package apoc.ml.aws;

import apoc.Description;
import apoc.Extended;
import apoc.result.MapResult;
import apoc.util.JsonUtil;
import apoc.util.Util;
import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.ObjectMethods;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Stream;
import org.jsoup.helper.HttpConnection;
import org.neo4j.graphdb.security.URLAccessChecker;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

/**
 * Neo4j procedures ({@code apoc.ml.sagemaker.*}) that send a request body to a SageMaker endpoint
 * and stream the decoded JSON response back to the caller.
 *
 * <p>The {@code chat}, {@code completion} and {@code embedding} procedures fill in a default
 * endpoint name (plus default headers or a default JSON path) with {@code putIfAbsent}, so any
 * value supplied in the caller's configuration map takes precedence. The {@code custom} procedure
 * applies no defaults.
 */
@Extended
public class SageMaker {

    /** Injected by Neo4j; forwarded to {@code JsonUtil.loadJson} together with the endpoint URL. */
    @Context
    public URLAccessChecker urlAccessChecker;

    /**
     * One row of the {@code apoc.ml.sagemaker.embedding} result.
     *
     * @param index position of this row in the response, starting at 0
     * @param text the input text found at {@code index} in the {@code texts} argument
     * @param embedding value stored under the {@code "embedding"} key of the response item
     */
    public record EmbeddingResult(long index, String text, List<Double> embedding) {}

    /**
     * Sends an arbitrary body to the endpoint described by the configuration, without applying any
     * defaults of its own.
     *
     * @param obj request body; a {@code String} is sent as is, anything else is serialized to JSON
     * @param map configuration passed unchanged to {@link SageMakerConfig}
     * @return one {@link MapResult} per element produced by the request
     */
    @Procedure("apoc.ml.sagemaker.custom")
    @Description("apoc.ml.sagemaker.custom(body, $conf) - To create a customizable SageMaker call")
    public Stream<MapResult> custom(@Name("body") Object obj, @Name(value = "configuration", defaultValue = "{}") Map<String, Object> map) {
        return executeRequestReturningMap(obj, new SageMakerConfig(map)).map(MapResult::new);
    }

    /**
     * Issues one request per message and concatenates the results.
     *
     * <p>For each message the value under {@code "content"} is used as the body when that key is
     * present; otherwise the whole message map is sent.
     *
     * @param list messages to send, one request each
     * @param map configuration; the endpoint name defaults to
     *     {@code "Endpoint-Distilbart-xsum-1-1-1"} and the headers default to a content type of
     *     {@code "application/x-text"}
     * @return the results of all requests, in message order
     */
    @Procedure("apoc.ml.sagemaker.chat")
    @Description("apoc.ml.sagemaker.chat(messages, $conf) - Prompts the chat completion API")
    public Stream<MapResult> chatCompletion(@Name("messages") List<Map<String, String>> list, @Name(value = "configuration", defaultValue = "{}") Map<String, Object> map) {
        // Copy first: the caller's configuration map must not be modified.
        Map<String, Object> settings = new HashMap<>(map);
        settings.putIfAbsent(SageMakerConfig.ENDPOINT_NAME_KEY, "Endpoint-Distilbart-xsum-1-1-1");
        settings.putIfAbsent("headers", Util.map(HttpConnection.CONTENT_TYPE, "application/x-text"));
        SageMakerConfig config = new SageMakerConfig(settings);
        return list.stream().flatMap(message -> {
            Object body = message.containsKey("content") ? message.get("content") : message;
            return executeRequestReturningMap(body, config).map(MapResult::new);
        });
    }

    /**
     * Sends a single prompt string as the request body.
     *
     * @param str prompt text, sent as is
     * @param map configuration; the endpoint name defaults to {@code "Endpoint-GPT-2-1"} and the
     *     headers default to a content type of {@code "application/x-text"}
     * @return one {@link MapResult} per element produced by the request
     */
    @Procedure("apoc.ml.sagemaker.completion")
    @Description("apoc.ml.sagemaker.completion(prompt, $conf) - Prompts the completion API")
    public Stream<MapResult> completion(@Name("prompt") String str, @Name(value = "configuration", defaultValue = "{}") Map<String, Object> map) {
        Map<String, Object> settings = new HashMap<>(map);
        settings.putIfAbsent(SageMakerConfig.ENDPOINT_NAME_KEY, "Endpoint-GPT-2-1");
        // This default is immutable; executeRequestCommon works on its own copy of the headers,
        // so nothing ever tries to write into it.
        settings.putIfAbsent("headers", Map.of(HttpConnection.CONTENT_TYPE, "application/x-text"));
        return executeRequestReturningMap(str, new SageMakerConfig(settings)).map(MapResult::new);
    }

