/*
 * Decompiled with CFR 0.152.
 */
package com.github.psambit9791.jdsp.transform;

import com.github.psambit9791.jdsp.misc.Random;
import com.github.psambit9791.jdsp.misc.UtilMethods;
import java.util.Arrays;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.SingularValueDecomposition;
import org.apache.commons.math3.stat.StatUtils;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.util.MathArrays;

/**
 * Independent Component Analysis (ICA) using a deflation-based FastICA iteration.
 *
 * <p>The input signal is laid out as {@code signal[sample][channel]}: the number of components is
 * {@code signal[0].length} and the number of samples is {@code signal.length}. One unmixing row is
 * estimated at a time and decorrelated (Gram-Schmidt) against the rows already estimated.
 *
 * <p>The option names and defaults appear to follow scikit-learn's {@code FastICA}:
 * <ul>
 *   <li>contrast functions {@code "logcosh"}, {@code "exp"} and {@code "cube"};</li>
 *   <li>whitening modes {@code "unit-variance"}, {@code "arbitrary-variance"} or {@code ""} (no whitening);</li>
 *   <li>200 iterations and tolerance 1e-4.</li>
 * </ul>
 *
 * <p>Typical use: construct, call {@link #fit()}, then read the estimated sources with
 * {@link #transform()} or project new data with {@link #transform(double[][])}.
 */
public class ICA {
    /** Input signal, {@code [sample][channel]}. */
    private double[][] signal;
    /**
     * The signal transposed to {@code [channel][sample]}; each channel is zero-centred when
     * whitening is enabled. Populated by {@link #fit()}.
     */
    public double[][] zm_signal;
    /** Contrast parameter; only read by the {@code "logcosh"} contrast. */
    private double alpha = 1.0;
    /** Contrast function value for each projected sample, from the most recent iteration. */
    public double[] gx;
    /** Mean of the contrast derivative over the projected samples, from the most recent iteration. */
    public double g_x;
    /** Initial unmixing matrix, {@code components x components}, one row per component. */
    public double[][] w_init;
    private int max_iter = 200;
    private double tol = 1.0E-4;
    private String whiten = "unit-variance";
    private String func = "logcosh";
    private long seed = 42L;
    /** Per-channel mean removed before whitening; left null when whitening is disabled. */
    private double[] mean_;
    private int components;
    /** Largest iteration count used by any component in the most recent fit; -1 before fitting. */
    private int n_iter = -1;
    public double[][] mixingMatrix = null;
    public double[][] unmixingMatrix = null;
    public double[][] whiteningMatrix = null;
    private double[][] componentMatrix = null;
    private double[][] sources = null;

    /**
     * Log-cosh contrast: {@code gx[j] = tanh(alpha * x[j])}, {@code g_x = mean(alpha * (1 - gx[j]^2))}.
     *
     * @throws IllegalArgumentException if {@code alpha} lies outside [1, 2]
     */
    private void logcosh_(double[] x) {
        this.gx = new double[x.length];
        if (this.alpha < 1.0 || this.alpha > 2.0) {
            throw new IllegalArgumentException("alpha should be between 1.0 and 2.0");
        }
        double derivativeSum = 0.0;
        for (int j = 0; j < this.gx.length; ++j) {
            this.gx[j] = FastMath.tanh(x[j] * this.alpha);
            derivativeSum += this.alpha * (1.0 - Math.pow(this.gx[j], 2.0));
        }
        this.g_x = derivativeSum / (double) this.gx.length;
    }

    /** Exponential contrast: {@code gx = x * exp(-x^2 / 2)}, {@code g_x = mean((1 - x^2) * exp(-x^2 / 2))}. */
    private void exp_(double[] x) {
        this.gx = new double[x.length];
        double derivativeSum = 0.0;
        for (int j = 0; j < this.gx.length; ++j) {
            double e = FastMath.exp((0.0 - Math.pow(x[j], 2.0)) / 2.0);
            this.gx[j] = x[j] * e;
            derivativeSum += (1.0 - Math.pow(x[j], 2.0)) * e;
        }
        this.g_x = derivativeSum / (double) this.gx.length;
    }

