/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.xgboost.task;

import hex.DataInfo;
import hex.tree.xgboost.XGBoostModel;
import hex.tree.xgboost.XGBoostModelInfo;
import hex.tree.xgboost.XGBoostOutput;
import hex.tree.xgboost.XGBoostUtils;
import hex.tree.xgboost.exec.XGBoostHttpClient;
import hex.tree.xgboost.matrix.MatrixFactoryUtils;
import hex.tree.xgboost.matrix.SparseMatrixDimensions;
import hex.tree.xgboost.matrix.SparseMatrixFactory;
import hex.tree.xgboost.remote.RemoteXGBoostUploadServlet;
import hex.tree.xgboost.task.AbstractXGBoostTask;
import java.util.Optional;
import org.apache.log4j.Logger;
import water.BootstrapFreezable;
import water.H2O;
import water.Iced;
import water.LocalMR;
import water.MemoryManager;
import water.MrFun;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.VecUtils;

/**
 * Uploads the node-local part of a training frame to a remote XGBoost process over HTTP.
 *
 * <p>When executed on a node, the task takes the chunks of {@code train} that are stored on that node,
 * converts them either to a CSR sparse matrix or to a dense row-major matrix (chosen by the model
 * output's {@code _sparse} flag) and uploads the matrix chunk by chunk to
 * {@code remoteNodes[H2O.SELF.index()] + contextPath}. Response, weight and offset values are gathered
 * into a {@link MatrixData} object that is uploaded after all matrix chunks have been sent.
 *
 * <p>Rows whose weight is exactly {@code 0} are not uploaded.
 */
public class XGBoostUploadMatrixTask extends AbstractXGBoostTask<XGBoostUploadMatrixTask> {
    private static final Logger LOG = Logger.getLogger(XGBoostUploadMatrixTask.class);

    /** Base address of the remote XGBoost endpoint for each H2O node, indexed by node index. */
    private final String[] remoteNodes;
    private final boolean https;
    private final String contextPath;
    private final String userName;
    // NOTE(review): credentials are held as plain fields of this task; never log the task itself.
    private final String password;
    private final Frame train;
    private final XGBoostModelInfo modelInfo;
    private final XGBoostModel.XGBoostParameters parms;
    /** {@code true} to upload the matrix in CSR layout, {@code false} for a dense layout. */
    private final boolean sparse;

    /**
     * @param model       model being trained; provides the key, model info, parameters and sparse flag
     * @param train       training frame whose node-local chunks are uploaded
     * @param frameNodes  forwarded to {@link AbstractXGBoostTask} (presumably marks the nodes holding
     *                    frame data - confirm against AbstractXGBoostTask)
     * @param remoteNodes base address of the remote endpoint per H2O node, indexed by {@code H2O.SELF.index()}
     * @param https       whether the remote endpoint is contacted over HTTPS
     * @param contextPath path appended to the per-node base address
     * @param userName    user name for the remote endpoint
     * @param password    password for the remote endpoint
     */
    public XGBoostUploadMatrixTask(XGBoostModel model, Frame train, boolean[] frameNodes, String[] remoteNodes, boolean https, String contextPath, String userName, String password) {
        super(model._key, frameNodes);
        this.remoteNodes = remoteNodes;
        this.https = https;
        this.contextPath = contextPath;
        this.userName = userName;
        this.password = password;
        this.modelInfo = model.model_info();
        this.parms = (XGBoostModel.XGBoostParameters) model._parms;
        this.sparse = ((XGBoostOutput) model._output)._sparse;
        this.train = train;
    }

    /**
     * Creates an HTTP client for the remote XGBoost endpoint assigned to the current H2O node.
     * A new client is created on every call; each per-chunk upload obtains its own instance.
     */
    private XGBoostHttpClient makeClient() {
        String remoteUri = remoteNodes[H2O.SELF.index()] + contextPath;
        return new XGBoostHttpClient(remoteUri, https, userName, password);
    }

