package org.neo4j.graphalgo.impl.nn;

import com.carrotsearch.hppc.LongContainer;
import com.carrotsearch.hppc.LongHashSet;
import com.carrotsearch.hppc.cursors.LongCursor;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Supplier;
import org.neo4j.graphalgo.Algorithm;
import org.neo4j.graphalgo.api.Graph;
import org.neo4j.graphalgo.api.RelationshipIterator;
import org.neo4j.graphalgo.core.ProcedureConfiguration;
import org.neo4j.graphalgo.core.huge.HugeGraph;
import org.neo4j.graphalgo.core.huge.loader.IdMapBuilder;
import org.neo4j.graphalgo.core.huge.loader.IdsAndProperties;
import org.neo4j.graphalgo.core.huge.loader.NodeImporter;
import org.neo4j.graphalgo.core.huge.loader.NodesBatchBuffer;
import org.neo4j.graphalgo.core.utils.ParallelUtil;
import org.neo4j.graphalgo.core.utils.Pools;
import org.neo4j.graphalgo.core.utils.paged.AllocationTracker;
import org.neo4j.graphalgo.core.utils.paged.HugeLongArrayBuilder;
import org.neo4j.graphalgo.impl.nn.HugeRelationshipsBuilder;
import org.neo4j.graphalgo.impl.results.SimilarityResult;
import org.neo4j.graphalgo.impl.similarity.AnnTopKConsumer;
import org.neo4j.graphalgo.impl.similarity.RleDecoder;
import org.neo4j.graphalgo.impl.similarity.SimilarityComputer;
import org.neo4j.graphalgo.impl.similarity.SimilarityInput;
import org.neo4j.graphdb.Direction;
import org.neo4j.logging.Log;
import org.roaringbitmap.RoaringBitmap;

/* loaded from: input_file:org/neo4j/graphalgo/impl/nn/ApproxNearestNeighbors.class */
/**
 * Approximate nearest-neighbour search over {@link SimilarityInput}s.
 *
 * <p>Every node's top-k list is first seeded by comparing it against randomly selected other
 * nodes ({@link InitTask}). Each following round:
 * <ol>
 *   <li>builds a graph from the current top-k lists;</li>
 *   <li>splits every node's neighbours into an "old" and a (optionally sampled) "new" graph via
 *       {@code NewOldGraph} and the per-node visited-relationship bitmaps ({@link SetupTask});</li>
 *   <li>compares new neighbours with each other and with old neighbours ({@link ComputeTask}),
 *       then merges the thread-local results into the global top-k lists.</li>
 * </ol>
 * The search stops after {@code iterations} rounds, or earlier when the number of top-k changes
 * in a round is zero or below {@code inputs.length * |topK| * precision}.
 *
 * @param <T> the input type whose pairwise similarity is computed
 */
public class ApproxNearestNeighbors<T extends SimilarityInput> extends Algorithm<ApproxNearestNeighbors<T>> {
    /** Inputs being compared, addressed by node index; set to {@code null} by {@link #release()}. */
    private T[] inputs;
    /** Passed unchanged to {@link AnnTopKConsumer}; only its magnitude is used for sizing here. */
    private final int topK;
    /** Maximum number of refinement rounds (config key {@code iterations}). */
    private final int iterations;
    /** Global per-node top-k lists; this is what {@link #result()} returns. */
    private final AnnTopKConsumer[] topKConsumers;
    private final double similarityCutoff;
    private final Log log;
    private final Supplier<RleDecoder> rleDecoderFactory;
    private final Random random;
    private final int concurrency;
    private final SimilarityComputer<T> similarityComputer;
    /** Fraction of {@code inputs.length * |topK|} below which a round's change count ends the search. */
    private final double precision;
    /** Sampling factor; the per-node sample size is {@code min(p, 1) * |topK|}. */
    private final double p;
    /** Whether neighbour lists are sampled down to the sample size. */
    private final boolean sampling;
    /** Per-node bitmaps of neighbours sampled in earlier rounds; accumulated in {@link #compute()}. */
    private final RoaringBitmap[] visitedRelationships;
    /**
     * Shared work counter that hands out node ids to {@link InitTask}/{@link ComputeTask} workers.
     * Never reassigned, so {@code final} (the original {@code volatile} was unnecessary).
     */
    private final AtomicLong nodeQueue = new AtomicLong();
    private final AtomicInteger actualIterations = new AtomicInteger();
    private final ExecutorService executor = Pools.DEFAULT;

