package org.neo4j.gds.ml.core.functions;

import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.tensor.Tensor;

/* loaded from: input_file:org/neo4j/gds/ml/core/functions/Relu.class */
/**
 * Leaky rectified linear unit, applied element-wise to the tensor produced by the parent variable.
 *
 * <p>Forward pass: each element {@code x} becomes {@code x} when {@code x > 0}, and
 * {@code alpha * x} otherwise.
 *
 * <p>Backward pass: the local derivative is {@code 1} for strictly positive inputs and
 * {@code alpha} otherwise (so {@code x == 0} uses {@code alpha}). It is multiplied element-wise by
 * the gradient flowing into this variable.
 *
 * @param <T> tensor type of both the input and the output
 */
public class Relu<T extends Tensor<T>> extends SingleParentVariable<T, T> {

    /** Negative-side slope used by the single-argument constructor. */
    private static final double DEFAULT_ALPHA = 0.01d;

    /** Slope applied to inputs that are not strictly positive. */
    private final double alpha;

    /**
     * Creates a leaky ReLU node with an explicit negative-side slope.
     *
     * @param parent the input variable; this node reports the same dimensions
     * @param alpha  slope applied to inputs that are not strictly positive
     */
    public Relu(Variable<T> parent, double alpha) {
        super(parent, parent.dimensions());
        this.alpha = alpha;
    }

    /**
     * Creates a leaky ReLU node whose negative-side slope is {@code 0.01}.
     *
     * @param parent the input variable; this node reports the same dimensions
     */
    public Relu(Variable<T> parent) {
        this(parent, DEFAULT_ALPHA);
    }

    /**
     * Evaluates the activation on the parent's current value.
     *
     * @param computationContext context holding the parent's computed data
     * @return a new tensor with the activation applied element-wise
     */
    @Override
    public T apply(ComputationContext computationContext) {
        return (T) computationContext.data(this.parent).map(this::activate);
    }

    /**
     * Computes the gradient with respect to the parent. The derivative tensor is built from the
     * parent's data and then scaled in place by this node's incoming gradient.
     *
     * @param computationContext context holding the parent's data and this node's gradient
     * @return the gradient to propagate to the parent
     */
    @Override
    public T gradientForParent(ComputationContext computationContext) {
        T localDerivative = (T) computationContext.data(this.parent).map(this::slopeAt);
        // Chain rule: local derivative times upstream gradient, done in place on the fresh tensor.
        localDerivative.elementwiseProductMutate(computationContext.gradient(this));
        return localDerivative;
    }

    /** Leaky ReLU for a single scalar. */
    private double activate(double x) {
        if (x > 0.0d) {
            return x;
        }
        return this.alpha * x;
    }

    /** Derivative of the leaky ReLU at a single scalar; {@code alpha} is used at zero. */
    private double slopeAt(double x) {
        return x > 0.0d ? 1.0d : this.alpha;
    }
}