    /** Cubic contrast: {@code gx = x^3}, {@code g_x = mean(3 * x^2)}. */
    private void cube_(double[] x) {
        this.gx = new double[x.length];
        double derivativeSum = 0.0;
        for (int j = 0; j < this.gx.length; ++j) {
            this.gx[j] = Math.pow(x[j], 3.0);
            derivativeSum += 3.0 * Math.pow(x[j], 2.0);
        }
        this.g_x = derivativeSum / (double) this.gx.length;
    }

    /**
     * Evaluates the selected contrast on {@code wx}, updating {@link #gx} and {@link #g_x}.
     * Any name other than {@code "logcosh"} or {@code "cube"} selects the exponential contrast.
     */
    private void applyContrast(String function, double[] wx) {
        switch (function) {
            case "logcosh":
                this.logcosh_(wx);
                break;
            case "cube":
                this.cube_(wx);
                break;
            default:
                this.exp_(wx);
                break;
        }
    }

    /** Returns the component count ({@code signal[0].length}), rejecting a null or empty signal. */
    private static int componentCount(double[][] signal) {
        if (signal == null || signal.length == 0 || signal[0] == null) {
            throw new IllegalArgumentException("signal must contain at least one sample");
        }
        return signal[0].length;
    }

    private static void validateFunc(String func) {
        if (!(func.equals("logcosh") || func.equals("exp") || func.equals("cube"))) {
            throw new IllegalArgumentException("func should be one of logcosh, exp or cube");
        }
    }

    /** alpha is only constrained when the log-cosh contrast (the only one that reads it) is selected. */
    private static void validateAlpha(String func, double alpha) {
        if (func.equals("logcosh") && (alpha > 2.0 || alpha < 1.0)) {
            throw new IllegalArgumentException("alpha should be between 1 and 2");
        }
    }

    private static void validateWhiten(String whiten) {
        if (!(whiten.equals("unit-variance") || whiten.equals("arbitrary-variance") || whiten.isEmpty())) {
            throw new IllegalArgumentException("whiten must be one of \"unit-variance\", \"arbitrary-variance\" or an empty string. ");
        }
    }

    /** Requires {@code w_init} to be a non-empty {@code components x components} matrix (every row checked). */
    private static void validateWInit(double[][] w_init, int components) {
        boolean valid = w_init.length == components && w_init.length > 0;
        for (int r = 0; valid && r < w_init.length; ++r) {
            valid = w_init[r] != null && w_init[r].length == w_init.length;
        }
        if (!valid) {
            throw new IllegalArgumentException("w_init should be a square matrix and the shape should be same as the number of components in signal");
        }
    }

    /** Builds a {@code components x components} initial matrix from a {@code Random} seeded with {@code seed}. */
    private static double[][] randomWInit(long seed, int components) {
        Random generator = new Random(seed);
        return generator.randomNormal2D(new int[]{components, components});
    }

    /**
     * Fully specified constructor with a caller-supplied initial unmixing matrix.
     *
     * @param signal   input, {@code [sample][channel]}
     * @param func     contrast: {@code "logcosh"}, {@code "exp"} or {@code "cube"}
     * @param whiten   {@code "unit-variance"}, {@code "arbitrary-variance"} or {@code ""}
     * @param w_init   square initial unmixing matrix, one row per component
     * @param max_iter maximum iterations per component
     * @param tol      convergence tolerance
     * @param alpha    log-cosh parameter, must be in [1, 2] when {@code func} is {@code "logcosh"}
     * @throws IllegalArgumentException on any invalid argument
     */
    public ICA(double[][] signal, String func, String whiten, double[][] w_init, int max_iter, double tol, double alpha) {
        this.signal = signal;
        this.components = componentCount(signal);
        validateWInit(w_init, this.components);
        validateFunc(func);
        validateAlpha(func, alpha);
        validateWhiten(whiten);
        this.func = func;
        this.whiten = whiten;
        this.w_init = w_init;
        this.max_iter = max_iter;
        this.tol = tol;
        this.alpha = alpha;
    }

