/*
 * Decompiled with CFR 0.152.
 */
package hex.glm;

import hex.DataInfo;
import hex.glm.GLM;
import hex.glm.GLMModel;
import hex.optimization.ADMM;
import hex.optimization.OptimizationUtils;
import java.util.Arrays;
import water.H2O;
import water.Job;
import water.MemoryManager;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.MathUtils;

public final class ComputationState {
    /** Copied from {@code _parms._intercept} at construction. */
    final boolean _intercept;
    /** Number of response classes; forced to 1 unless the family is multinomial. */
    final int _nclasses;
    private final GLMModel.GLMParameters _parms;
    /** Beta constraints over the full (unfiltered) set of expanded columns. */
    private GLM.BetaConstraint _bc;
    /** Elastic-net mixing parameter, taken from {@code _parms._alpha[0]}. */
    final double _alpha;
    double[] _ymu;
    /** Optional auxiliary vector kept column-aligned with {@code _beta}; may be null. */
    double[] _u;
    double[] _z;
    /** When set, strong-rule screening is skipped and all columns of {@code _dinfo} are used. */
    boolean _allIn;
    int _iter;
    private double _lambda = 0.0;
    private double _lambdaMax = Double.NaN;
    /** Cached gradient information; {@code null} means it is recomputed lazily by {@link #ginfo()}. */
    private GLM.GLMGradientInfo _ginfo;
    private double _likelihood;
    private double _gradientErr;
    /** Active, possibly column-filtered, view of {@code _dinfo}. */
    private DataInfo _activeData;
    /** Constraints matching {@code _activeData}; may be null after {@link #checkKKTs()}. */
    private GLM.BetaConstraint _activeBC = null;
    private double[] _beta;
    final DataInfo _dinfo;
    private GLM.GLMGradientSolver _gslvr;
    private final Job _job;
    /** Class whose slice {@link #beta()} / {@link #activeData()} expose; -1 means "no class selected". */
    private int _activeClass = -1;
    public boolean _lsNeeded = false;
    /** Per-class filtered data views (multinomial only); null until strong rules have run. */
    private DataInfo[] _activeDataMultinomial;
    private double _betaDiff;
    private double _relImprovement;
    String convergenceMsg = "";

    /**
     * Creates the optimization state for one GLM fit.
     *
     * @param job      job handed to every gradient solver built by this state
     * @param parms    GLM parameters; {@code _alpha[0]} and {@code _intercept} are cached here
     * @param dinfo    full data layout; also the initial active data
     * @param bc       beta constraints over all columns; also the initial active constraints
     * @param nclasses number of response classes, only honoured for the multinomial family
     */
    public ComputationState(Job job, GLMModel.GLMParameters parms, DataInfo dinfo, GLM.BetaConstraint bc, int nclasses) {
        _job = job;
        _parms = parms;
        _activeBC = _bc = bc;
        _activeData = _dinfo = dinfo;
        _intercept = _parms._intercept;
        _nclasses = parms._family == GLMModel.GLMParameters.Family.multinomial ? nclasses : 1;
        _alpha = _parms._alpha[0];
    }

    /** Returns the current gradient solver (may be null before {@link #setLambda(double)}). */
    public GLM.GLMGradientSolver gslvr() {
        return _gslvr;
    }

    /** Returns the current regularization strength. */
    public double lambda() {
        return _lambda;
    }

    /** Sets the upper bound that {@link #applyStrongRules(double, double)} clamps lambdas to. */
    public void setLambdaMax(double lmax) {
        _lambdaMax = lmax;
    }

