package org.neo4j.gds.procedures.pipelines;

import org.neo4j.gds.api.Graph;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkFeatureExtractor;
import org.neo4j.gds.similarity.knn.NeighborFilter;
import org.neo4j.gds.similarity.knn.NeighborFilterFactory;
import org.neo4j.gds.similarity.knn.metrics.SimilarityComputer;

/* loaded from: input_file:org/neo4j/gds/procedures/pipelines/LinkPredictionSimilarityComputer.class */
/**
 * A {@link SimilarityComputer} whose similarity score for a pair of nodes is the
 * classifier's predicted probability at {@link #POSITIVE_CLASS_INDEX} for the
 * link features extracted from that pair.
 */
public class LinkPredictionSimilarityComputer implements SimilarityComputer {
    /** Index into the classifier's probability array that is reported as the similarity. */
    private static final int POSITIVE_CLASS_INDEX = 1;

    private final LinkFeatureExtractor linkFeatureExtractor;
    private final Classifier classifier;

    /**
     * Creates a computer that scores node pairs with the given extractor and classifier.
     *
     * @param linkFeatureExtractor produces the feature vector for a node pair
     * @param classifier           turns a feature vector into class probabilities
     */
    public LinkPredictionSimilarityComputer(LinkFeatureExtractor linkFeatureExtractor, Classifier classifier) {
        this.linkFeatureExtractor = linkFeatureExtractor;
        this.classifier = classifier;
    }

    /**
     * Extracts the link features for the pair and returns the predicted probability
     * stored at {@link #POSITIVE_CLASS_INDEX}.
     *
     * @param firstNodeId  first node of the pair, passed first to the feature extractor
     * @param secondNodeId second node of the pair, passed second to the feature extractor
     * @return the classifier probability at {@link #POSITIVE_CLASS_INDEX}
     */
    @Override
    public double similarity(long firstNodeId, long secondNodeId) {
        var features = linkFeatureExtractor.extractFeatures(firstNodeId, secondNodeId);
        var probabilities = classifier.predictProbabilities(features);
        return probabilities[POSITIVE_CLASS_INDEX];
    }

    /**
     * @return whatever the underlying feature extractor reports for symmetry
     */
    @Override
    public boolean isSymmetric() {
        return linkFeatureExtractor.isSymmetric();
    }

    /**
     * {@link NeighborFilter} that rejects self pairs, pairs that do not satisfy the
     * source/target node filters in either orientation, and pairs for which the
     * graph already reports a relationship.
     */
    public static final class LinkFilter implements NeighborFilter {
        private final LPNodeFilter sourceNodeFilter;
        private final LPNodeFilter targetNodeFilter;
        private final Graph graph;

        // Private: instances are handed out by LinkFilterFactory#create().
        private LinkFilter(Graph graph, LPNodeFilter sourceNodeFilter, LPNodeFilter targetNodeFilter) {
            this.graph = graph;
            this.sourceNodeFilter = sourceNodeFilter;
            this.targetNodeFilter = targetNodeFilter;
        }

        /**
         * Decides whether a node pair must be skipped.
         *
         * @param firstNodeId  one node of the pair
         * @param secondNodeId the other node of the pair
         * @return {@code true} if both ids are equal, if the pair passes the
         *         source/target filters in neither orientation, or if
         *         {@code graph.exists(firstNodeId, secondNodeId)} is {@code true}
         */
        @Override
        public boolean excludeNodePair(long firstNodeId, long secondNodeId) {
            if (firstNodeId == secondNodeId) {
                return true;
            }

            // Try (first -> second) before (second -> first); the second orientation
            // is only evaluated when the first one does not match.
            boolean matchesFilters = passesFilters(firstNodeId, secondNodeId)
                || passesFilters(secondNodeId, firstNodeId);
            if (!matchesFilters) {
                return true;
            }

            // The graph is only consulted for pairs that passed the filters.
            return graph.exists(firstNodeId, secondNodeId);
        }

        /**
         * Computes a non-negative bound on the number of neighbours still available
         * to the given node: the valid node count of the opposite-side filter, minus
         * one, minus the node's degree, floored at zero.
         *
         * @param nodeId the node to compute the bound for
         * @return the bound described above, never negative
         */
        @Override
        public long lowerBoundOfPotentialNeighbours(long nodeId) {
            // A node accepted by the source filter is counted against the target
            // filter's valid nodes; every other node against the source filter's.
            LPNodeFilter oppositeFilter = sourceNodeFilter.test(nodeId) ? targetNodeFilter : sourceNodeFilter;
            long remaining = (oppositeFilter.validNodeCount() - 1) - graph.degree(nodeId);
            return Math.max(remaining, 0L);
        }

        /** Checks one orientation: {@code source} against the source filter, {@code target} against the target filter. */
        private boolean passesFilters(long source, long target) {
            return sourceNodeFilter.test(source) && targetNodeFilter.test(target);
        }
    }

    /**
     * Factory producing {@link LinkFilter} instances that share the same node
     * filters; each created filter receives its own {@code graph.concurrentCopy()}.
     */
    public static class LinkFilterFactory implements NeighborFilterFactory {
        private final Graph graph;
        private final LPNodeFilter sourceNodeFilter;
        private final LPNodeFilter targetNodeFilter;

        /**
         * @param graph            graph whose concurrent copies back the created filters
         * @param sourceNodeFilter filter applied to the source side of a candidate pair
         * @param targetNodeFilter filter applied to the target side of a candidate pair
         */
        public LinkFilterFactory(Graph graph, LPNodeFilter sourceNodeFilter, LPNodeFilter targetNodeFilter) {
            this.graph = graph;
            this.sourceNodeFilter = sourceNodeFilter;
            this.targetNodeFilter = targetNodeFilter;
        }

        /**
         * @return a new {@link LinkFilter} backed by a fresh concurrent copy of the graph
         */
        @Override
        public NeighborFilter create() {
            Graph graphCopy = graph.concurrentCopy();
            return new LinkFilter(graphCopy, sourceNodeFilter, targetNodeFilter);
        }
    }
}
