/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.classify;

import com.aliasi.classify.PrecisionRecallEvaluation;
import com.aliasi.stats.Statistics;
import com.aliasi.util.Math;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

/**
 * A square matrix of counts that cross-classifies cases by reference
 * category (row index) and response category (column index), together with
 * summary statistics computed from those counts.
 *
 * <p>Instances are mutable through the {@code increment} methods and no
 * method is synchronized.
 */
public class ConfusionMatrix {
    /** Category labels; position {@code i} labels row {@code i} and column {@code i}. */
    private final String[] mCategories;
    /** Counts indexed as {@code mMatrix[reference][response]}; kept non-negative by the constructors and {@link #incrementByN}. */
    private final int[][] mMatrix;
    /** Reverse lookup from category label to its index. */
    private final Map<String, Integer> mCategoryToIndex = new HashMap<String, Integer>();

    /**
     * Construct a confusion matrix for the given categories with every count
     * set to zero.
     *
     * <p>The category array is copied, so later changes to it do not affect
     * this matrix.
     *
     * @param categories Category labels in index order.
     */
    public ConfusionMatrix(String[] categories) {
        this.mCategories = categories.clone();
        // New int arrays are zero-filled by the JVM; no explicit reset needed.
        this.mMatrix = new int[categories.length][categories.length];
        this.indexCategories();
    }

    /**
     * Construct a confusion matrix from existing counts.
     *
     * <p>Both arguments are copied, so later changes to them do not affect
     * this matrix (and cannot break the non-negativity check below).
     *
     * @param categories Category labels in index order.
     * @param matrix Counts indexed as {@code matrix[reference][response]}.
     * @throws IllegalArgumentException If the matrix does not have exactly one
     * row and one column per category, or if any entry is negative.
     */
    public ConfusionMatrix(String[] categories, int[][] matrix) {
        if (categories.length != matrix.length) {
            String msg = "Categories and matrix must be of same length. Found categories length=" + categories.length + " and matrix length=" + matrix.length;
            throw new IllegalArgumentException(msg);
        }
        for (int j = 0; j < matrix.length; ++j) {
            if (categories.length != matrix[j].length) {
                String msg = "Categories and all matrix rows must be of same length. Found categories length=" + categories.length + " Found row " + j + " length=" + matrix[j].length;
                throw new IllegalArgumentException(msg);
            }
        }
        int len = matrix.length;
        for (int i = 0; i < len; ++i) {
            for (int j = 0; j < len; ++j) {
                if (matrix[i][j] < 0) {
                    String msg = "Matrix entries must be non-negative. matrix[" + i + "][" + j + "]=" + matrix[i][j];
                    throw new IllegalArgumentException(msg);
                }
            }
        }
        this.mCategories = categories.clone();
        this.mMatrix = copyMatrix(matrix);
        // Previously omitted for this constructor, which made getIndex()
        // always return -1 and increment(String,String) always fail.
        this.indexCategories();
    }

    /** Fill the label-to-index map from the category array. */
    private void indexCategories() {
        for (int i = 0; i < this.mCategories.length; ++i) {
            this.mCategoryToIndex.put(this.mCategories[i], i);
        }
    }

    /** Return a deep copy of a two-dimensional int array (row arrays are copied too). */
    private static int[][] copyMatrix(int[][] matrix) {
        int[][] copy = new int[matrix.length][];
        for (int i = 0; i < matrix.length; ++i) {
            copy[i] = matrix[i].clone();
        }
        return copy;
    }

    /**
     * Return a copy of the category labels in index order.
     *
     * @return Copy of the categories.
     */
    public String[] categories() {
        return this.mCategories.clone();
    }

    /**
     * Return the number of categories.
     *
     * @return Number of categories.
     */
    public int numCategories() {
        // Read the length directly rather than cloning via categories();
        // this method is called inside many loops.
        return this.mCategories.length;
    }

    /**
     * Return the index of the specified category, or -1 if it is not one of
     * this matrix's categories.
     *
     * @param category Category label.
     * @return Index of the category, or -1 if unknown.
     */
    public int getIndex(String category) {
        Integer index = this.mCategoryToIndex.get(category);
        if (index == null) {
            return -1;
        }
        return index;
    }

