/*
 * Decompiled with CFR 0.152.
 */
package hex.optimization;

import java.util.Arrays;
import java.util.Random;
import water.Iced;
import water.MemoryManager;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.MathUtils;

/**
 * Limited-memory BFGS (L-BFGS) minimizer with a batched backtracking line search.
 *
 * <p>The objective and its gradient are supplied by a {@link GradientSolver}. Each iteration
 * computes a quasi-Newton search direction from the most recent curvature pairs kept in a
 * {@link History}, then evaluates {@code L_BFGS_Params._nBetas} candidate step lengths along that
 * direction in a single {@link GradientSolver#getGradient(double[][])} call and accepts the
 * longest one that satisfies the sufficient-decrease (Armijo) condition.
 */
public class L_BFGS {
    /** Sufficient-decrease constant used by the line search ({@code needLineSearch}). */
    public static final double c1 = 0.1;

    /** Number of curvature pairs kept when the caller does not supply a {@link History}. */
    private static final int DEFAULT_HISTORY_SIZE = 20;

    /**
     * Minimizes the objective starting from {@code coefs}, keeping the
     * {@value #DEFAULT_HISTORY_SIZE} most recent curvature pairs.
     *
     * @param gslvr  provides objective values and gradients
     * @param params solver settings
     * @param coefs  starting point; updated in place and returned as {@link Result#coefs}
     * @return the final iterate together with the gradient information at it
     */
    public static final Result solve(GradientSolver gslvr, L_BFGS_Params params, double[] coefs) {
        return solve(gslvr, params, new History(DEFAULT_HISTORY_SIZE, coefs.length), coefs);
    }

    /**
     * Minimizes the objective starting from {@code coefs}, using and updating the supplied
     * curvature history.
     *
     * <p>Iteration stops when any of the following happens:
     * <ul>
     *   <li>{@code params._maxIter} is reached;</li>
     *   <li>{@code MathUtils.l2norm2} of the gradient drops to {@code params._gradEps} or below
     *       (presumably the squared L2 norm);</li>
     *   <li>the line search cannot find an acceptable step of at least {@code params._minStep}.</li>
     * </ul>
     *
     * <p>NOTE(review): {@link Result#iter} is the raw loop counter. It is incremented before the
     * stopping test, so it always equals the number of accepted steps plus one. It is kept as-is
     * because callers may rely on the reported value.
     *
     * @param gslvr  provides objective values and gradients; during the line search it is called
     *               with batches of {@code params._nBetas} candidate points
     * @param params solver settings
     * @param hist   curvature history; should be created with {@code n == coefs.length}
     * @param coefs  starting point; updated in place and returned as {@link Result#coefs}
     * @return the final iterate together with the gradient information at it
     */
    public static final Result solve(GradientSolver gslvr, L_BFGS_Params params, History hist, double[] coefs) {
        GradientInfo gOld = gslvr.getGradient(coefs);
        final double[] beta = coefs; // the iterate; updated in place
        // Scratch buffers for the candidate points of one line-search batch, reused by every batch.
        final double[][] lsBetas = new double[params._nBetas][];
        for (int i = 0; i < lsBetas.length; ++i) {
            lsBetas[i] = MemoryManager.malloc8d(beta.length);
        }
        int iter = 0;
        outer:
        while (iter++ < params._maxIter && MathUtils.l2norm2(gOld._gradient) > params._gradEps) {
            final double[] pk = getSearchDirection(iter - 1, hist, gOld._gradient);
            // Every new search direction starts from a unit step.
            double step = 1.0;
            while (step > params._minStep) {
                // Candidate j of this batch is beta + (step * stepDec^j) * pk.
                double t = step;
                for (double[] candidate : lsBetas) {
                    wadd(candidate, beta, pk, t);
                    t *= params._stepDec;
                }
                final GradientInfo[] ginfos = gslvr.getGradient(lsBetas);
                // Walk the candidates from the longest step down and accept the first that passes.
                t = step;
                for (int i = 0; i < ginfos.length; ++i) {
                    if (t < params._minStep) {
                        break outer; // step became too small: give up
                    }
                    if (!needLineSearch(t, gOld._objVal, ginfos[i]._objVal, pk, gOld._gradient)) {
                        ArrayUtils.mult(pk, t); // pk now holds the accepted displacement s_k
                        hist.update(iter - 1, pk, ginfos[i]._gradient, gOld._gradient);
                        gOld = ginfos[i];
                        ArrayUtils.add(beta, pk);
                        assert Arrays.equals(beta, lsBetas[i]);
                        continue outer;
                    }
                    t *= params._stepDec;
                }
                // Whole batch rejected: continue with the next, smaller batch of step lengths.
                step = t;
            }
            break outer; // no acceptable step above _minStep along this direction
        }
        Log.info("L_BFGS done after " + iter + " iterations");
        return new Result(iter, beta, gOld);
    }

