package org.neo4j.graphalgo.similarity;

import com.carrotsearch.hppc.LongDoubleHashMap;
import com.carrotsearch.hppc.LongDoubleMap;
import com.carrotsearch.hppc.LongHashSet;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Stream;
import org.HdrHistogram.DoubleHistogram;
import org.neo4j.graphalgo.core.ProcedureConfiguration;
import org.neo4j.graphalgo.core.heavyweight.HeavyCypherGraphFactory;
import org.neo4j.graphalgo.core.utils.TerminationFlag;
import org.neo4j.graphalgo.similarity.recorder.NonRecordingSimilarityRecorder;
import org.neo4j.graphalgo.similarity.recorder.RecordingSimilarityRecorder;
import org.neo4j.graphalgo.similarity.recorder.SimilarityRecorder;
import org.neo4j.graphdb.Result;
import org.neo4j.kernel.api.KernelTransaction;
import org.neo4j.kernel.internal.GraphDatabaseAPI;
import org.neo4j.logging.Log;
import org.neo4j.procedure.Context;

/**
 * Helper methods shared by the similarity procedures in this package.
 *
 * <p>The helpers cover:
 * <ul>
 *   <li>reading common configuration keys;</li>
 *   <li>preparing categorical and weighted inputs;</li>
 *   <li>generating similarity streams;</li>
 *   <li>selecting top-N / top-K results;</li>
 *   <li>writing or summarising results.</li>
 * </ul>
 *
 * <p>The public {@link Context} fields are injected by the Neo4j procedure framework.
 * This source was recovered by decompiling {@code SimilarityProc.class}.
 */
public class SimilarityProc {

    /** Database handle; used to run the sparse-weight Cypher query and to export results. */
    @Context
    public GraphDatabaseAPI api;

    /** Procedure log (not referenced by this class itself). */
    @Context
    public Log log;

    /** Current kernel transaction; wrapped in a {@link TerminationFlag} for stream generation. */
    @Context
    public KernelTransaction transaction;

    /**
     * Allocates {@code count} independent top-K consumers.
     *
     * @param count number of consumers to create
     * @param topK  signed K. A positive value orders by {@link SimilarityResult#DESCENDING};
     *              zero or negative orders by {@link SimilarityResult#ASCENDING}.
     *              {@code |topK|} is the capacity of each consumer.
     * @return a new array holding {@code count} freshly created consumers
     */
    @SuppressWarnings("unchecked") // generic array creation cannot be expressed without a raw array
    public static TopKConsumer<SimilarityResult>[] initializeTopKConsumers(int count, int topK) {
        Comparator<SimilarityResult> order = topK > 0 ? SimilarityResult.DESCENDING : SimilarityResult.ASCENDING;
        int capacity = Math.abs(topK);
        TopKConsumer<SimilarityResult>[] consumers = new TopKConsumer[count];
        for (int idx = 0; idx < consumers.length; idx++) {
            consumers[idx] = new TopKConsumer<>(capacity, order);
        }
        return consumers;
    }

    /**
     * Limits a result stream to its first {@code |n|} entries under a similarity ordering.
     *
     * @param stream results to limit
     * @param n      {@code 0} returns {@code stream} unchanged. A positive value keeps the first
     *               {@code n} results in {@code DESCENDING} order. A negative value keeps the
     *               first {@code -n} results in {@code ASCENDING} order.
     * @return the limited stream
     */
    public static Stream<SimilarityResult> topN(Stream<SimilarityResult> stream, int n) {
        if (n == 0) {
            return stream;
        }
        Comparator<SimilarityResult> order = n > 0 ? SimilarityResult.DESCENDING : SimilarityResult.ASCENDING;
        int limit = Math.abs(n);
        // Above this threshold a full sort + limit is used instead of the bounded TopKConsumer.
        if (limit > 10000) {
            return stream.sorted(order).limit(limit);
        }
        return TopKConsumer.topK(stream, limit, order);
    }

