/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.classify;

import com.aliasi.classify.ConditionalClassifierEvaluator;
import com.aliasi.classify.JointClassification;
import com.aliasi.classify.JointClassifier;
import com.aliasi.util.Math;

/**
 * An evaluator for {@link JointClassifier} instances.
 *
 * <p>Extends {@link ConditionalClassifierEvaluator} with statistics computed from
 * the joint log (base 2) probabilities reported by each stored
 * {@link JointClassification}.
 *
 * @param <E> the type of objects being classified
 */
public class JointClassifierEvaluator<E>
extends ConditionalClassifierEvaluator<E> {

    /**
     * Constructs an evaluator for the specified joint classifier.
     *
     * @param classifier the joint classifier to evaluate
     * @param categories the categories used for evaluation
     * @param storeInputs flag passed unchanged to the superclass constructor
     *        (presumably controls whether evaluated inputs are retained — confirm)
     */
    public JointClassifierEvaluator(JointClassifier<E> classifier, String[] categories, boolean storeInputs) {
        super(classifier, categories, storeInputs);
    }

    /**
     * Sets the joint classifier being evaluated.
     *
     * @param classifier the new joint classifier
     */
    @Override
    public void setClassifier(JointClassifier<E> classifier) {
        // Delegates to the superclass two-argument setter, identifying this
        // evaluator class; how the class argument is used is not visible here.
        this.setClassifier(classifier, JointClassifierEvaluator.class);
    }

    /**
     * Returns the joint classifier being evaluated.
     *
     * @return the joint classifier
     */
    @Override
    public JointClassifier<E> classifier() {
        // NOTE(review): the constructor and setClassifier in this class only ever
        // supply a JointClassifier<E>; this downcast assumes no other superclass
        // path stores a different classifier type — confirm.
        @SuppressWarnings("unchecked")
        JointClassifier<E> result = (JointClassifier<E>) super.classifier();
        return result;
    }

    /**
     * Returns the average joint log2 probability of {@code responseCategory}.
     *
     * <p>The average is taken over the evaluated cases whose reference category
     * is {@code refCategory}. For each such case, the first rank of its
     * classification whose category equals {@code responseCategory} contributes
     * its joint log2 probability. Cases whose classification does not list the
     * response category are skipped and are not counted.
     *
     * <p>Both arguments are first passed to {@code validateCategory}
     * (presumably rejecting unknown categories — confirm in the superclass).
     *
     * @param refCategory the reference category selecting the cases
     * @param responseCategory the response category whose joint log2 probability is averaged
     * @return the average joint log2 probability, or {@link Double#NaN} if no case contributes
     */
    public double averageLog2JointProbability(String refCategory, String responseCategory) {
        this.validateCategory(refCategory);
        this.validateCategory(responseCategory);
        double sum = 0.0;
        int count = 0;
        for (int i = 0; i < this.mReferenceCategories.size(); ++i) {
            if (!((String) this.mReferenceCategories.get(i)).equals(refCategory)) {
                continue;
            }
            JointClassification c = (JointClassification) this.mClassifications.get(i);
            int rank = rankOf(c, responseCategory);
            if (rank < 0) {
                continue;
            }
            sum += c.jointLog2Probability(rank);
            ++count;
        }
        // An empty selection would compute 0.0 / 0 = NaN anyway; made explicit.
        return count == 0 ? Double.NaN : sum / count;
    }

    /**
     * Returns the average joint log2 probability of each case's reference category.
     *
     * <p>For every evaluated case, the first rank of its classification whose
     * category equals the case's reference category contributes its joint log2
     * probability. The sum is divided by the total number of reference categories.
     *
     * <p>NOTE(review): a case whose classification does not list its reference
     * category adds 0.0 to the sum but is still counted in the denominator.
     * Confirm this is intended.
     *
     * @return the average, or {@link Double#NaN} when there are no cases (0.0 / 0)
     */
    public double averageLog2JointProbabilityReference() {
        double sum = 0.0;
        for (int i = 0; i < this.mReferenceCategories.size(); ++i) {
            String refCategory = (String) this.mReferenceCategories.get(i);
            JointClassification c = (JointClassification) this.mClassifications.get(i);
            int rank = rankOf(c, refCategory);
            if (rank >= 0) {
                sum += c.jointLog2Probability(rank);
            }
        }
        return sum / (double) this.mReferenceCategories.size();
    }

    /**
     * Returns the sum, over all stored classifications, of log2 of the total joint
     * probability mass in each classification.
     *
     * <p>For each classification this is log2 of the sum of 2^jointLog2Probability(rank)
     * over its ranks.
     *
     * @return the corpus log2 joint probability (0.0 when there are no classifications)
     */
    public double corpusLog2JointProbability() {
        double total = 0.0;
        for (int i = 0; i < this.mClassifications.size(); ++i) {
            JointClassification c = (JointClassification) this.mClassifications.get(i);
            total += log2TotalJointProbability(c);
        }
        return total;
    }

    /**
     * Computes log2 of the sum of 2^jointLog2Probability over all ranks of {@code c}.
     *
     * <p>Uses the max-shift (log-sum-exp) trick to avoid underflow. Returns
     * negative infinity when no entry exceeds negative infinity, including when
     * {@code c} is empty.
     */
    private static double log2TotalJointProbability(JointClassification c) {
        double maxJointLog2P = Double.NEGATIVE_INFINITY;
        for (int rank = 0; rank < c.size(); ++rank) {
            double jointLog2P = c.jointLog2Probability(rank);
            if (jointLog2P > maxJointLog2P) {
                maxJointLog2P = jointLog2P;
            }
        }
        if (maxJointLog2P == Double.NEGATIVE_INFINITY) {
            // Shifting by -Infinity would compute -Inf - -Inf = NaN; the total
            // probability here is zero, so its log2 is -Infinity.
            return Double.NEGATIVE_INFINITY;
        }
        double sum = 0.0;
        for (int rank = 0; rank < c.size(); ++rank) {
            sum += java.lang.Math.pow(2.0, c.jointLog2Probability(rank) - maxJointLog2P);
        }
        return maxJointLog2P + Math.log2(sum);
    }

    /**
     * Returns the first rank of {@code c} whose category equals {@code category}, or -1 if none.
     */
    private static int rankOf(JointClassification c, String category) {
        for (int rank = 0; rank < c.size(); ++rank) {
            if (c.category(rank).equals(category)) {
                return rank;
            }
        }
        return -1;
    }

    /**
     * Appends the superclass summary, followed by the average reference joint log2 probability.
     */
    @Override
    void baseToString(StringBuilder sb) {
        super.baseToString(sb);
        sb.append("Average Log2 Joint Probability Reference=")
          .append(this.averageLog2JointProbabilityReference())
          .append("\n");
    }

    /**
     * Appends the superclass one-versus-all report for {@code category}.
     *
     * <p>It is followed by a comma-separated row of average joint log2
     * probabilities, one per response category.
     */
    @Override
    void oneVsAllToString(StringBuilder sb, String category, int i) {
        super.oneVsAllToString(sb, category, i);
        // NOTE(review): the values appended below are average *log2* joint
        // probabilities, although the header text does not say "Log2".
        sb.append("Average Joint Probability Histogram=\n");
        this.appendCategoryLine(sb);
        String[] categories = this.categories();
        for (int j = 0; j < this.numCategories(); ++j) {
            if (j > 0) {
                sb.append(',');
            }
            sb.append(this.averageLog2JointProbability(category, categories[j]));
        }
        sb.append("\n");
    }
}

