/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.weka.classification.learner.reduction;

import ai.libs.jaicore.ml.weka.WekaUtil;
import ai.libs.jaicore.ml.weka.classification.learner.reduction.AMCTreeNode;
import ai.libs.jaicore.ml.weka.classification.learner.reduction.EMCNodeType;
import ai.libs.jaicore.ml.weka.classification.learner.reduction.ITreeClassifier;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.commons.lang3.builder.HashCodeBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.meta.MultiClassClassifier;
import weka.classifiers.rules.ZeroR;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.WekaException;

public class MCTreeNode
extends AMCTreeNode<Integer>
implements ITreeClassifier,
Iterable<MCTreeNode> {
    private static final long serialVersionUID = 8873192747068561266L;
    private EMCNodeType nodeType;
    private List<MCTreeNode> children = new ArrayList<MCTreeNode>();
    private Classifier classifier;
    private String classifierID;
    private boolean trained = false;
    private transient Logger logger = LoggerFactory.getLogger(MCTreeNode.class);
    public static final AtomicInteger cacheRetrievals = new AtomicInteger();
    private static Map<String, Classifier> classifierCacheMap = new HashMap<String, Classifier>();
    private static Lock classifierCacheMapLock = new ReentrantLock();

    public MCTreeNode(List<Integer> containedClasses) {
        super(containedClasses);
    }

    public MCTreeNode(List<Integer> containedClasses, EMCNodeType nodeType, String classifierID) throws Exception {
        this(containedClasses, nodeType, AbstractClassifier.forName((String)classifierID, null));
    }

    public MCTreeNode(List<Integer> containedClasses, EMCNodeType nodeType, Classifier baseClassifier) {
        this(containedClasses);
        this.setNodeType(nodeType);
        this.setBaseClassifier(baseClassifier);
    }

    public EMCNodeType getNodeType() {
        return this.nodeType;
    }

    public void addChild(MCTreeNode newNode) {
        if (newNode.getNodeType() == EMCNodeType.MERGE) {
            for (MCTreeNode child : newNode.getChildren()) {
                this.children.add(child);
            }
        } else {
            this.children.add(newNode);
        }
    }

    public List<MCTreeNode> getChildren() {
        return this.children;
    }

    public boolean isCompletelyConfigured() {
        if (this.classifier == null) {
            return false;
        }
        if (this.children.isEmpty()) {
            return false;
        }
        for (MCTreeNode child : this.children) {
            if (child.isCompletelyConfigured()) continue;
            return false;
        }
        return true;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public void buildClassifier(Instances data) throws Exception {
        assert (this.getNodeType() != EMCNodeType.MERGE) : "MERGE node detected while building classifier. This must not happen!";
        if (data.isEmpty()) {
            throw new IllegalArgumentException("Cannot train MCTree with empty set of instances.");
        }
        if (this.children.isEmpty()) {
            throw new IllegalStateException("Cannot train MCTree without children");
        }
        ArrayList<Set<String>> instancesCluster = new ArrayList<Set<String>>();
        IntStream.range(0, this.children.size()).forEach((int x) -> instancesCluster.add(new HashSet()));
        int index = 0;
        for (MCTreeNode child2 : this.children) {
            Iterator iterator = child2.getContainedClasses().iterator();
            while (iterator.hasNext()) {
                int classIndex = (Integer)iterator.next();
                ((Set)instancesCluster.get(index)).add(data.classAttribute().value(classIndex));
            }
            ++index;
        }
        String classifierKey = this.classifier.getClass().getName() + "#" + instancesCluster + "#" + data.size() + "#" + new HashCodeBuilder().append((Object)data.toString()).toHashCode();
        Instances trainingData = WekaUtil.mergeClassesOfInstances(data, instancesCluster);
        try {
            this.classifier.buildClassifier(trainingData);
        }
        catch (WekaException e) {
            this.classifier = new ZeroR();
            this.classifier.buildClassifier(trainingData);
        }
        classifierCacheMapLock.lock();
        try {
            classifierCacheMap.put(classifierKey, this.classifier);
        }
        finally {
            classifierCacheMapLock.unlock();
        }
        ((Stream)this.children.stream().parallel()).forEach(child -> {
            try {
                child.buildClassifier(data);
            }
            catch (Exception e) {
                this.logger.error("Encountered problem when training MCTreeNode.", (Throwable)e);
            }
        });
        this.trained = true;
    }

    public void distributionForInstance(Instance instance, double[] distribution) throws Exception {
        Instance iNew = WekaUtil.getRefactoredInstance(instance, IntStream.range(0, this.children.size()).mapToObj(x -> x + ".0").collect(Collectors.toList()));
        double[] localDistribution = this.classifier.distributionForInstance(iNew);
        for (MCTreeNode child : this.children) {
            child.distributionForInstance(instance, distribution);
            int indexOfChild = this.children.indexOf(child);
            Iterator iterator = child.getContainedClasses().iterator();
            while (iterator.hasNext()) {
                int classContainedInChild;
                int n = classContainedInChild = ((Integer)iterator.next()).intValue();
                distribution[n] = distribution[n] * localDistribution[indexOfChild];
            }
        }
    }

    public double[] distributionForInstance(Instance instance) throws Exception {
        if (!this.trained) {
            throw new IllegalStateException("Cannot get distribution from untrained classifier " + this.toStringWithOffset());
        }
        double[] classDistribution = new double[this.getContainedClasses().size()];
        this.distributionForInstance(instance, classDistribution);
        return classDistribution;
    }

    public Capabilities getCapabilities() {
        return this.classifier.getCapabilities();
    }

    @Override
    public int getHeight() {
        return 1 + this.children.stream().map(MCTreeNode::getHeight).mapToInt(Integer.TYPE::cast).max().getAsInt();
    }

    @Override
    public int getDepthOfFirstCommonParent(List<Integer> classes) {
        for (MCTreeNode child : this.children) {
            if (!child.getContainedClasses().containsAll(classes)) continue;
            return 1 + child.getDepthOfFirstCommonParent(classes);
        }
        return 1;
    }

    public static void clearCache() {
        classifierCacheMap.clear();
    }

    public static Map<String, Classifier> getClassifierCache() {
        return classifierCacheMap;
    }

    public Classifier getClassifier() {
        return this.classifier;
    }

    public void setBaseClassifier(Classifier classifier) {
        if (classifier == null) {
            throw new IllegalArgumentException("Cannot set null classifier!");
        }
        this.classifierID = classifier.getClass().getName();
        switch (this.nodeType) {
            case ONEVSREST: {
                MultiClassClassifier oneVsRestMCC = new MultiClassClassifier();
                oneVsRestMCC.setClassifier(classifier);
                this.classifier = oneVsRestMCC;
                break;
            }
            case ALLPAIRS: {
                MultiClassClassifier allPairsMCC = new MultiClassClassifier();
                try {
                    allPairsMCC.setOptions(new String[]{"-M", "3"});
                }
                catch (Exception e) {
                    this.logger.error("Observed problem when setting options for classifier.", (Throwable)e);
                }
                allPairsMCC.setClassifier(classifier);
                this.classifier = allPairsMCC;
                break;
            }
            case DIRECT: {
                this.classifier = classifier;
                break;
            }
        }
    }

    public void setNodeType(EMCNodeType nodeType) {
        this.nodeType = nodeType;
    }

    public String toString() {
        return this.toStringWithOffset("", null);
    }

    public String toStringWithOffset() {
        return this.toStringWithOffset("", "  ");
    }

    public String toStringWithOffset(String offset, String indent) {
        StringBuilder sb = new StringBuilder();
        sb.append(offset).append("(").append(this.getContainedClasses()).append(":").append(this.classifierID).append(":").append((Object)this.nodeType).append(") {");
        boolean first = true;
        for (MCTreeNode child : this.children) {
            if (first) {
                first = false;
            } else {
                sb.append(",");
            }
            if (indent != null) {
                sb.append("\n");
            }
            sb.append(child.toStringWithOffset(offset + (indent != null ? indent : ""), indent));
        }
        if (indent != null) {
            sb.append("\n").append(offset);
        }
        sb.append("}");
        return sb.toString();
    }

    @Override
    public Iterator<MCTreeNode> iterator() {
        return new Iterator<MCTreeNode>(){
            private int currentlyTraversedChild = -1;
            private Iterator<MCTreeNode> childIterator = null;

            @Override
            public boolean hasNext() {
                if (this.currentlyTraversedChild < 0) {
                    return true;
                }
                if (MCTreeNode.this.children.isEmpty()) {
                    return false;
                }
                if (this.childIterator == null) {
                    this.childIterator = ((MCTreeNode)MCTreeNode.this.children.get(this.currentlyTraversedChild)).iterator();
                }
                if (this.childIterator.hasNext()) {
                    return true;
                }
                if (this.currentlyTraversedChild == MCTreeNode.this.children.size() - 1) {
                    return false;
                }
                ++this.currentlyTraversedChild;
                this.childIterator = ((MCTreeNode)MCTreeNode.this.children.get(this.currentlyTraversedChild)).iterator();
                return this.childIterator.hasNext();
            }

            @Override
            public MCTreeNode next() {
                if (this.currentlyTraversedChild == -1) {
                    ++this.currentlyTraversedChild;
                    return MCTreeNode.this;
                }
                return this.childIterator.next();
            }
        };
    }
}