    /**
     * Moves the state to a new regularization strength.
     * <p>
     * The L2 contribution of the current lambda is first stripped from the cached gradient.
     * Strong rules are then applied to that penalty-free gradient to select the active columns.
     * Next, the L2 contribution of the new lambda is added back. Finally, a fresh gradient
     * solver is built for the (possibly reduced) active data.
     *
     * @param lambda the new lambda
     */
    public void setLambda(double lambda) {
        // adjustToNewLambda(lambdaNew, lambdaOld) shifts by (lambdaNew - lambdaOld):
        // moving from the current lambda down to 0 removes the existing L2 term...
        adjustToNewLambda(0.0, _lambda);
        applyStrongRules(lambda, _lambda);
        // ...and moving from 0 up to the new lambda adds the new L2 term back.
        // NOTE(review): adjustToNewLambda's l2pen() guard still sees the OLD _lambda here,
        // so when the current lambda is 0 no adjustment is applied at all — confirm intent.
        adjustToNewLambda(lambda, 0.0);
        _lambda = lambda;
        _gslvr = new GLM.GLMGradientSolver(_job, _parms, _activeData, l2pen(), _activeBC);
    }

    /**
     * Returns the current coefficients: the active class's slice when a class was selected
     * via {@link #setActiveClass(int)}, otherwise the whole stored vector (not a copy).
     */
    public double[] beta() {
        return _activeClass != -1 ? betaMultinomial(_activeClass, _beta) : _beta;
    }

    /** Returns the cached gradient info, computing it from {@link #beta()} if it is absent. */
    public GLM.GLMGradientInfo ginfo() {
        if (_ginfo == null) {
            _ginfo = gslvr().getGradient(beta());
        }
        return _ginfo;
    }

    public GLM.BetaConstraint activeBC() {
        return _activeBC;
    }

    public double likelihood() {
        return _likelihood;
    }

    /** Returns the active data: the active class's view when one is selected, otherwise {@code _activeData}. */
    public DataInfo activeData() {
        return _activeClass != -1 ? activeDataMultinomial(_activeClass) : _activeData;
    }

    /** Returns {@code _activeData} regardless of the active class. */
    public DataInfo activeDataMultinomial() {
        return _activeData;
    }

    /** Clears the active data view. */
    public void dropActiveData() {
        _activeData = null;
    }

    @Override
    public String toString() {
        return "iter=" + _iter
                + " lmb=" + GLM.lambdaFormatter.format(_lambda)
                + " obj=" + MathUtils.roundToNDigits(objective(), 4)
                + " imp=" + GLM.lambdaFormatter.format(_relImprovement)
                + " bdf=" + GLM.lambdaFormatter.format(_betaDiff);
    }

    /**
     * Shifts the cached gradient and objective by the L2 term for a lambda change of
     * {@code lambdaNew - lambdaOld}.
     * <p>
     * This is a no-op when the difference is zero or when {@link #l2pen()}, evaluated at the
     * current {@code _lambda}, is zero. Otherwise {@code _ginfo} must be non-null.
     * NOTE(review): the shift uses the raw lambda difference rather than
     * {@code (1 - alpha) * difference}, unlike {@link #l2pen()} — confirm against the solver's convention.
     */
    private void adjustToNewLambda(double lambdaNew, double lambdaOld) {
        final double ldiff = lambdaNew - lambdaOld;
        if (ldiff == 0.0 || l2pen() == 0.0) {
            return;
        }
        // NOTE(review): the boolean flag presumably excludes the trailing intercept — confirm in ArrayUtils.
        final double l2pen = 0.5 * ArrayUtils.l2norm2(_beta, true);
        if (l2pen > 0.0) {
            final double[] grad = _ginfo._gradient;
            if (_parms._family == GLMModel.GLMParameters.Family.multinomial) {
                // Coefficients are stored as consecutive per-class blocks of (fullN + 1) entries.
                int off = 0;
                for (int c = 0; c < _nclasses; ++c) {
                    final int n = activeDataMultinomial(c).fullN();
                    for (int i = 0; i < n; ++i) {
                        grad[off + i] += ldiff * _beta[off + i];
                    }
                    off += n + 1;
                }
            } else {
                final int n = _activeData.fullN();
                for (int i = 0; i < n; ++i) {
                    grad[i] += ldiff * _beta[i];
                }
            }
        }
        _ginfo = new GLM.GLMGradientInfo(_ginfo._likelihood, _ginfo._objVal + ldiff * l2pen, _ginfo._gradient);
    }

