package org.neo4j.graphalgo.similarity;

import java.util.Map;
import java.util.stream.Stream;
import org.neo4j.graphalgo.core.ProcedureConfiguration;
import org.neo4j.graphalgo.similarity.SimilarityProc;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Mode;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

/* loaded from: input_file:org/neo4j/graphalgo/similarity/CosineProc.class */
/**
 * Neo4j procedures that compute cosine similarity between weighted inputs.
 *
 * <p>Two entry points are exposed: {@code algo.similarity.cosine.stream} streams the individual
 * similarity results, while {@code algo.similarity.cosine} aggregates them into a summary row and
 * may write them back, depending on the configuration.
 *
 * <p>Internally the pipeline works on <em>squared</em> similarities: the configured cutoff is
 * squared up front (see {@link #similarityCutoff(ProcedureConfiguration)}) and every result is
 * passed through {@code squareRooted()} before it leaves
 * {@link #generateWeightedStream(ProcedureConfiguration, WeightedInput[], double, int, int, SimilarityProc.SimilarityComputer)}.
 */
public class CosineProc extends SimilarityProc {
    /**
     * Streams the cosine similarity results for the supplied data.
     *
     * @param obj raw input data, handed unchanged to {@code prepareWeights}
     * @param map user supplied configuration map
     * @return the similarity results, or an empty stream when no weighted inputs were prepared
     * @throws Exception declared for compatibility with the inherited helpers that are called here
     */
    @Procedure(name = "algo.similarity.cosine.stream", mode = Mode.READ)
    @Description("CALL algo.similarity.cosine.stream([{item:id, weights:[weights]}], {similarityCutoff:-1,degreeCutoff:0}) YIELD item1, item2, count1, count2, intersection, similarity - computes cosine distance")
    public Stream<SimilarityResult> cosineStream(@Name(value = "data", defaultValue = "null") Object obj, @Name(value = "config", defaultValue = "{}") Map<String, Object> map) throws Exception {
        ProcedureConfiguration configuration = ProcedureConfiguration.create(map);
        Double skipValue = readSkipValue(configuration);
        WeightedInput[] inputs = prepareWeights(obj, configuration, skipValue);
        if (inputs.length == 0) {
            return Stream.empty();
        }
        return generateWeightedStream(configuration, inputs, similarityCutoff(configuration), getTopN(configuration), getTopK(configuration), similarityComputer(skipValue));
    }

    /**
     * Computes cosine similarities and returns an aggregated summary.
     *
     * <p>The write flag handed to {@code writeAndAggregateResults} is only {@code true} when the
     * configuration enables writing <em>and</em> the (squared) similarity cutoff is strictly
     * positive.
     *
     * @param obj raw input data, handed unchanged to {@code prepareWeights}
     * @param map user supplied configuration map
     * @return the summary produced by {@code writeAndAggregateResults}, or the result of
     *         {@code emptyStream} when no weighted inputs were prepared
     * @throws Exception declared for compatibility with the inherited helpers that are called here
     */
    @Procedure(name = "algo.similarity.cosine", mode = Mode.WRITE)
    @Description("CALL algo.similarity.cosine([{item:id, weights:[weights]}], {similarityCutoff:-1,degreeCutoff:0}) YIELD p50, p75, p90, p99, p999, p100 - computes cosine similarities")
    public Stream<SimilaritySummaryResult> cosine(@Name(value = "data", defaultValue = "null") Object obj, @Name(value = "config", defaultValue = "{}") Map<String, Object> map) throws Exception {
        ProcedureConfiguration configuration = ProcedureConfiguration.create(map);
        Double skipValue = readSkipValue(configuration);
        WeightedInput[] inputs = prepareWeights(obj, configuration, skipValue);
        String writeRelationshipType = (String) configuration.get("writeRelationshipType", "SIMILAR");
        String writeProperty = configuration.getWriteProperty("score");
        if (inputs.length == 0) {
            return emptyStream(writeRelationshipType, writeProperty);
        }
        double squaredCutoff = similarityCutoff(configuration);
        Stream<SimilarityResult> results = generateWeightedStream(configuration, inputs, squaredCutoff, getTopN(configuration), getTopK(configuration), similarityComputer(skipValue));
        boolean write = configuration.isWriteFlag(false) && squaredCutoff > 0.0d;
        return writeAndAggregateResults(results, inputs.length, write, writeRelationshipType, writeProperty);
    }

