/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.functions;

import java.util.Collections;
import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.AbstractClassifier;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.estimators.MultivariateGaussianEstimator;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.RemoveUseless;

/**
 * Linear discriminant analysis (LDA) classifier.
 * <p>
 * Each class is modelled by a multivariate Gaussian with its own mean vector. The covariance
 * matrix is shared and is estimated by maximum likelihood from the pooled data. Attributes
 * removed by {@link RemoveUseless} are dropped before training and before prediction.
 * Instance weights are used for the class priors and are passed to the pooled Gaussian
 * estimator.
 */
public class LDA
extends AbstractClassifier
implements WeightedInstancesHandler {
    static final long serialVersionUID = -8213283598193689271L;

    /** Default value of the ridge parameter. */
    private static final double DEFAULT_RIDGE = 1.0E-6;

    /** Empty header of the filtered training data; supplies class and attribute metadata. */
    protected Instances m_Data;
    /** Pooled Gaussian estimator; {@code null} when no usable training instances were seen. */
    protected MultivariateGaussianEstimator m_Estimator;
    /** Per-class mean vectors (class attribute excluded); a {@code null} entry means no model for that class. */
    protected double[][] m_Means;
    /** Mean vector reported by the pooled estimator after estimation. */
    protected double[] m_GlobalMean;
    /**
     * Natural log of each class prior (class weight / total weight). Classes without positive
     * weight keep the default 0. NOTE(review): that is a prior of 1; verify this is intended.
     */
    protected double[] m_LogPriors;
    /** Ridge parameter forwarded to the Gaussian estimator. */
    protected double m_Ridge = DEFAULT_RIDGE;
    /** Filter that drops useless attributes; fitted on the training data. */
    protected RemoveUseless m_RemoveUseless;

    /**
     * Returns a short description of this classifier for the GUI.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Generates an LDA model. The covariance matrix is estimated using maximum likelihood from the pooled data.";
    }

    /**
     * Returns the tip text for the ridge property.
     *
     * @return the tip text
     */
    public String ridgeTipText() {
        return "The value of the ridge parameter.";
    }

    /**
     * Returns the ridge parameter passed to the Gaussian estimator.
     *
     * @return the ridge value
     */
    public double getRidge() {
        return this.m_Ridge;
    }

    /**
     * Sets the ridge parameter passed to the Gaussian estimator.
     *
     * @param newRidge the new ridge value
     */
    public void setRidge(double newRidge) {
        this.m_Ridge = newRidge;
    }

    /**
     * Lists the options of this classifier, followed by the superclass options.
     *
     * @return an enumeration of the available options
     */
    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> newVector = new Vector<Option>();
        // -R takes a numeric value, so declare one argument and show it in the synopsis.
        newVector.addElement(new Option("\tThe ridge parameter.\n\t(default is 1e-6)", "R", 1, "-R <num>"));
        newVector.addAll(Collections.list(super.listOptions()));
        return newVector.elements();
    }

    /**
     * Parses options. {@code -R <num>} sets the ridge; if it is absent, the ridge reverts to
     * its default.
     *
     * @param options the option array; consumed options are removed
     * @throws Exception if an option is invalid or unrecognised options remain
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        String ridgeString = Utils.getOption('R', options);
        this.setRidge(ridgeString.isEmpty() ? DEFAULT_RIDGE : Double.parseDouble(ridgeString));
        super.setOptions(options);
        Utils.checkForRemainingOptions(options);
    }

    /**
     * Returns the current settings as an option array.
     *
     * @return the options, including those of the superclass
     */
    @Override
    public String[] getOptions() {
        Vector<String> options = new Vector<String>();
        options.add("-R");
        options.add("" + this.getRidge());
        Collections.addAll(options, super.getOptions());
        return options.toArray(new String[0]);
    }

    /**
     * Declares the supported data. Only numeric attributes and a nominal class are accepted.
     * Missing class values are allowed; those instances are dropped during training.
     *
     * @return the capabilities of this classifier
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.NOMINAL_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        result.setMinimumNumberInstances(0);
        return result;
    }

    /**
     * Builds the LDA model.
     * <p>
     * Useless attributes are removed and instances with a missing class are deleted. The
     * remaining instances are grouped by class and handed, with their weights, to a pooled
     * {@link MultivariateGaussianEstimator}. Finally, log class priors are computed from the
     * class weights.
     *
     * @param insts the training data
     * @throws Exception if the data is not supported or estimation fails
     */
    @Override
    public void buildClassifier(Instances insts) throws Exception {
        this.getCapabilities().testWithFail(insts);
        this.m_RemoveUseless = new RemoveUseless();
        this.m_RemoveUseless.setInputFormat(insts);
        insts = Filter.useFilter(insts, this.m_RemoveUseless);
        insts.deleteWithMissingClass();

        int numClasses = insts.numClasses();
        if (insts.numInstances() == 0) {
            // No usable data: every class mean is null, so distributionForInstance gives all
            // classes the same log score. Also clear state left over from a previous build.
            this.m_Data = new Instances(insts, 0);
            this.m_Means = new double[numClasses][];
            this.m_Estimator = null;
            this.m_GlobalMean = null;
            this.m_LogPriors = null;
            return;
        }

        int numFeatures = insts.numAttributes() - 1;
        int classAttIndex = insts.classIndex();

        // First pass: count instances (to size arrays) and sum weights (for priors) per class.
        int[] counts = new int[numClasses];
        double[] sumOfWeightsPerClass = new double[numClasses];
        for (int i = 0; i < insts.numInstances(); ++i) {
            Instance inst = insts.instance(i);
            int classValue = (int) inst.classValue();
            counts[classValue]++;
            sumOfWeightsPerClass[classValue] += inst.weight();
        }

        // Second pass: copy each instance's non-class values and its weight into its class group.
        double[][][] data = new double[numClasses][][];
        double[][] weights = new double[numClasses][];
        for (int c = 0; c < numClasses; ++c) {
            data[c] = new double[counts[c]][numFeatures];
            weights[c] = new double[counts[c]];
        }
        int[] filled = new int[numClasses];
        for (int i = 0; i < insts.numInstances(); ++i) {
            Instance inst = insts.instance(i);
            int classValue = (int) inst.classValue();
            int row = filled[classValue]++;
            weights[classValue][row] = inst.weight();
            double[] features = data[classValue][row];
            int index = 0;
            for (int j = 0; j < inst.numAttributes(); ++j) {
                if (j != classAttIndex) {
                    features[index++] = inst.value(j);
                }
            }
        }

        this.m_Estimator = new MultivariateGaussianEstimator();
        this.m_Estimator.setRidge(this.getRidge());
        // NOTE(review): presumably estimatePooled returns a null mean for classes with no rows;
        // distributionForInstance and toString treat null entries as "no model".
        this.m_Means = this.m_Estimator.estimatePooled(data, weights);
        this.m_GlobalMean = this.m_Estimator.getMean();

        // Log prior = log(class weight) - log(total weight). Classes without positive weight keep 0.
        this.m_LogPriors = new double[numClasses];
        double logSumOfWeights = Math.log(Utils.sum(sumOfWeightsPerClass));
        for (int c = 0; c < numClasses; ++c) {
            if (sumOfWeightsPerClass[c] > 0.0) {
                this.m_LogPriors[c] = Math.log(sumOfWeightsPerClass[c]) - logSumOfWeights;
            }
        }
        this.m_Data = new Instances(insts, 0);
    }

    /**
     * Computes the class membership probabilities for an instance.
     * <p>
     * Each class with a model gets a log score equal to the pooled Gaussian log density at the
     * class-shifted point plus the log prior. Classes without a model get
     * {@code -Double.MAX_VALUE}. The scores are converted to probabilities with
     * {@link Utils#logs2probs}.
     *
     * @param inst the instance to classify (in the original, unfiltered format)
     * @return the class probability distribution
     * @throws Exception if filtering or density evaluation fails
     */
    @Override
    public double[] distributionForInstance(Instance inst) throws Exception {
        this.m_RemoveUseless.input(inst);
        inst = this.m_RemoveUseless.output();
        int numClasses = this.m_Data.numClasses();
        int classAttIndex = this.m_Data.classIndex();
        double[] logScores = new double[numClasses];
        double[] values = new double[inst.numAttributes() - 1];
        for (int i = 0; i < numClasses; ++i) {
            if (this.m_Means[i] == null) {
                // No model for this class: give it the lowest finite log score.
                logScores[i] = -Double.MAX_VALUE;
                continue;
            }
            int index = 0;
            for (int j = 0; j < this.m_Data.numAttributes(); ++j) {
                if (j == classAttIndex) {
                    continue;
                }
                // Shift by (global mean - class mean) so that a density centred on the global
                // mean is evaluated at x - mean_i. Presumably the pooled estimator is centred on
                // m_GlobalMean; confirm against MultivariateGaussianEstimator.
                values[index] = inst.value(j) - this.m_Means[i][index] + this.m_GlobalMean[index];
                ++index;
            }
            logScores[i] = this.m_Estimator.logDensity(values) + this.m_LogPriors[i];
        }
        return Utils.logs2probs(logScores);
    }

    /**
     * Returns a text description of the model: the pooled estimator, then the log prior,
     * prior and mean vector of each class that has a model.
     *
     * @return the model description, or a notice if no model has been built
     */
    public String toString() {
        if (this.m_Means == null) {
            return "No model has been built yet.";
        }
        int decimals = this.getNumDecimalPlaces();
        int classAttIndex = this.m_Data.classIndex();
        StringBuilder result = new StringBuilder();
        result.append("LDA model (multivariate Gaussian for each class)\n\n");
        result.append("Pooled estimator\n\n" + this.m_Estimator + "\n\n");
        for (int i = 0; i < this.m_Data.numClasses(); ++i) {
            if (this.m_Means[i] == null) {
                continue;
            }
            result.append("Estimates for class value " + this.m_Data.classAttribute().value(i) + "\n\n");
            result.append("Natural logarithm of class prior probability: " + Utils.doubleToString(this.m_LogPriors[i], decimals) + "\n");
            result.append("Class prior probability: " + Utils.doubleToString(Math.exp(this.m_LogPriors[i]), decimals) + "\n\n");
            result.append("Mean vector:\n\n");
            int index = 0;
            for (int j = 0; j < this.m_Data.numAttributes(); ++j) {
                if (j == classAttIndex) {
                    continue;
                }
                result.append(this.m_Data.attribute(j).name() + ": " + Utils.doubleToString(this.m_Means[i][index], decimals) + "\n");
                ++index;
            }
            result.append("\n");
        }
        return result.toString();
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 10382 $");
    }

    /**
     * Runs the classifier from the command line.
     *
     * @param argv the command-line options
     */
    public static void main(String[] argv) {
        LDA.runClassifier(new LDA(), argv);
    }
}

