/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.glrm;

import hex.ModelCategory;
import hex.genmodel.MojoModel;
import hex.genmodel.algos.glrm.GlrmInitialization;
import hex.genmodel.algos.glrm.GlrmLoss;
import hex.genmodel.algos.glrm.GlrmRegularizer;
import java.util.EnumSet;
import java.util.Random;

/**
 * MOJO scoring model for Generalized Low Rank Models (GLRM).
 *
 * <p>Scoring a row solves for its low-dimensional representation {@code x} (length
 * {@code _ncolX}) against the fixed archetype matrix {@code _archetypes}. Starting from a
 * random Gaussian vector projected by the regularizer, each iteration takes a gradient step on
 * the per-column losses followed by the regularizer's {@code rproxgrad}, trying
 * {@code _numAlphaFactors} step sizes and keeping the best. Iteration stops when the relative
 * objective improvement drops below {@code _accuracyEps}, or after {@code _iterNumber}
 * iterations.
 *
 * <p>Internally, columns are ordered categoricals first ({@code _ncats} of them), then numerics;
 * {@code _permutation} maps that internal order to positions in the caller's row.
 */
public class GlrmMojoModel extends MojoModel {
    /** Length of an input row (asserted in {@link #score0(double[], double[], long)}). */
    public int _ncolA;
    /** Rank of the model: length of the representation {@code x} and of the predictions. */
    public int _ncolX;
    /** Number of columns of {@code _archetypes} (categorical levels expanded, then numerics). */
    public int _ncolY;
    /** Number of rows of {@code _archetypes}; expected to equal {@code _ncolX}. */
    public int _nrowY;
    /** Archetype matrix laid out as {@code [_nrowY][_ncolY]}; used for scoring. */
    public double[][] _archetypes;
    /** Raw archetypes used by the static imputation helpers; layout depends on {@code transposed}. */
    public double[][] _archetypes_raw;
    /** Number of levels of each categorical column (internal order). */
    public int[] _numLevels;
    /** Start column of each categorical block; the last entry is where numeric columns begin. */
    public int[] _catOffsets;
    /** Maps internal column index (categoricals first, then numerics) to the input-row index. */
    public int[] _permutation;
    /** Per-column loss functions, in internal column order. */
    public GlrmLoss[] _losses;
    /** Regularizer applied to {@code x}. */
    public GlrmRegularizer _regx;
    /** Weight of the {@code x} regularizer in the objective. */
    public double _gammax;
    /** Initialization method recorded with the model; not read by the scoring code here. */
    public GlrmInitialization _init;
    /** Number of categorical columns. */
    public int _ncats;
    /** Number of numeric columns. */
    public int _nnums;
    /** Per-numeric-column offset: values are standardized as {@code (v - _normSub) * _normMul}. */
    public double[] _normSub;
    /** Per-numeric-column multiplier used in the same standardization. */
    public double[] _normMul;
    /** Base seed for the random initialization of {@code x}. */
    public long _seed;
    /** Layout flag for the raw archetypes (the {@code transposed} argument of the static helpers). */
    public boolean _transposed;
    /** Whether imputed numeric values should be de-standardized (see {@link #impute_data}). */
    public boolean _reverse_transform;
    /** Relative-improvement threshold below which the solver stops. */
    public double _accuracyEps = 1.0E-10;
    /** Maximum number of solver iterations per row. */
    public int _iterNumber = 100;
    private static final double DOWN_FACTOR = 0.5;
    private static final double UP_FACTOR = Math.pow(2.0, 0.25);
    /** Number of rows scored through {@link #score0(double[], double[])}; offsets the seed. */
    public long _rcnt = 0L;
    /** Number of candidate step sizes tried per iteration. */
    public int _numAlphaFactors = 10;
    /** Candidate step sizes (see {@link #initializeAlphas(int)}); must hold {@code _numAlphaFactors} entries. */
    public double[] _allAlphas;
    private static final EnumSet<ModelCategory> CATEGORIES =
            EnumSet.of(ModelCategory.AutoEncoder, ModelCategory.DimReduction);

    @Override
    public EnumSet<ModelCategory> getModelCategories() {
        return CATEGORIES;
    }

    public GlrmMojoModel(String[] columns, String[][] domains, String responseColumn) {
        super(columns, domains, responseColumn);
    }

    /** Predictions are the row's {@code _ncolX}-dimensional representation. */
    @Override
    public int getPredsSize(ModelCategory mc) {
        return this._ncolX;
    }