    /**
     * Creates the algorithm.
     *
     * <p>Configuration keys read (defaults in parentheses): {@code iterations} (10),
     * {@code precision} (0.001), {@code p} (0.5), {@code randomSeed} (1), {@code sampling} (true),
     * plus the configured concurrency.
     *
     * @param configuration      procedure configuration supplying the keys above
     * @param inputs             the inputs to compare
     * @param similarityCutoff   passed to the similarity computer for every comparison
     * @param rleDecoderFactory  creates one {@link RleDecoder} per worker task
     * @param similarityComputer computes a {@link SimilarityResult} (or {@code null}) for a pair
     * @param topK               number of neighbours to keep per node (see {@link #topK})
     * @param log                receives one progress line per round
     */
    public ApproxNearestNeighbors(
            ProcedureConfiguration configuration,
            T[] inputs,
            double similarityCutoff,
            Supplier<RleDecoder> rleDecoderFactory,
            SimilarityComputer<T> similarityComputer,
            int topK,
            Log log) {
        this.inputs = inputs;
        this.topK = topK;
        this.iterations = configuration.getNumber("iterations", 10).intValue();
        this.precision = configuration.getNumber("precision", Double.valueOf(0.001d)).doubleValue();
        this.p = configuration.getNumber("p", Double.valueOf(0.5d)).doubleValue();
        this.random = new Random(configuration.getNumber("randomSeed", 1).longValue());
        this.sampling = configuration.getBool("sampling", true).booleanValue();
        this.topKConsumers = AnnTopKConsumer.initializeTopKConsumers(inputs.length, topK);
        this.visitedRelationships = ANNUtils.initializeRoaringBitmaps(inputs.length);
        this.similarityCutoff = similarityCutoff;
        this.log = log;
        this.concurrency = configuration.getConcurrency();
        this.rleDecoderFactory = rleDecoderFactory;
        this.similarityComputer = similarityComputer;
    }

    /**
     * Runs the random initialisation followed by up to {@link #iterations} refinement rounds.
     * The number of rounds actually executed is available from {@link #iterations()}.
     */
    public void compute() {
        double sampleSize = Math.min(p, 1.0d) * Math.abs(topK);
        ParallelUtil.runWithConcurrency(concurrency, createInitTasks(), executor);
        IdsAndProperties nodes = buildNodes(inputs);

        // Bitmaps filled by this round's SetupTasks; folded into visitedRelationships at the
        // start of the next round.
        RoaringBitmap[] visitedThisRound = ANNUtils.initializeRoaringBitmaps(inputs.length);
        for (int iteration = 1; iteration <= iterations; iteration++) {
            for (int nodeId = 0; nodeId < inputs.length; nodeId++) {
                visitedRelationships[nodeId] = RoaringBitmap.or(visitedRelationships[nodeId], visitedThisRound[nodeId]);
            }
            visitedThisRound = ANNUtils.initializeRoaringBitmaps(inputs.length);

            // Graph of the current top-k lists.
            HugeRelationshipsBuilder.HugeRelationshipsBuilderWithBuffer currentBuilder =
                    new HugeRelationshipsBuilder(nodes).withBuffer();
            currentBuilder.addRelationshipsFrom(topKConsumers);
            HugeGraph currentGraph = ANNUtils.hugeGraph(nodes, currentBuilder.build());

            HugeRelationshipsBuilder oldBuilder = new HugeRelationshipsBuilder(nodes);
            HugeRelationshipsBuilder newBuilder = new HugeRelationshipsBuilder(nodes);
            // NOTE(review): setup deliberately runs with concurrency 1; the reason (builder or
            // bitmap thread-safety, determinism) is not visible in this file.
            ParallelUtil.runWithConcurrency(
                    1, setupTasks(sampleSize, visitedThisRound, currentGraph, oldBuilder, newBuilder), executor);

            Collection<NeighborhoodTask> tasks = computeTasks(
                    sampleSize,
                    ANNUtils.hugeGraph(nodes, oldBuilder.build()),
                    ANNUtils.hugeGraph(nodes, newBuilder.build()));
            ParallelUtil.runWithConcurrency(concurrency, tasks, executor);

            int changes = mergeConsumers(tasks);
            log.info("ANN: Changes in iteration %d: %d", iteration, changes);
            actualIterations.set(iteration);
            if (shouldTerminate(changes)) {
                return;
            }
        }
    }