    /**
     * Requests embeddings for all texts in a single call, sending
     * {@code {"data": [{"text": ...}, ...]}} as the body.
     *
     * @param list texts to embed; {@code Map.of} rejects {@code null} entries
     * @param map configuration; the endpoint name defaults to
     *     {@code "Endpoint-Jina-Embeddings-v2-Base-en-1"} and the JSON path to {@code "data[*]"}
     * @return one {@link EmbeddingResult} per response item, paired with the input text at the
     *     same position
     */
    @Procedure("apoc.ml.sagemaker.embedding")
    @Description("apoc.ml.sagemaker.embedding([texts], $configuration) - Returns the embeddings for a given text")
    @SuppressWarnings("unchecked") // the decoded JSON is untyped; the casts below describe the expected shape
    public Stream<EmbeddingResult> embedding(@Name("texts") List<String> list, @Name(value = "configuration", defaultValue = "{}") Map<String, Object> map) {
        Map<String, Object> settings = new HashMap<>(map);
        settings.putIfAbsent(SageMakerConfig.ENDPOINT_NAME_KEY, "Endpoint-Jina-Embeddings-v2-Base-en-1");
        settings.putIfAbsent("jsonPath", "data[*]");
        SageMakerConfig config = new SageMakerConfig(settings);

        List<Map<String, String>> entries = list.stream().map(text -> Map.of("text", text)).toList();
        Map<String, Object> payload = Map.of("data", entries);

        // Response items carry no reference to their input, so they are matched to the texts by
        // arrival order. NOTE(review): this assumes the endpoint answers in request order and
        // returns no more items than texts were sent - confirm against the endpoint contract.
        AtomicInteger counter = new AtomicInteger();
        return executeRequestCommon(payload, config)
                .flatMap(response -> ((List<Map<String, Object>>) response).stream())
                .map(item -> {
                    int position = counter.getAndIncrement();
                    return new EmbeddingResult(position, list.get(position), (List<Double>) item.get("embedding"));
                });
    }

    /**
     * Runs the request and views every produced element as a map.
     *
     * @param obj request body, see {@link #executeRequestCommon}
     * @param aWSConfig endpoint, header and JSON path settings
     * @return the response elements cast to {@code Map<String, Object>}
     */
    @SuppressWarnings("unchecked") // callers only use this where the response elements are JSON objects
    private Stream<Map<String, Object>> executeRequestReturningMap(Object obj, AWSConfig aWSConfig) {
        return executeRequestCommon(obj, aWSConfig).map(element -> (Map<String, Object>) element);
    }

    /**
     * Serializes the body, completes the headers and loads the JSON response.
     *
     * <p>The headers are completed on a per-request copy rather than on the map held by the
     * configuration. The configured map may be immutable, and one configuration is reused for
     * several bodies, so headers added for one request must not leak into the next one; otherwise
     * the authorization check below would skip the calculation for every body after the first.
     *
     * @param obj request body; a {@code String} is sent as is, anything else is serialized with
     *     {@code JsonUtil.OBJECT_MAPPER}
     * @param aWSConfig source of the endpoint, configured headers and JSON path
     * @return the stream returned by {@code JsonUtil.loadJson}
     * @throws RuntimeException wrapping any {@link IOException} raised while preparing the request
     */
    private Stream<Object> executeRequestCommon(Object obj, AWSConfig aWSConfig) {
        try {
            String payload = obj instanceof String text ? text : JsonUtil.OBJECT_MAPPER.writeValueAsString(obj);

            Map<String, Object> configuredHeaders = aWSConfig.getHeaders();
            Map<String, Object> headers =
                    configuredHeaders == null ? new LinkedHashMap<>() : new LinkedHashMap<>(configuredHeaders);
            headers.putIfAbsent(HttpConnection.CONTENT_TYPE, "application/json");
            headers.putIfAbsent("accept", "*/*");

            // A caller-supplied authorization header is respected; otherwise it is calculated for
            // this exact payload.
            if (!headers.containsKey(AwsSignatureV4Generator.AUTHORIZATION_KEY)) {
                AwsSignatureV4Generator.calculateAuthorizationHeaders(aWSConfig, payload, headers, "sagemaker");
            }
            return JsonUtil.loadJson(aWSConfig.getEndpoint(), headers, payload, aWSConfig.getJsonPath(), true, List.of(), this.urlAccessChecker);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }
}
