/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite.rules;

import com.facebook.presto.hive.$internal.com.google.common.collect.ImmutableList;
import com.facebook.presto.hive.$internal.com.google.common.collect.Lists;
import com.facebook.presto.hive.$internal.org.slf4j.Logger;
import com.facebook.presto.hive.$internal.org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinInfo;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories;

/**
 * Planner rule that removes or weakens a join against an {@link Aggregate} when the operator
 * directly above the join ({@link Project} or {@link Aggregate}) reads only columns produced by
 * the join's left input.
 *
 * <p>The rule matches {@code Top(Join(left, Aggregate(input)))} and requires that the right-hand
 * equi-join keys are exactly the aggregate's grouping columns. When that holds:
 * <ul>
 *   <li>a {@code LEFT} join is replaced by its left input;</li>
 *   <li>an {@code INNER} equi-join is replaced by a semi-join of the left input with the
 *       aggregate's input, on the join keys remapped to the aggregate's input columns.</li>
 * </ul>
 * All other join types are left untouched.
 */
public abstract class HiveSemiJoinRule
extends RelOptRule {
    protected static final Logger LOG = LoggerFactory.getLogger(HiveSemiJoinRule.class);
    public static final HiveProjectToSemiJoinRule INSTANCE_PROJECT = new HiveProjectToSemiJoinRule(HiveRelFactories.HIVE_BUILDER);
    public static final HiveAggregateToSemiJoinRule INSTANCE_AGGREGATE = new HiveAggregateToSemiJoinRule(HiveRelFactories.HIVE_BUILDER);

    private HiveSemiJoinRule(RelOptRuleOperand operand, RelBuilderFactory relBuilder) {
        super(operand, relBuilder, null);
    }

    /**
     * Shared transformation used by both rule variants.
     *
     * @param call        the rule call; supplies RelBuilders and receives the rewritten plan
     * @param topRefs     join output columns referenced by {@code topOperator}
     * @param topOperator the operator directly on top of {@code join}
     * @param join        the matched join
     * @param left        the left input of {@code join}
     * @param aggregate   the right input of {@code join}
     */
    protected void perform(RelOptRuleCall call, ImmutableBitSet topRefs, RelNode topOperator,
            Join join, RelNode left, Aggregate aggregate) {
        LOG.debug("Matched HiveSemiJoinRule");
        final RelOptCluster cluster = join.getCluster();
        final RexBuilder rexBuilder = cluster.getRexBuilder();

        // Columns of the join output that come from the right (aggregate) side.
        final ImmutableBitSet rightBits = ImmutableBitSet.range(
                left.getRowType().getFieldCount(), join.getRowType().getFieldCount());
        if (topRefs.intersects(rightBits)) {
            // The operator above reads right-side columns, so the right side cannot be dropped.
            return;
        }

        final JoinInfo joinInfo = join.analyzeCondition();
        if (!joinInfo.rightSet().equals(ImmutableBitSet.range(aggregate.getGroupCount()))) {
            // The right-hand join keys must be exactly the aggregate's grouping columns;
            // neither a subset nor a superset qualifies.
            return;
        }

        if (join.getJoinType() == JoinRelType.LEFT) {
            // Nothing above reads the right side, and the right side is keyed on the join
            // columns, so the LEFT join can be replaced by its left input.
            call.transformTo(topOperator.copy(topOperator.getTraitSet(), ImmutableList.of(left)));
            return;
        }
        if (join.getJoinType() != JoinRelType.INNER || !joinInfo.isEqui()) {
            return;
        }
        if (aggregate.getGroupCount() == 0) {
            // An aggregate without grouping columns always yields exactly one row, even over
            // empty input, so this (always-true) inner join keeps every left row. A semi-join
            // against the aggregate's input would wrongly drop all rows when that input is empty.
            return;
        }

        LOG.debug("All conditions matched for HiveSemiJoinRule. Going to apply transformation.");
        // Translate right join keys (aggregate output positions) into aggregate input positions.
        final List<Integer> aggregateKeys = aggregate.getGroupSet().asList();
        final List<Integer> remappedKeys = new ArrayList<>(joinInfo.rightKeys.size());
        for (int key : joinInfo.rightKeys) {
            remappedKeys.add(aggregateKeys.get(key));
        }
        final ImmutableIntList newRightKeys = ImmutableIntList.copyOf(remappedKeys);
        final RelNode newRight = aggregate.getInput();
        final RexNode newCondition = RelOptUtil.createEquiJoinCondition(
                left, joinInfo.leftKeys, newRight, newRightKeys, rexBuilder);

        final RelNode semiRight = semiJoinRightInput(call, rexBuilder, newRight);
        final RelNode semi = call.builder()
                .push(left)
                .push(semiRight)
                .semiJoin(newCondition)
                .build();
        call.transformTo(topOperator.copy(topOperator.getTraitSet(), ImmutableList.of(semi)));
    }

    /**
     * Builds the right input of the semi-join from the aggregate's input.
     *
     * <p>If {@code input} is a {@link HepRelVertex} whose current rel is a {@link Join}, a forced
     * identity {@link Project} (keeping the join's field names) is placed on top of that join and
     * returned. Otherwise {@code input} is returned unchanged.
     * NOTE(review): presumably this avoids handing later stages a semi-join whose right input is
     * a Join — confirm against the consumers of the rewritten plan.
     */
    private static RelNode semiJoinRightInput(RelOptRuleCall call, RexBuilder rexBuilder, RelNode input) {
        if (!(input instanceof HepRelVertex)) {
            return input;
        }
        final RelNode current = ((HepRelVertex) input).getCurrentRel();
        if (!(current instanceof Join)) {
            return input;
        }
        final Join rightJoin = (Join) current;
        final int fieldCount = rightJoin.getRowType().getFieldCount();
        final List<RexNode> projects = new ArrayList<>(fieldCount);
        for (int i = 0; i < fieldCount; i++) {
            projects.add(rexBuilder.makeInputRef(rightJoin, i));
        }
        return call.builder()
                .push(rightJoin)
                .project(projects, rightJoin.getRowType().getFieldNames(), true)
                .build();
    }

    /** Variant matching {@code Aggregate(Join(any, Aggregate))}. */
    public static class HiveAggregateToSemiJoinRule
    extends HiveSemiJoinRule {
        public HiveAggregateToSemiJoinRule(RelBuilderFactory relBuilder) {
            super(operand(Aggregate.class,
                            some(operand(Join.class,
                                    some(operand(RelNode.class, any()),
                                            operand(Aggregate.class, any()))))),
                    relBuilder);
        }

        @Override
        public void onMatch(RelOptRuleCall call) {
            final Aggregate topAggregate = (Aggregate) call.rel(0);
            final Join join = (Join) call.rel(1);
            final RelNode left = call.rel(2);
            final Aggregate aggregate = (Aggregate) call.rel(3);

            // Join columns used by the top aggregate: grouping keys, call arguments and filters.
            final ImmutableBitSet.Builder topRefs = ImmutableBitSet.builder();
            topRefs.addAll(topAggregate.getGroupSet());
            for (AggregateCall aggCall : topAggregate.getAggCallList()) {
                topRefs.addAll(aggCall.getArgList());
                if (aggCall.filterArg >= 0) {
                    topRefs.set(aggCall.filterArg);
                }
            }
            perform(call, topRefs.build(), topAggregate, join, left, aggregate);
        }
    }

    /** Variant matching {@code Project(Join(any, Aggregate))}. */
    public static class HiveProjectToSemiJoinRule
    extends HiveSemiJoinRule {
        public HiveProjectToSemiJoinRule(RelBuilderFactory relBuilder) {
            super(operand(Project.class,
                            some(operand(Join.class,
                                    some(operand(RelNode.class, any()),
                                            operand(Aggregate.class, any()))))),
                    relBuilder);
        }

        @Override
        public void onMatch(RelOptRuleCall call) {
            final Project project = (Project) call.rel(0);
            final Join join = (Join) call.rel(1);
            final RelNode left = call.rel(2);
            final Aggregate aggregate = (Aggregate) call.rel(3);
            // Join columns referenced by any of the project's expressions.
            final ImmutableBitSet topRefs = RelOptUtil.InputFinder.bits(project.getChildExps(), null);
            perform(call, topRefs, project, join, left, aggregate);
        }
    }
}

