/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer;

import com.facebook.presto.hive.shaded.org.apache.commons.logging.Log;
import com.facebook.presto.hive.shaded.org.apache.commons.logging.LogFactory;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Stack;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.exec.JoinOperator;
import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
import org.apache.hadoop.hive.ql.lib.DefaultGraphWalker;
import org.apache.hadoop.hive.ql.lib.DefaultRuleDispatcher;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.lib.Rule;
import org.apache.hadoop.hive.ql.lib.RuleRegExp;
import org.apache.hadoop.hive.ql.optimizer.SortBucketJoinProcCtx;
import org.apache.hadoop.hive.ql.optimizer.SortedMergeBucketMapjoinProc;
import org.apache.hadoop.hive.ql.optimizer.SortedMergeJoinProc;
import org.apache.hadoop.hive.ql.optimizer.Transform;
import org.apache.hadoop.hive.ql.parse.ParseContext;
import org.apache.hadoop.hive.ql.parse.SemanticException;

/**
 * Optimizer pass that applies the sort-merge-bucket (SMB) join processors to the operator tree.
 *
 * <p>The pass performs two walks over the operator DAG reachable from
 * {@link ParseContext#getTopOps()}. Both walks share one {@link SortBucketJoinProcCtx}.
 * <ol>
 *   <li>A pre-pass records, in {@link SortBucketJoinProcCtx#getRejectedJoinOps()}, every
 *       {@link JoinOperator} that is not an eligible candidate (see
 *       {@link #isSortMergeCandidate(Stack)}).</li>
 *   <li>The main pass sends {@link MapJoinOperator}s to {@link SortedMergeBucketMapjoinProc}.
 *       It also sends join operators to {@link SortedMergeJoinProc}, but only when
 *       {@code HiveConf.ConfVars.HIVE_AUTO_SORTMERGE_JOIN} is enabled.</li>
 * </ol>
 * NOTE(review): the processors presumably consult the rejected-join set collected by the
 * pre-pass; this is not visible from this class, so confirm in the processor implementations.
 */
