/*
 * Decompiled with CFR 0.152.
 */
package mulan.classifier.meta.thresholding;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.logging.Level;
import java.util.logging.Logger;
import mulan.classifier.MultiLabelLearner;
import mulan.classifier.MultiLabelOutput;
import mulan.classifier.meta.MultiLabelMetaLearner;
import mulan.classifier.transformation.BinaryRelevance;
import mulan.data.MultiLabelInstances;
import mulan.evaluation.measure.BipartitionMeasureBase;
import mulan.evaluation.measure.HammingLoss;
import weka.classifiers.Classifier;
import weka.classifiers.trees.J48;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.TechnicalInformation;
import weka.core.Utils;

/**
 * Meta-learner implementing the SCut (score-based local optimization)
 * thresholding strategy: one confidence threshold is tuned per label so that
 * a user-supplied bipartition measure gets as close as possible to its ideal
 * value. Predictions of the wrapped base learner are then turned into a
 * bipartition by comparing each label's confidence with its threshold.
 */
public class SCut extends MultiLabelMetaLearner {
    /** Relative change of a label's score below which that label counts as converged. */
    private static final double CONVERGENCE_TOLERANCE = 0.001;

    /**
     * Safety cap on the number of tuning passes, so that a score sequence
     * which never satisfies the convergence test cannot loop forever.
     */
    private static final int MAX_ITERATIONS = 100;

    /** Measure that is optimized when tuning the thresholds. */
    BipartitionMeasureBase measure;

    /**
     * Number of cross-validation folds used for tuning; {@code 0} means the
     * thresholds are tuned directly on the training set.
     */
    int kFoldsCV;

    /** One threshold per label, filled in by {@link #buildInternal}. */
    double[] thresholds;

    /**
     * Creates an SCut learner around {@code BinaryRelevance} with a
     * {@code J48} base classifier, optimizing {@code HammingLoss} with
     * 3 folds.
     */
    public SCut() {
        this(new BinaryRelevance(new J48()), new HammingLoss(), 3);
    }

    /**
     * @param baseLearner the learner whose confidences are thresholded
     * @param measure the measure to optimize when tuning thresholds
     * @param folds number of cross-validation folds; {@code 0} tunes on the
     *        training set itself
     */
    public SCut(MultiLabelLearner baseLearner, BipartitionMeasureBase measure, int folds) {
        super(baseLearner);
        this.measure = measure;
        this.kFoldsCV = folds;
    }

    /**
     * Creates a learner that leaves the fold count at 0, i.e. thresholds are
     * tuned on the training set without cross-validation.
     *
     * @param baseLearner the learner whose confidences are thresholded
     * @param measure the measure to optimize when tuning thresholds
     */
    public SCut(MultiLabelLearner baseLearner, BipartitionMeasureBase measure) {
        super(baseLearner);
        this.measure = measure;
    }

