package org.neo4j.gds.ml.core.functions;

import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Dimensions;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.tensor.Matrix;

/* loaded from: input_file:org/neo4j/gds/ml/core/functions/ReducedSoftmax.class */
/**
 * Row-wise softmax over the parent matrix extended by one implicit trailing column.
 *
 * <p>The parent supplies {@code cols} values per row; the result has {@code cols + 1} columns.
 * The extra (last) column is never read from the parent: its un-normalised weight is fixed to
 * {@code 1.0}, which equals {@code exp(0)}.
 */
public class ReducedSoftmax extends SingleParentVariable<Matrix, Matrix> {

    /**
     * Creates the variable with the parent's first dimension and the parent's second dimension
     * plus one.
     *
     * @param variable the parent variable whose data is fed through the softmax
     */
    public ReducedSoftmax(Variable<Matrix> variable) {
        super(
            variable,
            Dimensions.matrix(variable.dimension(0), variable.dimension(1) + 1)
        );
    }

    /**
     * Delegates to {@link Matrix#sizeInBytes} with the second argument reduced by one.
     *
     * <p>NOTE(review): the constructor adds one column while this subtracts one; which column
     * count callers are expected to pass is not visible here - confirm against call sites.
     *
     * @param i  forwarded unchanged as the first argument (presumably the row count)
     * @param i2 column count; {@code i2 - 1} is forwarded as the second argument
     * @return whatever {@code Matrix.sizeInBytes(i, i2 - 1)} reports
     */
    public static long sizeInBytes(int i, int i2) {
        int reducedCols = i2 - 1;
        return Matrix.sizeInBytes(i, reducedCols);
    }

    /**
     * Computes the forward pass.
     *
     * <p>Each row is filled with {@code exp(parent[row][col])} for the parent's columns and
     * {@code 1.0} for the additional last column, then divided by the row sum. Infinite
     * exponentials and infinite running sums are clamped to {@link Double#MAX_VALUE}; if any
     * clamping happened, the whole result is passed through {@link #rescale(Matrix)} afterwards.
     *
     * @param computationContext context holding the parent's computed data
     * @return a new matrix with one more column than the parent's data
     */
    @Override
    public Matrix apply(ComputationContext computationContext) {
        Matrix parentData = (Matrix) computationContext.data(this.parent);
        int rowCount = parentData.rows();
        int outputCols = parentData.cols() + 1;
        int implicitCol = outputCols - 1;

        Matrix result = Matrix.create(0.0d, rowCount, outputCols);
        boolean clamped = false;

        for (int row = 0; row < rowCount; row++) {
            double rowSum = 0.0d;

            // First pass: store un-normalised weights and accumulate their sum.
            for (int col = 0; col < outputCols; col++) {
                double weight;
                if (col == implicitCol) {
                    // The extra column has no backing data; its weight is exp(0).
                    weight = 1.0d;
                } else {
                    weight = Math.exp(parentData.dataAt(row, col));
                }
                if (Double.isInfinite(weight)) {
                    clamped = true;
                    weight = Double.MAX_VALUE;
                }
                result.setDataAt(row, col, weight);

                rowSum += weight;
                if (Double.isInfinite(rowSum)) {
                    clamped = true;
                    rowSum = Double.MAX_VALUE;
                }
            }

            // Second pass: normalise the row by the (possibly clamped) sum.
            for (int col = 0; col < outputCols; col++) {
                double weight = result.dataAt(row, col);
                result.setDataAt(row, col, weight / rowSum);
            }
        }

        if (clamped) {
            // Clamping can leave rows that do not sum to one; normalise once more.
            rescale(result);
        }
        return result;
    }

    /**
     * Divides every entry of each row by that row's sum, in place.
     *
     * <p>The sum starts at {@code 1e-15} rather than zero, so the divisor is slightly larger
     * than the plain row sum. Entries are addressed through the matrix's flat index
     * {@code row * cols + col}.
     *
     * @param matrix the matrix to normalise in place
     */
    private static void rescale(Matrix matrix) {
        int rowCount = matrix.rows();
        int colCount = matrix.cols();

        for (int row = 0; row < rowCount; row++) {
            int rowOffset = row * colCount;

            double rowSum = 1.0E-15d;
            for (int col = 0; col < colCount; col++) {
                rowSum += matrix.dataAt(rowOffset + col);
            }

            for (int col = 0; col < colCount; col++) {
                int flatIndex = rowOffset + col;
                matrix.setDataAt(flatIndex, matrix.dataAt(flatIndex) / rowSum);
            }
        }
    }

    /**
     * Computes the gradient with respect to the parent.
     *
     * <p>For every row and every parent column {@code j} this accumulates, over all output
     * columns {@code k}, the term {@code s[k] * ((j == k ? 1 : 0) - s[j]) * g[k]}, where
     * {@code s} is this variable's data and {@code g} is this variable's gradient. The result
     * has one column fewer than {@code s}, i.e. no entry for the implicit last column.
     *
     * @param computationContext context holding this variable's data and gradient
     * @return a new matrix with the same row count as this variable's data and one column fewer
     */
    @Override
    public Matrix gradientForParent(ComputationContext computationContext) {
        Matrix softmax = (Matrix) computationContext.data(this);
        Matrix selfGradient = (Matrix) computationContext.gradient(this);
        int rowCount = softmax.rows();
        int softmaxCols = softmax.cols();
        int parentCols = softmaxCols - 1;

        Matrix parentGradient = Matrix.create(0.0d, rowCount, parentCols);
        for (int row = 0; row < rowCount; row++) {
            for (int parentCol = 0; parentCol < parentCols; parentCol++) {
                double ownValue = softmax.dataAt(row, parentCol);
                for (int outCol = 0; outCol < softmaxCols; outCol++) {
                    int indicator = parentCol == outCol ? 1 : 0;
                    double term = softmax.dataAt(row, outCol)
                        * (indicator - ownValue)
                        * selfGradient.dataAt(row, outCol);
                    parentGradient.addDataAt(row, parentCol, term);
                }
            }
        }
        return parentGradient;
    }
}
