package org.neo4j.gds.ml.core.functions;

import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Dimensions;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.tensor.Matrix;

/* loaded from: input_file:org/neo4j/gds/ml/core/functions/Slice.class */
/**
 * Row-selection ("gather") over a matrix-valued parent variable.
 *
 * <p>The forward pass builds a matrix with {@code batchIds.length} rows and the parent's column
 * count. Output row {@code i} is filled from parent row {@code batchIds[i]} via
 * {@link Matrix#setRow}.
 *
 * <p>The backward pass starts from a matrix shaped like the parent's data. It then adds each
 * incoming gradient row onto the parent row it was selected from, so an id that appears more than
 * once receives the sum of its gradient rows.
 *
 * <p>NOTE(review): the id array is kept by reference, not copied, and ids are not range-checked
 * here. Callers presumably must not mutate the array after construction.
 */
public class Slice extends SingleParentVariable<Matrix, Matrix> {

    /** Parent row indices to select, listed in output-row order. */
    private final int[] batchIds;

    /**
     * Creates a row selection over {@code source}.
     *
     * @param source   matrix variable whose rows are selected
     * @param batchIds parent row indices; entry {@code i} becomes output row {@code i}
     */
    public Slice(Variable<Matrix> source, int[] batchIds) {
        super(source, Dimensions.matrix(batchIds.length, source.dimension(1)));
        this.batchIds = batchIds;
    }

    /**
     * Gathers the selected parent rows into a new matrix.
     *
     * @param ctx computation context holding the parent's forward value
     * @return a {@code batchIds.length} x parent-cols matrix of the selected rows
     */
    @Override
    public Matrix apply(ComputationContext ctx) {
        final Matrix source = (Matrix) ctx.data(this.parent);
        final Matrix selected = new Matrix(batchIds.length, source.cols());
        int outRow = 0;
        for (int sourceRow : batchIds) {
            selected.setRow(outRow++, source, sourceRow);
        }
        return selected;
    }

    /**
     * Scatters this node's gradient back onto the parent's shape.
     *
     * <p>The result starts from {@code createWithSameDimensions()} on the parent's data. This is
     * presumably a zero-filled matrix (confirm). Each gradient row is then added onto the parent
     * row it was taken from.
     *
     * @param ctx computation context holding this node's gradient and the parent's value
     * @return the gradient with respect to the parent
     */
    @Override
    public Matrix gradientForParent(ComputationContext ctx) {
        final Matrix upstream = (Matrix) ctx.gradient(this);
        final Matrix parentGradient = ((Matrix) ctx.data(this.parent)).createWithSameDimensions();
        final int width = upstream.cols();
        int gradientRow = 0;
        for (int targetRow : batchIds) {
            accumulateRow(parentGradient, targetRow, upstream, gradientRow++, width);
        }
        return parentGradient;
    }

    /** Adds the first {@code width} entries of {@code from}'s row {@code fromRow} onto {@code into}'s row {@code toRow}. */
    private static void accumulateRow(Matrix into, int toRow, Matrix from, int fromRow, int width) {
        for (int col = 0; col < width; col++) {
            into.addDataAt(toRow, col, from.dataAt(fromRow, col));
        }
    }
}