    /** L1 penalty weight: {@code alpha * lambda}. */
    public double l1pen() {
        return _alpha * _lambda;
    }

    /** L2 penalty weight: {@code (1 - alpha) * lambda}. */
    public double l2pen() {
        return (1.0 - _alpha) * _lambda;
    }

    /**
     * Applies strong-rule screening to choose the active columns for {@code lambdaNew}.
     * <p>
     * Both lambdas are clamped to {@code _lambdaMax}. Previously active columns are always
     * kept. Other columns are added when the absolute value of their cached gradient exceeds
     * {@code max(0, alpha * (2 * lambdaNew - lambdaOld))}.
     * <p>
     * When a strict subset is selected, {@code _beta}, {@code _u}, {@code _ginfo},
     * {@code _activeData}, {@code _activeBC} and {@code _gslvr} are all narrowed to it.
     * Multinomial problems are delegated to {@link #applyStrongRulesMultinomial(double, double)}.
     *
     * @return number of selected columns (including the trailing column at index {@code fullN()})
     */
    protected int applyStrongRules(double lambdaNew, double lambdaOld) {
        lambdaNew = Math.min(_lambdaMax, lambdaNew);
        lambdaOld = Math.min(_lambdaMax, lambdaOld);
        if (_parms._family == GLMModel.GLMParameters.Family.multinomial) {
            return applyStrongRulesMultinomial(lambdaNew, lambdaOld);
        }
        final int P = _dinfo.fullN();
        _activeBC = _bc;
        if (_activeData == null) {
            _activeData = _dinfo;
        }
        // Screening is skipped without an L1 term or when the constraints report bounds.
        _allIn = _allIn || _parms._alpha[0] * lambdaNew == 0.0 || _activeBC.hasBounds();
        if (!_allIn) {
            final double rhs = Math.max(0.0, _alpha * (2.0 * lambdaNew - lambdaOld));
            final int[] oldActiveCols = _activeData._activeCols == null ? new int[0] : _activeData.activeCols();
            int[] cols = MemoryManager.malloc4(P);
            int selected = 0;
            int j = 0;
            for (int i = 0; i < P; ++i) {
                if (j < oldActiveCols.length && i == oldActiveCols[j]) {
                    ++j;
                    cols[selected++] = i;
                } else if (_ginfo._gradient[i] > rhs || -_ginfo._gradient[i] > rhs) {
                    cols[selected++] = i;
                }
            }
            _allIn = selected == P;
            if (!_allIn) {
                // Column P (presumably the intercept slot) is always kept; selected < P leaves room for it.
                cols[selected++] = P;
                cols = Arrays.copyOf(cols, selected);
                _beta = ArrayUtils.select(_beta, cols);
                if (_u != null) {
                    _u = ArrayUtils.select(_u, cols);
                }
                _activeData = _dinfo.filterExpandedColumns(cols);
                assert _activeData.activeCols().length == _beta.length;
                assert _u == null || _activeData.activeCols().length == _u.length;
                _ginfo = new GLM.GLMGradientInfo(_ginfo._likelihood, _ginfo._objVal, ArrayUtils.select(_ginfo._gradient, cols));
                _activeBC = _bc.filterExpandedColumns(_activeData.activeCols());
                // Use the filtered constraints so they match the filtered data (as removeCols/checkKKTs do).
                _gslvr = new GLM.GLMGradientSolver(_job, _parms, _activeData, l2pen(), _activeBC);
                assert _beta.length == selected;
                return selected;
            }
        }
        _activeData = _dinfo;
        return _dinfo.fullN();
    }

    /** Returns class {@code c}'s filtered data view, or {@code _dinfo} if none has been computed. */
    public DataInfo activeDataMultinomial(int c) {
        return _activeDataMultinomial != null ? _activeDataMultinomial[c] : _dinfo;
    }

