/*
 * Decompiled with CFR 0.152.
 */
package hex.coxph;

import Jama.Matrix;
import hex.DataInfo;
import hex.FrameTask;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.coxph.CoxPHModel;
import java.util.Arrays;
import jsr166y.ForkJoinTask;
import jsr166y.RecursiveAction;
import water.DKV;
import water.Job;
import water.Key;
import water.Keyed;
import water.MemoryManager;
import water.Scope;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.VecUtils;

public class CoxPH
extends ModelBuilder<CoxPHModel, CoxPHModel.CoxPHParameters, CoxPHModel.CoxPHOutput> {
    /**
     * Lists the model categories this builder can produce.
     *
     * @return a newly allocated single-element array holding {@code ModelCategory.CoxPH}
     */
    public ModelCategory[] can_build() {
        ModelCategory[] supported = {ModelCategory.CoxPH};
        return supported;
    }

    /**
     * Reports the visibility level of this builder.
     *
     * @return {@code ModelBuilder.BuilderVisibility.Experimental}
     */
    public ModelBuilder.BuilderVisibility builderVisibility() {
        final ModelBuilder.BuilderVisibility visibility = ModelBuilder.BuilderVisibility.Experimental;
        return visibility;
    }

    /**
     * Reports whether this builder is supervised.
     *
     * @return always {@code true}
     */
    public boolean isSupervised() {
        return true;
    }

    /**
     * Creates a builder with a freshly constructed, default {@code CoxPHParameters} object.
     *
     * @param startup_once forwarded unchanged to the {@code ModelBuilder} constructor
     */
    public CoxPH(boolean startup_once) {
        super((Model.Parameters)new CoxPHModel.CoxPHParameters(), startup_once);
    }

    /**
     * Creates a builder for the given parameters and immediately runs the
     * non-expensive validation pass ({@code init(false)}).
     *
     * @param parms the CoxPH parameters handed to the {@code ModelBuilder} constructor
     */
    public CoxPH(CoxPHModel.CoxPHParameters parms) {
        super((Model.Parameters)parms);
        // Validate right away with expensive == false.
        this.init(false);
    }

    /**
     * Creates the driver that performs the training work for this builder.
     *
     * @return a new {@link CoxPHDriver} bound to this builder instance
     */
    protected CoxPHDriver trainModelImpl() {
        final CoxPHDriver driver = new CoxPHDriver();
        return driver;
    }

    /**
     * Validates the CoxPH-specific parameters, recording an error for every violation found.
     *
     * <p>The superclass validation runs first. When a training frame is set, the method checks
     * that the start column (if one is named) and the stop column are integer-valued, that the
     * response column is integer or categorical, and that the number of time bins spanned by the
     * data lies between 1 and {@code MAX_TIME_BINS} (10000). Independently of the training frame,
     * {@code lre_min} must be a positive number and {@code iter_max} must be at least 1.
     *
     * @param expensive forwarded unchanged to {@code super.init}
     */
    public void init(boolean expensive) {
        super.init(expensive);
        final CoxPHModel.CoxPHParameters parms = (CoxPHModel.CoxPHParameters)this._parms;

        if (parms._train != null) {
            final boolean hasStartColumn = parms._start_column != null;
            if (hasStartColumn && !parms.startVec().isInt()) {
                this.error("start_column", "start time must be null or of type integer");
            }
            if (!parms.stopVec().isInt()) {
                this.error("stop_column", "stop time must be of type integer");
            }
            final boolean responseOk = this._response.isInt() || this._response.isCategorical();
            if (!responseOk) {
                this.error("response_column", "response/event column must be of type integer or factor");
            }

            final int MAX_TIME_BINS = 10000;
            // The first time bin is the smallest stop time, or one past the smallest start time
            // when a start vector is available.
            final long minTime;
            if (parms.startVec() == null) {
                minTime = (long)parms.stopVec().min();
            } else {
                minTime = (long)parms.startVec().min() + 1L;
            }
            final int timeBins = (int)(parms.stopVec().max() - (double)minTime + 1.0);
            if (timeBins < 1) {
                this.error("start_column", "start times must be strictly less than stop times");
            }
            if (timeBins > MAX_TIME_BINS) {
                this.error("stop_column", "number of distinct stop times is " + timeBins + "; maximum number allowed is " + MAX_TIME_BINS);
            }
        }

        final double lreMin = parms._lre_min;
        if (Double.isNaN(lreMin) || lreMin <= 0.0) {
            this.error("lre_min", "lre_min must be a positive number");
        }
        if (parms._iter_max < 1) {
            this.error("iter_max", "iter_max must be a positive integer");
        }
    }

    /**
     * Allocates a {@code d1 x d2} matrix whose rows are obtained from
     * {@code MemoryManager.malloc8d}.
     *
     * @param d1 number of rows
     * @param d2 length of each row
     * @return the allocated matrix
     */
    private static double[][] malloc2DArray(int d1, int d2) {
        final double[][] rows = new double[d1][];
        for (int r = 0; r < rows.length; r++) {
            rows[r] = MemoryManager.malloc8d(d2);
        }
        return rows;
    }

    /**
     * Allocates a {@code d1 x d2 x d3} cube whose innermost arrays are obtained from
     * {@code MemoryManager.malloc8d}.
     *
     * @param d1 size of the first dimension
     * @param d2 size of the second dimension
     * @param d3 length of each innermost array
     * @return the allocated cube
     */
    private static double[][][] malloc3DArray(int d1, int d2, int d3) {
        final double[][][] cube = new double[d1][d2][];
        for (double[][] plane : cube) {
            for (int c = 0; c < d2; c++) {
                plane[c] = MemoryManager.malloc8d(d3);
            }
        }
        return cube;
    }

    /**
     * Frame task that accumulates, per time bin, the weighted sums later used to evaluate the
     * Cox partial log-likelihood, its gradient and its Hessian for a fixed coefficient vector.
     *
     * <p>Time bins are indexed as {@code time - _min_time}. When a start column is present a row
     * contributes to every bin from its (shifted) start bin through its stop bin; otherwise it
     * contributes only to its stop bin and {@link #postGlobal()} turns the {@code rcumsum*}
     * arrays into reverse cumulative sums.
     */
    protected static class CoxPHTask
    extends FrameTask<CoxPHTask> {
        /** Coefficient vector the risk scores are evaluated at. */
        private final double[] _beta;
        /** Number of time bins. */
        private final int _n_time;
        /** Time value that maps to bin 0. */
        private final long _min_time;
        /** Number of trailing numeric columns treated as offsets (added to the log-risk without a coefficient). */
        private final int _n_offsets;
        /** Whether the response array carries a start time (third from last entry). */
        private final boolean _has_start_column;
        /** Whether the response array carries a weight (first entry). */
        private final boolean _has_weights_column;
        /** Number of rows processed. */
        protected long n;
        /** Sum of row weights. */
        protected double sumWeights;
        /** Per categorical index: sum of the weights of rows in which the index is active. */
        protected double[] sumWeightedCatX;
        /** Per numeric column: sum of weight * value. */
        protected double[] sumWeightedNumX;
        /** Per time bin: accumulated weight of rows counted in that bin's risk set. */
        protected double[] sizeRiskSet;
        /** Per time bin: accumulated weight of rows without an event that stop in that bin. */
        protected double[] sizeCensored;
        /** Per time bin: accumulated weight of rows with an event that stop in that bin. */
        protected double[] sizeEvents;
        /** Per time bin: number of rows with an event that stop in that bin. */
        protected long[] countEvents;
        /** Per time bin and coefficient: sum of weight * x over event rows. */
        protected double[][] sumXEvents;
        /** Per time bin: sum of risk over event rows. */
        protected double[] sumRiskEvents;
        /** Per time bin and coefficient: sum of x * risk over event rows. */
        protected double[][] sumXRiskEvents;
        /** Per time bin and coefficient pair: sum of x_j * x_k * risk over event rows. */
        protected double[][][] sumXXRiskEvents;
        /** Per time bin: sum of weight * log-risk over event rows. */
        protected double[] sumLogRiskEvents;
        /** Per time bin: sum of risk over the rows counted in that bin. */
        protected double[] rcumsumRisk;
        /** Per time bin and coefficient: sum of x * risk over the rows counted in that bin. */
        protected double[][] rcumsumXRisk;
        /** Per time bin and coefficient pair: sum of x_j * x_k * risk over the rows counted in that bin. */
        protected double[][][] rcumsumXXRisk;

        /**
         * Stores the configuration of the pass; no accumulator is allocated here.
         *
         * @param jobKey             forwarded to the {@code FrameTask} constructor
         * @param dinfo              forwarded to the {@code FrameTask} constructor
         * @param beta               coefficient vector to evaluate the risk at
         * @param min_time           time value mapped to bin 0
         * @param n_time             number of time bins
         * @param n_offsets          number of trailing numeric columns treated as offsets
         * @param has_start_column   whether rows carry a start time
         * @param has_weights_column whether rows carry a weight
         */
        CoxPHTask(Key<Job> jobKey, DataInfo dinfo, double[] beta, long min_time, int n_time, int n_offsets, boolean has_start_column, boolean has_weights_column) {
            super(jobKey, dinfo);
            this._beta = beta;
            this._n_time = n_time;
            this._min_time = min_time;
            this._n_offsets = n_offsets;
            this._has_start_column = has_start_column;
            this._has_weights_column = has_weights_column;
        }

        /**
         * Allocates every accumulator array, sized by the number of time bins and coefficients.
         *
         * @return always {@code true}
         */
        @Override
        protected boolean chunkInit() {
            final int nCoef = this._beta.length;
            final int nTime = this._n_time;

            // Predictor means.
            this.sumWeightedCatX = MemoryManager.malloc8d((int)(nCoef - (this._dinfo._nums - this._n_offsets)));
            this.sumWeightedNumX = MemoryManager.malloc8d((int)this._dinfo._nums);

            // One scalar per time bin.
            this.sizeRiskSet = MemoryManager.malloc8d(nTime);
            this.sizeCensored = MemoryManager.malloc8d(nTime);
            this.sizeEvents = MemoryManager.malloc8d(nTime);
            this.countEvents = MemoryManager.malloc8(nTime);
            this.sumRiskEvents = MemoryManager.malloc8d(nTime);
            this.sumLogRiskEvents = MemoryManager.malloc8d(nTime);
            this.rcumsumRisk = MemoryManager.malloc8d(nTime);

            // One vector of coefficients per time bin.
            this.sumXEvents = CoxPH.malloc2DArray(nTime, nCoef);
            this.sumXRiskEvents = CoxPH.malloc2DArray(nTime, nCoef);
            this.rcumsumXRisk = CoxPH.malloc2DArray(nTime, nCoef);

            // One coefficient-by-coefficient matrix per time bin.
            this.sumXXRiskEvents = CoxPH.malloc3DArray(nTime, nCoef, nCoef);
            this.rcumsumXXRisk = CoxPH.malloc3DArray(nTime, nCoef, nCoef);
            return true;
        }

        /**
         * Folds one row into the accumulators.
         *
         * <p>The response array is read as {@code [weight?, ..., start?, stop, event]}: the weight
         * is the first entry when a weights column exists, the event indicator is the last entry,
         * the stop time the one before it, and the start time the one before that when a start
         * column exists.
         *
         * @param gid not used by this method
         * @param row the row to accumulate
         * @throws IllegalArgumentException if the weight is not positive, or the shifted start bin
         *                                  lies after the stop bin
         */
        @Override
        protected void processRow(long gid, DataInfo.Row row) {
            ++this.n;
            final double[] response = row.response;
            final int ncats = row.nBins;
            // NOTE(review): the category count comes from row.nBins but the index array is
            // row.numIds -- confirm against DataInfo.Row that this is the intended array.
            final int[] cats = row.numIds;
            final double[] nums = row.numVals;

            final double weight = this._has_weights_column ? response[0] : 1.0;
            if (weight <= 0.0) {
                throw new IllegalArgumentException("weights must be positive values");
            }
            final boolean hasEvent = (long)response[response.length - 1] > 0L;
            final int t1 = this._has_start_column ? (int)((long)response[response.length - 3] + 1L - this._min_time) : -1;
            final int t2 = (int)((long)response[response.length - 2] - this._min_time);
            if (t1 > t2) {
                throw new IllegalArgumentException("start times must be strictly less than stop times");
            }
            // With a start column the row is counted in bins t1..t2, otherwise only in bin t2.
            final int firstBin = this._has_start_column ? t1 : t2;
            final int numStart = this._dinfo.numStart();

            // Weighted predictor totals.
            this.sumWeights += weight;
            for (int c = 0; c < ncats; ++c) {
                this.sumWeightedCatX[cats[c]] += weight;
            }
            for (int v = 0; v < nums.length; ++v) {
                this.sumWeightedNumX[v] += weight * nums[v];
            }

            // Linear predictor: active categorical coefficients, numeric terms, then raw offsets.
            final int firstOffset = nums.length - this._n_offsets;
            double logRisk = 0.0;
            for (int c = 0; c < ncats; ++c) {
                logRisk += this._beta[cats[c]];
            }
            for (int v = 0; v < firstOffset; ++v) {
                logRisk += nums[v] * this._beta[numStart + v];
            }
            for (int v = firstOffset; v < nums.length; ++v) {
                logRisk += nums[v];
            }
            final double risk = weight * Math.exp(logRisk);
            logRisk *= weight;

            // Scalar per-bin accumulators.
            if (hasEvent) {
                this.countEvents[t2] += 1L;
                this.sizeEvents[t2] += weight;
                this.sumLogRiskEvents[t2] += logRisk;
                this.sumRiskEvents[t2] += risk;
            } else {
                this.sizeCensored[t2] += weight;
            }
            for (int t = firstBin; t <= t2; ++t) {
                this.sizeRiskSet[t] += weight;
                this.rcumsumRisk[t] += risk;
            }

            // First- and second-order accumulators over all non-offset predictors.
            final int nPredictors = ncats + firstOffset;
            final int numShift = numStart - ncats;
            for (int a = 0; a < nPredictors; ++a) {
                final boolean aIsCat = a < ncats;
                final int colA = aIsCat ? cats[a] : numShift + a;
                final double xA = aIsCat ? 1.0 : nums[a - ncats];
                final double xRisk = xA * risk;
                if (hasEvent) {
                    this.sumXEvents[t2][colA] += weight * xA;
                    this.sumXRiskEvents[t2][colA] += xRisk;
                }
                for (int t = firstBin; t <= t2; ++t) {
                    this.rcumsumXRisk[t][colA] += xRisk;
                }
                for (int b = 0; b < nPredictors; ++b) {
                    final boolean bIsCat = b < ncats;
                    final int colB = bIsCat ? cats[b] : numShift + b;
                    final double xB = bIsCat ? 1.0 : nums[b - ncats];
                    final double xxRisk = xB * xRisk;
                    if (hasEvent) {
                        this.sumXXRiskEvents[t2][colA][colB] += xxRisk;
                    }
                    for (int t = firstBin; t <= t2; ++t) {
                        this.rcumsumXXRisk[t][colA][colB] += xxRisk;
                    }
                }
            }
        }

        /**
         * Adds every accumulator of {@code that} into this task, element by element.
         *
         * @param that the task whose partial sums are merged into this one
         */
        public void reduce(CoxPHTask that) {
            this.n += that.n;
            this.sumWeights += that.sumWeights;
            ArrayUtils.add(this.sumWeightedCatX, that.sumWeightedCatX);
            ArrayUtils.add(this.sumWeightedNumX, that.sumWeightedNumX);
            ArrayUtils.add(this.sizeRiskSet, that.sizeRiskSet);
            ArrayUtils.add(this.sizeCensored, that.sizeCensored);
            ArrayUtils.add(this.sizeEvents, that.sizeEvents);
            ArrayUtils.add(this.countEvents, that.countEvents);
            ArrayUtils.add(this.sumXEvents, that.sumXEvents);
            ArrayUtils.add(this.sumRiskEvents, that.sumRiskEvents);
            ArrayUtils.add(this.sumXRiskEvents, that.sumXRiskEvents);
            ArrayUtils.add(this.sumXXRiskEvents, that.sumXXRiskEvents);
            ArrayUtils.add(this.sumLogRiskEvents, that.sumLogRiskEvents);
            ArrayUtils.add(this.rcumsumRisk, that.rcumsumRisk);
            ArrayUtils.add(this.rcumsumXRisk, that.rcumsumXRisk);
            ArrayUtils.add(this.rcumsumXXRisk, that.rcumsumXXRisk);
        }

        /**
         * Without a start column, converts the three {@code rcumsum*} accumulators into reverse
         * cumulative sums over time (each bin absorbs the total of all later bins). With a start
         * column nothing is done, because rows were already added to every bin they span.
         */
        protected void postGlobal() {
            if (this._has_start_column) {
                return;
            }
            for (int t = this.rcumsumRisk.length - 2; t >= 0; --t) {
                this.rcumsumRisk[t] += this.rcumsumRisk[t + 1];
            }
            for (int t = this.rcumsumXRisk.length - 2; t >= 0; --t) {
                for (int j = 0; j < this.rcumsumXRisk[t].length; ++j) {
                    this.rcumsumXRisk[t][j] += this.rcumsumXRisk[t + 1][j];
                }
            }
            for (int t = this.rcumsumXXRisk.length - 2; t >= 0; --t) {
                for (int j = 0; j < this.rcumsumXXRisk[t].length; ++j) {
                    for (int k = 0; k < this.rcumsumXXRisk[t][j].length; ++k) {
                        this.rcumsumXXRisk[t][j][k] += this.rcumsumXXRisk[t + 1][j][k];
                    }
                }
            }
        }
    }

    public class CoxPHDriver
    extends ModelBuilder.Driver {
        /**
         * Creates a driver bound to the enclosing {@code CoxPH} builder, which is passed to the
         * {@code ModelBuilder.Driver} constructor.
         */
        public CoxPHDriver() {
            super((ModelBuilder)CoxPH.this);
        }

        /**
         * Builds a new frame over the training vectors in which the special columns come last.
         *
         * <p>All other columns keep their relative order; they are followed by the weights, start,
         * stop and response columns, in that order, for whichever of those were found by name in
         * the training frame.
         *
         * @return the reordered frame
         */
        private Frame reorderTrainFrameColumns() {
            final CoxPHModel.CoxPHParameters parms = (CoxPHModel.CoxPHParameters)CoxPH.this._parms;
            final Frame reordered = new Frame(new Vec[0]);
            final Vec[] vecs = CoxPH.this.train().vecs();
            final String[] names = CoxPH.this.train().names();

            Vec weights = null;
            Vec start = null;
            Vec stop = null;
            Vec event = null;
            for (int col = 0; col < names.length; ++col) {
                final String name = names[col];
                // The first matching role wins; a column is never assigned to two roles.
                if (name.equals(parms._weights_column)) {
                    weights = vecs[col];
                } else if (name.equals(parms._start_column)) {
                    start = vecs[col];
                } else if (name.equals(parms._stop_column)) {
                    stop = vecs[col];
                } else if (name.equals(parms._response_column)) {
                    event = vecs[col];
                } else {
                    reordered.add(name, vecs[col]);
                }
            }

            if (weights != null) {
                reordered.add(parms._weights_column, weights);
            }
            if (start != null) {
                reordered.add(parms._start_column, start);
            }
            if (stop != null) {
                reordered.add(parms._stop_column, stop);
            }
            if (event != null) {
                reordered.add(parms._response_column, event);
            }
            return reordered;
        }

        /**
         * Initialises the model output: stores the data info, splits the coefficient names into
         * coefficient and offset names, records the time range, and allocates every output array.
         *
         * <p>Per-coefficient arrays have one entry per non-offset coefficient; per-time arrays have
         * one entry per element of the integer domain collected from the stop vector.
         *
         * @param model the model whose output is initialised
         * @param dinfo the data info to store in the output
         */
        protected void initStats(CoxPHModel model, DataInfo dinfo) {
            final CoxPHModel.CoxPHParameters parms = (CoxPHModel.CoxPHParameters)model._parms;
            final CoxPHModel.CoxPHOutput out = (CoxPHModel.CoxPHOutput)model._output;
            out._n = parms.stopVec().length();
            out.data_info = dinfo;

            final int nOffsets;
            if (CoxPH.this._offset == null) {
                nOffsets = 0;
            } else {
                nOffsets = 1;
            }
            final int nCoef = out.data_info.fullN() - nOffsets;

            // The leading names belong to coefficients; the trailing ones to offsets.
            final String[] allNames = out.data_info.coefNames();
            out._coef_names = new String[nCoef];
            System.arraycopy(allNames, 0, out._coef_names, 0, nCoef);

            out._coef = MemoryManager.malloc8d(nCoef);
            out._exp_coef = MemoryManager.malloc8d(nCoef);
            out._exp_neg_coef = MemoryManager.malloc8d(nCoef);
            out._se_coef = MemoryManager.malloc8d(nCoef);
            out._z_coef = MemoryManager.malloc8d(nCoef);
            out.gradient = MemoryManager.malloc8d(nCoef);
            out.hessian = CoxPH.malloc2DArray(nCoef, nCoef);
            out._var_coef = CoxPH.malloc2DArray(nCoef, nCoef);
            out._x_mean_cat = MemoryManager.malloc8d((int)(nCoef - (out.data_info._nums - nOffsets)));
            out._x_mean_num = MemoryManager.malloc8d((int)(out.data_info._nums - nOffsets));
            out._mean_offset = MemoryManager.malloc8d(nOffsets);
            out._offset_names = new String[nOffsets];
            System.arraycopy(allNames, nCoef, out._offset_names, 0, nOffsets);

            // The first time is the smallest stop time, or one past the smallest start time.
            if (parms.startVec() == null) {
                out._min_time = (long)parms.stopVec().min();
            } else {
                out._min_time = (long)parms.startVec().min() + 1L;
            }
            out._max_time = (long)parms.stopVec().max();

            final int nTime = ((VecUtils.CollectIntegerDomain)new VecUtils.CollectIntegerDomain().doAll(new Vec[]{parms.stopVec()})).domain().length;
            out._time = MemoryManager.malloc8(nTime);
            out._n_risk = MemoryManager.malloc8d(nTime);
            out._n_event = MemoryManager.malloc8d(nTime);
            out._n_censor = MemoryManager.malloc8d(nTime);
            out._cumhaz_0 = MemoryManager.malloc8d(nTime);
            out._var_cumhaz_1 = MemoryManager.malloc8d(nTime);
            out._var_cumhaz_2 = CoxPH.malloc2DArray(nTime, nCoef);
        }

        /**
         * Copies row counts, predictor means and per-time counts from a finished task into the
         * model output.
         *
         * <p>Only time bins that carry event or censored weight are emitted; they are packed at the
         * front of the output arrays. Without a start column, {@code _n_risk} is then replaced by
         * its reverse cumulative sum over the whole array.
         *
         * @param model the model whose output is updated
         * @param coxMR the task holding the accumulated sums
         */
        protected void calcCounts(CoxPHModel model, CoxPHTask coxMR) {
            final CoxPHModel.CoxPHParameters parms = (CoxPHModel.CoxPHParameters)model._parms;
            final CoxPHModel.CoxPHOutput out = (CoxPHModel.CoxPHOutput)model._output;

            // Rows the task did not see are reported as missing.
            out._n_missing = out._n - coxMR.n;
            out._n = coxMR.n;

            for (int c = 0; c < out._x_mean_cat.length; ++c) {
                out._x_mean_cat[c] = coxMR.sumWeightedCatX[c] / coxMR.sumWeights;
            }
            for (int v = 0; v < out._x_mean_num.length; ++v) {
                // The subtracted normalisation constant is added back to the weighted mean.
                out._x_mean_num[v] = out.data_info._normSub[v] + coxMR.sumWeightedNumX[v] / coxMR.sumWeights;
            }
            System.arraycopy(out.data_info._normSub, out._x_mean_num.length, out._mean_offset, 0, out._mean_offset.length);

            int slot = 0;
            for (int t = 0; t < coxMR.countEvents.length; ++t) {
                out._total_event += coxMR.countEvents[t];
                final boolean observed = coxMR.sizeEvents[t] > 0.0 || coxMR.sizeCensored[t] > 0.0;
                if (!observed) {
                    continue;
                }
                out._time[slot] = out._min_time + (long)t;
                out._n_risk[slot] = coxMR.sizeRiskSet[t];
                out._n_event[slot] = coxMR.sizeEvents[t];
                out._n_censor[slot] = coxMR.sizeCensored[t];
                ++slot;
            }

            if (parms._start_column == null) {
                for (int t = out._n_risk.length - 2; t >= 0; --t) {
                    out._n_risk[t] += out._n_risk[t + 1];
                }
            }
        }

        /**
         * Evaluates the partial log-likelihood from the task's accumulated sums and overwrites
         * {@code gradient} and {@code hessian} in the model output with its first and second
         * derivatives.
         *
         * <p>For {@code efron} ties each time bin is processed by its own fork/join action writing
         * into per-bin buffers, which are then summed in ascending time order. For {@code breslow}
         * ties the bins are processed sequentially from the last to the first. Bins without event
         * weight contribute nothing.
         *
         * @param model the model whose output receives the gradient and Hessian
         * @param coxMR the task holding the accumulated sums
         * @return the partial log-likelihood
         * @throws IllegalArgumentException if the ties method is neither efron nor breslow
         */
        protected double calcLoglik(CoxPHModel model, final CoxPHTask coxMR) {
            final CoxPHModel.CoxPHParameters parms = (CoxPHModel.CoxPHParameters)model._parms;
            final CoxPHModel.CoxPHOutput out = (CoxPHModel.CoxPHOutput)model._output;
            final int n_coef = out._coef.length;
            final int nTime = coxMR.sizeEvents.length;
            double loglik = 0.0;

            // Reset the derivative accumulators.
            for (int a = 0; a < n_coef; ++a) {
                out.gradient[a] = 0.0;
            }
            for (int a = 0; a < n_coef; ++a) {
                for (int b = 0; b < n_coef; ++b) {
                    out.hessian[a][b] = 0.0;
                }
            }

            switch (parms._ties) {
                case efron: {
                    final double[] loglikByTime = MemoryManager.malloc8d(nTime);
                    final double[][] gradientByTime = CoxPH.malloc2DArray(nTime, n_coef);
                    final double[][][] hessianByTime = CoxPH.malloc3DArray(nTime, n_coef, n_coef);
                    final ForkJoinTask[] actions = new ForkJoinTask[nTime];
                    for (int t = nTime - 1; t >= 0; --t) {
                        final int bin = t;
                        actions[t] = new RecursiveAction(){

                            protected void compute() {
                                final double eventWeight = coxMR.sizeEvents[bin];
                                if (!(eventWeight > 0.0)) {
                                    return;
                                }
                                final long nEvents = coxMR.countEvents[bin];
                                final double riskOfEvents = coxMR.sumRiskEvents[bin];
                                final double riskSet = coxMR.rcumsumRisk[bin];
                                final double avgWeight = eventWeight / (double)nEvents;
                                loglikByTime[bin] = coxMR.sumLogRiskEvents[bin];
                                System.arraycopy(coxMR.sumXEvents[bin], 0, gradientByTime[bin], 0, n_coef);
                                // Each tied event e removes the fraction e / nEvents of the
                                // events' own risk from the risk-set total.
                                for (long e = 0L; e < nEvents; ++e) {
                                    final double frac = (double)e / (double)nEvents;
                                    final double term = riskSet - frac * riskOfEvents;
                                    loglikByTime[bin] -= avgWeight * Math.log(term);
                                    for (int j = 0; j < n_coef; ++j) {
                                        final double djTerm = coxMR.rcumsumXRisk[bin][j] - frac * coxMR.sumXRiskEvents[bin][j];
                                        final double djLogTerm = djTerm / term;
                                        gradientByTime[bin][j] -= avgWeight * djLogTerm;
                                        for (int k = 0; k < n_coef; ++k) {
                                            final double dkTerm = coxMR.rcumsumXRisk[bin][k] - frac * coxMR.sumXRiskEvents[bin][k];
                                            final double djkTerm = coxMR.rcumsumXXRisk[bin][j][k] - frac * coxMR.sumXXRiskEvents[bin][j][k];
                                            hessianByTime[bin][j][k] -= avgWeight * (djkTerm / term - djLogTerm * (dkTerm / term));
                                        }
                                    }
                                }
                            }
                        };
                    }
                    ForkJoinTask.invokeAll((ForkJoinTask[])actions);

                    // Sum the per-bin buffers in ascending time order.
                    for (int t = 0; t < nTime; ++t) {
                        loglik += loglikByTime[t];
                        for (int j = 0; j < n_coef; ++j) {
                            out.gradient[j] += gradientByTime[t][j];
                            for (int k = 0; k < n_coef; ++k) {
                                out.hessian[j][k] += hessianByTime[t][j][k];
                            }
                        }
                    }
                    break;
                }
                case breslow: {
                    for (int t = nTime - 1; t >= 0; --t) {
                        final double eventWeight = coxMR.sizeEvents[t];
                        if (!(eventWeight > 0.0)) {
                            continue;
                        }
                        final double riskSet = coxMR.rcumsumRisk[t];
                        loglik += coxMR.sumLogRiskEvents[t];
                        loglik -= eventWeight * Math.log(riskSet);
                        for (int j = 0; j < n_coef; ++j) {
                            final double dlogTerm = coxMR.rcumsumXRisk[t][j] / riskSet;
                            // Kept as two separate updates so the rounding sequence is unchanged.
                            out.gradient[j] += coxMR.sumXEvents[t][j];
                            out.gradient[j] -= eventWeight * dlogTerm;
                            for (int k = 0; k < n_coef; ++k) {
                                out.hessian[j][k] -= eventWeight * (coxMR.rcumsumXXRisk[t][j][k] / riskSet - dlogTerm * (coxMR.rcumsumXRisk[t][k] / riskSet));
                            }
                        }
                    }
                    break;
                }
                default: {
                    throw new IllegalArgumentException("_ties method must be either efron or breslow");
                }
            }
            return loglik;
        }

        /**
         * Derives the coefficient statistics and test statistics from the current Hessian,
         * gradient and log-likelihood and stores them in the model output.
         *
         * <p>The coefficient covariance is the negated inverse of the Hessian. On iteration 0 the
         * null log-likelihood, {@code _maxrsq} and the score test are also recorded.
         *
         * @param model     the model whose output is updated
         * @param newCoef   coefficient values to store
         * @param newLoglik log-likelihood at {@code newCoef}
         */
        protected void calcModelStats(CoxPHModel model, double[] newCoef, double newLoglik) {
            final CoxPHModel.CoxPHParameters parms = (CoxPHModel.CoxPHParameters)model._parms;
            final CoxPHModel.CoxPHOutput out = (CoxPHModel.CoxPHOutput)model._output;
            final int nCoef = out._coef.length;

            // Covariance = -(Hessian)^-1; the lower triangle is mirrored into the upper one.
            final Matrix invHessian = new Matrix(out.hessian).inverse();
            for (int row = 0; row < nCoef; ++row) {
                for (int col = 0; col <= row; ++col) {
                    final double cov = -invHessian.get(row, col);
                    out._var_coef[row][col] = cov;
                    out._var_coef[col][row] = cov;
                }
            }

            for (int j = 0; j < nCoef; ++j) {
                final double coef = newCoef[j];
                out._coef[j] = coef;
                out._exp_coef[j] = Math.exp(coef);
                out._exp_neg_coef[j] = Math.exp(-coef);
                out._se_coef[j] = Math.sqrt(out._var_coef[j][j]);
                out._z_coef[j] = coef / out._se_coef[j];
            }

            if (out._iter == 0) {
                out._null_loglik = newLoglik;
                out._maxrsq = 1.0 - Math.exp(2.0 * out._null_loglik / (double)out._n);
                // Score test: gradient' * covariance * gradient.
                out._score_test = 0.0;
                for (int j = 0; j < nCoef; ++j) {
                    double dot = 0.0;
                    for (int k = 0; k < nCoef; ++k) {
                        dot += out._var_coef[j][k] * out.gradient[k];
                    }
                    out._score_test += out.gradient[j] * dot;
                }
            }

            out._loglik = newLoglik;
            out._loglik_test = -2.0 * (out._null_loglik - out._loglik);
            out._rsq = 1.0 - Math.exp(-out._loglik_test / (double)out._n);

            // Wald test: (coef - init)' * (-Hessian) * (coef - init).
            out._wald_test = 0.0;
            for (int j = 0; j < nCoef; ++j) {
                double dot = 0.0;
                for (int k = 0; k < nCoef; ++k) {
                    dot -= out.hessian[j][k] * (out._coef[k] - parms._init);
                }
                out._wald_test += (out._coef[j] - parms._init) * dot;
            }
        }

        /**
         * Fills {@code _cumhaz_0}, {@code _var_cumhaz_1} and {@code _var_cumhaz_2} in the model
         * output from the task's accumulated sums.
         *
         * <p>Per-bin increments are computed for every time bin that carries event or censored
         * weight, packed at the front of the output arrays, and then turned into running totals
         * over the whole arrays.
         *
         * @param model the model whose output is updated
         * @param coxMR the task holding the accumulated sums
         * @throws IllegalArgumentException if the ties method is neither efron nor breslow
         */
        protected void calcCumhaz_0(CoxPHModel model, CoxPHTask coxMR) {
            final CoxPHModel.CoxPHParameters parms = (CoxPHModel.CoxPHParameters)model._parms;
            final CoxPHModel.CoxPHOutput out = (CoxPHModel.CoxPHOutput)model._output;
            final int nCoef = out._coef.length;
            final int nBins = coxMR.sizeEvents.length;
            int slot = 0;

            switch (parms._ties) {
                case efron: {
                    for (int t = 0; t < nBins; ++t) {
                        final double eventWeight = coxMR.sizeEvents[t];
                        final boolean observed = eventWeight > 0.0 || coxMR.sizeCensored[t] > 0.0;
                        if (!observed) {
                            continue;
                        }
                        final long nEvents = coxMR.countEvents[t];
                        final double riskOfEvents = coxMR.sumRiskEvents[t];
                        final double riskSet = coxMR.rcumsumRisk[t];
                        final double avgWeight = eventWeight / (double)nEvents;
                        out._cumhaz_0[slot] = 0.0;
                        out._var_cumhaz_1[slot] = 0.0;
                        for (int j = 0; j < nCoef; ++j) {
                            out._var_cumhaz_2[slot][j] = 0.0;
                        }
                        // Each tied event e removes the fraction e / nEvents of the events' own
                        // risk from the risk-set total.
                        for (long e = 0L; e < nEvents; ++e) {
                            final double frac = (double)e / (double)nEvents;
                            final double haz = 1.0 / (riskSet - frac * riskOfEvents);
                            final double hazSq = haz * haz;
                            out._cumhaz_0[slot] += avgWeight * haz;
                            out._var_cumhaz_1[slot] += avgWeight * hazSq;
                            for (int j = 0; j < nCoef; ++j) {
                                out._var_cumhaz_2[slot][j] += avgWeight * ((coxMR.rcumsumXRisk[t][j] - frac * coxMR.sumXRiskEvents[t][j]) * hazSq);
                            }
                        }
                        ++slot;
                    }
                    break;
                }
                case breslow: {
                    for (int t = 0; t < nBins; ++t) {
                        final double eventWeight = coxMR.sizeEvents[t];
                        final boolean observed = eventWeight > 0.0 || coxMR.sizeCensored[t] > 0.0;
                        if (!observed) {
                            continue;
                        }
                        final double riskSet = coxMR.rcumsumRisk[t];
                        final double hazard = eventWeight / riskSet;
                        out._cumhaz_0[slot] = hazard;
                        out._var_cumhaz_1[slot] = eventWeight / (riskSet * riskSet);
                        for (int j = 0; j < nCoef; ++j) {
                            out._var_cumhaz_2[slot][j] = coxMR.rcumsumXRisk[t][j] / riskSet * hazard;
                        }
                        ++slot;
                    }
                    break;
                }
                default: {
                    throw new IllegalArgumentException("_ties method must be either efron or breslow");
                }
            }

            // Convert the per-bin increments into running totals.
            for (int t = 1; t < out._cumhaz_0.length; ++t) {
                out._cumhaz_0[t] = out._cumhaz_0[t - 1] + out._cumhaz_0[t];
                out._var_cumhaz_1[t] = out._var_cumhaz_1[t - 1] + out._var_cumhaz_1[t];
                for (int j = 0; j < nCoef; ++j) {
                    out._var_cumhaz_2[t][j] = out._var_cumhaz_2[t - 1][j] + out._var_cumhaz_2[t][j];
                }
            }
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        /**
         * Builds the CoxPH model: prepares the training data, iteratively refines the
         * coefficient vector and finally publishes the model.
         *
         * <p>Each iteration runs a {@code CoxPHTask} over the adapted frame with the current
         * candidate coefficients and evaluates the log-likelihood via {@code calcLoglik}:
         * <ul>
         *   <li>If the log-likelihood improved, the model statistics and baseline cumulative
         *       hazard are recomputed, the convergence measure {@code _lre} is updated, and a
         *       new step is derived as {@code -(_var_coef * gradient)} (presumably a
         *       Newton-Raphson step - confirm against {@code calcModelStats}).</li>
         *   <li>Otherwise the previous step is halved and retried from the last accepted
         *       coefficients.</li>
         * </ul>
         * Iteration stops when {@code _lre >= _lre_min}, when the freshly computed step
         * contains a NaN or infinite component, or after {@code _iter_max + 1} passes.
         *
         * <p>The model is write-locked for the duration of the build and is always unlocked
         * in the {@code finally} block, even when the build fails.
         */
        public void computeImpl() {
            CoxPHModel model = null;
            try {
                CoxPH.this.init(true);
                Frame f = this.reorderTrainFrameColumns();
                model = new CoxPHModel(CoxPH.this._job._result, (CoxPHModel.CoxPHParameters)CoxPH.this._parms, new CoxPHModel.CoxPHOutput(CoxPH.this));
                model.delete_and_lock(CoxPH.this._job);

                // Two response columns, or three when a start column is present.
                int nResponses = ((CoxPHModel.CoxPHParameters)CoxPH.this._parms).startVec() == null ? 2 : 3;
                DataInfo dinfo = new DataInfo(f, null, nResponses, false, DataInfo.TransformType.DEMEAN, DataInfo.TransformType.NONE, true, false, false, false, false, false);
                Scope.track_generic((Keyed)dinfo);
                DKV.put((Keyed)dinfo);
                this.initStats(model, dinfo);

                CoxPHModel.CoxPHParameters modelParms = (CoxPHModel.CoxPHParameters)model._parms;

                // The offset column (if any) is part of dinfo but carries no coefficient.
                int n_offsets = CoxPH.this._offset == null ? 0 : 1;
                int n_coef = dinfo.fullN() - n_offsets;
                double[] step = MemoryManager.malloc8d(n_coef);
                double[] oldCoef = MemoryManager.malloc8d(n_coef);
                double[] newCoef = MemoryManager.malloc8d(n_coef);
                // step/oldCoef are NaN until the first accepted iteration fills them in.
                Arrays.fill(step, Double.NaN);
                Arrays.fill(oldCoef, Double.NaN);
                Arrays.fill(newCoef, modelParms._init);

                // Smallest finite value, so that any finite first log-likelihood counts as an improvement.
                double oldLoglik = -Double.MAX_VALUE;
                int n_time = (int)(((CoxPHModel.CoxPHOutput)model._output)._max_time - ((CoxPHModel.CoxPHOutput)model._output)._min_time + 1L);
                boolean has_start_column = modelParms.startVec() != null;
                boolean has_weights_column = CoxPH.this._weights != null;

                for (int i = 0; i <= modelParms._iter_max; ++i) {
                    ((CoxPHModel.CoxPHOutput)model._output)._iter = i;
                    CoxPHTask coxMR = (CoxPHTask)new CoxPHTask((Key<Job>)CoxPH.this._job._key, dinfo, newCoef, ((CoxPHModel.CoxPHOutput)model._output)._min_time, n_time, n_offsets, has_start_column, has_weights_column).doAll(dinfo._adaptedFrame);
                    double newLoglik = this.calcLoglik(model, coxMR);

                    if (newLoglik > oldLoglik) {
                        // Accepted iteration: refresh the model statistics for the new coefficients.
                        if (i == 0) {
                            this.calcCounts(model, coxMR);
                        }
                        this.calcModelStats(model, newCoef, newLoglik);
                        this.calcCumhaz_0(model, coxMR);

                        CoxPHModel.CoxPHOutput output = (CoxPHModel.CoxPHOutput)model._output;
                        // Relative change of the log-likelihood on a -log10 scale
                        // (absolute change when the new log-likelihood is exactly 0).
                        double loglikChange = oldLoglik - newLoglik;
                        output._lre = newLoglik == 0.0
                                ? -Math.log10(Math.abs(loglikChange))
                                : -Math.log10(Math.abs(loglikChange / newLoglik));
                        if (output._lre >= modelParms._lre_min) {
                            break;
                        }

                        // step = -(_var_coef * gradient)
                        Arrays.fill(step, 0.0);
                        for (int j = 0; j < n_coef; ++j) {
                            for (int k = 0; k < n_coef; ++k) {
                                step[j] -= output._var_coef[j][k] * output.gradient[k];
                            }
                        }

                        // A NaN/infinite step would make every following coefficient vector
                        // non-finite, and halving cannot repair it. Stop here and keep the
                        // statistics of this (last accepted) iteration instead of spending
                        // the remaining iterations on passes that cannot improve the fit.
                        // Previously this check was a loop with an empty body and had no effect.
                        boolean stepIsFinite = true;
                        for (int j = 0; j < n_coef; ++j) {
                            if (Double.isNaN(step[j]) || Double.isInfinite(step[j])) {
                                stepIsFinite = false;
                                break;
                            }
                        }
                        if (!stepIsFinite) {
                            break;
                        }

                        oldLoglik = newLoglik;
                        System.arraycopy(newCoef, 0, oldCoef, 0, oldCoef.length);
                    } else {
                        // No improvement: retry from the last accepted coefficients with half the step.
                        for (int j = 0; j < n_coef; ++j) {
                            step[j] /= 2.0;
                        }
                    }

                    // Next candidate = last accepted coefficients minus the current step.
                    for (int j = 0; j < n_coef; ++j) {
                        newCoef[j] = oldCoef[j] - step[j];
                    }
                }
                model.update(CoxPH.this._job);
            }
            finally {
                // Always release the write lock taken by delete_and_lock, even on failure.
                if (model != null) {
                    model.unlock(CoxPH.this._job);
                }
            }
        }
    }
}

