/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.functions;

import java.util.Collections;
import java.util.Enumeration;
import no.uib.cipr.matrix.DenseMatrix;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.Matrix;
import no.uib.cipr.matrix.UpperSymmDenseMatrix;
import no.uib.cipr.matrix.Vector;
import weka.classifiers.AbstractClassifier;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.RemoveUseless;

/**
 * Fisher's Linear Discriminant Analysis (FLDA) for two-class problems with numeric attributes.
 *
 * <p>Training works in these steps:
 * <ol>
 *   <li>Constant attributes are removed with {@link RemoveUseless}.</li>
 *   <li>The two class centroids {@code mu0} and {@code mu1} are computed, along with a pooled
 *       scatter matrix {@code S} built from the class-centered data. The ridge value
 *       ({@link #getRidge()}) is added to the diagonal of {@code S}.</li>
 *   <li>The weight vector is obtained by solving {@code S w = mu0 - mu1}. It is then rescaled
 *       to unit Euclidean length.</li>
 *   <li>The decision threshold is the projection of the midpoint between the two centroids, so
 *       the separator lies half-way between them.</li>
 * </ol>
 *
 * <p>Class probabilities are produced by a logistic sigmoid of {@code w.x - threshold}.
 */
public class FLDA extends AbstractClassifier {

    static final long serialVersionUID = -9212385698193681291L;

    /** Header (no rows) of the filtered training data; maps attribute positions to weights. */
    protected Instances m_Data;

    /**
     * Discriminant direction over the non-class attributes of {@link #m_Data}. It has unit
     * length, unless the solved direction was all zeros.
     */
    protected Vector m_Weights;

    /** Projection of the midpoint between the two class centroids onto {@link #m_Weights}. */
    protected double m_Threshold;

    /** Value added to every diagonal entry of the scatter matrix before solving. */
    protected double m_Ridge = 1.0E-6;

    /** Filter fitted on the training data that drops constant attributes. */
    protected RemoveUseless m_RemoveUseless;

    /**
     * Returns a description of this classifier for display in the Weka GUI.
     *
     * @return the description string
     */
    public String globalInfo() {
        return "Builds Fisher's Linear Discriminant function. The threshold is selected so that the separator is half-way between centroids. The class must be binary and all other attributes must be numeric. Missing values are not permitted. Constant attributes are removed using RemoveUseless. No standardization or normalization of attributes is performed.";
    }

    /**
     * Returns the capabilities of this classifier.
     *
     * <p>It accepts numeric attributes and a binary class. Missing class values are accepted;
     * those instances are dropped in {@link #buildClassifier}. No minimum number of training
     * instances is enforced.
     *
     * @return the capabilities of this classifier
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        // attributes
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        // class
        result.enable(Capabilities.Capability.BINARY_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        result.setMinimumNumberInstances(0);
        return result;
    }

    /**
     * Computes the centroid of each of the two classes over the non-class attributes.
     *
     * @param data   training data with a binary class and no missing class values
     * @param counts array of length 2; entry {@code c} is incremented once per instance of class
     *               {@code c}. {@link #buildClassifier} passes a freshly zeroed array.
     * @return the two class centroids, indexed by class value
     */
    protected Vector[] getClassMeans(Instances data, int[] counts) {
        int classAttr = data.classIndex();
        double[][] sums = new double[2][data.numAttributes() - 1];
        for (int i = 0; i < data.numInstances(); ++i) {
            Instance inst = data.instance(i);
            int cls = (int) inst.classValue();
            // The target row depends only on the instance's class, so look it up once.
            double[] row = sums[cls];
            int index = 0;
            for (int j = 0; j < data.numAttributes(); ++j) {
                if (j == classAttr) {
                    continue;
                }
                row[index++] += inst.value(j);
            }
            counts[cls]++;
        }
        DenseVector[] centroidVectors = new DenseVector[2];
        for (int c = 0; c < 2; ++c) {
            centroidVectors[c] = new DenseVector(sums[c]);
            // NOTE(review): if class c has no instances, counts[c] is 0 and this scales by
            // 1/0 = Infinity. That presumably turns the zero sums into NaN entries; confirm
            // whether an absent class should be handled explicitly.
            centroidVectors[c].scale(1.0 / (double) counts[c]);
        }
        if (this.m_Debug) {
            System.out.println("Count for class 0: " + counts[0]);
            System.out.println("Centroid 0:" + centroidVectors[0]);
            System.out.println("Count for class 1: " + counts[1]);
            System.out.println("Centroid 1:" + centroidVectors[1]);
        }
        return centroidVectors;
    }

