/*
 * Decompiled with CFR 0.152.
 */
package com.whylogs.core.metrics;

import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.whylogs.core.DatasetProfile;
import com.whylogs.core.message.ScoreMatrixMessage;
import com.whylogs.core.statistics.NumberTracker;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import lombok.NonNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Confusion-matrix style metrics for classification results.
 *
 * <p>Each cell of the square matrix holds a {@link NumberTracker} of the scores observed for a
 * (prediction label, target label) pair, indexed as {@code values[predictionIndex][targetIndex]}.
 * Labels added through {@link #update} and {@link #merge} are kept in natural (sorted) string
 * order. Instances are mutable and perform no synchronization.
 */
public class ClassificationMetrics {
    private static final Logger log = LoggerFactory.getLogger(ClassificationMetrics.class);

    /** Label names; position i corresponds to row i and column i of {@link #values}. */
    private List<String> labels;

    /** Square matrix of score trackers, indexed [prediction][target]. */
    private NumberTracker[][] values;

    /** Creates an empty metrics object with no labels. */
    public ClassificationMetrics() {
        this(new ArrayList<>(), newMatrix(0));
    }

    /** Static factory equivalent to {@code new ClassificationMetrics()}. */
    public static ClassificationMetrics of() {
        return new ClassificationMetrics();
    }

    /** Returns a read-only view of the current labels, in matrix order. */
    public List<String> getLabels() {
        return Collections.unmodifiableList(this.labels);
    }

    /**
     * Returns the confusion matrix as counts.
     *
     * @return a fresh {@code [prediction][target]} array holding, for each cell, the count
     *     reported by that cell's tracker ({@code getDoubles().getCount()})
     */
    public long[][] getConfusionMatrix() {
        int len = this.labels.size();
        long[][] res = new long[len][len];
        for (int i = 0; i < len; ++i) {
            for (int j = 0; j < len; ++j) {
                res[i][j] = this.values[i][j].getDoubles().getCount();
            }
        }
        return res;
    }

    /** Allocates a {@code len x len} matrix with a fresh {@link NumberTracker} in every cell. */
    private static NumberTracker[][] newMatrix(int len) {
        NumberTracker[][] res = new NumberTracker[len][len];
        for (NumberTracker[] row : res) {
            for (int j = 0; j < len; ++j) {
                row[j] = new NumberTracker();
            }
        }
        return res;
    }

    /**
     * Records one classification result.
     *
     * <p>The textual forms of {@code prediction} and {@code target} are tracked into
     * {@code datasetProfile} under {@code "whylogs.metrics.predictions"} and
     * {@code "whylogs.metrics.targets"}. {@code score} is then tracked in the matching matrix
     * cell. If either label is new, the label list is extended and re-sorted, and existing
     * trackers are carried over into a resized matrix.
     *
     * @param datasetProfile profile that receives the prediction/target values
     * @param prediction predicted label; {@link Boolean}s are encoded as "1"/"0"
     * @param target ground-truth label; {@link Boolean}s are encoded as "1"/"0"
     * @param score score to record for this (prediction, target) pair
     * @throws NullPointerException if {@code prediction} or {@code target} (or its
     *     {@code toString()}) is null; thrown before anything is tracked
     */
    public <T> void update(DatasetProfile datasetProfile, T prediction, T target, double score) {
        // Null labels cannot be placed in the sorted label list, so reject them before any
        // side effect on the profile rather than failing halfway through.
        String predictionText =
                Objects.requireNonNull(textValue(prediction), "prediction must not be null");
        String targetText = Objects.requireNonNull(textValue(target), "target must not be null");
        datasetProfile.track("whylogs.metrics.predictions", predictionText);
        datasetProfile.track("whylogs.metrics.targets", targetText);

        int x = this.labels.indexOf(predictionText);
        int y = this.labels.indexOf(targetText);
        if (x >= 0 && y >= 0) {
            this.values[x][y].track(score);
            return;
        }

        List<String> newLabels = new ArrayList<>(this.labels);
        if (x < 0) {
            newLabels.add(predictionText);
        }
        // When both labels are new and equal, add the label only once.
        if (y < 0 && !predictionText.equals(targetText)) {
            newLabels.add(targetText);
        }
        Collections.sort(newLabels);

        NumberTracker[][] newValues = newMatrix(newLabels.size());
        addMatrix(this.labels, this.values, newLabels, newValues);
        newValues[newLabels.indexOf(predictionText)][newLabels.indexOf(targetText)].track(score);
        this.labels = newLabels;
        this.values = newValues;
    }

    /** Converts a label to text: null stays null, Booleans become "1"/"0", others use toString. */
    private static String textValue(Object value) {
        if (value == null) {
            return null;
        }
        if (value instanceof Boolean) {
            return ((Boolean) value) ? "1" : "0";
        }
        return value.toString();
    }

    /** Renders the labels followed by one bracketed line per matrix row. */
    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder();
        builder.append("Labels: ");
        for (String label : this.labels) {
            builder.append(label);
            builder.append(", ");
        }
        builder.append('\n');
        int len = this.labels.size();
        for (int i = 0; i < len; ++i) {
            builder.append('[');
            for (int j = 0; j < len; ++j) {
                builder.append(this.values[i][j]);
                if (j + 1 < len) {
                    builder.append(", ");
                }
            }
            builder.append("]\n");
        }
        return builder.toString();
    }

    /**
     * Merges this instance with {@code other} into a new instance; neither input is modified.
     *
     * @param other metrics to merge in; may be null, in which case a copy of this is returned
     * @return new metrics over the sorted union of both label sets
     */
    public ClassificationMetrics merge(ClassificationMetrics other) {
        if (other == null) {
            return this.copy();
        }
        HashSet<String> union = new HashSet<>(this.labels);
        union.addAll(other.labels);
        List<String> newLabels = new ArrayList<>(union);
        Collections.sort(newLabels);

        NumberTracker[][] newValues = newMatrix(newLabels.size());
        addMatrix(this.labels, this.values, newLabels, newValues);
        addMatrix(other.labels, other.values, newLabels, newValues);
        return new ClassificationMetrics(newLabels, newValues);
    }

    /**
     * Merges every cell of {@code oldValues} into the corresponding cell of {@code newValues}.
     * The correspondence is found by label name.
     *
     * <p>Every label of {@code oldLabels} is expected to be present in {@code newLabels}.
     */
    private static void addMatrix(
            List<String> oldLabels,
            NumberTracker[][] oldValues,
            List<String> newLabels,
            NumberTracker[][] newValues) {
        int n = oldLabels.size();
        // Resolve each old label's new position once instead of calling indexOf per cell.
        int[] newIndex = new int[n];
        for (int k = 0; k < n; ++k) {
            newIndex[k] = newLabels.indexOf(oldLabels.get(k));
        }
        for (int i = 0; i < n; ++i) {
            NumberTracker[] destRow = newValues[newIndex[i]];
            for (int j = 0; j < n; ++j) {
                int newJ = newIndex[j];
                destRow[newJ] = destRow[newJ].merge(oldValues[i][j]);
            }
        }
    }

    /**
     * Returns an independent copy: a new label list, and trackers rebuilt by merging each cell
     * into a fresh {@link NumberTracker}.
     */
    @NonNull
    public ClassificationMetrics copy() {
        int len = this.labels.size();
        NumberTracker[][] copyValues = newMatrix(len);
        for (int i = 0; i < len; ++i) {
            for (int j = 0; j < len; ++j) {
                copyValues[i][j] = copyValues[i][j].merge(this.values[i][j]);
            }
        }
        return new ClassificationMetrics(new ArrayList<>(this.labels), copyValues);
    }

    /**
     * Serializes to a protobuf builder.
     *
     * <p>Labels are written in matrix order, and scores in row-major order
     * ({@code [0][0], [0][1], ...}).
     */
    @NonNull
    public ScoreMatrixMessage.Builder toProtobuf() {
        ScoreMatrixMessage.Builder builder = ScoreMatrixMessage.newBuilder();
        this.labels.forEach(builder::addLabels);
        int len = this.labels.size();
        for (int i = 0; i < len; ++i) {
            for (int j = 0; j < len; ++j) {
                builder.addScores(this.values[i][j].toProtobuf());
            }
        }
        return builder;
    }

    /**
     * Deserializes a message produced by {@link #toProtobuf()}.
     *
     * <p>Scores are read in row-major order. Cells with no score in the message keep an empty
     * tracker.
     *
     * @throws IllegalArgumentException if the message has more scores than labels^2, which
     *     includes the case of scores with no labels
     */
    public static ClassificationMetrics fromProtobuf(ScoreMatrixMessage msg) {
        int labelCount = msg.getLabelsCount();
        List<String> labels = new ArrayList<>(labelCount);
        for (int i = 0; i < labelCount; ++i) {
            labels.add(msg.getLabels(i));
        }

        int scoreCount = msg.getScoresCount();
        // Guard against division by zero (no labels) and out-of-bounds writes (too many scores).
        if ((long) scoreCount > (long) labelCount * labelCount) {
            throw new IllegalArgumentException(
                    "ScoreMatrixMessage has " + scoreCount + " scores but only " + labelCount
                            + " labels");
        }

        NumberTracker[][] values = newMatrix(labelCount);
        for (int i = 0; i < scoreCount; ++i) {
            values[i / labelCount][i % labelCount] = NumberTracker.fromProtobuf(msg.getScores(i));
        }
        return new ClassificationMetrics(labels, values);
    }

    /** Wraps the given label list and matrix directly; no copies are made. */
    private ClassificationMetrics(List<String> labels, NumberTracker[][] values) {
        this.labels = labels;
        this.values = values;
    }
}