    /**
     * Return a copy of the counts, indexed as {@code [reference][response]}.
     * The copy is deep, so modifying it does not affect this matrix.
     *
     * @return Copy of the count matrix.
     */
    public int[][] matrix() {
        // int[][].clone() would share the row arrays with the caller.
        return copyMatrix(this.mMatrix);
    }

    /**
     * Add one to the count for the given reference and response indices.
     *
     * @param referenceCategoryIndex Row index.
     * @param responseCategoryIndex Column index.
     * @throws IllegalArgumentException If either index is out of range.
     */
    public void increment(int referenceCategoryIndex, int responseCategoryIndex) {
        this.checkIndex("reference", referenceCategoryIndex);
        this.checkIndex("response", responseCategoryIndex);
        this.mMatrix[referenceCategoryIndex][responseCategoryIndex] += 1;
    }

    /**
     * Add {@code num} (which may be negative) to the count for the given
     * reference and response indices.
     *
     * @param referenceCategoryIndex Row index.
     * @param responseCategoryIndex Column index.
     * @param num Amount to add.
     * @throws IllegalArgumentException If either index is out of range or if
     * the resulting count would be negative.
     */
    public void incrementByN(int referenceCategoryIndex, int responseCategoryIndex, int num) {
        this.checkIndex("reference", referenceCategoryIndex);
        this.checkIndex("response", responseCategoryIndex);
        int current = this.mMatrix[referenceCategoryIndex][responseCategoryIndex];
        if (current + num < 0) {
            // Report the cell actually being changed (was [ref][ref]).
            String msg = "Cannot decrement to less than 0 value. referenceCategoryIndex=" + referenceCategoryIndex + " responseCategoryIndex=" + responseCategoryIndex + " matrix[referenceCategoryIndex][responseCategoryIndex]=" + current + " increment=" + num;
            throw new IllegalArgumentException(msg);
        }
        this.mMatrix[referenceCategoryIndex][responseCategoryIndex] = current + num;
    }

    /**
     * Add one to the count for the given reference and response categories.
     *
     * @param referenceCategory Reference category label.
     * @param responseCategory Response category label.
     * @throws IllegalArgumentException If either label is not a category of
     * this matrix.
     */
    public void increment(String referenceCategory, String responseCategory) {
        this.increment(this.knownIndex("reference", referenceCategory),
                       this.knownIndex("response", responseCategory));
    }

    /** Look up a category's index, throwing a descriptive exception if it is unknown. */
    private int knownIndex(String argMsg, String category) {
        int index = this.getIndex(category);
        if (index < 0) {
            String msg = "Unknown " + argMsg + " category=" + category;
            throw new IllegalArgumentException(msg);
        }
        return index;
    }

    /**
     * Return the count for the given reference and response indices.
     *
     * @param referenceCategoryIndex Row index.
     * @param responseCategoryIndex Column index.
     * @return The count in that cell.
     * @throws IllegalArgumentException If either index is out of range.
     */
    public int count(int referenceCategoryIndex, int responseCategoryIndex) {
        this.checkIndex("reference", referenceCategoryIndex);
        this.checkIndex("response", responseCategoryIndex);
        return this.mMatrix[referenceCategoryIndex][responseCategoryIndex];
    }

    /**
     * Return the sum of all counts in the matrix.
     *
     * @return Total count.
     */
    public int totalCount() {
        int total = 0;
        for (int[] row : this.mMatrix) {
            for (int cell : row) {
                total += cell;
            }
        }
        return total;
    }

    /**
     * Return the sum of the diagonal counts (reference equals response).
     *
     * @return Number of correct cases.
     */
    public int totalCorrect() {
        int total = 0;
        for (int i = 0; i < this.mMatrix.length; ++i) {
            total += this.mMatrix[i][i];
        }
        return total;
    }

    /**
     * Return {@code totalCorrect() / totalCount()}; NaN for an empty matrix.
     *
     * @return Total accuracy.
     */
    public double totalAccuracy() {
        return (double) this.totalCorrect() / (double) this.totalCount();
    }

    /**
     * Return {@link #confidence(double)} with {@code z = 1.96}.
     *
     * @return Half-width of the 95% interval around total accuracy.
     */
    public double confidence95() {
        return this.confidence(1.96);
    }

    /**
     * Return {@link #confidence(double)} with {@code z = 2.58}.
     *
     * @return Half-width of the 99% interval around total accuracy.
     */
    public double confidence99() {
        return this.confidence(2.58);
    }

