package org.neo4j.gds.ml.core.functions;

import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.tensor.Tensor;
import org.neo4j.gds.utils.StringFormatting;

/* loaded from: input_file:org/neo4j/gds/ml/core/functions/ConstantScale.class */
/**
 * Computation-graph node that multiplies the value of its single parent by a fixed scalar.
 *
 * <p>The node has the same dimensions as its parent. Because the factor is a constant, the
 * gradient propagated back to the parent is this node's gradient multiplied by that same factor.
 *
 * @param <T> tensor type of both the parent and this node
 */
public class ConstantScale<T extends Tensor<T>> extends SingleParentVariable<T, T> {

    /** Scalar factor applied in both the forward and the backward pass. */
    private final double constant;

    /**
     * Creates a node that scales {@code parentVariable} by {@code factor}.
     *
     * @param parentVariable the variable whose value is scaled; its dimensions are reused
     * @param factor         the scalar multiplier
     */
    public ConstantScale(Variable<T> parentVariable, double factor) {
        super(parentVariable, parentVariable.dimensions());
        constant = factor;
    }

    /**
     * Forward pass: returns the parent's data scaled by the constant.
     *
     * @param ctx context holding the parent's computed data
     * @return the parent's tensor multiplied by the constant
     */
    @Override
    public T apply(ComputationContext ctx) {
        // Cast retained from the original; presumably scalarMultiply already yields T — TODO confirm.
        return (T) ctx.data(parent).scalarMultiply(constant);
    }

    /**
     * Backward pass: the derivative of {@code c * x} with respect to {@code x} is {@code c},
     * so the parent's gradient is this node's gradient scaled by the same constant.
     *
     * @param ctx context holding this node's gradient
     * @return this node's gradient multiplied by the constant
     */
    @Override
    protected T gradientForParent(ComputationContext ctx) {
        return (T) ctx.gradient(this).scalarMultiply(constant);
    }

    /**
     * Describes the node by its simple class name, the scale factor, and whether it requires
     * a gradient.
     */
    @Override
    public String toString() {
        // %s on a boxed Double delegates to Double.toString, so no explicit conversion is needed.
        return StringFormatting.formatWithLocale(
            "%s: scale by %s, requireGradient: %b",
            getClass().getSimpleName(),
            constant,
            requireGradient()
        );
    }
}
