package apoc.ml;

import apoc.ApocConfig;
import apoc.Extended;
import apoc.ExtendedApocConfig;
import apoc.ml.OpenAIRequestHandler;
import apoc.result.MapResult;
import apoc.util.ExtendedMapUtils;
import apoc.util.ExtendedUtil;
import apoc.util.JsonUtil;
import apoc.util.Util;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.unboundid.ldap.sdk.Version;
import java.net.MalformedURLException;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.lang3.StringUtils;
import org.jsoup.helper.HttpConnection;
import org.neo4j.graphdb.security.URLAccessChecker;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

@Extended
public class OpenAI {
    public static final String API_TYPE_CONF_KEY = "apiType";
    public static final String APIKEY_CONF_KEY = "apiKey";
    public static final String JSON_PATH_CONF_KEY = "jsonPath";
    public static final String PATH_CONF_KEY = "path";
    public static final String GPT_4O_MODEL = "gpt-4o";
    public static final String FAIL_ON_ERROR_CONF = "failOnError";
    public static final String ENABLE_BACK_OFF_RETRIES_CONF_KEY = "enableBackOffRetries";
    public static final String ENABLE_EXPONENTIAL_BACK_OFF_CONF_KEY = "exponentialBackoff";
    public static final String BACK_OFF_RETRIES_CONF_KEY = "backOffRetries";
    public static final String APOC_ML_OPENAI_URL = "apoc.ml.openai.url";

    /**
     * Configuration keys that control how a request is sent (connection, authentication, retry and
     * result-extraction options) rather than what is sent. The JSON payload starts as a copy of the
     * procedure configuration, so these keys are removed from it before serialization.
     */
    private static final List<String> NON_PAYLOAD_KEYS = List.of(
            "endpoint",
            API_TYPE_CONF_KEY,
            MLUtil.API_VERSION_CONF_KEY,
            APIKEY_CONF_KEY,
            FAIL_ON_ERROR_CONF,
            ENABLE_BACK_OFF_RETRIES_CONF_KEY,
            ENABLE_EXPONENTIAL_BACK_OFF_CONF_KEY,
            BACK_OFF_RETRIES_CONF_KEY,
            PATH_CONF_KEY,
            JSON_PATH_CONF_KEY);

    @Context
    public ApocConfig apocConfig;

    @Context
    public URLAccessChecker urlAccessChecker;

    /** One row of {@code apoc.ml.openai.embedding}: input position, input text and its embedding. */
    public static class EmbeddingResult {
        public final long index;
        public final String text;
        public final List<Double> embedding;

        public EmbeddingResult(long index, String text, List<Double> embedding) {
            this.index = index;
            this.text = text;
            this.embedding = embedding;
        }
    }

