package com.neo4j.gds.config;

import com.neo4j.gds.config.proto.CommonConfigProto;
import com.neo4j.gds.core.model.proto.TrainConfigsProto;
import com.neo4j.gds.shaded.com.google.protobuf.ProtocolStringList;
import java.util.Objects;
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainConfig;
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainConfigImpl;

/* loaded from: input_file:com/neo4j/gds/config/LinkPredictionPipelineTrainConfigSerializer.class */
public final class LinkPredictionPipelineTrainConfigSerializer {

    /** Wildcard value used for node labels / relationship types meaning "all". */
    private static final String WILDCARD = "*";

    /**
     * Placeholder target relationship type used when a legacy proto (no source node label)
     * does not name exactly one concrete relationship type.
     */
    private static final String LEGACY_INVALID_TARGET_RELATIONSHIP_TYPE = "INVALID_TARGET_RELATIONSHIPTYPE_IN_OLDER_MODEL";

    /** Static utility holder; not instantiable. */
    private LinkPredictionPipelineTrainConfigSerializer() {
    }

    /**
     * Converts a link prediction train config into its protobuf representation.
     *
     * <p>The model and algo-base sections are delegated to {@code CommonConfigSerializers};
     * the link-prediction specific fields are copied directly, and random-seed handling is
     * delegated to {@code CommonConfigSerializers.serializeRandomSeedConfig}, which is handed
     * the builder's {@code setRandomSeed} setter as a callback.
     *
     * @param linkPredictionTrainConfig the config to serialize
     * @return the built protobuf message
     */
    public static TrainConfigsProto.LinkPredictionPipelineTrainConfigProto serialize(LinkPredictionTrainConfig linkPredictionTrainConfig) {
        TrainConfigsProto.LinkPredictionPipelineTrainConfigProto.Builder protoBuilder =
            TrainConfigsProto.LinkPredictionPipelineTrainConfigProto.newBuilder();

        // Shared sections.
        protoBuilder.setModelConfig(CommonConfigSerializers.serializeModelConfig(linkPredictionTrainConfig));
        protoBuilder.setAlgoBaseConfig(CommonConfigSerializers.serializeAlgoBaseConfig(linkPredictionTrainConfig));

        // Link-prediction specific fields.
        protoBuilder.setSourceNodeLabel(linkPredictionTrainConfig.sourceNodeLabel());
        protoBuilder.setTargetNodeLabel(linkPredictionTrainConfig.targetNodeLabel());
        protoBuilder.setTargetRelationshipType(linkPredictionTrainConfig.targetRelationshipType());
        protoBuilder.setPipeline(linkPredictionTrainConfig.pipeline());
        protoBuilder.setGraphName(linkPredictionTrainConfig.graphName());
        protoBuilder.setNegativeClassWeight(linkPredictionTrainConfig.negativeClassWeight());

        // Evaluating the method reference already fails fast if protoBuilder were null.
        CommonConfigSerializers.serializeRandomSeedConfig(linkPredictionTrainConfig, protoBuilder::setRandomSeed);

        return protoBuilder.build();
    }

    /**
     * Rebuilds a link prediction train config from its protobuf representation.
     *
     * <p>A proto whose source node label is empty is treated as coming from an older model
     * format: both node labels become {@code "*"} and the target relationship type is derived
     * from the algo-base config's relationship types (see
     * {@link #legacyTargetRelationshipType(ProtocolStringList)}). Otherwise the labels and the
     * target relationship type are read directly from the proto.
     *
     * <p>The random seed is only set on the builder when the proto's optional seed has a value.
     *
     * @param linkPredictionPipelineTrainConfigProto the proto message to read
     * @return the reconstructed config
     */
    public static LinkPredictionTrainConfig deserialize(TrainConfigsProto.LinkPredictionPipelineTrainConfigProto linkPredictionPipelineTrainConfigProto) {
        LinkPredictionTrainConfigImpl.Builder configBuilder = LinkPredictionTrainConfigImpl.builder();

        boolean legacyModel = linkPredictionPipelineTrainConfigProto.getSourceNodeLabel().isEmpty();

        String sourceLabel = legacyModel
            ? WILDCARD
            : linkPredictionPipelineTrainConfigProto.getSourceNodeLabel();
        String targetLabel = legacyModel
            ? WILDCARD
            : linkPredictionPipelineTrainConfigProto.getTargetNodeLabel();
        String relationshipType = legacyModel
            ? legacyTargetRelationshipType(linkPredictionPipelineTrainConfigProto.getAlgoBaseConfig().getRelationshipTypesList())
            : linkPredictionPipelineTrainConfigProto.getTargetRelationshipType();

        configBuilder
            .modelName(linkPredictionPipelineTrainConfigProto.getModelConfig().getModelName())
            .modelUser(linkPredictionPipelineTrainConfigProto.getModelConfig().getUsername())
            .graphName(linkPredictionPipelineTrainConfigProto.getGraphName())
            .pipeline(linkPredictionPipelineTrainConfigProto.getPipeline())
            .targetRelationshipType(relationshipType)
            .sourceNodeLabel(sourceLabel)
            .targetNodeLabel(targetLabel)
            .sudo(linkPredictionPipelineTrainConfigProto.getAlgoBaseConfig().getSudo())
            .concurrency(Integer.valueOf(linkPredictionPipelineTrainConfigProto.getAlgoBaseConfig().getConcurrency()))
            .negativeClassWeight(linkPredictionPipelineTrainConfigProto.getNegativeClassWeight());

        CommonConfigProto.OptionalRandomSeedProto seedProto = linkPredictionPipelineTrainConfigProto.getRandomSeed();
        if (seedProto.hasValue()) {
            configBuilder.randomSeed(Long.valueOf(seedProto.getValue()));
        }

        return configBuilder.build();
    }

    /**
     * Picks the target relationship type for a legacy proto: the single listed relationship
     * type when exactly one non-wildcard type is present, otherwise a sentinel value.
     */
    private static String legacyTargetRelationshipType(ProtocolStringList relationshipTypes) {
        if (relationshipTypes.size() == 1 && !relationshipTypes.get(0).equals(WILDCARD)) {
            return relationshipTypes.get(0);
        }
        return LEGACY_INVALID_TARGET_RELATIONSHIP_TYPE;
    }
}
