package org.neo4j.gds.ml.core.samplers;

import com.neo4j.gds.shaded.org.apache.commons.lang3.mutable.MutableDouble;
import com.neo4j.gds.shaded.org.apache.commons.lang3.mutable.MutableLong;
import java.util.Arrays;
import java.util.SplittableRandom;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.mem.Estimate;
import org.neo4j.gds.mem.MemoryRange;

/**
 * Samples biased second-order random walks over a {@link Graph} in the style of node2vec.
 *
 * <p>Each step uses rejection sampling. A candidate neighbour of the current node is drawn with
 * probability proportional to its relationship weight. The candidate is then accepted with one of
 * three normalized probabilities, depending on its relation to the previously visited node:
 * <ul>
 *   <li>it is the previous node ("return");</li>
 *   <li>it is a neighbour of the previous node ("same distance");</li>
 *   <li>it is neither ("in/out").</li>
 * </ul>
 *
 * <p>Instances keep mutable scratch state (the random generator and two mutable holders) without
 * any synchronization. Do not share an instance between concurrently running walks.
 */
public class RandomWalkSampler {
    /** Sentinel returned when no next node could be chosen (e.g. a node without relationships). */
    private static final long NO_MORE_NODES = -1;
    /** Upper bound on rejection-sampling attempts per step before falling back to an unbiased pick. */
    private static final int MAX_TRIES = 100;

    private final Graph graph;
    private final int walkLength;
    private SplittableRandom random;
    // Scratch holders written from inside the relationship-iteration lambda.
    private final MutableDouble currentWeight = new MutableDouble(0.0d);
    private final MutableLong randomNeighbour = new MutableLong(NO_MORE_NODES);
    private final double normalizedReturnProbability;
    private final double normalizedSameDistanceProbability;
    private final double normalizedInOutProbability;
    private final CumulativeWeightSupplier cumulativeWeightSupplier;
    private final long randomSeed;

    /**
     * Supplies, per node, the total weight used to scale the weighted neighbour draw.
     *
     * <p>For unbiased first-step sampling this is expected to equal the sum of the weights seen by
     * {@code graph.forEachRelationship(node, 1.0, ...)} for that node.
     */
    @FunctionalInterface
    public interface CumulativeWeightSupplier {
        double forNode(long node);
    }

    /**
     * Creates a sampler from raw node2vec-style factors.
     *
     * <p>The unnormalized acceptance weights are {@code 1/returnFactor} (going back),
     * {@code 1} (same distance) and {@code 1/inOutFactor} (moving away). All three are divided by
     * their maximum so the largest becomes exactly {@code 1.0}.
     *
     * <p>NOTE(review): the factors are not validated here. A zero factor yields infinite/NaN
     * probabilities and a negative one yields a probability that never accepts. Presumably the
     * caller guarantees positive values; confirm.
     *
     * @param graph                    graph to walk on
     * @param cumulativeWeightSupplier per-node total relationship weight
     * @param walkLength               maximum number of nodes in a produced walk
     * @param returnFactor             node2vec return parameter (p)
     * @param inOutFactor              node2vec in/out parameter (q)
     * @param randomSeed               base seed, see {@link #prepareForNewNode(long)}
     * @return a new sampler
     */
    public static RandomWalkSampler create(
        Graph graph,
        CumulativeWeightSupplier cumulativeWeightSupplier,
        int walkLength,
        double returnFactor,
        double inOutFactor,
        long randomSeed
    ) {
        double returnWeight = 1.0d / returnFactor;
        double inOutWeight = 1.0d / inOutFactor;
        double maxWeight = Math.max(Math.max(returnWeight, 1.0d), inOutWeight);
        return new RandomWalkSampler(
            graph,
            cumulativeWeightSupplier,
            walkLength,
            returnWeight / maxWeight,
            1.0d / maxWeight,
            inOutWeight / maxWeight,
            randomSeed
        );
    }

    /**
     * Creates a sampler from already-normalized acceptance probabilities.
     *
     * <p>The probabilities are presumably expected in {@code [0, 1]}, as produced by
     * {@link #create}; this constructor does not check them.
     *
     * @param graph                             graph to walk on
     * @param cumulativeWeightSupplier          per-node total relationship weight
     * @param walkLength                        maximum number of nodes in a produced walk
     * @param normalizedReturnProbability       acceptance probability for stepping back to the previous node
     * @param normalizedSameDistanceProbability acceptance probability for a neighbour of the previous node
     * @param normalizedInOutProbability        acceptance probability for any other neighbour
     * @param randomSeed                        base seed for the random generator
     */
    public RandomWalkSampler(
        Graph graph,
        CumulativeWeightSupplier cumulativeWeightSupplier,
        int walkLength,
        double normalizedReturnProbability,
        double normalizedSameDistanceProbability,
        double normalizedInOutProbability,
        long randomSeed
    ) {
        this.randomSeed = randomSeed;
        this.cumulativeWeightSupplier = cumulativeWeightSupplier;
        this.graph = graph;
        this.walkLength = walkLength;
        this.normalizedReturnProbability = normalizedReturnProbability;
        this.normalizedSameDistanceProbability = normalizedSameDistanceProbability;
        this.normalizedInOutProbability = normalizedInOutProbability;
        this.random = new SplittableRandom(randomSeed);
    }