    /**
     * Builds the geometric step-size schedule {@code 0.5, 0.25, 0.125, ...}.
     *
     * @param numAlpha number of step sizes to generate
     * @return array of length {@code numAlpha} whose i-th entry is {@code 0.5^(i+1)}
     */
    public static double[] initializeAlphas(int numAlpha) {
        double[] alphas = new double[numAlpha];
        double alpha = 1.0;
        for (int i = 0; i < numAlpha; ++i) {
            alpha *= DOWN_FACTOR;
            alphas[i] = alpha;
        }
        return alphas;
    }

    /**
     * Computes the low-rank representation of one row using a caller-supplied seed.
     *
     * @param row       input row of length {@code _ncolA}, in the caller's column order
     * @param preds     output buffer of length {@code _ncolX}; overwritten with the result
     * @param seedValue seed for the random initialization of {@code x}
     * @return {@code preds}
     */
    public double[] score0(double[] row, double[] preds, long seedValue) {
        assert (row.length == this._ncolA);
        assert (preds.length == this._ncolX);
        assert (this._nrowY == this._ncolX);
        assert (this._archetypes.length == this._nrowY);
        assert (this._archetypes[0].length == this._ncolY);

        double[] a = this.getRowData(row);
        double[] scratch = new double[this._ncolX];
        Random random = new Random(seedValue);

        // Random Gaussian start, projected onto whatever set the regularizer requires.
        double[] x = new double[this._ncolX];
        for (int k = 0; k < this._ncolX; ++k) {
            x[k] = random.nextGaussian();
        }
        x = this._regx.project(x, random);

        double oldObj = this.objective(x, a);
        for (int iter = 0; iter < this._iterNumber; ++iter) {
            if (oldObj == 0.0) {
                // Nothing left to minimize: applyBestAlpha would return 0 without touching x, and
                // the improvement 1 - 0/0 is NaN, which never satisfies the test below.
                break;
            }
            double[] grad = this.gradientL(x, a);
            double obj = this.applyBestAlpha(scratch, x, grad, a, oldObj, random);
            double improvement = 1.0 - obj / oldObj;
            if (improvement < 0.0 || improvement < this._accuracyEps) {
                break; // no (significant) improvement
            }
            oldObj = obj;
        }
        System.arraycopy(x, 0, preds, 0, this._ncolX);
        return preds;
    }

    /**
     * Reorders a caller row into internal column order (categoricals first, then numerics).
     * Categorical values at or beyond the column's level count become {@code NaN} (missing).
     *
     * @param row input row in the caller's column order
     * @return new array of length {@code _ncolA} in internal order
     */
    public double[] getRowData(double[] row) {
        double[] a = new double[this._ncolA];
        for (int j = 0; j < this._ncats; ++j) {
            double level = row[this._permutation[j]];
            a[j] = level >= this._numLevels[j] ? Double.NaN : level;
        }
        for (int j = this._ncats; j < this._ncolA; ++j) {
            a[j] = row[this._permutation[j]];
        }
        return a;
    }

    /**
     * Tries every configured step size from {@code x2} along {@code -grad} and keeps the best.
     *
     * @param u2     scratch buffer of length {@code _ncolX}; overwritten
     * @param x2     current representation; replaced by the best candidate only if that candidate
     *               has a lower objective than {@code oldObj}
     * @param grad   gradient of the loss at {@code x2}
     * @param a2     row in internal column order (see {@link #getRowData})
     * @param oldObj objective value at {@code x2}
     * @param random random source forwarded to the regularizer's {@code rproxgrad}
     * @return the lowest candidate objective found ({@code Double.MAX_VALUE} if no candidate beat
     *         it, e.g. all were NaN), or {@code 0} immediately when {@code oldObj} is 0
     */
    public double applyBestAlpha(double[] u2, double[] x2, double[] grad, double[] a2, double oldObj, Random random) {
        if (oldObj == 0.0) {
            return 0.0; // already optimal
        }
        double[] bestX = new double[x2.length];
        double lowestObj = Double.MAX_VALUE;
        // Shrink step sizes when the objective is large.
        double alphaScale = oldObj > 10.0 ? 1.0 / oldObj : 1.0;
        for (int i = 0; i < this._numAlphaFactors; ++i) {
            double alpha = this._allAlphas[i] * alphaScale;
            for (int k = 0; k < this._ncolX; ++k) {
                u2[k] = x2[k] - alpha * grad[k];
            }
            double[] candidate = this._regx.rproxgrad(u2, alpha * this._gammax, random);
            double candidateObj = this.objective(candidate, a2);
            if (candidateObj < lowestObj) {
                System.arraycopy(candidate, 0, bestX, 0, candidate.length);
                lowestObj = candidateObj;
            }
            if (candidateObj == 0.0) {
                break; // cannot do better than zero
            }
        }
        if (lowestObj < oldObj) {
            System.arraycopy(bestX, 0, x2, 0, x2.length);
        }
        return lowestObj;
    }

