/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.classification.multiclass.reduction.reducer;

import ai.libs.jaicore.logging.LoggerUtil;
import ai.libs.jaicore.ml.WekaUtil;
import ai.libs.jaicore.ml.classification.multiclass.reduction.EMCNodeType;
import ai.libs.jaicore.ml.classification.multiclass.reduction.MCTreeNode;
import ai.libs.jaicore.ml.classification.multiclass.reduction.MCTreeNodeLeaf;
import ai.libs.jaicore.ml.classification.multiclass.reduction.reducer.Decision;
import ai.libs.jaicore.ml.classification.multiclass.reduction.reducer.ReductionGraphGenerator;
import ai.libs.jaicore.ml.classification.multiclass.reduction.reducer.RestProblem;
import ai.libs.jaicore.search.algorithms.standard.bestfirst.BestFirstEpsilon;
import ai.libs.jaicore.search.core.interfaces.GraphGenerator;
import ai.libs.jaicore.search.model.other.EvaluatedSearchGraphPath;
import ai.libs.jaicore.search.probleminputs.GraphSearchWithSubpathEvaluationsInput;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.rules.OneR;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;

public class ReductionOptimizer
implements Classifier {
    private final long seed;
    private MCTreeNode root;
    private transient Logger logger = LoggerFactory.getLogger(ReductionOptimizer.class);

    public ReductionOptimizer(long seed) {
        this.seed = seed;
    }

    public void buildClassifier(Instances data) throws Exception {
        Optional bestSolution;
        EvaluatedSearchGraphPath solution;
        List<Instances> dataSplit = WekaUtil.getStratifiedSplit(data, this.seed, (double)0.6f);
        Instances train = dataSplit.get(0);
        BestFirstEpsilon search = new BestFirstEpsilon(new GraphSearchWithSubpathEvaluationsInput((GraphGenerator)new ReductionGraphGenerator(new Random(this.seed), train), n -> (double)this.getLossForClassifier(this.getTreeFromSolution(n.externalPath(), data, false), data) * 1.0), n -> (double)n.path().size() * -1.0, 0.1, false);
        int i = 0;
        ArrayList<EvaluatedSearchGraphPath> solutions = new ArrayList<EvaluatedSearchGraphPath>();
        while ((solution = (EvaluatedSearchGraphPath)search.nextSolutionCandidate()) != null) {
            solutions.add(solution);
            if (i++ <= 100) continue;
        }
        if (!(bestSolution = solutions.stream().min((s1, s2) -> ((Double)s1.getScore()).compareTo((Double)s2.getScore()))).isPresent()) {
            this.logger.error("No solution found");
            return;
        }
        this.root = this.getTreeFromSolution(((EvaluatedSearchGraphPath)bestSolution.get()).getNodes(), data, true);
        this.root.buildClassifier(data);
    }

    public double classifyInstance(Instance instance) throws Exception {
        return this.root.classifyInstance(instance);
    }

    public double[] distributionForInstance(Instance instance) throws Exception {
        return this.root.distributionForInstance(instance);
    }

    public Capabilities getCapabilities() {
        return null;
    }

    private void completeTree(MCTreeNode tree) {
        if (!tree.isCompletelyConfigured()) {
            for (MCTreeNode node : tree) {
                if (!node.getChildren().isEmpty() || node.getContainedClasses().size() == 1) continue;
                node.setNodeType(EMCNodeType.DIRECT);
                node.setBaseClassifier((Classifier)new OneR());
                Iterator iterator = node.getContainedClasses().iterator();
                while (iterator.hasNext()) {
                    int openClass = (Integer)iterator.next();
                    try {
                        node.addChild(new MCTreeNodeLeaf(openClass));
                    }
                    catch (Exception e) {
                        this.logger.error(LoggerUtil.getExceptionInfo((Throwable)e));
                    }
                }
            }
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private int getLossForClassifier(MCTreeNode tree, Instances data) {
        this.completeTree(tree);
        ReductionOptimizer reductionOptimizer = this;
        synchronized (reductionOptimizer) {
            try {
                DescriptiveStatistics stats = new DescriptiveStatistics();
                for (int i = 0; i < 2; ++i) {
                    List<Instances> split = WekaUtil.getStratifiedSplit(data, this.seed + (long)i, (double)0.6f);
                    tree.buildClassifier(split.get(0));
                    Evaluation eval = new Evaluation(data);
                    eval.evaluateModel((Classifier)tree, split.get(1), new Object[0]);
                    stats.addValue(eval.pctIncorrect());
                }
                return (int)Math.round(stats.getMean() * 100.0);
            }
            catch (Exception e) {
                this.logger.error(LoggerUtil.getExceptionInfo((Throwable)e));
                return Integer.MAX_VALUE;
            }
        }
    }

    private MCTreeNode getTreeFromSolution(List<RestProblem> solution, Instances data, boolean mustBeComplete) {
        List decisions = solution.stream().filter(n -> n.getEdgeToParent() != null).map(RestProblem::getEdgeToParent).collect(Collectors.toList());
        LinkedList<MCTreeNode> open = new LinkedList<MCTreeNode>();
        Attribute classAttribute = data.classAttribute();
        MCTreeNode localRoot = new MCTreeNode(IntStream.range(0, classAttribute.numValues()).mapToObj(i -> i).collect(Collectors.toList()));
        open.addFirst(localRoot);
        for (Decision decision : decisions) {
            boolean isCutOff;
            MCTreeNode nodeToRefine = (MCTreeNode)open.removeFirst();
            if (nodeToRefine == null) {
                throw new IllegalStateException("No node to apply the decision to! Apparently, there are more decisions for nodes than there are inner nodes.");
            }
            nodeToRefine.setNodeType(decision.getClassificationType());
            nodeToRefine.setBaseClassifier(decision.getBaseClassifier());
            boolean bl = isCutOff = decision.getLft() == null || decision.getRgt() == null;
            if (isCutOff) {
                for (Integer c : nodeToRefine.getContainedClasses()) {
                    try {
                        nodeToRefine.addChild(new MCTreeNodeLeaf(c));
                    }
                    catch (Exception e) {
                        this.logger.error(LoggerUtil.getExceptionInfo((Throwable)e));
                    }
                }
                continue;
            }
            boolean addedLeftChild = false;
            ArrayList<String> classesLft = new ArrayList<String>(decision.getLft());
            if (classesLft.size() == 1) {
                try {
                    nodeToRefine.addChild(new MCTreeNodeLeaf(classAttribute.indexOfValue((String)classesLft.get(0))));
                }
                catch (Exception e) {
                    this.logger.error(LoggerUtil.getExceptionInfo((Throwable)e));
                }
            } else {
                MCTreeNode lft = new MCTreeNode(classesLft.stream().map(arg_0 -> ((Attribute)classAttribute).indexOfValue(arg_0)).collect(Collectors.toList()));
                nodeToRefine.addChild(lft);
                addedLeftChild = true;
                open.push(lft);
            }
            ArrayList<String> classesRgt = new ArrayList<String>(decision.getRgt());
            if (classesRgt.size() == 1) {
                try {
                    nodeToRefine.addChild(new MCTreeNodeLeaf(data.classAttribute().indexOfValue((String)classesRgt.get(0))));
                }
                catch (Exception e) {
                    this.logger.error(LoggerUtil.getExceptionInfo((Throwable)e));
                }
                continue;
            }
            MCTreeNode rgt = new MCTreeNode(classesRgt.stream().map(arg_0 -> ((Attribute)classAttribute).indexOfValue(arg_0)).collect(Collectors.toList()));
            nodeToRefine.addChild(rgt);
            if (addedLeftChild) {
                MCTreeNode lft = (MCTreeNode)open.pop();
                open.push(rgt);
                open.push(lft);
                continue;
            }
            open.push(rgt);
        }
        if (mustBeComplete && !open.isEmpty()) {
            throw new IllegalStateException("Not all nodes have been equipped with decisions!");
        }
        return localRoot;
    }
}

