package org.neo4j.gds.ml.models.automl;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.neo4j.gds.ml.api.TrainingMethod;
import org.neo4j.gds.ml.models.TrainerConfig;
import org.neo4j.gds.ml.models.automl.ParameterParser;
import org.neo4j.gds.ml.models.automl.hyperparameter.ConcreteParameter;
import org.neo4j.gds.ml.models.automl.hyperparameter.DoubleRangeParameter;
import org.neo4j.gds.ml.models.automl.hyperparameter.IntegerRangeParameter;
import org.neo4j.gds.ml.models.automl.hyperparameter.NumericalRangeParameter;
import org.neo4j.gds.ml.models.linearregression.LinearRegressionTrainConfig;
import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionTrainConfig;
import org.neo4j.gds.ml.models.mlp.MLPClassifierTrainConfig;
import org.neo4j.gds.ml.models.randomforest.RandomForestClassifierTrainerConfig;
import org.neo4j.gds.ml.models.randomforest.RandomForestRegressorTrainerConfig;
import org.neo4j.gds.utils.StringFormatting;

/**
 * Trainer configuration whose hyperparameters are either concrete values or numerical ranges
 * (double or integer) that are still to be tuned.
 *
 * <p>Concrete configurations ({@link TrainerConfig}) are obtained via {@link #materialize(Map)},
 * which overlays chosen range values on top of the concrete parameters.
 */
public final class TunableTrainerConfig {
    /** Inward shift applied to double-range endpoints when generating corner-case configs. */
    static final double EPSILON = 1.0E-8d;
    // NOTE(review): not referenced in this class; presumably consumed by the parameter parser /
    // sampler to decide which parameters are sampled on a log scale — confirm at call sites.
    static final List<String> LOG_SCALE_PARAMETERS = List.of("penalty", "learningRate", "tolerance");
    // Raw `Class` kept deliberately: this package-visible constant may be consumed elsewhere with
    // the raw type, so narrowing it to Class<?> could break callers outside this file.
    static final Map<String, Class> NON_NUMERIC_PARAMETERS = Map.of("criterion", String.class, "hiddenLayerSizes", List.class, "classWeights", List.class);

    private final Map<String, ConcreteParameter<?>> concreteParameters;
    public final Map<String, DoubleRangeParameter> doubleRanges;
    public final Map<String, IntegerRangeParameter> integerRanges;
    private final TrainingMethod method;

    private TunableTrainerConfig(
        Map<String, ConcreteParameter<?>> concreteParameters,
        Map<String, DoubleRangeParameter> doubleRanges,
        Map<String, IntegerRangeParameter> integerRanges,
        TrainingMethod method
    ) {
        this.concreteParameters = concreteParameters;
        this.doubleRanges = doubleRanges;
        this.integerRanges = integerRanges;
        this.method = method;
    }

    /**
     * Parses user input into a tunable config for the given training method.
     *
     * <p>Range parameters are parsed from the raw input; concrete parameters are parsed from the
     * input merged over the method's default config (built from an empty map). Afterwards every
     * corner-case config is materialized once.
     *
     * @param map            user-supplied parameter map
     * @param trainingMethod method whose trainer config is being described
     * @return the parsed tunable config
     * @throws IllegalStateException if {@code trainingMethod} has no trainer config implementation
     * @throws IllegalArgumentException if more than 20 range parameters are given
     */
    public static TunableTrainerConfig of(Map<String, Object> map, TrainingMethod trainingMethod) {
        ParameterParser.RangeParameters rangeParameters = ParameterParser.parseRangeParameters(map);
        Map<String, Object> defaults = createTrainerConfigFromMap(Map.of(), trainingMethod).toMap();
        TunableTrainerConfig config = new TunableTrainerConfig(
            ParameterParser.parseConcreteParameters(fillDefaults(map, defaults)),
            rangeParameters.doubleRanges(),
            rangeParameters.integerRanges(),
            trainingMethod
        );
        // Materialize every corner case purely for its side effect of running the trainer-config
        // factories on the extreme values (presumably so invalid ranges fail fast — confirm).
        // forEach is required: a terminal op like count() may skip the mapping step entirely.
        config.streamCornerCaseConfigs().forEach(ignored -> { });
        return config;
    }

    /**
     * Merges {@code userInput} over {@code defaults}: every key of either map is kept, user values
     * win, and the {@code "methodName"} key is dropped.
     */
    private static Map<String, Object> fillDefaults(Map<String, Object> userInput, Map<String, Object> defaults) {
        return Stream.concat(defaults.keySet().stream(), userInput.keySet().stream())
            .distinct() // Collectors.toMap throws on duplicate keys
            .filter(key -> !key.equals("methodName"))
            .collect(Collectors.toMap(
                key -> key,
                key -> userInput.getOrDefault(key, defaults.get(key))
            ));
    }

