/*
 * Decompiled with CFR 0.152.
 */
package deepnetts.eval;

import deepnetts.data.MLDataItem;
import deepnetts.eval.ClassificationMetrics;
import deepnetts.eval.ConfusionMatrix;
import deepnetts.net.NeuralNetwork;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import javax.visrec.ml.data.DataSet;
import javax.visrec.ml.eval.EvaluationMetrics;
import javax.visrec.ml.eval.Evaluator;

public class ClassifierEvaluator
implements Evaluator<NeuralNetwork, DataSet<? extends MLDataItem>> {
    private static final String LABEL_POSITIVE = "positive";
    private static final String LABEL_NEGATIVE = "negative";
    private static final String LABEL_NONE = "none";
    private List<String> classLabels = new ArrayList<String>();
    private ConfusionMatrix confusionMatrix;
    private HashMap<String, EvaluationMetrics> performanceByClass;
    private float threshold = 0.5f;

    private void init() {
        this.performanceByClass = new HashMap();
        if (this.classLabels.size() == 2) {
            this.confusionMatrix = new ConfusionMatrix(new String[]{LABEL_NEGATIVE, LABEL_POSITIVE});
        } else {
            this.confusionMatrix = new ConfusionMatrix(this.classLabels.toArray(new String[this.classLabels.size()]));
            this.classLabels.forEach(label -> this.performanceByClass.put((String)label, new EvaluationMetrics()));
        }
    }

    public EvaluationMetrics evaluate(NeuralNetwork neuralNet, DataSet<? extends MLDataItem> testSet) {
        this.classLabels.clear();
        this.classLabels.add(0, LABEL_NONE);
        for (String label : testSet.getTargetNames()) {
            this.classLabels.add(label);
        }
        this.init();
        for (MLDataItem item : testSet) {
            neuralNet.setInput(item.getInput());
            float[] predictedOut = neuralNet.getOutput();
            this.processResult(item.getTargetOutput().getValues(), predictedOut);
        }
        if (this.classLabels.size() == 2) {
            return this.createBinaryPerformanceMeasures();
        }
        this.createMultiClassPerformanceMeasures();
        return this.getTotalAverage();
    }

    private EvaluationMetrics createBinaryPerformanceMeasures() {
        EvaluationMetrics pm = new EvaluationMetrics();
        int tp = this.confusionMatrix.getTruePositive();
        int tn = this.confusionMatrix.getTrueNegative();
        int fp = this.confusionMatrix.getFalsePositive();
        int fn = this.confusionMatrix.getFalseNegative();
        ClassificationMetrics cm = new ClassificationMetrics(tn, fp, fn, tp);
        pm.set("TotalClasses", (float)this.classLabels.size());
        pm.set("TotalItems", (float)cm.getTotal());
        pm.set("TruePositive", (float)tp);
        pm.set("TrueNegative", (float)tn);
        pm.set("FalsePositive", (float)fp);
        pm.set("FalseNegative", (float)fn);
        pm.set("TotalCorrect", (float)(tp + tn));
        pm.set("TotalIncorrect", (float)(fp + fn));
        pm.set("Accuracy", cm.getAccuracy());
        pm.set("Precision", cm.getPrecision());
        pm.set("Recall", cm.getRecall());
        pm.set("F1Score", cm.getF1Score());
        return pm;
    }

    private Map<String, EvaluationMetrics> createMultiClassPerformanceMeasures() {
        this.performanceByClass = new HashMap();
        for (int clsIdx = 1; clsIdx < this.classLabels.size(); ++clsIdx) {
            EvaluationMetrics pm = new EvaluationMetrics();
            int tp = this.confusionMatrix.getTruePositive(clsIdx);
            int tn = this.confusionMatrix.getTrueNegative(clsIdx);
            int fp = this.confusionMatrix.getFalsePositive(clsIdx);
            int fn = this.confusionMatrix.getFalseNegative(clsIdx);
            pm.set("TruePositive", (float)tp);
            pm.set("TrueNegative", (float)tn);
            pm.set("FalsePositive", (float)fp);
            pm.set("FalseNegative", (float)fn);
            ClassificationMetrics cm = new ClassificationMetrics(tn, fp, fn, tp);
            pm.set("Accuracy", cm.getAccuracy());
            pm.set("Precision", cm.getPrecision());
            pm.set("Recall", cm.getRecall());
            pm.set("F1Score", cm.getF1Score());
            this.performanceByClass.put(this.classLabels.get(clsIdx), pm);
        }
        return this.performanceByClass;
    }

    private void processResult(float[] actual, float[] predicted) {
        if (this.classLabels.size() == 1) {
            if (actual[0] == 1.0f && predicted[0] >= this.threshold) {
                this.confusionMatrix.inc(1, 1);
            } else if (actual[0] == 0.0f && predicted[0] < this.threshold) {
                this.confusionMatrix.inc(0, 0);
            } else if (actual[0] == 0.0f && predicted[0] >= this.threshold) {
                this.confusionMatrix.inc(0, 1);
            } else if (actual[0] == 1.0f && predicted[0] < this.threshold) {
                this.confusionMatrix.inc(1, 0);
            }
        } else {
            int actualIdx = this.indexOfMax(actual);
            int predictedIdx = this.indexOfMax(predicted);
            this.confusionMatrix.inc(actualIdx, predictedIdx);
        }
    }

    private int indexOfMax(float[] array) {
        int maxIdx = -1;
        for (int i = 0; i < array.length; ++i) {
            if (!(array[i] >= this.threshold)) continue;
            if (maxIdx == -1) {
                maxIdx = i;
                continue;
            }
            if (!(array[i] > array[maxIdx])) continue;
            maxIdx = i;
        }
        if (maxIdx == -1) {
            return 0;
        }
        return maxIdx + 1;
    }

    public float getThreshold() {
        return this.threshold;
    }

    public void setThreshold(float threshold) {
        this.threshold = threshold;
    }

    public EvaluationMetrics getTotalAverage() {
        float accuracy = 0.0f;
        float precision = 0.0f;
        float recall = 0.0f;
        float f1score = 0.0f;
        for (EvaluationMetrics pm : this.performanceByClass.values()) {
            accuracy += pm.get("Accuracy");
            recall += pm.get("Recall");
            precision += pm.get("Precision");
            f1score += pm.get("F1Score");
        }
        int count = this.performanceByClass.values().size();
        EvaluationMetrics total = new EvaluationMetrics();
        total.set("Accuracy", accuracy / (float)count);
        total.set("Precision", precision / (float)count);
        total.set("Recall", recall / (float)count);
        total.set("F1Score", f1score / (float)count);
        return total;
    }

    public static EvaluationMetrics averagePerformance(List<EvaluationMetrics> measures) {
        float accuracy = 0.0f;
        float precision = 0.0f;
        float recall = 0.0f;
        float f1score = 0.0f;
        for (EvaluationMetrics pm : measures) {
            accuracy += pm.get("Accuracy");
            recall += pm.get("Recall");
            precision += pm.get("Precision");
            f1score += pm.get("F1Score");
        }
        int count = measures.size();
        EvaluationMetrics total = new EvaluationMetrics();
        total.set("Accuracy", accuracy / (float)count);
        total.set("Precision", precision / (float)count);
        total.set("Recall", recall / (float)count);
        total.set("F1Score", f1score / (float)count);
        return total;
    }

    public Map<String, EvaluationMetrics> getPerformanceByClass() {
        return this.performanceByClass;
    }

    public ConfusionMatrix getConfusionMatrix() {
        return this.confusionMatrix;
    }

    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(System.lineSeparator()).append("------------------------------------------------------------------------").append(System.lineSeparator()).append("CLASSIFIER EVALUATION RESULTS ").append(System.lineSeparator()).append("------------------------------------------------------------------------").append(System.lineSeparator());
        sb.append("Total classes: ").append(this.classLabels.size()).append(System.lineSeparator());
        sb.append("Results by labels").append(System.lineSeparator());
        for (String label : this.performanceByClass.keySet()) {
            EvaluationMetrics result = this.performanceByClass.get(label);
            sb.append(label).append(": ");
            sb.append(result).append(System.lineSeparator());
        }
        return sb.toString();
    }
}

