package org.neo4j.gds.ml.core.samplers;

import com.neo4j.gds.shaded.com.carrotsearch.hppc.LongHashSet;
import com.neo4j.gds.shaded.org.apache.commons.lang3.mutable.MutableInt;
import java.util.Arrays;
import java.util.PrimitiveIterator;
import java.util.SplittableRandom;
import java.util.stream.LongStream;
import java.util.stream.Stream;
import org.neo4j.gds.api.properties.relationships.RelationshipCursor;

/* loaded from: input_file:org/neo4j/gds/ml/core/samplers/UniformSampler.class */
public class UniformSampler {
    private final SplittableRandom random;

    public UniformSampler(long j) {
        this.random = new SplittableRandom(j);
    }

    public LongStream sample(Stream<RelationshipCursor> stream, long j, int i) {
        return sample(stream.mapToLong((v0) -> {
            return v0.targetId();
        }), j, i);
    }

    public LongStream sample(LongStream longStream, long j, int i) {
        return ((double) i) / ((double) j) < 0.5d ? sampleWithIndexes(longStream, j, i) : sampleWithReservoir(longStream, j, i);
    }

    /* JADX WARN: Type inference failed for: r0v7, types: [java.util.PrimitiveIterator$OfLong] */
    public LongStream sampleWithReservoir(LongStream longStream, long j, int i) {
        if (i == 0) {
            return LongStream.empty();
        }
        if (i >= j) {
            return longStream;
        }
        long[] jArr = new long[i];
        ?? it = longStream.iterator();
        for (int i2 = 0; i2 < i; i2++) {
            jArr[i2] = it.nextLong();
        }
        double computeSkipFactor = computeSkipFactor(i);
        int computeNumberOfSkips = (int) ((i - 1) + computeNumberOfSkips(computeSkipFactor));
        double computeSkipFactor2 = computeSkipFactor * computeSkipFactor(i);
        int i3 = i;
        while (it.hasNext()) {
            long nextLong = it.nextLong();
            if (i3 == computeNumberOfSkips) {
                jArr[this.random.nextInt(i)] = nextLong;
                computeNumberOfSkips = (int) (computeNumberOfSkips + computeNumberOfSkips(computeSkipFactor2));
                computeSkipFactor2 *= computeSkipFactor(i);
            }
            i3++;
        }
        return Arrays.stream(jArr);
    }

    private double computeSkipFactor(int i) {
        return Math.exp(Math.log(this.random.nextDouble()) / i);
    }

    private long computeNumberOfSkips(double d) {
        return ((long) (Math.log(this.random.nextDouble()) / Math.log(1.0d - d))) + 1;
    }

    public LongHashSet sampleUniqueNumbersHashSet(int i, long j) {
        if (i > j) {
            throw new IllegalArgumentException("Cannot sample more unique numbers than the range allows.");
        }
        LongHashSet longHashSet = new LongHashSet();
        if (j != i) {
            while (longHashSet.size() < i) {
                longHashSet.add(this.random.nextLong(j));
            }
            return longHashSet;
        }
        long j2 = 0;
        while (true) {
            long j3 = j2;
            if (j3 >= j) {
                return longHashSet;
            }
            longHashSet.add(j3);
            j2 = j3 + 1;
        }
    }

    public LongStream sampleWithIndexes(LongStream longStream, long j, int i) {
        if (i == 0) {
            return LongStream.empty();
        }
        if (i >= j) {
            return longStream;
        }
        LongHashSet sampleUniqueNumbersHashSet = sampleUniqueNumbersHashSet(i, j);
        MutableInt mutableInt = new MutableInt(0);
        return longStream.filter(j2 -> {
            return sampleUniqueNumbersHashSet.contains(mutableInt.getAndIncrement());
        });
    }
}
