package org.neo4j.gds.embeddings.graphsage;

import com.neo4j.gds.shaded.org.apache.commons.lang3.mutable.MutableInt;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.neo4j.gds.NodeLabel;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.IdMap;
import org.neo4j.gds.api.schema.GraphSchema;
import org.neo4j.gds.collections.ha.HugeObjectArray;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrainMemoryEstimateParameters;
import org.neo4j.gds.embeddings.graphsage.algo.MultiLabelFeatureExtractors;
import org.neo4j.gds.mem.Estimate;
import org.neo4j.gds.mem.MemoryEstimation;
import org.neo4j.gds.mem.MemoryEstimations;
import org.neo4j.gds.mem.MemoryRange;
import org.neo4j.gds.ml.core.NeighborhoodFunction;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.features.BiasFeature;
import org.neo4j.gds.ml.core.features.FeatureExtraction;
import org.neo4j.gds.ml.core.features.FeatureExtractor;
import org.neo4j.gds.ml.core.features.HugeObjectArrayFeatureConsumer;
import org.neo4j.gds.ml.core.functions.NormalizeRows;
import org.neo4j.gds.ml.core.subgraph.NeighborhoodSampler;
import org.neo4j.gds.ml.core.subgraph.SubGraph;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.utils.StringFormatting;

/* loaded from: input_file:org/neo4j/gds/embeddings/graphsage/GraphSageHelper.class */
/**
 * Static helpers for GraphSage.
 *
 * <p>Covers building the per-layer aggregation computation graph, sampling per-layer sub-graphs,
 * estimating memory for the computation graph, and extracting node features for single-label and
 * multi-label graphs.
 */
public final class GraphSageHelper {

    private GraphSageHelper() {
    }

    /**
     * Chains the aggregators of {@code layers} on top of {@code batchedFeatures} and row-normalizes the result.
     *
     * <p>Layers are applied in array order. Layer {@code k} aggregates over
     * {@code subGraphs.get(layers.length - 1 - k)}, so the sub-graph list is consumed from its last element
     * towards its first.
     *
     * @param subGraphs       one sub-graph per layer, e.g. as produced by {@link #subGraphsPerLayer}
     * @param layers          the layers whose aggregators are chained
     * @param batchedFeatures the input representation fed into the first layer
     * @return the aggregated representations wrapped in {@link NormalizeRows}
     */
    // NOTE(review): the decompiler reports this was package-private in the original source; kept public so
    // existing callers keep compiling.
    public static Variable<Matrix> embeddingsComputationGraph(
        List<SubGraph> subGraphs,
        Layer[] layers,
        Variable<Matrix> batchedFeatures
    ) {
        Variable<Matrix> representations = batchedFeatures;
        for (int layerIndex = 0; layerIndex < layers.length; layerIndex++) {
            SubGraph subGraph = subGraphs.get(layers.length - 1 - layerIndex);
            representations = layers[layerIndex].aggregator().aggregate(representations, subGraph);
        }
        return new NormalizeRows(representations);
    }

    /**
     * Samples one sub-graph per layer around the given batch of nodes.
     *
     * <p>Every layer gets its own {@link NeighborhoodSampler}. Each sampler is seeded from a {@link Random}
     * initialized with {@code randomSeed`}, and the seeds are drawn in layer order so the outcome is reproducible
     * for a fixed seed. The resulting neighborhood functions are handed to {@link SubGraph#buildSubGraphs} in
     * reverse layer order.
     *
     * @param graph      the graph to sample from
     * @param nodeIds    the batch of node ids to sample around
     * @param layers     the layers; each contributes its {@code sampleSize()}
     * @param randomSeed seed for the per-layer sampler seeds
     * @return the sub-graphs built by {@link SubGraph#buildSubGraphs}
     */
    // NOTE(review): the decompiler reports this was package-private in the original source; kept public so
    // existing callers keep compiling.
    public static List<SubGraph> subGraphsPerLayer(Graph graph, long[] nodeIds, Layer[] layers, long randomSeed) {
        Random random = new Random(randomSeed);
        List<NeighborhoodFunction> neighborhoodFunctions = new ArrayList<>(layers.length);
        for (Layer layer : layers) {
            NeighborhoodSampler sampler = new NeighborhoodSampler(random.nextLong());
            neighborhoodFunctions.add(nodeId -> sampler.sample(graph, nodeId, layer.sampleSize()));
        }
        Collections.reverse(neighborhoodFunctions);
        return SubGraph.buildSubGraphs(nodeIds, neighborhoodFunctions, SubGraph.relationshipWeightFunction(graph));
    }

