/*
 * Decompiled with CFR 0.152.
 */
package tech.tablesaw.api.ml.classification;

import com.google.common.collect.Table;
import com.google.common.collect.TreeBasedTable;
import java.util.ArrayList;
import java.util.Set;
import java.util.SortedMap;
import java.util.SortedSet;
import java.util.TreeMap;
import java.util.TreeSet;
import tech.tablesaw.api.CategoryColumn;
import tech.tablesaw.api.IntColumn;
import tech.tablesaw.api.Table;
import tech.tablesaw.api.ml.classification.ConfusionMatrix;

public class StandardConfusionMatrix
implements ConfusionMatrix {
    private final com.google.common.collect.Table<Integer, Integer, Integer> table = TreeBasedTable.create();
    private SortedMap<Integer, Object> labels = new TreeMap<Integer, Object>();

    public StandardConfusionMatrix(SortedSet<Object> labels) {
        int i = 0;
        for (Object e : labels) {
            this.labels.put(i, e);
            ++i;
        }
    }

    @Override
    public void increment(Integer predicted, Integer actual) {
        Integer v = (Integer)this.table.get((Object)predicted, (Object)actual);
        if (v == null) {
            this.table.put((Object)predicted, (Object)actual, (Object)1);
        } else {
            this.table.put((Object)predicted, (Object)actual, (Object)(v + 1));
        }
    }

    @Override
    public String toString() {
        return this.toTable().print();
    }

    @Override
    public Table toTable() {
        Table t = Table.create("Confusion Matrix");
        t.addColumn(new CategoryColumn(""));
        TreeSet allValues = new TreeSet();
        allValues.addAll(this.table.columnKeySet());
        allValues.addAll(this.table.rowKeySet());
        for (Integer comparable : allValues) {
            t.addColumn(new IntColumn(String.valueOf(this.labels.get(comparable))));
            t.column(0).appendCell("Predicted " + this.labels.get(comparable));
        }
        ArrayList valuesList = new ArrayList(allValues);
        int n = 0;
        for (int r = 0; r < valuesList.size(); ++r) {
            for (int c = 0; c < valuesList.size(); ++c) {
                Integer value = (Integer)this.table.get(valuesList.get(r), valuesList.get(c));
                if (value == null) {
                    t.intColumn(c + 1).append(0);
                    continue;
                }
                t.intColumn(c + 1).append(value);
                n += value.intValue();
            }
        }
        t.column(0).setName("n = " + n);
        for (int c = 1; c <= valuesList.size(); ++c) {
            t.column(c).setName("Actual " + this.labels.get(c - 1));
        }
        return t;
    }

    @Override
    public double accuracy() {
        Set cellSet = this.table.cellSet();
        int hits = 0;
        int misses = 0;
        for (Table.Cell cell : cellSet) {
            if (cell.getRowKey().equals(cell.getColumnKey())) {
                hits += ((Integer)cell.getValue()).intValue();
                continue;
            }
            misses += ((Integer)cell.getValue()).intValue();
        }
        return (double)hits / ((double)(hits + misses) * 1.0);
    }
}