    /**
     * Creates the recorder used for weighted inputs.
     *
     * @return a {@link RecordingSimilarityRecorder} when the {@code showComputations} option is
     *         {@code true}, otherwise a {@link NonRecordingSimilarityRecorder}
     */
    public static SimilarityRecorder<WeightedInput> similarityRecorder(SimilarityComputer<WeightedInput> similarityComputer, ProcedureConfiguration procedureConfiguration) {
        return showComputations(procedureConfiguration)
                ? new RecordingSimilarityRecorder<>(similarityComputer)
                : new NonRecordingSimilarityRecorder<>(similarityComputer);
    }

    /**
     * Creates the recorder used for categorical inputs.
     *
     * @return a {@link RecordingSimilarityRecorder} when the {@code showComputations} option is
     *         {@code true}, otherwise a {@link NonRecordingSimilarityRecorder}
     */
    public static SimilarityRecorder<CategoricalInput> categoricalSimilarityRecorder(SimilarityComputer<CategoricalInput> similarityComputer, ProcedureConfiguration procedureConfiguration) {
        return showComputations(procedureConfiguration)
                ? new RecordingSimilarityRecorder<>(similarityComputer)
                : new NonRecordingSimilarityRecorder<>(similarityComputer);
    }

    /** Reads the {@code showComputations} option (default {@code false}) shared by both recorder factories. */
    private static boolean showComputations(ProcedureConfiguration procedureConfiguration) {
        return (Boolean) procedureConfiguration.get("showComputations", false);
    }

    /** Reads the {@code degreeCutoff} option (default {@code 0}). */
    public Long getDegreeCutoff(ProcedureConfiguration procedureConfiguration) {
        return (Long) procedureConfiguration.get("degreeCutoff", 0L);
    }

    /** Reads the {@code writeBatchSize} option (default {@code 10000}). */
    Long getWriteBatchSize(ProcedureConfiguration procedureConfiguration) {
        return (Long) procedureConfiguration.get("writeBatchSize", 10000L);
    }

    /**
     * Fully consumes {@code stream} and returns a single summary row.
     *
     * <p>Each result is recorded into a histogram and counted. When {@code write} is
     * {@code true}, the results are also exported in batches of {@code writeBatchSize}.
     *
     * @param stream                 results to consume; consumed completely by this call
     * @param inputCount             forwarded unchanged to {@link SimilaritySummaryResult#from}
     * @param sourceIdCount          forwarded unchanged to {@link SimilaritySummaryResult#from}
     * @param targetIdCount          forwarded unchanged to {@link SimilaritySummaryResult#from}
     * @param procedureConfiguration source of the {@code writeBatchSize} option
     * @param write                  when {@code true}, results are exported via {@link SimilarityExporter}
     * @param writeRelationshipType  second exporter argument; presumably the relationship type (TODO confirm)
     * @param writeProperty          third exporter argument; presumably the property key (TODO confirm)
     * @param computations           supplies the computation count reported in the summary
     * @return a one-element stream holding the summary
     */
    public Stream<SimilaritySummaryResult> writeAndAggregateResults(Stream<SimilarityResult> stream, int inputCount, int sourceIdCount, int targetIdCount, ProcedureConfiguration procedureConfiguration, boolean write, String writeRelationshipType, String writeProperty, Computations computations) {
        long batchSize = getWriteBatchSize(procedureConfiguration);
        AtomicLong resultCount = new AtomicLong();
        DoubleHistogram histogram = new DoubleHistogram(5);
        Consumer<SimilarityResult> recordAndCount = result -> {
            result.record(histogram);
            resultCount.getAndIncrement();
        };

        if (write) {
            // The exporter drains the stream; peek records each result as it is written.
            new SimilarityExporter(this.api, writeRelationshipType, writeProperty)
                    .export(stream.peek(recordAndCount), batchSize);
        } else {
            stream.forEach(recordAndCount);
        }

        return Stream.of(SimilaritySummaryResult.from(inputCount, sourceIdCount, targetIdCount, resultCount,
                computations.count(), writeRelationshipType, writeProperty, write, histogram));
    }

    /**
     * Returns a single all-zero summary row.
     *
     * <p>The computation count is {@code -1}, the write flag is {@code false}, and the histogram
     * is empty.
     */
    public Stream<SimilaritySummaryResult> emptyStream(String writeRelationshipType, String writeProperty) {
        return Stream.of(SimilaritySummaryResult.from(0L, 0L, 0L, new AtomicLong(0L), -1L,
                writeRelationshipType, writeProperty, false, new DoubleHistogram(5)));
    }

