package org.neo4j.gds.embeddings.graphsage.algo;

import java.util.concurrent.ExecutorService;
import org.neo4j.gds.GraphAlgorithmFactory;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.compat.GdsVersionInfoProvider;
import org.neo4j.gds.core.concurrency.DefaultPool;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.embeddings.graphsage.GraphSageModelTrainer;
import org.neo4j.gds.embeddings.graphsage.TrainConfigTransformer;
import org.neo4j.gds.mem.MemoryEstimation;
import org.neo4j.gds.ml.core.EmbeddingUtils;
import org.neo4j.gds.termination.TerminationFlag;

/* loaded from: input_file:org/neo4j/gds/embeddings/graphsage/algo/GraphSageTrainAlgorithmFactory.class */
/**
 * Factory for {@link GraphSageTrain} algorithm instances.
 *
 * <p>Besides building the algorithm from a {@link GraphSageTrainConfig}, it
 * exposes the matching memory estimation and progress task definitions, each
 * available both from the config and from the already-transformed parameter
 * objects produced by {@link TrainConfigTransformer}.
 */
public final class GraphSageTrainAlgorithmFactory extends GraphAlgorithmFactory<GraphSageTrain, GraphSageTrainConfig> {

    /**
     * @return the simple class name of {@link GraphSageTrain}
     */
    @Override
    public String taskName() {
        return GraphSageTrain.class.getSimpleName();
    }

    /**
     * Creates the training algorithm for the given graph and configuration.
     *
     * <p>If the config declares a relationship weight property, the weight
     * values of {@code graph} are validated first via
     * {@link EmbeddingUtils#validateRelationshipWeightPropertyValue}. Afterwards
     * a {@link MultiLabelGraphSageTrain} is returned when
     * {@code graphSageTrainConfig.isMultiLabel()} is true, otherwise a
     * {@link SingleLabelGraphSageTrain}.
     *
     * @param graph                the graph to train on
     * @param graphSageTrainConfig the training configuration
     * @param progressTracker      tracker handed to the created algorithm
     * @return the single-label or multi-label training algorithm
     */
    @Override
    public GraphSageTrain build(Graph graph, GraphSageTrainConfig graphSageTrainConfig, ProgressTracker progressTracker) {
        ExecutorService pool = DefaultPool.INSTANCE;
        String version = GdsVersionInfoProvider.GDS_VERSION_INFO.gdsVersion();

        if (graphSageTrainConfig.hasRelationshipWeightProperty()) {
            EmbeddingUtils.validateRelationshipWeightPropertyValue(
                graph,
                graphSageTrainConfig.concurrency(),
                pool
            );
        }

        GraphSageTrainParameters trainParameters = TrainConfigTransformer.toParameters(graphSageTrainConfig);

        if (!graphSageTrainConfig.isMultiLabel()) {
            return new SingleLabelGraphSageTrain(
                graph,
                trainParameters,
                pool,
                progressTracker,
                TerminationFlag.RUNNING_TRUE,
                version,
                graphSageTrainConfig
            );
        }

        // NOTE(review): get() throws if the optional is empty; presumably config
        // validation guarantees a value for multi-label setups - confirm.
        int projectedFeatureDimension = graphSageTrainConfig.projectedFeatureDimension().get().intValue();
        return new MultiLabelGraphSageTrain(
            graph,
            trainParameters,
            projectedFeatureDimension,
            pool,
            progressTracker,
            TerminationFlag.RUNNING_TRUE,
            version,
            graphSageTrainConfig
        );
    }

    /**
     * Builds the memory estimation from already-derived estimate parameters.
     *
     * @param graphSageTrainMemoryEstimateParameters parameters driving the estimate
     * @return the estimation produced by {@link GraphSageTrainEstimateDefinition}
     */
    public MemoryEstimation memoryEstimation(GraphSageTrainMemoryEstimateParameters graphSageTrainMemoryEstimateParameters) {
        GraphSageTrainEstimateDefinition definition =
            new GraphSageTrainEstimateDefinition(graphSageTrainMemoryEstimateParameters);
        return definition.memoryEstimation();
    }

    /**
     * Builds the memory estimation for a configuration by first converting it
     * with {@link TrainConfigTransformer#toMemoryEstimateParameters}.
     *
     * @param graphSageTrainConfig the training configuration
     * @return the memory estimation for that configuration
     */
    @Override
    public MemoryEstimation memoryEstimation(GraphSageTrainConfig graphSageTrainConfig) {
        GraphSageTrainMemoryEstimateParameters estimateParameters =
            TrainConfigTransformer.toMemoryEstimateParameters(graphSageTrainConfig);
        return memoryEstimation(estimateParameters);
    }

    /**
     * Builds the progress task tree named after {@link #taskName()}, with the
     * sub-tasks supplied by {@link GraphSageModelTrainer#progressTasks}.
     *
     * @param j                        value passed to {@code numberOfBatches} and
     *                                 {@code batchesPerIteration}; the config-based
     *                                 overload supplies the graph's node count here
     * @param graphSageTrainParameters the training parameters
     * @return the root progress task
     */
    public Task progressTask(long j, GraphSageTrainParameters graphSageTrainParameters) {
        return Tasks.task(
            taskName(),
            GraphSageModelTrainer.progressTasks(
                graphSageTrainParameters.numberOfBatches(j),
                graphSageTrainParameters.batchesPerIteration(j),
                graphSageTrainParameters.maxIterations(),
                graphSageTrainParameters.epochs()
            )
        );
    }

    /**
     * Builds the progress task tree for a graph and configuration, using the
     * graph's node count and the parameters derived from the config.
     *
     * @param graph                the graph whose node count sizes the task
     * @param graphSageTrainConfig the training configuration
     * @return the root progress task
     */
    @Override
    public Task progressTask(Graph graph, GraphSageTrainConfig graphSageTrainConfig) {
        long nodeCount = graph.nodeCount();
        GraphSageTrainParameters trainParameters = TrainConfigTransformer.toParameters(graphSageTrainConfig);
        return progressTask(nodeCount, trainParameters);
    }
}
