/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.xgboost.matrix;

import ai.h2o.xgboost4j.java.DMatrix;
import ai.h2o.xgboost4j.java.XGBoostError;
import hex.DataInfo;
import hex.tree.xgboost.matrix.MatrixFactoryUtils;
import hex.tree.xgboost.matrix.MatrixLoader;
import hex.tree.xgboost.matrix.SparseMatrix;
import hex.tree.xgboost.matrix.SparseMatrixDimensions;
import java.util.Arrays;
import java.util.Objects;
import water.H2O;
import water.LocalMR;
import water.MemoryManager;
import water.MrFun;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;

public class SparseMatrixFactory {
    /**
     * Builds a CSR matrix provider from selected chunks of a frame.
     *
     * <p>Works in three steps: computes the matrix dimensions, allocates the backing arrays,
     * then fills them from the chunks. The filling step is also handed {@code resp},
     * {@code weights} and {@code offsets} so they can be populated alongside the matrix.
     *
     * @param frame frame whose columns supply the features
     * @param chunksIds indices of the chunks to load
     * @param weightsVec weights column, or {@code null} when rows are not weighted
     * @param offsetsVec offsets column, or {@code null} when there is none
     * @param responseVec response column
     * @param di data info describing the categorical/numeric layout; {@code di.fullN()} is
     *           passed on as the matrix shape
     * @param resp output array for responses
     * @param weights output array for weights
     * @param offsets output array for offsets
     * @return provider wrapping the filled CSR arrays
     */
    public static MatrixLoader.DMatrixProvider csr(Frame frame, int[] chunksIds, Vec weightsVec, Vec offsetsVec, Vec responseVec, DataInfo di, float[] resp, float[] weights, float[] offsets) {
        SparseMatrixDimensions dims = calculateCSRMatrixDimensions(frame, chunksIds, weightsVec, di);
        SparseMatrix matrix = allocateCSRMatrix(dims);
        int loadedRows = initializeFromChunkIds(
            frame, chunksIds, weightsVec, offsetsVec, di, matrix, dims, responseVec, resp, weights, offsets);
        int shape = di.fullN();
        return toDMatrix(matrix, dims, loadedRows, shape, resp, weights, offsets);
    }

    /**
     * Builds a CSR {@link DMatrix} directly from an array of chunks.
     *
     * <p>Computes the dimensions, allocates the backing arrays, fills them from {@code chunks},
     * and finally materializes the matrix through the provider's {@code get()}.
     *
     * @param chunks chunks to read; categorical columns first, numeric columns after them
     * @param weight index of the weight chunk in {@code chunks}, or {@code -1} when rows are
     *               not weighted
     * @param respIdx index of the response chunk, forwarded to the response/weight/offset helper
     * @param offsetIdx index of the offset chunk, forwarded to the response/weight/offset helper
     * @param di data info describing the categorical/numeric layout; {@code di.fullN()} is
     *           passed on as the matrix shape
     * @param resp output array for responses
     * @param weights output array for weights
     * @param offsets output array for offsets
     * @return the constructed matrix
     * @throws XGBoostError if the matrix cannot be created
     */
    public static DMatrix csr(Chunk[] chunks, int weight, int respIdx, int offsetIdx, DataInfo di, float[] resp, float[] weights, float[] offsets) throws XGBoostError {
        SparseMatrixDimensions dims = calculateCSRMatrixDimensions(chunks, di, weight);
        SparseMatrix matrix = allocateCSRMatrix(dims);
        int loadedRows = initializeFromChunks(
            chunks, weight, di,
            matrix._rowHeaders, matrix._sparseData, matrix._colIndices,
            respIdx, resp, weights, offsetIdx, offsets);
        SparseDMatrixProvider provider = toDMatrix(matrix, dims, loadedRows, di.fullN(), resp, weights, offsets);
        return provider.get();
    }

