package org.neo4j.graphalgo.impl.louvain;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.DoubleAdder;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.neo4j.graphalgo.api.Graph;
import org.neo4j.graphalgo.core.utils.ParallelUtil;
import org.neo4j.graphalgo.core.utils.ProgressLogger;
import org.neo4j.graphalgo.core.utils.TerminationFlag;
import org.neo4j.graphalgo.core.utils.traverse.SimpleBitSet;
import org.neo4j.graphalgo.impl.Algorithm;
import org.neo4j.graphalgo.impl.louvain.LouvainAlgorithm;
import org.neo4j.graphdb.Direction;

/**
 * Parallel, weighted variant of the Louvain community-detection heuristic (single level).
 *
 * <p>Every node starts in its own community. In each round, {@code concurrency} worker
 * {@link Task}s claim node ids from a shared counter and move each node into the community of the
 * outgoing neighbour with the highest non-negative modularity gain (ties go to the neighbour
 * visited last). The algorithm stops after a round without any move or after
 * {@code maxIterations} rounds.
 *
 * <p>Concurrency note: moves are serialized through the write lock of {@link #readWriteLock}, but
 * the read lock is never taken. Gain evaluations may therefore observe community totals that
 * another worker is updating concurrently.
 */
public class WeightedLouvain extends Algorithm<WeightedLouvain> implements LouvainAlgorithm {
    private Graph graph;
    private ExecutorService pool;
    private final int concurrency;
    private final int nodeCount;
    private final int maxIterations;
    /** Sum of the weights of all outgoing relationships, computed in {@link #init()}. */
    private double m2;
    /** {@code m2} squared; denominator of the expected-weight term in {@link #modGain}. */
    private double mq2;
    /** Per-node sum of outgoing relationship weights. */
    private double[] w;
    /** Community id per node. {@code volatile} only covers the array reference, not its elements. */
    private volatile int[] nodeCommunity;
    /** Per-community sum of the {@link #w} values of its member nodes. */
    private volatile double[] sTot;
    private int iterations;
    private final ReentrantReadWriteLock readWriteLock = new ReentrantReadWriteLock();
    // Only the write lock is ever acquired; it serializes calls to move().
    private final ReentrantReadWriteLock.WriteLock writeLock = readWriteLock.writeLock();
    /** Next node id to be claimed by a worker in the current round. */
    private final AtomicInteger queue = new AtomicInteger();
    private final List<Task> tasks = new ArrayList<>();

    /** Worker that claims node ids from {@link #queue} and relocates each node to its best community. */
    private final class Task implements Runnable {
        /** Whether this worker moved at least one node during its most recent run. */
        private boolean changes;

        @Override
        public void run() {
            final ProgressLogger progressLogger = getProgressLogger();
            changes = false;
            int node;
            // Claim the next id first, then stop once ids are exhausted or termination was requested.
            while ((node = queue.getAndIncrement()) < nodeCount && running()) {
                final int best = bestCommunity(node);
                if (best != nodeCommunity[node]) {
                    move(node, best);
                    changes = true;
                }
                progressLogger.logProgress(node, nodeCount - 1, () -> "Round " + iterations);
            }
        }
    }

    /**
     * Creates the algorithm and allocates its per-node state.
     *
     * @param graph         graph to partition; relationship weights are read via {@code weightOf}
     * @param pool          executor used for initialisation and for the worker tasks
     * @param concurrency   number of worker tasks per round
     * @param maxIterations upper bound on the number of rounds
     * @throws ArithmeticException if the graph's node count does not fit in an {@code int}
     */
    public WeightedLouvain(Graph graph, ExecutorService pool, int concurrency, int maxIterations) {
        this.graph = graph;
        this.pool = pool;
        this.concurrency = concurrency;
        this.nodeCount = Math.toIntExact(graph.nodeCount());
        this.maxIterations = maxIterations;
        this.nodeCommunity = new int[nodeCount];
        this.sTot = new double[nodeCount];
        this.w = new double[nodeCount];
    }

    /**
     * Resets all state. It creates one worker per concurrency slot, puts every node in its own
     * community, and computes the per-node weights {@link #w}, the initial {@link #sTot} and the
     * totals {@link #m2}/{@link #mq2}.
     */
    private void init() {
        tasks.clear();
        for (int i = 0; i < concurrency; i++) {
            tasks.add(new Task());
        }
        Arrays.setAll(nodeCommunity, node -> node);
        final DoubleAdder totalWeight = new DoubleAdder();
        ParallelUtil.iterateParallel(pool, nodeCount, concurrency, node -> {
            final double[] nodeWeight = {0.0};
            graph.forEachRelationship(node, Direction.OUTGOING, (source, target, relId) -> {
                nodeWeight[0] += graph.weightOf(source, target);
                return true;
            });
            totalWeight.add(nodeWeight[0]);
            w[node] = nodeWeight[0];
            // Each node is initially alone, so its community total equals its own weight.
            sTot[node] = nodeWeight[0];
        });
        iterations = 0;
        m2 = totalWeight.doubleValue();
        mq2 = Math.pow(m2, 2.0);
    }