    /** Reads the {@code similarityCutoff} option (default {@code -1.0}). */
    public Double getSimilarityCutoff(ProcedureConfiguration procedureConfiguration) {
        return (Double) procedureConfiguration.get("similarityCutoff", Double.valueOf(-1.0d));
    }

    /**
     * Builds the stream of pairwise similarities.
     *
     * <p>When both index arrays are empty, the all-pairs overload of
     * {@link SimilarityStreamGenerator#stream} is used. Otherwise the overload taking source and
     * target indices is used.
     *
     * @param inputs           items to compare
     * @param sourceIndexIds   indices into {@code inputs} used as sources, or empty
     * @param targetIndexIds   indices into {@code inputs} used as targets, or empty
     * @param similarityCutoff forwarded unchanged to the generator
     * @param topK             forwarded unchanged to the generator
     */
    public <T> Stream<SimilarityResult> similarityStream(T[] inputs, int[] sourceIndexIds, int[] targetIndexIds, SimilarityComputer<T> similarityComputer, ProcedureConfiguration procedureConfiguration, Supplier<RleDecoder> decoderFactory, double similarityCutoff, int topK) {
        SimilarityStreamGenerator generator = new SimilarityStreamGenerator(
                TerminationFlag.wrap(this.transaction), procedureConfiguration, decoderFactory, similarityComputer);
        if (sourceIndexIds.length == 0 && targetIndexIds.length == 0) {
            return generator.stream(inputs, similarityCutoff, topK);
        }
        return generator.stream(inputs, sourceIndexIds, targetIndexIds, similarityCutoff, topK);
    }

    /**
     * Converts raw rows into sorted categorical inputs.
     *
     * <p>Each row must carry an {@code "item"} id and a {@code "categories"} collection. Rows with
     * at most {@code degreeCutoff} categories are dropped. Each kept row's category ids are
     * sorted, and the resulting array of inputs is itself sorted.
     *
     * @param data         rows with {@code "item"} and {@code "categories"} entries
     * @param degreeCutoff rows need strictly more categories than this to be kept
     * @return the kept inputs, sorted
     */
    public CategoricalInput[] prepareCategories(List<Map<String, Object>> data, long degreeCutoff) {
        CategoricalInput[] inputs = new CategoricalInput[data.size()];
        int count = 0;
        for (Map<String, Object> row : data) {
            List<Number> categories = SimilarityInput.extractValues(row.get("categories"));
            int size = categories.size();
            if (size <= degreeCutoff) {
                continue;
            }
            long[] categoryIds = new long[size];
            int pos = 0;
            for (Number category : categories) {
                categoryIds[pos++] = category.longValue();
            }
            Arrays.sort(categoryIds);
            // Accept any numeric id type (Integer, Long, ...), not only Long.
            long item = ((Number) row.get("item")).longValue();
            inputs[count++] = new CategoricalInput(item, categoryIds);
        }
        if (count != inputs.length) {
            inputs = Arrays.copyOf(inputs, count);
        }
        Arrays.sort(inputs);
        return inputs;
    }

    /**
     * Prepares weighted inputs. The path taken depends on the configured graph type
     * (default {@code "dense"}).
     *
     * <p>For {@link HeavyCypherGraphFactory#TYPE}, {@code rawData} is a Cypher query that is run
     * to build sparse weights. For any other graph type, {@code rawData} is a list that is handed
     * to {@link WeightedInput#prepareDenseWeights}.
     *
     * @param rawData                Cypher query string, or a list of input rows
     * @param procedureConfiguration source of the graph type and {@code degreeCutoff}
     * @param skipValue              presumably the value that marks a missing weight (TODO confirm)
     * @return the prepared inputs
     * @throws Exception propagated from query execution
     */
    @SuppressWarnings({"unchecked", "rawtypes"}) // prepareDenseWeights' element type is not visible here
    public WeightedInput[] prepareWeights(Object rawData, ProcedureConfiguration procedureConfiguration, Double skipValue) throws Exception {
        String graphType = procedureConfiguration.getGraphName("dense");
        if (HeavyCypherGraphFactory.TYPE.equals(graphType)) {
            return prepareSparseWeights(this.api, (String) rawData, skipValue, procedureConfiguration);
        }
        return WeightedInput.prepareDenseWeights((List) rawData, getDegreeCutoff(procedureConfiguration).longValue(), skipValue);
    }

