/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.functions;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.Random;
import no.uib.cipr.matrix.DenseMatrix;
import no.uib.cipr.matrix.EVD;
import no.uib.cipr.matrix.Matrices;
import no.uib.cipr.matrix.Matrix;
import no.uib.cipr.matrix.UpperSPDDenseMatrix;
import weka.classifiers.RandomizableClassifier;
import weka.classifiers.functions.supportVector.Kernel;
import weka.classifiers.functions.supportVector.RBFKernel;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.OptionMetadata;
import weka.core.RevisionUtils;
import weka.core.SerializedObject;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.Nystroem;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
import weka.filters.unsupervised.attribute.Standardize;
import weka.filters.unsupervised.instance.RemoveRange;

/**
 * Semi-supervised regression with XNV (McWilliams, Balduzzi and Buhmann, 2013).
 * <p>
 * Two Nystroem feature maps ("views") are built from two disjoint random samples of the
 * pre-processed training data. Canonical correlation analysis (CCA) between the two views yields
 * a basis ({@code m_B1}) for the first view. A ridge regression with an additional CCA-based
 * penalty is then fitted, in that basis, on the labeled training instances ({@code m_wCCA}).
 * Instances with a missing class value take part in the CCA step but not in the final regression.
 */
public class XNV
extends RandomizableClassifier
implements TechnicalInformationHandler {
    static final long serialVersionUID = -1585383626378691736L;

    /** Ridge added to the diagonal of each view's covariance matrix so it can be inverted. */
    private static final double COVARIANCE_RIDGE = 1.0E-8;

    /** Nystroem feature map built on the first sample (view 1); also used at prediction time. */
    protected Nystroem m_N1;
    /** Nystroem feature map built on the second, disjoint sample (view 2); only used for CCA. */
    protected Nystroem m_N2;
    /** Regression weights (beta) on the kept canonical components, a column vector. */
    protected Matrix m_wCCA;
    /** Canonical directions of view 1, restricted to components with a positive eigenvalue. */
    protected Matrix m_B1;
    /** Standardization filter (applied to the target as well); unused when standardization is off. */
    protected Standardize m_Standardize;
    /** NominalToBinary filter fitted on the training data. */
    protected NominalToBinary m_NominalToBinary;
    /** ReplaceMissingValues filter fitted on the training data. */
    protected ReplaceMissingValues m_Missing;
    /** Slope mapping standardized target values back to the original target scale. */
    protected double m_x1 = 1.0;
    /** Intercept mapping standardized target values back to the original target scale. */
    protected double m_x0 = 0.0;
    /** Requested Nystroem sample size; capped at half the number of training instances. */
    protected int m_M = 100;
    /** Kernel used by both Nystroem feature maps (each map receives its own deep copy). */
    protected Kernel m_Kernel = new RBFKernel();
    /** Ridge regularization parameter gamma. */
    protected double m_Gamma = 0.01;
    /** Number of labeled training instances used in the regression step. */
    protected int m_numLabeled;
    /** If true, no standardization is applied to attributes or target. */
    protected boolean m_doNotStandardize;

    /**
     * Returns a description of this classifier for the GUI.
     *
     * @return the description, including the bibliographic reference
     */
    public String globalInfo() {
        return "Implements the XNV method for semi-supervised learning using a kernel function (default: RBFKernel). Standardizes all attributes, including the target, by default. Applies (unsupervised) NominalToBinary and ReplaceMissingValues before anything else is done.\n\nFor more information on the algorithm, see\n\n" + this.getTechnicalInformation().toString();
    }

    /**
     * Returns the bibliographic reference for the XNV paper.
     *
     * @return the technical information
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Brian McWilliams and David Balduzzi and Joachim M. Buhmann");
        result.setValue(TechnicalInformation.Field.TITLE, "Correlated random features for fast semi-supervised learning");
        result.setValue(TechnicalInformation.Field.BOOKTITLE, "Proc 27th Annual Conference on Neural Information Processing Systems");
        result.setValue(TechnicalInformation.Field.PAGES, "440--448");
        result.setValue(TechnicalInformation.Field.YEAR, "2013");
        result.setValue(TechnicalInformation.Field.URL, "http://papers.nips.cc/paper/5000-correlated-random-features-for-fast-semi-supervised-learning.pdf");
        return result;
    }

    @OptionMetadata(displayName="Regularization parameter gamma", description="The regularization parameter gamma.", displayOrder=1, commandLineParamName="G", commandLineParamSynopsis="-G")
    public double getGamma() {
        return this.m_Gamma;
    }

    public void setGamma(double v) {
        this.m_Gamma = v;
    }

    @OptionMetadata(displayName="Sample size for Nystroem method", description="The sample size for the Nystroem method.", displayOrder=2, commandLineParamName="M", commandLineParamSynopsis="-M")
    public int getM() {
        return this.m_M;
    }

    public void setM(int v) {
        this.m_M = v;
    }

    @OptionMetadata(displayName="Kernel function", description="The kernel function to use.", displayOrder=3, commandLineParamName="K", commandLineParamSynopsis="-K <kernel specification>")
    public void setKernel(Kernel kernel) {
        this.m_Kernel = kernel;
    }

    public Kernel getKernel() {
        return this.m_Kernel;
    }

    @OptionMetadata(displayName="Do not apply standardization", description="If true, standardization will not be performed.", displayOrder=4, commandLineParamName="S", commandLineParamSynopsis="-S")
    public boolean getDoNotStandardize() {
        return this.m_doNotStandardize;
    }

    public void setDoNotStandardize(boolean v) {
        this.m_doNotStandardize = v;
    }

    /**
     * Copies the non-class attribute values of a dataset into a dense matrix.
     *
     * @param data the data; the class attribute (if set) is skipped
     * @param center if true, each attribute's {@code meanOrMode} is subtracted from its values
     * @param transpose if true the result is features x instances, otherwise instances x features
     * @return the matrix of attribute values
     */
    public static DenseMatrix getMatrix(Instances data, boolean center, boolean transpose) {
        int classIndex = data.classIndex();
        int numAttributes = data.numAttributes();
        int numInstances = data.numInstances();
        int numFeatures = numAttributes - (classIndex >= 0 ? 1 : 0);

        double[] means = new double[numAttributes];
        if (center) {
            for (int j = 0; j < numAttributes; ++j) {
                if (j != classIndex) {
                    means[j] = data.meanOrMode(j);
                }
            }
        }

        // Both dimensions are computed explicitly (the previous form left the column count
        // unassigned on the transpose path, which does not compile).
        DenseMatrix X = transpose
            ? new DenseMatrix(numFeatures, numInstances)
            : new DenseMatrix(numInstances, numFeatures);
        for (int i = 0; i < numInstances; ++i) {
            Instance inst = data.instance(i);
            int feature = 0;
            for (int j = 0; j < numAttributes; ++j) {
                if (j == classIndex) {
                    continue;
                }
                double value = inst.value(j) - means[j];
                if (transpose) {
                    X.set(feature, i, value);
                } else {
                    X.set(i, feature, value);
                }
                ++feature;
            }
        }
        return X;
    }

    /**
     * Inverts a square matrix by solving {@code M * X = I}.
     *
     * @param M the square matrix to invert
     * @return a new dense matrix holding the inverse
     * @throws IllegalArgumentException if {@code M} is not square
     * @throws Exception declared for backward compatibility with existing callers
     */
    public static Matrix inverse(Matrix M) throws Exception {
        if (M.numRows() != M.numColumns()) {
            throw new IllegalArgumentException("Matrix is not square: cannot invert it.");
        }
        int n = M.numRows();
        DenseMatrix Minv = new DenseMatrix(n, n);
        M.solve(Matrices.identity(n), Minv);
        // Minv is freshly allocated and not shared, so no defensive copy is needed.
        return Minv;
    }

    /**
     * Performs CCA between two views. It eigen-decomposes
     * {@code C11^-1 * C12 * C22^-1 * C21}, whose eigenvalues are the squared canonical
     * correlations. Covariances are scaled by {@code 1/(N-1)}, and a small ridge is added to
     * {@code C11} and {@code C22}.
     *
     * @param X1 view 1, one row per feature and one column per instance (as produced by
     *           {@code getMatrix(..., true, true)}); presumably already centred, TODO confirm for other callers
     * @param X2 view 2, same layout and number of rows as {@code X1}
     * @return the eigen-decomposition; its right eigenvectors are the canonical directions of view 1
     * @throws Exception if a matrix inversion or the eigen-decomposition fails
     */
    public static EVD CCA(Matrix X1, Matrix X2) throws Exception {
        int M = X1.numRows();
        double scale = 1.0 / ((double) X1.numColumns() - 1.0);

        UpperSPDDenseMatrix CX1X1 = XNV.regularizedCovariance(X1, M, scale);
        UpperSPDDenseMatrix CX2X2 = XNV.regularizedCovariance(X2, M, scale);

        Matrix CX1X2 = X1.transBmult(X2, new DenseMatrix(M, M));
        CX1X2.scale(scale);
        Matrix CX2X1 = CX1X2.transpose(new DenseMatrix(M, M));

        Matrix left = XNV.inverse(CX1X1).mult(CX1X2, new DenseMatrix(M, M));
        Matrix right = XNV.inverse(CX2X2).mult(CX2X1, new DenseMatrix(M, M));
        Matrix product = left.mult(right, new DenseMatrix(M, M));
        return EVD.factorize(product);
    }

    /**
     * Computes {@code scale * X * X^T} and adds {@link #COVARIANCE_RIDGE} to its diagonal.
     *
     * @param X the feature-by-instance matrix
     * @param M the dimension of the (square) covariance matrix
     * @param scale the normalization factor applied before the ridge is added
     * @return the regularized covariance matrix
     */
    private static UpperSPDDenseMatrix regularizedCovariance(Matrix X, int M, double scale) {
        UpperSPDDenseMatrix C = (UpperSPDDenseMatrix) new UpperSPDDenseMatrix(M).rank1(X);
        C.scale(scale);
        for (int i = 0; i < M; ++i) {
            C.set(i, i, C.get(i, i) + COVARIANCE_RIDGE);
        }
        return C;
    }

    /**
     * Builds the XNV model.
     *
     * @param data the training data; instances with a missing class value are used for CCA only
     * @throws Exception if the data is unsuitable or a numerical step fails
     */
    @Override
    public void buildClassifier(Instances data) throws Exception {
        this.getCapabilities().testWithFail(data);

        // Fit the pre-processing filters on the training data. The trained filters are reused,
        // never re-fitted, at prediction time.
        this.m_Missing = new ReplaceMissingValues();
        this.m_Missing.setInputFormat(data);
        data = Filter.useFilter(data, this.m_Missing);
        this.m_NominalToBinary = new NominalToBinary();
        this.m_NominalToBinary.setInputFormat(data);
        data = Filter.useFilter(data, this.m_NominalToBinary);

        // Shuffle a private copy so that the two contiguous index ranges sampled below are random.
        data = new Instances(data);
        data.randomize(new Random(this.getSeed()));

        if (!this.getDoNotStandardize()) {
            data = this.fitStandardization(data);
        }

        int sampleSize = Math.min(this.m_M, data.numInstances() / 2);
        if (sampleSize < 1) {
            throw new Exception("XNV: Nystroem sample size must be at least 1; check the M parameter and provide at least two training instances.");
        }

        // Two views from disjoint samples: instances 1..M and M+1..2M (1-based indices).
        this.m_N1 = this.createNystroem("first-" + sampleSize, data);
        Instances N1data = Filter.useFilter(data, this.m_N1);
        this.m_N2 = this.createNystroem((sampleSize + 1) + "-" + (2 * sampleSize), data);
        // The view matrices are passed inline so they become unreachable once CCA returns.
        EVD evd = XNV.CCA(
            XNV.getMatrix(N1data, true, true),
            XNV.getMatrix(Filter.useFilter(data, this.m_N2), true, true));

        double[] correlations = this.keepPositiveComponents(evd);
        int numComponents = correlations.length;

        // The regression step only uses labeled instances of view 1.
        Instances labeledN1 = new Instances(N1data, N1data.numInstances());
        for (Instance inst : N1data) {
            if (!inst.classIsMissing()) {
                labeledN1.add(inst);
            }
        }
        this.m_numLabeled = labeledN1.numInstances();
        DenseMatrix labels = new DenseMatrix(this.m_numLabeled, 1);
        for (int i = 0; i < this.m_numLabeled; ++i) {
            labels.set(i, 0, labeledN1.instance(i).classValue());
        }

        // Z: labeled instances projected onto the kept canonical directions. They are left
        // uncentred, matching what is done at prediction time.
        Matrix Z = XNV.getMatrix(labeledN1, false, false)
            .mult(this.m_B1, new DenseMatrix(this.m_numLabeled, numComponents));

        // CCA_reg penalizes weakly correlated directions more ((1 - c) / c grows as c -> 0);
        // reg is the plain ridge term gamma * I.
        DenseMatrix CCA_reg = new DenseMatrix(numComponents, numComponents);
        DenseMatrix reg = new DenseMatrix(numComponents, numComponents);
        for (int i = 0; i < numComponents; ++i) {
            CCA_reg.set(i, i, (1.0 - correlations[i]) / correlations[i]);
            reg.set(i, i, this.m_Gamma);
        }

        // beta = (Z'Z + CCA_reg + gamma * I)^-1 * Z' * y
        Matrix inv = XNV.inverse(Z.transAmult(Z, new DenseMatrix(numComponents, numComponents)).add(CCA_reg).add(reg));
        this.m_wCCA = inv.transBmult(Z, new DenseMatrix(Z.numColumns(), Z.numRows()))
            .mult(labels, new DenseMatrix(numComponents, 1));
    }

    /**
     * Fits {@link #m_Standardize} (including the target) and records the affine map back to the
     * original target scale in {@link #m_x1} and {@link #m_x0}.
     *
     * @param data the (shuffled, pre-processed) training data
     * @return the standardized data
     * @throws Exception if there are not two labeled instances with different target values
     */
    private Instances fitStandardization(Instances data) throws Exception {
        int n = data.numInstances();
        int index0 = 0;
        while (index0 < n && data.instance(index0).classIsMissing()) {
            ++index0;
        }
        if (index0 >= n) {
            throw new Exception("Need at least two instances with different target values.");
        }
        double y0 = data.instance(index0).classValue();
        int index1 = index0 + 1;
        while (index1 < n && (data.instance(index1).classIsMissing() || data.instance(index1).classValue() == y0)) {
            ++index1;
        }
        if (index1 >= n) {
            throw new Exception("Need at least two instances with different target values.");
        }
        double y1 = data.instance(index1).classValue();

        this.m_Standardize = new Standardize();
        this.m_Standardize.setIgnoreClass(true);
        this.m_Standardize.setInputFormat(data);
        Instances standardized = Filter.useFilter(data, this.m_Standardize);

        // Two distinct (original, standardized) target pairs determine the affine back-transform.
        double z0 = standardized.instance(index0).classValue();
        double z1 = standardized.instance(index1).classValue();
        this.m_x1 = (y0 - y1) / (z0 - z1);
        this.m_x0 = y0 - this.m_x1 * z0;
        return standardized;
    }

    /**
     * Creates a Nystroem filter whose sample is given by a RemoveRange instance range. The
     * filter uses a deep copy of {@link #m_Kernel}.
     *
     * @param sampleRange 1-based instance range to keep as the Nystroem sample
     * @param data the data used to set the filter's input format
     * @return the configured filter
     * @throws Exception if the kernel cannot be copied or the input format is rejected
     */
    private Nystroem createNystroem(String sampleRange, Instances data) throws Exception {
        RemoveRange sampler = new RemoveRange();
        sampler.setInvertSelection(true);
        sampler.setInstancesIndices(sampleRange);
        Nystroem nystroem = new Nystroem();
        nystroem.setFilter(sampler);
        nystroem.setKernel((Kernel) new SerializedObject(this.m_Kernel).getObject());
        nystroem.setInputFormat(data);
        return nystroem;
    }

    /**
     * Keeps the canonical directions with a strictly positive eigenvalue. It stores them in
     * {@link #m_B1} and returns the square roots of their eigenvalues.
     *
     * @param evd the eigen-decomposition returned by {@link #CCA(Matrix, Matrix)}
     * @return the square roots of the kept eigenvalues, in column order of {@code m_B1}
     * @throws IllegalStateException if any eigenvalue is NaN
     */
    private double[] keepPositiveComponents(EVD evd) {
        double[] eigenvalues = evd.getRealEigenvalues();
        Matrix vectors = evd.getRightEigenvectors();
        ArrayList<Integer> toKeep = new ArrayList<Integer>(eigenvalues.length);
        for (int i = 0; i < eigenvalues.length; ++i) {
            if (Double.isNaN(eigenvalues[i])) {
                throw new IllegalStateException("XNV: Eigenvalue is NaN, aborting. Consider modifying parameters.");
            }
            if (eigenvalues[i] > 0.0) {
                toKeep.add(i);
            }
        }
        double[] correlations = new double[toKeep.size()];
        DenseMatrix B1 = new DenseMatrix(vectors.numRows(), correlations.length);
        for (int column = 0; column < correlations.length; ++column) {
            int index = toKeep.get(column);
            correlations[column] = Math.sqrt(eigenvalues[index]);
            for (int j = 0; j < vectors.numRows(); ++j) {
                B1.set(j, column, vectors.get(j, index));
            }
        }
        this.m_B1 = B1;
        return correlations;
    }

    /**
     * Maps a prediction from the internal (possibly standardized) target scale back to the
     * original target scale.
     *
     * @param value the raw model output
     * @return the prediction on the original target scale
     */
    private double toTargetScale(double value) {
        return this.getDoNotStandardize() ? value : value * this.m_x1 + this.m_x0;
    }

    /**
     * Predicts the target of a single instance using the filters fitted at training time.
     *
     * @param inst the instance
     * @return a one-element array holding the numeric prediction
     * @throws Exception if filtering fails
     */
    @Override
    public double[] distributionForInstance(Instance inst) throws Exception {
        this.m_Missing.input(inst);
        this.m_Missing.batchFinished();
        inst = this.m_Missing.output();
        this.m_NominalToBinary.input(inst);
        this.m_NominalToBinary.batchFinished();
        inst = this.m_NominalToBinary.output();
        if (!this.getDoNotStandardize()) {
            this.m_Standardize.input(inst);
            inst = this.m_Standardize.output();
        }
        this.m_N1.input(inst);
        inst = this.m_N1.output();

        // Assumes the class attribute is set on the filtered instance (it is skipped below).
        DenseMatrix features = new DenseMatrix(1, inst.numAttributes() - 1);
        int column = 0;
        for (int i = 0; i < inst.numAttributes(); ++i) {
            if (i != inst.classIndex()) {
                features.set(0, column++, inst.value(i));
            }
        }
        Matrix projected = features.mult(this.m_B1, new DenseMatrix(1, this.m_B1.numColumns()));
        Matrix prediction = projected.mult(this.m_wCCA, new DenseMatrix(1, 1));
        return new double[]{this.toTargetScale(prediction.get(0, 0))};
    }

    @Override
    public boolean implementsMoreEfficientBatchPrediction() {
        return true;
    }

    /**
     * Predicts the targets of a batch of instances using the filters fitted at training time.
     *
     * @param insts the instances
     * @return one single-element prediction row per instance
     * @throws Exception if filtering fails
     */
    @Override
    public double[][] distributionsForInstances(Instances insts) throws Exception {
        // Apply the trained filters as-is. Re-fitting them here would impute with test-set
        // statistics and overwrite the model's state.
        insts = Filter.useFilter(insts, this.m_Missing);
        insts = Filter.useFilter(insts, this.m_NominalToBinary);
        if (!this.getDoNotStandardize()) {
            insts = Filter.useFilter(insts, this.m_Standardize);
        }
        DenseMatrix features = XNV.getMatrix(Filter.useFilter(insts, this.m_N1), false, false);
        Matrix projected = features.mult(this.m_B1, new DenseMatrix(features.numRows(), this.m_B1.numColumns()));
        Matrix predictions = projected.mult(this.m_wCCA, new DenseMatrix(insts.numInstances(), 1));
        double[][] preds = new double[insts.numInstances()][1];
        for (int i = 0; i < insts.numInstances(); ++i) {
            preds[i][0] = this.toTargetScale(predictions.get(i, 0));
        }
        return preds;
    }

    /**
     * Returns a textual description of the model.
     *
     * @return the weight vector, or a notice if no model has been built
     */
    public String toString() {
        if (this.m_wCCA == null) {
            return "XNV: No model built yet.";
        }
        return "XNV weight vector (beta) based on " + this.m_numLabeled + " instances:\n\n" + this.m_wCCA;
    }

    /**
     * Returns the capabilities of this classifier: numeric and nominal attributes, missing
     * values, and a numeric class that may be missing (such instances are treated as unlabeled).
     *
     * @return the capabilities
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.MISSING_VALUES);
        result.enable(Capabilities.Capability.NUMERIC_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        return result;
    }

    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 12037 $");
    }

    /**
     * Runs the classifier from the command line.
     *
     * @param argv the command-line options
     * @throws Exception if running the classifier fails
     */
    public static void main(String[] argv) throws Exception {
        XNV.runClassifier(new XNV(), argv);
    }
}

