/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.utils;

import htsjdk.samtools.util.Histogram;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.util.FastMath;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.hellbender.utils.MathUtils;

/**
 * Mann-Whitney U (Wilcoxon rank-sum) test comparing two samples.
 *
 * <p>When at least one series has length {@code >= minimumNormalN} the p-value comes from a normal
 * approximation with a tie-corrected variance; otherwise an exact permutation test over every assignment
 * of the pooled ranks to the two groups is used.
 *
 * <p><b>Side effect:</b> the rank-based methods ({@link #calculateRank}, and everything built on it) sort
 * the caller's input arrays in place.
 */
public class MannWhitneyU {
    protected static Logger logger = LogManager.getLogger(MannWhitneyU.class);

    private static final double NORMAL_MEAN = 0.0;
    private static final double NORMAL_SD = 1.0;
    private static final NormalDistribution NORMAL = new NormalDistribution(NORMAL_MEAN, NORMAL_SD);

    /** Cache of enumerated group-label permutations, keyed by the starting label array. */
    private static final Map<Key, Set<List<Integer>>> PERMUTATIONS = new ConcurrentHashMap<>();

    private int minimumNormalN = 10;

    /**
     * Sets the series length at which the normal approximation replaces the exact permutation test. The
     * approximation is used when either series reaches this length.
     *
     * @param n the new threshold
     */
    public void setMinimumSeriesLengthForNormalApproximation(final int n) {
        this.minimumNormalN = n;
    }

    /**
     * Ranks the pooled values of both series (1-based ranks). Tied values receive the average of the ranks
     * they span.
     *
     * <p>Both input arrays are sorted in place.
     *
     * @param series1 first sample (sorted in place)
     * @param series2 second sample (sorted in place)
     * @return the ranks in ascending value order, plus the size of every tie group with more than one member
     */
    public RankedData calculateRank(final double[] series1, final double[] series2) {
        Arrays.sort(series1);
        Arrays.sort(series2);

        final Rank[] ranks = new Rank[series1.length + series2.length];
        int i = 0;
        int j = 0;
        // Merge the two sorted series. On equal values series1 is taken first; ties are averaged below anyway.
        for (int r = 0; r < ranks.length; r++) {
            final boolean takeFromFirst =
                    j >= series2.length || (i < series1.length && series1[i] <= series2[j]);
            ranks[r] = takeFromFirst ? new Rank(series1[i++], r + 1, 1) : new Rank(series2[j++], r + 1, 2);
        }

        final ArrayList<Integer> numOfTies = new ArrayList<>();
        int start = 0;
        while (start < ranks.length) {
            int end = start + 1;
            while (end < ranks.length && ranks[end].value == ranks[start].value) {
                end++;
            }
            final int count = end - start;
            if (count > 1) {
                // Positions start..end-1 hold ranks start+1..end; their mean is computed exactly in closed form
                // instead of by float accumulation, which loses precision for large ranks.
                final float midRank = (start + 1 + end) / 2.0f;
                for (int k = start; k < end; k++) {
                    ranks[k].rank = midRank;
                }
                numOfTies.add(count);
            }
            start = end;
        }
        return new RankedData(ranks, numOfTies);
    }

    /**
     * Computes U for both series, together with the tie term used in the variance of the normal approximation.
     *
     * <p>Both input arrays are sorted in place.
     *
     * @param series1 first sample
     * @param series2 second sample
     * @return a statistic carrying U1, U2 and the transformed tie count
     */
    public TestStatistic calculateU1andU2(final double[] series1, final double[] series2) {
        final RankedData ranked = calculateRank(series1, series2);
        final Rank[] ranks = ranked.getRank();
        final double numOfTiesForSigma = transformTies(ranks.length, ranked.getNumOfTies());

        // Rank sums are accumulated in double: a float sum loses integer precision past 2^24.
        double r1 = 0.0;
        double r2 = 0.0;
        for (final Rank rank : ranks) {
            if (rank.series == 1) {
                r1 += rank.rank;
            } else {
                r2 += rank.rank;
            }
        }
        final double n1 = series1.length;
        final double n2 = series2.length;
        final double u1 = r1 - n1 * (n1 + 1.0) / 2.0;
        final double u2 = r2 - n2 * (n2 + 1.0) / 2.0;
        return new TestStatistic(u1, u2, numOfTiesForSigma);
    }