    /**
     * Moves {@code node} into {@code community} and updates both affected community totals.
     * The write lock is always released, even if an update throws.
     */
    private void move(int node, int community) {
        writeLock.lock();
        try {
            final int previous = nodeCommunity[node];
            sTot[previous] -= w[node];
            sTot[community] += w[node];
            nodeCommunity[node] = community;
        } finally {
            writeLock.unlock();
        }
    }

    /** Sums the weights of {@code node}'s outgoing relationships whose target is in {@code community}. */
    private double weightIntoC(int node, int community) {
        final double[] sum = {0.0};
        graph.forEachRelationship(node, Direction.OUTGOING, (source, target, relId) -> {
            if (nodeCommunity[target] == community) {
                sum[0] += graph.weightOf(source, target);
            }
            return true;
        });
        return sum[0];
    }

    /**
     * Modularity gain of placing {@code node} into {@code community}.
     * The gain is {@code weightIntoC / m2 - w[node] * sTot[community] / m2^2}.
     */
    private double modGain(int node, int community) {
        return (weightIntoC(node, community) / m2) - ((w[node] * sTot[community]) / mq2);
    }

    /**
     * Picks the community of the outgoing neighbour with the highest gain, provided that gain is
     * at least 0. When several neighbours tie, the one visited last wins. If no neighbour
     * qualifies, the node keeps its current community.
     */
    private int bestCommunity(int node) {
        final int current = nodeCommunity[node];
        if (m2 == 0.0) {
            // Zero total weight: every gain would be 0/0 = NaN and would move nodes arbitrarily.
            return current;
        }
        final double[] bestGain = {0.0};
        final int[] best = {current};
        graph.forEachRelationship(node, Direction.OUTGOING, (source, target, relId) -> {
            final int candidate = nodeCommunity[target];
            final double gain = modGain(node, candidate);
            if (gain < bestGain[0]) {
                return true; // strictly worse than the best so far
            }
            bestGain[0] = gain;
            best[0] = candidate;
            return true;
        });
        return best[0];
    }

    @Override
    public WeightedLouvain me() {
        return this;
    }

    /**
     * Drops references to the graph, the executor and the weight arrays. Community ids are kept,
     * so {@link #getCommunityIds()} and {@link #getCommunityCount()} keep working.
     * {@link #compute()} and a fresh {@link #resultStream()} cannot be used afterwards.
     */
    @Override
    public WeightedLouvain release() {
        graph = null;
        pool = null;
        sTot = null;
        w = null;
        return this;
    }

    /**
     * Runs up to {@code maxIterations} rounds. It stops early once no worker moved any node;
     * this also happens when termination was requested, because workers then stop without moving.
     */
    @Override
    public LouvainAlgorithm compute() {
        init();
        for (iterations = 0; iterations < maxIterations; iterations++) {
            queue.set(0);
            ParallelUtil.runWithConcurrency(concurrency, tasks, getTerminationFlag(), pool);
            if (tasks.stream().noneMatch(task -> task.changes)) {
                break;
            }
        }
        return this;
    }

    /** Returns the live (not copied) community-id array, indexed by internal node id. */
    @Override
    public int[] getCommunityIds() {
        return nodeCommunity;
    }

    @Override
    public int getIterations() {
        return iterations;
    }

    /** Counts the distinct community ids currently assigned. */
    @Override
    public long getCommunityCount() {
        final SimpleBitSet seen = new SimpleBitSet(nodeCount);
        for (int node = 0; node < nodeCount; node++) {
            seen.put(nodeCommunity[node]);
        }
        return seen.size();
    }

    /**
     * Streams one result per node: its original node id and its community id. The graph is
     * captured when this method is called, so the stream stays usable after a later
     * {@link #release()}.
     */
    @Override
    public Stream<LouvainAlgorithm.Result> resultStream() {
        final Graph g = graph;
        final int[] communities = nodeCommunity;
        return IntStream.range(0, nodeCount)
                .mapToObj(node -> new LouvainAlgorithm.Result(g.toOriginalNodeId(node), communities[node]));
    }

    // withProgressLogger / withTerminationFlag are inherited from Algorithm (returning
    // WeightedLouvain). They satisfy LouvainAlgorithm through covariant returns, and javac emits
    // the bridge methods itself, so they are not redeclared here.
}
