/*
 * Decompiled with CFR 0.152.
 */
package hex.glrm;

import Jama.CholeskyDecomposition;
import Jama.Matrix;
import Jama.QRDecomposition;
import Jama.SingularValueDecomposition;
import hex.DataInfo;
import hex.FrameTask;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.glrm.GLRMModel;
import hex.gram.Gram;
import hex.kmeans.KMeans;
import hex.kmeans.KMeansModel;
import hex.pca.PCA;
import hex.pca.PCAModel;
import hex.schemas.GLRMV99;
import hex.schemas.ModelBuilderSchema;
import java.util.Arrays;
import java.util.Random;
import water.DKV;
import water.H2O;
import water.Iced;
import water.Job;
import water.Key;
import water.MRTask;
import water.MemoryManager;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.RandomUtils;

public class GLRM
extends ModelBuilder<GLRMModel, GLRMModel.GLRMParameters, GLRMModel.GLRMOutput> {
    // Numerical tolerance. NOTE(review): it is not referenced in the reviewed section; confirm its role in the training loop.
    private final double TOLERANCE = 1.0E-6;
    // Expanded-column count above which init() warns that the algorithm may be slow
    // (init() compares against the literal 5000, which matches this value).
    private final int MAX_COLS_EXPANDED = 5000;
    // Number of raw training columns; set in init() to _train.numCols().
    private transient int _ncolA;
    // Number of training columns after categorical expansion; set in init() via
    // _train.numColsExp(true, false). Initial Y matrices are built with this many columns.
    private transient int _ncolY;
    // Rank of the factorization; set in init() to _parms._k. It is the number of X columns.
    private transient int _ncolX;

    /**
     * Creates the REST schema object for this builder.
     *
     * @return a fresh {@code GLRMV99} instance
     */
    public ModelBuilderSchema schema() {
        GLRMV99 restSchema = new GLRMV99();
        return restSchema;
    }

    /**
     * Starts this job with a new {@link GLRMDriver} as its work item.
     *
     * @param work         amount of work units passed through to {@code start}
     * @param restartTimer passed through to {@code start}
     * @return the started job
     */
    public Job<GLRMModel> trainModelImpl(long work, boolean restartTimer) {
        GLRMDriver driver = new GLRMDriver();
        return start(driver, work, restartTimer);
    }

    /**
     * Reports the number of progress units for this job: one per iteration plus one.
     */
    public long progressUnits() {
        return 1 + ((GLRMModel.GLRMParameters) _parms)._max_iterations;
    }

    /**
     * Lists the model categories this builder produces.
     *
     * @return a single-element array containing {@code ModelCategory.Clustering}
     */
    public ModelCategory[] can_build() {
        ModelCategory[] categories = { ModelCategory.Clustering };
        return categories;
    }

    /**
     * Marks this builder as experimental.
     */
    public ModelBuilder.BuilderVisibility builderVisibility() {
        final ModelBuilder.BuilderVisibility visibility = ModelBuilder.BuilderVisibility.Experimental;
        return visibility;
    }

    /**
     * Creates a GLRM builder named "GLRM" and immediately runs the cheap validation pass.
     *
     * @param parms GLRM parameters to validate and train with
     */
    public GLRM(GLRMModel.GLRMParameters parms) {
        super("GLRM", parms);
        init(false);
    }

    /**
     * Validates the GLRM parameters and, when a training frame is present, derives the column
     * counts used during training ({@code _ncolY}, {@code _ncolX}, {@code _ncolA}).
     * Problems are reported through {@code error}/{@code warn}; nothing is thrown here.
     *
     * @param expensive passed through to the superclass validation
     */
    public void init(boolean expensive) {
        super.init(expensive);
        GLRMModel.GLRMParameters parms = (GLRMModel.GLRMParameters) this._parms;
        if (parms._loading_key == null) {
            // Assign a random, unique loading key when none was supplied.
            parms._loading_key = Key.make("GLRMLoading_" + Key.rand());
        }
        if (parms._gamma_x < 0.0) {
            error("_gamma_x", "gamma_x must be a non-negative number");
        }
        if (parms._gamma_y < 0.0) {
            error("_gamma_y", "gamma_y must be a non-negative number");
        }
        if (parms._max_iterations < 1 || (double) parms._max_iterations > 1000000.0) {
            error("_max_iterations", "max_iterations must be between 1 and 1e6 inclusive");
        }
        if (parms._init_step_size <= 0.0) {
            error("_init_step_size", "init_step_size must be a positive number");
        }
        if (parms._min_step_size < 0.0 || parms._min_step_size > parms._init_step_size) {
            error("_min_step_size", "min_step_size must be between 0 and " + parms._init_step_size);
        }
        if (_train == null) {
            return;
        }
        if (_train.numCols() < 2) {
            error("_train", "_train must have more than one column");
        }
        _ncolY = _train.numColsExp(true, false);
        // k can be at most min(#expanded columns, #rows).
        int k_min = (int) Math.min((long) _ncolY, _train.numRows());
        // 5000 matches MAX_COLS_EXPANDED.
        if (_ncolY > 5000) {
            warn("_train", "_train has " + _ncolY + " columns when categoricals are expanded. Algorithm may be slow.");
        }
        if (parms._k < 1 || parms._k > k_min) {
            error("_k", "_k must be between 1 and " + k_min);
        }
        if (parms._user_points != null) {
            if (parms._init != Initialization.User) {
                error("init", "init must be 'User' if providing user-specified points");
            }
            // Resolve the key once; it may not point at a frame.
            Frame userPoints = (Frame) parms._user_points.get();
            if (userPoints == null) {
                error("_user_points", "The user-specified points frame could not be found");
            } else if (userPoints.numCols() != _train.numCols()) {
                error("_user_points", "The user-specified points must have the same number of columns (" + _train.numCols() + ") as the training observations");
            } else if (userPoints.numRows() != (long) parms._k) {
                error("_user_points", "The user-specified points must have k = " + parms._k + " rows");
            } else {
                int zero_vec = 0;
                Vec[] centersVecs = userPoints.vecs();
                for (Vec v : centersVecs) {
                    if (v.naCnt() > 0L) {
                        error("_user_points", "The user-specified points cannot contain any missing values");
                        break;
                    }
                    if (v.isConst() && v.max() == 0.0) {
                        ++zero_vec;
                    }
                }
                if (zero_vec == _train.numCols()) {
                    error("_user_points", "The user-specified points cannot all be zero");
                }
            }
        }
        _ncolX = parms._k;
        _ncolA = _train.numCols();
    }

    /**
     * Returns the squared Frobenius norm of {@code x}: the sum of the squares of all entries.
     * A {@code null} matrix or {@code null} row contributes 0. Rows may have different
     * lengths; every entry of every row is included.
     *
     * @param x matrix (possibly jagged or null)
     * @return sum of squared entries
     */
    public static double frobenius2(double[][] x) {
        if (x == null) {
            return 0.0;
        }
        double frob = 0.0;
        for (double[] row : x) {
            if (row == null) {
                continue;
            }
            // Use each row's own length, not x[0].length, so jagged input is handled correctly.
            for (double v : row) {
                frob += v * v;
            }
        }
        return frob;
    }

    /**
     * Returns a copy of {@code centers} whose numeric part is standardized.
     * Each row's first {@code ncats} entries are copied unchanged. Each of the following
     * {@code nnums} entries becomes {@code (value - normSub[col]) * normMul[col]}. A null
     * {@code normSub} means subtract 0, and a null {@code normMul} means multiply by 1.
     *
     * @param centers K x N matrix with N == ncats + nnums (asserted)
     * @param normSub per-numeric-column offsets, or null
     * @param normMul per-numeric-column multipliers, or null
     * @param ncats   number of leading categorical columns
     * @param nnums   number of trailing numeric columns
     * @return a new K x N matrix; an empty array when {@code centers} has no rows
     */
    public static double[][] transform(double[][] centers, double[] normSub, double[] normMul, int ncats, int nnums) {
        int K = centers.length;
        if (K == 0) {
            return new double[0][];
        }
        int N = centers[0].length;
        assert (ncats + nnums == N);
        double[][] value = new double[K][N];
        for (int clu = 0; clu < K; ++clu) {
            System.arraycopy(centers[clu], 0, value[clu], 0, ncats);
            for (int col = 0; col < nnums; ++col) {
                // Identity defaults replace allocating zero/one-filled arrays.
                double mean = normSub == null ? 0.0 : normSub[col];
                double mult = normMul == null ? 1.0 : normMul[col];
                value[clu][ncats + col] = (centers[clu][ncats + col] - mean) * mult;
            }
        }
        return value;
    }

    /**
     * Expands the leading {@code dinfo._cats} categorical columns of each row into indicator
     * columns laid out by {@code dinfo._catOffsets}. The trailing {@code dinfo._nums} numeric
     * values are copied unchanged after the expanded block.
     *
     * <p>A missing categorical value sets the last indicator of its column's range, but only
     * when {@code dinfo._catMissing} is non-zero for that column; otherwise no indicator is set.
     * Negative ids from {@code getCategoricalId} also set no indicator.
     *
     * @return {@code sdata} itself when it is null or there are no categorical columns,
     *         otherwise a new expanded matrix
     */
    public static double[][] expandCats(double[][] sdata, DataInfo dinfo) {
        if (sdata == null || dinfo._cats == 0) {
            return sdata;
        }
        assert (sdata[0].length == dinfo._adaptedFrame.numCols());
        final int numCats = dinfo._cats;
        final int numNums = dinfo._nums;
        final int[] offsets = dinfo._catOffsets;
        final int expandedCatWidth = offsets[offsets.length - 1];
        double[][] expanded = new double[sdata.length][expandedCatWidth + numNums];
        for (int r = 0; r < sdata.length; r++) {
            double[] src = sdata[r];
            double[] dst = expanded[r];
            for (int c = 0; c < numCats; c++) {
                final double level = src[c];
                int hot;
                if (!Double.isNaN(level)) {
                    hot = dinfo.getCategoricalId(c, (int) level);
                } else if (dinfo._catMissing[c] == 0) {
                    hot = -1;
                } else {
                    hot = offsets[c + 1] - 1;
                }
                if (hot >= 0) {
                    dst[hot] = 1.0;
                }
            }
            System.arraycopy(src, numCats, dst, expandedCatWidth, numNums);
        }
        return expanded;
    }

    /** Index of the c-th X_old column: X_old starts right after the ncolA training columns. */
    protected static int idx_xold(int c, int ncolA) {
        final int base = ncolA;
        return c + base;
    }

    /** Index of the c-th X_new column: X_new follows the ncolA training and ncolX X_old columns. */
    protected static int idx_xnew(int c, int ncolA, int ncolX) {
        final int base = ncolA + ncolX;
        return c + base;
    }

    /**
     * Returns the row of the expanded Y matrix for categorical column {@code c} at the given
     * {@code level}: {@code dinfo._catOffsets[c] + level}.
     */
    protected static int idx_ycat(int c, int level, DataInfo dinfo) {
        assert (dinfo._adaptedFrame.domains() != null) : "Domain of categorical column cannot be null";
        // level is an int, so a NaN check would be meaningless; validate the range instead.
        assert (dinfo._adaptedFrame.domains()[c] != null && level >= 0 && level < dinfo._adaptedFrame.domains()[c].length);
        return dinfo._catOffsets[c] + level;
    }

    /** Row of the expanded Y matrix for numeric column c; numeric rows follow all categorical rows. */
    protected static int idx_ynum(int c, DataInfo dinfo) {
        int[] offsets = dinfo._catOffsets;
        int firstNumericRow = offsets[offsets.length - 1];
        return firstNumericRow + c;
    }

    /** Returns the chunk holding the c-th X_old column (located right after the ncolA training chunks). */
    protected static Chunk chk_xold(Chunk[] chks, int c, int ncolA) {
        final int column = c + ncolA;
        return chks[column];
    }

    /** Returns the chunk holding the c-th X_new column (after the training and X_old chunks). */
    protected static Chunk chk_xnew(Chunk[] chks, int c, int ncolA, int ncolX) {
        final int column = c + (ncolA + ncolX);
        return chks[column];
    }

    /** Non-transposed variant: returns the levels x k sub-block of yt for categorical column cidx. */
    protected static double[][] yt_block(double[][] yt, int cidx, DataInfo dinfo) {
        final boolean transpose = false;
        return yt_block(yt, cidx, dinfo, transpose);
    }

    /**
     * Copies the rows of {@code yt} that belong to categorical column {@code cidx}, one row per
     * level as located by {@link #idx_ycat}, into a new matrix.
     *
     * @param yt        expanded Y matrix; its column count is {@code yt[0].length}
     * @param cidx      categorical column index
     * @param dinfo     supplies domains and categorical offsets; 1 level is assumed when
     *                  the frame has no domains
     * @param transpose when true the result is (yt[0].length x levels), otherwise (levels x yt[0].length)
     * @return the copied block
     */
    protected static double[][] yt_block(double[][] yt, int cidx, DataInfo dinfo, boolean transpose) {
        String[][] domains = dinfo._adaptedFrame.domains();
        int catlvls = domains == null ? 1 : domains[cidx].length;
        int ncol = yt[0].length;
        double[][] block = transpose ? new double[ncol][catlvls] : new double[catlvls][ncol];
        for (int level = 0; level < catlvls; ++level) {
            double[] src = yt[GLRM.idx_ycat(cidx, level, dinfo)];
            for (int col = 0; col < ncol; ++col) {
                if (transpose) {
                    block[col][level] = src[col];
                } else {
                    block[level][col] = src[col];
                }
            }
        }
        return block;
    }

    /**
     * FrameTask that emits, for every input row, one output value per row of {@code _yt}:
     * the inner product of the (expanded) data row with that row of {@code _yt}.
     */
    private static class BMulTask
    extends FrameTask<BMulTask> {
        double[][] _yt;

        BMulTask(Key jobKey, DataInfo dinfo, double[][] yt) {
            super(jobKey, dinfo);
            _yt = yt;
        }

        @Override
        protected void processRow(long gid, DataInfo.Row row, NewChunk[] outputs) {
            assert (row.nBins + _dinfo._nums == _yt[0].length);
            int out = 0;
            for (double[] ytRow : _yt) {
                outputs[out++].addNum(row.innerProduct(ytRow));
            }
        }
    }

    /**
     * MRTask that computes X in closed form, row by row.
     *
     * <p>Chunk layout (asserted in {@code map}): {@code _ncolA} training chunks (categoricals
     * first), then {@code _ncolX} X_old chunks, then {@code _ncolX} X_new chunks. For each row
     * it forms the vector {@code A_row * Y'}. Categorical cells contribute their level's Y row.
     * Numeric cells contribute {@code (a - normSub) * normMul} times their Y row. Missing cells
     * are skipped. The vector is solved against {@code _chol}, and the solution is written into
     * both X_old and X_new.
     */
    private static class CholMulTask
    extends MRTask<CholMulTask> {
        DataInfo _dinfo;
        GLRMModel.GLRMParameters _parms;
        final double[][] _yt;
        final int _ncolA;
        final int _ncolX;
        final double[] _normSub;
        final double[] _normMul;
        CholeskyDecomposition _chol;

        // NOTE(review): this convenience constructor passes yt.length (the number of rows of yt,
        // i.e. the expanded column count) as ncolA. map() expects ncolA to be the number of raw
        // training chunks; the two agree only when no categorical column expands. Confirm callers.
        CholMulTask(DataInfo dinfo, GLRMModel.GLRMParameters parms, CholeskyDecomposition chol, double[][] yt, double[] normSub, double[] normMul) {
            this(dinfo, parms, chol, yt, yt.length, yt[0].length, normSub, normMul);
        }

        CholMulTask(DataInfo dinfo, GLRMModel.GLRMParameters parms, CholeskyDecomposition chol, double[][] yt, int ncolA, int ncolX, double[] normSub, double[] normMul) {
            assert (yt != null && yt[0].length == ncolX);
            this._parms = parms;
            this._yt = yt;
            this._ncolA = ncolA;
            this._ncolX = ncolX;
            this._chol = chol;
            assert (dinfo._cats <= ncolA);
            this._dinfo = dinfo;
            this._normSub = normSub;
            this._normMul = normMul;
        }

        public void map(Chunk[] cs) {
            assert (this._ncolA + 2 * this._ncolX == cs.length);
            final int ncats = this._dinfo._cats;
            double[] acc = new double[this._ncolX];
            for (int row = 0; row < cs[0]._len; ++row) {
                Arrays.fill(acc, 0.0);
                // Read each cell and resolve its Y row once, then add it into every component.
                // Per component, the additions still happen in column order (cats, then nums).
                for (int d = 0; d < ncats; ++d) {
                    double a = cs[d].atd(row);
                    if (Double.isNaN(a)) continue;
                    double[] yrow = this._yt[GLRM.idx_ycat(d, (int) a, this._dinfo)];
                    for (int k = 0; k < this._ncolX; ++k) {
                        acc[k] += yrow[k];
                    }
                }
                for (int d = ncats; d < this._ncolA; ++d) {
                    double a = cs[d].atd(row);
                    if (Double.isNaN(a)) continue;
                    int ds = d - ncats;
                    double scaled = (a - this._normSub[ds]) * this._normMul[ds];
                    double[] yrow = this._yt[GLRM.idx_ynum(ds, this._dinfo)];
                    for (int k = 0; k < this._ncolX; ++k) {
                        acc[k] += scaled * yrow[k];
                    }
                }
                // transpose() copies, so reusing acc for the next row is safe.
                Matrix sol = this._chol.solve(new Matrix(new double[][]{acc}).transpose());
                double[] xrow = sol.getColumnPackedCopy();
                assert (xrow.length == this._ncolX);
                for (int k = 0; k < this._ncolX; ++k) {
                    cs[this._ncolA + k].set(row, xrow[k]);
                    cs[this._ncolA + this._ncolX + k].set(row, xrow[k]);
                }
            }
        }
    }

    /**
     * MRTask that evaluates the loss at the current X_new and, optionally, the X_old regularizer.
     *
     * <p>Chunk layout (asserted in {@code map}): {@code _ncolA} training chunks (categoricals
     * first), then {@code _ncolX} X_old chunks, then {@code _ncolX} X_new chunks.
     * {@code _loss} accumulates {@code mloss}/{@code loss} over every non-missing cell, evaluated
     * at X_new times Y. When {@code _regX} is set, {@code _xold_reg} accumulates
     * {@code regularize_x} of each X_old row. {@code reduce} sums both across chunks.
     */
    private static class ObjCalc
    extends MRTask<ObjCalc> {
        DataInfo _dinfo;
        GLRMModel.GLRMParameters _parms;
        final double[][] _yt;
        final int _ncolA;
        final int _ncolX;
        final double[] _normSub;
        final double[] _normMul;
        final boolean _regX;
        double _loss;
        double _xold_reg;

        ObjCalc(DataInfo dinfo, GLRMModel.GLRMParameters parms, double[][] yt, int ncolA, int ncolX, double[] normSub, double[] normMul) {
            this(dinfo, parms, yt, ncolA, ncolX, normSub, normMul, false);
        }

        ObjCalc(DataInfo dinfo, GLRMModel.GLRMParameters parms, double[][] yt, int ncolA, int ncolX, double[] normSub, double[] normMul, boolean regX) {
            assert (yt != null && yt[0].length == ncolX);
            this._parms = parms;
            this._yt = yt;
            this._ncolA = ncolA;
            this._ncolX = ncolX;
            this._regX = regX;
            this._xold_reg = 0.0;
            this._loss = 0.0;
            assert (dinfo._cats <= ncolA);
            this._dinfo = dinfo;
            this._normSub = normSub;
            this._normMul = normMul;
        }

        public void map(Chunk[] cs) {
            assert (this._ncolA + 2 * this._ncolX == cs.length);
            final int ncats = this._dinfo._cats;
            // Frame.domains() builds a new array per call; look it up once per chunk.
            String[][] domains = ncats > 0 ? this._dinfo._adaptedFrame.domains() : null;
            double[] xnew = new double[this._ncolX];
            for (int row = 0; row < cs[0]._len; ++row) {
                // Read the X_new row once; every column below reuses it.
                for (int k = 0; k < this._ncolX; ++k) {
                    xnew[k] = GLRM.chk_xnew(cs, k, this._ncolA, this._ncolX).atd(row);
                }
                for (int j = 0; j < ncats; ++j) {
                    double a = cs[j].atd(row);
                    if (Double.isNaN(a)) continue;
                    double[] xy = new double[domains[j].length];
                    for (int level = 0; level < xy.length; ++level) {
                        double[] yrow = this._yt[GLRM.idx_ycat(j, level, this._dinfo)];
                        for (int k = 0; k < this._ncolX; ++k) {
                            xy[level] += xnew[k] * yrow[k];
                        }
                    }
                    this._loss += this._parms.mloss(xy, (int) a);
                }
                for (int j = ncats; j < this._ncolA; ++j) {
                    double a = cs[j].atd(row);
                    if (Double.isNaN(a)) continue;
                    int js = j - ncats;
                    double[] yrow = this._yt[GLRM.idx_ynum(js, this._dinfo)];
                    double xy = 0.0;
                    for (int k = 0; k < this._ncolX; ++k) {
                        xy += xnew[k] * yrow[k];
                    }
                    this._loss += this._parms.loss(xy, (a - this._normSub[js]) * this._normMul[js]);
                }
                if (!this._regX) continue;
                // The regularizer is evaluated on the X_old columns.
                double[] xrow = new double[this._ncolX];
                for (int k = 0; k < this._ncolX; ++k) {
                    xrow[k] = cs[this._ncolA + k].atd(row);
                }
                this._xold_reg += this._parms.regularize_x(xrow);
            }
        }

        // Without this override MRTask's default (no-op) reduce keeps only one side's partial
        // sums, so the objective would cover just a subset of the chunks.
        public void reduce(ObjCalc other) {
            this._loss += other._loss;
            this._xold_reg += other._xold_reg;
        }
    }

    /**
     * MRTask performing one proximal-gradient step on Y (stored transposed as {@code yt}).
     *
     * <p>Chunk layout (asserted in {@code map}): {@code _ncolA} training chunks (categoricals
     * first), then {@code _ncolX} X_old chunks, then {@code _ncolX} X_new chunks.
     * <ul>
     *   <li>{@code map}: for each row of {@code _ytold}, accumulates into {@code _ytnew} the sum
     *       over non-missing cells of the {@code mlgrad}/{@code lgrad} weight times the X_new row.
     *       The weight is evaluated at X_new times {@code _ytold}.</li>
     *   <li>{@code reduce}: element-wise sum of the partial {@code _ytnew} matrices.</li>
     *   <li>{@code postGlobal}: replaces each row with
     *       {@code rproxgrad_y(ytold - alpha * accumulated, alpha, rand)} and accumulates
     *       {@code regularize_y} into {@code _yreg}.</li>
     * </ul>
     */
    private static class UpdateY
    extends MRTask<UpdateY> {
        DataInfo _dinfo;
        GLRMModel.GLRMParameters _parms;
        final double _alpha;
        final double[][] _ytold;
        final int _ncolA;
        final int _ncolX;
        final double[] _normSub;
        final double[] _normMul;
        double[][] _ytnew;
        double _yreg;

        UpdateY(DataInfo dinfo, GLRMModel.GLRMParameters parms, double[][] yt, double alpha, int ncolA, int ncolX, double[] normSub, double[] normMul) {
            assert (yt != null && yt[0].length == ncolX);
            this._parms = parms;
            this._alpha = alpha;
            this._ncolA = ncolA;
            this._ncolX = ncolX;
            this._ytold = yt;
            this._yreg = 0.0;
            assert (dinfo._cats <= ncolA);
            this._dinfo = dinfo;
            this._normSub = normSub;
            this._normMul = normMul;
        }

        public void map(Chunk[] cs) {
            assert (this._ncolA + 2 * this._ncolX == cs.length);
            final int ncats = this._dinfo._cats;
            this._ytnew = new double[this._ytold.length][this._ncolX];
            // Resolve the X_new chunks once instead of once per (row, k) access.
            Chunk[] xnewChunks = new Chunk[this._ncolX];
            for (int k = 0; k < this._ncolX; ++k) {
                xnewChunks[k] = GLRM.chk_xnew(cs, k, this._ncolA, this._ncolX);
            }
            // Frame.domains() builds a new array per call; look it up once per chunk.
            String[][] domains = ncats > 0 ? this._dinfo._adaptedFrame.domains() : null;
            for (int j = 0; j < ncats; ++j) {
                for (int row = 0; row < cs[0]._len; ++row) {
                    double a = cs[j].atd(row);
                    if (Double.isNaN(a)) continue;
                    double[] xy = new double[domains[j].length];
                    for (int level = 0; level < xy.length; ++level) {
                        double[] yrow = this._ytold[GLRM.idx_ycat(j, level, this._dinfo)];
                        for (int k = 0; k < this._ncolX; ++k) {
                            xy[level] += xnewChunks[k].atd(row) * yrow[k];
                        }
                    }
                    double[] weight = this._parms.mlgrad(xy, (int) a);
                    for (int level = 0; level < xy.length; ++level) {
                        double[] grow = this._ytnew[GLRM.idx_ycat(j, level, this._dinfo)];
                        for (int k = 0; k < this._ncolX; ++k) {
                            grow[k] += weight[level] * xnewChunks[k].atd(row);
                        }
                    }
                }
            }
            for (int j = ncats; j < this._ncolA; ++j) {
                int js = j - ncats;
                int yidx = GLRM.idx_ynum(js, this._dinfo);
                for (int row = 0; row < cs[0]._len; ++row) {
                    double a = cs[j].atd(row);
                    if (Double.isNaN(a)) continue;
                    double xy = 0.0;
                    for (int k = 0; k < this._ncolX; ++k) {
                        xy += xnewChunks[k].atd(row) * this._ytold[yidx][k];
                    }
                    double weight = this._parms.lgrad(xy, (a - this._normSub[js]) * this._normMul[js]);
                    double[] grow = this._ytnew[yidx];
                    for (int k = 0; k < this._ncolX; ++k) {
                        grow[k] += weight * xnewChunks[k].atd(row);
                    }
                }
            }
        }

        public void reduce(UpdateY other) {
            ArrayUtils.add((double[][]) this._ytnew, (double[][]) other._ytnew);
        }

        protected void postGlobal() {
            assert (this._ytnew.length == this._ytold.length && this._ytnew[0].length == this._ytold[0].length);
            Random rand = RandomUtils.getRNG(new long[]{this._parms._seed});
            for (int j = 0; j < this._ytnew.length; ++j) {
                double[] u = new double[this._ytnew[0].length];
                for (int k = 0; k < this._ytnew[0].length; ++k) {
                    u[k] = this._ytold[j][k] - this._alpha * this._ytnew[j][k];
                }
                this._ytnew[j] = this._parms.rproxgrad_y(u, this._alpha, rand);
                this._yreg += this._parms.regularize_y(this._ytnew[j]);
            }
        }
    }

    /**
     * MRTask performing one proximal-gradient step on every row of X.
     *
     * <p>Chunk layout (asserted in {@code map}): {@code _ncolA} training chunks (categoricals
     * first), then {@code _ncolX} X_old chunks, then {@code _ncolX} X_new chunks. For each row:
     * <ol>
     *   <li>if {@code _update} is set, X_new is first copied into X_old;</li>
     *   <li>{@code grad} accumulates, over non-missing cells, the {@code mlgrad}/{@code lgrad}
     *       weight (evaluated at X_old times Y) times the matching rows of Y;</li>
     *   <li>{@code rproxgrad_x(X_old - alpha * grad)} is written into X_new;</li>
     *   <li>{@code regularize_x} of the new row is added to {@code _xreg}, and the loss at the
     *       new row to {@code _loss}.</li>
     * </ol>
     * {@code reduce} sums {@code _loss} and {@code _xreg}.
     */
    private static class UpdateX
    extends MRTask<UpdateX> {
        DataInfo _dinfo;
        GLRMModel.GLRMParameters _parms;
        final double _alpha;
        final boolean _update;
        final double[][] _yt;
        final int _ncolA;
        final int _ncolX;
        final double[] _normSub;
        final double[] _normMul;
        double _loss;
        double _xreg;

        UpdateX(DataInfo dinfo, GLRMModel.GLRMParameters parms, double[][] yt, double alpha, boolean update, int ncolA, int ncolX, double[] normSub, double[] normMul) {
            assert (yt != null && yt[0].length == ncolX);
            this._parms = parms;
            this._yt = yt;
            this._alpha = alpha;
            this._update = update;
            this._ncolA = ncolA;
            this._ncolX = ncolX;
            assert (dinfo._cats <= ncolA);
            this._dinfo = dinfo;
            this._normSub = normSub;
            this._normMul = normMul;
        }

        public void map(Chunk[] cs) {
            assert (this._ncolA + 2 * this._ncolX == cs.length);
            final int ncats = this._dinfo._cats;
            double[] a = new double[this._ncolA];
            Random rand = RandomUtils.getRNG(new long[]{this._parms._seed + cs[0].start()});
            this._xreg = 0.0;
            this._loss = 0.0;
            // Frame.domains() builds a new array per call; look it up once per chunk.
            String[][] domains = ncats > 0 ? this._dinfo._adaptedFrame.domains() : null;
            // Y sub-blocks depend only on _yt. Build them lazily on first use, then reuse them
            // for the rest of the chunk instead of copying per row.
            double[][][] ytBlocks = new double[ncats][][];
            double[][][] ytBlocksT = new double[ncats][][];
            double[] xold = new double[this._ncolX];
            for (int row = 0; row < cs[0]._len; ++row) {
                if (this._update) {
                    for (int k = 0; k < this._ncolX; ++k) {
                        GLRM.chk_xold(cs, k, this._ncolA).set(row, GLRM.chk_xnew(cs, k, this._ncolA, this._ncolX).atd(row));
                    }
                }
                // Snapshot the (possibly just updated) X_old row; X_old is not written again for this row.
                for (int k = 0; k < this._ncolX; ++k) {
                    xold[k] = GLRM.chk_xold(cs, k, this._ncolA).atd(row);
                }
                double[] grad = new double[this._ncolX];
                for (int j = 0; j < ncats; ++j) {
                    a[j] = cs[j].atd(row);
                    if (Double.isNaN(a[j])) continue;
                    double[] xy = new double[domains[j].length];
                    for (int level = 0; level < xy.length; ++level) {
                        double[] yrow = this._yt[GLRM.idx_ycat(j, level, this._dinfo)];
                        for (int k = 0; k < this._ncolX; ++k) {
                            xy[level] += xold[k] * yrow[k];
                        }
                    }
                    double[] weight = this._parms.mlgrad(xy, (int) a[j]);
                    if (ytBlocks[j] == null) {
                        ytBlocks[j] = GLRM.yt_block(this._yt, j, this._dinfo);
                    }
                    double[][] ytsub = ytBlocks[j];
                    for (int k = 0; k < this._ncolX; ++k) {
                        for (int c = 0; c < weight.length; ++c) {
                            grad[k] += weight[c] * ytsub[c][k];
                        }
                    }
                }
                for (int j = ncats; j < this._ncolA; ++j) {
                    int js = j - ncats;
                    int yidx = GLRM.idx_ynum(js, this._dinfo);
                    a[j] = cs[j].atd(row);
                    if (Double.isNaN(a[j])) continue;
                    double xy = 0.0;
                    for (int k = 0; k < this._ncolX; ++k) {
                        xy += xold[k] * this._yt[yidx][k];
                    }
                    double weight = this._parms.lgrad(xy, (a[j] - this._normSub[js]) * this._normMul[js]);
                    for (int k = 0; k < this._ncolX; ++k) {
                        grad[k] += weight * this._yt[yidx][k];
                    }
                }
                double[] u = new double[this._ncolX];
                for (int k = 0; k < this._ncolX; ++k) {
                    u[k] = xold[k] - this._alpha * grad[k];
                }
                double[] xnew = this._parms.rproxgrad_x(u, this._alpha, rand);
                this._xreg += this._parms.regularize_x(xnew);
                for (int k = 0; k < this._ncolX; ++k) {
                    GLRM.chk_xnew(cs, k, this._ncolA, this._ncolX).set(row, xnew[k]);
                }
                for (int j = 0; j < ncats; ++j) {
                    if (Double.isNaN(a[j])) continue;
                    if (ytBlocksT[j] == null) {
                        ytBlocksT[j] = GLRM.yt_block(this._yt, j, this._dinfo, true);
                    }
                    double[] xy = ArrayUtils.multVecArr(xnew, ytBlocksT[j]);
                    this._loss += this._parms.mloss(xy, (int) a[j]);
                }
                for (int j = ncats; j < this._ncolA; ++j) {
                    if (Double.isNaN(a[j])) continue;
                    int js = j - ncats;
                    double xy = ArrayUtils.innerProduct(xnew, this._yt[GLRM.idx_ynum(js, this._dinfo)]);
                    this._loss += this._parms.loss(xy, (a[j] - this._normSub[js]) * this._normMul[js]);
                }
            }
        }

        public void reduce(UpdateX other) {
            this._loss += other._loss;
            this._xreg += other._xreg;
        }
    }

    class GLRMDriver
    extends H2O.H2OCountedCompleter<GLRMDriver> {
        /** Creates a driver; it only invokes the superclass constructor. */
        GLRMDriver() {
            super();
        }

        /**
         * Computes the initial Y matrix (k rows, {@code _ncolY} expanded columns).
         * <ul>
         *   <li>If {@code _user_points} is set: the user frame's k rows, column-permuted with
         *       {@code tinfo._permutation} and categorically expanded.</li>
         *   <li>{@code Random}: a standard normal k x {@code _ncolY} matrix (returned directly).</li>
         *   <li>{@code SVD}: the transposed raw eigenvectors of a GramSVD PCA on the training frame.</li>
         *   <li>{@code PlusPlus}: k-means++ centers, standardized with {@code tinfo} and expanded.
         *       This may also initialize X from the cluster assignments.</li>
         * </ul>
         * If no method applies, an error is recorded. If the resulting matrix is all zeros or NaN,
         * or no method applied, a standard normal matrix is returned instead, with a warning.
         *
         * @param tinfo  DataInfo used for column permutation, standardization and expansion
         * @param dinfo  DataInfo whose adapted frame receives X when initialized from k-means
         * @param na_cnt number of missing training values; X is initialized from k-means only when 0
         */
        private double[][] initialXY(DataInfo tinfo, DataInfo dinfo, long na_cnt) {
            GLRMModel.GLRMParameters glrmParms = (GLRMModel.GLRMParameters) GLRM.this._parms;
            double[][] centers_exp = null;
            if (glrmParms._user_points != null) {
                Vec[] centersVecs = ((Frame) glrmParms._user_points.get()).vecs();
                double[][] centers = new double[glrmParms._k][GLRM.this._ncolA];
                for (int c = 0; c < GLRM.this._ncolA; ++c) {
                    for (int r = 0; r < glrmParms._k; ++r) {
                        centers[r][c] = centersVecs[c].at((long) r);
                    }
                }
                centers = ArrayUtils.permuteCols(centers, tinfo._permutation);
                centers_exp = GLRM.expandCats(centers, tinfo);
            } else if (glrmParms._init == Initialization.Random) {
                return ArrayUtils.gaussianArray(glrmParms._k, GLRM.this._ncolY);
            } else if (glrmParms._init == Initialization.SVD) {
                centers_exp = initialYFromSVD(glrmParms, tinfo);
            } else if (glrmParms._init == Initialization.PlusPlus) {
                centers_exp = initialYFromKMeans(glrmParms, tinfo, dinfo, na_cnt);
            } else {
                GLRM.this.error("_init", "Initialization method " + glrmParms._init + " is undefined");
            }
            // centers_exp is null only on the undefined-method path (error already recorded);
            // frobenius2(null) == 0 then triggers the random fallback below.
            assert (centers_exp == null || centers_exp[0].length == GLRM.this._ncolY);
            double frob = GLRM.frobenius2(centers_exp);
            if (frob == 0.0 || Double.isNaN(frob)) {
                GLRM.this.warn("_init", "Initialization failed. Setting initial Y to standard normal random matrix instead");
                centers_exp = ArrayUtils.gaussianArray(glrmParms._k, GLRM.this._ncolY);
            }
            return centers_exp;
        }

        /** Trains a temporary GramSVD PCA model and returns its transposed raw eigenvectors. */
        private double[][] initialYFromSVD(GLRMModel.GLRMParameters glrmParms, DataInfo tinfo) {
            PCAModel.PCAParameters pcaParms = new PCAModel.PCAParameters();
            pcaParms._train = glrmParms._train;
            pcaParms._ignored_columns = glrmParms._ignored_columns;
            pcaParms._ignore_const_cols = glrmParms._ignore_const_cols;
            pcaParms._score_each_iteration = glrmParms._score_each_iteration;
            pcaParms._use_all_factor_levels = true;
            pcaParms._k = glrmParms._k;
            pcaParms._max_iterations = glrmParms._max_iterations;
            pcaParms._transform = glrmParms._transform;
            pcaParms._seed = glrmParms._seed;
            pcaParms._pca_method = PCAModel.PCAParameters.Method.GramSVD;
            PCAModel pca = null;
            PCA job = null;
            try {
                job = new PCA(pcaParms);
                pca = (PCAModel) job.trainModel().get();
                // Extract everything needed before the model is removed in finally.
                PCAModel.PCAOutput out = (PCAModel.PCAOutput) pca._output;
                assert (out._permutation.length == tinfo._permutation.length);
                for (int i = 0; i < tinfo._permutation.length; ++i) {
                    assert (out._permutation[i] == tinfo._permutation[i]);
                }
                return ArrayUtils.transpose(out._eigenvectors_raw);
            } finally {
                if (job != null) {
                    job.remove();
                }
                if (pca != null) {
                    pca.remove();
                }
            }
        }

        /** Trains a temporary k-means++ model; returns its standardized, expanded centers and may initialize X. */
        private double[][] initialYFromKMeans(GLRMModel.GLRMParameters glrmParms, DataInfo tinfo, DataInfo dinfo, long na_cnt) {
            KMeansModel.KMeansParameters kmParms = new KMeansModel.KMeansParameters();
            kmParms._train = glrmParms._train;
            kmParms._ignored_columns = glrmParms._ignored_columns;
            kmParms._ignore_const_cols = glrmParms._ignore_const_cols;
            kmParms._score_each_iteration = glrmParms._score_each_iteration;
            kmParms._init = KMeans.Initialization.PlusPlus;
            kmParms._k = glrmParms._k;
            kmParms._max_iterations = glrmParms._max_iterations;
            kmParms._standardize = true;
            kmParms._seed = glrmParms._seed;
            kmParms._pred_indicator = true;
            KMeansModel km = null;
            KMeans job = null;
            double[][] centers;
            try {
                job = new KMeans(kmParms);
                km = (KMeansModel) job.trainModel().get();
                KMeansModel.KMeansOutput out = (KMeansModel.KMeansOutput) km._output;
                double frob = GLRM.frobenius2(out._centers_raw);
                if (frob != 0.0 && !Double.isNaN(frob) && na_cnt == 0L && !glrmParms.hasClosedForm()) {
                    initialXKmeans(dinfo, km);
                }
                // Read the centers before the model is removed in finally.
                centers = ArrayUtils.permuteCols(out._centers_raw, tinfo.mapNames(out._names));
            } finally {
                if (job != null) {
                    job.remove();
                }
                if (km != null) {
                    km.remove();
                }
            }
            centers = GLRM.transform(centers, tinfo._normSub, tinfo._normMul, tinfo._cats, tinfo._nums);
            return GLRM.expandCats(centers, tinfo);
        }

        /**
         * Overwrites the X columns of {@code dinfo._adaptedFrame} with the cluster-indicator encoding
         * returned by {@code KMeansModel.score_indicator} for each row.
         *
         * <p>The same indicator values are written to both X blocks that follow the {@code _ncolA}
         * data columns: offset {@code _ncolA} and offset {@code _ncolA + _ncolX}. These are
         * presumably X_old and X_new.
         *
         * @param dinfo data info whose adapted frame holds the data columns followed by the two X blocks
         * @param km    trained k-means model used to assign each row to a cluster
         */
        private void initialXKmeans(DataInfo dinfo, KMeansModel km) {
            Log.info((Object[])new Object[]{"Initializing X to matrix of indicator columns corresponding to cluster assignments"});
            final KMeansModel model = km;
            // _ncolA/_ncolX are transient fields of the builder. Transient state is presumably not
            // shipped with the task to other nodes. Capturing the values in final locals makes them
            // part of the task's own state.
            final int ncolA = GLRM.this._ncolA;
            final int ncolX = GLRM.this._ncolX;
            new MRTask(){

                public void map(Chunk[] chks) {
                    double[] tmp = new double[ncolA];
                    double[] preds = new double[ncolX];
                    for (int row = 0; row < chks[0]._len; ++row) {
                        double[] p = model.score_indicator(chks, row, tmp, preds);
                        for (int c = 0; c < preds.length; ++c) {
                            chks[ncolA + c].set(row, p[c]);
                            chks[ncolA + ncolX + c].set(row, p[c]);
                        }
                    }
                }
            }.doAll(dinfo._adaptedFrame);
        }

        /**
         * Initializes X in closed form, {@code X = A Y' (Y Y' + gamma_y I)^-1}, with A the training data.
         *
         * <p>The Gram matrix of {@code yt} gets {@code _gamma_y} added to its diagonal when that
         * parameter is positive. It is then Cholesky-factored with up to 10 ridge-regularization
         * attempts. If the factor is still not SPD, a warning is logged and X is left untouched.
         * Otherwise a {@code CholMulTask} is run over {@code dinfo._adaptedFrame}; that task is
         * defined elsewhere and presumably writes the solved X.
         *
         * @param dinfo   data info over the working frame
         * @param yt      current archetypes, transposed
         * @param normSub per-numeric-column centering offsets
         * @param normMul per-numeric-column scaling factors
         */
        private void initialXClosedForm(DataInfo dinfo, double[][] yt, double[] normSub, double[] normMul) {
            Log.info((Object[])new Object[]{"Initializing X = AY'(YY' + gamma I)^(-1) where A = training data"});
            final GLRMModel.GLRMParameters parms = (GLRMModel.GLRMParameters)GLRM.this._parms;
            double[][] ygram = ArrayUtils.formGram((double[][])yt);
            final double gammaY = parms._gamma_y;
            if (gammaY > 0.0) {
                for (int d = 0; d < ygram.length; ++d) {
                    ygram[d][d] += gammaY;
                }
            }
            // Do not throw on failure; fall back to the existing (random) X instead.
            CholeskyDecomposition yychol = this.regularizedCholesky(ygram, 10, false);
            if (yychol.isSPD()) {
                new CholMulTask(dinfo, parms, yychol, yt, GLRM.this._ncolA, GLRM.this._ncolX, normSub, normMul).doAll(dinfo._adaptedFrame);
            } else {
                Log.warn((Object[])new Object[]{"Initialization failed: (YY' + gamma I) is non-SPD. Setting initial X to standard normal random matrix. Results will be numerically unstable"});
            }
        }

        /**
         * Decides whether the optimization loop should stop.
         *
         * <p>Stops when any of the following holds:
         * <ul>
         *   <li>the job is no longer running;</li>
         *   <li>the iteration cap {@code _max_iterations} is reached;</li>
         *   <li>the step size has shrunk to {@code _min_step_size} or below;</li>
         *   <li>after more than 10 iterations and more than 3 consecutive accepted steps, the average
         *       objective change is below {@code TOLERANCE}.</li>
         * </ul>
         *
         * @param model        model whose output carries the iteration count and average objective change
         * @param steps_in_row number of consecutive accepted steps (non-positive after a rejection)
         * @param step         current step size
         * @return {@code true} if the loop should terminate
         */
        private boolean isDone(GLRMModel model, int steps_in_row, double step) {
            if (!GLRM.this.isRunning()) {
                return true;
            }
            final GLRMModel.GLRMOutput out = (GLRMModel.GLRMOutput)model._output;
            final GLRMModel.GLRMParameters parms = (GLRMModel.GLRMParameters)GLRM.this._parms;
            return out._iterations >= parms._max_iterations
                || step <= parms._min_step_size
                || (out._iterations > 10 && steps_in_row > 3 && Math.abs(out._avg_change_obj) < GLRM.this.TOLERANCE);
        }

        /**
         * Cholesky-factors {@code gram}, adding a growing ridge to its diagonal until it is SPD.
         *
         * <p>The first ridge is 1e-5 and each retry multiplies it by 10. Every ridge is added through
         * {@code Gram.addDiag}, so the total diagonal shift is cumulative and {@code gram} is modified
         * in place.
         *
         * @param gram         Gram matrix to factor (mutated when regularization is needed)
         * @param max_attempts maximum number of ridge additions
         * @return an SPD Cholesky factor
         * @throws Gram.NonSPDMatrixException if the matrix is still not SPD after all attempts
         */
        public Gram.Cholesky regularizedCholesky(Gram gram, int max_attempts) {
            Gram.Cholesky chol = gram.cholesky(null);
            double ridge = 0.0;
            int tries = 0;
            while (!chol.isSPD() && tries < max_attempts) {
                ridge = (ridge == 0.0) ? 1.0E-5 : ridge * 10.0;
                gram.addDiag(ridge);
                Log.info((Object[])new Object[]{"Added L2 regularization = " + ridge + " to diagonal of Gram matrix"});
                // Refactor into the same Cholesky object.
                gram.cholesky(chol);
                tries++;
            }
            if (chol.isSPD()) {
                return chol;
            }
            throw new Gram.NonSPDMatrixException();
        }

        /**
         * Cholesky-factors {@code gram} with up to 10 ridge-regularization attempts.
         *
         * @param gram Gram matrix to factor (may be mutated by added diagonal ridge)
         * @return an SPD Cholesky factor
         * @throws Gram.NonSPDMatrixException if no attempt yields an SPD matrix
         */
        public Gram.Cholesky regularizedCholesky(Gram gram) {
            final int defaultAttempts = 10;
            return regularizedCholesky(gram, defaultAttempts);
        }

        /**
         * Cholesky-decomposes a dense Gram matrix, adding a growing ridge to its diagonal until it is SPD.
         *
         * <p>The first ridge is 1e-5 and each retry multiplies it by 10. The ridge is ADDED to the
         * current diagonal, which makes the total shift cumulative like {@code Gram.addDiag} in the
         * {@code Gram} overload. Previously the diagonal was overwritten with the ridge value, which
         * discarded the original diagonal entries and factored an unrelated matrix.
         *
         * <p>Jama's {@code Matrix(double[][])} wraps the array without copying. Any regularization
         * is therefore applied to the caller's {@code gram} in place.
         *
         * @param gram            square Gram matrix (diagonal mutated in place when regularized)
         * @param max_attempts    maximum number of ridge additions
         * @param throw_exception whether to throw when the matrix is still not SPD after all attempts
         * @return the last Cholesky decomposition computed (check {@code isSPD()} when not throwing)
         * @throws Gram.NonSPDMatrixException if not SPD after all attempts and {@code throw_exception}
         */
        public CholeskyDecomposition regularizedCholesky(double[][] gram, int max_attempts, boolean throw_exception) {
            double addedL2 = 0.0;
            Matrix gmat = new Matrix(gram);
            CholeskyDecomposition chol = new CholeskyDecomposition(gmat);
            for (int attempts = 0; !chol.isSPD() && attempts < max_attempts; ++attempts) {
                addedL2 = addedL2 == 0.0 ? 1.0E-5 : addedL2 * 10.0;
                for (int i = 0; i < gram.length; ++i) {
                    // Add to the diagonal rather than overwrite it.
                    gmat.set(i, i, gmat.get(i, i) + addedL2);
                }
                Log.info((Object[])new Object[]{"Added L2 regularization = " + addedL2 + " to diagonal of Gram matrix"});
                chol = new CholeskyDecomposition(gmat);
            }
            if (!chol.isSPD() && throw_exception) {
                throw new Gram.NonSPDMatrixException();
            }
            return chol;
        }

        /**
         * Cholesky-decomposes a dense Gram matrix with up to 10 ridge-regularization attempts,
         * throwing if it never becomes SPD.
         *
         * @param gram square Gram matrix (diagonal may be modified in place)
         * @return an SPD Cholesky decomposition
         * @throws Gram.NonSPDMatrixException if no attempt yields an SPD matrix
         */
        public CholeskyDecomposition regularizedCholesky(double[][] gram) {
            final int defaultAttempts = 10;
            final boolean failIfNotSPD = true;
            return regularizedCholesky(gram, defaultAttempts, failIfNotSPD);
        }

        /**
         * Recovers singular vectors and values of the fitted factorization from X and the archetypes.
         *
         * <p>The method proceeds in three steps:
         * <ol>
         *   <li>Take the transposed Cholesky factor of the Gram of X and scale it by sqrt(numRows).</li>
         *   <li>Take the R factor of a QR decomposition of the archetypes.</li>
         *   <li>Run an SVD on their product.</li>
         * </ol>
         * The right singular vectors are rotated back by the archetypes' Q factor and stored as
         * {@code _eigenvectors}. The singular values are stored as {@code _singular_vals}.
         *
         * @param model model whose output supplies {@code _archetypes} and receives the results
         * @param xinfo data info over the final X frame
         */
        public void recoverSVD(GLRMModel model, DataInfo xinfo) {
            final GLRMModel.GLRMOutput out = (GLRMModel.GLRMOutput)model._output;
            Gram.GramTask gramTask = (Gram.GramTask)new Gram.GramTask(this.self(), xinfo).doAll(xinfo._adaptedFrame);
            Gram.Cholesky xChol = this.regularizedCholesky(gramTask._gram);
            final double rowScale = Math.sqrt(GLRM.this._train.numRows());
            Matrix xR = new Matrix(xChol.getL()).transpose().times(rowScale);
            QRDecomposition archQR = new QRDecomposition(new Matrix(out._archetypes));
            Matrix archR = archQR.getR();
            SingularValueDecomposition svd = new SingularValueDecomposition(xR.times(archR.transpose()));
            out._eigenvectors = archQR.getQ().times(svd.getV()).getArray();
            out._singular_vals = svd.getSingularValues();
        }

        /**
         * Fits the GLRM model end to end on the job's training frame.
         *
         * <p>Steps:
         * <ol>
         *   <li>Validate parameters and lock the input frames.</li>
         *   <li>Record transform and column metadata on the model output.</li>
         *   <li>Build a working frame: the training columns followed by two blocks of {@code _ncolX}
         *       Gaussian-initialized X columns.</li>
         *   <li>Initialize Y, and optionally X via the closed form.</li>
         *   <li>Alternate the {@code UpdateX}/{@code UpdateY} tasks with an adaptive step size until
         *       {@code isDone} returns true.</li>
         * </ol>
         *
         * <p>A step is accepted when the objective (loss + gamma_x * reg(X) + gamma_y * reg(Y))
         * decreases; the step size then grows by 5%. Otherwise the step is shrunk. The X block
         * selected by {@code overwriteX} is published under {@code _loading_key}. The other block
         * and all temporary DataInfos are removed in {@code finally}. Any failure marks the job
         * failed and is rethrown.
         */
        protected void compute2() {
            GLRMModel model = null;
            DataInfo dinfo = null;
            DataInfo xinfo = null;
            DataInfo tinfo = null;
            Frame fr = null;
            Frame x = null;
            boolean overwriteX = false;
            try {
                GLRM.this.init(true);
                ((GLRMModel.GLRMParameters)GLRM.this._parms).read_lock_frames((Job)GLRM.this);
                if (GLRM.this.error_count() > 0) {
                    throw new IllegalArgumentException("Found validation errors: " + GLRM.this.validationErrors());
                }
                model = new GLRMModel(GLRM.this.dest(), (GLRMModel.GLRMParameters)GLRM.this._parms, new GLRMModel.GLRMOutput(GLRM.this));
                model.delete_and_lock(this.self());
                tinfo = new DataInfo(Key.make(), GLRM.this._train, GLRM.this._valid, 0, true, ((GLRMModel.GLRMParameters)GLRM.this._parms)._transform, DataInfo.TransformType.NONE, false, false, false, false, false);
                DKV.put((Key)tinfo._key, (Iced)tinfo);
                // No transform: zero offsets and unit scales.
                ((GLRMModel.GLRMOutput)model._output)._normSub = tinfo._normSub == null ? new double[tinfo._nums] : tinfo._normSub;
                if (tinfo._normMul == null) {
                    ((GLRMModel.GLRMOutput)model._output)._normMul = new double[tinfo._nums];
                    Arrays.fill(((GLRMModel.GLRMOutput)model._output)._normMul, 1.0);
                } else {
                    ((GLRMModel.GLRMOutput)model._output)._normMul = tinfo._normMul;
                }
                ((GLRMModel.GLRMOutput)model._output)._permutation = tinfo._permutation;
                ((GLRMModel.GLRMOutput)model._output)._nnums = tinfo._nums;
                ((GLRMModel.GLRMOutput)model._output)._ncats = tinfo._cats;
                ((GLRMModel.GLRMOutput)model._output)._catOffsets = tinfo._catOffsets;
                ((GLRMModel.GLRMOutput)model._output)._names_expanded = tinfo.coefNames();
                long nobs = GLRM.this._train.numRows() * (long)GLRM.this._train.numCols();
                long na_cnt = 0L;
                for (int col = 0; col < GLRM.this._train.numCols(); ++col) {
                    na_cnt += GLRM.this._train.vec(col).naCnt();
                }
                ((GLRMModel.GLRMOutput)model._output)._nobs = nobs - na_cnt;
                // Working frame layout: [training columns (ncolA) | X block (ncolX) | X block (ncolX)].
                Vec[] vecs = new Vec[GLRM.this._ncolA + 2 * GLRM.this._ncolX];
                for (int col = 0; col < GLRM.this._ncolA; ++col) {
                    vecs[col] = GLRM.this._train.vec(col);
                }
                // makeGaus is seeded deterministically, so reusing _seed for every column would give
                // all X columns the same values. Derive a distinct, reproducible seed per column.
                Random colSeeds = new Random(((GLRMModel.GLRMParameters)GLRM.this._parms)._seed);
                for (int col = GLRM.this._ncolA; col < vecs.length; ++col) {
                    vecs[col] = GLRM.this._train.anyVec().makeGaus(colSeeds.nextLong());
                }
                fr = new Frame(null, vecs);
                dinfo = new DataInfo(Key.make(), fr, null, 0, true, ((GLRMModel.GLRMParameters)GLRM.this._parms)._transform, DataInfo.TransformType.NONE, false, false, false, false, false);
                DKV.put((Key)dinfo._key, (Iced)dinfo);
                double[][] yt = this.initialXY(tinfo, dinfo, na_cnt);
                yt = ArrayUtils.transpose((double[][])yt);
                if (na_cnt == 0L && ((GLRMModel.GLRMParameters)GLRM.this._parms).hasClosedForm()) {
                    this.initialXClosedForm(dinfo, yt, ((GLRMModel.GLRMOutput)model._output)._normSub, ((GLRMModel.GLRMOutput)model._output)._normMul);
                }
                ObjCalc objtsk = (ObjCalc)new ObjCalc(dinfo, (GLRMModel.GLRMParameters)GLRM.this._parms, yt, GLRM.this._ncolA, GLRM.this._ncolX, ((GLRMModel.GLRMOutput)model._output)._normSub, ((GLRMModel.GLRMOutput)model._output)._normMul, ((GLRMModel.GLRMParameters)GLRM.this._parms)._gamma_x != 0.0).doAll(dinfo._adaptedFrame);
                ((GLRMModel.GLRMOutput)model._output)._objective = objtsk._loss + ((GLRMModel.GLRMParameters)GLRM.this._parms)._gamma_x * objtsk._xold_reg + ((GLRMModel.GLRMParameters)GLRM.this._parms)._gamma_y * ((GLRMModel.GLRMParameters)GLRM.this._parms).regularize_y(yt);
                ((GLRMModel.GLRMOutput)model._output)._iterations = 0;
                // Start above the tolerance so the convergence test cannot fire before any update.
                ((GLRMModel.GLRMOutput)model._output)._avg_change_obj = 2.0 * GLRM.this.TOLERANCE;
                model.update(GLRM.this._key);
                GLRM.this.update(1L);
                double step = ((GLRMModel.GLRMParameters)GLRM.this._parms)._init_step_size;
                int steps_in_row = 0;
                while (!this.isDone(model, steps_in_row, step)) {
                    UpdateX xtsk = new UpdateX(dinfo, (GLRMModel.GLRMParameters)GLRM.this._parms, yt, step / (double)GLRM.this._ncolA, overwriteX, GLRM.this._ncolA, GLRM.this._ncolX, ((GLRMModel.GLRMOutput)model._output)._normSub, ((GLRMModel.GLRMOutput)model._output)._normMul);
                    xtsk.doAll(dinfo._adaptedFrame);
                    UpdateY ytsk = new UpdateY(dinfo, (GLRMModel.GLRMParameters)GLRM.this._parms, yt, step / (double)GLRM.this._ncolA, GLRM.this._ncolA, GLRM.this._ncolX, ((GLRMModel.GLRMOutput)model._output)._normSub, ((GLRMModel.GLRMOutput)model._output)._normMul);
                    double[][] ytnew = ((UpdateY)ytsk.doAll((Frame)dinfo._adaptedFrame))._ytnew;
                    objtsk = (ObjCalc)new ObjCalc(dinfo, (GLRMModel.GLRMParameters)GLRM.this._parms, ytnew, GLRM.this._ncolA, GLRM.this._ncolX, ((GLRMModel.GLRMOutput)model._output)._normSub, ((GLRMModel.GLRMOutput)model._output)._normMul).doAll(dinfo._adaptedFrame);
                    double obj_new = objtsk._loss + ((GLRMModel.GLRMParameters)GLRM.this._parms)._gamma_x * xtsk._xreg + ((GLRMModel.GLRMParameters)GLRM.this._parms)._gamma_y * ytsk._yreg;
                    ((GLRMModel.GLRMOutput)model._output)._avg_change_obj = (((GLRMModel.GLRMOutput)model._output)._objective - obj_new) / (double)nobs;
                    ++((GLRMModel.GLRMOutput)model._output)._iterations;
                    if (((GLRMModel.GLRMOutput)model._output)._avg_change_obj > 0.0) {
                        // Objective decreased: accept the step and grow the step size.
                        yt = ytnew;
                        ((GLRMModel.GLRMOutput)model._output)._archetypes = ytnew;
                        ((GLRMModel.GLRMOutput)model._output)._objective = obj_new;
                        step *= 1.05;
                        steps_in_row = Math.max(1, steps_in_row + 1);
                        overwriteX = true;
                    } else {
                        // Objective did not decrease: reject and shrink the step, by more after repeated rejections.
                        steps_in_row = Math.min(0, steps_in_row - 1);
                        overwriteX = false;
                        step /= Math.max(1.5, (double)(-steps_in_row));
                        Log.info((Object[])new Object[]{"Iteration " + ((GLRMModel.GLRMOutput)model._output)._iterations + ": Objective increased to " + obj_new + "; reducing step size to " + step});
                    }
                    ((GLRMModel.GLRMOutput)model._output)._step_size = step;
                    model.update(this.self());
                    GLRM.this.update(1L);
                }
                // Publish the X block selected by overwriteX (the one updated by the last accepted step).
                Vec[] xvecs = new Vec[GLRM.this._ncolX];
                for (int col = 0; col < GLRM.this._ncolX; ++col) {
                    int src = overwriteX ? GLRM.idx_xnew(col, GLRM.this._ncolA, GLRM.this._ncolX) : GLRM.idx_xold(col, GLRM.this._ncolA);
                    xvecs[col] = fr.vec(src);
                }
                x = new Frame(((GLRMModel.GLRMParameters)GLRM.this._parms)._loading_key, null, xvecs);
                xinfo = new DataInfo(Key.make(), x, null, 0, true, DataInfo.TransformType.NONE, DataInfo.TransformType.NONE, false, false, false, false, false);
                DKV.put((Key)x._key, (Iced)x);
                DKV.put((Key)xinfo._key, (Iced)xinfo);
                ((GLRMModel.GLRMOutput)model._output)._loading_key = ((GLRMModel.GLRMParameters)GLRM.this._parms)._loading_key;
                ((GLRMModel.GLRMOutput)model._output)._archetypes = yt;
                ((GLRMModel.GLRMOutput)model._output)._step_size = step;
                if (((GLRMModel.GLRMParameters)GLRM.this._parms)._recover_svd) {
                    this.recoverSVD(model, xinfo);
                }
                model.update(this.self());
                GLRM.this.done();
            }
            catch (Throwable t) {
                // The job may no longer be in the DKV; guard so an NPE here cannot mask the real
                // failure or skip failed(t).
                Job thisJob = (Job)DKV.getGet((Key)GLRM.this._key);
                if (thisJob != null && thisJob._state == Job.JobState.CANCELLED) {
                    Log.info((Object[])new Object[]{"Job cancelled by user."});
                }
                t.printStackTrace();
                GLRM.this.failed(t);
                throw t;
            }
            finally {
                GLRM.this.updateModelOutput();
                ((GLRMModel.GLRMParameters)GLRM.this._parms).read_unlock_frames((Job)GLRM.this);
                if (model != null) {
                    model.unlock(GLRM.this._key);
                }
                if (tinfo != null) {
                    tinfo.remove();
                }
                if (dinfo != null) {
                    dinfo.remove();
                }
                if (xinfo != null) {
                    xinfo.remove();
                }
                if (fr != null) {
                    // Remove the X block that was not published.
                    if (overwriteX) {
                        for (int i = 0; i < GLRM.this._ncolX; ++i) {
                            fr.vec(GLRM.idx_xold(i, GLRM.this._ncolA)).remove();
                        }
                    } else {
                        for (int i = 0; i < GLRM.this._ncolX; ++i) {
                            fr.vec(GLRM.idx_xnew(i, GLRM.this._ncolA, GLRM.this._ncolX)).remove();
                        }
                    }
                }
            }
            this.tryComplete();
        }

        /** Returns the key of the enclosing GLRM job. */
        Key self() {
            final Key jobKey = GLRM.this._key;
            return jobKey;
        }
    }

    /**
     * Initialization strategies for GLRM. These are presumably the values accepted by the
     * {@code _init} parameter; confirm against {@code GLRMModel.GLRMParameters}.
     */
    public static enum Initialization {
        Random, SVD, PlusPlus, User
    }
}

