package com.neo4j.gds.embeddings.graphsage;

import com.neo4j.gds.core.model.proto.GraphSageProto;
import com.neo4j.gds.core.model.proto.TensorProto;
import com.neo4j.gds.ml.core.TensorSerializer;
import java.util.Map;
import java.util.stream.Collectors;
import org.neo4j.gds.NodeLabel;
import org.neo4j.gds.embeddings.graphsage.FeatureFunction;
import org.neo4j.gds.embeddings.graphsage.MultiLabelFeatureFunction;
import org.neo4j.gds.embeddings.graphsage.SingleLabelFeatureFunction;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.utils.StringFormatting;

/* loaded from: input_file:com/neo4j/gds/embeddings/graphsage/FeatureFunctionSerializer.class */
/**
 * Converts GraphSAGE {@link FeatureFunction} instances to and from their protobuf form,
 * {@link GraphSageProto.FeatureFunction}.
 *
 * <p>Two implementations are supported:
 * <ul>
 *   <li>{@link SingleLabelFeatureFunction}: stored only as the {@code SINGLE} function type;</li>
 *   <li>{@link MultiLabelFeatureFunction}: stored as the {@code MULTI} function type together with
 *   its per-label weight matrices (keyed by label name) and its projected feature dimension.</li>
 * </ul>
 * Any other implementation is rejected with an {@link IllegalArgumentException}.
 */
public final class FeatureFunctionSerializer {

    private FeatureFunctionSerializer() {
        // static utility class; not instantiable
    }

    /**
     * Serializes a feature function into its protobuf representation.
     *
     * @param featureFunction the function to serialize; must be a {@link SingleLabelFeatureFunction}
     *                        or a {@link MultiLabelFeatureFunction}
     * @return the protobuf message describing {@code featureFunction}
     * @throws IllegalArgumentException if {@code featureFunction} is of any other type (including
     *                                  {@code null})
     */
    public static GraphSageProto.FeatureFunction toSerializable(FeatureFunction featureFunction) {
        if (featureFunction instanceof SingleLabelFeatureFunction) {
            // The single-label function carries no state; the type tag is sufficient.
            return GraphSageProto.FeatureFunction.newBuilder()
                .setFunctionType(GraphSageProto.FeatureFunctionType.SINGLE)
                .build();
        }
        if (featureFunction instanceof MultiLabelFeatureFunction) {
            MultiLabelFeatureFunction multiLabel = (MultiLabelFeatureFunction) featureFunction;
            return GraphSageProto.FeatureFunction.newBuilder()
                .setFunctionType(GraphSageProto.FeatureFunctionType.MULTI)
                .putAllWeightsByLabel(unwrapWeightsByLabelMatrix(multiLabel.weightsByLabel()))
                .setProjectedFeatureDimension(multiLabel.projectedFeatureDimension())
                .build();
        }
        throw new IllegalArgumentException(StringFormatting.formatWithLocale(
            "Unknown feature function class: %s",
            featureFunction
        ));
    }

    /**
     * Reconstructs a feature function from its protobuf representation.
     *
     * <p>NOTE(review): the decompiler reports this method was package-private in the original
     * build; it is kept {@code public} here so existing callers keep compiling.
     *
     * @param featureFunction the protobuf message to read
     * @return a {@link SingleLabelFeatureFunction} for {@code SINGLE}, or a
     *         {@link MultiLabelFeatureFunction} rebuilt from the stored weights and projected
     *         feature dimension for {@code MULTI}
     * @throws IllegalArgumentException for any other function type (for example an enum value
     *                                  that is not handled by this switch)
     */
    public static FeatureFunction fromSerializable(GraphSageProto.FeatureFunction featureFunction) {
        switch (featureFunction.getFunctionType()) {
            case SINGLE:
                return new SingleLabelFeatureFunction();
            case MULTI:
                return new MultiLabelFeatureFunction(
                    wrapWeightsByLabelMatrix(featureFunction.getWeightsByLabelMap()),
                    featureFunction.getProjectedFeatureDimension()
                );
            default:
                throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                    "Unknown proto feature function class: %s",
                    featureFunction
                ));
        }
    }

    /**
     * Converts per-label weights into a map from label name to serialized weight matrix.
     * This is the inverse of {@link #wrapWeightsByLabelMatrix(Map)}.
     */
    private static Map<String, TensorProto.Matrix> unwrapWeightsByLabelMatrix(
        Map<NodeLabel, Weights<Matrix>> weightsByLabel
    ) {
        return weightsByLabel.entrySet().stream().collect(Collectors.toMap(
            entry -> entry.getKey().name(),
            entry -> TensorSerializer.serialize(entry.getValue().data())
        ));
    }

    /**
     * Converts a map from label name to serialized weight matrix back into per-label weights.
     * Each label name is turned back into a {@link NodeLabel} via {@link NodeLabel#of}, and each
     * matrix is deserialized and wrapped in a new {@link Weights} instance.
     */
    private static Map<NodeLabel, Weights<Matrix>> wrapWeightsByLabelMatrix(
        Map<String, TensorProto.Matrix> serializedWeightsByLabel
    ) {
        return serializedWeightsByLabel.entrySet().stream().collect(Collectors.toMap(
            entry -> NodeLabel.of(entry.getKey()),
            entry -> new Weights<>(TensorSerializer.deserialize(entry.getValue()))
        ));
    }
}
