package org.neo4j.gds.ml.core.tensor;

import java.util.Arrays;
import java.util.function.DoubleUnaryOperator;
import org.neo4j.gds.mem.Estimate;
import org.neo4j.gds.ml.core.Dimensions;
import org.neo4j.gds.ml.core.tensor.Tensor;

/* loaded from: input_file:org/neo4j/gds/ml/core/tensor/Tensor.class */
/**
 * Base class for dense tensors backed by a flat {@code double[]} and an {@code int[]} of dimension sizes.
 *
 * <p>Both arrays are stored by reference (no defensive copies), and {@link #data()} / {@link #dimensions()}
 * expose them directly, so mutations through those arrays are visible to this tensor.
 *
 * @param <SELF> the concrete tensor subtype, returned by the copying/creating operations
 */
public abstract class Tensor<SELF extends Tensor<SELF>> {
    protected final double[] data;
    protected final int[] dimensions;

    /**
     * Wraps the given arrays without copying them.
     *
     * @param data       the flat element storage
     * @param dimensions the size of each dimension
     */
    public Tensor(double[] data, int[] dimensions) {
        this.data = data;
        this.dimensions = dimensions;
    }

    /** Returns {@link #shortDescription()} followed by the flat data contents. */
    @Override
    public String toString() {
        return shortDescription() + ": " + Arrays.toString(this.data);
    }

    /** A short, subtype-specific description used as the prefix of {@link #toString()}. */
    protected abstract String shortDescription();

    /**
     * Creates a new tensor of the same concrete type and dimensions.
     * NOTE(review): presumably zero-filled; {@link #map} and {@link #elementwiseProduct} overwrite
     * every element they touch, so they do not depend on the initial contents.
     */
    public abstract SELF createWithSameDimensions();

    /**
     * Returns a copy of this tensor. {@link #scalarMultiply} mutates the returned copy, so
     * implementations must not share the {@code data} array with this instance.
     */
    public abstract SELF copy();

    /**
     * Adds {@code other} to this tensor.
     * NOTE(review): whether this returns a new tensor or mutates is defined by subclasses — confirm.
     */
    public abstract SELF add(SELF other);

    /** Returns the size of dimension {@code dimensionIndex}. */
    public int dimension(int dimensionIndex) {
        return this.dimensions[dimensionIndex];
    }

    /** Returns the internal dimensions array (not a copy). */
    public int[] dimensions() {
        return this.dimensions;
    }

    /** Returns the internal flat data array (not a copy). */
    public double[] data() {
        return this.data;
    }

    /** Returns the element at flat index {@code index}. */
    public double dataAt(int index) {
        return this.data[index];
    }

    /** Overwrites the element at flat index {@code index} with {@code value}. */
    public void setDataAt(int index, double value) {
        this.data[index] = value;
    }

    /** Adds {@code value} to the element at flat index {@code index}. */
    public void addDataAt(int index, double value) {
        this.data[index] += value;
    }

    /**
     * Returns a new tensor with the same dimensions whose elements are {@code f} applied to this
     * tensor's elements. This tensor is not modified.
     */
    public SELF map(DoubleUnaryOperator f) {
        SELF result = createWithSameDimensions();
        Arrays.setAll(result.data, i -> f.applyAsDouble(this.data[i]));
        return result;
    }

    /** Replaces every element with {@code f} applied to it and returns {@code this}. */
    public Tensor<SELF> mapInPlace(DoubleUnaryOperator f) {
        Arrays.setAll(this.data, i -> f.applyAsDouble(this.data[i]));
        return this;
    }

    /**
     * Adds {@code other}'s elements to this tensor's, element by element, over the first
     * {@link #totalSize()} flat indices. Shapes are not checked; {@code other} must have at least
     * that many elements.
     */
    public void addInPlace(Tensor<?> other) {
        int size = totalSize();
        for (int i = 0; i < size; i++) {
            this.data[i] += other.data[i];
        }
    }

