package org.neo4j.gds.ml.core.functions;

import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.tensor.Matrix;

/**
 * Row-wise L2 normalisation of a matrix-valued variable.
 *
 * <p>The forward pass divides every entry of a row by that row's Euclidean norm plus
 * {@link #EPSILON}; the backward pass propagates the upstream gradient through the
 * normalisation of each row independently.
 */
public class NormalizeRows extends SingleParentVariable<Matrix, Matrix> {
    /** Added to each row norm in the forward pass so an all-zero row is not divided by zero. */
    private static final double EPSILON = 1.0E-10d;

    /**
     * @param variable the matrix whose rows are normalised; this variable has the same dimensions
     */
    public NormalizeRows(Variable<Matrix> variable) {
        super(variable, variable.dimensions());
    }

    /**
     * Computes {@code out[r][c] = in[r][c] / (||in[r]|| + EPSILON)} for every row {@code r}.
     *
     * @param computationContext context holding the parent's current value
     * @return a new matrix of the parent's dimensions with L2-normalised rows
     */
    @Override
    public Matrix apply(ComputationContext computationContext) {
        Matrix input = (Matrix) computationContext.data(this.parent);
        int rows = input.rows();
        int cols = input.cols();
        Matrix result = input.createWithSameDimensions();
        for (int row = 0; row < rows; row++) {
            double denominator = Math.sqrt(squaredRowNorm(input, row, cols)) + EPSILON;
            for (int col = 0; col < cols; col++) {
                result.setDataAt(row, col, input.dataAt(row, col) / denominator);
            }
        }
        return result;
    }

    /**
     * Back-propagates this variable's gradient to its parent.
     *
     * <p>For a row {@code x} with squared norm {@code s = sum_k x_k^2} and upstream gradient
     * {@code g}, the Jacobian of {@code y = x / ||x||} is
     * {@code dy_j/dx_i = (delta_ij * s - x_i * x_j) / ||x||^3}, so
     * {@code dL/dx_i = (g_i * s - x_i * (g . x)) / ||x||^3}.
     * Collapsing the sum over {@code j} into the single dot product {@code g . x} makes each
     * row O(cols) instead of O(cols^2).
     *
     * <p>Note: this is the gradient of the exact normalisation {@code x / ||x||}; the
     * {@link #EPSILON} added in {@link #apply} is not reflected here. Rows whose cubed norm is
     * exactly zero are skipped and keep the initial value of the result matrix.
     *
     * @param computationContext context holding the parent's value and this variable's gradient
     * @return the gradient with respect to the parent, with the parent's dimensions
     */
    @Override
    public Matrix gradientForParent(ComputationContext computationContext) {
        Matrix input = (Matrix) computationContext.data(this.parent);
        Matrix upstream = (Matrix) computationContext.gradient(this);
        // NOTE(review): relies on createWithSameDimensions() yielding a zero-filled matrix so
        // skipped rows carry a zero gradient — TODO confirm against Matrix.
        Matrix result = input.createWithSameDimensions();
        int rows = input.rows();
        int cols = input.cols();
        for (int row = 0; row < rows; row++) {
            double squaredNorm = squaredRowNorm(input, row, cols);
            double cubedNorm = Math.sqrt(squaredNorm) * squaredNorm;
            if (Double.compare(cubedNorm, 0.0d) == 0) {
                continue;
            }
            double gradientDotInput = 0.0d;
            for (int col = 0; col < cols; col++) {
                gradientDotInput += upstream.dataAt(row, col) * input.dataAt(row, col);
            }
            for (int col = 0; col < cols; col++) {
                double numerator = upstream.dataAt(row, col) * squaredNorm
                    - input.dataAt(row, col) * gradientDotInput;
                result.setDataAt(row, col, numerator / cubedNorm);
            }
        }
        return result;
    }

    /** Returns the sum of squares of the first {@code cols} entries of row {@code row}. */
    private static double squaredRowNorm(Matrix matrix, int row, int cols) {
        double sum = 0.0d;
        for (int col = 0; col < cols; col++) {
            double value = matrix.dataAt(row, col);
            sum += value * value;
        }
        return sum;
    }
}
