package com.neo4j.gds.embeddings.graphsage;

import com.neo4j.gds.core.model.proto.GraphSageCommonProto;
import com.neo4j.gds.core.model.proto.GraphSageProto;
import com.neo4j.gds.ml.core.TensorSerializer;
import org.neo4j.gds.embeddings.graphsage.ActivationFunctionFactory;
import org.neo4j.gds.embeddings.graphsage.ActivationFunctionType;
import org.neo4j.gds.embeddings.graphsage.MaxPoolingAggregator;
import org.neo4j.gds.ml.core.functions.Weights;

/* loaded from: input_file:com/neo4j/gds/embeddings/graphsage/MaxPoolingAggregatorSerializer.class */
public final class MaxPoolingAggregatorSerializer {

    /** Static utility holder; never instantiated. */
    private MaxPoolingAggregatorSerializer() {
    }

    /**
     * Converts a {@link MaxPoolingAggregator} into its protobuf message form.
     *
     * <p>The pool, self, neighbors and bias weights are each serialized with
     * {@link TensorSerializer#serialize}. The aggregator's activation function
     * type is mapped to the protobuf {@code ActivationFunction} enum constant
     * with the same name.
     *
     * @param maxPoolingAggregator the aggregator to serialize
     * @return the populated {@code GraphSageProto.MaxPoolingAggregator} message
     */
    public static GraphSageProto.MaxPoolingAggregator toSerializable(MaxPoolingAggregator maxPoolingAggregator) {
        return GraphSageProto.MaxPoolingAggregator.newBuilder()
            .setPoolWeights(TensorSerializer.serialize(maxPoolingAggregator.poolWeights()))
            .setSelfWeights(TensorSerializer.serialize(maxPoolingAggregator.selfWeights()))
            .setNeighborsWeights(TensorSerializer.serialize(maxPoolingAggregator.neighborsWeights()))
            .setBias(TensorSerializer.serialize(maxPoolingAggregator.bias()))
            // Enum mapping relies on both enums sharing constant names;
            // valueOf throws IllegalArgumentException if a name has no counterpart.
            .setActivationFunction(
                GraphSageCommonProto.ActivationFunction.valueOf(maxPoolingAggregator.activationFunctionType().name())
            )
            .build();
    }

    /**
     * Rebuilds a {@link MaxPoolingAggregator} from its protobuf message form.
     *
     * <p>Each serialized tensor is deserialized with
     * {@link TensorSerializer#deserialize} and wrapped in a new {@link Weights}.
     * The activation function is resolved by name via
     * {@link ActivationFunctionType#of} and wrapped by
     * {@link ActivationFunctionFactory#activationFunctionWrapper}.
     *
     * @param maxPoolingAggregator the protobuf message to read
     * @return a new aggregator built from the message contents
     */
    public static MaxPoolingAggregator fromSerializable(GraphSageProto.MaxPoolingAggregator maxPoolingAggregator) {
        // Locals are evaluated in the same order as the constructor arguments.
        Weights pool = new Weights(TensorSerializer.deserialize(maxPoolingAggregator.getPoolWeights()));
        Weights self = new Weights(TensorSerializer.deserialize(maxPoolingAggregator.getSelfWeights()));
        Weights neighbors = new Weights(TensorSerializer.deserialize(maxPoolingAggregator.getNeighborsWeights()));
        Weights biasWeights = new Weights(TensorSerializer.deserialize(maxPoolingAggregator.getBias()));
        ActivationFunctionType activationType =
            ActivationFunctionType.of(maxPoolingAggregator.getActivationFunction().name());

        return new MaxPoolingAggregator(
            pool,
            self,
            neighbors,
            biasWeights,
            ActivationFunctionFactory.activationFunctionWrapper(activationType)
        );
    }
}