    /**
     * Wraps already-filled CSR arrays in a {@link SparseDMatrixProvider}.
     *
     * @param sm matrix holding the row headers, column indices and values
     * @param smd dimensions; only its non-zero element count is used here
     * @param actualRows number of rows actually written into the matrix
     * @param shape shape value handed to the provider
     * @param resp responses
     * @param weights weights
     * @param offsets offsets
     * @return provider describing the matrix as {@code DMatrix.SparseType.CSR}
     */
    public static SparseDMatrixProvider toDMatrix(SparseMatrix sm, SparseMatrixDimensions smd, int actualRows, int shape, float[] resp, float[] weights, float[] offsets) {
        return new SparseDMatrixProvider(
            sm._rowHeaders,
            sm._colIndices,
            sm._sparseData,
            DMatrix.SparseType.CSR,
            shape,
            smd._nonZeroElementsCount,
            actualRows,
            resp,
            weights,
            offsets);
    }

    /**
     * Fills a pre-allocated CSR matrix from the given chunks of a frame.
     *
     * <p>Submits one {@link LocalMR} task with {@code chunks.length} work items and blocks until
     * it completes.
     *
     * @param frame frame whose columns supply the features
     * @param chunks indices of the chunks to load
     * @param weightsVec weights column, or {@code null}
     * @param offsetsVec offsets column, or {@code null}
     * @param di data info describing the categorical/numeric layout
     * @param matrix destination matrix
     * @param dimensions dimensions used to locate each chunk's starting position
     * @param respVec response column
     * @param resp output array for responses
     * @param weights output array for weights
     * @param offsets output array for offsets
     * @return total number of rows written, summed over all chunks
     */
    public static int initializeFromChunkIds(Frame frame, int[] chunks, Vec weightsVec, Vec offsetsVec, DataInfo di, SparseMatrix matrix, SparseMatrixDimensions dimensions, Vec respVec, float[] resp, float[] weights, float[] offsets) {
        InitializeCSRMatrixFromChunkIdsMrFun filler = new InitializeCSRMatrixFromChunkIdsMrFun(
            frame, chunks, weightsVec, offsetsVec, di, matrix, dimensions, respVec, resp, weights, offsets);
        LocalMR task = new LocalMR((MrFun)filler, chunks.length);
        LocalMR submitted = (LocalMR)H2O.submitTask((H2O.H2OCountedCompleter)task);
        submitted.join();
        return ArrayUtils.sum((int[])filler._actualRows);
    }

    /**
     * Writes the rows of {@code chunks} into pre-allocated CSR arrays.
     *
     * <p>For every row whose weight is not zero: records the row's start offset in
     * {@code rowHeaders}, emits one entry with value {@code 1.0f} per categorical column, then
     * one entry per numeric column whose float value is not {@code 0.0f}. After the last row,
     * the total number of entries is written as the closing row header.
     *
     * @param chunks chunks to read; categorical columns first, numeric columns after them
     * @param weight index of the weight chunk, or {@code -1} when rows are not weighted
     * @param di data info supplying column counts and categorical ids
     * @param rowHeaders destination for row start offsets
     * @param data destination for entry values
     * @param colIndex destination for entry column indices
     * @param respIdx index of the response chunk
     * @param resp output array for responses
     * @param weights output array for weights
     * @param offsetIdx index of the offset chunk
     * @param offsets output array for offsets
     * @return number of rows written
     */
    private static int initializeFromChunks(Chunk[] chunks, int weight, DataInfo di, long[][] rowHeaders, float[][] data, int[][] colIndex, int respIdx, float[] resp, float[] weights, int offsetIdx, float[] offsets) {
        int writtenRows = 0;
        int writtenEntries = 0;
        int outputRow = 0;
        NestedArrayPointer headerCursor = new NestedArrayPointer();
        NestedArrayPointer entryCursor = new NestedArrayPointer();
        for (int row = 0; row < chunks[0].len(); row++) {
            // Rows with zero weight are left out of the matrix entirely.
            boolean zeroWeight = weight != -1 && chunks[weight].atd(row) == 0.0;
            if (zeroWeight) {
                continue;
            }
            writtenRows++;
            headerCursor.setAndIncrement(rowHeaders, writtenEntries);
            // Categorical columns always produce exactly one entry per row.
            for (int cat = 0; cat < di._cats; cat++) {
                entryCursor.set(data, 1.0f);
                double level = chunks[cat].isNA(row) ? Double.NaN : (double)chunks[cat].at8(row);
                entryCursor.set(colIndex, di.getCategoricalId(cat, level));
                entryCursor.increment();
                writtenEntries++;
            }
            // Numeric columns produce an entry only when the float value is non-zero.
            for (int num = 0; num < di._nums; num++) {
                float value = (float)chunks[di._cats + num].atd(row);
                if (value != 0.0f) {
                    entryCursor.set(data, value);
                    entryCursor.set(colIndex, di._catOffsets[di._catOffsets.length - 1] + num);
                    entryCursor.increment();
                    writtenEntries++;
                }
            }
            outputRow = MatrixFactoryUtils.setResponseAndWeightAndOffset(chunks, respIdx, weight, offsetIdx, resp, weights, offsets, outputRow, row);
        }
        // Closing header: the end offset of the last row.
        headerCursor.set(rowHeaders, (long)writtenEntries);
        return writtenRows;
    }

