/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.util.Arrays;
import smile.classification.Classifier;
import smile.classification.ClassifierTrainer;
import smile.math.Math;
import smile.math.matrix.EigenValueDecomposition;

/**
 * Discriminant-analysis classifier whose per-class covariance matrix is a
 * blend, controlled by the regularization factor {@code alpha}, of the
 * class's own sample covariance and a matrix shared by all classes:
 * {@code alpha * cov[class] + (1 - alpha) * C}.
 *
 * <p>Each blended matrix is eigen-decomposed at construction time; the
 * eigenvectors and eigenvalues are kept and used by
 * {@link #predict(double[], double[])} to evaluate a quadratic discriminant
 * score per class.
 *
 * <p>Class labels must be the consecutive integers {@code 0 .. k-1}.
 */
public class RDA
implements Classifier<double[]> {
    /** Number of variables (length of {@code x[0]} at training time). */
    private int p;
    /** Number of classes. */
    private int k;
    /** Per-class constant term: {@code log(priori[i]) - 0.5 * sum(log(ev[i][j]))}. */
    private final double[] ct;
    /** A priori probability of each class. */
    private double[] priori;
    /** Per-class mean vectors, indexed {@code [class][variable]}. */
    private double[][] mu;
    /** Per-class eigenvector matrices of the regularized covariance. */
    private double[][][] scaling;
    /** Per-class eigenvalues of the regularized covariance. */
    private double[][] ev;

    /**
     * Trains with priors estimated from class frequencies and the default
     * tolerance of {@code 1.0E-4}.
     *
     * @param x training samples, one row per sample
     * @param y class labels in {@code 0 .. k-1}
     * @param alpha regularization factor in {@code [0, 1]}
     */
    public RDA(double[][] x, int[] y, double alpha) {
        this(x, y, null, alpha);
    }

    /**
     * Trains with the default tolerance of {@code 1.0E-4}.
     *
     * @param x training samples, one row per sample
     * @param y class labels in {@code 0 .. k-1}
     * @param priori a priori class probabilities, or {@code null} to estimate
     *        them from class frequencies
     * @param alpha regularization factor in {@code [0, 1]}
     */
    public RDA(double[][] x, int[] y, double[] priori, double alpha) {
        this(x, y, priori, alpha, 1.0E-4);
    }

    /**
     * Trains the classifier.
     *
     * @param x training samples, one row per sample
     * @param y class labels; the distinct values must be exactly
     *        {@code 0 .. k-1} with {@code k >= 2}
     * @param priori a priori class probabilities (each strictly between 0 and
     *        1, summing to one, one per class), or {@code null} to estimate
     *        them as {@code ni / n}. A non-null array is stored as-is, not
     *        copied.
     * @param alpha regularization factor in {@code [0, 1]}
     * @param tol non-negative tolerance; a variance or eigenvalue below
     *        {@code tol * tol} is treated as a singular covariance
     * @throws IllegalArgumentException if any argument is invalid, a class
     *         has fewer than two samples, the sample size does not exceed the
     *         number of classes, or a covariance matrix is close to singular
     */
    public RDA(double[][] x, int[] y, double[] priori, double alpha, double tol) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        // Negated form so that NaN is rejected as well.
        if (!(alpha >= 0.0 && alpha <= 1.0)) {
            throw new IllegalArgumentException("Invalid regularization factor: " + alpha);
        }
        if (priori != null) {
            if (priori.length < 2) {
                throw new IllegalArgumentException("Invalid number of priori probabilities: " + priori.length);
            }
            double sum = 0.0;
            for (double pr : priori) {
                if (!(pr > 0.0 && pr < 1.0)) {
                    throw new IllegalArgumentException("Invalid priori probability: " + pr);
                }
                sum += pr;
            }
            if (Math.abs(sum - 1.0) > 1.0E-10) {
                throw new IllegalArgumentException("The sum of priori probabilities is not one: " + sum);
            }
        }

        // Labels are used directly as array indices below, so they must be
        // exactly 0, 1, ..., k-1.
        int[] labels = Math.unique(y);
        Arrays.sort(labels);
        for (int i = 0; i < labels.length; ++i) {
            if (labels[i] < 0) {
                throw new IllegalArgumentException("Negative class label: " + labels[i]);
            }
            if (i == 0) {
                if (labels[0] != 0) {
                    throw new IllegalArgumentException("Missing class: 0");
                }
                continue;
            }
            if (labels[i] - labels[i - 1] > 1) {
                // Parenthesized: report the first absent label numerically.
                throw new IllegalArgumentException("Missing class: " + (labels[i - 1] + 1));
            }
        }
        this.k = labels.length;
        if (this.k < 2) {
            throw new IllegalArgumentException("Only one class.");
        }
        if (priori != null && this.k != priori.length) {
            throw new IllegalArgumentException("The number of classes and the number of priori probabilities don't match.");
        }
        if (!(tol >= 0.0)) {
            throw new IllegalArgumentException("Invalid tol: " + tol);
        }
        int n = x.length;
        if (n <= this.k) {
            throw new IllegalArgumentException(String.format("Sample size is too small: %d <= %d", n, this.k));
        }

        this.p = x[0].length;
        int[] ni = new int[this.k];
        double[] mean = Math.colMean(x);
        double[][] C = new double[this.p][this.p];
        this.mu = new double[this.k][this.p];
        double[][][] cov = new double[this.k][this.p][this.p];

        // Class counts and per-class sums.
        for (int i = 0; i < n; ++i) {
            int c = y[i];
            ni[c]++;
            for (int j = 0; j < this.p; ++j) {
                this.mu[c][j] += x[i][j];
            }
        }

        // Turn sums into means; every class needs at least two samples.
        for (int i = 0; i < this.k; ++i) {
            if (ni[i] <= 1) {
                throw new IllegalArgumentException(String.format("Class %d has only one sample.", i));
            }
            for (int j = 0; j < this.p; ++j) {
                this.mu[i][j] /= (double)ni[i];
            }
        }

        if (priori == null) {
            priori = new double[this.k];
            for (int i = 0; i < this.k; ++i) {
                priori[i] = (double)ni[i] / (double)n;
            }
        }
        this.priori = priori;

        // Accumulate the lower triangles of the per-class scatter (around the
        // class mean) and of the shared scatter C (around the overall mean).
        for (int i = 0; i < n; ++i) {
            int c = y[i];
            for (int j = 0; j < this.p; ++j) {
                for (int l = 0; l <= j; ++l) {
                    cov[c][j][l] += (x[i][j] - this.mu[c][j]) * (x[i][l] - this.mu[c][l]);
                    C[j][l] += (x[i][j] - mean[j]) * (x[i][l] - mean[l]);
                }
            }
        }

        // Variances and eigenvalues are compared against tol squared.
        double tol2 = tol * tol;
        for (int j = 0; j < this.p; ++j) {
            for (int l = 0; l <= j; ++l) {
                C[j][l] /= (double)(n - this.k);
                C[l][j] = C[j][l];
            }
            if (C[j][j] < tol2) {
                throw new IllegalArgumentException(String.format("Covariance matrix (variable %d) is close to singular.", j));
            }
        }

        this.ev = new double[this.k][];
        for (int i = 0; i < this.k; ++i) {
            for (int j = 0; j < this.p; ++j) {
                for (int l = 0; l <= j; ++l) {
                    cov[i][j][l] /= (double)(ni[i] - 1);
                    // Blend the class covariance with the shared matrix.
                    cov[i][j][l] = alpha * cov[i][j][l] + (1.0 - alpha) * C[j][l];
                    cov[i][l][j] = cov[i][j][l];
                }
                if (cov[i][j][j] < tol2) {
                    throw new IllegalArgumentException(String.format("Class %d covariance matrix (variable %d) is close to singular.", i, j));
                }
            }
            EigenValueDecomposition eigen = EigenValueDecomposition.decompose(cov[i], true);
            double[] values = eigen.getEigenValues();
            for (double s : values) {
                if (s < tol2) {
                    throw new IllegalArgumentException(String.format("Class %d covariance matrix is close to singular.", i));
                }
            }
            this.ev[i] = values;
            // The covariance slot is reused to hold the eigenvectors.
            cov[i] = eigen.getEigenVectors();
        }
        this.scaling = cov;

        this.ct = new double[this.k];
        for (int i = 0; i < this.k; ++i) {
            double logev = 0.0;
            for (int j = 0; j < this.p; ++j) {
                logev += Math.log(this.ev[i][j]);
            }
            this.ct[i] = Math.log(priori[i]) - 0.5 * logev;
        }
    }

    /**
     * Returns the a priori class probabilities.
     *
     * @return the internal array (not a copy)
     */
    public double[] getPriori() {
        return this.priori;
    }

    /**
     * Predicts the class label of {@code x}.
     *
     * @param x input vector of length {@code p}
     * @return the label with the highest discriminant score
     * @throws IllegalArgumentException if {@code x} has the wrong length
     */
    @Override
    public int predict(double[] x) {
        return this.predict(x, (double[])null);
    }

    /**
     * Predicts the class label of {@code x} and optionally fills in the
     * posterior probabilities.
     *
     * @param x input vector of length {@code p}
     * @param posteriori if non-null, an array of length {@code k} that
     *        receives the normalized posterior probability of each class
     * @return the label with the highest discriminant score
     * @throws IllegalArgumentException if {@code x} or {@code posteriori} has
     *         the wrong length
     */
    @Override
    public int predict(double[] x, double[] posteriori) {
        if (x.length != this.p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, this.p));
        }
        if (posteriori != null && posteriori.length != this.k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, this.k));
        }
        int y = 0;
        double max = Double.NEGATIVE_INFINITY;
        double[] d = new double[this.p];
        double[] ux = new double[this.p];
        for (int i = 0; i < this.k; ++i) {
            for (int j = 0; j < this.p; ++j) {
                d[j] = x[j] - this.mu[i][j];
            }
            // Writes the result into ux; presumably scaling[i]' * d (project
            // the centered vector onto the eigenvectors) - confirm against
            // smile.math.Math.atx.
            Math.atx(this.scaling[i], d, ux);
            double f = 0.0;
            for (int j = 0; j < this.p; ++j) {
                f += ux[j] * ux[j] / this.ev[i][j];
            }
            f = this.ct[i] - 0.5 * f;
            if (max < f) {
                max = f;
                y = i;
            }
            if (posteriori != null) {
                posteriori[i] = f;
            }
        }
        if (posteriori != null) {
            // Subtract the maximum score before exponentiating, then normalize.
            double sum = 0.0;
            for (int i = 0; i < this.k; ++i) {
                posteriori[i] = Math.exp(posteriori[i] - max);
                sum += posteriori[i];
            }
            for (int i = 0; i < this.k; ++i) {
                posteriori[i] /= sum;
            }
        }
        return y;
    }

    /**
     * Trainer that holds the regularization factor, optional priors and
     * tolerance, and builds {@link RDA} instances from them.
     */
    public static class Trainer
    extends ClassifierTrainer<double[]> {
        /** Regularization factor in {@code [0, 1]}. */
        private double alpha;
        /** A priori class probabilities; {@code null} means estimate from data. */
        private double[] priori;
        /** Singularity tolerance passed to the {@link RDA} constructor. */
        private double tol = 1.0E-4;

        /**
         * @param alpha regularization factor in {@code [0, 1]}
         * @throws IllegalArgumentException if {@code alpha} is out of range or NaN
         */
        public Trainer(double alpha) {
            if (!(alpha >= 0.0 && alpha <= 1.0)) {
                throw new IllegalArgumentException("Invalid regularization factor: " + alpha);
            }
            this.alpha = alpha;
        }

        /**
         * Sets the a priori class probabilities. They are validated when
         * {@link #train(double[][], int[])} constructs the classifier.
         *
         * @param priori the probabilities, or {@code null}
         * @return this trainer
         */
        public Trainer setPriori(double[] priori) {
            this.priori = priori;
            return this;
        }

        /**
         * Sets the singularity tolerance.
         *
         * @param tol non-negative tolerance
         * @return this trainer
         * @throws IllegalArgumentException if {@code tol} is negative or NaN
         */
        public Trainer setTolerance(double tol) {
            if (!(tol >= 0.0)) {
                throw new IllegalArgumentException("Invalid tol: " + tol);
            }
            this.tol = tol;
            return this;
        }

        /**
         * Builds an {@link RDA} from the given data and this trainer's settings.
         *
         * @param x training samples, one row per sample
         * @param y class labels in {@code 0 .. k-1}
         * @return the trained classifier
         */
        public RDA train(double[][] x, int[] y) {
            return new RDA(x, y, this.priori, this.alpha, this.tol);
        }
    }
}