    /**
     * Return {@code z * sqrt(p * (1 - p) / n)}, where {@code p} is the total
     * accuracy and {@code n} the total count.
     *
     * @param z Number of standard deviations.
     * @return Half-width of the interval around total accuracy.
     */
    public double confidence(double z) {
        double p = this.totalAccuracy();
        double n = this.totalCount();
        return z * java.lang.Math.sqrt(p * (1.0 - p) / n);
    }

    /**
     * Return the base-2 entropy of the reference likelihoods of the
     * one-versus-all evaluations. Zero-probability categories contribute
     * zero (the usual {@code 0 log 0 = 0} convention) instead of NaN.
     *
     * @return Reference entropy in bits.
     */
    public double referenceEntropy() {
        double sum = 0.0;
        for (int i = 0; i < this.numCategories(); ++i) {
            double prob = this.oneVsAll(i).referenceLikelihood();
            if (prob == 0.0) {
                continue; // 0 * log2(0) would be NaN
            }
            sum += prob * Math.log2(prob);
        }
        return -sum;
    }

    /**
     * Return the base-2 entropy of the response likelihoods of the
     * one-versus-all evaluations. Zero-probability categories contribute zero.
     *
     * @return Response entropy in bits.
     */
    public double responseEntropy() {
        double sum = 0.0;
        for (int i = 0; i < this.numCategories(); ++i) {
            double prob = this.oneVsAll(i).responseLikelihood();
            if (prob == 0.0) {
                continue; // 0 * log2(0) would be NaN
            }
            sum += prob * Math.log2(prob);
        }
        return -sum;
    }

    /**
     * Return {@code -sum_i refProb(i) * log2(respProb(i))} over categories.
     * Categories with zero reference probability contribute zero.
     *
     * @return Cross entropy in bits.
     */
    public double crossEntropy() {
        double sum = 0.0;
        for (int i = 0; i < this.numCategories(); ++i) {
            PrecisionRecallEvaluation eval = this.oneVsAll(i);
            double referenceProb = eval.referenceLikelihood();
            if (referenceProb == 0.0) {
                continue; // avoid 0 * -Infinity = NaN
            }
            double responseProb = eval.responseLikelihood();
            sum += referenceProb * Math.log2(responseProb);
        }
        return -sum;
    }

    /**
     * Return the base-2 entropy of the cell probabilities
     * {@code count(i,j) / totalCount()}, skipping empty cells.
     *
     * @return Joint entropy in bits.
     */
    public double jointEntropy() {
        double totalCount = this.totalCount();
        double entropySum = 0.0;
        for (int i = 0; i < this.numCategories(); ++i) {
            for (int j = 0; j < this.numCategories(); ++j) {
                double prob = (double) this.mMatrix[i][j] / totalCount;
                if (prob <= 0.0) {
                    continue;
                }
                entropySum += prob * Math.log2(prob);
            }
        }
        return -entropySum;
    }

    /**
     * Return the base-2 entropy of the response distribution within the
     * given reference row. If the row is empty the result is NaN.
     *
     * @param refCategoryIndex Reference (row) index.
     * @return Conditional entropy in bits for that row.
     * @throws IllegalArgumentException If the index is out of range.
     */
    public double conditionalEntropy(int refCategoryIndex) {
        double entropySum = 0.0;
        long refCount = this.oneVsAll(refCategoryIndex).positiveReference();
        for (int j = 0; j < this.numCategories(); ++j) {
            double conditionalProb = (double) this.count(refCategoryIndex, j) / (double) refCount;
            // NaN (empty row) is not <= 0, so it propagates into the result.
            if (conditionalProb <= 0.0) {
                continue;
            }
            entropySum += conditionalProb * Math.log2(conditionalProb);
        }
        return -entropySum;
    }

    /**
     * Return the reference-likelihood-weighted sum of the per-row
     * conditional entropies. Rows with zero reference likelihood are skipped
     * so their NaN row entropy does not poison the sum.
     *
     * @return Conditional entropy in bits.
     */
    public double conditionalEntropy() {
        double entropySum = 0.0;
        for (int i = 0; i < this.numCategories(); ++i) {
            double refProbI = this.oneVsAll(i).referenceLikelihood();
            if (refProbI == 0.0) {
                continue; // 0 * NaN would be NaN
            }
            entropySum += refProbI * this.conditionalEntropy(i);
        }
        return entropySum;
    }

