package org.neo4j.gds.procedures.pipelines;

import java.util.List;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.mem.MemoryEstimation;
import org.neo4j.gds.mem.MemoryEstimations;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.nodeClassification.NodeClassificationPredict;
import org.neo4j.gds.ml.pipeline.NodePropertyStepExecutor;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodePropertyPredictPipeline;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineModelInfo;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineTrainConfig;
import org.neo4j.gds.procedures.algorithms.AlgorithmsProcedureFacade;

/* loaded from: input_file:org/neo4j/gds/procedures/pipelines/NodeClassificationPredictPipelineEstimator.class */
public class NodeClassificationPredictPipelineEstimator {
    private final ModelCatalog modelCatalog;
    private final AlgorithmsProcedureFacade algorithmsProcedureFacade;

    /**
     * Creates an estimator that uses the given collaborators when estimating the
     * node-property steps of a node classification pipeline.
     *
     * @param modelCatalog              model catalog passed on to the node-property-step estimation
     * @param algorithmsProcedureFacade algorithms facade passed on to the node-property-step estimation
     */
    public NodeClassificationPredictPipelineEstimator(
        ModelCatalog modelCatalog,
        AlgorithmsProcedureFacade algorithmsProcedureFacade
    ) {
        this.algorithmsProcedureFacade = algorithmsProcedureFacade;
        this.modelCatalog = modelCatalog;
    }

    /**
     * Estimates the memory needed to run node classification prediction with a trained pipeline model.
     *
     * <p>Two estimations are built:
     * <ul>
     *   <li>the node-property steps of the model's pipeline, and</li>
     *   <li>the prediction itself, registered under the name {@code "Pipeline Predict"}.</li>
     * </ul>
     * The result is their maximum, as computed by {@link MemoryEstimations#maxEstimation}.
     *
     * @param model         trained node classification pipeline model; supplies the pipeline, the
     *                      class list, the trainer method, the feature dimension and the train config
     * @param configuration predict configuration; if its target node labels or relationship types
     *                      are empty, the corresponding values from the model's train config are used
     * @return the combined memory estimation
     */
    public MemoryEstimation estimate(
        Model<Classifier.ClassifierData, NodeClassificationPipelineTrainConfig, NodeClassificationPipelineModelInfo> model,
        NodeClassificationPredictPipelineBaseConfig configuration
    ) {
        NodePropertyPredictPipeline pipeline = model.customInfo().pipeline();
        int classCount = model.customInfo().classes().size();

        // Empty labels/types on the predict config mean "reuse what the model was trained on".
        MemoryEstimation nodePropertyStepsEstimation = NodePropertyStepExecutor.estimateNodePropertySteps(
            this.algorithmsProcedureFacade,
            this.modelCatalog,
            configuration.username(),
            pipeline.nodePropertySteps(),
            configuration.targetNodeLabels().isEmpty()
                ? model.trainConfig().targetNodeLabels()
                : configuration.targetNodeLabels(),
            configuration.relationshipTypes().isEmpty()
                ? model.trainConfig().relationshipTypes()
                : configuration.relationshipTypes()
        );

        MemoryEstimation predictEstimation = MemoryEstimations.builder()
            .add(
                "Pipeline Predict",
                NodeClassificationPredict.memoryEstimationWithDerivedBatchSize(
                    model.data().trainerMethod(),
                    configuration.includePredictedProbabilities(),
                    // NOTE(review): literal likely inlined from a constant by the decompiler
                    // (presumably a minimum batch size) — confirm against NodeClassificationPredict.
                    100,
                    model.data().featureDimension(),
                    classCount,
                    // NOTE(review): meaning of this flag is not visible here — confirm.
                    false
                )
            )
            .build();

        return MemoryEstimations.maxEstimation(List.of(nodePropertyStepsEstimation, predictEstimation));
    }
}