    /**
     * Estimates the memory of the GraphSage computation graph for one batch.
     *
     * <p>The estimation is rooted at {@code "computationGraph"}. It contains:
     * <ul>
     *   <li>a {@code "subgraphs"} field with one entry per layer;</li>
     *   <li>for multi-label graphs, a {@code "multiLabelFeatureFunction"} entry;</li>
     *   <li>a {@code "forward"} field holding the first-layer, aggregator and normalization estimates;</li>
     *   <li>optionally, a {@code "backward"} field with the same components as {@code "forward"}.</li>
     * </ul>
     *
     * @param parameters         model parameters (layer configs, dimensions, multi-label flag)
     * @param batchSize          number of nodes in the batch; the starting node count before any sampling
     * @param nodeCount          upper bound applied to every sampled node count
     * @param labelCount         number of labels, used for the multi-label feature function estimate
     * @param includeBackwardPass whether to add a {@code "backward"} field mirroring {@code "forward"}
     * @return the composed memory estimation
     */
    public static MemoryEstimation embeddingsEstimation(
        GraphSageTrainMemoryEstimateParameters parameters,
        long batchSize,
        long nodeCount,
        int labelCount,
        boolean includeBackwardPass
    ) {
        boolean isMultiLabel = parameters.isMultiLabel();
        List<LayerConfig> layerConfigs = parameters.layerConfigs();
        int numberOfLayers = layerConfigs.size();

        MemoryEstimations.Builder subgraphsBuilder = MemoryEstimations
            .builder("computationGraph")
            .startField("subgraphs");

        // Lower/upper bounds of the node count at each sampling hop; index 0 is the batch itself.
        List<Long> minNodeCounts = new ArrayList<>(numberOfLayers + 1);
        List<Long> maxNodeCounts = new ArrayList<>(numberOfLayers + 1);
        minNodeCounts.add(batchSize);
        maxNodeCounts.add(batchSize);

        for (int hop = 0; hop < numberOfLayers; hop++) {
            int sampleSize = layerConfigs.get(hop).sampleSize();
            long minCount = minNodeCounts.get(hop);
            long maxCount = maxNodeCounts.get(hop);
            // Best case: no new nodes are reached. Worst case: every node adds sampleSize neighbours.
            // Both are capped by the total node count.
            long minNextCount = Math.min(minCount, nodeCount);
            long maxNextCount = Math.min(maxCount * (sampleSize + 1), nodeCount);
            minNodeCounts.add(minNextCount);
            maxNodeCounts.add(maxNextCount);

            MemoryRange subgraphRange = MemoryRange.of(
                Estimate.sizeOfIntArray(minCount)
                    + Estimate.sizeOfObjectArray(minCount)
                    + minCount * Estimate.sizeOfIntArray(0L)
                    + Estimate.sizeOfLongArray(minNextCount),
                Estimate.sizeOfIntArray(maxCount)
                    + Estimate.sizeOfObjectArray(maxCount)
                    + maxCount * Estimate.sizeOfIntArray(sampleSize)
                    + Estimate.sizeOfLongArray(maxNextCount)
            );
            subgraphsBuilder.add(MemoryEstimations.of("subgraph " + (hop + 1), subgraphRange));
        }

        // After reversing, index i holds the input node count of layer i and index i + 1 holds its output;
        // the last index is the batch itself.
        Collections.reverse(minNodeCounts);
        Collections.reverse(maxNodeCounts);

        int embeddingDimension = parameters.embeddingDimension();
        MemoryEstimations.Builder layersBuilder = MemoryEstimations.builder();
        for (int layerIndex = 0; layerIndex < numberOfLayers; layerIndex++) {
            LayerConfig layerConfig = layerConfigs.get(layerIndex);
            long minInputCount = minNodeCounts.get(layerIndex);
            long maxInputCount = maxNodeCounts.get(layerIndex);
            long minOutputCount = minNodeCounts.get(layerIndex + 1);
            long maxOutputCount = maxNodeCounts.get(layerIndex + 1);

            if (layerIndex == 0) {
                int featureDimension = parameters.estimationFeatureDimension();
                MemoryRange firstLayer = MemoryRange.of(
                    Estimate.sizeOfDoubleArray(minInputCount * featureDimension),
                    Estimate.sizeOfDoubleArray(maxInputCount * featureDimension)
                );
                if (isMultiLabel) {
                    firstLayer = firstLayer.add(MemoryRange.of(Estimate.sizeOfDoubleArray(featureDimension)));
                }
                layersBuilder.fixed("firstLayer", firstLayer);
            }

            AggregatorType aggregatorType = layerConfig.aggregatorType();
            AggregatorMemoryEstimator aggregatorEstimator = switch (aggregatorType) {
                case MEAN -> new MeanAggregatorMemoryEstimator();
                case POOL -> new PoolAggregatorMemoryEstimator();
            };
            layersBuilder.fixed(
                StringFormatting.formatWithLocale("%s %d", aggregatorType.name(), layerIndex + 1),
                aggregatorEstimator.estimate(
                    minOutputCount,
                    maxOutputCount,
                    minInputCount,
                    maxInputCount,
                    layerConfig.cols(),
                    embeddingDimension
                )
            );

            if (layerIndex == numberOfLayers - 1) {
                layersBuilder.fixed("normalizeRows", MemoryRange.of(
                    Estimate.sizeOfDoubleArray(minOutputCount * embeddingDimension),
                    Estimate.sizeOfDoubleArray(maxOutputCount * embeddingDimension)
                ));
            }
        }

        MemoryEstimations.Builder computationGraphBuilder = subgraphsBuilder.endField();
        if (isMultiLabel) {
            computationGraphBuilder.fixed(
                "multiLabelFeatureFunction",
                MemoryRange.of(
                    Estimate.sizeOfObjectArray(minNodeCounts.get(0)),
                    Estimate.sizeOfObjectArray(maxNodeCounts.get(0))
                ).add(MemoryRange.of(Estimate.sizeOfObjectArray(labelCount)))
            );
        }

        // Build the per-layer estimation once and reuse it for both passes.
        MemoryEstimation layerEstimations = layersBuilder.build();
        MemoryEstimations.Builder resultBuilder = computationGraphBuilder
            .startField("forward")
            .addComponentsOf(layerEstimations);
        if (includeBackwardPass) {
            resultBuilder = resultBuilder.endField().startField("backward").addComponentsOf(layerEstimations);
        }
        return resultBuilder.endField().build();
    }

