/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.classify;

import com.aliasi.classify.PrecisionRecallEvaluation;
import com.aliasi.util.Scored;
import com.aliasi.util.ScoredObject;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

/**
 * Accumulates scored binary decisions and computes ranked evaluation
 * statistics over them.
 *
 * <p>The statistics are:
 * <ul>
 *   <li>precision-recall and ROC curves, and the areas under them;</li>
 *   <li>average precision, R-precision and precision at a rank;</li>
 *   <li>reciprocal rank and maximum F-measure.</li>
 * </ul>
 *
 * <p>Returned cases are ranked by sorting them with
 * {@link ScoredObject#reverseComparator()} (presumably highest score first;
 * confirm against {@code ScoredObject}). Cases whose scores differ by less
 * than {@link #FLOATING_POINT_EQUALS_EPSILON} are treated as tied when
 * building curves.
 *
 * <p>Reference items that never appear in the ranking can be registered with
 * {@link #addMisses(int)} (positives) and {@link #addNegativeMisses(int)}
 * (negatives). They count toward the reference totals only.
 *
 * <p>Instances are mutable and perform no synchronization.
 */
public class ScoredPrecisionRecallEvaluation {

    /** Tolerance below which two case scores are considered tied. */
    public static final double FLOATING_POINT_EQUALS_EPSILON = 1.0E-12;

    // Every returned case; re-sorted in place by sortedCases().
    private final List<Case> mCases = new ArrayList<Case>();

    // Reference totals, including misses added without a case.
    private int mNegativeRef = 0;
    private int mPositiveRef = 0;

    // Typed seed for List.toArray so curves are returned as double[][].
    static final double[][] EMPTY_DOUBLE_2D_ARRAY = new double[0][];

    /**
     * Adds a returned (ranked) case.
     *
     * @param correct {@code true} if the case is a positive reference item,
     *     {@code false} if it is a negative one.
     * @param score the score used to rank the case.
     */
    public void addCase(boolean correct, double score) {
        this.mCases.add(new Case(correct, score));
        if (correct) {
            ++this.mPositiveRef;
        } else {
            ++this.mNegativeRef;
        }
    }

    /**
     * Adds positive reference items that were never returned. They lower
     * recall but never appear in the ranking.
     *
     * @param count number of missed positive items.
     * @throws IllegalArgumentException if {@code count} is negative.
     */
    public void addMisses(int count) {
        if (count < 0) {
            String msg = "Miss count must be non-negative. Found count=" + count;
            throw new IllegalArgumentException(msg);
        }
        this.mPositiveRef += count;
    }

    /**
     * Adds negative reference items that were never returned.
     *
     * @param count number of missed negative items.
     * @throws IllegalArgumentException if {@code count} is negative.
     */
    public void addNegativeMisses(int count) {
        if (count < 0) {
            String msg = "Miss count must be non-negative. Found count=" + count;
            throw new IllegalArgumentException(msg);
        }
        this.mNegativeRef += count;
    }

    /**
     * Returns the total number of reference items, positive plus negative,
     * including misses.
     *
     * @return number of reference items.
     */
    public int numCases() {
        return this.mPositiveRef + this.mNegativeRef;
    }

    /**
     * Returns the number of positive reference items, including misses.
     *
     * @return number of positive reference items.
     */
    public int numPositiveRef() {
        return this.mPositiveRef;
    }

    /**
     * Returns the number of negative reference items, including misses.
     *
     * @return number of negative reference items.
     */
    public int numNegativeRef() {
        return this.mNegativeRef;
    }

    /**
     * Returns the R-precision: the precision over the top {@code R} ranked
     * cases, where {@code R} is {@link #numPositiveRef()}.
     *
     * <p>At rank {@code R}, precision and recall are both {@code TP / R}, so
     * this is also the precision-recall break-even point.
     *
     * @return the R-precision, or 1.0 if there are no positive reference items.
     */
    public double rPrecision() {
        if (this.mPositiveRef == 0) {
            return 1.0;
        }
        // Indexing into prCurve() cannot stand in for a rank. Its points are
        // per (epsilon-)tied score group, and it starts with a synthetic (0,1)
        // point, so point k is not the state after k cases.
        return this.precisionAt(this.mPositiveRef);
    }