    /**
     * Computes the L-BFGS search direction {@code -H_k * gradient} using the two-loop recursion.
     *
     * @param iter     number of curvature pairs recorded so far; only the most recent
     *                 {@code min(iter, hist._m)} of them are used
     * @param hist     curvature history
     * @param gradient gradient at the current iterate; not modified
     * @return a newly allocated search direction
     */
    private static final double[] getSearchDirection(int iter, History hist, double[] gradient) {
        final int k = Math.min(iter, hist._m); // number of pairs actually available
        final double[] alpha = MemoryManager.malloc8d(hist._m);
        final double[] q = gradient.clone();
        // First loop: newest pair to oldest.
        for (int i = 1; i <= k; ++i) {
            final int idx = iter - i;
            alpha[i - 1] = hist.rho(idx) * ArrayUtils.innerProduct(hist.getS(idx), q);
            MathUtils.wadd(q, hist.getY(idx), -alpha[i - 1]);
        }
        // Scale by the initial inverse-Hessian guess gamma * I, where gamma = s'y / y'y for the
        // newest pair.
        if (iter > 0) {
            final double[] s = hist.getS(iter - 1);
            final double[] y = hist.getY(iter - 1);
            final double hk0 = ArrayUtils.innerProduct(s, y) / ArrayUtils.innerProduct(y, y);
            ArrayUtils.mult(q, hk0);
        }
        // Second loop: oldest pair to newest.
        for (int i = k; i > 0; --i) {
            final int idx = iter - i;
            final double b = hist.rho(idx) * ArrayUtils.innerProduct(hist.getY(idx), q);
            MathUtils.wadd(q, hist.getS(idx), alpha[i - 1] - b);
        }
        ArrayUtils.mult(q, -1.0);
        return q;
    }

    /**
     * Writes {@code x + w * y} element-wise into {@code res}.
     *
     * @param res destination; must be at least {@code x.length} long
     * @param x   base vector
     * @param y   direction vector
     * @param w   weight applied to {@code y}
     * @return {@code res}
     */
    private static final double[] wadd(double[] res, double[] x, double[] y, double w) {
        for (int i = 0; i < x.length; ++i) {
            res[i] = x[i] + w * y[i];
        }
        // Return the array that was actually written (previously returned x by mistake).
        return res;
    }

    /**
     * Returns {@code n} coefficients drawn independently from a standard normal distribution.
     *
     * @param n    number of coefficients
     * @param seed seed for {@link Random}, making the result reproducible
     * @return a newly allocated array of length {@code n}
     */
    public static double[] startCoefs(int n, long seed) {
        double[] res = MemoryManager.malloc8d(n);
        Random r = new Random(seed);
        for (int i = 0; i < res.length; ++i) {
            res[i] = r.nextGaussian();
        }
        return res;
    }

    /**
     * Tests whether a step of length {@code step} along {@code pk} fails the sufficient-decrease
     * (Armijo) condition {@code objNew <= objOld + c1 * step * gradOld'pk}.
     *
     * @return {@code true} when the condition fails, i.e. the line search must try a shorter step
     */
    private static final boolean needLineSearch(double step, double objOld, double objNew, double[] pk, double[] gradOld) {
        double slope = 0.0; // directional derivative gradOld'pk
        for (int i = 0; i < pk.length; ++i) {
            slope += gradOld[i] * pk[i];
        }
        return objNew > c1 * step * slope + objOld;
    }

    /** Tunable settings for {@link L_BFGS#solve}. */
    public static final class L_BFGS_Params
    extends Iced {
        /** Upper bound on the number of iterations. */
        public int _maxIter = 1000;
        /** Convergence threshold on {@code MathUtils.l2norm2} of the gradient. */
        public double _gradEps = 1.0E-5;
        /** Number of candidate step lengths evaluated per line-search batch. */
        public int _nBetas = 16;
        /** Factor by which the step length shrinks between consecutive candidates. */
        public double _stepDec = 0.8;
        /**
         * Smallest step the line search will try.
         *
         * <p>It is computed once from the default {@code _stepDec} and {@code _nBetas} when the
         * object is constructed. It is NOT recomputed if those fields are changed afterwards.
         */
        public double _minStep = Math.pow(this._stepDec, this._nBetas * 2);
    }