    /**
     * Reads the {@code skipValue} setting, passing {@code NaN} as the default to the configuration.
     *
     * <p>Any {@link Number} is accepted and converted with {@link Number#doubleValue()}, so an
     * integer literal (which reaches Java as a {@code Long}) no longer fails with a
     * {@link ClassCastException}. A {@code null} coming back from the configuration is returned
     * as {@code null}.
     *
     * @param configuration the procedure configuration to read from
     * @return the skip value as a {@code Double}, or {@code null} if the configuration returned null
     * @throws IllegalArgumentException if the configured value is neither null nor a number
     */
    private static Double readSkipValue(ProcedureConfiguration configuration) {
        // The default is passed as Object so that no cast to Double is applied to the returned value.
        Object raw = configuration.get("skipValue", (Object) Double.valueOf(Double.NaN));
        if (raw == null) {
            return null;
        }
        if (raw instanceof Double) {
            return (Double) raw;
        }
        if (raw instanceof Number) {
            return Double.valueOf(((Number) raw).doubleValue());
        }
        throw new IllegalArgumentException("Expected a numeric value for 'skipValue' but got " + raw.getClass().getName() + ": " + raw);
    }

    /**
     * Selects the pairwise similarity function.
     *
     * @param d the skip value; {@code null} selects {@code cosineSquares}, any other value
     *          (including {@code NaN}) selects {@code cosineSquaresSkip} with that value
     * @return the computer used by {@code similarityStream}
     */
    private SimilarityProc.SimilarityComputer<WeightedInput> similarityComputer(Double d) {
        if (d == null) {
            return (decoder, source, target, cutoff) -> source.cosineSquares(decoder, cutoff, target);
        }
        // Unbox once instead of on every pairwise comparison.
        final double skipValue = d.doubleValue();
        return (decoder, source, target, cutoff) -> source.cosineSquaresSkip(decoder, cutoff, target, skipValue);
    }

    /**
     * Builds the result stream: pairwise similarities limited by top-K, then by top-N, with every
     * result passed through {@code squareRooted()}.
     *
     * @param procedureConfiguration the procedure configuration
     * @param weightedInputArr the prepared inputs; the first element is used to create the decoder factory
     * @param d the similarity cutoff, forwarded to {@code similarityStream}
     * @param i the top-N limit, forwarded to {@code topN}
     * @param i2 the top-K limit, forwarded to {@code similarityStream}
     * @param similarityComputer the pairwise similarity function
     * @return the result stream; empty when {@code weightedInputArr} has no elements
     */
    Stream<SimilarityResult> generateWeightedStream(ProcedureConfiguration procedureConfiguration, WeightedInput[] weightedInputArr, double d, int i, int i2, SimilarityProc.SimilarityComputer<WeightedInput> similarityComputer) {
        if (weightedInputArr.length == 0) {
            // Nothing to compare; also avoids indexing element 0 of an empty array below.
            return Stream.empty();
        }
        Stream<SimilarityResult> similarities = similarityStream(weightedInputArr, similarityComputer, procedureConfiguration, createDecoderFactory(procedureConfiguration, weightedInputArr[0]), d, i2);
        return topN(similarities, i).map(SimilarityResult::squareRooted);
    }

    /**
     * Returns the configured similarity cutoff, squared when it is strictly positive.
     *
     * <p>Zero and negative cutoffs are returned unchanged, because squaring a negative value
     * would turn it into a positive threshold.
     *
     * @param procedureConfiguration the procedure configuration
     * @return the cutoff to apply to squared similarities
     */
    private double similarityCutoff(ProcedureConfiguration procedureConfiguration) {
        double cutoff = getSimilarityCutoff(procedureConfiguration).doubleValue();
        return cutoff > 0.0d ? cutoff * cutoff : cutoff;
    }
}