    /** Like the fully specified constructor, but {@code w_init} is drawn from a generator seeded with {@code random_state}. */
    public ICA(double[][] signal, String func, String whiten, int max_iter, double tol, double alpha, long random_state) {
        this.signal = signal;
        this.components = componentCount(signal);
        validateFunc(func);
        validateAlpha(func, alpha);
        validateWhiten(whiten);
        this.func = func;
        this.whiten = whiten;
        this.seed = random_state;
        this.max_iter = max_iter;
        this.tol = tol;
        this.alpha = alpha;
        this.w_init = randomWInit(this.seed, this.components);
    }

    /** Default whitening and tolerance; {@code w_init} drawn from a generator seeded with {@code random_state}. */
    public ICA(double[][] signal, String func, int max_iter, double alpha, long random_state) {
        this.signal = signal;
        this.components = componentCount(signal);
        validateFunc(func);
        validateAlpha(func, alpha);
        this.func = func;
        this.seed = random_state;
        this.max_iter = max_iter;
        this.alpha = alpha;
        this.w_init = randomWInit(this.seed, this.components);
    }

    /** Default whitening with a caller-supplied (validated) initial unmixing matrix. */
    public ICA(double[][] signal, String func, double[][] w_init, int max_iter, double tol, double alpha) {
        this.signal = signal;
        this.components = componentCount(signal);
        validateFunc(func);
        validateAlpha(func, alpha);
        validateWInit(w_init, this.components);
        this.w_init = w_init;
        this.func = func;
        this.max_iter = max_iter;
        this.tol = tol;
        this.alpha = alpha;
    }

    /** Default iteration settings; {@code w_init} drawn from a generator seeded with {@code random_state}. */
    public ICA(double[][] signal, String func, double alpha, long random_state) {
        this.signal = signal;
        this.components = componentCount(signal);
        validateFunc(func);
        validateAlpha(func, alpha);
        this.func = func;
        this.alpha = alpha;
        this.seed = random_state;
        this.w_init = randomWInit(this.seed, this.components);
    }

    /** Default iteration settings and default seed. */
    public ICA(double[][] signal, String func, double alpha) {
        this.signal = signal;
        this.components = componentCount(signal);
        validateFunc(func);
        validateAlpha(func, alpha);
        this.func = func;
        this.alpha = alpha;
        this.w_init = randomWInit(this.seed, this.components);
    }

    /** Selects the contrast only; everything else uses defaults. */
    public ICA(double[][] signal, String func) {
        this.signal = signal;
        this.components = componentCount(signal);
        validateFunc(func);
        this.gx = new double[this.signal.length];
        this.g_x = 0.0;
        this.func = func;
        this.w_init = randomWInit(this.seed, this.components);
    }

    /** All defaults except the seed used to draw {@code w_init}. */
    public ICA(double[][] signal, long random_state) {
        this.signal = signal;
        this.components = componentCount(signal);
        this.seed = random_state;
        this.w_init = randomWInit(this.seed, this.components);
    }

    /** All defaults. */
    public ICA(double[][] signal) {
        this.signal = signal;
        this.components = componentCount(signal);
        this.w_init = randomWInit(this.seed, this.components);
    }

    /**
     * Gram-Schmidt step: returns {@code w - P w} where {@code P = (Wp^T Wp)^T}.
     *
     * <p>{@code Wp = UtilMethods.subarray(W, j, W.length)} — presumably the {@code j} rows already
     * estimated; TODO confirm against {@code UtilMethods.subarray}. When {@code j == 0}, P is all
     * zeros, so a copy of {@code w} is returned.
     */
    private double[] _gs_decorrelation(double[] w, double[][] W, int j) {
        double[][] projector;
        if (j == 0) {
            // No earlier rows to decorrelate against; Java zero-initialises the array.
            projector = new double[w.length][w.length];
        } else {
            double[][] previous = UtilMethods.subarray(W, j, W.length);
            projector = UtilMethods.transpose(UtilMethods.matrixMultiply(UtilMethods.transpose(previous), previous));
        }
        double[] projection = new double[w.length];
        for (int i = 0; i < projection.length; ++i) {
            projection[i] = StatUtils.sum(MathArrays.ebeMultiply(projector[i], w));
        }
        return MathArrays.ebeSubtract(w, projection);
    }

