/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.glrm;

import hex.ModelCategory;
import hex.genmodel.MojoModel;
import hex.genmodel.algos.glrm.GlrmInitialization;
import hex.genmodel.algos.glrm.GlrmLoss;
import hex.genmodel.algos.glrm.GlrmRegularizer;
import java.util.EnumSet;
import java.util.Random;

/**
 * Scoring-side GLRM model.
 *
 * <p>{@link #score0(double[], double[])} takes one input row {@code a} and searches for a
 * vector {@code x} of length {@link #_ncolX} that minimizes
 * {@code sum_j loss_j((x * archetypes)_j, a_j) + _gammax * _regx.regularize(x)},
 * skipping columns whose value in the row is {@code NaN}. The search is an iterative
 * descent: a gradient step on the loss followed by the regularizer's {@code rproxgrad},
 * with a step size that grows after an accepted step and shrinks after a rejected one.
 *
 * <p>The starting point is drawn from an unseeded {@link Random}, so repeated calls on the
 * same row are not guaranteed to return bit-identical results.
 */
public class GlrmMojoModel extends MojoModel {
    /** Length of an input row; asserted against {@code row.length} in {@code score0}. */
    public int _ncolA;
    /** Length of the output vector {@code x} (and of {@code preds}). */
    public int _ncolX;
    /** Number of columns in each row of {@link #_archetypes}. */
    public int _ncolY;
    /** Number of rows in {@link #_archetypes}; asserted equal to {@link #_ncolX}. */
    public int _nrowY;
    /** Archetype matrix, indexed as {@code [row in 0.._nrowY)[column in 0.._ncolY)}. */
    public double[][] _archetypes;
    /** Number of levels of each categorical column, indexed by categorical column. */
    public int[] _numLevels;
    /** Column permutation: permuted column {@code i} is read from {@code row[_permutation[i]]}. */
    public int[] _permutation;
    /** Loss function of each permuted column (categoricals first, then numerics). */
    public GlrmLoss[] _losses;
    /** Regularizer applied to {@code x}. */
    public GlrmRegularizer _regx;
    /** Weight of the regularization term in the objective. */
    public double _gammax;
    /** Initialization setting stored with the model; not read by this class. */
    public GlrmInitialization _init;
    /** Number of categorical columns; they occupy the first positions of the permuted row. */
    public int _ncats;
    /** Not read by this class; presumably the number of numeric columns (TODO confirm). */
    public int _nnums;
    /** Per numeric column: value subtracted before scaling. */
    public double[] _normSub;
    /** Per numeric column: factor applied after subtracting {@link #_normSub}. */
    public double[] _normMul;

    /** Step size every call to {@code score0} starts from. */
    private static final double INITIAL_ALPHA = 1.0;
    /** Factor applied to the step size after a rejected step. */
    private static final double DOWN_FACTOR = 0.5;
    /** Factor applied to the step size after an accepted step (four of them undo one halving). */
    private static final double UP_FACTOR = Math.pow(1.0 / DOWN_FACTOR, 0.25);
    /** Maximum number of outer descent iterations per row. */
    private static final int MAX_ITERATIONS = 100;
    /** Relative objective improvement below which the descent is considered converged. */
    private static final double MIN_REL_IMPROVEMENT = 1.0E-6;
    private static final EnumSet<ModelCategory> CATEGORIES =
            EnumSet.of(ModelCategory.AutoEncoder, ModelCategory.DimReduction);

    static {
        assert (UP_FACTOR > 1.0);
    }

    @Override
    public EnumSet<ModelCategory> getModelCategories() {
        return CATEGORIES;
    }

    protected GlrmMojoModel(String[] columns, String[][] domains) {
        super(columns, domains);
    }

    /** Returns the number of prediction slots, which is {@link #_ncolX} for every category. */
    @Override
    public int getPredsSize(ModelCategory mc) {
        return _ncolX;
    }

