package org.neo4j.gds.graphsampling.samplers.rw.cnarw;

import com.neo4j.gds.shaded.org.apache.commons.lang3.mutable.MutableInt;
import java.util.SplittableRandom;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.compress.DoubleArrayBuffer;
import org.neo4j.gds.api.compress.LongArrayBuffer;
import org.neo4j.gds.core.utils.TwoArraysSort;
import org.neo4j.gds.functions.similarity.OverlapSimilarity;
import org.neo4j.gds.graphsampling.samplers.rw.NextNodeStrategy;

/* loaded from: input_file:org/neo4j/gds/graphsampling/samplers/rw/cnarw/WeightedCommonNeighbourAwareNextNodeStrategy.class */
/**
 * {@link NextNodeStrategy} that picks the next node of a walk among the neighbours of the
 * current node, taking relationship weights and neighbourhood overlap into account.
 *
 * <p>A candidate neighbour is drawn with probability proportional to the weight of the
 * relationship leading to it. The candidate is then accepted with probability
 * {@code 1 - overlap}, where {@code overlap} is the weighted overlap similarity between the
 * neighbourhood of the current node and that of the candidate; otherwise a new candidate is drawn.
 *
 * <p>The four scratch buffers below are reused between calls and mutated without any
 * synchronization, so a single instance must not be used by several threads at once.
 */
public class WeightedCommonNeighbourAwareNextNodeStrategy implements NextNodeStrategy {
    private final Graph inputGraph;
    private final SplittableRandom rng;
    // Neighbourhood (ids + relationship weights) of the current node, ordered by TwoArraysSort.
    private final LongArrayBuffer uSortedNeighsIds = new LongArrayBuffer();
    private final DoubleArrayBuffer uSortedNeighsWeights = new DoubleArrayBuffer();
    // Neighbourhood (ids + relationship weights) of the candidate node, ordered by TwoArraysSort.
    private final LongArrayBuffer vSortedNeighsIds = new LongArrayBuffer();
    private final DoubleArrayBuffer vSortedNeighsWeights = new DoubleArrayBuffer();

    /**
     * Creates a strategy operating on the given graph.
     *
     * @param graph            graph whose relationships and weights are read
     * @param splittableRandom source of randomness for candidate selection and acceptance
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public WeightedCommonNeighbourAwareNextNodeStrategy(Graph graph, SplittableRandom splittableRandom) {
        this.inputGraph = graph;
        this.rng = splittableRandom;
    }

    /**
     * Selects the node to move to from node {@code j}.
     *
     * <p>Keeps drawing weight-proportional candidates until one is accepted, so the call only
     * returns once the acceptance test succeeds.
     *
     * @param j id of the node the walk is currently at
     * @return id of the accepted neighbour
     * @throws IllegalArgumentException if the weights of the relationships of {@code j} do not sum
     *                                  to a positive value (for instance when {@code j} has no
     *                                  relationships), because the random draw then has an empty range
     */
    @Override // org.neo4j.gds.graphsampling.samplers.rw.NextNodeStrategy
    public long getNextNode(long j) {
        sortNeighsWithWeights(this.inputGraph, j, this.uSortedNeighsIds, this.uSortedNeighsWeights);
        double totalWeight = sumWeights(this.uSortedNeighsIds, this.uSortedNeighsWeights.buffer);

        long candidateNode;
        boolean rejected;
        do {
            candidateNode = getCandidateNode(this.uSortedNeighsIds, this.uSortedNeighsWeights.buffer, totalWeight);
            sortNeighsWithWeights(this.inputGraph, candidateNode, this.vSortedNeighsIds, this.vSortedNeighsWeights);
            double overlap = computeOverlapSimilarity(
                this.uSortedNeighsIds,
                this.uSortedNeighsWeights,
                this.vSortedNeighsIds,
                this.vSortedNeighsWeights
            );
            // The more the two neighbourhoods overlap, the less likely the candidate is accepted.
            double acceptanceProbability = 1.0d - overlap;
            rejected = this.rng.nextDouble() > acceptanceProbability;
        } while (rejected);
        return candidateNode;
    }

