package org.neo4j.gds.ml.core.functions;

import java.util.List;
import org.neo4j.gds.ml.core.AbstractVariable;
import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Scalar;
import org.neo4j.gds.ml.core.tensor.Tensor;

/* loaded from: input_file:org/neo4j/gds/ml/core/functions/EWiseAddMatrixScalar.class */
/**
 * Computation-graph node that adds a scalar variable to every entry of a matrix variable.
 *
 * <p>The node has two parents, the matrix and the scalar, and reports the matrix's
 * dimensions as its own.
 */
public class EWiseAddMatrixScalar extends AbstractVariable<Matrix> {
    /** Parent supplying the matrix operand. */
    private final Variable<Matrix> matrixVariable;
    /** Parent supplying the scalar that is added to each matrix entry. */
    private final Variable<Scalar> scalarVariable;

    /**
     * Creates the node.
     *
     * @param matrixVariable the matrix operand; its dimensions become this node's dimensions
     * @param scalarVariable the scalar added to every entry of the matrix
     */
    public EWiseAddMatrixScalar(Variable<Matrix> matrixVariable, Variable<Scalar> scalarVariable) {
        super(List.of(matrixVariable, scalarVariable), matrixVariable.dimensions());
        this.matrixVariable = matrixVariable;
        this.scalarVariable = scalarVariable;
    }

    /**
     * Forward pass: returns a matrix whose entries are the input entries shifted by the scalar.
     *
     * @param ctx context holding the already-computed data of the parent variables
     * @return the result of {@code Matrix.map} with {@code entry + scalar} applied to each entry
     */
    @Override // org.neo4j.gds.ml.core.Variable
    public Matrix apply(ComputationContext ctx) {
        Matrix input = (Matrix) ctx.data(matrixVariable);
        Scalar offset = (Scalar) ctx.data(scalarVariable);
        double shift = offset.value();
        return input.map(entry -> entry + shift);
    }

    /**
     * Backward pass: propagates this node's gradient to one of its parents.
     *
     * <p>For the matrix parent, the derivative of {@code m + s} with respect to {@code m} is the
     * identity, so the upstream gradient is returned as is (the same instance). For any other
     * parent, which is expected to be the scalar, the scalar contributed to every entry, so its
     * gradient is the aggregated sum of the upstream gradient entries.
     *
     * @param parent the parent variable whose gradient is requested
     * @param ctx context holding this node's upstream gradient
     * @return the upstream gradient matrix, or a {@link Scalar} holding its aggregated sum
     */
    @Override // org.neo4j.gds.ml.core.Variable
    public Tensor<?> gradient(Variable<?> parent, ComputationContext ctx) {
        Matrix upstream = (Matrix) ctx.gradient(this);
        if (parent != matrixVariable) {
            // Broadcast scalar: accumulate the gradient from every matrix entry.
            return new Scalar(upstream.aggregateSum());
        }
        return upstream;
    }
}
