/*
 * Decompiled with CFR 0.152.
 */
package mulan.classifier.neural;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.Set;
import mulan.classifier.InvalidDataException;
import mulan.classifier.MultiLabelLearnerBase;
import mulan.classifier.MultiLabelOutput;
import mulan.classifier.neural.BPMLLAlgorithm;
import mulan.classifier.neural.DataPair;
import mulan.classifier.neural.NormalizationFilter;
import mulan.classifier.neural.ThresholdFunction;
import mulan.classifier.neural.model.ActivationTANH;
import mulan.classifier.neural.model.BasicNeuralNet;
import mulan.classifier.neural.model.NeuralNet;
import mulan.core.WekaException;
import mulan.data.DataUtils;
import mulan.data.InvalidDataFormatException;
import mulan.data.MultiLabelInstances;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.TechnicalInformation;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NominalToBinary;

public class BPMLL
extends MultiLabelLearnerBase {
    private static final long serialVersionUID = 2153814250172139021L;
    private static final double NET_BIAS = 1.0;
    private static final double ERROR_SMALL_CHANGE = 1.0E-6;
    private NominalToBinary nominalToBinaryFilter;
    private int epochs = 100;
    private final Long randomnessSeed;
    private double weightsDecayCost = 1.0E-5;
    private double learningRate = 0.05;
    private int[] hiddenLayersTopology;
    private boolean normalizeAttributes = true;
    private NormalizationFilter normalizer;
    private NeuralNet model;
    private ThresholdFunction thresholdF;

    public BPMLL() {
        this.randomnessSeed = null;
    }

    public BPMLL(long randomnessSeed) {
        this.randomnessSeed = randomnessSeed;
    }

    public void setHiddenLayers(int[] hiddenLayers) {
        if (hiddenLayers != null) {
            for (int value : hiddenLayers) {
                if (value > 0) continue;
                throw new IllegalArgumentException("Invalid hidden layer topology definition. Number of neurons in hidden layer must be larger than zero.");
            }
        }
        this.hiddenLayersTopology = hiddenLayers;
    }

    public int[] getHiddenLayers() {
        return this.hiddenLayersTopology == null ? this.hiddenLayersTopology : Arrays.copyOf(this.hiddenLayersTopology, this.hiddenLayersTopology.length);
    }

    public void setLearningRate(double learningRate) {
        if (learningRate <= 0.0 || learningRate > 1.0) {
            throw new IllegalArgumentException("The learning rate must be greater than 0 and no more than 1. Entered value is : " + learningRate);
        }
        this.learningRate = learningRate;
    }

    public double getLearningRate() {
        return this.learningRate;
    }

    public void setWeightsDecayRegularization(double weightsDecayCost) {
        if (weightsDecayCost <= 0.0 || weightsDecayCost > 1.0) {
            throw new IllegalArgumentException("The weights decay regularization cost term must be greater than 0 and no more than 1. The passed  value is : " + weightsDecayCost);
        }
        this.weightsDecayCost = weightsDecayCost;
    }

    public double getWeightsDecayRegularization() {
        return this.weightsDecayCost;
    }

    public void setTrainingEpochs(int epochs) {
        if (epochs <= 0) {
            throw new IllegalArgumentException("The number of training epochs must be greater than zero. Entered value is : " + epochs);
        }
        this.epochs = epochs;
    }

    public int getTrainingEpochs() {
        return this.epochs;
    }

    public void setNormalizeAttributes(boolean normalize) {
        this.normalizeAttributes = normalize;
    }

    public boolean getNormalizeAttributes() {
        return this.normalizeAttributes;
    }

    @Override
    protected void buildInternal(MultiLabelInstances instances) throws Exception {
        this.nominalToBinaryFilter = null;
        MultiLabelInstances trainInstances = instances.clone();
        List<DataPair> trainData = this.prepareData(trainInstances);
        int inputsDim = trainData.get(0).getInput().length;
        this.model = this.buildNeuralNetwork(inputsDim);
        BPMLLAlgorithm learnAlg = new BPMLLAlgorithm(this.model, this.weightsDecayCost);
        int numInstances = trainData.size();
        int processedInstances = 0;
        double prevError = Double.MAX_VALUE;
        double error = 0.0;
        for (int epoch = 0; epoch < this.epochs; ++epoch) {
            double errorDiff;
            Collections.shuffle(trainData, new Random(1L));
            for (int index = 0; index < numInstances; ++index) {
                DataPair trainPair = trainData.get(index);
                double result = learnAlg.learn(trainPair.getInput(), trainPair.getOutput(), this.learningRate);
                if (Double.isNaN(result)) continue;
                error += result;
                ++processedInstances;
            }
            if (this.getDebug() && epoch % 10 == 0) {
                this.debug("Training epoch : " + epoch + "  Model error : " + error / (double)processedInstances);
            }
            if (!((errorDiff = prevError - error) <= 1.0E-6 * prevError)) continue;
            if (!this.getDebug()) break;
            this.debug("Global training error does not decrease enough. Training terminated.");
            break;
        }
        this.thresholdF = this.buildThresholdFunction(trainData);
    }

    @Override
    public String globalInfo() {
        return "The implementation of Back-Propagation Multi-Label Learning (BPMLL) learner. The learned model is stored in {@link NeuralNet} neural network. The models of the learner built by {@link BPMLLAlgorithm} from given training data set.";
    }

    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation technicalInfo = new TechnicalInformation(TechnicalInformation.Type.ARTICLE);
        technicalInfo.setValue(TechnicalInformation.Field.AUTHOR, "Zhang, M.L., Zhou, Z.H.");
        technicalInfo.setValue(TechnicalInformation.Field.YEAR, "2006");
        technicalInfo.setValue(TechnicalInformation.Field.TITLE, "Multi-label neural networks with applications to functional genomics and text categorization");
        technicalInfo.setValue(TechnicalInformation.Field.JOURNAL, "IEEE Transactions on Knowledge and Data Engineering");
        technicalInfo.setValue(TechnicalInformation.Field.VOLUME, "18");
        technicalInfo.setValue(TechnicalInformation.Field.PAGES, "1338-1351");
        return technicalInfo;
    }

    private ThresholdFunction buildThresholdFunction(List<DataPair> trainData) {
        int numExamples = trainData.size();
        double[][] idealLabels = new double[numExamples][this.numLabels];
        double[][] modelConfidences = new double[numExamples][this.numLabels];
        for (int example = 0; example < numExamples; ++example) {
            DataPair dataPair = trainData.get(example);
            idealLabels[example] = dataPair.getOutput();
            modelConfidences[example] = this.model.feedForward(dataPair.getInput());
        }
        return new ThresholdFunction(idealLabels, modelConfidences);
    }

    private NeuralNet buildNeuralNetwork(int inputsDim) {
        int[] networkTopology;
        if (this.hiddenLayersTopology == null) {
            int hiddenUnits = Math.round(0.2f * (float)inputsDim);
            this.hiddenLayersTopology = new int[]{hiddenUnits};
            networkTopology = new int[]{inputsDim, hiddenUnits, this.numLabels};
        } else {
            networkTopology = new int[this.hiddenLayersTopology.length + 2];
            networkTopology[0] = inputsDim;
            System.arraycopy(this.hiddenLayersTopology, 0, networkTopology, 1, this.hiddenLayersTopology.length);
            networkTopology[networkTopology.length - 1] = this.numLabels;
        }
        BasicNeuralNet aModel = new BasicNeuralNet(networkTopology, 1.0, ActivationTANH.class, this.randomnessSeed == null ? null : new Random(this.randomnessSeed));
        return aModel;
    }

    private List<DataPair> prepareData(MultiLabelInstances mlData) {
        Instances data = mlData.getDataSet();
        if ((data = this.checkAttributesFormat(data, mlData.getFeatureAttributes())) == null) {
            throw new InvalidDataException("Attributes are not in correct format. Input attributes (all but the label attributes) must be nominal or numeric.");
        }
        try {
            mlData = mlData.reintegrateModifiedDataSet(data);
            this.labelIndices = mlData.getLabelIndices();
        }
        catch (InvalidDataFormatException e) {
            throw new InvalidDataException("Failed to create a multilabel data set from modified instances.");
        }
        if (this.normalizeAttributes) {
            this.normalizer = new NormalizationFilter(mlData, true, -0.8, 0.8);
        }
        return DataPair.createDataPairs(mlData, true);
    }

    private Instances checkAttributesFormat(Instances dataSet, Set<Attribute> inputAttributes) {
        StringBuilder nominalAttrRange = new StringBuilder();
        String rangeDelimiter = ",";
        for (Attribute attribute : inputAttributes) {
            if (attribute.isNumeric()) continue;
            if (attribute.isNominal()) {
                nominalAttrRange.append(attribute.index() + 1 + rangeDelimiter);
                continue;
            }
            return null;
        }
        if (nominalAttrRange.length() > 0) {
            nominalAttrRange.deleteCharAt(nominalAttrRange.lastIndexOf(rangeDelimiter));
            try {
                this.nominalToBinaryFilter = new NominalToBinary();
                this.nominalToBinaryFilter.setAttributeIndices(nominalAttrRange.toString());
                this.nominalToBinaryFilter.setInputFormat(dataSet);
                dataSet = Filter.useFilter((Instances)dataSet, (Filter)this.nominalToBinaryFilter);
            }
            catch (Exception exception) {
                this.nominalToBinaryFilter = null;
                if (this.getDebug()) {
                    this.debug("Failed to apply NominalToBinary filter to the input instances data. Error message: " + exception.getMessage());
                }
                throw new WekaException("Failed to apply NominalToBinary filter to the input instances data.", exception);
            }
        }
        return dataSet;
    }

    @Override
    public MultiLabelOutput makePredictionInternal(Instance instance) throws InvalidDataException {
        Instance inputInstance = null;
        if (this.nominalToBinaryFilter != null) {
            try {
                this.nominalToBinaryFilter.input(instance);
                inputInstance = this.nominalToBinaryFilter.output();
                inputInstance.setDataset(null);
            }
            catch (Exception ex) {
                throw new InvalidDataException("The input instance for prediction is invalid. Instance is not consistent with the data the model was built for.");
            }
        } else {
            inputInstance = DataUtils.createInstance(instance, instance.weight(), instance.toDoubleArray());
        }
        int numAttributes = inputInstance.numAttributes();
        if (numAttributes < this.model.getNetInputSize()) {
            throw new InvalidDataException("Input instance do not have enough attributes to be processed by the model. Instance is not consistent with the data the model was built for.");
        }
        ArrayList<Integer> someLabelIndices = new ArrayList<Integer>();
        boolean labelsAreThere = false;
        if (numAttributes > this.model.getNetInputSize()) {
            for (int index : this.labelIndices) {
                someLabelIndices.add(index);
            }
            labelsAreThere = true;
        }
        if (this.normalizeAttributes) {
            this.normalizer.normalize(inputInstance);
        }
        int inputDim = this.model.getNetInputSize();
        double[] inputPattern = new double[inputDim];
        int indexCounter = 0;
        for (int attrIndex = 0; attrIndex < numAttributes; ++attrIndex) {
            if (labelsAreThere && someLabelIndices.contains(attrIndex)) continue;
            inputPattern[indexCounter] = inputInstance.value(attrIndex);
            ++indexCounter;
        }
        double[] labelConfidences = this.model.feedForward(inputPattern);
        double threshold = this.thresholdF.computeThreshold(labelConfidences);
        boolean[] labelPredictions = new boolean[this.numLabels];
        Arrays.fill(labelPredictions, false);
        for (int labelIndex = 0; labelIndex < this.numLabels; ++labelIndex) {
            if (labelConfidences[labelIndex] > threshold) {
                labelPredictions[labelIndex] = true;
            }
            labelConfidences[labelIndex] = (labelConfidences[labelIndex] + 1.0) / 2.0;
        }
        MultiLabelOutput mlo = new MultiLabelOutput(labelPredictions, labelConfidences);
        return mlo;
    }
}

