/*
 * Decompiled with CFR 0.152.
 */
package jsat.datatransform;

import java.util.Comparator;
import jsat.DataSet;
import jsat.classifiers.DataPoint;
import jsat.datatransform.DataTransform;
import jsat.datatransform.DataTransformBase;
import jsat.distributions.Distribution;
import jsat.distributions.discrete.UniformDiscrete;
import jsat.linear.DenseVector;
import jsat.linear.EigenValueDecomposition;
import jsat.linear.Matrix;
import jsat.linear.MatrixStatistics;
import jsat.linear.SingularValueDecomposition;
import jsat.linear.SubMatrix;
import jsat.linear.Vec;

public class WhitenedPCA
extends DataTransformBase {
    private static final long serialVersionUID = 6134243673037330608L;
    protected double regularization;
    protected int dimensions;
    protected Matrix transform;

    public WhitenedPCA() {
        this(50);
    }

    public WhitenedPCA(int dims) {
        this(1.0E-4, dims);
    }

    public WhitenedPCA(double regularization, int dims) {
        this.setRegularization(regularization);
        this.setDimensions(dims);
    }

    public WhitenedPCA(DataSet dataSet, double regularization, int dims) {
        this(regularization, dims);
        this.fit(dataSet);
    }

    @Override
    public void fit(DataSet dataSet) {
        this.setUpTransform(this.getSVD(dataSet));
    }

    public WhitenedPCA(DataSet dataSet, double regularization) {
        this.setRegularization(regularization);
        SingularValueDecomposition svd = this.getSVD(dataSet);
        this.setDimensions(svd.getRank());
        this.setUpTransform(svd);
    }

    public WhitenedPCA(DataSet dataSet) {
        SingularValueDecomposition svd = this.getSVD(dataSet);
        this.setRegularization(svd);
        this.setDimensions(svd.getRank());
        this.setUpTransform(svd);
    }

    public WhitenedPCA(DataSet dataSet, int dims) {
        SingularValueDecomposition svd = this.getSVD(dataSet);
        this.setRegularization(svd);
        this.setDimensions(dims);
        this.setUpTransform(svd);
    }

    private WhitenedPCA(WhitenedPCA other) {
        this.regularization = other.regularization;
        this.dimensions = other.dimensions;
        this.transform = other.transform.clone();
    }

    private SingularValueDecomposition getSVD(DataSet dataSet) {
        Matrix cov = MatrixStatistics.covarianceMatrix(MatrixStatistics.meanVector(dataSet), dataSet);
        for (int i = 0; i < cov.rows(); ++i) {
            for (int j = 0; j < i; ++j) {
                cov.set(j, i, cov.get(i, j));
            }
        }
        EigenValueDecomposition evd = new EigenValueDecomposition(cov);
        evd.sortByEigenValue(new Comparator<Double>(){

            @Override
            public int compare(Double o1, Double o2) {
                return -Double.compare(o1, o2);
            }
        });
        return new SingularValueDecomposition(evd.getVRaw(), evd.getVRaw(), evd.getRealEigenvalues());
    }

    protected void setUpTransform(SingularValueDecomposition svd) {
        DenseVector diag = new DenseVector(this.dimensions);
        double[] s = svd.getSingularValues();
        for (int i = 0; i < this.dimensions; ++i) {
            ((Vec)diag).set(i, 1.0 / Math.sqrt(s[i] + this.regularization));
        }
        this.transform = new SubMatrix(svd.getU().transpose(), 0, 0, this.dimensions, s.length).clone();
        Matrix.diagMult(diag, this.transform);
    }

    @Override
    public DataPoint transform(DataPoint dp) {
        Vec newVec = this.transform.multiply(dp.getNumericalValues());
        DataPoint newDp = new DataPoint(newVec, dp.getCategoricalValues(), dp.getCategoricalData(), dp.getWeight());
        return newDp;
    }

    public void setRegularization(double regularization) {
        if (regularization < 0.0 || Double.isNaN(regularization) || Double.isInfinite(regularization)) {
            throw new ArithmeticException("Regularization must be non negative value, not " + regularization);
        }
        this.regularization = regularization;
    }

    public double getRegularization() {
        return this.regularization;
    }

    @Override
    public DataTransform clone() {
        return new WhitenedPCA(this);
    }

    private void setRegularization(SingularValueDecomposition svd) {
        if (svd.isFullRank()) {
            this.setRegularization(1.0E-10);
        } else {
            this.setRegularization(Math.max(Math.log(1.0 + svd.getSingularValues()[svd.getRank()]) * 0.25, 1.0E-4));
        }
    }

    public void setDimensions(int dimensions) {
        if (dimensions < 1) {
            throw new IllegalArgumentException("Number of dimensions must be positive, not " + dimensions);
        }
        this.dimensions = dimensions;
    }

    public int getDimensions() {
        return this.dimensions;
    }

    public static Distribution guessDimensions(DataSet d) {
        if (d.getNumNumericalVars() < 100) {
            return new UniformDiscrete(1, d.getNumNumericalVars());
        }
        return new UniformDiscrete(20, 100);
    }
}