    /**
     * Computes the {@code x} vector for one input row and stores it in {@code preds}.
     *
     * @param row   input row of length {@link #_ncolA}, in the caller's column order
     * @param preds output array of length {@link #_ncolX}; overwritten with {@code x}
     * @return the same {@code preds} array
     */
    @Override
    public double[] score0(double[] row, double[] preds) {
        assert (row.length == _ncolA);
        assert (preds.length == _ncolX);
        assert (_nrowY == _ncolX);
        assert (_archetypes.length == _nrowY);
        assert (_archetypes[0].length == _ncolY);

        // Reorder the row so that categorical columns come first.
        double[] a = new double[_ncolA];
        for (int i = 0; i < _ncolA; ++i) {
            a[i] = row[_permutation[i]];
        }

        // Random starting point, projected by the regularizer.
        double[] x = new double[_ncolX];
        Random random = new Random();
        for (int i = 0; i < _ncolX; ++i) {
            x[i] = random.nextGaussian();
        }
        x = _regx.project(x, random);

        double obj = objective(x, a);
        // The step size is local to this call: a shared static value would be raced on by
        // concurrent scorers and, once driven to zero, would stay zero for every later row.
        double alpha = INITIAL_ALPHA;
        boolean done = false;
        int iters = 0;
        while (!done && iters++ < MAX_ITERATIONS) {
            double[] grad = gradientL(x, a);
            double[] u = new double[_ncolX];
            // Shrink the step until the objective stops getting worse.
            while (true) {
                for (int k = 0; k < _ncolX; ++k) {
                    u[k] = x[k] - alpha * grad[k];
                }
                double[] xnew = _regx.rproxgrad(u, alpha * _gammax, random);
                double newobj = objective(xnew, a);
                if (newobj == 0.0) {
                    // A zero objective cannot serve as the denominator of the next relative
                    // improvement, so stop here; keep the candidate unless it is a step back.
                    if (obj >= 0.0) {
                        x = xnew;
                    }
                    done = true;
                    break;
                }
                double improvement = 1.0 - newobj / obj;
                if (improvement >= 0.0) {
                    if (improvement < MIN_REL_IMPROVEMENT) {
                        done = true;
                    }
                    obj = newobj;
                    x = xnew;
                    alpha *= UP_FACTOR;
                    break;
                }
                alpha *= DOWN_FACTOR;
                if (alpha == 0.0) {
                    // The step size underflowed (this is also where a NaN objective ends up):
                    // no further move is possible, so give up instead of looping forever.
                    done = true;
                    break;
                }
            }
        }
        System.arraycopy(x, 0, preds, 0, _ncolX);
        return preds;
    }

    /**
     * Gradient of the loss part of the objective with respect to {@code x}.
     * Columns whose value in {@code a} is {@code NaN} contribute nothing.
     */
    private double[] gradientL(double[] x, double[] a) {
        double[] grad = new double[_ncolX];

        // Categorical columns: each one owns a run of _numLevels[j] archetype columns.
        int catOffset = 0;
        for (int j = 0; j < _ncats; ++j) {
            int nLevels = _numLevels[j];
            if (!Double.isNaN(a[j])) {
                double[] xy = categoricalXY(x, nLevels, catOffset);
                double[] gradL = _losses[j].mlgrad(xy, (int) a[j]);
                for (int k = 0; k < _ncolX; ++k) {
                    for (int c = 0; c < nLevels; ++c) {
                        grad[k] += gradL[c] * _archetypes[k][c + catOffset];
                    }
                }
            }
            // Advance even for a missing value, otherwise every later column is misaligned.
            catOffset += nLevels;
        }

        // Numeric columns follow all categorical levels in the archetype matrix.
        for (int j = _ncats; j < _ncolA; ++j) {
            if (Double.isNaN(a[j])) {
                continue;
            }
            int js = j - _ncats;
            int col = js + catOffset;
            double xy = numericXY(x, col);
            double gradL = _losses[j].lgrad(xy, (a[j] - _normSub[js]) * _normMul[js]);
            for (int k = 0; k < _ncolX; ++k) {
                grad[k] += gradL * _archetypes[k][col];
            }
        }
        return grad;
    }

    /**
     * Objective value at {@code x}: the sum of the per-column losses over the non-missing
     * entries of {@code a}, plus {@code _gammax * _regx.regularize(x)}.
     */
    private double objective(double[] x, double[] a) {
        double res = 0.0;

        int catOffset = 0;
        for (int j = 0; j < _ncats; ++j) {
            int nLevels = _numLevels[j];
            if (!Double.isNaN(a[j])) {
                double[] xy = categoricalXY(x, nLevels, catOffset);
                res += _losses[j].mloss(xy, (int) a[j]);
            }
            // Advance even for a missing value, otherwise every later column is misaligned.
            catOffset += nLevels;
        }

        for (int j = _ncats; j < _ncolA; ++j) {
            if (Double.isNaN(a[j])) {
                continue;
            }
            int js = j - _ncats;
            double xy = numericXY(x, js + catOffset);
            res += _losses[j].loss(xy, (a[j] - _normSub[js]) * _normMul[js]);
        }
        return res + _gammax * _regx.regularize(x);
    }

    /** Products of {@code x} with the {@code nLevels} archetype columns starting at {@code offset}. */
    private double[] categoricalXY(double[] x, int nLevels, int offset) {
        double[] xy = new double[nLevels];
        for (int level = 0; level < nLevels; ++level) {
            for (int k = 0; k < _ncolX; ++k) {
                xy[level] += x[k] * _archetypes[k][level + offset];
            }
        }
        return xy;
    }

    /** Product of {@code x} with archetype column {@code col}. */
    private double numericXY(double[] x, int col) {
        double xy = 0.0;
        for (int k = 0; k < _ncolX; ++k) {
            xy += x[k] * _archetypes[k][col];
        }
        return xy;
    }
}