    /**
     * Builds a JSON request for an OpenAI-compatible provider, sends it and streams the values
     * selected by the configured JSON path.
     *
     * <p>The caller's {@code configuration} map is never modified. Provider defaults (path, JSON
     * path, model) are applied to a private copy, so a map reused across calls does not carry
     * state from one request into the next.
     *
     * @param apiKey           fallback API key, used when neither {@code configuration.apiKey} nor
     *                         the APOC OpenAI key setting provides one
     * @param configuration    procedure configuration; keys not listed in {@link #NON_PAYLOAD_KEYS}
     *                         are sent as part of the JSON payload
     * @param path             default request path (e.g. {@code "completions"})
     * @param model            default model put in the payload for the default provider branch
     * @param key              payload key under which {@code inputs} is stored
     * @param inputs           the prompt / messages / texts to send
     * @param jsonPath         default JSON path used to extract the result
     * @param apocConfig       APOC configuration used for key, type and URL lookups
     * @param urlAccessChecker checker forwarded to the HTTP loader
     * @return the extracted response values
     * @throws IllegalArgumentException if no non-blank API key can be resolved
     */
    public static Stream<Object> executeRequest(String apiKey, Map<String, Object> configuration, String path, String model, String key, Object inputs, String jsonPath, ApocConfig apocConfig, URLAccessChecker urlAccessChecker) throws JsonProcessingException, MalformedURLException {
        // Private, mutable copy: provider defaults must not leak into (or fail on) the caller's map
        Map<String, Object> conf = new HashMap<>(configuration);

        String resolvedApiKey = (String) conf.getOrDefault(APIKEY_CONF_KEY, apocConfig.getString(ExtendedApocConfig.APOC_OPENAI_KEY, apiKey));
        boolean enableBackOffRetries = Util.toBoolean(conf.get(ENABLE_BACK_OFF_RETRIES_CONF_KEY));
        Integer backOffRetries = Util.toInteger(conf.getOrDefault(BACK_OFF_RETRIES_CONF_KEY, 5));
        boolean exponentialBackOff = Util.toBoolean(conf.get(ENABLE_EXPONENTIAL_BACK_OFF_CONF_KEY));
        if (resolvedApiKey == null || resolvedApiKey.isBlank()) {
            throw new IllegalArgumentException("API Key must not be empty");
        }

        String apiTypeName = (String) conf.getOrDefault(API_TYPE_CONF_KEY,
                apocConfig.getString(ExtendedApocConfig.APOC_ML_OPENAI_TYPE, OpenAIRequestHandler.Type.OPENAI.name()));
        OpenAIRequestHandler.Type apiType = OpenAIRequestHandler.Type.valueOf(apiTypeName.toUpperCase(Locale.ENGLISH));

        // The JSON payload is derived from the configuration, minus request-control options
        Map<String, Object> payload = new HashMap<>(conf);
        NON_PAYLOAD_KEYS.forEach(payload::remove);

        Map<String, Object> headers = new HashMap<>();
        handleAPIProvider(apiType, conf, path, model, key, inputs, payload, headers);

        // Read after handleAPIProvider, which may have filled provider-specific defaults into conf
        String requestPath = (String) conf.getOrDefault(PATH_CONF_KEY, path);
        String resultJsonPath = (String) conf.getOrDefault(JSON_PATH_CONF_KEY, jsonPath);

        OpenAIRequestHandler handler = apiType.get();
        headers.put("Content-Type", "application/json");
        handler.addApiKey(headers, resolvedApiKey);

        String body = JsonUtil.OBJECT_MAPPER.writeValueAsString(payload);
        String url = handler.getFullUrl(requestPath, conf, apocConfig);

        return ExtendedUtil.withBackOffRetries(
                () -> JsonUtil.loadJson(url, headers, body, resultJsonPath, true, List.of(), urlAccessChecker),
                enableBackOffRetries,
                backOffRetries,
                exponentialBackOff,
                exception -> {
                    // Errors not mentioning 429 are rethrown right away; presumably only
                    // rate-limit (HTTP 429) failures are meant to be retried.
                    // getMessage() may be null, which must not mask the original exception.
                    String message = exception.getMessage();
                    if (message == null || !message.contains("429")) {
                        throw new RuntimeException(exception);
                    }
                });
    }

    /**
     * Applies provider-specific request shaping.
     *
     * <ul>
     *   <li>MIXEDBREAD_CUSTOM: the payload is left exactly as taken from the configuration.</li>
     *   <li>HUGGINGFACE: inputs go under {@code "inputs"}, with an empty path, a {@code "method": "POST"}
     *       header entry and JSON path {@code $[0]} as defaults.</li>
     *   <li>ANTHROPIC: adds the version header (default {@code 2023-06-01}) and picks the path, token
     *       limit and model defaults for the completions vs. messages API.</li>
     *   <li>anything else: {@code model} defaults to the given model and inputs go under {@code key}.</li>
     * </ul>
     *
     * @param type          resolved provider type
     * @param configuration mutable (private) configuration copy; path/jsonPath defaults are put here
     * @param path          requested path; {@code "completions"} selects Anthropic's legacy API
     * @param model         default model for the default branch
     * @param key           payload key for {@code inputs}
     * @param inputs        request inputs
     * @param payload       JSON payload being built
     * @param headers       HTTP headers being built
     */
    private static void handleAPIProvider(OpenAIRequestHandler.Type type, Map<String, Object> configuration, String path, String model, String key, Object inputs, Map<String, Object> payload, Map<String, Object> headers) {
        switch (type) {
            case MIXEDBREAD_CUSTOM -> {
                // no payload manipulation: taken from the configuration as-is
            }
            case HUGGINGFACE -> {
                payload.putIfAbsent("inputs", inputs);
                configuration.putIfAbsent(PATH_CONF_KEY, "");
                headers.putIfAbsent("method", "POST");
                configuration.putIfAbsent(JSON_PATH_CONF_KEY, "$[0]");
            }
            case ANTHROPIC -> {
                headers.putIfAbsent(MLUtil.ANTHROPIC_VERSION, configuration.getOrDefault(MLUtil.ANTHROPIC_VERSION, "2023-06-01"));
                if (path.equals("completions")) {
                    configuration.putIfAbsent(PATH_CONF_KEY, "complete");
                    payload.putIfAbsent(MLUtil.MAX_TOKENS_TO_SAMPLE, 1000);
                    payload.putIfAbsent("model", "claude-2.1");
                } else {
                    configuration.putIfAbsent(PATH_CONF_KEY, "messages");
                    payload.putIfAbsent(MLUtil.MAX_TOKENS, 1000);
                    payload.putIfAbsent("model", "claude-3-5-sonnet-20240620");
                }
                // the version travels as a header, not as a payload field
                payload.remove(MLUtil.ANTHROPIC_VERSION);
                payload.put(key, inputs);
            }
            default -> {
                payload.putIfAbsent("model", model);
                payload.put(key, inputs);
            }
        }
    }

