package org.neo4j.gds.ml.core.functions;

import java.util.List;
import org.neo4j.gds.ml.core.AbstractVariable;
import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Dimensions;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.tensor.Scalar;
import org.neo4j.gds.ml.core.tensor.Tensor;
import org.neo4j.gds.utils.StringFormatting;

/* loaded from: input_file:org/neo4j/gds/ml/core/functions/MeanSquareError.class */
/**
 * Mean squared error loss between a {@code predictions} variable and a {@code targets} variable.
 *
 * <p>The two inputs are compared element by element using their flat ({@code dataAt}) index.
 * Because of this, they only need the same total number of elements, not the same shape. The
 * forward value is {@code (1/n) * sum_i (predictions_i - targets_i)^2}, where {@code n} is the
 * total size of the predictions, and the result is a scalar.
 */
public class MeanSquareError extends AbstractVariable<Scalar> {
    private final Variable<?> predictions;
    private final Variable<?> targets;

    /**
     * Creates the loss node with {@code predictions} and {@code targets} as its two parents.
     *
     * @param predictions the predicted values
     * @param targets the expected values
     * @throws IllegalArgumentException if the two variables do not have the same total size
     */
    public MeanSquareError(Variable<?> predictions, Variable<?> targets) {
        super(List.of(predictions, targets), Dimensions.scalar());
        this.predictions = predictions;
        this.targets = targets;
        validateDimensions(predictions, targets);
    }

    /**
     * Computes the mean of the squared element-wise differences between predictions and targets.
     *
     * <p>If the mean is not finite, {@link Double#MAX_VALUE} is returned instead, so a non-finite
     * loss never leaves this node. This happens when the sum of squares overflowed, when an input
     * holds NaN or infinity, or when the inputs are empty and the mean would be 0/0.
     *
     * @param ctx the computation context holding the parents' data
     * @return the mean squared error, or {@link Double#MAX_VALUE} if that value is not finite
     */
    @Override // org.neo4j.gds.ml.core.Variable
    public Scalar apply(ComputationContext ctx) {
        Tensor<?> predicted = ctx.data(this.predictions);
        Tensor<?> expected = ctx.data(this.targets);
        int size = predicted.totalSize();

        double sumOfSquares = 0.0d;
        for (int i = 0; i < size; i++) {
            double error = predicted.dataAt(i) - expected.dataAt(i);
            sumOfSquares += error * error;
        }

        // Guard the mean rather than the sum: for size == 0 the sum is a finite 0.0 but the mean
        // is NaN. For size >= 1 both checks agree.
        double mean = sumOfSquares / size;
        return Double.isFinite(mean) ? new Scalar(mean) : new Scalar(Double.MAX_VALUE);
    }

    /**
     * Back-propagates the upstream gradient of this scalar loss to one of its parents.
     *
     * <p>For the requested parent {@code x} and the other input {@code y}, element {@code i} of
     * the result is {@code 2 * g * (x_i - y_i) / n}. Here {@code g} is the gradient flowing into
     * this node and {@code n} is the total size of {@code x}. Any variable other than
     * {@code predictions} is treated as the targets side, i.e. it is differenced against the
     * predictions.
     *
     * @param parent the parent to compute the gradient for
     * @param ctx the computation context holding data and the upstream gradient
     * @return a tensor shaped like {@code parent}'s data containing the partial derivatives
     */
    @Override // org.neo4j.gds.ml.core.Variable
    public Tensor<?> gradient(Variable<?> parent, ComputationContext ctx) {
        Tensor<?> own = ctx.data(parent);
        Tensor<?> other = parent == this.predictions
            ? ctx.data(this.targets)
            : ctx.data(this.predictions);
        int size = own.totalSize();
        Tensor<?> result = own.createWithSameDimensions();

        // Chain rule: d/dx of (x - y)^2 / n is 2 (x - y) / n, scaled by the upstream gradient.
        double scale = (2.0d * ((Scalar) ctx.gradient(this)).dataAt(0)) / size;
        for (int i = 0; i < size; i++) {
            result.setDataAt(i, scale * (own.dataAt(i) - other.dataAt(i)));
        }
        return result;
    }

    /**
     * Ensures predictions and targets hold the same number of elements.
     *
     * @throws IllegalArgumentException if their total sizes differ
     */
    private static void validateDimensions(Variable<?> predictions, Variable<?> targets) {
        if (Dimensions.totalSize(predictions.dimensions()) != Dimensions.totalSize(targets.dimensions())) {
            throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                "Targets and predictions must be of equal size. Got predictions: %s, targets: %s",
                Dimensions.render(predictions.dimensions()),
                Dimensions.render(targets.dimensions())));
        }
    }
}