    /**
     * Ring buffer of the last {@code m} L-BFGS curvature pairs, indexed by iteration number modulo
     * {@code m}.
     *
     * <p>Each slot holds {@code s} (the accepted displacement), {@code y} (the gradient
     * difference) and {@code rho = 1 / (s'y)}. Unwritten slots are filled with NaN.
     */
    public static final class History {
        private final double[][] _s;
        private final double[][] _y;
        private final double[] _rho;
        /** Number of pairs retained; must be positive because indices are taken modulo it. */
        final int _m;
        /** Length of each stored vector. */
        final int _n;

        /**
         * @param m number of curvature pairs to keep
         * @param n length of the coefficient vector
         */
        public History(int m, int n) {
            this._m = m;
            this._n = n;
            this._s = new double[m][];
            this._y = new double[m][];
            this._rho = MemoryManager.malloc8d(m);
            Arrays.fill(this._rho, Double.NaN);
            for (int i = 0; i < m; ++i) {
                this._s[i] = MemoryManager.malloc8d(n);
                Arrays.fill(this._s[i], Double.NaN);
                this._y[i] = MemoryManager.malloc8d(n);
                Arrays.fill(this._y[i], Double.NaN);
            }
        }

        /** Returns the gradient difference stored for iteration {@code k}. */
        double[] getY(int k) {
            return this._y[k % this._m];
        }

        /** Returns the displacement stored for iteration {@code k}. */
        double[] getS(int k) {
            return this._s[k % this._m];
        }

        /** Returns {@code 1 / (s'y)} stored for iteration {@code i}. */
        double rho(int i) {
            return this._rho[i % this._m];
        }

        /**
         * Records the pair for iteration {@code iter}: {@code s = pk} and {@code y = gNew - gOld}.
         * Once more than {@code _m} pairs have been recorded, the oldest slot is overwritten.
         */
        private final void update(int iter, double[] pk, double[] gNew, double[] gOld) {
            assert (iter >= 0);
            int id = iter % this._m;
            double[] gradDiff = this._y[id];
            for (int i = 0; i < gNew.length; ++i) {
                gradDiff[i] = gNew[i] - gOld[i];
            }
            System.arraycopy(pk, 0, this._s[id], 0, pk.length);
            this._rho[id] = 1.0 / ArrayUtils.innerProduct(this._s[id], this._y[id]);
        }
    }

    /** Outcome of {@link L_BFGS#solve}: iteration counter, final coefficients and gradient info. */
    public static final class Result {
        public final int iter;
        public final double[] coefs;
        public final GradientInfo ginfo;

        public Result(int iter, double[] coefs, GradientInfo ginfo) {
            this.iter = iter;
            this.coefs = coefs;
            this.ginfo = ginfo;
        }

        /** Prints everything for fewer than 50 coefficients, otherwise only the first/last two. */
        @Override
        public String toString() {
            final double[] grad = this.ginfo._gradient;
            if (this.coefs.length < 50) {
                return "L-BFGS_res(iter = " + this.iter + ", obj = " + this.ginfo._objVal + ", " + " coefs = " + Arrays.toString(this.coefs) + ", grad = " + Arrays.toString(grad) + ")";
            }
            final int nc = this.coefs.length;
            final int ng = grad.length;
            return "L-BFGS_res(iter = " + this.iter + ", obj = " + this.ginfo._objVal
                + ", coefs = [" + this.coefs[0] + ", " + this.coefs[1] + ", ..., " + this.coefs[nc - 2] + ", " + this.coefs[nc - 1] + "]"
                + ", grad = [" + grad[0] + ", " + grad[1] + ", ..., " + grad[ng - 2] + ", " + grad[ng - 1] + "])"
                + "|grad|^2 = " + MathUtils.l2norm2(grad);
        }
    }

    /** Source of objective values and gradients for the optimizer. */
    public static abstract class GradientSolver {
        /**
         * Evaluates the objective and its gradient at each of the given points.
         *
         * @param betas points to evaluate
         * @return one {@link GradientInfo} per point, in the same order ({@code solve} relies on
         *         this pairing)
         */
        public abstract GradientInfo[] getGradient(double[][] betas);

        /** Convenience overload that evaluates a single point. */
        public final GradientInfo getGradient(double[] betas) {
            return this.getGradient(new double[][]{betas})[0];
        }
    }

    /** Objective value and gradient at a single point. */
    public static class GradientInfo {
        public final double _objVal;
        public final double[] _gradient;

        public GradientInfo(double objVal, double[] grad) {
            this._objVal = objVal;
            this._gradient = grad;
        }

        @Override
        public String toString() {
            return " objVal = " + this._objVal + ", " + Arrays.toString(this._gradient);
        }
    }
}

