/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.classification.multiclass.reduction;

import ai.libs.jaicore.basic.StringUtil;
import ai.libs.jaicore.ml.WekaUtil;
import ai.libs.jaicore.ml.classification.multiclass.reduction.ConstantClassifier;
import ai.libs.jaicore.ml.classification.multiclass.reduction.MCTreeMergeNode;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.rules.ZeroR;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.WekaException;

public class MCTreeNodeReD
implements Classifier,
Serializable {
    private static final long serialVersionUID = 8873192747068561266L;
    private Classifier innerNodeClassifier;
    private final List<String> containedClasses = new ArrayList<String>();
    private List<ChildNode> children = new ArrayList<ChildNode>();
    private boolean trained = false;

    public MCTreeNodeReD(String innerNodeClassifier, Collection<String> leftChildClasses, String leftChildClassifier, Collection<String> rightChildClasses, String rightChildClassifier) throws Exception {
        this(innerNodeClassifier, leftChildClasses, AbstractClassifier.forName((String)leftChildClassifier, null), rightChildClasses, AbstractClassifier.forName((String)rightChildClassifier, null));
    }

    public MCTreeNodeReD(Classifier innerNodeClassifier, Collection<String> leftChildClasses, Classifier leftChildClassifier, Collection<String> rightChildClasses, Classifier rightChildClassifier) {
        this(innerNodeClassifier, Arrays.asList(leftChildClasses, rightChildClasses), Arrays.asList(leftChildClassifier, rightChildClassifier));
    }

    public MCTreeNodeReD(String innerNodeClassifier, Collection<String> leftChildClasses, Classifier leftChildClassifier, Collection<String> rightChildClasses, Classifier rightChildClassifier) throws Exception {
        this(AbstractClassifier.forName((String)innerNodeClassifier, (String[])new String[0]), leftChildClasses, leftChildClassifier, rightChildClasses, rightChildClassifier);
    }

    public MCTreeNodeReD(Classifier innerNodeClassifier, List<Collection<String>> childClasses, List<Classifier> childClassifier) {
        if (childClasses.size() != childClassifier.size()) {
            throw new IllegalArgumentException("Number of child classes does not equal the number of child classifiers");
        }
        this.innerNodeClassifier = innerNodeClassifier;
        for (int i = 0; i < childClasses.size(); ++i) {
            this.addChild(new ArrayList<String>(childClasses.get(i)), (Classifier)(childClasses.get(i).size() > 1 ? childClassifier.get(i) : new ConstantClassifier()));
        }
    }

    protected MCTreeNodeReD() {
    }

    public void addChild(List<String> childClasses, Classifier childClassifier) {
        assert (!this.trained) : "Cannot insert children after the tree node has been trained!";
        if (childClassifier instanceof MCTreeMergeNode) {
            this.children.addAll(((MCTreeMergeNode)childClassifier).getChildren());
        } else {
            this.children.add(new ChildNode(childClasses, childClassifier));
        }
        this.containedClasses.addAll(childClasses);
    }

    public List<ChildNode> getChildren() {
        return this.children;
    }

    public List<String> getContainedClasses() {
        return this.containedClasses;
    }

    public boolean isCompletelyConfigured() {
        if (this.innerNodeClassifier == null || this.children.isEmpty()) {
            return false;
        }
        for (ChildNode child : this.children) {
            if (!(child.childNodeClassifier instanceof MCTreeNodeReD) || ((MCTreeNodeReD)child.childNodeClassifier).isCompletelyConfigured()) continue;
            return false;
        }
        return true;
    }

    public void buildClassifier(Instances data) throws Exception {
        assert (!data.isEmpty()) : "Cannot train MCTree with empty set of instances.";
        assert (!this.children.isEmpty()) : "Cannot train MCTree without children";
        assert (!this.trained) : "Cannot retrain MCTreeNodeReD";
        assert (this.containedClasses.containsAll(WekaUtil.getClassesActuallyContainedInDataset(data))) : "The classes for which this MCTreeNodeReD has been defined (" + this.containedClasses + ") is not a superset of the given training data (" + WekaUtil.getClassesActuallyContainedInDataset(data) + ") ...";
        assert (WekaUtil.getClassesActuallyContainedInDataset(data).containsAll(this.containedClasses)) : "The classes for which this MCTreeNodeReD has been defined (" + this.containedClasses + ") is not a subset of the given training data (" + WekaUtil.getClassesActuallyContainedInDataset(data) + ") ...";
        this.containedClasses.clear();
        for (int i = 0; i < data.numClasses(); ++i) {
            this.containedClasses.add(data.classAttribute().value(i));
        }
        ArrayList<Set<String>> instancesClusters = new ArrayList<Set<String>>();
        int childNum = 0;
        for (ChildNode child : this.getChildren()) {
            ++childNum;
            assert (!child.containedClasses.isEmpty()) : "Contained classes of child must not be empty";
            Instances childData = WekaUtil.getEmptySetOfInstancesWithRefactoredClass(data, child.containedClasses);
            for (Instance i : data) {
                String className = i.classAttribute().value((int)Math.round(i.classValue()));
                if (!child.containedClasses.contains(className)) continue;
                Instance iNew = WekaUtil.getRefactoredInstance(i, child.containedClasses);
                iNew.setClassValue(className);
                iNew.setDataset(childData);
                childData.add(iNew);
            }
            assert (child.containedClasses.containsAll(WekaUtil.getClassesActuallyContainedInDataset(childData))) : "There are data for the child node that are not contained in its declaration";
            assert (WekaUtil.getClassesActuallyContainedInDataset(childData).containsAll(child.containedClasses)) : "There are classes declared in the child, but no corresponding data have been passed";
            try {
                child.childNodeClassifier.buildClassifier(childData);
            }
            catch (Throwable e) {
                throw new RuntimeException("Cannot train classifier in child #" + childNum, e);
            }
            instancesClusters.add(new HashSet(child.containedClasses));
        }
        Instances trainingData = WekaUtil.mergeClassesOfInstances(data, instancesClusters);
        try {
            this.innerNodeClassifier.buildClassifier(trainingData);
        }
        catch (WekaException e) {
            this.innerNodeClassifier = new ZeroR();
            this.innerNodeClassifier.buildClassifier(trainingData);
        }
        catch (Throwable e) {
            throw new RuntimeException("Cannot train inner classifier", e);
        }
        this.trained = true;
    }

    public double classifyInstance(Instance instance) throws Exception {
        double selection = -1.0;
        double best = 0.0;
        double[] dist = this.distributionForInstance(instance);
        for (int i = 0; i < dist.length; ++i) {
            double score = dist[i];
            if (!(score > best)) continue;
            best = score;
            selection = i;
        }
        return selection;
    }

    public double[] distributionForInstance(Instance instance) throws Exception {
        assert (this.trained) : "Cannot get distribution from untrained classifier " + this.toStringWithOffset();
        Instance refactoredInstance = WekaUtil.getRefactoredInstance(instance);
        double[] innerNodeClassifierDistribution = this.innerNodeClassifier.distributionForInstance(refactoredInstance);
        double[] classDistribution = new double[this.getContainedClasses().size()];
        for (int childIndex = 0; childIndex < this.children.size(); ++childIndex) {
            ChildNode child = this.children.get(childIndex);
            double[] childDistribution = child.childNodeClassifier.distributionForInstance(WekaUtil.getRefactoredInstance(instance, child.containedClasses));
            assert (childDistribution.length == child.containedClasses.size()) : "Mismatch of child classes (" + ChildNode.access$200(child).size() + ") and distribution in child (" + childDistribution.length + ")";
            for (int i = 0; i < childDistribution.length; ++i) {
                String classValue = (String)child.containedClasses.get(i);
                classDistribution[this.getContainedClasses().indexOf((Object)classValue)] = childDistribution[i] * innerNodeClassifierDistribution[childIndex];
            }
        }
        double sum = Arrays.stream(classDistribution).sum();
        assert (sum - 1.0E-8 <= 1.0 && sum + 1.0E-8 >= 1.0) : "Distribution does not sum up to 1; actual some of distribution entries: " + sum;
        return classDistribution;
    }

    public Capabilities getCapabilities() {
        return this.innerNodeClassifier.getCapabilities();
    }

    public int getHeight() {
        int maxHeightChildren = 0;
        for (ChildNode child : this.children) {
            if (!(child.childNodeClassifier instanceof MCTreeNodeReD)) continue;
            maxHeightChildren = Math.max(((MCTreeNodeReD)child.childNodeClassifier).getHeight(), maxHeightChildren);
        }
        return 1 + maxHeightChildren;
    }

    public int getDepthOfFirstCommonParent(List<String> classes) {
        for (ChildNode child : this.children) {
            if (!child.containedClasses.containsAll(classes)) continue;
            int depth = 1;
            if (child.childNodeClassifier instanceof MCTreeNodeReD) {
                depth += ((MCTreeNodeReD)child.childNodeClassifier).getDepthOfFirstCommonParent(classes);
            }
            return depth;
        }
        return 1;
    }

    public Classifier getClassifier() {
        return this.innerNodeClassifier;
    }

    public void setBaseClassifier(Classifier classifier) {
        assert (classifier != null) : "Cannot set null classifier!";
        this.innerNodeClassifier = classifier;
    }

    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("(");
        sb.append(this.innerNodeClassifier.getClass().getSimpleName());
        sb.append(")");
        sb.append("{");
        boolean first = true;
        for (ChildNode child : this.children) {
            if (first) {
                first = false;
            } else {
                sb.append(",");
            }
            sb.append(child);
        }
        sb.append("}");
        return sb.toString();
    }

    public String toStringWithOffset() {
        return this.toStringWithOffset("");
    }

    public String toStringWithOffset(String offset) {
        StringBuilder sb = new StringBuilder();
        sb.append(offset);
        sb.append("(");
        sb.append(this.getContainedClasses());
        sb.append(":");
        sb.append(this.innerNodeClassifier.getClass().getSimpleName());
        sb.append(") {");
        boolean first = true;
        for (ChildNode child : this.children) {
            if (first) {
                first = false;
            } else {
                sb.append(",");
            }
            sb.append("\n");
            sb.append(child.toStringWithOffset(offset + "  "));
        }
        sb.append("\n");
        sb.append(offset);
        sb.append("}");
        return sb.toString();
    }

    public MCTreeNodeReD clone() {
        try {
            Classifier lcClone = WekaUtil.cloneClassifier(this.children.get(0).childNodeClassifier);
            Classifier rcClone = WekaUtil.cloneClassifier(this.children.get(1).childNodeClassifier);
            return new MCTreeNodeReD(this.innerNodeClassifier.getClass().getName(), (Collection<String>)this.children.get(0).containedClasses, lcClone, (Collection<String>)this.children.get(1).containedClasses, rcClone);
        }
        catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }

    private class ChildNode {
        private List<String> containedClasses;
        private Classifier childNodeClassifier;

        private ChildNode(List<String> containedClasses, Classifier childNodeClassifier) {
            this.containedClasses = containedClasses;
            this.childNodeClassifier = childNodeClassifier;
        }

        public String toString() {
            StringBuilder sb = new StringBuilder();
            if (this.childNodeClassifier instanceof MCTreeNodeReD) {
                sb.append(this.childNodeClassifier.toString());
            } else {
                sb.append(this.childNodeClassifier.getClass().getSimpleName() + "(");
                sb.append(StringUtil.implode(this.containedClasses, (String)","));
                sb.append(")");
            }
            return sb.toString();
        }

        public String toStringWithOffset(String offset) {
            StringBuilder sb = new StringBuilder();
            if (this.childNodeClassifier instanceof MCTreeNodeReD) {
                sb.append(((MCTreeNodeReD)this.childNodeClassifier).toStringWithOffset(offset + "\t"));
            } else {
                sb.append(offset);
                sb.append("(");
                sb.append(this.containedClasses);
                sb.append(":");
                sb.append(this.childNodeClassifier.getClass().getSimpleName());
                sb.append(")");
            }
            return sb.toString();
        }
    }
}

