/*
 * Decompiled with CFR 0.152.
 */
package mulan.classifier.meta;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import mulan.classifier.MultiLabelLearner;
import mulan.classifier.MultiLabelOutput;
import mulan.classifier.meta.MultiLabelMetaLearner;
import mulan.classifier.transformation.BinaryRelevance;
import mulan.classifier.transformation.LabelPowerset;
import mulan.core.ArgumentNullException;
import mulan.data.ConditionalDependenceIdentifier;
import mulan.data.GreedyLabelClustering;
import mulan.data.LabelClustering;
import mulan.data.MultiLabelInstances;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.meta.FilteredClassifier;
import weka.classifiers.trees.J48;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.TechnicalInformation;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

/**
 * Multi-label learner that splits the labels into subsets and trains one model
 * per subset. Subsets containing more than one label are learned by a copy of
 * the multi-label base learner; single-label subsets are learned by a copy of
 * a single-label Weka classifier wrapped in a {@link FilteredClassifier}. The
 * per-subset predictions are merged into one bipartition and one confidence
 * vector.
 * <p>
 * The subsets are either supplied explicitly or determined at build time by a
 * {@link LabelClustering} method. Learned models can optionally be stored in
 * static caches (see {@link #setUseCache(boolean)}) that are shared by all
 * instances of this class and keyed by the subset and a hash of the training
 * data. These caches are plain {@link HashMap}s, so the class is not
 * thread-safe.
 */