    /**
     * Allocates the three backing arrays of a CSR matrix for the given dimensions.
     *
     * <p>Each logical array is stored as a two-dimensional array whose inner arrays hold at most
     * {@code SparseMatrix.MAX_DIM} elements: as many full inner arrays as fit, followed by one
     * shorter inner array for any remainder. Values and column indices are sized by the non-zero
     * element count; row headers are sized by the row header count.
     *
     * @param sparseMatrixDimensions element and row-header counts to allocate for
     * @return matrix with allocated, not yet filled arrays
     */
    public static SparseMatrix allocateCSRMatrix(SparseMatrixDimensions sparseMatrixDimensions) {
        final int fullEntryChunks = (int)(sparseMatrixDimensions._nonZeroElementsCount / (long)SparseMatrix.MAX_DIM);
        final int entryRemainder = (int)(sparseMatrixDimensions._nonZeroElementsCount % (long)SparseMatrix.MAX_DIM);
        final int fullHeaderChunks = sparseMatrixDimensions._rowHeadersCount / SparseMatrix.MAX_DIM;
        final int headerRemainder = sparseMatrixDimensions._rowHeadersCount % SparseMatrix.MAX_DIM;

        // Values.
        float[][] values = new float[entryRemainder == 0 ? fullEntryChunks : fullEntryChunks + 1][];
        for (int k = 0; k < fullEntryChunks; k++) {
            values[k] = MemoryManager.malloc4f(SparseMatrix.MAX_DIM);
        }
        if (entryRemainder > 0) {
            values[values.length - 1] = MemoryManager.malloc4f(entryRemainder);
        }

        // Row headers.
        long[][] headers = new long[headerRemainder == 0 ? fullHeaderChunks : fullHeaderChunks + 1][];
        for (int k = 0; k < fullHeaderChunks; k++) {
            headers[k] = MemoryManager.malloc8(SparseMatrix.MAX_DIM);
        }
        if (headerRemainder > 0) {
            headers[headers.length - 1] = MemoryManager.malloc8(headerRemainder);
        }

        // Column indices mirror the layout of the values.
        int[][] columns = new int[entryRemainder == 0 ? fullEntryChunks : fullEntryChunks + 1][];
        for (int k = 0; k < fullEntryChunks; k++) {
            columns[k] = MemoryManager.malloc4(SparseMatrix.MAX_DIM);
        }
        if (entryRemainder > 0) {
            columns[columns.length - 1] = MemoryManager.malloc4(entryRemainder);
        }

        return new SparseMatrix(values, headers, columns);
    }

