package org.neo4j.graphalgo.impl.louvain;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.IntConsumer;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.neo4j.graphalgo.api.Graph;
import org.neo4j.graphalgo.api.NodeIterator;
import org.neo4j.graphalgo.core.sources.ShuffledNodeIterator;
import org.neo4j.graphalgo.core.utils.ParallelUtil;
import org.neo4j.graphalgo.core.utils.Pointer;
import org.neo4j.graphalgo.core.utils.ProgressLogger;
import org.neo4j.graphalgo.core.utils.TerminationFlag;
import org.neo4j.graphalgo.core.utils.paged.AllocationTracker;
import org.neo4j.graphalgo.core.utils.traverse.SimpleBitSet;
import org.neo4j.graphalgo.impl.Algorithm;
import org.neo4j.graphalgo.impl.louvain.LouvainAlgorithm;
import org.neo4j.graphdb.Direction;

/**
 * Parallel, single-level Louvain community detection.
 *
 * <p>Every node starts in its own community. Each iteration runs {@code concurrency} independent
 * {@link Task}s. Each task walks all nodes through the shared {@link NodeIterator} (a
 * {@link ShuffledNodeIterator}) and greedily moves each node into the connected community with the
 * highest positive modularity gain. The task whose final state has the highest modularity wins;
 * its state is copied into every other task and into {@link #getCommunityIds()}. Iteration stops
 * after {@code maxIterations} rounds, or as soon as the best task no longer improves on the
 * modularity of the previous round.
 *
 * <p>Relationships are only read in the {@link Direction#OUTGOING} direction; during
 * initialisation each relationship adds its weight to the degree of both endpoints.
 */
public class Louvain extends Algorithm<Louvain> implements LouvainAlgorithm {
    private static final Direction D = Direction.OUTGOING;
    /** Community marker for the node that is currently being moved (it belongs to no community). */
    private static final int NONE = -1;
    private static final double MINIMUM_MODULARITY = Double.NEGATIVE_INFINITY;
    /** Bytes per node held by this instance: {@code ki} (8) + {@code communities} (4). */
    private static final long BYTES_PER_NODE = 12L;
    /** Bytes per node held by each task: {@code sTot} (8) + {@code sIn} (8) + {@code localCommunities} (4). */
    private static final long BYTES_PER_NODE_PER_TASK = 20L;

    private Graph graph;
    private ExecutorService pool;
    private final NodeIterator nodeIterator;
    private final int nodeCount;
    private final int maxIterations;
    private final int concurrency;
    private final AllocationTracker tracker;
    /** Sum of all relationship weights seen during {@link #init()}. */
    private double m;
    /** {@code 2 * m}, the normaliser used by {@link Task#modularity()}. */
    private double m2;
    /** Final community id per node; updated from the winning task after each improving iteration. */
    private int[] communities;
    /** Weighted degree per node, accumulated in {@link #init()}. */
    private double[] ki;
    private int iterations;
    /** Best modularity reached so far. */
    private double q = MINIMUM_MODULARITY;
    /** Shared progress counter across all tasks of the current iteration. */
    private final AtomicInteger counter = new AtomicInteger(0);

    /**
     * One optimisation pass over all nodes, working on its own copy of the community state.
     *
     * <p>{@code sTot[c]} holds the summed weighted degree of the nodes in community {@code c}.
     * {@code sIn[c]} accumulates twice the weight between the nodes moved into {@code c} and the
     * members of {@code c}. Both are kept consistent by {@link #move(int)}.
     *
     * <p>Access was widened from {@code private} by the decompiler; kept {@code public} so the
     * class interface does not change.
     */
    public class Task implements Runnable {
        private final double[] sTot;
        private final double[] sIn;
        private final int[] localCommunities;
        // Scratch state written by the lambda inside move(); only meaningful during a move.
        private double bestGain;
        private double bestWeight;
        private int bestCommunity;
        /** Modularity of this task's state after its last run that moved at least one node. */
        private double q = MINIMUM_MODULARITY;

        /** Creates a task seeded from the outer instance's current degrees and communities. */
        public Task() {
            this.sTot = Arrays.copyOf(ki, nodeCount);
            this.sIn = new double[nodeCount]; // Java zero-initialises new arrays.
            this.localCommunities = Arrays.copyOf(communities, nodeCount);
        }

        /** Overwrites this task's community state with a copy of {@code task}'s state. */
        public void sync(Task task) {
            System.arraycopy(task.localCommunities, 0, this.localCommunities, 0, nodeCount);
            System.arraycopy(task.sTot, 0, this.sTot, 0, nodeCount);
            System.arraycopy(task.sIn, 0, this.sIn, 0, nodeCount);
        }

