package org.neo4j.gds.ml.metrics.classification;

import com.neo4j.gds.shaded.org.eclipse.collections.api.block.function.primitive.LongIntToObjectFunction;
import com.neo4j.gds.shaded.org.intellij.lang.annotations.RegExp;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.neo4j.gds.collections.LongMultiSet;
import org.neo4j.gds.mem.MemoryEstimation;
import org.neo4j.gds.mem.MemoryEstimations;
import org.neo4j.gds.mem.MemoryRange;
import org.neo4j.gds.ml.core.subgraph.LocalIdMap;
import org.neo4j.gds.ml.metrics.Metric;
import org.neo4j.gds.utils.StringFormatting;

/* loaded from: input_file:org/neo4j/gds/ml/metrics/classification/ClassificationMetricSpecification.class */
public final class ClassificationMetricSpecification {
    /** Textual form of this specification; returned by {@link #toString()}, which equals/hashCode are based on. */
    private final String stringRepresentation;
    /** Factory applied by {@code createMetrics} to the supplied {@code LocalIdMap} and {@code LongMultiSet}. */
    private final BiFunction<LocalIdMap, LongMultiSet, Stream<Metric>> metricFactory;

    /* loaded from: input_file:org/neo4j/gds/ml/metrics/classification/ClassificationMetricSpecification$Parser.class */
    public static final class Parser {

        /** Capturing group for the class selector: an optionally negative integer or a literal {@code *}. */
        @RegExp
        private static final String NUMBER_OR_STAR = "(-?[\\d]+|\\*)";

        /** Capturing group for the text preceding {@code (CLASS=...)}, i.e. the metric name. */
        @RegExp
        private static final String CLASS_NAME_PATTERN = "(.+)";

        /** Metric names that are matched by exact name only (currently the out-of-bag error). */
        private static final List<String> MODEL_SPECIFIC_METRICS = List.of(OutOfBagError.OUT_OF_BAG_ERROR.name());

        /** Factories for metrics computed for one class, keyed by metric name; invoked with a (long, int) pair. */
        private static final Map<String, LongIntToObjectFunction<ClassificationMetric>> SINGLE_CLASS_METRIC_FACTORIES = Map.of(
            F1Score.NAME, F1Score::new,
            Precision.NAME, Precision::new,
            Recall.NAME, Recall::new,
            "ACCURACY", Accuracy::new
        );

        /** Factories for metrics computed over all classes, keyed by metric name. */
        private static final Map<String, BiFunction<LocalIdMap, LongMultiSet, ClassificationMetric>> ALL_CLASS_METRIC_FACTORIES = Map.of(
            F1Weighted.NAME, F1Weighted::new,
            F1Macro.NAME, (classIdMap, classCounts) -> new F1Macro(classIdMap),
            "ACCURACY", (classIdMap, classCounts) -> new GlobalAccuracy()
        );

        /** Matches upper-cased expressions of the form {@code NAME(CLASS=<integer or *>)}, whitespace allowed inside the brackets. */
        private static final Pattern SINGLE_CLASS_METRIC_PATTERN = Pattern.compile(
            CLASS_NAME_PATTERN + "\\(\\s*CLASS\\s*=\\s*" + NUMBER_OR_STAR + "\\s*\\)"
        );

        /** Not instantiable; this class only offers static helpers. */
        private Parser() {}

        /**
         * Lists the names of the metrics that can be requested for a single class.
         *
         * @return the keys of the single-class metric factory map
         */
        public static Iterable<String> singleClassMetrics() {
            return SINGLE_CLASS_METRIC_FACTORIES.keySet();
        }

        /**
         * Lists the names of the metrics that are computed across all classes.
         *
         * @return the keys of the all-class metric factory map
         */
        public static Iterable<String> allClassMetrics() {
            return ALL_CLASS_METRIC_FACTORIES.keySet();
        }

