/*
 * Decompiled with CFR 0.152.
 */
package water.rapids.ast.prims.advmath;

import hex.quantile.QuantileModel;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Locale;
import java.util.Set;
import water.Freezable;
import water.H2O;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.Env;
import water.rapids.Val;
import water.rapids.ast.AstFrame;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.rapids.ast.params.AstNum;
import water.rapids.ast.params.AstNumList;
import water.rapids.ast.params.AstStr;
import water.rapids.ast.params.AstStrList;
import water.rapids.ast.prims.advmath.AstMode;
import water.rapids.ast.prims.mungers.AstGroup;
import water.rapids.ast.prims.reducers.AstMean;
import water.rapids.ast.prims.reducers.AstMedian;
import water.rapids.vals.ValFrame;
import water.rapids.vals.ValNums;
import water.util.ArrayUtils;
import water.util.IcedDouble;
import water.util.IcedHashMap;

/**
 * Rapids primitive {@code h2o.impute}: replaces missing values (NAs) of a frame in place.
 *
 * <p>Arguments, after the primitive itself:
 * <ul>
 *   <li>{@code ary} - the frame to impute;</li>
 *   <li>{@code col} - the column to impute, or {@code -1} for every column;</li>
 *   <li>{@code method} - {@code mean}, {@code median} or {@code mode} ({@code ffill}/{@code bfill}
 *       are recognised but rejected as unimplemented);</li>
 *   <li>{@code combineMethod} - a {@link QuantileModel.CombineMethod} name, consulted only for the
 *       median of a single, ungrouped column;</li>
 *   <li>{@code groupByCols} - group-by columns as a number, number list or name list (may be empty);</li>
 *   <li>{@code groupByFrame} - a precomputed per-group impute frame, or {@code _} for none;</li>
 *   <li>{@code values} - explicit per-column fill values (ungrouped case only).</li>
 * </ul>
 *
 * <p>Without grouping, the per-column fill values are returned as a {@link ValNums}. NaN marks a
 * column that was not imputed. With grouping, the frame of per-group imputed values is returned
 * as a {@link ValFrame}.
 */
public class AstImpute extends AstPrimitive {

    /** Names of the arguments that follow the primitive in the Rapids expression. */
    @Override
    public String[] args() {
        return new String[]{"ary", "col", "method", "combineMethod", "groupByCols", "groupByFrame", "values"};
    }

    /** Rapids name of this primitive. */
    @Override
    public String str() {
        return "h2o.impute";
    }

    /** The primitive itself plus the seven arguments listed by {@link #args()}. */
    @Override
    public int nargs() {
        return 8;
    }

    /**
     * Parses the arguments, then imputes with or without grouping.
     *
     * @throws IllegalArgumentException for a column outside {@code -1 .. numCols-1}, an unknown
     *         method or combine method, an unknown group-by column name, a group-by argument of
     *         the wrong kind, too few explicit values, or an ambiguous group-by frame
     */
    @Override
    public Val apply(Env env, Env.StackHelp stk, AstRoot[] asts) {
        Frame fr = stk.track(asts[1].exec(env)).getFrame();

        // Column being imputed; -1 means every column.
        final int col = (int) asts[2].exec(env).getNum();
        if (col < -1 || col >= fr.numCols()) {
            throw new IllegalArgumentException("Column not -1 or in range 0 to " + fr.numCols());
        }
        final Vec vec = col == -1 ? null : fr.vec(col);

        // Locale.ROOT: default-locale upper-casing turns e.g. "median" into "MEDİAN" under a Turkish locale.
        AstPrimitive method = null;
        boolean directionalFill = false;
        switch (asts[3].exec(env).getStr().toUpperCase(Locale.ROOT)) {
            case "MEAN":
                method = new AstMean();
                break;
            case "MEDIAN":
                method = new AstMedian();
                break;
            case "MODE":
                method = new AstMode();
                break;
            case "FFILL":
            case "BFILL":
                directionalFill = true;
                break;
            default:
                throw new IllegalArgumentException("Method must be one of mean, median or mode");
        }

        // Only consulted when a single ungrouped column is imputed by its median.
        QuantileModel.CombineMethod combine =
            QuantileModel.CombineMethod.valueOf(asts[4].exec(env).getStr().toUpperCase(Locale.ROOT));

        AstNumList by2 = parseGroupByColumns(asts[5], fr);
        Frame groupByFrame = asts[6].str().equals("_") ? null : stk.track(asts[6].exec(env)).getFrame();
        AstNumList values = parseValues(asts[7]);

        if (directionalFill) {
            // Forward/backward fill is not implemented, with or without grouping.
            throw H2O.unimpl("No ffill or bfill imputation supported");
        }

        boolean doGrpBy = !by2.isEmpty() || groupByFrame != null;
        return doGrpBy
            ? imputeGrouped(env, stk, fr, col, method, by2, groupByFrame)
            : imputeUngrouped(fr, col, vec, method, combine, values);
    }

