/*
 * Decompiled with CFR 0.152.
 */
package water.rapids.ast.prims.mungers;

import java.util.Arrays;
import water.DKV;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.Env;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.rapids.vals.ValFrame;
import water.util.ArrayUtils;
import water.util.VecUtils;

/**
 * Rapids primitive {@code relevel.by.freq}: returns a new frame in which every categorical column is
 * replaced by a releveled copy whose domain is ordered by descending level weight, as computed by
 * {@link VecUtils#collectDomainWeights} (presumably per-level row counts, or per-level weight sums
 * when a weights column is given -- confirm against VecUtils). Non-categorical columns are passed
 * through unchanged.
 *
 * <p>Arguments: {@code frame}, {@code weights} (name of a weights column, or null for none) and
 * {@code topN} ({@code -1} to reorder the whole domain, otherwise a positive integer: only the
 * {@code topN} heaviest levels are moved to the front and the rest keep their original order).
 */
public class AstRelevelByFreq extends AstPrimitive<AstRelevelByFreq> {
    @Override
    public String[] args() {
        return new String[]{"frame", "weights", "topn"};
    }

    /** Four ASTs: {@code apply} reads its three arguments from {@code asts[1..3]}. */
    @Override
    public int nargs() {
        return 4;
    }

    @Override
    public String str() {
        return "relevel.by.freq";
    }

    /**
     * Validates the arguments and relevels copies of all categorical columns.
     *
     * @throws IllegalArgumentException if the named weights column is missing, or {@code topN} is
     *         neither {@code -1} nor a positive integer
     */
    @Override
    public ValFrame apply(Env env, Env.StackHelp stk, AstRoot[] asts) {
        final Frame frame = stk.track(asts[1].exec(env)).getFrame();
        final String weightsColumn = asts[2].exec(env).getStr();
        final Vec weights = frame.vec(weightsColumn);
        if (weightsColumn != null && weights == null) {
            throw new IllegalArgumentException("Frame doesn't contain weights column '" + weightsColumn + "'.");
        }
        final double topN = asts[3].exec(env).getNum();
        // -1 means "reorder the whole domain"; anything else must be a positive whole number.
        // The int round-trip also rejects NaN, infinities and values outside the int range.
        if (topN != -1.0 && topN <= 0.0 || (double) (int) topN != topN) {
            throw new IllegalArgumentException("TopN argument needs to be a positive integer number, got: " + topN);
        }
        final Frame result = new Frame(frame);
        for (int col = 0; col < result.numCols(); ++col) {
            Vec vec = result.vec(col);
            if (!vec.isCategorical()) {
                continue;
            }
            // Relevel a copy (makeCopy) swapped into the result rather than the input frame's own Vec.
            vec = vec.makeCopy();
            result.replace(col, vec);
            relevelByFreq(vec, weights, (int) topN);
        }
        return new ValFrame(result);
    }

    /**
     * Reorders the domain of {@code vec} by descending level weight, rewrites the row values so they
     * keep pointing at the same labels, and publishes the updated Vec to the DKV.
     *
     * @param vec     categorical Vec to relevel; modified in place
     * @param weights weights Vec forwarded to {@link VecUtils#collectDomainWeights}; possibly null
     *                when no weights column was requested
     * @param topN    {@code -1} to sort the whole domain, otherwise the number of heaviest levels to
     *                move to the front (remaining levels keep their original relative order)
     */
    static void relevelByFreq(Vec vec, Vec weights, int topN) {
        final double[] levelWeights = VecUtils.collectDomainWeights(vec, weights);
        // order[k] = old level index with the k-th smallest weight (ascending).
        int[] order = ArrayUtils.seq(0, levelWeights.length);
        ArrayUtils.sort(order, levelWeights);
        final String[] domain = vec.domain();
        // With topN >= size - 1 a partial reorder equals the full sort, so only cut below that.
        if (topN != -1 && topN < order.length - 1) {
            order = takeTopN(order, topN, domain.length);
        }
        // The order is ascending, so the new domain is the order read back-to-front.
        final String[] newDomain = domain.clone();
        for (int i = 0; i < order.length; ++i) {
            newDomain[i] = domain[order[order.length - i - 1]];
        }
        new RemapDomain(order).doAll(vec);
        vec.setDomain(newDomain);
        DKV.put(vec);
    }

    /**
     * Builds a level order (same ascending convention as {@code domainOrder}) in which only the
     * {@code topN} heaviest levels are moved. They occupy the last {@code topN} slots exactly as in
     * {@code domainOrder}. All other levels fill the remaining slots in descending index order, so
     * that when the order is read back-to-front they follow the top levels in their original order.
     *
     * @param domainOrder level indices sorted by ascending weight
     * @param topN        number of heaviest levels to move to the front
     * @param domainSize  number of levels in the domain
     * @return a permutation of {@code 0..domainSize-1}, position -> old level index
     */
    static int[] takeTopN(int[] domainOrder, int topN, int domainSize) {
        final int[] newDomainOrder = new int[domainSize];
        final int[] topLevels = new int[topN];
        for (int rank = 0; rank < topN; ++rank) {
            final int topLevel = domainOrder[domainOrder.length - rank - 1];
            topLevels[rank] = topLevel;
            newDomainOrder[domainSize - rank - 1] = topLevel;
        }
        // Sorted so membership can be tested with a binary search.
        Arrays.sort(topLevels);
        int slot = domainSize - topN - 1;
        for (int level = 0; level < domainSize; ++level) {
            if (Arrays.binarySearch(topLevels, level) < 0) {
                newDomainOrder[slot--] = level;
            }
        }
        assert slot == -1;
        return newDomainOrder;
    }

    /**
     * MRTask that rewrites categorical level indices according to a new level order. NA rows are
     * left unchanged.
     *
     * <p>The constructor takes the order produced by {@link #relevelByFreq}: {@code mapping[k]} is the
     * old level index at ascending position {@code k}. The new domain lists those levels
     * back-to-front, so old level {@code mapping[k]} becomes new level {@code mapping.length - k - 1}.
     */
    static class RemapDomain extends MRTask<RemapDomain> {
        /** Lookup table: old level index -> new level index. */
        private final int[] _oldToNew;

        public RemapDomain(int[] mapping) {
            final int size = mapping.length;
            // Invert the position -> level permutation once. Indexing the order array directly by
            // an old level would only be correct when the permutation is its own inverse.
            _oldToNew = new int[size];
            for (int pos = 0; pos < size; ++pos) {
                _oldToNew[mapping[pos]] = size - pos - 1;
            }
        }

        @Override
        public void map(Chunk c) {
            for (int row = 0; row < c._len; ++row) {
                if (c.isNA(row)) {
                    continue;
                }
                final int oldLevel = (int) c.atd(row);
                c.set(row, _oldToNew[oldLevel]);
            }
        }
    }
}

