package org.neo4j.gds.leiden;

import com.neo4j.gds.shaded.com.carrotsearch.hppc.BitSet;
import com.neo4j.gds.shaded.org.apache.commons.lang3.mutable.MutableLong;
import java.util.Optional;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.collections.ha.HugeDoubleArray;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.mem.Estimate;
import org.neo4j.gds.mem.MemoryEstimation;
import org.neo4j.gds.mem.MemoryEstimations;

/**
 * Refinement phase of Leiden: splits each community of {@code originalCommunities} into refined
 * communities.
 *
 * <p>Refinement starts from the singleton partition, where every node is its own refined
 * community. Nodes are then visited in node-id order. A node that is still a singleton and is well
 * connected (see {@link #isWellConnected(long)}) may be merged into one of the refined communities
 * of its neighbours. Only neighbours that share the node's original community, and whose refined
 * community is itself well connected, are considered.
 *
 * <p>The target community is drawn at random. Each candidate has weight {@code exp(gain / theta)}
 * when its gain is non-negative, and weight 0 otherwise. If that distribution is degenerate (sum
 * is zero or infinite), the candidate with the best strictly positive gain is used instead.
 */
final class RefinementPhase {

    /** Marker in {@link #encounteredCommunitiesWeights} for "not seen while scanning the current node". */
    private static final double NOT_ENCOUNTERED = -1.0;

    private final Graph workingGraph;
    /** Input partition. A node is only merged with neighbours that share its community here. */
    private final HugeLongArray originalCommunities;
    /** Volume of every node. */
    private final HugeDoubleArray nodeVolumes;
    /** Volume of every original community, indexed by original community id. */
    private final HugeDoubleArray communityVolumes;
    /** Volume of every refined community; starts as a copy of {@link #nodeVolumes}. */
    private final HugeDoubleArray communityVolumesAfterMerge;
    private final double gamma;
    /** Randomness parameter: larger values flatten the merge-probability distribution. */
    private final double theta;
    /**
     * Per refined community: presumably the weight of relationships leaving it while staying
     * inside its original community.
     * Filled by {@code RefinementBetweenRelationshipCounter} (not visible here) — TODO confirm.
     */
    private final HugeDoubleArray relationshipsBetweenCommunities;
    /** Scratch list of candidate refined communities for the node currently being processed. */
    private final HugeLongArray encounteredCommunities;
    /** Scratch: weight from the current node to each candidate community, or {@link #NOT_ENCOUNTERED}. */
    private final HugeDoubleArray encounteredCommunitiesWeights;
    private final long seed;
    /** Number of valid entries in {@link #encounteredCommunities} for the current node. */
    private long communityCounter = 0L;
    private final Concurrency concurrency;
    private final ExecutorService executorService;
    /** Scratch: unnormalised merge probability, aligned with {@link #encounteredCommunities}. */
    private final HugeDoubleArray nextCommunityProbabilities;
    private final ProgressTracker progressTracker;

    /** Result of {@link #run()}: refined partition, refined volumes and the largest refined id used. */
    static class RefinementPhaseResult {
        private final HugeLongArray communities;
        private final HugeDoubleArray communityVolumes;
        private final long maximumRefinementCommunityId;

        RefinementPhaseResult(
            HugeLongArray communities,
            HugeDoubleArray communityVolumes,
            long maximumRefinementCommunityId
        ) {
            this.communities = communities;
            this.communityVolumes = communityVolumes;
            this.maximumRefinementCommunityId = maximumRefinementCommunityId;
        }

        /** Refined community id of every node. */
        HugeLongArray communities() {
            return this.communities;
        }

        /** Volume of every refined community, indexed by refined community id. */
        HugeDoubleArray communityVolumes() {
            return this.communityVolumes;
        }

        /** Largest refined community id assigned to any node, or -1 for an empty graph. */
        long maximumRefinedCommunityId() {
            return this.maximumRefinementCommunityId;
        }
    }

