package org.neo4j.gds.ml.core.functions;

import java.util.List;
import org.neo4j.gds.ml.core.AbstractVariable;
import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Dimensions;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Scalar;
import org.neo4j.gds.ml.core.tensor.Tensor;
import org.neo4j.gds.ml.core.tensor.Vector;

/* loaded from: input_file:org/neo4j/gds/ml/core/functions/CrossEntropyLoss.class */
/**
 * Class-weighted cross-entropy loss over a batch of predictions.
 *
 * <p>For a batch of {@code n} examples (where {@code n} is the number of prediction rows) the value is
 * {@code -(1/n) * sum_i classWeights[t_i] * ln(p[i, t_i])}. Here {@code t_i} is the target class of
 * example {@code i} and {@code p[i, t_i]} is the predicted entry for that class. Rows whose predicted
 * entry is not strictly positive are skipped in the sum but still count towards {@code n}.
 *
 * <p>NOTE(review): {@code predictions} is presumably a row-per-example matrix of class probabilities
 * (e.g. a softmax output); this class does not validate that.
 */
public class CrossEntropyLoss extends AbstractVariable<Scalar> {
    /**
     * Predicted entries at or below this value get no gradient. Because the gradient divides by the
     * predicted entry, this bounds its magnitude.
     */
    private static final double PREDICTED_PROBABILITY_THRESHOLD = 1.0E-50d;

    private final Variable<Matrix> predictions;
    private final Variable<Vector> targets;
    /** Weight of each class' loss term, indexed by target class. Used as given (not copied). */
    protected final double[] classWeights;

    /**
     * Creates the loss node with {@code predictions} and {@code targets} as its parents.
     *
     * @param predictions  variable producing the prediction matrix, one row per example
     * @param targets      variable producing each example's target class index (stored as a double,
     *                     truncated to {@code int} when used)
     * @param classWeights per-class weights, indexed by class; the array is stored, not copied
     */
    public CrossEntropyLoss(Variable<Matrix> predictions, Variable<Vector> targets, double[] classWeights) {
        super(List.of(predictions, targets), Dimensions.scalar());
        this.predictions = predictions;
        this.targets = targets;
        this.classWeights = classWeights;
    }

    /** Returns the size estimate of this variable's output, which is a single {@link Scalar}. */
    public static long sizeInBytes() {
        return Scalar.sizeInBytes();
    }

    /**
     * Computes the mean weighted negative log of the true-class predicted entries.
     *
     * @param computationContext context holding the already-computed parent data
     * @return the loss as a scalar
     */
    @Override // org.neo4j.gds.ml.core.Variable
    public Scalar apply(ComputationContext computationContext) {
        Matrix predictionsMatrix = (Matrix) computationContext.data(this.predictions);
        Vector targetsVector = (Vector) computationContext.data(this.targets);

        double weightedLogSum = 0.0d;
        for (int row = 0; row < targetsVector.totalSize(); row++) {
            int trueClass = (int) targetsVector.dataAt(row);
            double predictedForTrueClass = predictionsMatrix.dataAt(row, trueClass);
            // ln(p) is undefined for p <= 0, so such rows contribute nothing to the sum.
            if (predictedForTrueClass > 0.0d) {
                weightedLogSum += computeIndividualLoss(predictedForTrueClass, trueClass);
            }
        }
        return new Scalar(-weightedLogSum / predictionsMatrix.rows());
    }

    /**
     * Per-example loss term before negation and averaging. Package-private and non-final,
     * presumably so subclasses can replace it.
     *
     * @param predictedForTrueClass predicted entry for the example's true class ({@code > 0})
     * @param trueClass             the example's target class index
     * @return {@code classWeights[trueClass] * ln(predictedForTrueClass)}
     */
    double computeIndividualLoss(double predictedForTrueClass, int trueClass) {
        return this.classWeights[trueClass] * Math.log(predictedForTrueClass);
    }

    /**
     * Gradient of the loss with respect to the prediction matrix.
     *
     * <p>Only the true-class entry of each row is written. Its value is the upstream gradient times
     * {@link #computeErrorPerExample}. Entries whose predicted value is at or below
     * {@link #PREDICTED_PROBABILITY_THRESHOLD} are left as produced by {@code createWithSameDimensions()}
     * (presumably zero; TODO confirm).
     *
     * @param variable           the parent to differentiate with respect to; must be {@code predictions}
     * @param computationContext context holding parent data and this node's upstream gradient
     * @return a matrix with the same dimensions as the predictions
     * @throws IllegalStateException if {@code variable} is not the predictions variable
     */
    @Override // org.neo4j.gds.ml.core.Variable
    public Tensor<?> gradient(Variable<?> variable, ComputationContext computationContext) {
        if (variable != this.predictions) {
            // Render the variable actually received; the original always rendered the targets.
            String received = variable == null ? "null" : variable.render();
            throw new IllegalStateException(
                "The gradient should not be necessary for the targets. But got: " + received);
        }
        Matrix predictionsMatrix = (Matrix) computationContext.data(this.predictions);
        Matrix result = predictionsMatrix.createWithSameDimensions();
        Vector targetsVector = (Vector) computationContext.data(this.targets);
        double upstreamGradient = ((Scalar) computationContext.gradient(this)).value();

        int numberOfExamples = result.rows();
        for (int row = 0; row < numberOfExamples; row++) {
            int trueClass = (int) targetsVector.dataAt(row);
            // Same (row, col) accessor as apply() and as the setDataAt() below.
            double predictedForTrueClass = predictionsMatrix.dataAt(row, trueClass);
            // Skip near-zero entries: the error term divides by this value.
            if (predictedForTrueClass > PREDICTED_PROBABILITY_THRESHOLD) {
                result.setDataAt(
                    row,
                    trueClass,
                    upstreamGradient * computeErrorPerExample(numberOfExamples, predictedForTrueClass, trueClass)
                );
            }
        }
        return result;
    }

    /**
     * Per-example derivative of the averaged loss with respect to the true-class predicted entry.
     * Package-private and non-final, presumably so subclasses can replace it.
     *
     * @param numberOfExamples      batch size used for averaging
     * @param predictedForTrueClass predicted entry for the example's true class
     * @param trueClass             the example's target class index
     * @return {@code -classWeights[trueClass] / (predictedForTrueClass * numberOfExamples)}
     */
    double computeErrorPerExample(int numberOfExamples, double predictedForTrueClass, int trueClass) {
        return (-this.classWeights[trueClass]) / (predictedForTrueClass * numberOfExamples);
    }
}
