/*
 * Decompiled with CFR 0.152.
 */
package hex;

import java.util.Arrays;
import java.util.concurrent.atomic.AtomicInteger;
import water.DKV;
import water.H2O;
import water.Iced;
import water.Key;
import water.MRTask;
import water.MemoryManager;
import water.TAtomic;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;

public class DMatrix {
    static int cnt = 0;

    /**
     * Transposes {@code src} into a freshly allocated target frame.
     *
     * <p>The target has one column per source row and one row per source column.
     * Its rows are split into {@code max(1, src.numCols() / 10000)} chunks whose
     * lengths differ by at most one (the leading chunks take the remainder).
     *
     * @param src frame to transpose
     * @return the newly created, transposed frame
     * @throws RuntimeException via {@code H2O.unimpl()} when the source row count
     *         does not fit in an {@code int}
     */
    public static Frame transpose(Frame src) {
        final long rowCount = src.numRows();
        if ((long)((int)rowCount) != rowCount) {
            throw H2O.unimpl();
        }
        final int colCount = src.numCols();
        final int chunkCount = Math.max(1, colCount / 10000);
        final int baseLen = colCount / chunkCount;
        final int longerChunks = colCount % chunkCount;

        // espc[c] is the first target row of chunk c; espc[chunkCount] is the total row count.
        final long[] espc = new long[chunkCount + 1];
        long offset = 0L;
        for (int c = 0; c <= chunkCount; ++c) {
            espc[c] = offset;
            offset += (long)(c < longerChunks ? baseLen + 1 : baseLen);
        }

        final Key<Vec> layoutKey = Vec.newKey();
        final int rowLayout = Vec.ESPC.rowLayout(layoutKey, espc);
        final Vec template = new Vec(layoutKey, rowLayout);
        return DMatrix.transpose(src, new Frame(template.makeZeros((int)rowCount)));
    }

    /**
     * Transposes {@code src} into the caller-supplied frame {@code tgt}.
     *
     * @param src source frame; every column must be non-categorical and hold at
     *            most 1,000,000 rows
     * @param tgt pre-allocated target with {@code src.numRows()} columns and
     *            {@code src.numCols()} rows
     * @return {@code tgt}, after the transpose task has run over {@code src}
     * @throws IllegalArgumentException if the dimensions do not line up, a source
     *         column is categorical, or a source column is longer than 1,000,000
     */
    public static Frame transpose(Frame src, Frame tgt) {
        if (!(src.numRows() == (long)tgt.numCols() && (long)src.numCols() == tgt.numRows())) {
            throw new IllegalArgumentException("dimension do not match!");
        }
        for (Vec column : src.vecs()) {
            if (column.isCategorical()) {
                throw new IllegalArgumentException("transpose can only be applied to all-numeric frames (representing a matrix)");
            }
            if (column.length() > 1000000L) {
                throw new IllegalArgumentException("too many rows, transpose only works for frames with < 1M rows.");
            }
        }
        new TransposeTsk(tgt).doAll(src);
        return tgt;
    }

    /**
     * Multiplies {@code x} by {@code y} and blocks until the product is ready.
     *
     * <p>When the caller is already running on an {@code H2O.FJWThr} thread the
     * task is forked directly; otherwise it is handed to {@code H2O.submitTask}.
     *
     * @param x left operand
     * @param y right operand; its row count must equal {@code x.numCols()}
     * @return the result frame held by the finished task
     * @throws IllegalArgumentException if the inner dimensions differ
     */
    public static Frame mmul(Frame x, Frame y) {
        final MatrixMulTsk product = new MatrixMulTsk(null, null, x, y);
        final boolean onWorkerThread = Thread.currentThread() instanceof H2O.FJWThr;
        if (!onWorkerThread) {
            H2O.submitTask(product).join();
        } else {
            product.fork().join();
        }
        return product._z;
    }