    /**
     * Returns the interpolated precision at the eleven recall levels
     * 0.0, 0.1, ..., 1.0, read from the interpolated PR curve.
     *
     * @return an array of 11 precision values indexed by recall level.
     */
    public double[] elevenPtInterpPrecision() {
        double[] xs = new double[11];
        double[][] rps = this.prCurve(true);
        for (int i = 0; i <= 10; ++i) {
            xs[i] = ScoredPrecisionRecallEvaluation.precisionAtRecall(0.1 * (double)i, rps);
        }
        return xs;
    }

    /**
     * Returns the precision of the first curve point whose recall reaches
     * {@code recall}.
     *
     * @param recall target recall level.
     * @param rps curve points as {@code {recall, precision}}, in curve order.
     * @return the matching point's precision, or 0.0 if no point reaches the
     *     target.
     */
    static double precisionAtRecall(double recall, double[][] rps) {
        for (int i = 0; i < rps.length; ++i) {
            // Small slack so that e.g. 0.1 * 3 still matches a recall of 0.3.
            if (rps[i][0] + 1.0E-13 >= recall) {
                return rps[i][1];
            }
        }
        return 0.0;
    }

    /**
     * Returns the average precision.
     *
     * <p>It sums the precision at each point of the uninterpolated PR curve
     * where recall strictly increases, then divides by the number of positive
     * reference items.
     *
     * <p>NOTE(review): when there are no positive reference items this divides
     * by zero and the result is not meaningful.
     *
     * @return average precision.
     */
    public double averagePrecision() {
        double recall = 0.0;
        double[][] rps = this.prCurve(false);
        double sum = 0.0;
        for (double[] rp : rps) {
            if (!(rp[0] > recall)) {
                continue;
            }
            sum += rp[1];
            recall = rp[0];
        }
        return sum / (double)this.mPositiveRef;
    }

    /**
     * Returns the precision-recall curve as {@code {recall, precision}} points.
     *
     * <p>One point is emitted at every score boundary between tied groups,
     * plus one for the final state. Points at (0,0) are dropped. The curve
     * always starts at (0,1) and ends at (1,0); those endpoints are added
     * when missing.
     *
     * @param interpolate if {@code true}, replace each precision with the
     *     maximum precision at that or any later point, and keep only one
     *     point per recall level.
     * @return the curve points in order of increasing rank depth.
     */
    public double[][] prCurve(boolean interpolate) {
        PrecisionRecallEvaluation eval = new PrecisionRecallEvaluation();
        List<double[]> prList = new ArrayList<double[]>();
        double previousScore = -1.0;
        for (Case cse : this.sortedCases()) {
            // A score change closes the previous tie group: record the state
            // reached so far. Ties are judged against the preceding case.
            if (eval.total() > 0L && !this.epsilonEquals(cse.score(), previousScore)) {
                double r = ScoredPrecisionRecallEvaluation.div(eval.truePositive(), this.mPositiveRef);
                double p = eval.precision();
                if (r != 0.0 || p != 0.0) {
                    prList.add(new double[]{r, p});
                }
            }
            eval.addCase(cse.mCorrect, true);
            previousScore = cse.score();
        }
        double r = ScoredPrecisionRecallEvaluation.div(eval.truePositive(), this.mPositiveRef);
        double p = eval.precision();
        if (r != 0.0 || p != 0.0) {
            prList.add(new double[]{r, p});
        }
        if (r != 1.0 || p != 0.0) {
            prList.add(new double[]{1.0, 0.0});
        }
        double[] first = prList.get(0);
        if (first[0] != 0.0 || first[1] != 1.0) {
            prList.add(0, new double[]{0.0, 1.0});
        }
        return ScoredPrecisionRecallEvaluation.interpolate(prList, interpolate);
    }

