/*
 * Decompiled with CFR 0.152.
 */
package opennlp.tools.cmdline.doccat;

import java.io.OutputStream;
import java.io.PrintStream;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.Map;
import java.util.SortedSet;
import java.util.TreeSet;
import opennlp.tools.doccat.DoccatEvaluationMonitor;
import opennlp.tools.doccat.DocumentSample;
import opennlp.tools.util.Span;
import opennlp.tools.util.eval.FMeasure;
import opennlp.tools.util.eval.Mean;

/**
 * Evaluation monitor for the document categorizer that accumulates every
 * reference/prediction pair it is notified about and renders a fine-grained
 * plain-text report with {@link #writeReport()}.
 *
 * <p>The report has three sections: a summary (document count, document size
 * statistics, number of categories, overall accuracy), a per-category table
 * (errors, count, error rate, precision, recall, F-measure) and a confusion
 * matrix.
 *
 * <p>The "sentence" getters describe documents: the size of a document is the
 * number of entries in its text array. Every document is recorded under the
 * single pseudo-token {@code "xx"}, so that is the only token the token-level
 * getters know about.
 */
public class DoccatFineGrainedReportListener
implements DoccatEvaluationMonitor {
    /** Letters used to build the column labels of the confusion matrix. */
    private static final char[] ALPHABET = "abcdefghijklmnopqrstuvwxyz".toCharArray();

    private final PrintStream printStream;
    private final Stats stats = new Stats();

    /** Creates a listener that writes its report to {@link System#err}. */
    public DoccatFineGrainedReportListener() {
        this(System.err);
    }

    /**
     * Creates a listener that writes its report to the given stream.
     *
     * @param outputStream destination of {@link #writeReport()}; it is wrapped in
     *     a {@link PrintStream} and is never closed by this class
     */
    public DoccatFineGrainedReportListener(OutputStream outputStream) {
        this.printStream = new PrintStream(outputStream);
    }

    /** Records a document whose predicted category differs from the reference. */
    @Override
    public void missclassified(DocumentSample reference, DocumentSample prediction) {
        this.stats.add(reference, prediction);
    }

    /** Records a document whose predicted category matches the reference. */
    @Override
    public void correctlyClassified(DocumentSample reference, DocumentSample prediction) {
        this.stats.add(reference, prediction);
    }

    /**
     * Writes the summary, the per-category table and the confusion matrix, then
     * flushes the output stream.
     */
    public void writeReport() {
        printGeneralStatistics();
        printTagsErrorRank();
        printGeneralConfusionTable();
        // The PrintStream was created without auto-flush; make sure a buffered
        // destination actually receives the report.
        this.printStream.flush();
    }

    /** Returns the number of documents recorded so far. */
    public long getNumberOfSentences() {
        return this.stats.getNumberOfSentences();
    }

    /** Returns the mean document size (length of the document text array). */
    public double getAverageSentenceSize() {
        return this.stats.getAverageSentenceSize();
    }

    /** Returns the smallest document size seen, or 0 if nothing was recorded. */
    public int getMinSentenceSize() {
        return this.stats.getMinSentenceSize();
    }

    /** Returns the largest document size seen, or 0 if nothing was recorded. */
    public int getMaxSentenceSize() {
        return this.stats.getMaxSentenceSize();
    }

    /** Returns the number of distinct reference categories. */
    public int getNumberOfTags() {
        return this.stats.getNumberOfTags();
    }

    /** Returns the fraction of documents whose prediction equals the reference. */
    public double getAccuracy() {
        return this.stats.getAccuracy();
    }

    /**
     * Returns the accuracy recorded for {@code token} (only {@code "xx"} is recorded).
     *
     * @throws NullPointerException if {@code token} was never recorded
     */
    public double getTokenAccuracy(String token) {
        return this.stats.getTokenAccuracy(token);
    }

    /** Returns the recorded tokens, most frequent first, ties broken by name. */
    public SortedSet<String> getTokensOrderedByFrequency() {
        return this.stats.getTokensOrderedByFrequency();
    }

    /**
     * Returns how many documents were recorded under {@code token}.
     *
     * @throws NullPointerException if {@code token} was never recorded
     */
    public int getTokenFrequency(String token) {
        return this.stats.getTokenFrequency(token);
    }

    /**
     * Returns how many documents recorded under {@code token} were misclassified.
     *
     * @throws NullPointerException if {@code token} was never recorded
     */
    public int getTokenErrors(String token) {
        return this.stats.getTokenErrors(token);
    }

    /** Returns the recorded tokens, most errors first, ties broken by name. */
    public SortedSet<String> getTokensOrderedByNumberOfErrors() {
        return this.stats.getTokensOrderedByNumberOfErrors();
    }

    /** Returns the reference categories, most errors first, ties broken by name. */
    public SortedSet<String> getTagsOrderedByErrors() {
        return this.stats.getTagsOrderedByErrors();
    }

    /**
     * Returns how many documents have {@code tag} as reference category.
     *
     * @throws NullPointerException if {@code tag} never occurred as a reference
     */
    public int getTagFrequency(String tag) {
        return this.stats.getTagFrequency(tag);
    }

    /**
     * Returns how many documents with reference category {@code tag} were misclassified.
     *
     * @throws NullPointerException if {@code tag} never occurred as a reference
     */
    public int getTagErrors(String tag) {
        return this.stats.getTagErrors(tag);
    }

    /** Returns the precision for {@code tag}, as computed by {@code FMeasure}. */
    public double getTagPrecision(String tag) {
        return this.stats.getTagPrecision(tag);
    }

    /** Returns the recall for {@code tag}, as computed by {@code FMeasure}. */
    public double getTagRecall(String tag) {
        return this.stats.getTagRecall(tag);
    }

    /** Returns the F-measure for {@code tag}, as computed by {@code FMeasure}. */
    public double getTagFMeasure(String tag) {
        return this.stats.getTagFMeasure(tag);
    }

    /**
     * Returns every reference and predicted category of the general confusion
     * matrix, ordered by descending row accuracy, then by name; categories that
     * never occur as a reference come last.
     */
    public SortedSet<String> getConfusionMatrixTagset() {
        return this.stats.getConfusionMatrixTagset();
    }

    /** Same as {@link #getConfusionMatrixTagset()}, restricted to {@code token}. */
    public SortedSet<String> getConfusionMatrixTagset(String token) {
        return this.stats.getConfusionMatrixTagset(token);
    }

    /**
     * Returns the general confusion matrix in {@link #getConfusionMatrixTagset()}
     * order: {@code [i][j]} counts documents with reference category {@code i}
     * predicted as {@code j}; the extra last column holds the accuracy of row {@code i}.
     */
    public double[][] getConfusionMatrix() {
        return this.stats.getConfusionMatrix();
    }

    /** Same as {@link #getConfusionMatrix()}, restricted to {@code token}. */
    public double[][] getConfusionMatrix(String token) {
        return this.stats.getConfusionMatrix(token);
    }

    /**
     * Renders a confusion matrix as text.
     *
     * @param tagset row/column labels, in the same order as {@code data}
     * @param data count matrix with one extra trailing column per row holding the
     *     row accuracy
     * @param filter when {@code true}, rows and columns up to and including the
     *     last row whose accuracy is exactly 1.0 are omitted
     * @return the formatted matrix; just the header line when {@code data} is empty
     */
    private String matrixToString(SortedSet<String> tagset, double[][] data, boolean filter) {
        String[] tags = tagset.toArray(new String[tagset.size()]);
        String[][] cells = new String[data.length][];
        // Every count cell is at least one character wide ("." for zero), so 1 is a
        // safe lower bound and never yields a negative format width.
        int columnWidth = 1;
        int firstShown = 0;
        for (int row = 0; row < data.length; ++row) {
            int accuracyColumn = data[row].length - 1;
            cells[row] = new String[data[row].length];
            for (int col = 0; col < accuracyColumn; ++col) {
                String cell = data[row][col] > 0.0 ? Integer.toString((int) data[row][col]) : ".";
                cells[row][col] = cell;
                columnWidth = Math.max(columnWidth, cell.length());
            }
            cells[row][accuracyColumn] = MessageFormat.format("{0,number,#.##%}", data[row][accuracyColumn]);
            // With the CategoryComparator ordering (descending accuracy) perfectly
            // classified rows come first, so filtering skips a leading prefix.
            if (filter && data[row][accuracyColumn] == 1.0) {
                firstShown = row + 1;
            }
        }
        String cellFormat = "%" + (columnWidth + 2) + "s ";
        String diagonalFormat = " %" + (columnWidth + 2) + "s";
        StringBuilder sb = new StringBuilder();
        for (int col = firstShown; col < tags.length; ++col) {
            sb.append(String.format(cellFormat, generateAlphaLabel(col - firstShown)));
        }
        sb.append("| Accuracy | <-- classified as\n");
        for (int row = firstShown; row < data.length; ++row) {
            int accuracyColumn = data[row].length - 1;
            for (int col = firstShown; col < accuracyColumn; ++col) {
                if (row == col) {
                    sb.append(String.format(diagonalFormat, "<" + cells[row][col] + ">"));
                } else {
                    sb.append(String.format(cellFormat, cells[row][col]));
                }
            }
            sb.append(String.format("|   %-6s |   %3s = ", cells[row][accuracyColumn],
                    generateAlphaLabel(row - firstShown))).append(tags[row]);
            sb.append("\n");
        }
        return sb.toString();
    }

    /** Prints the "Evaluation summary" section. */
    private void printGeneralStatistics() {
        printHeader("Evaluation summary");
        printStatistic("Number of documents", Long.toString(getNumberOfSentences()));
        printStatistic("Min sentence size", getMinSentenceSize());
        printStatistic("Max sentence size", getMaxSentenceSize());
        printStatistic("Average sentence size", MessageFormat.format("{0,number,#.##}", getAverageSentenceSize()));
        printStatistic("Categories count", getNumberOfTags());
        printStatistic("Accuracy", MessageFormat.format("{0,number,#.##%}", getAccuracy()));
    }

    /** Prints one right-aligned "name: value" line of the summary section. */
    private void printStatistic(String name, Object value) {
        this.printStream.append(String.format("%21s: %6s", name, value)).append("\n");
    }

    /** Prints the per-category table, categories with the most errors first. */
    private void printTagsErrorRank() {
        printHeader("Detailed Accuracy By Tag");
        SortedSet<String> tags = getTagsOrderedByErrors();
        this.printStream.append("\n");
        int maxTagSize = 3; // at least as wide as the "Tag" header
        for (String tag : tags) {
            maxTagSize = Math.max(maxTagSize, tag.length());
        }
        int tableSize = 65 + maxTagSize;
        String headerFormat = "| %" + maxTagSize + "s | %6s | %6s | %7s | %9s | %6s | %9s |\n";
        String rowFormat = "| %" + maxTagSize + "s | %6s | %6s | %-7s | %-9s | %-6s | %-9s |\n";
        printLine(tableSize);
        this.printStream.append(String.format(headerFormat, "Tag", "Errors", "Count", "% Err", "Precision", "Recall", "F-Measure"));
        printLine(tableSize);
        for (String tag : tags) {
            int occurrences = getTagFrequency(tag);
            int errors = getTagErrors(tag);
            String rate = MessageFormat.format("{0,number,#.###}", (double) errors / (double) occurrences);
            this.printStream.append(String.format(rowFormat, tag, errors, occurrences, rate,
                    formatScore(getTagPrecision(tag)),
                    formatScore(getTagRecall(tag)),
                    formatScore(getTagFMeasure(tag))));
        }
        printLine(tableSize);
    }

    /** Formats a score with up to three decimals; negative and NaN scores print as 0. */
    private static String formatScore(double score) {
        return MessageFormat.format("{0,number,#.###}", score > 0.0 ? score : 0.0);
    }

    /**
     * Prints the perfectly classified categories with their counts, followed by
     * the confusion matrix of the remaining categories.
     */
    private void printGeneralConfusionTable() {
        printHeader("Confusion matrix");
        SortedSet<String> labels = getConfusionMatrixTagset();
        double[][] confusionMatrix = getConfusionMatrix();
        int row = 0;
        for (String label : labels) {
            double[] line = confusionMatrix[row];
            if (line[line.length - 1] == 1.0) {
                this.printStream.append(label).append(" (").append(Integer.toString((int) line[row])).append(") ");
            }
            ++row;
        }
        this.printStream.append("\n\n");
        this.printStream.append(matrixToString(labels, confusionMatrix, true));
    }

    /** Prints a section header of the form {@code === text ===}. */
    private void printHeader(String text) {
        this.printStream.append("\n=== ").append(text).append(" ===\n");
    }

    /** Prints a horizontal rule of {@code size} dashes followed by a newline. */
    private void printLine(int size) {
        StringBuilder rule = new StringBuilder(size + 1);
        for (int i = 0; i < size; ++i) {
            rule.append('-');
        }
        rule.append('\n');
        this.printStream.append(rule);
    }

    /**
     * Converts a zero-based index into a spreadsheet-style label:
     * 0 -> "a", 25 -> "z", 26 -> "aa", 27 -> "ab", 702 -> "aaa", ...
     *
     * <p>The result contains letters only (no padding and no length limit); callers
     * apply their own field width.
     */
    private static String generateAlphaLabel(int index) {
        StringBuilder label = new StringBuilder(3);
        int remaining = index;
        do {
            label.append(ALPHABET[remaining % ALPHABET.length]);
            remaining = remaining / ALPHABET.length - 1;
        } while (remaining >= 0);
        return label.reverse().toString();
    }

    /** Mutable int counter used as a map value. */
    private static final class Counter {
        private int c = 0;

        private Counter() {
        }

        public void increment() {
            ++this.c;
        }

        public int value() {
            return this.c;
        }
    }

    /** One row of a confusion matrix: prediction counts for a single reference category. */
    private static final class ConfusionMatrixLine {
        private final Map<String, Counter> line = new HashMap<String, Counter>();
        private final String ref;
        private int total = 0;
        private int correct = 0;

        public ConfusionMatrixLine(String ref) {
            this.ref = ref;
        }

        /** Records one document of this row's reference category predicted as {@code column}. */
        public void increment(String column) {
            ++this.total;
            if (column.equals(this.ref)) {
                ++this.correct;
            }
            this.line.computeIfAbsent(column, k -> new Counter()).increment();
        }

        /**
         * Returns the fraction of this row's documents predicted correctly, or 0 for an
         * empty row. Computed on demand so it never goes stale when more documents arrive.
         */
        public double getAccuracy() {
            return this.total == 0 ? 0.0 : (double) this.correct / (double) this.total;
        }

        /** Returns how many documents of this row were predicted as {@code column}. */
        public int getValue(String column) {
            Counter c = this.line.get(column);
            return c == null ? 0 : c.value();
        }
    }

    /**
     * Orders categories by descending row accuracy, then by name. Categories with no
     * row (they only ever occurred as predictions) sort after all others, by name.
     */
    private static final class CategoryComparator
    implements Comparator<String> {
        private final Map<String, ConfusionMatrixLine> confusionMatrix;

        public CategoryComparator(Map<String, ConfusionMatrixLine> confusionMatrix) {
            this.confusionMatrix = confusionMatrix;
        }

        @Override
        public int compare(String o1, String o2) {
            if (o1.equals(o2)) {
                return 0;
            }
            ConfusionMatrixLine t1 = this.confusionMatrix.get(o1);
            ConfusionMatrixLine t2 = this.confusionMatrix.get(o2);
            if (t1 == null && t2 == null) {
                // Needed for a consistent total order; returning 1 both ways lets a
                // TreeSet store the same category twice.
                return o1.compareTo(o2);
            }
            if (t1 == null) {
                return 1;
            }
            if (t2 == null) {
                return -1;
            }
            int byAccuracy = Double.compare(t2.getAccuracy(), t1.getAccuracy());
            return byAccuracy != 0 ? byAccuracy : o1.compareTo(o2);
        }
    }

    /** Accumulates all counters behind the report. */
    private static final class Stats {
        /** Pseudo-token under which every document is recorded. */
        private static final String DOCUMENT_TOKEN = "xx";

        private final Mean accuracy = new Mean();
        private final Mean averageSentenceLength = new Mean();
        private int minimalSentenceLength = Integer.MAX_VALUE;
        private int maximumSentenceLength = Integer.MIN_VALUE;
        private final Map<String, Mean> tokAccuracies = new HashMap<String, Mean>();
        private final Map<String, Counter> tokOcurrencies = new HashMap<String, Counter>();
        private final Map<String, Counter> tokErrors = new HashMap<String, Counter>();
        private final Map<String, Counter> tagOcurrencies = new HashMap<String, Counter>();
        private final Map<String, Counter> tagErrors = new HashMap<String, Counter>();
        private final Map<String, FMeasure> tagFMeasure = new HashMap<String, FMeasure>();
        private final Map<String, ConfusionMatrixLine> generalConfusionMatrix = new HashMap<String, ConfusionMatrixLine>();
        private final Map<String, Map<String, ConfusionMatrixLine>> tokenConfusionMatrix = new HashMap<String, Map<String, ConfusionMatrixLine>>();

        private Stats() {
        }

        /** Records one reference/prediction document pair. */
        public void add(DocumentSample reference, DocumentSample prediction) {
            int length = reference.getText().length;
            this.averageSentenceLength.add(length);
            this.minimalSentenceLength = Math.min(this.minimalSentenceLength, length);
            this.maximumSentenceLength = Math.max(this.maximumSentenceLength, length);
            String[] refs = new String[]{reference.getCategory()};
            String[] preds = new String[]{prediction.getCategory()};
            updateTagFMeasure(refs, preds);
            add(DOCUMENT_TOKEN, reference.getCategory(), prediction.getCategory());
        }

        /** Updates accuracy, error and confusion counters for one (token, ref, pred) triple. */
        private void add(String tok, String ref, String pred) {
            Mean tokAccuracy = this.tokAccuracies.computeIfAbsent(tok, k -> new Mean());
            Counter tokErrorCount = this.tokErrors.computeIfAbsent(tok, k -> new Counter());
            this.tokOcurrencies.computeIfAbsent(tok, k -> new Counter()).increment();

            Counter tagErrorCount = this.tagErrors.computeIfAbsent(ref, k -> new Counter());
            this.tagOcurrencies.computeIfAbsent(ref, k -> new Counter()).increment();

            boolean correct = ref.equals(pred);
            double score = correct ? 1.0 : 0.0;
            tokAccuracy.add(score);
            this.accuracy.add(score);
            if (!correct) {
                tokErrorCount.increment();
                tagErrorCount.increment();
            }

            this.generalConfusionMatrix.computeIfAbsent(ref, ConfusionMatrixLine::new).increment(pred);
            this.tokenConfusionMatrix
                    .computeIfAbsent(tok, k -> new HashMap<>())
                    .computeIfAbsent(ref, ConfusionMatrixLine::new)
                    .increment(pred);
        }

        /**
         * Updates the per-category F-measure for every category occurring in
         * {@code refs} or {@code preds}, treating position {@code i} as span [i, i+1).
         */
        private void updateTagFMeasure(String[] refs, String[] preds) {
            HashSet<String> tags = new HashSet<String>(Arrays.asList(refs));
            tags.addAll(Arrays.asList(preds));
            for (String tag : tags) {
                ArrayList<Span> reference = new ArrayList<Span>();
                ArrayList<Span> prediction = new ArrayList<Span>();
                for (int i = 0; i < refs.length; ++i) {
                    if (refs[i].equals(tag)) {
                        reference.add(new Span(i, i + 1));
                    }
                    if (preds[i].equals(tag)) {
                        prediction.add(new Span(i, i + 1));
                    }
                }
                this.tagFMeasure.computeIfAbsent(tag, k -> new FMeasure())
                        .updateScores(reference.toArray(new Span[reference.size()]),
                                prediction.toArray(new Span[prediction.size()]));
            }
        }

        public double getAccuracy() {
            return this.accuracy.mean();
        }

        public int getNumberOfTags() {
            return this.tagOcurrencies.size();
        }

        public long getNumberOfSentences() {
            return this.averageSentenceLength.count();
        }

        public double getAverageSentenceSize() {
            return this.averageSentenceLength.mean();
        }

        /** Returns the smallest document size, or 0 before any document was recorded. */
        public int getMinSentenceSize() {
            return getNumberOfSentences() == 0 ? 0 : this.minimalSentenceLength;
        }

        /** Returns the largest document size, or 0 before any document was recorded. */
        public int getMaxSentenceSize() {
            return getNumberOfSentences() == 0 ? 0 : this.maximumSentenceLength;
        }

        public double getTokenAccuracy(String token) {
            return this.tokAccuracies.get(token).mean();
        }

        public int getTokenErrors(String token) {
            return this.tokErrors.get(token).value();
        }

        public int getTokenFrequency(String token) {
            return this.tokOcurrencies.get(token).value();
        }

        public SortedSet<String> getTokensOrderedByFrequency() {
            return sortedByCountDescending(this.tokOcurrencies);
        }

        public SortedSet<String> getTokensOrderedByNumberOfErrors() {
            return sortedByCountDescending(this.tokErrors);
        }

        public int getTagFrequency(String tag) {
            return this.tagOcurrencies.get(tag).value();
        }

        public int getTagErrors(String tag) {
            return this.tagErrors.get(tag).value();
        }

        public double getTagFMeasure(String tag) {
            return this.tagFMeasure.get(tag).getFMeasure();
        }

        public double getTagRecall(String tag) {
            return this.tagFMeasure.get(tag).getRecallScore();
        }

        public double getTagPrecision(String tag) {
            return this.tagFMeasure.get(tag).getPrecisionScore();
        }

        public SortedSet<String> getTagsOrderedByErrors() {
            return sortedByCountDescending(this.tagErrors);
        }

        /**
         * Returns an unmodifiable view of the keys of {@code counts}, highest count
         * first, ties broken by natural string order.
         */
        private static SortedSet<String> sortedByCountDescending(Map<String, Counter> counts) {
            TreeSet<String> keys = new TreeSet<String>((o1, o2) -> {
                int byCount = Integer.compare(countOf(counts, o2), countOf(counts, o1));
                return byCount != 0 ? byCount : o1.compareTo(o2);
            });
            keys.addAll(counts.keySet());
            return Collections.unmodifiableSortedSet(keys);
        }

        /** Returns the counter value stored for {@code key}, or 0 if absent. */
        private static int countOf(Map<String, Counter> counts, String key) {
            Counter counter = counts.get(key);
            return counter == null ? 0 : counter.value();
        }

        public SortedSet<String> getConfusionMatrixTagset() {
            return getConfusionMatrixTagset(this.generalConfusionMatrix);
        }

        public double[][] getConfusionMatrix() {
            return createConfusionMatrix(getConfusionMatrixTagset(), this.generalConfusionMatrix);
        }

        public SortedSet<String> getConfusionMatrixTagset(String token) {
            return getConfusionMatrixTagset(this.tokenConfusionMatrix.get(token));
        }

        public double[][] getConfusionMatrix(String token) {
            return createConfusionMatrix(getConfusionMatrixTagset(token), this.tokenConfusionMatrix.get(token));
        }

        /**
         * Builds a {@code size x (size + 1)} matrix in {@code tagset} order; the last
         * column holds the row accuracy. Rows without data stay all zero.
         */
        private double[][] createConfusionMatrix(SortedSet<String> tagset, Map<String, ConfusionMatrixLine> data) {
            int size = tagset.size();
            double[][] matrix = new double[size][size + 1];
            int row = 0;
            for (String ref : tagset) {
                ConfusionMatrixLine refLine = data.get(ref);
                if (refLine != null) {
                    int column = 0;
                    for (String pred : tagset) {
                        matrix[row][column++] = refLine.getValue(pred);
                    }
                    matrix[row][size] = refLine.getAccuracy();
                }
                ++row;
            }
            return matrix;
        }

        /** Collects every reference and predicted category, ordered by {@link CategoryComparator}. */
        private SortedSet<String> getConfusionMatrixTagset(Map<String, ConfusionMatrixLine> data) {
            TreeSet<String> tags = new TreeSet<String>(new CategoryComparator(data));
            tags.addAll(data.keySet());
            for (ConfusionMatrixLine refLine : data.values()) {
                tags.addAll(refLine.line.keySet());
            }
            return Collections.unmodifiableSortedSet(tags);
        }
    }
}