    /**
     * Sums {@code t^3 - t} over the tie-group sizes {@code t}. This is the tie term subtracted in the
     * variance of U.
     *
     * @param numOfRanks total number of ranks; a tie group spanning all of them is excluded
     * @param numOfTies sizes of the tie groups
     * @return the summed tie term
     */
    public double transformTies(final int numOfRanks, final ArrayList<Integer> numOfTies) {
        double numOfTiesForSigma = 0.0;
        for (final int count : numOfTies) {
            if (count != numOfRanks) {
                numOfTiesForSigma += (double) count * count * count - count;
            }
        }
        return numOfTiesForSigma;
    }

    /**
     * Returns U1 when {@code whichSeriesDominates} is {@link TestType#FIRST_DOMINATES}, otherwise U2.
     *
     * <p>Both input arrays are sorted in place.
     */
    public TestStatistic calculateOneSidedU(final double[] series1, final double[] series2,
                                            final TestType whichSeriesDominates) {
        final TestStatistic stat = calculateU1andU2(series1, series2);
        final double u = whichSeriesDominates == TestType.FIRST_DOMINATES ? stat.getU1() : stat.getU2();
        return new TestStatistic(u, stat.getTies());
    }

    /**
     * Returns {@code min(U1, U2)}, the statistic for the two-sided test.
     *
     * <p>Both input arrays are sorted in place.
     */
    public TestStatistic calculateTwoSidedU(final double[] series1, final double[] series2) {
        final TestStatistic u1AndU2 = calculateU1andU2(series1, series2);
        return new TestStatistic(Math.min(u1AndU2.getU1(), u1AndU2.getU2()), u1AndU2.getTies());
    }

    /**
     * Converts U into a z-score using the tie-corrected normal approximation.
     *
     * @param u the U statistic
     * @param n1 length of the first series
     * @param n2 length of the second series
     * @param nties summed tie term from {@link #transformTies}
     * @param whichSide test direction; it selects the sign of the continuity correction
     * @return {@code (u - n1*n2/2 - correction) / sigma}
     */
    public double calculateZ(final double u, final int n1, final int n2, final double nties,
                             final TestType whichSide) {
        // Widen before multiplying: n1 * n2 and (n1 + n2) * (n1 + n2 - 1) overflow int for large series.
        final double n = (double) n1 + n2;
        final double m = (double) n1 * n2 / 2.0;

        // The +/-0.5 continuity correction is applied only when ties are present (nties != 0).
        final double correction;
        if (nties == 0.0) {
            correction = 0.0;
        } else if (whichSide == TestType.TWO_SIDED) {
            correction = u - m >= 0.0 ? 0.5 : -0.5;
        } else {
            correction = whichSide == TestType.FIRST_DOMINATES ? -0.5 : 0.5;
        }

        final double sigma = Math.sqrt((double) n1 * n2 / 12.0 * ((n + 1.0) - nties / (n * (n - 1.0))));
        return (u - m - correction) / sigma;
    }

    /**
     * Returns the median of {@code data}. The input must already be sorted ascending; it is not sorted here.
     *
     * @param data sorted values
     * @return the middle value, the mean of the two middle values for even lengths, or NaN when empty
     */
    public double median(final double[] data) {
        final int len = data.length;
        if (len == 0) {
            return Double.NaN;
        }
        final int mid = len / 2;
        return len % 2 == 0 ? (data[mid] + data[mid - 1]) / 2.0 : data[mid];
    }

    /**
     * Runs the Mann-Whitney U test.
     *
     * <p>Both input arrays are sorted in place.
     *
     * @param series1 first sample
     * @param series2 second sample
     * @param whichSide direction of the test
     * @return U, z, p and the absolute difference of medians; all NaN if either series is empty
     */
    public Result test(final double[] series1, final double[] series2, final TestType whichSide) {
        final int n1 = series1.length;
        final int n2 = series2.length;
        if (n1 == 0 || n2 == 0) {
            return new Result(Double.NaN, Double.NaN, Double.NaN, Double.NaN);
        }

        // Computing U sorts both series in place; median() below depends on that ordering.
        final TestStatistic statistic = whichSide == TestType.TWO_SIDED
                ? calculateTwoSidedU(series1, series2)
                : calculateOneSidedU(series1, series2, whichSide);
        final double u = statistic.getTrueU();
        final double nties = statistic.getTies();

        final double z;
        final double p;
        if (n1 >= minimumNormalN || n2 >= minimumNormalN) {
            z = calculateZ(u, n1, n2, nties, whichSide);
            final double twoSidedP = 2.0 * NORMAL.cumulativeProbability(NORMAL_MEAN + z * NORMAL_SD);
            p = whichSide == TestType.TWO_SIDED ? twoSidedP : twoSidedP / 2.0;
        } else {
            if (whichSide != TestType.FIRST_DOMINATES) {
                logger.warn("An exact two-sided MannWhitneyU test was called. Only the one-sided exact test is implemented, use the approximation instead by setting minimumNormalN to 0.");
            }
            p = permutationTest(series1, series2, u);
            z = NORMAL.inverseCumulativeProbability(p);
        }
        return new Result(u, z, p, Math.abs(median(series1) - median(series2)));
    }