    /** Returns {@code v} divided by its Euclidean norm (via {@code UtilMethods.scalarArithmetic}). */
    private static double[] normalize(double[] v) {
        double norm = Math.sqrt(StatUtils.sum(UtilMethods.scalarArithmetic(v, 2.0, "pow")));
        return UtilMethods.scalarArithmetic(v, norm, "div");
    }

    /**
     * Deflation FastICA: estimates the unmixing rows one at a time.
     *
     * @param X              data, {@code [component][sample]}
     * @param function       contrast name
     * @param max_iterations iteration cap per component
     * @param w_init         initial rows; row {@code j} seeds component {@code j}
     * @return the {@code components x components} unmixing matrix
     */
    private double[][] icaDef(double[][] X, String function, int max_iterations, double[][] w_init) {
        // Unfilled rows stay zero (Java zero-initialises arrays).
        double[][] W = new double[this.components][this.components];
        for (int j = 0; j < this.components; ++j) {
            double[] w = normalize(w_init[j]);
            int iterationsDone = 0;
            for (int i = 0; i < max_iterations; ++i) {
                iterationsDone = i + 1;
                double[] wX = UtilMethods.flattenMatrix(UtilMethods.matrixMultiply(new double[][]{w}, X));
                this.applyContrast(function, wX);
                // Fixed-point update: w1 = E[X g(w^T X)] - E[g'(w^T X)] w
                double[] w1 = new double[X.length];
                for (int h = 0; h < w1.length; ++h) {
                    w1[h] = StatUtils.mean(MathArrays.ebeMultiply(X[h], this.gx));
                }
                w1 = MathArrays.ebeSubtract(w1, UtilMethods.scalarArithmetic(w, this.g_x, "mul"));
                w1 = normalize(this._gs_decorrelation(w1, W, j));
                // Converged when the new direction is (anti)parallel to the previous one.
                double lim = Math.abs(Math.abs(StatUtils.sum(MathArrays.ebeMultiply(w1, w))) - 1.0);
                w = w1;
                if (lim < this.tol) {
                    break;
                }
            }
            this.n_iter = Math.max(this.n_iter, iterationsDone);
            W[j] = w;
        }
        return W;
    }

    /**
     * Computes the whitening matrix {@code K = (U / d)^T} from the SVD of {@code centered}.
     *
     * <p>The columns of U are sign-flipped so the first row is non-negative, then each column of U
     * is divided by its singular value.
     */
    private static double[][] computeWhitening(double[][] centered) {
        SingularValueDecomposition svd = new SingularValueDecomposition(MatrixUtils.createRealMatrix(centered));
        double[][] U = svd.getU().getData();
        double[][] S = svd.getS().getData();
        double[] signs = UtilMethods.sign(U[0]);
        for (int i = 0; i < U.length; ++i) {
            U[i] = MathArrays.ebeMultiply(U[i], signs);
        }
        // Fill every row of S with the diagonal so element-wise division scales column j by 1/d_j.
        for (int i = 0; i < S.length; ++i) {
            for (int j = 0; j < S.length; ++j) {
                S[i][j] = S[j][j];
            }
        }
        double[][] scaled = UtilMethods.ebeDivide(MatrixUtils.createRealMatrix(U), MatrixUtils.createRealMatrix(S)).getData();
        return UtilMethods.transpose(scaled);
    }

    /** Returns 1 / population standard deviation of each row. */
    private static double[] inverseStd(double[][] rows) {
        double[] inv = new double[rows.length];
        for (int i = 0; i < rows.length; ++i) {
            int n = rows[i].length;
            // StatUtils.variance is bias-corrected; rescale to the population variance.
            inv[i] = 1.0 / Math.sqrt(StatUtils.variance(rows[i]) * (double) (n - 1) / (double) n);
        }
        return inv;
    }

