/*
 * Decompiled with CFR 0.152.
 */
package hex;

import java.util.Arrays;
import water.Iced;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.TwoDimTable;

/**
 * A confusion matrix over a categorical response.
 *
 * <p>{@code _cm[a][p]} holds the (possibly non-integral) count of rows whose actual class is
 * {@code a} and whose predicted class is {@code p}: rows are actual classes, columns are
 * predicted classes. For the binary metrics, index 1 is treated as the positive class and
 * index 0 as the negative class.
 *
 * <p>Most metrics refuse to run (throw {@link UnsupportedOperationException}) when the number of
 * classes exceeds the limit read from the {@code sys.ai.h2o.cm.maxClasses} system property
 * (default 1000).
 */
public class ConfusionMatrix
extends Iced {
    private static final String MAX_CM_CLASSES_KEY = "sys.ai.h2o.cm.maxClasses";
    private static final int MAX_CM_CLASSES_DEFAULT = 1000;
    /** Lazily built rendering of {@link #_cm}; reset whenever an {@code add} method changes the counts. */
    private TwoDimTable _table;
    /** Counts indexed as {@code [actual][predicted]}; may be {@code null}. */
    public final double[][] _cm;
    /** Class labels; its length is the number of classes. May be {@code null}. */
    public final String[] _domain;

    /**
     * Wraps the given counts and labels. Neither array is copied.
     *
     * @param value counts indexed as {@code [actual][predicted]}
     * @param domain class labels
     */
    public ConfusionMatrix(double[][] value, String[] domain) {
        this._cm = value;
        this._domain = domain;
    }

    /**
     * Records one observation.
     *
     * @param i2 index of the actual class (row)
     * @param j2 index of the predicted class (column)
     */
    public void add(int i2, int j2) {
        this._cm[i2][j2] += 1.0;
        // The counts changed, so any previously rendered table is now stale.
        this._table = null;
    }

    /** Returns the number of class labels; requires a non-null {@link #_domain}. */
    public final int size() {
        return this._domain.length;
    }

    /** True when the number of classes exceeds {@link #maxClasses()}; null-safe on the domain. */
    boolean tooLarge() {
        return this.nclasses() > ConfusionMatrix.maxClasses();
    }

    /**
     * Reads the class limit from the {@code sys.ai.h2o.cm.maxClasses} system property,
     * falling back to the default when the property is absent or invalid.
     */
    static int maxClasses() {
        String maxClassesSpec = System.getProperty(MAX_CM_CLASSES_KEY);
        return maxClassesSpec == null ? MAX_CM_CLASSES_DEFAULT : ConfusionMatrix.parseMaxClasses(maxClassesSpec);
    }

    /**
     * Parses a user-supplied class limit.
     *
     * @param maxClassesSpec textual limit
     * @return the parsed value if it is a positive integer, otherwise the default (a warning is logged)
     */
    static int parseMaxClasses(String maxClassesSpec) {
        try {
            int maxClasses = Integer.parseInt(maxClassesSpec);
            if (maxClasses <= 0) {
                Log.warn("Using default limit of max classes in a confusion matrix (" + MAX_CM_CLASSES_DEFAULT + ", user specification is invalid: " + maxClasses + ")");
                return MAX_CM_CLASSES_DEFAULT;
            }
            return maxClasses;
        }
        catch (NumberFormatException e2) {
            Log.warn("Using default limit of max classes in a confusion matrix (" + MAX_CM_CLASSES_DEFAULT + ", user specification is invalid: " + maxClassesSpec + ")", e2);
            return MAX_CM_CLASSES_DEFAULT;
        }
    }

    /**
     * Averages {@link #class_error(int)} over all rows of the matrix.
     *
     * @throws UnsupportedOperationException if there are too many classes
     */
    public final double mean_per_class_error() {
        if (this.tooLarge()) {
            throw new UnsupportedOperationException("mean per class error cannot be computed: too many classes");
        }
        double err = 0.0;
        for (int c = 0; c < this._cm.length; ++c) {
            err += this.class_error(c);
        }
        return err / (double) this._cm.length;
    }

    /** Returns {@code 1 - mean_per_class_error()}. */
    public final double mean_per_class_accuracy() {
        return 1.0 - this.mean_per_class_error();
    }

    /**
     * Fraction of rows with actual class {@code c2} that were predicted as another class.
     * Returns 0 for a class with no rows.
     *
     * @throws UnsupportedOperationException if there are too many classes
     */
    public final double class_error(int c2) {
        if (this.tooLarge()) {
            throw new UnsupportedOperationException("class errors cannot be computed: too many classes");
        }
        double s2 = ArrayUtils.sum(this._cm[c2]);
        if (s2 == 0.0) {
            return 0.0;
        }
        return (s2 - this._cm[c2][c2]) / s2;
    }

    /** Sum of every cell in the matrix. */
    public double total_rows() {
        double n = 0.0;
        for (double[] row : this._cm) {
            n += ArrayUtils.sum(row);
        }
        return n;
    }

    /**
     * Accumulates {@code other}'s counts into this matrix via {@code ArrayUtils.add}. Its return value
     * is ignored, so it is relied on to update {@code this._cm} in place. Does nothing if either
     * matrix is {@code null}.
     */
    public void add(ConfusionMatrix other) {
        if (this._cm != null && other._cm != null) {
            ArrayUtils.add(this._cm, other._cm);
            // The counts changed, so any previously rendered table is now stale.
            this._table = null;
        }
    }

    /**
     * Fraction of all rows that fall off the diagonal (NaN when the matrix is empty).
     *
     * @throws UnsupportedOperationException if there are too many classes
     */
    public double err() {
        if (this.tooLarge()) {
            throw new UnsupportedOperationException("error cannot be computed: too many classes");
        }
        return this.offDiagonalTotal() / this.total_rows();
    }

    /**
     * Total count of rows that fall off the diagonal.
     *
     * @throws UnsupportedOperationException if there are too many classes
     */
    public double err_count() {
        if (this.tooLarge()) {
            throw new UnsupportedOperationException("error count cannot be computed: too many classes");
        }
        double err = this.offDiagonalTotal();
        assert (err >= 0.0);
        return err;
    }

    /** Total count minus each diagonal cell, subtracted in row order. */
    private double offDiagonalTotal() {
        double err = this.total_rows();
        for (int d = 0; d < this._cm.length; ++d) {
            err -= this._cm[d][d];
        }
        return err;
    }

    /** Returns {@code 1 - err()}. */
    public double accuracy() {
        return 1.0 - this.err();
    }

    /**
     * True negative rate {@code tn / (tn + fp)}, with class 0 as the negative class.
     *
     * @throws UnsupportedOperationException if the problem is not binary or has too many classes
     */
    public double specificity() {
        if (!this.isBinary()) {
            throw new UnsupportedOperationException("specificity is only implemented for 2 class problems.");
        }
        if (this.tooLarge()) {
            throw new UnsupportedOperationException("specificity cannot be computed: too many classes");
        }
        double tn = this._cm[0][0];
        double fp = this._cm[0][1];
        return tn / (tn + fp);
    }

    /**
     * True positive rate {@code tp / (tp + fn)}, with class 1 as the positive class.
     *
     * @throws UnsupportedOperationException if the problem is not binary or has too many classes
     */
    public double recall() {
        if (!this.isBinary()) {
            throw new UnsupportedOperationException("recall is only implemented for 2 class problems.");
        }
        if (this.tooLarge()) {
            throw new UnsupportedOperationException("recall cannot be computed: too many classes");
        }
        double tp = this._cm[1][1];
        double fn = this._cm[1][0];
        return tp / (tp + fn);
    }

    /**
     * Positive predictive value {@code tp / (tp + fp)}, with class 1 as the positive class.
     *
     * @throws UnsupportedOperationException if the problem is not binary or has too many classes
     */
    public double precision() {
        if (!this.isBinary()) {
            throw new UnsupportedOperationException("precision is only implemented for 2 class problems.");
        }
        if (this.tooLarge()) {
            throw new UnsupportedOperationException("precision cannot be computed: too many classes");
        }
        double tp = this._cm[1][1];
        double fp = this._cm[0][1];
        return tp / (tp + fp);
    }

    /**
     * Matthews correlation coefficient for a binary problem.
     *
     * @throws UnsupportedOperationException if the problem is not binary or has too many classes
     */
    public double mcc() {
        if (!this.isBinary()) {
            throw new UnsupportedOperationException("mcc is only implemented for 2 class problems.");
        }
        if (this.tooLarge()) {
            throw new UnsupportedOperationException("mcc cannot be computed: too many classes");
        }
        double tn = this._cm[0][0];
        double fp = this._cm[0][1];
        double tp = this._cm[1][1];
        double fn = this._cm[1][0];
        return (tp * tn - fp * fn) / Math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn));
    }

    /**
     * Largest {@link #class_error(int)} over all classes.
     *
     * @throws UnsupportedOperationException if there is no domain or there are too many classes
     */
    public double max_per_class_error() {
        int n = this.nclasses();
        if (n == 0) {
            throw new UnsupportedOperationException("max per class error is only defined for classification problems");
        }
        if (this.tooLarge()) {
            throw new UnsupportedOperationException("max per class error cannot be computed: too many classes");
        }
        double res = this.class_error(0);
        for (int i = 1; i < n; ++i) {
            res = Math.max(res, this.class_error(i));
        }
        return res;
    }

    /** Number of class labels, or 0 when the domain is {@code null}. */
    public final int nclasses() {
        return this._domain == null ? 0 : this._domain.length;
    }

    /** True when there are exactly two class labels. */
    public final boolean isBinary() {
        return this.nclasses() == 2;
    }

    /** F-beta score with beta = 1. */
    public double f1() {
        return this.fScore(1.0);
    }

    /** F-beta score with beta = 2 (recall weighted higher). */
    public double f2() {
        return this.fScore(4.0);
    }

    /** F-beta score with beta = 0.5 (precision weighted higher). */
    public double f0point5() {
        return this.fScore(0.25);
    }

    /**
     * {@code (1 + beta^2) * P * R / (beta^2 * P + R)}. Precision is evaluated before recall, so for
     * non-binary problems the precision error message is the one reported.
     */
    private double fScore(double betaSquared) {
        double precision = this.precision();
        double recall = this.recall();
        return (1.0 + betaSquared) * (precision * recall) / (betaSquared * precision + recall);
    }

    /** One line per matrix row, formatted with {@link Arrays#toString(double[])}. */
    public String toString() {
        StringBuilder sb = new StringBuilder();
        for (double[] row : this._cm) {
            sb.append(Arrays.toString(row)).append('\n');
        }
        return sb.toString();
    }

    /**
     * Builds row or column headers for the rendered table.
     *
     * <p>Index {@code i} receives label {@code ds[i]} unless its total {@code xs[i]} is not
     * non-negative (e.g. NaN) AND the label is missing, empty, or equal to
     * {@code Double.toString(i)}. Indices past the end of {@code ds} get no label, so they stay
     * {@code null} and are skipped by the caller. The one exception is a single extra trailing
     * index, which is labelled "NA" when its total is positive.
     */
    private static String[] createConfusionMatrixHeader(double[] xs, String[] ds) {
        String[] ss = new String[xs.length];
        for (int i = 0; i < xs.length; ++i) {
            if (i >= ds.length) {
                // No label exists for this index (e.g. a trailing NA class); handled below.
                continue;
            }
            boolean nonNegative = xs[i] >= 0.0;
            // NOTE(review): Double.toString(i) renders "0.0", "1.0", ... so index-like labels such as
            // "0" never match; confirm whether Integer.toString was intended.
            boolean meaningfulLabel = ds[i] != null && ds[i].length() > 0 && !Double.toString(i).equals(ds[i]);
            if (nonNegative || meaningfulLabel) {
                ss[i] = ds[i];
            }
        }
        if (ds.length == xs.length - 1 && xs[xs.length - 1] > 0.0) {
            ss[xs.length - 1] = "NA";
        }
        return ss;
    }

    /** The rendered table as text, or "" when no table can be built. */
    public String toASCII() {
        TwoDimTable t = this.table();
        return t == null ? "" : t.toString();
    }

    /** The rendered table, built on first use and cached until the counts change; may be {@code null}. */
    public TwoDimTable table() {
        if (this._table == null) {
            this._table = this.toTable();
        }
        return this._table;
    }

    /**
     * Formats the "errors / total" text of a Rate cell. The same precision is used for every row,
     * the totals row, and the column-width computation.
     */
    private static String formatRate(boolean isInt, double err, double total) {
        return isInt
                ? String.format("%,d / %,d", (long) err, (long) total)
                : String.format("%.4f / %.4f", err, total);
    }

    /**
     * Diagonal count for actual row {@code a}: the cell whose predicted label equals the actual
     * label (the last such cell if several match), or 0 if none does.
     */
    private double correctCount(int a, String[] adomain, String[] pdomain) {
        double correct = 0.0;
        for (int p = 0; p < pdomain.length; ++p) {
            if (pdomain[p] != null && adomain[a].equals(pdomain[p])) {
                correct = this._cm[a][p];
            }
        }
        return correct;
    }

    /**
     * Renders the matrix with per-row Error/Rate columns and a Totals row.
     *
     * @return the table, or {@code null} when counts or domain are missing, the matrix is empty,
     *     or there are too many classes
     */
    private TwoDimTable toTable() {
        // Check for missing or empty data first, before anything dereferences the arrays.
        if (this._cm == null || this._domain == null || this._cm.length == 0) {
            return null;
        }
        if (this.tooLarge()) {
            return null;
        }
        for (double[] row : this._cm) {
            assert (this._cm.length == row.length);
        }
        int n = this._cm.length;
        double[] acts = new double[n];                    // row sums: total per actual class
        double[] preds = new double[this._cm[0].length];  // column sums: total per predicted class
        boolean isInt = true;                             // whether every cell holds an integral value
        for (int a = 0; a < n; ++a) {
            double sum = 0.0;
            for (int p = 0; p < this._cm[a].length; ++p) {
                double v = this._cm[a][p];
                sum += v;
                preds[p] += v;
                isInt &= v == (double) ((long) v);
            }
            acts[a] = sum;
        }
        String[] adomain = ConfusionMatrix.createConfusionMatrixHeader(acts, this._domain);
        String[] pdomain = ConfusionMatrix.createConfusionMatrixHeader(preds, this._domain);
        assert (adomain.length == pdomain.length) : "The confusion matrix should have the same length for both directions.";
        String[] rowHeader = Arrays.copyOf(adomain, adomain.length + 1);
        int totalsRow = adomain.length;
        rowHeader[totalsRow] = "Totals";
        String[] colHeader = Arrays.copyOf(pdomain, pdomain.length + 2);
        int errorCol = pdomain.length;
        int rateCol = pdomain.length + 1;
        colHeader[errorCol] = "Error";
        colHeader[rateCol] = "Rate";
        String[] colType = new String[colHeader.length];
        String[] colFormat = new String[colHeader.length];
        Arrays.fill(colType, 0, errorCol, isInt ? "long" : "double");
        Arrays.fill(colFormat, 0, errorCol, isInt ? "%d" : "%.2f");
        colType[errorCol] = "double";
        colFormat[errorCol] = "%.4f";
        colType[rateCol] = "string";

        // First pass: total error and the width needed to right-align the Rate column.
        double terr = 0.0;
        int width = 0;
        for (int a = 0; a < n; ++a) {
            if (adomain[a] == null) {
                continue;
            }
            double err = acts[a] - this.correctCount(a, adomain, pdomain);
            terr += err;
            width = Math.max(width, ConfusionMatrix.formatRate(isInt, err, acts[a]).length());
        }
        double nrows = 0.0;
        for (double act : acts) {
            nrows += act;
        }
        width = Math.max(width, ConfusionMatrix.formatRate(isInt, terr, nrows).length());
        colFormat[rateCol] = "= %" + width + "s";

        TwoDimTable table = new TwoDimTable("Confusion Matrix", "Row labels: Actual class; Column labels: Predicted class", rowHeader, colHeader, colType, colFormat, null);
        // Second pass: fill per-class rows.
        for (int a = 0; a < n; ++a) {
            if (adomain[a] == null) {
                continue;
            }
            for (int p = 0; p < pdomain.length; ++p) {
                if (pdomain[p] == null) {
                    continue;
                }
                if (isInt) {
                    table.set(a, p, (long) this._cm[a][p]);
                } else {
                    table.set(a, p, this._cm[a][p]);
                }
            }
            double err = acts[a] - this.correctCount(a, adomain, pdomain);
            table.set(a, errorCol, err / acts[a]);
            table.set(a, rateCol, ConfusionMatrix.formatRate(isInt, err, acts[a]));
        }
        // Totals row: column sums plus overall error.
        for (int p = 0; p < pdomain.length; ++p) {
            if (pdomain[p] == null) {
                continue;
            }
            if (isInt) {
                table.set(totalsRow, p, (long) preds[p]);
            } else {
                table.set(totalsRow, p, preds[p]);
            }
        }
        // Full double precision; narrowing terr to float loses exactness once counts exceed 2^24.
        table.set(totalsRow, errorCol, terr / nrows);
        table.set(totalsRow, rateCol, ConfusionMatrix.formatRate(isInt, terr, nrows));
        return table;
    }
}