        /**
         * Parses a list of metric specifications.
         *
         * <p>If the first element is already a {@link ClassificationMetricSpecification}, the list is
         * treated as fully parsed and returned as is. Otherwise all {@code String} elements are
         * validated up front so that every invalid expression is reported in one error, and the
         * distinct parsed specifications are returned in input order. Elements that are not strings
         * are handed to {@link #parse(Object)}, which decides whether they are acceptable.
         *
         * @param list metric expressions (case insensitive) and/or already parsed specifications
         * @return the parsed specifications without duplicates
         * @throws IllegalArgumentException if the list is empty, if the first metric contains the
         *     {@code *} wildcard, or if any expression is invalid
         */
        public static List<ClassificationMetricSpecification> parse(List<?> list) {
            if (list.isEmpty()) {
                throw new IllegalArgumentException(StringFormatting.formatWithLocale("No metrics specified, we require at least one"));
            }
            Object primaryMetric = list.get(0);
            if (primaryMetric instanceof ClassificationMetricSpecification) {
                // Only the head is inspected, hence the unchecked cast of the whole list.
                @SuppressWarnings("unchecked")
                List<ClassificationMetricSpecification> alreadyParsed = (List<ClassificationMetricSpecification>) list;
                return alreadyParsed;
            }

            List<String> errors = new ArrayList<>();
            // A wildcard expands to several metrics, so it cannot serve as the single primary metric.
            if (primaryMetric instanceof String && ((String) primaryMetric).toUpperCase(Locale.ENGLISH).contains("*")) {
                errors.add(StringFormatting.formatWithLocale("The primary (first) metric provided must be one of %s.", String.join(", ", validPrimaryMetricExpressions())));
            }

            List<String> invalidExpressions = new ArrayList<>();
            for (Object specification : list) {
                // Non-string elements are not validated here; parse(Object) below deals with them.
                if (specification instanceof String && invalidSpecification((String) specification)) {
                    invalidExpressions.add((String) specification);
                }
            }
            if (!invalidExpressions.isEmpty()) {
                errors.add(errorMessage(invalidExpressions));
            }
            if (!errors.isEmpty()) {
                throw new IllegalArgumentException(String.join(" ", errors));
            }
            return list.stream().map(Parser::parse).distinct().collect(Collectors.toList());
        }

        /**
         * Parses a single metric specification.
         *
         * <p>Accepted inputs are an already parsed {@link ClassificationMetricSpecification} (returned
         * unchanged) or a case-insensitive string: the out-of-bag error name, the name of an all-class
         * metric, or {@code NAME(class=<integer>)} / {@code NAME(class=*)} for a single-class metric.
         * The {@code *} form yields one metric per entry of the {@code LocalIdMap} passed to
         * {@code createMetrics}.
         *
         * @param obj the specification or expression to parse
         * @return the parsed specification
         * @throws IllegalArgumentException if {@code obj} is {@code null}, is neither a string nor a
         *     specification, names an unknown metric, or carries a class id that is not a valid long
         */
        public static ClassificationMetricSpecification parse(Object obj) {
            if (obj instanceof ClassificationMetricSpecification) {
                return (ClassificationMetricSpecification) obj;
            }
            if (!(obj instanceof String)) {
                // Guard against null so the caller gets the documented exception instead of an NPE.
                String actualType = obj == null ? "null" : obj.getClass().getSimpleName();
                throw new IllegalArgumentException(StringFormatting.formatWithLocale("Expected MetricSpecification or String. Got %s.", actualType));
            }
            String expression = (String) obj;
            String normalized = StringFormatting.toUpperCaseWithLocale(expression);

            if (normalized.equals(OutOfBagError.OUT_OF_BAG_ERROR.name())) {
                return ClassificationMetricSpecification.createSpecification((classIdMap, classCounts) -> {
                    return Stream.of(OutOfBagError.OUT_OF_BAG_ERROR);
                }, normalized);
            }

            Matcher matcher = SINGLE_CLASS_METRIC_PATTERN.matcher(normalized);
            if (!matcher.matches()) {
                BiFunction<LocalIdMap, LongMultiSet, ClassificationMetric> allClassFactory = ALL_CLASS_METRIC_FACTORIES.get(normalized);
                if (allClassFactory == null) {
                    throw new IllegalArgumentException(errorMessage(List.of(expression)));
                }
                return ClassificationMetricSpecification.createSpecification((classIdMap, classCounts) -> {
                    return Stream.of((Metric) allClassFactory.apply(classIdMap, classCounts));
                }, normalized);
            }

            String metricName = matcher.group(1);
            String classSelector = matcher.group(2);
            LongIntToObjectFunction<ClassificationMetric> singleClassFactory = SINGLE_CLASS_METRIC_FACTORIES.get(metricName);
            if (singleClassFactory == null) {
                throw new IllegalArgumentException(errorMessage(List.of(expression)));
            }

            Function<LocalIdMap, Stream<Metric>> metricsForClasses;
            if (classSelector.equals("*")) {
                metricsForClasses = classIdMap -> classIdMap.getMappings().map(cursor -> {
                    return (Metric) singleClassFactory.value(cursor.key, cursor.value);
                });
            } else {
                // Parse once and eagerly: an out-of-range id fails here rather than when metrics are created.
                final long classId;
                try {
                    classId = Long.parseLong(classSelector);
                } catch (NumberFormatException e) {
                    throw new IllegalArgumentException(errorMessage(List.of(expression)), e);
                }
                metricsForClasses = classIdMap -> {
                    return Stream.of((Metric) singleClassFactory.value(classId, classIdMap.toMapped(classId)));
                };
            }
            return ClassificationMetricSpecification.createSpecification(
                (classIdMap, classCounts) -> metricsForClasses.apply(classIdMap),
                StringFormatting.formatWithLocale("%s(class=%s)", metricName, classSelector)
            );
        }

