/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.xgboost;

import hex.DataInfo;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.java.util.BigDenseMatrix;
import water.H2O;
import water.LocalMR;
import water.MemoryManager;
import water.MrFun;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.VecUtils;

public class XGBoostUtils {
    // Maximum length of a single inner array in the chunked 2-D buffers (values, column indices,
    // row headers) built by this class; 0x7FFFFFF5 == Integer.MAX_VALUE - 10.
    // NOTE(review): left non-final, presumably so tests can shrink it - confirm.
    protected static int SPARSE_MATRIX_DIM = 0x7FFFFFF5;

    /**
     * Renders a textual feature map with one line per expanded column of {@code di}:
     * {@code "<index> <name-with-whitespace-removed> <type>\n"}.
     * The type is {@code "i"} for columns below the last categorical offset or backed by a
     * binary vec, {@code "int"} for integer vecs and {@code "q"} for everything else.
     *
     * @param f  frame whose vecs are inspected for the non-categorical columns
     * @param di data info supplying the coefficient names and categorical offsets
     * @return the feature map text
     */
    public static String makeFeatureMap(Frame f, DataInfo di) {
        final String[] names = di.coefNames();
        final StringBuilder out = new StringBuilder();
        assert (names.length == di.fullN());
        // number of expanded (one-hot) categorical columns
        final int expandedCats = di._catOffsets[di._catOffsets.length - 1];
        for (int col = 0; col < di.fullN(); ++col) {
            out.append(col).append(" ").append(names[col].replaceAll("\\s*", "")).append(" ");
            // NOTE(review): the vec is looked up at (col - expandedCats), while other methods in this
            // class read numeric columns at (di._cats + j) - confirm the column layout of f.
            final String type;
            if (col < expandedCats || f.vec(col - expandedCats).isBinary()) {
                type = "i";
            } else if (f.vec(col - expandedCats).isInt()) {
                type = "int";
            } else {
                type = "q";
            }
            out.append(type);
            out.append("\n");
        }
        return out.toString();
    }

    /**
     * Builds a {@link DMatrix} from the chunks of {@code f} whose ids are returned by
     * {@code VecUtils.getLocalChunkIds}, attaching the response as label and, when a weight
     * column is present, the weights.
     *
     * @param di       data info describing the expanded column layout (must not be null)
     * @param f        source frame
     * @param response name of the response column in {@code f}
     * @param weight   name of the weight column in {@code f}; rows are only weighted when the lookup yields a vec
     * @param sparse   {@code true} to build a CSR matrix, {@code false} for a dense one
     * @return the populated matrix
     * @throws IllegalArgumentException if the counted number of rows exceeds {@code Integer.MAX_VALUE}
     * @throws XGBoostError             propagated from the XGBoost bindings
     */
    public static DMatrix convertFrameToDMatrix(DataInfo di, Frame f, String response, String weight, boolean sparse) throws XGBoostError {
        assert (di != null);
        final int[] chunkIds = VecUtils.getLocalChunkIds((Vec)f.anyVec());
        final Vec responseVec = f.vec(response);
        final Vec weightVec = f.vec(weight);
        final int[] rowsPerChunk = new int[chunkIds.length];
        final long rowTotal = XGBoostUtils.sumChunksLength(chunkIds, responseVec, weightVec, rowsPerChunk);
        if (rowTotal > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("XGBoost currently doesn't support datasets with more than 2147483647 per node. To train a XGBoost model on this dataset add more nodes to your H2O cluster and use distributed training.");
        }
        final int nRows = (int)rowTotal;
        final float[] resp = MemoryManager.malloc4f(nRows);
        final float[] weights = weightVec != null ? MemoryManager.malloc4f(nRows) : null;

        DMatrix trainMat;
        if (!sparse) {
            Log.debug(new Object[]{"Treating matrix as dense."});
            BigDenseMatrix dense = null;
            try {
                dense = XGBoostUtils.allocateDenseMatrix(nRows, di);
                final long writtenRows = XGBoostUtils.denseChunk(dense, chunkIds, rowsPerChunk, f, weightVec, responseVec, di, resp, weights);
                assert ((long)dense.nrow == writtenRows);
                trainMat = new DMatrix(dense, Float.NaN);
            } finally {
                // the dense buffer is released on every path once the DMatrix attempt is over
                if (dense != null) {
                    dense.dispose();
                }
            }
        } else {
            Log.debug(new Object[]{"Treating matrix as sparse."});
            trainMat = XGBoostUtils.csr(f, chunkIds, weightVec, responseVec, di, resp, weights);
        }

        assert (trainMat.rowNum() == (long)nRows);
        trainMat.setLabel(resp);
        if (weights != null) {
            trainMat.setWeight(weights);
        }
        return trainMat;
    }

