package org.neo4j.gds.ml.core.functions;

import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Dimensions;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.subgraph.BatchNeighbors;
import org.neo4j.gds.ml.core.tensor.Matrix;

/* loaded from: input_file:org/neo4j/gds/ml/core/functions/ElementWiseMax.class */
/**
 * Max-pooling aggregation over weighted neighbors for a batch of nodes.
 *
 * <p>Row {@code i} of the result belongs to the node {@code batchIds()[i]}. Column {@code c}
 * holds the largest value of {@code parent[neighbor][c] * relationshipWeight(node, neighbor)}
 * over that node's neighbors. Nodes whose degree is zero produce a row of zeros.
 *
 * <p>The result has {@code batchSize()} rows and as many columns as the parent matrix.
 */
public class ElementWiseMax extends SingleParentVariable<Matrix, Matrix> {
    /** Sentinel returned by {@link #indexOfMaxContributor} when no neighbor could be selected. */
    private static final int INVALID_NEIGHBOR = -1;

    private final BatchNeighbors batchNeighbors;

    /**
     * @param variable       parent matrix; must have at least one row per node of the subgraph
     *                       (checked only when assertions are enabled)
     * @param batchNeighbors batch node ids, their neighbors and relationship weights
     */
    public ElementWiseMax(Variable<Matrix> variable, BatchNeighbors batchNeighbors) {
        super(variable, Dimensions.matrix(batchNeighbors.batchSize(), variable.dimension(1)));
        this.batchNeighbors = batchNeighbors;
        assert variable.dimension(0) >= batchNeighbors.nodeCount()
            : "Expecting a row for each node in the subgraph";
    }

    /**
     * Computes, per batch node and column, the maximum weighted neighbor value.
     *
     * @return a {@code batchSize x parentCols} matrix; rows of isolated nodes are all zero
     */
    @Override
    public Matrix apply(ComputationContext computationContext) {
        Matrix parentData = (Matrix) computationContext.data(this.parent);
        int batchSize = this.batchNeighbors.batchSize();
        int cols = parentData.cols();
        int[] batchIds = this.batchNeighbors.batchIds();

        // Start at -infinity so the first neighbor value always wins the comparison.
        Matrix result = Matrix.create(Double.NEGATIVE_INFINITY, batchSize, cols);
        for (int row = 0; row < batchSize; row++) {
            int nodeId = batchIds[row];
            for (int neighbor : this.batchNeighbors.neighbors(nodeId)) {
                double weight = this.batchNeighbors.relationshipWeight(nodeId, neighbor);
                for (int col = 0; col < cols; col++) {
                    double weighted = parentData.dataAt(neighbor, col) * weight;
                    if (weighted >= result.dataAt(row, col)) {
                        result.setDataAt(row, col, weighted);
                    }
                }
            }
            // A node without neighbors would otherwise keep -infinity; emit zeros instead.
            if (this.batchNeighbors.degree(nodeId) == 0) {
                for (int col = 0; col < cols; col++) {
                    result.setDataAt(row, col, 0.0d);
                }
            }
        }
        return result;
    }

    /**
     * Routes the upstream gradient of every output cell to the single neighbor that produced
     * the maximum, scaled by that neighbor's relationship weight (d(x * w)/dx = w).
     *
     * @return a matrix shaped like the parent, zero everywhere except at contributing cells
     */
    @Override
    public Matrix gradientForParent(ComputationContext computationContext) {
        Matrix parentData = (Matrix) computationContext.data(this.parent);
        Matrix parentGradient = parentData.createWithSameDimensions();
        Matrix upstreamGradient = (Matrix) computationContext.gradient(this);
        Matrix output = (Matrix) computationContext.data(this);

        int cols = parentGradient.cols();
        int batchSize = this.batchNeighbors.batchSize();
        int[] batchIds = this.batchNeighbors.batchIds();

        for (int row = 0; row < batchSize; row++) {
            int nodeId = batchIds[row];
            int[] neighbors = this.batchNeighbors.neighbors(nodeId);

            // Look weights up once per node rather than once per column.
            double[] weights = new double[neighbors.length];
            for (int k = 0; k < neighbors.length; k++) {
                weights[k] = this.batchNeighbors.relationshipWeight(nodeId, neighbors[k]);
            }

            for (int col = 0; col < cols; col++) {
                int k = indexOfMaxContributor(parentData, neighbors, weights, col, output.dataAt(row, col));
                if (k != INVALID_NEIGHBOR) {
                    parentGradient.addDataAt(neighbors[k], col, upstreamGradient.dataAt(row, col) * weights[k]);
                } else {
                    assert neighbors.length == 0 : "No neighbor matched the max of node " + nodeId + " in column " + col;
                }
            }
        }
        return parentGradient;
    }

    /**
     * Finds the position in {@code neighbors} whose weighted value in column {@code col}
     * produced {@code target}.
     *
     * <p>An exact match is preferred and the first one wins. The forward pass stores exactly
     * the product {@code parent * weight}, so a match normally exists. This also handles
     * infinite values, where the distance below would be NaN. If there is no exact match,
     * the closest candidate is used, again preferring the first on ties.
     *
     * @return an index into {@code neighbors}, or {@link #INVALID_NEIGHBOR} if none qualified
     */
    private static int indexOfMaxContributor(
        Matrix parentData,
        int[] neighbors,
        double[] weights,
        int col,
        double target
    ) {
        int bestIndex = INVALID_NEIGHBOR;
        double bestDistance = Double.MAX_VALUE;
        for (int k = 0; k < neighbors.length; k++) {
            double candidate = parentData.dataAt(neighbors[k], col) * weights[k];
            if (candidate == target) {
                return k;
            }
            double distance = Math.abs(target - candidate);
            if (distance < bestDistance) {
                bestDistance = distance;
                bestIndex = k;
            }
        }
        return bestIndex;
    }
}