    /** Transposes {@code m}, then multiplies each row of the result element-wise by {@code factors}. */
    private static double[][] transposeAndScale(double[][] m, double[] factors) {
        double[][] t = UtilMethods.transpose(m);
        for (int i = 0; i < t.length; ++i) {
            t[i] = MathArrays.ebeMultiply(t[i], factors);
        }
        return t;
    }

    /**
     * Runs ICA on the signal given at construction.
     *
     * <p>Populates {@link #zm_signal}, {@link #unmixingMatrix}, {@link #mixingMatrix}, the internal
     * sources, and (when whitening is enabled) {@link #whiteningMatrix}.
     */
    public void fit() {
        this.n_iter = -1;
        boolean whitening = !this.whiten.isEmpty();
        double n_samples = this.signal.length;
        double[][] sigT = UtilMethods.transpose(this.signal);
        this.zm_signal = UtilMethods.transpose(this.signal);

        double[][] K = null; // whitening matrix; only computed and used when whitening is enabled
        double[][] X1;
        if (whitening) {
            this.mean_ = new double[sigT.length];
            for (int c = 0; c < sigT.length; ++c) {
                this.mean_[c] = StatUtils.mean(sigT[c]);
                this.zm_signal[c] = UtilMethods.zeroCenter(sigT[c]);
            }
            K = computeWhitening(this.zm_signal);
            X1 = UtilMethods.matrixMultiply(K, this.zm_signal);
            for (int i = 0; i < X1.length; ++i) {
                X1[i] = UtilMethods.scalarArithmetic(X1[i], Math.sqrt(n_samples), "mul");
            }
        } else {
            X1 = sigT;
        }

        double[][] W = this.icaDef(X1, this.func, this.max_iter, this.w_init);
        double[][] S2 = whitening
                ? UtilMethods.matrixMultiply(W, UtilMethods.matrixMultiply(K, this.zm_signal))
                : UtilMethods.matrixMultiply(W, this.zm_signal);

        if (whitening) {
            if (this.whiten.equals("unit-variance")) {
                // Rescale each source to unit variance and fold the same scaling into W.
                double[] invStd = inverseStd(S2);
                S2 = transposeAndScale(S2, invStd);
                W = UtilMethods.transpose(transposeAndScale(W, invStd));
            } else {
                S2 = UtilMethods.transpose(S2);
            }
            this.whiteningMatrix = K;
            this.componentMatrix = UtilMethods.matrixMultiply(W, K);
        } else {
            S2 = UtilMethods.transpose(S2);
            this.componentMatrix = W;
        }
        this.mixingMatrix = UtilMethods.pseudoInverse(this.componentMatrix);
        this.unmixingMatrix = W;
        this.sources = S2;
    }

    /**
     * Returns the sources estimated by {@link #fit()}, laid out {@code [sample][component]}.
     *
     * @throws ExceptionInInitializerError if {@link #fit()} has not been called
     */
    public double[][] transform() throws ExceptionInInitializerError {
        if (this.unmixingMatrix == null) {
            throw new ExceptionInInitializerError("Execute fit() before calling this function");
        }
        return this.sources;
    }

    /**
     * Projects a new signal ({@code [sample][channel]}) onto the fitted components.
     *
     * <p>When whitening is enabled, the means learned by {@link #fit()} are subtracted first.
     *
     * @throws ExceptionInInitializerError if {@link #fit()} has not been called
     * @throws ArithmeticException if the channel count differs from the fitted signal
     */
    public double[][] transform(double[][] signal) throws ExceptionInInitializerError, ArithmeticException {
        if (this.unmixingMatrix == null) {
            throw new ExceptionInInitializerError("Execute fit() before calling this function");
        }
        if (signal[0].length != this.components) {
            throw new ArithmeticException("Number of components has to be same as original signal");
        }
        if (!this.whiten.isEmpty()) {
            double[][] channels = UtilMethods.transpose(signal);
            for (int i = 0; i < channels.length; ++i) {
                channels[i] = UtilMethods.scalarArithmetic(channels[i], this.mean_[i], "sub");
            }
            signal = UtilMethods.transpose(channels);
        }
        return UtilMethods.matrixMultiply(signal, UtilMethods.transpose(this.componentMatrix));
    }
}

