/*
 * Decompiled with CFR 0.152.
 */
package mulan.classifier.meta.thresholding;

import java.util.Arrays;
import java.util.logging.Level;
import java.util.logging.Logger;
import mulan.classifier.InvalidDataException;
import mulan.classifier.MultiLabelLearner;
import mulan.classifier.MultiLabelOutput;
import mulan.classifier.meta.MultiLabelMetaLearner;
import mulan.classifier.transformation.BinaryRelevance;
import mulan.core.MulanRuntimeException;
import mulan.data.LabelsMetaData;
import mulan.data.MultiLabelInstances;
import mulan.evaluation.measure.BipartitionMeasureBase;
import mulan.evaluation.measure.HammingLoss;
import weka.classifiers.Classifier;
import weka.classifiers.trees.J48;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.TechnicalInformation;
import weka.core.Utils;

/**
 * Meta-learner that replaces the bipartition of a base multi-label learner with one obtained by
 * applying a single confidence threshold to all labels and all examples.
 *
 * <p>The threshold is chosen at build time as the candidate whose bipartitions bring the supplied
 * measure closest to its ideal value. The search has two stages:
 * <ol>
 *   <li>candidates 0.0, 0.1, ..., 1.0;</li>
 *   <li>a finer grid of step 0.01 within +/-0.05 of the best coarse candidate.</li>
 * </ol>
 * When built with a number of folds, the threshold is estimated on each held-out fold and the
 * mean of the per-fold thresholds is used. Otherwise it is estimated on the training data itself.
 */
