package org.neo4j.gds.embeddings.graphsage;

import java.util.Map;
import org.neo4j.gds.NodeLabel;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.IdMap;
import org.neo4j.gds.collections.ha.HugeObjectArray;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.functions.LabelwiseFeatureProjection;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.tensor.Matrix;

/* loaded from: input_file:org/neo4j/gds/embeddings/graphsage/MultiLabelFeatureFunction.class */
/**
 * A {@link FeatureFunction} that selects a separate weight matrix per node label.
 *
 * <p>For every node in a batch, the label reported first by
 * {@code Graph#forEachNodeLabel} is used as that node's label. The features,
 * the per-label weights and the resolved labels are then handed to a
 * {@link LabelwiseFeatureProjection} that produces vectors of size
 * {@link #projectedFeatureDimension()}.
 */
public class MultiLabelFeatureFunction implements FeatureFunction {
    private final Map<NodeLabel, Weights<Matrix>> weightsByLabel;
    private final int projectedFeatureDimension;

    /**
     * Label consumer that remembers only the first label it is offered and then
     * asks the iteration to stop (by returning {@code false} from {@code accept}).
     */
    private static final class SingleNodeLabelConsumer implements IdMap.NodeLabelConsumer {
        NodeLabel nodeLabel;

        /**
         * Looks up the first label reported for {@code nodeId}.
         *
         * <p>The field is cleared before each lookup. Without this, a node for which
         * no label is reported would inherit the label of the previously
         * inspected node.
         *
         * @return the first reported label, or {@code null} if none was reported
         */
        NodeLabel firstLabel(Graph graph, long nodeId) {
            this.nodeLabel = null;
            graph.forEachNodeLabel(nodeId, this);
            return this.nodeLabel;
        }

        @Override // org.neo4j.gds.api.IdMap.NodeLabelConsumer
        public boolean accept(NodeLabel nodeLabel) {
            this.nodeLabel = nodeLabel;
            // Stop after the first label; additional labels are ignored.
            return false;
        }
    }

    /**
     * @param weightsByLabel            weight matrix to use for each node label; stored
     *                                  as-is (not copied)
     * @param projectedFeatureDimension dimension of the projected feature vectors
     */
    public MultiLabelFeatureFunction(Map<NodeLabel, Weights<Matrix>> weightsByLabel, int projectedFeatureDimension) {
        this.weightsByLabel = weightsByLabel;
        this.projectedFeatureDimension = projectedFeatureDimension;
    }

    /** Returns the per-label weights this function was created with (the same map instance). */
    public Map<NodeLabel, Weights<Matrix>> weightsByLabel() {
        return this.weightsByLabel;
    }

    /**
     * Builds a label-wise feature projection for the given batch of nodes.
     *
     * @param graph    graph used to look up each node's label
     * @param batch    ids of the nodes in the batch, as understood by {@code graph}
     * @param features feature arrays, passed through to the projection
     * @return a {@link LabelwiseFeatureProjection} over the batch
     * @throws IllegalStateException if no label is reported for some node in the batch
     */
    @Override // org.neo4j.gds.embeddings.graphsage.FeatureFunction
    public Variable<Matrix> apply(Graph graph, long[] batch, HugeObjectArray<double[]> features) {
        NodeLabel[] labels = new NodeLabel[batch.length];
        SingleNodeLabelConsumer labelConsumer = new SingleNodeLabelConsumer();
        for (int i = 0; i < batch.length; i++) {
            NodeLabel label = labelConsumer.firstLabel(graph, batch[i]);
            if (label == null) {
                // Fail fast instead of projecting with a stale or missing label.
                throw new IllegalStateException(
                    "No node label found for node " + batch[i] + "; cannot select label-specific weights."
                );
            }
            labels[i] = label;
        }
        return new LabelwiseFeatureProjection(batch, features, this.weightsByLabel, this.projectedFeatureDimension, labels);
    }

    /** Returns the dimension of the projected feature vectors. */
    public int projectedFeatureDimension() {
        return this.projectedFeatureDimension;
    }
}
