/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.glrm;

import hex.ModelCategory;
import hex.genmodel.MojoModel;
import hex.genmodel.algos.glrm.GlrmInitialization;
import hex.genmodel.algos.glrm.GlrmLoss;
import hex.genmodel.algos.glrm.GlrmRegularizer;
import java.util.EnumSet;
import java.util.Random;

/**
 * MOJO scoring model for Generalized Low Rank Models (GLRM).
 *
 * <p>Scoring a row solves for its low-rank representation {@code x} (length {@code _ncolX}).
 * It starts from a seeded Gaussian point projected by the regularizer and runs proximal-gradient
 * iterations against the fixed archetype matrix {@code _archetypes}. Each iteration tries a
 * ladder of step sizes and keeps the best one. It stops when the relative objective improvement
 * drops below {@code _accuracyEps} or after {@code _iterNumber} iterations. The final {@code x}
 * is the prediction.
 *
 * <p>Column layout used throughout: after applying {@code _permutation}, a row holds the
 * {@code _ncats} categorical columns first and the numeric columns after them. In the archetype
 * matrix, categorical column {@code j} expands to {@code _numLevels[j]} consecutive columns, and
 * the numeric columns follow all categorical levels.
 *
 * <p>NOTE(review): {@link #score0(double[], double[])} increments {@code _rcnt} without any
 * synchronization, so concurrent scoring from several threads may reuse seeds.
 */
public class GlrmMojoModel extends MojoModel {
    /** Length of an input row ({@code score0} asserts {@code row.length == _ncolA}). */
    public int _ncolA;
    /** Rank of the model, i.e. the length of the predicted representation {@code x}. */
    public int _ncolX;
    /** Number of columns of {@code _archetypes} (expanded categorical levels plus numeric columns). */
    public int _ncolY;
    /** Number of rows of {@code _archetypes}; expected to equal {@code _ncolX}. */
    public int _nrowY;
    /** Archetype matrix indexed {@code [rank][expandedColumn]}, used while optimizing {@code x}. */
    public double[][] _archetypes;
    /** Archetypes used by the static imputation helpers; orientation is given by {@code _transposed}. */
    public double[][] _archetypes_raw;
    /** Number of levels of each categorical column, in permuted order. */
    public int[] _numLevels;
    /** First expanded column of each categorical column; the last entry is where numeric columns start. */
    public int[] _catOffsets;
    /** Maps a permuted column position to its index in the input row. */
    public int[] _permutation;
    /** Loss function of each column, in permuted order. */
    public GlrmLoss[] _losses;
    /** Regularizer applied to {@code x}. */
    public GlrmRegularizer _regx;
    /** Weight of the {@code x} regularization term in the objective. */
    public double _gammax;
    public GlrmInitialization _init;
    /** Number of categorical columns (they come first in permuted order). */
    public int _ncats;
    /** Number of numeric columns. */
    public int _nnums;
    /** Numeric normalization: value {@code v} of numeric column {@code js} becomes {@code (v - _normSub[js]) * _normMul[js]}. */
    public double[] _normSub;
    public double[] _normMul;
    /** Base seed; each {@link #score0(double[], double[])} call uses {@code _seed + _rcnt}, then increments {@code _rcnt}. */
    public long _seed;
    public boolean _transposed;
    public boolean _reverse_transform;
    /** Optimization stops once an iteration's relative objective improvement falls below this value. */
    public double _accuracyEps = 1.0E-10;
    /** Maximum number of optimization iterations per row. */
    public int _iterNumber = 200;
    // Not referenced by the scoring code in this class (UP_FACTOR is only checked in the static initializer).
    private static final double DOWN_FACTOR = 0.5;
    private static final double UP_FACTOR = Math.pow(2.0, 0.25);
    /** Number of rows scored through {@link #score0(double[], double[])}; used to derive per-row seeds. */
    public long _rcnt = 0L;
    /** Number of step sizes tried per iteration (leading entries of {@code _allAlphas} consulted). */
    public int _numAlphaFactors = 20;
    /** Step-size ladder; presumably filled via {@link #initializeAlphas(int)} by the reader — confirm. */
    public double[] _allAlphas;
    private static final EnumSet<ModelCategory> CATEGORIES;

    @Override
    public EnumSet<ModelCategory> getModelCategories() {
        return CATEGORIES;
    }