    /**
     * Scores a row using seed {@code _seed + _rcnt}, then increments {@code _rcnt} so each call
     * gets a different seed.
     */
    @Override
    public double[] score0(double[] row, double[] preds) {
        // NOTE(review): _rcnt++ is a plain non-atomic update; concurrent callers sharing one
        // instance may observe repeated or skipped seeds.
        return this.score0(row, preds, this._seed + this._rcnt++);
    }

    /**
     * Reconstructs a full row from a representation {@code xfactor} and the raw archetypes.
     *
     * @param xfactor           low-rank representation of the row
     * @param preds             output of length {@code nnums + ncats}, written in caller column order
     * @param permutation       internal-to-caller column index map
     * @param reverse_transform if true, de-standardize numeric values with {@code normMul}/{@code normSub}
     * @param losses            per-column losses in internal order, used to impute each value
     * @return {@code preds}
     */
    public static double[] impute_data(double[] xfactor, double[] preds, int nnums, int ncats, int[] permutation, boolean reverse_transform, double[] normMul, double[] normSub, GlrmLoss[] losses, boolean transposed, double[][] archetypes_raw, int[] catOffsets, int[] numLevels) {
        assert (preds.length == nnums + ncats);
        for (int j = 0; j < ncats; ++j) {
            double[] xy = lmulCatBlock(xfactor, j, numLevels, transposed, archetypes_raw, catOffsets);
            preds[permutation[j]] = losses[j].mimpute(xy);
        }
        for (int j = ncats; j < preds.length; ++j) {
            int js = j - ncats;
            double xy = lmulNumCol(xfactor, js, transposed, archetypes_raw, catOffsets);
            double imputed = losses[j].impute(xy);
            if (reverse_transform) {
                // Invert the (v - normSub) * normMul standardization of numeric columns.
                imputed = imputed / normMul[js] + normSub[js];
            }
            preds[permutation[j]] = imputed;
        }
        return preds;
    }

    /** Archetype column of numeric column {@code j2}: numerics follow all categorical blocks. */
    public static int getNumCidx(int j2, int[] catOffsets) {
        return catOffsets[catOffsets.length - 1] + j2;
    }

    /** Inner product of {@code x2} with the raw-archetype column of numeric column {@code j2}. */
    public static double lmulNumCol(double[] x2, int j2, boolean transposed, double[][] archetypes_raw, int[] catOffsets) {
        int k = rank(transposed, archetypes_raw);
        assert (x2 != null && x2.length == k) : "x must be of length " + k;
        int col = getNumCidx(j2, catOffsets);
        double xy = 0.0;
        for (int r = 0; r < k; ++r) {
            xy += x2[r] * (transposed ? archetypes_raw[col][r] : archetypes_raw[r][col]);
        }
        return xy;
    }

    /** Archetype column of {@code level} within categorical column {@code j2}'s block. */
    public static int getCatCidx(int j2, int level, int[] numLevels, int[] catOffsets) {
        int nLevels = numLevels[j2];
        assert (nLevels != 0) : "Number of levels in categorical column cannot be zero";
        assert (level >= 0 && level < nLevels) : "Got level = " + level + " when expected integer in [0," + nLevels + ")";
        return catOffsets[j2] + level;
    }

    /**
     * Products of {@code x2} with each raw-archetype column of categorical column {@code j2}.
     *
     * @return array of length {@code numLevels[j2]}, one entry per level
     */
    public static double[] lmulCatBlock(double[] x2, int j2, int[] numLevels, boolean transposed, double[][] archetypes_raw, int[] catOffsets) {
        int nLevels = numLevels[j2];
        assert (nLevels != 0) : "Number of levels in categorical column cannot be zero";
        int k = rank(transposed, archetypes_raw);
        assert (x2 != null && x2.length == k) : "x must be of length " + k;
        double[] xy = new double[nLevels];
        for (int level = 0; level < nLevels; ++level) {
            int col = getCatCidx(j2, level, numLevels, catOffsets);
            for (int r = 0; r < k; ++r) {
                xy[level] += x2[r] * (transposed ? archetypes_raw[col][r] : archetypes_raw[r][col]);
            }
        }
        return xy;
    }

