package org.neo4j.gds.ml.core.batch;

import java.util.List;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.function.IntFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.termination.TerminationFlag;

/* loaded from: input_file:org/neo4j/gds/ml/core/batch/BatchQueue.class */
/**
 * A source of {@link Batch}es that is drained through {@link #pop()}, either by the caller directly or in
 * parallel through one of the {@code parallelConsume} overloads.
 *
 * <p>Instances come from the static factories:
 * <ul>
 *   <li>{@link #consecutive(long)} and its overloads create a {@code ConsecutiveBatchQueue};</li>
 *   <li>{@link #fromArray(ReadOnlyHugeLongArray)} and its overload create an {@code ArraySourcedBatchQueue}.</li>
 * </ul>
 * How each batch is formed is defined by those subclasses.
 */
public abstract class BatchQueue {

    /** Batch size used by the factory overloads that do not take an explicit batch size. */
    public static final int DEFAULT_BATCH_SIZE = 100;

    /** Total number of elements covered by this queue, as passed to the constructor. */
    final long totalSize;

    /** Batch size passed to the constructor. */
    final int batchSize;

    // Starts at 0 and is not modified in this class.
    // NOTE(review): presumably the index of the next batch to hand out, advanced by subclasses in pop() — confirm.
    long currentBatch = 0;

    /**
     * Runnable that keeps popping batches from the enclosing queue and passes each one to its consumer.
     * It stops as soon as {@link #pop()} returns an empty {@link Optional}.
     *
     * <p>NOTE(review): the decompiler reports this class was originally {@code private}. It is left
     * {@code public} so existing callers keep compiling, but it can likely be narrowed.
     */
    public class ConsumerTask implements Runnable {
        private final Consumer<Batch> batchConsumer;

        ConsumerTask(Consumer<Batch> batchConsumer) {
            this.batchConsumer = batchConsumer;
        }

        @Override
        public void run() {
            Optional<Batch> batch = BatchQueue.this.pop();
            while (batch.isPresent()) {
                batchConsumer.accept(batch.get());
                batch = BatchQueue.this.pop();
            }
        }
    }

    /**
     * Creates a queue over {@code totalSize} elements.
     *
     * <p>NOTE(review): the decompiler reports this constructor was originally package-private. It is kept
     * {@code public} here to avoid breaking callers.
     *
     * @param totalSize total number of elements covered by the queue
     * @param batchSize batch size used by the concrete implementation
     */
    public BatchQueue(long totalSize, int batchSize) {
        this.totalSize = totalSize;
        this.batchSize = batchSize;
    }

    /**
     * Computes a batch size for {@code totalSize} elements and {@code concurrency} workers.
     * The work is delegated to {@link ParallelUtil#adjustedBatchSize}, and the result is clamped to the
     * {@code int} range.
     *
     * @param totalSize    number of elements to split into batches
     * @param minBatchSize batch-size hint forwarded to {@code adjustedBatchSize} (presumably a lower bound — confirm)
     * @param concurrency  level of concurrency the batches are meant for
     * @return the adjusted batch size, at most {@link Integer#MAX_VALUE}
     */
    public static int computeBatchSize(long totalSize, int minBatchSize, Concurrency concurrency) {
        long adjusted = ParallelUtil.adjustedBatchSize(totalSize, concurrency, minBatchSize);
        // Clamping first means toIntExact can never throw here.
        return Math.toIntExact(Math.min(Integer.MAX_VALUE, adjusted));
    }

    /** Creates a consecutive queue over {@code totalSize} elements using {@link #DEFAULT_BATCH_SIZE}. */
    public static BatchQueue consecutive(long totalSize) {
        return consecutive(totalSize, DEFAULT_BATCH_SIZE);
    }

    /**
     * Creates a consecutive queue whose batch size comes from
     * {@link #computeBatchSize(long, int, Concurrency)}.
     */
    public static BatchQueue consecutive(long totalSize, int minBatchSize, Concurrency concurrency) {
        return consecutive(totalSize, computeBatchSize(totalSize, minBatchSize, concurrency));
    }

    /** Creates a consecutive queue over {@code totalSize} elements with the given batch size. */
    public static BatchQueue consecutive(long totalSize, int batchSize) {
        return new ConsecutiveBatchQueue(totalSize, batchSize);
    }

    /** Creates a queue backed by {@code data} using {@link #DEFAULT_BATCH_SIZE}. */
    public static BatchQueue fromArray(ReadOnlyHugeLongArray data) {
        return fromArray(data, DEFAULT_BATCH_SIZE);
    }

    /** Creates a queue backed by {@code data} with the given batch size. */
    public static BatchQueue fromArray(ReadOnlyHugeLongArray data, int batchSize) {
        return new ArraySourcedBatchQueue(data, batchSize);
    }

    /**
     * Returns the next batch, or an empty {@link Optional} once the queue is exhausted.
     *
     * <p>{@code parallelConsume} runs several {@link ConsumerTask}s against the same queue, and each of them
     * calls this method. NOTE(review): implementations presumably must therefore be thread-safe — confirm in
     * the subclasses.
     */
    abstract Optional<Batch> pop();

    /** Returns the total number of elements covered by this queue. */
    public long totalSize() {
        return totalSize;
    }

    /**
     * Drains the queue with {@code concurrency.value()} tasks that all share the single {@code consumer}.
     * The same consumer instance is handed to every task.
     *
     * @param consumer        receives every batch
     * @param concurrency     number of tasks to run
     * @param terminationFlag forwarded to {@link RunWithConcurrency}
     */
    public void parallelConsume(Consumer<Batch> consumer, Concurrency concurrency, TerminationFlag terminationFlag) {
        parallelConsume(concurrency, ignored -> consumer, terminationFlag);
    }

    /**
     * Drains the queue with one {@link ConsumerTask} per consumer in {@code consumers}.
     * The tasks are run through {@link RunWithConcurrency}.
     *
     * @param concurrency     concurrency setting; with assertions enabled it must equal {@code consumers.size()}
     * @param consumers       one consumer per task
     * @param terminationFlag forwarded to {@link RunWithConcurrency}
     */
    public void parallelConsume(
        Concurrency concurrency,
        List<? extends Consumer<Batch>> consumers,
        TerminationFlag terminationFlag
    ) {
        assert consumers.size() == concurrency.value()
            : "Expected " + concurrency.value() + " consumers but got " + consumers.size();

        // Explicitly typed so the same tasks(Stream<? extends Runnable>) overload is selected as before.
        Stream<? extends Runnable> tasks = consumers.stream().map(ConsumerTask::new);
        RunWithConcurrency.builder()
            .concurrency(concurrency)
            .tasks(tasks)
            .terminationFlag(terminationFlag)
            .run();
    }

    /**
     * Drains the queue with {@code concurrency.value()} tasks.
     * Task {@code k} consumes batches with {@code consumerFactory.apply(k)}.
     *
     * @param concurrency     number of tasks / consumers to create
     * @param consumerFactory maps a task index in {@code [0, concurrency.value())} to its consumer
     * @param terminationFlag forwarded to {@link RunWithConcurrency}
     */
    public void parallelConsume(
        Concurrency concurrency,
        IntFunction<? extends Consumer<Batch>> consumerFactory,
        TerminationFlag terminationFlag
    ) {
        List<? extends Consumer<Batch>> consumers = IntStream.range(0, concurrency.value())
            .mapToObj(consumerFactory)
            .collect(Collectors.toList());
        parallelConsume(concurrency, consumers, terminationFlag);
    }
}
