package org.neo4j.gds.ml.core.functions;

import java.util.ArrayList;
import java.util.Map;
import java.util.stream.IntStream;
import org.neo4j.gds.NodeLabel;
import org.neo4j.gds.collections.ha.HugeObjectArray;
import org.neo4j.gds.ml.core.AbstractVariable;
import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Dimensions;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Tensor;
import org.neo4j.gds.ml.core.tensor.operations.DoubleMatrixOperations;

/* loaded from: input_file:org/neo4j/gds/ml/core/functions/LabelwiseFeatureProjection.class */
/**
 * Projects node features using a separate weight matrix per node label.
 *
 * <p>For the {@code i}-th entry of {@code nodeIds}, the weights registered for {@code labels[i]} are
 * combined with that node's feature vector. The result becomes row {@code i} of an output matrix
 * with {@code nodeIds.length} rows and {@code projectedFeatureDimension} columns.
 *
 * <p>Every weight matrix in {@code weightsByLabel} is registered as a parent of this variable.
 */
public class LabelwiseFeatureProjection extends AbstractVariable<Matrix> {
    private final long[] nodeIds;
    private final HugeObjectArray<double[]> features;
    private final Map<NodeLabel, Weights<Matrix>> weightsByLabel;
    private final int projectedFeatureDimension;
    private final NodeLabel[] labels;

    /**
     * @param nodeIds                   node ids to project; output row {@code i} belongs to {@code nodeIds[i]}
     * @param features                  feature vectors, looked up by node id
     * @param weightsByLabel            weight matrix per node label; all values become parents of this variable
     * @param projectedFeatureDimension number of columns of the output matrix
     * @param labels                    label per node, aligned by index with {@code nodeIds}
     *                                  (NOTE(review): assumes {@code labels.length >= nodeIds.length} — confirm callers)
     */
    public LabelwiseFeatureProjection(
        long[] nodeIds,
        HugeObjectArray<double[]> features,
        Map<NodeLabel, Weights<Matrix>> weightsByLabel,
        int projectedFeatureDimension,
        NodeLabel[] labels
    ) {
        // Typed copy (not a raw ArrayList) of all weight matrices: they are the parents of this variable.
        super(
            new ArrayList<>(weightsByLabel.values()),
            Dimensions.matrix(nodeIds.length, projectedFeatureDimension)
        );
        this.nodeIds = nodeIds;
        this.features = features;
        this.weightsByLabel = weightsByLabel;
        this.projectedFeatureDimension = projectedFeatureDimension;
        this.labels = labels;
    }

    /**
     * Computes the forward pass.
     *
     * <p>For each node, the weights of its label are multiplied with its features via
     * {@link DoubleMatrixOperations#multTransB}. The product is written as that node's output row.
     */
    @Override // org.neo4j.gds.ml.core.Variable
    public Matrix apply(ComputationContext computationContext) {
        Matrix result = new Matrix(nodeIds.length, projectedFeatureDimension);
        for (int nodeIndex = 0; nodeIndex < nodeIds.length; nodeIndex++) {
            Weights<Matrix> weights = weightsByLabel.get(labels[nodeIndex]);
            double[] nodeFeatures = features.get(nodeIds[nodeIndex]);

            // Wrap the feature vector as a 1 x featureCount matrix.
            // multTransB presumably multiplies by the transpose of this second operand,
            // giving a weights.dimension(0) x 1 column — confirm in DoubleMatrixOperations.
            Matrix featureRow = new Matrix(nodeFeatures, 1, nodeFeatures.length);
            Matrix projected = new Matrix(weights.dimension(0), 1);
            DoubleMatrixOperations.multTransB(
                weights.data(),
                featureRow,
                projected,
                // Index filter limiting computation to indices below projectedFeatureDimension
                // (exact filter semantics live in DoubleMatrixOperations — confirm).
                index -> index < projectedFeatureDimension
            );
            result.setRow(nodeIndex, projected.data());
        }
        return result;
    }

    /**
     * Computes the gradient of this projection with respect to one of its weight matrices.
     *
     * <p>Output row {@code n} depends on the weights of {@code labels[n]} only. The gradient for entry
     * {@code (row, col)} is therefore the sum, over nodes using {@code variable}, of
     * {@code upstream(n, row) * features(n)[col]}. Nodes projected with other weights contribute
     * nothing; if {@code variable} is used by no node, an all-zero matrix is returned.
     *
     * @param variable           the parent whose gradient is requested
     * @param computationContext context providing the upstream gradient of this variable
     * @return a matrix with the same dimensions as {@code variable}
     */
    @Override // org.neo4j.gds.ml.core.Variable
    public Tensor<?> gradient(Variable<?> variable, ComputationContext computationContext) {
        Matrix outputGradient = (Matrix) computationContext.gradient(this);
        int rows = variable.dimension(0);
        int cols = variable.dimension(1);
        double[] weightGradient = new double[rows * cols];

        IntStream.range(0, nodeIds.length).forEach(nodeIndex -> {
            // Reference comparison: only nodes whose label maps to exactly this weights object contribute.
            if (weightsByLabel.get(labels[nodeIndex]) != variable) {
                return;
            }
            double[] nodeFeatures = features.get(nodeIds[nodeIndex]);
            for (int row = 0; row < rows; row++) {
                // Upstream gradient of output entry (nodeIndex, row).
                // The decompiled version read (i, i), i.e. the diagonal, which is wrong.
                double upstream = outputGradient.dataAt(nodeIndex, row);
                for (int col = 0; col < cols; col++) {
                    weightGradient[row * cols + col] += nodeFeatures[col] * upstream;
                }
            }
        });
        return new Matrix(weightGradient, rows, cols);
    }
}