    /**
     * Allocates the per-node scratch arrays and builds a refinement phase.
     *
     * @param workingGraph        graph to refine
     * @param originalCommunities community of every node in the input partition
     * @param nodeVolumes         volume of every node
     * @param communityVolumes    volume of every original community
     * @param gamma               resolution parameter
     * @param theta               randomness parameter for the merge distribution
     * @param seed                seed of the random generator used to pick merge targets
     * @param concurrency         concurrency for the parallel pre-computation
     * @param executorService     executor for the parallel pre-computation
     * @param progressTracker     receives one progress tick per node
     * @return a new, not yet run, refinement phase
     */
    static RefinementPhase create(
        Graph workingGraph,
        HugeLongArray originalCommunities,
        HugeDoubleArray nodeVolumes,
        HugeDoubleArray communityVolumes,
        double gamma,
        double theta,
        long seed,
        Concurrency concurrency,
        ExecutorService executorService,
        ProgressTracker progressTracker
    ) {
        long nodeCount = workingGraph.nodeCount();
        return new RefinementPhase(
            workingGraph,
            originalCommunities,
            nodeVolumes,
            communityVolumes,
            HugeLongArray.newArray(nodeCount),
            HugeDoubleArray.newArray(nodeCount),
            HugeDoubleArray.newArray(nodeCount),
            gamma,
            theta,
            seed,
            concurrency,
            executorService,
            progressTracker
        );
    }

    private RefinementPhase(
        Graph workingGraph,
        HugeLongArray originalCommunities,
        HugeDoubleArray nodeVolumes,
        HugeDoubleArray communityVolumes,
        HugeLongArray encounteredCommunities,
        HugeDoubleArray encounteredCommunitiesWeights,
        HugeDoubleArray nextCommunityProbabilities,
        double gamma,
        double theta,
        long seed,
        Concurrency concurrency,
        ExecutorService executorService,
        ProgressTracker progressTracker
    ) {
        this.workingGraph = workingGraph;
        this.originalCommunities = originalCommunities;
        this.nodeVolumes = nodeVolumes;
        // Singleton partition: every refined community initially has the volume of its only node.
        this.communityVolumesAfterMerge = nodeVolumes.copyOf(nodeVolumes.size());
        this.communityVolumes = communityVolumes;
        this.encounteredCommunities = encounteredCommunities;
        this.encounteredCommunitiesWeights = encounteredCommunitiesWeights;
        this.nextCommunityProbabilities = nextCommunityProbabilities;
        this.gamma = gamma;
        this.theta = theta;
        this.seed = seed;
        encounteredCommunitiesWeights.setAll(ignored -> NOT_ENCOUNTERED);
        this.relationshipsBetweenCommunities = HugeDoubleArray.newArray(workingGraph.nodeCount());
        this.concurrency = concurrency;
        this.executorService = executorService;
        this.progressTracker = progressTracker;
    }

    /** Per-node memory estimation of the arrays this phase allocates. */
    static MemoryEstimation memoryEstimation() {
        return MemoryEstimations.builder(RefinementPhase.class)
            .perNode("encountered communities", HugeLongArray::memoryEstimation)
            .perNode("encountered community weights", HugeDoubleArray::memoryEstimation)
            .perNode("next community probabilities", HugeDoubleArray::memoryEstimation)
            .perNode("merged community volumes", HugeDoubleArray::memoryEstimation)
            .perNode("relationships between communities", HugeDoubleArray::memoryEstimation)
            .perNode("refined communities", HugeLongArray::memoryEstimation)
            .perNode("merge tracking bitset", Estimate::sizeOfBitset)
            .build();
    }