    /** Converts the group-by argument (number, number list or column-name list) into a number list. */
    private static AstNumList parseGroupByColumns(AstRoot ast, Frame fr) {
        if (ast instanceof AstNumList) {
            return (AstNumList) ast;
        }
        if (ast instanceof AstNum) {
            return new AstNumList(((AstNum) ast).getNum());
        }
        if (ast instanceof AstStrList) {
            String[] names = ((AstStrList) ast)._strs;
            double[] list = new double[names.length];
            for (int i = 0; i < names.length; ++i) {
                int idx = fr.find(names[i]);
                if (idx == -1) {
                    throw new IllegalArgumentException("Unknown group-by column: " + names[i]);
                }
                list[i] = idx;
            }
            Arrays.sort(list);
            return new AstNumList(list);
        }
        throw new IllegalArgumentException("Requires a number-list, but found a " + ast.getClass());
    }

    /** Converts the explicit-values argument into a number list, or {@code null} if none was given. */
    private static AstNumList parseValues(AstRoot vals) {
        if (vals instanceof AstNumList) {
            return (AstNumList) vals;
        }
        if (vals instanceof AstNum) {
            return new AstNumList(((AstNum) vals).getNum());
        }
        return null;
    }

    /**
     * Imputes without grouping: the NAs of each selected column are replaced by one value.
     *
     * <p>NOTE(review): with {@code col == -1} the method is ignored. Numeric columns get their mean.
     * Categorical columns get the index of the largest bin of {@code Vec.bins()}, presumably the
     * most frequent level; confirm.
     *
     * @return the per-column fill values; NaN marks a column that was left untouched
     */
    private static Val imputeUngrouped(Frame fr, int col, Vec vec, AstPrimitive method,
                                       QuantileModel.CombineMethod combine, AstNumList values) {
        final double[] res;
        if (values != null) {
            res = values.expand();
            if (res.length < fr.numCols()) {
                throw new IllegalArgumentException(
                    "Expected " + fr.numCols() + " impute values, but got " + res.length);
            }
        } else {
            res = new double[fr.numCols()];
            // NaN means "do not impute": columns that get no value below stay untouched.
            Arrays.fill(res, Double.NaN);
            if (col == -1) {
                for (int i = 0; i < res.length; ++i) {
                    Vec v = fr.vec(i);
                    if (v.isNumeric()) {
                        res[i] = v.mean();
                    } else if (v.isCategorical()) {
                        res[i] = ArrayUtils.maxIndex(v.bins());
                    }
                }
            } else if (method instanceof AstMean) {
                res[col] = vec.mean();
            } else if (method instanceof AstMedian) {
                res[col] = AstMedian.median(new Frame(vec), combine);
            } else if (method instanceof AstMode) {
                res[col] = AstMode.mode(vec);
            }
        }

        new MRTask() {
            @Override
            public void map(Chunk[] cs) {
                final int len = cs[0]._len;
                for (int c = 0; c < cs.length; ++c) {
                    final double fill = res[c];
                    if (Double.isNaN(fill)) {
                        continue;
                    }
                    for (int row = 0; row < len; ++row) {
                        if (cs[c].isNA(row)) {
                            cs[c].set(row, fill);
                        }
                    }
                }
            }
        }.doAll(fr);
        return new ValNums(res);
    }

