package org.neo4j.gds.ml.core.functions;

import com.neo4j.gds.shaded.org.jetbrains.annotations.TestOnly;
import java.util.List;
import org.neo4j.gds.ml.core.AbstractVariable;
import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Scalar;
import org.neo4j.gds.ml.core.tensor.Tensor;
import org.neo4j.gds.ml.core.tensor.Vector;
import org.neo4j.gds.utils.StringFormatting;

/* loaded from: input_file:org/neo4j/gds/ml/core/functions/Constant.class */
/**
 * A variable that wraps a fixed tensor value.
 *
 * <p>The wrapped tensor is handed back as-is by {@link #apply(ComputationContext)};
 * requesting a gradient always fails with
 * {@link AbstractVariable.NotAFunctionException}.
 *
 * @param <T> the concrete tensor type held by this constant
 */
public class Constant<T extends Tensor<T>> extends AbstractVariable<T> {
    /** The tensor this constant evaluates to. */
    private final T data;

    /**
     * Wraps the given tensor. The superclass is initialised with an empty list
     * and the tensor's own dimensions.
     *
     * @param t the tensor to wrap; a {@code null} value fails immediately with a
     *          {@link NullPointerException} because its dimensions are read here
     */
    public Constant(T t) {
        super(List.of(), t.dimensions());
        data = t;
    }

    /**
     * Creates a constant backed by a new {@link Scalar}.
     *
     * @param d the value passed to the {@link Scalar} constructor
     * @return a constant wrapping the new scalar
     */
    public static Constant<Scalar> scalar(double d) {
        Scalar wrapped = new Scalar(d);
        return new Constant<>(wrapped);
    }

    /**
     * Creates a constant backed by a new {@link Vector}.
     *
     * @param dArr the array passed to the {@link Vector} constructor
     * @return a constant wrapping the new vector
     */
    public static Constant<Vector> vector(double[] dArr) {
        Vector wrapped = new Vector(dArr);
        return new Constant<>(wrapped);
    }

    /**
     * Creates a constant backed by a new {@link Matrix}.
     *
     * @param dArr the array passed to the {@link Matrix} constructor
     * @param i    second {@link Matrix} constructor argument (presumably the row count; verify against Matrix)
     * @param i2   third {@link Matrix} constructor argument (presumably the column count; verify against Matrix)
     * @return a constant wrapping the new matrix
     */
    public static Constant<Matrix> matrix(double[] dArr, int i, int i2) {
        Matrix wrapped = new Matrix(dArr, i, i2);
        return new Constant<>(wrapped);
    }

    /**
     * Exposes the wrapped tensor for tests.
     *
     * @return the same tensor instance that was passed to the constructor (no copy is made)
     */
    @TestOnly
    public T data() {
        return data;
    }

    /**
     * Evaluates this constant; the context is not consulted.
     *
     * @param computationContext unused
     * @return the wrapped tensor instance
     */
    @Override
    public T apply(ComputationContext computationContext) {
        return data;
    }

    /**
     * Always fails: this implementation never computes a gradient.
     *
     * @param variable           unused
     * @param computationContext unused
     * @return never returns normally
     * @throws AbstractVariable.NotAFunctionException on every call
     */
    @Override
    public T gradient(Variable<?> variable, ComputationContext computationContext) {
        throw new AbstractVariable.NotAFunctionException();
    }

    /**
     * Delegates to {@link Tensor#sizeInBytes(int[])}.
     *
     * @param iArr the argument forwarded unchanged (presumably tensor dimensions; verify against Tensor)
     * @return whatever {@link Tensor#sizeInBytes(int[])} reports for {@code iArr}
     */
    public static long sizeInBytes(int[] iArr) {
        return Tensor.sizeInBytes(iArr);
    }

    /**
     * Renders the runtime class's simple name, the wrapped tensor's string form
     * and the {@code requireGradient()} flag.
     *
     * @return the formatted description
     */
    @Override
    public String toString() {
        String typeName = getClass().getSimpleName();
        // Rendered explicitly so the formatter receives a plain String.
        String renderedData = data.toString();
        boolean needsGradient = requireGradient();
        return StringFormatting.formatWithLocale(
            "%s: %s, requireGradient: %b",
            typeName,
            renderedData,
            needsGradient
        );
    }
}