    /**
     * Copies block {@code c} (entries {@code [c*N, c*N+N)}) of {@code src}, or, when
     * {@code ids} is non-null, only the entries {@code src[c*N + id]} for each id in order.
     */
    private static double[] extractSubRange(int N, int c, int[] ids, double[] src) {
        if (ids == null) {
            return Arrays.copyOfRange(src, c * N, c * N + N);
        }
        double[] res = MemoryManager.malloc8d(ids.length);
        int j = 0;
        int off = c * N;
        for (int i : ids) {
            res[j++] = src[off + i];
        }
        return res;
    }

    /**
     * Inverse of {@link #extractSubRange}: writes {@code src} back into block {@code c} of
     * {@code dst}, either contiguously or scattered to {@code dst[c*N + id]}.
     */
    private static void fillSubRange(int N, int c, int[] ids, double[] src, double[] dst) {
        if (ids == null) {
            System.arraycopy(src, 0, dst, c * N, N);
        } else {
            int j = 0;
            int off = c * N;
            for (int i : ids) {
                dst[off + i] = src[j++];
            }
        }
    }

    /** Returns the full stacked coefficient vector (not a copy). */
    public double[] betaMultinomial() {
        return _beta;
    }

    /** Extracts class {@code c}'s active coefficients from a stacked vector with blocks of {@code fullN()+1}. */
    public double[] betaMultinomial(int c, double[] beta) {
        return extractSubRange(_activeData.fullN() + 1, c, _activeDataMultinomial[c].activeCols(), beta);
    }

    /** Wraps the cached gradient info restricted to class {@code c}'s active columns. */
    public GLMSubsetGinfo ginfoMultinomial(int c) {
        return new GLMSubsetGinfo(_ginfo, _activeData.fullN() + 1, c, _activeDataMultinomial[c].activeCols());
    }

    /** Replaces both the full and the active beta constraints. */
    public void setBC(GLM.BetaConstraint bc) {
        _activeBC = _bc = bc;
    }

    /** Selects the class exposed by {@link #beta()} and {@link #activeData()}; -1 clears it. */
    public void setActiveClass(int activeClass) {
        _activeClass = activeClass;
    }

    /** Returns {@code 2 * likelihood()}. */
    public double deviance() {
        return 2.0 * likelihood();
    }

    /**
     * Returns a solver that optimizes class {@code c} only.
     * <p>
     * The coefficients of the other classes are frozen at a snapshot of {@code _beta} taken
     * now. On each call, the class-{@code c} slice is scattered into that snapshot, the full
     * gradient is evaluated, and the result is sliced back to class {@code c}.
     */
    public OptimizationUtils.GradientSolver gslvrMultinomial(final int c) {
        final double[] fullbeta = _beta.clone();
        return new OptimizationUtils.GradientSolver() {

            @Override
            public OptimizationUtils.GradientInfo getGradient(double[] beta) {
                final int N = ComputationState.this._activeData.fullN() + 1;
                final int[] ids = ComputationState.this._activeDataMultinomial[c].activeCols();
                fillSubRange(N, c, ids, beta, fullbeta);
                GLM.GLMGradientInfo fullGinfo = ComputationState.this._gslvr.getGradient(fullbeta);
                return new GLMSubsetGinfo(fullGinfo, N, c, ComputationState.this._activeDataMultinomial[c].activeCols());
            }

            @Override
            public OptimizationUtils.GradientInfo getObjective(double[] beta) {
                return getGradient(beta);
            }
        };
    }

    /**
     * Zeros {@code _u} (if present), then scatters {@code bc} into class {@code c}'s block of
     * {@code beta}. Note the argument roles: {@code bc} is the source and {@code beta} the destination.
     */
    public void setBetaMultinomial(int c, double[] beta, double[] bc) {
        if (_u != null) {
            Arrays.fill(_u, 0.0);
        }
        fillSubRange(_activeData.fullN() + 1, c, _activeDataMultinomial[c].activeCols(), bc, beta);
    }

