/*
 * Decompiled with CFR 0.152.
 */
package mulan.classifier.neural;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.Set;
import mulan.classifier.InvalidDataException;
import mulan.classifier.MultiLabelLearnerBase;
import mulan.classifier.MultiLabelOutput;
import mulan.classifier.neural.DataPair;
import mulan.classifier.neural.MMPMaxUpdateRule;
import mulan.classifier.neural.MMPRandomizedUpdateRule;
import mulan.classifier.neural.MMPUniformUpdateRule;
import mulan.classifier.neural.MMPUpdateRuleType;
import mulan.classifier.neural.ModelUpdateRule;
import mulan.classifier.neural.NormalizationFilter;
import mulan.classifier.neural.model.ActivationFunction;
import mulan.classifier.neural.model.ActivationLinear;
import mulan.classifier.neural.model.Neuron;
import mulan.core.ArgumentNullException;
import mulan.core.WekaException;
import mulan.data.MultiLabelInstances;
import mulan.evaluation.loss.RankingLoss;
import mulan.evaluation.loss.RankingLossFunction;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.TechnicalInformation;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NominalToBinary;

/**
 * Multiclass Multilabel Perceptrons (MMP) learner.
 * <p>
 * The model consists of one perceptron with a linear activation function per
 * label. Training feeds every training example to the configured
 * {@link MMPUpdateRuleType} update rule, once per training epoch. Prediction
 * ranks the labels by the output of their perceptrons; only a ranking is
 * produced.
 * <p>
 * The learner reports itself as updatable. The perceptrons (and the
 * nominal-to-binary filter, if one is needed) are created on the first call to
 * {@code buildInternal}. Later calls continue training that same model rather
 * than starting over.
 * <p>
 * Reference: K. Crammer, Y. Singer, "A Family of Additive Online Algorithms for
 * Category Ranking", Journal of Machine Learning Research 3(6), 2003,
 * pp. 1025-1058.
 */
public class MMPLearner extends MultiLabelLearnerBase {
    private static final long serialVersionUID = 2221778416856852684L;
    /** Bias value passed to every perceptron when the model is created. */
    private static final double PERCEP_BIAS = 1.0;
    /** One perceptron per label, indexed by label position; created on the first build. */
    private List<Neuron> perceptrons;
    /** NOTE(review): not referenced anywhere in this class; kept to preserve the serialized form. */
    private NormalizationFilter normalizer;
    /** Number of passes over the training data per build call (default 1). */
    private int epochs = 1;
    /** Whether nominal feature attributes are converted to binary ones (default true). */
    private boolean convertNomToBin = true;
    /** Filter created on the first build when nominal features exist; reused for later updates and predictions. */
    private NominalToBinary nomToBinFilter;
    private final RankingLossFunction lossFunction;
    private final MMPUpdateRuleType mmpUpdateRule;
    /** Set once the perceptrons have been created by the first build call. */
    private boolean isInitialized = false;
    /** Optional seed for the random generator handed to the perceptrons; {@code null} means none is supplied. */
    private final Long randomnessSeed;

    /**
     * Creates a learner that uses {@link RankingLoss} and the
     * {@link MMPUpdateRuleType#UniformUpdate} rule, with no explicit random seed.
     */
    public MMPLearner() {
        this(new RankingLoss(), MMPUpdateRuleType.UniformUpdate);
    }

    /**
     * Creates a learner with the given loss and update rule, with no explicit random seed.
     *
     * @param lossMeasure the ranking loss used by the update rule; must not be null
     * @param modelUpdateRule the model update rule type; must not be null
     * @throws ArgumentNullException if either argument is null
     */
    public MMPLearner(RankingLossFunction lossMeasure, MMPUpdateRuleType modelUpdateRule) {
        // The explicit (Long) cast selects the private boxed-seed constructor below.
        this(lossMeasure, modelUpdateRule, (Long) null);
    }