    /**
     * Builds the node-local matrix, uploads it and finally uploads the response/weights/offsets.
     *
     * @throws IllegalArgumentException if this node holds more than {@link Integer#MAX_VALUE} rows
     */
    @Override
    protected void execute() {
        XGBoostHttpClient client = makeClient();
        LOG.info("Starting matrix upload for " + _modelKey);
        long start = System.currentTimeMillis();
        DataInfo dataInfo = modelInfo.dataInfo();
        assert dataInfo != null;
        int[] chunks = VecUtils.getLocalChunkIds(train.anyVec());
        Vec responseVec = train.vec(parms._response_column);
        Vec weightVec = train.vec(parms._weights_column);
        Vec offsetsVec = train.vec(parms._offset_column);
        int[] nRowsByChunk = new int[chunks.length];
        long nRowsL = XGBoostUtils.sumChunksLength(chunks, responseVec, Optional.ofNullable(weightVec), nRowsByChunk);
        if (nRowsL > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("XGBoost currently doesn't support datasets with more than "
                    + Integer.MAX_VALUE + " rows per node. "
                    + "To train a XGBoost model on this dataset add more nodes to your H2O cluster and use distributed training.");
        }
        int nRows = (int) nRowsL;
        MatrixData matrixData = new MatrixData(nRows, weightVec, offsetsVec);
        if (sparse) {
            LOG.debug("Treating matrix as sparse.");
            // shape is only reported for the sparse layout; the dense layout sends its own dimensions
            matrixData.shape = dataInfo.fullN();
            matrixData.actualRows = csr(client, chunks, weightVec, offsetsVec, responseVec, dataInfo, matrixData.resp, matrixData.weights, matrixData.offsets);
        } else {
            LOG.debug("Treating matrix as dense.");
            matrixData.actualRows = dense(client, chunks, nRows, nRowsByChunk, weightVec, offsetsVec, responseVec, dataInfo, matrixData.resp, matrixData.weights, matrixData.offsets);
        }
        client.uploadObject(_modelKey, RemoteXGBoostUploadServlet.RequestType.matrixData, matrixData);
        LOG.debug("Matrix upload finished in " + (System.currentTimeMillis() - start) / 1000.0 + "s");
    }

    /**
     * Uploads the dense matrix dimensions and then every local chunk as a {@link DenseMatrixChunk}.
     * Response/weight/offset values are written into the given arrays as a side effect.
     *
     * @return number of rows actually written (rows with zero weight are skipped)
     */
    private int dense(XGBoostHttpClient client, int[] chunksIds, int nRows, int[] nRowsByChunk, Vec weightVec, Vec offsetsVec, Vec responseVec, DataInfo dataInfo, float[] resp, float[] weights, float[] offsets) {
        // rowOffsets[i] = first output row of local chunk i; rowOffsets[last] = total row count
        int[] rowOffsets = new int[nRowsByChunk.length + 1];
        for (int i = 0; i < nRowsByChunk.length; ++i) {
            rowOffsets[i + 1] = rowOffsets[i] + nRowsByChunk[i];
        }
        client.uploadObject(_modelKey, RemoteXGBoostUploadServlet.RequestType.denseMatrixDimensions,
                new DenseMatrixDimensions(nRows, dataInfo.fullN(), rowOffsets));
        UploadDenseChunkFun writeFun = new UploadDenseChunkFun(train, chunksIds, rowOffsets, weightVec, offsetsVec, responseVec, dataInfo, resp, weights, offsets);
        H2O.submitTask(new LocalMR(writeFun, chunksIds.length)).join();
        return writeFun.getTotalRows();
    }

    /**
     * Computes and uploads the CSR dimensions, then uploads every local chunk as a {@link SparseMatrixChunk}.
     * Response/weight/offset values are written into the given arrays as a side effect.
     *
     * @return number of rows actually written (rows with zero weight are skipped)
     */
    private int csr(XGBoostHttpClient client, int[] chunksIds, Vec weightVec, Vec offsetsVec, Vec responseVec, DataInfo dataInfo, float[] resp, float[] weights, float[] offsets) {
        SparseMatrixDimensions dimensions = SparseMatrixFactory.calculateCSRMatrixDimensions(train, chunksIds, weightVec, dataInfo);
        client.uploadObject(_modelKey, RemoteXGBoostUploadServlet.RequestType.sparseMatrixDimensions, dimensions);
        UploadSparseMatrixFun fun = new UploadSparseMatrixFun(train, chunksIds, weightVec, offsetsVec, dataInfo, dimensions, responseVec, resp, weights, offsets);
        H2O.submitTask(new LocalMR(fun, chunksIds.length)).join();
        return ArrayUtils.sum(fun._actualRows);
    }

