/*
 * Decompiled with CFR 0.152.
 */
package hex.glrm;

import hex.DataInfo;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.genmodel.algos.glrm.GlrmInitialization;
import hex.genmodel.algos.glrm.GlrmLoss;
import hex.genmodel.algos.glrm.GlrmRegularizer;
import hex.glrm.GLRM;
import hex.glrm.GlrmMojoWriter;
import hex.glrm.ModelMetricsGLRM;
import hex.svd.SVDModel;
import java.util.ArrayList;
import water.AutoBuffer;
import water.DKV;
import water.Futures;
import water.H2O;
import water.Job;
import water.Key;
import water.Keyed;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.TwoDimTable;

public class GLRMModel
extends Model<GLRMModel, GLRMParameters, GLRMOutput>
implements Model.GLRMArchetypes {
    public GLRMModel(Key<GLRMModel> selfKey, GLRMParameters parms, GLRMOutput output) {
        super(selfKey, (Model.Parameters)parms, (Model.Output)output);
    }

    protected Futures remove_impl(Futures fs) {
        if (((GLRMOutput)this._output)._init_key != null) {
            ((GLRMOutput)this._output)._init_key.remove(fs);
        }
        if (((GLRMOutput)this._output)._representation_key != null) {
            ((GLRMOutput)this._output)._representation_key.remove(fs);
        }
        return super.remove_impl(fs);
    }

    protected AutoBuffer writeAll_impl(AutoBuffer ab) {
        ab.putKey(((GLRMOutput)this._output)._init_key);
        ab.putKey(((GLRMOutput)this._output)._representation_key);
        return super.writeAll_impl(ab);
    }

    protected Keyed readAll_impl(AutoBuffer ab, Futures fs) {
        ab.getKey(((GLRMOutput)this._output)._init_key, fs);
        ab.getKey(((GLRMOutput)this._output)._representation_key, fs);
        return super.readAll_impl(ab, fs);
    }

    public GlrmMojoWriter getMojo() {
        return new GlrmMojoWriter(this);
    }

    private Frame reconstruct(Frame orig, Frame adaptedFr, Key<Frame> destination_key, boolean save_imputed, boolean reverse_transform) {
        int ncols = ((GLRMOutput)this._output)._names.length;
        assert (ncols == adaptedFr.numCols());
        String prefix = "reconstr_";
        Frame fullFrm = new Frame(adaptedFr);
        Frame loadingFrm = (Frame)DKV.get(((GLRMOutput)this._output)._representation_key).get();
        fullFrm.add(loadingFrm);
        String[][] adaptedDomme = adaptedFr.domains();
        Vec anyVec = fullFrm.anyVec();
        assert (anyVec != null);
        for (int i = 0; i < ncols; ++i) {
            Vec v = anyVec.makeZero();
            v.setDomain(adaptedDomme[i]);
            fullFrm.add(prefix + ((GLRMOutput)this._output)._names[i], v);
        }
        GLRMScore gs = (GLRMScore)new GLRMScore(ncols, ((GLRMParameters)this._parms)._k, save_imputed, reverse_transform).doAll(fullFrm);
        int x = ncols + ((GLRMParameters)this._parms)._k;
        int y = fullFrm.numCols();
        Frame f = fullFrm.extractFrame(x, y);
        f = new Frame(destination_key == null ? Key.make() : destination_key, f.names(), f.vecs());
        DKV.put((Keyed)f);
        gs._mb.makeModelMetrics(this, orig, null, null);
        return f;
    }

    protected Frame predictScoreImpl(Frame orig, Frame adaptedFr, String destination_key, Job j, boolean computeMetrics) {
        return this.reconstruct(orig, adaptedFr, (Key<Frame>)Key.make((String)destination_key), true, ((GLRMParameters)this._parms)._impute_original);
    }

    public Frame scoreReconstruction(Frame frame, Key<Frame> destination_key, boolean reverse_transform) {
        Frame adaptedFr = new Frame(frame);
        this.adaptTestForTrain(adaptedFr, true, false);
        return this.reconstruct(frame, adaptedFr, destination_key, true, reverse_transform);
    }

    public Frame scoreArchetypes(Frame frame, Key<Frame> destination_key, boolean reverse_transform) {
        int k;
        int d;
        int ncols = ((GLRMOutput)this._output)._names.length;
        Frame adaptedFr = new Frame(frame);
        this.adaptTestForTrain(adaptedFr, true, false);
        assert (ncols == adaptedFr.numCols());
        String[][] adaptedDomme = adaptedFr.domains();
        double[][] proj = new double[((GLRMParameters)this._parms)._k][((GLRMOutput)this._output)._nnums + ((GLRMOutput)this._output)._ncats];
        for (d = 0; d < ((GLRMOutput)this._output)._ncats; ++d) {
            double[][] block = ((GLRMOutput)this._output)._archetypes_raw.getCatBlock(d);
            for (k = 0; k < ((GLRMParameters)this._parms)._k; ++k) {
                proj[k][((GLRMOutput)this._output)._permutation[d]] = ((GLRMOutput)this._output)._lossFunc[d].mimpute(block[k]);
            }
        }
        for (d = ((GLRMOutput)this._output)._ncats; d < ((GLRMOutput)this._output)._ncats + ((GLRMOutput)this._output)._nnums; ++d) {
            int ds = d - ((GLRMOutput)this._output)._ncats;
            for (k = 0; k < ((GLRMParameters)this._parms)._k; ++k) {
                double num = ((GLRMOutput)this._output)._archetypes_raw.getNum(ds, k);
                proj[k][((GLRMOutput)this._output)._permutation[d]] = ((GLRMOutput)this._output)._lossFunc[d].impute(num);
                if (!reverse_transform) continue;
                proj[k][((GLRMOutput)this._output)._permutation[d]] = proj[k][((GLRMOutput)this._output)._permutation[d]] / ((GLRMOutput)this._output)._normMul[ds] + ((GLRMOutput)this._output)._normSub[ds];
            }
        }
        Frame f = ArrayUtils.frame(destination_key == null ? Key.make() : destination_key, (String[])adaptedFr.names(), (double[][])proj);
        for (int i = 0; i < ncols; ++i) {
            f.vec(i).setDomain(adaptedDomme[i]);
        }
        return f;
    }

    protected double[] score0(double[] data, double[] preds) {
        throw H2O.unimpl();
    }

    public ModelMetricsGLRM scoreMetricsOnly(Frame frame) {
        if (frame == null) {
            return null;
        }
        int ncols = ((GLRMOutput)this._output)._names.length;
        Frame adaptedFr = new Frame(frame);
        this.adaptTestForTrain(adaptedFr, true, false);
        assert (ncols == adaptedFr.numCols());
        Frame fullFrm = new Frame(adaptedFr);
        Frame loadingFrm = (Frame)DKV.get(((GLRMOutput)this._output)._representation_key).get();
        fullFrm.add(loadingFrm);
        GLRMScore gs = (GLRMScore)new GLRMScore(ncols, ((GLRMParameters)this._parms)._k, false).doAll(fullFrm);
        ModelMetrics mm = gs._mb.makeModelMetrics(this, adaptedFr, null, null);
        return (ModelMetricsGLRM)mm;
    }

    public ModelMetricsGLRM.GlrmModelMetricsBuilder makeMetricBuilder(String[] domain) {
        return new ModelMetricsGLRM.GlrmModelMetricsBuilder(((GLRMParameters)this._parms)._k, ((GLRMOutput)this._output)._permutation, ((GLRMParameters)this._parms)._impute_original);
    }

    private class GLRMScore
    extends MRTask<GLRMScore> {
        final int _ncolA;
        final int _ncolX;
        final boolean _save_imputed;
        final boolean _reverse_transform;
        ModelMetricsGLRM.GlrmModelMetricsBuilder _mb;

        GLRMScore(int ncolA, int ncolX, boolean save_imputed) {
            this(ncolA, ncolX, save_imputed, ((GLRMParameters)gLRMModel._parms)._impute_original);
        }

        GLRMScore(int ncolA, int ncolX, boolean save_imputed, boolean reverse_transform) {
            this._ncolA = ncolA;
            this._ncolX = ncolX;
            this._save_imputed = save_imputed;
            this._reverse_transform = reverse_transform;
        }

        public void map(Chunk[] chks) {
            float[] atmp = new float[this._ncolA];
            double[] xtmp = new double[this._ncolX];
            double[] preds = new double[this._ncolA];
            this._mb = GLRMModel.this.makeMetricBuilder(null);
            if (this._save_imputed) {
                for (int row = 0; row < chks[0]._len; ++row) {
                    double[] p = this.impute_data(chks, row, xtmp, preds);
                    this.compute_metrics(chks, row, atmp, p);
                    for (int c = 0; c < preds.length; ++c) {
                        chks[this._ncolA + this._ncolX + c].set(row, p[c]);
                    }
                }
            } else {
                for (int row = 0; row < chks[0]._len; ++row) {
                    double[] p = this.impute_data(chks, row, xtmp, preds);
                    this.compute_metrics(chks, row, atmp, p);
                }
            }
        }

        public void reduce(GLRMScore other) {
            if (this._mb != null) {
                this._mb.reduce(other._mb);
            }
        }

        protected void postGlobal() {
            if (this._mb != null) {
                this._mb.postGlobal();
            }
        }

        private float[] compute_metrics(Chunk[] chks, int row_in_chunk, float[] tmp, double[] preds) {
            for (int i = 0; i < tmp.length; ++i) {
                tmp[i] = (float)chks[i].atd(row_in_chunk);
            }
            this._mb.perRow(preds, tmp, GLRMModel.this);
            return tmp;
        }

        private double[] impute_data(Chunk[] chks, int row_in_chunk, double[] tmp, double[] preds) {
            for (int i = 0; i < tmp.length; ++i) {
                tmp[i] = chks[this._ncolA + i].atd(row_in_chunk);
            }
            this.impute_data(tmp, preds);
            return preds;
        }

        private double[] impute_data(double[] tmp, double[] preds) {
            int d;
            assert (preds.length == ((GLRMOutput)GLRMModel.this._output)._nnums + ((GLRMOutput)GLRMModel.this._output)._ncats);
            for (d = 0; d < ((GLRMOutput)GLRMModel.this._output)._ncats; ++d) {
                double[] xyblock = ((GLRMOutput)GLRMModel.this._output)._archetypes_raw.lmulCatBlock(tmp, d);
                preds[((GLRMOutput)GLRMModel.this._output)._permutation[d]] = ((GLRMOutput)GLRMModel.this._output)._lossFunc[d].mimpute(xyblock);
            }
            for (d = ((GLRMOutput)GLRMModel.this._output)._ncats; d < preds.length; ++d) {
                int ds = d - ((GLRMOutput)GLRMModel.this._output)._ncats;
                double xy = ((GLRMOutput)GLRMModel.this._output)._archetypes_raw.lmulNumCol(tmp, ds);
                preds[((GLRMOutput)GLRMModel.this._output)._permutation[d]] = ((GLRMOutput)GLRMModel.this._output)._lossFunc[d].impute(xy);
                if (!this._reverse_transform) continue;
                preds[((GLRMOutput)GLRMModel.this._output)._permutation[d]] = preds[((GLRMOutput)GLRMModel.this._output)._permutation[d]] / ((GLRMOutput)GLRMModel.this._output)._normMul[ds] + ((GLRMOutput)GLRMModel.this._output)._normSub[ds];
            }
            return preds;
        }
    }

    public static class GLRMOutput
    extends Model.Output {
        public int _iterations;
        public int _updates;
        public double _objective;
        public double _step_size;
        public double _avg_change_obj;
        public ArrayList<Double> _history_objective = new ArrayList();
        public TwoDimTable _archetypes;
        public GLRM.Archetypes _archetypes_raw;
        public ArrayList<Double> _history_step_size = new ArrayList();
        public double[][] _eigenvectors_raw;
        public TwoDimTable _eigenvectors;
        public double[] _singular_vals;
        public String _representation_name;
        public Key<Frame> _representation_key;
        public Key<? extends Model> _init_key;
        public int _ncats;
        public int _nnums;
        public long _nobs;
        public int[] _catOffsets;
        public double[] _normSub;
        public double[] _normMul;
        public int[] _permutation;
        public String[] _names_expanded;
        public GlrmLoss[] _lossFunc;
        public ArrayList<Long> _training_time_ms = new ArrayList();

        public GLRMOutput(GLRM b) {
            super((ModelBuilder)b);
        }

        public int nfeatures() {
            return this._names.length;
        }

        public ModelCategory getModelCategory() {
            return ModelCategory.DimReduction;
        }
    }

    public static class GLRMParameters
    extends Model.Parameters {
        public DataInfo.TransformType _transform = DataInfo.TransformType.NONE;
        public int _k = 1;
        public GlrmInitialization _init = GlrmInitialization.PlusPlus;
        public SVDModel.SVDParameters.Method _svd_method = SVDModel.SVDParameters.Method.Randomized;
        public Key<Frame> _user_y;
        public Key<Frame> _user_x;
        public boolean _expand_user_y = true;
        public GlrmLoss _loss = GlrmLoss.Quadratic;
        public GlrmLoss _multi_loss = GlrmLoss.Categorical;
        public int _period = 1;
        public GlrmLoss[] _loss_by_col;
        public int[] _loss_by_col_idx;
        public GlrmRegularizer _regularization_x = GlrmRegularizer.None;
        public GlrmRegularizer _regularization_y = GlrmRegularizer.None;
        public double _gamma_x = 0.0;
        public double _gamma_y = 0.0;
        public int _max_iterations = 1000;
        public int _max_updates = 2 * this._max_iterations;
        public double _init_step_size = 1.0;
        public double _min_step_size = 1.0E-4;
        public String _representation_name;
        public boolean _recover_svd = false;
        public boolean _impute_original = false;
        public boolean _verbose = true;

        public String algoName() {
            return "GLRM";
        }

        public String fullName() {
            return "Generalized Low Rank Modeling";
        }

        public String javaName() {
            return GLRMModel.class.getName();
        }

        public long progressUnits() {
            return 2 + this._max_iterations;
        }
    }
}

