/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.reduction.single.confusion;

import ai.libs.jaicore.basic.sets.SetUtil;
import ai.libs.jaicore.ml.WekaUtil;
import ai.libs.jaicore.ml.classification.multiclass.reduction.MCTreeNodeReD;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.Instances;

public class ConfusionBasedGreedyOptimizingAlgorithm {
    public MCTreeNodeReD buildClassifier(Instances data, Collection<String> pClassifierNames) throws Exception {
        System.out.println("START: " + data.relationName());
        int seed = 0;
        List split = WekaUtil.getStratifiedSplit((Instances)data, (long)seed, (double[])new double[]{0.7f});
        int numClasses = data.numClasses();
        System.out.println("Computing confusion matrices ...");
        HashMap<String, double[][]> confusionMatrices = new HashMap<String, double[][]>();
        for (String string : pClassifierNames) {
            System.out.println("\t" + string + " ...");
            try {
                Classifier c = AbstractClassifier.forName((String)string, null);
                c.buildClassifier((Instances)split.get(0));
                Evaluation eval = new Evaluation((Instances)split.get(0));
                eval.evaluateModel(c, (Instances)split.get(1), new Object[0]);
                confusionMatrices.put(string, eval.confusionMatrix());
            }
            catch (Throwable e) {
                System.err.println(e.getClass().getName() + ": " + e.getMessage());
            }
        }
        System.out.println("done");
        HashMap<String, Collection<Collection<Integer>>> zeroConflictSets = new HashMap<String, Collection<Collection<Integer>>>();
        for (String classifier : confusionMatrices.keySet()) {
            zeroConflictSets.put(classifier, this.getZeroConflictSets((double[][])confusionMatrices.get(classifier)));
        }
        Collection collection = SetUtil.cartesianProduct(confusionMatrices.keySet(), (int)2);
        int leastSeenMistakes = Integer.MAX_VALUE;
        String bestLeft = null;
        String bestRight = null;
        String bestInner = null;
        ArrayList<Integer> bestLeftClasses = null;
        Object bestRightClasses = null;
        int numPair = 0;
        for (List classifierPair : collection) {
            Object newBestZ2;
            String c1 = (String)classifierPair.get(0);
            String c2 = (String)classifierPair.get(1);
            System.out.println("\tConsidering " + c1 + "/" + c2 + "(" + ++numPair + "/" + collection.size() + ")");
            double[][] cm1 = (double[][])confusionMatrices.get(c1);
            double[][] cm2 = (double[][])confusionMatrices.get(c2);
            Collection z1 = (Collection)zeroConflictSets.get(c1);
            Collection z2 = (Collection)zeroConflictSets.get(c2);
            int sizeOfBestCombo = 0;
            ArrayList<Integer> bestZ1 = null;
            Object bestZ2 = null;
            for (Collection zeroSet1 : z1) {
                for (Collection zeroSet2 : z2) {
                    Collection coveredClassesOfThisPair = SetUtil.union((Collection[])new Collection[]{zeroSet1, zeroSet2});
                    if (coveredClassesOfThisPair.size() <= sizeOfBestCombo) continue;
                    sizeOfBestCombo = coveredClassesOfThisPair.size();
                    bestZ1 = zeroSet1;
                    bestZ2 = zeroSet2;
                }
            }
            for (int cId = 0; cId < numClasses; ++cId) {
                if (bestZ1.contains(cId) || bestZ2.contains(cId)) continue;
                ArrayList<Integer> newBestZ1 = new ArrayList<Integer>(bestZ1);
                newBestZ1.add(cId);
                int p1 = this.getPenaltyOfCluster(newBestZ1, cm1);
                newBestZ2 = new ArrayList(bestZ2);
                newBestZ2.add(cId);
                int p2 = this.getPenaltyOfCluster((Collection<Integer>)newBestZ2, cm2);
                if (p1 < p2) {
                    bestZ1 = newBestZ1;
                    continue;
                }
                bestZ2 = newBestZ2;
            }
            int p1 = this.getPenaltyOfCluster((Collection<Integer>)bestZ1, cm1);
            int p2 = this.getPenaltyOfCluster((Collection<Integer>)bestZ2, cm2);
            HashMap<String, String> classMap = new HashMap<String, String>();
            newBestZ2 = bestZ1.iterator();
            while (newBestZ2.hasNext()) {
                int i1 = (Integer)newBestZ2.next();
                classMap.put(data.classAttribute().value(i1), "l");
            }
            newBestZ2 = bestZ2.iterator();
            while (newBestZ2.hasNext()) {
                int i2 = (Integer)newBestZ2.next();
                classMap.put(data.classAttribute().value(i2), "r");
            }
            Instances newData = WekaUtil.getRefactoredInstances((Instances)data, classMap);
            List binaryInnerSplit = WekaUtil.getStratifiedSplit((Instances)newData, (long)seed, (double[])new double[]{0.7f});
            for (String classifier : pClassifierNames) {
                try {
                    System.out.println("\t\tConsidering " + c1 + "/" + c2 + "/" + classifier);
                    Classifier c = AbstractClassifier.forName((String)classifier, null);
                    c.buildClassifier((Instances)binaryInnerSplit.get(0));
                    Evaluation eval = new Evaluation(newData);
                    eval.evaluateModel(c, (Instances)binaryInnerSplit.get(1), new Object[0]);
                    int mistakes = (int)eval.incorrect();
                    int overallMistakes = p1 + p2 + mistakes;
                    if (overallMistakes >= leastSeenMistakes) continue;
                    leastSeenMistakes = overallMistakes;
                    System.out.println("New best system: " + c1 + "/" + c2 + "/" + classifier + " with " + leastSeenMistakes);
                    bestLeftClasses = bestZ1;
                    bestRightClasses = bestZ2;
                    bestLeft = c1;
                    bestRight = c2;
                    bestInner = classifier;
                }
                catch (Exception e) {
                    System.err.println(e.getClass() + ": " + e.getMessage());
                }
            }
        }
        MCTreeNodeReD tree = new MCTreeNodeReD(bestInner, (Collection)bestLeftClasses.stream().map(i -> data.classAttribute().value(i.intValue())).collect(Collectors.toList()), bestLeft, (Collection)bestRightClasses.stream().map(i -> data.classAttribute().value(i.intValue())).collect(Collectors.toList()), bestRight);
        tree.buildClassifier(data);
        return tree;
    }