    /**
     * Computes the weighted overlap similarity of two neighbourhoods.
     *
     * @return the similarity reported by {@link OverlapSimilarity}, or {@code 0.0} when that
     *         result is {@code NaN}
     */
    private double computeOverlapSimilarity(
        LongArrayBuffer firstIds,
        DoubleArrayBuffer firstWeights,
        LongArrayBuffer secondIds,
        DoubleArrayBuffer secondWeights
    ) {
        // NOTE(review): the trailing 0.0 looks like a similarity cutoff - confirm against OverlapSimilarity.
        double similarity = OverlapSimilarity.computeWeightedSimilarity(
            firstIds.buffer,
            firstIds.length,
            secondIds.buffer,
            secondIds.length,
            firstWeights.buffer,
            secondWeights.buffer,
            0.0d
        );
        return Double.isNaN(similarity) ? 0.0d : similarity;
    }

    /**
     * Sums the first {@code neighbourIds.length} entries of {@code weights}.
     *
     * @param neighbourIds buffer whose {@code length} gives the number of valid weights
     * @param weights      weights aligned with the ids in {@code neighbourIds}
     * @return the sum of the valid weights, {@code 0.0} when there are none
     */
    private double sumWeights(LongArrayBuffer neighbourIds, double[] weights) {
        double total = 0.0d;
        int count = neighbourIds.length;
        for (int index = 0; index < count; index++) {
            total += weights[index];
        }
        return total;
    }

    /**
     * Draws one neighbour with probability proportional to its weight.
     *
     * <p>A value is drawn from {@code [0, totalWeight)} and the weights are subtracted from it in
     * order until the remainder is no longer positive; the neighbour whose weight was subtracted
     * last is returned.
     *
     * <p>The scan is bounded to the valid part of the buffers. Without that bound a draw of
     * exactly {@code 0.0} would address index {@code -1}, and floating-point rounding could leave
     * a positive remainder after the last weight and make the scan run past
     * {@code neighbourIds.length} into stale or non-existent entries.
     *
     * @param neighbourIds ids of the neighbours; only the first {@code length} entries are valid
     * @param weights      weights aligned with the ids in {@code neighbourIds}
     * @param totalWeight  sum of the valid weights
     * @return id of the selected neighbour
     * @throws IllegalArgumentException if {@code totalWeight} is not greater than zero
     */
    private long getCandidateNode(LongArrayBuffer neighbourIds, double[] weights, double totalWeight) {
        double draw = this.rng.nextDouble(0.0d, totalWeight);
        int lastIndex = neighbourIds.length - 1;
        int index = 0;
        // Subtracting the first weight up front makes a draw of exactly 0.0 select index 0.
        double remaining = draw - weights[0];
        // Capping at lastIndex keeps rounding residue from pushing the scan beyond the valid entries.
        while (remaining > 0.0d && index < lastIndex) {
            index++;
            remaining -= weights[index];
        }
        return neighbourIds.buffer[index];
    }

    /**
     * Loads the neighbourhood of {@code nodeId} into the given buffers and sorts it.
     *
     * <p>Both buffers are resized to the degree of the node, filled with the target id and the
     * weight of every relationship of the node, and then handed to
     * {@link TwoArraysSort#sortDoubleArrayByLongValues} so that ids and weights are reordered
     * together (by id, going by the name of that helper).
     *
     * @param graph   graph to read the relationships from
     * @param nodeId  node whose neighbourhood is loaded
     * @param ids     receives the neighbour ids
     * @param weights receives the relationship weights, aligned with {@code ids}
     */
    private static void sortNeighsWithWeights(Graph graph, long nodeId, LongArrayBuffer ids, DoubleArrayBuffer weights) {
        int degree = graph.degree(nodeId);
        ids.ensureCapacity(degree);
        ids.length = degree;
        weights.ensureCapacity(degree);
        weights.length = degree;

        // A mutable counter is needed because the lambda can only capture effectively final locals.
        MutableInt nextSlot = new MutableInt(0);
        // NOTE(review): the 0.0 argument is presumably the fallback weight for relationships
        // without a property value - confirm against the Graph API.
        graph.forEachRelationship(nodeId, 0.0d, (source, target, weight) -> {
            int slot = nextSlot.getAndIncrement();
            ids.buffer[slot] = target;
            weights.buffer[slot] = weight;
            return true;
        });
        TwoArraysSort.sortDoubleArrayByLongValues(ids.buffer, weights.buffer, degree);
    }
}