    /**
     * Builds a concrete trainer config from the concrete parameters of this config, overridden by
     * the given values (typically chosen values for the range parameters).
     *
     * @param map parameter values that take precedence over the concrete parameters
     * @return the trainer config for this config's training method
     */
    public TrainerConfig materialize(Map<String, Object> map) {
        Map<String, Object> values = new HashMap<>();
        this.concreteParameters.forEach((name, parameter) -> values.put(name, parameter.value()));
        values.putAll(map);
        return createTrainerConfigFromMap(values, this.method);
    }

    /**
     * Streams one materialized config per combination of range endpoints: each range parameter
     * takes either its min or its max, giving {@code 2^k} configs for {@code k} ranges.
     * Double-range endpoints are shifted inward by {@link #EPSILON}.
     *
     * @throws IllegalArgumentException if there are more than 20 range parameters
     */
    public Stream<TrainerConfig> streamCornerCaseConfigs() {
        Map<String, NumericalRangeParameter<?>> rangeParameters = new HashMap<>();
        rangeParameters.putAll(this.doubleRanges);
        rangeParameters.putAll(this.integerRanges);
        int numberOfRanges = rangeParameters.size();
        if (numberOfRanges > 20) {
            throw new IllegalArgumentException("Currently at most 20 hyperparameters are supported");
        }
        // Bit j of `bitmask` selects the endpoint of the j-th range (in this map's iteration order,
        // which is stable because the same map instance is reused for every bitmask).
        return IntStream.range(0, 1 << numberOfRanges).mapToObj(bitmask -> {
            Map<String, Object> endpoints = new HashMap<>();
            int bit = 0;
            for (Map.Entry<String, NumericalRangeParameter<?>> entry : rangeParameters.entrySet()) {
                String name = entry.getKey();
                boolean useMin = ((bitmask >> bit) & 1) == 0;
                endpoints.put(name, endPoint(useMin, this.doubleRanges.containsKey(name), entry.getValue()));
                bit++;
            }
            return materialize(endpoints);
        });
    }

    /**
     * Returns the min or max of {@code range}; for double ranges the value is moved inward by
     * {@link #EPSILON} (min + EPSILON, max - EPSILON).
     */
    private static Number endPoint(boolean useMin, boolean isDoubleRange, NumericalRangeParameter<?> range) {
        if (isDoubleRange) {
            return useMin
                ? (Double) range.min() + EPSILON
                : (Double) range.max() - EPSILON;
        }
        return useMin ? range.min() : range.max();
    }

    /**
     * Serializes this config: concrete parameters as their values, ranges via their own
     * {@code toMap()}, plus a {@code "methodName"} entry.
     */
    public Map<String, Object> toMap() {
        Map<String, Object> result = new HashMap<>();
        this.concreteParameters.forEach((name, parameter) -> result.put(name, parameter.value()));
        this.doubleRanges.forEach((name, range) -> result.put(name, range.toMap()));
        this.integerRanges.forEach((name, range) -> result.put(name, range.toMap()));
        result.put("methodName", trainingMethod().toString());
        return result;
    }

    /** Returns the training method this config belongs to. */
    public TrainingMethod trainingMethod() {
        return this.method;
    }

    /** Returns {@code true} if this config has no double or integer range parameters. */
    public boolean isConcrete() {
        return this.doubleRanges.isEmpty() && this.integerRanges.isEmpty();
    }

    /**
     * Two configs are equal when method, concrete parameters and both range maps are equal.
     * NOTE(review): assumes the range parameter types implement value-based equals/hashCode.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        TunableTrainerConfig other = (TunableTrainerConfig) obj;
        return this.method == other.method
            && Objects.equals(this.concreteParameters, other.concreteParameters)
            && Objects.equals(this.doubleRanges, other.doubleRanges)
            && Objects.equals(this.integerRanges, other.integerRanges);
    }

    @Override
    public int hashCode() {
        return Objects.hash(this.concreteParameters, this.doubleRanges, this.integerRanges, this.method);
    }

    /**
     * Dispatches to the trainer-config factory of {@code trainingMethod}.
     *
     * @throws IllegalStateException if the method has no trainer config implementation
     */
    private static TrainerConfig createTrainerConfigFromMap(Map<String, Object> map, TrainingMethod trainingMethod) {
        switch (trainingMethod) {
            case LogisticRegression:
                return LogisticRegressionTrainConfig.of(map);
            case RandomForestClassification:
                return RandomForestClassifierTrainerConfig.of(map);
            case MLPClassification:
                return MLPClassifierTrainConfig.of(map);
            case LinearRegression:
                return LinearRegressionTrainConfig.of(map);
            case RandomForestRegression:
                return RandomForestRegressorTrainerConfig.of(map);
            default:
                throw new IllegalStateException(StringFormatting.formatWithLocale("Method %s does not have a trainerConfig Implemented", trainingMethod.name()));
        }
    }
}