public class SubsetLearner
extends MultiLabelMetaLearner {
    /** Multi-label models, one per subset with more than one label, in split order. */
    private ArrayList<MultiLabelLearner> multiLabelLearners;
    /** Single-label models, one per single-label subset, in split order. */
    private ArrayList<FilteredClassifier> singleLabelLearners;
    /** Label subsets; entries are label positions (indices into labelIndices), not attribute indices. */
    private int[][] splitOrder;
    /** Per subset, the absolute attribute indices of every label outside that subset. */
    private int[][] absoluteIndicesToRemove;
    /** Per subset, the filter that strips the labels outside that subset. */
    private Remove[] remove;
    /** Single-label classifier that is copied for every single-label subset. */
    protected Classifier baseSingleLabelClassifier;
    /** Whether learned models are stored in, and reused from, the static caches. */
    private boolean useCache = false;
    /** Optional method that determines the label subsets at build time. */
    private LabelClustering clusterer = null;
    /** Cached multi-label models, keyed by {@link #createKey(int[], int)}. */
    private static HashMap<String, MultiLabelLearner> existingMultiLabelModels = new HashMap<String, MultiLabelLearner>();
    /** Cached single-label models, keyed by {@link #createKey(int[], int)}. */
    private static HashMap<String, FilteredClassifier> existingSingleLabelModels = new HashMap<String, FilteredClassifier>();
    /** Cached Remove filters belonging to the cached models, same keys. */
    private static HashMap<String, Remove> existingRemove = new HashMap<String, Remove>();

    /**
     * Creates a learner whose subsets are found by {@link GreedyLabelClustering},
     * with {@link BinaryRelevance} over {@link J48} for multi-label subsets and
     * {@link J48} for single-label subsets.
     */
    public SubsetLearner() {
        this(new GreedyLabelClustering(new BinaryRelevance((Classifier)new J48()), (Classifier)new J48(), new ConditionalDependenceIdentifier((Classifier)new J48())), (MultiLabelLearner)new BinaryRelevance((Classifier)new J48()), (Classifier)new J48());
    }

    /**
     * Creates a learner for fixed label subsets that uses {@link LabelPowerset}
     * over {@code singleLabelClassifier} for multi-label subsets.
     *
     * @param labelsSubsets label subsets (label positions); sorted in place at build time
     * @param singleLabelClassifier classifier used for single-label subsets
     * @throws ArgumentNullException if either argument is null
     */
    public SubsetLearner(int[][] labelsSubsets, Classifier singleLabelClassifier) {
        this(labelsSubsets, new LabelPowerset(singleLabelClassifier), singleLabelClassifier);
    }

    /**
     * Creates a learner for fixed label subsets.
     *
     * @param labelsSubsets label subsets (label positions); sorted in place at build time
     * @param multiLabelLearner learner copied for every subset with several labels
     * @param singleLabelClassifier classifier copied for every single-label subset
     * @throws ArgumentNullException if {@code labelsSubsets} or {@code singleLabelClassifier} is null
     */
    public SubsetLearner(int[][] labelsSubsets, MultiLabelLearner multiLabelLearner, Classifier singleLabelClassifier) {
        super(multiLabelLearner);
        if (singleLabelClassifier == null) {
            throw new ArgumentNullException("singleLabelClassifier");
        }
        if (labelsSubsets == null) {
            throw new ArgumentNullException("labelsSubsets");
        }
        this.baseSingleLabelClassifier = singleLabelClassifier;
        this.splitOrder = labelsSubsets;
        this.absoluteIndicesToRemove = new int[this.splitOrder.length][];
    }

    /**
     * Creates a learner whose label subsets are determined at build time.
     *
     * @param clusteringMethod method that computes the label subsets from the training data
     * @param multiLabelLearner learner copied for every subset with several labels
     * @param singleLabelClassifier classifier copied for every single-label subset
     * @throws ArgumentNullException if {@code clusteringMethod} or {@code singleLabelClassifier} is null
     */
    public SubsetLearner(LabelClustering clusteringMethod, MultiLabelLearner multiLabelLearner, Classifier singleLabelClassifier) {
        super(multiLabelLearner);
        if (clusteringMethod == null) {
            throw new ArgumentNullException("clusteringMethod");
        }
        if (singleLabelClassifier == null) {
            throw new ArgumentNullException("singleLabelClassifier");
        }
        this.baseSingleLabelClassifier = singleLabelClassifier;
        this.clusterer = clusteringMethod;
    }

    /**
     * Replaces the label subsets used by the next build. If a clustering method
     * was supplied, it overrides these subsets at build time.
     *
     * @param labelsSubsets new label subsets (label positions)
     * @throws ArgumentNullException if {@code labelsSubsets} is null
     */
    public void resetSubsets(int[][] labelsSubsets) {
        if (labelsSubsets == null) {
            throw new ArgumentNullException("labelsSubsets");
        }
        this.splitOrder = labelsSubsets;
        this.absoluteIndicesToRemove = new int[this.splitOrder.length][];
    }

    @Override
    protected void buildInternal(MultiLabelInstances trainingSet) throws Exception {
        if (this.clusterer != null) {
            this.splitOrder = this.clusterer.determineClusters(trainingSet);
            this.absoluteIndicesToRemove = new int[this.splitOrder.length][];
        }
        this.remove = new Remove[this.splitOrder.length];
        this.prepareIndicesToRemove();
        this.multiLabelLearners = new ArrayList<MultiLabelLearner>();
        this.singleLabelLearners = new ArrayList<FilteredClassifier>();
        // Serialising the whole data set is expensive and the result is the same
        // for every split, so compute it once and only when keys are needed.
        int foldHash = this.useCache ? trainingSet.getDataSet().toString().hashCode() : 0;
        for (int totalSplitNo = 0; totalSplitNo < this.splitOrder.length; ++totalSplitNo) {
            // Presumably keeps label order inside a subset aligned with the order of
            // the label attributes left by Remove, which makePredictionInternal relies
            // on when mapping subset outputs back to label positions.
            Arrays.sort(this.splitOrder[totalSplitNo]);
            String modelKey = this.useCache ? this.createKey(this.splitOrder[totalSplitNo], foldHash) : null;
            if (this.splitOrder[totalSplitNo].length > 1) {
                this.buildMultiLabelModel(trainingSet, totalSplitNo, modelKey);
            } else {
                this.buildSingleLabelModel(trainingSet, totalSplitNo, modelKey);
            }
        }
    }

    /**
     * Fills {@link #absoluteIndicesToRemove}: for every subset, the attribute
     * indices of all labels not contained in it, in ascending label order.
     * Duplicate labels within a subset are counted once.
     */
    private void prepareIndicesToRemove() {
        for (int i = 0; i < this.splitOrder.length; ++i) {
            boolean[] selected = new boolean[this.numLabels];
            int selectedCount = 0;
            for (int label : this.splitOrder[i]) {
                if (!selected[label]) {
                    selected[label] = true;
                    ++selectedCount;
                }
            }
            int[] toRemove = new int[this.numLabels - selectedCount];
            int k = 0;
            for (int j = 0; j < this.numLabels; ++j) {
                if (!selected[j]) {
                    toRemove[k++] = this.labelIndices[j];
                }
            }
            this.absoluteIndicesToRemove[i] = toRemove;
        }
    }

    /**
     * Creates the Remove filter that drops the labels outside subset
     * {@code splitNo}, with its input format set to {@code format}.
     */
    private Remove createRemoveFilter(int splitNo, Instances format) throws Exception {
        Remove filter = new Remove();
        filter.setAttributeIndicesArray(this.absoluteIndicesToRemove[splitNo]);
        filter.setInputFormat(format);
        // Kept after setInputFormat to preserve the original configuration order.
        filter.setInvertSelection(false);
        return filter;
    }

    /**
     * Adds a multi-label model for subset {@code totalSplitNo}. It is either a
     * copy of a cached model or a copy of the base learner trained on the data
     * restricted to the subset's labels.
     */
    private void buildMultiLabelModel(MultiLabelInstances trainingSet, int totalSplitNo, String modelKey) throws Exception {
        if (this.useCache && existingMultiLabelModels.containsKey(modelKey)) {
            MultiLabelLearner model = existingMultiLabelModels.get(modelKey);
            this.resetRandomSeed(model);
            this.multiLabelLearners.add(model.makeCopy());
            this.remove[totalSplitNo] = existingRemove.get(modelKey);
            return;
        }
        Instances data = trainingSet.getDataSet();
        Remove filter = this.createRemoveFilter(totalSplitNo, data);
        this.remove[totalSplitNo] = filter;
        Instances trainSubset = Filter.useFilter(data, filter);
        MultiLabelLearner learner = this.baseLearner.makeCopy();
        this.multiLabelLearners.add(learner);
        learner.build(trainingSet.reintegrateModifiedDataSet(trainSubset));
        if (this.useCache) {
            existingMultiLabelModels.put(modelKey, learner);
            existingRemove.put(modelKey, filter);
        }
    }

    /**
     * Adds a single-label model for subset {@code totalSplitNo}. It is either a
     * cached model or a {@link FilteredClassifier} that removes the other labels
     * and trains a copy of the base single-label classifier, with the subset's
     * label as class.
     */
    private void buildSingleLabelModel(MultiLabelInstances trainingSet, int totalSplitNo, String modelKey) throws Exception {
        if (this.useCache && existingSingleLabelModels.containsKey(modelKey)) {
            FilteredClassifier model = existingSingleLabelModels.get(modelKey);
            this.resetRandomSeed(model.getClassifier());
            this.singleLabelLearners.add(model);
            this.remove[totalSplitNo] = existingRemove.get(modelKey);
            return;
        }
        FilteredClassifier model = new FilteredClassifier();
        this.singleLabelLearners.add(model);
        model.setClassifier(AbstractClassifier.makeCopy(this.baseSingleLabelClassifier));
        Instances trainSubset = trainingSet.getDataSet();
        Remove filter = this.createRemoveFilter(totalSplitNo, trainSubset);
        this.remove[totalSplitNo] = filter;
        model.setFilter(filter);
        // The class index is set on the caller's data set only for training and is
        // restored afterwards so the caller's data is not left modified.
        int previousClassIndex = trainSubset.classIndex();
        trainSubset.setClassIndex(this.labelIndices[this.splitOrder[totalSplitNo][0]]);
        try {
            model.buildClassifier(trainSubset);
        } finally {
            trainSubset.setClassIndex(previousClassIndex);
        }
        if (this.useCache) {
            existingSingleLabelModels.put(modelKey, model);
            existingRemove.put(modelKey, filter);
        }
    }

    /**
     * Builds a cache key such as {@code "_0_3_5_<fold>"} from the subset's label
     * positions followed by the data fingerprint.
     */
    private String createKey(int[] set, int fold) {
        StringBuilder sb = new StringBuilder("_");
        for (int i : set) {
            sb.append(i);
            sb.append("_");
        }
        sb.append(fold);
        return sb.toString();
    }

    /**
     * Sets the random seed of {@code model} to 1 by reflectively invoking a
     * public {@code setSeed(int)} method or, if absent, {@code setRandomSeed(int)}.
     * Reflection failures are reported through {@code debug} and otherwise ignored.
     *
     * @param model object whose seed should be reset
     */
    public void resetRandomSeed(Object model) {
        Class<?> aClass = model.getClass();
        Method method = null;
        try {
            method = aClass.getMethod("setSeed", Integer.TYPE);
        }
        catch (NoSuchMethodException e) {
            try {
                method = aClass.getMethod("setRandomSeed", Integer.TYPE);
            }
            catch (NoSuchMethodException e2) {
                this.debug("NoSuchMethodExceptions: " + e.getMessage() + " and " + e2.getMessage());
            }
        }
        if (method == null) {
            return;
        }
        try {
            method.invoke(model, 1);
        }
        catch (IllegalAccessException e) {
            this.debug("IllegalAccessException: " + e.getMessage());
        }
        catch (InvocationTargetException e) {
            // getMessage() of InvocationTargetException is usually null; the cause carries the detail.
            this.debug("InvocationTargetException: " + e.getCause());
        }
    }

    /**
     * Resets the random seed of every built model (see {@link #resetRandomSeed(Object)}).
     * Does nothing for models that have not been built yet.
     */
    public void setSeed() {
        if (this.multiLabelLearners != null) {
            for (MultiLabelLearner multiLabelLearner : this.multiLabelLearners) {
                this.resetRandomSeed(multiLabelLearner);
            }
        }
        if (this.singleLabelLearners != null) {
            for (FilteredClassifier filteredClassifier : this.singleLabelLearners) {
                this.resetRandomSeed(filteredClassifier);
            }
        }
    }

    /**
     * Predicts every subset with its model and writes each subset's bipartition
     * and confidences back to the corresponding label positions. Exceptions
     * thrown by the underlying models propagate to the caller.
     */
    @Override
    public MultiLabelOutput makePredictionInternal(Instance instance) throws Exception {
        boolean[] bipartitionOut = new boolean[this.numLabels];
        double[] confidenceOut = new double[this.numLabels];
        int singleSplitNo = 0;
        int multiSplitNo = 0;
        for (int i = 0; i < this.splitOrder.length; ++i) {
            int[] subset = this.splitOrder[i];
            boolean[] subsetBipartition;
            double[] subsetConfidences;
            if (subset.length == 1) {
                FilteredClassifier model = this.singleLabelLearners.get(singleSplitNo);
                ++singleSplitNo;
                double[] distribution = model.distributionForInstance(instance);
                int maxIndex = distribution[0] > distribution[1] ? 0 : 1;
                Attribute classAttribute = model.getFilter().getOutputFormat().classAttribute();
                subsetBipartition = new boolean[] { classAttribute.value(maxIndex).equals("1") };
                subsetConfidences = new double[] { distribution[classAttribute.indexOfValue("1")] };
            } else {
                this.remove[i].input(instance);
                this.remove[i].batchFinished();
                Instance newInstance = this.remove[i].output();
                MultiLabelOutput output = this.multiLabelLearners.get(multiSplitNo).makePrediction(newInstance);
                ++multiSplitNo;
                subsetBipartition = output.getBipartition();
                subsetConfidences = output.getConfidences();
            }
            for (int j = 0; j < subset.length; ++j) {
                bipartitionOut[subset[j]] = subsetBipartition[j];
                confidenceOut[subset[j]] = subsetConfidences[j];
            }
        }
        return new MultiLabelOutput(bipartitionOut, confidenceOut);
    }

    /**
     * Enables or disables storing and reusing learned models in the static caches.
     *
     * @param useCache true to use the caches on subsequent builds
     */
    public void setUseCache(boolean useCache) {
        this.useCache = useCache;
    }

    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Lena Tenenboim, Lior Rokach, and Bracha Shapira");
        result.setValue(TechnicalInformation.Field.TITLE, "Multi-label Classification by Analyzing Labels Dependencies");
        result.setValue(TechnicalInformation.Field.VOLUME, "Proc. ECML/PKDD 2009 Workshop on Learning from Multi-Label Data (MLD'09)");
        result.setValue(TechnicalInformation.Field.YEAR, "2009");
        result.setValue(TechnicalInformation.Field.PAGES, "117--132");
        result.setValue(TechnicalInformation.Field.ADDRESS, "Bled, Slovenia");
        TechnicalInformation result2 = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        result2.setValue(TechnicalInformation.Field.AUTHOR, "Lena Tenenboim-Chekina, Lior Rokach, and Bracha Shapira");
        result2.setValue(TechnicalInformation.Field.TITLE, "Identification of Label Dependencies for Multi-label Classification");
        result2.setValue(TechnicalInformation.Field.VOLUME, "Proc. ICML 2010 Workshop on Learning from Multi-Label Data (MLD'10)");
        result2.setValue(TechnicalInformation.Field.YEAR, "2010");
        result2.setValue(TechnicalInformation.Field.PAGES, "53--60");
        result2.setValue(TechnicalInformation.Field.ADDRESS, "Haifa, Israel");
        result.add(result2);
        return result;
    }

    /**
     * Returns the concatenated textual descriptions of the multi-label models.
     * For {@link LabelPowerset} models the description of the underlying base
     * classifier is used; other models contribute their own {@code toString()}.
     *
     * @return the descriptions, or an empty string before the learner is built
     */
    public String getModel() {
        if (this.multiLabelLearners == null) {
            return "";
        }
        StringBuilder out = new StringBuilder();
        for (MultiLabelLearner learner : this.multiLabelLearners) {
            if (learner instanceof LabelPowerset) {
                out.append(((LabelPowerset)learner).getBaseClassifier().toString());
            } else {
                out.append(learner.toString());
            }
        }
        return out.toString();
    }

    @Override
    public String globalInfo() {
        StringBuilder sb = new StringBuilder();
        sb.append("A class for learning a classifier according to disjoint ");
        sb.append("label subsets: a multi-label learner (the Label Powerset ");
        sb.append("by default) is applied to subsets with multiple labels and");
        sb.append(" a single-label learner is applied to single label");
        sb.append(" subsets. The final classification prediction is");
        sb.append(" determined by combining labels predicted by all the ");
        sb.append("learned models. Note: the class is not multi-thread safe. ");
        sb.append("<br> <br> There is a mechanism for caching and reusing ");
        sb.append("learned classification models. The caching mechanism is ");
        sb.append("controlled by the useCache parameter (see setUseCache).\n\nFor more ");
        sb.append("information, see\n\n");
        sb.append(this.getTechnicalInformation().toString());
        return sb.toString();
    }
}