    /**
     * Splits the node range into batches, one {@link SetupTask} per batch. The task count is
     * {@code length / batchSize + 1}, so the last batch may be empty.
     */
    private Collection<Runnable> setupTasks(
            double sampleSize,
            RoaringBitmap[] visitedThisRound,
            HugeGraph currentGraph,
            HugeRelationshipsBuilder oldBuilder,
            HugeRelationshipsBuilder newBuilder) {
        int batchSize = ParallelUtil.adjustedBatchSize(inputs.length, concurrency, 100);
        int taskCount = (inputs.length / batchSize) + 1;
        List<Runnable> tasks = new ArrayList<>(taskCount);
        long startNodeId = 0;
        for (int task = 0; task < taskCount; task++) {
            // Use the running long offset instead of the int product task * batchSize.
            long nodeCount = Math.min(batchSize, inputs.length - startNodeId);
            tasks.add(new SetupTask(
                    new NewOldGraph(currentGraph, visitedRelationships),
                    visitedThisRound,
                    oldBuilder,
                    newBuilder,
                    sampleSize,
                    startNodeId,
                    nodeCount));
            startNodeId += nodeCount;
        }
        return tasks;
    }

    /** Resets the work counter and creates one {@link ComputeTask} per unit of concurrency. */
    private Collection<NeighborhoodTask> computeTasks(double sampleSize, HugeGraph oldGraph, HugeGraph newGraph) {
        nodeQueue.set(0L);
        List<NeighborhoodTask> tasks = new ArrayList<>(concurrency);
        for (int worker = 0; worker < concurrency; worker++) {
            tasks.add(new ComputeTask(rleDecoderFactory, inputs.length, oldGraph, newGraph, sampleSize));
        }
        return tasks;
    }

    /** Imports every input's id into an id map whose highest id is the largest {@code getId()}. */
    private IdsAndProperties buildNodes(T[] nodeInputs) {
        HugeLongArrayBuilder idArrayBuilder = HugeLongArrayBuilder.of(nodeInputs.length, AllocationTracker.EMPTY);
        NodeImporter nodeImporter = new NodeImporter(idArrayBuilder, null);
        long maxNodeId = 0;
        NodesBatchBuffer buffer = new NodesBatchBuffer(null, -1, nodeInputs.length, false);
        for (T input : nodeInputs) {
            maxNodeId = Math.max(maxNodeId, input.getId());
            buffer.add(input.getId(), -1L);
            if (buffer.isFull()) {
                nodeImporter.importNodes(buffer, null);
                buffer.reset();
            }
        }
        nodeImporter.importNodes(buffer, null);
        return new IdsAndProperties(
                IdMapBuilder.build(idArrayBuilder, maxNodeId, 1, AllocationTracker.EMPTY),
                Collections.emptyMap());
    }

    /** Resets the work counter and creates one {@link InitTask} per unit of concurrency. */
    private List<Runnable> createInitTasks() {
        nodeQueue.set(0L);
        List<Runnable> tasks = new ArrayList<>(concurrency);
        for (int worker = 0; worker < concurrency; worker++) {
            tasks.add(new InitTask(rleDecoderFactory));
        }
        return tasks;
    }

    /** Merges every task's local top-k lists into {@link #topKConsumers}; returns the total change count. */
    private int mergeConsumers(Iterable<NeighborhoodTask> tasks) {
        int changes = 0;
        for (NeighborhoodTask task : tasks) {
            changes += task.mergeInto(topKConsumers);
        }
        return changes;
    }