    /**
     * Computes, for every listed chunk, the number of rows that will be kept (chunk length minus
     * rows whose weight equals 0) and stores it in {@code chunkLengths}.
     *
     * @param chunkIds      chunk indices to inspect
     * @param vec           vec used to obtain the raw chunk lengths
     * @param weightsVector optional weights; when null every row is kept
     * @param chunkLengths  output array, one entry per element of {@code chunkIds}
     * @return the sum of all entries of {@code chunkLengths}
     */
    private static long sumChunksLength(int[] chunkIds, Vec vec, Vec weightsVector, int[] chunkLengths) {
        for (int slot = 0; slot < chunkIds.length; ++slot) {
            final int chunkId = chunkIds[slot];
            chunkLengths[slot] = vec.chunkLen(chunkId);
            if (weightsVector != null) {
                final Chunk weightChunk = weightsVector.chunkForChunkIdx(chunkId);
                // row 0 is tested on its own; the scan below starts after it
                if (weightChunk.atd(0) == 0.0) {
                    chunkLengths[slot]--;
                }
                int row = weightChunk.nextNZ(0, true);
                while (row >= 0 && row < weightChunk._len) {
                    if (weightChunk.atd(row) == 0.0) {
                        chunkLengths[slot]--;
                    }
                    row = weightChunk.nextNZ(row, true);
                }
            }
        }
        long total = 0L;
        for (int k = 0; k < chunkLengths.length; ++k) {
            total += (long)chunkLengths[k];
        }
        return total;
    }

    /**
     * Copies response (and weight, if a weight column is given) of row {@code i} into slot {@code j}
     * of the output arrays. Rows with weight 0 are skipped and consume no slot.
     *
     * @param chunks    chunks holding the response and optional weight columns
     * @param respIdx   index of the response chunk
     * @param weightIdx index of the weight chunk, or -1 when there is none
     * @param resp      output responses
     * @param weights   output weights (only written when {@code weightIdx != -1})
     * @param j         next free output slot
     * @param i         row within the chunk
     * @return the next free output slot after this row
     */
    private static int setResponseAndWeight(Chunk[] chunks, int respIdx, int weightIdx, float[] resp, float[] weights, int j, int i) {
        final boolean weighted = weightIdx != -1;
        if (weighted) {
            final double rowWeight = chunks[weightIdx].atd(i);
            if (rowWeight == 0.0) {
                return j;
            }
            weights[j] = (float)rowWeight;
        }
        resp[j] = (float)chunks[respIdx].atd(i);
        return j + 1;
    }

    /**
     * Reader-based variant: copies response (and weight, if a weight reader is given) of row
     * {@code i} into slot {@code j}. Rows with weight 0 are skipped and consume no slot.
     *
     * @param w       weight reader, or null when rows are unweighted
     * @param resp    output responses
     * @param weights output weights (only written when {@code w} is not null)
     * @param respVec response reader
     * @param j       next free output slot
     * @param i       row index passed to the readers
     * @return the next free output slot after this row
     */
    private static int setResponseAndWeight(Vec.Reader w, float[] resp, float[] weights, Vec.Reader respVec, int j, long i) {
        if (w != null) {
            final double rowWeight = w.at(i);
            if (rowWeight == 0.0) {
                return j;
            }
            weights[j] = (float)rowWeight;
        }
        resp[j] = (float)respVec.at(i);
        return j + 1;
    }

    /**
     * Builds a {@link DMatrix} from a single set of chunks and attaches the response column as label.
     * No weight column is used on this path.
     *
     * @param di       data info describing the expanded column layout
     * @param chunks   chunks of one row range; {@code chunks[0]} determines the row count
     * @param response index of the response chunk within {@code chunks}
     * @param sparse   {@code true} to build a CSR matrix, {@code false} for a dense one
     * @return the populated matrix
     * @throws IllegalArgumentException if a {@link NegativeArraySizeException} is raised while building
     *                                  the matrix; the original exception is attached as the cause
     * @throws XGBoostError             propagated from the XGBoost bindings
     */
    public static DMatrix convertChunksToDMatrix(DataInfo di, Chunk[] chunks, int response, boolean sparse) throws XGBoostError {
        DMatrix trainMat;
        int nRows = chunks[0]._len;
        float[] resp = MemoryManager.malloc4f((int)nRows);
        try {
            if (sparse) {
                Log.debug((Object[])new Object[]{"Treating matrix as sparse."});
                trainMat = XGBoostUtils.csr(chunks, -1, response, di, resp, null);
            } else {
                trainMat = XGBoostUtils.dense(chunks, di, response, resp, null);
            }
        }
        catch (NegativeArraySizeException e) {
            // Keep the triggering exception as the cause so the failing allocation stays visible in the stack trace.
            throw new IllegalArgumentException(H2O.technote((int)11, (String)"Data is too large to fit into the 32-bit Java float[] array that needs to be passed to the XGBoost C++ backend. Use H2O GBM instead."), e);
        }
        // trim the label array to the number of rows the matrix actually reports
        int len = (int)trainMat.rowNum();
        resp = Arrays.copyOf(resp, len);
        trainMat.setLabel(resp);
        return trainMat;
    }

    /**
     * Builds a dense {@link DMatrix} (missing value marker {@code Float.NaN}) from one set of chunks.
     * The intermediate {@link BigDenseMatrix} is disposed before returning, on success and on failure.
     *
     * @param chunks  source chunks; {@code chunks[0].len()} is the row count
     * @param di      data info describing the expanded column layout
     * @param respIdx index of the response chunk
     * @param resp    output responses
     * @param weights output weights, passed through to the row writer
     * @return the dense matrix
     * @throws XGBoostError propagated from the XGBoost bindings
     */
    private static DMatrix dense(Chunk[] chunks, DataInfo di, int respIdx, float[] resp, float[] weights) throws XGBoostError {
        Log.debug(new Object[]{"Treating matrix as dense."});
        BigDenseMatrix buffer = null;
        try {
            buffer = XGBoostUtils.allocateDenseMatrix(chunks[0].len(), di);
            final long writtenRows = XGBoostUtils.denseChunk(buffer, chunks, respIdx, di, resp, weights);
            assert (writtenRows == (long)buffer.nrow);
            return new DMatrix(buffer, Float.NaN);
        } finally {
            if (buffer != null) {
                buffer.dispose();
            }
        }
    }

