/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.classification.multilabel.learner.homer;

import ai.libs.jaicore.basic.ArrayUtil;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.stream.Collectors;
import meka.classifiers.multilabel.AbstractMultiLabelClassifier;
import meka.classifiers.multilabel.BR;
import meka.classifiers.multilabel.MultiLabelClassifier;
import meka.core.F;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.core.Instance;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Add;

/**
 * Inner node of a HOMER-style label hierarchy.
 *
 * <p>Each child node covers a set of label indices (see {@link #getLabels()}). This node trains one
 * multi-label base learner (a {@link BR} by default) whose targets are one binary "meta-label" per
 * child. A meta-label is active for an instance when any of that child's labels is set. At
 * prediction time, the meta-label predictions are combined with the children's own predictions.
 *
 * <p>NOTE(review): label indices are used directly as attribute indices and as positions in an array
 * of length {@code classIndex()}. This presumably relies on MEKA's layout, where the first
 * {@code classIndex()} attributes are the labels — confirm against callers.
 */
public class HOMERNode extends AbstractMultiLabelClassifier {
    private static final long serialVersionUID = -2634579245812714183L;
    private static final Logger LOGGER = LoggerFactory.getLogger(HOMERNode.class);
    // Currently not referenced anywhere in this class.
    private static final boolean HIERARCHICAL_STRING = false;
    /** Cut-off used to binarize the base learner's output when thresholding is enabled. */
    private static final double THRESHOLD = 0.5;
    private List<HOMERNode> children;
    private MultiLabelClassifier baselearner;
    private String baselearnerName;
    private boolean doThreshold = false;

    /**
     * Creates an inner node over the given child nodes.
     *
     * @param nodes the child nodes (see {@link #HOMERNode(List)})
     */
    public HOMERNode(HOMERNode... nodes) {
        this(Arrays.asList(nodes));
    }

    /**
     * Creates an inner node over the given child nodes.
     *
     * <p>The children are copied into a new list, so the caller's list is neither mutated nor required
     * to be modifiable. The copy is ordered by the smallest label index each child covers. The base
     * learner defaults to a fresh {@link BR}.
     *
     * @param nodes the child nodes; each child is expected to cover at least one label
     * @throws IllegalStateException if ordering has to compare a child that covers no labels
     */
    public HOMERNode(List<HOMERNode> nodes) {
        this.children = new ArrayList<>(nodes);
        this.children.sort((first, second) -> Integer.compare(smallestLabel(first), smallestLabel(second)));
        this.baselearner = new BR();
    }

    /**
     * Returns the smallest label index covered by {@code node}; used to order sibling nodes.
     *
     * @throws IllegalStateException if the node covers no labels
     */
    private static int smallestLabel(HOMERNode node) {
        Collection<Integer> labels = node.getLabels();
        if (labels.isEmpty()) {
            throw new IllegalStateException("Cannot order a HOMER child node that covers no labels.");
        }
        return Collections.min(labels);
    }

    /**
     * Enables or disables binarizing the base learner's meta-label output at {@value #THRESHOLD}.
     * This only affects this node; children keep their own setting.
     */
    public void setThreshold(boolean doThreshold) {
        this.doThreshold = doThreshold;
    }

    /** Replaces the base learner used for this node's meta-label problem. */
    public void setBaselearner(MultiLabelClassifier baselearner) {
        this.baselearner = baselearner;
    }

    /** Returns the base learner name; it is only used for debug logging in this class. */
    public String getBaselearnerName() {
        return this.baselearnerName;
    }

    /** Sets the base learner name reported in debug logging; does not change the base learner itself. */
    public void setBaselearnerName(String baselearnerName) {
        this.baselearnerName = baselearnerName;
    }

    /** Returns this node's (sorted) child list; the returned list is the live internal list. */
    public List<HOMERNode> getChildren() {
        return this.children;
    }

    /**
     * Returns the union of all label indices covered by this node's children.
     *
     * <p>A fresh set is built recursively on every call.
     */
    public Collection<Integer> getLabels() {
        HashSet<Integer> labels = new HashSet<Integer>();
        for (HOMERNode child : this.children) {
            labels.addAll(child.getLabels());
        }
        return labels;
    }

    /**
     * Trains this node's base learner on the meta-label problem, then recursively trains every child
     * that covers more than one label.
     *
     * <p>Instances for which no meta-label is active are left out of this node's training set. Each
     * child receives the full {@code trainingSet} and performs its own filtering.
     *
     * @param trainingSet the multi-label training data
     * @throws Exception if data preparation or any base learner fails
     */
    public void buildClassifier(Instances trainingSet) throws Exception {
        LOGGER.debug("Build node with {} as a base learner", this.baselearnerName);
        Instances currentDataset = this.prepareInstances(trainingSet);
        // getLabels() rebuilds its set recursively, so resolve each child's labels once up front.
        List<Collection<Integer>> childLabels = this.children.stream().map(HOMERNode::getLabels).collect(Collectors.toList());
        // Walk backwards so removing row i never shifts rows that are still to be visited.
        for (int i = trainingSet.size() - 1; i >= 0; --i) {
            Instance source = trainingSet.get(i);
            Instance target = currentDataset.get(i);
            boolean anyMetaLabelActive = false;
            for (int j = 0; j < childLabels.size(); ++j) {
                boolean active = childLabels.get(j).stream().mapToDouble(label -> source.value(label.intValue())).sum() > 0.0;
                target.setValue(j, active ? 1.0 : 0.0);
                anyMetaLabelActive |= active;
            }
            if (!anyMetaLabelActive) {
                currentDataset.remove(i);
            }
        }
        this.baselearner.buildClassifier(currentDataset);
        for (int j = 0; j < this.children.size(); ++j) {
            // Children covering at most one label are answered directly and need no model.
            if (childLabels.get(j).size() > 1) {
                this.children.get(j).buildClassifier(trainingSet);
            }
        }
    }

