/*
 * Decompiled with CFR 0.152.
 */
package hex.svd;

import hex.DataInfo;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.gram.Gram;
import hex.schemas.ModelBuilderSchema;
import hex.schemas.SVDV3;
import hex.svd.SVDModel;
import java.util.Arrays;
import water.DKV;
import water.H2O;
import water.Iced;
import water.Job;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;

public class SVD
extends ModelBuilder<SVDModel, SVDModel.SVDParameters, SVDModel.SVDOutput> {
    private final double TOLERANCE = 1.0E-6;
    private final int MAX_COLS_EXPANDED = 5000;
    private transient int _ncolExp;

    public ModelBuilderSchema schema() {
        return new SVDV3();
    }

    public Job<SVDModel> trainModel() {
        return this.start(new SVDDriver(), 0L);
    }

    public ModelCategory[] can_build() {
        return new ModelCategory[]{ModelCategory.DimReduction};
    }

    public ModelBuilder.BuilderVisibility builderVisibility() {
        return ModelBuilder.BuilderVisibility.Experimental;
    }

    public SVD(SVDModel.SVDParameters parms) {
        super("SVD", (Model.Parameters)parms);
        this.init(false);
    }

    public void init(boolean expensive) {
        super.init(expensive);
        if (((SVDModel.SVDParameters)this._parms)._u_key == null) {
            ((SVDModel.SVDParameters)this._parms)._u_key = Key.make((String)("SVDUMatrix_" + Key.rand()));
        }
        if (((SVDModel.SVDParameters)this._parms)._max_iterations < 1) {
            this.error("_max_iterations", "max_iterations must be at least 1");
        }
        if (this._train == null) {
            return;
        }
        this._ncolExp = this._train.numColsExp(((SVDModel.SVDParameters)this._parms)._useAllFactorLevels, false);
        if (this._ncolExp > 5000) {
            this.warn("_train", "_train has " + this._ncolExp + " columns when categoricals are expanded. Algorithm may be slow.");
        }
        if (((SVDModel.SVDParameters)this._parms)._nv < 1 || ((SVDModel.SVDParameters)this._parms)._nv > this._ncolExp) {
            this.error("_nv", "Number of right singular values must be between 1 and " + this._ncolExp);
        }
    }

    public double[] powerLoop(double[][] gram) {
        return this.powerLoop(gram, ArrayUtils.gaussianVector((int)gram[0].length));
    }

    public double[] powerLoop(double[][] gram, long seed) {
        return this.powerLoop(gram, ArrayUtils.gaussianVector((int)gram[0].length, (long)seed));
    }

    public double[] powerLoop(double[][] gram, double[] vinit) {
        assert (gram.length == gram[0].length);
        assert (vinit.length == gram.length);
        double err = 2.0E-6;
        double[] v = (double[])vinit.clone();
        double[] vnew = new double[v.length];
        for (int iters = 0; iters < ((SVDModel.SVDParameters)this._parms)._max_iterations && err > 1.0E-6; ++iters) {
            for (int i = 0; i < v.length; ++i) {
                vnew[i] = ArrayUtils.innerProduct((double[])gram[i], (double[])v);
            }
            double norm = ArrayUtils.l2norm((double[])vnew);
            err = 0.0;
            for (int i = 0; i < v.length; ++i) {
                int n = i;
                vnew[n] = vnew[n] / norm;
                double diff = v[i] - vnew[i];
                err += diff * diff;
                v[i] = vnew[i];
            }
            err = Math.sqrt(err);
        }
        return v;
    }

    public double[][] sub_symm(double[][] lmat, double[][] rmat) {
        for (int i = 0; i < rmat.length; ++i) {
            for (int j = 0; j < i; ++j) {
                double diff;
                double d = diff = lmat[i][j] - rmat[i][j];
                lmat[j][i] = d;
                lmat[i][j] = d;
            }
            double[] dArray = lmat[i];
            int n = i;
            dArray[n] = dArray[n] - rmat[i][i];
        }
        return lmat;
    }

    protected static Chunk chk_u(Chunk[] chks, int c, int ncols) {
        return chks[ncols + c];
    }

    private static double l2norm2(Chunk[] cs, double[] vec, int k, DataInfo dinfo, double[] normSub, double[] normMul) {
        double sumsqr = 0.0;
        int ncols = dinfo._adaptedFrame.numCols();
        for (int row = 0; row < cs[0]._len; ++row) {
            double sum = 0.0;
            for (int j = 0; j < dinfo._cats; ++j) {
                int i = (int)cs[j].atd(row);
                sum += vec[dinfo._catOffsets[j] + i];
            }
            int cidx = dinfo._cats;
            int vidx = dinfo.numStart();
            for (int j = 0; j < dinfo._nums; ++j) {
                double a = cs[cidx].atd(row);
                sum += (a - normSub[j]) * normMul[j] * vec[vidx];
                ++cidx;
                ++vidx;
            }
            assert (cidx == ncols && vidx == vec.length);
            sumsqr += sum * sum;
            SVD.chk_u(cs, k, ncols).set(row, sum);
        }
        return sumsqr;
    }

    private static void div(Chunk chk, double norm) {
        for (int row = 0; row < chk._len; ++row) {
            double tmp = chk.atd(row);
            chk.set(row, tmp / norm);
        }
    }

    private static class CalcSigmaUNorm
    extends MRTask<CalcSigmaUNorm> {
        DataInfo _dinfo;
        SVDModel.SVDParameters _parms;
        final int _k;
        final double[] _svec;
        final double _sval_old;
        final double[] _normSub;
        final double[] _normMul;
        final int _ncols;
        double _sval;

        CalcSigmaUNorm(DataInfo dinfo, SVDModel.SVDParameters parms, double[] svec, int k, double sval_old, double[] normSub, double[] normMul) {
            assert (k >= 1) : "Index of singular vector k must be at least 1";
            this._dinfo = dinfo;
            this._parms = parms;
            this._k = k;
            this._svec = svec;
            this._normSub = normSub;
            this._normMul = normMul;
            this._ncols = this._dinfo._adaptedFrame.numCols();
            this._sval_old = sval_old;
            this._sval = 0.0;
        }

        public void map(Chunk[] cs) {
            assert (cs.length - this._ncols == this._parms._nv);
            this._sval += SVD.l2norm2(cs, this._svec, this._k, this._dinfo, this._normSub, this._normMul);
            SVD.div(SVD.chk_u(cs, this._k - 1, this._ncols), this._sval_old);
        }

        protected void postGlobal() {
            this._sval = Math.sqrt(this._sval);
        }
    }

    private static class CalcSigmaU
    extends MRTask<CalcSigmaU> {
        DataInfo _dinfo;
        SVDModel.SVDParameters _parms;
        final double[] _normSub;
        final double[] _normMul;
        final int _ncols;
        final double[] _svec;
        double _sval;

        CalcSigmaU(DataInfo dinfo, SVDModel.SVDParameters parms, double[] svec, double[] normSub, double[] normMul) {
            this._dinfo = dinfo;
            this._parms = parms;
            this._svec = svec;
            this._normSub = normSub;
            this._normMul = normMul;
            this._ncols = this._dinfo._adaptedFrame.numCols();
            this._sval = 0.0;
        }

        public void map(Chunk[] cs) {
            assert (cs.length - this._ncols == this._parms._nv);
            this._sval += SVD.l2norm2(cs, this._svec, 0, this._dinfo, this._normSub, this._normMul);
        }

        protected void postGlobal() {
            this._sval = Math.sqrt(this._sval);
        }
    }

    class SVDDriver
    extends H2O.H2OCountedCompleter<SVDDriver> {
        SVDDriver() {
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        protected void compute2() {
            Frame u;
            DataInfo dinfo;
            DataInfo uinfo;
            block28: {
                SVDModel model = null;
                uinfo = null;
                dinfo = null;
                Frame fr = null;
                u = null;
                try {
                    double[] ivv_vk;
                    ((SVDModel.SVDParameters)SVD.this._parms).read_lock_frames((Job)SVD.this);
                    SVD.this.init(true);
                    if (SVD.this.error_count() > 0) {
                        throw new IllegalArgumentException("Found validation errors: " + SVD.this.validationErrors());
                    }
                    model = new SVDModel(SVD.this.dest(), (SVDModel.SVDParameters)SVD.this._parms, new SVDModel.SVDOutput(SVD.this));
                    model.delete_and_lock(this.self());
                    dinfo = new DataInfo(Key.make(), SVD.this._train, null, 0, ((SVDModel.SVDParameters)SVD.this._parms)._useAllFactorLevels, ((SVDModel.SVDParameters)SVD.this._parms)._transform, DataInfo.TransformType.NONE, true, false);
                    DKV.put((Key)dinfo._key, (Iced)dinfo);
                    double[] dArray = ((SVDModel.SVDOutput)model._output)._normSub = dinfo._normSub == null ? new double[dinfo._nums] : dinfo._normSub;
                    if (dinfo._normMul == null) {
                        ((SVDModel.SVDOutput)model._output)._normMul = new double[dinfo._nums];
                        Arrays.fill(((SVDModel.SVDOutput)model._output)._normMul, 1.0);
                    } else {
                        ((SVDModel.SVDOutput)model._output)._normMul = dinfo._normMul;
                    }
                    ((SVDModel.SVDOutput)model._output)._permutation = dinfo._permutation;
                    ((SVDModel.SVDOutput)model._output)._nnums = dinfo._nums;
                    ((SVDModel.SVDOutput)model._output)._ncats = dinfo._cats;
                    ((SVDModel.SVDOutput)model._output)._catOffsets = dinfo._catOffsets;
                    Gram.GramTask tsk = (Gram.GramTask)new Gram.GramTask(this.self(), dinfo).doAll(dinfo._adaptedFrame);
                    double[][] gram = tsk._gram.getXX();
                    double[] sigma = new double[((SVDModel.SVDParameters)SVD.this._parms)._nv];
                    double[][] rsvec = new double[((SVDModel.SVDParameters)SVD.this._parms)._nv][gram.length];
                    assert (gram.length == SVD.this._ncolExp);
                    rsvec[0] = SVD.this.powerLoop(gram, ((SVDModel.SVDParameters)SVD.this._parms)._seed);
                    double[][] ivv_sum = new double[gram.length][gram.length];
                    for (int i = 0; i < gram.length; ++i) {
                        ivv_sum[i][i] = 1.0;
                    }
                    if (!((SVDModel.SVDParameters)SVD.this._parms)._only_v) {
                        Vec[] vecs = new Vec[SVD.this._train.numCols() + ((SVDModel.SVDParameters)SVD.this._parms)._nv];
                        Vec[] uvecs = new Vec[((SVDModel.SVDParameters)SVD.this._parms)._nv];
                        for (int i = 0; i < SVD.this._train.numCols(); ++i) {
                            vecs[i] = SVD.this._train.vec(i);
                        }
                        int c = 0;
                        for (int i = SVD.this._train.numCols(); i < vecs.length; ++i) {
                            vecs[i] = SVD.this._train.anyVec().makeZero();
                            uvecs[c++] = vecs[i];
                        }
                        assert (c == uvecs.length);
                        fr = new Frame(null, vecs);
                        u = new Frame(((SVDModel.SVDParameters)SVD.this._parms)._u_key, null, uvecs);
                        uinfo = new DataInfo(Key.make(), fr, null, 0, false, ((SVDModel.SVDParameters)SVD.this._parms)._transform, DataInfo.TransformType.NONE, true, false);
                        DKV.put((Key)uinfo._key, (Iced)uinfo);
                        DKV.put((Key)u._key, (Iced)u);
                        ivv_vk = ArrayUtils.multArrVec((double[][])ivv_sum, (double[])rsvec[0]);
                        sigma[0] = ((CalcSigmaU)new CalcSigmaU((DataInfo)dinfo, (SVDModel.SVDParameters)((SVDModel.SVDParameters)SVD.this._parms), (double[])ivv_vk, (double[])((SVDModel.SVDOutput)model._output)._normSub, (double[])((SVDModel.SVDOutput)model._output)._normMul).doAll((Frame)uinfo._adaptedFrame))._sval;
                    }
                    double[][] vv = ArrayUtils.outerProduct((double[])rsvec[0], (double[])rsvec[0]);
                    ivv_sum = SVD.this.sub_symm(ivv_sum, vv);
                    double[][] gram_update = ArrayUtils.multArrArr((double[][])ArrayUtils.multArrArr((double[][])ivv_sum, (double[][])gram), (double[][])ivv_sum);
                    for (int k = 1; k < ((SVDModel.SVDParameters)SVD.this._parms)._nv; ++k) {
                        rsvec[k] = SVD.this.powerLoop(gram_update, ((SVDModel.SVDParameters)SVD.this._parms)._seed);
                        if (!((SVDModel.SVDParameters)SVD.this._parms)._only_v) {
                            ivv_vk = ArrayUtils.multArrVec((double[][])ivv_sum, (double[])rsvec[k]);
                            sigma[k] = ((CalcSigmaUNorm)new CalcSigmaUNorm((DataInfo)dinfo, (SVDModel.SVDParameters)((SVDModel.SVDParameters)SVD.this._parms), (double[])ivv_vk, (int)k, (double)sigma[k - 1], (double[])((SVDModel.SVDOutput)model._output)._normSub, (double[])((SVDModel.SVDOutput)model._output)._normMul).doAll((Frame)uinfo._adaptedFrame))._sval;
                        }
                        vv = ArrayUtils.outerProduct((double[])rsvec[k], (double[])rsvec[k]);
                        ivv_sum = SVD.this.sub_symm(ivv_sum, vv);
                        double[][] lmat = ArrayUtils.multArrArr((double[][])ivv_sum, (double[][])gram);
                        gram_update = ArrayUtils.multArrArr((double[][])lmat, (double[][])ivv_sum);
                        model.update(this.self());
                        SVD.this.update(1L);
                    }
                    ((SVDModel.SVDOutput)model._output)._v = ArrayUtils.transpose((double[][])rsvec);
                    if (!((SVDModel.SVDParameters)SVD.this._parms)._only_v) {
                        ((SVDModel.SVDOutput)model._output)._d = sigma;
                        if (((SVDModel.SVDParameters)SVD.this._parms)._keep_u) {
                            final int idx = ((SVDModel.SVDParameters)SVD.this._parms)._nv - 1;
                            final int ncols = SVD.this._train.numCols();
                            final double sigma_last = sigma[((SVDModel.SVDParameters)SVD.this._parms)._nv - 1];
                            new MRTask(){

                                public void map(Chunk[] cs) {
                                    SVD.div(SVD.chk_u(cs, idx, ncols), sigma_last);
                                }
                            }.doAll(uinfo._adaptedFrame);
                            ((SVDModel.SVDOutput)model._output)._u_key = ((SVDModel.SVDParameters)SVD.this._parms)._u_key;
                        }
                    }
                    model.update(this.self());
                    SVD.this.done();
                    if (model == null) break block28;
                }
                catch (Throwable t) {
                    block29: {
                        try {
                            Job thisJob = (Job)DKV.getGet((Key)SVD.this._key);
                            if (thisJob._state != Job.JobState.CANCELLED) {
                                t.printStackTrace();
                                SVD.this.failed(t);
                                throw t;
                            }
                            Log.info((Object[])new Object[]{"Job cancelled by user."});
                            if (model == null) break block29;
                        }
                        catch (Throwable throwable) {
                            if (model != null) {
                                model.unlock(SVD.this._key);
                            }
                            if (dinfo != null) {
                                dinfo.remove();
                            }
                            if (uinfo != null) {
                                uinfo.remove();
                            }
                            if (u != null & !((SVDModel.SVDParameters)SVD.this._parms)._keep_u) {
                                u.delete();
                            }
                            ((SVDModel.SVDParameters)SVD.this._parms).read_unlock_frames((Job)SVD.this);
                            throw throwable;
                        }
                        model.unlock(SVD.this._key);
                    }
                    if (dinfo != null) {
                        dinfo.remove();
                    }
                    if (uinfo != null) {
                        uinfo.remove();
                    }
                    if (u != null & !((SVDModel.SVDParameters)SVD.this._parms)._keep_u) {
                        u.delete();
                    }
                    ((SVDModel.SVDParameters)SVD.this._parms).read_unlock_frames((Job)SVD.this);
                }
                model.unlock(SVD.this._key);
            }
            if (dinfo != null) {
                dinfo.remove();
            }
            if (uinfo != null) {
                uinfo.remove();
            }
            if (u != null & !((SVDModel.SVDParameters)SVD.this._parms)._keep_u) {
                u.delete();
            }
            ((SVDModel.SVDParameters)SVD.this._parms).read_unlock_frames((Job)SVD.this);
            this.tryComplete();
        }

        Key self() {
            return SVD.this._key;
        }
    }
}