    /**
     * Fills {@code data}, {@code resp} and {@code weights} from the listed chunks by running a
     * {@code WriteDenseChunkFun} over them via {@code LocalMR}. Each chunk writes at the row offset
     * obtained from the prefix sums of {@code nRowsByChunk}.
     *
     * @return the total number of rows reported by the write function
     */
    private static long denseChunk(BigDenseMatrix data, int[] chunks, int[] nRowsByChunk, Frame f, Vec weightsVec, Vec respVec, DataInfo di, float[] resp, float[] weights) {
        // offsets[k] = first output row of the k-th chunk; offsets[last] = total row count
        final int[] offsets = new int[nRowsByChunk.length + 1];
        int running = 0;
        for (int k = 0; k < chunks.length; ++k) {
            running += nRowsByChunk[k];
            offsets[k + 1] = running;
        }
        final WriteDenseChunkFun writer = new WriteDenseChunkFun(f, chunks, offsets, weightsVec, respVec, di, data, resp, weights);
        final LocalMR job = new LocalMR((MrFun)writer, chunks.length);
        ((LocalMR)H2O.submitTask((H2O.H2OCountedCompleter)job)).join();
        return writer.getTotalRows();
    }

    /**
     * Writes every row of the given chunks into {@code data} and copies the responses.
     * No weight column is consulted on this path.
     *
     * @return the number of rows written
     */
    private static long denseChunk(BigDenseMatrix data, Chunk[] chunks, int respIdx, DataInfo di, float[] resp, float[] weights) {
        long cell = 0L;
        long rowsWritten = 0L;
        int labelSlot = 0;
        final int rowCount = chunks[0]._len;
        for (int row = 0; row < rowCount; ++row) {
            cell = XGBoostUtils.writeDenseRow(di, chunks, row, data, cell);
            rowsWritten++;
            labelSlot = XGBoostUtils.setResponseAndWeight(chunks, respIdx, -1, resp, weights, labelSlot, row);
        }
        // the whole matrix must have been filled
        assert ((long)data.nrow * (long)data.ncol == cell);
        return rowsWritten;
    }

    /**
     * Writes one expanded row into {@code data} starting at flat position {@code idx}: each
     * categorical column becomes a zero-filled block with a single 1 at its level, followed by the
     * numeric columns (NA written as {@code Float.NaN}).
     *
     * @param di         data info describing the expanded column layout
     * @param chunks     chunks of the row, categoricals first, then numerics
     * @param rowInChunk row within the chunks
     * @param data       destination matrix
     * @param idx        flat position of the first cell of the row
     * @return the flat position immediately after the written row
     */
    private static long writeDenseRow(DataInfo di, Chunk[] chunks, int rowInChunk, BigDenseMatrix data, long idx) {
        long cursor = idx;
        for (int c = 0; c < di._cats; ++c) {
            final int blockWidth = di._catOffsets[c + 1] - di._catOffsets[c];
            final double level;
            if (chunks[c].isNA(rowInChunk)) {
                level = Double.NaN;
            } else {
                level = (double)chunks[c].at8(rowInChunk);
            }
            final int hot = di.getCategoricalId(c, level) - di._catOffsets[c];
            // clear the whole block first, then raise the single active level
            for (int k = 0; k < blockWidth; ++k) {
                data.set(cursor + (long)k, 0.0f);
            }
            data.set(cursor + (long)hot, 1.0f);
            cursor += (long)blockWidth;
        }
        for (int n = 0; n < di._nums; ++n) {
            final Chunk numeric = chunks[di._cats + n];
            final float value = numeric.isNA(rowInChunk) ? Float.NaN : (float)numeric.atd(rowInChunk);
            data.set(cursor, value);
            cursor++;
        }
        return cursor;
    }

    /**
     * Builds a CSR {@link DMatrix} from the listed chunks of a frame: sizes the buffers, allocates
     * them, fills them through per-column readers and wraps the result.
     *
     * @throws XGBoostError propagated from the XGBoost bindings
     */
    private static DMatrix csr(Frame f, int[] chunksIds, Vec weightsVec, Vec responseVec, DataInfo di, float[] resp, float[] weights) throws XGBoostError {
        final SparseMatrixDimensions dims = XGBoostUtils.calculateCSRMatrixDimensions(f, chunksIds, weightsVec, di);
        final SparseMatrix buffers = XGBoostUtils.allocateCSRMatrix(dims);

        // one reader per frame column
        final int columnCount = f.numCols();
        final Vec.Reader[] readers = new Vec.Reader[columnCount];
        for (int c = 0; c < readers.length; ++c) {
            readers[c] = new Vec.Reader(f.vec(c));
        }
        Vec.Reader weightsReader = null;
        if (weightsVec != null) {
            weightsReader = new Vec.Reader(weightsVec);
        }
        final Vec.Reader responseReader = new Vec.Reader(responseVec);

        final int keptRows = XGBoostUtils.initalizeFromChunkIds(f, chunksIds, readers, weightsReader, di, buffers._rowHeaders, buffers._sparseData, buffers._colIndices, responseReader, resp, weights);
        return XGBoostUtils.toDMatrix(buffers, dims, keptRows, di);
    }

