/*
 * Decompiled with CFR 0.152.
 */
package hex.glm;

import hex.glm.Gram;
import java.util.Arrays;
import jsr166y.CountedCompleter;
import water.H2O;
import water.Iced;
import water.Key;
import water.MemoryManager;

public abstract class LSMSolver
extends Iced {
    // Overall regularization strength (lambda); scales both the L1 and the L2 penalty.
    double _lambda;
    // Elastic-net mixing parameter: the L1 term is weighted by alpha, the L2 term by (1 - alpha).
    final double _alpha;
    public Key _jobKey;
    // Identifier forwarded to Gram.cholesky by ADMMSolver.
    public String _id;
    // Convergence flag reported by converged(); written by concrete solvers (e.g. ADMMSolver.solve).
    protected boolean _converged;

    /**
     * Creates a solver for the given elastic-net penalty.
     *
     * @param lambda overall regularization strength
     * @param alpha  L1/L2 mixing parameter
     */
    public LSMSolver(double lambda, double alpha) {
        _alpha = alpha;
        _lambda = lambda;
    }

    /**
     * Computes the gradient of the unpenalized least-squares objective, {@code gram * beta - xy}.
     *
     * <p>The array returned by {@code gram.mul(beta)} is updated in place and returned (presumably a
     * freshly allocated array — confirm against {@code Gram.mul}).
     *
     * @param gram Gram matrix
     * @param beta current coefficients
     * @param xy   right-hand side
     * @return the gradient vector
     */
    public final double[] grad(Gram gram, double[] beta, double[] xy) {
        final double[] g = gram.mul(beta);
        int k = 0;
        while (k < g.length) {
            g[k] -= xy[k];
            ++k;
        }
        return g;
    }

    /**
     * Adds the L1 sub-gradient to {@code grad} in place, for every entry except the last one.
     *
     * <p>For a positive coefficient {@code alpha * lambda} is added, for a negative one it is subtracted,
     * and for a zero (or NaN) coefficient the gradient entry is soft-thresholded by {@code alpha * lambda}.
     * Does nothing when {@code beta} is null.
     *
     * @param alpha L1/L2 mixing parameter
     * @param lambda regularization strength
     * @param beta  coefficients, or null
     * @param grad  gradient to adjust in place
     */
    public static void subgrad(double alpha, double lambda, double[] beta, double[] grad) {
        if (beta != null) {
            final double penalty = alpha * lambda;
            final int last = grad.length - 1;
            for (int k = 0; k < last; ++k) {
                final double b = beta[k];
                if (b > 0.0) {
                    grad[k] += penalty;
                } else if (b < 0.0) {
                    grad[k] -= penalty;
                } else {
                    grad[k] = LSMSolver.shrinkage(grad[k], penalty);
                }
            }
        }
    }

    /**
     * Solves the penalized least-squares problem described by {@code gram} and {@code xy}.
     *
     * <p>The implementations in this file write the solution into {@code beta}.
     *
     * @param gram Gram matrix
     * @param xy   right-hand side
     * @param yy   constant term of the objective
     * @param beta output coefficients
     * @return whether the solver considers the result converged
     */
    public abstract boolean solve(Gram gram, double[] xy, double yy, double[] beta);

    /**
     * Returns the convergence flag recorded by a concrete solver.
     *
     * @return the value of {@code _converged}
     */
    public final boolean converged() {
        return _converged;
    }

    /**
     * Returns a human-readable name for this solver.
     *
     * @return the solver name
     */
    public abstract String name();

    /**
     * Soft-thresholding operator: moves {@code x} toward zero by {@code kappa}.
     *
     * <p>Returns 0 when {@code |x| <= kappa}.
     *
     * @param x     input value
     * @param kappa threshold
     * @return {@code sign(x) * (|x| - kappa)}, or 0 inside the threshold
     */
    protected static double shrinkage(double x, double kappa) {
        final double magnitude = Math.abs(x);
        if (magnitude <= kappa) {
            return 0.0;
        }
        final double reduced = magnitude - kappa;
        return x < 0.0 ? -reduced : reduced;
    }

    /**
     * Evaluates the full elastic-net objective.
     *
     * <p>It is {@code lsm_objectiveVal(...) + alpha*lambda*||beta||_1 + 0.5*(1-alpha)*lambda*||beta||_2^2}.
     * All coefficients, including the last one, are penalized here.
     *
     * @param xy   right-hand side
     * @param yy   constant term
     * @param beta coefficients
     * @param xb   product of the Gram matrix with {@code beta} (assumed; see lsm_objectiveVal)
     * @return the penalized objective value
     */
    protected double objectiveVal(double[] xy, double yy, double[] beta, double[] xb) {
        final double lsm = lsm_objectiveVal(xy, yy, beta, xb);
        double absSum = 0.0;
        double sqSum = 0.0;
        for (final double b : beta) {
            absSum += Math.abs(b);
            sqSum += b * b;
        }
        return lsm + _alpha * _lambda * absSum + 0.5 * (1.0 - _alpha) * _lambda * sqSum;
    }

    /**
     * Evaluates {@code 0.5*yy + sum_i beta[i] * (0.5*xb[i] - xy[i])}.
     *
     * <p>When {@code xb} is the Gram matrix times {@code beta} (as ProxSolver passes it), this equals the
     * unpenalized quadratic objective {@code 0.5*b'Gb - b'xy + 0.5*yy}. The sum runs over {@code xb.length}.
     *
     * @param xy   right-hand side
     * @param yy   constant term
     * @param beta coefficients
     * @param xb   Gram matrix times {@code beta} (assumed by callers — confirm)
     * @return the least-squares objective value
     */
    protected double lsm_objectiveVal(double[] xy, double yy, double[] beta, double[] xb) {
        double total = 0.5 * yy;
        final int n = xb.length;
        for (int k = 0; k < n; k++) {
            total += beta[k] * (0.5 * xb[k] - xy[k]);
        }
        return total;
    }

    /**
     * Computes the matrix-vector product {@code z = X * y}, writing into {@code z}.
     *
     * <p>Row {@code i} of the result is {@code sum_j X[i][j] * y[j]} for {@code j < y.length}. An empty
     * {@code y} yields a zero vector rather than an out-of-bounds access.
     *
     * @param X matrix with {@code z.length} (or more) rows, each at least {@code y.length} long
     * @param y vector to multiply
     * @param z output buffer of length at least {@code X.length}
     * @return {@code z}
     */
    static final double[] mul(double[][] X, double[] y, double[] z) {
        final int M = X.length;
        final int N = y.length;
        for (int i = 0; i < M; ++i) {
            final double[] row = X[i];
            double sum = 0.0;
            for (int j = 0; j < N; ++j) {
                sum += row[j] * y[j];
            }
            z[i] = sum;
        }
        return z;
    }

    /**
     * Scales a vector: {@code z[i] = a * x[i]} for every index of {@code x}.
     *
     * @param x input vector
     * @param a scale factor
     * @param z output buffer, at least {@code x.length} long
     * @return {@code z}
     */
    static final double[] mul(double[] x, double a, double[] z) {
        int k = 0;
        for (final double v : x) {
            z[k] = a * v;
            k++;
        }
        return z;
    }

    /**
     * Element-wise sum {@code z[i] = x[i] + y[i]} over the indices of {@code x}.
     *
     * @param x first operand
     * @param y second operand, at least {@code x.length} long
     * @param z output buffer, at least {@code x.length} long
     * @return {@code z}
     */
    static final double[] plus(double[] x, double[] y, double[] z) {
        int k = 0;
        while (k < x.length) {
            z[k] = x[k] + y[k];
            ++k;
        }
        return z;
    }

    /**
     * Element-wise difference {@code z[i] = x[i] - y[i]} over the indices of {@code x}.
     *
     * @param x minuend
     * @param y subtrahend, at least {@code x.length} long
     * @param z output buffer, at least {@code x.length} long
     * @return {@code z}
     */
    static final double[] minus(double[] x, double[] y, double[] z) {
        int k = 0;
        for (final double xv : x) {
            z[k] = xv - y[k];
            ++k;
        }
        return z;
    }

    /**
     * Soft-thresholds every entry of {@code x} except the last, which is copied unchanged.
     *
     * <p>The last entry is presumably the intercept. An empty {@code x} throws
     * {@code ArrayIndexOutOfBoundsException}.
     *
     * @param x     input vector (non-empty)
     * @param z     output buffer, at least {@code x.length} long
     * @param kappa threshold
     * @return {@code z}
     */
    static final double[] shrink(double[] x, double[] z, double kappa) {
        final int last = x.length - 1;
        for (int k = 0; k < last; k++) {
            z[k] = LSMSolver.shrinkage(x[k], kappa);
        }
        z[last] = x[last];
        return z;
    }

    /**
     * Proximal-gradient solver for the elastic-net penalized least-squares problem.
     *
     * <p>NOTE(review): {@link #solve} hands every Gram matrix to an {@link ADMMSolver} with gradient
     * tolerance 0.01. The proximal-gradient iteration in {@code proximalGradient} is therefore not reachable
     * from {@code solve}; it is a private helper so it can be wired in deliberately.
     */
    public static final class ProxSolver
    extends LSMSolver {
        public ProxSolver(double lambda, double alpha) {
            super(lambda, alpha);
        }

        /**
         * Quadratic model used by the backtracking line search.
         *
         * <p>It is {@code oldObj + (xb - xy)·(newB - oldB) + 0.25 * ||newB - oldB||^2 / t}.
         * NOTE(review): the textbook proximal-gradient bound uses {@code 1/(2t)}; this uses {@code 1/(4t)}.
         * That is stricter but still a valid sufficient-decrease test — confirm intent.
         */
        private static final double f_hat(double[] newB, double oldObj, double[] oldB, double[] xb, double[] xy, double t) {
            double res = oldObj;
            double l2 = 0.0;
            for (int i = 0; i < newB.length; ++i) {
                double diff = newB[i] - oldB[i];
                res += (xb[i] - xy[i]) * diff;
                l2 += diff * diff;
            }
            return res + 0.25 * l2 / t;
        }

        /** Elastic-net penalty {@code lambda * (alpha*||b||_1 + 0.5*(1-alpha)*||b||_2^2)} over all entries. */
        private double penalty(double[] beta) {
            double l1 = 0.0;
            double l2 = 0.0;
            for (int i = 0; i < beta.length; ++i) {
                l1 += Math.abs(beta[i]);
                l2 += beta[i] * beta[i];
            }
            return this._lambda * (this._alpha * l1 + (1.0 - this._alpha) * l2 * 0.5);
        }

        /** Largest absolute element-wise difference between two coefficient vectors (max-norm). */
        private static double betaDiff(double[] b1, double[] b2) {
            double res = 0.0;
            for (int i = 0; i < b1.length; ++i) {
                // The result of Math.max must be kept; discarding it made this always return 0.
                res = Math.max(res, Math.abs(b1[i] - b2[i]));
            }
            return res;
        }

        /**
         * Solves the problem by delegating to {@link ADMMSolver} with gradient tolerance 0.01.
         *
         * @param gram Gram matrix; must not be null
         * @param xy   right-hand side
         * @param yy   constant term
         * @param beta output coefficients
         * @return the ADMM solver's convergence result
         * @throws NullPointerException if {@code gram} is null
         */
        @Override
        public boolean solve(Gram gram, double[] xy, double yy, double[] beta) {
            if (gram == null) {
                throw new NullPointerException("gram must not be null");
            }
            ADMMSolver admm = new ADMMSolver(this._lambda, this._alpha, 0.01);
            return admm.solve(gram, xy, yy, beta);
        }

        /**
         * Proximal-gradient iteration with backtracking line search, starting from {@code beta = 0}.
         *
         * <p>Each step takes a gradient step of size {@code step}, soft-thresholds by {@code l1pen*step}, and
         * scales by {@code 1/(1 + step*l2pen)}. The last coefficient (intercept) only gets the gradient step.
         * The step shrinks by 0.8 until the model test in {@code f_hat} passes. The loop stops when the max
         * coefficient change drops below 1e-6, when no acceptable step above 1e-12 exists, or after
         * 1000 iterations.
         *
         * @return whether the loop stopped before the 1000-iteration cap
         */
        private boolean proximalGradient(Gram gram, double[] xy, double yy, double[] beta) {
            Arrays.fill(beta, 0.0);
            final double[] xb = gram.mul(beta);
            final double[] newB = MemoryManager.malloc8d(beta.length);
            final double[] newG = MemoryManager.malloc8d(beta.length);
            final double l1pen = this._lambda * this._alpha;
            final double l2pen = this._lambda * (1.0 - this._alpha);
            final int intercept = beta.length - 1;
            double lsmobjval = this.lsm_objectiveVal(xy, yy, beta, xb);
            boolean converged = false;
            for (int iter = 0; !converged && iter < 1000; ++iter) {
                boolean accepted = false;
                for (double step = 1.0; step > 1.0E-12; step *= 0.8) {
                    double l2shrink = 1.0 / (1.0 + step * l2pen);
                    double l1shrink = l1pen * step;
                    for (int i = 0; i < intercept; ++i) {
                        newB[i] = l2shrink * LSMSolver.shrinkage(beta[i] - step * (xb[i] - xy[i]), l1shrink);
                    }
                    newB[intercept] = beta[intercept] - step * (xb[intercept] - xy[intercept]);
                    gram.mul(newB, newG);
                    double newlsmobj = this.lsm_objectiveVal(xy, yy, newB, newG);
                    double fhat = ProxSolver.f_hat(newB, lsmobjval, beta, xb, xy, step);
                    if (!(newlsmobj <= fhat)) {
                        continue;
                    }
                    lsmobjval = newlsmobj;
                    converged = ProxSolver.betaDiff(beta, newB) < 1.0E-6;
                    System.arraycopy(newB, 0, beta, 0, newB.length);
                    System.arraycopy(newG, 0, xb, 0, newG.length);
                    accepted = true;
                    break;
                }
                if (!accepted) {
                    // Line search exhausted without an acceptable step: stop and report as converged.
                    converged = true;
                }
            }
            return converged;
        }

        @Override
        public String name() {
            return "ProximalGradientSolver";
        }
    }

    public static final class ADMMSolver
    extends LSMSolver {
        // Default elastic-net mixing value; not referenced within this class.
        public static final double DEFAULT_ALPHA = 0.5;
        // Reference coefficients of the proximal term; solve adds _proximalPenalty * _wgiven to xy.
        public double[] _wgiven;
        // Proximal weight; when > 0 (and _wgiven != null) it is added to the Gram diagonal and scales _wgiven.
        public double _proximalPenalty;
        // Sub-gradient tolerance below which the iterative ADMM solve reports convergence.
        public final double _gradientEps;
        private static final double GLM1_RHO = 0.001;
        // Max absolute sub-gradient component from the most recent convergence check.
        public double gerr = Double.POSITIVE_INFINITY;
        // ADMM iterations performed by the most recent serial solve.
        public int iterations = 0;
        // Milliseconds spent in the initial Cholesky factorization of the most recent solve.
        public long decompTime;
        // Extra ridge added to the Gram diagonal; grown automatically when factorization fails.
        public double _addedL2;
        // Not referenced by solve, which uses a local tolerance with the same starting value.
        static final double RELTOL = 1.0E-4;

        /**
         * Returns whether lambda is non-zero.
         *
         * @return {@code true} when {@code _lambda} is non-zero
         */
        public boolean normalize() {
            return _lambda != 0;
        }

        /**
         * @param lambda  regularization strength
         * @param alpha   L1/L2 mixing parameter
         * @param gradEps sub-gradient convergence tolerance
         */
        public ADMMSolver(double lambda, double alpha, double gradEps) {
            super(lambda, alpha);
            this._gradientEps = gradEps;
        }

        /**
         * @param lambda  regularization strength
         * @param alpha   L1/L2 mixing parameter
         * @param gradEps sub-gradient convergence tolerance
         * @param addedL2 initial extra ridge added to the Gram diagonal
         */
        public ADMMSolver(double lambda, double alpha, double gradEps, double addedL2) {
            this(lambda, alpha, gradEps);
            this._addedL2 = addedL2;
        }

        /**
         * Solves with {@code rho = Double.POSITIVE_INFINITY}.
         *
         * <p>See the five-argument overload for details and the caveat about infinite rho.
         */
        @Override
        public boolean solve(Gram gram, double[] xy, double yy, double[] z) {
            final double defaultRho = Double.POSITIVE_INFINITY;
            return solve(gram, xy, yy, z, defaultRho);
        }

        /** Sum of absolute values of {@code v} (L1 norm). Not referenced within this class. */
        private static double l1_norm(double[] v) {
            double total = 0.0;
            for (int k = 0; k < v.length; k++) {
                total += Math.abs(v[k]);
            }
            return total;
        }

        /**
         * Sum of squares of {@code v}: the squared L2 norm, not its square root.
         *
         * <p>Not referenced within this class.
         */
        private static double l2_norm(double[] v) {
            double total = 0.0;
            for (int k = 0; k < v.length; k++) {
                final double e = v[k];
                total += e * e;
            }
            return total;
        }

        /**
         * Max absolute component of the penalized sub-gradient at {@code beta}.
         *
         * <p>All entries are included, the last one too. NaN components are ignored.
         * Not referenced within this class.
         */
        private double converged(Gram g, double[] beta, double[] xy) {
            final double[] sg = this.grad(g, beta, xy);
            LSMSolver.subgrad(this._alpha, this._lambda, beta, sg);
            double worst = 0.0;
            for (int k = 0; k < sg.length; k++) {
                final double mag = Math.abs(sg[k]);
                if (mag > worst) {
                    worst = mag;
                }
            }
            return worst;
        }

        /**
         * Max absolute component of the unpenalized gradient {@code gram * beta - xy}.
         *
         * <p>NaN components are ignored. Not referenced within this class.
         */
        private double getGrad(Gram gram, double[] beta, double[] xy) {
            double maxAbs = 0.0;
            for (final double component : this.grad(gram, beta, xy)) {
                final double mag = Math.abs(component);
                if (mag > maxAbs) {
                    maxAbs = mag;
                }
            }
            return maxAbs;
        }

        /**
         * Creates a fork/join ADMM task bound to this solver's settings.
         *
         * @param gram   Gram matrix
         * @param xy     right-hand side
         * @param res    output coefficients
         * @param rho    ADMM penalty parameter
         * @param iBlock block parameter forwarded to {@code Gram.Cholesky.parSolver}
         * @param rBlock block parameter forwarded to {@code Gram.Cholesky.parSolver}
         * @return the (not yet started) task
         */
        public ParallelSolver parSolver(Gram gram, double[] xy, double[] res, double rho, int iBlock, int rBlock) {
            final ParallelSolver task = new ParallelSolver(gram, xy, res, rho, iBlock, rBlock);
            return task;
        }

        /**
         * Solves the elastic-net penalized least-squares problem and writes the coefficients into {@code z}.
         *
         * <p>The L2 part of the penalty (plus {@code _addedL2}) is added to the Gram diagonal, and so is
         * {@code rho} when an L1 penalty is active. An optional proximal term is added as well. The matrix is
         * then Cholesky-factorized. If the factor is not SPD, {@code _addedL2} is grown (1e-5, then x10) for
         * up to 10 retries.
         *
         * <p>Without an L1 penalty the system is solved directly. Otherwise an over-relaxed ADMM loop runs;
         * the last coefficient is never soft-thresholded. The Gram diagonal is restored to its entry offset
         * on every exit path, including exceptions.
         *
         * <p>NOTE(review): the four-argument overload passes {@code rho = +Infinity}. With an active L1
         * penalty, the first ADMM step then evaluates {@code rho * (z[j] - u[j])} with both terms zero, i.e.
         * {@code Infinity * 0 = NaN}. Confirm that callers needing L1 pass a finite rho.
         *
         * @param gram Gram matrix; its diagonal is modified temporarily
         * @param xy   right-hand side; not modified (a copy is made when a proximal term is added)
         * @param yy   not used by this implementation
         * @param z    output coefficients, overwritten
         * @param rho  ADMM penalty parameter
         * @return {@code true} for the direct (no-L1) solve; otherwise whether {@code gerr < _gradientEps}
         * @throws NonSPDMatrixException if the Gram matrix is still not SPD after the retries
         */
        public boolean solve(Gram gram, double[] xy, double yy, double[] z, double rho) {
            this.gerr = 0.0;
            final double d = gram._diagAdded;
            Arrays.fill(z, 0.0);
            int iters;
            try {
                if (this._lambda > 0.0 || this._addedL2 > 0.0) {
                    gram.addDiag(this._lambda * (1.0 - this._alpha) + this._addedL2);
                }
                if (this._alpha > 0.0 && this._lambda > 0.0) {
                    gram.addDiag(rho);
                }
                if (this._proximalPenalty > 0.0 && this._wgiven != null) {
                    gram.addDiag(this._proximalPenalty, true);
                    xy = xy.clone();
                    for (int i = 0; i < xy.length; ++i) {
                        xy[i] += this._proximalPenalty * this._wgiven[i];
                    }
                }
                long t1 = System.currentTimeMillis();
                Gram.Cholesky chol = gram.cholesky(null, true, this._id);
                long t2 = System.currentTimeMillis();
                for (int attempts = 0; !chol.isSPD() && attempts < 10; ++attempts) {
                    this._addedL2 = this._addedL2 == 0.0 ? 1.0E-5 : this._addedL2 * 10.0;
                    gram.addDiag(this._addedL2);
                    gram.cholesky(chol);
                }
                this.decompTime = t2 - t1;
                if (!chol.isSPD()) {
                    throw new NonSPDMatrixException(gram);
                }
                if (this._alpha == 0.0 || this._lambda == 0.0) {
                    System.arraycopy(xy, 0, z, 0, xy.length);
                    chol.solve(z);
                    return true;
                }
                // Start "not converged": a run that never reaches the residual check must not report
                // success. This matches ParallelSolver, which also starts from +Infinity.
                this.gerr = Double.POSITIVE_INFINITY;
                iters = this.runAdmm(gram, chol, xy, z, rho);
            } finally {
                // Undo every diagonal addition made above, on success and on failure alike.
                gram.addDiag(-gram._diagAdded + d);
            }
            assert (gram._diagAdded == d);
            this.iterations = iters;
            this._converged = this.gerr < this._gradientEps;
            return this._converged;
        }

        /**
         * Runs the over-relaxed ADMM iterations on the already-factorized system and updates {@code gerr}.
         *
         * <p>When both the primal and dual residuals fall below the relative tolerance, the max absolute
         * penalized sub-gradient (last entry excluded) is stored in {@code gerr}. The loop stops once it is
         * below 1e-4 or the tolerance has reached 1e-6; otherwise the tolerance is tightened. Every 20
         * iterations the relaxation factor moves toward 1.
         *
         * @return the iteration index at which the loop stopped ({@code max_iter} if exhausted)
         */
        private int runAdmm(Gram gram, Gram.Cholesky chol, double[] xy, double[] z, double rho) {
            final int N = xy.length;
            final double[] u = MemoryManager.malloc8d(N);
            final double[] xyPrime = xy.clone();
            final double kappa = this._lambda * this._alpha / rho;
            final int max_iter = Math.max(500, (int)(50000.0 / (double)(1 + (xy.length >> 3))));
            double orlx = 1.8;
            double reltol = 1.0E-4;
            int i;
            for (i = 0; i < max_iter; ++i) {
                for (int j = 0; j < N - 1; ++j) {
                    xyPrime[j] = xy[j] + rho * (z[j] - u[j]);
                }
                xyPrime[N - 1] = xy[N - 1];
                chol.solve(xyPrime);
                double rnorm = 0.0;
                double snorm = 0.0;
                double unorm = 0.0;
                double xnorm = 0.0;
                for (int j = 0; j < N - 1; ++j) {
                    double x = xyPrime[j];
                    double zold = z[j];
                    double x_hat = x * orlx + (1.0 - orlx) * zold;
                    z[j] = LSMSolver.shrinkage(x_hat + u[j], kappa);
                    u[j] += x_hat - z[j];
                    double r = x - z[j];
                    double s = z[j] - zold;
                    rnorm += r * r;
                    snorm += s * s;
                    xnorm += x * x;
                    unorm += u[j] * u[j];
                }
                z[N - 1] = xyPrime[N - 1];
                if (rnorm < reltol * xnorm && snorm < reltol * unorm) {
                    this.gerr = 0.0;
                    double[] grad = this.grad(gram, z, xy);
                    LSMSolver.subgrad(this._alpha, this._lambda, z, grad);
                    for (int k = 0; k < grad.length - 1; ++k) {
                        double mag = Math.abs(grad[k]);
                        if (mag > this.gerr) {
                            this.gerr = mag;
                        }
                    }
                    if (this.gerr < 1.0E-4 || reltol <= 1.0E-6) {
                        break;
                    }
                    while (rnorm < reltol * xnorm && snorm < reltol * unorm) {
                        reltol *= 0.1;
                    }
                }
                if (i % 20 == 0) {
                    orlx = (1.0 + 15.0 * orlx) * 0.0625;
                }
            }
            return i;
        }

        /** @return the fixed label {@code "ADMM"} */
        @Override
        public String name() {
            final String label = "ADMM";
            return label;
        }

        /**
         * Fork/join variant of the ADMM solve.
         *
         * <p>{@link #compute2()} regularizes and factorizes the Gram matrix like the serial {@code solve}.
         * Each ADMM iteration then runs as an {@code ADMMIteration} task whose linear solve is delegated to
         * {@code Gram.Cholesky.parSolver}, and the next iteration is forked from the previous one's completion.
         * Coefficients are written into the {@code res} array given to {@link ADMMSolver#parSolver}.
         */
        public final class ParallelSolver
        extends H2O.H2OCountedCompleter {
            final Gram gram;
            final double rho;
            // Soft-threshold level lambda * alpha / rho.
            final double kappa;
            double _bestErr = Double.POSITIVE_INFINITY;
            double _lastErr = Double.POSITIVE_INFINITY;
            // Right-hand side. Not final: compute2() swaps in a penalized copy when a proximal term is active,
            // so the caller's array is never modified (matching the serial solve).
            double[] xy;
            // Right-hand side / solution buffer of the linear step, reused across iterations.
            double[] _xyPrime;
            // Over-relaxation factor used when forming x_hat.
            double _orlx;
            // Iteration at which the next sub-gradient check happens; advanced by `round`.
            int _k;
            // Scaled dual variable.
            final double[] u;
            // Output coefficients (the caller's `res` array).
            final double[] z;
            Gram.Cholesky chol;
            // Gram diagonal offset at construction; restored when the solve finishes.
            final double d;
            int _iter;
            final int N;
            final int max_iter;
            // Number of iterations between sub-gradient checks.
            final int round;
            final int _iBlock;
            final int _rBlock;

            private ParallelSolver(Gram g, double[] xy, double[] res, double rho, int iBlock, int rBlock) {
                this._iBlock = iBlock;
                this._rBlock = rBlock;
                this.gram = g;
                this.xy = xy;
                this.z = res;
                this.N = xy.length;
                this.d = this.gram._diagAdded;
                this.rho = rho;
                this.u = MemoryManager.malloc8d(this.N);
                this.kappa = ADMMSolver.this._lambda * ADMMSolver.this._alpha / rho;
                this.max_iter = (int)(10000.0 * (250.0 / (double)(1 + xy.length)));
                this._k = this.round = Math.max(20, (int)((double)this.max_iter * 0.01));
            }

            public void compute2() {
                Arrays.fill(this.z, 0.0);
                if (ADMMSolver.this._lambda > 0.0 || ADMMSolver.this._addedL2 > 0.0) {
                    this.gram.addDiag(ADMMSolver.this._lambda * (1.0 - ADMMSolver.this._alpha) + ADMMSolver.this._addedL2);
                }
                if (ADMMSolver.this._alpha > 0.0 && ADMMSolver.this._lambda > 0.0) {
                    this.gram.addDiag(this.rho);
                }
                if (ADMMSolver.this._proximalPenalty > 0.0 && ADMMSolver.this._wgiven != null) {
                    this.gram.addDiag(ADMMSolver.this._proximalPenalty, true);
                    // Penalize a private copy so the caller's xy does not accumulate the term across solves.
                    this.xy = this.xy.clone();
                    for (int i = 0; i < this.xy.length; ++i) {
                        this.xy[i] += ADMMSolver.this._proximalPenalty * ADMMSolver.this._wgiven[i];
                    }
                }
                long t1 = System.currentTimeMillis();
                this.chol = this.gram.cholesky(null, true, ADMMSolver.this._id);
                long t2 = System.currentTimeMillis();
                for (int attempts = 0; !this.chol.isSPD() && attempts < 10; ++attempts) {
                    ADMMSolver.this._addedL2 = ADMMSolver.this._addedL2 == 0.0 ? 1.0E-5 : ADMMSolver.this._addedL2 * 10.0;
                    this.gram.addDiag(ADMMSolver.this._addedL2);
                    this.gram.cholesky(this.chol);
                }
                ADMMSolver.this.decompTime = t2 - t1;
                if (!this.chol.isSPD()) {
                    // Build the message from the regularized matrix first (as before), then undo the diagonal.
                    // onCompletion() is presumably not invoked on exceptional completion; restoring twice is
                    // harmless because the second call adds -d + d = 0.
                    NonSPDMatrixException failure = new NonSPDMatrixException(this.gram);
                    this.restoreGramDiag();
                    throw failure;
                }
                if (ADMMSolver.this._alpha == 0.0 || ADMMSolver.this._lambda == 0.0) {
                    System.arraycopy(this.xy, 0, this.z, 0, this.xy.length);
                    this.chol.parSolver((CountedCompleter)this, this.z, this._iBlock, this._rBlock).fork();
                    return;
                }
                ADMMSolver.this.gerr = Double.POSITIVE_INFINITY;
                this._xyPrime = (double[])this.xy.clone();
                this._orlx = 1.8;
                new ADMMIteration(this).fork();
            }

            public void onCompletion(CountedCompleter caller) {
                this.restoreGramDiag();
                assert (this.gram._diagAdded == this.d);
            }

            // Undoes every diagonal addition made since construction.
            private void restoreGramDiag() {
                this.gram.addDiag(-this.gram._diagAdded + this.d);
            }

            /** One ADMM iteration: forks the linear solve, then updates z/u and schedules the next iteration. */
            private final class ADMMIteration
            extends CountedCompleter {
                final long t1;

                public ADMMIteration(H2O.H2OCountedCompleter cmp) {
                    super((CountedCompleter)cmp);
                    this.t1 = System.currentTimeMillis();
                }

                public void compute() {
                    ++ParallelSolver.this._iter;
                    double[] xyPrime = ParallelSolver.this._xyPrime;
                    for (int j = 0; j < ParallelSolver.this.N - 1; ++j) {
                        xyPrime[j] = ParallelSolver.this.xy[j] + ParallelSolver.this.rho * (ParallelSolver.this.z[j] - ParallelSolver.this.u[j]);
                    }
                    xyPrime[ParallelSolver.this.N - 1] = ParallelSolver.this.xy[ParallelSolver.this.N - 1];
                    ParallelSolver.this.chol.parSolver(this, xyPrime, ParallelSolver.this._iBlock, ParallelSolver.this._rBlock).fork();
                }

                public void onCompletion(CountedCompleter caller) {
                    double[] xyPrime = ParallelSolver.this._xyPrime;
                    double orlx = ParallelSolver.this._orlx;
                    for (int j = 0; j < ParallelSolver.this.N - 1; ++j) {
                        double x_hat = xyPrime[j];
                        x_hat = x_hat * orlx + (1.0 - orlx) * ParallelSolver.this.z[j];
                        ParallelSolver.this.z[j] = LSMSolver.shrinkage(x_hat + ParallelSolver.this.u[j], ParallelSolver.this.kappa);
                        ParallelSolver.this.u[j] += x_hat - ParallelSolver.this.z[j];
                    }
                    ParallelSolver.this.z[ParallelSolver.this.N - 1] = xyPrime[ParallelSolver.this.N - 1];
                    if (ParallelSolver.this._iter == ParallelSolver.this._k) {
                        double[] grad = ADMMSolver.this.grad(ParallelSolver.this.gram, ParallelSolver.this.z, ParallelSolver.this.xy);
                        LSMSolver.subgrad(ADMMSolver.this._alpha, ADMMSolver.this._lambda, ParallelSolver.this.z, grad);
                        // Fresh max |sub-gradient| (last entry excluded), as in the serial solve. Accumulating
                        // into the +Infinity start value would never update it, and must store a magnitude.
                        double err = 0.0;
                        for (int x = 0; x < grad.length - 1; ++x) {
                            double mag = Math.abs(grad[x]);
                            if (mag > err) {
                                err = mag;
                            }
                        }
                        ADMMSolver.this.gerr = err;
                        if (err < 9.0E-4) {
                            return;
                        }
                        ParallelSolver.this._k += ParallelSolver.this.round;
                    }
                    if (ParallelSolver.this._iter < ParallelSolver.this.max_iter) {
                        this.getCompleter().addToPendingCount(1);
                        new ADMMIteration((H2O.H2OCountedCompleter)this.getCompleter()).fork();
                    }
                }
            }
        }

        /** Signals that a Gram matrix is not symmetric positive definite and so cannot be Cholesky-factorized. */
        public static class NonSPDMatrixException extends LSMSolverException {
            private static final String MESSAGE = "Matrix is not SPD, can't solve without regularization\n";

            public NonSPDMatrixException() {
                super(MESSAGE);
            }

            /** @param grm the offending matrix; its string form is appended to the message */
            public NonSPDMatrixException(Gram grm) {
                super(MESSAGE + grm);
            }
        }
    }

    /** Unchecked base exception for failures raised by the least-squares solvers. */
    public static class LSMSolverException extends RuntimeException {
        /** @param msg detail message */
        public LSMSolverException(String msg) {
            super(msg);
        }
    }

    /** Selectable least-squares solver kinds. */
    public static enum LSMSolverType { AUTO, ADMM, GenGradient }
}