    /** Swaps {@code arr[i]} and {@code arr[j]} in place. */
    private void swap(final Integer[] arr, final int i, final int j) {
        final Integer temp = arr[i];
        arr[i] = arr[j];
        arr[j] = temp;
    }

    /**
     * Adds every distinct permutation of {@code temp}, in lexicographic order from its current arrangement,
     * to {@code allPermutations}.
     *
     * <p>{@code temp} must start sorted ascending for the enumeration to be complete. It is permuted in place
     * and left in its final (descending) arrangement.
     */
    private void calculatePermutations(final Integer[] temp, final Set<List<Integer>> allPermutations) {
        allPermutations.add(new ArrayList<>(Arrays.asList(temp)));
        while (true) {
            // Rightmost k with temp[k] < temp[k + 1]; if there is none, this is the last permutation.
            int k = -1;
            for (int i = temp.length - 2; i >= 0; --i) {
                if (temp[i] < temp[i + 1]) {
                    k = i;
                    break;
                }
            }
            if (k == -1) {
                break;
            }
            // Rightmost l > k with temp[k] < temp[l].
            int l = -1;
            for (int i = temp.length - 1; i > k; --i) {
                if (temp[k] < temp[i]) {
                    l = i;
                    break;
                }
            }
            swap(temp, k, l);
            // Reverse the suffix after k to obtain the next lexicographic permutation.
            for (int begin = k + 1, end = temp.length - 1; begin < end; ++begin, --end) {
                swap(temp, begin, end);
            }
            allPermutations.add(new ArrayList<>(Arrays.asList(temp)));
        }
    }

    /**
     * Returns all distinct permutations of {@code listToPermute}, computing them once and caching the result.
     * The returned set is shared through the cache; callers must not modify it.
     *
     * @param listToPermute starting arrangement, sorted ascending; it is permuted in place on a cache miss
     * @param numOfPermutations expected number of permutations (initial set capacity)
     */
    Set<List<Integer>> getPermutations(final Integer[] listToPermute, final int numOfPermutations) {
        // Key copies the array, so the in-place permutation below cannot corrupt the cache key.
        final Key key = new Key(listToPermute);
        Set<List<Integer>> permutations = PERMUTATIONS.get(key);
        if (permutations == null) {
            permutations = new HashSet<>(numOfPermutations);
            calculatePermutations(listToPermute, permutations);
            final Set<List<Integer>> existing = PERMUTATIONS.putIfAbsent(key, permutations);
            if (existing != null) {
                permutations = existing;
            }
        }
        return permutations;
    }

    /**
     * Exact permutation p-value for {@code testStatU}, computed against the distribution of U1 over all
     * assignments of the pooled ranks to groups of sizes n1 and n2.
     *
     * <p>The result is {@code P(U < testStatU) + 0.5 * P(U == testStatU)}. Both input arrays are sorted in place.
     */
    public double permutationTest(final double[] series1, final double[] series2, final double testStatU) {
        final int n1 = series1.length;
        final int n2 = series2.length;
        final Rank[] ranks = calculateRank(series1, series2).getRank();

        // Label 0 = series 1, label 1 = series 2. Starting from the sorted labelling enumerates all of them.
        final Integer[] firstPermutation = new Integer[n1 + n2];
        for (int i = 0; i < firstPermutation.length; ++i) {
            firstPermutation[i] = i < n1 ? 0 : 1;
        }
        final int numOfPerms = (int) MathUtils.binomialCoefficient(n1 + n2, n2);
        final Set<List<Integer>> allPermutations = getPermutations(firstPermutation, numOfPerms);

        // Histogram keys are 2*U so that half-integer U values (from mid-ranks) become exact long keys.
        final Histogram<Long> histo = new Histogram<>();
        final double[] newSeries1 = new double[n1];
        for (final List<Integer> currPerm : allPermutations) {
            int series1End = 0;
            for (int i = 0; i < currPerm.size(); ++i) {
                if (currPerm.get(i) == 0) {
                    newSeries1[series1End++] = ranks[i].rank;
                }
            }
            assert series1End == n1;
            final double newU = MathUtils.sum(newSeries1) - (double) (n1 * (n1 + 1)) / 2.0;
            histo.increment(FastMath.round(2.0 * newU));
        }

        final long observedKey = FastMath.round(2.0 * testStatU);
        // The observed value can be missing from the U1 distribution (e.g. U2 with asymmetric ties).
        final Histogram.Bin<Long> observedBin = histo.get(observedKey);
        double sumOfAllSmallerBins = observedBin == null ? 0.0 : observedBin.getValue() / 2.0;
        for (final Histogram.Bin<Long> bin : histo.values()) {
            if (bin.getId() < observedKey) {
                sumOfAllSmallerBins += bin.getValue();
            }
        }
        return sumOfAllSmallerBins / histo.getCount();
    }