    /**
     * Returns the precision-recall curve with scores, as
     * {@code {recall, precision, score}} points.
     *
     * <p>The score of a point is the score of the tie group that the point
     * closes. Unlike {@link #prCurve(boolean)}, no (0,1) or (1,0) endpoints
     * are added.
     *
     * @param interpolate if {@code true}, interpolate precision as in
     *     {@link #prCurve(boolean)}.
     * @return the curve points in order of increasing rank depth.
     */
    public double[][] prScoreCurve(boolean interpolate) {
        PrecisionRecallEvaluation eval = new PrecisionRecallEvaluation();
        List<double[]> prList = new ArrayList<double[]>();
        // Score of the first case of the current tie group.
        double previousScore = -1.0;
        for (Case cse : this.sortedCases()) {
            if (eval.total() > 0L && !this.epsilonEquals(cse.score(), previousScore)) {
                double r = ScoredPrecisionRecallEvaluation.div(eval.truePositive(), this.mPositiveRef);
                double p = eval.precision();
                if (r != 0.0 || p != 0.0) {
                    prList.add(new double[]{r, p, previousScore});
                }
                previousScore = cse.score();
            } else if (eval.total() == 0L) {
                previousScore = cse.score();
            }
            eval.addCase(cse.mCorrect, true);
        }
        double r = ScoredPrecisionRecallEvaluation.div(eval.truePositive(), this.mPositiveRef);
        double p = eval.precision();
        if (r != 0.0 || p != 0.0) {
            prList.add(new double[]{r, p, previousScore});
        }
        return ScoredPrecisionRecallEvaluation.interpolate(prList, interpolate);
    }

    // True if the two scores differ by less than FLOATING_POINT_EQUALS_EPSILON.
    private boolean epsilonEquals(double val1, double val2) {
        return Math.abs(val1 - val2) < FLOATING_POINT_EQUALS_EPSILON;
    }

    /**
     * Returns the ROC curve as {@code {false positive rate, true positive rate}}
     * points.
     *
     * <p>One point is emitted at every score boundary between tied groups,
     * plus one for the final state. The curve always starts at (0,0) and ends
     * at (1,1); those endpoints are added when missing.
     *
     * @param interpolate if {@code true}, keep only the last point for each
     *     false positive rate.
     * @return the curve points in order of increasing rank depth.
     */
    public double[][] rocCurve(boolean interpolate) {
        PrecisionRecallEvaluation eval = new PrecisionRecallEvaluation();
        List<double[]> ssList = new ArrayList<double[]>();
        // Negatives not yet ranked; trueNegs / mNegativeRef is the specificity.
        int trueNegs = this.mNegativeRef;
        double previousScore = -1.0;
        for (Case cse : this.sortedCases()) {
            if (eval.total() > 0L && !this.epsilonEquals(cse.score(), previousScore)) {
                double r = ScoredPrecisionRecallEvaluation.div(eval.truePositive(), this.mPositiveRef);
                double rr = ScoredPrecisionRecallEvaluation.div(trueNegs, this.mNegativeRef);
                ssList.add(new double[]{1.0 - rr, r});
                previousScore = cse.score();
            } else if (eval.total() == 0L) {
                previousScore = cse.score();
            }
            boolean correct = cse.mCorrect;
            eval.addCase(correct, true);
            if (!correct) {
                --trueNegs;
            }
        }
        double r = ScoredPrecisionRecallEvaluation.div(eval.truePositive(), this.mPositiveRef);
        double rr = ScoredPrecisionRecallEvaluation.div(trueNegs, this.mNegativeRef);
        ssList.add(new double[]{1.0 - rr, r});
        if (r != 1.0 || rr != 0.0) {
            ssList.add(new double[]{1.0, 1.0});
        }
        double[] first = ssList.get(0);
        if (first[0] != 0.0 || first[1] != 0.0) {
            ssList.add(0, new double[]{0.0, 0.0});
        }
        if (interpolate) {
            ssList = ScoredPrecisionRecallEvaluation.interpolateRoc(ssList);
        }
        return ssList.toArray(EMPTY_DOUBLE_2D_ARRAY);
    }

    /**
     * Collapses each run of consecutive ROC points that share a false positive
     * rate down to the run's last point (the highest TPR reached at that FPR).
     *
     * @param ssList non-empty list of {@code {fpr, tpr}} points in curve order.
     * @return the collapsed points, in the same order.
     */
    static List<double[]> interpolateRoc(List<double[]> ssList) {
        List<double[]> result = new ArrayList<double[]>();
        for (int i = 0; i + 1 < ssList.size(); ++i) {
            if (ssList.get(i)[0] != ssList.get(i + 1)[0]) {
                result.add(ssList.get(i));
            }
        }
        result.add(ssList.get(ssList.size() - 1));
        return result;
    }

    /**
     * Returns the maximum F(1) measure over the uninterpolated PR curve.
     *
     * @return maximum balanced F-measure.
     */
    public double maximumFMeasure() {
        return this.maximumFMeasure(1.0);
    }