    /**
     * @return bibliographic information for the paper describing the method
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Yiming Yang");
        result.setValue(TechnicalInformation.Field.TITLE, "A study of thresholding strategies for text categorization");
        result.setValue(TechnicalInformation.Field.BOOKTITLE, "Proceedings of the 24th annual international ACM SIGIR conference on Research and development in information retrieval");
        result.setValue(TechnicalInformation.Field.PAGES, "137 - 145");
        result.setValue(TechnicalInformation.Field.LOCATION, "New Orleans, Louisiana, United States");
        result.setValue(TechnicalInformation.Field.YEAR, "2001");
        return result;
    }

    /**
     * Tunes one threshold per label on {@code data} using predictions of
     * {@code learner}.
     *
     * <p>Candidate thresholds for a label are the smallest observed
     * confidence and the midpoints between neighbouring sorted confidences.
     * Labels are tuned one at a time while the other thresholds stay fixed,
     * and the passes are repeated until every label's summed score changes
     * by less than {@link #CONVERGENCE_TOLERANCE} relative to its first-pass
     * score, or until {@link #MAX_ITERATIONS} passes have run.
     *
     * @param learner an already built learner used to score {@code data}
     * @param data the instances on which thresholds are tuned
     * @return the tuned threshold for each label
     * @throws IllegalArgumentException if {@code data} has no instances
     * @throws IllegalStateException if the learner yields no usable confidences
     * @throws Exception if prediction or measure handling fails
     */
    private double[] computeThresholds(MultiLabelLearner learner, MultiLabelInstances data) throws Exception {
        Instances dataSet = data.getDataSet();
        int numInstances = data.getNumInstances();
        if (numInstances == 0) {
            throw new IllegalArgumentException("Cannot tune SCut thresholds on an empty data set");
        }

        double[][] confidences = new double[numInstances][];
        boolean[][] trueLabels = new boolean[numInstances][this.numLabels];
        // sortedConfidences[l] holds the confidences of label l over all instances.
        double[][] sortedConfidences = new double[this.numLabels][numInstances];

        for (int i = 0; i < numInstances; ++i) {
            Instance instance = dataSet.instance(i);
            // Prediction failures propagate: tuning on made-up confidences
            // would silently produce meaningless thresholds.
            MultiLabelOutput output = learner.makePrediction(instance);
            double[] instanceConfidences = output.hasConfidences() ? output.getConfidences() : null;
            if (instanceConfidences == null || instanceConfidences.length < this.numLabels) {
                throw new IllegalStateException(
                        "Base learner did not provide a confidence for every label (instance " + i + ")");
            }
            confidences[i] = instanceConfidences;
            for (int l = 0; l < this.numLabels; ++l) {
                int labelIndex = this.labelIndices[l];
                trueLabels[i][l] = dataSet.attribute(labelIndex).value((int) instance.value(labelIndex)).equals("1");
                sortedConfidences[l][i] = instanceConfidences[l];
            }
        }
        for (int l = 0; l < this.numLabels; ++l) {
            Arrays.sort(sortedConfidences[l]);
        }

        double[] currentThresholds = new double[this.numLabels];
        Arrays.fill(currentThresholds, 0.5);

        double[] currentScore = new double[this.numLabels];
        double[] previousScore = new double[this.numLabels];
        // Score of the first pass; used as the denominator of the relative change.
        double[] baselineScore = new double[this.numLabels];

        int numOfThresholds = numInstances;
        double[] performance = new double[numOfThresholds];
        BipartitionMeasureBase[] measureForThreshold = new BipartitionMeasureBase[numOfThresholds];
        for (int i = 0; i < numOfThresholds; ++i) {
            measureForThreshold[i] = (BipartitionMeasureBase) this.measure.makeCopy();
            measureForThreshold[i].reset();
        }

        int iteration = 0;
        int convergedLabels;
        do {
            System.arraycopy(currentScore, 0, previousScore, 0, this.numLabels);

            for (int j = 0; j < this.numLabels; ++j) {
                double score = 0.0;
                for (int t = numOfThresholds - 1; t >= 0; --t) {
                    measureForThreshold[t].reset();
                    // Temporarily install the candidate; the best one is set after the scan.
                    currentThresholds[j] = candidateThreshold(sortedConfidences[j], t);
                    for (int k = 0; k < numInstances; ++k) {
                        boolean[] predictedLabels = new boolean[this.numLabels];
                        for (int x = 0; x < this.numLabels; ++x) {
                            predictedLabels[x] = confidences[k][x] >= currentThresholds[x];
                        }
                        measureForThreshold[t].update(new MultiLabelOutput(predictedLabels), trueLabels[k]);
                    }
                    score += measureForThreshold[t].getValue();
                }
                for (int i = 0; i < numOfThresholds; ++i) {
                    performance[i] = Math.abs(this.measure.getIdealValue() - measureForThreshold[i].getValue());
                }
                int best = Utils.minIndex(performance);
                currentThresholds[j] = candidateThreshold(sortedConfidences[j], best);
                currentScore[j] = score;
                if (iteration == 0) {
                    baselineScore[j] = score;
                }
            }

            convergedLabels = 0;
            // The first pass has no previous score to compare against.
            if (iteration > 0) {
                for (int l = 0; l < this.numLabels; ++l) {
                    if (hasConverged(currentScore[l], previousScore[l], baselineScore[l])) {
                        ++convergedLabels;
                    }
                }
            }
            ++iteration;
        } while (convergedLabels != this.numLabels && iteration < MAX_ITERATIONS);

        if (convergedLabels != this.numLabels) {
            Logger.getLogger(SCut.class.getName()).log(Level.WARNING,
                    "SCut threshold tuning stopped after {0} passes without converging", iteration);
        }
        return currentThresholds;
    }