    /**
     * Creates a learner with the given loss, update rule and random seed.
     *
     * @param lossMeasure the ranking loss used by the update rule; must not be null
     * @param modelUpdateRule the model update rule type; must not be null
     * @param randomnessSeed seed for the {@link Random} passed to every perceptron at
     *        creation (presumably used for weight initialization; see {@code Neuron})
     * @throws ArgumentNullException if {@code lossMeasure} or {@code modelUpdateRule} is null
     */
    public MMPLearner(RankingLossFunction lossMeasure, MMPUpdateRuleType modelUpdateRule, long randomnessSeed) {
        // A boxed Long argument matches the private constructor without unboxing, so no recursion occurs.
        this(lossMeasure, modelUpdateRule, Long.valueOf(randomnessSeed));
    }

    /**
     * Shared implementation of the public constructors.
     *
     * @param randomnessSeed the seed, or {@code null} for no explicit random generator
     */
    private MMPLearner(RankingLossFunction lossMeasure, MMPUpdateRuleType modelUpdateRule, Long randomnessSeed) {
        if (lossMeasure == null) {
            throw new ArgumentNullException("lossMeasure");
        }
        if (modelUpdateRule == null) {
            throw new ArgumentNullException("modelUpdateRule");
        }
        this.mmpUpdateRule = modelUpdateRule;
        this.lossFunction = lossMeasure;
        this.randomnessSeed = randomnessSeed;
    }

    /**
     * Sets the number of passes over the training data performed by each build call.
     *
     * @param epochs the number of epochs; must be greater than zero
     * @throws IllegalArgumentException if {@code epochs <= 0}
     */
    public void setTrainingEpochs(int epochs) {
        if (epochs <= 0) {
            throw new IllegalArgumentException("The number of training epochs must be greater than zero. Entered value is : " + epochs);
        }
        this.epochs = epochs;
    }

    /** @return the number of training epochs per build call */
    public int getTrainingEpochs() {
        return this.epochs;
    }

    /**
     * Sets whether nominal feature attributes are converted to binary attributes
     * before training and prediction.
     */
    public void setConvertNominalToBinary(boolean convert) {
        this.convertNomToBin = convert;
    }

    /** @return whether nominal feature attributes are converted to binary attributes */
    public boolean getConvertNominalToBinary() {
        return this.convertNomToBin;
    }

    @Override
    public boolean isUpdatable() {
        return true;
    }

    /**
     * Trains the model on the given data.
     * <p>
     * The first call creates the perceptrons. Every call then runs the configured
     * update rule over all examples for {@link #getTrainingEpochs()} epochs.
     *
     * @throws IllegalArgumentException if the model is not yet initialized and the
     *         training set contains no instances
     */
    @Override
    protected void buildInternal(MultiLabelInstances trainingSet) throws Exception {
        // Operate on a clone of the caller's data; presumably this guards the original
        // against the preprocessing done in prepareData.
        MultiLabelInstances workingSet = trainingSet.clone();
        List<DataPair> trainData = this.prepareData(workingSet);
        if (!this.isInitialized) {
            // The input dimension is taken from the first example, so at least one is required.
            if (trainData.isEmpty()) {
                throw new IllegalArgumentException("The training set must contain at least one instance to initialize the model.");
            }
            int numFeatures = trainData.get(0).getInput().length;
            this.perceptrons = this.initializeModel(numFeatures, this.numLabels);
            this.isInitialized = true;
        }
        ModelUpdateRule modelUpdateRule = this.getModelUpdateRule(this.lossFunction);
        for (int iter = 0; iter < this.epochs; ++iter) {
            for (DataPair dataItem : trainData) {
                modelUpdateRule.process(dataItem, null);
            }
        }
    }

