/*
 * Decompiled with CFR 0.152.
 */
package hex.coxph;

import hex.DataInfo;
import hex.coxph.CPHBaseTask;
import hex.coxph.CoxPH;
import hex.coxph.CoxPHUtils;
import hex.coxph.EfronDJKSetupFun;
import water.util.ArrayUtils;

/**
 * Distributed pass that accumulates the {@code _djkTerm} matrix for the Cox proportional-hazards
 * fit (the Efron tie-handling variant, per the class name).
 *
 * <p>For each row, the task computes {@code mult = (riskTermT2 - cumsumRiskTerm) * risk}. It
 * combines the row's risk with per-time-index terms precomputed by {@link EfronDJKSetupFun}. The
 * task then adds {@code mult * x_j * x_k} to {@code _djkTerm[j][k]} for every pair of the row's
 * active predictors. Only pairs with {@code k}'s position at or after {@code j}'s position in the
 * row are visited. Partial matrices are combined in {@link #reduce}. {@link #postGlobal} then
 * mirrors the upper triangle into the lower one.
 */
class EfronDJKTermTask extends CPHBaseTask<EfronDJKTermTask> {
    /** Source of the coefficients ({@code _beta}), response-layout flags and {@code _time}. */
    private final CoxPH.CoxPHTask _coxMR;
    /** Precomputed per-index terms {@code _cumsumRiskTerm} and {@code _riskTermT2}. */
    private final EfronDJKSetupFun _setup;
    /** Output: square {@code n_coef x n_coef} accumulator, allocated in {@link #chunkInit}. */
    double[][] _djkTerm;

    /**
     * @param dinfo data layout passed through to {@link CPHBaseTask}
     * @param coxMR task holding the current coefficients and response-layout flags
     * @param setup precomputed risk terms indexed by the row's time values
     */
    EfronDJKTermTask(DataInfo dinfo, CoxPH.CoxPHTask coxMR, EfronDJKSetupFun setup) {
        super(dinfo);
        this._coxMR = coxMR;
        this._setup = setup;
    }

    /** Allocates a zeroed {@code n_coef x n_coef} accumulator, where n_coef = beta length. */
    @Override
    protected void chunkInit() {
        final int nCoef = _coxMR._beta.length;
        _djkTerm = CoxPHUtils.malloc2DArray(nCoef, nCoef);
    }

    /**
     * Adds one row's contribution to {@code _djkTerm}. Rows whose strata value is NaN are skipped.
     *
     * @throws IllegalArgumentException if the row's weight is {@code <= 0}
     */
    @Override
    protected void processRow(DataInfo.Row row) {
        final double[] response = row.response;
        final int ncats = row.nBins;
        final int[] cats = row.binIds;
        final double[] nums = row.numVals;

        // NOTE(review): the weight is read from response[0], but the backward scan below
        // (checked by the respIdx == -1 assertion) also consumes response[0] as the
        // strata/start/stop value. Confirm the response layout when a weights column is present.
        final double weight = _coxMR._has_weights_column ? response[0] : 1.0;
        if (weight <= 0.0) {
            throw new IllegalArgumentException("weights must be positive values");
        }

        // Response values are consumed from the end. The order is: event (shifted by
        // _min_event), stop value t2, optional start value t1, and optional strata.
        // t1/t2 are used below as indices into the setup arrays.
        int respIdx = response.length - 1;
        final long event = (long) (response[respIdx--] - _coxMR._min_event);
        final int t2 = (int) response[respIdx--];
        final int t1 = _coxMR._has_start_column ? (int) response[respIdx--] : -1;
        final double strata = _coxMR._has_strata_column ? response[respIdx--] : 0.0;
        assert respIdx == -1 : "expected to use all response data";
        if (Double.isNaN(strata)) {
            return; // a missing stratum means the row contributes nothing
        }

        final int numStart = _dinfo.numStart();
        // The trailing _n_offsets numeric values are offsets without a coefficient.
        final int nCoefNums = nums.length - _coxMR._n_offsets;
        final double risk = weight * Math.exp(logRisk(cats, ncats, nums, nCoefNums, numStart));

        final double cumsumRiskTerm = (_coxMR._has_start_column && t1 % _coxMR._time.length > 0)
                ? _setup._cumsumRiskTerm[t2] - _setup._cumsumRiskTerm[t1 - 1]
                : _setup._cumsumRiskTerm[t2];
        final double riskTermT2 = event > 0L ? _setup._riskTermT2[t2] : 0.0;
        final double mult = (riskTermT2 - cumsumRiskTerm) * risk;

        accumulateOuterProduct(cats, ncats, nums, ncats + nCoefNums, numStart - ncats, mult);
    }

    /**
     * Computes the row's linear predictor. It sums three parts:
     * <ul>
     *   <li>the coefficients of the active categorical bins;</li>
     *   <li>each of the first {@code nCoefNums} numeric values times its coefficient at
     *       {@code numStart + i};</li>
     *   <li>the remaining (offset) numeric values, added as-is.</li>
     * </ul>
     */
    private double logRisk(int[] cats, int ncats, double[] nums, int nCoefNums, int numStart) {
        final double[] beta = _coxMR._beta;
        double logRisk = 0.0;
        for (int i = 0; i < ncats; ++i) {
            logRisk += beta[cats[i]];
        }
        for (int i = 0; i < nCoefNums; ++i) {
            logRisk += nums[i] * beta[numStart + i];
        }
        for (int i = nCoefNums; i < nums.length; ++i) {
            logRisk += nums[i];
        }
        return logRisk;
    }

    /**
     * Adds {@code mult * x_j * x_k} to {@code _djkTerm[j][k]} for every pair of active
     * predictors at row positions {@code jit <= kit < ntotal}.
     *
     * <p>Position {@code p < ncats} is categorical bin {@code cats[p]}, with value 1. Any other
     * position is coefficient {@code numStartIter + p}, with value {@code nums[p - ncats]}.
     * NOTE(review): writing only to the upper triangle assumes cats are ascending and below the
     * numeric indices. {@link #postGlobal} relies on this.
     */
    private void accumulateOuterProduct(int[] cats, int ncats, double[] nums, int ntotal,
                                        int numStartIter, double mult) {
        for (int jit = 0; jit < ntotal; ++jit) {
            final boolean jIsCat = jit < ncats;
            final int j = jIsCat ? cats[jit] : numStartIter + jit;
            final double x1 = jIsCat ? 1.0 : nums[jit - ncats];
            final double x1mult = x1 * mult;
            final double[] djkRow = _djkTerm[j];
            for (int kit = jit; kit < ntotal; ++kit) {
                final boolean kIsCat = kit < ncats;
                final int k = kIsCat ? cats[kit] : numStartIter + kit;
                final double x2 = kIsCat ? 1.0 : nums[kit - ncats];
                djkRow[k] += x1mult * x2;
            }
        }
    }

    /** Adds {@code that}'s partial matrix into this one via {@code ArrayUtils.add}. */
    public void reduce(EfronDJKTermTask that) {
        ArrayUtils.add(_djkTerm, that._djkTerm);
    }

    /** Makes {@code _djkTerm} symmetric by copying each upper-triangle entry into the lower triangle. */
    protected void postGlobal() {
        for (int j = 1; j < _djkTerm.length; ++j) {
            for (int k = 0; k < j; ++k) {
                _djkTerm[j][k] = _djkTerm[k][j];
            }
        }
    }
}