    /** Direction of the test. */
    public static enum TestType {
        FIRST_DOMINATES,
        SECOND_DOMINATES,
        TWO_SIDED;

    }

    /** Content-based cache key over a label array; holds its own copy of the array. */
    private static final class Key {
        final Integer[] listToPermute;

        private Key(final Integer[] listToPermute) {
            // Defensive copy: the caller permutes its array in place after building the key.
            this.listToPermute = listToPermute.clone();
        }

        @Override
        public boolean equals(final Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            return Arrays.equals(this.listToPermute, ((Key) o).listToPermute);
        }

        @Override
        public int hashCode() {
            return Arrays.hashCode(this.listToPermute);
        }
    }

    /** Pooled ranks plus the sizes of the tie groups (groups of size 1 are not recorded). */
    public static class RankedData {
        private final Rank[] rank;
        private final ArrayList<Integer> numOfTies;

        public RankedData(final Rank[] rank, final ArrayList<Integer> numOfTies) {
            this.rank = rank;
            this.numOfTies = numOfTies;
        }

        public Rank[] getRank() {
            return this.rank;
        }

        public ArrayList<Integer> getNumOfTies() {
            return this.numOfTies;
        }
    }

    /**
     * Either both U values (the 3-argument constructor) or a single chosen U (the 2-argument constructor),
     * together with the transformed tie count. Fields that are not set hold NaN.
     */
    public static class TestStatistic {
        private final double u1;
        private final double u2;
        private final double trueU;
        private final double numOfTiesTransformed;

        public TestStatistic(final double u1, final double u2, final double numOfTiesTransformed) {
            this.u1 = u1;
            this.u2 = u2;
            this.numOfTiesTransformed = numOfTiesTransformed;
            this.trueU = Double.NaN;
        }

        public TestStatistic(final double trueU, final double numOfTiesTransformed) {
            this.trueU = trueU;
            this.numOfTiesTransformed = numOfTiesTransformed;
            this.u1 = Double.NaN;
            this.u2 = Double.NaN;
        }

        public double getU1() {
            return this.u1;
        }

        public double getU2() {
            return this.u2;
        }

        public double getTies() {
            return this.numOfTiesTransformed;
        }

        public double getTrueU() {
            return this.trueU;
        }
    }

    /** Outcome of {@link MannWhitneyU#test}: U, z-score, p-value and absolute median shift. */
    public static class Result {
        private final double u;
        private final double z;
        private final double p;
        private final double medianShift;

        public Result(final double u, final double z, final double p, final double medianShift) {
            this.u = u;
            this.z = z;
            this.p = p;
            this.medianShift = medianShift;
        }

        public double getU() {
            return this.u;
        }

        public double getZ() {
            return this.z;
        }

        public double getP() {
            return this.p;
        }

        public double getMedianShift() {
            return this.medianShift;
        }
    }

    /** One pooled observation: its value, (possibly averaged) rank, and source series (1 or 2). */
    private static final class Rank implements Comparable<Rank> {
        final double value;
        float rank;
        final int series;

        private Rank(final double value, final float rank, final int series) {
            this.value = value;
            this.rank = rank;
            this.series = series;
        }

        /** Orders by value. Double.compare avoids truncating fractional differences to zero. */
        @Override
        public int compareTo(final Rank that) {
            return Double.compare(this.value, that.value);
        }

        @Override
        public String toString() {
            return "Rank{value=" + this.value + ", rank=" + this.rank + ", series=" + this.series + '}';
        }
    }
}

