/*
 * Decompiled with CFR 0.152.
 */
package com.whylogs.core.metrics;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.whylogs.core.message.ScoreMatrixMessage;
import com.whylogs.core.statistics.NumberTracker;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Accumulates score statistics for every (prediction label, target label) pair seen in a
 * classification setting.
 *
 * <p>{@code values[i][j]} holds a {@link NumberTracker} of the scores observed when the prediction
 * label was {@code labels.get(i)} and the target label was {@code labels.get(j)}. The square matrix
 * grows whenever {@link #update} sees a label it has not seen before. Labels are stored as text;
 * booleans are normalized to {@code "1"}/{@code "0"}.
 */
public class ScoreMatrix {
    private static final Logger log = LoggerFactory.getLogger(ScoreMatrix.class);
    /** Label texts; position i is row/column i of {@link #values}. Re-sorted whenever labels are added or merged. */
    private List<String> labels;
    private final String predictionField;
    private final String targetField;
    private final String scoreField;
    /** Square matrix indexed as [prediction label index][target label index]. */
    private NumberTracker[][] values;

    /**
     * Creates an empty matrix that reads its inputs from the given column names in {@link #track}.
     *
     * @param predictionField column holding the predicted label
     * @param targetField column holding the true label
     * @param scoreField column holding the score to track; may be null, in which case 0.0 is tracked
     */
    public ScoreMatrix(String predictionField, String targetField, String scoreField) {
        this(Lists.newArrayList(), predictionField, targetField, scoreField, ScoreMatrix.newMatrix(0));
    }

    /** Returns a read-only view of the current labels, in matrix index order. */
    public List<String> getLabels() {
        return Collections.unmodifiableList(this.labels);
    }

    /**
     * Returns the number of tracked scores per cell.
     *
     * @return a fresh {@code long[n][n]} where {@code [i][j]} is the count for prediction label i
     *     and target label j
     */
    public long[][] getConfusionMatrix() {
        int len = this.labels.size();
        long[][] res = new long[len][len];
        for (int i = 0; i < len; ++i) {
            for (int j = 0; j < len; ++j) {
                res[i][j] = this.values[i][j].getDoubles().getCount();
            }
        }
        return res;
    }

    /** Allocates a {@code len x len} matrix with a fresh {@link NumberTracker} in every cell. */
    private static NumberTracker[][] newMatrix(int len) {
        NumberTracker[][] res = new NumberTracker[len][len];
        for (int i = 0; i < len; ++i) {
            for (int j = 0; j < len; ++j) {
                res[i][j] = new NumberTracker();
            }
        }
        return res;
    }

    /**
     * Tracks one row of data.
     *
     * <p>The score is taken from {@code scoreField}. A {@link Number} is used directly, any other
     * non-null value is parsed with {@link Double#parseDouble}, and a missing or unparseable score is
     * tracked as 0.0. An unparseable score also logs a warning.
     *
     * @param columns the row, keyed by column name
     * @throws IllegalStateException if {@code predictionField} or {@code targetField} is null
     */
    public void track(Map<String, ?> columns) {
        Preconditions.checkState(this.predictionField != null, "predictionField must be set to track");
        Preconditions.checkState(this.targetField != null, "targetField must be set to track");
        Object prediction = columns.get(this.predictionField);
        Object target = columns.get(this.targetField);
        // Null-hostile maps (e.g. TreeMap) throw on get(null), so only look up a configured field.
        Object scoreRaw = this.scoreField == null ? null : columns.get(this.scoreField);
        double score = 0.0;
        if (scoreRaw instanceof Number) {
            score = ((Number) scoreRaw).doubleValue();
        } else if (scoreRaw != null) {
            try {
                score = Double.parseDouble(scoreRaw.toString());
            } catch (NumberFormatException e) {
                log.warn("Failed to parse score: {}", scoreRaw, e);
            }
        }
        this.update(prediction, target, score);
    }

    /**
     * Adds {@code score} to the cell for ({@code prediction}, {@code target}).
     *
     * <p>If either label is new, the matrix is rebuilt over the sorted union of labels and the
     * existing statistics are carried over. Observations whose prediction or target is null are
     * skipped: a null label cannot be sorted with the others and cannot be serialized.
     *
     * @param prediction the predicted label
     * @param target the true label
     * @param score the score to track
     */
    public <T> void update(T prediction, T target, double score) {
        String predictionText = ScoreMatrix.textValue(prediction);
        String targetText = ScoreMatrix.textValue(target);
        if (predictionText == null || targetText == null) {
            // Logged at debug level: a missing column could otherwise flood the log once per row.
            log.debug("Skipping ScoreMatrix update with null label: prediction={}, target={}", prediction, target);
            return;
        }
        int x = this.labels.indexOf(predictionText);
        int y = this.labels.indexOf(targetText);
        if (x >= 0 && y >= 0) {
            this.values[x][y].track(score);
            return;
        }
        // Adding a label that is already present is a no-op for a set.
        HashSet<String> newLabelsSet = Sets.newHashSet(this.labels);
        newLabelsSet.add(predictionText);
        newLabelsSet.add(targetText);
        ArrayList<String> newLabels = Lists.newArrayList(newLabelsSet);
        Collections.sort(newLabels);
        NumberTracker[][] newValues = ScoreMatrix.newMatrix(newLabels.size());
        ScoreMatrix.addMatrix(this.labels, this.values, newLabels, newValues);
        newValues[newLabels.indexOf(predictionText)][newLabels.indexOf(targetText)].track(score);
        this.labels = newLabels;
        this.values = newValues;
    }

    /** Returns the label text for {@code value}: null for null, "1"/"0" for booleans, else toString(). */
    private static String textValue(Object value) {
        if (value == null) {
            return null;
        }
        if (value instanceof Boolean) {
            return ((Boolean) value) ? "1" : "0";
        }
        return value.toString();
    }

    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder();
        builder.append("Labels: ");
        for (String label : this.labels) {
            builder.append(label);
            builder.append(", ");
        }
        builder.append('\n');
        int len = this.labels.size();
        for (int i = 0; i < len; ++i) {
            builder.append('[');
            for (int j = 0; j < len; ++j) {
                if (j > 0) {
                    builder.append(", ");
                }
                builder.append(this.values[i][j]);
            }
            builder.append("]\n");
        }
        return builder.toString();
    }

    /**
     * Combines this matrix with {@code other} into a new matrix over the sorted union of their labels.
     *
     * <p>The prediction, target and score field names are taken from this instance; those of
     * {@code other} are ignored.
     *
     * @param other the matrix to merge in; if null, a copy of this matrix is returned
     * @return a new matrix; this instance is not reassigned
     */
    public ScoreMatrix merge(ScoreMatrix other) {
        if (other == null) {
            return this.copy();
        }
        HashSet<String> allLabels = Sets.newHashSet(this.labels);
        allLabels.addAll(other.labels);
        ArrayList<String> newLabels = Lists.newArrayList(allLabels);
        Collections.sort(newLabels);
        NumberTracker[][] newValues = ScoreMatrix.newMatrix(newLabels.size());
        ScoreMatrix.addMatrix(this.labels, this.values, newLabels, newValues);
        ScoreMatrix.addMatrix(other.labels, other.values, newLabels, newValues);
        // Argument order must match the constructor: (labels, prediction, target, score, values).
        return new ScoreMatrix(newLabels, this.predictionField, this.targetField, this.scoreField, newValues);
    }

    /**
     * Merges every cell of {@code oldValues} into the matching cell of {@code newValues}.
     *
     * <p>Each old label is mapped to its new index by label text, so every label in
     * {@code oldLabels} must be present in {@code newLabels}.
     */
    private static void addMatrix(List<String> oldLabels, NumberTracker[][] oldValues, List<String> newLabels, NumberTracker[][] newValues) {
        int oldLen = oldLabels.size();
        // Resolve each old label's new position once rather than once per cell.
        int[] newIndex = new int[oldLen];
        for (int k = 0; k < oldLen; ++k) {
            newIndex[k] = newLabels.indexOf(oldLabels.get(k));
        }
        for (int i = 0; i < oldLen; ++i) {
            int newI = newIndex[i];
            for (int j = 0; j < oldLen; ++j) {
                int newJ = newIndex[j];
                newValues[newI][newJ] = newValues[newI][newJ].merge(oldValues[i][j]);
            }
        }
    }

    /**
     * Returns a copy with its own label list and its own matrix.
     *
     * <p>Each cell is produced by merging a fresh {@link NumberTracker} with the original cell.
     */
    @NonNull
    public ScoreMatrix copy() {
        int len = this.labels.size();
        NumberTracker[][] copyValues = ScoreMatrix.newMatrix(len);
        for (int i = 0; i < len; ++i) {
            for (int j = 0; j < len; ++j) {
                copyValues[i][j] = copyValues[i][j].merge(this.values[i][j]);
            }
        }
        return new ScoreMatrix(Lists.newArrayList(this.labels), this.predictionField, this.targetField, this.scoreField, copyValues);
    }

    /**
     * Serializes this matrix.
     *
     * <p>Scores are written in row-major order (prediction index outer, target index inner). Field
     * names that are null are left unset, because protobuf setters reject null.
     */
    @NonNull
    public ScoreMatrixMessage.Builder toProtobuf() {
        ScoreMatrixMessage.Builder builder = ScoreMatrixMessage.newBuilder();
        this.labels.forEach(builder::addLabels);
        int len = this.labels.size();
        for (int i = 0; i < len; ++i) {
            for (int j = 0; j < len; ++j) {
                builder.addScores(this.values[i][j].toProtobuf());
            }
        }
        if (this.predictionField != null) {
            builder.setPredictionField(this.predictionField);
        }
        if (this.targetField != null) {
            builder.setTargetField(this.targetField);
        }
        if (this.scoreField != null) {
            builder.setScoreField(this.scoreField);
        }
        return builder;
    }

    /**
     * Deserializes a message produced by {@link #toProtobuf}.
     *
     * <p>Scores are read in row-major order. Cells with no corresponding score keep a fresh
     * {@link NumberTracker}.
     *
     * @return the matrix; null if {@code msg} is null or empty, or if its scores do not fit an
     *     {@code n x n} matrix for its n labels
     */
    public static ScoreMatrix fromProtobuf(ScoreMatrixMessage msg) {
        if (msg == null || msg.getSerializedSize() == 0) {
            return null;
        }
        int n = msg.getLabelsCount();
        ArrayList<String> labels = new ArrayList<>(n);
        for (int i = 0; i < n; ++i) {
            labels.add(msg.getLabels(i));
        }
        int scoresCount = msg.getScoresCount();
        if (n == 0 && scoresCount > 0) {
            log.warn("Skipping classification ScoreMatrix: has scores but no labels");
            return null;
        }
        // More than n*n scores would index past the end of the matrix below.
        if ((long) scoresCount > (long) n * n) {
            log.warn("Skipping classification ScoreMatrix: {} scores do not fit a {}x{} label matrix", scoresCount, n, n);
            return null;
        }
        NumberTracker[][] values = ScoreMatrix.newMatrix(n);
        for (int i = 0; i < scoresCount; ++i) {
            values[i / n][i % n] = NumberTracker.fromProtobuf(msg.getScores(i));
        }
        return new ScoreMatrix(labels, msg.getPredictionField(), msg.getTargetField(), msg.getScoreField(), values);
    }

    private ScoreMatrix(List<String> labels, String predictionField, String targetField, String scoreField, NumberTracker[][] values) {
        this.labels = labels;
        this.predictionField = predictionField;
        this.targetField = targetField;
        this.scoreField = scoreField;
        this.values = values;
    }

    /** Returns the column name read as the predicted label by {@link #track}. */
    public String getPredictionField() {
        return this.predictionField;
    }

    /** Returns the column name read as the true label by {@link #track}. */
    public String getTargetField() {
        return this.targetField;
    }

    /** Returns the column name read as the score by {@link #track}; may be null. */
    public String getScoreField() {
        return this.scoreField;
    }

    /**
     * Returns the internal matrix itself, not a copy; changes to it affect this instance.
     * It is indexed as [prediction label index][target label index].
     */
    public NumberTracker[][] getValues() {
        return this.values;
    }
}

