/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.reduction.single.confusion;

import ai.libs.jaicore.basic.sets.SetUtil;
import ai.libs.jaicore.ml.WekaUtil;
import ai.libs.jaicore.ml.classification.multiclass.reduction.MCTreeNodeReD;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.Instances;

public class ConfusionBasedAlgorithm {
    private Logger logger = LoggerFactory.getLogger(ConfusionBasedAlgorithm.class);

    /*
     * WARNING - void declaration
     */
    public MCTreeNodeReD buildClassifier(Instances data, Collection<String> pClassifierNames) throws Exception {
        Object newBestZ2;
        void var8_13;
        if (this.logger.isInfoEnabled()) {
            this.logger.info("START: {}", (Object)data.relationName());
        }
        int seed = 0;
        HashMap<String, double[][]> confusionMatrices = new HashMap<String, double[][]>();
        int numClasses = data.numClasses();
        this.logger.info("Computing confusion matrices ...");
        for (int i2 = 0; i2 < 10; ++i2) {
            List split = WekaUtil.getStratifiedSplit((Instances)data, (long)seed, (double[])new double[]{0.7f});
            for (String classifier : pClassifierNames) {
                try {
                    Classifier c = AbstractClassifier.forName((String)classifier, null);
                    c.buildClassifier((Instances)split.get(0));
                    Evaluation eval = new Evaluation((Instances)split.get(0));
                    eval.evaluateModel(c, (Instances)split.get(1), new Object[0]);
                    if (!confusionMatrices.containsKey(classifier)) {
                        confusionMatrices.put(classifier, new double[numClasses][numClasses]);
                    }
                    double[][] currentCM = (double[][])confusionMatrices.get(classifier);
                    Object addedCM = eval.confusionMatrix();
                    for (int j = 0; j < numClasses; ++j) {
                        for (int k = 0; k < numClasses; ++k) {
                            double[] dArray = currentCM[j];
                            int n = k;
                            dArray[n] = dArray[n] + addedCM[j][k];
                        }
                    }
                }
                catch (Exception e) {
                    this.logger.error("Unexpected exception has been thrown", (Throwable)e);
                }
            }
        }
        this.logger.info("done");
        HashMap zeroConflictSets = new HashMap();
        for (Map.Entry entry : confusionMatrices.entrySet()) {
            zeroConflictSets.put(entry.getKey(), this.getZeroConflictSets((double[][])entry.getValue()));
        }
        Collection classifierPairs = SetUtil.cartesianProduct(confusionMatrices.keySet(), (int)2);
        Object var8_12 = null;
        String bestRight = null;
        String bestInner = null;
        Object bestLeftClasses = null;
        Object bestRightClasses = null;
        for (List classifierPair : classifierPairs) {
            String c1 = (String)classifierPair.get(0);
            String c2 = (String)classifierPair.get(1);
            Collection z1 = (Collection)zeroConflictSets.get(c1);
            Collection z2 = (Collection)zeroConflictSets.get(c2);
            int sizeOfBestCombo = 0;
            for (Object zeroSet1 : z1) {
                for (Collection zeroSet2 : z2) {
                    Collection coveredClassesOfThisPair = SetUtil.union((Collection[])new Collection[]{zeroSet1, zeroSet2});
                    if (coveredClassesOfThisPair.size() <= sizeOfBestCombo) continue;
                    String string = c1;
                    bestRight = c2;
                    sizeOfBestCombo = coveredClassesOfThisPair.size();
                    bestLeftClasses = zeroSet1;
                    bestRightClasses = zeroSet2;
                }
            }
        }
        double[][] cm1 = (double[][])confusionMatrices.get(var8_13);
        double[][] cm2 = (double[][])confusionMatrices.get(bestRight);
        for (int cId = 0; cId < numClasses; ++cId) {
            if (bestLeftClasses.contains(cId) || bestRightClasses.contains(cId)) continue;
            ArrayList<Integer> newBestZ1 = new ArrayList<Integer>((Collection<Integer>)bestLeftClasses);
            newBestZ1.add(cId);
            int p1 = this.getPenaltyOfCluster(newBestZ1, cm1);
            newBestZ2 = new ArrayList(bestRightClasses);
            newBestZ2.add(cId);
            int p2 = this.getPenaltyOfCluster((Collection<Integer>)newBestZ2, cm2);
            if (p1 < p2) {
                bestLeftClasses = newBestZ1;
                continue;
            }
            bestRightClasses = newBestZ2;
        }
        int p1 = this.getPenaltyOfCluster((Collection<Integer>)bestLeftClasses, cm1);
        int p2 = this.getPenaltyOfCluster((Collection<Integer>)bestRightClasses, cm2);
        HashMap<String, String> classMap = new HashMap<String, String>();
        newBestZ2 = bestLeftClasses.iterator();
        while (newBestZ2.hasNext()) {
            int i1 = (Integer)newBestZ2.next();
            classMap.put(data.classAttribute().value(i1), "l");
        }
        newBestZ2 = bestRightClasses.iterator();
        while (newBestZ2.hasNext()) {
            int i2 = (Integer)newBestZ2.next();
            classMap.put(data.classAttribute().value(i2), "r");
        }
        Instances newData = WekaUtil.getRefactoredInstances((Instances)data, classMap);
        List binaryInnerSplit = WekaUtil.getStratifiedSplit((Instances)newData, (long)seed, (double[])new double[]{0.7f});
        int leastSeenMistakes = Integer.MAX_VALUE;
        for (String classifier : pClassifierNames) {
            try {
                Classifier c = AbstractClassifier.forName((String)classifier, null);
                c.buildClassifier((Instances)binaryInnerSplit.get(0));
                Evaluation eval = new Evaluation(newData);
                eval.evaluateModel(c, (Instances)binaryInnerSplit.get(1), new Object[0]);
                int mistakes = (int)eval.incorrect();
                int overallMistakes = p1 + p2 + mistakes;
                if (overallMistakes >= leastSeenMistakes) continue;
                leastSeenMistakes = overallMistakes;
                this.logger.info("New best system: {}/{}/{} with {}", new Object[]{var8_13, bestRight, classifier, leastSeenMistakes});
                bestInner = classifier;
            }
            catch (Exception e) {
                this.logger.error("Exception has been thrown unexpectedly.", (Throwable)e);
            }
        }
        if (bestInner == null) {
            throw new IllegalStateException("No best inner has been chosen!");
        }
        MCTreeNodeReD tree = new MCTreeNodeReD(bestInner, (Collection)bestLeftClasses.stream().map(i -> data.classAttribute().value(i.intValue())).collect(Collectors.toList()), (String)var8_13, (Collection)bestRightClasses.stream().map(i -> data.classAttribute().value(i.intValue())).collect(Collectors.toList()), bestRight);
        tree.buildClassifier(data);
        return tree;
    }

