/*
 * Decompiled with CFR 0.152.
 */
package weka.filters.supervised.attribute;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import no.uib.cipr.matrix.DenseMatrix;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.Matrices;
import no.uib.cipr.matrix.Matrix;
import no.uib.cipr.matrix.SymmDenseEVD;
import no.uib.cipr.matrix.UpperSymmDenseMatrix;
import no.uib.cipr.matrix.Vector;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.filters.SimpleBatchFilter;

/**
 * Fisher's linear discriminant analysis (FLDA) for supervised dimensionality
 * reduction with a nominal class.
 * <p>
 * The filter estimates its transformation from the full input format:
 * <ol>
 * <li>It computes the pooled within-class scatter matrix Cw and adds the ridge
 * value to its diagonal.</li>
 * <li>It computes the between-class scatter matrix Cb.</li>
 * <li>It forms the symmetric matrix Cw^(-1/2) * Cb * Cw^(-1/2).</li>
 * <li>It keeps the eigenvectors whose eigenvalues are greater than zero
 * according to {@link Utils#gr}.</li>
 * </ol>
 * Each instance is then projected onto Cw^(-1/2) times those eigenvectors.
 * The output has new numeric attributes {@code z1, z2, ...} followed by a copy
 * of the class attribute.
 * <p>
 * Instance weights are used in all estimates. Instances with a missing class
 * value are ignored when estimating the transformation.
 */