    /**
     * Predicts the label vector for {@code testInstance}.
     *
     * <p>The result has length {@code testInstance.classIndex()} and is indexed by label. Only the labels
     * covered by this node's children are written.
     *
     * <p>Without thresholding, a single-label child receives its meta-label score directly. For a larger
     * child, its own prediction for each covered label is multiplied by the child's meta-label score.
     *
     * <p>With thresholding, children whose meta-label falls below {@value #THRESHOLD} are skipped. A
     * single-label child that passes receives 1.0; a larger child's prediction is added in.
     *
     * @param testInstance the instance to classify; it is copied, not modified
     * @return per-label scores
     * @throws Exception if data preparation or any base learner fails
     */
    public double[] distributionForInstance(Instance testInstance) throws Exception {
        // Wrap a copy of the instance in a one-row dataset so it can be run through prepareInstances().
        Instances singleton = new Instances(testInstance.dataset(), 0);
        singleton.add(testInstance.copy(testInstance.toDoubleArray()));
        Instance prepared = this.prepareInstances(singleton).get(0);

        double[] metaDist = this.baselearner.distributionForInstance(prepared);
        int[] metaDecisions = this.doThreshold ? ArrayUtil.thresholdDoubleToBinaryArray(metaDist, THRESHOLD) : null;
        int length = this.doThreshold ? metaDecisions.length : metaDist.length;

        double[] returnDist = new double[testInstance.classIndex()];
        for (int i = 0; i < length; ++i) {
            if (this.doThreshold && metaDecisions[i] != 1) {
                continue;
            }
            HOMERNode child = this.children.get(i);
            Collection<Integer> childLabels = child.getLabels();
            if (childLabels.size() == 1) {
                returnDist[childLabels.iterator().next()] = this.doThreshold ? 1.0 : metaDist[i];
            } else if (this.doThreshold) {
                ArrayUtil.add(returnDist, child.distributionForInstance(testInstance));
            } else {
                double[] childDist = child.distributionForInstance(testInstance);
                for (int label : childLabels) {
                    returnDist[label] = childDist[label] * metaDist[i];
                }
            }
        }
        return returnDist;
    }

    /**
     * Builds this node's meta-label view of {@code dataset}.
     *
     * <p>The original label attributes are passed to {@code F.keepLabels} with an empty keep-list.
     * One nominal {0,1} attribute per child is then inserted at the front, in child order. Each such
     * attribute is named after the child's label attributes, joined with "&amp;". The class index is
     * set to the number of children.
     *
     * @param dataset the multi-label dataset; it is not modified
     * @return a new dataset whose first {@code children.size()} attributes are the meta-labels
     * @throws Exception if a filter step fails
     */
    public Instances prepareInstances(Instances dataset) throws Exception {
        Instances currentDataset = F.keepLabels(dataset, dataset.classIndex(), new int[0]);
        // Every attribute is inserted at "first", so walking children backwards leaves them in child order.
        for (int i = this.children.size() - 1; i >= 0; --i) {
            Collection<Integer> labels = this.children.get(i).getLabels();
            Add add = new Add();
            add.setAttributeName(labels.stream().map(x -> dataset.attribute(x.intValue()).name()).collect(Collectors.joining("&")));
            add.setAttributeIndex("first");
            add.setNominalLabels("0,1");
            add.setInputFormat(currentDataset);
            currentDataset = Filter.useFilter(currentDataset, add);
        }
        currentDataset.setClassIndex(this.children.size());
        return currentDataset;
    }

    /** Returns {@code false}: this node always has children of its own. */
    public boolean isLeaf() {
        return false;
    }

    /**
     * Renders the subtree as {@code ShortLearnerName(child1,child2,...)}.
     *
     * <p>The learner name is the part after the last '.' of the base learner's second option. That
     * option presumably holds the wrapped classifier's class name — confirm for non-BR learners. When
     * fewer than two options are reported, the base learner's own class name is used instead of failing.
     */
    public String toString() {
        String[] options = this.baselearner.getOptions();
        String qualifiedName = options.length > 1 ? options[1] : this.baselearner.getClass().getName();
        StringBuilder sb = new StringBuilder();
        sb.append(qualifiedName.substring(qualifiedName.lastIndexOf('.') + 1));
        sb.append("(");
        sb.append(this.children.stream().map(HOMERNode::toString).collect(Collectors.joining(",")));
        sb.append(")");
        return sb.toString();
    }
}