    public GlrmMojoModel(String[] columns, String[][] domains, String responseColumn) {
        super(columns, domains, responseColumn);
    }

    /** The prediction is the representation {@code x}, so its size is the rank {@code _ncolX}. */
    @Override
    public int getPredsSize(ModelCategory mc) {
        return this._ncolX;
    }

    /**
     * Builds a geometric step-size ladder {@code 0.5, 0.25, 0.125, ...}.
     *
     * @param numAlpha number of step sizes to generate
     * @return array whose entry {@code i} equals {@code 0.5^(i+1)}
     */
    public static double[] initializeAlphas(int numAlpha) {
        double[] alphas = new double[numAlpha];
        double alpha = 1.0;
        for (int index = 0; index < numAlpha; ++index) {
            alpha *= 0.5;
            alphas[index] = alpha;
        }
        return alphas;
    }

    /**
     * Computes the low-rank representation {@code x} of one row.
     *
     * @param row       raw input row of length {@code _ncolA}
     * @param preds     output array of length {@code _ncolX}; receives {@code x}
     * @param seedValue seed for the random starting point and any randomness in the regularizer
     * @return {@code preds}
     */
    public double[] score0(double[] row, double[] preds, long seedValue) {
        assert (row.length == this._ncolA);
        assert (preds.length == this._ncolX);
        assert (this._nrowY == this._ncolX);
        assert (this._archetypes.length == this._nrowY);
        assert (this._archetypes[0].length == this._ncolY);
        double[] a = this.getRowData(row);
        double[] x = new double[this._ncolX];
        double[] u = new double[this._ncolX];
        Random random = new Random(seedValue);
        for (int i = 0; i < this._ncolX; ++i) {
            x[i] = random.nextGaussian();
        }
        x = this._regx.project(x, random);
        double oldObj = this.objective(x, a);
        int iters = 0;
        // Once the objective is exactly 0, applyBestAlpha returns at once without touching x or
        // random. Further iterations would be no-ops, and 1 - 0/0 is NaN, which would never
        // satisfy the convergence test below, so stop right away.
        while (oldObj != 0.0 && iters++ < this._iterNumber) {
            double[] grad = this.gradientL(x, a);
            double obj = this.applyBestAlpha(u, x, grad, a, oldObj, random);
            double objImprovement = 1.0 - obj / oldObj;
            oldObj = obj;
            if (objImprovement < 0.0 || objImprovement < this._accuracyEps) {
                break;
            }
        }
        System.arraycopy(x, 0, preds, 0, this._ncolX);
        return preds;
    }

    /**
     * Reorders {@code row} into the model's column order (categoricals first) using
     * {@code _permutation}. A categorical value at or above its column's level count is replaced
     * by NaN, so it is treated like a missing value.
     *
     * @param row raw input row
     * @return permuted row of length {@code _ncolA}
     */
    public double[] getRowData(double[] row) {
        double[] a = new double[this._ncolA];
        for (int i = 0; i < this._ncats; ++i) {
            double level = row[this._permutation[i]];
            a[i] = level >= (double) this._numLevels[i] ? Double.NaN : level;
        }
        for (int i = this._ncats; i < this._ncolA; ++i) {
            a[i] = row[this._permutation[i]];
        }
        return a;
    }

    /**
     * Tries one proximal-gradient step for each of the first {@code _numAlphaFactors} step sizes
     * and keeps the candidate with the lowest objective.
     *
     * <p>{@code x} is overwritten with the best candidate only when it beats {@code oldObj}.
     * Otherwise {@code x} is left unchanged.
     *
     * @param u      scratch buffer of length {@code _ncolX} (overwritten)
     * @param x      current representation; updated in place on improvement
     * @param grad   gradient of the loss at {@code x}
     * @param a      permuted row data
     * @param oldObj objective value at {@code x}
     * @param random randomness passed to the regularizer's proximal operator
     * @return the lowest objective seen among the candidates ({@code 0.0} immediately if
     *         {@code oldObj == 0.0}; {@code Double.MAX_VALUE} if no candidate compared lower)
     */
    public double applyBestAlpha(double[] u, double[] x, double[] grad, double[] a, double oldObj, Random random) {
        if (oldObj == 0.0) {
            return 0.0;
        }
        double[] bestX = new double[x.length];
        double lowestObj = Double.MAX_VALUE;
        // Large objectives shrink every step size proportionally (presumably to avoid overshooting).
        double alphaScale = oldObj > 10.0 ? 1.0 / oldObj : 1.0;
        for (int index = 0; index < this._numAlphaFactors; ++index) {
            double alpha = this._allAlphas[index] * alphaScale;
            for (int k = 0; k < this._ncolX; ++k) {
                u[k] = x[k] - alpha * grad[k];
            }
            double[] xnew = this._regx.rproxgrad(u, alpha * this._gammax, random);
            double newObj = this.objective(xnew, a);
            if (newObj < lowestObj) {
                System.arraycopy(xnew, 0, bestX, 0, xnew.length);
                lowestObj = newObj;
            }
            if (newObj == 0.0) {
                break; // an exact zero objective ends the search early
            }
        }
        if (lowestObj < oldObj) {
            System.arraycopy(bestX, 0, x, 0, x.length);
        }
        return lowestObj;
    }

