/*
 * Decompiled with CFR 0.152.
 */
package hex.svd;

import hex.DataInfo;
import hex.FrameTask;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.gram.Gram;
import hex.schemas.ModelBuilderSchema;
import hex.schemas.SVDV99;
import hex.svd.SVDModel;
import java.util.Arrays;
import water.DKV;
import water.H2O;
import water.Iced;
import water.Job;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;

/**
 * Model builder for a truncated singular value decomposition of the training frame, after
 * categorical columns are expanded.
 *
 * <p>How the decomposition is computed:
 * <ul>
 *   <li>Right singular vectors are found one at a time by power iteration ({@link #powerLoop}).
 *   <li>The first vector uses the Gram matrix computed by {@code Gram.GramTask}.
 *   <li>Each later vector uses the Gram matrix of the rows projected by the running matrix
 *       {@code ivv_sum}. This matrix starts as the identity, and {@link #updateIVVSum} subtracts
 *       {@code v v^T} for every vector already found.
 *   <li>Unless only V is requested, each singular value is the L2 norm of the data multiplied by
 *       {@code ivv_sum * v_k}. The matching U column is that product divided by the singular
 *       value.
 * </ul>
 */
public class SVD
extends ModelBuilder<SVDModel, SVDModel.SVDParameters, SVDModel.SVDOutput> {
    /** Power iteration stops once the L2 change between successive iterates is at most this. */
    private static final double TOLERANCE = 1.0E-6;
    /** Expanded column count above which {@link #init(boolean)} warns that training may be slow. */
    private static final int MAX_COLS_EXPANDED = 5000;
    /** Column count of the training frame after categorical expansion; set by {@link #init(boolean)}. */
    private transient int _ncolExp;

    /** Returns the REST schema used to describe this builder. */
    public ModelBuilderSchema schema() {
        return new SVDV99();
    }

    /** Starts the background driver that performs the decomposition. */
    public Job<SVDModel> trainModelImpl(long work, boolean restartTimer) {
        return this.start(new SVDDriver(), work, restartTimer);
    }

    /**
     * Returns the number of progress units reported by the driver.
     *
     * <p>There is one unit for the initial Gram matrix and one for each singular vector.
     */
    public long progressUnits() {
        return ((SVDModel.SVDParameters) this._parms)._nv + 1;
    }

    /** SVD produces dimensionality-reduction models only. */
    public ModelCategory[] can_build() {
        return new ModelCategory[]{ModelCategory.DimReduction};
    }

    /** This builder is flagged as experimental. */
    public ModelBuilder.BuilderVisibility builderVisibility() {
        return ModelBuilder.BuilderVisibility.Experimental;
    }

    /**
     * Creates a builder for the given parameters.
     *
     * <p>Only the cheap validation ({@code init(false)}) runs here.
     */
    public SVD(SVDModel.SVDParameters parms) {
        super("SVD", (Model.Parameters) parms);
        this.init(false);
    }

    /**
     * Validates the parameters and fills in defaults.
     *
     * <p>Behavior:
     * <ul>
     *   <li>Generates a random U frame name when none was supplied.
     *   <li>Requires {@code _max_iterations >= 1}.
     *   <li>When a training frame is present, records its expanded column count.
     *   <li>Warns when that count exceeds {@link #MAX_COLS_EXPANDED}.
     *   <li>Requires {@code 1 <= _nv <= expanded column count}.
     * </ul>
     */
    public void init(boolean expensive) {
        super.init(expensive);
        SVDModel.SVDParameters parms = (SVDModel.SVDParameters) this._parms;
        if (parms._u_name == null || parms._u_name.length() == 0) {
            parms._u_name = "SVDUMatrix_" + Key.rand();
        }
        if (parms._max_iterations < 1) {
            this.error("_max_iterations", "max_iterations must be at least 1");
        }
        if (this._train == null) {
            return;
        }
        this._ncolExp = this._train.numColsExp(parms._use_all_factor_levels, false);
        if (this._ncolExp > MAX_COLS_EXPANDED) {
            this.warn("_train", "_train has " + this._ncolExp + " columns when categoricals are expanded. Algorithm may be slow.");
        }
        if (parms._nv < 1 || parms._nv > this._ncolExp) {
            this.error("_nv", "Number of right singular values must be between 1 and " + this._ncolExp);
        }
    }

    /**
     * Runs {@link #powerLoop(Gram, double[])} from a seeded Gaussian random start vector.
     *
     * @param gram matrix to iterate on
     * @param seed seed for the random start vector
     * @return the final normalized iterate
     */
    public double[] powerLoop(Gram gram, long seed) {
        return this.powerLoop(gram, ArrayUtils.gaussianVector((int) gram.fullN(), seed));
    }

    /**
     * Power iteration: repeatedly multiplies the iterate by {@code gram} and rescales it to unit
     * L2 norm.
     *
     * <p>Iteration stops after {@code _max_iterations} steps, or earlier once the L2 change
     * between successive iterates is at most {@link #TOLERANCE}. If a product is exactly the zero
     * vector (the iterate lies in the matrix's null space), iteration stops and the current iterate
     * is normalized and returned. Before this guard, the division by a zero norm returned NaN.
     *
     * @param gram  matrix to iterate on; its {@code fullN()} must equal {@code vinit.length}
     * @param vinit start vector; it is not modified
     * @return the final iterate
     */
    public double[] powerLoop(Gram gram, double[] vinit) {
        assert (vinit.length == gram.fullN());
        final int maxIters = ((SVDModel.SVDParameters) this._parms)._max_iterations;
        double[] v = vinit.clone();
        double[] vnew = new double[v.length];
        double err = Double.POSITIVE_INFINITY; // guarantees at least one iteration
        for (int iters = 0; iters < maxIters && err > TOLERANCE; ++iters) {
            gram.mul(v, vnew);
            double norm = ArrayUtils.l2norm(vnew);
            if (norm == 0.0) {
                // Dividing by a zero norm would turn every entry into NaN; keep the current direction instead.
                double vnorm = ArrayUtils.l2norm(v);
                if (vnorm > 0.0) {
                    for (int i = 0; i < v.length; ++i) {
                        v[i] /= vnorm;
                    }
                }
                break;
            }
            err = 0.0;
            for (int i = 0; i < v.length; ++i) {
                vnew[i] /= norm;
                double diff = v[i] - vnew[i];
                err += diff * diff;
                v[i] = vnew[i];
            }
            err = Math.sqrt(err);
        }
        return v;
    }

    /**
     * Subtracts the outer product {@code vec * vec^T} from {@code ivv_sum} in place.
     *
     * <p>Only the lower triangle and the diagonal are read. Each updated off-diagonal entry is
     * written to both {@code [i][j]} and {@code [j][i]}, so the result is symmetric.
     *
     * @param ivv_sum square matrix of size at least {@code vec.length}; modified in place
     * @param vec     vector whose outer product is subtracted
     * @return {@code ivv_sum}, for chaining
     */
    public static double[][] updateIVVSum(double[][] ivv_sum, double[] vec) {
        for (int i = 0; i < vec.length; ++i) {
            for (int j = 0; j < i; ++j) {
                double d = ivv_sum[i][j] - vec[i] * vec[j];
                ivv_sum[i][j] = d;
                ivv_sum[j][i] = d;
            }
            ivv_sum[i][i] -= vec[i] * vec[i];
        }
        return ivv_sum;
    }

    /**
     * Computes the Gram matrix of the rows after projection by {@code _ivv} (each row x becomes
     * {@code _ivv * x}).
     *
     * <p>The result is averaged over the processed rows. Each chunk divides by its own row count,
     * and {@link #reduce} combines chunks with weights proportional to their row counts.
     */
    private static class GramUpdate
    extends FrameTask<GramUpdate> {
        final double[][] _ivv;
        public Gram _gram;
        public long _nobs;

        public GramUpdate(Key jobKey, DataInfo dinfo, double[][] ivv) {
            super(jobKey, dinfo);
            assert (null != ivv && ivv.length == ivv[0].length);
            this._ivv = ivv;
        }

        @Override
        protected boolean chunkInit() {
            this._gram = new Gram(this._dinfo.fullN(), 0, this._ivv.length, 0, false);
            return true;
        }

        @Override
        protected void processRow(long gid, DataInfo.Row r) {
            double[] nums = new double[this._ivv.length];
            for (int row = 0; row < this._ivv.length; ++row) {
                nums[row] = r.innerProduct(this._ivv[row]);
            }
            this._gram.addRow(this._dinfo.newDenseRow(nums), 1.0);
            ++this._nobs;
        }

        @Override
        protected void chunkDone(long n) {
            // A chunk with no processed rows has nothing to average.
            // Dividing by zero would give 0 * Infinity = NaN in every entry.
            if (this._nobs > 0L) {
                this._gram.mul(1.0 / (double) this._nobs);
            }
        }

        public void reduce(GramUpdate gt) {
            long total = this._nobs + gt._nobs;
            if (total == 0L) {
                // Neither side processed a row, so nothing was added to either Gram.
                // Weighting here would be 0/0 = NaN.
                return;
            }
            this._gram.mul((double) this._nobs / (double) total);
            gt._gram.mul((double) gt._nobs / (double) total);
            this._gram.add(gt._gram);
            this._nobs = total;
        }
    }

    /**
     * Multiplies every row by {@code _svec} and writes each inner product to output column 0.
     *
     * <p>{@code _sval} ends as the L2 norm of that output column (a sum of squares, square-rooted
     * in {@link #postGlobal}). {@code _nobs} counts the processed rows.
     */
    private static class CalcSigmaU
    extends FrameTask<CalcSigmaU> {
        final double[] _svec;
        public double _sval;
        public long _nobs;

        public CalcSigmaU(Key jobKey, DataInfo dinfo, double[] svec) {
            super(jobKey, dinfo);
            this._svec = svec;
            this._sval = 0.0;
        }

        @Override
        protected void processRow(long gid, DataInfo.Row r, NewChunk[] outputs) {
            double num = r.innerProduct(this._svec);
            outputs[0].addNum(num);
            this._sval += num * num;
            ++this._nobs;
        }

        public void reduce(CalcSigmaU other) {
            this._nobs += other._nobs;
            this._sval += other._sval;
        }

        protected void postGlobal() {
            this._sval = Math.sqrt(this._sval);
        }
    }

    /**
     * Divides every column {@code c} of the input frame in place by {@code _sigma[c]}.
     *
     * <p>This is a static task, so it does not capture the enclosing driver.
     */
    private static class DivideU
    extends MRTask<DivideU> {
        final double[] _sigma;

        DivideU(double[] sigma) {
            this._sigma = sigma;
        }

        public void map(Chunk[] cs) {
            for (int col = 0; col < cs.length; ++col) {
                Chunk c = cs[col];
                double s = this._sigma[col];
                int len = c.len();
                for (int row = 0; row < len; ++row) {
                    c.set(row, c.atd(row) / s);
                }
            }
        }
    }

    /** Background task that computes the decomposition and fills in the model output. */
    class SVDDriver
    extends H2O.H2OCountedCompleter<SVDDriver> {
        SVDDriver() {
        }

        /**
         * Computes the k-th singular value and the unnormalized k-th column of U.
         *
         * <p>The data is multiplied by {@code ivv_sum * v_k}. The resulting column is stored in
         * {@code uvecs[k]}, and its L2 norm becomes {@code out._d[k]}.
         */
        private void computeSigmaU(SVDModel.SVDOutput out, DataInfo dinfo, double[][] ivv_sum, int k, Vec[] uvecs) {
            double[] ivv_vk = ArrayUtils.multArrVec(ivv_sum, out._v[k]);
            CalcSigmaU ctsk = (CalcSigmaU) new CalcSigmaU(this.self(), dinfo, ivv_vk).doAll(1, dinfo._adaptedFrame);
            out._d[k] = ctsk._sval;
            // Rows skipped here must match the rows skipped by the Gram task.
            assert (ctsk._nobs == out._nobs) : "Processed " + ctsk._nobs + " rows but expected " + out._nobs;
            Frame tmp = ctsk.outputFrame();
            uvecs[k] = tmp.vec(0);
            tmp.unlock(this.self());
        }

        /**
         * Runs the decomposition. Cleanup runs exactly once, in {@code finally}, on success,
         * failure and cancellation.
         */
        protected void compute2() {
            SVDModel model = null;
            DataInfo dinfo = null;
            Frame u = null;
            Vec[] uvecs = null;
            try {
                SVD.this.init(true);
                SVDModel.SVDParameters parms = (SVDModel.SVDParameters) SVD.this._parms;
                parms.read_lock_frames((Job) SVD.this);
                if (SVD.this.error_count() > 0) {
                    throw new IllegalArgumentException("Found validation errors: " + SVD.this.validationErrors());
                }
                model = new SVDModel(SVD.this.dest(), parms, new SVDModel.SVDOutput(SVD.this));
                model.delete_and_lock(this.self());
                SVDModel.SVDOutput out = (SVDModel.SVDOutput) model._output;

                dinfo = new DataInfo(Key.make(), SVD.this._train, SVD.this._valid, 0, parms._use_all_factor_levels, parms._transform, DataInfo.TransformType.NONE, true, false, false, false, false, false);
                DKV.put(dinfo._key, dinfo);

                // Copy the column transform so the model can be applied to new data.
                // A missing transform is recorded as identity: subtract 0, multiply by 1.
                out._normSub = dinfo._normSub == null ? new double[dinfo._nums] : dinfo._normSub;
                if (dinfo._normMul == null) {
                    out._normMul = new double[dinfo._nums];
                    Arrays.fill(out._normMul, 1.0);
                } else {
                    out._normMul = dinfo._normMul;
                }
                out._permutation = dinfo._permutation;
                out._nnums = dinfo._nums;
                out._ncats = dinfo._cats;
                out._catOffsets = dinfo._catOffsets;
                out._names_expanded = dinfo.coefNames();

                // Gram matrix of the full data.
                Gram.GramTask gtsk = (Gram.GramTask) new Gram.GramTask(this.self(), dinfo).doAll(dinfo._adaptedFrame);
                Gram gram = gtsk._gram;
                assert (gram.fullN() == SVD.this._ncolExp);
                out._nobs = gtsk._nobs;
                // While computing, row k holds the k-th right singular vector.
                out._v = new double[parms._nv][SVD.this._ncolExp];
                // Diagonal sum of the Gram matrix, rescaled by n/(n-1).
                out._total_variance = gram.diagSum() * (double) gtsk._nobs / (double) (gtsk._nobs - 1L);
                model.update(this.self());
                SVD.this.update(1L);

                // Running projector I - sum_j v_j v_j^T, starting as the identity.
                double[][] ivv_sum = new double[SVD.this._ncolExp][SVD.this._ncolExp];
                for (int i = 0; i < SVD.this._ncolExp; ++i) {
                    ivv_sum[i][i] = 1.0;
                }
                if (!parms._only_v) {
                    out._d = new double[parms._nv];
                    out._u_key = Key.make(parms._u_name);
                    uvecs = new Vec[parms._nv];
                }

                final int nv = parms._nv;
                Gram residual = gram;
                for (int k = 0; k < nv; ++k) {
                    if (k > 0 && !SVD.this.isRunning()) {
                        break; // cancelled; the first vector is always computed
                    }
                    out._v[k] = SVD.this.powerLoop(residual, parms._seed);
                    if (!parms._only_v) {
                        this.computeSigmaU(out, dinfo, ivv_sum, k, uvecs);
                    }
                    model.update(this.self());
                    SVD.this.update(1L);
                    // Deflate for the next vector. After the last vector this pass would be wasted.
                    if (k + 1 < nv) {
                        SVD.updateIVVSum(ivv_sum, out._v[k]);
                        GramUpdate guptsk = (GramUpdate) new GramUpdate(this.self(), dinfo, ivv_sum).doAll(dinfo._adaptedFrame);
                        residual = guptsk._gram;
                    }
                }
                // Transpose so that columns, not rows, hold the right singular vectors.
                out._v = ArrayUtils.transpose(out._v);

                if (!parms._only_v && parms._keep_u) {
                    u = new Frame(out._u_key, null, uvecs);
                    DKV.put(u._key, u);
                    new DivideU(out._d).doAll(u);
                }
                model.update(this.self());
                SVD.this.done();
            }
            catch (Throwable t) {
                Job thisJob = (Job) DKV.getGet(SVD.this._key);
                if (thisJob != null && thisJob._state == Job.JobState.CANCELLED) {
                    Log.info("Job cancelled by user.");
                } else {
                    t.printStackTrace();
                    SVD.this.failed(t);
                    throw t;
                }
            }
            finally {
                SVD.this.updateModelOutput();
                if (model != null) {
                    model.unlock(SVD.this._key);
                }
                if (dinfo != null) {
                    dinfo.remove();
                }
                // U columns not adopted by a returned U frame would otherwise stay in the DKV.
                // This covers _keep_u == false, and failure or cancellation before the frame was built.
                if (u == null && uvecs != null) {
                    for (Vec vec : uvecs) {
                        if (vec != null) {
                            vec.remove();
                        }
                    }
                }
                ((SVDModel.SVDParameters) SVD.this._parms).read_unlock_frames((Job) SVD.this);
            }
            this.tryComplete();
        }

        /** Key of the enclosing job, used to tag sub-tasks and locks. */
        Key self() {
            return SVD.this._key;
        }
    }
}

