/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.classification.loss;

import ai.libs.jaicore.basic.StringUtil;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * A confusion matrix built from two position-aligned lists of expected and actual labels.
 *
 * <p>Every distinct label occurring in either list is assigned an index, available via
 * {@link #getObjectIndex()} / {@link #getIndexOfObject(Object)}. Entry {@code [i][j]} of
 * {@link #getConfusionMatrix()} counts the positions whose expected label has index {@code i}
 * and whose actual label has index {@code j}.
 */
public class ConfusionMatrix {
    /** Separator placed between the columns of the table rendered by {@link #toString()}. */
    private static final String COL_SEP = " | ";

    /** Distinct labels; a label's position in this list is its row and column index in the matrix. */
    private final List<Object> objectIndex;

    /** Counts: rows are indexed by expected label, columns by actual label. */
    private final int[][] matrixEntries;

    /**
     * Builds the confusion matrix for the given pairs of labels.
     *
     * @param expected the expected labels, one per sample
     * @param actual the actual labels, aligned position-by-position with {@code expected}
     * @throws IllegalArgumentException if the two lists differ in length
     */
    public ConfusionMatrix(final List<?> expected, final List<?> actual) {
        if (expected.size() != actual.size()) {
            throw new IllegalArgumentException("The provided lists must be of the same length (expected: " + expected.size() + ", actual: " + actual.size() + ").");
        }
        // The index order is the HashSet iteration order (unchanged from the previous implementation).
        Set<Object> distinctClasses = new HashSet<>(expected);
        distinctClasses.addAll(actual);
        this.objectIndex = new ArrayList<>(distinctClasses);

        // Build the label -> index lookup once so that filling the matrix is linear in the number of
        // samples rather than performing two linear indexOf searches per sample. HashMap permits a
        // null key, just like the HashSet above, so null labels are still supported.
        Map<Object, Integer> indexOfLabel = new HashMap<>();
        for (int idx = 0; idx < this.objectIndex.size(); ++idx) {
            indexOfLabel.put(this.objectIndex.get(idx), idx);
        }

        int numClasses = this.objectIndex.size();
        this.matrixEntries = new int[numClasses][numClasses];
        for (int i = 0; i < expected.size(); ++i) {
            // Every label is present in the lookup, so unboxing cannot hit null here.
            int row = indexOfLabel.get(expected.get(i));
            int col = indexOfLabel.get(actual.get(i));
            this.matrixEntries[row][col]++;
        }
    }

    /**
     * Returns the distinct labels in index order.
     *
     * @return the internal label list itself (not a copy); a label's position is its matrix index
     */
    public List<Object> getObjectIndex() {
        return this.objectIndex;
    }

    /**
     * Returns the matrix index of the given label.
     *
     * @param object the label to look up (may be {@code null})
     * @return the label's row/column index, or {@code -1} if it does not occur (per {@link List#indexOf})
     */
    public int getIndexOfObject(final Object object) {
        return this.objectIndex.indexOf(object);
    }

    /**
     * Returns the count matrix: rows are indexed by expected label, columns by actual label.
     *
     * @return the internal array itself (not a copy)
     */
    public int[][] getConfusionMatrix() {
        return this.matrixEntries;
    }

    /**
     * Renders the matrix as a text table: a header row of labels, a dashed rule, then one row per
     * expected label with right-aligned counts. All cells share the width of the widest label or count.
     */
    @Override
    public String toString() {
        final int numClasses = this.objectIndex.size();
        // orElse(0) keeps an empty matrix (two empty input lists) from throwing; String.valueOf
        // tolerates the null labels that the constructor accepts.
        final int labelWidth = this.objectIndex.stream().mapToInt(x -> String.valueOf(x).length()).max().orElse(0);
        final int entryWidth = Arrays.stream(this.matrixEntries).flatMapToInt(Arrays::stream).map(v -> Integer.toString(v).length()).max().orElse(0);
        final int cellWidth = Math.max(labelWidth, entryWidth);

        StringBuilder sb = new StringBuilder();
        // Header row: an empty corner cell followed by one column per label.
        sb.append(StringUtil.spaces(cellWidth));
        for (Object label : this.objectIndex) {
            sb.append(COL_SEP).append(StringUtil.postpaddedString(String.valueOf(label), cellWidth));
        }
        sb.append("\n");
        // Dashed rule spanning the full table width.
        sb.append(IntStream.range(0, cellWidth + (cellWidth + COL_SEP.length()) * numClasses).mapToObj(x -> "-").collect(Collectors.joining())).append("\n");
        // Body: one row per expected label.
        for (int i = 0; i < numClasses; ++i) {
            sb.append(StringUtil.postpaddedString(String.valueOf(this.objectIndex.get(i)), cellWidth));
            for (int j = 0; j < numClasses; ++j) {
                sb.append(COL_SEP).append(StringUtil.prepaddedString(Integer.toString(this.matrixEntries[i][j]), cellWidth));
            }
            sb.append("\n");
        }
        return sb.toString();
    }
}