    /**
     * Multinomial strong rules: chooses, per class, the columns whose gradient magnitude
     * exceeds {@code alpha * (2 * lambdaNew - lambdaOld)}, plus column {@code fullN()}.
     * It then builds a filtered data view per class.
     *
     * @return total number of selected entries across all classes
     */
    protected int applyStrongRulesMultinomial(double lambdaNew, double lambdaOld) {
        final int P = _dinfo.fullN();
        final int N = P + 1;
        int selected = 0;
        _activeBC = _bc;
        _activeData = _dinfo;
        if (!_allIn) {
            if (_activeDataMultinomial == null) {
                _activeDataMultinomial = new DataInfo[_nclasses];
            }
            final double rhs = _alpha * (2.0 * lambdaNew - lambdaOld);
            final int[] oldActiveCols = _activeData._activeCols == null ? new int[0] : _activeData.activeCols();
            final int[] cols = MemoryManager.malloc4(N * _nclasses);
            int j = 0;
            for (int c = 0; c < _nclasses; ++c) {
                final int start = selected;
                final int base = c * N;
                for (int i = 0; i < P; ++i) {
                    if (j < oldActiveCols.length && i == oldActiveCols[j]) {
                        cols[selected++] = i;
                        ++j;
                    } else if (_ginfo._gradient[base + i] > rhs || _ginfo._gradient[base + i] < -rhs) {
                        cols[selected++] = i;
                    }
                }
                cols[selected++] = P;
                _activeDataMultinomial[c] = _dinfo.filterExpandedColumns(Arrays.copyOfRange(cols, start, selected));
                // Convert this class's local column ids into offsets in the stacked vector.
                for (int i = start; i < selected; ++i) {
                    cols[i] += base;
                }
            }
            _allIn = selected == cols.length;
        }
        return selected;
    }

    /** Multinomial KKT check: trivially true when no column filtering is active, otherwise unimplemented. */
    protected boolean checkKKTsMultinomial() {
        if (_activeData._activeCols == null) {
            return true;
        }
        throw H2O.unimpl();
    }

    /**
     * Verifies the KKT conditions on the full column set.
     * <p>
     * Steps:
     * <ol>
     *   <li>{@code _beta} and {@code _u} are expanded to full size.</li>
     *   <li>The gradient is recomputed on {@code _dinfo} when needed.</li>
     *   <li>The gradient is reduced by {@code ADMM.subgrad} with {@code l1pen()}.</li>
     *   <li>The tolerance is set to {@code max(1e-4, max |grad| over active columns)}.</li>
     * </ol>
     * Unless all columns are already in, any inactive column whose reduced gradient exceeds
     * that tolerance is added back. When that happens, the state is narrowed to the enlarged
     * column set and the method returns false.
     * <p>
     * Side effect: {@code _activeBC} is left null when no column fails.
     *
     * @return true if no inactive column violates the KKT conditions
     */
    protected boolean checkKKTs() {
        if (_parms._family == GLMModel.GLMParameters.Family.multinomial) {
            return checkKKTsMultinomial();
        }
        double[] beta = _beta;
        double[] u = _u;
        if (_activeData._activeCols != null) {
            beta = ArrayUtils.expandAndScatter(beta, _dinfo.fullN() + 1, _activeData._activeCols);
            if (_u != null) {
                u = ArrayUtils.expandAndScatter(_u, _dinfo.fullN() + 1, _activeData._activeCols);
            }
        }
        final int[] activeCols = _activeData.activeCols();
        if (beta != _beta || _ginfo == null) {
            // The cached gradient only covers the active columns; evaluate on the full data.
            _gslvr = new GLM.GLMGradientSolver(_job, _parms, _dinfo, l2pen(), _bc);
            _ginfo = _gslvr.getGradient(beta);
        }
        final double[] grad = _ginfo._gradient.clone();
        double err = 1.0E-4;
        if (u != null && u != _u) {
            // u was expanded above: fill the previously inactive entries with -grad.
            int k = 0;
            for (int i = 0; i < u.length; ++i) {
                if (_activeData._activeCols[k] == i) {
                    ++k;
                    continue;
                }
                assert u[i] == 0.0;
                u[i] = -grad[i];
            }
        }
        ADMM.subgrad(l1pen(), beta, grad);
        // Tolerance = largest subgradient magnitude among the included columns.
        for (int c : activeCols) {
            if (grad[c] > err) {
                err = grad[c];
            } else if (grad[c] < -err) {
                err = -grad[c];
            }
        }
        _gradientErr = err;
        _beta = beta;
        _u = u;
        _activeBC = null;
        if (!_allIn) {
            int[] failedCols = new int[64];
            int fcnt = 0;
            for (int i = 0; i < grad.length - 1; ++i) {
                if (Arrays.binarySearch(activeCols, i) >= 0) {
                    continue; // previously active columns are always kept
                }
                if (grad[i] > err || -grad[i] > err) {
                    if (fcnt == failedCols.length) {
                        failedCols = Arrays.copyOf(failedCols, failedCols.length << 1);
                    }
                    failedCols[fcnt++] = i;
                }
            }
            if (fcnt > 0) {
                Log.info(fcnt + " variables failed KKT conditions, adding them to the model and recomputing.");
                final int n = activeCols.length;
                final int[] newCols = Arrays.copyOf(activeCols, n + fcnt);
                System.arraycopy(failedCols, 0, newCols, n, fcnt);
                Arrays.sort(newCols);
                _beta = ArrayUtils.select(beta, newCols);
                if (_u != null) {
                    _u = ArrayUtils.select(_u, newCols);
                }
                _ginfo = new GLM.GLMGradientInfo(_ginfo._likelihood, _ginfo._objVal, ArrayUtils.select(_ginfo._gradient, newCols));
                _activeData = _dinfo.filterExpandedColumns(newCols);
                _activeBC = _bc.filterExpandedColumns(_activeData.activeCols());
                _gslvr = new GLM.GLMGradientSolver(_job, _parms, _activeData, l2pen(), _activeBC);
                return false;
            }
        }
        return true;
    }

