package org.neo4j.gds.ml.core.functions;

import java.util.Iterator;
import java.util.List;
import org.neo4j.gds.ml.core.AbstractVariable;
import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Dimensions;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.tensor.Scalar;
import org.neo4j.gds.ml.core.tensor.Tensor;

/**
 * Computation-graph node that adds up every element of every parent tensor and yields a single
 * {@link Scalar}.
 *
 * <p>Because the output is a plain sum, the partial derivative with respect to each input element
 * is 1. The gradient passed to a parent is therefore just the upstream scalar gradient, copied into
 * every position of that parent's tensor.
 */
public class ElementSum extends AbstractVariable<Scalar> {

    /**
     * Creates a sum node over the given parents.
     *
     * @param list the variables whose elements are summed; the node itself has scalar dimensions
     */
    public ElementSum(List<Variable<?>> list) {
        super(list, Dimensions.scalar());
    }

    /**
     * Evaluates the node by summing the element totals of all parents, in parent order.
     *
     * @param computationContext context holding the already-computed data of the parents
     * @return a scalar containing the grand total of all parent elements
     */
    @Override // org.neo4j.gds.ml.core.Variable
    public Scalar apply(ComputationContext computationContext) {
        double total = 0.0d;
        for (Variable<?> parent : parents()) {
            // aggregateSum() collapses one parent's tensor to the sum of its entries.
            total += computationContext.data(parent).aggregateSum();
        }
        return new Scalar(total);
    }

    /**
     * Computes the gradient of this sum with respect to one parent.
     *
     * @param variable the parent for which the gradient is requested
     * @param computationContext context holding this node's upstream gradient and the parent's data
     * @return the parent's data mapped so that every entry equals this node's upstream gradient
     *     (presumably the same dimensions as the parent; map() is assumed to preserve shape)
     */
    @Override // org.neo4j.gds.ml.core.Variable
    public Tensor<?> gradient(Variable<?> variable, ComputationContext computationContext) {
        Scalar upstreamGradient = (Scalar) computationContext.gradient(this);
        double upstream = upstreamGradient.value();
        // d(sum)/d(x_i) == 1, so each entry receives the upstream value unchanged.
        return computationContext.data(variable).map(ignored -> upstream);
    }
}
