package org.neo4j.gds.ml.core.functions;

import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Dimensions;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.tensor.Scalar;
import org.neo4j.gds.ml.core.tensor.Tensor;

/* loaded from: input_file:org/neo4j/gds/ml/core/functions/L2NormSquared.class */
/**
 * Computation-graph node that reduces its single parent tensor to the sum of the squares of all
 * its elements, i.e. the squared L2 norm {@code sum_i x_i^2}, produced as a {@link Scalar}.
 *
 * @param <T> the tensor type of the parent variable
 */
public class L2NormSquared<T extends Tensor<T>> extends SingleParentVariable<T, Scalar> {

    /**
     * Creates a squared-L2-norm node over the given parent variable.
     *
     * @param variable the parent variable whose data is reduced
     */
    public L2NormSquared(Variable<T> variable) {
        super(variable, Dimensions.scalar());
    }

    /**
     * Returns the memory estimate for the result of {@link #apply(ComputationContext)}.
     *
     * @return the size in bytes of a single {@link Scalar}
     */
    public static long sizeInBytesOfApply() {
        return Scalar.sizeInBytes();
    }

    /**
     * Sums the squares of every element of the parent's data, visiting indices
     * {@code 0 .. totalSize() - 1}.
     *
     * @param computationContext context holding the already-computed parent data
     * @return a new {@link Scalar} holding the sum of squared elements
     */
    @Override
    public Scalar apply(ComputationContext computationContext) {
        // Only element access is needed, so a wildcard suffices and avoids a raw type.
        Tensor<?> parentData = computationContext.data(this.parent);
        int elementCount = parentData.totalSize();
        double sumOfSquares = 0.0d;
        for (int index = 0; index < elementCount; index++) {
            double value = parentData.dataAt(index);
            sumOfSquares += value * value;
        }
        return new Scalar(sumOfSquares);
    }

    /**
     * Back-propagates the upstream gradient to the parent.
     *
     * <p>Since {@code d/dx_i (sum_j x_j^2) = 2 * x_i}, the parent gradient is the parent data
     * scaled by {@code 2 * upstream}, where {@code upstream} is this node's scalar gradient.
     *
     * @param computationContext context holding the parent data and this node's gradient
     * @return the parent's data multiplied element-wise by {@code 2 * upstream}
     */
    // The cast to T mirrors the decompiled original; scaling a T is expected to yield a T.
    // NOTE(review): drop the cast and suppression if scalarMultiply already returns T.
    @SuppressWarnings("unchecked")
    @Override
    public T gradientForParent(ComputationContext computationContext) {
        double upstream = ((Scalar) computationContext.gradient(this)).value();
        return (T) computationContext.data(this.parent).scalarMultiply(2.0d * upstream);
    }
}