    /**
     * Centers each training instance on its class centroid.
     *
     * @param data      training data with a binary class and no missing class values
     * @param counts    number of instances in each class, as filled in by {@link #getClassMeans}
     * @param centroids the class centroids returned by {@link #getClassMeans}
     * @return one matrix per class, with {@code data.numAttributes() - 1} rows and
     *         {@code counts[c]} columns. Column k holds the k-th instance of class c (in data
     *         order) minus that class's centroid.
     */
    protected Matrix[] getCenteredData(Instances data, int[] counts, Vector[] centroids) {
        int numFeatures = data.numAttributes() - 1;
        int classAttr = data.classIndex();
        Matrix[] centeredData = new Matrix[2];
        for (int c = 0; c < 2; ++c) {
            centeredData[c] = new DenseMatrix(numFeatures, counts[c]);
        }
        // Next free column in each class's matrix.
        int[] nextColumn = new int[2];
        for (int i = 0; i < data.numInstances(); ++i) {
            Instance inst = data.instance(i);
            int cls = (int) inst.classValue();
            int column = nextColumn[cls]++;
            int row = 0;
            for (int j = 0; j < data.numAttributes(); ++j) {
                if (j == classAttr) {
                    continue;
                }
                centeredData[cls].set(row, column, inst.value(j) - centroids[cls].get(row));
                ++row;
            }
        }
        if (this.m_Debug) {
            System.out.println("Centered data for class 0:\n" + centeredData[0]);
            System.out.println("Centered data for class 1:\n" + centeredData[1]);
        }
        return centeredData;
    }

    /**
     * Builds the discriminant from the given training data.
     *
     * @param insts training data: numeric attributes and a binary class
     * @throws Exception if the data does not satisfy {@link #getCapabilities()}, or if
     *                   filtering fails
     */
    @Override
    public void buildClassifier(Instances insts) throws Exception {
        // Can the classifier handle the data?
        this.getCapabilities().testWithFail(insts);

        // Drop constant attributes, then instances whose class value is missing.
        this.m_RemoveUseless = new RemoveUseless();
        this.m_RemoveUseless.setInputFormat(insts);
        insts = Filter.useFilter(insts, this.m_RemoveUseless);
        insts.deleteWithMissingClass();

        int[] classCounts = new int[2];
        Vector[] centroids = this.getClassMeans(insts, classCounts);
        // mu0 - mu1: the right-hand side of the discriminant system.
        Vector diff = centroids[0].copy().add(-1.0, centroids[1]);

        // Pooled scatter matrix, built from the centered data of both classes via rank1 updates.
        Matrix[] data = this.getCenteredData(insts, classCounts, centroids);
        Matrix scatter = new UpperSymmDenseMatrix(data[0].numRows()).rank1(data[0])
                .add(new UpperSymmDenseMatrix(data[1].numRows()).rank1(data[1]));
        // Add the ridge to the diagonal, presumably to keep the system well-conditioned
        // (e.g. with collinear attributes).
        for (int i = 0; i < scatter.numColumns(); ++i) {
            scatter.add(i, i, this.m_Ridge);
        }
        if (this.m_Debug) {
            System.out.println("Scatter:\n" + scatter);
        }

        // Solve scatter * w = (mu0 - mu1).
        this.m_Weights = scatter.solve(diff, new DenseVector(scatter.numColumns()));
        double norm = this.m_Weights.norm(Vector.Norm.Two);
        // Rescale to unit length. An all-zero direction (e.g. identical centroids) is left
        // unchanged: scaling it by 1/0 would turn every weight, and every prediction, into NaN.
        if (norm > 0.0) {
            this.m_Weights.scale(1.0 / norm);
        }
        // Threshold at the projection of the midpoint between the two centroids.
        this.m_Threshold = 0.5 * this.m_Weights.dot(centroids[0].copy().add(centroids[1]));
        this.m_Data = new Instances(insts, 0);
    }

