/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.mlplan.multiclasswithreduction;

import ai.libs.jaicore.basic.sets.SetUtil;
import ai.libs.jaicore.ml.WekaUtil;
import ai.libs.jaicore.ml.classification.multiclass.reduction.splitters.RPNDSplitter;
import ai.libs.mlplan.multiclass.wekamlplan.weka.model.MLPipeline;
import ai.libs.mlplan.multiclasswithreduction.ClassSplit;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.attributeSelection.ASEvaluation;
import weka.attributeSelection.ASSearch;
import weka.attributeSelection.InfoGainAttributeEval;
import weka.attributeSelection.Ranker;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Static helpers that compute binary splits ({@link ClassSplit}) of a collection of class labels,
 * used when building nested dichotomies for multi-class reduction.
 */
public final class NestedDichotomyUtil {
    private static final Logger logger = LoggerFactory.getLogger(NestedDichotomyUtil.class);

    /** Utility class; not meant to be instantiated. */
    private NestedDichotomyUtil() {
    }

    /**
     * Computes a two-way split of {@code classes} using an {@link RPNDSplitter}. The splitter's base
     * learner is an {@link MLPipeline} made of a {@link Ranker} search, an {@link InfoGainAttributeEval}
     * evaluator and the classifier named {@code classifierName}.
     *
     * @param classes        the class labels to split; must contain at least two elements
     * @param rand           source of randomness handed to the splitter
     * @param classifierName class name passed to {@link AbstractClassifier#forName(String, String[])}
     *                       (with {@code null} options)
     * @param data           the data passed to {@link RPNDSplitter#split}
     * @return the split built from the first two collections produced by the splitter, or
     *         {@code null} if any other exception occurred (the exception is logged)
     * @throws IllegalArgumentException if {@code classes} has fewer than two elements
     * @throws InterruptedException     if the splitting is interrupted
     */
    public static ClassSplit<String> createGeneralRPNDBasedSplit(final Collection<String> classes, final Random rand, final String classifierName, final Instances data) throws InterruptedException {
        if (classes.size() < 2) {
            throw new IllegalArgumentException("Cannot compute split for less than two classes!");
        }
        try {
            Classifier baseLearner = new MLPipeline((ASSearch) new Ranker(), (ASEvaluation) new InfoGainAttributeEval(), AbstractClassifier.forName(classifierName, null));
            RPNDSplitter splitter = new RPNDSplitter(rand, baseLearner);
            Collection<? extends Collection<String>> split = splitter.split(data);
            return toClassSplit(classes, split);
        } catch (InterruptedException e) {
            throw e;
        } catch (Exception e) {
            logger.error("Unexpected exception occurred while creating an RPND split", e);
            return null;
        }
    }

    /**
     * Computes a two-way split of {@code classes} using an {@link RPNDSplitter} whose base learner is
     * the classifier named {@code classifierName}, created with an empty option array.
     * {@code s1} and {@code s2} are forwarded to the splitter (presumably as the initial seed groups;
     * confirm against {@link RPNDSplitter}).
     *
     * <p>This method does not declare {@link InterruptedException}. If one occurs, it is logged, the
     * thread's interrupt status is restored, and {@code null} is returned.
     *
     * @param classes        the class labels to split
     * @param s1             first group handed to the splitter
     * @param s2             second group handed to the splitter
     * @param rand           source of randomness handed to the splitter
     * @param classifierName class name passed to {@link AbstractClassifier#forName(String, String[])}
     * @param data           the data passed to the splitter
     * @return the split built from the first two collections produced by the splitter, or
     *         {@code null} if an exception occurred (the exception is logged)
     */
    public static ClassSplit<String> createGeneralRPNDBasedSplit(final Collection<String> classes, final Collection<String> s1, final Collection<String> s2, final Random rand, final String classifierName, final Instances data) {
        try {
            RPNDSplitter splitter = new RPNDSplitter(rand, AbstractClassifier.forName(classifierName, new String[0]));
            Collection<? extends Collection<String>> split = splitter.split(classes, s1, s2, data);
            return toClassSplit(classes, split);
        } catch (Exception e) {
            restoreInterruptIfNeeded(e);
            logger.error("Unexpected exception occurred while creating an RPND split", e);
            return null;
        }
    }

