package org.neo4j.gds.ml.core.functions;

import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Dimensions;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.subgraph.BatchNeighbors;
import org.neo4j.gds.ml.core.tensor.Matrix;

/* loaded from: input_file:org/neo4j/gds/ml/core/functions/MultiMean.class */
/**
 * Aggregates, for every node of a batch, its own feature row together with the
 * rows of its neighbours into a single output row.
 *
 * <p>For a batch node {@code v} with neighbour list {@code N(v)} the output row is
 * <pre>
 *   out[v] = (x[v] + sum_{u in N(v)} w(v, u) * x[u]) / (|N(v)| + 1)
 * </pre>
 * where {@code x} is the parent matrix (rows indexed by subgraph node id) and
 * {@code w(v, u)} is {@code BatchNeighbors.relationshipWeight(v, u)}. Note that the
 * divisor is the neighbour count plus one, not the sum of the relationship weights.
 *
 * <p>The output has one row per batch node and as many columns as the parent.
 */
public class MultiMean extends SingleParentVariable<Matrix, Matrix> {
    private final BatchNeighbors subGraph;

    /**
     * @param variable       parent matrix; expected to hold at least one row per node
     *                       of {@code batchNeighbors} (checked only when assertions are enabled)
     * @param batchNeighbors batch node ids, their neighbour lists and relationship weights
     */
    public MultiMean(Variable<Matrix> variable, BatchNeighbors batchNeighbors) {
        super(variable, Dimensions.matrix(batchNeighbors.batchSize(), variable.dimension(1)));
        this.subGraph = batchNeighbors;
        assert variable.dimension(0) >= batchNeighbors.nodeCount()
            : "Expecting a row for each node in the subgraph";
    }

    /**
     * Computes the aggregated matrix described in the class documentation.
     *
     * @param ctx computation context holding the parent's data
     * @return a new matrix of shape {@code (batchIds.length, parent.cols())}
     */
    @Override
    public Matrix apply(ComputationContext ctx) {
        Matrix parentData = (Matrix) ctx.data(this.parent);
        int[] batchIds = this.subGraph.batchIds();
        int batchSize = batchIds.length;
        int cols = parentData.cols();

        Matrix means = Matrix.create(0.0d, batchSize, cols);
        for (int batchIdx = 0; batchIdx < batchSize; batchIdx++) {
            int nodeId = batchIds[batchIdx];
            int[] neighbors = this.subGraph.neighbors(nodeId);
            // The node itself counts as one extra member of its neighbourhood.
            int divisor = neighbors.length + 1;

            for (int col = 0; col < cols; col++) {
                means.addDataAt(batchIdx, col, parentData.dataAt(nodeId, col) / divisor);
            }
            for (int neighbor : neighbors) {
                double weight = this.subGraph.relationshipWeight(nodeId, neighbor);
                for (int col = 0; col < cols; col++) {
                    means.addDataAt(batchIdx, col, (parentData.dataAt(neighbor, col) * weight) / divisor);
                }
            }
        }
        return means;
    }

    /**
     * Back-propagates the upstream gradient (one row per batch node) onto the parent.
     *
     * <p>Batch row {@code i} for node {@code v} contributes {@code grad[i] / (|N(v)| + 1)}
     * to parent row {@code v}. It contributes {@code grad[i] * w(v, u) / (|N(v)| + 1)} to
     * each neighbour row {@code u}. Contributions are accumulated, so a node that appears
     * in several neighbourhoods receives the sum of all of them.
     *
     * @param ctx computation context holding this variable's gradient and the parent's data
     * @return a matrix with the parent's dimensions containing the accumulated gradient
     */
    @Override
    public Matrix gradientForParent(ComputationContext ctx) {
        Matrix upstream = (Matrix) ctx.gradient(this);
        // NOTE(review): accumulation below assumes createWithSameDimensions() yields a
        // zero-filled matrix — confirm against Matrix.
        Matrix parentGradient = ((Matrix) ctx.data(this.parent)).createWithSameDimensions();
        int cols = parentGradient.cols();
        int[] batchIds = this.subGraph.batchIds();

        for (int batchIdx = 0; batchIdx < batchIds.length; batchIdx++) {
            int nodeId = batchIds[batchIdx];
            int[] neighbors = this.subGraph.neighbors(nodeId);
            int divisor = neighbors.length + 1;

            for (int col = 0; col < cols; col++) {
                parentGradient.addDataAt(nodeId, col, upstream.dataAt(batchIdx, col) / divisor);
            }
            for (int neighbor : neighbors) {
                double weight = this.subGraph.relationshipWeight(nodeId, neighbor);
                for (int col = 0; col < cols; col++) {
                    parentGradient.addDataAt(neighbor, col, (upstream.dataAt(batchIdx, col) * weight) / divisor);
                }
            }
        }
        return parentGradient;
    }
}