    /**
     * Removes the given positions from the active set and from {@code _beta}, {@code _u} and the
     * cached gradient (when present). It then rebuilds the active data, constraints and solver.
     *
     * @param cols positions to remove, as understood by {@code ArrayUtils.removeIds}
     * @return the remaining active column ids
     */
    public int[] removeCols(int[] cols) {
        final int[] activeCols = ArrayUtils.removeIds(_activeData.activeCols(), cols);
        if (_beta != null) {
            _beta = ArrayUtils.removeIds(_beta, cols);
        }
        if (_u != null) {
            _u = ArrayUtils.removeIds(_u, cols);
        }
        if (_ginfo != null && _ginfo._gradient != null) {
            _ginfo._gradient = ArrayUtils.removeIds(_ginfo._gradient, cols);
        }
        _activeData = _dinfo.filterExpandedColumns(activeCols);
        _activeBC = _bc.filterExpandedColumns(activeCols);
        _gslvr = new GLM.GLMGradientSolver(_job, _parms, _activeData, l2pen(), _activeBC);
        return activeCols;
    }

    /**
     * Elastic-net penalty {@code l1pen() * |b|_1 + 0.5 * l2pen() * |b|_2^2}.
     * <p>
     * The last entry of the vector (non-multinomial) or the last entry of each per-class block
     * (multinomial) is excluded from the sums.
     * For multinomial, {@code beta} is assumed to hold {@code _nclasses} equally sized
     * consecutive blocks — NOTE(review): confirm no caller passes a single-class slice.
     */
    private double penalty(double[] beta) {
        if (_lambda == 0.0) {
            return 0.0;
        }
        double l1norm = 0.0;
        double l2norm = 0.0;
        if (_parms._family == GLMModel.GLMParameters.Family.multinomial) {
            final int len = beta.length / _nclasses;
            for (int c = 0; c < _nclasses; ++c) {
                final int end = (c + 1) * len - 1; // skip the trailing entry of each class block
                for (int i = c * len; i < end; ++i) {
                    final double d = beta[i];
                    l1norm += d >= 0.0 ? d : -d;
                    l2norm += d * d;
                }
            }
        } else {
            for (int i = 0; i < beta.length - 1; ++i) {
                final double d = beta[i];
                l1norm += d >= 0.0 ? d : -d;
                l2norm += d * d;
            }
        }
        return l1pen() * l1norm + 0.5 * l2pen() * l2norm;
    }

