/*
 * Decompiled with CFR 0.152.
 */
package hex.optimization;

import java.util.Arrays;
import water.H2O;
import water.MemoryManager;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.MathUtils;

public class ADMM {

    /**
     * Soft-thresholding (shrinkage) operator.
     *
     * @param x     value to shrink towards zero
     * @param kappa threshold
     * @return {@code sign(x) * max(|x| - kappa, 0)}; exactly {@code 0.0} when {@code |x| <= kappa}
     */
    public static double shrinkage(double x, double kappa) {
        double sign = x < 0.0 ? -1.0 : 1.0;
        double sx = x * sign;
        if (sx <= kappa) {
            return 0.0;
        }
        return sign * (sx - kappa);
    }

    /**
     * Turns the gradient of the smooth loss into a subgradient of
     * {@code loss(beta) + lambda * ||beta||_1}, in place.
     *
     * <p>For {@code beta[i] != 0} the penalty derivative {@code lambda * sign(beta[i])} is added
     * (and the result shrunk by a tiny {@code lambda * 1e-4} tolerance). For {@code beta[i] == 0}
     * the minimum-norm element of the subdifferential, {@code shrinkage(grad[i], lambda)}, is used.
     * The last entry (the intercept slot in {@code L1Solver.solve}) is not penalized and is left
     * untouched.
     *
     * @param lambda L1 penalty
     * @param beta   current coefficients; if {@code null} this method does nothing
     * @param grad   gradient of the smooth part; overwritten with the subgradient
     */
    public static void subgrad(double lambda, double[] beta, double[] grad) {
        if (beta == null) {
            return;
        }
        final double tol = lambda * 1.0E-4;
        for (int i = 0; i < grad.length - 1; ++i) {
            if (beta[i] < 0.0) {
                grad[i] = shrinkage(grad[i] - lambda, tol);
            } else if (beta[i] > 0.0) {
                grad[i] = shrinkage(grad[i] + lambda, tol);
            } else {
                grad[i] = shrinkage(grad[i], lambda);
            }
        }
    }

    /**
     * ADMM solver for an L1-penalized problem with optional box constraints {@code lb <= z <= ub}.
     *
     * <p>The smooth part is delegated to a {@link ProximalSolver}. The L1 penalty and the bounds are
     * applied coordinate-wise by soft-thresholding and clamping.
     */
    public static class L1Solver {
        // Relative tolerance of the primal/dual residual stopping test.
        final double RELTOL;
        // Absolute tolerance of the residual stopping test (scaled by sqrt(N) in solve).
        final double ABSTOL;
        // Last optimality residual computed by computeErr.
        double gerr;
        // Iteration index at which solve converged, or max_iter if it did not.
        int iter;
        // Residual (gerr) at or below which solve declares convergence.
        final double _eps;
        final int max_iter;
        // Norm used to reduce the residual vector to the scalar gerr.
        MathUtils.Norm _gradientNorm = MathUtils.Norm.L_Infinite;
        public static double DEFAULT_RELTOL = 0.01;
        public static double DEFAULT_ABSTOL = 1.0E-4;

        /** Sets the norm used for the optimality residual; returns {@code this} for chaining. */
        public L1Solver setGradientNorm(MathUtils.Norm n) {
            this._gradientNorm = n;
            return this;
        }

        /** Creates a solver using {@link #DEFAULT_RELTOL} and {@link #DEFAULT_ABSTOL}. */
        public L1Solver(double eps, int max_iter) {
            this(eps, max_iter, DEFAULT_RELTOL, DEFAULT_ABSTOL);
        }

        /**
         * @param eps      residual threshold for convergence
         * @param max_iter maximum number of ADMM iterations
         * @param reltol   initial relative tolerance of the residual stopping test
         * @param abstol   initial absolute tolerance of the residual stopping test
         */
        public L1Solver(double eps, int max_iter, double reltol, double abstol) {
            this._eps = eps;
            this.max_iter = max_iter;
            this.RELTOL = reltol;
            this.ABSTOL = abstol;
        }

        /** Solves without bounds, treating the last coefficient of {@code res} as an intercept. */
        public boolean solve(ProximalSolver solver, double[] res, double lambda) {
            return this.solve(solver, res, lambda, true, null, null);
        }