    /** Rank of the raw archetypes: inner dimension when transposed, row count otherwise. */
    public static int rank(boolean transposed, double[][] archetypes_raw) {
        return transposed ? archetypes_raw[0].length : archetypes_raw.length;
    }

    /**
     * Gradient w.r.t. {@code x2} of the summed per-column losses (regularizer excluded; it is
     * handled by {@code rproxgrad}). Missing ({@code NaN}) entries of {@code a2} contribute nothing.
     */
    private double[] gradientL(double[] x2, double[] a2) {
        double[] grad = new double[this._ncolX];
        int catOffset = 0; // first _archetypes column of the current categorical block
        for (int j = 0; j < this._ncats; ++j) {
            int nLevels = this._numLevels[j];
            if (!Double.isNaN(a2[j])) {
                double[] xy = this.archetypeBlockProducts(x2, catOffset, nLevels);
                double[] gradL = this._losses[j].mlgrad(xy, (int) a2[j]);
                for (int k = 0; k < this._ncolX; ++k) {
                    for (int level = 0; level < nLevels; ++level) {
                        grad[k] += gradL[level] * this._archetypes[k][catOffset + level];
                    }
                }
            }
            // Always skip this column's block, even when its value is missing; otherwise every
            // later block and all numeric columns would read the wrong archetype columns.
            catOffset += nLevels;
        }
        for (int j = this._ncats; j < this._ncolA; ++j) {
            if (Double.isNaN(a2[j])) {
                continue;
            }
            int js = j - this._ncats;
            int col = catOffset + js;
            double xy = this.archetypeColumnProduct(x2, col);
            double gradL = this._losses[j].lgrad(xy, (a2[j] - this._normSub[js]) * this._normMul[js]);
            for (int k = 0; k < this._ncolX; ++k) {
                grad[k] += gradL * this._archetypes[k][col];
            }
        }
        return grad;
    }

    /**
     * Objective at {@code x2}: summed per-column losses over non-missing entries of {@code a2}
     * plus {@code _gammax} times the regularizer.
     */
    private double objective(double[] x2, double[] a2) {
        double obj = 0.0;
        int catOffset = 0; // first _archetypes column of the current categorical block
        for (int j = 0; j < this._ncats; ++j) {
            int nLevels = this._numLevels[j];
            if (!Double.isNaN(a2[j])) {
                double[] xy = this.archetypeBlockProducts(x2, catOffset, nLevels);
                obj += this._losses[j].mloss(xy, (int) a2[j]);
            }
            // Advance unconditionally so column offsets stay fixed regardless of missing values.
            catOffset += nLevels;
        }
        for (int j = this._ncats; j < this._ncolA; ++j) {
            if (Double.isNaN(a2[j])) {
                continue;
            }
            int js = j - this._ncats;
            double xy = this.archetypeColumnProduct(x2, catOffset + js);
            obj += this._losses[j].loss(xy, (a2[j] - this._normSub[js]) * this._normMul[js]);
        }
        return obj + this._gammax * this._regx.regularize(x2);
    }

    /** Returns {@code x · _archetypes[:, offset + l]} for each level {@code l} in {@code [0, nLevels)}. */
    private double[] archetypeBlockProducts(double[] x, int offset, int nLevels) {
        double[] xy = new double[nLevels];
        for (int level = 0; level < nLevels; ++level) {
            for (int k = 0; k < this._ncolX; ++k) {
                xy[level] += x[k] * this._archetypes[k][offset + level];
            }
        }
        return xy;
    }

    /** Returns the inner product of {@code x} with column {@code col} of {@code _archetypes}. */
    private double archetypeColumnProduct(double[] x, int col) {
        double xy = 0.0;
        for (int k = 0; k < this._ncolX; ++k) {
            xy += x[k] * this._archetypes[k][col];
        }
        return xy;
    }

    /** Output columns are named {@code Arch1 .. Arch<_ncolX>}. */
    @Override
    public String[] getOutputNames() {
        String[] names = new String[this._ncolX];
        for (int i = 0; i < names.length; ++i) {
            names[i] = "Arch" + (i + 1);
        }
        return names;
    }

    static {
        assert (UP_FACTOR > 1.0);
    }
}

