package org.neo4j.gds.ml.core.functions;

import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Scalar;
import org.neo4j.gds.ml.core.tensor.Vector;

/* loaded from: input_file:org/neo4j/gds/ml/core/functions/ReducedFocalLoss.class */
/**
 * Focal-loss flavour of {@link ReducedCrossEntropyLoss}.
 *
 * <p>The two per-example hooks of the parent class are overridden so that the
 * log-probability term is modulated by {@code (1 - p)^focusWeight}, where
 * {@code p} is the probability argument passed to the hook. The class weights
 * ({@code classWeights}) are inherited from the parent.
 */
public class ReducedFocalLoss extends ReducedCrossEntropyLoss {
    /** Exponent applied to {@code (1 - p)} in the focal modulation factor. */
    private final double focusWeight;

    /**
     * Creates the loss by forwarding every argument except {@code d} to the
     * parent constructor, in the same order, and remembering {@code d} as the
     * focus weight.
     *
     * <p>NOTE(review): the roles of the forwarded arguments are defined by
     * {@link ReducedCrossEntropyLoss} and are not visible from this class.
     *
     * @param variable  first argument of the parent constructor
     * @param variable2 second argument of the parent constructor
     * @param weights   third argument of the parent constructor
     * @param variable3 fourth argument of the parent constructor
     * @param variable4 fifth argument of the parent constructor
     * @param d         focus weight, i.e. the exponent of the modulation factor
     * @param dArr      sixth argument of the parent constructor
     *                  (presumably the per-class weights; confirm in the parent)
     */
    public ReducedFocalLoss(Variable<Matrix> variable, Variable<Matrix> variable2, Weights<Vector> weights, Variable<Matrix> variable3, Variable<Vector> variable4, double d, double[] dArr) {
        super(variable, variable2, weights, variable3, variable4, dArr);
        focusWeight = d;
    }

    /**
     * Reports the memory footprint of this function's result.
     *
     * @return whatever {@link Scalar#sizeInBytes()} reports
     */
    public static long sizeInBytes() {
        return Scalar.sizeInBytes();
    }

    /**
     * Computes {@code classWeights[i] * (1 - d)^focusWeight * ln(d)}.
     *
     * <p>No sign flip is applied here; any negation or averaging is left to
     * the caller.
     *
     * @param d probability fed into the logarithm (presumably the predicted
     *          probability of the true class; confirm against the parent)
     * @param i index into {@code classWeights} (presumably the true class)
     * @return the weighted, focally modulated log-probability
     */
    @Override
    double computeIndividualLoss(double d, int i) {
        // Evaluation order is kept as ((weight * factor) * log) so that the
        // floating-point result is unchanged.
        double modulation = Math.pow(1.0d - d, focusWeight);
        double weightedModulation = classWeights[i] * modulation;
        return weightedModulation * Math.log(d);
    }

    /**
     * Computes the per-example error term
     * <pre>
     * classWeights[i2]
     *   * (focusWeight * (1 - d3)^(focusWeight - 1) * ln(d3)
     *      - (1 - d3)^(focusWeight - 1) * (1 - d3) / d3)
     *   * d3 * (d2 - d)
     *   / i
     * </pre>
     *
     * <p>NOTE(review): the expression evaluates to {@code NaN} when
     * {@code d3 == 0} (infinite bracket times zero) and when {@code d3 == 1}
     * with {@code focusWeight < 1} (infinite power times {@code ln(1) == 0}).
     * Confirm that the caller keeps {@code d3} away from those values.
     *
     * @param i  divisor of the final result (presumably the number of
     *           examples; confirm against the parent)
     * @param d  value subtracted from {@code d2} (presumably the predicted
     *           probability of the class being differentiated)
     * @param d2 value {@code d} is subtracted from (presumably a 0/1 indicator
     *           for the true class)
     * @param d3 probability used in the logarithm and the power (presumably
     *           the predicted probability of the true class)
     * @param i2 index into {@code classWeights} (presumably the true class)
     * @return the error contribution described above
     */
    @Override
    double computeErrorPerExample(int i, double d, double d2, double d3, int i2) {
        double complement = 1.0d - d3;
        double reducedPower = Math.pow(complement, focusWeight - 1.0d);

        // The grouping below mirrors the original parenthesisation exactly so
        // that rounding behaviour is preserved.
        double logPart = (focusWeight * reducedPower) * Math.log(d3);
        double ratioPart = (reducedPower * complement) / d3;
        double weightedBracket = classWeights[i2] * (logPart - ratioPart);
        double probabilityDelta = d3 * (d2 - d);

        return (weightedBracket * probabilityDelta) / i;
    }
}