    /**
     * Counts the rows and non-zero entries a CSR matrix built from {@code chunks} will contain.
     *
     * <p>Rows whose weight is zero are not counted. Every counted row contributes
     * {@code di._cats} entries plus one entry per numeric column whose value is non-zero.
     *
     * <p>Numeric values are narrowed to {@code float} before the zero test. The matrix stores
     * floats and {@code initializeFromChunks} skips values equal to {@code 0.0f}, so a double
     * that underflows to {@code 0.0f} must not be counted here either; otherwise the reported
     * non-zero count would exceed the number of entries actually written.
     *
     * @param chunks chunks to inspect; categorical columns first, numeric columns after them
     * @param di data info supplying the categorical and numeric column counts
     * @param weightColIndex index of the weight chunk, or {@code -1} when rows are not weighted
     * @return dimensions built from the single-element count arrays
     */
    protected static SparseMatrixDimensions calculateCSRMatrixDimensions(Chunk[] chunks, DataInfo di, int weightColIndex) {
        int[] nonZeroElementsCounts = new int[1];
        int[] rowIndicesCounts = new int[1];
        for (int i = 0; i < chunks[0].len(); ++i) {
            // Rows with zero weight are ignored.
            if (weightColIndex != -1 && chunks[weightColIndex].atd(i) == 0.0) continue;
            rowIndicesCounts[0]++;
            nonZeroElementsCounts[0] += di._cats;
            for (int j = 0; j < di._nums; ++j) {
                // Same float narrowing as the fill step, so both agree on what "zero" is.
                float val = (float)chunks[di._cats + j].atd(i);
                if (val != 0.0f) {
                    nonZeroElementsCounts[0]++;
                }
            }
        }
        return new SparseMatrixDimensions(nonZeroElementsCounts, rowIndicesCounts);
    }

    /**
     * Counts, per chunk, the rows and non-zero entries a CSR matrix built from the given chunks
     * of a frame will contain.
     *
     * <p>Submits one {@link LocalMR} task with {@code chunkIds.length} work items and blocks
     * until it completes.
     *
     * @param f frame whose columns supply the features
     * @param chunkIds indices of the chunks to inspect
     * @param w weights column, or {@code null} when rows are not weighted
     * @param di data info supplying the categorical and numeric column counts
     * @return dimensions built from the per-chunk counts
     */
    public static SparseMatrixDimensions calculateCSRMatrixDimensions(Frame f, int[] chunkIds, Vec w, DataInfo di) {
        CalculateCSRMatrixDimensionsMrFun counter = new CalculateCSRMatrixDimensionsMrFun(f, di, w, chunkIds);
        LocalMR task = new LocalMR((MrFun)counter, chunkIds.length);
        LocalMR submitted = (LocalMR)H2O.submitTask((H2O.H2OCountedCompleter)task);
        submitted.join();
        return new SparseMatrixDimensions(counter._nonZeroElementsCounts, counter._rowIndicesCounts);
    }

    /**
     * Per-chunk task that counts rows and non-zero entries for a CSR matrix.
     *
     * <p>Work item {@code i} inspects chunk {@code _chunkIds[i]} and stores its results at
     * index {@code i} of {@code _rowIndicesCounts} and {@code _nonZeroElementsCounts}.
     */
    private static class CalculateCSRMatrixDimensionsMrFun
    extends MrFun<CalculateCSRMatrixDimensionsMrFun> {
        private Frame _f;
        private DataInfo _di;
        // Weights column; null when rows are not weighted.
        private Vec _w;
        private int[] _chunkIds;
        // Outputs, one slot per work item.
        private int[] _rowIndicesCounts;
        private int[] _nonZeroElementsCounts;

        CalculateCSRMatrixDimensionsMrFun(Frame f, DataInfo di, Vec w, int[] chunkIds) {
            this._f = f;
            this._di = di;
            this._w = w;
            this._chunkIds = chunkIds;
            this._rowIndicesCounts = new int[chunkIds.length];
            this._nonZeroElementsCounts = new int[chunkIds.length];
        }

        /**
         * Counts rows and entries of a single chunk.
         *
         * <p>Rows with zero weight are not counted. Every counted row contributes
         * {@code _di._cats} entries plus one entry per numeric column whose value is non-zero.
         * Numeric values are narrowed to {@code float} before the zero test, matching
         * {@code InitializeCSRMatrixFromChunkIdsMrFun.map}, which stores floats and skips
         * {@code 0.0f}. Counting a double that underflows to {@code 0.0f} would overstate this
         * chunk's entry count.
         *
         * @param i index of the work item, i.e. position within {@code _chunkIds}
         */
        protected void map(int i) {
            int cidx = this._chunkIds[i];
            int rowIndicesCount = 0;
            int nonZeroElementsCount = 0;
            if (this._di._nums == 0) {
                // Only categoricals: the entry count follows directly from the row count.
                if (this._w == null) {
                    rowIndicesCount = this._f.anyVec().chunkForChunkIdx((int)cidx)._len;
                    nonZeroElementsCount = rowIndicesCount * this._di._cats;
                } else {
                    Chunk ws = this._w.chunkForChunkIdx(cidx);
                    int nzWeights = 0;
                    for (int r = 0; r < ws._len; ++r) {
                        if (ws.atd(r) == 0.0) continue;
                        ++nzWeights;
                    }
                    rowIndicesCount += nzWeights;
                    nonZeroElementsCount += nzWeights * this._di._cats;
                }
            } else {
                // Numeric columns follow the categorical ones in the frame.
                Chunk[] cs = new Chunk[this._di._nums];
                for (int c = 0; c < cs.length; ++c) {
                    cs[c] = this._f.vec(this._di._cats + c).chunkForChunkIdx(cidx);
                }
                Chunk ws = this._w != null ? this._w.chunkForChunkIdx(cidx) : null;
                for (int r = 0; r < cs[0]._len; ++r) {
                    if (ws != null && ws.atd(r) == 0.0) continue;
                    ++rowIndicesCount;
                    nonZeroElementsCount += this._di._cats;
                    for (int j = 0; j < this._di._nums; ++j) {
                        // Same float narrowing as the fill step, so both agree on what "zero" is.
                        if ((float)cs[j].atd(r) == 0.0f) continue;
                        ++nonZeroElementsCount;
                    }
                }
            }
            this._rowIndicesCounts[i] = rowIndicesCount;
            this._nonZeroElementsCounts[i] = nonZeroElementsCount;
        }
    }