    /** Converts one local chunk per {@code map} call into CSR form and uploads it. */
    private class UploadSparseMatrixFun extends MrFun<UploadSparseMatrixFun> {
        private final Frame _frame;
        private final int[] _chunks;
        private final Vec _weightVec;
        private final Vec _offsetsVec;
        private final DataInfo _di;
        private final SparseMatrixDimensions _dims;
        private final Vec _respVec;
        private final float[] _resp;
        private final float[] _weights;
        private final float[] _offsets;
        /** Number of uploaded (non-zero-weight) rows per local chunk index. */
        private final int[] _actualRows;

        UploadSparseMatrixFun(Frame frame, int[] chunks, Vec weightVec, Vec offsetVec, DataInfo di, SparseMatrixDimensions dimensions, Vec respVec, float[] resp, float[] weights, float[] offsets) {
            this._actualRows = new int[chunks.length];
            this._frame = frame;
            this._chunks = chunks;
            this._weightVec = weightVec;
            this._offsetsVec = offsetVec;
            this._di = di;
            this._dims = dimensions;
            this._respVec = respVec;
            this._resp = resp;
            this._weights = weights;
            this._offsets = offsets;
        }

        /**
         * Builds and uploads the CSR piece for the {@code chunkIdx}-th local chunk.
         * Row-header entries are global non-zero counts, starting from this chunk's preceding count.
         */
        protected void map(int chunkIdx) {
            int chunk = _chunks[chunkIdx];
            long nonZeroCount = _dims._precedingNonZeroElementsCounts[chunkIdx];
            int rwRow = _dims._precedingRowCounts[chunkIdx];
            int rowHeaderSize;
            long dataSize;
            if (chunkIdx == _dims._precedingNonZeroElementsCounts.length - 1) {
                // last local chunk: sizes are derived from the totals stored in _dims
                rowHeaderSize = _dims._rowHeadersCount - rwRow;
                dataSize = _dims._nonZeroElementsCount - nonZeroCount;
            } else {
                // +1 for the closing row-header entry written after the loop
                rowHeaderSize = _dims._precedingRowCounts[chunkIdx + 1] - rwRow + 1;
                dataSize = _dims._precedingNonZeroElementsCounts[chunkIdx + 1] - nonZeroCount;
            }
            assert dataSize < Integer.MAX_VALUE;
            Chunk weightChunk = _weightVec != null ? _weightVec.chunkForChunkIdx(chunk) : null;
            Chunk offsetChunk = _offsetsVec != null ? _offsetsVec.chunkForChunkIdx(chunk) : null;
            Chunk respChunk = _respVec.chunkForChunkIdx(chunk);
            Vec[] vecs = _frame.vecs();
            Chunk[] featChunks = new Chunk[vecs.length];
            for (int i = 0; i < featChunks.length; ++i) {
                featChunks[i] = vecs[i].chunkForChunkIdx(chunk);
            }
            // numeric columns are laid out after all one-hot encoded categorical columns
            int numColsOffset = _di._catOffsets[_di._catOffsets.length - 1];
            SparseMatrixChunk chunkData = new SparseMatrixChunk(chunkIdx, rowHeaderSize, (int) dataSize);
            int dataIndex = 0;
            int rowHeaderIndex = 0;
            for (int i = 0; i < respChunk._len; ++i) {
                if (weightChunk != null && weightChunk.atd(i) == 0.0) {
                    continue; // zero-weight rows are not part of the matrix
                }
                chunkData.rowHeader[rowHeaderIndex++] = nonZeroCount;
                _actualRows[chunkIdx]++;
                for (int j = 0; j < _di._cats; ++j) {
                    chunkData.data[dataIndex] = 1.0f;
                    chunkData.colIndices[dataIndex] = featChunks[j].isNA(i)
                            ? _di.getCategoricalId(j, Double.NaN)
                            : _di.getCategoricalId(j, (double) featChunks[j].at8(i));
                    ++dataIndex;
                    ++nonZeroCount;
                }
                for (int j = 0; j < _di._nums; ++j) {
                    float val = (float) featChunks[_di._cats + j].atd(i);
                    if (val == 0.0f) {
                        continue; // zeros are implicit in CSR; NaN != 0 so missing values are kept
                    }
                    chunkData.data[dataIndex] = val;
                    chunkData.colIndices[dataIndex] = numColsOffset + j;
                    ++dataIndex;
                    ++nonZeroCount;
                }
                rwRow = MatrixFactoryUtils.setResponseWeightAndOffset(weightChunk, offsetChunk, respChunk, _resp, _weights, _offsets, rwRow, i);
            }
            // closing entry: end offset of the last row of this chunk
            chunkData.rowHeader[rowHeaderIndex] = nonZeroCount;
            XGBoostUploadMatrixTask.this.makeClient().uploadObject(XGBoostUploadMatrixTask.this._modelKey, RemoteXGBoostUploadServlet.RequestType.sparseMatrixChunk, chunkData);
        }
    }

