package org.neo4j.gds.graphsampling.samplers.rw.rwr;

import com.neo4j.gds.shaded.com.carrotsearch.hppc.LongContainer;
import com.neo4j.gds.shaded.com.carrotsearch.hppc.LongHashSet;
import com.neo4j.gds.shaded.com.carrotsearch.hppc.LongSet;
import java.util.Collection;
import java.util.Optional;
import java.util.SplittableRandom;
import java.util.function.Supplier;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.collections.haa.HugeAtomicDoubleArray;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.paged.HugeAtomicBitSet;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.graphsampling.RandomWalkBasedNodesSampler;
import org.neo4j.gds.graphsampling.config.RandomWalkWithRestartsConfig;
import org.neo4j.gds.graphsampling.samplers.SeenNodes;
import org.neo4j.gds.graphsampling.samplers.rw.InitialStartQualities;
import org.neo4j.gds.graphsampling.samplers.rw.RandomWalkCompanion;
import org.neo4j.gds.graphsampling.samplers.rw.WalkQualities;
import org.neo4j.gds.graphsampling.samplers.rw.Walker;
import org.neo4j.gds.graphsampling.samplers.rw.WalkerProducer;

/* loaded from: input_file:org/neo4j/gds/graphsampling/samplers/rw/rwr/RandomWalkWithRestarts.class */
/**
 * Node sampler based on random walks with restarts.
 *
 * <p>Each call to {@link #compute(Graph, ProgressTracker)} creates a {@link SeenNodes} tracker sized
 * by the configured sampling ratio. It then runs one walker per concurrency slot, built by the
 * RWR {@link WalkerProducer}, and returns the nodes the walkers marked as sampled. The start nodes used
 * by all walkers of the most recent run are collected and exposed via {@link #startNodesUsed()}.
 *
 * <p>NOTE(review): instances keep per-run state ({@code startNodesUsed}); running {@code compute}
 * concurrently on the same instance is not safe.
 */
public class RandomWalkWithRestarts extends RandomWalkBasedNodesSampler {
    // Not referenced in this class; presumably consumed by walkers / WalkQualities — confirm before changing.
    public static final double QUALITY_MOMENTUM = 0.9d;

    // Base quality threshold handed to each walker, after division by concurrency squared (see compute).
    private static final double QUALITY_THRESHOLD_BASE = 0.05d;

    // Not referenced in this class; presumably a per-start-node walk cap used elsewhere — confirm.
    public static final int MAX_WALKS_PER_START = 100;

    // Sentinel node id; not referenced in this class, kept for subclasses / package users.
    protected static final long INVALID_NODE_ID = -1;

    // Multiplier applied to the expected number of sampled nodes to size the progress leaf.
    // NOTE(review): compute() overrides the step count via setSteps(totalExpectedNodes()), so this
    // only affects the initial estimate — confirm intent.
    private static final long RANDOM_WALK_PROGRESS_FACTOR = 10;

    private final RandomWalkWithRestartsConfig config;
    private final Concurrency concurrency;
    private final WalkerProducer walkerProducer = WalkerProducer.RWRWalkerProducer();

    // Start nodes used by the walkers of the most recent compute() call. Initialized eagerly so that
    // startNodesUsed()/startNodesCount() return an empty result instead of null / NPE before compute().
    private LongHashSet startNodesUsed = new LongHashSet();

    /**
     * Creates a sampler for the given configuration.
     *
     * @param randomWalkWithRestartsConfig sampling configuration; its concurrency is captured here
     */
    public RandomWalkWithRestarts(RandomWalkWithRestartsConfig randomWalkWithRestartsConfig) {
        this.config = randomWalkWithRestartsConfig;
        this.concurrency = randomWalkWithRestartsConfig.concurrency();
    }