    /**
     * Return kappa using {@link #randomAccuracy()} as expected agreement.
     *
     * @return Kappa statistic.
     */
    public double kappa() {
        return this.kappa(this.randomAccuracy());
    }

    /**
     * Return kappa using {@link #randomAccuracyUnbiased()} as expected agreement.
     *
     * @return Unbiased kappa statistic.
     */
    public double kappaUnbiased() {
        return this.kappa(this.randomAccuracyUnbiased());
    }

    /**
     * Return {@code 2 * totalAccuracy() - 1}.
     *
     * @return Kappa with no prevalence adjustment.
     */
    public double kappaNoPrevalence() {
        return 2.0 * this.totalAccuracy() - 1.0;
    }

    /** Return {@code (PA - PE) / (1 - PE)} where {@code PA} is total accuracy. */
    private double kappa(double PE) {
        double PA = this.totalAccuracy();
        return (PA - PE) / (1.0 - PE);
    }

    /**
     * Return {@code sum_i refLikelihood(i) * respLikelihood(i)} over the
     * one-versus-all evaluations.
     *
     * @return Random accuracy.
     */
    public double randomAccuracy() {
        double randomAccuracy = 0.0;
        for (int i = 0; i < this.numCategories(); ++i) {
            PrecisionRecallEvaluation eval = this.oneVsAll(i);
            randomAccuracy += eval.referenceLikelihood() * eval.responseLikelihood();
        }
        return randomAccuracy;
    }

    /**
     * Return {@code sum_i avg(i)^2}, where {@code avg(i)} is the mean of the
     * reference and response likelihoods of category {@code i}.
     *
     * @return Unbiased random accuracy.
     */
    public double randomAccuracyUnbiased() {
        double randomAccuracy = 0.0;
        for (int i = 0; i < this.numCategories(); ++i) {
            PrecisionRecallEvaluation eval = this.oneVsAll(i);
            double avgLikelihood = (eval.referenceLikelihood() + eval.responseLikelihood()) / 2.0;
            randomAccuracy += avgLikelihood * avgLikelihood;
        }
        return randomAccuracy;
    }

    /**
     * Return {@code (numCategories() - 1)^2}.
     *
     * @return Degrees of freedom for {@link #chiSquared()}.
     */
    public int chiSquaredDegreesOfFreedom() {
        int sqrt = this.numCategories() - 1;
        return sqrt * sqrt;
    }

    /**
     * Return the chi-squared independence statistic of the count matrix, as
     * computed by {@code Statistics.chiSquaredIndependence}.
     *
     * @return Chi-squared statistic.
     */
    public double chiSquared() {
        int numCategories = this.numCategories();
        double[][] contingencyMatrix = new double[numCategories][numCategories];
        for (int i = 0; i < numCategories; ++i) {
            for (int j = 0; j < numCategories; ++j) {
                contingencyMatrix[i][j] = this.mMatrix[i][j];
            }
        }
        return Statistics.chiSquaredIndependence(contingencyMatrix);
    }

    /**
     * Return {@code chiSquared() / totalCount()}.
     *
     * @return Phi-squared statistic.
     */
    public double phiSquared() {
        return this.chiSquared() / (double) this.totalCount();
    }

    /**
     * Return {@code sqrt(phiSquared() / (numCategories() - 1))}.
     *
     * @return Cramer's V statistic.
     */
    public double cramersV() {
        double LMinusOne = this.numCategories() - 1;
        return java.lang.Math.sqrt(this.phiSquared() / LMinusOne);
    }

    /**
     * Return a precision-recall evaluation treating the given category as
     * positive and all other categories as negative; each cell's count is
     * added as a case with reference positive iff its row is the category
     * and response positive iff its column is the category.
     *
     * @param categoryIndex Index of the positive category.
     * @return One-versus-all evaluation.
     * @throws IllegalArgumentException If the index is out of range.
     */
    public PrecisionRecallEvaluation oneVsAll(int categoryIndex) {
        // Out-of-range indices previously yielded a meaningless all-negative evaluation.
        this.checkIndex("category", categoryIndex);
        PrecisionRecallEvaluation eval = new PrecisionRecallEvaluation();
        for (int i = 0; i < this.numCategories(); ++i) {
            for (int j = 0; j < this.numCategories(); ++j) {
                eval.addCase(i == categoryIndex, j == categoryIndex, this.mMatrix[i][j]);
            }
        }
        return eval;
    }