    /**
     * Returns the maximum F(beta) measure over the points of the
     * uninterpolated PR curve.
     *
     * @param beta relative weight passed to
     *     {@code PrecisionRecallEvaluation.fMeasure}.
     * @return the maximum F-measure found, or 0.0 if none exceeds 0.
     */
    public double maximumFMeasure(double beta) {
        double maxF = 0.0;
        for (double[] pr : this.prCurve(false)) {
            double f = PrecisionRecallEvaluation.fMeasure(beta, pr[0], pr[1]);
            maxF = Math.max(maxF, f);
        }
        return maxF;
    }

    /**
     * Returns the fraction of the top {@code rank} ranked cases that are
     * correct.
     *
     * <p>If fewer than {@code rank} cases exist, the missing positions count
     * as incorrect.
     *
     * @param rank number of top-ranked cases to inspect.
     * @return precision at {@code rank}, or 1.0 when {@code rank} is 0.
     * @throws IllegalArgumentException if {@code rank} is negative.
     */
    public double precisionAt(int rank) {
        if (rank < 0) {
            String msg = "Rank must be non-negative. Found rank=" + rank;
            throw new IllegalArgumentException(msg);
        }
        if (rank == 0) {
            return 1.0;
        }
        int correctCount = 0;
        Iterator<Case> it = this.sortedCases().iterator();
        for (int i = 0; i < rank && it.hasNext(); ++i) {
            if (it.next().mCorrect) {
                ++correctCount;
            }
        }
        return (double)correctCount / (double)rank;
    }

    /**
     * Returns the precision-recall break-even point, which equals the
     * R-precision.
     *
     * @return the break-even point.
     */
    public double prBreakevenPoint() {
        return this.rPrecision();
    }

    /**
     * Returns {@code 1 / k}, where {@code k} is the 1-based rank of the first
     * correct case.
     *
     * @return reciprocal rank, or 0.0 if no ranked case is correct.
     */
    public double reciprocalRank() {
        int i = 0;
        for (Case cse : this.sortedCases()) {
            if (cse.mCorrect) {
                return 1.0 / (double)(i + 1);
            }
            ++i;
        }
        return 0.0;
    }

    /**
     * Returns the trapezoidal area under the PR curve.
     *
     * @param interpolate whether to use the interpolated curve.
     * @return area under the PR curve.
     */
    public double areaUnderPrCurve(boolean interpolate) {
        return ScoredPrecisionRecallEvaluation.areaUnder(this.prCurve(interpolate));
    }

    /**
     * Returns the trapezoidal area under the ROC curve.
     *
     * @param interpolate whether to use the interpolated curve.
     * @return area under the ROC curve.
     */
    public double areaUnderRocCurve(boolean interpolate) {
        return ScoredPrecisionRecallEvaluation.areaUnder(this.rocCurve(interpolate));
    }

    /**
     * Returns a multi-line report of the summary statistics.
     *
     * <p>Precision at ranks 5, 10, 25, 100 and 500 is included only for ranks
     * that do not exceed the number of ranked cases.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("  Area Under PR Curve (interpolated)=" + this.areaUnderPrCurve(true));
        sb.append("\n  Area Under PR Curve (uninterpolated)=" + this.areaUnderPrCurve(false));
        sb.append("\n  Area Under ROC Curve (interpolated)=" + this.areaUnderRocCurve(true));
        sb.append("\n  Area Under ROC Curve (uninterpolated)=" + this.areaUnderRocCurve(false));
        sb.append("\n  Average Precision=" + this.averagePrecision());
        sb.append("\n  Maximum F(1) Measure=" + this.maximumFMeasure());
        sb.append("\n  BEP (Precision-Recall break even point)=" + this.prBreakevenPoint());
        sb.append("\n  Reciprocal Rank=" + this.reciprocalRank());
        int[] ranks = new int[]{5, 10, 25, 100, 500};
        // Ranks ascend, so stop at the first one deeper than the ranked list.
        for (int i = 0; i < ranks.length && ranks[i] <= this.mCases.size(); ++i) {
            sb.append("\n  Precision at " + ranks[i] + "=" + this.precisionAt(ranks[i]));
        }
        return sb.toString();
    }

    /**
     * Prints a PR curve as precision, recall and F(1) columns, then flushes
     * the writer.
     *
     * @param prCurve points as {@code {recall, precision}}.
     * @param pw destination writer.
     */
    public static void printPrecisionRecallCurve(double[][] prCurve, PrintWriter pw) {
        pw.printf("%8s %8s %8s\n", "PRECI.", "RECALL", "F");
        for (double[] pr : prCurve) {
            pw.printf("%8.6f %8.6f %8.6f\n", pr[1], pr[0], PrecisionRecallEvaluation.fMeasure(1.0, pr[0], pr[1]));
        }
        pw.flush();
    }