        /**
         * Computes the optimality residual of {@code z}, stores it in {@link #gerr}, and returns it.
         *
         * <p>The subgradient from {@link ADMM#subgrad} is projected onto the feasible box. At a
         * coordinate exactly on its lower bound only a negative component counts, because only
         * increasing {@code z[j]} is feasible. At an upper bound only a positive component counts.
         * Projecting after {@code subgrad} (rather than patching the raw gradient before it) also
         * handles the unpenalized last entry, which {@code subgrad} leaves untouched.
         *
         * @param z      current solution (not modified)
         * @param grad   gradient of the smooth part at {@code z} (not modified; a copy is used)
         * @param lambda L1 penalty
         * @param lb     lower bounds, or {@code null}
         * @param ub     upper bounds, or {@code null}
         * @return the residual norm, also stored in {@link #gerr}
         * @throws RuntimeException from {@code H2O.unimpl()} for an unsupported {@link #_gradientNorm}
         */
        private double computeErr(double[] z, double[] grad, double lambda, double[] lb, double[] ub) {
            final double[] g = grad.clone();
            ADMM.subgrad(lambda, z, g);
            if (lb != null || ub != null) {
                for (int j = 0; j < z.length; ++j) {
                    if (lb != null && z[j] == lb[j]) {
                        g[j] = Math.min(g[j], 0.0);
                    }
                    if (ub != null && z[j] == ub[j]) {
                        g[j] = Math.max(g[j], 0.0);
                    }
                }
            }
            switch (this._gradientNorm) {
                case L_Infinite:
                    this.gerr = ArrayUtils.linfnorm(g, false);
                    break;
                case L2_2:
                    this.gerr = ArrayUtils.l2norm2(g, false);
                    break;
                case L2:
                    this.gerr = Math.sqrt(ArrayUtils.l2norm2(g, false));
                    break;
                case L1:
                    this.gerr = ArrayUtils.l1norm(g, false);
                    break;
                default:
                    throw H2O.unimpl();
            }
            return this.gerr;
        }

        /**
         * Runs ADMM on {@code z}, alternating between the inner solver and a coordinate-wise
         * soft-threshold/clamp step.
         *
         * @param solver       inner solver; {@code solver.solve(beta_given, x)} is called once per
         *                     iteration, and {@code solver.gradient(z)} when residuals get small
         * @param z            initial point on entry, solution on exit (modified in place)
         * @param l1pen        L1 penalty; the last coordinate is never penalized
         * @param hasIntercept whether the last coordinate is updated during ADMM iterations
         *                     (projected onto its bounds only); if {@code false} the iterations
         *                     leave it untouched
         * @param lb           lower bounds, or {@code null}
         * @param ub           upper bounds, or {@code null}
         * @return {@code true} if converged (or no penalty and no bounds were given);
         *         {@code false} if {@code max_iter} was reached
         */
        public boolean solve(ProximalSolver solver, double[] z, double l1pen, boolean hasIntercept, double[] lb, double[] ub) {
            this.gerr = Double.POSITIVE_INFINITY;
            if (l1pen == 0.0 && lb == null && ub == null) {
                // Nothing non-smooth to split off: the inner solver handles the whole problem.
                solver.solve(null, z);
                return true;
            }
            final int N = z.length;
            final double[] rho = solver.rho();
            final double[] u = MemoryManager.malloc8d(N);          // running sum of primal residuals
            final double[] x = z.clone();
            final double[] beta_given = MemoryManager.malloc8d(N); // z - u, handed to the inner solver
            final double[] kappa = MemoryManager.malloc8d(rho.length);
            if (l1pen > 0.0) {
                for (int j = 0; j < N - 1; ++j) {
                    kappa[j] = rho[j] != 0.0 ? l1pen / rho[j] : 0.0;
                }
            }
            final double orlx = 1.0; // relaxation factor; 1.0 means x_hat == x
            double abstol = this.ABSTOL * Math.sqrt(N);
            double reltol = this.RELTOL;
            for (int i = 0; i < this.max_iter; ++i) {
                solver.solve(beta_given, x);
                double rnorm = 0.0;
                double snorm = 0.0;
                double unorm = 0.0;
                double xnorm = 0.0;
                // Penalized coordinates: soft-threshold, then clamp to the bounds.
                for (int j = 0; j < N - 1; ++j) {
                    final double xj = x[j];
                    final double zjold = z[j];
                    final double x_hat = xj * orlx + (1.0 - orlx) * zjold;
                    double zj = shrinkage(x_hat + u[j], kappa[j]);
                    if (lb != null && zj < lb[j]) {
                        zj = lb[j];
                    }
                    if (ub != null && zj > ub[j]) {
                        zj = ub[j];
                    }
                    u[j] += x_hat - zj;
                    beta_given[j] = zj - u[j];
                    final double r = xj - zj;
                    final double s = zj - zjold;
                    rnorm += r * r;
                    snorm += s * s;
                    xnorm += xj * xj;
                    unorm += rho[j] * rho[j] * u[j] * u[j];
                    z[j] = zj;
                }
                if (hasIntercept) {
                    // The intercept is unpenalized: only project it onto its bounds.
                    final int idx = N - 1;
                    double icpt = x[idx];
                    if (lb != null && icpt < lb[idx]) {
                        icpt = lb[idx];
                    }
                    if (ub != null && icpt > ub[idx]) {
                        icpt = ub[idx];
                    }
                    final double r = x[idx] - icpt;
                    final double s = icpt - z[idx];
                    u[idx] += r;
                    beta_given[idx] = icpt - u[idx];
                    rnorm += r * r;
                    snorm += s * s;
                    xnorm += icpt * icpt;
                    unorm += rho[idx] * rho[idx] * u[idx] * u[idx];
                    z[idx] = icpt;
                }
                // NOTE(review): rnorm/snorm are squared norms compared against un-squared
                // thresholds; kept as-is to preserve existing convergence tuning.
                final boolean residualsSmall = rnorm < abstol + reltol * Math.sqrt(xnorm)
                        && snorm < abstol + reltol * Math.sqrt(unorm);
                if (!residualsSmall) {
                    continue;
                }
                final double oldGerr = this.gerr;
                this.computeErr(z, solver.gradient(z), l1pen, lb, ub);
                if (this.gerr > this._eps) {
                    Log.debug("ADMM.L1Solver: iter = " + i + " , gerr =  " + this.gerr + ", oldGerr = " + oldGerr + ", rnorm = " + rnorm + ", snorm  " + snorm);
                    // Residuals are small but z is not optimal yet: tighten both tolerances,
                    // each down to a fixed floor, and keep iterating.
                    if (abstol > 1.0E-12) {
                        abstol *= 0.1;
                    }
                    if (reltol > 1.0E-10) {
                        reltol *= 0.1;
                    }
                    continue;
                }
                this.iter = i;
                Log.info("ADMM.L1Solver: converged at iteration = " + i + ", gerr = " + this.gerr + ", inner solver took " + solver.iter() + " iterations");
                return true;
            }
            this.computeErr(z, solver.gradient(z), l1pen, lb, ub);
            Log.warn("ADMM solver reached maximum number of iterations (" + this.max_iter + ")");
            if (this.gerr > this._eps) {
                Log.warn("ADMM solver finished with gerr = " + this.gerr + " >  eps = " + this._eps);
            }
            this.iter = this.max_iter;
            return false;
        }