    /**
     * Return an evaluation whose true/false positive/negative counts are the
     * sums of those of every one-versus-all evaluation.
     *
     * @return Micro-averaged evaluation.
     */
    public PrecisionRecallEvaluation microAverage() {
        long tp = 0L;
        long fp = 0L;
        long fn = 0L;
        long tn = 0L;
        for (int i = 0; i < this.numCategories(); ++i) {
            PrecisionRecallEvaluation eval = this.oneVsAll(i);
            tp += eval.truePositive();
            fp += eval.falsePositive();
            tn += eval.trueNegative();
            fn += eval.falseNegative();
        }
        return new PrecisionRecallEvaluation(tp, fn, fp, tn);
    }

    /**
     * Return the mean of the one-versus-all precisions.
     *
     * @return Macro-averaged precision.
     */
    public double macroAvgPrecision() {
        double sum = 0.0;
        for (int i = 0; i < this.numCategories(); ++i) {
            sum += this.oneVsAll(i).precision();
        }
        return sum / (double) this.numCategories();
    }

    /**
     * Return the mean of the one-versus-all recalls.
     *
     * @return Macro-averaged recall.
     */
    public double macroAvgRecall() {
        double sum = 0.0;
        for (int i = 0; i < this.numCategories(); ++i) {
            sum += this.oneVsAll(i).recall();
        }
        return sum / (double) this.numCategories();
    }

    /**
     * Return the mean of the one-versus-all F-measures.
     *
     * @return Macro-averaged F-measure.
     */
    public double macroAvgFMeasure() {
        double sum = 0.0;
        for (int i = 0; i < this.numCategories(); ++i) {
            sum += this.oneVsAll(i).fMeasure();
        }
        return sum / (double) this.numCategories();
    }

    /**
     * Return {@code (S - M) / (N - M)}, where {@code S} is the sum over
     * columns of each column's largest count, {@code M} is the largest
     * reference (row) total, and {@code N} is the total count.
     *
     * @return Lambda A statistic.
     */
    public double lambdaA() {
        double maxReferenceCount = 0.0;
        for (int j = 0; j < this.numCategories(); ++j) {
            double referenceCount = this.oneVsAll(j).positiveReference();
            if (referenceCount > maxReferenceCount) {
                maxReferenceCount = referenceCount;
            }
        }
        double maxCountSum = 0.0;
        for (int j = 0; j < this.numCategories(); ++j) {
            int maxCount = 0;
            for (int i = 0; i < this.numCategories(); ++i) {
                if (this.mMatrix[i][j] > maxCount) {
                    maxCount = this.mMatrix[i][j];
                }
            }
            maxCountSum += (double) maxCount;
        }
        double totalCount = this.totalCount();
        return (maxCountSum - maxReferenceCount) / (totalCount - maxReferenceCount);
    }

    /**
     * Return {@code (S - M) / (N - M)}, where {@code S} is the sum over rows
     * of each row's largest count, {@code M} is the largest response (column)
     * total, and {@code N} is the total count.
     *
     * @return Lambda B statistic.
     */
    public double lambdaB() {
        double maxResponseCount = 0.0;
        for (int i = 0; i < this.numCategories(); ++i) {
            double responseCount = this.oneVsAll(i).positiveResponse();
            if (responseCount > maxResponseCount) {
                maxResponseCount = responseCount;
            }
        }
        double maxCountSum = 0.0;
        for (int i = 0; i < this.numCategories(); ++i) {
            int maxCount = 0;
            for (int j = 0; j < this.numCategories(); ++j) {
                if (this.mMatrix[i][j] > maxCount) {
                    maxCount = this.mMatrix[i][j];
                }
            }
            maxCountSum += (double) maxCount;
        }
        double totalCount = this.totalCount();
        return (maxCountSum - maxResponseCount) / (totalCount - maxResponseCount);
    }

