/*
 * Decompiled with CFR 0.152.
 */
package hex;

import java.util.Arrays;
import water.Iced;
import water.Key;
import water.MRTask;
import water.Scope;
import water.fvec.CategoricalWrappedVec;
import water.fvec.Chunk;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.TwoDimTable;

/**
 * Confusion matrix for a classification problem.
 *
 * <p>{@link #_cm} is indexed as {@code _cm[actual][predicted]} (this is how
 * {@link #buildCM(Vec, Vec)} fills it), and {@link #_domain} holds the class
 * labels. Cells are {@code double}s, so fractional counts are supported.
 *
 * <p>Most metrics refuse to run (with {@link UnsupportedOperationException})
 * when the number of classes exceeds {@link #MAX_CM_CLASSES}; in that case
 * {@link #buildCM(Vec, Vec)} produces an instance whose {@code _cm} is
 * {@code null}.
 */
public class ConfusionMatrix
extends Iced {
    /** Lazily built, cached tabular rendering; reset to {@code null} whenever the counts change via {@code add}. */
    private TwoDimTable _table;
    /** Counts indexed as {@code [actual][predicted]}; may be {@code null} for over-sized domains. */
    public final double[][] _cm;
    /** Class labels; may be {@code null}. */
    public final String[] _domain;
    /** Largest number of classes for which the matrix is materialised and metrics are computed. */
    public static final int MAX_CM_CLASSES = 1000;

    /**
     * Wraps an existing count matrix. The arrays are stored by reference, not copied.
     *
     * @param value  counts indexed as {@code [actual][predicted]}
     * @param domain class labels
     */
    public ConfusionMatrix(double[][] value, String[] domain) {
        this._cm = value;
        this._domain = domain;
    }

    /**
     * Builds a confusion matrix from a vec of actual labels and a vec of predicted labels.
     *
     * <p>The predictions are adapted to the domain of the actuals before counting. The work is
     * bracketed by {@code Scope.enter()} / {@code Scope.exit(...)} — presumably so the temporary
     * adapted vec is released afterwards; confirm against {@code water.Scope}.
     *
     * <p>If the actuals' domain has more than {@link #MAX_CM_CLASSES} levels, no counts are
     * accumulated and the returned instance has a {@code null} {@code _cm}.
     *
     * @param actuals     categorical vec of true labels
     * @param predictions categorical vec of predicted labels
     * @return the confusion matrix over the actuals' domain
     * @throws IllegalArgumentException if either vec is not categorical
     */
    public static ConfusionMatrix buildCM(Vec actuals, Vec predictions) {
        if (!actuals.isCategorical()) {
            throw new IllegalArgumentException("actuals must be categorical.");
        }
        if (!predictions.isCategorical()) {
            throw new IllegalArgumentException("predictions must be categorical.");
        }
        Scope.enter();
        try {
            CategoricalWrappedVec adapted = predictions.adaptTo(actuals.domain());
            int len = actuals.domain().length;
            CMBuilder cm = (CMBuilder)new CMBuilder(len).doAll(actuals, adapted);
            return new ConfusionMatrix(cm._arr, actuals.domain());
        }
        finally {
            Scope.exit(new Key[0]);
        }
    }

    /**
     * Records one observation with actual class {@code i} and predicted class {@code j}.
     *
     * @param i index of the actual class (row)
     * @param j index of the predicted class (column)
     */
    public void add(int i, int j) {
        this._cm[i][j] += 1.0;
        // The cached table was rendered from the old counts.
        this._table = null;
    }

    /**
     * @return the number of labels in the domain
     * @throws NullPointerException if the domain is {@code null}; use {@link #nclasses()} for a null-safe count
     */
    public final int size() {
        return this._domain.length;
    }

    /**
     * @return {@code true} when the number of classes exceeds {@link #MAX_CM_CLASSES};
     *         a {@code null} domain counts as zero classes
     */
    boolean tooLarge() {
        // nclasses() rather than size(): callers such as toTable() must be able to
        // ask this question on an instance without a domain.
        return this.nclasses() > MAX_CM_CLASSES;
    }

    /**
     * @return the unweighted mean of {@link #class_error(int)} over all rows of the matrix
     * @throws UnsupportedOperationException if there are too many classes
     */
    public final double mean_per_class_error() {
        if (this.tooLarge()) {
            throw new UnsupportedOperationException("mean per class error cannot be computed: too many classes");
        }
        double sum = 0.0;
        for (int c = 0; c < this._cm.length; ++c) {
            sum += this.class_error(c);
        }
        return sum / (double)this._cm.length;
    }

    /** @return {@code 1 - mean_per_class_error()} */
    public final double mean_per_class_accuracy() {
        return 1.0 - this.mean_per_class_error();
    }

    /**
     * Error rate of a single actual class: the off-diagonal share of row {@code c}.
     *
     * @param c index of the actual class
     * @return the row's error rate, or {@code 0} when the row sums to zero
     * @throws UnsupportedOperationException if there are too many classes
     */
    public final double class_error(int c) {
        if (this.tooLarge()) {
            throw new UnsupportedOperationException("class errors cannot be computed: too many classes");
        }
        double rowTotal = ArrayUtils.sum(this._cm[c]);
        if (rowTotal == 0.0) {
            return 0.0;
        }
        return (rowTotal - this._cm[c][c]) / rowTotal;
    }

    /** @return the sum of every cell in the matrix */
    public double total_rows() {
        double total = 0.0;
        for (double[] row : this._cm) {
            total += ArrayUtils.sum(row);
        }
        return total;
    }

    /**
     * Adds the counts of {@code other} into this matrix. Does nothing when either
     * matrix has no count array.
     *
     * @param other matrix whose counts are accumulated into this one
     */
    public void add(ConfusionMatrix other) {
        if (this._cm != null && other._cm != null) {
            ArrayUtils.add(this._cm, other._cm);
            // The cached table was rendered from the old counts.
            this._table = null;
        }
    }

    /**
     * @return the overall error rate: off-diagonal mass divided by the total mass
     *         (NaN when the matrix is all zeros)
     * @throws UnsupportedOperationException if there are too many classes
     */
    public double err() {
        if (this.tooLarge()) {
            throw new UnsupportedOperationException("error cannot be computed: too many classes");
        }
        double total = this.total_rows();
        return this.offDiagonal(total) / total;
    }

    /**
     * @return the off-diagonal mass, i.e. the (possibly fractional) number of misclassified rows
     * @throws UnsupportedOperationException if there are too many classes
     */
    public double err_count() {
        if (this.tooLarge()) {
            throw new UnsupportedOperationException("error count cannot be computed: too many classes");
        }
        double err = this.offDiagonal(this.total_rows());
        assert (err >= 0.0);
        return err;
    }

    /**
     * Subtracts the diagonal cells from {@code total}, one at a time and in row order
     * (the order matters for floating-point reproducibility).
     */
    private double offDiagonal(double total) {
        double remaining = total;
        for (int d = 0; d < this._cm.length; ++d) {
            remaining -= this._cm[d][d];
        }
        return remaining;
    }

    /** @return {@code 1 - err()} */
    public double accuracy() {
        return 1.0 - this.err();
    }

    /**
     * @return TN / (TN + FP), where class 0 is treated as the negative class
     * @throws UnsupportedOperationException if the problem is not binary
     */
    public double specificity() {
        if (!this.isBinary()) {
            throw new UnsupportedOperationException("specificity is only implemented for 2 class problems.");
        }
        if (this.tooLarge()) {
            throw new UnsupportedOperationException("specificity cannot be computed: too many classes");
        }
        double tn = this._cm[0][0];
        double fp = this._cm[0][1];
        return tn / (tn + fp);
    }

    /**
     * @return TP / (TP + FN), where class 1 is treated as the positive class
     * @throws UnsupportedOperationException if the problem is not binary
     */
    public double recall() {
        if (!this.isBinary()) {
            throw new UnsupportedOperationException("recall is only implemented for 2 class problems.");
        }
        if (this.tooLarge()) {
            throw new UnsupportedOperationException("recall cannot be computed: too many classes");
        }
        double tp = this._cm[1][1];
        double fn = this._cm[1][0];
        return tp / (tp + fn);
    }

    /**
     * @return TP / (TP + FP), where class 1 is treated as the positive class
     * @throws UnsupportedOperationException if the problem is not binary
     */
    public double precision() {
        if (!this.isBinary()) {
            throw new UnsupportedOperationException("precision is only implemented for 2 class problems.");
        }
        if (this.tooLarge()) {
            throw new UnsupportedOperationException("precision cannot be computed: too many classes");
        }
        double tp = this._cm[1][1];
        double fp = this._cm[0][1];
        return tp / (tp + fp);
    }

    /**
     * @return the Matthews correlation coefficient of the 2x2 matrix
     * @throws UnsupportedOperationException if the problem is not binary
     */
    public double mcc() {
        if (!this.isBinary()) {
            throw new UnsupportedOperationException("mcc is only implemented for 2 class problems.");
        }
        if (this.tooLarge()) {
            throw new UnsupportedOperationException("mcc cannot be computed: too many classes");
        }
        double tn = this._cm[0][0];
        double fp = this._cm[0][1];
        double tp = this._cm[1][1];
        double fn = this._cm[1][0];
        return (tp * tn - fp * fn) / Math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn));
    }

    /**
     * @return the largest {@link #class_error(int)} over all classes
     * @throws UnsupportedOperationException if there is no domain or there are too many classes
     */
    public double max_per_class_error() {
        int n = this.nclasses();
        if (n == 0) {
            throw new UnsupportedOperationException("max per class error is only defined for classification problems");
        }
        if (this.tooLarge()) {
            throw new UnsupportedOperationException("max per class error cannot be computed: too many classes");
        }
        double worst = this.class_error(0);
        for (int c = 1; c < n; ++c) {
            worst = Math.max(worst, this.class_error(c));
        }
        return worst;
    }

    /** @return the number of labels in the domain, or {@code 0} when there is no domain */
    public final int nclasses() {
        return this._domain == null ? 0 : this._domain.length;
    }

    /** @return {@code true} when the domain has exactly two labels */
    public final boolean isBinary() {
        return this.nclasses() == 2;
    }

    /** @return the F1 score (harmonic mean of precision and recall); binary problems only */
    public double f1() {
        double precision = this.precision();
        double recall = this.recall();
        return 2.0 * (precision * recall) / (precision + recall);
    }

    /** @return the F2 score (recall weighted higher than precision); binary problems only */
    public double f2() {
        double precision = this.precision();
        double recall = this.recall();
        return 5.0 * (precision * recall) / (4.0 * precision + recall);
    }

    /** @return the F0.5 score (precision weighted higher than recall); binary problems only */
    public double f0point5() {
        double precision = this.precision();
        double recall = this.recall();
        return 1.25 * (precision * recall) / (0.25 * precision + recall);
    }

    /**
     * @return one {@link Arrays#toString(double[])} line per matrix row, or an empty
     *         string when there is no count array
     */
    @Override
    public String toString() {
        // buildCM leaves _cm null for over-sized domains; toString must not throw on those.
        if (this._cm == null) {
            return "";
        }
        StringBuilder sb = new StringBuilder();
        for (double[] row : this._cm) {
            sb.append(Arrays.toString(row)).append('\n');
        }
        return sb.toString();
    }

    /**
     * Picks the header label for each row/column of the table.
     *
     * <p>A label from {@code ds} is kept when the matching total in {@code xs} is
     * {@code >= 0}, or when the label is non-empty and is not just its own index
     * rendered as a double. Other slots stay {@code null}, which makes
     * {@link #toTable()} skip them. When {@code xs} has exactly one more entry than
     * {@code ds} and that extra total is positive, the extra slot is labelled "NA".
     *
     * @param xs per-row or per-column totals
     * @param ds class labels
     * @return labels, same length as {@code xs}
     */
    private static String[] createConfusionMatrixHeader(double[] xs, String[] ds) {
        String[] labels = new String[xs.length];
        // Only index ds where it actually has entries: xs may carry one extra trailing slot.
        int labelled = Math.min(xs.length, ds.length);
        for (int i = 0; i < labelled; ++i) {
            boolean meaningfulName = ds[i] != null && ds[i].length() > 0 && !Double.toString(i).equals(ds[i]);
            if (xs[i] >= 0.0 || meaningfulName) {
                labels[i] = ds[i];
            }
        }
        if (ds.length == xs.length - 1 && xs[xs.length - 1] > 0.0) {
            labels[xs.length - 1] = "NA";
        }
        return labels;
    }

    /** @return the table rendered as text, or an empty string when no table can be built */
    public String toASCII() {
        return this.table() == null ? "" : this._table.toString();
    }

    /**
     * @return the cached table, building it on first use; {@code null} when no table
     *         can be built (see {@link #toTable()})
     */
    public TwoDimTable table() {
        if (this._table == null) {
            this._table = this.toTable();
        }
        return this._table;
    }

    /**
     * Renders the matrix as a table: one row per actual class plus a "Totals" row, one
     * column per predicted class plus an "Error" (rate) and a "Rate" ("errors / total") column.
     *
     * @return the table, or {@code null} when there are no counts, no domain, an empty
     *         matrix, or too many classes
     */
    private TwoDimTable toTable() {
        // Null checks come first so a missing domain yields null rather than an exception.
        if (this._cm == null || this._domain == null || this._cm.length == 0) {
            return null;
        }
        if (this.tooLarge()) {
            return null;
        }
        for (double[] row : this._cm) {
            assert (this._cm.length == row.length);
        }

        // Row totals (actuals), column totals (predictions), and whether every cell is integral.
        double[] acts = new double[this._cm.length];
        double[] preds = new double[this._cm[0].length];
        boolean isInt = true;
        for (int a = 0; a < this._cm.length; ++a) {
            double sum = 0.0;
            for (int p = 0; p < this._cm[a].length; ++p) {
                double cell = this._cm[a][p];
                sum += cell;
                preds[p] += cell;
                isInt &= cell == (double)((long)cell);
            }
            acts[a] = sum;
        }

        String[] adomain = ConfusionMatrix.createConfusionMatrixHeader(acts, this._domain);
        String[] pdomain = ConfusionMatrix.createConfusionMatrixHeader(preds, this._domain);
        assert (adomain.length == pdomain.length) : "The confusion matrix should have the same length for both directions.";

        String[] rowHeader = Arrays.copyOf(adomain, adomain.length + 1);
        rowHeader[adomain.length] = "Totals";
        String[] colHeader = Arrays.copyOf(pdomain, pdomain.length + 2);
        colHeader[colHeader.length - 2] = "Error";
        colHeader[colHeader.length - 1] = "Rate";

        String[] colType = new String[colHeader.length];
        String[] colFormat = new String[colHeader.length];
        for (int i = 0; i < colFormat.length - 1; ++i) {
            colType[i] = isInt ? "long" : "double";
            colFormat[i] = isInt ? "%d" : "%.2f";
        }
        colType[colFormat.length - 2] = "double";
        colFormat[colFormat.length - 2] = "%.4f";
        colType[colFormat.length - 1] = "string";

        // Per-row error mass, its running total, and the width of the widest "Rate" string.
        double[] rowErr = new double[this._cm.length];
        double terr = 0.0;
        int width = 0;
        for (int a = 0; a < this._cm.length; ++a) {
            if (adomain[a] == null) continue;
            rowErr[a] = acts[a] - this.correctCount(a, adomain, pdomain);
            terr += rowErr[a];
            width = Math.max(width, ConfusionMatrix.formatRate(isInt, rowErr[a], acts[a]).length());
        }
        double nrows = 0.0;
        for (double n : acts) {
            nrows += n;
        }
        // Same formatter as the per-class rows, so the Totals row matches the precision
        // the column width is computed for.
        String totalRate = ConfusionMatrix.formatRate(isInt, terr, nrows);
        width = Math.max(width, totalRate.length());
        colFormat[colFormat.length - 1] = "= %" + width + "s";

        TwoDimTable table = new TwoDimTable("Confusion Matrix", "vertical: actual; across: predicted", rowHeader, colHeader, colType, colFormat, null);
        for (int a = 0; a < this._cm.length; ++a) {
            if (adomain[a] == null) continue;
            for (int p = 0; p < pdomain.length; ++p) {
                if (pdomain[p] == null) continue;
                // Keep the two calls separate: a ternary would promote the long to double.
                if (isInt) {
                    table.set(a, p, (long)this._cm[a][p]);
                } else {
                    table.set(a, p, this._cm[a][p]);
                }
            }
            table.set(a, pdomain.length, rowErr[a] / acts[a]);
            table.set(a, pdomain.length + 1, ConfusionMatrix.formatRate(isInt, rowErr[a], acts[a]));
        }
        for (int p = 0; p < pdomain.length; ++p) {
            if (pdomain[p] == null) continue;
            if (isInt) {
                table.set(adomain.length, p, (long)preds[p]);
            } else {
                table.set(adomain.length, p, preds[p]);
            }
        }
        // Plain double division: narrowing terr to float first would lose precision
        // once the error mass exceeds 2^24.
        table.set(adomain.length, pdomain.length, terr / nrows);
        table.set(adomain.length, pdomain.length + 1, totalRate);
        return table;
    }

    /**
     * Finds the "correct" cell of row {@code a}: the cell whose predicted label equals the
     * row's actual label. If several columns share the label, the last one wins; if none
     * does, the result is {@code 0}.
     */
    private double correctCount(int a, String[] adomain, String[] pdomain) {
        double correct = 0.0;
        for (int p = 0; p < pdomain.length; ++p) {
            if (pdomain[p] != null && adomain[a].equals(pdomain[p])) {
                correct = this._cm[a][p];
            }
        }
        return correct;
    }

    /** Formats an "errors / total" cell: grouped integers for integral matrices, 4 decimals otherwise. */
    private static String formatRate(boolean isInt, double errors, double total) {
        return isInt
            ? String.format("%,d / %,d", (long)errors, (long)total)
            : String.format("%.4f / %.4f", errors, total);
    }

    /**
     * Distributed task that counts (actual, predicted) pairs into a {@code _len x _len} matrix.
     */
    private static class CMBuilder
    extends MRTask<CMBuilder> {
        /** Number of classes, i.e. the side length of the count matrix. */
        final int _len;
        /** Counts indexed as {@code [actual][predicted]}; stays {@code null} when {@code _len} is over the limit. */
        double[][] _arr;

        CMBuilder(int len) {
            this._len = len;
        }

        /**
         * Counts the rows of one chunk pair. Rows whose actual value is NA are skipped.
         *
         * @param ca chunk of actual labels
         * @param cp chunk of predicted labels (adapted to the actuals' domain)
         */
        @Override
        public void map(Chunk ca, Chunk cp) {
            // Too many classes: do not allocate a matrix at all.
            if (this._len > MAX_CM_CLASSES) {
                return;
            }
            this._arr = new double[this._len][this._len];
            for (int row = 0; row < ca._len; ++row) {
                if (ca.isNA(row)) continue;
                // NOTE(review): cp is not checked for NA or for indices >= _len before
                // being used as a column index — confirm adaptTo() guarantees both.
                this._arr[(int)ca.at8(row)][(int)cp.at8(row)] += 1.0;
            }
        }

        /** Merges another partial result into this one via {@code ArrayUtils.add}. */
        @Override
        public void reduce(CMBuilder cm) {
            ArrayUtils.add(this._arr, cm._arr);
        }
    }
}