    /**
     * Per-chunk task that writes rows of a frame into a pre-allocated CSR matrix.
     *
     * <p>Work item {@code chunkIdx} handles chunk {@code _chunks[chunkIdx]}. It starts writing at
     * the positions given by the preceding-count arrays of {@code _dims} and records how many
     * rows it wrote in {@code _actualRows[chunkIdx]}.
     */
    private static class InitializeCSRMatrixFromChunkIdsMrFun
    extends MrFun<InitializeCSRMatrixFromChunkIdsMrFun> {
        Frame _frame;
        int[] _chunks;
        Vec _weightVec;
        Vec _offsetsVec;
        DataInfo _di;
        SparseMatrix _matrix;
        SparseMatrixDimensions _dims;
        Vec _respVec;
        float[] _resp;
        float[] _weights;
        float[] _offsets;
        int[] _actualRows;

        InitializeCSRMatrixFromChunkIdsMrFun(Frame frame, int[] chunks, Vec weightVec, Vec offsetVec, DataInfo di, SparseMatrix matrix, SparseMatrixDimensions dimensions, Vec respVec, float[] resp, float[] weights, float[] offsets) {
            _actualRows = new int[chunks.length];
            _frame = frame;
            _chunks = chunks;
            _weightVec = weightVec;
            _offsetsVec = offsetVec;
            _di = di;
            _matrix = matrix;
            _dims = dimensions;
            _respVec = respVec;
            _resp = resp;
            _weights = weights;
            _offsets = offsets;
        }

        /**
         * Writes one chunk's rows into the matrix.
         *
         * <p>For every row whose weight is not zero: records the row's start offset, emits one
         * entry with value {@code 1.0f} per categorical column, then one entry per numeric
         * column whose float value is not {@code 0.0f}. After the chunk's last row, the running
         * entry count is written at the next row-header position.
         *
         * @param chunkIdx index of the work item, i.e. position within {@code _chunks}
         */
        protected void map(int chunkIdx) {
            final int cidx = _chunks[chunkIdx];
            long writtenEntries = _dims._precedingNonZeroElementsCounts[chunkIdx];
            int outputRow = _dims._precedingRowCounts[chunkIdx];
            NestedArrayPointer headerCursor = new NestedArrayPointer(outputRow);
            NestedArrayPointer entryCursor = new NestedArrayPointer(writtenEntries);

            Chunk weightChunk = null;
            if (_weightVec != null) {
                weightChunk = _weightVec.chunkForChunkIdx(cidx);
            }
            Chunk offsetChunk = null;
            if (_offsetsVec != null) {
                offsetChunk = _offsetsVec.chunkForChunkIdx(cidx);
            }
            Chunk respChunk = _respVec.chunkForChunkIdx(cidx);
            Chunk[] featChunks = new Chunk[_frame.vecs().length];
            for (int col = 0; col < featChunks.length; col++) {
                featChunks[col] = _frame.vecs()[col].chunkForChunkIdx(cidx);
            }

            for (int row = 0; row < respChunk._len; row++) {
                // Rows with zero weight are left out of the matrix entirely.
                boolean zeroWeight = weightChunk != null && weightChunk.atd(row) == 0.0;
                if (zeroWeight) {
                    continue;
                }
                headerCursor.setAndIncrement(_matrix._rowHeaders, writtenEntries);
                _actualRows[chunkIdx]++;
                // Categorical columns always produce exactly one entry per row.
                for (int cat = 0; cat < _di._cats; cat++) {
                    entryCursor.set(_matrix._sparseData, 1.0f);
                    double level = featChunks[cat].isNA(row) ? Double.NaN : (double)featChunks[cat].at8(row);
                    entryCursor.set(_matrix._colIndices, _di.getCategoricalId(cat, level));
                    entryCursor.increment();
                    writtenEntries++;
                }
                // Numeric columns produce an entry only when the float value is non-zero.
                for (int num = 0; num < _di._nums; num++) {
                    float value = (float)featChunks[_di._cats + num].atd(row);
                    if (value != 0.0f) {
                        entryCursor.set(_matrix._sparseData, value);
                        entryCursor.set(_matrix._colIndices, _di._catOffsets[_di._catOffsets.length - 1] + num);
                        entryCursor.increment();
                        writtenEntries++;
                    }
                }
                outputRow = MatrixFactoryUtils.setResponseWeightAndOffset(weightChunk, offsetChunk, respChunk, _resp, _weights, _offsets, outputRow, row);
            }
            // End offset of this chunk's last row.
            headerCursor.set(_matrix._rowHeaders, writtenEntries);
        }
    }

