package org.neo4j.gds.procedures.pipelines;

import java.util.stream.Stream;
import org.neo4j.gds.Orientation;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.applications.algorithms.machinery.MutateStep;
import org.neo4j.gds.applications.algorithms.metadata.RelationshipsWritten;
import org.neo4j.gds.core.Aggregation;
import org.neo4j.gds.core.concurrency.DefaultPool;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.loading.SingleTypeRelationships;
import org.neo4j.gds.core.loading.construction.GraphFactory;
import org.neo4j.gds.core.loading.construction.RelationshipsBuilder;
import org.neo4j.gds.logging.Log;
import org.neo4j.gds.ml.linkmodels.LinkPredictionResult;
import org.neo4j.gds.ml.linkmodels.PredictedLink;
import org.neo4j.gds.termination.TerminationFlag;

/**
 * Mutate step of the link prediction pipeline.
 *
 * <p>Writes the predicted links back into the {@link GraphStore} as a new, undirected
 * relationship type named by {@code configuration.mutateRelationshipType()}. Each link's
 * probability is stored under the property {@code configuration.mutateProperty()}. A
 * histogram of the probabilities can optionally be collected along the way.
 */
class LinkPredictionPipelineMutateStep implements MutateStep<LinkPredictionResult, LinkPredictionMutateMetadata> {
    private final Log log;
    private final LinkPredictionPredictPipelineMutateConfig configuration;
    private final TerminationFlag terminationFlag;
    private final TrainedLPPipelineModel trainedLPPipelineModel;
    private final boolean shouldProduceHistogram;

    /**
     * Creates the step.
     *
     * <p>NOTE: the decompiler reported that this constructor was originally package-private.
     *
     * @param log                    log passed on to the graph-store filter factory
     * @param configuration          mutate configuration (model name, user, target type/property, concurrency)
     * @param terminationFlag        flag checked while predicted links are consumed
     * @param trainedLPPipelineModel lookup for the trained pipeline model
     * @param shouldProduceHistogram whether probabilities are recorded in a real histogram
     */
    public LinkPredictionPipelineMutateStep(
        Log log,
        LinkPredictionPredictPipelineMutateConfig configuration,
        TerminationFlag terminationFlag,
        TrainedLPPipelineModel trainedLPPipelineModel,
        boolean shouldProduceHistogram
    ) {
        this.log = log;
        this.configuration = configuration;
        this.terminationFlag = terminationFlag;
        this.trainedLPPipelineModel = trainedLPPipelineModel;
        this.shouldProduceHistogram = shouldProduceHistogram;
    }

    /**
     * Adds the predicted links to {@code graphStore} as a new relationship type.
     *
     * <p>The {@code graph} argument is not read here. Node ids are resolved against the
     * graph derived from the trained model's predict-node-label filter instead.
     *
     * @return the number of relationships in the new topology, plus the finalised histogram
     */
    @Override // org.neo4j.gds.applications.algorithms.machinery.MutateStep
    public LinkPredictionMutateMetadata execute(Graph graph, GraphStore graphStore, LinkPredictionResult linkPredictionResult) {
        Graph predictionGraph = predictionGraph(graphStore);
        RelationshipsBuilder relationshipsBuilder = newRelationshipsBuilder(predictionGraph);

        Stream<PredictedLink> predictedLinks = linkPredictionResult.stream();
        GdsHistogram histogram = readyHistogram(shouldProduceHistogram);

        // NOTE(review): the builder and the histogram are both shared by all consumer
        // partitions. Their thread-safety is not visible from this file; confirm it.
        ParallelUtil.parallelStreamConsume(
            predictedLinks,
            configuration.concurrency(),
            terminationFlag,
            partition -> partition.forEach(link -> {
                long rootSource = predictionGraph.toRootNodeId(link.sourceId());
                long rootTarget = predictionGraph.toRootNodeId(link.targetId());
                relationshipsBuilder.addFromInternal(rootSource, rootTarget, link.probability());
                histogram.onPredictedLink(link.probability());
            })
        );

        SingleTypeRelationships relationships = relationshipsBuilder.build();
        graphStore.addRelationshipType(relationships);

        return new LinkPredictionMutateMetadata(
            new RelationshipsWritten(relationships.topology().elementCount()),
            histogram.finalise()
        );
    }

    /** Resolves the sub-graph selected by the trained model's predict-node-label filter. */
    private Graph predictionGraph(GraphStore graphStore) {
        var trainConfig = trainedLPPipelineModel
            .get(configuration.modelName(), configuration.username())
            .trainConfig();
        var storeFilter = LPGraphStoreFilterFactory.generate(log, trainConfig, configuration, graphStore);
        return graphStore.getGraph(storeFilter.predictNodeLabels());
    }

    /**
     * Configures an undirected, single-property relationships builder over {@code nodes}.
     * The builder uses {@link Aggregation#SINGLE}.
     */
    private RelationshipsBuilder newRelationshipsBuilder(Graph nodes) {
        return GraphFactory.initRelationshipsBuilder()
            .aggregation(Aggregation.SINGLE)
            .nodes(nodes)
            .relationshipType(RelationshipType.of(configuration.mutateRelationshipType()))
            .orientation(Orientation.UNDIRECTED)
            .addPropertyConfig(GraphFactory.PropertyConfig.of(configuration.mutateProperty()))
            .concurrency(configuration.concurrency())
            .executorService(DefaultPool.INSTANCE)
            .build();
    }

    /** Returns an HDR-backed histogram when enabled, otherwise the no-op {@code DISABLED} instance. */
    private GdsHistogram readyHistogram(boolean enabled) {
        if (enabled) {
            return new HdrBackedGdsHistogram();
        }
        return GdsHistogram.DISABLED;
    }
}
