/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.evaluation;

import java.util.Arrays;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.evaluation.ClassificationScore;

/**
 * Balanced accuracy score: the unweighted mean, over all classes, of the
 * weighted fraction of points of each true class that were predicted
 * correctly. A class whose accumulated weight is (effectively) zero
 * contributes a perfect per-class score of 1.0.
 */
public class BalancedAccuracy
implements ClassificationScore {
    /** Number of classes; set by {@link #prepare(CategoricalData)}. */
    private int classes;
    /** Per-class sum of weights of points whose prediction matched the true label. */
    double[] class_correct;
    /** Per-class sum of weights of all points with that true label. */
    double[] total_class_weight;

    /**
     * Creates an unprepared score; {@link #prepare(CategoricalData)} must be
     * called before results are added.
     */
    public BalancedAccuracy() {
    }

    /**
     * Copy constructor. The accumulated per-class statistics are copied, so
     * the new instance does not share arrays with {@code toClone}.
     *
     * @param toClone the score to copy
     */
    public BalancedAccuracy(BalancedAccuracy toClone) {
        this.classes = toClone.classes;
        if (toClone.class_correct != null) {
            this.class_correct = Arrays.copyOf(toClone.class_correct, toClone.class_correct.length);
        }
        if (toClone.total_class_weight != null) {
            this.total_class_weight = Arrays.copyOf(toClone.total_class_weight, toClone.total_class_weight.length);
        }
    }

    /**
     * Computes the mean per-class recall.
     *
     * @return the balanced accuracy; {@code NaN} if the number of classes is 0
     *         (e.g. {@link #prepare(CategoricalData)} was never called)
     */
    @Override
    public double getScore() {
        double score = 0.0;
        for (int i = 0; i < this.classes; ++i) {
            // A class with no observed weight cannot be misclassified, so it counts as perfect.
            if (this.total_class_weight[i] > 1.0E-15) {
                score += this.class_correct[i] / this.total_class_weight[i];
            } else {
                score += 1.0;
            }
        }
        return score / this.classes;
    }

    /** @return {@code false}, since a higher balanced accuracy is better */
    @Override
    public boolean lowerIsBetter() {
        return false;
    }

    /** @return a copy of this score, including its accumulated statistics */
    @Override
    public BalancedAccuracy clone() {
        return new BalancedAccuracy(this);
    }

    @Override
    public String getName() {
        return "BalancedAccuracy";
    }

    /** Hash is based only on the score's name, consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        return this.getName().hashCode();
    }

    /**
     * Two scores are equal when they are instances of exactly the same class;
     * accumulated statistics are not compared.
     *
     * @param obj the object to compare against, may be {@code null}
     * @return {@code true} if {@code obj} is non-null and of the same class
     */
    @Override
    public boolean equals(Object obj) {
        // The null guard is required by the Object.equals contract; the original threw NPE here.
        return obj != null && this.getClass() == obj.getClass();
    }

    /**
     * Resets the accumulators for a dataset with the given target.
     *
     * @param toPredict the categorical target whose category count sizes the accumulators
     */
    @Override
    public void prepare(CategoricalData toPredict) {
        this.classes = toPredict.getNumOfCategories();
        this.total_class_weight = new double[this.classes];
        this.class_correct = new double[this.classes];
    }

    /**
     * Records one prediction.
     *
     * @param prediction the classifier output; its {@code mostLikely()} class is
     *                   compared with the true label
     * @param trueLabel  the true class index, expected in {@code [0, classes)}
     * @param weight     the weight of this data point
     */
    @Override
    public void addResult(CategoricalResults prediction, int trueLabel, double weight) {
        this.total_class_weight[trueLabel] += weight;
        if (prediction.mostLikely() == trueLabel) {
            this.class_correct[trueLabel] += weight;
        }
    }

    /**
     * Merges the accumulated statistics of another balanced-accuracy score into
     * this one. Scores of any other type are ignored.
     *
     * @param other the score to merge; presumably prepared with the same number
     *              of classes as this one (NOTE(review): a smaller {@code other}
     *              would throw ArrayIndexOutOfBoundsException)
     */
    @Override
    public void addResults(ClassificationScore other) {
        if (other instanceof BalancedAccuracy) {
            BalancedAccuracy o = (BalancedAccuracy)other;
            for (int i = 0; i < this.classes; ++i) {
                this.class_correct[i] += o.class_correct[i];
                this.total_class_weight[i] += o.total_class_weight[i];
            }
        }
    }
}