    /**
     * Cursor over a logical one-dimensional array that is stored as a two-dimensional array
     * whose inner arrays hold {@code SparseMatrix.MAX_DIM} elements each.
     *
     * <p>{@code _row} selects the inner array and {@code _col} the position inside it.
     */
    static class NestedArrayPointer {
        int _row;
        int _col;

        /** Creates a cursor positioned at the first element. */
        public NestedArrayPointer() {
        }

        /**
         * Creates a cursor positioned at the given flat index.
         *
         * @param pos flat index into the logical array
         */
        public NestedArrayPointer(long pos) {
            final long dim = (long)SparseMatrix.MAX_DIM;
            _row = (int)(pos / dim);
            _col = (int)(pos % dim);
        }

        /** Moves to the next element, wrapping to the start of the next inner array when needed. */
        void increment() {
            if (++_col == SparseMatrix.MAX_DIM) {
                _col = 0;
                ++_row;
            }
        }

        /** Stores {@code val} at the current position of {@code dest}. */
        void set(long[][] dest, long val) {
            dest[_row][_col] = val;
        }

        /** Stores {@code val} at the current position of {@code dest}. */
        void set(float[][] dest, float val) {
            dest[_row][_col] = val;
        }

        /** Stores {@code val} at the current position of {@code dest}. */
        void set(int[][] dest, int val) {
            dest[_row][_col] = val;
        }

        /** Stores {@code val} at the current position of {@code dest}, then advances the cursor. */
        void setAndIncrement(long[][] dest, long val) {
            set(dest, val);
            increment();
        }

        /** Returns the element of {@code dest} at the current position. */
        public long get(long[][] dest) {
            return dest[_row][_col];
        }

        /** Returns the element of {@code dest} at the current position. */
        public int get(int[][] dest) {
            return dest[_row][_col];
        }

        /** Returns the element of {@code dest} at the current position. */
        public float get(float[][] dest) {
            return dest[_row][_col];
        }
    }

