package com.neo4j.gds.config;

import com.neo4j.gds.config.proto.CommonConfigProto;
import com.neo4j.gds.core.model.proto.TrainConfigsProto;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.neo4j.gds.ml.metrics.classification.ClassificationMetricSpecification;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineTrainConfig;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineTrainConfigImpl;

/* loaded from: input_file:com/neo4j/gds/config/NodeClassificationPipelineTrainConfigSerializer.class */
/**
 * Static helpers that convert a {@link NodeClassificationPipelineTrainConfig} to its protobuf
 * representation ({@code TrainConfigsProto.NodeClassificationPipelineTrainConfigProto}) and back.
 *
 * <p>The class only contains static methods and cannot be instantiated.
 */
public final class NodeClassificationPipelineTrainConfigSerializer {
    private NodeClassificationPipelineTrainConfigSerializer() {
    }

    /**
     * Builds the protobuf message for the given train configuration.
     *
     * <p>The message carries the model config and algo base config (both produced by
     * {@code CommonConfigSerializers}), the target node labels, target property, pipeline,
     * graph name, the random seed (written through
     * {@code CommonConfigSerializers.serializeRandomSeedConfig}) and every metric in its
     * {@code toString()} form.
     *
     * @param nodeClassificationPipelineTrainConfig the configuration to serialize
     * @return the populated protobuf message
     */
    public static TrainConfigsProto.NodeClassificationPipelineTrainConfigProto serialize(NodeClassificationPipelineTrainConfig nodeClassificationPipelineTrainConfig) {
        TrainConfigsProto.NodeClassificationPipelineTrainConfigProto.Builder protoBuilder = TrainConfigsProto.NodeClassificationPipelineTrainConfigProto.newBuilder();
        protoBuilder
            .setModelConfig(CommonConfigSerializers.serializeModelConfig(nodeClassificationPipelineTrainConfig))
            .setAlgoBaseConfig(CommonConfigSerializers.serializeAlgoBaseConfig(nodeClassificationPipelineTrainConfig))
            .addAllTargetNodeLabels(nodeClassificationPipelineTrainConfig.targetNodeLabels())
            .setTargetProperty(nodeClassificationPipelineTrainConfig.targetProperty())
            .setPipeline(nodeClassificationPipelineTrainConfig.pipeline())
            .setGraphName(nodeClassificationPipelineTrainConfig.graphName());

        CommonConfigSerializers.serializeRandomSeedConfig(nodeClassificationPipelineTrainConfig, protoBuilder::setRandomSeed);

        // Metrics are stored as their string form; deserialize() parses them back.
        nodeClassificationPipelineTrainConfig.metrics()
            .stream()
            .map(Object::toString)
            .forEach(protoBuilder::addMetrics);

        return protoBuilder.build();
    }

    /**
     * Rebuilds a train configuration from its protobuf message.
     *
     * <p>If the message has no target node labels, the node labels of the embedded algo base
     * config are used instead (presumably to stay compatible with messages written before the
     * target-node-labels field existed; confirm against the proto history). The random seed is
     * only set on the result when the message reports that it has a value.
     *
     * @param nodeClassificationPipelineTrainConfigProto the protobuf message to read
     * @return the configuration built from the message
     */
    public static NodeClassificationPipelineTrainConfig deserialize(TrainConfigsProto.NodeClassificationPipelineTrainConfigProto nodeClassificationPipelineTrainConfigProto) {
        NodeClassificationPipelineTrainConfigImpl.Builder builder = NodeClassificationPipelineTrainConfigImpl.builder();
        builder
            .modelName(nodeClassificationPipelineTrainConfigProto.getModelConfig().getModelName())
            .modelUser(nodeClassificationPipelineTrainConfigProto.getModelConfig().getUsername())
            .targetProperty(nodeClassificationPipelineTrainConfigProto.getTargetProperty())
            .graphName(nodeClassificationPipelineTrainConfigProto.getGraphName())
            .pipeline(nodeClassificationPipelineTrainConfigProto.getPipeline())
            .relationshipTypes(nodeClassificationPipelineTrainConfigProto.getAlgoBaseConfig().getRelationshipTypesList())
            // Fall back to the algo base config's node labels when no target labels were stored.
            .targetNodeLabels(
                nodeClassificationPipelineTrainConfigProto.getTargetNodeLabelsList().isEmpty()
                    ? nodeClassificationPipelineTrainConfigProto.getAlgoBaseConfig().getNodeLabelsList()
                    : nodeClassificationPipelineTrainConfigProto.getTargetNodeLabelsList()
            )
            .sudo(nodeClassificationPipelineTrainConfigProto.getAlgoBaseConfig().getSudo())
            .concurrency(Integer.valueOf(nodeClassificationPipelineTrainConfigProto.getAlgoBaseConfig().getConcurrency()));

        CommonConfigProto.OptionalRandomSeedProto randomSeed = nodeClassificationPipelineTrainConfigProto.getRandomSeed();
        if (randomSeed.hasValue()) {
            builder.randomSeed(Long.valueOf(randomSeed.getValue()));
        }

        builder.metrics(
            nodeClassificationPipelineTrainConfigProto.getMetricsList()
                .stream()
                .map(metric -> ClassificationMetricSpecification.Parser.parse(metric))
                .collect(Collectors.toList())
        );

        return builder.build();
    }
}