        /**
         * Heuristic per-coordinate ADMM penalty ({@code rho}) for a coefficient with value {@code x}.
         *
         * <p>L1 estimate: with an L1 penalty and {@code x != 0}, the estimate is
         * {@code 0.25 * (l1pen + sqrt(l1pen * (l1pen + 4|x|))) / (2|x|)}, used when that ratio is
         * positive (otherwise 0, with a warning).
         *
         * <p>Bounds override: if either bound is finite, the L1 estimate is discarded. It is replaced
         * by {@code 10} when {@code x} is within {@code 1e-4} of (or outside) a bound, and by
         * {@code 0.1} otherwise.
         *
         * @return the estimate; {@code 0} for infinite {@code x}
         */
        public static double estimateRho(double x, double l1pen, double lb, double ub) {
            if (Double.isInfinite(x)) {
                return 0.0;
            }
            double rho = 0.0;
            if (l1pen != 0.0 && x != 0.0) {
                if (x > 0.0) {
                    final double D = l1pen * (l1pen + 4.0 * x);
                    if (D >= 0.0) {
                        final double r = (l1pen + Math.sqrt(D)) / (2.0 * x);
                        if (r > 0.0) {
                            rho = r;
                        } else {
                            Log.warn("negative rho estimate(1)! r = " + r);
                        }
                    }
                } else if (x < 0.0) {
                    final double D = l1pen * (l1pen - 4.0 * x);
                    if (D >= 0.0) {
                        final double r = -(l1pen + Math.sqrt(D)) / (2.0 * x);
                        if (r > 0.0) {
                            rho = r;
                        } else {
                            Log.warn("negative rho estimate(2)!  r = " + r);
                        }
                    }
                }
                rho *= 0.25;
            }
            if (!Double.isInfinite(lb) || !Double.isInfinite(ub)) {
                // Distance to the nearest bound; negative when x lies outside the box.
                final boolean oob = Math.min(x - lb, ub - x) < 1.0E-4;
                rho = oob ? 10.0 : 0.1;
            }
            return rho;
        }
    }

    /**
     * Inner solver for the smooth part of the problem, driven by {@link L1Solver}.
     * NOTE(review): presumably solves min f(beta) + sum_j rho[j]/2 * (beta_j - beta_given_j)^2;
     * confirm against implementations.
     */
    public interface ProximalSolver {
        /** Per-coordinate penalties; {@link L1Solver} uses {@code l1pen / rho[j]} as shrinkage threshold. */
        double[] rho();

        /**
         * Solves the inner problem.
         *
         * @param beta_given target from {@link L1Solver} ({@code z - u}), or {@code null} when there
         *                   is no L1 penalty and no bounds
         * @param result     receives the solution; holds the previous iterate on entry
         *                   (presumably a warm start; TODO confirm)
         * @return implementation-defined flag; ignored by {@link L1Solver}
         */
        boolean solve(double[] beta_given, double[] result);

        /** Whether {@link #gradient} is available. NOTE(review): not consulted by {@link L1Solver}. */
        boolean hasGradient();

        /** Gradient of the smooth part at {@code beta}; used for the optimality residual. */
        double[] gradient(double[] beta);

        /** Iteration count of the inner solver; logged by {@link L1Solver} on convergence. */
        int iter();
    }
}