    /** CSR piece of one local chunk, as sent to the remote XGBoost endpoint. */
    public static class SparseMatrixChunk extends Iced<SparseMatrixChunk> implements BootstrapFreezable<SparseMatrixChunk> {
        /** Local chunk index this piece belongs to. */
        public final int id;
        /** Global non-zero offsets at which each row starts, plus one closing entry. */
        public final long[] rowHeader;
        /** Non-zero values. */
        public final float[] data;
        /** Column index of each value in {@link #data}. */
        public final int[] colIndices;

        SparseMatrixChunk(int id, int rowHeaderSize, int dataSize) {
            this.id = id;
            this.rowHeader = new long[rowHeaderSize];
            this.data = new float[dataSize];
            this.colIndices = new int[dataSize];
        }
    }

    /** Converts one local chunk per {@code map} call into a dense row-major block and uploads it. */
    private class UploadDenseChunkFun extends MrFun<UploadDenseChunkFun> {
        private final Frame _f;
        private final int[] _chunks;
        private final int[] _rowOffsets;
        private final Vec _weightsVec;
        private final Vec _offsetsVec;
        private final Vec _respVec;
        private final DataInfo _di;
        private final float[] _resp;
        private final float[] _weights;
        private final float[] _offsets;
        /** Number of uploaded (non-zero-weight) rows per local chunk index. */
        private final int[] _nRowsByChunk;

        private UploadDenseChunkFun(Frame f, int[] chunks, int[] rowOffsets, Vec weightsVec, Vec offsetsVec, Vec respVec, DataInfo di, float[] resp, float[] weights, float[] offsets) {
            this._f = f;
            this._chunks = chunks;
            this._rowOffsets = rowOffsets;
            this._weightsVec = weightsVec;
            this._offsetsVec = offsetsVec;
            this._respVec = respVec;
            this._di = di;
            this._resp = resp;
            this._weights = weights;
            this._offsets = offsets;
            this._nRowsByChunk = new int[chunks.length];
        }

        /** Builds and uploads the dense block for the {@code id}-th local chunk. */
        protected void map(int id) {
            int chunkIdx = _chunks[id];
            Chunk[] chks = new Chunk[_f.numCols()];
            for (int c = 0; c < chks.length; ++c) {
                chks[c] = _f.vec(c).chunkForChunkIdx(chunkIdx);
            }
            Chunk weightsChk = _weightsVec != null ? _weightsVec.chunkForChunkIdx(chunkIdx) : null;
            Chunk offsetsChk = _offsetsVec != null ? _offsetsVec.chunkForChunkIdx(chunkIdx) : null;
            Chunk respChk = _respVec.chunkForChunkIdx(chunkIdx);
            int rowStart = _rowOffsets[id];
            // multiplyExact: fail loudly instead of allocating a wrapped-around (wrong) size
            DenseMatrixChunk chunkData = new DenseMatrixChunk(id, Math.multiplyExact(_rowOffsets[id + 1] - rowStart, _di.fullN()));
            int idx = 0;
            int actualRows = 0;
            for (int i = 0; i < chks[0]._len; ++i) {
                if (weightsChk != null && weightsChk.atd(i) == 0.0) {
                    continue; // zero-weight rows are not part of the matrix
                }
                idx = writeDenseRow(_di, chks, i, chunkData.data, idx);
                int row = rowStart + actualRows;
                _resp[row] = (float) respChk.atd(i);
                if (weightsChk != null) {
                    _weights[row] = (float) weightsChk.atd(i);
                }
                if (offsetsChk != null) {
                    _offsets[row] = (float) offsetsChk.atd(i);
                }
                ++actualRows;
            }
            assert idx == chunkData.data.length : "idx should be " + chunkData.data.length + " but it is " + idx;
            _nRowsByChunk[id] = actualRows;
            XGBoostUploadMatrixTask.this.makeClient().uploadObject(XGBoostUploadMatrixTask.this._modelKey, RemoteXGBoostUploadServlet.RequestType.denseMatrixChunk, chunkData);
        }