    /**
     * Computes a split in which one side starts from a single random seed class and all other
     * classes are assigned together to one side.
     *
     * <p>Procedure:
     * <ol>
     * <li>Two classes are drawn at random to seed {@code s1} and {@code s2}.</li>
     * <li>A classifier is trained on the two-class problem returned by
     *     {@link WekaUtil#mergeClassesOfInstances}.</li>
     * <li>The classifier predicts every instance of the remaining classes.</li>
     * <li>All remaining classes are added to {@code s1} if strictly more predictions were
     *     {@code 0.0}. Otherwise (including ties) they are added to {@code s2}.</li>
     * </ol>
     *
     * @param classes        the class labels to split; must not be empty
     * @param rand           source of randomness for choosing the seed classes
     * @param classifierName class name passed to {@link AbstractClassifier#forName(String, String[])}
     * @param data           the full data set; the two-class training data is derived from it
     * @return the split. If {@code classes} has a single element, both sides of the returned split
     *         are {@code null}. Returns {@code null} if the classifier cannot be created or trained
     *         (the exception is logged).
     * @throws IllegalArgumentException if {@code classes} is empty
     */
    public static ClassSplit<String> createUnaryRPNDBasedSplit(final Collection<String> classes, final Random rand, final String classifierName, final Instances data) {
        if (classes.isEmpty()) {
            throw new IllegalArgumentException("Cannot compute split for an empty set of classes!");
        }
        if (classes.size() == 1) {
            return new ClassSplit<String>(classes, null, null);
        }

        // Pick two distinct positions at random as the seeds of the two sides.
        List<String> shuffled = new ArrayList<>(classes);
        Collections.shuffle(shuffled, rand);
        Set<String> s1 = new HashSet<>();
        s1.add(shuffled.get(0));
        Set<String> s2 = new HashSet<>();
        s2.add(shuffled.get(1));
        Instances reducedData = WekaUtil.mergeClassesOfInstances(data, s1, s2);

        Classifier c;
        try {
            c = AbstractClassifier.forName(classifierName, new String[0]);
        } catch (Exception e1) {
            logger.error("Could not get object of classifier with name {}", classifierName, e1);
            return null;
        }
        try {
            c.buildClassifier(reducedData);
        } catch (Exception e) {
            // An untrained classifier would only produce failed or meaningless predictions below.
            restoreInterruptIfNeeded(e);
            logger.error("Could not train classifier", e);
            return null;
        }

        // Every class that is not one of the two seeds; all of these end up on the same side.
        List<String> remainingClasses = new ArrayList<>(classes);
        remainingClasses.removeAll(s1);
        remainingClasses.removeAll(s2);

        int votesForFirst = 0;
        int votesForSecond = 0;
        for (String className : remainingClasses) {
            Instances testData = WekaUtil.getInstancesOfClass(data, className);
            for (Instance inst : testData) {
                try {
                    double prediction = c.classifyInstance(WekaUtil.getRefactoredInstance(inst));
                    // NOTE(review): assumes label index 0.0 is the meta-class of s1 in reducedData;
                    // confirm against WekaUtil.mergeClassesOfInstances.
                    if (prediction == 0.0) {
                        votesForFirst++;
                    } else {
                        votesForSecond++;
                    }
                } catch (Exception e) {
                    restoreInterruptIfNeeded(e);
                    logger.error("Could not get prediction for some instance to assign it to a meta-class", e);
                }
            }
        }

        // Ties, including the case where no prediction succeeded, go to the second side.
        if (votesForFirst > votesForSecond) {
            s1.addAll(remainingClasses);
        } else {
            s2.addAll(remainingClasses);
        }
        return new ClassSplit<String>(classes, s1, s2);
    }

    /**
     * Builds a {@link ClassSplit} from the first two collections yielded by a splitter result.
     *
     * @throws java.util.NoSuchElementException if {@code split} yields fewer than two collections
     */
    private static ClassSplit<String> toClassSplit(final Collection<String> classes, final Collection<? extends Collection<String>> split) {
        Iterator<? extends Collection<String>> it = split.iterator();
        Collection<String> left = it.next();
        Collection<String> right = it.next();
        return new ClassSplit<String>(classes, left, right);
    }

    /**
     * Restores the current thread's interrupt status when an {@link InterruptedException} is
     * swallowed by a broad {@code catch (Exception)}, so callers can still see the interruption.
     */
    private static void restoreInterruptIfNeeded(final Exception e) {
        if (e instanceof InterruptedException) {
            Thread.currentThread().interrupt();
        }
    }
}