    /**
     * Runs the random walks and returns the set of sampled nodes.
     *
     * <p>Side effects: reports progress under the "Sample nodes" / "Do random walks" subtasks and
     * replaces the set returned by {@link #startNodesUsed()} with the start nodes of this run.
     *
     * @param graph           graph to sample; its relationship-property presence must match the config
     * @param progressTracker tracker receiving subtask begin/end events and step counts
     * @return bit set of sampled node ids, as produced by {@link SeenNodes#sampledNodes()}
     */
    @Override
    public HugeAtomicBitSet compute(Graph graph, ProgressTracker progressTracker) {
        assert graph.hasRelationshipProperty() == config.hasRelationshipWeightProperty()
            : "graph relationship property presence must match the configured relationship weight property";

        progressTracker.beginSubTask("Sample nodes");

        SeenNodes seenNodes = SeenNodes.create(
            graph,
            progressTracker,
            terminationFlag,
            config.nodeLabelStratification(),
            concurrency,
            config.samplingRatio()
        );

        progressTracker.beginSubTask("Do random walks");
        progressTracker.setSteps(seenNodes.totalExpectedNodes());

        // Reset per run so repeated compute() calls do not accumulate start nodes from earlier runs.
        this.startNodesUsed = new LongHashSet();

        // Use the configured seed when present, otherwise draw a fresh one.
        long seed = config.randomSeed().orElseGet(() -> new SplittableRandom().nextLong());
        SplittableRandom rng = new SplittableRandom(seed);

        InitialStartQualities initialStartQualities = InitialStartQualities.init(graph, rng, config.startNodes());
        Optional<HugeAtomicDoubleArray> totalWeights =
            RandomWalkCompanion.initializeTotalWeights(config, graph.nodeCount());

        // The threshold shrinks with the square of the concurrency; hoisted since it is the same for every walker.
        double qualityThreshold = QUALITY_THRESHOLD_BASE / concurrency.squared();

        // One walker per task; each gets its own split RNG, WalkQualities and concurrent graph copy.
        Collection<Runnable> tasks = ParallelUtil.tasks(concurrency, () -> walkerProducer.getWalker(
            seenNodes,
            totalWeights,
            qualityThreshold,
            new WalkQualities(initialStartQualities),
            rng.split(),
            graph.concurrentCopy(),
            config,
            progressTracker,
            terminationFlag
        ));

        RunWithConcurrency.builder().concurrency(concurrency).tasks(tasks).run();

        // Merge per-walker start nodes. Assumes run() has returned only after all walkers finished —
        // the merge below is single-threaded and not synchronized.
        for (Runnable task : tasks) {
            startNodesUsed.addAll((LongContainer) ((Walker) task).startNodesUsed());
        }

        progressTracker.endSubTask("Do random walks");
        progressTracker.endSubTask("Sample nodes");
        return seenNodes.sampledNodes();
    }

    /**
     * Builds the progress task tree for this sampler.
     *
     * <p>When node-label stratification is enabled, a "Count node labels" leaf sized by the node count
     * precedes the "Do random walks" leaf.
     *
     * @param graphStore graph store whose node count sizes the leaves
     * @return a "Sample nodes" task with one or two leaf children
     */
    @Override
    public Task progressTask(GraphStore graphStore) {
        long nodeCount = graphStore.nodeCount();
        long expectedSampledNodes = Math.round(nodeCount * config.samplingRatio());

        if (config.nodeLabelStratification()) {
            return Tasks.task(
                "Sample nodes",
                Tasks.leaf("Count node labels", nodeCount),
                Tasks.leaf("Do random walks", RANDOM_WALK_PROGRESS_FACTOR * expectedSampledNodes)
            );
        }
        return Tasks.task(
            "Sample nodes",
            Tasks.leaf("Do random walks", RANDOM_WALK_PROGRESS_FACTOR * expectedSampledNodes)
        );
    }

    /** Returns the human-readable name of the top-level progress task. */
    @Override
    public String progressTaskName() {
        return "Random walk with restarts sampling";
    }

    /**
     * Returns the start nodes used by the walkers of the most recent {@code compute} call.
     *
     * <p>Returns an empty set before {@code compute} has run. The returned set is the live internal
     * set; callers should not modify it.
     */
    public LongSet startNodesUsed() {
        return startNodesUsed;
    }

    /** Returns how many distinct start nodes the most recent {@code compute} call used (0 before any run). */
    @Override
    public long startNodesCount() {
        return startNodesUsed.size();
    }
}
