package org.neo4j.gds.ml.core.samplers;

import com.neo4j.gds.shaded.com.carrotsearch.hppc.LongHashSet;
import java.util.SplittableRandom;
import java.util.function.LongPredicate;
import java.util.stream.LongStream;
import org.neo4j.gds.mem.Estimate;
import org.neo4j.gds.mem.MemoryRange;

/* loaded from: input_file:org/neo4j/gds/ml/core/samplers/LongUniformSamplerWithRetries.class */
/**
 * Draws distinct {@code long} values uniformly at random from a half-open range, rejecting
 * candidates that a caller-supplied predicate marks as invalid and retrying until enough
 * distinct valid values have been accepted.
 *
 * <p>The set of already-accepted values is an unsynchronized instance field that is cleared and
 * refilled by every {@link #sample} call, so one instance should not be used by several threads
 * at the same time.
 */
public class LongUniformSamplerWithRetries {
    // Random source supplied by the caller; every draw in sample() goes through it.
    private final SplittableRandom rng;
    // Values accepted during the current sample() call; cleared at the start of the random branch.
    private final LongHashSet sampledValuesCache = new LongHashSet();

    /**
     * Creates a sampler backed by the given random number generator.
     *
     * @param splittableRandom generator used for every random draw
     */
    public LongUniformSamplerWithRetries(SplittableRandom splittableRandom) {
        rng = splittableRandom;
    }

    /**
     * Estimates the memory used by one sampler: the instance itself, its value cache sized for
     * the given number of samples, and a result array of that length.
     *
     * @param numberOfSamples number of samples the estimate is sized for
     * @return a fixed (min == max) memory range
     */
    public static MemoryRange memoryEstimation(long numberOfSamples) {
        return MemoryRange.of(
            Estimate.sizeOfInstance(LongUniformSamplerWithRetries.class)
                + Estimate.sizeOfLongHashSet(numberOfSamples)
                + Estimate.sizeOfLongArray(numberOfSamples)
        );
    }

    /**
     * Samples distinct values from {@code [inclusiveMin, exclusiveMax)} that are not rejected by
     * {@code isInvalidSample}.
     *
     * <p>If {@code numberOfSamples} is at least {@code lowerBoundOnValidSamplesInRange}, no random
     * drawing happens: every valid value in the range is returned in ascending order, so the
     * result length equals the number of valid values and may differ from {@code numberOfSamples}.
     *
     * <p>Otherwise values are drawn uniformly at random, and any candidate that is invalid or was
     * already accepted is discarded, until exactly {@code numberOfSamples} values are accepted.
     * That loop only ends if the range truly holds at least {@code numberOfSamples} valid values;
     * an overstated lower bound makes it spin forever.
     *
     * @param inclusiveMin smallest value that may be returned
     * @param exclusiveMax one past the largest value that may be returned
     * @param lowerBoundOnValidSamplesInRange caller-supplied lower bound on how many valid values the range holds
     * @param numberOfSamples number of distinct values requested
     * @param isInvalidSample returns {@code true} for values that must not be returned
     * @return the accepted values; in the random branch they appear in the order they were drawn
     */
    public long[] sample(
        long inclusiveMin,
        long exclusiveMax,
        long lowerBoundOnValidSamplesInRange,
        int numberOfSamples,
        LongPredicate isInvalidSample
    ) {
        if (lowerBoundOnValidSamplesInRange <= numberOfSamples) {
            // Asking for at least as many as may exist: just enumerate the whole range.
            return allValidValues(inclusiveMin, exclusiveMax, isInvalidSample);
        }
        return drawDistinctValidValues(inclusiveMin, exclusiveMax, numberOfSamples, isInvalidSample);
    }

    // Walks [inclusiveMin, exclusiveMax) in ascending order and keeps every value the predicate accepts.
    private static long[] allValidValues(long inclusiveMin, long exclusiveMax, LongPredicate isInvalidSample) {
        return LongStream.range(inclusiveMin, exclusiveMax)
            .filter(value -> !isInvalidSample.test(value))
            .toArray();
    }

    // Rejection-samples until `count` distinct valid values are accepted, reusing the instance cache.
    private long[] drawDistinctValidValues(
        long inclusiveMin,
        long exclusiveMax,
        int count,
        LongPredicate isInvalidSample
    ) {
        long[] accepted = new long[count];
        sampledValuesCache.clear();
        for (int filled = 0; filled < count; ) {
            long candidate = rng.nextLong(inclusiveMin, exclusiveMax);
            // Validity is checked first so rejected candidates never enter the cache.
            boolean valid = !isInvalidSample.test(candidate);
            if (valid && sampledValuesCache.add(candidate)) {
                accepted[filled++] = candidate;
            }
        }
        return accepted;
    }
}