    /**
     * Return {@code sum_ij p(i,j) * log2(p(i,j) / (p(i) * p(j)))}, where
     * {@code p(i,j)} is the cell probability, {@code p(i)} the reference
     * likelihood and {@code p(j)} the response likelihood; terms with any
     * zero probability are skipped.
     *
     * @return Mutual information in bits.
     */
    public double mutualInformation() {
        int numCategories = this.numCategories();
        double totalCount = this.totalCount();
        // Hoisted: the response likelihoods do not depend on the row.
        double[] responseProbs = new double[numCategories];
        for (int j = 0; j < numCategories; ++j) {
            responseProbs[j] = this.oneVsAll(j).responseLikelihood();
        }
        double sum = 0.0;
        for (int i = 0; i < numCategories; ++i) {
            double pI = this.oneVsAll(i).referenceLikelihood();
            if (pI <= 0.0) {
                continue;
            }
            for (int j = 0; j < numCategories; ++j) {
                double pJ = responseProbs[j];
                if (pJ <= 0.0) {
                    continue;
                }
                double pIJ = (double) this.mMatrix[i][j] / totalCount;
                if (pIJ <= 0.0) {
                    continue;
                }
                sum += pIJ * Math.log2(pIJ / (pI * pJ));
            }
        }
        return sum;
    }

    /**
     * Return {@code sum_k refProb(k) * log2(refProb(k) / respProb(k))}.
     * Categories with zero reference probability contribute zero.
     *
     * @return KL divergence of response from reference likelihoods, in bits.
     */
    public double klDivergence() {
        double sum = 0.0;
        for (int k = 0; k < this.numCategories(); ++k) {
            PrecisionRecallEvaluation eval = this.oneVsAll(k);
            double refProb = eval.referenceLikelihood();
            if (refProb == 0.0) {
                continue; // 0 * log2(0 / q) would be NaN
            }
            double responseProb = eval.responseLikelihood();
            sum += refProb * Math.log2(refProb / responseProb);
        }
        return sum;
    }

    /**
     * Return a multi-line report of the global statistics followed by each
     * category's conditional entropy and one-versus-all evaluation.
     *
     * @return Text report.
     */
    @Override
    public String toString() {
        String[] categories = this.categories();
        StringBuilder sb = new StringBuilder();
        sb.append("GLOBAL CONFUSION MATRIX STATISTICS\n");
        this.toStringGlobal(sb);
        for (int i = 0; i < categories.length; ++i) {
            sb.append("CATEGORY " + i + "=" + categories[i] + " VS. ALL\n");
            sb.append("  Conditional Entropy=" + this.conditionalEntropy(i));
            sb.append('\n');
            sb.append(this.oneVsAll(i).toString());
            sb.append('\n');
        }
        return sb.toString();
    }

    /** Append the global statistics section of the report to {@code sb}, one statistic per line. */
    void toStringGlobal(StringBuilder sb) {
        String[] categories = this.categories();
        sb.append("Categories=" + Arrays.asList(categories));
        sb.append('\n');
        sb.append("Total Count=" + this.totalCount());
        sb.append('\n');
        sb.append("Total Correct=" + this.totalCorrect());
        sb.append('\n');
        sb.append("Total Accuracy=" + this.totalAccuracy());
        sb.append('\n');
        sb.append("95% Confidence Interval=" + this.totalAccuracy() + " +/- " + this.confidence95());
        sb.append('\n');
        sb.append("Confusion Matrix\n");
        sb.append("reference \\ response\n");
        sb.append(this.matrixToCSV());
        sb.append('\n');
        sb.append("Macro-averaged Precision=" + this.macroAvgPrecision());
        sb.append('\n');
        sb.append("Macro-averaged Recall=" + this.macroAvgRecall());
        sb.append('\n');
        sb.append("Macro-averaged F=" + this.macroAvgFMeasure());
        sb.append('\n');
        sb.append("Micro-averaged Results\n");
        sb.append("         the following symmetries are expected:\n");
        sb.append("           TP=TN, FN=FP\n");
        sb.append("           PosRef=PosResp=NegRef=NegResp\n");
        sb.append("           Acc=Prec=Rec=F\n");
        sb.append(this.microAverage().toString());
        sb.append('\n');
        sb.append("Random Accuracy=" + this.randomAccuracy());
        sb.append('\n');
        sb.append("Random Accuracy Unbiased=" + this.randomAccuracyUnbiased());
        sb.append('\n');
        sb.append("kappa=" + this.kappa());
        sb.append('\n');
        sb.append("kappa Unbiased=" + this.kappaUnbiased());
        sb.append('\n');
        sb.append("kappa No Prevalence =" + this.kappaNoPrevalence());
        sb.append('\n');
        sb.append("Reference Entropy=" + this.referenceEntropy());
        sb.append('\n');
        sb.append("Response Entropy=" + this.responseEntropy());
        sb.append('\n');
        sb.append("Cross Entropy=" + this.crossEntropy());
        sb.append('\n');
        sb.append("Joint Entropy=" + this.jointEntropy());
        sb.append('\n');
        sb.append("Conditional Entropy=" + this.conditionalEntropy());
        sb.append('\n');
        sb.append("Mutual Information=" + this.mutualInformation());
        sb.append('\n');
        sb.append("Kullback-Liebler Divergence=" + this.klDivergence());
        sb.append('\n');
        sb.append("chi Squared=" + this.chiSquared());
        sb.append('\n');
        sb.append("chi-Squared Degrees of Freedom=" + this.chiSquaredDegreesOfFreedom());
        sb.append('\n');
        sb.append("phi Squared=" + this.phiSquared());
        sb.append('\n');
        sb.append("Cramer's V=" + this.cramersV());
        sb.append('\n');
        sb.append("lambda A=" + this.lambdaA());
        sb.append('\n');
        sb.append("lambda B=" + this.lambdaB());
        sb.append('\n');
    }