    /**
     * Atomic update that records one finished chunk in a {@link MatrixMulStats}:
     * bumps the done counter, the per-chunk-type counter, the accumulated byte
     * size and the last-update timestamp.
     */
    private static class UpdateProgress
    extends TAtomic<MatrixMulStats> {
        final int _chunkSz;
        final int _chunkType;

        /**
         * @param sz   byte size added to {@code MatrixMulStats.size}
         * @param type chunk type code whose counter is incremented
         */
        public UpdateProgress(int sz, int type) {
            this._chunkType = type;
            this._chunkSz = sz;
        }

        /**
         * Applies the update to {@code old} and returns the same instance.
         * The counts array is replaced by a copy before it is modified, so the
         * array instance that was attached on entry is never written to.
         */
        @Override
        public MatrixMulStats atomic(MatrixMulStats old) {
            old.chunkCnts = old.chunkCnts.clone();

            // Linear scan for the slot that already tracks this chunk type.
            int slot = 0;
            while (slot < old.chunkTypes.length && old.chunkTypes[slot] != this._chunkType) {
                ++slot;
            }
            if (slot == old.chunkTypes.length) {
                // Unknown type: grow both arrays by one and claim the new last slot.
                old.chunkTypes = Arrays.copyOf(old.chunkTypes, slot + 1);
                old.chunkCnts = Arrays.copyOf(old.chunkCnts, old.chunkCnts.length + 1);
                old.chunkTypes[slot] = this._chunkType;
            }

            old.chunksDone += 1L;
            old.chunkCnts[slot] += 1L;
            old.lastUpdateAt = System.currentTimeMillis();
            old.size += (long)this._chunkSz;
            return old;
        }
    }

    /**
     * Map task that fills one result column. For each chunk row it accumulates
     * {@code sum_i _y[i] * chks[i][row]} over the first {@code _y.length} input
     * chunks and stores the compressed result under the chunk key of the last
     * vec of the frame.
     */
    private static class VecTsk
    extends MRTask<VecTsk> {
        double[] _y;
        Key _progressKey;

        /**
         * @param cmp         completer passed to the superclass constructor
         * @param progressKey key of the progress record to update, or {@code null}
         *                    to skip progress reporting
         * @param y           one weight per input column
         */
        public VecTsk(H2O.H2OCountedCompleter cmp, Key progressKey, double[] y) {
            super(cmp);
            this._y = y;
            this._progressKey = progressKey;
        }

        /** Calls {@code preWriting()} on the last vec of the frame. */
        @Override
        public void setupLocal() {
            this._fr.lastVec().preWriting();
        }

        @Override
        public void map(Chunk[] chks) {
            final Chunk outChunk = chks[chks.length - 1];
            final double[] acc = MemoryManager.malloc8d(chks[0]._len);

            for (int col = 0; col < this._y.length; ++col) {
                final double weight = this._y[col];
                final Chunk in = chks[col];
                // Only the row indices reported by nextNZ are visited.
                for (int row = in.nextNZ(-1); row < acc.length; row = in.nextNZ(row)) {
                    try {
                        acc[row] += weight * in.atd(row);
                    }
                    catch (Throwable t) {
                        t.printStackTrace();
                        throw new RuntimeException(t);
                    }
                }
            }

            final Chunk packed = new NewChunk(acc).setSparseRatio(2).compress();
            if (this._progressKey != null) {
                new UpdateProgress(packed.getBytes().length, packed.frozenType()).fork(this._progressKey);
            }
            DKV.put(outChunk.vec().chunkKey(outChunk.cidx()), packed, this._fs);
        }

        /** Clears the references to the weights and the progress key. */
        @Override
        public void closeLocal() {
            this._progressKey = null;
            this._y = null;
        }
    }