    /** Objective at the stored state; {@code Double.MAX_VALUE} before any beta has been set. */
    public double objective() {
        return _beta == null ? Double.MAX_VALUE : objective(_beta, _likelihood);
    }

    /**
     * Returns {@code likelihood * obj_reg + penalty(beta)}, plus the active constraints'
     * {@code proxPen(beta)} when active constraints are set.
     */
    public double objective(double[] beta, double likelihood) {
        final double prox = _activeBC == null ? 0.0 : _activeBC.proxPen(beta);
        return likelihood * _parms._obj_reg + penalty(beta) + prox;
    }

    /**
     * Stores {@code beta} (by reference) and {@code likelihood} and invalidates the cached gradient.
     * It also records the max-abs coefficient change and the relative objective improvement.
     *
     * @return the relative improvement {@code (objOld - objNew) / objOld}
     */
    protected double updateState(double[] beta, double likelihood) {
        _betaDiff = ArrayUtils.linfnorm(_beta == null ? beta : ArrayUtils.subtract(_beta, beta), false);
        final double objOld = objective();
        _beta = beta;
        _ginfo = null;
        _likelihood = likelihood;
        _relImprovement = (objOld - objective()) / objOld;
        return _relImprovement;
    }

    /**
     * Checks whether the last update converged.
     * <p>
     * It reports convergence when the coefficient change is below {@code _beta_epsilon} or the
     * relative improvement is below {@code _objective_epsilon}. The reason is recorded in
     * {@code convergenceMsg}.
     */
    public boolean converged() {
        if (_betaDiff < _parms._beta_epsilon) {
            convergenceMsg = "betaDiff < eps; betaDiff = " + _betaDiff + ", eps = " + _parms._beta_epsilon;
            return true;
        }
        if (_relImprovement < _parms._objective_epsilon) {
            convergenceMsg = "relImprovement < eps; relImprovement = " + _relImprovement + ", eps = " + _parms._objective_epsilon;
            return true;
        }
        convergenceMsg = "not converged, betaDiff = " + _betaDiff + ", relImprovement = " + _relImprovement;
        return false;
    }

    /**
     * Copies {@code beta} into the stored coefficients and adopts {@code ginfo} as the cached
     * gradient and likelihood. The copy overwrites the stored array in place when one exists.
     *
     * @return the relative improvement {@code (objOld - objNew) / objOld}
     */
    protected double updateState(double[] beta, GLM.GLMGradientInfo ginfo) {
        _betaDiff = ArrayUtils.linfnorm(_beta == null ? beta : ArrayUtils.subtract(_beta, beta), false);
        final double objOld = objective();
        if (_beta == null) {
            _beta = beta.clone();
        } else {
            System.arraycopy(beta, 0, _beta, 0, beta.length);
        }
        _ginfo = ginfo;
        _likelihood = ginfo._likelihood;
        _relImprovement = (objOld - objective()) / objOld;
        return _relImprovement;
    }

    /**
     * Scatters active-set coefficients back to the full {@code (fullN + 1) * nclasses} layout.
     * The input is returned unchanged when no column filtering is active.
     */
    public double[] expandBeta(double[] beta) {
        if (_activeData._activeCols == null) {
            return beta;
        }
        return ArrayUtils.expandAndScatter(beta, (_dinfo.fullN() + 1) * _nclasses, _activeData._activeCols);
    }

    /**
     * Gradient info restricted to one class's active entries.
     * Also keeps a reference to the full-problem gradient info it was derived from.
     */
    public static class GLMSubsetGinfo
    extends GLM.GLMGradientInfo {
        public final GLM.GLMGradientInfo _fullInfo;

        public GLMSubsetGinfo(GLM.GLMGradientInfo fullInfo, int N, int c, int[] ids) {
            super(fullInfo._likelihood, fullInfo._objVal, ComputationState.extractSubRange(N, c, ids, fullInfo._gradient));
            _fullInfo = fullInfo;
        }
    }
}