        /**
         * Tries to move every node once.
         *
         * <p>If any node changed community, recomputes this task's modularity. Otherwise the
         * previous value is kept.
         */
        @Override
        public void run() {
            final ProgressLogger progressLogger = getProgressLogger();
            final Pointer.BoolPointer changed = Pointer.wrap(false);
            final int totalSteps = nodeCount * concurrency;
            nodeIterator.forEachNode(node -> {
                changed.v |= move(node);
                progressLogger.logProgress(counter.getAndIncrement(), totalSteps,
                        () -> String.format("Iteration %d", iterations));
                return true;
            });
            if (changed.v) {
                this.q = modularity();
            }
        }

        /** Returns the modularity recorded by the last run that moved at least one node. */
        public double getModularity() {
            return this.q;
        }

        /**
         * Detaches {@code node} from its community and re-attaches it to the connected community
         * with the highest positive gain, or to its original community if none is positive.
         *
         * @return {@code true} if the node ended up in a different community
         */
        private boolean move(int node) {
            final int currentCommunity = localCommunities[node];
            // Measured while the node is still a member, so the fallback below can restore exactly
            // what is subtracted here.
            final double currentWeight = weightIntoCom(node, currentCommunity);

            sTot[currentCommunity] -= ki[node];
            sIn[currentCommunity] -= 2.0 * currentWeight;
            localCommunities[node] = NONE;

            // Fallback: go back to the current community. The weight must be the real one (not 0),
            // otherwise sIn[currentCommunity] would stay understated.
            bestCommunity = currentCommunity;
            bestWeight = currentWeight;
            bestGain = 0.0;
            forEachConnectedCommunity(node, community -> {
                final double weight = weightIntoCom(node, community);
                final double gain = 2.0 * weight - sTot[community] * ki[node] / m;
                if (gain > bestGain) {
                    bestGain = gain;
                    bestCommunity = community;
                    bestWeight = weight;
                }
            });

            sTot[bestCommunity] += ki[node];
            sIn[bestCommunity] += 2.0 * bestWeight;
            localCommunities[node] = bestCommunity;
            return bestCommunity != currentCommunity;
        }

        /**
         * Calls {@code action} once for each distinct community reachable from {@code node} over an
         * outgoing relationship.
         *
         * <p>Skips {@link #NONE}, which a self-loop reports while {@code node} is detached.
         * NOTE(review): this allocates a {@code nodeCount}-sized bitset per call, i.e. once per
         * node move.
         */
        private void forEachConnectedCommunity(int node, IntConsumer action) {
            final SimpleBitSet visited = new SimpleBitSet(nodeCount);
            graph.forEachRelationship(node, D, (sourceNodeId, targetNodeId, relationId) -> {
                final int community = localCommunities[targetNodeId];
                if (community != NONE && !visited.contains(community)) {
                    visited.put(community);
                    action.accept(community);
                }
                return true;
            });
        }

        /** Sums {@code sIn[c]/m2 - (sTot[c]/m2)^2} over every community in use by this task. */
        private double modularity() {
            double sum = 0.0;
            final SimpleBitSet seen = new SimpleBitSet(nodeCount);
            for (int node = 0; node < nodeCount; node++) {
                final int community = localCommunities[node];
                if (!seen.contains(community)) {
                    seen.put(community);
                    sum += sIn[community] / m2 - Math.pow(sTot[community] / m2, 2.0);
                }
            }
            return sum;
        }

        /** Returns the total weight of {@code node}'s outgoing relationships into {@code community}. */
        private double weightIntoCom(int node, int community) {
            final Pointer.DoublePointer weight = Pointer.wrap(0.0d);
            graph.forEachRelationship(node, D, (sourceNodeId, targetNodeId, relationId) -> {
                if (localCommunities[targetNodeId] == community) {
                    weight.v += graph.weightOf(sourceNodeId, targetNodeId);
                }
                return true;
            });
            return weight.v;
        }
    }

    /**
     * @param graph          graph to partition
     * @param maxIterations  upper bound on optimisation rounds
     * @param pool           executor used to run the tasks
     * @param concurrency    number of independent tasks per round
     * @param tracker        receives the memory held by this instance and its tasks
     */
    public Louvain(Graph graph, int maxIterations, ExecutorService pool, int concurrency, AllocationTracker tracker) {
        this.graph = graph;
        this.nodeCount = Math.toIntExact(graph.nodeCount());
        this.maxIterations = maxIterations;
        this.pool = pool;
        this.concurrency = concurrency;
        this.tracker = tracker;
        this.nodeIterator = new ShuffledNodeIterator(this.nodeCount);
        this.ki = new double[this.nodeCount];
        this.communities = new int[this.nodeCount];
        tracker.add(BYTES_PER_NODE * this.nodeCount);
    }

