package org.neo4j.gds.ml.core.functions;

import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.tensor.Matrix;

/**
 * Row-wise softmax of a matrix variable.
 *
 * <p>Each row {@code x} of the parent matrix is mapped to
 * {@code exp(x[i]) / sum_j exp(x[j])}. The computation shifts every row by its
 * maximum before exponentiating, which leaves the mathematical result unchanged
 * but keeps {@code Math.exp} from overflowing for large logits and from
 * underflowing to an all-zero row for very negative logits.
 *
 * <p>Matrix entries are addressed through the flat index {@code row * cols + col}.
 */
public class Softmax extends SingleParentVariable<Matrix, Matrix> {

    /**
     * Creates a softmax node whose output has the same dimensions as its parent.
     *
     * @param parent the variable holding the input matrix
     */
    public Softmax(Variable<Matrix> parent) {
        super(parent, parent.dimensions());
    }

    /**
     * Reports the memory needed for a result of the given shape by delegating to
     * {@code Matrix.sizeInBytes}.
     *
     * @param rows number of rows
     * @param cols number of columns
     * @return the value returned by {@code Matrix.sizeInBytes(rows, cols)}
     */
    public static long sizeInBytes(int rows, int cols) {
        return Matrix.sizeInBytes(rows, cols);
    }

    /**
     * Computes the row-wise softmax of the parent's data.
     *
     * <p>A row that contains {@code NaN} produces {@code NaN} in every entry of that
     * row. Entries equal to the row maximum (including an infinite maximum) are
     * treated as having a shift of zero, so a row such as {@code [inf, 0]} yields
     * {@code [1, 0]} instead of {@code NaN}.
     *
     * @param computationContext context providing the parent's data
     * @return a new matrix with the same dimensions as the parent's data
     */
    @Override
    public Matrix apply(ComputationContext computationContext) {
        Matrix input = (Matrix) computationContext.data(this.parent);
        int rows = input.rows();
        int cols = input.cols();
        // Every entry of the result is overwritten below.
        Matrix result = input.createWithSameDimensions();

        for (int row = 0; row < rows; row++) {
            int rowStart = row * cols;
            int rowEnd = rowStart + cols;

            // Math.max propagates NaN, so a NaN anywhere in the row poisons the whole row.
            double rowMax = Double.NEGATIVE_INFINITY;
            for (int idx = rowStart; idx < rowEnd; idx++) {
                rowMax = Math.max(rowMax, input.dataAt(idx));
            }

            // The maximal entry contributes exp(0) == 1, so rowSum >= 1 for any
            // non-empty NaN-free row: no overflow, and no division by zero.
            double rowSum = 0.0;
            for (int idx = rowStart; idx < rowEnd; idx++) {
                double value = input.dataAt(idx);
                // Comparing first avoids inf - inf == NaN when the maximum is infinite.
                double shifted = value == rowMax ? 0.0 : value - rowMax;
                double exp = Math.exp(shifted);
                result.setDataAt(idx, exp);
                rowSum += exp;
            }

            for (int idx = rowStart; idx < rowEnd; idx++) {
                result.setDataAt(idx, result.dataAt(idx) / rowSum);
            }
        }
        return result;
    }

    /**
     * Back-propagates the gradient of this node to its parent.
     *
     * <p>With {@code s} the softmax output row and {@code g} the incoming gradient
     * row, the Jacobian is {@code ds[j]/dx[i] = s[j] * (delta(i, j) - s[i])}, so
     * {@code sum_j s[j] * (delta(i, j) - s[i]) * g[j]} simplifies to
     * {@code s[i] * (g[i] - sum_j s[j] * g[j])}. The inner sum is computed once per
     * row, giving linear instead of quadratic work in the number of columns.
     *
     * @param computationContext context providing this node's data and gradient
     * @return a new matrix holding the gradient with respect to the parent
     */
    @Override
    public Matrix gradientForParent(ComputationContext computationContext) {
        Matrix softmax = (Matrix) computationContext.data(this);
        Matrix upstream = (Matrix) computationContext.gradient(this);
        int rows = softmax.rows();
        int cols = softmax.cols();
        Matrix gradient = Matrix.create(0.0d, rows, cols);

        for (int row = 0; row < rows; row++) {
            int rowStart = row * cols;
            int rowEnd = rowStart + cols;

            // sum_j s[j] * g[j], shared by every entry of this row.
            double weightedUpstream = 0.0;
            for (int idx = rowStart; idx < rowEnd; idx++) {
                weightedUpstream += softmax.dataAt(idx) * upstream.dataAt(idx);
            }

            for (int idx = rowStart; idx < rowEnd; idx++) {
                gradient.setDataAt(idx, softmax.dataAt(idx) * (upstream.dataAt(idx) - weightedUpstream));
            }
        }
        return gradient;
    }
}
