/*
 * Decompiled with CFR 0.152.
 */
package water.rapids.prims.mungers;

import ai.h2o.targetencoding.TargetEncoder;
import water.fvec.Frame;
import water.rapids.Env;
import water.rapids.ast.AstBuiltin;
import water.rapids.ast.AstRoot;
import water.rapids.ast.params.AstStr;
import water.rapids.ast.params.AstStrList;
import water.rapids.vals.ValMapFrame;
import water.util.IcedHashMapGeneric;

/**
 * Rapids primitive {@code target.encoder.fit}: builds target-encoding maps for the requested
 * columns of a training frame via {@link TargetEncoder#prepareEncodingMap} and returns them
 * wrapped in a {@link ValMapFrame}.
 *
 * <p>Arguments are read from {@code asts[1]}..{@code asts[4]}:
 * <ol>
 *   <li>{@code trainFrame} - the training frame;</li>
 *   <li>{@code teColumns} - either a string list or a single string naming the columns to encode;</li>
 *   <li>{@code targetColumnName} - name of the target column;</li>
 *   <li>{@code foldColumnName} - name of the fold column; an empty string (or a value for which
 *       reading a string fails with {@link IllegalArgumentException}) means "no fold column".</li>
 * </ol>
 */
public class AstTargetEncoderFit extends AstBuiltin<AstTargetEncoderFit> {

    /**
     * Returns the argument names, one entry per argument (the primitive's own slot is excluded).
     */
    public String[] args() {
        return new String[]{"trainFrame", "teColumns", "targetColumnName", "foldColumnName"};
    }

    /** Returns the Rapids name under which this primitive is invoked. */
    public String str() {
        return "target.encoder.fit";
    }

    /**
     * Returns the arity: the four arguments listed by {@link #args()} plus one extra slot
     * ({@code asts[0]}, which this class never reads — presumably the primitive itself).
     */
    public int nargs() {
        return 5;
    }

    /**
     * Evaluates the arguments and computes the encoding maps.
     *
     * @param env  evaluation environment used to execute argument ASTs
     * @param stk  stack helper that tracks the values produced while evaluating arguments
     * @param asts argument ASTs; indices 1..4 are read
     * @return the encoding map produced by {@link TargetEncoder#prepareEncodingMap}, wrapped
     *         in a {@link ValMapFrame}
     * @throws IllegalStateException if {@code teColumns} is neither a string list nor a string
     */
    public ValMapFrame apply(Env env, Env.StackHelp stk, AstRoot[] asts) {
        Frame trainFrame = getTrainingFrame(env, stk, asts);
        String[] teColumnsToEncode = getTEColumns(env, stk, asts);
        String targetColumnName = getTargetColumnName(env, stk, asts);
        String foldColumnName = getFoldColumnName(env, stk, asts);
        // Imputation for the original columns is not configurable from Rapids: nargs() leaves no
        // argument slot for it, so it is always enabled here.
        final boolean withImputationForOriginalColumns = true;
        TargetEncoder tec = new TargetEncoder(teColumnsToEncode);
        IcedHashMapGeneric<String, Frame> encodingMap = tec.prepareEncodingMap(
                trainFrame, targetColumnName, foldColumnName, withImputationForOriginalColumns);
        return new ValMapFrame(encodingMap);
    }

    /** Evaluates {@code asts[1]} and returns it as a frame (the value is tracked on {@code stk}). */
    private Frame getTrainingFrame(Env env, Env.StackHelp stk, AstRoot[] asts) {
        return stk.track(asts[1].exec(env)).getFrame();
    }

    /**
     * Reads the columns to encode from {@code asts[2]}, which may be a string list or a single
     * string.
     *
     * @throws IllegalStateException if {@code asts[2]} is neither kind of node
     */
    private String[] getTEColumns(Env env, Env.StackHelp stk, AstRoot[] asts) {
        AstRoot teColumnsAst = asts[2];
        if (teColumnsAst instanceof AstStrList) {
            // Copy so the parsed AST's own array is never shared with (or altered via) the encoder.
            String[] strs = ((AstStrList) teColumnsAst)._strs;
            return strs == null ? null : strs.clone();
        }
        if (teColumnsAst instanceof AstStr) {
            String teColumn = stk.track(teColumnsAst.exec(env)).getStr();
            return new String[]{teColumn};
        }
        throw new IllegalStateException("Couldn't parse `teColumns` parameter");
    }

    /** Evaluates {@code asts[3]} and returns it as the target column name. */
    private String getTargetColumnName(Env env, Env.StackHelp stk, AstRoot[] asts) {
        return stk.track(asts[3].exec(env)).getStr();
    }

    /**
     * Evaluates {@code asts[4]} as the fold column name.
     *
     * @return the fold column name, or {@code null} when it is empty, {@code null}, or could not
     *         be read as a string
     */
    private String getFoldColumnName(Env env, Env.StackHelp stk, AstRoot[] asts) {
        String foldColumnName;
        try {
            foldColumnName = stk.track(asts[4].exec(env)).getStr();
        } catch (IllegalArgumentException ex) {
            // NOTE(review): presumably raised when the argument is not a string value (e.g. an
            // NA passed to mean "no fold column"); treated as "no fold column" — confirm.
            return null;
        }
        return (foldColumnName == null || foldColumnName.isEmpty()) ? null : foldColumnName;
    }
}