    /**
     * Builds a CSR {@link DMatrix} from one set of chunks: sizes the buffers, allocates them,
     * fills them and wraps the result.
     *
     * @param weight index of the weight chunk, or -1 when there is none
     * @throws XGBoostError propagated from the XGBoost bindings
     */
    private static DMatrix csr(Chunk[] chunks, int weight, int respIdx, DataInfo di, float[] resp, float[] weights) throws XGBoostError {
        final SparseMatrixDimensions dims = XGBoostUtils.calculateCSRMatrixDimensions(chunks, di, weight);
        final SparseMatrix buffers = XGBoostUtils.allocateCSRMatrix(dims);
        final int keptRows = XGBoostUtils.initializeFromChunks(chunks, weight, di, buffers._rowHeaders, buffers._sparseData, buffers._colIndices, respIdx, resp, weights);
        return XGBoostUtils.toDMatrix(buffers, dims, keptRows, di);
    }

    /**
     * Wraps filled CSR buffers in a {@link DMatrix} with {@code di.fullN()} columns,
     * {@code actualRows + 1} row headers and the pre-computed non-zero count.
     *
     * @throws XGBoostError propagated from the XGBoost bindings
     */
    private static DMatrix toDMatrix(SparseMatrix sm, SparseMatrixDimensions smd, int actualRows, DataInfo di) throws XGBoostError {
        final int columnCount = di.fullN();
        final int headerCount = actualRows + 1;
        final DMatrix matrix = new DMatrix(sm._rowHeaders, sm._colIndices, sm._sparseData, DMatrix.SparseType.CSR, columnCount, headerCount, smd._nonZeroElementsCount);
        assert (matrix.rowNum() == (long)actualRows);
        return matrix;
    }

    /**
     * Fills the chunked CSR buffers ({@code rowHeaders}, {@code data}, {@code colIndex}) from the
     * rows of the listed chunks, and copies response/weight of every kept row.
     * Rows whose weight equals 0 are skipped entirely. Each categorical column contributes one
     * entry of value 1 at its categorical id; numeric columns contribute an entry only when the
     * value is not 0.
     *
     * @param f          frame whose espc() supplies the row range of each chunk
     * @param chunks     chunk indices to process
     * @param vecs       one reader per column, categoricals first, then numerics
     * @param w          weight reader, or null when rows are unweighted
     * @param di         data info describing the expanded column layout
     * @param rowHeaders output row-start offsets
     * @param data       output values
     * @param colIndex   output column indices
     * @param respVec    response reader
     * @param resp       output responses
     * @param weights    output weights
     * @return the number of rows kept
     */
    static int initalizeFromChunkIds(Frame f, int[] chunks, Vec.Reader[] vecs, Vec.Reader w, DataInfo di, long[][] rowHeaders, float[][] data, int[][] colIndex, Vec.Reader respVec, float[] resp, float[] weights) {
        int actualRows = 0;
        int nonZeroCount = 0;
        // (rowPointer, currentCol) address the next free cell of data/colIndex
        int rowPointer = 0;
        int currentCol = 0;
        int rwRow = 0;
        // (rowHeaderRowPointer, rowHeaderColPointer) address the next free cell of rowHeaders
        int rowHeaderRowPointer = 0;
        int rowHeaderColPointer = 0;
        // number of kept rows whose header has already been written
        int lastNonZeroRow = 0;
        int[] nArray = chunks;
        int n = nArray.length;
        for (int i = 0; i < n; ++i) {
            Integer chunk = nArray[i];
            // iterate the global row range [espc[chunk], espc[chunk + 1]) of this chunk
            for (long i2 = f.anyVec().espc()[chunk]; i2 < f.anyVec().espc()[chunk + 1]; ++i2) {
                int j;
                // zero-weight rows are dropped
                if (w != null && w.at(i2) == 0.0) continue;
                ++actualRows;
                // current header array is full - continue in the next one
                if (rowHeaderColPointer == SPARSE_MATRIX_DIM) {
                    rowHeaderColPointer = 0;
                    ++rowHeaderRowPointer;
                }
                boolean foundNonZero = false;
                for (j = 0; j < di._cats; ++j) {
                    // current value/index array is full - continue in the next one
                    if (currentCol == SPARSE_MATRIX_DIM) {
                        currentCol = 0;
                        ++rowPointer;
                    }
                    data[rowPointer][currentCol] = 1.0f;
                    if (!foundNonZero) {
                        foundNonZero = true;
                        // first entry of this row: write a header for every kept row that has none yet
                        // NOTE(review): this inner loop advances rowHeaderColPointer without the
                        // SPARSE_MATRIX_DIM roll-over check - confirm it can only run once per row here.
                        for (int k = lastNonZeroRow; k < actualRows; ++k) {
                            rowHeaders[rowHeaderRowPointer][rowHeaderColPointer++] = nonZeroCount;
                        }
                        lastNonZeroRow = actualRows;
                    }
                    colIndex[rowPointer][currentCol++] = vecs[j].isNA(i2) ? di.getCategoricalId(j, Double.NaN) : di.getCategoricalId(j, (double)vecs[j].at8(i2));
                    ++nonZeroCount;
                }
                for (j = 0; j < di._nums; ++j) {
                    float val;
                    if (currentCol == SPARSE_MATRIX_DIM) {
                        currentCol = 0;
                        ++rowPointer;
                    }
                    // zeros are not stored
                    if ((val = (float)vecs[di._cats + j].at(i2)) == 0.0f) continue;
                    data[rowPointer][currentCol] = val;
                    // numeric columns follow all expanded categorical columns
                    colIndex[rowPointer][currentCol++] = di._catOffsets[di._catOffsets.length - 1] + j;
                    if (!foundNonZero) {
                        foundNonZero = true;
                        // same header catch-up as in the categorical loop above
                        for (int k = lastNonZeroRow; k < actualRows; ++k) {
                            rowHeaders[rowHeaderRowPointer][rowHeaderColPointer++] = nonZeroCount;
                        }
                        lastNonZeroRow = actualRows;
                    }
                    ++nonZeroCount;
                }
                rwRow = XGBoostUtils.setResponseAndWeight(w, resp, weights, respVec, rwRow, i2);
            }
        }
        // close the header array: any kept rows still without a header plus the terminating entry
        // all point at the final non-zero count
        for (int k = lastNonZeroRow; k <= actualRows; ++k) {
            if (rowHeaderColPointer == SPARSE_MATRIX_DIM) {
                rowHeaderColPointer = 0;
                ++rowHeaderRowPointer;
            }
            rowHeaders[rowHeaderRowPointer][rowHeaderColPointer++] = nonZeroCount;
        }
        return actualRows;
    }