    /**
     * Computes weighted degrees and total weight, then puts every node in its own community.
     *
     * <p>Resets the accumulators first, so a repeated {@link #compute()} does not double count.
     */
    private void init() {
        final ProgressLogger progressLogger = getProgressLogger();
        this.m = 0.0;
        Arrays.fill(this.ki, 0.0);
        for (int node = 0; node < nodeCount; node++) {
            graph.forEachRelationship(node, D, (sourceNodeId, targetNodeId, relationId) -> {
                final double weight = graph.weightOf(sourceNodeId, targetNodeId);
                this.m += weight;
                ki[sourceNodeId] += weight;
                ki[targetNodeId] += weight;
                return true;
            });
            progressLogger.logProgress(node, nodeCount, () -> "Init");
        }
        this.m2 = 2.0 * this.m;
        Arrays.setAll(communities, id -> id);
        progressLogger.logDone(() -> "Init complete");
    }

    /**
     * Runs the optimisation rounds described in the class documentation.
     *
     * @return this instance, with {@link #getCommunityIds()} and {@link #getIterations()} populated
     */
    @Override
    public Louvain compute() {
        init();
        final ProgressLogger progressLogger = getProgressLogger();
        final ArrayList<Task> tasks = new ArrayList<>();
        for (int i = 0; i < concurrency; i++) {
            tasks.add(new Task());
        }
        // Computed in long: nodeCount * concurrency * 20 can exceed Integer.MAX_VALUE.
        final long taskBytes = BYTES_PER_NODE_PER_TASK * nodeCount * concurrency;
        tracker.add(taskBytes);
        this.q = MINIMUM_MODULARITY;
        this.iterations = 0;
        while (this.iterations < maxIterations) {
            counter.set(0);
            ParallelUtil.runWithConcurrency(concurrency, tasks, pool);
            final Task best = best(tasks);
            if (best == null || best.getModularity() <= this.q) {
                break; // no tasks, or no improvement over the previous round
            }
            this.q = best.getModularity();
            sync(best, tasks);
            progressLogger.logDone(() -> String.format("Iteration %d led to a modularity %.4f", this.iterations, this.q));
            this.iterations++;
        }
        tracker.remove(taskBytes);
        progressLogger.logDone(() -> String.format("Done in %d iterations with Q=%.5f", this.iterations, this.q));
        return this;
    }

    /** Returns the task with the highest modularity (first one on ties), or {@code null} if none. */
    private static Task best(Collection<Task> tasks) {
        Task best = null;
        double bestModularity = Double.NEGATIVE_INFINITY;
        for (Task task : tasks) {
            final double modularity = task.getModularity();
            if (modularity > bestModularity) {
                bestModularity = modularity;
                best = task;
            }
        }
        return best;
    }

    /** Copies {@code winner}'s state into every other task and into {@link #communities}. */
    private void sync(Task winner, Collection<Task> tasks) {
        for (Task task : tasks) {
            if (task != winner) {
                task.sync(winner);
            }
        }
        System.arraycopy(winner.localCommunities, 0, this.communities, 0, nodeCount);
    }

    /** Returns the internal community-id array (not a copy), indexed by internal node id. */
    @Override
    public int[] getCommunityIds() {
        return this.communities;
    }

    /** Returns the number of improving rounds performed by the last {@link #compute()}. */
    @Override
    public int getIterations() {
        return this.iterations;
    }

    /** Returns the number of distinct community ids in {@link #getCommunityIds()}. */
    @Override
    public long getCommunityCount() {
        final SimpleBitSet distinct = new SimpleBitSet(nodeCount);
        for (int node = 0; node < nodeCount; node++) {
            distinct.put(communities[node]);
        }
        return distinct.size();
    }

    /** Streams one result per node: its original node id and its community id. */
    @Override
    public Stream<LouvainAlgorithm.Result> resultStream() {
        return IntStream.range(0, nodeCount)
                .mapToObj(node -> new LouvainAlgorithm.Result(graph.toOriginalNodeId(node), communities[node]));
    }

    @Override
    public Louvain me() {
        return this;
    }

    /**
     * Drops references to the graph, executor and result arrays, and un-tracks their memory.
     *
     * <p>Named {@code release} in the compiled class; the decompiler renamed it, and the name is
     * kept so existing callers in this source tree keep resolving.
     */
    @Override
    public Louvain mo104release() {
        this.graph = null;
        this.pool = null;
        this.communities = null;
        this.ki = null;
        tracker.remove(BYTES_PER_NODE * nodeCount);
        return this;
    }

    // Compiler-generated bridge methods, as emitted by the decompiler; kept for interface parity.
    @Override
    public /* bridge */ /* synthetic */ LouvainAlgorithm withTerminationFlag(TerminationFlag terminationFlag) {
        return (LouvainAlgorithm) super.withTerminationFlag(terminationFlag);
    }

    @Override
    public /* bridge */ /* synthetic */ LouvainAlgorithm withProgressLogger(ProgressLogger progressLogger) {
        return (LouvainAlgorithm) super.withProgressLogger(progressLogger);
    }
}
