package org.neo4j.gds.ml.core.tensor;

import com.neo4j.gds.shaded.org.ejml.data.DMatrixRMaj;
import com.neo4j.gds.shaded.org.ejml.dense.row.mult.MatrixMatrixMult_DDRM;
import java.util.Arrays;
import java.util.function.DoubleUnaryOperator;
import org.neo4j.gds.collections.ArrayUtil;
import org.neo4j.gds.mem.Estimate;
import org.neo4j.gds.ml.core.Dimensions;
import org.neo4j.gds.utils.StringFormatting;

/* loaded from: input_file:org/neo4j/gds/ml/core/tensor/Matrix.class */
/**
 * A dense two-dimensional tensor.
 *
 * <p>Entries are stored row-major in the inherited {@code data} array: entry {@code (row, col)}
 * lives at index {@code row * cols() + col}.
 */
public class Matrix extends Tensor<Matrix> {
    private final int columns;
    private final int rows;

    /**
     * Estimates the memory needed for the backing array of a {@code rows x columns} matrix.
     *
     * @param rows    number of rows
     * @param columns number of columns
     * @return estimated size in bytes, as computed by {@link Estimate#sizeOfDoubleArray}
     * @throws ArithmeticException if {@code rows * columns} overflows an {@code int}
     */
    public static long sizeInBytes(int rows, int columns) {
        // Mirrors the allocating constructor: an unchecked int overflow here would silently
        // produce a meaningless (possibly negative) estimate.
        return Estimate.sizeOfDoubleArray(Math.multiplyExact(rows, columns));
    }

    /**
     * Creates a matrix backed by the given row-major array.
     *
     * <p>The array is handed to {@link Tensor} as-is (presumably not copied — confirm in
     * {@code Tensor}). Its length is expected to equal {@code rows * columns}.
     *
     * @param values  row-major entries
     * @param rows    number of rows
     * @param columns number of columns
     */
    public Matrix(double[] values, int rows, int columns) {
        super(values, Dimensions.matrix(rows, columns));
        this.rows = rows;
        this.columns = columns;
    }

    /**
     * Creates a matrix from an EJML row-major matrix, reusing its backing array when possible.
     *
     * @param dMatrixRMaj source matrix
     * @return a matrix with the same shape and entries
     */
    public static Matrix of(DMatrixRMaj dMatrixRMaj) {
        int numRows = dMatrixRMaj.numRows;
        int numCols = dMatrixRMaj.numCols;
        double[] values = dMatrixRMaj.data;
        // EJML permits the backing array to be longer than numRows * numCols (e.g. after a
        // shrinking reshape). Trim it so this tensor's data length matches its dimensions;
        // in the common exact-size case the array is reused without copying.
        int size = numRows * numCols;
        if (values.length != size) {
            values = Arrays.copyOf(values, size);
        }
        return new Matrix(values, numRows, numCols);
    }

    /**
     * Creates a zero-filled matrix.
     *
     * @param rows    number of rows
     * @param columns number of columns
     * @throws ArithmeticException if {@code rows * columns} overflows an {@code int}
     */
    public Matrix(int rows, int columns) {
        this(new double[Math.multiplyExact(rows, columns)], rows, columns);
    }

    /**
     * Creates a matrix whose every entry is {@code value}.
     *
     * @param value   fill value
     * @param rows    number of rows
     * @param columns number of columns
     * @return the filled matrix
     * @throws ArithmeticException if {@code rows * columns} overflows an {@code int}
     */
    public static Matrix create(double value, int rows, int columns) {
        return new Matrix(ArrayUtil.fill(value, Math.multiplyExact(rows, columns)), rows, columns);
    }

    /** Returns the entry at {@code (row, col)}. */
    public double dataAt(int row, int col) {
        return dataAt(row * this.columns + col);
    }

    /** Overwrites the entry at {@code (row, col)} with {@code value}. */
    public void setDataAt(int row, int col, double value) {
        setDataAt(row * this.columns + col, value);
    }