    /**
     * Returns the candidate threshold at position {@code index}: the smallest
     * confidence for index 0, otherwise the midpoint between the confidence
     * at {@code index} and its predecessor.
     */
    private static double candidateThreshold(double[] sortedConfidences, int index) {
        if (index == 0) {
            return sortedConfidences[0];
        }
        return (sortedConfidences[index] + sortedConfidences[index - 1]) / 2.0;
    }

    /**
     * Tells whether a label's score has stopped changing between two passes.
     * An unchanged score (including NaN on both sides) counts as converged
     * even when {@code baseline} is zero, where the relative change would
     * otherwise evaluate to NaN and never pass the tolerance test.
     */
    private static boolean hasConverged(double current, double previous, double baseline) {
        if (current == previous || (Double.isNaN(current) && Double.isNaN(previous))) {
            return true;
        }
        return Math.abs(current - previous) / baseline < CONVERGENCE_TOLERANCE;
    }

    /**
     * Tunes the thresholds and builds the base learner on the full training
     * set. With a fold count of 0 the base learner is built first and the
     * thresholds are tuned on the training set; otherwise the thresholds are
     * the average of the thresholds tuned on each cross-validation test fold
     * using a copy of the base learner trained on the matching train fold.
     *
     * @param trainingSet the data to learn from
     * @throws IllegalArgumentException if the fold count is negative or 1
     * @throws Exception if building a learner or tuning thresholds fails
     */
    @Override
    protected void buildInternal(MultiLabelInstances trainingSet) throws Exception {
        if (this.kFoldsCV < 0 || this.kFoldsCV == 1) {
            throw new IllegalArgumentException(
                    "Number of folds must be 0 (no cross-validation) or at least 2, but was " + this.kFoldsCV);
        }

        if (this.kFoldsCV == 0) {
            this.baseLearner.build(trainingSet);
            this.thresholds = this.computeThresholds(this.baseLearner, trainingSet);
            return;
        }

        this.thresholds = new double[this.numLabels];
        Instances allInstances = trainingSet.getDataSet();
        for (int fold = 0; fold < this.kFoldsCV; ++fold) {
            Instances train = allInstances.trainCV(this.kFoldsCV, fold);
            MultiLabelInstances mlTrain = new MultiLabelInstances(train, trainingSet.getLabelsMetaData());
            Instances test = allInstances.testCV(this.kFoldsCV, fold);
            MultiLabelInstances mlTest = new MultiLabelInstances(test, trainingSet.getLabelsMetaData());

            MultiLabelLearner learner = this.baseLearner.makeCopy();
            learner.build(mlTrain);
            double[] foldThresholds = this.computeThresholds(learner, mlTest);
            for (int l = 0; l < this.numLabels; ++l) {
                this.thresholds[l] += foldThresholds[l];
            }
        }
        for (int l = 0; l < this.numLabels; ++l) {
            this.thresholds[l] /= this.kFoldsCV;
        }
        this.baseLearner.build(trainingSet);
    }

    /**
     * Predicts with the base learner and marks a label as relevant when its
     * confidence is at least the tuned threshold. If the base learner gives
     * no confidences, all labels are predicted irrelevant with zero
     * confidences.
     *
     * @param instance the instance to classify
     * @return the thresholded bipartition together with the confidences
     * @throws Exception if the base learner fails to predict
     */
    @Override
    public MultiLabelOutput makePredictionInternal(Instance instance) throws Exception {
        MultiLabelOutput baseOutput = this.baseLearner.makePrediction(instance);
        boolean[] predictedLabels = new boolean[this.numLabels];
        double[] arrayOfConfidences;
        if (baseOutput.hasConfidences()) {
            arrayOfConfidences = baseOutput.getConfidences();
            for (int i = 0; i < this.numLabels; ++i) {
                predictedLabels[i] = arrayOfConfidences[i] >= this.thresholds[i];
            }
        } else {
            arrayOfConfidences = new double[this.numLabels];
        }
        return new MultiLabelOutput(predictedLabels, arrayOfConfidences);
    }

    /**
     * @return a short description of the method followed by its technical
     *         information
     */
    @Override
    public String globalInfo() {
        return "Class that implements the SCut method (Score-based local  optimization). It computes a separate threshold for each label based on improving a user defined performance measure.For more information, see\n\n" + this.getTechnicalInformation().toString();
    }
}