    /**
     * Fills the chunked CSR buffers ({@code rowHeaders}, {@code data}, {@code colIndex}) from the
     * rows of one set of chunks, and copies response/weight of every kept row.
     * Rows whose weight equals 0 are skipped entirely. Each categorical column contributes one
     * entry of value 1 at its categorical id; numeric columns contribute an entry only when the
     * value is not 0.
     *
     * @param chunks     source chunks, categoricals first, then numerics
     * @param weight     index of the weight chunk, or -1 when there is none
     * @param di         data info describing the expanded column layout
     * @param rowHeaders output row-start offsets
     * @param data       output values
     * @param colIndex   output column indices
     * @param respIdx    index of the response chunk
     * @param resp       output responses
     * @param weights    output weights
     * @return the number of rows kept
     */
    private static int initializeFromChunks(Chunk[] chunks, int weight, DataInfo di, long[][] rowHeaders, float[][] data, int[][] colIndex, int respIdx, float[] resp, float[] weights) {
        int actualRows = 0;
        int nonZeroCount = 0;
        // (rowPointer, currentCol) address the next free cell of data/colIndex
        int rowPointer = 0;
        int currentCol = 0;
        int rwRow = 0;
        // (rowHeaderRowPointer, rowHeaderColPointer) address the next free cell of rowHeaders
        int rowHeaderRowPointer = 0;
        int rowHeaderColPointer = 0;
        // number of kept rows whose header has already been written
        int lastNonZeroRow = 0;
        for (int i = 0; i < chunks[0].len(); ++i) {
            int j;
            // zero-weight rows are dropped
            if (weight != -1 && chunks[weight].atd(i) == 0.0) continue;
            ++actualRows;
            // current header array is full - continue in the next one
            if (rowHeaderColPointer == SPARSE_MATRIX_DIM) {
                rowHeaderColPointer = 0;
                ++rowHeaderRowPointer;
            }
            boolean foundNonZero = false;
            for (j = 0; j < di._cats; ++j) {
                // current value/index array is full - continue in the next one
                if (currentCol == SPARSE_MATRIX_DIM) {
                    currentCol = 0;
                    ++rowPointer;
                }
                data[rowPointer][currentCol] = 1.0f;
                if (!foundNonZero) {
                    foundNonZero = true;
                    // first entry of this row: write a header for every kept row that has none yet
                    // NOTE(review): this inner loop advances rowHeaderColPointer without the
                    // SPARSE_MATRIX_DIM roll-over check - confirm it can only run once per row here.
                    for (int k = lastNonZeroRow; k < actualRows; ++k) {
                        rowHeaders[rowHeaderRowPointer][rowHeaderColPointer++] = nonZeroCount;
                    }
                    lastNonZeroRow = actualRows;
                }
                colIndex[rowPointer][currentCol++] = chunks[j].isNA(i) ? di.getCategoricalId(j, Double.NaN) : di.getCategoricalId(j, (double)chunks[j].at8(i));
                ++nonZeroCount;
            }
            for (j = 0; j < di._nums; ++j) {
                float val;
                if (currentCol == SPARSE_MATRIX_DIM) {
                    currentCol = 0;
                    ++rowPointer;
                }
                // zeros are not stored
                if ((val = (float)chunks[di._cats + j].atd(i)) == 0.0f) continue;
                data[rowPointer][currentCol] = val;
                // numeric columns follow all expanded categorical columns
                colIndex[rowPointer][currentCol++] = di._catOffsets[di._catOffsets.length - 1] + j;
                if (!foundNonZero) {
                    foundNonZero = true;
                    // same header catch-up as in the categorical loop above
                    for (int k = lastNonZeroRow; k < actualRows; ++k) {
                        rowHeaders[rowHeaderRowPointer][rowHeaderColPointer++] = nonZeroCount;
                    }
                    lastNonZeroRow = actualRows;
                }
                ++nonZeroCount;
            }
            rwRow = XGBoostUtils.setResponseAndWeight(chunks, respIdx, weight, resp, weights, rwRow, i);
        }
        // close the header array: any kept rows still without a header plus the terminating entry
        // all point at the final non-zero count
        for (int k = lastNonZeroRow; k <= actualRows; ++k) {
            if (rowHeaderColPointer == SPARSE_MATRIX_DIM) {
                rowHeaderColPointer = 0;
                ++rowHeaderRowPointer;
            }
            rowHeaders[rowHeaderRowPointer][rowHeaderColPointer++] = nonZeroCount;
        }
        return actualRows;
    }

