/*
 * Decompiled with CFR 0.152.
 */
package smile.ica;

import java.io.Serializable;
import java.util.Properties;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.ica.Exp;
import smile.ica.LogCosh;
import smile.math.DifferentiableFunction;
import smile.math.MathEx;
import smile.math.matrix.Matrix;
import smile.stat.distribution.GaussianDistribution;

/**
 * Independent component analysis fitted with a fixed-point iteration.
 *
 * <p>Components are extracted one at a time. Each candidate vector is updated
 * with the supplied contrast function, orthogonalized against the components
 * already extracted, and re-normalized until it stops changing (up to sign).
 *
 * <p>Instances are created through the static {@code fit} factories.
 */
public class ICA implements Serializable {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(ICA.class);

    /**
     * The extracted components: {@code p} rows, each of length
     * {@code data[0].length} of the data passed to {@code fit}.
     */
    public final double[][] components;

    /**
     * Wraps the fitted components; the array is stored as-is, not copied.
     *
     * @param components the component vectors, one per row.
     */
    private ICA(double[][] components) {
        this.components = components;
    }

    /**
     * Fits ICA with default settings (an empty {@link Properties}).
     *
     * @param data the input matrix; all rows must have the same length.
     * @param p    the number of components to extract,
     *             {@code 1 <= p <= data.length}.
     * @return the fitted model.
     */
    public static ICA fit(double[][] data, int p) {
        return ICA.fit(data, p, new Properties());
    }

    /**
     * Fits ICA with hyper-parameters read from {@code params}.
     *
     * <ul>
     * <li>{@code smile.ica.contrast}: {@code "LogCosh"} (default) or
     *     {@code "Gaussian"}.</li>
     * <li>{@code smile.ica.tolerance}: convergence tolerance, default
     *     {@code 1E-4}.</li>
     * <li>{@code smile.ica.iterations}: maximum iterations per component,
     *     default {@code 100}.</li>
     * </ul>
     *
     * @param data   the input matrix; all rows must have the same length.
     * @param p      the number of components to extract.
     * @param params the hyper-parameters.
     * @return the fitted model.
     * @throws IllegalArgumentException if the contrast name is not supported
     *         or a numeric property cannot be parsed.
     */
    public static ICA fit(double[][] data, int p, Properties params) {
        String contrast = params.getProperty("smile.ica.contrast", "LogCosh");
        DifferentiableFunction f = switch (contrast) {
            case "LogCosh" -> new LogCosh();
            case "Gaussian" -> new Exp();
            default -> throw new IllegalArgumentException("Unsupported contrast function: " + contrast);
        };
        double tol = Double.parseDouble(params.getProperty("smile.ica.tolerance", "1E-4"));
        int maxIter = Integer.parseInt(params.getProperty("smile.ica.iterations", "100"));
        return ICA.fit(data, p, f, tol, maxIter);
    }

