/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.concurrent.Callable;
import smile.classification.Classifier;
import smile.classification.ClassifierTrainer;
import smile.math.DifferentiableMultivariateFunction;
import smile.math.Math;
import smile.util.MulticoreExecutor;

/**
 * Maximum entropy classifier over sparse binary features.
 *
 * <p>Each sample is an {@code int[]} listing the indices, in {@code [0, p)}, of the features that
 * are present. Every listed feature contributes its weight once to the linear score, and weight
 * slot {@code p} holds the intercept. Class labels must be exactly {@code 0..k-1} with
 * {@code k >= 2}.
 *
 * <p>With two classes the model is a logistic regression with one weight vector. With more
 * classes it is a softmax model with one weight vector per class. The training objective is the
 * negative log-likelihood plus {@code 0.5 * lambda * ||w||^2}, where the penalty excludes the
 * intercepts. It is minimized with {@code smile.math.Math.min} using a history size of 5
 * (presumably L-BFGS; confirm against smile.math).
 *
 * <p>When there are at least 1000 samples and a thread pool of two or more threads, the
 * objective and gradient are evaluated in parallel through {@link MulticoreExecutor}.
 */
public class Maxent implements Classifier<int[]> {
    /** Minimum number of samples before the objective is evaluated in parallel. */
    private static final int PARALLEL_THRESHOLD = 1000;
    /** Minimum number of samples handed to one parallel task. */
    private static final int MIN_CHUNK_SIZE = 100;
    /** History size passed to {@code Math.min}. */
    private static final int OPTIMIZER_HISTORY = 5;

    /** Dimension of the feature space (number of non-intercept weights per class). */
    private final int p;
    /** Number of classes. */
    private final int k;
    /** Negated value of the minimized objective; see {@link #loglikelihood()}. */
    private final double L;
    /** Binary-model weights (length p + 1, intercept last); null for multi-class models. */
    private final double[] w;
    /** Multi-class weights: one row of length p + 1 per class (intercept last); null for binary. */
    private final double[][] W;

    /**
     * Trains a maximum entropy classifier with regularization factor 0.1.
     *
     * @param p the dimension of the feature space; feature indices must lie in {@code [0, p)}.
     * @param x the training samples; each row lists the indices of the features present.
     * @param y the class labels, which must cover exactly {@code 0..k-1} with {@code k >= 2}.
     * @throws IllegalArgumentException if any argument is invalid.
     */
    public Maxent(int p, int[][] x, int[] y) {
        this(p, x, y, 0.1);
    }

    /**
     * Trains a maximum entropy classifier with tolerance 1E-5 and at most 500 iterations.
     *
     * @param p the dimension of the feature space; feature indices must lie in {@code [0, p)}.
     * @param x the training samples; each row lists the indices of the features present.
     * @param y the class labels, which must cover exactly {@code 0..k-1} with {@code k >= 2}.
     * @param lambda the non-negative L2 regularization factor (intercepts are not penalized).
     * @throws IllegalArgumentException if any argument is invalid.
     */
    public Maxent(int p, int[][] x, int[] y, double lambda) {
        this(p, x, y, lambda, 1.0E-5, 500);
    }

    /**
     * Trains a maximum entropy classifier.
     *
     * @param p the dimension of the feature space; feature indices must lie in {@code [0, p)}.
     * @param x the training samples; each row lists the indices of the features present.
     * @param y the class labels, which must cover exactly {@code 0..k-1} with {@code k >= 2}.
     * @param lambda the non-negative L2 regularization factor (intercepts are not penalized).
     * @param tol the positive convergence tolerance passed to the optimizer.
     * @param maxIter the positive maximum number of optimizer iterations.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length; if {@code p},
     *     {@code lambda}, {@code tol} or {@code maxIter} is out of range; if a feature index lies
     *     outside {@code [0, p)}; or if the labels are negative, skip a class, or form fewer than
     *     two classes.
     */
    public Maxent(int p, int[][] x, int[] y, double lambda, double tol, int maxIter) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (p < 0) {
            throw new IllegalArgumentException("Invalid dimension: " + p);
        }
        if (lambda < 0.0) {
            throw new IllegalArgumentException("Invalid regularization factor: " + lambda);
        }
        if (tol <= 0.0) {
            throw new IllegalArgumentException("Invalid tolerance: " + tol);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }
        checkFeatureIndices(p, x);
        this.p = p;