    private int getLeastConflictingClass(double[][] confusionMatrix, Collection<Integer> blackList) {
        int leastConflictingClass = -1;
        int leastKnownScore = Integer.MAX_VALUE;
        for (int i = 0; i < confusionMatrix.length; ++i) {
            if (blackList.contains(i)) continue;
            int sum = 0;
            for (int j = 0; j < confusionMatrix.length; ++j) {
                if (i == j) continue;
                sum = (int)((double)sum + confusionMatrix[i][j]);
            }
            if (sum >= leastKnownScore) continue;
            leastKnownScore = sum;
            leastConflictingClass = i;
        }
        return leastConflictingClass;
    }

    private Collection<Collection<Integer>> getZeroConflictSets(double[][] confusionMatrix) {
        ArrayList<Integer> blackList = new ArrayList<Integer>();
        ArrayList<Collection<Integer>> partitions = new ArrayList<Collection<Integer>>();
        int leastConflictingClass = -1;
        do {
            Collection<Integer> newCluster;
            if ((leastConflictingClass = this.getLeastConflictingClass(confusionMatrix, blackList)) < 0) continue;
            Collection<Integer> cluster = new ArrayList<Integer>();
            cluster.add(leastConflictingClass);
            while ((newCluster = this.incrementCluster(cluster, confusionMatrix, blackList)).size() != cluster.size()) {
                cluster = newCluster;
                if (cluster.contains(-1)) {
                    throw new IllegalStateException("Computed illegal cluster: " + cluster);
                }
                if (this.getPenaltyOfCluster(cluster, confusionMatrix) == 0 && cluster.size() < confusionMatrix.length) continue;
            }
            blackList.addAll(cluster);
            partitions.add(cluster);
        } while (leastConflictingClass >= 0 && blackList.size() < confusionMatrix.length);
        return partitions;
    }

    private Collection<Integer> incrementCluster(Collection<Integer> cluster, double[][] confusionMatrix, Collection<Integer> blackList) {
        int leastSeenPenalty = Integer.MAX_VALUE;
        int choice = -1;
        for (int cId = 0; cId < confusionMatrix.length; ++cId) {
            if (cluster.contains(cId) || blackList.contains(cId)) continue;
            int addedPenalty = 0;
            for (int i = 0; i < confusionMatrix.length; ++i) {
                addedPenalty = (int)((double)addedPenalty + confusionMatrix[i][cId]);
                addedPenalty = (int)((double)addedPenalty + confusionMatrix[cId][i]);
            }
            if (addedPenalty >= leastSeenPenalty) continue;
            leastSeenPenalty = addedPenalty;
            choice = cId;
        }
        ArrayList<Integer> newCluster = new ArrayList<Integer>(cluster);
        if (choice < 0) {
            return newCluster;
        }
        newCluster.add(choice);
        return newCluster;
    }

    private int getPenaltyOfCluster(Collection<Integer> cluster, double[][] confusionMatrix) {
        int sum = 0;
        for (int i : cluster) {
            for (int j : cluster) {
                if (i == j) continue;
                sum = (int)((double)sum + confusionMatrix[i][j]);
            }
        }
        return sum;
    }
}