    @Override
    public double[] score0(double[] row, double[] preds) {
        return this.score0(row, preds, this._seed + this._rcnt++);
    }

    /**
     * Reconstructs a full row from a representation {@code xfactor}.
     *
     * <p>Categorical columns are imputed with each loss's {@code mimpute}, and numeric columns
     * with {@code impute}. When {@code reverse_transform} is set, numeric results are mapped back
     * through {@code v / normMul + normSub}. Results are written to their original (unpermuted)
     * positions in {@code preds}.
     *
     * @return {@code preds}
     */
    public static double[] impute_data(double[] xfactor, double[] preds, int nnums, int ncats, int[] permutation, boolean reverse_transform, double[] normMul, double[] normSub, GlrmLoss[] losses, boolean transposed, double[][] archetypes_raw, int[] catOffsets, int[] numLevels) {
        assert (preds.length == nnums + ncats);
        for (int d = 0; d < ncats; ++d) {
            double[] xyblock = GlrmMojoModel.lmulCatBlock(xfactor, d, numLevels, transposed, archetypes_raw, catOffsets);
            preds[permutation[d]] = losses[d].mimpute(xyblock);
        }
        for (int d = ncats; d < preds.length; ++d) {
            int ds = d - ncats;
            double xy = GlrmMojoModel.lmulNumCol(xfactor, ds, transposed, archetypes_raw, catOffsets);
            double imputed = losses[d].impute(xy);
            if (reverse_transform) {
                imputed = imputed / normMul[ds] + normSub[ds];
            }
            preds[permutation[d]] = imputed;
        }
        return preds;
    }

    /** Returns the expanded archetype column of numeric column {@code j} (numerics follow all categorical levels). */
    public static int getNumCidx(int j, int[] catOffsets) {
        return catOffsets[catOffsets.length - 1] + j;
    }

    /** Dot product of {@code x} with the archetype column of numeric column {@code j}. */
    public static double lmulNumCol(double[] x, int j, boolean transposed, double[][] archetypes_raw, int[] catOffsets) {
        int rank = GlrmMojoModel.rank(transposed, archetypes_raw);
        assert (x != null && x.length == rank) : "x must be of length " + rank;
        int cidx = GlrmMojoModel.getNumCidx(j, catOffsets);
        double prod = 0.0;
        for (int k = 0; k < rank; ++k) {
            prod += x[k] * (transposed ? archetypes_raw[cidx][k] : archetypes_raw[k][cidx]);
        }
        return prod;
    }

    /** Returns the expanded archetype column of level {@code level} of categorical column {@code j}. */
    public static int getCatCidx(int j, int level, int[] numLevels, int[] catOffsets) {
        int catColJLevel = numLevels[j];
        assert (catColJLevel != 0) : "Number of levels in categorical column cannot be zero";
        assert (level >= 0 && level < catColJLevel) : "Got level = " + level + " when expected integer in [0," + catColJLevel + ")";
        return catOffsets[j] + level;
    }

    /** Products of {@code x} with every level column of categorical column {@code j}; one entry per level. */
    public static double[] lmulCatBlock(double[] x, int j, int[] numLevels, boolean transposed, double[][] archetypes_raw, int[] catOffsets) {
        int catColJLevel = numLevels[j];
        assert (catColJLevel != 0) : "Number of levels in categorical column cannot be zero";
        int rank = GlrmMojoModel.rank(transposed, archetypes_raw);
        assert (x != null && x.length == rank) : "x must be of length " + rank;
        double[] prod = new double[catColJLevel];
        for (int level = 0; level < catColJLevel; ++level) {
            int cidx = GlrmMojoModel.getCatCidx(j, level, numLevels, catOffsets);
            for (int k = 0; k < rank; ++k) {
                prod[level] += x[k] * (transposed ? archetypes_raw[cidx][k] : archetypes_raw[k][cidx]);
            }
        }
        return prod;
    }