    /**
     * Extracts the given node properties of every node into a newly allocated feature array.
     *
     * @param graph              the graph whose nodes are read
     * @param featureProperties  the node property keys to extract
     * @return one feature vector per node, sized {@code graph.nodeCount()}
     */
    public static HugeObjectArray<double[]> initializeSingleLabelFeatures(
        Graph graph,
        Collection<String> featureProperties
    ) {
        HugeObjectArray<double[]> features = HugeObjectArray.newArray(double[].class, graph.nodeCount());
        return FeatureExtraction.extract(
            graph,
            FeatureExtraction.propertyExtractors(graph, featureProperties),
            features
        );
    }

    /**
     * Builds feature extractors and feature counts per node label.
     *
     * <p>For each label, the extractors are created lazily from the first node encountered with that label.
     * They use the subset of {@code featureProperties} that the label's schema declares, plus a trailing
     * {@link BiasFeature}.
     *
     * @param graph             the graph; every node must carry exactly one label
     * @param featureProperties candidate property keys, filtered per label against the graph schema
     * @return the feature counts and extractors keyed by node label
     * @throws IllegalArgumentException if a node does not have exactly one label
     */
    public static MultiLabelFeatureExtractors multiLabelFeatureExtractors(Graph graph, List<String> featureProperties) {
        Map<NodeLabel, Set<String>> propertyKeysPerLabel = filteredPropertyKeysPerNodeLabel(graph, featureProperties);
        Map<NodeLabel, Integer> featureCountPerLabel = new HashMap<>();
        Map<NodeLabel, List<FeatureExtractor>> extractorsPerLabel = new HashMap<>();

        graph.forEachNode(nodeId -> {
            NodeLabel label = labelOf(graph, nodeId);
            List<FeatureExtractor> extractors = extractorsPerLabel.computeIfAbsent(label, key -> {
                List<FeatureExtractor> created = new ArrayList<>(
                    FeatureExtraction.propertyExtractors(graph, propertyKeysPerLabel.get(key), nodeId)
                );
                created.add(new BiasFeature());
                return created;
            });
            featureCountPerLabel.computeIfAbsent(label, key -> FeatureExtraction.featureCount(extractors));
            return true;
        });

        return new MultiLabelFeatureExtractors(featureCountPerLabel, extractorsPerLabel);
    }