    /**
     * Prints a PR score curve as precision, recall and score columns, then
     * flushes the writer.
     *
     * @param prScoreCurve points as {@code {recall, precision, score}}.
     * @param pw destination writer.
     */
    public static void printScorePrecisionRecallCurve(double[][] prScoreCurve, PrintWriter pw) {
        pw.printf("%8s %8s %8s\n", "PRECI.", "RECALL", "SCORE");
        for (double[] pr : prScoreCurve) {
            pw.printf("%8.6f %8.6f %8.6f\n", pr[1], pr[0], pr[2]);
        }
        pw.flush();
    }

    // Sorts the cases in place into ranking order and returns the live list.
    private List<Case> sortedCases() {
        Collections.sort(this.mCases, ScoredObject.reverseComparator());
        return this.mCases;
    }

    // Plain floating-point division; a zero denominator yields NaN/Infinity.
    static double div(double x, double y) {
        return x / y;
    }

    /**
     * Optionally interpolates a PR-style curve.
     *
     * <p>Interpolation has two steps:
     * <ol>
     *   <li>Sweep from the end, replacing each point's precision (index 1)
     *       with the maximum precision at that or any later point.</li>
     *   <li>Keep only the first point of each run of equal recall (index 0),
     *       starting from a (0,1) seed.</li>
     * </ol>
     * The points in {@code prList} are modified in place.
     *
     * <p>NOTE(review): the (0,1) seed is a 2-element array. For a
     * {@code prScoreCurve} whose first recall is not 0, it lands in the
     * result without a score column; confirm whether callers expect that.
     *
     * @param prList curve points in order of increasing rank depth.
     * @param interpolate if {@code false}, return the points unchanged.
     * @return the (possibly interpolated) points.
     */
    private static double[][] interpolate(List<double[]> prList, boolean interpolate) {
        if (!interpolate) {
            return prList.toArray(EMPTY_DOUBLE_2D_ARRAY);
        }
        double maxP = 0.0;
        for (int i = prList.size() - 1; i >= 0; --i) {
            double[] rp = prList.get(i);
            if (rp[1] > maxP) {
                maxP = rp[1];
            } else {
                rp[1] = maxP;
            }
        }
        List<double[]> trimmedResultList = new ArrayList<double[]>();
        double[] last = new double[]{0.0, 1.0};
        for (double[] rp : prList) {
            // The first point of an equal-recall run carries the run's highest
            // interpolated precision; drop the rest of the run.
            if (rp[0] == last[0]) {
                continue;
            }
            trimmedResultList.add(last);
            last = rp;
        }
        trimmedResultList.add(last);
        return trimmedResultList.toArray(EMPTY_DOUBLE_2D_ARRAY);
    }

    // Trapezoidal-rule area under a piecewise-linear curve of {x, y} points.
    private static double areaUnder(double[][] f) {
        double area = 0.0;
        for (int i = 1; i < f.length; ++i) {
            area += ScoredPrecisionRecallEvaluation.area(f[i - 1][0], f[i - 1][1], f[i][0], f[i][1]);
        }
        return area;
    }

    // Area of the trapezoid between (x1,y1) and (x2,y2) and the x-axis.
    private static double area(double x1, double y1, double x2, double y2) {
        return (y1 + y2) * (x2 - x1) / 2.0;
    }

    /** A single ranked case: its correctness and its score. */
    static class Case
    implements Scored {
        private final boolean mCorrect;
        private final double mScore;

        Case(boolean correct, double score) {
            this.mCorrect = correct;
            this.mScore = score;
        }

        @Override
        public double score() {
            return this.mScore;
        }

        @Override
        public String toString() {
            return this.mCorrect + " : " + this.mScore;
        }
    }
}