    /** Reads the {@code skipValue} option (default {@code NaN}). */
    public Double readSkipValue(ProcedureConfiguration procedureConfiguration) {
        return (Double) procedureConfiguration.get("skipValue", Double.valueOf(Double.NaN));
    }

    /**
     * Runs {@code query} and builds run-length-encoded weight vectors, one per item.
     *
     * <p>The query must return numeric {@code item}, {@code category} and {@code weight} columns.
     * Each item's vector has one slot per category seen anywhere in the result. A category the
     * item did not report is filled with {@code skipValue}. Items with at most
     * {@code degreeCutoff} reported categories are dropped, and the result is sorted.
     *
     * @throws Exception propagated from query execution or row visiting
     */
    private WeightedInput[] prepareSparseWeights(GraphDatabaseAPI graphDatabaseAPI, String query, Double skipValue, ProcedureConfiguration procedureConfiguration) throws Exception {
        Map<String, Object> params = procedureConfiguration.getParams();
        Long degreeCutoff = getDegreeCutoff(procedureConfiguration);
        int repeatCutoff = ((Long) procedureConfiguration.get("sparseVectorRepeatCutoff", 3L)).intValue();

        Map<Long, LongDoubleMap> weightsByItem = new HashMap<>();
        LongHashSet allCategories = new LongHashSet();
        // try-with-resources so the Result is released even if a row visit throws.
        try (Result result = graphDatabaseAPI.execute(query, params)) {
            result.accept(row -> {
                long item = row.getNumber("item").longValue();
                long category = row.getNumber("category").longValue();
                allCategories.add(category);
                double weight = row.getNumber("weight").doubleValue();
                weightsByItem.computeIfAbsent(item, ignored -> new LongDoubleHashMap()).put(category, weight);
                return true;
            });
        }

        WeightedInput[] inputs = new WeightedInput[weightsByItem.size()];
        int count = 0;
        long[] categoryIds = allCategories.toArray();
        for (Map.Entry<Long, LongDoubleMap> entry : weightsByItem.entrySet()) {
            LongDoubleMap sparse = entry.getValue();
            if (sparse.size() <= degreeCutoff) {
                continue;
            }
            // Densify over the shared category order so all items' vectors line up.
            List<Number> dense = new ArrayList<>(categoryIds.length);
            for (long categoryId : categoryIds) {
                dense.add(sparse.getOrDefault(categoryId, skipValue));
            }
            inputs[count++] = WeightedInput.sparse(entry.getKey(),
                    Weights.buildRleWeights(dense, repeatCutoff), dense.size(), sparse.size());
        }
        if (count != inputs.length) {
            inputs = Arrays.copyOf(inputs, count);
        }
        Arrays.sort(inputs);
        return inputs;
    }

    /** Reads the {@code topK} option (default {@code 0}). */
    public int getTopK(ProcedureConfiguration procedureConfiguration) {
        return procedureConfiguration.getInt("topK", 0);
    }

    /** Reads the {@code top} option (default {@code 0}). */
    public int getTopN(ProcedureConfiguration procedureConfiguration) {
        return procedureConfiguration.getInt("top", 0);
    }

    /**
     * Returns a supplier of {@link RleDecoder}s of the given size for the Cypher graph type.
     * For any other graph type it returns a supplier that always yields {@code null}.
     */
    private Supplier<RleDecoder> createDecoderFactory(String graphType, int size) {
        if (HeavyCypherGraphFactory.TYPE.equals(graphType)) {
            return () -> new RleDecoder(size);
        }
        return () -> null;
    }

    /**
     * Creates the decoder factory for the configured graph type (default {@code "dense"}).
     * The decoder size is taken from {@code weightedInput.initialSize}.
     */
    public Supplier<RleDecoder> createDecoderFactory(ProcedureConfiguration procedureConfiguration, WeightedInput weightedInput) {
        return createDecoderFactory(procedureConfiguration.getGraphName("dense"), weightedInput.initialSize);
    }
}
