/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.weka.classification.learner.reduction.splitter;

import ai.libs.jaicore.basic.sets.SetUtil;
import ai.libs.jaicore.logging.LoggerUtil;
import ai.libs.jaicore.ml.weka.WekaUtil;
import ai.libs.jaicore.ml.weka.classification.learner.reduction.splitter.ISplitter;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import org.api4.java.ai.ml.core.dataset.splitter.SplitFailedException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;

public class RPNDSplitter
implements ISplitter {
    private static final Logger logger = LoggerFactory.getLogger(RPNDSplitter.class);
    private final Random rand;
    private final Classifier rpndClassifier;

    public RPNDSplitter(Random rand, Classifier rpndClassifier) {
        this.rand = rand;
        this.rpndClassifier = rpndClassifier;
    }

    @Override
    public Collection<Collection<String>> split(Instances data) throws SplitFailedException, InterruptedException {
        Collection<String> classes = WekaUtil.getClassesActuallyContainedInDataset(data);
        if (classes.size() == 1) {
            ArrayList<Collection<String>> split = new ArrayList<Collection<String>>();
            split.add(classes);
            return split;
        }
        ArrayList<String> copy = new ArrayList<String>(classes);
        Collections.shuffle(copy, this.rand);
        String c1 = (String)copy.get(0);
        String c2 = (String)copy.get(1);
        HashSet<String> s1 = new HashSet<String>();
        s1.add(c1);
        HashSet<String> s2 = new HashSet<String>();
        s2.add(c2);
        return this.split(copy, s1, s2, data);
    }

    public Collection<Collection<String>> split(Collection<String> classes, Collection<String> s1, Collection<String> s2, Instances data) throws SplitFailedException, InterruptedException {
        logger.info("Start creation of RPND split with basis {}/{} for classes {}", new Object[]{s1, s2, classes});
        Instances reducedData = WekaUtil.mergeClassesOfInstances(data, s1, s2);
        logger.debug("Building classifier for separating the two class sets {} and {}", s1, s2);
        try {
            this.rpndClassifier.buildClassifier(reducedData);
        }
        catch (Exception e1) {
            throw new SplitFailedException((Throwable)e1);
        }
        logger.info("Now classifying the items of the other classes");
        ArrayList remainingClasses = new ArrayList(SetUtil.difference((Collection)SetUtil.difference(classes, s1), s2));
        for (int i = 0; i < remainingClasses.size(); ++i) {
            String className = (String)remainingClasses.get(i);
            Instances testData = WekaUtil.getInstancesOfClass(data, className);
            logger.debug("Classify {} instances of class {}", (Object)testData.size(), (Object)className);
            int o1 = 0;
            int o2 = 0;
            for (Instance inst : testData) {
                if (Thread.interrupted()) {
                    throw new InterruptedException();
                }
                try {
                    double prediction = this.rpndClassifier.classifyInstance(WekaUtil.getRefactoredInstance(inst));
                    if (prediction == 0.0) {
                        ++o1;
                        continue;
                    }
                    ++o2;
                }
                catch (Exception e) {
                    logger.error(LoggerUtil.getExceptionInfo((Throwable)e));
                }
            }
            if (o1 > o2) {
                s1.add(className);
                continue;
            }
            s2.add(className);
        }
        ArrayList<Collection<String>> split = new ArrayList<Collection<String>>();
        split.add(s1);
        split.add(s2);
        return split;
    }
}