    /**
     * Returns whether a round with {@code changes} top-k updates should end the search.
     * The threshold is computed in {@code double}; the previous {@code int} product
     * {@code inputs.length * |topK|} could overflow for large inputs.
     */
    private boolean shouldTerminate(int changes) {
        return changes == 0 || (double) changes < (double) inputs.length * Math.abs(topK) * precision;
    }

    @Override
    public ApproxNearestNeighbors<T> me() {
        return this;
    }

    /** Drops the reference to the inputs. */
    @Override
    public void release() {
        this.inputs = null;
    }

    /** Returns the per-node top-k lists (the live array, not a copy). */
    public AnnTopKConsumer[] result() {
        return topKConsumers;
    }

    /** Returns the number of refinement rounds actually executed by {@link #compute()}. */
    public int iterations() {
        return actualIterations.get();
    }

    /** Collects the targets of all relationships of {@code nodeId} in {@code direction}. */
    public LongHashSet findNeighbors(long nodeId, RelationshipIterator relationships, Direction direction) {
        LongHashSet neighbors = new LongHashSet();
        relationships.forEachRelationship(nodeId, direction, (source, target) -> {
            neighbors.add(target);
            return true;
        });
        return neighbors;
    }

    /**
     * Worker for one refinement round: claims nodes from {@link #nodeQueue} and compares the
     * node's new neighbours pairwise and against its old neighbours. Results go into
     * thread-local top-k lists that {@link #mergeInto} folds into the global ones afterwards.
     */
    public class ComputeTask implements NeighborhoodTask {
        private final RleDecoder rleDecoder;
        private final AnnTopKConsumer[] localTopKConsumers;
        private final Graph oldGraph;
        private final Graph newGraph;
        private final double sampleRate;

        ComputeTask(
                Supplier<RleDecoder> decoderFactory,
                int consumerCount,
                HugeGraph oldGraph,
                HugeGraph newGraph,
                double sampleRate) {
            this.rleDecoder = decoderFactory.get();
            this.localTopKConsumers = AnnTopKConsumer.initializeTopKConsumers(consumerCount, topK);
            this.oldGraph = oldGraph.concurrentCopy();
            this.newGraph = newGraph.concurrentCopy();
            this.sampleRate = sampleRate;
        }

        @Override
        public void run() {
            while (true) {
                long nodeId = nodeQueue.getAndIncrement();
                if (nodeId >= inputs.length || !running()) {
                    return;
                }
                LongHashSet oldNeighbors = getNeighbors(nodeId, oldGraph);
                long[] newNeighbors = getNeighbors(nodeId, newGraph).toArray();
                for (int i = 0; i < newNeighbors.length; i++) {
                    int sourceId = Math.toIntExact(newNeighbors[i]);
                    T source = inputs[sourceId];
                    // Each unordered pair of new neighbours is compared once.
                    for (int j = i + 1; j < newNeighbors.length; j++) {
                        compareAndRecord(source, sourceId, Math.toIntExact(newNeighbors[j]));
                    }
                    for (LongCursor cursor : oldNeighbors) {
                        int targetId = Math.toIntExact(cursor.value);
                        if (sourceId != targetId) {
                            compareAndRecord(source, sourceId, targetId);
                        }
                    }
                }
            }
        }

        /** Compares source with target and, if a result is produced, records it in both directions. */
        private void compareAndRecord(T source, int sourceId, int targetId) {
            SimilarityResult result = similarityComputer.similarity(rleDecoder, source, inputs[targetId], similarityCutoff);
            if (result != null) {
                localTopKConsumers[sourceId].applyAsInt(result);
                localTopKConsumers[targetId].applyAsInt(result.reverse());
            }
        }

        /** Union of (optionally sampled) incoming neighbours and all outgoing neighbours of {@code nodeId}. */
        private LongHashSet getNeighbors(long nodeId, Graph graph) {
            long[] incoming = findNeighbors(nodeId, graph, Direction.INCOMING).toArray();
            long[] sampledIncoming = sampling ? ANNUtils.sampleNeighbors(incoming, sampleRate, random) : incoming;
            LongHashSet outgoing = findNeighbors(nodeId, graph, Direction.OUTGOING);
            LongHashSet neighbors = new LongHashSet();
            neighbors.addAll(sampledIncoming);
            neighbors.addAll((LongContainer) outgoing);
            return neighbors;
        }

