/*
 * Decompiled with CFR 0.152.
 */
package hex.coxph;

import Jama.Matrix;
import hex.DataInfo;
import hex.FrameTask;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.SupervisedModel;
import hex.SupervisedModelBuilder;
import hex.coxph.CoxPHModel;
import hex.schemas.ModelBuilderSchema;
import java.util.Arrays;
import jsr166y.ForkJoinTask;
import jsr166y.RecursiveAction;
import water.DKV;
import water.H2O;
import water.Job;
import water.Key;
import water.MemoryManager;
import water.Scope;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;

public class CoxPH
extends SupervisedModelBuilder<CoxPHModel, CoxPHModel.CoxPHParameters, CoxPHModel.CoxPHOutput> {
    /**
     * Lists the model categories this builder can produce.
     *
     * @return a fresh array containing only {@link ModelCategory#Unknown}
     */
    public ModelCategory[] can_build() {
        final ModelCategory[] categories = {ModelCategory.Unknown};
        return categories;
    }

    /**
     * Reports how this builder is exposed to clients.
     *
     * @return always {@code BuilderVisibility.Experimental}
     */
    public ModelBuilder.BuilderVisibility builderVisibility() {
        final ModelBuilder.BuilderVisibility visibility = ModelBuilder.BuilderVisibility.Experimental;
        return visibility;
    }

    /**
     * Creates a Cox proportional-hazards builder named "CoxPHLearning" and immediately runs
     * parameter validation with {@code expensive == false}.
     *
     * @param parms the CoxPH model parameters
     */
    public CoxPH(CoxPHModel.CoxPHParameters parms) {
        super("CoxPHLearning", (SupervisedModel.SupervisedParameters) parms);
        // Only the non-expensive validation pass is performed at construction time.
        init(false);
    }

    /**
     * REST schema for this builder. Not implemented for CoxPH.
     *
     * @throws RuntimeException always; the exception produced by {@code H2O.unimpl()}
     */
    public ModelBuilderSchema schema() {
        // The previous code called H2O.unimpl() as a bare statement and then returned null, so the
        // "unimplemented" exception it produces was discarded and callers silently got a null schema.
        // NOTE(review): this relies on H2O.unimpl() returning the exception (the usual
        // `throw H2O.unimpl()` idiom); if it already throws internally, behavior is unchanged.
        throw H2O.unimpl();
    }

    /**
     * Starts training: wires a {@link CoxPHDriver} to the builder's training frame and launches it,
     * using {@code iter_max} as the job's work amount.
     *
     * @return this builder, as the running job
     */
    public Job<CoxPHModel> trainModel() {
        CoxPHDriver driver = new CoxPHDriver();
        driver.setModelBuilderTrain(_train);
        return (CoxPH) start(driver, ((CoxPHModel.CoxPHParameters) _parms).iter_max);
    }

    /**
     * Validates the CoxPH parameters and records a field error via {@code error(field, msg)} for
     * each problem found.
     *
     * <p>Checks performed:
     * <ul>
     *   <li>start column (optional) must be integer;</li>
     *   <li>stop column must be present and integer;</li>
     *   <li>event column must be present and integer or categorical;</li>
     *   <li>{@code lre_min} must be a positive number;</li>
     *   <li>{@code iter_max} must be at least 1;</li>
     *   <li>the integer time range covered by the data must be non-empty and no larger than
     *       {@code MAX_TIME_BINS}.</li>
     * </ul>
     *
     * @param expensive forwarded to the parent validation
     */
    public void init(boolean expensive) {
        super.init(expensive);
        final CoxPHModel.CoxPHParameters parms = (CoxPHModel.CoxPHParameters) this._parms;
        if (parms.start_column != null && !parms.start_column.isInt()) {
            this.error("start_column", "start time must be null or of type integer");
        }
        // Missing stop/event columns used to surface as a NullPointerException here instead of a
        // field error, so report them explicitly.
        if (parms.stop_column == null) {
            this.error("stop_column", "stop time column must be specified");
        } else if (!parms.stop_column.isInt()) {
            this.error("stop_column", "stop time must be of type integer");
        }
        if (parms.event_column == null) {
            this.error("event_column", "event column must be specified");
        } else if (!parms.event_column.isInt() && !parms.event_column.isEnum()) {
            this.error("event_column", "event must be of type integer or factor");
        }
        if (Double.isNaN(parms.lre_min) || parms.lre_min <= 0.0) {
            this.error("lre_min", "lre_min must be a positive number");
        }
        if (parms.iter_max < 1) {
            this.error("iter_max", "iter_max must be a positive integer");
        }
        if (parms.stop_column == null) {
            // The time-range checks below need the stop column; its absence is already reported.
            return;
        }
        // Upper bound on the number of integer time bins the per-bin statistics are allocated for.
        final int MAX_TIME_BINS = 10000;
        // With a start column, the first bin is one past the earliest start time (rows are at risk
        // strictly after their start).
        final long min_time = parms.start_column == null
                ? (long) parms.stop_column.min()
                : (long) parms.start_column.min() + 1L;
        final int n_time = (int) (parms.stop_column.max() - (double) min_time + 1.0);
        if (n_time < 1) {
            this.error("start_column", "start times must be strictly less than stop times");
        }
        if (n_time > MAX_TIME_BINS) {
            this.error("stop_column", "number of distinct stop times is " + n_time + "; maximum number allowed is " + MAX_TIME_BINS);
        }
    }

    /**
     * Allocates a {@code d1 x d2} array of doubles, each row obtained from
     * {@code MemoryManager.malloc8d}.
     *
     * @param d1 number of rows
     * @param d2 length of each row
     * @return the newly allocated rows
     */
    private static double[][] malloc2DArray(int d1, int d2) {
        final double[][] rows = new double[d1][];
        for (int r = 0; r < rows.length; r++) {
            rows[r] = MemoryManager.malloc8d(d2);
        }
        return rows;
    }

    /**
     * Allocates a {@code d1 x d2 x d3} array of doubles, each innermost vector obtained from
     * {@code MemoryManager.malloc8d}.
     *
     * @param d1 outer dimension
     * @param d2 middle dimension
     * @param d3 length of each innermost vector
     * @return the newly allocated cube
     */
    private static double[][][] malloc3DArray(int d1, int d2, int d3) {
        final double[][][] cube = new double[d1][d2][];
        for (double[][] plane : cube) {
            for (int k = 0; k < plane.length; k++) {
                plane[k] = MemoryManager.malloc8d(d3);
            }
        }
        return cube;
    }

    /**
     * One pass over the training rows that accumulates, per integer time bin, the sufficient
     * statistics used to evaluate the Cox partial log-likelihood, its gradient and its Hessian at
     * the coefficient vector {@code beta}.
     *
     * <p>Layout of {@code row.response} as read by {@link #processRow}: {@code response[0]} is the
     * case weight when a weights column is present, {@code response[len-3]} the start time when a
     * start column is present, {@code response[len-2]} the stop time, and {@code response[len-1]}
     * the event indicator ({@code > 0} means an event, otherwise censored).
     *
     * <p>A stop time {@code s} maps to bin {@code s - min_time}. A start time {@code a} maps to the
     * first bin {@code a + 1 - min_time}.
     *
     * <p>Per-bin arrays are indexed {@code [time]}, {@code [time][coef]} or
     * {@code [time][coef][coef]}.
     */
    protected static class CoxPHTask
    extends FrameTask<CoxPHTask> {
        private final double[] _beta;
        private final int _n_time;
        private final long _min_time;
        private final int _n_offsets;
        private final boolean _has_start_column;
        private final boolean _has_weights_column;
        /** Number of rows handed to {@link #processRow}. */
        protected long n;
        /** NOTE(review): never written by this class. */
        protected long n_missing;
        /** Sum of row weights. */
        protected double sumWeights;
        /** Per categorical coefficient: sum of weights of rows where that level is active. */
        protected double[] sumWeightedCatX;
        /** Per numeric column (offsets included): sum of weight * value as delivered by DataInfo. */
        protected double[] sumWeightedNumX;
        /**
         * Per bin: weight at risk in the bin (start column present), or weight of rows whose stop
         * falls in the bin (no start column).
         */
        protected double[] sizeRiskSet;
        /** Per stop bin: weight of censored rows. */
        protected double[] sizeCensored;
        /** Per stop bin: weight of event rows. */
        protected double[] sizeEvents;
        /** Per stop bin: number of event rows. */
        protected long[] countEvents;
        /** Per stop bin and coefficient: sum of weight * x over event rows. */
        protected double[][] sumXEvents;
        /** Per stop bin: sum of risk over event rows. */
        protected double[] sumRiskEvents;
        /** Per stop bin and coefficient: sum of x * risk over event rows. */
        protected double[][] sumXRiskEvents;
        /** Per stop bin and coefficient pair: sum of x_j * x_k * risk over event rows. */
        protected double[][][] sumXXRiskEvents;
        /** Per stop bin: sum of weight * linear predictor over event rows. */
        protected double[] sumLogRiskEvents;
        /** Per bin: sum of risk over the risk set (reverse-cumulated in postGlobal when no start column). */
        protected double[] rcumsumRisk;
        /** Per bin and coefficient: sum of x * risk over the risk set (see rcumsumRisk). */
        protected double[][] rcumsumXRisk;
        /** Per bin and coefficient pair: sum of x_j * x_k * risk over the risk set (see rcumsumRisk). */
        protected double[][][] rcumsumXXRisk;

        /**
         * @param jobKey             key of the owning job
         * @param dinfo              row adapter for the training frame
         * @param beta               current coefficients; its length fixes the number of coefficients
         * @param min_time           time value mapped to bin 0
         * @param n_time             number of time bins
         * @param n_offsets          number of trailing numeric columns treated as offsets (coefficient 1)
         * @param has_start_column   whether a start time is present in the response
         * @param has_weights_column whether a weight is present in the response
         */
        CoxPHTask(Key jobKey, DataInfo dinfo, double[] beta, long min_time, int n_time, int n_offsets, boolean has_start_column, boolean has_weights_column) {
            super(jobKey, dinfo);
            this._beta = beta;
            this._n_time = n_time;
            this._min_time = min_time;
            this._n_offsets = n_offsets;
            this._has_start_column = has_start_column;
            this._has_weights_column = has_weights_column;
        }

        /** Allocates zeroed accumulators sized by the coefficient count and the number of time bins. */
        @Override
        protected void chunkInit() {
            final int n_coef = this._beta.length;
            this.sumWeightedCatX = MemoryManager.malloc8d(n_coef - (this._dinfo._nums - this._n_offsets));
            this.sumWeightedNumX = MemoryManager.malloc8d(this._dinfo._nums);
            this.sizeRiskSet = MemoryManager.malloc8d(this._n_time);
            this.sizeCensored = MemoryManager.malloc8d(this._n_time);
            this.sizeEvents = MemoryManager.malloc8d(this._n_time);
            this.countEvents = MemoryManager.malloc8(this._n_time);
            this.sumRiskEvents = MemoryManager.malloc8d(this._n_time);
            this.sumLogRiskEvents = MemoryManager.malloc8d(this._n_time);
            this.rcumsumRisk = MemoryManager.malloc8d(this._n_time);
            this.sumXEvents = CoxPH.malloc2DArray(this._n_time, n_coef);
            this.sumXRiskEvents = CoxPH.malloc2DArray(this._n_time, n_coef);
            this.rcumsumXRisk = CoxPH.malloc2DArray(this._n_time, n_coef);
            this.sumXXRiskEvents = CoxPH.malloc3DArray(this._n_time, n_coef, n_coef);
            this.rcumsumXXRisk = CoxPH.malloc3DArray(this._n_time, n_coef, n_coef);
        }

        /**
         * Adds one row's contribution to every accumulator.
         *
         * @throws IllegalArgumentException if the weight is not a positive number (including NaN),
         *                                  or the start time is not before the stop time
         */
        @Override
        protected void processRow(long gid, DataInfo.Row row) {
            ++this.n;
            final double[] response = row.response;
            final int ncats = row.nBins;
            // NOTE(review): categorical coefficient indices are read from row.numIds (kept as in the
            // compiled code); confirm against DataInfo.Row that this array holds the active level ids.
            final int[] cats = row.numIds;
            final double[] nums = row.numVals;

            final double weight = this._has_weights_column ? response[0] : 1.0;
            // Written as !(w > 0) rather than (w <= 0) so a NaN weight is rejected too; otherwise it
            // would silently turn every accumulator it touches into NaN.
            if (!(weight > 0.0)) {
                throw new IllegalArgumentException("weights must be positive values");
            }
            final long event = (long) response[response.length - 1];
            final int t1 = this._has_start_column ? (int) ((long) response[response.length - 3] + 1L - this._min_time) : -1;
            final int t2 = (int) ((long) response[response.length - 2] - this._min_time);
            if (t1 > t2) {
                throw new IllegalArgumentException("start times must be strictly less than stop times");
            }

            final int numStart = this._dinfo.numStart();
            this.sumWeights += weight;
            for (int j = 0; j < ncats; ++j) {
                this.sumWeightedCatX[cats[j]] += weight;
            }
            for (int j = 0; j < nums.length; ++j) {
                this.sumWeightedNumX[j] += weight * nums[j];
            }

            // Linear predictor: beta of each active categorical level, beta * x for numeric
            // predictors, and the raw value (fixed coefficient 1) for the trailing offset columns.
            final int nPredNums = nums.length - this._n_offsets;
            double logRisk = 0.0;
            for (int j = 0; j < ncats; ++j) {
                logRisk += this._beta[cats[j]];
            }
            for (int j = 0; j < nPredNums; ++j) {
                logRisk += nums[j] * this._beta[numStart + j];
            }
            for (int j = nPredNums; j < nums.length; ++j) {
                logRisk += nums[j];
            }
            final double risk = weight * Math.exp(logRisk);
            final double weightedLogRisk = weight * logRisk;

            if (event > 0L) {
                this.countEvents[t2] += 1L;
                this.sizeEvents[t2] += weight;
                this.sumLogRiskEvents[t2] += weightedLogRisk;
                this.sumRiskEvents[t2] += risk;
            } else {
                this.sizeCensored[t2] += weight;
            }

            // With a start column the row is at risk in every bin t1..t2 and is added to each bin
            // directly. Without one it is recorded at t2 only; postGlobal() later converts the
            // rcumsum* arrays into reverse cumulative sums.
            final int firstBin = this._has_start_column ? t1 : t2;
            for (int t = firstBin; t <= t2; ++t) {
                this.sizeRiskSet[t] += weight;
                this.rcumsumRisk[t] += risk;
            }

            // First- and second-order moments. Categorical levels contribute x = 1 at their
            // coefficient index; numeric predictors sit after them at numStart + (position).
            final int ntotal = ncats + nPredNums;
            final int numStartIter = numStart - ncats;
            for (int jit = 0; jit < ntotal; ++jit) {
                final boolean jIsCat = jit < ncats;
                final int j = jIsCat ? cats[jit] : numStartIter + jit;
                final double x1 = jIsCat ? 1.0 : nums[jit - ncats];
                final double xRisk = x1 * risk;
                if (event > 0L) {
                    this.sumXEvents[t2][j] += weight * x1;
                    this.sumXRiskEvents[t2][j] += xRisk;
                }
                for (int t = firstBin; t <= t2; ++t) {
                    this.rcumsumXRisk[t][j] += xRisk;
                }
                for (int kit = 0; kit < ntotal; ++kit) {
                    final boolean kIsCat = kit < ncats;
                    final int k = kIsCat ? cats[kit] : numStartIter + kit;
                    final double x2 = kIsCat ? 1.0 : nums[kit - ncats];
                    final double xxRisk = x2 * xRisk;
                    if (event > 0L) {
                        this.sumXXRiskEvents[t2][j][k] += xxRisk;
                    }
                    for (int t = firstBin; t <= t2; ++t) {
                        this.rcumsumXXRisk[t][j][k] += xxRisk;
                    }
                }
            }
        }

        /** Merges another task's partial sums into this one, element by element. */
        public void reduce(CoxPHTask that) {
            this.n += that.n;
            this.sumWeights += that.sumWeights;
            ArrayUtils.add(this.sumWeightedCatX, that.sumWeightedCatX);
            ArrayUtils.add(this.sumWeightedNumX, that.sumWeightedNumX);
            ArrayUtils.add(this.sizeRiskSet, that.sizeRiskSet);
            ArrayUtils.add(this.sizeCensored, that.sizeCensored);
            ArrayUtils.add(this.sizeEvents, that.sizeEvents);
            ArrayUtils.add(this.countEvents, that.countEvents);
            ArrayUtils.add(this.sumXEvents, that.sumXEvents);
            ArrayUtils.add(this.sumRiskEvents, that.sumRiskEvents);
            ArrayUtils.add(this.sumXRiskEvents, that.sumXRiskEvents);
            ArrayUtils.add(this.sumXXRiskEvents, that.sumXXRiskEvents);
            ArrayUtils.add(this.sumLogRiskEvents, that.sumLogRiskEvents);
            ArrayUtils.add(this.rcumsumRisk, that.rcumsumRisk);
            ArrayUtils.add(this.rcumsumXRisk, that.rcumsumXRisk);
            ArrayUtils.add(this.rcumsumXXRisk, that.rcumsumXXRisk);
        }

        /**
         * Without a start column, turns the per-stop-bin rcumsum* arrays into reverse cumulative
         * sums, so bin t holds the total over all bins >= t. With a start column they were already
         * spread over each row's interval and are left as-is. sizeRiskSet is not touched here.
         */
        protected void postGlobal() {
            if (this._has_start_column) {
                return;
            }
            for (int t = this.rcumsumRisk.length - 2; t >= 0; --t) {
                this.rcumsumRisk[t] += this.rcumsumRisk[t + 1];
            }
            for (int t = this.rcumsumXRisk.length - 2; t >= 0; --t) {
                final double[] cur = this.rcumsumXRisk[t];
                final double[] next = this.rcumsumXRisk[t + 1];
                for (int j = 0; j < cur.length; ++j) {
                    cur[j] += next[j];
                }
            }
            for (int t = this.rcumsumXXRisk.length - 2; t >= 0; --t) {
                for (int j = 0; j < this.rcumsumXXRisk[t].length; ++j) {
                    final double[] cur = this.rcumsumXXRisk[t][j];
                    final double[] next = this.rcumsumXXRisk[t + 1][j];
                    for (int k = 0; k < cur.length; ++k) {
                        cur[k] += next[k];
                    }
                }
            }
        }
    }

    public class CoxPHDriver
    extends H2O.H2OCountedCompleter<CoxPHDriver> {
        /** Frame this driver mutates (columns appended before the response); set via the setter below. */
        private Frame _modelBuilderTrain = null;

        /**
         * Sets the frame the driver operates on.
         *
         * @param v the training frame to use (stored by reference)
         */
        public void setModelBuilderTrain(Frame v) {
            _modelBuilderTrain = v;
        }

        /**
         * Appends the offset columns to {@code _modelBuilderTrain}, keeping the response as the last
         * column. Does nothing when no offset columns are configured.
         *
         * @throws RuntimeException if an offset vector is not found in the builder's training frame
         */
        private void applyScoringFrameSideEffects() {
            final CoxPHModel.CoxPHParameters parms = (CoxPHModel.CoxPHParameters) CoxPH.this._parms;
            final Vec[] offsets = parms.offset_columns;
            if (offsets == null || offsets.length == 0) {
                return;
            }
            final Frame target = this._modelBuilderTrain;
            final int responseIdx = target.numCols() - 1;
            // Detach the response so the offsets land in front of it, then re-attach it at the end.
            final String responseName = target.names()[responseIdx];
            final Vec response = target.remove(responseIdx);
            for (Vec offset : offsets) {
                final int rawIdx = CoxPH.this._train.find(offset);
                if (rawIdx < 0) {
                    throw new RuntimeException("CoxPHDriver failed to find offsetVec");
                }
                // NOTE(review): the index comes from CoxPH.this._train but the name is taken from
                // _parms.train(); confirm both frames share the same column layout.
                target.add(parms.train().names()[rawIdx], offset);
            }
            target.add(responseName, response);
        }

        /**
         * Appends the weights column (if any), the start column (if any) and the stop column to
         * {@code _modelBuilderTrain}, in that order, keeping the response as the last column.
         *
         * @throws RuntimeException if one of those vectors is not found in the builder's training frame
         */
        private void applyTrainingFrameSideEffects() {
            final CoxPHModel.CoxPHParameters parms = (CoxPHModel.CoxPHParameters) CoxPH.this._parms;
            final Frame target = this._modelBuilderTrain;
            final int responseIdx = target.numCols() - 1;
            final String responseName = target.names()[responseIdx];
            final Vec response = target.remove(responseIdx);
            final boolean hasWeights = parms.weights_column != null;
            final boolean hasStart = parms.start_column != null;
            if (hasWeights) {
                appendRawTrainingColumn(parms.weights_column, "CoxPHDriver failed to find weightVec");
            }
            if (hasStart) {
                appendRawTrainingColumn(parms.start_column, "CoxPHDriver failed to find startVec");
            }
            appendRawTrainingColumn(parms.stop_column, "CoxPHDriver failed to find stopVec");
            target.add(responseName, response);
        }

        /**
         * Locates {@code vec} in the builder's training frame and appends it to
         * {@code _modelBuilderTrain} under the name at that index in {@code _parms.train()}.
         *
         * @param vec             column to append
         * @param notFoundMessage message of the RuntimeException thrown when {@code vec} is not found
         */
        private void appendRawTrainingColumn(Vec vec, String notFoundMessage) {
            final int rawIdx = CoxPH.this._train.find(vec);
            if (rawIdx < 0) {
                throw new RuntimeException(notFoundMessage);
            }
            final String name = ((CoxPHModel.CoxPHParameters) CoxPH.this._parms).train().names()[rawIdx];
            this._modelBuilderTrain.add(name, vec);
        }

        /**
         * Allocates and sizes every output array of the model before fitting.
         *
         * <p>The coefficient count is {@code dinfo.fullN()} minus the number of offset columns. The
         * time dimension is the number of distinct values in the stop column.
         *
         * @param model model whose output is initialised
         * @param dinfo row adapter; stored as the output's {@code data_info}
         */
        protected void initStats(CoxPHModel model, DataInfo dinfo) {
            final CoxPHModel.CoxPHParameters parms = (CoxPHModel.CoxPHParameters) model._parms;
            final CoxPHModel.CoxPHOutput out = (CoxPHModel.CoxPHOutput) model._output;
            out.n = parms.stop_column.length();
            out.data_info = dinfo;

            final int nOffsets = parms.offset_columns == null ? 0 : parms.offset_columns.length;
            final int nCoef = out.data_info.fullN() - nOffsets;
            final int nNumPredictors = out.data_info._nums - nOffsets;

            // Coefficient names come first in coefNames(); offset names follow them.
            final String[] allNames = out.data_info.coefNames();
            out.coef_names = new String[nCoef];
            System.arraycopy(allNames, 0, out.coef_names, 0, nCoef);

            out.coef = MemoryManager.malloc8d(nCoef);
            out.exp_coef = MemoryManager.malloc8d(nCoef);
            out.exp_neg_coef = MemoryManager.malloc8d(nCoef);
            out.se_coef = MemoryManager.malloc8d(nCoef);
            out.z_coef = MemoryManager.malloc8d(nCoef);
            out.gradient = MemoryManager.malloc8d(nCoef);
            out.hessian = CoxPH.malloc2DArray(nCoef, nCoef);
            out.var_coef = CoxPH.malloc2DArray(nCoef, nCoef);
            out.x_mean_cat = MemoryManager.malloc8d(nCoef - nNumPredictors);
            out.x_mean_num = MemoryManager.malloc8d(nNumPredictors);
            out.mean_offset = MemoryManager.malloc8d(nOffsets);
            out.offset_names = new String[nOffsets];
            System.arraycopy(allNames, nCoef, out.offset_names, 0, nOffsets);

            final Vec stop = parms.stop_column;
            if (parms.start_column == null) {
                out.min_time = (long) stop.min();
            } else {
                out.min_time = (long) parms.start_column.min() + 1L;
            }
            out.max_time = (long) stop.max();

            final int nDistinctStops = ((Vec.CollectDomain) new Vec.CollectDomain().doAll(new Vec[]{stop})).domain().length;
            out.time = MemoryManager.malloc8(nDistinctStops);
            out.n_risk = MemoryManager.malloc8d(nDistinctStops);
            out.n_event = MemoryManager.malloc8d(nDistinctStops);
            out.n_censor = MemoryManager.malloc8d(nDistinctStops);
            out.cumhaz_0 = MemoryManager.malloc8d(nDistinctStops);
            out.var_cumhaz_1 = MemoryManager.malloc8d(nDistinctStops);
            out.var_cumhaz_2 = CoxPH.malloc2DArray(nDistinctStops, nCoef);
        }

        /**
         * Copies row counts, predictor means and the per-time event/censor/risk tables from a
         * finished {@link CoxPHTask} into the model output.
         *
         * <p>Only time bins with a positive event or censored weight are emitted, compacted to the
         * front of the output arrays. Without a start column, {@code n_risk} is then turned into a
         * reverse cumulative sum.
         *
         * @param model model whose output is filled
         * @param coxMR completed statistics pass
         */
        protected void calcCounts(CoxPHModel model, CoxPHTask coxMR) {
            final CoxPHModel.CoxPHParameters parms = (CoxPHModel.CoxPHParameters) model._parms;
            final CoxPHModel.CoxPHOutput out = (CoxPHModel.CoxPHOutput) model._output;
            // Rows that never reached processRow are counted as missing.
            out.n_missing = out.n - coxMR.n;
            out.n = coxMR.n;

            final double totalWeight = coxMR.sumWeights;
            for (int c = 0; c < out.x_mean_cat.length; ++c) {
                out.x_mean_cat[c] = coxMR.sumWeightedCatX[c] / totalWeight;
            }
            final double[] normSub = coxMR.dinfo()._normSub;
            // Add _normSub back to the weighted average of the values seen by the task.
            for (int c = 0; c < out.x_mean_num.length; ++c) {
                out.x_mean_num[c] = normSub[c] + coxMR.sumWeightedNumX[c] / totalWeight;
            }
            System.arraycopy(normSub, out.x_mean_num.length, out.mean_offset, 0, out.mean_offset.length);

            int slot = 0;
            for (int bin = 0; bin < coxMR.countEvents.length; ++bin) {
                out.total_event += coxMR.countEvents[bin];
                if (coxMR.sizeEvents[bin] > 0.0 || coxMR.sizeCensored[bin] > 0.0) {
                    out.time[slot] = out.min_time + (long) bin;
                    out.n_risk[slot] = coxMR.sizeRiskSet[bin];
                    out.n_event[slot] = coxMR.sizeEvents[bin];
                    out.n_censor[slot] = coxMR.sizeCensored[bin];
                    slot++;
                }
            }

            if (parms.start_column != null) {
                return;
            }
            for (int i = out.n_risk.length - 2; i >= 0; --i) {
                out.n_risk[i] += out.n_risk[i + 1];
            }
        }

        /**
         * Evaluates the partial log-likelihood at the coefficients used by {@code coxMR}.
         * Overwrites {@code o.gradient} and {@code o.hessian} with its first and second derivatives.
         * Ties are handled by the Efron or Breslow approximation, selected by {@code ties}.
         *
         * @param model model whose parameters select the ties method and whose output receives
         *              the gradient and Hessian
         * @param coxMR completed statistics pass
         * @return the partial log-likelihood
         * @throws IllegalArgumentException if the ties method is neither efron nor breslow
         */
        protected double calcLoglik(CoxPHModel model, final CoxPHTask coxMR) {
            final CoxPHModel.CoxPHParameters p = (CoxPHModel.CoxPHParameters) model._parms;
            final CoxPHModel.CoxPHOutput o = (CoxPHModel.CoxPHOutput) model._output;
            final int n_coef = o.coef.length;
            final int n_time = coxMR.sizeEvents.length;

            // Clear the derivative accumulators before re-summing them.
            Arrays.fill(o.gradient, 0, n_coef, 0.0);
            for (int row = 0; row < n_coef; ++row) {
                Arrays.fill(o.hessian[row], 0, n_coef, 0.0);
            }

            switch (p.ties) {
                case efron:
                    return accumulateEfronLoglik(o, coxMR, n_coef, n_time);
                case breslow:
                    return accumulateBreslowLoglik(o, coxMR, n_coef, n_time);
                default:
                    throw new IllegalArgumentException("ties method must be either efron or breslow");
            }
        }

        /**
         * Efron ties: each time bin is evaluated in its own fork/join task into private buffers.
         * The buffers are then summed in increasing bin order.
         */
        private double accumulateEfronLoglik(final CoxPHModel.CoxPHOutput o, final CoxPHTask coxMR,
                                             final int n_coef, final int n_time) {
            final double[] loglikByBin = MemoryManager.malloc8d(n_time);
            final double[][] gradientByBin = CoxPH.malloc2DArray(n_time, n_coef);
            final double[][][] hessianByBin = CoxPH.malloc3DArray(n_time, n_coef, n_coef);
            final RecursiveAction[] tasks = new RecursiveAction[n_time];
            for (int t = n_time - 1; t >= 0; --t) {
                final int bin = t;
                tasks[t] = new RecursiveAction() {
                    @Override
                    protected void compute() {
                        final double eventWeight = coxMR.sizeEvents[bin];
                        if (!(eventWeight > 0.0)) {
                            return; // no events in this bin: contributes nothing
                        }
                        final long nEvents = coxMR.countEvents[bin];
                        final double riskOfEvents = coxMR.sumRiskEvents[bin];
                        final double riskSet = coxMR.rcumsumRisk[bin];
                        final double avgSize = eventWeight / (double) nEvents;
                        final double[] xRiskSet = coxMR.rcumsumXRisk[bin];
                        final double[] xRiskEvents = coxMR.sumXRiskEvents[bin];
                        final double[][] xxRiskSet = coxMR.rcumsumXXRisk[bin];
                        final double[][] xxRiskEvents = coxMR.sumXXRiskEvents[bin];
                        final double[] grad = gradientByBin[bin];
                        final double[][] hess = hessianByBin[bin];

                        double loglik = coxMR.sumLogRiskEvents[bin];
                        System.arraycopy(coxMR.sumXEvents[bin], 0, grad, 0, n_coef);
                        // Efron: the e-th tied event sees the risk set minus a fraction e/d of the
                        // tied events' own risk.
                        for (long e = 0L; e < nEvents; ++e) {
                            final double frac = (double) e / (double) nEvents;
                            final double term = riskSet - frac * riskOfEvents;
                            loglik -= avgSize * Math.log(term);
                            for (int j = 0; j < n_coef; ++j) {
                                final double djLogTerm = (xRiskSet[j] - frac * xRiskEvents[j]) / term;
                                grad[j] -= avgSize * djLogTerm;
                                for (int k = 0; k < n_coef; ++k) {
                                    final double dkTerm = xRiskSet[k] - frac * xRiskEvents[k];
                                    final double djkTerm = xxRiskSet[j][k] - frac * xxRiskEvents[j][k];
                                    hess[j][k] -= avgSize * (djkTerm / term - djLogTerm * (dkTerm / term));
                                }
                            }
                        }
                        loglikByBin[bin] = loglik;
                    }
                };
            }
            ForkJoinTask.invokeAll(tasks);

            double loglik = 0.0;
            for (int t = 0; t < n_time; ++t) {
                loglik += loglikByBin[t];
            }
            for (int t = 0; t < n_time; ++t) {
                final double[] g = gradientByBin[t];
                for (int j = 0; j < n_coef; ++j) {
                    o.gradient[j] += g[j];
                }
            }
            for (int t = 0; t < n_time; ++t) {
                for (int j = 0; j < n_coef; ++j) {
                    final double[] hRow = hessianByBin[t][j];
                    final double[] outRow = o.hessian[j];
                    for (int k = 0; k < n_coef; ++k) {
                        outRow[k] += hRow[k];
                    }
                }
            }
            return loglik;
        }

        /**
         * Breslow ties: every event at a time shares the full risk set. Bins are accumulated
         * sequentially from the last bin to the first.
         */
        private double accumulateBreslowLoglik(final CoxPHModel.CoxPHOutput o, final CoxPHTask coxMR,
                                               final int n_coef, final int n_time) {
            double loglik = 0.0;
            for (int t = n_time - 1; t >= 0; --t) {
                final double eventWeight = coxMR.sizeEvents[t];
                if (!(eventWeight > 0.0)) {
                    continue;
                }
                final double riskSet = coxMR.rcumsumRisk[t];
                loglik += coxMR.sumLogRiskEvents[t];
                loglik -= eventWeight * Math.log(riskSet);
                final double[] xRiskSet = coxMR.rcumsumXRisk[t];
                for (int j = 0; j < n_coef; ++j) {
                    final double dlogTerm = xRiskSet[j] / riskSet;
                    o.gradient[j] += coxMR.sumXEvents[t][j];
                    o.gradient[j] -= eventWeight * dlogTerm;
                    final double[] outRow = o.hessian[j];
                    for (int k = 0; k < n_coef; ++k) {
                        outRow[k] -= eventWeight * (coxMR.rcumsumXXRisk[t][j][k] / riskSet - dlogTerm * (xRiskSet[k] / riskSet));
                    }
                }
            }
            return loglik;
        }

        /**
         * Publishes new coefficients and the derived summary statistics into the model output.
         *
         * <p>The variance-covariance matrix is the negated inverse of the current Hessian. On the
         * first iteration ({@code iter == 0}) the null log-likelihood, max R-squared and score test
         * are also recorded. Every call refreshes the log-likelihood, likelihood-ratio test,
         * R-squared and Wald test (the last measured against {@code init}).
         *
         * @param model     model whose output is updated
         * @param newCoef   coefficient vector to publish
         * @param newLoglik partial log-likelihood at {@code newCoef}
         */
        protected void calcModelStats(CoxPHModel model, double[] newCoef, double newLoglik) {
            final CoxPHModel.CoxPHParameters parms = (CoxPHModel.CoxPHParameters) model._parms;
            final CoxPHModel.CoxPHOutput out = (CoxPHModel.CoxPHOutput) model._output;
            final int nCoef = out.coef.length;

            // var_coef = -(H^-1), filled symmetrically from the lower triangle.
            final Matrix invHessian = new Matrix(out.hessian).inverse();
            for (int row = 0; row < nCoef; ++row) {
                for (int col = 0; col <= row; ++col) {
                    final double cov = -invHessian.get(row, col);
                    out.var_coef[row][col] = cov;
                    out.var_coef[col][row] = cov;
                }
            }

            for (int i = 0; i < nCoef; ++i) {
                final double b = newCoef[i];
                out.coef[i] = b;
                out.exp_coef[i] = Math.exp(out.coef[i]);
                out.exp_neg_coef[i] = Math.exp(-out.coef[i]);
                out.se_coef[i] = Math.sqrt(out.var_coef[i][i]);
                out.z_coef[i] = out.coef[i] / out.se_coef[i];
            }

            if (out.iter == 0) {
                out.null_loglik = newLoglik;
                out.maxrsq = 1.0 - Math.exp(2.0 * out.null_loglik / (double) out.n);
                // Score test: g' * var_coef * g at the starting coefficients.
                double score = 0.0;
                for (int row = 0; row < nCoef; ++row) {
                    double dot = 0.0;
                    for (int col = 0; col < nCoef; ++col) {
                        dot += out.var_coef[row][col] * out.gradient[col];
                    }
                    score += out.gradient[row] * dot;
                }
                out.score_test = score;
            }

            out.loglik = newLoglik;
            out.loglik_test = -2.0 * (out.null_loglik - out.loglik);
            out.rsq = 1.0 - Math.exp(-out.loglik_test / (double) out.n);

            // Wald test: (b - init)' * (-H) * (b - init).
            double wald = 0.0;
            for (int row = 0; row < nCoef; ++row) {
                double dot = 0.0;
                for (int col = 0; col < nCoef; ++col) {
                    dot -= out.hessian[row][col] * (out.coef[col] - parms.init);
                }
                wald += (out.coef[row] - parms.init) * dot;
            }
            out.wald_test = wald;
        }

        /**
         * Fills {@code cumhaz_0}, {@code var_cumhaz_1} and {@code var_cumhaz_2} of the model output
         * from the per-time aggregates held by {@code coxMR}.
         *
         * <p>Only time slots with a positive event size or a positive censored size produce an
         * output row. Rows are filled in time order. For the {@code efron} ties method the events at
         * one time are spread over {@code countEvents[t]} sub-steps, each removing a fraction
         * {@code e / countEvents[t]} of the events' risk from the risk set. For {@code breslow} the
         * whole risk set is used once. Finally all three outputs are turned into running sums over
         * the rows.
         *
         * @param model the model whose {@code _output} arrays are written (assumed to be pre-sized)
         * @param coxMR the finished task supplying the risk-set sums per time slot
         * @throws IllegalArgumentException if the ties method is neither efron nor breslow
         */
        protected void calcCumhaz_0(CoxPHModel model, CoxPHTask coxMR) {
            final CoxPHModel.CoxPHParameters params = (CoxPHModel.CoxPHParameters) model._parms;
            final CoxPHModel.CoxPHOutput out = (CoxPHModel.CoxPHOutput) model._output;
            final int nCoef = out.coef.length;
            final int nTimes = coxMR.sizeEvents.length;
            int row = 0;
            switch (params.ties) {
                case efron: {
                    for (int t = 0; t < nTimes; ++t) {
                        final double events = coxMR.sizeEvents[t];
                        final double censored = coxMR.sizeCensored[t];
                        // Written as !(a || b) so NaN sizes are skipped exactly like non-positive ones.
                        if (!(events > 0.0 || censored > 0.0)) {
                            continue;
                        }
                        final long nEvents = coxMR.countEvents[t];
                        final double riskEvents = coxMR.sumRiskEvents[t];
                        final double riskSet = coxMR.rcumsumRisk[t];
                        final double meanWeight = events / (double) nEvents;

                        out.cumhaz_0[row] = 0.0;
                        out.var_cumhaz_1[row] = 0.0;
                        for (int c = 0; c < nCoef; ++c) {
                            out.var_cumhaz_2[row][c] = 0.0;
                        }

                        long step = 0L;
                        while (step < nEvents) {
                            final double frac = (double) step / (double) nEvents;
                            final double h = 1.0 / (riskSet - frac * riskEvents);
                            final double hSquared = h * h;
                            out.cumhaz_0[row] += meanWeight * h;
                            out.var_cumhaz_1[row] += meanWeight * hSquared;
                            for (int c = 0; c < nCoef; ++c) {
                                out.var_cumhaz_2[row][c] += meanWeight
                                        * ((coxMR.rcumsumXRisk[t][c] - frac * coxMR.sumXRiskEvents[t][c]) * hSquared);
                            }
                            ++step;
                        }
                        row++;
                    }
                    break;
                }
                case breslow: {
                    for (int t = 0; t < nTimes; ++t) {
                        final double events = coxMR.sizeEvents[t];
                        final double censored = coxMR.sizeCensored[t];
                        if (!(events > 0.0 || censored > 0.0)) {
                            continue;
                        }
                        final double riskSet = coxMR.rcumsumRisk[t];
                        final double increment = events / riskSet;
                        out.cumhaz_0[row] = increment;
                        out.var_cumhaz_1[row] = events / (riskSet * riskSet);
                        for (int c = 0; c < nCoef; ++c) {
                            out.var_cumhaz_2[row][c] = coxMR.rcumsumXRisk[t][c] / riskSet * increment;
                        }
                        row++;
                    }
                    break;
                }
                default:
                    throw new IllegalArgumentException("ties method must be either efron or breslow");
            }

            // Convert the per-row increments into running (cumulative) totals.
            for (int r = 1; r < out.cumhaz_0.length; ++r) {
                out.cumhaz_0[r] += out.cumhaz_0[r - 1];
                out.var_cumhaz_1[r] += out.var_cumhaz_1[r - 1];
                for (int c = 0; c < nCoef; ++c) {
                    out.var_cumhaz_2[r][c] += out.var_cumhaz_2[r - 1][c];
                }
            }
        }

        /**
         * Fits the CoxPH model with a Newton-Raphson iteration and publishes the result.
         *
         * <p>Each iteration evaluates the partial log-likelihood at {@code newCoef} with a
         * {@link CoxPHTask}.
         * <ul>
         *   <li>An improving iterate is accepted. The model statistics and the baseline cumulative
         *       hazard are recomputed, and convergence is tested via the log relative error
         *       ({@code lre >= lre_min}). A new step {@code -var_coef * gradient} is then computed
         *       and subtracted from the accepted coefficients.</li>
         *   <li>A non-improving iterate halves the previous step instead.</li>
         * </ul>
         * The loop stops on convergence, after {@code iter_max + 1} evaluations, or when the
         * computed step contains NaN/infinite entries.
         *
         * <p>If the job was cancelled, the failure is logged and the task still completes. Any
         * other failure marks the job as failed and is rethrown, skipping {@code tryComplete()}.
         * Frame read locks and the scope are always released.
         */
        protected void compute2() {
            try {
                Scope.enter();
                ((CoxPHModel.CoxPHParameters) CoxPH.this._parms).read_lock_frames((Job) CoxPH.this);
                CoxPH.this.init(true);
                this.applyScoringFrameSideEffects();
                final CoxPHModel model = new CoxPHModel(CoxPH.this.dest(),
                        (CoxPHModel.CoxPHParameters) CoxPH.this._parms, new CoxPHModel.CoxPHOutput(CoxPH.this));
                model.delete_and_lock(CoxPH.this._key);
                this.applyTrainingFrameSideEffects();

                final int nResponses = 1;
                final boolean useAllFactorLevels = false;
                final DataInfo dinfo = new DataInfo(Key.make(), this._modelBuilderTrain, null, nResponses,
                        useAllFactorLevels, DataInfo.TransformType.DEMEAN, DataInfo.TransformType.NONE, true, false);
                this.initStats(model, dinfo);

                // Captured after initStats(), which presumably fills min_time/max_time on the output.
                final CoxPHModel.CoxPHParameters parms = (CoxPHModel.CoxPHParameters) model._parms;
                final CoxPHModel.CoxPHOutput output = (CoxPHModel.CoxPHOutput) model._output;

                final int n_offsets = parms.offset_columns == null ? 0 : parms.offset_columns.length;
                final int n_coef = dinfo.fullN() - n_offsets;
                final double[] step = MemoryManager.malloc8d(n_coef);
                final double[] oldCoef = MemoryManager.malloc8d(n_coef);
                final double[] newCoef = MemoryManager.malloc8d(n_coef);
                Arrays.fill(step, Double.NaN);
                Arrays.fill(oldCoef, Double.NaN);
                Arrays.fill(newCoef, parms.init);

                final int n_time = (int) (output.max_time - output.min_time + 1L);
                final boolean has_start_column = parms.start_column != null;
                final boolean has_weights_column = parms.weights_column != null;
                double oldLoglik = -Double.MAX_VALUE;

                for (int i = 0; i <= parms.iter_max; ++i) {
                    output.iter = i;
                    final CoxPHTask coxMR = (CoxPHTask) new CoxPHTask(this.self(), dinfo, newCoef, output.min_time,
                            n_time, n_offsets, has_start_column, has_weights_column).doAll(dinfo._adaptedFrame);
                    final double newLoglik = this.calcLoglik(model, coxMR);

                    if (newLoglik > oldLoglik) {
                        if (i == 0) {
                            this.calcCounts(model, coxMR);
                        }
                        this.calcModelStats(model, newCoef, newLoglik);
                        this.calcCumhaz_0(model, coxMR);
                        // Relative change of the log-likelihood; absolute change when it is exactly 0.
                        output.lre = newLoglik == 0.0
                                ? -Math.log10(Math.abs(oldLoglik - newLoglik))
                                : -Math.log10(Math.abs((oldLoglik - newLoglik) / newLoglik));
                        if (output.lre >= parms.lre_min) {
                            break;
                        }

                        Arrays.fill(step, 0.0);
                        for (int j = 0; j < n_coef; ++j) {
                            for (int k = 0; k < n_coef; ++k) {
                                step[j] -= output.var_coef[j][k] * output.gradient[k];
                            }
                        }

                        // A NaN/infinite step would turn every later iterate into NaN, which can never
                        // improve the log-likelihood. Stop here and keep the statistics already
                        // computed for the current (accepted) coefficients.
                        boolean finiteStep = true;
                        for (int j = 0; j < n_coef && finiteStep; ++j) {
                            finiteStep = !Double.isNaN(step[j]) && !Double.isInfinite(step[j]);
                        }
                        if (!finiteStep) {
                            break;
                        }

                        oldLoglik = newLoglik;
                        System.arraycopy(newCoef, 0, oldCoef, 0, oldCoef.length);
                    } else {
                        // Step overshot: back off halfway towards the last accepted coefficients.
                        for (int j = 0; j < n_coef; ++j) {
                            step[j] /= 2.0;
                        }
                    }

                    for (int j = 0; j < n_coef; ++j) {
                        newCoef[j] = oldCoef[j] - step[j];
                    }
                }
                model.update(CoxPH.this._key);
            } catch (Throwable t) {
                // Guard against a missing job entry so an NPE does not mask the original failure.
                final Job thisJob = (Job) DKV.getGet((Key) CoxPH.this._key);
                if (thisJob != null && thisJob._state == Job.JobState.CANCELLED) {
                    Log.info("Job cancelled by user.");
                } else {
                    t.printStackTrace();
                    CoxPH.this.failed(t);
                    throw t;
                }
            } finally {
                // NOTE(review): no unlock of the model's write lock (taken by delete_and_lock above)
                // happens in this method; confirm it is released elsewhere.
                ((CoxPHModel.CoxPHParameters) CoxPH.this._parms).read_unlock_frames((Job) CoxPH.this);
                Scope.exit();
                CoxPH.this.done();
            }
            this.tryComplete();
        }

        /**
         * Returns the key of the enclosing CoxPH job; it is handed to every {@code CoxPHTask}
         * launched by this driver.
         */
        Key self() {
            final Key jobKey = CoxPH.this._key;
            return jobKey;
        }
    }
}

