package org.neo4j.gds.procedures.pipelines;

import com.neo4j.gds.arrow.core.Constants;
import java.util.Map;
import org.neo4j.gds.api.User;
import org.neo4j.gds.config.AlgoBaseConfig;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.ml.models.BaseModelData;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineTrainConfig;
import org.neo4j.gds.model.ModelConfig;

/* loaded from: input_file:org/neo4j/gds/procedures/pipelines/NodeClassificationPredictConfigPreProcessor.class */
/**
 * Pre-processes the raw configuration map of a node classification predict call by copying
 * selected parameters from the training configuration of the referenced pipeline model.
 */
final class NodeClassificationPredictConfigPreProcessor {
    private final ModelCatalog modelCatalog;
    private final User user;

    /**
     * @param modelCatalog catalog used to look up the trained pipeline model
     * @param user         user whose username scopes the model lookup in {@code modelCatalog}
     */
    NodeClassificationPredictConfigPreProcessor(ModelCatalog modelCatalog, User user) {
        this.modelCatalog = modelCatalog;
        this.user = user;
    }

    /**
     * Adds the target node labels and relationship types from the model's training configuration
     * to {@code configuration}, unless the caller already supplied them.
     *
     * <p>Does nothing when the map has no {@link ModelConfig#MODEL_NAME_KEY} entry. Existing entries
     * are never overwritten ({@link Map#putIfAbsent}), so explicit user input takes precedence over
     * the values the model was trained with.
     *
     * @param configuration raw procedure configuration; it is modified in place, so it must be mutable
     * @throws ClassCastException if the model-name entry is present but is not a {@code String}
     */
    void enhanceInputWithPipelineParameters(Map<String, Object> configuration) {
        if (!configuration.containsKey(ModelConfig.MODEL_NAME_KEY)) {
            return;
        }

        String modelName = (String) configuration.get(ModelConfig.MODEL_NAME_KEY);
        // NOTE(review): what happens for an unknown model name is decided by ModelCatalog.get,
        // which is not visible here.
        Model<BaseModelData, NodeClassificationPipelineTrainConfig, Model.CustomInfo> model = this.modelCatalog.get(
            this.user.getUsername(),
            modelName,
            BaseModelData.class,
            NodeClassificationPipelineTrainConfig.class,
            Model.CustomInfo.class
        );
        // Read the training config once; the typed Model makes the former per-use casts unnecessary.
        NodeClassificationPipelineTrainConfig trainConfig = model.trainConfig();

        // NOTE(review): this key constant comes from the Arrow module rather than a config class;
        // confirm its value matches the predict config's target-node-labels key.
        configuration.putIfAbsent(Constants.TARGET_NODE_LABEL_FIELD, trainConfig.targetNodeLabels());
        configuration.putIfAbsent(AlgoBaseConfig.RELATIONSHIP_TYPES_KEY, trainConfig.relationshipTypes());
    }
}
