/*
 * Decompiled with CFR 0.152.
 */
package water.rapids.ast.prims.mungers;

import java.util.Arrays;
import java.util.Comparator;
import java.util.Locale;
import water.Futures;
import water.H2O;
import water.Iced;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.rapids.Env;
import water.rapids.Merge;
import water.rapids.Val;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.rapids.ast.params.AstNum;
import water.rapids.ast.params.AstNumList;
import water.rapids.vals.ValFrame;
import water.rapids.vals.ValFun;
import water.util.ArrayUtils;
import water.util.IcedHashMap;
import water.util.Log;

/**
 * Rapids group-by primitive ({@code GB}).
 *
 * <p>Groups the rows of a frame by the values of zero or more group-by columns and computes one
 * aggregate value per (group, requested aggregate). The output frame has one row per distinct
 * group key: first the group-by columns, then one numeric column per aggregate. Rows are sorted by
 * the group-by columns, with NaN keys first.
 */
public class AstGroup
extends AstPrimitive {
    /**
     * Per-call median bookkeeping.
     *
     * <p>{@code -1} means no median aggregate was requested. While aggregates are parsed it is set
     * to {@code 0} if at least one median aggregate was added. After grouping it holds the total
     * number of (group, median aggregate) pairs.
     *
     * <p>NOTE(review): this is mutable state on the primitive instance. {@link #apply} resets it at
     * the start of every call, but concurrent calls on the same instance would still race on it.
     */
    public int _numberOfMedianActionsNeeded = -1;

    /** Returns -1; together with the {@code "..."} argument list this presumably marks variable arity. */
    @Override
    public int nargs() {
        return -1;
    }

    @Override
    public String[] args() {
        return new String[]{"..."};
    }

    /** Rapids name of this primitive. */
    @Override
    public String str() {
        return "GB";
    }

    /**
     * Evaluates a group-by call.
     *
     * <p>Argument layout read here:
     * <ul>
     *   <li>{@code asts[1]} evaluates to the input frame.
     *   <li>{@code asts[2]} is the list of group-by column indices.
     *   <li>Each following triple {@code asts[idx], asts[idx + 1], asts[idx + 2]} (idx = 3, 6, ...)
     *       gives an aggregate function name ({@link FCN}), one column index, and an NA-handling
     *       mode ({@link NAHandling}, matched case-insensitively).
     * </ul>
     * Aggregates over string columns are skipped with a warning.
     *
     * @throws IllegalArgumentException on out-of-range or multi-column selections, unknown
     *     function or NA-mode names, or {@code mode} requested on a non-categorical column
     */
    @Override
    public ValFrame apply(Env env, Env.StackHelp stk, AstRoot[] asts) {
        Frame fr = stk.track(asts[1].exec(env)).getFrame();
        int ncols = fr.numCols();
        AstNumList groupby = AstGroup.check(ncols, asts[2]);
        int[] gbCols = groupby.expand4();
        int validAggregatesCount = AstGroup.countNumberOfAggregates(fr, ncols, asts);
        // Reset per-call median state. Without this, a median requested by an earlier call on the
        // same instance would leave the field >= 0 and force median processing for this call too.
        this._numberOfMedianActionsNeeded = -1;
        AGG[] aggs = this.constructAggregates(fr, validAggregatesCount, env, asts);
        return this.performGroupingWithAggregations(fr, gbCols, aggs, this._numberOfMedianActionsNeeded);
    }

    /**
     * Groups {@code fr} by {@code gbCols}, computes {@code aggs}, and builds the output frame.
     *
     * @param fr input frame
     * @param gbCols indices of the group-by columns (may be empty: then all rows form one group)
     * @param aggs aggregates to compute, one output column each
     * @param medianCount {@code -1} when no median aggregate is present, otherwise {@code >= 0};
     *     enables the (distributed) median pass
     * @return a frame with one row per group
     */
    public ValFrame performGroupingWithAggregations(Frame fr, int[] gbCols, AGG[] aggs, int medianCount) {
        IcedHashMap<G, String> gss = AstGroup.doGroups(fr, gbCols, aggs, medianCount);
        G[] grps = gss.keySet().toArray(new G[gss.size()]);
        AstGroup.applyOrdering(gbCols, grps);
        this.calculateMediansForGRPS(fr, gbCols, aggs, gss, grps, medianCount);
        MRTask mrFill = AstGroup.prepareMRFillTask(grps, aggs, medianCount);
        String[] fcNames = AstGroup.prepareFCNames(fr, aggs);
        Frame f = AstGroup.buildOutput(gbCols, aggs.length, fr, fcNames, grps.length, mrFill);
        return new ValFrame(f);
    }

    /**
     * Builds the task that writes one output row per group.
     *
     * <p>Each row holds the group key followed by each aggregate's value. Static, so the anonymous
     * task does not capture (and serialize) the enclosing primitive instance.
     */
    private static MRTask prepareMRFillTask(final G[] grps, final AGG[] aggs, final int medianCount) {
        return new MRTask(){

            @Override
            public void map(Chunk[] c, NewChunk[] ncs) {
                int start = (int)c[0].start();
                for (int i = 0; i < c[0]._len; ++i) {
                    // Output row i of this chunk corresponds to (sorted) group i + start.
                    G g = grps[i + start];
                    int j = 0;
                    for (; j < g._gs.length; ++j) {
                        ncs[j].addNum(g._gs[j]);
                    }
                    for (int a = 0; a < aggs.length; ++a) {
                        // Medians are computed in a separate pass and stored on the group.
                        double value = medianCount >= 0 && g._isMedian[a]
                                ? g._medians[a]
                                : aggs[a]._fcn.postPass(g._dss[a], g._ns[a]);
                        ncs[j++].addNum(value);
                    }
                }
            }
        };
    }

    /**
     * Names the aggregate output columns.
     *
     * <p>{@code nrow} is named just {@code nrow}; every other aggregate is {@code <fcn>_<column name>}.
     */
    private static String[] prepareFCNames(Frame fr, AGG[] aggs) {
        String[] fcnames = new String[aggs.length];
        for (int i = 0; i < aggs.length; ++i) {
            String fcn = aggs[i]._fcn.toString();
            // Compare enum identity rather than String references.
            fcnames[i] = aggs[i]._fcn == FCN.nrow ? fcn : fcn + "_" + fr.name(aggs[i]._col);
        }
        return fcnames;
    }

    /**
     * Validates the aggregate column selections and counts the aggregates that will be computed.
     *
     * <p>String columns are not counted; a warning is logged for each of them.
     *
     * @throws IllegalArgumentException if an aggregate selects anything other than exactly one
     *     in-range column
     */
    private static int countNumberOfAggregates(Frame fr, int numberOfColumns, AstRoot[] asts) {
        int validGroupByCols = 0;
        for (int idx = 3; idx < asts.length; idx += 3) {
            AstNumList col = AstGroup.check(numberOfColumns, asts[idx + 1]);
            if (col.cnt() != 1L) {
                throw new IllegalArgumentException("Group-By functions take only a single column");
            }
            int agg_col = (int)col.min();
            if (fr.vec(agg_col).isString()) {
                Log.warn("Column " + fr._names[agg_col] + " is a string column.  Groupby operations will be skipped for this column.");
                continue;
            }
            ++validGroupByCols;
        }
        return validGroupByCols;
    }

    /**
     * Parses each (function, column, NA-mode) triple into an {@link AGG}, skipping string columns.
     *
     * <p>Sets {@link #_numberOfMedianActionsNeeded} to 0 if a median aggregate was actually added.
     *
     * @param numberOfAggregates number of non-skipped aggregates, as returned by
     *     {@link #countNumberOfAggregates}
     */
    private AGG[] constructAggregates(Frame fr, int numberOfAggregates, Env env, AstRoot[] asts) {
        AGG[] aggs = new AGG[numberOfAggregates];
        int ncols = fr.numCols();
        int countCols = 0;
        for (int idx = 3; idx < asts.length; idx += 3) {
            Val v = asts[idx].exec(env);
            String fn = v instanceof ValFun ? v.getFun().str() : v.getStr();
            FCN fcn = FCN.valueOf(fn);
            AstNumList col = AstGroup.check(ncols, asts[idx + 1]);
            if (col.cnt() != 1L) {
                throw new IllegalArgumentException("Group-By functions take only a single column");
            }
            int agg_col = (int)col.min();
            Vec aggVec = fr.vec(agg_col);
            if (fcn == FCN.mode && !aggVec.isCategorical()) {
                throw new IllegalArgumentException("Mode only allowed on categorical columns");
            }
            // Locale.ROOT: a default-locale upper-casing (e.g. Turkish) would turn "ignore" into
            // "İGNORE" and make valueOf fail.
            NAHandling na = NAHandling.valueOf(asts[idx + 2].exec(env).getStr().toUpperCase(Locale.ROOT));
            if (aggVec.isString()) {
                continue; // skipped; the warning was logged by countNumberOfAggregates
            }
            // Only mode needs a per-level count array. Computing (int)max()+1 for other functions
            // can be negative (e.g. negative max) or huge for numeric columns.
            int maxx = fcn == FCN.mode ? (int)aggVec.max() + 1 : 0;
            aggs[countCols++] = new AGG(fcn, agg_col, na, maxx);
            if (fcn == FCN.median) {
                this._numberOfMedianActionsNeeded = 0;
            }
        }
        return aggs;
    }

    /**
     * Sorts groups by their key values, column by column.
     *
     * <p>NaN sorts before any number, and two NaNs in the same key position compare equal, so the
     * next key column decides. This keeps the comparator consistent, which {@link Arrays#sort}
     * requires.
     */
    private static void applyOrdering(final int[] gbCols, G[] grps) {
        if (gbCols.length > 0) {
            Arrays.sort(grps, new Comparator<G>(){

                @Override
                public int compare(G g1, G g2) {
                    for (int i = 0; i < gbCols.length; ++i) {
                        double a = g1._gs[i];
                        double b = g2._gs[i];
                        boolean aNaN = Double.isNaN(a);
                        boolean bNaN = Double.isNaN(b);
                        if (aNaN || bNaN) {
                            if (aNaN && bNaN) {
                                continue; // tie on this key; look at the next one
                            }
                            return aNaN ? -1 : 1;
                        }
                        if (a != b) {
                            return a < b ? -1 : 1;
                        }
                    }
                    return 0;
                }

                @Override
                public boolean equals(Object o) {
                    throw H2O.unimpl();
                }
            });
        }
    }

    /**
     * Computes median aggregates for every group, when median processing is enabled.
     *
     * <p>The pass runs in four steps:
     * <ol>
     *   <li>Assign each (group, median aggregate) pair its own output column index in
     *       {@code G._newChunkCols}.
     *   <li>Run {@link BuildGroup} to collect each pair's values into that column.
     *   <li>Close those columns into Vecs.
     *   <li>Sort each Vec to pick the median.
     * </ol>
     * Afterwards {@link #_numberOfMedianActionsNeeded} holds the number of pairs.
     *
     * @param medianCount {@code < 0} disables the pass
     */
    private void calculateMediansForGRPS(Frame fr, int[] gbCols, AGG[] aggs, IcedHashMap<G, String> gss, G[] grps, int medianCount) {
        if (medianCount < 0) {
            return;
        }
        int nMedianCols = 0;
        for (G g : grps) {
            for (int index = 0; index < g._isMedian.length; ++index) {
                if (g._isMedian[index]) {
                    // Without this assignment every pair would write into output column 0.
                    g._newChunkCols[index] = nMedianCols++;
                }
            }
        }
        this._numberOfMedianActionsNeeded = nMedianCols;
        if (nMedianCols == 0) {
            return;
        }
        BuildGroup buildMedians = new BuildGroup(gbCols, aggs, gss, grps, nMedianCols);
        Vec[] groupChunks = ((BuildGroup)buildMedians.doAll(nMedianCols, Vec.T_NUM, fr)).close();
        buildMedians.calcMedian(groupChunks);
    }

    /**
     * Converts {@code ast} (a number or a number list) into a column selection.
     *
     * <p>Every selected index is checked against {@code [0, dstX)}. An empty list is returned
     * unchecked.
     *
     * @throws IllegalArgumentException for other AST node types or out-of-range indices
     */
    public static AstNumList check(long dstX, AstRoot ast) {
        AstNumList dim;
        if (ast instanceof AstNumList) {
            dim = (AstNumList)ast;
        } else if (ast instanceof AstNum) {
            dim = new AstNumList(((AstNum)ast).getNum());
        } else {
            throw new IllegalArgumentException("Requires a number-list, but found a " + ast.getClass());
        }
        if (dim.isEmpty()) {
            return dim;
        }
        for (int col : dim.expand4()) {
            if (0 <= col && (long)col < dstX) continue;
            throw new IllegalArgumentException("Selection must be an integer from 0 to " + dstX);
        }
        return dim;
    }

    /** Groups rows and accumulates aggregates, with median processing disabled. */
    public static IcedHashMap<G, String> doGroups(Frame fr, int[] gbCols, AGG[] aggs) {
        return AstGroup.doGroups(fr, gbCols, aggs, -1);
    }

    /**
     * Runs {@link GBTask} over {@code fr} and returns the distinct groups with their partial
     * aggregates.
     *
     * <p>The values of the returned map are unused placeholders.
     */
    public static IcedHashMap<G, String> doGroups(Frame fr, int[] gbCols, AGG[] aggs, int medianCount) {
        long start = System.currentTimeMillis();
        GBTask p1 = (GBTask)new GBTask(gbCols, aggs, medianCount).doAll(fr);
        Log.info("Group By Task done in " + (double)(System.currentTimeMillis() - start) / 1000.0 + " (s)");
        return p1._gss;
    }

    /** A single row-count aggregate (column 0, NAs ignored). */
    public static AGG[] aggNRows() {
        return new AGG[]{new AGG(FCN.nrow, 0, NAHandling.IGNORE, 0)};
    }

    /**
     * Builds the output frame by running {@code mrfill} over a temporary zero Vec with
     * {@code ngrps} rows.
     *
     * <p>Group-by columns keep their source names, domains and types. Aggregate columns are
     * numeric and named by {@code fcnames}. The temporary Vec is removed even if the fill task
     * fails.
     */
    public static Frame buildOutput(int[] gbCols, int noutCols, Frame fr, String[] fcnames, int ngrps, MRTask mrfill) {
        int nCols = gbCols.length + noutCols;
        String[] names = new String[nCols];
        String[][] domains = new String[nCols][];
        byte[] types = new byte[nCols];
        String[][] frDomains = fr.domains();
        for (int i = 0; i < gbCols.length; ++i) {
            names[i] = fr.name(gbCols[i]);
            domains[i] = frDomains[gbCols[i]];
            // Look up by index: a by-name lookup picks the wrong column when names are duplicated.
            types[i] = fr.vec(gbCols[i]).get_type();
        }
        for (int i = 0; i < fcnames.length; ++i) {
            names[i + gbCols.length] = fcnames[i];
            types[i + gbCols.length] = Vec.T_NUM;
        }
        Vec v = Vec.makeZero(ngrps);
        try {
            return ((MRTask)mrfill.doAll(types, new Frame(v))).outputFrame(names, domains);
        } finally {
            v.remove();
        }
    }

    /**
     * Second pass for medians.
     *
     * <p>For every row, appends the value of each median aggregate to the output column assigned
     * to the row's (group, aggregate) pair. Then turns each of those columns into a Vec and takes
     * its median.
     */
    private static class BuildGroup
    extends MRTask<BuildGroup> {
        final int[] _gbCols;
        private final AGG[] _aggs;
        private final int _medianCols;
        IcedHashMap<G, String> _gss;
        private G[] _grps;

        BuildGroup(int[] gbCols, AGG[] aggs, IcedHashMap<G, String> gss, G[] grps, int medianCols) {
            this._gbCols = gbCols;
            this._aggs = aggs;
            this._gss = gss;
            this._grps = grps;
            this._medianCols = medianCols;
        }

        @Override
        public void map(Chunk[] cs, NewChunk[] ncs) {
            G gWork = new G(this._gbCols.length, this._aggs, this._medianCols);
            for (int row = 0; row < cs[0]._len; ++row) {
                gWork.fill(row, cs, this._gbCols);
                G gOld = this._gss.getk(gWork);
                for (int i = 0; i < gOld._isMedian.length; ++i) {
                    if (!gOld._isMedian[i]) continue;
                    double d1 = cs[gOld._medianCols[i]].atd(row);
                    // NA values are dropped only under NAHandling.RM.
                    if (Double.isNaN(d1) && gOld._na[i] == NAHandling.RM) continue;
                    ncs[gOld._newChunkCols[i]].addNum(d1);
                }
            }
        }

        /**
         * Closes the appendable output column of every (group, median) pair.
         *
         * <p>Columns are closed in group order, which is the order {@link #calcMedian} consumes
         * them.
         */
        Vec[] close() {
            Futures fs = new Futures();
            int cCount = 0;
            Vec[] tempVgrps = new Vec[this._medianCols];
            for (G oneG : this._grps) {
                for (int index = 0; index < oneG._isMedian.length; ++index) {
                    if (!oneG._isMedian[index]) continue;
                    tempVgrps[cCount++] = this._appendables[oneG._newChunkCols[index]].close(this._appendables[oneG._newChunkCols[index]].compute_rowLayout(), fs);
                }
            }
            fs.blockForPending();
            return tempVgrps;
        }

        /**
         * Sorts each pair's Vec and stores its median (mean of the two middle values for even
         * counts) in {@code G._medians}.
         *
         * <p>Empty Vecs yield NaN. Every temporary Vec/Frame is deleted afterwards.
         */
        public void calcMedian(Vec[] tempVgrps) {
            int cCount = 0;
            for (G oneG : this._grps) {
                for (int index = 0; index < oneG._isMedian.length; ++index) {
                    if (!oneG._isMedian[index]) continue;
                    Vec groupVec = tempVgrps[cCount++];
                    long totalRows = groupVec.length();
                    double medianVal;
                    if (totalRows == 0L) {
                        medianVal = Double.NaN;
                        groupVec.remove(); // not wrapped in a Frame below, so delete it here
                    } else {
                        Frame myFrame = new Frame(Key.make(), new Vec[]{groupVec}, true);
                        long midRow = totalRows / 2L;
                        Frame tempFrame = Merge.sort(myFrame, new int[]{0});
                        medianVal = totalRows % 2L == 0L ? 0.5 * (tempFrame.vec(0).at(midRow - 1L) + tempFrame.vec(0).at(midRow)) : tempFrame.vec(0).at(midRow);
                        tempFrame.delete();
                        myFrame.delete();
                    }
                    oneG._medians[index] = medianVal;
                }
            }
        }
    }

    /**
     * One group.
     *
     * <p>Holds the group-key values ({@code _gs}) plus the per-aggregate accumulators
     * ({@code _dss}) and counts ({@code _ns}). Equality and hashing use only the key values. The
     * median-related arrays are allocated only when median processing is enabled.
     */
    public static class G
    extends Iced {
        public final double[] _gs;
        int _hash;
        public final double[][] _dss;
        public final long[] _ns;
        int[] _medianCols;
        double[] _medians;
        boolean[] _isMedian;
        int[] _newChunkCols;
        public NAHandling[] _na;

        public G(int ncols, AGG[] aggs) {
            this(ncols, aggs, -1);
        }

        /**
         * @param ncols number of group-by key columns
         * @param aggs aggregates to accumulate (may be null for none)
         * @param medianCounts {@code >= 0} to allocate median bookkeeping arrays
         */
        public G(int ncols, AGG[] aggs, int medianCounts) {
            this._gs = new double[ncols];
            int len = aggs == null ? 0 : aggs.length;
            this._dss = new double[len][];
            this._ns = new long[len];
            if (medianCounts >= 0) {
                this._medianCols = new int[len];
                this._medians = new double[len];
                this._isMedian = new boolean[len];
                this._newChunkCols = new int[len];
                this._na = new NAHandling[len];
            }
            for (int i = 0; i < len; ++i) {
                this._dss[i] = aggs[i].initVal();
                if (medianCounts < 0 || aggs[i]._fcn != FCN.median) continue;
                this._medianCols[i] = aggs[i]._col;
                this._isMedian[i] = true;
                this._na[i] = aggs[i]._na;
            }
        }

        /** Loads the key from every chunk's value at {@code row} and refreshes the hash. */
        public G fill(int row, Chunk[] chks) {
            for (int c = 0; c < chks.length; ++c) {
                this._gs[c] = chks[c].atd(row);
            }
            this._hash = this.hash();
            return this;
        }

        /** Loads the key from the chunks selected by {@code cols} at {@code row} and refreshes the hash. */
        public G fill(int row, Chunk[] chks, int[] cols) {
            for (int c = 0; c < cols.length; ++c) {
                this._gs[c] = chks[cols[c]].atd(row);
            }
            this._hash = this.hash();
            return this;
        }

        /**
         * Hashes the key values.
         *
         * <p>Uses {@code doubleToLongBits}, which canonicalizes NaN, to agree with {@link #equals}
         * ({@code Arrays.equals} on {@code double[]} treats all NaNs as equal).
         */
        protected int hash() {
            long h = 0L;
            for (double d : this._gs) {
                h += Double.doubleToLongBits(d);
            }
            h ^= h >>> 20 ^ h >>> 12;
            h ^= h >>> 7 ^ h >>> 4;
            return (int)((h ^ h >> 32) & Integer.MAX_VALUE);
        }

        @Override
        public boolean equals(Object o) {
            return o instanceof G && Arrays.equals(this._gs, ((G)o)._gs);
        }

        @Override
        public int hashCode() {
            return this._hash;
        }

        @Override
        public String toString() {
            return Arrays.toString(this._gs);
        }
    }

    /**
     * First pass: finds the distinct groups and accumulates the non-median aggregates.
     *
     * <p>Each {@code map} call builds a chunk-local table and then merges it into the shared
     * {@code _gss}. Merging synchronizes per accumulator array (see {@link AGG#atomic_op}).
     */
    public static class GBTask
    extends MRTask<GBTask> {
        final IcedHashMap<G, String> _gss;
        private final int[] _gbCols;
        private final AGG[] _aggs;
        private final int _medianCounts;

        GBTask(int[] gbCols, AGG[] aggs, int medianCounts) {
            this._gbCols = gbCols;
            this._aggs = aggs;
            this._gss = new IcedHashMap<G, String>();
            this._medianCounts = medianCounts;
        }

        @Override
        public void map(Chunk[] cs) {
            IcedHashMap<G, String> gs = new IcedHashMap<G, String>();
            G gWork = new G(this._gbCols.length, this._aggs, this._medianCounts);
            for (int row = 0; row < cs[0]._len; ++row) {
                G gOld;
                gWork.fill(row, cs, this._gbCols);
                if (gs.putIfAbsent(gWork, "") == null) {
                    // gWork is now a key in the table; allocate a fresh working group.
                    gOld = gWork;
                    gWork = new G(this._gbCols.length, this._aggs, this._medianCounts);
                } else {
                    gOld = gs.getk(gWork);
                }
                for (int i = 0; i < this._aggs.length; ++i) {
                    this._aggs[i].op(gOld._dss, gOld._ns, i, cs[this._aggs[i]._col].atd(row));
                }
            }
            this.reduce(gs);
        }

        @Override
        public void reduce(GBTask t) {
            if (this._gss != t._gss) {
                this.reduce(t._gss);
            }
        }

        /**
         * Merges {@code r} into {@code _gss}.
         *
         * <p>Groups that are new are inserted as-is; existing groups have their accumulators
         * combined. Private overload, so no {@code @Override} (which would not compile).
         */
        private void reduce(IcedHashMap<G, String> r) {
            for (G rg : r.keySet()) {
                if (this._gss.putIfAbsent(rg, "") == null) continue;
                G lg = this._gss.getk(rg);
                for (int i = 0; i < this._aggs.length; ++i) {
                    this._aggs[i].atomic_op(lg._dss, lg._ns, i, rg._dss[i], rg._ns[i]);
                }
            }
        }
    }

    /** One requested aggregate: a function applied to one column under one NA-handling mode. */
    public static class AGG
    extends Iced {
        final FCN _fcn;
        public final int _col;
        final NAHandling _na;
        /** Size hint passed to {@link FCN#initVal}; only {@code mode} uses it (number of levels). */
        final int _maxx;

        public AGG(FCN fcn, int col, NAHandling na, int maxx) {
            this._fcn = fcn;
            this._col = col;
            this._na = na;
            this._maxx = maxx;
        }

        /**
         * Folds value {@code d1} into accumulator {@code i}.
         *
         * <ul>
         *   <li>Non-NA values are always folded in and counted.
         *   <li>An NA is folded in only under {@code ALL}.
         *   <li>An NA is counted only under {@code IGNORE}.
         * </ul>
         */
        public void op(double[][] d0ss, long[] n0s, int i, double d1) {
            boolean isNA = Double.isNaN(d1);
            if (!isNA || this._na == NAHandling.ALL) {
                this._fcn.op(d0ss[i], d1);
            }
            if (!isNA || this._na == NAHandling.IGNORE) {
                n0s[i]++;
            }
        }

        /** Merges another partial accumulator/count into slot {@code i}, locking on the target array. */
        public void atomic_op(double[][] d0ss, long[] n0s, int i, double[] d1s, long n1) {
            double[] d0s = d0ss[i];
            synchronized (d0s) {
                this._fcn.atomic_op(d0s, d1s);
                n0s[i] += n1;
            }
        }

        public double[] initVal() {
            return this._fcn.initVal(this._maxx);
        }
    }

    /**
     * Aggregate functions.
     *
     * <p>Each constant defines how to:
     * <ul>
     *   <li>fold one value into an accumulator ({@code op});
     *   <li>merge two accumulators ({@code atomic_op});
     *   <li>produce the final value from the accumulator and the count ({@code postPass});
     *   <li>create an empty accumulator ({@code initVal}).
     * </ul>
     */
    public static enum FCN {
        nrow{

            @Override
            public void op(double[] d0s, double d1) {
                d0s[0] += 1.0;
            }

            @Override
            public void atomic_op(double[] d0s, double[] d1s) {
                d0s[0] += d1s[0];
            }

            @Override
            public double postPass(double[] ds, long n) {
                return ds[0];
            }
        }
        ,
        mean{

            @Override
            public void op(double[] d0s, double d1) {
                d0s[0] += d1;
            }

            @Override
            public void atomic_op(double[] d0s, double[] d1s) {
                d0s[0] += d1s[0];
            }

            @Override
            public double postPass(double[] ds, long n) {
                return ds[0] / (double)n;
            }
        }
        ,
        sum{

            @Override
            public void op(double[] d0s, double d1) {
                d0s[0] += d1;
            }

            @Override
            public void atomic_op(double[] d0s, double[] d1s) {
                d0s[0] += d1s[0];
            }

            @Override
            public double postPass(double[] ds, long n) {
                return ds[0];
            }
        }
        ,
        sumSquares{

            @Override
            public void op(double[] d0s, double d1) {
                d0s[0] += d1 * d1;
            }

            @Override
            public void atomic_op(double[] d0s, double[] d1s) {
                d0s[0] += d1s[0];
            }

            @Override
            public double postPass(double[] ds, long n) {
                return ds[0];
            }
        }
        ,
        var{

            /** Accumulator layout: {@code [sum of squares, sum]}. */
            @Override
            public void op(double[] d0s, double d1) {
                d0s[0] += d1 * d1;
                d0s[1] += d1;
            }

            @Override
            public void atomic_op(double[] d0s, double[] d1s) {
                ArrayUtils.add(d0s, d1s);
            }

            /**
             * Sample variance with an {@code n - 1} denominator.
             *
             * <p>A numerator below 1e-5 in magnitude is clamped to 0.
             */
            @Override
            public double postPass(double[] ds, long n) {
                double numerator = ds[0] - ds[1] * ds[1] / (double)n;
                if (Math.abs(numerator) < 1.0E-5) {
                    numerator = 0.0;
                }
                return numerator / (double)(n - 1L);
            }

            @Override
            public double[] initVal(int ignored) {
                return new double[2];
            }
        }
        ,
        sdev{

            /** Same accumulator as {@code var}: {@code [sum of squares, sum]}. */
            @Override
            public void op(double[] d0s, double d1) {
                d0s[0] += d1 * d1;
                d0s[1] += d1;
            }

            @Override
            public void atomic_op(double[] d0s, double[] d1s) {
                ArrayUtils.add(d0s, d1s);
            }

            /** Square root of the {@code var} result. */
            @Override
            public double postPass(double[] ds, long n) {
                return Math.sqrt(FCN.var.postPass(ds, n));
            }

            @Override
            public double[] initVal(int ignored) {
                return new double[2];
            }
        }
        ,
        min{

            @Override
            public void op(double[] d0s, double d1) {
                d0s[0] = Math.min(d0s[0], d1);
            }

            @Override
            public void atomic_op(double[] d0s, double[] d1s) {
                this.op(d0s, d1s[0]);
            }

            @Override
            public double postPass(double[] ds, long n) {
                return ds[0];
            }

            @Override
            public double[] initVal(int maxx) {
                return new double[]{Double.MAX_VALUE};
            }
        }
        ,
        max{

            @Override
            public void op(double[] d0s, double d1) {
                d0s[0] = Math.max(d0s[0], d1);
            }

            @Override
            public void atomic_op(double[] d0s, double[] d1s) {
                this.op(d0s, d1s[0]);
            }

            @Override
            public double postPass(double[] ds, long n) {
                return ds[0];
            }

            @Override
            public double[] initVal(int maxx) {
                return new double[]{-Double.MAX_VALUE};
            }
        }
        ,
        median{

            // Medians are computed by a separate pass (see BuildGroup); the accumulator is unused.

            @Override
            public void op(double[] d0s, double d1) {
            }

            @Override
            public void atomic_op(double[] d0s, double[] d1s) {
            }

            @Override
            public double postPass(double[] ds, long n) {
                return 0.0;
            }

            /**
             * Returns an empty accumulator.
             *
             * <p>Sizing it by {@code maxx} would allocate (or fail on a negative size) for no
             * reason.
             */
            @Override
            public double[] initVal(int maxx) {
                return new double[0];
            }
        }
        ,
        mode{

            /**
             * Increments the count of level {@code (int) d1}.
             *
             * <p>NOTE(review): under NAHandling.ALL a NaN reaches here and {@code (int) NaN == 0},
             * so NAs are counted as level 0; confirm the intended semantics.
             */
            @Override
            public void op(double[] d0s, double d1) {
                d0s[(int)d1] += 1.0;
            }

            @Override
            public void atomic_op(double[] d0s, double[] d1s) {
                ArrayUtils.add(d0s, d1s);
            }

            /** Index of the most frequent level. */
            @Override
            public double postPass(double[] ds, long n) {
                return ArrayUtils.maxIndex(ds);
            }

            /** One counter per categorical level. */
            @Override
            public double[] initVal(int maxx) {
                return new double[maxx];
            }
        };


        public abstract void op(double[] var1, double var2);

        public abstract void atomic_op(double[] var1, double[] var2);

        public abstract double postPass(double[] var1, long var2);

        /** Default accumulator: a single zero. */
        public double[] initVal(int maxx) {
            return new double[]{0.0};
        }
    }

    /**
     * How NA values in an aggregate column are treated.
     *
     * <ul>
     *   <li>{@code ALL}: NAs are folded into the accumulator and counted.
     *   <li>{@code RM}: NAs are neither folded in nor counted.
     *   <li>{@code IGNORE}: NAs are not folded in but are counted.
     * </ul>
     */
    public static enum NAHandling {
        ALL,
        RM,
        IGNORE;

    }
}

