package com.neo4j.gds.config;

import com.neo4j.gds.ml.model.proto.ModelTrainingConfigsProto;
import org.neo4j.gds.ml.models.mlp.MLPClassifierTrainConfig;
import org.neo4j.gds.ml.models.mlp.MLPClassifierTrainConfigImpl;

/**
 * Converts {@link MLPClassifierTrainConfig} instances to and from their protobuf representation,
 * {@link ModelTrainingConfigsProto.MLPTrainConfigProto}.
 *
 * <p>Static utility; not instantiable.
 */
public final class MLPTrainConfigSerializer {

    private MLPTrainConfigSerializer() {
        // utility class: no instances
    }

    /**
     * Builds the protobuf form of an MLP classifier training config.
     *
     * <p>Copies penalty, focus weight, class weights and hidden layer sizes directly, and encodes
     * the shared training settings through
     * {@code CommonConfigSerializers.serializeTrainingConfig}.
     *
     * @param config the training config to serialize
     * @return the populated {@code MLPTrainConfigProto}
     */
    public static ModelTrainingConfigsProto.MLPTrainConfigProto serialize(MLPClassifierTrainConfig config) {
        return ModelTrainingConfigsProto.MLPTrainConfigProto
            .newBuilder()
            .setPenalty(config.penalty())
            .setFocusWeight(config.focusWeight())
            .addAllClassWeights(config.classWeights())
            .setTrainingConfig(CommonConfigSerializers.serializeTrainingConfig(config))
            .addAllHiddenLayerSizes(config.hiddenLayerSizes())
            .build();
    }

    /**
     * Rebuilds an MLP classifier training config from its protobuf form.
     *
     * <p>Penalty, focus weight, class weights and hidden layer sizes are read from the top-level
     * message. Batch size, max/min epochs, patience and tolerance are read from the nested
     * {@code TrainingConfigProto}.
     *
     * <p>NOTE(review): only those five training fields are restored here. If
     * {@code CommonConfigSerializers.serializeTrainingConfig} writes additional fields, they are
     * not read back. Confirm against that helper.
     *
     * @param proto the protobuf message to read
     * @return a new {@link MLPClassifierTrainConfig} built via {@code MLPClassifierTrainConfigImpl}
     */
    public static MLPClassifierTrainConfig deserialize(ModelTrainingConfigsProto.MLPTrainConfigProto proto) {
        ModelTrainingConfigsProto.TrainingConfigProto common = proto.getTrainingConfig();

        return MLPClassifierTrainConfigImpl
            .builder()
            .penalty(proto.getPenalty())
            .focusWeight(proto.getFocusWeight())
            .classWeights(proto.getClassWeightsList())
            .batchSize(common.getBatchSize())
            .maxEpochs(common.getMaxEpochs())
            .minEpochs(common.getMinEpochs())
            .patience(common.getPatience())
            .tolerance(common.getTolerance())
            .hiddenLayerSizes(proto.getHiddenLayerSizesList())
            .build();
    }
}