    private int getLeastConflictingClass(double[][] confusionMatrix, Collection<Integer> blackList) {
        int leastConflictingClass = -1;
        int leastKnownScore = Integer.MAX_VALUE;
        for (int i = 0; i < confusionMatrix.length; ++i) {
            if (blackList.contains(i)) continue;
            int sum = 0;
            for (int j = 0; j < confusionMatrix.length; ++j) {
                sum = (int)((double)sum + confusionMatrix[i][j]);
            }
            if (sum >= leastKnownScore) continue;
            leastKnownScore = sum;
            leastConflictingClass = i;
        }
        return leastConflictingClass;
    }

    private Collection<Collection<Integer>> getZeroConflictSets(double[][] confusionMatrix) {
        ArrayList<Integer> blackList = new ArrayList<Integer>();
        ArrayList<Collection<Integer>> partitions = new ArrayList<Collection<Integer>>();
        int leastConflictingClass = -1;
        do {
            if ((leastConflictingClass = this.getLeastConflictingClass(confusionMatrix, blackList)) < 0) continue;
            Collection<Integer> cluster = new ArrayList<Integer>();
            cluster.add(leastConflictingClass);
            do {
                if (!(cluster = this.incrementCluster(cluster, confusionMatrix, blackList)).contains(-1)) continue;
                throw new IllegalStateException("Computed illegal cluster: " + cluster);
            } while (this.getPenaltyOfCluster(cluster, confusionMatrix) == 0);
            blackList.addAll(cluster);
            partitions.add(cluster);
        } while (leastConflictingClass >= 0);
        return partitions;
    }

    private Collection<Integer> incrementCluster(Collection<Integer> cluster, double[][] confusionMatrix, Collection<Integer> blackList) {
        int leastSeenPenalty = Integer.MAX_VALUE;
        int choice = -1;
        for (int cId = 0; cId < confusionMatrix.length; ++cId) {
            if (cluster.contains(cId) || blackList.contains(cId)) continue;
            int addedPenalty = 0;
            for (int i = 0; i < confusionMatrix.length; ++i) {
                addedPenalty = (int)((double)addedPenalty + confusionMatrix[i][cId]);
                addedPenalty = (int)((double)addedPenalty + confusionMatrix[cId][i]);
            }
            if (addedPenalty >= leastSeenPenalty) continue;
            leastSeenPenalty = addedPenalty;
            choice = cId;
        }
        ArrayList<Integer> newCluster = new ArrayList<Integer>(cluster);
        if (choice < 0) {
            return newCluster;
        }
        newCluster.add(choice);
        return newCluster;
    }

    private int getPenaltyOfCluster(Collection<Integer> cluster, double[][] confusionMatrix) {
        int sum = 0;
        for (int i : cluster) {
            for (int j : cluster) {
                sum = (int)((double)sum + confusionMatrix[i][j]);
            }
        }
        return sum;
    }
}