public class SortedMergeBucketMapJoinOptimizer
implements Transform {
    // Not referenced anywhere in this class at present.
    private static final Log LOG = LogFactory.getLog(SortedMergeBucketMapJoinOptimizer.class.getName());

    /**
     * Runs the pre-pass that marks ineligible join operators as rejected.
     *
     * @param pctx parse context whose top operators are the walk roots
     * @param smbJoinContext context whose rejected-join set is populated
     * @throws SemanticException if a node processor fails during the walk
     */
    private void getListOfRejectedJoins(ParseContext pctx, SortBucketJoinProcCtx smbJoinContext) throws SemanticException {
        Map<Rule, NodeProcessor> opRules = new LinkedHashMap<Rule, NodeProcessor>();
        opRules.put(new RuleRegExp("R1", JoinOperator.getOperatorName() + "%"), this.getCheckCandidateJoin());
        this.walkOperatorTree(pctx, opRules, smbJoinContext);
    }

    /**
     * Collects rejected joins, then walks the tree again applying the SMB join processors.
     *
     * @param pctx the parse context to optimize; it is returned after the walks
     * @return {@code pctx}, the same instance that was passed in
     * @throws SemanticException if a node processor fails during either walk
     */
    @Override
    public ParseContext transform(ParseContext pctx) throws SemanticException {
        HiveConf conf = pctx.getConf();
        SortBucketJoinProcCtx smbJoinContext = new SortBucketJoinProcCtx(conf);
        this.getListOfRejectedJoins(pctx, smbJoinContext);

        Map<Rule, NodeProcessor> opRules = new LinkedHashMap<Rule, NodeProcessor>();
        opRules.put(new RuleRegExp("R1", MapJoinOperator.getOperatorName() + "%"), this.getSortedMergeBucketMapjoinProc(pctx));
        if (HiveConf.getBoolVar(conf, HiveConf.ConfVars.HIVE_AUTO_SORTMERGE_JOIN)) {
            // Build the pattern from the operator name, as the pre-pass does, rather than a literal.
            opRules.put(new RuleRegExp("R2", JoinOperator.getOperatorName() + "%"), this.getSortedMergeJoinProc(pctx));
        }
        this.walkOperatorTree(pctx, opRules, smbJoinContext);
        return pctx;
    }

    /**
     * Walks the operator DAG rooted at {@code pctx}'s top operators.
     *
     * <p>Each visited node goes to the processor that {@link DefaultRuleDispatcher} selects from
     * {@code opRules}. If no rule matches, it goes to the no-op {@link #getDefaultProc()}.
     *
     * @param pctx parse context supplying the walk roots
     * @param opRules rule-to-processor mapping for this walk
     * @param smbJoinContext context handed to every processor
     * @throws SemanticException if a node processor fails during the walk
     */
    private void walkOperatorTree(ParseContext pctx, Map<Rule, NodeProcessor> opRules,
            SortBucketJoinProcCtx smbJoinContext) throws SemanticException {
        DefaultRuleDispatcher disp = new DefaultRuleDispatcher(this.getDefaultProc(), opRules, smbJoinContext);
        DefaultGraphWalker ogw = new DefaultGraphWalker(disp);
        List<Node> topNodes = new ArrayList<Node>(pctx.getTopOps().values());
        // Per-node outputs are not needed by this pass, so no output map is supplied.
        ogw.startWalking(topNodes, null);
    }

    /** Returns the processor applied to map-join operators in the main pass. */
    private NodeProcessor getSortedMergeBucketMapjoinProc(ParseContext pctx) {
        return new SortedMergeBucketMapjoinProc(pctx);
    }

    /** Returns the processor applied to join operators when automatic sort-merge join is enabled. */
    private NodeProcessor getSortedMergeJoinProc(ParseContext pctx) {
        return new SortedMergeJoinProc(pctx);
    }

    /** Returns a processor that does nothing and yields {@code null}; used for unmatched nodes. */
    private NodeProcessor getDefaultProc() {
        return new NodeProcessor(){

            @Override
            public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx, Object ... nodeOutputs) throws SemanticException {
                return null;
            }
        };
    }

    /**
     * Returns a processor that adds a {@link JoinOperator} to the context's rejected-join set
     * whenever {@link #isSortMergeCandidate(Stack)} is false for the current walk stack.
     * It always returns {@code null}.
     */
    private NodeProcessor getCheckCandidateJoin() {
        return new NodeProcessor(){

            @Override
            public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx, Object ... nodeOutputs) throws SemanticException {
                SortBucketJoinProcCtx smbJoinContext = (SortBucketJoinProcCtx)procCtx;
                JoinOperator joinOperator = (JoinOperator)nd;
                if (!isSortMergeCandidate(stack)) {
                    smbJoinContext.getRejectedJoinOps().add(joinOperator);
                }
                return null;
            }
        };
    }

    /**
     * Decides whether the join on top of {@code stack} may be considered for sort-merge conversion.
     *
     * <p>The join qualifies only when all of the following hold:
     * <ul>
     *   <li>the top of the stack is a {@link JoinOperator};</li>
     *   <li>the element directly below it is a {@link ReduceSinkOperator};</li>
     *   <li>every operator below that reports {@code supportAutomaticSortMergeJoin()}.</li>
     * </ul>
     * A stack with fewer than two entries cannot satisfy the ReduceSink requirement, so it is
     * treated as not a candidate instead of causing an out-of-bounds access.
     *
     * @param stack the walker's current path from a root to the visited join
     * @return {@code true} if the join is a candidate, {@code false} if it should be rejected
     */
    private static boolean isSortMergeCandidate(Stack<Node> stack) {
        int size = stack.size();
        if (size < 2
                || !(stack.get(size - 1) instanceof JoinOperator)
                || !(stack.get(size - 2) instanceof ReduceSinkOperator)) {
            return false;
        }
        for (int pos = size - 3; pos >= 0; --pos) {
            Operator<?> op = (Operator<?>) stack.get(pos);
            if (!op.supportAutomaticSortMergeJoin()) {
                return false;
            }
        }
        return true;
    }
}