    /**
     * Estimates the memory used by one sampler.
     *
     * <p>The lower bound is the instance plus one walk array. The upper bound allows a second array
     * for the truncated copy made when a walk ends early.
     *
     * @param walkLength maximum walk length
     * @return the estimated memory range
     */
    public static MemoryRange memoryEstimation(long walkLength) {
        long instanceSize = Estimate.sizeOfInstance(RandomWalkSampler.class);
        long walkArraySize = Estimate.sizeOfLongArray(walkLength);
        return MemoryRange.of(instanceSize + walkArraySize, instanceSize + 2 * walkArraySize);
    }

    /**
     * Samples one walk starting at {@code startNode}.
     *
     * <p>The walk contains at most {@code walkLength} node ids, starting with {@code startNode}. It
     * is shorter if a node with no next step is reached. A non-positive {@code walkLength} yields an
     * empty array.
     *
     * @param startNode id of the first node of the walk
     * @return the sampled node ids, in visiting order
     */
    public long[] walk(long startNode) {
        if (this.walkLength <= 0) {
            return new long[0];
        }
        long[] path = new long[this.walkLength];
        path[0] = startNode;
        if (this.walkLength == 1) {
            return path;
        }

        // The first step has no previous node, so it is a plain weighted neighbour draw.
        path[1] = randomNeighbour(startNode);
        if (path[1] == NO_MORE_NODES) {
            return new long[]{startNode};
        }

        for (int i = 2; i < this.walkLength; i++) {
            long next = walkOneStep(path[i - 2], path[i - 1]);
            if (next == NO_MORE_NODES) {
                // Dead end: return only the nodes visited so far.
                return Arrays.copyOf(path, i);
            }
            path[i] = next;
        }
        return path;
    }

    /**
     * Chooses the next node after {@code currentNode}, given that the walk came from {@code previousNode}.
     *
     * @return the chosen neighbour, or {@link #NO_MORE_NODES} if {@code currentNode} has degree 0
     */
    private long walkOneStep(long previousNode, long currentNode) {
        int degree = this.graph.degree(currentNode);
        if (degree == 0) {
            return NO_MORE_NODES;
        }
        if (degree == 1) {
            // Only one candidate exists, so rejection sampling cannot change the outcome.
            return randomNeighbour(currentNode);
        }

        // Loop-invariant acceptance thresholds.
        double lowerThreshold = Math.min(this.normalizedSameDistanceProbability, this.normalizedInOutProbability);
        double upperThreshold = Math.max(this.normalizedSameDistanceProbability, this.normalizedInOutProbability);

        for (int tries = 0; tries < MAX_TRIES; tries++) {
            long candidate = randomNeighbour(currentNode);
            double r = this.random.nextDouble();

            if (candidate == previousNode) {
                if (r < this.normalizedReturnProbability) {
                    return candidate;
                }
            } else if (r < lowerThreshold) {
                // Accepted whatever the distance to previousNode, so the exists() lookup is skipped.
                return candidate;
            } else if (r < upperThreshold) {
                double acceptance = isNeighbour(previousNode, candidate)
                    ? this.normalizedSameDistanceProbability
                    : this.normalizedInOutProbability;
                if (r < acceptance) {
                    return candidate;
                }
            }
            // r >= upperThreshold, or the candidate was rejected: try again.
        }

        // No candidate was accepted within MAX_TRIES; fall back to a plain weighted draw.
        return randomNeighbour(currentNode);
    }

    /**
     * Draws a neighbour of {@code node} with probability proportional to its relationship weight.
     *
     * <p>Relationships without a stored weight count as {@code 1.0}.
     *
     * @return the drawn neighbour, or {@link #NO_MORE_NODES} if {@code node} has no relationships
     */
    private long randomNeighbour(long node) {
        double randomWeight = this.cumulativeWeightSupplier.forNode(node) * this.random.nextDouble();
        this.currentWeight.setValue(0.0d);
        this.randomNeighbour.setValue(NO_MORE_NODES);
        this.graph.forEachRelationship(node, 1.0d, (source, target, weight) -> {
            // Remember every visited target. If rounding (or a supplier that disagrees with this
            // running sum) leaves the sum just short of randomWeight, the last neighbour is
            // returned instead of NO_MORE_NODES, which would silently end the walk.
            this.randomNeighbour.setValue(target);
            // true = keep scanning; false once the running sum reaches randomWeight. The consumer
            // contract is assumed to treat false as "stop iterating".
            return randomWeight > this.currentWeight.addAndGet(weight);
        });
        return this.randomNeighbour.longValue();
    }

    /**
     * Reseeds the random generator with {@code randomSeed + node}.
     *
     * <p>This makes the walks sampled for a given start node independent of the order in which
     * start nodes are processed.
     *
     * @param node the start node of the upcoming walk(s)
     */
    public void prepareForNewNode(long node) {
        this.random = new SplittableRandom(this.randomSeed + node);
    }

    /** Returns whether a relationship from {@code source} to {@code target} exists in the graph. */
    private boolean isNeighbour(long source, long target) {
        return this.graph.exists(source, target);
    }
}
