/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.classification.multiclass.reduction;

import ai.libs.jaicore.basic.sets.SetUtil;
import ai.libs.jaicore.ml.WekaUtil;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.Instances;

/**
 * Table of pairwise class separabilities.
 *
 * <p>For every unordered pair of classes contained in the training data, the given
 * classifier is trained on the training instances of just these two classes and
 * evaluated on the validation instances of the same two classes. The fraction of
 * correctly classified validation instances (a value in [0, 1]) is stored as the
 * separability of the pair.
 *
 * <p>Entries are stored under the lexicographically smaller class name first;
 * {@link #getSeparability(String, String)} normalizes the argument order accordingly.
 */
public class AllPairsTable {
    /** Number of training instances per class, as reported by WekaUtil. */
    private final Map<String, Integer> classCount;
    /** separabilities.get(a).get(b) holds the value for the pair (a, b) with a &lt; b. */
    private final Map<String, Map<String, Double>> separabilities = new HashMap<>();
    /** Total number of training instances. */
    private final int sum;

    /**
     * Builds the table by training and evaluating {@code c} once per class pair.
     *
     * <p>Note that {@code c} itself is rebuilt for every pair (no copy is made), so after
     * construction it holds the model of the last pair that was processed.
     *
     * @param training   data from which the classes, the binary training sets and the class counts are taken
     * @param validation data from which the binary validation sets are taken
     * @param c          classifier used to solve each binary problem
     * @throws Exception if building or evaluating the classifier fails
     */
    public AllPairsTable(Instances training, Instances validation, Classifier c) throws Exception {
        Collection<String> classes = WekaUtil.getClassesActuallyContainedInDataset(training);
        for (Collection<String> set : SetUtil.getAllPossibleSubsetsWithSize(classes, 2)) {
            // Sort so that the pair is always stored as (smaller, larger).
            List<String> pair = set.stream().sorted().collect(Collectors.toList());
            String a = pair.get(0);
            String b = pair.get(1);

            Instances trainingData = WekaUtil.getInstancesOfClass(training, a);
            trainingData.addAll(WekaUtil.getInstancesOfClass(training, b));
            c.buildClassifier(trainingData);

            Instances validationData = WekaUtil.getInstancesOfClass(validation, a);
            validationData.addAll(WekaUtil.getInstancesOfClass(validation, b));
            Evaluation eval = new Evaluation(trainingData);
            eval.evaluateModel(c, validationData);

            // pctCorrect() is a percentage; store it as a fraction.
            this.separabilities.computeIfAbsent(a, key -> new HashMap<>()).put(b, eval.pctCorrect() / 100.0);
        }
        this.classCount = WekaUtil.getNumberOfInstancesPerClass(training);
        this.sum = training.size();
    }

    /**
     * Returns the stored separability of two distinct classes; the argument order is irrelevant.
     *
     * @param c1 name of the first class
     * @param c2 name of the second class
     * @return the fraction of validation instances that were classified correctly for this pair
     * @throws IllegalArgumentException if both names are equal or no value is stored for the pair
     */
    public double getSeparability(String c1, String c2) {
        if (c1.equals(c2)) {
            throw new IllegalArgumentException("Cannot separate a class from itself.");
        }
        // Normalize to the storage order used by the constructor.
        boolean ordered = c1.compareTo(c2) < 0;
        String smaller = ordered ? c1 : c2;
        String larger = ordered ? c2 : c1;

        Map<String, Double> row = this.separabilities.get(smaller);
        Double value = row == null ? null : row.get(larger);
        if (value == null) {
            // Previously this surfaced as a message-less NullPointerException during unboxing.
            throw new IllegalArgumentException("No separability known for classes \"" + c1 + "\" and \"" + c2 + "\".");
        }
        return value;
    }

    /**
     * Computes {@code 1 - max} where {@code max} is the largest value, over all pairs of the
     * given classes, of {@code (1 - separability)} weighted by the share of training instances
     * that belong to the two classes of the pair.
     *
     * @param classes class names to consider; with fewer than two classes the result is 1.0
     * @return the bound described above
     * @throws IllegalArgumentException if a pair has no stored separability
     */
    public double getUpperBoundOnSeparability(Collection<String> classes) {
        double max = 0.0;
        for (Collection<String> pair : SetUtil.getAllPossibleSubsetsWithSize(classes, 2)) {
            Iterator<String> members = pair.iterator();
            String a = members.next();
            String b = members.next();
            double expectedContributionToError = 1.0 - this.getSeparability(a, b);
            // NOTE(review): assumes classCount has an entry for every class that has a separability - confirm.
            int instancesOfPair = this.classCount.get(a) + this.classCount.get(b);
            // Divide in double precision; a float detour would lose precision for sums above 2^24.
            double relativeExpectedContributionToError = expectedContributionToError * instancesOfPair / (double) this.sum;
            max = Math.max(max, relativeExpectedContributionToError);
        }
        return 1.0 - max;
    }

    /**
     * Computes the arithmetic mean of the separabilities of all pairs of the given classes.
     *
     * @param classes class names to consider
     * @return the mean; if there is no pair, whatever {@code DescriptiveStatistics.getMean()}
     *         yields for an empty sample (NaN according to the Commons Math documentation)
     * @throws IllegalArgumentException if a pair has no stored separability
     */
    public double getAverageSeparability(Collection<String> classes) {
        DescriptiveStatistics stats = new DescriptiveStatistics();
        for (Collection<String> pair : SetUtil.getAllPossibleSubsetsWithSize(classes, 2)) {
            stats.addValue(this.separabilityOfPair(pair));
        }
        return stats.getMean();
    }

    /**
     * Computes the product of the separabilities of all pairs of the given classes.
     *
     * @param classes class names to consider; with fewer than two classes the result is 1.0
     * @return the product of all pairwise separabilities
     * @throws IllegalArgumentException if a pair has no stored separability
     */
    public double getMultipliedSeparability(Collection<String> classes) {
        double separability = 1.0;
        for (Collection<String> pair : SetUtil.getAllPossibleSubsetsWithSize(classes, 2)) {
            separability *= this.separabilityOfPair(pair);
        }
        return separability;
    }

    /** Looks up the separability of the two classes contained in a two-element collection. */
    private double separabilityOfPair(Collection<String> pair) {
        Iterator<String> members = pair.iterator();
        String a = members.next();
        String b = members.next();
        return this.getSeparability(a, b);
    }
}