    /**
     * Procedure {@code apoc.ml.openai.embedding}: embeds every non-blank text.
     * Null or (after dropping blanks) empty input yields no rows, or fails when
     * {@code failOnError} is true (the default).
     */
    @Procedure("apoc.ml.openai.embedding")
    @Description("apoc.ml.openai.embedding([texts], api_key, configuration) - returns the embeddings for a given text")
    @SuppressWarnings("unchecked")
    public Stream<EmbeddingResult> getEmbedding(@Name("texts") List<String> texts, @Name("api_key") String apiKey, @Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) throws Exception {
        boolean failOnError = isFailOnError(configuration);
        if (checkNullInput(texts, failOnError)) {
            return Stream.empty();
        }
        List<String> nonBlankTexts = texts.stream().filter(StringUtils::isNotBlank).toList();
        if (checkEmptyInput(nonBlankTexts, failOnError)) {
            return Stream.empty();
        }
        return getEmbeddingResult(nonBlankTexts, apiKey, configuration, this.apocConfig, this.urlAccessChecker,
                (entry, text) -> new EmbeddingResult(((Number) entry.get("index")).longValue(), text, (List<Double>) entry.get("embedding")),
                text -> new EmbeddingResult(-1L, text, List.of()));
    }

    /**
     * Requests embeddings for the non-null entries of {@code texts} and maps each response item
     * with {@code embeddingMapping}; null entries are mapped with {@code nullMapping} and appended
     * after the embedded ones. No request is sent when there is no non-null text.
     *
     * @param embeddingMapping builds a result from a response item (which carries an {@code index}
     *                         into the non-null texts) and the corresponding text
     * @param nullMapping      builds the result for a null text
     * @throws RuntimeException with {@link MLUtil#ERROR_NULL_INPUT} when {@code texts} is null
     */
    public static <T> Stream<T> getEmbeddingResult(List<String> texts, String apiKey, Map<String, Object> configuration, ApocConfig apocConfig, URLAccessChecker urlAccessChecker, BiFunction<Map, String, T> embeddingMapping, Function<String, T> nullMapping) throws JsonProcessingException, MalformedURLException {
        if (texts == null) {
            throw new RuntimeException(MLUtil.ERROR_NULL_INPUT);
        }
        // partitioningBy always yields both keys, unlike groupingBy
        Map<Boolean, List<String>> byPresence = texts.stream().collect(Collectors.partitioningBy(Objects::nonNull));
        List<String> nonNullTexts = byPresence.get(true);
        Stream<T> nullResults = byPresence.get(false).stream().map(nullMapping);
        if (nonNullTexts.isEmpty()) {
            return nullResults;
        }
        Stream<T> embeddings = executeRequest(apiKey, configuration, RagConfig.EMBEDDINGS_CONF, "text-embedding-ada-002", "input", nonNullTexts, "$.data", apocConfig, urlAccessChecker)
                .flatMap(page -> ((List<?>) page).stream())
                .map(item -> {
                    Map<?, ?> entry = (Map<?, ?>) item;
                    // the response index refers to the position in the list that was sent
                    int index = ((Number) entry.get("index")).intValue();
                    return embeddingMapping.apply(entry, nonNullTexts.get(index));
                });
        return Stream.concat(embeddings, nullResults);
    }

