/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.similarity;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.function.Predicate;
import org.neo4j.gds.core.utils.Intersections;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.UserFunction;
import org.neo4j.values.storable.FloatingPointValue;
import org.neo4j.values.storable.IntegralValue;
import org.neo4j.values.storable.Values;

/**
 * Neo4j user functions that compute similarity and distance scores between two numeric
 * vectors passed in as lists.
 *
 * <p>None of the functions modify the lists they are given.
 */
public class SimilaritiesFunc {
    /** Matches {@code null} elements so they can be stripped before set-style comparisons. */
    private static final Predicate<Number> IS_NULL = Predicate.isEqual(null);
    /** Orders numbers; mixed {@link Long}/non-Long pairs are compared via Neo4j value objects. */
    private static final Comparator<Number> NUMBER_COMPARATOR = new NumberComparator();

    /**
     * Computes the Jaccard similarity of two vectors.
     *
     * <p>Both vectors are sorted and compared as multisets: equal elements are matched
     * pairwise. {@code null} elements are ignored.
     *
     * @return {@code 0.0} if either vector is {@code null}, {@code 1.0} if both are empty
     *     (after dropping nulls), otherwise |intersection| / |union|
     */
    @UserFunction("gds.similarity.jaccard")
    @Description("RETURN gds.similarity.jaccard(vector1, vector2) - Given two collection vectors, calculate Jaccard similarity")
    public double jaccardSimilarity(@Name("vector1") List<Number> vector1, @Name("vector2") List<Number> vector2) {
        if (vector1 == null || vector2 == null) {
            return 0.0;
        }
        return jaccard(vector1, vector2);
    }

    /**
     * Computes the cosine similarity of two equally sized vectors by delegating to
     * {@link Intersections#cosine}. {@code null} elements are treated as {@code 0.0}.
     *
     * @throws IllegalArgumentException if either vector is {@code null} or empty, or their sizes differ
     */
    @UserFunction("gds.similarity.cosine")
    @Description("RETURN gds.similarity.cosine(vector1, vector2) - Given two collection vectors, calculate cosine similarity")
    public double cosineSimilarity(@Name("vector1") List<Number> vector1, @Name("vector2") List<Number> vector2) {
        int len = validateLength(vector1, vector2);
        return Intersections.cosine(toArray(vector1), toArray(vector2), len);
    }

    /**
     * Computes the Pearson similarity of two equally sized vectors by delegating to
     * {@link Intersections#pearson}. {@code null} elements are treated as {@code 0.0}.
     *
     * @throws IllegalArgumentException if either vector is {@code null} or empty, or their sizes differ
     */
    @UserFunction("gds.similarity.pearson")
    @Description("RETURN gds.similarity.pearson(vector1, vector2) - Given two collection vectors, calculate pearson similarity")
    public double pearsonSimilarity(@Name("vector1") List<Number> vector1, @Name("vector2") List<Number> vector2) {
        int len = validateLength(vector1, vector2);
        return Intersections.pearson(toArray(vector1), toArray(vector2), len);
    }

    /**
     * Computes the Euclidean distance of two equally sized vectors as the square root of
     * {@link Intersections#sumSquareDelta}. {@code null} elements are treated as {@code 0.0}.
     *
     * @throws IllegalArgumentException if either vector is {@code null} or empty, or their sizes differ
     */
    @UserFunction("gds.similarity.euclideanDistance")
    @Description("RETURN gds.similarity.euclideanDistance(vector1, vector2) - Given two collection vectors, calculate the euclidean distance (square root of the sum of the squared differences)")
    public double euclideanDistance(@Name("vector1") List<Number> vector1, @Name("vector2") List<Number> vector2) {
        int len = validateLength(vector1, vector2);
        return Math.sqrt(Intersections.sumSquareDelta(toArray(vector1), toArray(vector2), len));
    }

    /**
     * Maps the Euclidean distance {@code d} of two vectors to a similarity {@code 1 / (1 + d)}.
     *
     * @throws IllegalArgumentException under the same conditions as {@link #euclideanDistance}
     */
    @UserFunction("gds.similarity.euclidean")
    @Description("RETURN gds.similarity.euclidean(vector1, vector2) - Given two collection vectors, calculate similarity based on euclidean distance")
    public double euclideanSimilarity(@Name("vector1") List<Number> vector1, @Name("vector2") List<Number> vector2) {
        return 1.0 / (1.0 + euclideanDistance(vector1, vector2));
    }

