/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.physical;

import io.prestosql.hive.$internal.org.slf4j.Logger;
import io.prestosql.hive.$internal.org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.Stack;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
import org.apache.hadoop.hive.ql.exec.Task;
import org.apache.hadoop.hive.ql.exec.tez.TezTask;
import org.apache.hadoop.hive.ql.lib.DefaultGraphWalker;
import org.apache.hadoop.hive.ql.lib.DefaultRuleDispatcher;
import org.apache.hadoop.hive.ql.lib.Dispatcher;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.lib.Rule;
import org.apache.hadoop.hive.ql.lib.RuleRegExp;
import org.apache.hadoop.hive.ql.optimizer.physical.LlapDecider;
import org.apache.hadoop.hive.ql.optimizer.physical.PhysicalContext;
import org.apache.hadoop.hive.ql.optimizer.physical.PhysicalPlanResolver;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.BaseWork;
import org.apache.hadoop.hive.ql.plan.MapJoinDesc;
import org.apache.hadoop.hive.ql.plan.TezWork;

/**
 * Physical plan resolver that adjusts map-join operators of Tez tasks when
 * LLAP execution is enabled.
 *
 * <p>When the configured LLAP execution mode is {@code none} the plan is
 * returned untouched. Otherwise every root task of the plan is walked and,
 * for each {@link TezTask}, every {@link BaseWork} in its {@link TezWork} is
 * inspected. If the execution mode is {@code "only"} and grace join in LLAP
 * is not enabled, hybrid (grace) hash join is switched off on every
 * {@link MapJoinOperator} that is not a dynamic partition hash join.
 */
public class LlapPreVectorizationPass
implements PhysicalPlanResolver {
    protected static final Logger LOG = LoggerFactory.getLogger(LlapPreVectorizationPass.class);

    /**
     * Walks the task graph of the given physical context and applies the
     * LLAP map-join adjustments in place.
     *
     * @param pctx the physical context whose root tasks are walked
     * @return the same {@code pctx} instance that was passed in
     * @throws SemanticException if walking the task or operator graph fails
     */
    @Override
    public PhysicalContext resolve(PhysicalContext pctx) throws SemanticException {
        HiveConf conf = pctx.getConf();
        LlapDecider.LlapMode mode = LlapDecider.LlapMode.valueOf(HiveConf.getVar(conf, HiveConf.ConfVars.LLAP_EXECUTION_MODE));
        if (mode == LlapDecider.LlapMode.none) {
            LOG.info("LLAP disabled.");
            return pctx;
        }

        Dispatcher taskDispatcher = new LlapPreVectorizationPassDispatcher(pctx);
        DefaultGraphWalker taskWalker = new DefaultGraphWalker(taskDispatcher);
        ArrayList<Node> rootTasks = new ArrayList<Node>(pctx.getRootTasks());
        // No per-node outputs are collected, hence the null output map.
        taskWalker.startWalking(rootTasks, null);
        return pctx;
    }

    /**
     * Task-level dispatcher: for every {@link TezTask} visited, runs the
     * operator-level rules over each piece of work the task contains.
     */
    class LlapPreVectorizationPassDispatcher
    implements Dispatcher {
        /** Configuration taken from the physical context at construction time. */
        HiveConf conf;

        LlapPreVectorizationPassDispatcher(PhysicalContext pctx) {
            this.conf = pctx.getConf();
        }

        /**
         * Processes one node of the task graph. Nodes that are not
         * {@link TezTask} instances are ignored.
         *
         * @param nd the node being visited
         * @param stack the walker's current node stack (unused)
         * @param nodeOutputs outputs of previously visited nodes (unused)
         * @return always {@code null}
         * @throws SemanticException if walking an operator tree fails
         */
        @Override
        public Object dispatch(Node nd, Stack<Node> stack, Object ... nodeOutputs) throws SemanticException {
            if (nd instanceof TezTask) {
                TezWork tezWork = ((TezTask) nd).getWork();
                for (BaseWork baseWork : tezWork.getAllWork()) {
                    this.handleWork(tezWork, baseWork);
                }
            }
            return null;
        }

        /**
         * Applies the operator rules to the root operators of one work unit.
         * Does nothing unless the LLAP execution mode is {@code "only"} and
         * grace join in LLAP is not enabled.
         *
         * @param tezWork the enclosing Tez work (currently not consulted)
         * @param work the work unit whose operator tree is walked
         * @throws SemanticException if walking the operator tree fails
         */
        private void handleWork(TezWork tezWork, BaseWork work) throws SemanticException {
            boolean llapOnly = "only".equals(this.conf.getVar(HiveConf.ConfVars.LLAP_EXECUTION_MODE));
            if (!llapOnly || this.conf.getBoolVar(HiveConf.ConfVars.LLAP_ENABLE_GRACE_JOIN_IN_LLAP)) {
                // No rule applies under this configuration, so there is nothing to walk.
                return;
            }

            LinkedHashMap<Rule, NodeProcessor> opRules = new LinkedHashMap<Rule, NodeProcessor>();
            opRules.put(new RuleRegExp("Disable grace hash join if LLAP mode and not dynamic partition hash join", MapJoinOperator.getOperatorName() + "%"), new NodeProcessor(){

                @Override
                public Object process(Node n, Stack<Node> s, NodeProcessorCtx c, Object ... os) {
                    MapJoinDesc joinDesc = ((MapJoinOperator) n).getConf();
                    // Dynamic partition hash joins keep their hybrid setting.
                    if (joinDesc.isHybridHashJoin() && !joinDesc.isDynamicPartitionHashJoin()) {
                        joinDesc.setHybridHashJoin(false);
                    }
                    // Use the cached constant instead of the deprecated Boolean constructor.
                    return Boolean.TRUE;
                }
            });

            // No default processor and no processor context are supplied.
            DefaultRuleDispatcher ruleDispatcher = new DefaultRuleDispatcher(null, opRules, null);
            DefaultGraphWalker operatorWalker = new DefaultGraphWalker(ruleDispatcher);
            ArrayList<Node> rootOperators = new ArrayList<Node>(work.getAllRootOperators());
            operatorWalker.startWalking(rootOperators, null);
        }
    }
}

