/*
 * Decompiled with CFR 0.152.
 */
package hex.optimization;

import java.util.Arrays;
import java.util.Random;
import water.Iced;
import water.MemoryManager;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.MathUtils;

/**
 * Limited-memory BFGS (L-BFGS) minimizer with a backtracking (Armijo) line search.
 *
 * <p>Each iteration computes a quasi-Newton search direction from the stored curvature
 * {@link History} and then evaluates a batch of up to {@code _nBetas} candidate points
 * {@code beta + t * pk}, with {@code t} shrinking geometrically by {@code _stepDec}, in a single call
 * to {@link GradientSolver#getGradient(double[][])}. The first (largest) candidate that has a valid
 * gradient and satisfies the sufficient-decrease condition is accepted.
 *
 * <p>The solver keeps mutable state (the history) and performs no synchronization; the history is
 * reused by later {@code solve} calls on coefficient vectors of the same length.
 */
public final class L_BFGS
extends Iced {
    /** Maximum number of accepted iterations performed by {@link #solve}. */
    int _maxIter = 500;
    /** Convergence threshold, compared against the SQUARED L2 norm of the gradient. */
    double _gradEps = 1.0E-8;
    /** Number of candidate step sizes evaluated per batch during the line search. */
    int _nBetas = 32;
    /** Factor by which the step size shrinks between consecutive line-search candidates. */
    double _stepDec = 0.7;
    /** Smallest step size the line search tries before giving up. */
    double _minStep = Math.pow(this._stepDec, this._nBetas);
    /** Number of (s, y) correction pairs kept in the history. */
    int _historySz = 20;
    /** Curvature history; created lazily by {@link #solve}. */
    History _hist;
    /** Sufficient-decrease (Armijo) constant used by the line search. */
    public static final double c1 = 0.1;

    /** Sets the maximum number of accepted iterations; returns this instance for chaining. */
    public L_BFGS setMaxIter(int m) {
        this._maxIter = m;
        return this;
    }

    /** Sets the threshold on the squared gradient norm; returns this instance for chaining. */
    public L_BFGS setGradEps(double d) {
        this._gradEps = d;
        return this;
    }

    /**
     * Sets the number of correction pairs kept in the history (0 degrades to steepest descent).
     * Only takes effect when a new history is created.
     */
    public L_BFGS setHistorySz(int sz) {
        this._historySz = sz;
        return this;
    }

    /**
     * Sets the smallest step size the line search may try, and derives the number of candidates per
     * batch as {@code log(d) / log(_stepDec)}, clamped to [1, 48].
     *
     * @param d minimum step size; must be non-negative
     * @return this instance, for chaining
     * @throws IllegalArgumentException if {@code d} is negative or NaN
     */
    public L_BFGS setMinStep(double d) {
        if (Double.isNaN(d) || d < 0) {
            throw new IllegalArgumentException("minStep must be a non-negative number, got " + d);
        }
        this._minStep = d;
        int nBetas = (int) (Math.log(d) / Math.log(this._stepDec));
        // Keep at least one candidate per batch: with zero candidates the line search never shrinks
        // the step, and solve() would spin forever whenever _stepDec < d < 1.
        this._nBetas = Math.max(1, Math.min(48, nBetas));
        return this;
    }

    /** @return number of correction pairs accepted into the history so far (0 before any solve). */
    public int k() {
        return this._hist == null ? 0 : this._hist._k;
    }

    /** @return the configured maximum number of iterations. */
    public int maxIter() {
        return this._maxIter;
    }

    /**
     * Minimizes starting from {@code coefs}, evaluating the initial gradient with {@code gslvr} and
     * using a progress monitor that never stops early.
     */
    public final Result solve(GradientSolver gslvr, double[] coefs) {
        return this.solve(gslvr, coefs, gslvr.getGradient(coefs), new ProgressMonitor());
    }

    /**
     * Minimizes the objective exposed by {@code gslvr}, starting from {@code beta}.
     *
     * <p>{@code beta} is updated in place and is the same array returned in {@link Result#coefs}.
     * The loop stops when {@code pm} returns false, when the squared gradient norm is at most
     * {@code _gradEps}, after {@code _maxIter} accepted steps, or when the line search finds no
     * acceptable step of size at least {@code _minStep}.
     *
     * @param gslvr evaluates objective values and gradients for a batch of points
     * @param beta  starting coefficients; overwritten with the final coefficients
     * @param gOld  objective/gradient at the starting {@code beta} (assumed to match it)
     * @param pm    consulted before every iteration; returning false stops the solver
     * @return number of accepted steps, final coefficients, and final objective/gradient
     */
    public final Result solve(GradientSolver gslvr, double[] beta, GradientInfo gOld, ProgressMonitor pm) {
        // A history built for a different number of coefficients cannot be reused.
        if (this._hist == null || this._hist._n != beta.length) {
            this._hist = new History(this._historySz, beta.length);
        }
        double[][] lsBetas = new double[this._nBetas][];
        for (int i = 0; i < lsBetas.length; ++i) {
            lsBetas[i] = MemoryManager.malloc8d(beta.length);
        }
        double step = 1.0;
        int iter = 0; // number of ACCEPTED steps
        // The progress monitor is consulted first on every pass, including the one that hits _maxIter.
        iterations:
        while (pm.progress(gOld) && MathUtils.l2norm2(gOld._gradient) > this._gradEps && iter < this._maxIter) {
            double[] pk = this._hist.getSearchDirection(gOld._gradient);
            double t = step;
            while (t > this._minStep) {
                // Fill the batch with candidates beta + t * pk for geometrically shrinking t.
                for (double[] candidate : lsBetas) {
                    L_BFGS.wadd(candidate, beta, pk, t);
                    t *= this._stepDec;
                }
                GradientInfo[] ginfos = gslvr.getGradient(lsBetas);
                t = step;
                for (int i = 0; i < ginfos.length; ++i) {
                    if (t < this._minStep) {
                        break iterations; // step fell below the minimum: line search failed
                    }
                    if (ginfos[i].isValid() && !L_BFGS.needLineSearch(t, gOld._objVal, ginfos[i]._objVal, pk, gOld._gradient)) {
                        ArrayUtils.mult(pk, t); // pk now holds the accepted step s = t * pk
                        this._hist.update(pk, ginfos[i]._gradient, gOld._gradient);
                        gOld = ginfos[i];
                        ArrayUtils.add(beta, pk);
                        assert Arrays.equals(beta, lsBetas[i]);
                        step = 1.0;
                        ++iter;
                        continue iterations;
                    }
                    t *= this._stepDec;
                }
                // Nothing in this batch was acceptable: resume from the next smaller step size.
                step = t;
            }
            break; // step fell below the minimum: line search failed
        }
        double gradNorm2 = MathUtils.l2norm2(gOld._gradient);
        Log.info("L_BFGS done after " + iter + " iterations, objval = " + gOld._objVal + ", gradient norm2 = " + gradNorm2 + ",  converged = " + (gradNorm2 <= this._gradEps));
        return new Result(iter, beta, gOld);
    }

    /** Writes {@code x + w * y} element-wise into {@code res} and returns {@code res}. */
    private static double[] wadd(double[] res, double[] x, double[] y, double w) {
        for (int i = 0; i < x.length; ++i) {
            res[i] = x[i] + w * y[i];
        }
        return res;
    }

    /** Returns {@code n} standard-normal values drawn from a {@link Random} seeded with {@code seed}. */
    public static double[] startCoefs(int n, long seed) {
        double[] res = MemoryManager.malloc8d(n);
        Random r = new Random(seed);
        for (int i = 0; i < res.length; ++i) {
            res[i] = r.nextGaussian();
        }
        return res;
    }

    /**
     * Armijo test: returns true when {@code objNew} does NOT satisfy
     * {@code objNew <= objOld + c1 * step * (gradOld . pk)}, i.e. the step must shrink further.
     */
    private static boolean needLineSearch(double step, double objOld, double objNew, double[] pk, double[] gradOld) {
        double directionalDerivative = 0.0;
        for (int i = 0; i < pk.length; ++i) {
            directionalDerivative += gradOld[i] * pk[i];
        }
        return objNew > c1 * step * directionalDerivative + objOld;
    }

    /**
     * Circular buffer of the {@code _m} most recent correction pairs (s = accepted step,
     * y = gradient change) together with {@code rho = 1 / (s . y)}.
     */
    public static final class History
    extends Iced {
        private final double[][] _s;
        private final double[][] _y;
        private final double[] _rho;
        /** Capacity (number of pairs kept). */
        final int _m;
        /** Length of each stored vector. */
        final int _n;
        /** Total number of pairs accepted so far; the next slot to write is {@code _k % _m}. */
        int _k;

        public History(int m, int n) {
            this._m = m;
            this._n = n;
            this._s = new double[m][];
            this._y = new double[m][];
            this._rho = MemoryManager.malloc8d(m);
            Arrays.fill(this._rho, Double.NaN);
            for (int i = 0; i < m; ++i) {
                this._s[i] = MemoryManager.malloc8d(n);
                Arrays.fill(this._s[i], Double.NaN);
                this._y[i] = MemoryManager.malloc8d(n);
                Arrays.fill(this._y[i], Double.NaN);
            }
        }

        /** y of the pair at offset {@code k} relative to {@code _k} (k = -1 is the newest). */
        double[] getY(int k) {
            return this._y[(this._k + k) % this._m];
        }

        /** s of the pair at offset {@code k} relative to {@code _k} (k = -1 is the newest). */
        double[] getS(int k) {
            return this._s[(this._k + k) % this._m];
        }

        /** rho of the pair at offset {@code k} relative to {@code _k} (k = -1 is the newest). */
        double rho(int k) {
            return this._rho[(this._k + k) % this._m];
        }

        /**
         * Records the pair (s = pk, y = gNew - gOld) if it has positive curvature.
         *
         * @return true if the pair was stored, false if it was skipped because {@code s . y <= 0}
         *         (or is NaN), which would break positive definiteness of the inverse-Hessian
         *         approximation, or because the history has zero capacity
         */
        private boolean update(double[] pk, double[] gNew, double[] gOld) {
            // Check curvature before writing: once the buffer is full, slot _k % _m still holds the
            // oldest pair in use, which must survive a rejected update.
            double sy = 0.0;
            for (int i = 0; i < pk.length; ++i) {
                sy += pk[i] * (gNew[i] - gOld[i]);
            }
            if (this._m == 0 || !(sy > 0.0)) {
                return false;
            }
            int id = this._k % this._m;
            double[] gradDiff = this._y[id];
            for (int i = 0; i < gNew.length; ++i) {
                gradDiff[i] = gNew[i] - gOld[i];
            }
            System.arraycopy(pk, 0, this._s[id], 0, pk.length);
            this._rho[id] = 1.0 / ArrayUtils.innerProduct(this._s[id], this._y[id]);
            ++this._k;
            return true;
        }

        /**
         * Two-loop recursion: returns {@code -H * gradient}, where H is the L-BFGS inverse-Hessian
         * approximation built from the stored pairs with initial scaling {@code (s . y) / (y . y)}
         * taken from the newest pair. With an empty history this is {@code -gradient}.
         */
        private double[] getSearchDirection(double[] gradient) {
            int used = Math.min(this._k, this._m);
            double[] alpha = MemoryManager.malloc8d(this._m);
            double[] q = gradient.clone();
            // First loop: newest to oldest.
            for (int i = 1; i <= used; ++i) {
                alpha[i - 1] = this.rho(-i) * ArrayUtils.innerProduct(this.getS(-i), q);
                MathUtils.wadd(q, this.getY(-i), -alpha[i - 1]);
            }
            if (this._k > 0) {
                double[] s = this.getS(-1);
                double[] y = this.getY(-1);
                double hk0 = ArrayUtils.innerProduct(s, y) / ArrayUtils.innerProduct(y, y);
                ArrayUtils.mult(q, hk0);
            }
            // Second loop: oldest to newest.
            for (int i = used; i > 0; --i) {
                double beta = this.rho(-i) * ArrayUtils.innerProduct(this.getY(-i), q);
                MathUtils.wadd(q, this.getS(-i), alpha[i - 1] - beta);
            }
            ArrayUtils.mult(q, -1.0);
            return q;
        }
    }

    /** Outcome of {@link #solve}: accepted-iteration count, coefficients, and final gradient info. */
    public static final class Result {
        public final int iter;
        public final double[] coefs;
        public final GradientInfo ginfo;

        public Result(int iter, double[] coefs, GradientInfo ginfo) {
            this.iter = iter;
            this.coefs = coefs;
            this.ginfo = ginfo;
        }

        @Override
        public String toString() {
            if (this.coefs.length < 50) {
                return "L-BFGS_res(iter = " + this.iter + ", obj = " + this.ginfo._objVal + ", " + " coefs = " + Arrays.toString(this.coefs) + ", grad = " + Arrays.toString(this.ginfo._gradient) + ")";
            }
            double[] c = this.coefs;
            double[] g = this.ginfo._gradient;
            return "L-BFGS_res(iter = " + this.iter + ", obj = " + this.ginfo._objVal
                + ", coefs = [" + c[0] + ", " + c[1] + ", ..., " + c[c.length - 2] + ", " + c[c.length - 1] + "]"
                + ", grad = [" + g[0] + ", " + g[1] + ", ..., " + g[g.length - 2] + ", " + g[g.length - 1] + "])"
                + "|grad|^2 = " + MathUtils.l2norm2(g);
        }
    }

    /** Callback consulted before each iteration; returning false stops the solver. */
    public static class ProgressMonitor {
        public boolean progress(GradientInfo ginfo) {
            return true;
        }
    }

    /** Supplies objective values and gradients for batches of coefficient vectors. */
    public static abstract class GradientSolver {
        /** Evaluates every row of {@code betas}; the result is indexed like {@code betas}. */
        public abstract GradientInfo[] getGradient(double[][] betas);

        /** Convenience wrapper evaluating a single coefficient vector. */
        public final GradientInfo getGradient(double[] betas) {
            return this.getGradient(new double[][]{betas})[0];
        }
    }

    /** Objective value and gradient at one point. */
    public static class GradientInfo {
        public final double _objVal;
        public final double[] _gradient;

        public GradientInfo(double objVal, double[] grad) {
            this._objVal = objVal;
            this._gradient = grad;
        }

        /** @return false if the objective is NaN or the gradient contains NaN/Inf entries. */
        public boolean isValid() {
            if (Double.isNaN(this._objVal)) {
                return false;
            }
            return !ArrayUtils.hasNaNsOrInfs(this._gradient);
        }

        @Override
        public String toString() {
            return " objVal = " + this._objVal + ", " + Arrays.toString(this._gradient);
        }
    }
}