        /** Returns every accepted metric expression, including the {@code (class=*)} forms. */
        private static List<String> allValidMetricExpressions() {
            final boolean includeWildcardForms = true;
            return validMetricExpressions(includeWildcardForms);
        }

        /** Returns the expressions allowed for the primary metric, i.e. without the {@code (class=*)} forms. */
        private static List<String> validPrimaryMetricExpressions() {
            final boolean includeWildcardForms = false;
            return validMetricExpressions(includeWildcardForms);
        }

        /**
         * Builds the list of metric expressions shown to users in error messages.
         *
         * <p>The list starts with the model-specific metrics, followed by the all-class metrics and
         * then, per single-class metric, the optional {@code (class=*)} form and the
         * {@code (class=<class value>)} form.
         *
         * @param z whether the {@code (class=*)} wildcard forms are included
         * @return the expressions in the order described above
         */
        private static List<String> validMetricExpressions(boolean z) {
            List<String> expressions = new ArrayList<>(MODEL_SPECIFIC_METRICS);
            expressions.addAll(ALL_CLASS_METRIC_FACTORIES.keySet());
            for (String metricName : singleClassMetrics()) {
                if (z) {
                    expressions.add(metricName + "(class=*)");
                }
                expressions.add(metricName + "(class=<class value>)");
            }
            return expressions;
        }

        /**
         * Formats the error text for one or more invalid metric expressions.
         *
         * @param list the offending expressions as entered by the user
         * @return a message quoting each expression and listing all available metrics
         */
        private static String errorMessage(List<String> list) {
            String pluralSuffix = list.size() == 1 ? "" : "s";
            String quotedExpressions = list.stream()
                .map(expression -> "`" + expression + "`")
                .collect(Collectors.joining(", "));
            String availableMetrics = String.join(", ", allValidMetricExpressions());
            return StringFormatting.formatWithLocale(
                "Invalid metric expression%s %s. Available metrics are %s (case insensitive and space allowed between brackets).",
                pluralSuffix,
                quotedExpressions,
                availableMetrics
            );
        }

        /**
         * Tells whether an expression would be rejected by {@link #parse(Object)}.
         *
         * <p>A single-class expression is only valid when the name before {@code (class=...)} is a
         * known single-class metric; checking this here lets unknown names be reported together with
         * the other invalid expressions instead of failing later, one at a time.
         *
         * @param str the user-supplied expression (case insensitive)
         * @return {@code true} if the expression matches no known metric
         */
        private static boolean invalidSpecification(String str) {
            String normalized = str.toUpperCase(Locale.ENGLISH);
            if (MODEL_SPECIFIC_METRICS.contains(normalized) || ALL_CLASS_METRIC_FACTORIES.containsKey(normalized)) {
                return false;
            }
            Matcher matcher = SINGLE_CLASS_METRIC_PATTERN.matcher(normalized);
            return !(matcher.matches() && SINGLE_CLASS_METRIC_FACTORIES.containsKey(matcher.group(1)));
        }

