package org.neo4j.gds.ml.core.functions;

import java.util.List;
import org.neo4j.gds.ml.core.AbstractVariable;
import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Dimensions;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Scalar;
import org.neo4j.gds.ml.core.tensor.Tensor;
import org.neo4j.gds.ml.core.tensor.Vector;

/* loaded from: input_file:org/neo4j/gds/ml/core/functions/RootMeanSquareError.class */
/**
 * Computation-graph variable that evaluates the root mean square error (RMSE)
 * between a predictions tensor and a targets vector.
 *
 * <p>The value is {@code sqrt(sum_i (prediction_i - target_i)^2 / n)} where
 * {@code n} is the length of the targets vector. A gradient is only provided
 * with respect to the predictions parent.
 */
public class RootMeanSquareError extends AbstractVariable<Scalar> {
    private final Variable<Matrix> predictionsVar;
    private final Variable<Vector> targetsVar;

    /**
     * Creates the RMSE variable with the two given parents.
     *
     * <p>The shape requirements below are checked with {@code assert}, i.e. only
     * when the JVM runs with assertions enabled.
     *
     * @param variable  the predictions; its dimensions must describe a non-empty vector
     * @param variable2 the targets; must have the same total size as the predictions
     */
    public RootMeanSquareError(Variable<Matrix> variable, Variable<Vector> variable2) {
        super(List.of(variable, variable2), Dimensions.scalar());
        assert Dimensions.isVector(variable.dimensions()) : "Predictions need to be a vector";
        assert Dimensions.totalSize(variable.dimensions()) > 0;
        assert Dimensions.totalSize(variable.dimensions()) == Dimensions.totalSize(variable2.dimensions())
            : "Predictions and targets need to have the same total size";
        this.predictionsVar = variable;
        this.targetsVar = variable2;
    }

    /**
     * Computes the RMSE over all entries of the targets vector.
     *
     * @param computationContext context holding the current data of both parents
     * @return the RMSE, or {@link Double#MAX_VALUE} when the sum of squared
     *         errors is not finite (infinite or NaN)
     */
    @Override // org.neo4j.gds.ml.core.Variable
    public Scalar apply(ComputationContext computationContext) {
        Matrix predictions = (Matrix) computationContext.data(this.predictionsVar);
        Vector targets = (Vector) computationContext.data(this.targetsVar);

        double squaredErrorSum = 0.0d;
        for (int i = 0; i < targets.length(); i++) {
            double error = predictions.dataAt(i) - targets.dataAt(i);
            squaredErrorSum += error * error;
        }

        // A non-finite sum is reported as the largest finite double instead of Infinity/NaN.
        if (!Double.isFinite(squaredErrorSum)) {
            return new Scalar(Double.MAX_VALUE);
        }
        return new Scalar(Math.sqrt(squaredErrorSum / targets.length()));
    }

    /**
     * Computes the gradient of the RMSE with respect to the predictions.
     *
     * <p>Entry {@code i} is {@code (prediction_i - target_i) * upstream / (rmse * n)},
     * where {@code upstream} is the gradient already propagated to this variable and
     * {@code n} is the length of the targets vector. When the RMSE is exactly zero the
     * freshly created tensor is returned without writing any entry, which avoids a
     * division by zero.
     *
     * @param variable           the parent to differentiate against; must be the predictions parent
     * @param computationContext context holding the data and gradients of the graph
     * @return a tensor with the same dimensions as the data of {@code variable}
     * @throws IllegalStateException if {@code variable} is not the predictions parent
     */
    @Override // org.neo4j.gds.ml.core.Variable
    public Tensor<?> gradient(Variable<?> variable, ComputationContext computationContext) {
        if (variable != this.predictionsVar) {
            throw new IllegalStateException("The gradient should only be computed for the prediction parent, but got " + variable.render());
        }

        Matrix predictions = (Matrix) computationContext.data(this.predictionsVar);
        // Looked up once and reused for both the length and the per-entry errors.
        Vector targets = (Vector) computationContext.data(this.targetsVar);
        int length = targets.length();
        Scalar rmse = (Scalar) computationContext.data(this);
        Tensor<?> gradient = computationContext.data(variable).createWithSameDimensions();

        if (Double.compare(rmse.value(), 0.0d) == 0) {
            return gradient;
        }

        // Factor shared by every entry: upstream gradient divided by (rmse * n).
        double scale = ((Scalar) computationContext.gradient(this)).value() / rmse.scalarMultiply(length).value();
        for (int i = 0; i < length; i++) {
            gradient.setDataAt(i, (predictions.dataAt(i) - targets.dataAt(i)) * scale);
        }
        return gradient;
    }
}
