/*
 * Decompiled with CFR 0.152.
 */
package com.github.psambit9791.jdsp.transform;

import com.github.psambit9791.jdsp.misc.UtilMethods;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.SingularValueDecomposition;
import org.apache.commons.math3.stat.StatUtils;
import org.apache.commons.math3.util.MathArrays;

public class PCA {
    private double[][] signal;
    private double[][] zm_signal;
    private double[][] output;
    private int n_components;
    private int n_samples;
    private double[][] U;
    private double[][] S;
    private double[][] V;
    public double[] explained_variance_;
    public double[] explained_variance_ratio_;
    public double[] singular_values_;
    private double[] mean_;

    public PCA(double[][] signal, int n_components) throws ExceptionInInitializerError, IllegalArgumentException {
        if (signal.length < signal[0].length) {
            throw new ExceptionInInitializerError("Signal length must be more than number of channels");
        }
        if (n_components > signal[0].length || n_components <= 0) {
            throw new ExceptionInInitializerError("n_components must be greater than 0 and less than total channels in signal");
        }
        this.signal = signal;
        this.n_samples = signal.length;
        this.n_components = n_components;
        this.output = new double[n_components][signal.length];
    }

    public double[][][] getUSV() throws ExceptionInInitializerError {
        if (this.singular_values_ == null) {
            throw new ExceptionInInitializerError("Execute fit() before calling this function");
        }
        double[][][] usv = new double[][][]{this.U, this.S, this.V};
        return usv;
    }

    public void fit() {
        double[][] sigT = UtilMethods.transpose(this.signal);
        this.zm_signal = new double[this.signal.length][this.signal[0].length];
        this.zm_signal = UtilMethods.transpose(this.zm_signal);
        this.mean_ = new double[sigT.length];
        for (int i = 0; i < sigT.length; ++i) {
            this.mean_[i] = StatUtils.mean((double[])sigT[i]);
            this.zm_signal[i] = UtilMethods.zeroCenter(sigT[i]);
        }
        this.zm_signal = UtilMethods.transpose(this.zm_signal);
        RealMatrix m = MatrixUtils.createRealMatrix((double[][])this.zm_signal);
        SingularValueDecomposition svdM = new SingularValueDecomposition(m);
        double[][] U = svdM.getU().getData();
        double[][] S = svdM.getS().getData();
        double[][] V = svdM.getVT().getData();
        double[][][] temp2 = this.svdFlip(U, V);
        U = temp2[0];
        V = temp2[1];
        this.singular_values_ = svdM.getSingularValues();
        this.explained_variance_ = MathArrays.ebeMultiply((double[])this.singular_values_, (double[])this.singular_values_);
        for (int i = 0; i < this.explained_variance_.length; ++i) {
            this.explained_variance_[i] = this.explained_variance_[i] / (double)(this.n_samples - 1);
        }
        double total_var = StatUtils.sum((double[])this.explained_variance_);
        this.explained_variance_ratio_ = new double[S.length];
        for (int i = 0; i < this.explained_variance_.length; ++i) {
            this.explained_variance_ratio_[i] = this.explained_variance_[i] / total_var;
        }
        this.singular_values_ = UtilMethods.splitByIndex(this.singular_values_, 0, this.n_components);
        this.explained_variance_ = UtilMethods.splitByIndex(this.explained_variance_, 0, this.n_components);
        this.explained_variance_ratio_ = UtilMethods.splitByIndex(this.explained_variance_ratio_, 0, this.n_components);
        this.U = U;
        this.S = S;
        this.V = V;
    }

    public double[][] transform() throws ExceptionInInitializerError {
        if (this.singular_values_ == null) {
            throw new ExceptionInInitializerError("Execute fit() before calling this function");
        }
        double[][] components = new double[this.n_components][this.n_samples];
        for (int i = 0; i < this.n_components; ++i) {
            components[i] = this.V[i];
        }
        double[][] components_T = UtilMethods.transpose(components);
        this.output = UtilMethods.matrixMultiply(this.zm_signal, components_T);
        return this.output;
    }

    public double[][] transform(double[][] x) throws ExceptionInInitializerError, ArithmeticException {
        if (this.singular_values_ == null) {
            throw new ExceptionInInitializerError("Execute fit() before calling this function");
        }
        if (x[0].length != this.signal[0].length) {
            throw new ArithmeticException("Number of channels has to be same as original signal");
        }
        double[][] xT = UtilMethods.transpose(x);
        double[][] zm_x = new double[x.length][x[0].length];
        zm_x = UtilMethods.transpose(zm_x);
        for (int i = 0; i < xT.length; ++i) {
            zm_x[i] = UtilMethods.scalarArithmetic(xT[i], this.mean_[i], "sub");
        }
        zm_x = UtilMethods.transpose(zm_x);
        double[][] components = new double[this.n_components][this.n_samples];
        for (int i = 0; i < this.n_components; ++i) {
            components[i] = this.V[i];
        }
        double[][] components_T = UtilMethods.transpose(components);
        double[][] out = UtilMethods.matrixMultiply(zm_x, components_T);
        return out;
    }

    private double[][][] svdFlip(double[][] U, double[][] V) {
        int i;
        double[][] U_new = UtilMethods.absoluteArray(U);
        int[] max_abs_cols = new int[U[0].length];
        double[] signs = new double[U[0].length];
        for (int j = 0; j < U_new[0].length; ++j) {
            double[] column_vals = new double[U_new.length];
            for (int i2 = 0; i2 < U_new.length; ++i2) {
                column_vals[i2] = U_new[i2][j];
            }
            max_abs_cols[j] = UtilMethods.argmax(column_vals, false);
        }
        for (i = 0; i < max_abs_cols.length; ++i) {
            signs[i] = Math.signum(U[max_abs_cols[i]][i]);
        }
        for (i = 0; i < U.length; ++i) {
            U[i] = MathArrays.ebeMultiply((double[])U[i], (double[])signs);
        }
        V = UtilMethods.transpose(V);
        for (i = 0; i < V.length; ++i) {
            V[i] = MathArrays.ebeMultiply((double[])V[i], (double[])signs);
        }
        V = UtilMethods.transpose(V);
        double[][][] out = new double[][][]{U, V};
        return out;
    }
}

