package com.neo4j.gds.ml.pipeline;

import com.neo4j.gds.ModelDataSerializer;
import com.neo4j.gds.ml.core.TensorSerializer;
import com.neo4j.gds.ml.model.proto.MLPDataProto;
import com.neo4j.gds.shaded.com.google.protobuf.GeneratedMessageV3;
import com.neo4j.gds.shaded.com.google.protobuf.InvalidProtocolBufferException;
import com.neo4j.gds.shaded.com.google.protobuf.Parser;
import com.neo4j.gds.shaded.org.jetbrains.annotations.NotNull;
import java.util.List;
import java.util.stream.Collectors;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.models.mlp.MLPClassifierData;

/* loaded from: input_file:com/neo4j/gds/ml/pipeline/MLPDataSerializer.class */
/**
 * {@link ModelDataSerializer} implementation that converts between
 * {@link MLPClassifierData} and its protobuf representation
 * {@link MLPDataProto.MLPData}.
 *
 * <p>Every weights and biases entry is translated individually through
 * {@link TensorSerializer}; the order of the entries is preserved in both
 * directions.
 */
public class MLPDataSerializer implements ModelDataSerializer {

    /**
     * Serializes MLP classifier data into a {@link MLPDataProto.MLPData} message.
     *
     * @param obj the model data; must be an instance of {@link MLPClassifierData}
     * @return a message holding the serialized tensors of all weights and biases
     * @throws ClassCastException if {@code obj} is not a {@link MLPClassifierData}
     */
    @Override
    public MLPDataProto.MLPData serialize(Object obj) {
        MLPClassifierData classifierData = (MLPClassifierData) obj;
        // The streams are handed straight to the builder so that the element types are
        // inferred by the compiler rather than being erased through raw List casts.
        return MLPDataProto.MLPData.newBuilder()
            .addAllWeights(
                classifierData.weights().stream()
                    .map(Weights::data)
                    .map(TensorSerializer::serialize)
                    .collect(Collectors.toList()))
            .addAllBiases(
                classifierData.biases().stream()
                    .map(Weights::data)
                    .map(TensorSerializer::serialize)
                    .collect(Collectors.toList()))
            .build();
    }

    /**
     * Rebuilds MLP classifier data from a {@link MLPDataProto.MLPData} message.
     *
     * @param generatedMessageV3 the message to read; must be a {@link MLPDataProto.MLPData}
     * @return classifier data whose weights and biases wrap the deserialized tensors
     * @throws InvalidProtocolBufferException if the message contains no weights or no biases
     * @throws ClassCastException if the message is not a {@link MLPDataProto.MLPData}
     */
    @Override
    @NotNull
    public MLPClassifierData deserialize(GeneratedMessageV3 generatedMessageV3) throws InvalidProtocolBufferException {
        MLPDataProto.MLPData serialized = (MLPDataProto.MLPData) generatedMessageV3;
        // A message lacking either list is rejected before any tensor is decoded.
        if (serialized.getBiasesCount() == 0 || serialized.getWeightsCount() == 0) {
            throw new InvalidProtocolBufferException("Could not parse serializedData because biases or weights count is zero");
        }
        return MLPClassifierData.builder()
            .weights(
                serialized.getWeightsList().stream()
                    .map(TensorSerializer::deserialize)
                    .map(Weights::new)
                    .collect(Collectors.toList()))
            .biases(
                serialized.getBiasesList().stream()
                    .map(TensorSerializer::deserialize)
                    .map(Weights::new)
                    .collect(Collectors.toList()))
            .build();
    }

    /**
     * Returns the protobuf parser for {@link MLPDataProto.MLPData} messages.
     *
     * @return the parser obtained from {@code MLPDataProto.MLPData.parser()}
     */
    @Override
    public Parser<MLPDataProto.MLPData> parser() {
        return MLPDataProto.MLPData.parser();
    }
}