        /*
         * Compiler-generated (synthetic) hook that rebuilds serialized lambdas declared in this class.
         * It recognises the four constructor references used as LongIntToObjectFunction values
         * (F1Score, Precision, Recall, Accuracy) and rejects anything else.
         *
         * NOTE(review): this is decompiler output. The selector is declared `boolean` but assigned -1
         * and switched on, which is not valid Java; presumably it was an int in the bytecode.
         * A compiler normally emits this method itself, so it likely should not be kept in
         * hand-maintained source - confirm before recompiling.
         */
        private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
            String implMethodName = serializedLambda.getImplMethodName();
            boolean z = -1;
            // First stage: map the implementation method name to a selector ("<init>" = constructor).
            switch (implMethodName.hashCode()) {
                case 1818100338:
                    if (implMethodName.equals("<init>")) {
                        z = false;
                        break;
                    }
                    break;
            }
            // Second stage: match the full lambda shape; method kind 8 is MethodHandleInfo.REF_newInvokeSpecial.
            switch (z) {
                case false:
                    if (serializedLambda.getImplMethodKind() == 8 && serializedLambda.getFunctionalInterfaceClass().equals("com/neo4j/gds/shaded/org/eclipse/collections/api/block/function/primitive/LongIntToObjectFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("value") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(JI)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/neo4j/gds/ml/metrics/classification/F1Score") && serializedLambda.getImplMethodSignature().equals("(JI)V")) {
                        return F1Score::new;
                    }
                    if (serializedLambda.getImplMethodKind() == 8 && serializedLambda.getFunctionalInterfaceClass().equals("com/neo4j/gds/shaded/org/eclipse/collections/api/block/function/primitive/LongIntToObjectFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("value") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(JI)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/neo4j/gds/ml/metrics/classification/Precision") && serializedLambda.getImplMethodSignature().equals("(JI)V")) {
                        return Precision::new;
                    }
                    if (serializedLambda.getImplMethodKind() == 8 && serializedLambda.getFunctionalInterfaceClass().equals("com/neo4j/gds/shaded/org/eclipse/collections/api/block/function/primitive/LongIntToObjectFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("value") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(JI)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/neo4j/gds/ml/metrics/classification/Recall") && serializedLambda.getImplMethodSignature().equals("(JI)V")) {
                        return Recall::new;
                    }
                    if (serializedLambda.getImplMethodKind() == 8 && serializedLambda.getFunctionalInterfaceClass().equals("com/neo4j/gds/shaded/org/eclipse/collections/api/block/function/primitive/LongIntToObjectFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("value") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(JI)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/neo4j/gds/ml/metrics/classification/Accuracy") && serializedLambda.getImplMethodSignature().equals("(JI)V")) {
                        return Accuracy::new;
                    }
                    break;
            }
            // No known lambda shape matched.
            throw new IllegalArgumentException("Invalid lambda deserialization");
        }
    }

    /**
     * Creates a specification; instances are obtained through {@code createSpecification}.
     *
     * @param str the textual form reported by {@code toString()}
     * @param biFunction the factory used by {@code createMetrics}
     */
    private ClassificationMetricSpecification(String str, BiFunction<LocalIdMap, LongMultiSet, Stream<Metric>> biFunction) {
        this.metricFactory = biFunction;
        this.stringRepresentation = str;
    }

    /**
     * Static factory wrapping the private constructor (note the swapped argument order).
     *
     * @param biFunction the factory used by {@code createMetrics}
     * @param str the textual form reported by {@code toString()}
     * @return a new specification
     */
    private static ClassificationMetricSpecification createSpecification(BiFunction<LocalIdMap, LongMultiSet, Stream<Metric>> biFunction, String str) {
        ClassificationMetricSpecification specification = new ClassificationMetricSpecification(str, biFunction);
        return specification;
    }

    /**
     * Instantiates the metric(s) described by this specification.
     *
     * @param localIdMap passed to the underlying metric factory
     * @param longMultiSet passed to the underlying metric factory
     * @return the stream produced by the metric factory
     */
    public Stream<Metric> createMetrics(LocalIdMap localIdMap, LongMultiSet longMultiSet) {
        Stream<Metric> metrics = this.metricFactory.apply(localIdMap, longMultiSet);
        return metrics;
    }

    /** Returns the textual form this specification was created with. */
    @Override
    public String toString() {
        return stringRepresentation;
    }

    /**
     * Two specifications are equal when their textual forms are equal.
     *
     * @param obj the object to compare with
     * @return {@code true} if {@code obj} is a specification with the same string representation
     */
    @Override
    public boolean equals(Object obj) {
        if (!(obj instanceof ClassificationMetricSpecification)) {
            return false;
        }
        String otherRepresentation = obj.toString();
        return toString().equals(otherRepresentation);
    }

    /** Hash of the textual form, consistent with {@code equals}. */
    @Override
    public int hashCode() {
        String representation = toString();
        return representation.hashCode();
    }

    /**
     * Estimates the memory needed for the metrics as a range from 24 up to {@code i * 24}.
     *
     * <p>NOTE(review): 24 is presumably the size in bytes of one metric and {@code i} the number
     * of classes - confirm against the callers.
     *
     * @param i multiplier for the upper bound of the range
     * @return an estimation with a single "metrics" component
     */
    public static MemoryEstimation memoryEstimation(int i) {
        // Multiply in long arithmetic: the int product overflows for large values of i.
        long upperBound = i * 24L;
        return MemoryEstimations.builder().rangePerNode("metrics", nodeCount -> {
            return MemoryRange.of(24L, upperBound);
        }).build();
    }

    /**
     * Converts specifications to their textual forms, preserving order.
     *
     * @param list the specifications to render
     * @return one string per specification
     */
    public static List<String> specificationsToString(List<ClassificationMetricSpecification> list) {
        return list.stream()
            .map(ClassificationMetricSpecification::toString)
            .collect(Collectors.toList());
    }
}
