/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.functions;

import java.util.Collections;
import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.AbstractClassifier;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.estimators.MultivariateGaussianEstimator;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.RemoveUseless;

/**
 * Quadratic discriminant analysis (QDA) classifier.
 *
 * <p>For every class with positive total training weight, a
 * {@link MultivariateGaussianEstimator} (using the configured ridge) is fitted to
 * that class's instances. A log prior is taken from the class's share of the total
 * training weight. Prediction adds each class's log density and log prior and
 * normalizes the result with {@link Utils#logs2probs(double[])}.
 *
 * <p>A {@link RemoveUseless} filter is fitted on the training data and applied to
 * both training and test instances.
 *
 * <p>Valid options:
 * <pre>
 * -R &lt;num&gt;
 *  The ridge parameter.
 *  (default is 1e-6)
 * </pre>
 */
public class QDA
extends AbstractClassifier
implements WeightedInstancesHandler {
    static final long serialVersionUID = -9113383498193689291L;

    /** Empty header (structure only) of the filtered training data; set by buildClassifier. */
    protected Instances m_Data;

    /** One estimator per class; null for classes whose total training weight is not positive. */
    protected MultivariateGaussianEstimator[] m_Estimators;

    /** Natural log of each class's prior (share of total training weight); 0 for skipped classes. */
    protected double[] m_LogPriors;

    /** Ridge value passed to every per-class estimator. */
    protected double m_Ridge = 1.0E-6;

    /** Filter fitted on the training data and reused on each instance at prediction time. */
    protected RemoveUseless m_RemoveUseless;

    /**
     * Returns a short description of this classifier for the GUI.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Generates a QDA model. The covariance matrices are estimated using maximum likelihood from the per-class data.";
    }

    /**
     * Returns the tip text for the ridge property.
     *
     * @return the tip text
     */
    public String ridgeTipText() {
        return "The value of the ridge parameter.";
    }

    /**
     * Returns the ridge parameter handed to each per-class estimator.
     *
     * @return the ridge value
     */
    public double getRidge() {
        return this.m_Ridge;
    }

    /**
     * Sets the ridge parameter handed to each per-class estimator.
     *
     * @param newRidge the new ridge value
     */
    public void setRidge(double newRidge) {
        this.m_Ridge = newRidge;
    }

    /**
     * Lists the options of this classifier followed by those of the superclass.
     *
     * @return an enumeration of the available options
     */
    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> newVector = new Vector<Option>();
        // -R consumes a value (see setOptions), so it must be declared with one argument.
        newVector.addElement(new Option("\tThe ridge parameter.\n\t(default is 1e-6)", "R", 1, "-R <num>"));
        newVector.addAll(Collections.list(super.listOptions()));
        return newVector.elements();
    }

    /**
     * Parses the options. If {@code -R} is absent, the ridge is reset to 1e-6.
     *
     * @param options the option array; recognized options are consumed
     * @throws Exception if an option value cannot be parsed or options remain unconsumed
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        String ridgeString = Utils.getOption('R', options);
        if (ridgeString.length() != 0) {
            this.setRidge(Double.parseDouble(ridgeString));
        } else {
            this.setRidge(1.0E-6);
        }
        super.setOptions(options);
        Utils.checkForRemainingOptions(options);
    }

    /**
     * Returns the current settings, with {@code -R} first, followed by the superclass options.
     *
     * @return the option array
     */
    @Override
    public String[] getOptions() {
        Vector<String> options = new Vector<String>();
        options.add("-R");
        options.add(String.valueOf(this.getRidge()));
        Collections.addAll(options, super.getOptions());
        return options.toArray(new String[0]);
    }

    /**
     * Returns the capabilities: numeric attributes and a nominal class, which may
     * have missing values.
     *
     * @return the capabilities of this classifier
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.NOMINAL_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        result.setMinimumNumberInstances(0);
        return result;
    }

    /**
     * Builds the model.
     *
     * <p>The steps are:
     * <ol>
     * <li>Check capabilities.</li>
     * <li>Fit and apply RemoveUseless, then drop instances with a missing class.</li>
     * <li>Fit one weighted Gaussian estimator per class with positive total weight.</li>
     * <li>Compute log priors from the per-class weight totals.</li>
     * </ol>
     *
     * @param insts the training data
     * @throws Exception if the data fails the capability check or estimation fails
     */
    @Override
    public void buildClassifier(Instances insts) throws Exception {
        this.getCapabilities().testWithFail(insts);
        this.m_RemoveUseless = new RemoveUseless();
        this.m_RemoveUseless.setInputFormat(insts);
        insts = Filter.useFilter(insts, this.m_RemoveUseless);
        insts.deleteWithMissingClass();

        int numClasses = insts.numClasses();
        int numInstances = insts.numInstances();
        int classAttIndex = insts.classIndex();

        // First pass: instance counts (to size the arrays) and weight totals (for priors).
        int[] counts = new int[numClasses];
        double[] sumOfWeightsPerClass = new double[numClasses];
        for (int i = 0; i < numInstances; ++i) {
            Instance inst = insts.instance(i);
            int classIndex = (int) inst.classValue();
            counts[classIndex]++;
            sumOfWeightsPerClass[classIndex] += inst.weight();
        }

        // Second pass: copy each instance's non-class values and its weight into per-class arrays.
        double[][][] data = new double[numClasses][][];
        double[][] weights = new double[numClasses][];
        for (int c = 0; c < numClasses; ++c) {
            data[c] = new double[counts[c]][insts.numAttributes() - 1];
            weights[c] = new double[counts[c]];
        }
        int[] currentCount = new int[numClasses];
        for (int i = 0; i < numInstances; ++i) {
            Instance inst = insts.instance(i);
            int classIndex = (int) inst.classValue();
            int pos = currentCount[classIndex]++;
            weights[classIndex][pos] = inst.weight();
            double[] row = data[classIndex][pos];
            int index = 0;
            for (int j = 0; j < inst.numAttributes(); ++j) {
                if (j != classAttIndex) {
                    row[index++] = inst.value(j);
                }
            }
        }

        // Classes without positive total weight get no estimator (left null).
        this.m_Estimators = new MultivariateGaussianEstimator[numClasses];
        for (int c = 0; c < numClasses; ++c) {
            if (sumOfWeightsPerClass[c] > 0.0) {
                this.m_Estimators[c] = new MultivariateGaussianEstimator();
                this.m_Estimators[c].setRidge(this.getRidge());
                this.m_Estimators[c].estimate(data[c], weights[c]);
            }
        }

        this.m_LogPriors = new double[numClasses];
        double sumOfWeights = Utils.sum(sumOfWeightsPerClass);
        for (int c = 0; c < numClasses; ++c) {
            if (sumOfWeightsPerClass[c] > 0.0) {
                this.m_LogPriors[c] = Math.log(sumOfWeightsPerClass[c]) - Math.log(sumOfWeights);
            }
        }
        this.m_Data = new Instances(insts, 0);
    }

    /**
     * Computes the class membership probabilities for an instance.
     *
     * <p>Classes without an estimator get a log score of {@code -Double.MAX_VALUE}.
     *
     * @param inst the instance to classify
     * @return the normalized class distribution
     * @throws Exception if filtering or density evaluation fails
     */
    @Override
    public double[] distributionForInstance(Instance inst) throws Exception {
        this.m_RemoveUseless.input(inst);
        inst = this.m_RemoveUseless.output();
        // NOTE(review): missing attribute values are not checked here; whatever inst.value(i)
        // returns for them is passed to logDensity as-is — confirm this is intended.
        double[] values = new double[inst.numAttributes() - 1];
        int index = 0;
        for (int i = 0; i < this.m_Data.numAttributes(); ++i) {
            if (i != this.m_Data.classIndex()) {
                values[index++] = inst.value(i);
            }
        }
        double[] logPosteriors = new double[this.m_Data.numClasses()];
        for (int c = 0; c < logPosteriors.length; ++c) {
            logPosteriors[c] = this.m_Estimators[c] != null
                ? this.m_Estimators[c].logDensity(values) + this.m_LogPriors[c]
                : -Double.MAX_VALUE;
        }
        return Utils.logs2probs(logPosteriors);
    }

    /**
     * Describes the model: log prior, prior and estimator for each class that has one.
     *
     * @return the textual description, or a notice if no model has been built
     */
    @Override
    public String toString() {
        if (this.m_LogPriors == null) {
            return "No model has been built yet.";
        }
        StringBuilder result = new StringBuilder();
        result.append("QDA model (multivariate Gaussian for each class)\n\n");
        for (int i = 0; i < this.m_Data.numClasses(); ++i) {
            if (this.m_Estimators[i] == null) {
                continue;
            }
            result.append("Estimates for class " + this.m_Data.classAttribute().value(i) + "\n\n");
            result.append("Natural logarithm of class prior probability: " + Utils.doubleToString(this.m_LogPriors[i], this.getNumDecimalPlaces()) + "\n");
            result.append("Class prior probability: " + Utils.doubleToString(Math.exp(this.m_LogPriors[i]), this.getNumDecimalPlaces()) + "\n\n");
            result.append("Multivariate Gaussian estimator:\n\n" + this.m_Estimators[i] + "\n");
        }
        return result.toString();
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 10382 $");
    }

    /**
     * Runs the classifier from the command line.
     *
     * @param argv the command-line options
     */
    public static void main(String[] argv) {
        QDA.runClassifier(new QDA(), argv);
    }
}

