package com.neo4j.gds.embeddings.graphsage;

import com.neo4j.gds.core.model.proto.GraphSageProto;
import java.util.List;
import java.util.stream.Collectors;
import org.neo4j.gds.embeddings.graphsage.ActivationFunctionFactory;
import org.neo4j.gds.embeddings.graphsage.Aggregator;
import org.neo4j.gds.embeddings.graphsage.Layer;
import org.neo4j.gds.embeddings.graphsage.MaxPoolAggregatingLayer;
import org.neo4j.gds.embeddings.graphsage.MaxPoolingAggregator;
import org.neo4j.gds.embeddings.graphsage.MeanAggregatingLayer;
import org.neo4j.gds.embeddings.graphsage.MeanAggregator;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.utils.StringFormatting;

/* loaded from: input_file:com/neo4j/gds/embeddings/graphsage/LayerSerializer.class */
/**
 * Static helpers that convert between {@link Layer} instances and their
 * {@link GraphSageProto.Layer} protobuf representation.
 *
 * <p>Two aggregator kinds are handled, {@code MEAN} and {@code POOL}; any other kind is
 * rejected with an {@link IllegalArgumentException} in both directions.
 */
public final class LayerSerializer {
    /** Not instantiable: this class only exposes static methods. */
    private LayerSerializer() {
    }

    /**
     * Deserializes every protobuf layer in the given list, preserving encounter order.
     *
     * @param list the protobuf layers to convert
     * @return an array with one {@link Layer} per input element
     * @throws IllegalArgumentException if an element has an aggregator case other than
     *                                  {@code MEAN} or {@code POOL}
     */
    public static Layer[] deserialize(List<GraphSageProto.Layer> list) {
        // Typed array generator: no raw List, no unchecked cast, no intermediate collection.
        return list.stream()
            .map(LayerSerializer::deserialize)
            .toArray(Layer[]::new);
    }

    /**
     * Converts a layer into its protobuf form. The layer's aggregator is stored in the
     * {@code mean} or {@code pool} field depending on {@code aggregator.type()}, and the
     * layer's sample size is copied over.
     *
     * @param layer the layer to serialize
     * @return the built protobuf message
     * @throws IllegalArgumentException if the aggregator type is neither {@code MEAN} nor
     *                                  {@code POOL}
     */
    public static GraphSageProto.Layer serialize(Layer layer) {
        GraphSageProto.Layer.Builder builder = GraphSageProto.Layer.newBuilder();
        Aggregator aggregator = layer.aggregator();
        switch (aggregator.type()) {
            case MEAN:
                builder.setMean(MeanAggregatorSerializer.toSerializable((MeanAggregator) aggregator));
                break;
            case POOL:
                builder.setPool(MaxPoolingAggregatorSerializer.toSerializable((MaxPoolingAggregator) aggregator));
                break;
            default:
                // Fail here rather than emit a message without an aggregator, which
                // deserialize(GraphSageProto.Layer) would only reject later on.
                throw new IllegalArgumentException(
                    StringFormatting.formatWithLocale("Unknown aggregator: %s", aggregator.type())
                );
        }
        builder.setSampleSize(layer.sampleSize());
        return builder.build();
    }

    /**
     * Rebuilds a single layer from its protobuf form, choosing the concrete layer class
     * from the aggregator case that is set on the message.
     *
     * @param layer the protobuf layer to convert
     * @return a {@link MeanAggregatingLayer} for the {@code MEAN} case or a
     *         {@link MaxPoolAggregatingLayer} for the {@code POOL} case
     * @throws IllegalArgumentException for any other aggregator case
     */
    public static Layer deserialize(GraphSageProto.Layer layer) {
        GraphSageProto.Layer.AggregatorCase aggregatorCase = layer.getAggregatorCase();
        switch (aggregatorCase) {
            case MEAN:
                MeanAggregator mean = MeanAggregatorSerializer.fromSerializable(layer.getMean());
                return new MeanAggregatingLayer(
                    new Weights(mean.weightsData()),
                    layer.getSampleSize(),
                    ActivationFunctionFactory.activationFunctionWrapper(mean.activationFunctionType())
                );
            case POOL:
                MaxPoolingAggregator pool = MaxPoolingAggregatorSerializer.fromSerializable(layer.getPool());
                return new MaxPoolAggregatingLayer(
                    layer.getSampleSize(),
                    new Weights(pool.poolWeights()),
                    new Weights(pool.selfWeights()),
                    new Weights(pool.neighborsWeights()),
                    new Weights(pool.bias()),
                    ActivationFunctionFactory.activationFunctionWrapper(pool.activationFunctionType())
                );
            default:
                throw new IllegalArgumentException(StringFormatting.formatWithLocale("Unknown aggregator: %s", aggregatorCase));
        }
    }
}