    /**
     * Allocates the three chunked CSR buffers for the given dimensions. Every buffer is a 2-D
     * array whose inner arrays hold {@code SPARSE_MATRIX_DIM} elements, except for a shorter last
     * inner array that holds the remainder (omitted when the remainder is 0). Values and column
     * indices are sized by the non-zero count, row headers by the row-header count.
     *
     * @param sparseMatrixDimensions element counts to allocate for
     * @return the freshly allocated buffers
     */
    protected static SparseMatrix allocateCSRMatrix(SparseMatrixDimensions sparseMatrixDimensions) {
        final long nonZeros = sparseMatrixDimensions._nonZeroElementsCount;
        final long headers = sparseMatrixDimensions._rowHeadersCount;

        final int fullValueRows = (int)(nonZeros / (long)SPARSE_MATRIX_DIM);
        final int valueTail = (int)(nonZeros % (long)SPARSE_MATRIX_DIM);
        final int fullHeaderRows = (int)(headers / (long)SPARSE_MATRIX_DIM);
        final int headerTail = (int)(headers % (long)SPARSE_MATRIX_DIM);

        // values
        final float[][] values = new float[valueTail == 0 ? fullValueRows : fullValueRows + 1][];
        for (int r = 0; r < fullValueRows; ++r) {
            values[r] = MemoryManager.malloc4f(SPARSE_MATRIX_DIM);
        }
        if (valueTail > 0) {
            values[values.length - 1] = MemoryManager.malloc4f(valueTail);
        }

        // row headers
        final long[][] rowHeaders = new long[headerTail == 0 ? fullHeaderRows : fullHeaderRows + 1][];
        for (int r = 0; r < fullHeaderRows; ++r) {
            rowHeaders[r] = MemoryManager.malloc8(SPARSE_MATRIX_DIM);
        }
        if (headerTail > 0) {
            rowHeaders[rowHeaders.length - 1] = MemoryManager.malloc8(headerTail);
        }

        // column indices mirror the shape of the values buffer
        final int[][] columns = new int[valueTail == 0 ? fullValueRows : fullValueRows + 1][];
        for (int r = 0; r < fullValueRows; ++r) {
            columns[r] = MemoryManager.malloc4(SPARSE_MATRIX_DIM);
        }
        if (valueTail > 0) {
            columns[columns.length - 1] = MemoryManager.malloc4(valueTail);
        }

        return new SparseMatrix(values, rowHeaders, columns);
    }

    /**
     * Counts the stored entries and row headers a CSR matrix built from these chunks needs.
     * Rows with weight 0 are ignored; every kept row contributes {@code di._cats} entries plus one
     * per numeric value that is not 0. The header count is the number of kept rows plus one.
     *
     * @param chunks         source chunks, categoricals first, then numerics
     * @param di             data info describing the column layout
     * @param weightColIndex index of the weight chunk, or -1 when there is none
     * @return the computed dimensions
     */
    protected static SparseMatrixDimensions calculateCSRMatrixDimensions(Chunk[] chunks, DataInfo di, int weightColIndex) {
        long entries = 0L;
        long keptRows = 0L;
        for (int row = 0; row < chunks[0].len(); ++row) {
            final boolean dropped = weightColIndex != -1 && chunks[weightColIndex].atd(row) == 0.0;
            if (!dropped) {
                keptRows++;
                entries += (long)di._cats;
                for (int n = 0; n < di._nums; ++n) {
                    if (chunks[di._cats + n].atd(row) != 0.0) {
                        entries++;
                    }
                }
            }
        }
        return new SparseMatrixDimensions(entries, keptRows + 1L);
    }

    /**
     * Frame-based variant: counts entries and rows per chunk with a
     * {@code CalculateCSRMatrixDimensionsMrFun} run through {@code LocalMR}, then sums the
     * per-chunk results. The header count is the total row count plus one.
     *
     * @param f        source frame
     * @param chunkIds chunk indices to count
     * @param w        optional weight vec
     * @param di       data info describing the column layout
     * @return the computed dimensions
     */
    static SparseMatrixDimensions calculateCSRMatrixDimensions(Frame f, int[] chunkIds, Vec w, DataInfo di) {
        final CalculateCSRMatrixDimensionsMrFun counter = new CalculateCSRMatrixDimensionsMrFun(f, di, w, chunkIds);
        final LocalMR job = new LocalMR((MrFun)counter, chunkIds.length);
        ((LocalMR)H2O.submitTask((H2O.H2OCountedCompleter)job)).join();
        final long entries = ArrayUtils.sum((long[])counter._nonZeroElementsCounts);
        final long headers = ArrayUtils.sum((long[])counter._rowIndicesCounts) + 1L;
        return new SparseMatrixDimensions(entries, headers);
    }

    /**
     * Allocates a dense matrix with {@code rowCount} rows and {@code dataInfo.fullN()} columns.
     */
    private static BigDenseMatrix allocateDenseMatrix(int rowCount, DataInfo dataInfo) {
        final int columnCount = dataInfo.fullN();
        return new BigDenseMatrix(rowCount, columnCount);
    }

