package org.neo4j.gds.kmeans;

import org.neo4j.gds.api.properties.nodes.NodePropertyValues;
import org.neo4j.gds.collections.ha.HugeDoubleArray;
import org.neo4j.gds.collections.ha.HugeIntArray;
import org.neo4j.gds.core.utils.partition.Partition;
import org.neo4j.gds.mem.Estimate;
import org.neo4j.gds.mem.MemoryEstimation;
import org.neo4j.gds.mem.MemoryEstimations;
import org.neo4j.gds.mem.MemoryRange;

/* loaded from: input_file:org/neo4j/gds/kmeans/KmeansTask.class */
/**
 * One partition's share of the k-means work.
 *
 * <p>{@link #run()} dispatches on the current {@link TaskPhase}:
 * <ul>
 *   <li>{@code ITERATION} - assign every node of the partition to the centroid returned by
 *       {@code clusterManager.findClosestCentroid} and count how many nodes changed community;</li>
 *   <li>{@code DISTANCE} - store each node's euclidean distance to its assigned centroid and
 *       accumulate the sum of those distances;</li>
 *   <li>any other phase (e.g. {@code INITIAL}) - compare every node against the most recently
 *       sampled centroid, keep the smaller distance, and accumulate the squared distances.</li>
 * </ul>
 *
 * <p>Concrete subclasses ({@link DoubleKmeansTask} / {@link FloatKmeansTask}, chosen by
 * {@link #createTask}) supply the per-assignment bookkeeping through {@link #reset()} and
 * {@link #updateAfterAssignmentToCentroid(long, int)}.
 *
 * <p>NOTE(review): per the decompiler, several members below were widened from package-private
 * to {@code public}; they are kept {@code public} here to preserve the visible interface.
 */
public abstract class KmeansTask implements Runnable {
    private final ClusterManager clusterManager;
    private final Partition partition;
    final NodePropertyValues nodePropertyValues;
    // Per-node distance to the current centroid. During sampling, a value <= -1 is treated as an
    // encoded centroid index c stored as -(c + 1).
    private final HugeDoubleArray distanceFromCentroid;
    final HugeIntArray communities;
    // Number of nodes handled by this task that were assigned to each centroid.
    final long[] communitySizes;
    final int k;
    final int dimensions;
    // Nodes whose community changed during the most recent ITERATION pass.
    private long swaps;
    // Sum of distances recorded by DISTANCE passes.
    private double distance = 0.0d;
    // Sum of squared best-known distances from the most recent sampling pass.
    private double squaredDistance = 0.0d;
    private TaskPhase phase;

    /**
     * Creates a task over {@code partition}.
     *
     * @param samplerType          {@code UNIFORM} starts the task in {@code ITERATION};
     *                             any other value starts it in {@code INITIAL}
     * @param clusterManager       source of centroid lookups and distances
     * @param nodePropertyValues   node property values kept for subclasses
     * @param communities          shared per-node community assignment
     * @param distanceFromCentroid shared per-node distance store
     * @param k                    number of centroids (size of {@link #communitySizes})
     * @param dimensions           dimensionality of the node vectors
     * @param partition            node range handled by this task
     */
    public KmeansTask(
        SamplerType samplerType,
        ClusterManager clusterManager,
        NodePropertyValues nodePropertyValues,
        HugeIntArray communities,
        HugeDoubleArray distanceFromCentroid,
        int k,
        int dimensions,
        Partition partition
    ) {
        this.clusterManager = clusterManager;
        this.nodePropertyValues = nodePropertyValues;
        this.communities = communities;
        this.distanceFromCentroid = distanceFromCentroid;
        this.k = k;
        this.dimensions = dimensions;
        this.partition = partition;
        this.communitySizes = new long[k];
        this.phase = samplerType == SamplerType.UNIFORM ? TaskPhase.ITERATION : TaskPhase.INITIAL;
    }

    /**
     * Builds the concrete task matching the precision of {@code clusterManager}.
     *
     * @return a {@link DoubleKmeansTask} when {@code clusterManager} is a
     *         {@link DoubleClusterManager}, otherwise a {@link FloatKmeansTask}
     */
    public static KmeansTask createTask(
        SamplerType samplerType,
        ClusterManager clusterManager,
        NodePropertyValues nodePropertyValues,
        HugeIntArray communities,
        HugeDoubleArray distanceFromCentroid,
        int k,
        int dimensions,
        Partition partition
    ) {
        if (clusterManager instanceof DoubleClusterManager) {
            return new DoubleKmeansTask(
                samplerType, clusterManager, nodePropertyValues,
                communities, distanceFromCentroid, k, dimensions, partition
            );
        }
        return new FloatKmeansTask(
            samplerType, clusterManager, nodePropertyValues,
            communities, distanceFromCentroid, k, dimensions, partition
        );
    }

