/*
 * Decompiled with CFR 0.152.
 */
package hex.glm;

import hex.DataInfo;
import hex.Model;
import hex.ModelMetrics;
import hex.glm.GLMModel;
import java.util.Arrays;
import water.Job;
import water.MRTask;
import water.MemoryManager;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.util.ArrayUtils;
import water.util.FrameUtils;

public class GLMScore
extends MRTask<GLMScore> {
    final GLMModel _m;
    final Job _j;
    ModelMetrics.MetricBuilder _mb;
    final DataInfo _dinfo;
    final boolean _sparse;
    final String[] _domain;
    final boolean _computeMetrics;
    final boolean _generatePredictions;
    transient double[][] _vcov;
    transient double[] _tmp;
    transient double[] _eta;
    final int _nclasses;
    private final double[] _beta;
    private final double[][] _beta_multinomial;
    private final double _defaultThreshold;

    public GLMScore(Job j, GLMModel m, DataInfo dinfo, String[] domain, boolean computeMetrics, boolean generatePredictions) {
        this._j = j;
        this._m = m;
        this._computeMetrics = computeMetrics;
        this._sparse = FrameUtils.sparseRatio((Frame)dinfo._adaptedFrame) < 0.5;
        this._domain = domain;
        this._generatePredictions = generatePredictions;
        this._m._parms = m._parms;
        this._nclasses = ((GLMModel.GLMOutput)m._output).nclasses();
        if (((GLMModel.GLMParameters)this._m._parms)._family == GLMModel.GLMParameters.Family.multinomial || ((GLMModel.GLMParameters)this._m._parms)._family == GLMModel.GLMParameters.Family.ordinal) {
            this._beta = null;
            this._beta_multinomial = ((GLMModel.GLMOutput)m._output)._global_beta_multinomial;
        } else {
            double[] beta = m.beta();
            int[] ids = new int[beta.length - 1];
            int k = 0;
            for (int i = 0; i < beta.length - 1; ++i) {
                if (beta[i] == 0.0) continue;
                ids[k++] = i;
            }
            if (k < beta.length - 1) {
                ids = Arrays.copyOf(ids, k);
                dinfo = dinfo.filterExpandedColumns(ids);
                double[] beta2 = MemoryManager.malloc8d((int)(ids.length + 1));
                int l = 0;
                for (int x : ids) {
                    beta2[l++] = beta[x];
                }
                beta2[l] = beta[beta.length - 1];
                beta = beta2;
            }
            this._beta_multinomial = null;
            this._beta = beta;
        }
        this._dinfo = dinfo;
        this._dinfo._valid = true;
        this._defaultThreshold = m.defaultThreshold();
    }

    public double[] scoreRow(DataInfo.Row r, double o, double[] preds) {
        int lastClass = this._nclasses - 1;
        if (((GLMModel.GLMParameters)this._m._parms)._family == GLMModel.GLMParameters.Family.ordinal) {
            double[][] bm = this._beta_multinomial;
            Arrays.fill(preds, 0.0);
            double previousCDF = 0.0;
            for (int cInd = 0; cInd < lastClass; ++cInd) {
                double eta = r.innerProduct(bm[cInd]) + o;
                double currCDF = 1.0 / (1.0 + Math.exp(-eta));
                preds[cInd + 1] = currCDF - previousCDF;
                previousCDF = currCDF;
            }
            preds[this._nclasses] = 1.0 - previousCDF;
            preds[0] = ArrayUtils.maxIndex((double[])preds) - 1;
        } else if (((GLMModel.GLMParameters)this._m._parms)._family == GLMModel.GLMParameters.Family.multinomial) {
            int c;
            double[] eta = this._eta;
            double[][] bm = this._beta_multinomial;
            double sumExp = 0.0;
            double maxRow = 0.0;
            for (c = 0; c < bm.length; ++c) {
                eta[c] = r.innerProduct(bm[c]) + o;
                if (!(eta[c] > maxRow)) continue;
                maxRow = eta[c];
            }
            for (c = 0; c < bm.length; ++c) {
                eta[c] = Math.exp(eta[c] - maxRow);
                sumExp += eta[c];
            }
            sumExp = 1.0 / sumExp;
            for (c = 0; c < bm.length; ++c) {
                preds[c + 1] = eta[c] * sumExp;
            }
            preds[0] = ArrayUtils.maxIndex((double[])eta);
        } else {
            double mu = ((GLMModel.GLMParameters)this._m._parms).linkInv(r.innerProduct(this._beta) + o);
            if (((GLMModel.GLMParameters)this._m._parms)._family == GLMModel.GLMParameters.Family.binomial || ((GLMModel.GLMParameters)this._m._parms)._family == GLMModel.GLMParameters.Family.quasibinomial || ((GLMModel.GLMParameters)this._m._parms)._family == GLMModel.GLMParameters.Family.fractionalbinomial) {
                preds[0] = mu >= this._defaultThreshold ? 1.0 : 0.0;
                preds[1] = 1.0 - mu;
                preds[2] = mu;
            } else {
                preds[0] = mu;
            }
        }
        return preds;
    }

    private void processRow(DataInfo.Row r, float[] res, double[] ps, NewChunk[] preds, int ncols) {
        if (this._dinfo._responses != 0) {
            res[0] = (float)r.response[0];
        }
        if (r.predictors_bad) {
            Arrays.fill(ps, Double.NaN);
        } else if (r.weight == 0.0) {
            Arrays.fill(ps, 0.0);
        } else {
            this.scoreRow(r, r.offset, ps);
            if (this._computeMetrics && !r.response_bad) {
                this._mb.perRow(ps, res, r.weight, r.offset, (Model)this._m);
            }
        }
        if (this._generatePredictions) {
            for (int c = 0; c < ncols; ++c) {
                preds[c].addNum(ps[c]);
            }
            if (this._vcov != null) {
                preds[ncols].addNum(Math.sqrt(r.innerProduct(r.mtrxMul(this._vcov, this._tmp))));
            }
        }
    }

    public void map(Chunk[] chks, NewChunk[] preds) {
        int ncols;
        double[] ps;
        if (this.isCancelled() || this._j != null && this._j.stop_requested()) {
            return;
        }
        if (((GLMModel.GLMParameters)this._m._parms)._family == GLMModel.GLMParameters.Family.multinomial || ((GLMModel.GLMParameters)this._m._parms)._family == GLMModel.GLMParameters.Family.ordinal) {
            this._eta = MemoryManager.malloc8d((int)this._nclasses);
        }
        this._vcov = ((GLMModel.GLMOutput)this._m._output)._vcov;
        if (this._generatePredictions && this._vcov != null) {
            this._tmp = MemoryManager.malloc8d((int)this._vcov.length);
        }
        if (this._computeMetrics) {
            this._mb = this._m.makeMetricBuilder(this._domain);
            ps = this._mb._work;
        } else {
            ps = new double[((GLMModel.GLMOutput)this._m._output)._nclasses + 1];
        }
        float[] res = new float[1];
        int nc = ((GLMModel.GLMOutput)this._m._output).nclasses();
        int n = ncols = nc == 1 ? 1 : nc + 1;
        if (this._sparse) {
            for (DataInfo.Row r : this._dinfo.extractSparseRows(chks)) {
                this.processRow(r, res, ps, preds, ncols);
            }
        } else {
            DataInfo.Row r = this._dinfo.newDenseRow();
            for (int rid = 0; rid < chks[0]._len; ++rid) {
                this._dinfo.extractDenseRow(chks, rid, r);
                this.processRow(r, res, ps, preds, ncols);
            }
        }
        if (this._j != null) {
            this._j.update(1L);
        }
    }

    public void reduce(GLMScore bs) {
        if (this._mb != null) {
            this._mb.reduce(bs._mb);
        }
    }

    protected void postGlobal() {
        if (this._mb != null) {
            this._mb.postGlobal();
        }
    }
}