    /**
     * Return the counts as comma-separated lines: a header row of category
     * labels, then one row per reference category prefixed by its label.
     * Every line is indented by two spaces.
     */
    String matrixToCSV() {
        String[] categories = this.categories();
        StringBuilder sb = new StringBuilder();
        sb.append("  ");
        for (String category : categories) {
            sb.append(',');
            sb.append(category);
        }
        for (int i = 0; i < categories.length; ++i) {
            sb.append("\n  ");
            sb.append(categories[i]);
            for (int j = 0; j < categories.length; ++j) {
                sb.append(',');
                sb.append(this.mMatrix[i][j]);
            }
        }
        return sb.toString();
    }

    /**
     * Return the counts as an HTML table with response categories across the
     * top and reference categories down the side. Diagonal cells are green,
     * empty off-diagonal cells yellow, and non-empty off-diagonal cells red.
     * NOTE(review): category labels are inserted without HTML escaping.
     */
    String matrixToHTML() {
        String[] categories = this.categories();
        int numCategories = categories.length;
        StringBuilder sb = new StringBuilder();
        sb.append("<html>\n");
        sb.append("<table border='1' cellpadding='5'>");
        sb.append('\n');
        sb.append("<tr>\n  <td colspan='2' rowspan='2'>&nbsp;</td>");
        sb.append("\n  <td colspan='" + numCategories + "' align='center' bgcolor='darkgray'><b>Response</b></td></tr>");
        sb.append("<tr>");
        for (String category : categories) {
            sb.append("\n  <td align='right' bgcolor='lightgray'><i>" + category + "</i></td>");
        }
        sb.append("</tr>\n");
        for (int i = 0; i < numCategories; ++i) {
            sb.append("<tr>");
            if (i == 0) {
                sb.append("\n  <td rowspan='" + numCategories + "' bgcolor='darkgray'><b>Ref-<br>erence</b></td>");
            }
            sb.append("\n  <td align='right' bgcolor='lightgray'><i>" + categories[i] + "</i></td>");
            for (int j = 0; j < numCategories; ++j) {
                int count = this.mMatrix[i][j];
                if (i == j) {
                    sb.append("\n  <td align='right' bgcolor='lightgreen'>");
                } else if (count == 0) {
                    sb.append("\n  <td align='right' bgcolor='yellow'>");
                } else {
                    sb.append("\n  <td align='right' bgcolor='red'>");
                }
                sb.append(count);
                sb.append("</td>");
            }
            sb.append("</tr>\n");
        }
        sb.append("</table>\n");
        sb.append("</html>\n");
        return sb.toString();
    }

    /**
     * Throw if {@code index} is not in {@code [0, numCategories())}.
     *
     * @param argMsg Argument name used in the error message.
     * @param index Index to check.
     * @throws IllegalArgumentException If the index is out of range.
     */
    private void checkIndex(String argMsg, int index) {
        if (index < 0) {
            // 0 is a valid index; the old message claimed "must be > 0".
            String msg = "Index for " + argMsg + " must be >= 0. Found index=" + index;
            throw new IllegalArgumentException(msg);
        }
        if (index >= this.numCategories()) {
            String msg = "Index for " + argMsg + " must be < numCategories()=" + this.numCategories() + ". Found index=" + index;
            throw new IllegalArgumentException(msg);
        }
    }
}