    /**
     * Memory footprint of a single task: a {@code long[k]} for community sizes plus a range for
     * {@code k} coordinate-sum arrays of length {@code dimensions}, from float to double precision.
     * NOTE(review): the coordinate sums are presumably held by the subclasses; confirm.
     */
    public static MemoryEstimation memoryEstimation(int k, int dimensions) {
        MemoryRange coordinateSumsRange = MemoryRange.of(
            k * Estimate.sizeOfFloatArray(dimensions),
            k * Estimate.sizeOfDoubleArray(dimensions)
        );
        MemoryEstimations.Builder builder = MemoryEstimations.builder(KmeansTask.class);
        builder
            .fixed("communitySizes", Estimate.sizeOfLongArray(k))
            .add(
                "communityCoordinateSums",
                MemoryEstimations.of("communityCoordinateSums", coordinateSumsRange)
            );
        return builder.build();
    }

    /**
     * Hook called once at the start of every {@code ITERATION} pass, before any node is assigned.
     * NOTE(review): this class itself does not clear {@link #communitySizes} before a pass, so
     * implementations are presumably expected to; confirm.
     */
    abstract void reset();

    /**
     * Hook called right after {@code nodeId} was assigned to {@code community} and
     * {@code communitySizes[community]} was incremented.
     */
    abstract void updateAfterAssignmentToCentroid(long nodeId, int community);

    /** Returns how many of this task's nodes are currently counted for centroid {@code clusterId}. */
    public long getNumAssignedAtCluster(int clusterId) {
        return communitySizes[clusterId];
    }

    /** Returns the number of nodes that changed community during the last {@code ITERATION} pass. */
    public long getSwaps() {
        return swaps;
    }

    /** Sets the phase that the next {@link #run()} call executes. */
    public void switchToPhase(TaskPhase taskPhase) {
        phase = taskPhase;
    }

    /**
     * Returns this task's accumulated distance divided by {@code communities.size()}.
     * NOTE(review): the divisor is the size of the shared array, presumably the global node count,
     * so summing this value over all tasks would give the overall mean; confirm.
     */
    public double getDistanceFromCentroidNormalized() {
        return distance / communities.size();
    }

    /** Returns the sum of squared best-known distances from the most recent sampling pass. */
    public double getSquaredDistance() {
        return squaredDistance;
    }

    @Override
    public void run() {
        long startNode = partition.startNode();
        long endNode = startNode + partition.nodeCount();

        if (phase == TaskPhase.ITERATION) {
            assignNodeToCentroid(startNode, endNode);
            return;
        }
        if (phase == TaskPhase.DISTANCE) {
            calculateFinalDistance(startNode, endNode);
            return;
        }
        distanceFromLastSampledCentroid(startNode, endNode, clusterManager.getCurrentlyAssigned());
    }

    // Reassigns every node in [startNode, endNode) to its closest centroid, counting swaps.
    private void assignNodeToCentroid(long startNode, long endNode) {
        swaps = 0L;
        reset();
        for (long nodeId = startNode; nodeId < endNode; nodeId++) {
            int closest = clusterManager.findClosestCentroid(nodeId);
            communitySizes[closest]++;
            if (closest != communities.get(nodeId)) {
                swaps++;
            }
            communities.set(nodeId, closest);
            updateAfterAssignmentToCentroid(nodeId, closest);
        }
    }

    // Stores each node's distance to its assigned centroid and adds it to the running total.
    private void calculateFinalDistance(long startNode, long endNode) {
        for (long nodeId = startNode; nodeId < endNode; nodeId++) {
            double toCentroid = clusterManager.euclidean(nodeId, communities.get(nodeId));
            distance += toCentroid;
            distanceFromCentroid.set(nodeId, toCentroid);
        }
    }

    // Folds the newest sampled centroid (index sampledSoFar - 1) into every node's best-known distance.
    // Once all k centroids have been sampled, it also records the final assignment of every node.
    private void distanceFromLastSampledCentroid(long startNode, long endNode, int sampledSoFar) {
        squaredDistance = 0.0d;
        int newestCentroid = sampledSoFar - 1;
        boolean lastRound = sampledSoFar == k;

        for (long nodeId = startNode; nodeId < endNode; nodeId++) {
            double current = distanceFromCentroid.get(nodeId);
            // Values <= -1 are encoded centroid markers and are never overwritten here.
            if (current > -1.0d) {
                double candidate = clusterManager.euclidean(nodeId, newestCentroid);
                // With a single centroid sampled there is nothing to compare against, so take it.
                // In that case newestCentroid == 0.
                if (sampledSoFar == 1 || current > candidate) {
                    distanceFromCentroid.set(nodeId, candidate);
                    squaredDistance += candidate * candidate;
                    communities.set(nodeId, newestCentroid);
                } else {
                    squaredDistance += current * current;
                }
            }
            if (lastRound) {
                finalizeSampledAssignment(nodeId);
            }
        }
    }

    // Decodes a centroid marker if present, then counts the node toward its community.
    private void finalizeSampledAssignment(long nodeId) {
        double stored = distanceFromCentroid.get(nodeId);
        if (stored <= -1.0d) {
            // A marker of -(c + 1) decodes to centroid c; the node is at distance 0 from itself.
            communities.set(nodeId, ((int) (-stored)) - 1);
            distanceFromCentroid.set(nodeId, 0.0d);
        }
        int community = communities.get(nodeId);
        communitySizes[community]++;
        updateAfterAssignmentToCentroid(nodeId, community);
    }
}