    /**
     * Overwrites row {@code row} with a copy of {@code values}.
     *
     * @param row    target row index
     * @param values new row entries; must have exactly {@link #cols()} elements
     * @throws IllegalArgumentException if {@code values.length != cols()}
     */
    public void setRow(int row, double[] values) {
        if (values.length != this.columns) {
            throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                "Input vector dimension is unequal to column count of the matrix. Got %d, but expected %d.",
                values.length,
                this.columns
            ));
        }
        System.arraycopy(values, 0, this.data, row * this.columns, this.columns);
    }

    /** Returns a copy of row {@code row}. */
    public double[] getRow(int row) {
        int start = row * this.columns;
        return Arrays.copyOfRange(this.data, start, start + this.columns);
    }

    /** Adds {@code delta} to the entry at {@code (row, col)}. */
    public void addDataAt(int row, int col, double delta) {
        updateDataAt(row, col, current -> current + delta);
    }

    /** Replaces the entry at {@code (row, col)} with {@code updater} applied to its current value. */
    public void updateDataAt(int row, int col, DoubleUnaryOperator updater) {
        setDataAt(row, col, updater.applyAsDouble(dataAt(row, col)));
    }

    /** Returns a new zero-filled matrix with this matrix's shape. */
    @Override
    public Matrix createWithSameDimensions() {
        return new Matrix(rows(), cols());
    }

    /** Returns a matrix with the same shape backed by a copy of this matrix's data. */
    @Override
    public Matrix copy() {
        return new Matrix(this.data.clone(), rows(), cols());
    }

    /**
     * Returns the element-wise sum {@code this + matrix} as a new matrix.
     *
     * @throws ArithmeticException if the shapes differ
     */
    @Override
    public Matrix add(Matrix matrix) {
        if (rows() != matrix.rows() || cols() != matrix.cols()) {
            throw new ArithmeticException(StringFormatting.formatWithLocale(
                "Matrix dimensions must match! Got dimensions (%d, %d) + (%d, %d)",
                rows(),
                cols(),
                matrix.rows(),
                matrix.cols()
            ));
        }
        Matrix sum = createWithSameDimensions();
        for (int k = 0; k < this.data.length; k++) {
            sum.data[k] = this.data[k] + matrix.data[k];
        }
        return sum;
    }

    /** Returns the matrix product {@code this * matrix}. */
    public Matrix multiply(Matrix matrix) {
        DMatrixRMaj result = new DMatrixRMaj(this.rows, matrix.cols());
        MatrixMatrixMult_DDRM.mult_reorder(toEjml(), matrix.toEjml(), result);
        return of(result);
    }

    /** Returns the matrix product {@code this * transpose(matrix)}. */
    public Matrix multiplyTransB(Matrix matrix) {
        DMatrixRMaj result = new DMatrixRMaj(this.rows, matrix.rows);
        MatrixMatrixMult_DDRM.multTransB(toEjml(), matrix.toEjml(), result);
        return of(result);
    }

    /** Returns the matrix product {@code transpose(this) * matrix}. */
    public Matrix multiplyTransA(Matrix matrix) {
        DMatrixRMaj result = new DMatrixRMaj(cols(), matrix.cols());
        MatrixMatrixMult_DDRM.multTransA_reorder(toEjml(), matrix.toEjml(), result);
        return of(result);
    }

    /**
     * Returns a new matrix where {@code vector[col]} is added to every entry of column {@code col}.
     *
     * <p>Only the first {@link #cols()} entries of {@code vector} are read.
     */
    public Matrix sumBroadcastColumnWise(Vector vector) {
        Matrix result = createWithSameDimensions();
        for (int row = 0; row < this.rows; row++) {
            int offset = row * this.columns;
            for (int col = 0; col < this.columns; col++) {
                result.data[offset + col] = this.data[offset + col] + vector.data[col];
            }
        }
        return result;
    }

    /** Returns a vector holding the sum of each column. */
    public Vector sumPerColumn() {
        double[] sums = new double[this.columns];
        // Traverse in storage (row-major) order for sequential memory access. Each column is
        // still accumulated over rows 0..rows-1 in order, so results are bit-identical to a
        // column-by-column traversal.
        for (int row = 0; row < this.rows; row++) {
            int offset = row * this.columns;
            for (int col = 0; col < this.columns; col++) {
                sums[col] += this.data[offset + col];
            }
        }
        return new Vector(sums);
    }

    /**
     * Copies row {@code sourceRow} of {@code matrix} into row {@code row} of this matrix.
     *
     * @throws ArithmeticException if the column counts differ
     */
    public void setRow(int row, Matrix matrix, int sourceRow) {
        if (matrix.columns != this.columns) {
            throw new ArithmeticException(StringFormatting.formatWithLocale(
                "Input matrix must have the same number of columns. Expected %s, but got %s.",
                this.columns,
                matrix.columns
            ));
        }
        System.arraycopy(matrix.data, sourceRow * matrix.columns, this.data, row * this.columns, matrix.columns);
    }

    @Override
    public String shortDescription() {
        return StringFormatting.formatWithLocale("Matrix(%d, %d)", rows(), cols());
    }

    /** Returns the number of rows. */
    public int rows() {
        return this.rows;
    }

    /** Returns the number of columns. */
    public int cols() {
        return this.columns;
    }

    /** Returns whether these dimensions are considered a vector by {@link Dimensions#isVector}. */
    public boolean isVector() {
        return Dimensions.isVector(this.dimensions);
    }

    /** Returns an EJML view that wraps (does not copy) this matrix's backing array. */
    public DMatrixRMaj toEjml() {
        return DMatrixRMaj.wrap(dimension(0), dimension(1), data());
    }
}