    /**
     * Collects the entries of a single vec that {@code nextNZ} reports, as
     * parallel arrays of global row indices ({@code _idxs}) and values
     * ({@code _vals}). Fails once more than {@code _maxsz} entries are gathered.
     */
    private static class GetNonZerosTsk
    extends MRTask<GetNonZerosTsk> {
        final int _maxsz;
        int[] _idxs;
        double[] _vals;

        /** Uses a limit of 10,000,000 collected entries. */
        public GetNonZerosTsk(H2O.H2OCountedCompleter cmp) {
            this(cmp, 10000000);
        }

        /**
         * @param cmp   completer passed to the superclass constructor
         * @param maxsz maximum number of entries that may be collected
         */
        public GetNonZerosTsk(H2O.H2OCountedCompleter cmp, int maxsz) {
            super(cmp);
            this._maxsz = maxsz;
        }

        @Override
        public void map(Chunk c) {
            final int base = (int)c.start();
            // Guards against the chunk's global row range not fitting in an int.
            assert (c.start() + (long)c._len == (long)(base + c._len));

            final int expected = c.sparseLenZero();
            this._idxs = MemoryManager.malloc4(expected);
            this._vals = MemoryManager.malloc8d(expected);

            int filled = 0;
            for (int row = c.nextNZ(-1); row < c._len; row = c.nextNZ(row), ++filled) {
                this._idxs[filled] = row + base;
                this._vals[filled] = c.atd(row);
            }
            assert (filled == expected);

            if (this._idxs.length > this._maxsz) {
                throw new RuntimeException("too many nonzeros! found at least " + this._idxs.length + " nonzeros.");
            }
        }

        /** Merges the other task's entries into this one via {@code ArrayUtils.sortedMerge}. */
        @Override
        public void reduce(GetNonZerosTsk gnz) {
            final int mergedLen = this._idxs.length + gnz._idxs.length;
            if (mergedLen > this._maxsz) {
                throw new RuntimeException("too many nonzeros! found at least " + mergedLen + " nonzeros.");
            }
            final int[] mergedIdxs = MemoryManager.malloc4(mergedLen);
            final double[] mergedVals = MemoryManager.malloc8d(this._vals.length + gnz._vals.length);
            ArrayUtils.sortedMerge(this._idxs, this._vals, gnz._idxs, gnz._vals, mergedIdxs, mergedVals);
            this._vals = mergedVals;
            this._idxs = mergedIdxs;
        }
    }

    /**
     * Counted completer that computes {@code _z = _x * _y} one result column at a
     * time. For column {@code i} it first gathers the entries of {@code _y.vec(i)}
     * with a {@link GetNonZerosTsk}, then runs a {@link VecTsk} over the matching
     * columns of {@code _x} to fill {@code _z.vec(i)}. At most {@code maxP}
     * columns are in flight at once; each finished column starts the next one.
     */
    public static class MatrixMulTsk
    extends H2O.H2OCountedCompleter {
        final transient Frame _x;
        Frame _y;
        Frame _z;
        final Key _progressKey;
        AtomicInteger _cntr;

        /**
         * @param cmp         parent completer, may be {@code null}
         * @param progressKey key of the progress record passed on to each
         *                    {@link VecTsk}, may be {@code null}
         * @param x           left operand
         * @param y           right operand
         * @throws IllegalArgumentException if {@code x.numCols() != y.numRows()}
         */
        public MatrixMulTsk(H2O.H2OCountedCompleter cmp, Key progressKey, Frame x, Frame y) {
            super(cmp);
            if ((long)x.numCols() != y.numRows()) {
                throw new IllegalArgumentException("dimensions do not match! x.numcols = " + x.numCols() + ", y.numRows = " + y.numRows());
            }
            this._x = x;
            this._y = y;
            this._progressKey = progressKey;
        }

        @Override
        public void compute2() {
            this._z = new Frame(this._x.anyVec().makeZeros(this._y.numCols()));
            int total_cores = H2O.CLOUD.size() * H2O.NUMCPUS;
            int chunksPerCol = this._y.anyVec().nChunks();
            // The integer division truncates to 0 once chunksPerCol exceeds
            // 256 * total_cores. With maxP == 0 no column task would be forked
            // while the pending count below is still raised, so this completer
            // would never finish. Always keep at least one column in flight.
            int maxP = Math.max(1, 256 * total_cores / chunksPerCol);
            Log.info("maxP = " + maxP);
            // _cntr holds the index of the last column already started.
            this._cntr = new AtomicInteger(maxP - 1);
            // Two completions are expected per column (the GetNonZerosTsk
            // callback and the VecTsk callback).
            this.addToPendingCount(2 * this._y.numCols() - 1);
            for (int i = 0; i < Math.min(this._y.numCols(), maxP); ++i) {
                this.forkVecTask(i);
            }
        }

        /** Starts the two-stage computation of result column {@code i}. */
        private void forkVecTask(final int i) {
            new GetNonZerosTsk(new H2O.H2OCallback<GetNonZerosTsk>((H2O.H2OCountedCompleter)this){

                @Override
                public void callback(GetNonZerosTsk gnz) {
                    // Run over the x-columns selected by the gathered indices, with
                    // the target column of z appended as the last vec.
                    new VecTsk(new Callback(), _progressKey, gnz._vals).dfork(ArrayUtils.append(_x.vecs(gnz._idxs), _z.vec(i)));
                }
            }).dfork(this._y.vec(i));
        }

        /** Invoked when a {@link VecTsk} finishes; starts the next pending column, if any. */
        private class Callback
        extends H2O.H2OCallback {
            public Callback() {
                super(MatrixMulTsk.this);
            }

            public void callback(H2O.H2OCountedCompleter h2OCountedCompleter) {
                int i = MatrixMulTsk.this._cntr.incrementAndGet();
                if (i < MatrixMulTsk.this._y.numCols()) {
                    MatrixMulTsk.this.forkVecTask(i);
                }
            }
        }
    }