    /**
     * Imputes per group: each NA is replaced by the value imputed for its row's group.
     *
     * <p>The per-group values come from {@code groupByFrame} when given. Otherwise they come from an
     * {@link AstGroup} run: mean for numeric columns and mode for categorical ones when
     * {@code col == -1}, else {@code method} on {@code col}. NAs of rows whose group is missing from
     * those values, or whose imputed value is NaN, are left as they are.
     *
     * @return the frame of per-group imputed values
     */
    private static Val imputeGrouped(Env env, Env.StackHelp stk, Frame fr, int col, AstPrimitive method,
                                     AstNumList by2, Frame groupByFrame) {
        // Leading columns of the group-by result that form the group key.
        final int nKeyCols = Math.max((int) by2.cnt(), 1);
        // For col == -1: the group-by result column holding each frame column's value (-1 = none).
        final int[] aggCols = col == -1 ? aggregateColumns(fr, by2, nKeyCols) : null;

        Frame imputes = groupByFrame;
        if (imputes == null) {
            AstGroup astGrp = new AstGroup();
            if (col == -1) {
                // One (aggregate, column, "rm") triple per column that has an aggregate slot.
                AstRoot[] aggs = new AstRoot[3 + 3 * fr.numCols()];
                aggs[0] = astGrp;
                aggs[1] = new AstFrame(fr);
                aggs[2] = by2;
                int a = 3;
                for (int i = 0; i < aggCols.length; ++i) {
                    if (aggCols[i] == -1) {
                        continue;
                    }
                    aggs[a] = fr.vec(i).isNumeric() ? new AstMean() : new AstMode();
                    aggs[a + 1] = new AstNumList(i, i + 1);
                    aggs[a + 2] = new AstStr("rm");
                    a += 3;
                }
                // Trim unused slots so only populated triples are passed to AstGroup.
                imputes = astGrp.apply(env, stk, Arrays.copyOf(aggs, a)).getFrame();
            } else {
                imputes = astGrp.apply(env, stk, new AstRoot[]{
                    astGrp, new AstFrame(fr), by2, method, new AstNumList(col, col + 1), new AstStr("rm")
                }).getFrame();
            }
        }
        if (by2.isEmpty() && imputes.numCols() > 2) {
            throw new IllegalArgumentException("Ambiguous group-by frame. Supply the `by` columns to proceed.");
        }

        // Where each frame column's imputed value lives in the group-by result (-1 = not imputed).
        final int[] srcCols;
        if (col == -1) {
            srcCols = aggCols;
        } else {
            srcCols = new int[fr.numCols()];
            Arrays.fill(srcCols, -1);
            srcCols[col] = imputes.numCols() - 1;
        }

        IcedHashMap<AstGroup.G, Freezable[]> gathered =
            new Gather(ArrayUtils.seq(0, nKeyCols), srcCols).doAll(imputes)._group_impute_map;
        final IcedHashMap<AstGroup.G, Freezable[]> groupImputes =
            gathered != null ? gathered : new IcedHashMap<AstGroup.G, Freezable[]>();

        // Without explicit group-by columns, all but the last column of the group-by frame name
        // the group-by columns of fr.
        if (by2.isEmpty()) {
            int[] byCols = new int[imputes.numCols() - 1];
            for (int i = 0; i < byCols.length; ++i) {
                byCols[i] = fr.find(imputes.name(i));
            }
            by2 = new AstNumList(byCols);
        }
        final int[] bycols = by2.expand4();

        new MRTask() {
            @Override
            public void map(Chunk[] cs) {
                Set<Integer> byColSet = new HashSet<>();
                for (int b : bycols) {
                    byColSet.add(b);
                }
                AstGroup.G g = new AstGroup.G(bycols.length, null);
                for (int row = 0; row < cs[0]._len; ++row) {
                    Freezable[] rowImputes = null;
                    boolean looked = false;  // group lookup is done lazily, at most once per row
                    for (int c = 0; c < cs.length; ++c) {
                        if (byColSet.contains(c) || !cs[c].isNA(row)) {
                            continue;
                        }
                        if (!looked) {
                            rowImputes = groupImputes.get(g.fill(row, cs, bycols));
                            looked = true;
                        }
                        if (rowImputes == null) {
                            break;  // group has no imputed values: leave this row's NAs alone
                        }
                        double v = ((IcedDouble) rowImputes[c])._val;
                        if (!Double.isNaN(v)) {
                            cs[c].set(row, v);
                        }
                    }
                }
            }
        }.doAll(fr);
        return new ValFrame(imputes);
    }

    /**
     * Maps each column of {@code fr} to the group-by result column holding its aggregate, or -1.
     *
     * <p>Group-by columns and columns that are neither numeric nor categorical have no aggregate.
     * The remaining columns' aggregates follow the {@code firstAggCol} key columns, in frame order.
     */
    private static int[] aggregateColumns(Frame fr, AstNumList by2, int firstAggCol) {
        int[] src = new int[fr.numCols()];
        int next = firstAggCol;
        for (int i = 0; i < src.length; ++i) {
            Vec v = fr.vec(i);
            boolean hasAggregate = !by2.has(i) && (v.isNumeric() || v.isCategorical());
            src[i] = hasAggregate ? next++ : -1;
        }
        return src;
    }

    /** Flattens a group-by result frame into a map from group key to per-column imputed values. */
    private static class Gather
    extends MRTask<Gather> {
        private final int[] _keyCols;  // columns of the group-by result forming the group key
        private final int[] _srcCols;  // per frame column: group-by result column with its value, or -1
        private IcedHashMap<AstGroup.G, Freezable[]> _group_impute_map;

        Gather(int[] keyCols, int[] srcCols) {
            this._keyCols = keyCols;
            this._srcCols = srcCols;
        }

        @Override
        public void map(Chunk[] cs) {
            this._group_impute_map = new IcedHashMap<>();
            for (int row = 0; row < cs[0]._len; ++row) {
                IcedDouble[] imputes = new IcedDouble[this._srcCols.length];
                for (int c = 0; c < imputes.length; ++c) {
                    int src = this._srcCols[c];
                    imputes[c] = new IcedDouble(src == -1 ? Double.NaN : cs[src].atd(row));
                }
                this._group_impute_map.put(new AstGroup.G(this._keyCols.length, null).fill(row, cs, this._keyCols), imputes);
            }
        }

        @Override
        public void reduce(Gather mrt) {
            // A task that never ran map() carries no map; merge defensively.
            if (this._group_impute_map == null) {
                this._group_impute_map = mrt._group_impute_map;
            } else if (mrt._group_impute_map != null) {
                this._group_impute_map.putAll(mrt._group_impute_map);
            }
        }
    }
}

