package org.neo4j.gds.ml.core.functions;

import java.util.List;
import org.neo4j.gds.ml.core.AbstractVariable;
import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.tensor.Tensor;

/* loaded from: input_file:org/neo4j/gds/ml/core/functions/SingleParentVariable.class */
/**
 * Base class for a {@link Variable} that has exactly one parent in the computation graph.
 *
 * <p>The single parent is registered with {@link AbstractVariable} as a one-element parent
 * list and is also kept in {@link #parent} for direct access by subclasses. Subclasses
 * implement {@link #gradientForParent(ComputationContext)}. {@link #gradient(Variable,
 * ComputationContext)} delegates to it after checking that the requested variable is this
 * node's parent.
 *
 * @param <P> tensor type of the parent variable (and of the value returned by
 *     {@link #gradientForParent(ComputationContext)})
 * @param <T> tensor type produced by this variable
 */
public abstract class SingleParentVariable<P extends Tensor<P>, T extends Tensor<T>> extends AbstractVariable<T> {
    /** The only parent of this variable; compared by reference in {@link #gradient}. */
    protected final Variable<P> parent;

    /**
     * Creates a variable with a single parent.
     *
     * @param parent the sole parent variable; {@code List.of} rejects {@code null}
     *     with a {@link NullPointerException}
     * @param dimensions forwarded unchanged to {@link AbstractVariable}'s constructor
     *     (presumably this variable's dimensions — TODO confirm against AbstractVariable)
     */
    public SingleParentVariable(Variable<P> parent, int[] dimensions) {
        super(List.of(parent), dimensions);
        this.parent = parent;
    }

    /**
     * Returns the gradient computed by {@link #gradientForParent(ComputationContext)},
     * provided {@code variable} is this node's parent.
     *
     * @param variable must be the exact instance held in {@link #parent}
     * @param computationContext context passed through to {@link #gradientForParent}
     * @return the tensor produced by {@link #gradientForParent(ComputationContext)}
     * @throws IllegalArgumentException if {@code variable} is not this node's parent
     */
    @Override // org.neo4j.gds.ml.core.Variable
    public Tensor<?> gradient(Variable<?> variable, ComputationContext computationContext) {
        validateParent(variable);
        return gradientForParent(computationContext);
    }

    /**
     * Computes this variable's contribution to the gradient with respect to {@link #parent}.
     *
     * @param computationContext the context supplied to {@link #gradient}
     * @return a tensor of the parent's tensor type {@code P}
     */
    protected abstract P gradientForParent(ComputationContext computationContext);

    /**
     * Ensures {@code variable} is the parent instance this node was constructed with.
     *
     * <p>Reference equality is intentional: the check asks whether it is the same graph
     * node, not an equal-valued one.
     *
     * @throws IllegalArgumentException if {@code variable} is not {@link #parent}
     */
    private void validateParent(Variable<?> variable) {
        if (variable != this.parent) {
            // IllegalArgumentException (a RuntimeException subclass) signals a bad argument
            // more precisely while staying compatible with callers catching RuntimeException.
            throw new IllegalArgumentException("Calling gradient with a `parent` that was not expected");
        }
    }
}