    /** Procedure {@code apoc.ml.openai.completion}: sends a single prompt to the completions API. */
    @Procedure("apoc.ml.openai.completion")
    @Description("apoc.ml.openai.completion(prompt, api_key, configuration) - prompts the completion API")
    public Stream<MapResult> completion(@Name("prompt") String prompt, @Name("api_key") String apiKey, @Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) throws Exception {
        if (checkBlankInput(prompt, isFailOnError(configuration))) {
            return Stream.empty();
        }
        return executeRequest(apiKey, configuration, "completions", "gpt-3.5-turbo-instruct", RagConfig.PROMPT_CONF, prompt, "$", this.apocConfig, this.urlAccessChecker)
                .map(OpenAI::toMapResult);
    }

    /**
     * Procedure {@code apoc.ml.openai.chat}: sends the non-empty messages to the chat API.
     * The default model ({@value #GPT_4O_MODEL}) is passed as a default rather than written into
     * the configuration, so provider-specific defaults (e.g. Anthropic's) still apply.
     */
    @Procedure("apoc.ml.openai.chat")
    @Description("apoc.ml.openai.chat(messages, api_key, configuration) - prompts the completion API")
    public Stream<MapResult> chatCompletion(@Name("messages") List<Map<String, Object>> messages, @Name("api_key") String apiKey, @Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) throws Exception {
        boolean failOnError = isFailOnError(configuration);
        if (checkNullInput(messages, failOnError)) {
            return Stream.empty();
        }
        List<Map<String, Object>> nonEmptyMessages = messages.stream().filter(ExtendedMapUtils::isNotEmpty).toList();
        if (checkEmptyInput(nonEmptyMessages, failOnError)) {
            return Stream.empty();
        }
        // a missing or null "model" falls back to gpt-4o (as putIfAbsent did before)
        String model = (String) Objects.requireNonNullElse(configuration.get("model"), GPT_4O_MODEL);
        return executeRequest(apiKey, configuration, "chat/completions", model, "messages", nonEmptyMessages, "$", this.apocConfig, this.urlAccessChecker)
                .map(OpenAI::toMapResult);
    }

    /** Wraps one response value (expected to be a JSON object) as a procedure row. */
    @SuppressWarnings("unchecked")
    private static MapResult toMapResult(Object response) {
        return new MapResult((Map<String, Object>) response);
    }

    /** Reads {@code failOnError} from the configuration; defaults to {@code true}. */
    private static boolean isFailOnError(Map<String, Object> configuration) {
        return Util.toBoolean(configuration.getOrDefault(FAIL_ON_ERROR_CONF, true));
    }

    /** @return true if {@code input} is null and the caller should return no rows */
    static boolean checkNullInput(Object input, boolean failOnError) {
        return checkInput(failOnError, () -> Objects.isNull(input));
    }

    /** @return true if {@code input} is empty and the caller should return no rows */
    static boolean checkEmptyInput(Collection<?> input, boolean failOnError) {
        return checkInput(failOnError, input::isEmpty);
    }

    /** @return true if {@code input} is null/blank and the caller should return no rows */
    static boolean checkBlankInput(String input, boolean failOnError) {
        return checkInput(failOnError, () -> StringUtils.isBlank(input));
    }

    /**
     * Evaluates {@code isInvalid}: returns false for valid input; for invalid input either throws
     * (when {@code failOnError}) or returns true so the caller can produce an empty result.
     *
     * @throws RuntimeException with {@link MLUtil#ERROR_NULL_INPUT} on invalid input when failOnError
     */
    private static boolean checkInput(boolean failOnError, Supplier<Boolean> isInvalid) {
        if (!isInvalid.get()) {
            return false;
        }
        if (failOnError) {
            throw new RuntimeException(MLUtil.ERROR_NULL_INPUT);
        }
        return true;
    }
}
