package org.neo4j.gds.ml.core.samplers;

import com.neo4j.gds.shaded.com.carrotsearch.hppc.LongArrayList;
import java.util.Arrays;
import java.util.SplittableRandom;
import java.util.function.LongPredicate;
import java.util.stream.LongStream;
import org.neo4j.gds.mem.Estimate;
import org.neo4j.gds.mem.MemoryRange;

/**
 * Samples distinct {@code long} values uniformly at random from a half-open range
 * {@code [inclusiveMin, exclusiveMax)}, skipping values rejected by a predicate.
 *
 * <p>Rather than drawing the wanted samples directly, this sampler materializes every valid value
 * in the range and then uses {@link LongUniformSamplerWithRetries} to randomly pick the values to
 * <em>exclude</em>, keeping all the others. Presumably this is meant for the case where the number
 * of wanted samples is a large fraction of the valid values, so that drawing the exclusions is
 * cheaper than drawing the samples (NOTE(review): confirm against the callers).
 *
 * <p>Returned samples are always in ascending order.
 */
public class LongUniformSamplerByExclusion {
    private final LongUniformSamplerWithRetries samplerWithRetries;

    /**
     * @param rng source of randomness, handed to the internal {@link LongUniformSamplerWithRetries}
     *            that picks which valid values are excluded
     */
    public LongUniformSamplerByExclusion(SplittableRandom rng) {
        this.samplerWithRetries = new LongUniformSamplerWithRetries(rng);
    }

    /**
     * Estimates the memory used by one instance of this sampler together with a single call to
     * {@link #sample}.
     *
     * <p>The estimate covers:
     * <ul>
     *   <li>the instance itself;</li>
     *   <li>the result array of {@code numberOfSamples} longs;</li>
     *   <li>a list of valid candidates, sized anywhere from empty to {@code maxRangeSize};</li>
     *   <li>the internal exclusion sampler.</li>
     * </ul>
     *
     * @param numberOfSamples number of samples requested per call
     * @param maxRangeSize    presumably an upper bound on {@code exclusiveMax - inclusiveMin}
     *                        (TODO confirm against callers)
     * @return a memory range from the best case (no candidates kept) to the worst case
     */
    public static MemoryRange memoryEstimation(long numberOfSamples, long maxRangeSize) {
        MemoryRange exclusionSampler = LongUniformSamplerWithRetries
            .memoryEstimation(Math.min(numberOfSamples, maxRangeSize - numberOfSamples))
            .union(LongUniformSamplerWithRetries.memoryEstimation(0L));

        long instanceAndResult = Estimate.sizeOfInstance(LongUniformSamplerByExclusion.class)
            + Estimate.sizeOfLongArray(numberOfSamples);

        return exclusionSampler.add(MemoryRange.of(
            instanceAndResult + Estimate.sizeOfLongArrayList(0L),
            instanceAndResult + Estimate.sizeOfLongArrayList(maxRangeSize)
        ));
    }

    /**
     * Samples up to {@code numberOfSamples} distinct values from {@code [inclusiveMin, exclusiveMax)}
     * for which {@code isInvalidSample} returns {@code false}.
     *
     * <p>If there are at most {@code numberOfSamples} valid values, all of them are returned.
     * Otherwise exactly {@code numberOfSamples} values are chosen uniformly at random. In both cases
     * the result is in ascending order. The predicate is tested exactly once per value in the range,
     * in ascending order.
     *
     * @param inclusiveMin                    smallest candidate value
     * @param exclusiveMax                    one past the largest candidate value
     * @param lowerBoundOnValidSamplesInRange a lower bound on the number of valid values in the
     *                                        range; also used as the initial capacity of the
     *                                        candidate list
     * @param numberOfSamples                 maximum number of values to return
     * @param isInvalidSample                 returns {@code true} for values that must not be
     *                                        sampled
     * @return the sampled values, in ascending order
     */
    public long[] sample(
        long inclusiveMin,
        long exclusiveMax,
        long lowerBoundOnValidSamplesInRange,
        int numberOfSamples,
        LongPredicate isInvalidSample
    ) {
        if (numberOfSamples >= lowerBoundOnValidSamplesInRange) {
            long[] validSamples = LongStream.range(inclusiveMin, exclusiveMax)
                .filter(isInvalidSample.negate())
                .toArray();
            // The bound is only a lower bound, so there may still be more valid values than
            // requested. Returning all of them would then exceed numberOfSamples.
            if (validSamples.length <= numberOfSamples) {
                return validSamples;
            }
            return keepAllButRandomlyExcluded(validSamples, validSamples.length, numberOfSamples);
        }

        LongArrayList validSampleSpace = new LongArrayList((int) lowerBoundOnValidSamplesInRange);
        for (long candidate = inclusiveMin; candidate < exclusiveMax; candidate++) {
            if (!isInvalidSample.test(candidate)) {
                validSampleSpace.add(candidate);
            }
        }

        // Holds whenever the caller's bound really is a lower bound, since numberOfSamples < bound here.
        assert validSampleSpace.size() >= numberOfSamples
            : "fewer valid samples (" + validSampleSpace.size() + ") than requested (" + numberOfSamples + ")";

        return keepAllButRandomlyExcluded(validSampleSpace.buffer, validSampleSpace.size(), numberOfSamples);
    }

    /**
     * Randomly excludes {@code candidateCount - numberOfSamples} of the first {@code candidateCount}
     * entries of {@code candidates} and returns the remaining entries in their original order.
     *
     * @param candidates      array whose first {@code candidateCount} entries are the candidates
     *                        (it may be longer, e.g. a list's backing buffer)
     * @param candidateCount  number of meaningful entries in {@code candidates}
     * @param numberOfSamples number of entries to keep; must not exceed {@code candidateCount}
     * @return a new array of length {@code numberOfSamples}
     */
    private long[] keepAllButRandomlyExcluded(long[] candidates, int candidateCount, int numberOfSamples) {
        long[] excludedIndices = samplerWithRetries.sample(
            0L,
            candidateCount,
            candidateCount,
            candidateCount - numberOfSamples,
            ignored -> false
        );
        // Sorting lets us copy the kept entries as contiguous runs between consecutive exclusions.
        Arrays.sort(excludedIndices);

        long[] samples = new long[numberOfSamples];
        int written = 0;
        int readFrom = 0;
        for (long excluded : excludedIndices) {
            int excludedIndex = (int) excluded;
            int runLength = excludedIndex - readFrom;
            System.arraycopy(candidates, readFrom, samples, written, runLength);
            written += runLength;
            readFrom = excludedIndex + 1;
        }
        // Copy the tail that follows the last excluded index.
        System.arraycopy(candidates, readFrom, samples, written, numberOfSamples - written);
        return samples;
    }
}