    /**
     * Runs the refinement.
     *
     * <p>The node loop is sequential and seeded, so for a fixed seed (and deterministic
     * pre-computation) the refined partition is reproducible. This call mutates
     * {@code communityVolumesAfterMerge}, which is handed out in the result.
     *
     * @return the refined partition, its community volumes and the maximum refined community id
     */
    RefinementPhaseResult run() {
        long nodeCount = workingGraph.nodeCount();

        HugeLongArray refinedCommunities = HugeLongArray.newArray(nodeCount);
        refinedCommunities.setAll(nodeId -> nodeId); // singleton partition

        computeRelationshipsBetweenCommunities();

        // Bit c is set while refined community c still only contains node c.
        BitSet singleton = new BitSet(nodeCount);
        singleton.set(0L, nodeCount);

        Random random = new Random(seed);
        MutableLong maxRefinedCommunityId = new MutableLong(-1L);

        workingGraph.forEachNode(nodeId -> {
            if (singleton.get(nodeId) && isWellConnected(nodeId)) {
                mergeNodeSubset(nodeId, refinedCommunities, singleton, random);
            }
            long refinedCommunityId = refinedCommunities.get(nodeId);
            if (maxRefinedCommunityId.longValue() < refinedCommunityId) {
                maxRefinedCommunityId.setValue(refinedCommunityId);
            }
            progressTracker.logProgress();
            return true;
        });

        return new RefinementPhaseResult(
            refinedCommunities,
            communityVolumesAfterMerge,
            maxRefinedCommunityId.longValue()
        );
    }

    /** Fills {@link #relationshipsBetweenCommunities} in parallel over degree-balanced partitions. */
    private void computeRelationshipsBetweenCommunities() {
        var tasks = PartitionUtils.degreePartition(
            workingGraph,
            concurrency,
            partition -> new RefinementBetweenRelationshipCounter(
                workingGraph.concurrentCopy(),
                relationshipsBetweenCommunities,
                originalCommunities,
                partition
            ),
            Optional.empty()
        );
        RunWithConcurrency.builder()
            .concurrency(concurrency)
            .tasks(tasks)
            .executor(executorService)
            .run();
    }

    /**
     * Tries to move {@code nodeId} out of its singleton community into a neighbouring refined
     * community.
     */
    private void mergeNodeSubset(
        long nodeId,
        HugeLongArray refinedCommunities,
        BitSet singleton,
        Random random
    ) {
        communityCounter = 0L;
        computeCommunityInformation(nodeId, refinedCommunities);
        if (communityCounter == 0L) {
            return; // no eligible neighbour, nothing to reset either
        }

        long currentCommunityId = refinedCommunities.get(nodeId);
        double nodeVolume = nodeVolumes.get(nodeId);

        double bestGain = 0.0;
        long bestCommunityId = 0L;
        double probabilitiesSum = 0.0;
        double weightToEncounteredCommunities = 0.0;

        for (long i = 0; i < communityCounter; i++) {
            long candidateId = encounteredCommunities.get(i);
            double weightToCandidate = encounteredCommunitiesWeights.get(candidateId);
            weightToEncounteredCommunities += weightToCandidate;

            double gain = weightToCandidate - nodeVolume * communityVolumesAfterMerge.get(candidateId) * gamma;
            if (gain > bestGain) {
                bestGain = gain;
                bestCommunityId = candidateId;
            }

            double probability = gain >= 0.0 ? Math.exp(gain / theta) : 0.0;
            nextCommunityProbabilities.set(i, probability);
            probabilitiesSum += probability;
        }

        long chosenCommunityId = currentCommunityId;
        if (!Double.isInfinite(probabilitiesSum) && probabilitiesSum > 0.0) {
            chosenCommunityId = selectRandomCommunity(
                nextCommunityProbabilities,
                probabilitiesSum,
                random,
                chosenCommunityId
            );
        } else if (bestGain > 0.0) {
            // exp overflowed (or everything was 0): fall back to the greedy choice
            chosenCommunityId = bestCommunityId;
        }

        boolean moves = chosenCommunityId != currentCommunityId;
        // Read the chosen weight before the scratch weights are wiped for the next node.
        double weightToChosen = moves ? encounteredCommunitiesWeights.get(chosenCommunityId) : 0.0;
        resetEncounteredCommunities();

        if (moves) {
            addToCommunity(
                nodeId,
                refinedCommunities,
                singleton,
                currentCommunityId,
                weightToEncounteredCommunities,
                chosenCommunityId,
                weightToChosen
            );
        }
    }

    /**
     * Marks every community encountered for the current node as "not encountered" again.
     *
     * <p>An explicit sentinel is used instead of negating the weights. Negating a zero weight yields
     * {@code -0.0}, which still passes the {@code >= 0.0} "already encountered" check and would
     * silently drop that community from every later candidate list.
     */
    private void resetEncounteredCommunities() {
        for (long i = 0; i < communityCounter; i++) {
            encounteredCommunitiesWeights.set(encounteredCommunities.get(i), NOT_ENCOUNTERED);
        }
    }