    /**
     * Computes the overlap similarity: the number of distinct shared elements divided by the
     * size of the smaller vector. {@code null} elements are ignored. Elements are matched with
     * {@link Object#equals}/{@link Object#hashCode}.
     *
     * @return {@code 0.0} if either vector is {@code null} or the smaller vector is empty
     */
    @UserFunction("gds.similarity.overlap")
    @Description("RETURN gds.similarity.overlap(vector1, vector2) - Given two collection vectors, calculate overlap similarity")
    public double overlapSimilarity(@Name("vector1") List<Number> vector1, @Name("vector2") List<Number> vector2) {
        // The null check must come first; the original dereferenced the lists before checking.
        if (vector1 == null || vector2 == null) {
            return 0.0;
        }
        List<Number> left = withoutNulls(vector1);
        List<Number> right = withoutNulls(vector2);

        HashSet<Number> intersectionSet = new HashSet<>(left);
        intersectionSet.retainAll(right);
        int intersection = intersectionSet.size();

        long denominator = Math.min(left.size(), right.size());
        return denominator == 0L ? 0.0 : (double) intersection / (double) denominator;
    }

    /** Converts {@code input} to a primitive array, mapping {@code null} elements to {@code 0.0}. */
    private double[] toArray(List<Number> input) {
        int length = input.size();
        double[] weights = new double[length];
        for (int i = 0; i < length; ++i) {
            weights[i] = getDoubleValue(input.get(i));
        }
        return weights;
    }

    /**
     * Checks that both vectors are present, non-empty and of equal size.
     *
     * @return the common vector length
     * @throws IllegalArgumentException otherwise
     */
    private int validateLength(List<Number> vector1, List<Number> vector2) {
        if (vector1 == null || vector2 == null || vector1.isEmpty() || vector1.size() != vector2.size()) {
            throw new IllegalArgumentException("Vectors must be non-empty and of the same size");
        }
        return vector1.size();
    }

    /** Returns a new, mutable copy of {@code input} with all {@code null} elements removed. */
    private static List<Number> withoutNulls(List<Number> input) {
        // Copy so the caller's (possibly unmodifiable) list is never mutated.
        List<Number> copy = new ArrayList<>(input);
        copy.removeIf(IS_NULL);
        return copy;
    }

    /**
     * Multiset Jaccard similarity of two non-null vectors. Both vectors are sorted and then
     * walked in a single merge pass.
     */
    private double jaccard(List<Number> vector1, List<Number> vector2) {
        List<Number> left = withoutNulls(vector1);
        List<Number> right = withoutNulls(vector2);
        left.sort(NUMBER_COMPARATOR);
        right.sort(NUMBER_COMPARATOR);

        int i = 0;
        int j = 0;
        int intersection = 0;
        while (i < left.size() && j < right.size()) {
            int compare = NUMBER_COMPARATOR.compare(left.get(i), right.get(j));
            if (compare == 0) {
                ++intersection;
                ++i;
                ++j;
            } else if (compare < 0) {
                ++i;
            } else {
                ++j;
            }
        }
        // Each matched pair counts once in the union and every unmatched element counts once,
        // so |union| = |left| + |right| - |intersection|. Use long arithmetic to avoid int overflow.
        double union = (double) ((long) left.size() + right.size() - intersection);
        return union == 0.0 ? 1.0 : intersection / union;
    }

    /** Returns {@code value.doubleValue()}, or {@code 0.0} for {@code null}, without boxing. */
    private static double getDoubleValue(Number value) {
        return value == null ? 0.0 : value.doubleValue();
    }

    /**
     * Total order over non-null {@link Number}s.
     *
     * <p>Two {@link Long}s are compared exactly. When exactly one side is a {@link Long}, the
     * pair is compared through Neo4j's {@link Values} long/double value objects. All other
     * pairs are compared by their {@code double} values.
     */
    static class NumberComparator implements Comparator<Number> {
        @Override
        public int compare(Number o1, Number o2) {
            if (o1 instanceof Long && o2 instanceof Long) {
                return ((Long) o1).compareTo((Long) o2);
            }
            // The casts pin the IntegralValue/FloatingPointValue compareTo overloads.
            if (o1 instanceof Long) {
                return Values.longValue(o1.longValue()).compareTo((FloatingPointValue) Values.doubleValue(o2.doubleValue()));
            }
            if (o2 instanceof Long) {
                return Values.doubleValue(o1.doubleValue()).compareTo((IntegralValue) Values.longValue(o2.longValue()));
            }
            return Double.compare(o1.doubleValue(), o2.doubleValue());
        }
    }
}