    /**
     * Fits ICA with an explicit contrast function.
     *
     * @param data     the input matrix with {@code m = data.length} rows and
     *                 {@code n = data[0].length} columns; must be non-empty
     *                 and rectangular.
     * @param p        the number of components to extract, {@code 1 <= p <= m}.
     * @param contrast the contrast function; its {@code g} and {@code g2}
     *                 are evaluated on every projected value.
     * @param tol      the convergence tolerance; must be positive.
     * @param maxIter  the maximum number of iterations per component; must be
     *                 positive.
     * @return the fitted model holding {@code p} vectors of length {@code n}.
     * @throws IllegalArgumentException if an argument is out of range, the
     *         data is empty or ragged, or whitening finds an eigenvalue
     *         below {@code 1E-8}.
     */
    public static ICA fit(double[][] data, int p, DifferentiableFunction contrast, double tol, int maxIter) {
        // Written as !(tol > 0) so that NaN is rejected too. A NaN tolerance
        // would make the loop guard "diff > tol" false from the start and the
        // random initial vectors would be returned without any warning.
        if (!(tol > 0.0)) {
            throw new IllegalArgumentException("Invalid tolerance: " + tol);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }
        if (data.length == 0 || data[0].length == 0) {
            throw new IllegalArgumentException("Empty data");
        }

        int n = data[0].length;
        int m = data.length;
        // The row means are computed from the raw rows while the matrix is
        // built from data as a whole, so a ragged input is rejected up front.
        for (int i = 1; i < m; ++i) {
            if (data[i].length != n) {
                throw new IllegalArgumentException(String.format(
                        "Row %d has %d columns, expected %d", i, data[i].length, n));
            }
        }
        if (p < 1 || p > m) {
            throw new IllegalArgumentException("Invalid dimension of feature space: " + p);
        }

        // Random starting vectors drawn from N(0, 1), then normalized.
        GaussianDistribution g = new GaussianDistribution(0.0, 1.0);
        double[][] W = new double[p][n];
        for (int i = 0; i < p; ++i) {
            for (int j = 0; j < n; ++j) {
                W[i][j] = g.rand();
            }
            MathEx.unitize(W[i]);
        }

        // n x m whitened matrix (rows of X correspond to columns of data).
        Matrix X = ICA.whiten(data);

        // Scratch buffers shared by all components.
        double[] wold = new double[n];
        double[] wdif = new double[n];
        double[] gwX = new double[m];
        double[] g2w = new double[n];

        for (int i = 0; i < p; ++i) {
            double[] w = W[i];
            double diff = Double.MAX_VALUE;

            for (int iter = 0; iter < maxIter && diff > tol; ++iter) {
                System.arraycopy(w, 0, wold, 0, n);

                // Projection wX = X' * w (length m).
                double[] wX = new double[m];
                X.tv(w, wX);

                // g of each projected value, and the sum of g2 over them.
                double g2 = 0.0;
                for (int j = 0; j < m; ++j) {
                    gwX[j] = contrast.g(wX[j]);
                    g2 += contrast.g2(wX[j]);
                }

                // g2w must be captured before X.mv overwrites w.
                for (int j = 0; j < n; ++j) {
                    g2w[j] = w[j] * g2;
                }

                // Fixed-point update: w = (X * g(X'w) - sum(g2) * w) / m.
                X.mv(gwX, w);
                for (int j = 0; j < n; ++j) {
                    w[j] = (w[j] - g2w[j]) / m;
                }

                // Remove the projections onto the components found so far.
                for (int k = 0; k < i; ++k) {
                    double[] wk = W[k];
                    double wkw = MathEx.dot(wk, w);
                    for (int j = 0; j < n; ++j) {
                        w[j] -= wkw * wk[j];
                    }
                }

                MathEx.unitize2(w);

                // w and -w are treated as the same component, so the change
                // is the smaller of ||w - wold|| and ||w + wold||.
                for (int j = 0; j < n; ++j) {
                    wdif[j] = w[j] - wold[j];
                }
                double normOfDifference = MathEx.norm(wdif);

                for (int j = 0; j < n; ++j) {
                    wdif[j] = w[j] + wold[j];
                }
                double normOfSum = MathEx.norm(wdif);

                diff = Math.min(normOfDifference, normOfSum);
            }

            if (diff > tol) {
                logger.warn("Component {} did not converge in {} iterations.", i, maxIter);
            }
        }
        return new ICA(W);
    }

    /**
     * Centers and whitens the data.
     *
     * <p>The data is transposed into an {@code n x m} matrix (n columns of
     * data become rows), each column is centered by the mean of the
     * corresponding data row, the result is projected onto the eigenvectors
     * of {@code X'X}, and every projected column is scaled by
     * {@code 1 / sqrt(eigenvalue)}.
     *
     * @param data the rectangular input matrix.
     * @return the whitened {@code n x m} matrix.
     * @throws IllegalArgumentException if any eigenvalue of {@code X'X} is
     *         below {@code 1E-8}.
     */
    private static Matrix whiten(double[][] data) {
        double[] mean = MathEx.rowMeans(data);
        Matrix X = Matrix.of(data).transpose();
        int n = X.nrow();
        int m = X.ncol();

        // Column j of X is row j of data; subtract that row's mean.
        for (int j = 0; j < m; ++j) {
            double mu = mean[j];
            for (int i = 0; i < n; ++i) {
                X.sub(i, j, mu);
            }
        }

        Matrix XtX = X.ata();
        // NOTE(review): the flags presumably mean (left vectors, right
        // vectors, overwrite) - confirm against the Matrix API. Only the
        // right eigenvectors (Vr) and the eigenvalues (wr) are used below.
        Matrix.EVD eigen = XtX.eigen(false, true, true);
        Matrix E = eigen.Vr;
        Matrix Y = X.mm(E);

        // Turn the eigenvalues into scale factors in place.
        double[] d = eigen.wr;
        for (int i = 0; i < d.length; ++i) {
            if (d[i] < 1.0E-8) {
                throw new IllegalArgumentException(String.format("Covariance matrix (column %d) is close to singular.", i));
            }
            d[i] = 1.0 / Math.sqrt(d[i]);
        }

        for (int j = 0; j < m; ++j) {
            for (int i = 0; i < n; ++i) {
                Y.mul(i, j, d[j]);
            }
        }
        return Y;
    }
}