    /**
     * {@link MatrixLoader.DMatrixProvider} backed by sparse-matrix arrays (row headers, column
     * indices and values), each stored as a two-dimensional array.
     */
    public static class SparseDMatrixProvider
    extends MatrixLoader.DMatrixProvider {
        private long[][] rowHeaders;
        private int[][] colIndices;
        private float[][] sparseData;
        private DMatrix.SparseType csr;
        private int shape;
        private long nonZeroElementsCount;

        /**
         * @param rowHeaders row start offsets
         * @param colIndices column index of every stored entry
         * @param sparseData value of every stored entry
         * @param csr sparse layout passed to the {@link DMatrix} constructor
         * @param shape shape value passed to the {@link DMatrix} constructor
         * @param nonZeroElementsCount number of stored entries
         * @param actualRows number of rows in the matrix
         * @param response responses
         * @param weights weights
         * @param offsets offsets
         */
        public SparseDMatrixProvider(long[][] rowHeaders, int[][] colIndices, float[][] sparseData, DMatrix.SparseType csr, int shape, long nonZeroElementsCount, int actualRows, float[] response, float[] weights, float[] offsets) {
            super(actualRows, response, weights, offsets);
            this.rowHeaders = rowHeaders;
            this.colIndices = colIndices;
            this.sparseData = sparseData;
            this.csr = csr;
            this.shape = shape;
            this.nonZeroElementsCount = nonZeroElementsCount;
        }

        /**
         * Creates the {@link DMatrix} from the stored arrays.
         *
         * @return the new matrix
         * @throws XGBoostError if the matrix cannot be created
         */
        @Override
        public DMatrix makeDMatrix() throws XGBoostError {
            // NOTE(review): actualRows + 1 is presumably the number of row-header entries
            // (one per row plus the closing offset) - confirm against the DMatrix constructor.
            return new DMatrix(this.rowHeaders, this.colIndices, this.sparseData, this.csr, this.shape, (int)this.actualRows + 1, this.nonZeroElementsCount);
        }

        /**
         * Prints rows to standard output as {@code column:value} pairs followed by the response.
         *
         * @param nrow number of rows to print; when not positive, {@code actualRows} rows are
         *             printed
         */
        @Override
        public void print(int nrow) {
            NestedArrayPointer r = new NestedArrayPointer();
            NestedArrayPointer d = new NestedArrayPointer();
            long elemIndex = 0L;
            // Skip the first header so that r always points at the end offset of the current row.
            r.increment();
            int i = 0;
            while ((long)i < (nrow > 0 ? (long)nrow : this.actualRows)) {
                System.out.print(i + ":\t");
                long rowEnd = r.get(this.rowHeaders);
                r.increment();
                while (elemIndex < rowEnd) {
                    System.out.print(d.get(this.colIndices) + ":" + d.get(this.sparseData) + "\t");
                    d.increment();
                    ++elemIndex;
                }
                System.out.print(this.response[i]);
                System.out.println();
                ++i;
            }
        }

        /**
         * Two providers are equal when the superclass state, shape, entry count, sparse layout
         * and the contents of all three nested arrays are equal.
         */
        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            if (!super.equals(o)) {
                return false;
            }
            SparseDMatrixProvider that = (SparseDMatrixProvider)o;
            return this.shape == that.shape && this.nonZeroElementsCount == that.nonZeroElementsCount && Arrays.deepEquals((Object[])this.rowHeaders, (Object[])that.rowHeaders) && Arrays.deepEquals((Object[])this.colIndices, (Object[])that.colIndices) && Arrays.deepEquals((Object[])this.sparseData, (Object[])that.sparseData) && this.csr == that.csr;
        }

        /**
         * Hash code consistent with {@link #equals(Object)}.
         *
         * <p>The nested arrays are hashed with {@link Arrays#deepHashCode(Object[])} because
         * {@code equals} compares them with {@code Arrays.deepEquals}. A shallow
         * {@code Arrays.hashCode} would hash the inner arrays by identity, so two equal
         * providers could end up with different hash codes.
         */
        @Override
        public int hashCode() {
            int result = Objects.hash(super.hashCode(), this.csr, this.shape, this.nonZeroElementsCount);
            result = 31 * result + Arrays.deepHashCode((Object[])this.rowHeaders);
            result = 31 * result + Arrays.deepHashCode((Object[])this.colIndices);
            result = 31 * result + Arrays.deepHashCode((Object[])this.sparseData);
            return result;
        }
    }
}