    /**
     * Samples a candidate community with probability proportional to its entry in
     * {@code probabilities}.
     *
     * @param probabilities    unnormalised probabilities, aligned with {@link #encounteredCommunities}
     * @param probabilitiesSum sum of the first {@link #communityCounter} probabilities
     * @param random           source of randomness
     * @param fallbackCommunityId returned if no candidate is selected
     * @return the sampled community id
     */
    private long selectRandomCommunity(
        HugeDoubleArray probabilities,
        double probabilitiesSum,
        Random random,
        long fallbackCommunityId
    ) {
        double threshold = probabilitiesSum * random.nextDouble();
        assert threshold >= 0.0 : "negative sampling threshold " + threshold;

        double cumulative = 0.0;
        for (long i = 0; i < communityCounter; i++) {
            cumulative += probabilities.get(i);
            if (threshold <= cumulative) {
                return encounteredCommunities.get(i);
            }
        }
        return fallbackCommunityId;
    }

    /**
     * Moves {@code nodeId} from {@code currentCommunityId} into {@code chosenCommunityId} and
     * updates volumes and {@link #relationshipsBetweenCommunities}.
     *
     * @param weightToEncounteredCommunities total weight from the node to all candidate communities
     * @param weightToChosen                 weight from the node to the chosen community
     */
    private void addToCommunity(
        long nodeId,
        HugeLongArray refinedCommunities,
        BitSet singleton,
        long currentCommunityId,
        double weightToEncounteredCommunities,
        long chosenCommunityId,
        double weightToChosen
    ) {
        refinedCommunities.set(nodeId, chosenCommunityId);
        singleton.clear(chosenCommunityId); // the target now has at least two members

        double nodeVolume = nodeVolumes.get(nodeId);
        communityVolumesAfterMerge.addTo(chosenCommunityId, nodeVolume);
        communityVolumesAfterMerge.addTo(currentCommunityId, -nodeVolume);
        relationshipsBetweenCommunities.addTo(chosenCommunityId, weightToEncounteredCommunities - weightToChosen);
    }

    /**
     * Collects the candidate refined communities of {@code nodeId} into
     * {@link #encounteredCommunities}, and the node's summed relationship weight to each into
     * {@link #encounteredCommunitiesWeights}.
     *
     * <p>Only neighbours in the same original community whose refined community is well connected
     * are considered. Relationships without a weight count as 1.0.
     */
    private void computeCommunityInformation(long nodeId, HugeLongArray refinedCommunities) {
        long originalCommunityId = originalCommunities.get(nodeId);
        workingGraph.forEachRelationship(nodeId, 1.0, (source, target, weight) -> {
            if (originalCommunities.get(target) != originalCommunityId) {
                return true;
            }
            long targetCommunityId = refinedCommunities.get(target);
            if (!isWellConnected(targetCommunityId)) {
                return true;
            }
            if (encounteredCommunitiesWeights.get(targetCommunityId) >= 0.0) {
                encounteredCommunitiesWeights.addTo(targetCommunityId, weight);
            } else {
                encounteredCommunities.set(communityCounter++, targetCommunityId);
                encounteredCommunitiesWeights.set(targetCommunityId, weight);
            }
            return true;
        });
    }

    /**
     * Checks whether a refined community is well connected to the rest of its original community.
     * The condition is {@code R(c) >= gamma * vol(c) * (vol(S) - vol(c))}, where {@code S} is the
     * original community of {@code c}.
     */
    private boolean isWellConnected(long refinedCommunityId) {
        double originalCommunityVolume = communityVolumes.get(originalCommunities.get(refinedCommunityId));
        double refinedCommunityVolume = communityVolumesAfterMerge.get(refinedCommunityId);
        return relationshipsBetweenCommunities.get(refinedCommunityId)
            >= gamma * refinedCommunityVolume * (originalCommunityVolume - refinedCommunityVolume);
    }
}