    /**
     * Progress record for a matrix multiplication: total and finished chunk
     * counts, accumulated byte size, timestamps, and a per-chunk-type histogram
     * kept as the parallel arrays {@code chunkTypes} / {@code chunkCnts}.
     */
    public static class MatrixMulStats
    extends Iced {
        public final Key jobKey;
        public final long chunksTotal;
        public final long _startTime;
        public long lastUpdateAt;
        public long chunksDone;
        public long size;
        public int[] chunkTypes = new int[0];
        public long[] chunkCnts = new long[0];

        /**
         * @param n      value stored as {@code chunksTotal}
         * @param jobKey key stored as {@code jobKey}
         */
        public MatrixMulStats(long n, Key jobKey) {
            this.jobKey = jobKey;
            this.chunksTotal = n;
            this._startTime = System.currentTimeMillis();
        }

        /** @return {@code chunksDone / chunksTotal}, computed in double precision and narrowed to float */
        public float progress() {
            final double done = (double)this.chunksDone;
            final double total = (double)this.chunksTotal;
            return (float)(done / total);
        }
    }

    /**
     * Map task that writes the transpose of the frame it runs over into the
     * pre-allocated frame {@code _tgt}. Each source row becomes a target column
     * and each source column becomes a target row.
     */
    public static class TransposeTsk
    extends MRTask<TransposeTsk> {
        final Frame _tgt;

        /** @param tgt frame that receives the transposed data */
        public TransposeTsk(Frame tgt) {
            this._tgt = tgt;
        }

        /**
         * Transposes one row-slice of the source. For every target chunk index
         * {@code fi} it builds one {@code NewChunk} per source row of this slice,
         * fills it from the source columns {@code [espc[fi], espc[fi + 1])} and
         * closes it.
         *
         * @param chks chunks of all source columns for the current row range
         */
        @Override
        public void map(Chunk[] chks) {
            Frame tgt = this._tgt;
            // Row boundaries of the target chunks; they double as source column ranges.
            long[] espc = tgt.anyVec().espc();
            // First source row of this slice == first target column written here.
            int colStart = (int)chks[0].start();
            for (int i = 0; i < espc.length - 1; ++i) {
                int j;
                int fi = i;
                // One output chunk per source row in this slice, all for target chunk index fi.
                NewChunk[] tgtChunks = new NewChunk[chks[0]._len];
                for (j = 0; j < tgtChunks.length; ++j) {
                    tgtChunks[j] = new NewChunk(tgt.vec(j + colStart), fi);
                }
                for (int c = (int)espc[fi]; c < (int)espc[fi + 1]; ++c) {
                    int k;
                    Chunk nc = chks[c];
                    if (nc.isSparseZero()) {
                        // Sparse source column: visit only the rows reported by nextNZ.
                        k = nc.nextNZ(-1);
                        while (k < nc._len) {
                            // Pad with zeros until the output chunk's length equals c - espc[fi].
                            tgtChunks[k].addZeros((int)((long)c - espc[fi]) - tgtChunks[k]._len);
                            // NOTE(review): presumably appends row k of nc to tgtChunks[k] - confirm against Chunk.extractRows.
                            nc.extractRows(tgtChunks[k], k);
                            k = nc.nextNZ(k);
                        }
                        continue;
                    }
                    // Non-sparse source column: handle every row.
                    for (k = 0; k < nc._len; ++k) {
                        tgtChunks[k].addZeros((int)((long)c - espc[fi]) - tgtChunks[k]._len);
                        nc.extractRows(tgtChunks[k], k);
                    }
                }
                j = 0;
                while (j < tgtChunks.length) {
                    int fj = j++;
                    // Pad each output chunk to the full length of target chunk fi, then close it.
                    tgtChunks[fj].addZeros((int)(espc[fi + 1] - espc[fi]) - tgtChunks[fj]._len);
                    tgtChunks[fj].close(this._fs);
                    // Drop the reference once the chunk has been closed.
                    tgtChunks[fj] = null;
                }
            }
        }
    }
}

