/*
 * Decompiled with CFR 0.152.
 */
package water.init;

import java.util.Random;
import jsr166y.CountedCompleter;
import water.DTask;
import water.Futures;
import water.H2O;
import water.Iced;
import water.MRTask;
import water.RPC;
import water.util.Log;
import water.util.TwoDimTable;

/**
 * Cluster network benchmark.
 *
 * <p>For every round {@code i}, each node sends {@code MSG_CNT[i]} RPCs carrying
 * {@code MSG_SZS[i]} random bytes to every other node and records how long each per-destination
 * batch took. Results are stored in {@link #_results} and printed to stdout as one table per round.
 */
public class NetworkBench
extends Iced {
    /** Payload size in bytes of each message, one entry per benchmark round. */
    public static int[] MSG_SZS = new int[]{1, 1, 1, 1, 1};
    /** Messages sent per node pair, one entry per round; needs at least as many entries as MSG_SZS. */
    public static int[] MSG_CNT = new int[]{500000, 500000, 500000, 500000, 500000};
    /** Per-round results, populated by {@link #doTest()}. */
    public NetworkBenchResults[] _results;

    /**
     * Runs every configured benchmark round on the cluster, prints one result table per round to
     * stdout and logs the total elapsed time. Blocks the caller until all rounds have finished.
     *
     * @return this instance, with {@link #_results} populated
     */
    public NetworkBench doTest() {
        long t1 = System.currentTimeMillis();
        H2O.submitTask(new H2O.H2OCountedCompleter(){

            @Override
            protected void compute2() {
                NetworkBench.this._results = new NetworkBenchResults[MSG_SZS.length];
                for (int i = 0; i < MSG_SZS.length; ++i) {
                    long t2 = System.currentTimeMillis();
                    // The MRTask round (TestMRTasks) is not run here, so these timings stay 0 and
                    // the "MrTasks" row of the result table is reported as NaN.
                    long[] mrts = new long[H2O.CLOUD.size()];
                    Log.info("Network Bench, running All2All, message size = " + MSG_SZS[i] + ", message count = " + MSG_CNT[i]);
                    long[][] all2all = ((TestAll2All)new TestAll2All(MSG_SZS[i], MSG_CNT[i]).doAllNodes())._time;
                    Log.info("All2All test done in " + (double)(System.currentTimeMillis() - t2) * 0.001 + "s");
                    NetworkBench.this._results[i] = new NetworkBenchResults(MSG_SZS[i], MSG_CNT[i], all2all, mrts);
                }
                this.tryComplete();
            }
        }).join();
        for (NetworkBenchResults r : this._results) {
            System.out.println("===================================== MSG SZ = " + r._msgSz + ", CNT = " + r._msgCnt + " =========================================");
            System.out.println(r.to2dTable());
            System.out.println();
        }
        Log.info("Network test done in " + (double)(System.currentTimeMillis() - t1) * 0.001 + "s");
        return this;
    }

    /**
     * Launches {@code _msgCnt} MRTasks over all nodes, each carrying a {@code _msgSz}-byte payload,
     * and stores the elapsed wall-clock milliseconds in {@link #_time} once all have completed.
     *
     * <p>Not currently invoked by {@link NetworkBench#doTest()}.
     */
    private static class TestMRTasks
    extends DTask<TestMRTasks> {
        final int _msgSz;
        final int _msgCnt;
        /** Start timestamp while running; elapsed milliseconds after completion. */
        long _time;

        public TestMRTasks(int msgSz, int msgCnt) {
            this._msgSz = msgSz;
            this._msgCnt = msgCnt;
        }

        @Override
        protected void compute2() {
            this._time = System.currentTimeMillis();
            if (this._msgCnt <= 0) {
                // Nothing to send: complete now rather than driving the pending count negative.
                this.tryComplete();
                return;
            }
            // Every MRTask below has this task as its completer; one of them is covered by the
            // implicit initial pending count, the rest need an explicit pending slot each.
            this.addToPendingCount(this._msgCnt - 1);
            final byte[] data = new byte[this._msgSz];
            new Random().nextBytes(data);
            for (int i = 0; i < this._msgCnt; ++i) {
                new MRTask(this){
                    // Present only so the task carries a payload (presumably serialized with the
                    // task); dropped again in setupLocal.
                    byte[] dd = data;

                    @Override
                    public void setupLocal() {
                        this.dd = null;
                    }
                }.asyncExecOnAllNodes();
            }
        }

        @Override
        public byte priority() {
            return 1;
        }

        @Override
        public void onCompletion(CountedCompleter cc) {
            this._time = System.currentTimeMillis() - this._time;
        }
    }

    /**
     * All-to-all round: on every node, sends {@code _msgCnt} RPCs of {@code _msgSz} random bytes to
     * each other node and records, per destination, how long that batch took.
     */
    private static class TestAll2All
    extends MRTask<TestAll2All> {
        final int _msgSz;
        final int _msgCnt;
        /** {@code _time[src][dst]}: ms for node src to finish all RPCs to node dst (0 on the diagonal). */
        long[][] _time;

        public TestAll2All(int msgSz, int msgCnt) {
            this._msgSz = msgSz;
            this._msgCnt = msgCnt;
        }

        @Override
        public void setupLocal() {
            final int cloudSize = H2O.CLOUD.size();
            final int myId = H2O.SELF.index();
            this._time = new long[cloudSize][];
            // Each node fills in only its own row; reduce() merges the rows.
            this._time[myId] = new long[cloudSize];
            // One sender per remote node is created below with this task as its completer.
            this.addToPendingCount(cloudSize - 1);
            for (int i = 0; i < cloudSize; ++i) {
                if (i == myId) continue;
                final int fi = i;
                H2O.submitTask(new H2O.H2OCountedCompleter(this){
                    long t1;

                    @Override
                    protected void compute2() {
                        this.t1 = System.currentTimeMillis();
                        if (TestAll2All.this._msgCnt <= 0) {
                            // Nothing to send: complete now rather than driving the pending count negative.
                            this.tryComplete();
                            return;
                        }
                        this.addToPendingCount(TestAll2All.this._msgCnt - 1);
                        for (int j = 0; j < TestAll2All.this._msgCnt; ++j) {
                            new RPC<SendRandomBytesTsk>(H2O.CLOUD._memary[fi], new SendRandomBytesTsk(TestAll2All.this._msgSz)).addCompleter(this).call();
                        }
                    }

                    @Override
                    public void onCompletion(CountedCompleter cc) {
                        long t2 = System.currentTimeMillis();
                        TestAll2All.this._time[myId][fi] = t2 - this.t1;
                    }
                });
            }
        }

        /** Merges another node's timing rows into this one; a row may be filled by only one side. */
        @Override
        public void reduce(TestAll2All tst) {
            for (int i = 0; i < this._time.length; ++i) {
                if (this._time[i] == null) {
                    this._time[i] = tst._time[i];
                } else {
                    assert tst._time[i] == null : "timing row " + i + " reported by two nodes";
                }
            }
        }

        /** RPC payload: a random byte array of the requested size; the remote side just completes. */
        private static class SendRandomBytesTsk
        extends DTask<SendRandomBytesTsk> {
            final byte[] dd;

            public SendRandomBytesTsk(int sz) {
                this.dd = new byte[sz];
                new Random().nextBytes(this.dd);
            }

            @Override
            protected void compute2() {
                this.tryComplete();
            }
        }
    }

    /** Timings of one benchmark round plus a tabular rendering of them. */
    public static class NetworkBenchResults {
        /** Bytes per megabyte, matching the MB definition used in the table title. */
        private static final double BYTES_PER_MB = 1048576.0;

        final int _msgSz;
        final int _msgCnt;
        final long[] _mrtTimes;
        final long[][] _all2AllTimes;

        public NetworkBenchResults(int msgSz, int msgCnt, long[][] all2all, long[] mrts) {
            this._msgSz = msgSz;
            this._msgCnt = msgCnt;
            this._mrtTimes = mrts;
            this._all2AllTimes = all2all;
        }

        /**
         * Renders the round as a table with one row and one column per node. Cell (i, j) holds the
         * bandwidth in MB/s of node i's all-to-all batch to node j. A final "MrTasks" row holds the
         * per-node MRTask timings. Cells without a usable timing (the diagonal, unmeasured entries,
         * or batches that took under 1 ms) are NaN.
         *
         * @return the result table
         */
        public TwoDimTable to2dTable() {
            final int cloudSize = H2O.CLOUD.size();
            final long totalBytes = this.totalBytes();
            // Truncate to 0.01 MB; dividing by 100.0 (not multiplying by 0.01) keeps the printed
            // value free of floating-point noise such as 0.47000000000000003.
            double totalMB = Math.floor(100.0 * totalBytes / BYTES_PER_MB) / 100.0;
            String title = "Network Bench, sz = " + this._msgSz + "B, cnt = " + this._msgCnt + ", total sz = " + totalMB + "MB";
            String[] rowHeaders = new String[cloudSize + 1];
            rowHeaders[cloudSize] = "MrTasks";
            String[] colHeaders = new String[cloudSize];
            String[] colTypes = new String[cloudSize];
            String[] colFormats = new String[cloudSize];
            for (int i = 0; i < cloudSize; ++i) {
                rowHeaders[i] = colHeaders[i] = H2O.CLOUD._memary[i].toString();
                colTypes[i] = "double";
                colFormats[i] = "%.2f";
            }
            TwoDimTable td = new TwoDimTable(title, "Network benchmark results, round-trip bandwidth in MB/s", rowHeaders, colHeaders, colTypes, colFormats, "");
            for (int i = 0; i < this._all2AllTimes.length; ++i) {
                long[] row = this._all2AllTimes[i];
                for (int j = 0; j < this._all2AllTimes.length; ++j) {
                    long millis = row == null || j >= row.length ? 0L : row[j];
                    td.set(i, j, bandwidthMBps(totalBytes, millis));
                }
                long mrtMillis = this._mrtTimes == null || i >= this._mrtTimes.length ? 0L : this._mrtTimes[i];
                td.set(cloudSize, i, bandwidthMBps(totalBytes, mrtMillis));
            }
            return td;
        }

        /** Total payload bytes per node pair, computed in long so large rounds cannot overflow int. */
        private long totalBytes() {
            return (long) this._msgSz * this._msgCnt;
        }

        /**
         * Converts a transfer of {@code bytes} taking {@code millis} ms into MB/s.
         *
         * @return the bandwidth, or NaN when {@code millis} is not positive (no usable timing)
         */
        private static double bandwidthMBps(long bytes, long millis) {
            if (millis <= 0L) {
                return Double.NaN;
            }
            return (bytes / BYTES_PER_MB) / (millis / 1000.0);
        }
    }
}