        @Override
        public int mergeInto(AnnTopKConsumer[] target) {
            int changes = 0;
            for (int nodeId = 0; nodeId < target.length; nodeId++) {
                changes += target[nodeId].apply(localTopKConsumers[nodeId]);
            }
            return changes;
        }
    }

    /**
     * Seeds the global top-k lists: every claimed node is compared against |topK| randomly
     * selected other nodes. Each node id is handed out once by {@link #nodeQueue}, so each
     * node's consumer is written by a single task.
     */
    public class InitTask implements Runnable {
        private final RleDecoder rleDecoder;

        InitTask(Supplier<RleDecoder> decoderFactory) {
            this.rleDecoder = decoderFactory.get();
        }

        @Override
        public void run() {
            while (true) {
                long next = nodeQueue.getAndIncrement();
                if (next >= inputs.length || !running()) {
                    return;
                }
                int nodeId = Math.toIntExact(next);
                AnnTopKConsumer consumer = topKConsumers[nodeId];
                T source = inputs[nodeId];
                for (int candidate : ANNUtils.selectRandomNeighbors(Math.abs(topK), inputs.length, nodeId, random)) {
                    SimilarityResult result = similarityComputer.similarity(rleDecoder, source, inputs[candidate], similarityCutoff);
                    if (result != null) {
                        consumer.applyAsInt(result);
                    }
                }
            }
        }
    }

    /** A refinement-round worker whose local results can be merged into global top-k lists. */
    public interface NeighborhoodTask extends Runnable {
        /** Merges local results into {@code annTopKConsumerArr}; returns the number of changes. */
        int mergeInto(AnnTopKConsumer[] annTopKConsumerArr);
    }

    /**
     * Builds the old/new neighbour graphs for the node range
     * {@code [startNodeId, startNodeId + nodeCount)} and records the sampled new neighbours in
     * this round's visited bitmaps.
     */
    public class SetupTask implements Runnable {
        private final NewOldGraph graph;
        private final HugeRelationshipsBuilder.HugeRelationshipsBuilderWithBuffer oldRelationshipBuilder;
        private final HugeRelationshipsBuilder.HugeRelationshipsBuilderWithBuffer newRelationshipBuilder;
        private final double sampleSize;
        /** This round's bitmaps (distinct from the outer, accumulated {@code visitedRelationships}). */
        private final RoaringBitmap[] visitedThisRound;
        private final long startNodeId;
        private final long nodeCount;

        SetupTask(
                NewOldGraph graph,
                RoaringBitmap[] visitedThisRound,
                HugeRelationshipsBuilder oldBuilder,
                HugeRelationshipsBuilder newBuilder,
                double sampleSize,
                long startNodeId,
                long nodeCount) {
            this.graph = graph;
            this.visitedThisRound = visitedThisRound;
            this.oldRelationshipBuilder = oldBuilder.withBuffer();
            this.newRelationshipBuilder = newBuilder.withBuffer();
            this.sampleSize = sampleSize;
            this.startNodeId = startNodeId;
            this.nodeCount = nodeCount;
        }

        @Override
        public void run() {
            long endNodeId = startNodeId + nodeCount;
            for (long nodeId = startNodeId; nodeId < endNodeId; nodeId++) {
                if (!running()) {
                    // Terminated: skip the final flush, as before.
                    return;
                }
                int nodeIndex = Math.toIntExact(nodeId);
                for (LongCursor cursor : graph.findOldNeighbors(nodeId)) {
                    oldRelationshipBuilder.addRelationship(nodeId, cursor.value);
                }
                long[] candidates = graph.findNewNeighbors(nodeId).toArray();
                long[] sampled = sampling ? ANNUtils.sampleNeighbors(candidates, sampleSize, random) : candidates;
                for (long neighbor : sampled) {
                    newRelationshipBuilder.addRelationship(nodeId, neighbor);
                    visitedThisRound[nodeIndex].add(Math.toIntExact(neighbor));
                }
            }
            oldRelationshipBuilder.flushAll();
            newRelationshipBuilder.flushAll();
        }
    }
}