public class OneThreshold
extends MultiLabelMetaLearner {
    /** The threshold learned at build time and applied to every label confidence. */
    private double threshold;
    /** Measure whose distance from its ideal value is minimised when picking the threshold. */
    private BipartitionMeasureBase measure;
    /** Number of cross-validation folds; 0 means "estimate on the training data itself". */
    private int folds = 0;
    /** Unbuilt copy of the base learner, used as a template for each cross-validation fold. */
    private MultiLabelLearner foldLearner;

    /**
     * Creates the default configuration: binary relevance over J48, Hamming loss and 3 folds.
     */
    public OneThreshold() {
        this(new BinaryRelevance((Classifier)new J48()), new HammingLoss(), 3);
    }

    /**
     * Creates a learner whose threshold is estimated by internal cross-validation.
     *
     * @param baseLearner the learner whose confidences are thresholded
     * @param aMeasure the measure to optimise
     * @param someFolds the number of cross-validation folds; must be at least 2
     * @throws IllegalArgumentException if {@code someFolds < 2}
     */
    public OneThreshold(MultiLabelLearner baseLearner, BipartitionMeasureBase aMeasure, int someFolds) {
        super(baseLearner);
        if (someFolds < 2) {
            throw new IllegalArgumentException("folds should be more than 1");
        }
        this.measure = aMeasure;
        this.folds = someFolds;
        try {
            this.foldLearner = baseLearner.makeCopy();
        }
        catch (Exception ex) {
            // Construction still succeeds (callers may rely on that); buildInternal reports the
            // missing fold template explicitly instead of failing with a NullPointerException.
            Logger.getLogger(OneThreshold.class.getName()).log(Level.SEVERE,
                    "Could not copy the base learner; cross-validated threshold estimation will not be possible", ex);
        }
    }

    /**
     * Creates a learner whose threshold is estimated on the training data itself, with no
     * cross-validation.
     *
     * @param baseLearner the learner whose confidences are thresholded
     * @param aMeasure the measure to optimise
     */
    public OneThreshold(MultiLabelLearner baseLearner, BipartitionMeasureBase aMeasure) {
        super(baseLearner);
        this.measure = aMeasure;
    }

    /**
     * Evaluates the evenly spaced candidate thresholds {@code min, min+step, ...} up to
     * {@code max} and returns the one whose bipartitions bring {@code measure} closest to its
     * ideal value.
     *
     * <p>Instances with missing labels are skipped. If a candidate makes the measure throw a
     * {@link MulanRuntimeException}, or yields a NaN value, it is ranked last. Among equally good
     * candidates the lowest one wins, because {@link Utils#minIndex(double[])} returns the first
     * minimum.
     *
     * @param learner an already built learner that produces confidences
     * @param data the instances the candidates are evaluated on
     * @param measure prototype measure; a fresh copy is made per candidate
     * @param min the first candidate threshold
     * @param step the spacing between candidates
     * @param max the last candidate threshold
     * @return the best candidate threshold
     * @throws Exception if the learner fails to predict or the measure cannot be copied
     */
    private double computeThreshold(MultiLabelLearner learner, MultiLabelInstances data, BipartitionMeasureBase measure, double min, double step, double max) throws Exception {
        int numOfThresholds = (int)Math.rint((max - min) / step + 1.0);
        // Candidates are derived from their index instead of by accumulating step. Accumulation
        // drifts: it could skip the last candidate, or evaluate a value different from the one
        // finally returned.
        double[] candidates = new double[numOfThresholds];
        BipartitionMeasureBase[] measureForThreshold = new BipartitionMeasureBase[numOfThresholds];
        for (int i = 0; i < numOfThresholds; ++i) {
            candidates[i] = min + (double)i * step;
            measureForThreshold[i] = (BipartitionMeasureBase)measure.makeCopy();
            measureForThreshold[i].reset();
        }
        boolean[] thresholdHasProblem = new boolean[numOfThresholds];
        Instances dataSet = data.getDataSet();
        for (int j = 0; j < data.getNumInstances(); ++j) {
            Instance instance = dataSet.instance(j);
            if (data.hasMissingLabels(instance)) continue;
            MultiLabelOutput mlo = learner.makePrediction(instance);
            // A label is relevant when its nominal value is the string "1".
            boolean[] trueLabels = new boolean[this.numLabels];
            for (int counter = 0; counter < this.numLabels; ++counter) {
                int classIdx = this.labelIndices[counter];
                String classValue = instance.attribute(classIdx).value((int)instance.value(classIdx));
                trueLabels[counter] = classValue.equals("1");
            }
            double[] confidences = mlo.getConfidences();
            for (int t = 0; t < numOfThresholds; ++t) {
                boolean[] bipartition = new boolean[this.numLabels];
                for (int k = 0; k < this.numLabels; ++k) {
                    bipartition[k] = confidences[k] >= candidates[t];
                }
                try {
                    MultiLabelOutput temp = new MultiLabelOutput(bipartition);
                    measureForThreshold[t].update(temp, trueLabels);
                }
                catch (MulanRuntimeException e) {
                    // The measure rejected this bipartition; rank the candidate last.
                    thresholdHasProblem[t] = true;
                }
            }
        }
        double[] performance = new double[numOfThresholds];
        for (int i = 0; i < numOfThresholds; ++i) {
            if (thresholdHasProblem[i]) {
                performance[i] = Double.MAX_VALUE;
                continue;
            }
            double value = measureForThreshold[i].getValue();
            // NaN must not reach Utils.minIndex: a NaN at index 0 would always be "selected".
            performance[i] = Double.isNaN(value) ? Double.MAX_VALUE : Math.abs(measure.getIdealValue() - value);
        }
        return candidates[Utils.minIndex((double[])performance)];
    }

    /**
     * Runs the two-stage threshold search: a coarse pass over [0, 1] with step 0.1, then a fine
     * pass with step 0.01 within +/-0.05 of the coarse result.
     *
     * @param learner an already built learner that produces confidences
     * @param data the instances the thresholds are evaluated on
     * @param measure the measure to optimise
     * @return the threshold selected by the fine pass
     * @throws Exception if prediction or measure copying fails
     */
    private double computeThreshold(MultiLabelLearner learner, MultiLabelInstances data, BipartitionMeasureBase measure) throws Exception {
        double stage1 = this.computeThreshold(learner, data, measure, 0.0, 0.1, 1.0);
        this.debug("1st stage threshold = " + stage1);
        double stage2 = this.computeThreshold(learner, data, measure, stage1 - 0.05, 0.01, stage1 + 0.05);
        this.debug("2nd stage threshold = " + stage2);
        return stage2;
    }

    /**
     * Builds the base learner on all training data, then estimates the threshold. With
     * {@code folds == 0} the threshold is estimated on the training data itself. Otherwise a copy
     * of the base learner is trained on each training fold and the per-fold thresholds are
     * averaged.
     *
     * @param trainingData the training data
     * @throws IllegalStateException if cross-validation is requested but the base learner could
     *     not be copied at construction time
     * @throws Exception if building or prediction fails
     */
    @Override
    protected void buildInternal(MultiLabelInstances trainingData) throws Exception {
        this.baseLearner.build(trainingData);
        if (this.folds == 0) {
            this.threshold = this.computeThreshold(this.baseLearner, trainingData, this.measure);
        } else {
            if (this.foldLearner == null) {
                throw new IllegalStateException("No copy of the base learner is available for cross-validation; copying it failed at construction time");
            }
            LabelsMetaData labelsMetaData = trainingData.getLabelsMetaData();
            Instances dataSet = trainingData.getDataSet();
            double[] thresholds = new double[this.folds];
            for (int f = 0; f < this.folds; ++f) {
                Instances train = dataSet.trainCV(this.folds, f);
                MultiLabelInstances trainMulti = new MultiLabelInstances(train, labelsMetaData);
                Instances test = dataSet.testCV(this.folds, f);
                MultiLabelInstances testMulti = new MultiLabelInstances(test, labelsMetaData);
                // Each fold gets its own fresh copy of the unbuilt template.
                MultiLabelLearner tempLearner = this.foldLearner.makeCopy();
                tempLearner.build(trainMulti);
                thresholds[f] = this.computeThreshold(tempLearner, testMulti, this.measure);
            }
            this.threshold = Utils.mean((double[])thresholds);
        }
    }

    /**
     * Predicts with the base learner. Each label whose confidence is at least the learned
     * threshold is marked relevant; the base learner's confidences are kept unchanged.
     *
     * @param instance the instance to classify
     * @return a new output holding the thresholded bipartition and the original confidences
     * @throws Exception if the base learner fails to predict
     */
    @Override
    protected MultiLabelOutput makePredictionInternal(Instance instance) throws Exception, InvalidDataException {
        MultiLabelOutput mlo = this.baseLearner.makePrediction(instance);
        double[] confidences = mlo.getConfidences();
        boolean[] predictedLabels = new boolean[this.numLabels];
        for (int i = 0; i < this.numLabels; ++i) {
            predictedLabels[i] = confidences[i] >= this.threshold;
        }
        return new MultiLabelOutput(predictedLabels, confidences);
    }

    /**
     * Returns the bibliographic reference cited for this method.
     *
     * @return the technical information
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation info = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        info.setValue(TechnicalInformation.Field.AUTHOR, "Read, Jesse and Pfahringer, Bernhard and Holmes, Geoff");
        info.setValue(TechnicalInformation.Field.YEAR, "2008");
        info.setValue(TechnicalInformation.Field.TITLE, "Multi-label Classification Using Ensembles of Pruned Sets");
        info.setValue(TechnicalInformation.Field.BOOKTITLE, "Data Mining, 2008. ICDM '08. Eighth IEEE International Conference on");
        info.setValue(TechnicalInformation.Field.PAGES, "995-1000");
        info.setValue(TechnicalInformation.Field.LOCATION, "Pisa, Italy");
        return info;
    }

    /**
     * Returns the threshold learned by the last build.
     *
     * @return the learned threshold (0.0 before the first build)
     */
    public double getThreshold() {
        return this.threshold;
    }

    /**
     * Returns a short description of this learner followed by its technical information.
     *
     * @return the description
     */
    @Override
    public String globalInfo() {
        return "Class that estimates a single threshold for all labels and examples. For more information, see\n\n" + this.getTechnicalInformation().toString();
    }
}

