/*
 * Decompiled with CFR 0.152.
 */
package hex.optimization;

import java.util.Arrays;
import java.util.Random;
import water.Iced;
import water.MemoryManager;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.MathUtils;

public final class L_BFGS
extends Iced {
    // Iteration cap: the main loop in solve() stops once its iteration counter equals this value.
    int _maxIter = 500;
    // Gradient tolerance; solve() multiplies it by the number of coefficients to bound its convergence threshold.
    double _gradEps = 1.0E-8;
    // Number of history slots; passed to the History constructor as its first argument (m).
    int _historySz = 20;
    // Curvature history; allocated lazily by solve() when still null, otherwise reused as-is.
    History _hist;
    // NOTE(review): not referenced anywhere in this file; admissibleStep hard-codes the same 0.25 factor,
    // so this is presumably the sufficient-decrease constant it was meant to use -- confirm.
    public static final double c1 = 0.25;

    /**
     * Sets the iteration cap that the main loop of {@code solve} compares its counter against.
     *
     * @param m new value for the iteration cap
     * @return this solver, to allow call chaining
     */
    public L_BFGS setMaxIter(int m) {
        _maxIter = m;
        return this;
    }

    /**
     * Sets the gradient tolerance from which {@code solve} derives its convergence threshold.
     *
     * @param d new gradient tolerance
     * @return this solver, to allow call chaining
     */
    public L_BFGS setGradEps(double d) {
        _gradEps = d;
        return this;
    }

    /**
     * Sets the number of slots used when {@code solve} allocates its {@link History}.
     *
     * <p>The value is only read when the history is first allocated; a history that already
     * exists keeps its size.
     *
     * <p>Sizes below 1 are rejected up front: a zero-slot history fails with a division by zero
     * on its first update (the slot index is computed modulo the size) and a negative size fails
     * at array allocation, so failing here gives a clearer error at the point of misuse.
     *
     * @param sz number of history slots, at least 1
     * @return this solver, to allow call chaining
     * @throws IllegalArgumentException if {@code sz} is less than 1
     */
    public L_BFGS setHistorySz(int sz) {
        if (sz < 1) {
            throw new IllegalArgumentException("L-BFGS: history size must be at least 1, got " + sz);
        }
        this._historySz = sz;
        return this;
    }

    /**
     * Returns the update counter of the history, i.e. how many times the history has been updated.
     *
     * <p>The history is allocated lazily by {@code solve}; before that happens there have been
     * no updates, so 0 is returned instead of failing with a NullPointerException.
     *
     * @return the history's update count, or 0 when no history has been allocated yet
     */
    public int k() {
        return this._hist == null ? 0 : this._hist._k;
    }

    /**
     * @return the currently configured iteration cap
     */
    public int maxIter() {
        return _maxIter;
    }

    /**
     * Convenience overload: evaluates the gradient at {@code coefs} and runs the four-argument
     * {@code solve} with a default {@link ProgressMonitor}.
     *
     * @param gslvr supplier of objective values and gradients
     * @param coefs starting coefficients
     * @return the result produced by the four-argument {@code solve}
     */
    public final Result solve(GradientSolver gslvr, double[] coefs) {
        GradientInfo startInfo = gslvr.getGradient(coefs);
        ProgressMonitor monitor = new ProgressMonitor();
        return solve(gslvr, coefs, startInfo, monitor);
    }

    /**
     * Runs the L-BFGS iteration starting from {@code beta}.
     *
     * <p>The loop continues while the progress monitor returns true, the gradient measure
     * ({@code MathUtils.l2norm2}, presumably the squared L2 norm) stays above the convergence
     * threshold, and the iteration cap has not been reached. Each iteration asks the history for
     * a search direction and then either:
     * <ul>
     *   <li>runs the solver's line search (24 candidate steps, each half the previous), or</li>
     *   <li>takes a unit step without line search.</li>
     * </ul>
     * Line search is switched off after two consecutive line searches that accepted the full
     * step 1.0. It is switched back on after two consecutive inadmissible unit steps, or
     * immediately when a unit step increases the objective by more than 0.1% (relative); in the
     * latter case the step is undone and the iteration is retried with line search.
     *
     * <p>Changes relative to the previous implementation:
     * <ul>
     *   <li>Relative tolerances are taken against the absolute value of the objective. With a
     *       negative objective the old thresholds were negative, so the line-search consistency
     *       check threw even for identical values.</li>
     *   <li>The history is updated with the displacement actually applied to {@code beta}
     *       (step * direction) rather than the raw direction, so the stored pair matches the
     *       gradient difference taken between the two points.</li>
     * </ul>
     *
     * @param gslvr supplier of gradients and line-search objective values
     * @param beta  starting coefficients; cloned, the caller's array is not modified
     * @param ginfo objective value and gradient at {@code beta}
     * @param pm    monitor consulted before every iteration; returning false stops the loop
     * @return iteration count, final coefficients and the gradient info at those coefficients
     * @throws IllegalArgumentException if the objective reported by the line search differs from
     *         the one reported by the gradient evaluation at the same point
     */
    public final Result solve(GradientSolver gslvr, double[] beta, GradientInfo ginfo, ProgressMonitor pm) {
        if (this._hist == null) {
            this._hist = new History(this._historySz, beta.length);
        }
        // Work on a private copy; all updates below happen in place on this copy.
        beta = beta.clone();
        int iter = 0;
        boolean doLineSearch = true;
        // Counts consecutive events that argue for toggling line search on or off.
        int ls_switch = 0;
        // Threshold: at most gradEps * n, but otherwise 10% of the initial gradient measure (floored at 1e-15).
        double gEps = Math.min(this._gradEps * (double)beta.length, Math.max(MathUtils.l2norm2(ginfo._gradient) * 0.1, 1.0E-15));
        while (pm.progress(beta, ginfo) && MathUtils.l2norm2(ginfo._gradient) > gEps && iter != this._maxIter) {
            // Fresh array owned by this iteration (the history clones the gradient before working on it).
            double[] pk = this._hist.getSearchDirection(ginfo._gradient);
            if (ArrayUtils.hasNaNsOrInfs(pk)) {
                Log.warn((Object[])new Object[]{"LBFGS: Got NaNs in search direction."});
                break;
            }
            double lsVal = Double.POSITIVE_INFINITY;
            // Multiple of pk actually added to beta in this iteration.
            double step = 1.0;
            if (doLineSearch) {
                LineSearchSol ls = gslvr.doLineSearch(ginfo, beta, pk, 24, 0.5);
                if (ls.step == 1.0) {
                    if (++ls_switch == 2) {
                        ls_switch = 0;
                        doLineSearch = false;
                    }
                } else {
                    ls_switch = 0;
                }
                // No progress once the history holds at least two updates: treat as converged.
                if (!ls.madeProgress && this._hist._k >= 2) {
                    break;
                }
                lsVal = ls.objVal;
                step = ls.step;
                // presumably beta += step * pk -- matches how the line search evaluates candidates
                ArrayUtils.wadd(beta, pk, step);
            } else {
                ArrayUtils.add(beta, pk);
            }
            GradientInfo newGinfo = gslvr.getGradient(beta);
            // Sanity check (skipped when both values are NaN): line search and gradient task must agree
            // on the objective at the new point. Tolerance is relative to |lsVal| so it stays
            // non-negative for negative objectives.
            if (doLineSearch && !(Double.isNaN(lsVal) && Double.isNaN(newGinfo._objVal)) && Math.abs(lsVal - newGinfo._objVal) > 1.0E-10 * Math.abs(lsVal)) {
                throw new IllegalArgumentException("L-BFGS: Got invalid gradient solver, objective values from line-search and gradient tasks differ, " + lsVal + " != " + newGinfo._objVal);
            }
            if (!doLineSearch) {
                // Reached only after a unit step: either no line search ran, or it accepted step 1.0.
                if (!L_BFGS.admissibleStep(1.0, ginfo._objVal, newGinfo._objVal, pk, ginfo._gradient)) {
                    if (++ls_switch == 2) {
                        doLineSearch = true;
                        ls_switch = 0;
                    }
                    // Objective got noticeably worse (> 0.1% relative): undo the unit step and retry with line search.
                    if (ginfo._objVal < newGinfo._objVal && newGinfo._objVal - ginfo._objVal > 0.001 * Math.abs(ginfo._objVal)) {
                        doLineSearch = true;
                        ArrayUtils.subtract(beta, pk, beta);
                        continue;
                    }
                } else {
                    ls_switch = 0;
                }
            }
            ++iter;
            if (step != 1.0) {
                // Turn the direction into the displacement actually applied to beta.
                for (int i = 0; i < pk.length; ++i) {
                    pk[i] *= step;
                }
            }
            this._hist.update(pk, newGinfo._gradient, ginfo._gradient);
            ginfo = newGinfo;
        }
        return new Result(iter, beta, ginfo);
    }

    /**
     * Builds a vector of {@code n} pseudo-random starting coefficients.
     *
     * <p>Each entry is the next value of {@link Random#nextGaussian()} from a generator created
     * with {@code seed}, so the same arguments always produce the same vector.
     *
     * @param n    number of coefficients
     * @param seed seed for the random generator
     * @return a newly allocated array of length {@code n}
     */
    public static double[] startCoefs(int n, long seed) {
        final double[] coefs = MemoryManager.malloc8d((int)n);
        final Random rng = new Random(seed);
        int idx = 0;
        while (idx < coefs.length) {
            coefs[idx] = rng.nextGaussian();
            ++idx;
        }
        return coefs;
    }

    /**
     * Sufficient-decrease test for a candidate step along {@code pk}.
     *
     * <p>The step is admissible when
     * {@code objNew < objOld + 0.25 * step * (gradOld . pk)} (strict inequality).
     * A NaN {@code objNew} is never admissible.
     *
     * @param step    step length that was applied along {@code pk}
     * @param objOld  objective value before the step
     * @param objNew  objective value after the step
     * @param pk      search direction
     * @param gradOld gradient before the step
     * @return true if the new objective is below the sufficient-decrease bound
     */
    private static final boolean admissibleStep(double step, double objOld, double objNew, double[] pk, double[] gradOld) {
        if (Double.isNaN(objNew)) {
            return false;
        }
        // Directional derivative of the objective along pk at the old point.
        double slope = 0.0;
        for (int j = 0; j < pk.length; ++j) {
            slope += gradOld[j] * pk[j];
        }
        final double bound = 0.25 * step * slope + objOld;
        return objNew < bound;
    }

    /**
     * Immutable outcome of a line search: the chosen step, the objective value at that step and
     * whether the search considered it progress.
     */
    public static class LineSearchSol {
        /** Objective value at the chosen step. */
        public final double objVal;
        /** Chosen step length along the search direction. */
        public final double step;
        /** Whether the line search reported progress. */
        public final boolean madeProgress;

        /**
         * @param progress whether the line search made progress
         * @param obj      objective value at the chosen step
         * @param step     chosen step length
         */
        public LineSearchSol(boolean progress, double obj, double step) {
            this.madeProgress = progress;
            this.step = step;
            this.objVal = obj;
        }
    }

    /**
     * Circular buffer of the last {@code m} (s, y) pairs plus the derived {@code rho} values,
     * and the two-loop recursion that turns a gradient into a search direction.
     *
     * <p>Slots are addressed relative to the update counter {@code _k}: offset {@code -1} is the
     * most recently written pair, {@code -2} the one before it, and so on.
     */
    public static final class History
    extends Iced {
        private final double[][] _s;
        private final double[][] _y;
        private final double[] _rho;
        final int _m;
        final int _n;
        int _k;

        /**
         * Allocates {@code m} slots of length {@code n}; every entry starts out as NaN.
         *
         * @param m number of history slots
         * @param n length of each stored vector
         */
        public History(int m, int n) {
            this._m = m;
            this._n = n;
            this._s = new double[m][];
            this._y = new double[m][];
            this._rho = nanVector(m);
            for (int slot = 0; slot < m; ++slot) {
                this._s[slot] = nanVector(n);
                this._y[slot] = nanVector(n);
            }
        }

        /** Allocates a vector of the given length with every entry set to NaN. */
        private static double[] nanVector(int len) {
            double[] v = MemoryManager.malloc8d((int)len);
            Arrays.fill(v, Double.NaN);
            return v;
        }

        /** Maps an offset relative to the update counter onto a slot index. */
        private int slotOf(int offset) {
            return (_k + offset) % _m;
        }

        /** Stored gradient difference at the given offset from the update counter. */
        double[] getY(int k) {
            return _y[slotOf(k)];
        }

        /** Stored step vector at the given offset from the update counter. */
        double[] getS(int k) {
            return _s[slotOf(k)];
        }

        /** Stored {@code 1 / (s . y)} at the given offset from the update counter. */
        double rho(int k) {
            return _rho[slotOf(k)];
        }

        /**
         * Writes a new pair into the slot selected by the update counter, overwriting the oldest
         * pair once the buffer is full, and advances the counter.
         *
         * @param pk   vector stored as s (copied)
         * @param gNew gradient after the step
         * @param gOld gradient before the step
         */
        private final void update(double[] pk, double[] gNew, double[] gOld) {
            final int slot = _k % _m;
            final double[] ySlot = _y[slot];
            int j = 0;
            while (j < gNew.length) {
                ySlot[j] = gNew[j] - gOld[j];
                ++j;
            }
            final double[] sSlot = _s[slot];
            System.arraycopy(pk, 0, sSlot, 0, pk.length);
            _rho[slot] = 1.0 / ArrayUtils.innerProduct((double[])sSlot, (double[])ySlot);
            ++_k;
        }

        /**
         * Computes the search direction for {@code gradient} with the two-loop recursion over the
         * stored pairs; with an empty history this is the negated gradient.
         *
         * @param gradient current gradient (not modified; a clone is used as work vector)
         * @return the search direction, as returned by {@code ArrayUtils.mult} on the work vector
         */
        protected final double[] getSearchDirection(double[] gradient) {
            final double[] alpha = MemoryManager.malloc8d((int)_m);
            final double[] q = (double[])gradient.clone();
            // Number of pairs actually available.
            final int depth = Math.min(_k, _m);
            // First loop: newest pair to oldest.
            for (int back = 1; back <= depth; ++back) {
                alpha[back - 1] = rho(-back) * ArrayUtils.innerProduct((double[])getS(-back), (double[])q);
                MathUtils.wadd((double[])q, (double[])getY(-back), (double)(-alpha[back - 1]));
            }
            if (_k > 0) {
                // Initial scaling (s . y) / (y . y) taken from the most recent pair.
                final double[] sLast = getS(-1);
                final double[] yLast = getY(-1);
                double scale = ArrayUtils.innerProduct((double[])sLast, (double[])yLast) / ArrayUtils.innerProduct((double[])yLast, (double[])yLast);
                ArrayUtils.mult((double[])q, (double)scale);
            }
            // Second loop: oldest pair to newest.
            for (int back = depth; back > 0; --back) {
                double correction = rho(-back) * ArrayUtils.innerProduct((double[])getY(-back), (double[])q);
                MathUtils.wadd((double[])q, (double[])getS(-back), (double)(alpha[back - 1] - correction));
            }
            return ArrayUtils.mult((double[])q, (double)-1.0);
        }
    }

    /**
     * Outcome of a solve: iteration count, final coefficients and the gradient info at them.
     */
    public static final class Result {
        public final int iter;
        public final double[] coefs;
        public final GradientInfo ginfo;

        /**
         * @param iter  number of iterations performed
         * @param coefs final coefficients (stored by reference)
         * @param ginfo objective value and gradient at {@code coefs}
         */
        public Result(int iter, double[] coefs, GradientInfo ginfo) {
            this.ginfo = ginfo;
            this.coefs = coefs;
            this.iter = iter;
        }

        /**
         * Renders the result. With fewer than 50 coefficients the full coefficient and gradient
         * arrays are printed; otherwise only the first two and last two entries of each, followed
         * by the value of {@code MathUtils.l2norm2} on the gradient.
         */
        public String toString() {
            final int nCoefs = this.coefs.length;
            if (nCoefs < 50) {
                return "L-BFGS_res(iter = " + this.iter + ", obj = " + this.ginfo._objVal + ", " + " coefs = " + Arrays.toString(this.coefs) + ", grad = " + Arrays.toString(this.ginfo._gradient) + ")";
            }
            final String head = "L-BFGS_res(iter = " + this.iter + ", obj = " + this.ginfo._objVal;
            final String coefPart = ", coefs = [" + this.coefs[0] + ", " + this.coefs[1] + ", ..., " + this.coefs[nCoefs - 2] + ", " + this.coefs[nCoefs - 1] + "]";
            final double[] grad = this.ginfo._gradient;
            final String gradPart = ", grad = [" + grad[0] + ", " + grad[1] + ", ..., " + grad[grad.length - 2] + ", " + grad[grad.length - 1] + "])";
            return head + coefPart + gradPart + "|grad|^2 = " + MathUtils.l2norm2((double[])grad);
        }
    }

    /**
     * Hook consulted by {@code solve} before every iteration. This default implementation never
     * asks the solver to stop; subclasses can override {@link #progress} to do so.
     */
    public static class ProgressMonitor {
        /**
         * @param beta  current coefficients
         * @param ginfo objective value and gradient at {@code beta}
         * @return true to let the solver continue; this implementation always returns true
         */
        public boolean progress(double[] beta, GradientInfo ginfo) {
            final boolean keepGoing = true;
            return keepGoing;
        }
    }

    /**
     * Supplies objective values and gradients to the solver and provides a default
     * backtracking-style line search built on top of {@link #getObjVals}.
     */
    public static abstract class GradientSolver {
        /** Evaluates objective value and gradient at the given coefficients. */
        public abstract GradientInfo getGradient(double[] var1);

        /**
         * Evaluates objective values for a sequence of candidate steps. The default line search
         * passes (coefficients, direction, number of steps, step decrement) and reads the result
         * as the values for steps 1, tdec, tdec^2, ... in that order.
         */
        public abstract double[] getObjVals(double[] var1, double[] var2, int var3, double var4);

        /**
         * Picks the first candidate step that passes {@code L_BFGS.admissibleStep}, starting at
         * step 1.0 and multiplying by {@code tdec} after every rejected candidate.
         *
         * <p>If no candidate is admissible, the last (smallest) step is returned, flagged as
         * progress only when its objective value is below the current one.
         *
         * @param ginfo     objective value and gradient at {@code beta}
         * @param beta      current coefficients
         * @param direction search direction
         * @param nSteps    number of candidate steps requested from {@link #getObjVals}
         * @param tdec      factor applied to the step after each rejected candidate
         * @return the selected step together with its objective value
         */
        public LineSearchSol doLineSearch(GradientInfo ginfo, double[] beta, double[] direction, int nSteps, double tdec) {
            final double[] candidates = this.getObjVals(beta, direction, nSteps, tdec);
            double stepLen = 1.0;
            for (double candidate : candidates) {
                if (L_BFGS.admissibleStep(stepLen, ginfo._objVal, candidate, direction, ginfo._gradient)) {
                    return new LineSearchSol(true, candidate, stepLen);
                }
                stepLen *= tdec;
            }
            // Nothing admissible: fall back to the last candidate; stepLen has been decremented once too often.
            final double lastVal = candidates[candidates.length - 1];
            final boolean improved = lastVal < ginfo._objVal;
            return new LineSearchSol(improved, lastVal, stepLen / tdec);
        }
    }

    /**
     * Objective value together with the gradient at one point.
     */
    public static class GradientInfo
    extends Iced {
        public double _objVal;
        public final double[] _gradient;

        /**
         * @param objVal objective value
         * @param grad   gradient (stored by reference)
         */
        public GradientInfo(double objVal, double[] grad) {
            this._gradient = grad;
            this._objVal = objVal;
        }

        /**
         * @return true when the objective is not NaN and {@code ArrayUtils.hasNaNsOrInfs} reports
         *         nothing for the gradient
         */
        public boolean isValid() {
            return !Double.isNaN(this._objVal) && !ArrayUtils.hasNaNsOrInfs((double[])this._gradient);
        }

        /** Renders the objective value followed by the full gradient array. */
        public String toString() {
            final String gradText = Arrays.toString(this._gradient);
            return " objVal = " + this._objVal + ", " + gradText;
        }

        /**
         * @return true when the objective is NaN or {@code ArrayUtils.hasNaNsOrInfs} flags the gradient
         */
        public boolean hasNaNsOrInfs() {
            if (Double.isNaN(this._objVal)) {
                return true;
            }
            return ArrayUtils.hasNaNsOrInfs((double[])this._gradient);
        }
    }
}