    /** Rank of the raw archetypes: inner dimension when {@code transposed}, outer dimension otherwise. */
    public static int rank(boolean transposed, double[][] archetypes_raw) {
        return transposed ? archetypes_raw[0].length : archetypes_raw.length;
    }

    /** Returns {@code x} times archetype columns {@code [offset, offset + nLevels)}; one entry per level. */
    private double[] xyCategorical(double[] x, int offset, int nLevels) {
        double[] xy = new double[nLevels];
        for (int level = 0; level < nLevels; ++level) {
            for (int k = 0; k < this._ncolX; ++k) {
                xy[level] += x[k] * this._archetypes[k][level + offset];
            }
        }
        return xy;
    }

    /** Returns the dot product of {@code x} with archetype column {@code col}. */
    private double xyNumeric(double[] x, int col) {
        double xy = 0.0;
        for (int k = 0; k < this._ncolX; ++k) {
            xy += x[k] * this._archetypes[k][col];
        }
        return xy;
    }

    /** Applies the numeric normalization of numeric column {@code js} to {@code value}. */
    private double normalizeNumeric(double value, int js) {
        return (value - this._normSub[js]) * this._normMul[js];
    }

    /**
     * Gradient of the summed per-column losses with respect to {@code x}. Missing (NaN) entries
     * of {@code a} contribute nothing. The regularizer is handled separately by the proximal step.
     */
    private double[] gradientL(double[] x, double[] a) {
        double[] grad = new double[this._ncolX];
        int catOffset = 0;
        for (int j = 0; j < this._ncats; ++j) {
            int nLevels = this._numLevels[j];
            int offset = catOffset;
            // Advance before any skip: column j's levels occupy archetype columns even when a[j]
            // is missing. Skipping the advance would misalign every later column.
            catOffset += nLevels;
            if (Double.isNaN(a[j])) continue;
            double[] gradL = this._losses[j].mlgrad(this.xyCategorical(x, offset, nLevels), (int) a[j]);
            for (int k = 0; k < this._ncolX; ++k) {
                for (int c = 0; c < nLevels; ++c) {
                    grad[k] += gradL[c] * this._archetypes[k][c + offset];
                }
            }
        }
        // catOffset now equals the total number of categorical levels: numerics start there.
        for (int j = this._ncats; j < this._ncolA; ++j) {
            if (Double.isNaN(a[j])) continue;
            int js = j - this._ncats;
            int col = js + catOffset;
            double gradL = this._losses[j].lgrad(this.xyNumeric(x, col), this.normalizeNumeric(a[j], js));
            for (int k = 0; k < this._ncolX; ++k) {
                grad[k] += gradL * this._archetypes[k][col];
            }
        }
        return grad;
    }

    /**
     * GLRM objective for one row: the summed per-column losses over non-missing entries of
     * {@code a}, plus {@code _gammax} times the regularization of {@code x}.
     */
    private double objective(double[] x, double[] a) {
        double res = 0.0;
        int catOffset = 0;
        for (int j = 0; j < this._ncats; ++j) {
            int nLevels = this._numLevels[j];
            int offset = catOffset;
            // Advance before any skip so a missing categorical does not shift later columns.
            catOffset += nLevels;
            if (Double.isNaN(a[j])) continue;
            res += this._losses[j].mloss(this.xyCategorical(x, offset, nLevels), (int) a[j]);
        }
        for (int j = this._ncats; j < this._ncolA; ++j) {
            if (Double.isNaN(a[j])) continue;
            int js = j - this._ncats;
            res += this._losses[j].loss(this.xyNumeric(x, js + catOffset), this.normalizeNumeric(a[j], js));
        }
        return res + this._gammax * this._regx.regularize(x);
    }

    static {
        assert (UP_FACTOR > 1.0);
        CATEGORIES = EnumSet.of(ModelCategory.AutoEncoder, ModelCategory.DimReduction);
    }
}