    /**
     * Collects the names of all expanded columns of {@code di} and flags the ones that lie below
     * the last categorical offset as one-hot encoded.
     *
     * @param di data info supplying coefficient names and categorical offsets
     * @return names and one-hot flags, both of length {@code di.fullN()}
     */
    public static FeatureProperties assembleFeatureNames(DataInfo di) {
        final String[] coefNames = di.coefNames();
        assert (coefNames.length == di.fullN());
        final int expandedCats = di._catOffsets[di._catOffsets.length - 1];
        final String[] names = new String[di.fullN()];
        final boolean[] oneHot = new boolean[di.fullN()];
        for (int col = 0; col < di.fullN(); ++col) {
            names[col] = coefNames[col];
            if (col < expandedCats) {
                oneHot[col] = true;
            }
        }
        return new FeatureProperties(names, oneHot);
    }

    /**
     * Aggregates per-feature statistics from a textual model dump. Every line of the form
     * {@code ...[<fid><...]<k=v,k=v,...>} counts as one use of feature {@code fid} (the text
     * before the first {@code '<'} inside the brackets); its {@code gain=} and {@code cover=}
     * values are summed per feature. Lines without a bracketed part are ignored.
     *
     * @param modelDump one string per tree, nodes separated by newlines
     * @return accumulated score per feature id
     */
    static Map<String, FeatureScore> parseFeatureScores(String[] modelDump) {
        final HashMap<String, FeatureScore> scores = new HashMap<String, FeatureScore>();
        for (int t = 0; t < modelDump.length; ++t) {
            final String[] nodes = modelDump[t].split("\n");
            for (int n = 0; n < nodes.length; ++n) {
                final String[] aroundOpen = nodes[n].split("\\[", 2);
                if (aroundOpen.length < 2) {
                    continue;
                }
                final String[] aroundClose = aroundOpen[1].split("\\]", 2);
                if (aroundClose.length < 2) {
                    continue;
                }
                final String fid = aroundClose[0].split("<")[0];
                final FeatureScore current = new FeatureScore();
                final String[] attributes = aroundClose[1].split(",");
                for (int a = 0; a < attributes.length; ++a) {
                    final String attribute = attributes[a];
                    if (attribute.startsWith("gain=")) {
                        current._gain = Float.parseFloat(attribute.substring("gain=".length()));
                    } else if (attribute.startsWith("cover=")) {
                        current._cover = Float.parseFloat(attribute.substring("cover=".length()));
                    }
                }
                current._frequency = 1;
                if (!scores.containsKey(fid)) {
                    scores.put(fid, current);
                } else {
                    scores.get(fid).add(current);
                }
            }
        }
        return scores;
    }

    /**
     * Accumulated statistics of one feature: how often it was used and the sums of the gain and
     * cover values seen for it.
     */
    static class FeatureScore {
        static final String GAIN_KEY = "gain";
        static final String COVER_KEY = "cover";

        int _frequency;
        float _gain;
        float _cover;

        FeatureScore() {
            // all counters start at zero
        }

        /** Adds the counters of {@code fs} to this instance. */
        void add(FeatureScore fs) {
            _frequency = _frequency + fs._frequency;
            _gain = _gain + fs._gain;
            _cover = _cover + fs._cover;
        }
    }

    /**
     * Pairs feature names with a per-feature flag telling whether the feature is one-hot encoded.
     * The arrays are stored by reference, not copied.
     */
    static class FeatureProperties {
        public String[] _names;
        public boolean[] _oneHotEncoded;

        public FeatureProperties(String[] names, boolean[] oneHotEncoded) {
            _names = names;
            _oneHotEncoded = oneHotEncoded;
        }
    }

    /**
     * Holder for the three chunked CSR buffers: stored values, row headers and column indices.
     * The arrays are stored by reference, not copied.
     */
    protected static final class SparseMatrix {
        protected final float[][] _sparseData;
        protected final long[][] _rowHeaders;
        protected final int[][] _colIndices;

        public SparseMatrix(float[][] sparseData, long[][] rowIndices, int[][] colIndices) {
            _sparseData = sparseData;
            _rowHeaders = rowIndices;
            _colIndices = colIndices;
        }
    }

    /**
     * Element counts needed to allocate CSR buffers: number of stored entries and number of
     * row headers.
     */
    protected static final class SparseMatrixDimensions {
        protected final long _nonZeroElementsCount;
        protected final long _rowHeadersCount;

        public SparseMatrixDimensions(long nonZeroElementsCount, long rowIndicesCount) {
            _nonZeroElementsCount = nonZeroElementsCount;
            _rowHeadersCount = rowIndicesCount;
        }
    }

