package org.neo4j.gds.procedures.pipelines;

import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.neo4j.gds.NodeLabel;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.api.schema.Direction;
import org.neo4j.gds.config.ElementTypeValidator;
import org.neo4j.gds.logging.Log;
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainConfig;
import org.neo4j.gds.utils.StringFormatting;
import org.neo4j.gds.utils.StringJoining;

/* loaded from: input_file:org/neo4j/gds/procedures/pipelines/LPGraphStoreFilterFactory.class */
/**
 * Static factory that builds the {@link LPGraphStoreFilter} used when predicting with a
 * link prediction pipeline.
 *
 * <p>Values given in the predict config take precedence; the train config is only
 * consulted for values the predict config leaves unset.
 */
public final class LPGraphStoreFilterFactory {
    private LPGraphStoreFilterFactory() {
        // static utility class, not meant to be instantiated
    }

    /**
     * Resolves the source labels, target labels and relationship types for prediction and
     * bundles them into an {@link LPGraphStoreFilter}.
     *
     * <ul>
     *   <li>Source/target node labels: taken from the predict config when present, otherwise
     *       resolved and validated from the train config.</li>
     *   <li>Relationship types: taken from the predict config when it lists any, otherwise
     *       the train config's target relationship type is used. Either way they are
     *       resolved and validated against the graph store.</li>
     *   <li>Node-property-step base labels: the union of the target and source labels.</li>
     * </ul>
     *
     * <p>The resulting filter is logged at info level before it is returned.
     *
     * @param log                                    logger that receives the resulting filter
     * @param linkPredictionTrainConfig              train config of the model, used as fallback
     * @param linkPredictionPredictPipelineBaseConfig predict config, whose values win when set
     * @param graphStore                             graph store the names are resolved against
     * @return the filter describing which labels and relationship types prediction uses
     * @throws IllegalArgumentException if any selected relationship type is not UNDIRECTED
     */
    public static LPGraphStoreFilter generate(Log log, LinkPredictionTrainConfig linkPredictionTrainConfig, LinkPredictionPredictPipelineBaseConfig linkPredictionPredictPipelineBaseConfig, GraphStore graphStore) {
        // orElseGet (not orElse): the train-config label must only be resolved and validated
        // when the predict config does not override it. With orElse the fallback is evaluated
        // unconditionally, so its validation runs even when the result is thrown away.
        Collection<NodeLabel> sourceNodeLabels = linkPredictionPredictPipelineBaseConfig.sourceNodeLabel()
            .map(label -> ElementTypeValidator.resolve(graphStore, List.of(label)))
            .orElseGet(() -> ElementTypeValidator.resolveAndValidate(
                graphStore,
                List.of(linkPredictionTrainConfig.sourceNodeLabel()),
                "`sourceNodeLabel` from the model's train config"
            ));
        Collection<NodeLabel> targetNodeLabels = linkPredictionPredictPipelineBaseConfig.targetNodeLabel()
            .map(label -> ElementTypeValidator.resolve(graphStore, List.of(label)))
            .orElseGet(() -> ElementTypeValidator.resolveAndValidate(
                graphStore,
                List.of(linkPredictionTrainConfig.targetNodeLabel()),
                "`targetNodeLabel` from the model's train config"
            ));

        Collection<RelationshipType> predictRelationshipTypes;
        if (!linkPredictionPredictPipelineBaseConfig.relationshipTypes().isEmpty()) {
            predictRelationshipTypes = ElementTypeValidator.resolveAndValidateTypes(
                graphStore,
                linkPredictionPredictPipelineBaseConfig.relationshipTypes(),
                "`relationshipTypes` from the model's predict config"
            );
        } else {
            predictRelationshipTypes = ElementTypeValidator.resolveAndValidateTypes(
                graphStore,
                List.of(linkPredictionTrainConfig.targetRelationshipType()),
                "`targetRelationshipType` from the model's train config"
            );
        }
        validateGraphFilter(graphStore, predictRelationshipTypes);

        // Union of target and source labels; the Set removes labels present in both.
        Set<NodeLabel> nodePropertyStepsBaseLabels = Stream
            .concat(targetNodeLabels.stream(), sourceNodeLabels.stream())
            .collect(Collectors.toSet());

        LPGraphStoreFilter filter = ImmutableLPGraphStoreFilter.builder()
            .sourceNodeLabels(sourceNodeLabels)
            .targetNodeLabels(targetNodeLabels)
            .nodePropertyStepsBaseLabels(nodePropertyStepsBaseLabels)
            .predictRelationshipTypes(predictRelationshipTypes)
            .build();
        log.info(StringFormatting.formatWithLocale("The graph filters used for filtering in prediction is %s", filter));
        return filter;
    }

    /**
     * Verifies that every given relationship type is UNDIRECTED in the graph store's schema.
     *
     * @param graphStore graph store whose relationship schema is inspected
     * @param collection relationship types selected for prediction
     * @throws IllegalArgumentException if at least one of the types has a direction other
     *                                  than {@link Direction#UNDIRECTED}; the message names
     *                                  all requested types and the offending ones
     */
    static void validateGraphFilter(GraphStore graphStore, Collection<RelationshipType> collection) {
        List<String> directedTypeNames = graphStore.schema()
            .filterRelationshipTypes(new HashSet<>(collection))
            .relationshipSchema()
            .directions()
            .entrySet()
            .stream()
            .filter(entry -> entry.getValue() != Direction.UNDIRECTED)
            .map(entry -> entry.getKey().name())
            .collect(Collectors.toList());
        if (!directedTypeNames.isEmpty()) {
            throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                "Procedure requires all relationships of %s to be UNDIRECTED, but found %s to be directed.",
                StringJoining.join(collection.stream().map(RelationshipType::name)),
                StringJoining.join(directedTypeNames)
            ));
        }
    }
}