        /**
         * Writes one row into {@code data} starting at {@code idx}: for each categorical column a block of
         * {@code _catOffsets[j + 1] - _catOffsets[j]} slots with a single {@code 1.0f} (the rest stay 0),
         * followed by one slot per numeric column (NA written as NaN).
         * NOTE(review): assumes getCategoricalId returns an id within
         * [_catOffsets[j], _catOffsets[j + 1]) for every value including NA - confirm.
         *
         * @return index just past the written row
         */
        private int writeDenseRow(DataInfo di, Chunk[] chunks, int rowInChunk, float[] data, int idx) {
            for (int j = 0; j < di._cats; ++j) {
                int len = di._catOffsets[j + 1] - di._catOffsets[j];
                double val = chunks[j].isNA(rowInChunk) ? Double.NaN : (double) chunks[j].at8(rowInChunk);
                int pos = di.getCategoricalId(j, val) - di._catOffsets[j];
                data[idx + pos] = 1.0f;
                idx += len;
            }
            for (int j = 0; j < di._nums; ++j) {
                Chunk numChunk = chunks[di._cats + j];
                data[idx++] = numChunk.isNA(rowInChunk) ? Float.NaN : (float) numChunk.atd(rowInChunk);
            }
            return idx;
        }

        /** @return total number of rows written across all local chunks */
        private int getTotalRows() {
            return ArrayUtils.sum(_nRowsByChunk);
        }
    }

    /** Dense row-major block of one local chunk, as sent to the remote XGBoost endpoint. */
    public static class DenseMatrixChunk extends Iced<DenseMatrixChunk> implements BootstrapFreezable<DenseMatrixChunk> {
        /** Local chunk index this block belongs to. */
        public final int id;
        /** Row-major values, {@code rows * cols} entries. */
        public final float[] data;

        DenseMatrixChunk(int id, int dataSize) {
            this.id = id;
            this.data = new float[dataSize];
        }
    }

    /** Shape of the node-local dense matrix plus the first row of each local chunk. */
    public static class DenseMatrixDimensions extends Iced<DenseMatrixDimensions> implements BootstrapFreezable<DenseMatrixDimensions> {
        public final int rows;
        public final int cols;
        /** Prefix sums of rows per local chunk; length is the number of chunks + 1. */
        public final int[] rowOffsets;

        public DenseMatrixDimensions(int rows, int cols, int[] rowOffsets) {
            this.rows = rows;
            this.cols = cols;
            this.rowOffsets = rowOffsets;
        }
    }

    /** Per-row labels and optional weights/offsets, uploaded after the matrix itself. */
    public static class MatrixData extends Iced<MatrixData> implements BootstrapFreezable<MatrixData> {
        public final float[] resp;
        /** {@code null} when the model has no weights column. */
        public final float[] weights;
        /** {@code null} when the model has no offset column. */
        public final float[] offsets;
        /** Number of rows actually uploaded. */
        public int actualRows;
        /** Number of matrix columns; only set for the sparse layout. */
        public int shape;

        MatrixData(int nRows, Vec weightVec, Vec offsetsVec) {
            this.resp = MemoryManager.malloc4f(nRows);
            this.weights = weightVec != null ? MemoryManager.malloc4f(nRows) : null;
            this.offsets = offsetsVec != null ? MemoryManager.malloc4f(nRows) : null;
        }
    }
}