    /**
     * Returns class-membership probabilities for the given instance.
     *
     * <p>The instance is first passed through the trained {@link RemoveUseless} filter. Then
     * {@code dist[1] = 1 / (1 + exp(w.x - threshold))} and {@code dist[0] = 1 - dist[1]}.
     * Because {@code w} solves {@code S w = mu0 - mu1}, projections above the threshold favour
     * class 0.
     *
     * @param inst the instance to classify
     * @return a two-element probability distribution, indexed by class value
     * @throws Exception if the filter cannot process the instance
     */
    @Override
    public double[] distributionForInstance(Instance inst) throws Exception {
        this.m_RemoveUseless.input(inst);
        inst = this.m_RemoveUseless.output();

        // Gather the non-class attribute values in the same order used for the weights.
        DenseVector instM = new DenseVector(inst.numAttributes() - 1);
        int index = 0;
        for (int i = 0; i < inst.numAttributes(); ++i) {
            if (i == this.m_Data.classIndex()) {
                continue;
            }
            instM.set(index++, inst.value(i));
        }

        // Pass the signed distance from the separator through a sigmoid.
        double[] dist = new double[2];
        dist[1] = 1.0 / (1.0 + Math.exp(instM.dot(this.m_Weights) - this.m_Threshold));
        dist[0] = 1.0 - dist[1];
        return dist;
    }

    /**
     * Returns a textual description of the model: the threshold and one weight per attribute.
     *
     * @return the model description, or a notice if no model has been built
     */
    @Override
    public String toString() {
        if (this.m_Weights == null) {
            return "No model has been built yet.";
        }
        StringBuilder result = new StringBuilder();
        result.append("Fisher's Linear Discriminant Analysis\n\n");
        result.append("Threshold: ").append(this.m_Threshold).append("\n\n");
        result.append("Weights:\n\n");
        int index = 0;
        for (int i = 0; i < this.m_Data.numAttributes(); ++i) {
            if (i == this.m_Data.classIndex()) {
                continue;
            }
            double weight = this.m_Weights.get(index++);
            result.append(this.m_Data.attribute(i).name()).append(": \t");
            // Pad non-negative weights so they line up with negative ones.
            if (weight >= 0.0) {
                result.append(' ');
            }
            result.append(weight).append('\n');
        }
        return result.toString();
    }

    /**
     * Returns the tip text for the ridge property, for display in the Weka GUI.
     *
     * @return the tip text
     */
    public String ridgeTipText() {
        return "The value of the ridge parameter.";
    }

    /**
     * Returns the value added to the diagonal of the scatter matrix.
     *
     * @return the ridge value
     */
    public double getRidge() {
        return this.m_Ridge;
    }

    /**
     * Sets the value added to the diagonal of the scatter matrix.
     *
     * @param newRidge the new ridge value
     */
    public void setRidge(double newRidge) {
        this.m_Ridge = newRidge;
    }

    /**
     * Lists the available options: {@code -R <ridge>}, followed by the superclass options.
     *
     * @return an enumeration of the options
     */
    @Override
    public Enumeration<Option> listOptions() {
        java.util.Vector<Option> newVector = new java.util.Vector<Option>();
        // -R takes a value, so it declares one argument and shows it in the synopsis.
        newVector.addElement(new Option("\tThe ridge parameter.\n\t(default is 1e-6)", "R", 1, "-R <ridge>"));
        newVector.addAll(Collections.list(super.listOptions()));
        return newVector.elements();
    }

    /**
     * Parses the options.
     *
     * <p>{@code -R <ridge>} sets the ridge; if it is absent, the ridge is reset to 1e-6. The
     * remaining options are passed to the superclass, and any unrecognized leftovers cause an
     * error.
     *
     * @param options the option array; recognized entries are consumed
     * @throws Exception if an option is malformed or unrecognized
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        String ridgeString = Utils.getOption('R', options);
        if (ridgeString.length() != 0) {
            this.setRidge(Double.parseDouble(ridgeString));
        } else {
            this.setRidge(1.0E-6);
        }
        super.setOptions(options);
        Utils.checkForRemainingOptions(options);
    }

    /**
     * Returns the current option settings: {@code -R <ridge>} followed by the superclass
     * options.
     *
     * @return the current options
     */
    @Override
    public String[] getOptions() {
        java.util.Vector<String> options = new java.util.Vector<String>();
        options.add("-R");
        options.add("" + this.getRidge());
        Collections.addAll(options, super.getOptions());
        return options.toArray(new String[0]);
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 10382 $");
    }

    /**
     * Command-line entry point; runs this classifier through Weka's standard runner.
     *
     * @param argv the command-line options
     */
    public static void main(String[] argv) {
        FLDA.runClassifier(new FLDA(), argv);
    }
}