    /** Multiplies the first {@link #totalSize()} elements by {@code factor} in place and returns {@code this}. */
    public Tensor<SELF> scalarMultiplyMutate(double factor) {
        int size = totalSize();
        for (int i = 0; i < size; i++) {
            this.data[i] *= factor;
        }
        return this;
    }

    /** Returns a scaled copy of this tensor; this tensor is not modified. */
    public SELF scalarMultiply(double factor) {
        SELF result = copy();
        result.scalarMultiplyMutate(factor);
        return result;
    }

    /** Returns the element count as computed by {@code Dimensions.totalSize} over the dimensions. */
    public int totalSize() {
        return Dimensions.totalSize(this.dimensions);
    }

    /**
     * Returns a new tensor holding the element-wise (Hadamard) product of this tensor and {@code other}.
     * Shapes are not checked; {@code other} must have at least as many elements as this tensor.
     */
    public SELF elementwiseProduct(Tensor<?> other) {
        SELF result = createWithSameDimensions();
        for (int i = 0; i < this.data.length; i++) {
            result.data[i] = this.data[i] * other.data[i];
        }
        return result;
    }

    /** Multiplies this tensor element-wise by {@code other} in place and returns {@code this}. */
    public Tensor<SELF> elementwiseProductMutate(Tensor<?> other) {
        for (int i = 0; i < this.data.length; i++) {
            this.data[i] *= other.data[i];
        }
        return this;
    }

    /** Returns the sum of all elements in the data array. */
    public double aggregateSum() {
        double sum = 0.0d;
        for (double value : this.data) {
            sum += value;
        }
        return sum;
    }

    /**
     * Two tensors are equal when they have the same concrete class and the same dimensions, and their
     * elements match within an absolute tolerance of {@code 1e-32} (see {@link #equals(Tensor, double)}).
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        return equals((Tensor<?>) obj, 1.0E-32d);
    }

    /**
     * Hashes only the dimensions.
     *
     * <p>{@link #equals(Object)} compares elements with a tolerance, so tensors whose data differ
     * slightly (or by sign of zero, e.g. {@code 0.0} vs {@code -0.0}) can be equal. Hashing the data
     * would then give equal tensors different hash codes, which violates the {@code hashCode} contract.
     * Dimensions are compared exactly, so hashing them alone is consistent with {@code equals}.
     */
    @Override
    public int hashCode() {
        return Arrays.hashCode(this.dimensions);
    }

    /**
     * Estimated memory footprint of the data array of a tensor with the given dimensions, as computed
     * by {@code Estimate.sizeOfDoubleArray}.
     */
    public static long sizeInBytes(int[] dimensions) {
        return Estimate.sizeOfDoubleArray(Dimensions.totalSize(dimensions));
    }

    /**
     * Compares this tensor with {@code other} using an absolute per-element tolerance.
     *
     * <p>Returns {@code false} if {@code other} is {@code null}, or if the dimensions or data lengths
     * differ. Two elements match when they are identical per {@link Double#compare} (so {@code NaN}
     * matches {@code NaN}), or when their absolute difference is at most {@code tolerance}. A {@code NaN}
     * on only one side never matches.
     *
     * @param other     the tensor to compare against
     * @param tolerance the maximum allowed absolute difference per element
     * @return whether the tensors are equal within {@code tolerance}
     */
    public boolean equals(Tensor<?> other, double tolerance) {
        if (other == null) {
            return false;
        }
        if (!Arrays.equals(this.dimensions, other.dimensions)) {
            return false;
        }
        if (this.data.length != other.data.length) {
            return false;
        }
        for (int i = 0; i < other.data.length; i++) {
            double mine = this.data[i];
            double theirs = other.data[i];
            // Written as !(diff <= tolerance) so that a NaN difference counts as a mismatch;
            // the previous `diff > tolerance` form silently accepted NaN against any value.
            if (Double.compare(mine, theirs) != 0 && !(Math.abs(mine - theirs) <= tolerance)) {
                return false;
            }
        }
        return true;
    }
}