    /**
     * Per-chunk counter used to size a CSR matrix. For the i-th listed chunk, {@code map(i)}
     * stores the number of rows with non-zero weight in {@code _rowIndicesCounts[i]} and the
     * number of entries those rows need in {@code _nonZeroElementsCounts[i]}
     * ({@code _di._cats} per row plus one per numeric value that is not 0).
     */
    private static class CalculateCSRMatrixDimensionsMrFun
    extends MrFun<CalculateCSRMatrixDimensionsMrFun> {
        private Frame _f;
        private DataInfo _di;
        private Vec _w;
        private int[] _chunkIds;
        private long[] _rowIndicesCounts;
        private long[] _nonZeroElementsCounts;

        CalculateCSRMatrixDimensionsMrFun(Frame f, DataInfo di, Vec w, int[] chunkIds) {
            this._f = f;
            this._di = di;
            this._w = w;
            this._chunkIds = chunkIds;
            this._rowIndicesCounts = new long[chunkIds.length];
            this._nonZeroElementsCounts = new long[chunkIds.length];
        }

        /**
         * Counts rows and entries of the chunk at position {@code i} of the chunk id list.
         *
         * @param i position within the chunk id array (not the chunk index itself)
         */
        protected void map(int i) {
            int cidx = this._chunkIds[i];
            long rowIndicesCount = 0L;
            long nonZeroElementsCount = 0L;
            if (this._di._nums == 0) {
                // only categoricals: every kept row has exactly _cats entries, no values need to be read
                if (this._w == null) {
                    rowIndicesCount = this._f.anyVec().chunkForChunkIdx((int)cidx)._len;
                    nonZeroElementsCount = rowIndicesCount * (long)this._di._cats;
                } else {
                    Chunk ws = this._w.chunkForChunkIdx(cidx);
                    int nzWeights = 0;
                    for (int r = 0; r < ws._len; ++r) {
                        if (ws.atd(r) == 0.0) continue;
                        ++nzWeights;
                    }
                    rowIndicesCount += (long)nzWeights;
                    // widen before multiplying so the product cannot wrap around in int arithmetic
                    nonZeroElementsCount += (long)nzWeights * (long)this._di._cats;
                }
            } else {
                // numeric columns sit after the categorical ones in the frame
                Chunk[] cs = new Chunk[this._di._nums];
                for (int c = 0; c < cs.length; ++c) {
                    cs[c] = this._f.vec(this._di._cats + c).chunkForChunkIdx(cidx);
                }
                Chunk ws = this._w != null ? this._w.chunkForChunkIdx(cidx) : null;
                for (int r = 0; r < cs[0]._len; ++r) {
                    // zero-weight rows are not part of the matrix
                    if (ws != null && ws.atd(r) == 0.0) continue;
                    ++rowIndicesCount;
                    nonZeroElementsCount += (long)this._di._cats;
                    for (int j = 0; j < this._di._nums; ++j) {
                        if (cs[j].atd(r) == 0.0) continue;
                        ++nonZeroElementsCount;
                    }
                }
            }
            this._rowIndicesCounts[i] = rowIndicesCount;
            this._nonZeroElementsCounts[i] = nonZeroElementsCount;
        }
    }

    /**
     * Mutable pair of an integer position and a double value.
     * NOTE(review): not referenced anywhere in the visible part of this file - confirm it is still needed.
     */
    static class SparseItem {
        int pos;
        double val;

        SparseItem() {
            // fields keep their default values (0 and 0.0)
        }
    }

    /**
     * Per-chunk writer for the dense matrix. {@code map(id)} copies the rows with non-zero weight
     * of the id-th listed chunk into the shared matrix, response and weight arrays, starting at
     * output row {@code _offsets[id]}, and records how many rows it wrote.
     */
    private static class WriteDenseChunkFun
    extends MrFun<WriteDenseChunkFun> {
        private final Frame _f;
        private final int[] _chunks;
        private final int[] _offsets;
        private final Vec _weightsVec;
        private final Vec _respVec;
        private final DataInfo _di;
        private final BigDenseMatrix _data;
        private final float[] _resp;
        private final float[] _weights;
        private int[] _nRowsByChunk;

        private WriteDenseChunkFun(Frame f, int[] chunks, int[] offsets, Vec weightsVec, Vec respVec, DataInfo di, BigDenseMatrix data, float[] resp, float[] weights) {
            this._f = f;
            this._chunks = chunks;
            this._offsets = offsets;
            this._weightsVec = weightsVec;
            this._respVec = respVec;
            this._di = di;
            this._data = data;
            this._resp = resp;
            this._weights = weights;
            this._nRowsByChunk = new int[chunks.length];
        }

        /**
         * Writes the chunk at position {@code id} of the chunk list.
         *
         * @param id position within the chunk id array (not the chunk index itself)
         */
        protected void map(int id) {
            int chunkIdx = this._chunks[id];
            Chunk[] chks = new Chunk[this._f.numCols()];
            for (int c = 0; c < chks.length; ++c) {
                chks[c] = this._f.vec(c).chunkForChunkIdx(chunkIdx);
            }
            Chunk weightsChk = this._weightsVec != null ? this._weightsVec.chunkForChunkIdx(chunkIdx) : null;
            Chunk respChk = this._respVec.chunkForChunkIdx(chunkIdx);
            // Flat start position of this chunk's first row. Widen before multiplying: the product
            // of row offset and column count can exceed the int range for large matrices.
            long idx = (long)this._offsets[id] * (long)this._data.ncol;
            int actualRows = 0;
            for (int i = 0; i < chks[0]._len; ++i) {
                // zero-weight rows are skipped and consume no output row
                if (weightsChk != null && weightsChk.atd(i) == 0.0) continue;
                idx = XGBoostUtils.writeDenseRow(this._di, chks, i, this._data, idx);
                this._resp[this._offsets[id] + actualRows] = (float)respChk.atd(i);
                if (weightsChk != null) {
                    this._weights[this._offsets[id] + actualRows] = (float)weightsChk.atd(i);
                }
                ++actualRows;
            }
            // the chunk must end exactly where the next chunk's rows begin
            assert (idx == (long)this._offsets[id + 1] * (long)this._data.ncol);
            this._nRowsByChunk[id] = actualRows;
        }

        /** Returns the sum of the per-chunk row counts recorded by {@code map}. */
        private long getTotalRows() {
            long totalRows = 0L;
            for (int r : this._nRowsByChunk) {
                totalRows += (long)r;
            }
            return totalRows;
        }
    }
}