public class MultiClassFLDA
extends SimpleBatchFilter
implements OptionHandler,
WeightedInstancesHandler {
    static final long serialVersionUID = -291536442147283133L;

    /** Default value of the ridge parameter. */
    private static final double DEFAULT_RIDGE = 1.0E-6;

    /**
     * Projection matrix computed in {@link #determineOutputFormat(Instances)}.
     * It has one row per retained discriminant and one column per non-class
     * input attribute.
     */
    protected Matrix m_WeightingMatrix;

    /** Value added to the diagonal of the pooled within-class scatter matrix. */
    protected double m_Ridge = DEFAULT_RIDGE;

    /**
     * Returns the capabilities of this filter.
     *
     * @return numeric attributes and a nominal class, with missing class values
     *         allowed and no minimum number of instances
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = new Capabilities(this);
        result.disableAll();
        result.setMinimumNumberInstances(0);
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.NOMINAL_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        return result;
    }

    /**
     * Returns a short description of the filter for the GUI.
     *
     * @return the description text
     */
    @Override
    public String globalInfo() {
        return "Implements Fisher's linear discriminant analysis for dimensionality reduction. Note that this implementation adds the value of the ridge parameter to the diagonal of the pooled within-class scatter matrix.";
    }

    /**
     * Returns the tip text for the ridge property.
     *
     * @return the tip text
     */
    public String ridgeTipText() {
        return "The ridge parameter to add to the diagonal of the pooled within-class scatter matrix.";
    }

    /**
     * Returns the ridge parameter.
     *
     * @return the value added to the diagonal of the within-class scatter matrix
     */
    public double getRidge() {
        return this.m_Ridge;
    }

    /**
     * Sets the ridge parameter.
     *
     * @param newRidge the value to add to the diagonal of the within-class scatter matrix
     */
    public void setRidge(double newRidge) {
        this.m_Ridge = newRidge;
    }

    /**
     * Lists the command-line options of this filter and of its superclass.
     *
     * @return an enumeration of the options
     */
    @Override
    public Enumeration<Option> listOptions() {
        java.util.Vector<Option> newVector = new java.util.Vector<Option>(7);
        // -R takes a numeric argument, so declare one argument and show it in the synopsis.
        newVector.addElement(new Option("\tThe ridge parameter to add to the diagonal of the pooled within-class scatter matrix.\n\t(default is 1e-6)", "R", 1, "-R <num>"));
        newVector.addAll(Collections.list(super.listOptions()));
        return newVector.elements();
    }

    /**
     * Parses the options.
     * <p>
     * {@code -R <num>} sets the ridge; if absent, the ridge is reset to its
     * default. The superclass options are parsed next, and any leftover
     * options cause an exception.
     *
     * @param options the option array, which is consumed as it is parsed
     * @throws Exception if an option is invalid or left unparsed
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        String ridgeString = Utils.getOption('R', options);
        if (ridgeString.length() != 0) {
            this.setRidge(Double.parseDouble(ridgeString));
        } else {
            this.setRidge(DEFAULT_RIDGE);
        }
        super.setOptions(options);
        Utils.checkForRemainingOptions(options);
    }

    /**
     * Returns the current settings as an option array.
     *
     * @return {@code -R <ridge>} followed by the superclass options
     */
    @Override
    public String[] getOptions() {
        java.util.Vector<String> options = new java.util.Vector<String>();
        options.add("-R");
        options.add(String.valueOf(this.getRidge()));
        Collections.addAll(options, super.getOptions());
        return options.toArray(new String[0]);
    }

    /**
     * Indicates that {@link #determineOutputFormat(Instances)} receives the
     * full training data.
     *
     * @return always {@code true}
     */
    @Override
    public boolean allowAccessToFullInputFormat() {
        return true;
    }

    /**
     * Computes the weighted mean of the non-class attributes.
     * <p>
     * Only instances of {@code data} whose class value is not missing are
     * included.
     *
     * @param data the instances to average; the class index must be set
     * @param totalWeight output array; {@code totalWeight[aI]} receives the sum
     *        of the weights of the instances that were used
     * @param aI the slot of {@code totalWeight} to write
     * @return the weighted mean vector (length {@code numAttributes() - 1}),
     *         or a zero vector if the total weight is zero
     */
    protected Vector computeMean(Instances data, double[] totalWeight, int aI) {
        DenseVector meanVector = new DenseVector(data.numAttributes() - 1);
        double sumOfWeights = 0.0;
        for (Instance inst : data) {
            if (inst.classIsMissing()) {
                continue;
            }
            meanVector.add(inst.weight(), this.instanceToVector(inst));
            sumOfWeights += inst.weight();
        }
        totalWeight[aI] = sumOfWeights;
        // An empty subset (e.g. a class with no training instances) would
        // otherwise produce 0/0 = NaN in every entry.
        if (sumOfWeights != 0.0) {
            meanVector.scale(1.0 / sumOfWeights);
        }
        return meanVector;
    }

    /**
     * Copies the non-class attribute values of an instance into a dense
     * vector, preserving attribute order.
     *
     * @param inst the instance; its class index must be set
     * @return a vector of length {@code inst.numAttributes() - 1}
     */
    protected Vector instanceToVector(Instance inst) {
        DenseVector v = new DenseVector(inst.numAttributes() - 1);
        int index = 0;
        for (int i = 0; i < inst.numAttributes(); ++i) {
            if (i == inst.classIndex()) {
                continue;
            }
            v.set(index++, inst.value(i));
        }
        return v;
    }

    /**
     * Estimates the FLDA projection from the full training data.
     * <p>
     * The projection is stored in {@link #m_WeightingMatrix}, and the method
     * builds the output format: attributes {@code z1..zk}, then a copy of the
     * class attribute, which becomes the class.
     *
     * @param inputFormat the full training data; the class index must be set
     * @return the (empty) output format
     * @throws Exception if an eigendecomposition fails to converge
     * @throws IllegalArgumentException if the regularized within-class scatter
     *         matrix has a non-positive eigenvalue
     */
    @Override
    protected Instances determineOutputFormat(Instances inputFormat) throws Exception {
        int m = inputFormat.numAttributes() - 1;
        int numClasses = inputFormat.numClasses();

        // Weighted global mean and per-class means; instances with a missing class are skipped.
        double[] totalWeight = new double[1];
        Vector globalMean = this.computeMean(inputFormat, totalWeight, 0);
        Instances[] subsets = splitByClass(inputFormat);
        Vector[] perClassMeans = new Vector[numClasses];
        double[] perClassWeights = new double[numClasses];
        for (int i = 0; i < numClasses; ++i) {
            perClassMeans[i] = this.computeMean(subsets[i], perClassWeights, i);
        }

        UpperSymmDenseMatrix Cw = this.regularizedWithinClassScatter(inputFormat, perClassMeans, m);
        UpperSymmDenseMatrix Cb = betweenClassScatter(perClassMeans, perClassWeights, globalMean, m);
        if (this.m_Debug) {
            System.err.println("Within-class scatter matrix :\n" + Cw);
            System.err.println("Between-class scatter matrix :\n" + Cb);
        }

        Matrix sqrtCwInverse = this.inverseSquareRoot(Cw, m);

        // Whitened between-class scatter: Cw^(-1/2) * Cb * Cw^(-1/2).
        Matrix temp = sqrtCwInverse.mult(Cb, new DenseMatrix(m, m));
        Matrix symmMatrix = temp.mult(sqrtCwInverse, new UpperSymmDenseMatrix(m));
        if (this.m_Debug) {
            System.err.println("Symmetric matrix : \n" + symmMatrix);
        }
        SymmDenseEVD evd = SymmDenseEVD.factorize(symmMatrix);
        DenseMatrix eigenvectors = evd.getEigenvectors();
        double[] eigenvalues = evd.getEigenvalues();
        if (this.m_Debug) {
            System.err.println("Eigenvectors of symmetric matrix :\n" + eigenvectors);
            System.err.println("Eigenvalues of symmetric matrix :\n" + Utils.arrayToString(eigenvalues) + "\n");
        }

        int[] cols = positiveEigenvalueColumns(eigenvalues);
        int[] rows = new int[eigenvectors.numRows()];
        for (int i = 0; i < rows.length; ++i) {
            rows[i] = i;
        }
        Matrix reducedMatrix = Matrices.getSubMatrix(eigenvectors, rows, cols);
        if (this.m_Debug) {
            System.err.println("Eigenvectors with eigenvalues > eps :\n" + reducedMatrix);
        }

        // Weighting matrix = (Cw^(-1/2) * selected eigenvectors)^T, size cols.length x m.
        this.m_WeightingMatrix = sqrtCwInverse.mult(reducedMatrix, new DenseMatrix(rows.length, cols.length)).transpose(new DenseMatrix(cols.length, rows.length));
        if (this.m_Debug) {
            printEntries("Weighting matrix: \n", this.m_WeightingMatrix);
        }

        ArrayList<Attribute> atts = new ArrayList<Attribute>(cols.length + 1);
        for (int i = 0; i < cols.length; ++i) {
            atts.add(new Attribute("z" + (i + 1)));
        }
        atts.add((Attribute) inputFormat.classAttribute().copy());
        Instances d = new Instances(inputFormat.relationName(), atts, 0);
        d.setClassIndex(d.numAttributes() - 1);
        return d;
    }

    /**
     * Splits the instances with a non-missing class into one subset per class
     * value.
     */
    private static Instances[] splitByClass(Instances data) {
        Instances[] subsets = new Instances[data.numClasses()];
        for (int j = 0; j < subsets.length; ++j) {
            subsets[j] = new Instances(data, data.numInstances());
        }
        for (Instance inst : data) {
            if (!inst.classIsMissing()) {
                subsets[(int) inst.classValue()].add(inst);
            }
        }
        return subsets;
    }

    /**
     * Builds the weighted pooled within-class scatter matrix and adds
     * {@link #m_Ridge} to its diagonal.
     */
    private UpperSymmDenseMatrix regularizedWithinClassScatter(Instances data, Vector[] perClassMeans, int m) {
        UpperSymmDenseMatrix scatter = new UpperSymmDenseMatrix(m);
        for (Instance inst : data) {
            if (inst.classIsMissing()) {
                continue;
            }
            Vector diff = this.instanceToVector(inst).add(-1.0, perClassMeans[(int) inst.classValue()]);
            // rank1 updates the matrix in place (A = alpha * x * x^T + A).
            scatter.rank1(inst.weight(), diff);
        }
        for (int i = 0; i < scatter.numColumns(); ++i) {
            scatter.add(i, i, this.m_Ridge);
        }
        return scatter;
    }

    /**
     * Builds the between-class scatter matrix, weighting each class by its
     * total instance weight.
     */
    private static UpperSymmDenseMatrix betweenClassScatter(Vector[] perClassMeans, double[] perClassWeights, Vector globalMean, int m) {
        UpperSymmDenseMatrix scatter = new UpperSymmDenseMatrix(m);
        for (int i = 0; i < perClassMeans.length; ++i) {
            // A class with zero weight contributes nothing.
            if (perClassWeights[i] == 0.0) {
                continue;
            }
            Vector diff = perClassMeans[i].copy().add(-1.0, globalMean);
            scatter.rank1(perClassWeights[i], diff);
        }
        return scatter;
    }

    /**
     * Computes Cw^(-1/2) = V * D * V^T from the eigendecomposition of Cw, where
     * D holds the reciprocal square roots of the eigenvalues.
     */
    private Matrix inverseSquareRoot(UpperSymmDenseMatrix Cw, int m) throws Exception {
        SymmDenseEVD evdCw = SymmDenseEVD.factorize(Cw);
        DenseMatrix evCw = evdCw.getEigenvectors();
        double[] evs = evdCw.getEigenvalues();
        UpperSymmDenseMatrix D = new UpperSymmDenseMatrix(evs.length);
        for (int i = 0; i < evs.length; ++i) {
            // Negated comparison so that NaN eigenvalues are rejected as well.
            if (!(evs[i] > 0.0)) {
                throw new IllegalArgumentException("Found non-positive eigenvalue of within-class scatter matrix.");
            }
            D.set(i, i, Math.sqrt(1.0 / evs[i]));
        }
        if (this.m_Debug) {
            System.err.println("evCw : \n" + evCw);
            System.err.println("Sqrt of reciprocal of eigenvalues of Cw: \n" + D);
            System.err.println("evCw times evCwTransposed : \n" + evCw.mult(evCw.transpose(new DenseMatrix(m, m)), new DenseMatrix(m, m)));
        }
        Matrix temp = evCw.mult(D, new DenseMatrix(m, m));
        // transBmult computes temp * evCw^T without transposing evCw in place.
        Matrix sqrtCwInverse = temp.transBmult(evCw, new UpperSymmDenseMatrix(m));
        if (this.m_Debug) {
            printEntries("sqrtCwInverse : \n", sqrtCwInverse);
            System.err.println("sqrtCwInverse times sqrtCwInverse : \n" + sqrtCwInverse.mult(sqrtCwInverse, new DenseMatrix(m, m)));
            DenseMatrix I = Matrices.identity(m);
            Matrix CwInverse = I.copy();
            System.err.println("CwInverse : \n" + Cw.solve(I, CwInverse));
        }
        return sqrtCwInverse;
    }

    /**
     * Returns the indices of eigenvalues greater than zero, as judged by
     * {@link Utils#gr}, ordered from the highest index to the lowest.
     */
    private static int[] positiveEigenvalueColumns(double[] eigenvalues) {
        ArrayList<Integer> indices = new ArrayList<Integer>();
        for (int i = 0; i < eigenvalues.length; ++i) {
            if (Utils.gr(eigenvalues[i], 0.0)) {
                indices.add(i);
            }
        }
        int[] cols = new int[indices.size()];
        for (int k = 0; k < cols.length; ++k) {
            cols[k] = indices.get(cols.length - 1 - k);
        }
        return cols;
    }

    /**
     * Prints a header line, then the matrix entries to standard error,
     * tab-separated, one row per line.
     */
    private static void printEntries(String header, Matrix matrix) {
        System.err.println(header);
        for (int i = 0; i < matrix.numRows(); ++i) {
            for (int j = 0; j < matrix.numColumns(); ++j) {
                System.err.print(matrix.get(i, j) + "\t");
            }
            System.err.println();
        }
    }

    /**
     * Projects every instance with {@link #m_WeightingMatrix} and copies its
     * class value and weight.
     *
     * @param instances the instances to transform
     * @return the transformed instances in the output format
     * @throws Exception if the output format is not available
     */
    @Override
    protected Instances process(Instances instances) throws Exception {
        Instances transformed = this.getOutputFormat();
        int numProjected = this.m_WeightingMatrix.numRows();
        for (Instance inst : instances) {
            Vector projected = this.m_WeightingMatrix.mult(this.instanceToVector(inst), new DenseVector(numProjected));
            double[] newVals = new double[numProjected + 1];
            for (int i = 0; i < numProjected; ++i) {
                newVals[i] = projected.get(i);
            }
            newVals[transformed.classIndex()] = inst.classValue();
            transformed.add(new DenseInstance(inst.weight(), newVals));
        }
        return transformed;
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 12037 $");
    }

    /**
     * Runs the filter from the command line.
     *
     * @param argv the command-line options
     */
    public static void main(String[] argv) {
        MultiClassFLDA.runFilter(new MultiClassFLDA(), argv);
    }
}