    /**
     * Extracts features for every node using the extractors registered for that node's label.
     *
     * @param graph                       the graph; every node must carry exactly one label
     * @param multiLabelFeatureExtractors extractors and feature counts keyed by node label
     * @return one feature vector per node, each sized by its label's feature count
     * @throws IllegalArgumentException if a node does not have exactly one label
     */
    public static HugeObjectArray<double[]> initializeMultiLabelFeatures(
        Graph graph,
        MultiLabelFeatureExtractors multiLabelFeatureExtractors
    ) {
        HugeObjectArray<double[]> features = HugeObjectArray.newArray(double[].class, graph.nodeCount());
        HugeObjectArrayFeatureConsumer featureConsumer = new HugeObjectArrayFeatureConsumer(features);

        graph.forEachNode(nodeId -> {
            NodeLabel label = labelOf(graph, nodeId);
            List<FeatureExtractor> extractors = multiLabelFeatureExtractors.extractorsPerLabel().get(label);
            int featureCount = multiLabelFeatureExtractors.featureCountPerLabel().get(label).intValue();
            // Allocate the target vector first; the consumer writes into the slot at nodeId.
            features.set(nodeId, new double[featureCount]);
            FeatureExtraction.extract(nodeId, nodeId, extractors, featureConsumer);
            return true;
        });

        return features;
    }

    /**
     * Creates one {@link LayerConfig} per entry of {@code sampleSizes}.
     *
     * <p>All layers share the aggregator, the activation function and {@code rows == embeddingDimension}. The
     * first layer uses {@code featureDimension} columns and later layers use {@code embeddingDimension} columns.
     * Per-layer random seeds are drawn from a {@link Random} that is seeded with {@code randomSeed} when one is
     * present.
     *
     * @param featureDimension   column count of the first layer
     * @param sampleSizes        neighbourhood sample size per layer; its size is the number of layers
     * @param randomSeed         optional seed for the per-layer seed generator
     * @param aggregatorType     aggregator used by every layer
     * @param activationFunction activation used by every layer
     * @param embeddingDimension row count of every layer, and column count of every layer after the first
     * @return the layer configurations in layer order
     */
    public static List<LayerConfig> layerConfigs(
        int featureDimension,
        List<Integer> sampleSizes,
        Optional<Long> randomSeed,
        AggregatorType aggregatorType,
        ActivationFunctionType activationFunction,
        int embeddingDimension
    ) {
        Random random = new Random();
        randomSeed.ifPresent(random::setSeed);

        List<LayerConfig> configs = new ArrayList<>(sampleSizes.size());
        for (int layerIndex = 0; layerIndex < sampleSizes.size(); layerIndex++) {
            configs.add(LayerConfig.builder()
                .aggregatorType(aggregatorType)
                .activationFunction(activationFunction)
                .rows(embeddingDimension)
                .cols(layerIndex == 0 ? featureDimension : embeddingDimension)
                .sampleSize(sampleSizes.get(layerIndex).intValue())
                .randomSeed(random.nextLong())
                .build());
        }
        return configs;
    }

    /** Maps each node label in the schema to the key set of its declared node properties. */
    private static Map<NodeLabel, Set<String>> propertyKeysPerNodeLabel(GraphSchema graphSchema) {
        return graphSchema.nodeSchema().entries().stream()
            .collect(Collectors.toMap(
                entry -> entry.identifier(),
                entry -> entry.properties().keySet()
            ));
    }

    /**
     * Maps each node label in the graph schema to the subset of {@code featureProperties} that label declares.
     */
    private static Map<NodeLabel, Set<String>> filteredPropertyKeysPerNodeLabel(
        Graph graph,
        List<String> featureProperties
    ) {
        return propertyKeysPerNodeLabel(graph.schema()).entrySet().stream()
            .collect(Collectors.toMap(
                Map.Entry::getKey,
                entry -> featureProperties.stream()
                    .filter(entry.getValue()::contains)
                    .collect(Collectors.toSet())
            ));
    }

    /**
     * Returns the single label of {@code nodeId}.
     *
     * @throws IllegalArgumentException if the node has zero labels or more than one label
     */
    private static NodeLabel labelOf(IdMap idMap, long nodeId) {
        AtomicReference<NodeLabel> lastSeenLabel = new AtomicReference<>();
        MutableInt labelsSeen = new MutableInt(0);
        idMap.forEachNodeLabel(nodeId, nodeLabel -> {
            lastSeenLabel.set(nodeLabel);
            // Returns true only for the first label. Assuming the return value means "continue", iteration
            // stops after a second label has been counted, which is enough to detect the error below.
            return labelsSeen.getAndIncrement() == 0;
        });
        if (labelsSeen.intValue() != 1) {
            throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                "Each node must have exactly one label: nodeId=%d, labels=%s",
                nodeId,
                idMap.nodeLabels(nodeId)
            ));
        }
        return lastSeenLabel.get();
    }
}