        int[] labels = Math.unique(y);
        Arrays.sort(labels);
        for (int i = 0; i < labels.length; ++i) {
            if (labels[i] < 0) {
                throw new IllegalArgumentException("Negative class label: " + labels[i]);
            }
            // Labels are used directly as array indices and binary targets, so they must be
            // exactly 0..k-1 with no gaps.
            int expected = i == 0 ? 0 : labels[i - 1] + 1;
            if (labels[i] != expected) {
                throw new IllegalArgumentException("Missing class: " + expected);
            }
        }
        this.k = labels.length;
        if (this.k < 2) {
            throw new IllegalArgumentException("Only one class.");
        }

        int width = p + 1;
        if (this.k == 2) {
            double[] weights = new double[width];
            this.L = train(new BinaryObjectiveFunction(x, y, lambda), weights, tol, maxIter);
            this.w = weights;
            this.W = null;
        } else {
            // The optimizer works on one flat vector; unpack it into one row per class.
            double[] weights = new double[this.k * width];
            this.L = train(new MultiClassObjectiveFunction(x, y, this.k, p, lambda), weights, tol, maxIter);
            this.w = null;
            this.W = new double[this.k][];
            for (int i = 0; i < this.k; ++i) {
                this.W[i] = Arrays.copyOfRange(weights, i * width, (i + 1) * width);
            }
        }
    }

    /**
     * Rejects any sample that references a feature index outside {@code [0, p)}. Such an index
     * would otherwise alias the intercept or another class's weights.
     */
    private static void checkFeatureIndices(int p, int[][] x) {
        for (int i = 0; i < x.length; ++i) {
            for (int j : x[i]) {
                if (j < 0 || j >= p) {
                    throw new IllegalArgumentException(String.format("Invalid feature index %d in sample %d: not in [0, %d)", j, i, p));
                }
            }
        }
    }

    /**
     * Minimizes {@code func}, starting from and updating {@code w} in place.
     *
     * @return the negated minimum. If the optimizer throws, the exception is printed to stderr,
     *     0.0 is returned, and {@code w} keeps whatever the optimizer left in it.
     */
    private static double train(DifferentiableMultivariateFunction func, double[] w, double tol, int maxIter) {
        try {
            return -Math.min(func, OPTIMIZER_HISTORY, w, tol, maxIter);
        } catch (Exception ex) {
            System.err.println(ex);
            return 0.0;
        }
    }

    /**
     * Returns the dimension of the feature space.
     *
     * @return the number of features {@code p}.
     */
    public int getDimension() {
        return this.p;
    }

    /**
     * Computes {@code log(1 + e^x)}. For {@code x > 15} it returns {@code x}, which differs from
     * the exact value by less than {@code e^-15} (about 3.1e-7) and avoids evaluating a large
     * exponential.
     */
    private static double log1pe(double x) {
        return x > 15.0 ? x : Math.log1p(Math.exp(x));
    }

    /**
     * Natural logarithm clamped from below. Inputs under 1e-300 map to -690.7755, which is about
     * ln(1e-300), so a zero probability never produces -Infinity in the loss.
     */
    private static double log(double x) {
        return x < 1.0E-300 ? -690.7755 : Math.log(x);
    }

    /**
     * Replaces {@code prob} in place with its softmax. The maximum is subtracted before
     * exponentiating, for numerical stability.
     */
    private static void softmax(double[] prob) {
        double max = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < prob.length; ++i) {
            if (prob[i] > max) {
                max = prob[i];
            }
        }
        double Z = 0.0;
        for (int i = 0; i < prob.length; ++i) {
            prob[i] = Math.exp(prob[i] - max);
            Z += prob[i];
        }
        for (int i = 0; i < prob.length; ++i) {
            prob[i] /= Z;
        }
    }

    /**
     * Computes the score of a sparse binary sample against one weight vector: the intercept
     * (last element of {@code w}) plus {@code w[i]} for every listed feature {@code i}.
     */
    private static double dot(int[] x, double[] w) {
        double dot = w[w.length - 1];
        for (int i : x) {
            dot += w[i];
        }
        return dot;
    }

    /**
     * Same as {@link #dot(int[], double[])}, but for class {@code j} of a flat vector that packs
     * consecutive per-class blocks of length {@code p + 1}, each with its intercept last.
     */
    private static double dot(int[] x, double[] w, int j, int p) {
        int pos = j * (p + 1);
        double dot = w[pos + p];
        for (int i : x) {
            dot += w[pos + i];
        }
        return dot;
    }

    /**
     * Splits the sample range {@code [0, n)} into at most {@code m} contiguous chunks for parallel
     * evaluation. Each chunk holds {@code max(n / m, MIN_CHUNK_SIZE)} samples, except the last,
     * which covers whatever remains up to {@code n}. No chunk extends past {@code n}.
     *
     * @return {@code {start, end}} pairs that together cover {@code [0, n)} exactly once.
     */
    private static int[][] partition(int n, int m) {
        int step = n / m;
        if (step < MIN_CHUNK_SIZE) {
            step = MIN_CHUNK_SIZE;
        }
        // Number of full-size chunks that end strictly before n, capped so there are at most m chunks.
        int full = (n - 1) / step;
        if (full > m - 1) {
            full = m - 1;
        }
        int[][] ranges = new int[full + 1][];
        for (int i = 0; i < full; ++i) {
            ranges[i] = new int[] {i * step, (i + 1) * step};
        }
        ranges[full] = new int[] {full * step, n};
        return ranges;
    }

    /**
     * Adds a task's partial gradient into {@code g} and returns {@code f} plus the task's partial
     * loss. {@code partial} holds the gradient in its first {@code partial.length - 1} slots and
     * the loss in its last slot.
     */
    private static double accumulate(double[] partial, double[] g, double f) {
        int last = partial.length - 1;
        for (int i = 0; i < last; ++i) {
            g[i] += partial[i];
        }
        return f + partial[last];
    }

    /**
     * Returns the negated value of the minimized training objective. This is the training
     * log-likelihood minus the L2 penalty, or 0.0 if the optimizer threw.
     *
     * @return the (penalized) training log-likelihood.
     */
    public double loglikelihood() {
        return this.L;
    }

    @Override
    public int predict(int[] x) {
        return this.predict(x, (double[]) null);
    }

    /**
     * Predicts the class of {@code x} and, optionally, its posterior probabilities.
     *
     * @param x the sample, as the indices of its present features.
     * @param posteriori if non-null, receives P(class i | x) at index i; its length must be k.
     * @return the predicted class label, in {@code 0..k-1}.
     * @throws IllegalArgumentException if {@code posteriori} has the wrong length.
     */
    @Override
    public int predict(int[] x, double[] posteriori) {
        if (posteriori != null && posteriori.length != this.k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, this.k));
        }
        if (this.w != null) {
            // Binary model: f is P(class 1 | x), matching the logistic training loss.
            double f = 1.0 / (1.0 + Math.exp(-Maxent.dot(x, this.w)));
            if (posteriori != null) {
                posteriori[0] = 1.0 - f;
                posteriori[1] = f;
            }
            return f < 0.5 ? 0 : 1;
        }
        int label = -1;
        double max = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < this.k; ++i) {
            double score = Maxent.dot(x, this.W[i]);
            if (score > max) {
                max = score;
                label = i;
            }
            if (posteriori != null) {
                posteriori[i] = score;
            }
        }
        if (posteriori != null) {
            Maxent.softmax(posteriori);
        }
        return label;
    }

    /**
     * Penalized negative log-likelihood of the softmax model, with its gradient. The parameter
     * vector packs the k class weight vectors back to back, each of length p + 1 with the
     * intercept last.
     */
    static class MultiClassObjectiveFunction implements DifferentiableMultivariateFunction {
        int[][] x;
        int[] y;
        int k;
        int p;
        double lambda;

        MultiClassObjectiveFunction(int[][] x, int[] y, int k, int p, double lambda) {
            this.x = x;
            this.y = y;
            this.k = k;
            this.p = p;
            this.lambda = lambda;
        }

        @Override
        public double f(double[] w) {
            double f = logLoss(w);
            if (this.lambda != 0.0) {
                double wnorm = 0.0;
                for (int i = 0; i < this.k; ++i) {
                    int pos = i * (this.p + 1);
                    for (int j = 0; j < this.p; ++j) {
                        wnorm += w[pos + j] * w[pos + j];
                    }
                }
                f += 0.5 * this.lambda * wnorm;
            }
            return f;
        }

        @Override
        public double f(double[] w, double[] g) {
            Arrays.fill(g, 0.0);
            double f = logLossAndGradient(w, g);
            if (this.lambda != 0.0) {
                double wnorm = 0.0;
                for (int i = 0; i < this.k; ++i) {
                    int pos = i * (this.p + 1);
                    for (int j = 0; j < this.p; ++j) {
                        double wj = w[pos + j];
                        wnorm += wj * wj;
                        g[pos + j] += this.lambda * wj;
                    }
                }
                f += 0.5 * this.lambda * wnorm;
            }
            return f;
        }

        /** Unpenalized negative log-likelihood over all samples, evaluated in parallel when worthwhile. */
        private double logLoss(double[] w) {
            int n = this.x.length;
            int m = MulticoreExecutor.getThreadPoolSize();
            if (n < PARALLEL_THRESHOLD || m < 2) {
                return new FTask(w, 0, n).call();
            }
            ArrayList<FTask> tasks = new ArrayList<FTask>(m);
            for (int[] range : partition(n, m)) {
                tasks.add(new FTask(w, range[0], range[1]));
            }
            try {
                double f = 0.0;
                for (Double fi : MulticoreExecutor.run(tasks)) {
                    f += fi;
                }
                return f;
            } catch (Exception ex) {
                // Best effort: if the thread pool fails, evaluate serially instead.
                return new FTask(w, 0, n).call();
            }
        }

        /** Adds the unpenalized gradient into the zeroed {@code g} and returns the unpenalized loss. */
        private double logLossAndGradient(double[] w, double[] g) {
            int n = this.x.length;
            int m = MulticoreExecutor.getThreadPoolSize();
            if (n < PARALLEL_THRESHOLD || m < 2) {
                return accumulate(new GTask(w, 0, n).call(), g, 0.0);
            }
            ArrayList<GTask> tasks = new ArrayList<GTask>(m);
            for (int[] range : partition(n, m)) {
                tasks.add(new GTask(w, range[0], range[1]));
            }
            try {
                double f = 0.0;
                for (double[] gi : MulticoreExecutor.run(tasks)) {
                    f = accumulate(gi, g, f);
                }
                return f;
            } catch (Exception ex) {
                // Best effort: discard any partial sums and evaluate serially instead.
                Arrays.fill(g, 0.0);
                return accumulate(new GTask(w, 0, n).call(), g, 0.0);
            }
        }

        /** Loss and gradient over samples [start, end); the loss is stored in the last slot. */
        class GTask implements Callable<double[]> {
            double[] w;
            int start;
            int end;

            GTask(double[] w, int start, int end) {
                this.w = w;
                this.start = start;
                this.end = end;
            }

            @Override
            public double[] call() {
                double f = 0.0;
                double[] prob = new double[k];
                double[] g = new double[this.w.length + 1];
                for (int i = this.start; i < this.end; ++i) {
                    for (int j = 0; j < k; ++j) {
                        prob[j] = Maxent.dot(x[i], this.w, j, p);
                    }
                    Maxent.softmax(prob);
                    f -= Maxent.log(prob[y[i]]);
                    for (int j = 0; j < k; ++j) {
                        // d(-log P(y_i)) / d(score_j) = prob[j] - [y_i == j]; each listed feature
                        // and the intercept have coefficient 1 in score_j.
                        double yi = (y[i] == j ? 1.0 : 0.0) - prob[j];
                        int pos = j * (p + 1);
                        for (int l : x[i]) {
                            g[pos + l] -= yi;
                        }
                        g[pos + p] -= yi;
                    }
                }
                g[this.w.length] = f;
                return g;
            }
        }

        /** Loss over samples [start, end). */
        class FTask implements Callable<Double> {
            double[] w;
            int start;
            int end;

            FTask(double[] w, int start, int end) {
                this.w = w;
                this.start = start;
                this.end = end;
            }

            @Override
            public Double call() {
                double f = 0.0;
                double[] prob = new double[k];
                for (int i = this.start; i < this.end; ++i) {
                    for (int j = 0; j < k; ++j) {
                        prob[j] = Maxent.dot(x[i], this.w, j, p);
                    }
                    Maxent.softmax(prob);
                    f -= Maxent.log(prob[y[i]]);
                }
                return f;
            }
        }
    }

    /**
     * Penalized negative log-likelihood of the binary logistic model, with its gradient. The
     * parameter vector has length p + 1 with the intercept last; labels are 0 or 1.
     */
    static class BinaryObjectiveFunction implements DifferentiableMultivariateFunction {
        int[][] x;
        int[] y;
        double lambda;

        BinaryObjectiveFunction(int[][] x, int[] y, double lambda) {
            this.x = x;
            this.y = y;
            this.lambda = lambda;
        }

        @Override
        public double f(double[] w) {
            double f = logLoss(w);
            if (this.lambda != 0.0) {
                int p = w.length - 1;
                double wnorm = 0.0;
                for (int j = 0; j < p; ++j) {
                    wnorm += w[j] * w[j];
                }
                f += 0.5 * this.lambda * wnorm;
            }
            return f;
        }

        @Override
        public double f(double[] w, double[] g) {
            Arrays.fill(g, 0.0);
            double f = logLossAndGradient(w, g);
            if (this.lambda != 0.0) {
                int p = w.length - 1;
                double wnorm = 0.0;
                for (int j = 0; j < p; ++j) {
                    wnorm += w[j] * w[j];
                    g[j] += this.lambda * w[j];
                }
                f += 0.5 * this.lambda * wnorm;
            }
            return f;
        }

        /** Unpenalized negative log-likelihood over all samples, evaluated in parallel when worthwhile. */
        private double logLoss(double[] w) {
            int n = this.x.length;
            int m = MulticoreExecutor.getThreadPoolSize();
            if (n < PARALLEL_THRESHOLD || m < 2) {
                return new FTask(w, 0, n).call();
            }
            ArrayList<FTask> tasks = new ArrayList<FTask>(m);
            for (int[] range : partition(n, m)) {
                tasks.add(new FTask(w, range[0], range[1]));
            }
            try {
                double f = 0.0;
                for (Double fi : MulticoreExecutor.run(tasks)) {
                    f += fi;
                }
                return f;
            } catch (Exception ex) {
                // Best effort: if the thread pool fails, evaluate serially instead.
                return new FTask(w, 0, n).call();
            }
        }

        /** Adds the unpenalized gradient into the zeroed {@code g} and returns the unpenalized loss. */
        private double logLossAndGradient(double[] w, double[] g) {
            int n = this.x.length;
            int m = MulticoreExecutor.getThreadPoolSize();
            if (n < PARALLEL_THRESHOLD || m < 2) {
                return accumulate(new GTask(w, 0, n).call(), g, 0.0);
            }
            ArrayList<GTask> tasks = new ArrayList<GTask>(m);
            for (int[] range : partition(n, m)) {
                tasks.add(new GTask(w, range[0], range[1]));
            }
            try {
                double f = 0.0;
                for (double[] gi : MulticoreExecutor.run(tasks)) {
                    f = accumulate(gi, g, f);
                }
                return f;
            } catch (Exception ex) {
                // Best effort: discard any partial sums and evaluate serially instead.
                Arrays.fill(g, 0.0);
                return accumulate(new GTask(w, 0, n).call(), g, 0.0);
            }
        }

        /** Loss and gradient over samples [start, end); the loss is stored in the last slot. */
        class GTask implements Callable<double[]> {
            double[] w;
            int start;
            int end;

            GTask(double[] w, int start, int end) {
                this.w = w;
                this.start = start;
                this.end = end;
            }

            @Override
            public double[] call() {
                double f = 0.0;
                int p = this.w.length - 1;
                double[] g = new double[this.w.length + 1];
                for (int i = this.start; i < this.end; ++i) {
                    double wx = Maxent.dot(x[i], this.w);
                    f += Maxent.log1pe(wx) - (double) y[i] * wx;
                    // d(loss_i)/d(w_j) = (logistic(wx) - y_i) * x_ij, and x_ij = 1 for every
                    // listed feature (binary features), so each listed weight gets -yi.
                    double yi = (double) y[i] - Math.logistic(wx);
                    for (int j : x[i]) {
                        g[j] -= yi;
                    }
                    g[p] -= yi;
                }
                g[this.w.length] = f;
                return g;
            }
        }

        /** Loss over samples [start, end). */
        class FTask implements Callable<Double> {
            double[] w;
            int start;
            int end;

            FTask(double[] w, int start, int end) {
                this.w = w;
                this.start = start;
                this.end = end;
            }

            @Override
            public Double call() {
                double f = 0.0;
                for (int i = this.start; i < this.end; ++i) {
                    double wx = Maxent.dot(x[i], this.w);
                    f += Maxent.log1pe(wx) - (double) y[i] * wx;
                }
                return f;
            }
        }
    }

    /**
     * Builder-style trainer for {@link Maxent}. The defaults are regularization 0.0 (unlike the
     * 0.1 used by the three-argument {@code Maxent} constructor), tolerance 1E-5 and 500
     * iterations. The setters do not validate; the {@code Maxent} constructor rejects invalid
     * values when {@link #train} is called.
     */
    public static class Trainer extends ClassifierTrainer<int[]> {
        private int p;
        private double lambda = 0.0;
        private double tol = 1.0E-5;
        private int maxIter = 500;

        /**
         * Creates a trainer for a feature space of dimension {@code p}.
         *
         * @param p the dimension of the feature space.
         * @throws IllegalArgumentException if {@code p} is negative.
         */
        public Trainer(int p) {
            if (p < 0) {
                throw new IllegalArgumentException("Invalid dimension: " + p);
            }
            this.p = p;
        }

        /** Sets the L2 regularization factor and returns this trainer. */
        public Trainer setRegularizationFactor(double lambda) {
            this.lambda = lambda;
            return this;
        }

        /** Sets the optimizer convergence tolerance and returns this trainer. */
        public Trainer setTolerance(double tol) {
            this.tol = tol;
            return this;
        }

        /** Sets the maximum number of optimizer iterations and returns this trainer. */
        public Trainer setMaxNumIteration(int maxIter) {
            this.maxIter = maxIter;
            return this;
        }

        /** Trains a {@link Maxent} model on {@code x} and {@code y} with this trainer's settings. */
        public Maxent train(int[][] x, int[] y) {
            return new Maxent(this.p, x, y, this.lambda, this.tol, this.maxIter);
        }
    }
}