    /**
     * Ranks the labels for an instance by the output of each label's perceptron.
     *
     * @throws InvalidDataException if the instance is not consistent with the training data
     */
    @Override
    public MultiLabelOutput makePredictionInternal(Instance instance) throws InvalidDataException {
        double[] input = this.getFeatureVector(instance);
        double[] labelConfidences = new double[this.numLabels];
        for (int index = 0; index < this.numLabels; ++index) {
            labelConfidences[index] = this.perceptrons.get(index).processInput(input);
        }
        return new MultiLabelOutput(MultiLabelOutput.ranksFromValues(labelConfidences));
    }

    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation technicalInfo = new TechnicalInformation(TechnicalInformation.Type.ARTICLE);
        technicalInfo.setValue(TechnicalInformation.Field.AUTHOR, "Koby Crammer, Yoram Singer");
        technicalInfo.setValue(TechnicalInformation.Field.YEAR, "2003");
        technicalInfo.setValue(TechnicalInformation.Field.TITLE, "A Family of Additive Online Algorithms for Category Ranking.");
        technicalInfo.setValue(TechnicalInformation.Field.JOURNAL, "Journal of Machine Learning Research");
        technicalInfo.setValue(TechnicalInformation.Field.VOLUME, "3(6)");
        technicalInfo.setValue(TechnicalInformation.Field.PAGES, "1025-1058");
        return technicalInfo;
    }

    /**
     * Creates one linear-activation perceptron per label.
     * <p>
     * When a seed was configured, all perceptrons share a single {@link Random}
     * created from it. Otherwise {@code null} is passed.
     */
    private List<Neuron> initializeModel(int numFeatures, int numLabels) {
        Random random = this.randomnessSeed == null ? null : new Random(this.randomnessSeed);
        List<Neuron> tempPerceptrons = new ArrayList<Neuron>(numLabels);
        for (int i = 0; i < numLabels; ++i) {
            tempPerceptrons.add(new Neuron((ActivationFunction) new ActivationLinear(), numFeatures, PERCEP_BIAS, random));
        }
        return tempPerceptrons;
    }

    /**
     * Instantiates the update rule selected at construction, bound to the current perceptrons.
     *
     * @throws IllegalArgumentException if the configured rule type is not handled
     */
    private ModelUpdateRule getModelUpdateRule(RankingLossFunction lossMeasure) {
        switch (this.mmpUpdateRule) {
            case UniformUpdate:
                return new MMPUniformUpdateRule(this.perceptrons, lossMeasure);
            case MaxUpdate:
                return new MMPMaxUpdateRule(this.perceptrons, lossMeasure);
            case RandomizedUpdate:
                return new MMPRandomizedUpdateRule(this.perceptrons, lossMeasure);
            default:
                break;
        }
        throw new IllegalArgumentException(String.format("The specified model update rule '%s' is not supported.", this.mmpUpdateRule));
    }

    /**
     * Converts the data into training pairs.
     * <p>
     * If conversion is enabled and nominal features exist, the data is first passed
     * through a {@link NominalToBinary} filter. The filter is configured on the first
     * (uninitialized) call and reused afterwards. In that case the inherited
     * {@code labelIndices} are refreshed from the filtered data.
     *
     * @throws WekaException if the filter cannot be created or applied
     */
    private List<DataPair> prepareData(MultiLabelInstances mlData) {
        String nominalAttrRange = this.ensureAttributesFormat(mlData.getFeatureAttributes());
        if (!this.convertNomToBin || nominalAttrRange.isEmpty()) {
            return DataPair.createDataPairs(mlData, false);
        }
        Instances dataSet = mlData.getDataSet();
        if (!this.isInitialized) {
            try {
                NominalToBinary filter = new NominalToBinary();
                filter.setAttributeIndices(nominalAttrRange);
                filter.setInputFormat(dataSet);
                this.nomToBinFilter = filter;
            } catch (Exception exception) {
                this.nomToBinFilter = null;
                if (this.getDebug()) {
                    this.debug("Failed to create NominalToBinary filter for the input instances data. Error message: " + exception.getMessage());
                }
                throw new WekaException("Failed to create NominalToBinary filter for the input instances data.", exception);
            }
        }
        MultiLabelInstances filteredData;
        try {
            Instances filteredSet = Filter.useFilter(dataSet, this.nomToBinFilter);
            filteredData = mlData.reintegrateModifiedDataSet(filteredSet);
            // Label positions can shift once nominal features are expanded into several binary ones.
            this.labelIndices = filteredData.getLabelIndices();
        } catch (Exception exception) {
            if (this.getDebug()) {
                this.debug("Failed to apply NominalToBinary filter to the input instances data. Error message: " + exception.getMessage());
            }
            throw new WekaException("Failed to apply NominalToBinary filter to the input instances data.", exception);
        }
        return DataPair.createDataPairs(filteredData, false);
    }

    /**
     * Builds the attribute-range string of the nominal attributes among the given ones.
     *
     * @return comma-separated 1-based attribute indices (the +1 matches Weka's range
     *         syntax used by {@code setAttributeIndices}); empty if none are nominal
     */
    private String ensureAttributesFormat(Set<Attribute> attributes) {
        StringBuilder nominalAttrRange = new StringBuilder();
        for (Attribute attribute : attributes) {
            if (!attribute.isNominal()) {
                continue;
            }
            if (nominalAttrRange.length() > 0) {
                nominalAttrRange.append(',');
            }
            nominalAttrRange.append(attribute.index() + 1);
        }
        return nominalAttrRange.toString();
    }

    /**
     * Converts an instance into the model's input vector.
     * <p>
     * Applies the nominal-to-binary filter if one was created during training. If
     * the instance has more attributes than the model has inputs, the attributes at
     * {@code labelIndices} are treated as labels and skipped.
     *
     * @throws InvalidDataException if filtering fails or the number of remaining
     *         feature values does not match the model's input dimension
     */
    private double[] getFeatureVector(Instance inputInstance) {
        if (this.convertNomToBin && this.nomToBinFilter != null) {
            try {
                this.nomToBinFilter.input(inputInstance);
                inputInstance = this.nomToBinFilter.output();
                inputInstance.setDataset(null);
            } catch (Exception ex) {
                throw new InvalidDataException("The input instance for prediction is invalid. Instance is not consistent with the data the model was built for.");
            }
        }
        int numAttributes = inputInstance.numAttributes();
        // Each perceptron holds one weight more than it has inputs (presumably the bias weight).
        int modelInputDim = this.perceptrons.get(0).getWeights().length - 1;
        if (numAttributes < modelInputDim) {
            throw new InvalidDataException("Input instance do not have enough attributes to be processed by the model. Instance is not consistent with the data the model was built for.");
        }
        // Mark label attributes to skip; a mask gives O(1) lookup per attribute.
        boolean[] isLabel = new boolean[numAttributes];
        if (numAttributes > modelInputDim) {
            for (int labelIndex : this.labelIndices) {
                if (labelIndex >= 0 && labelIndex < numAttributes) {
                    isLabel[labelIndex] = true;
                }
            }
        }
        double[] inputPattern = new double[modelInputDim];
        int featureCount = 0;
        for (int attrIndex = 0; attrIndex < numAttributes; ++attrIndex) {
            if (isLabel[attrIndex]) {
                continue;
            }
            if (featureCount == modelInputDim) {
                throw new InvalidDataException("Input instance has more feature attributes than the model expects. Instance is not consistent with the data the model was built for.");
            }
            inputPattern[featureCount++] = inputInstance.value(attrIndex);
        }
        if (featureCount != modelInputDim) {
            throw new InvalidDataException("Input instance has fewer feature attributes than the model expects. Instance is not consistent with the data the model was built for.");
        }
        return inputPattern;
    }

    @Override
    public String globalInfo() {
        return "Implementation of Multiclass Multilabel Perceptrons learner. For more information, see\n\n" + this.getTechnicalInformation().toString();
    }
}

