/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite.rules.views;

import io.prestosql.hive.$internal.com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.List;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Union;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories;

public class HiveAggregateIncrementalRewritingRule
extends RelOptRule {
    public static final HiveAggregateIncrementalRewritingRule INSTANCE = new HiveAggregateIncrementalRewritingRule();

    private HiveAggregateIncrementalRewritingRule() {
        super(HiveAggregateIncrementalRewritingRule.operand(Aggregate.class, (RelOptRuleOperand)HiveAggregateIncrementalRewritingRule.operand(Union.class, (RelOptRuleOperandChildren)HiveAggregateIncrementalRewritingRule.any()), (RelOptRuleOperand[])new RelOptRuleOperand[0]), HiveRelFactories.HIVE_BUILDER, "HiveAggregateIncrementalRewritingRule");
    }

    public void onMatch(RelOptRuleCall call) {
        Aggregate agg = (Aggregate)call.rel(0);
        Union union = (Union)call.rel(1);
        RexBuilder rexBuilder = agg.getCluster().getRexBuilder();
        RelNode joinLeftInput = union.getInput(1);
        RelNode joinRightInput = union.getInput(0);
        ArrayList<Object> projExprs = new ArrayList<Object>();
        ArrayList<RexNode> joinConjs = new ArrayList<RexNode>();
        ArrayList<RexNode> filterConjs = new ArrayList<RexNode>();
        int groupCount = agg.getGroupCount();
        int totalCount = agg.getGroupCount() + agg.getAggCallList().size();
        int leftPos = 0;
        int rightPos = totalCount;
        while (leftPos < groupCount) {
            RexInputRef leftRef = rexBuilder.makeInputRef(((RelDataTypeField)joinLeftInput.getRowType().getFieldList().get(leftPos)).getType(), leftPos);
            RexInputRef rightRef = rexBuilder.makeInputRef(((RelDataTypeField)joinRightInput.getRowType().getFieldList().get(leftPos)).getType(), rightPos);
            projExprs.add(rightRef);
            joinConjs.add(rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.EQUALS, ImmutableList.of(leftRef, rightRef)));
            filterConjs.add(rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.IS_NULL, ImmutableList.of(leftRef)));
            ++leftPos;
            ++rightPos;
        }
        RexNode caseFilterCond = rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.AND, filterConjs);
        int i = 0;
        int leftPos2 = groupCount;
        int rightPos2 = totalCount + groupCount;
        while (leftPos2 < totalCount) {
            RexNode elseReturn;
            RexInputRef leftRef = rexBuilder.makeInputRef(((RelDataTypeField)joinLeftInput.getRowType().getFieldList().get(leftPos2)).getType(), leftPos2);
            RexInputRef rightRef = rexBuilder.makeInputRef(((RelDataTypeField)joinRightInput.getRowType().getFieldList().get(leftPos2)).getType(), rightPos2);
            SqlAggFunction aggCall = ((AggregateCall)agg.getAggCallList().get(i)).getAggregation();
            switch (aggCall.getKind()) {
                case SUM: {
                    elseReturn = rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.PLUS, ImmutableList.of(rightRef, leftRef));
                    break;
                }
                case MIN: {
                    RexNode condInnerCase = rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.LESS_THAN, ImmutableList.of(rightRef, leftRef));
                    elseReturn = rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.CASE, ImmutableList.of(condInnerCase, rightRef, leftRef));
                    break;
                }
                case MAX: {
                    RexNode condInnerCase = rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.GREATER_THAN, ImmutableList.of(rightRef, leftRef));
                    elseReturn = rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.CASE, ImmutableList.of(condInnerCase, rightRef, leftRef));
                    break;
                }
                default: {
                    throw new AssertionError((Object)("Found an aggregation that could not be recognized: " + aggCall));
                }
            }
            projExprs.add(rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.CASE, ImmutableList.of(caseFilterCond, rightRef, elseReturn)));
            ++i;
            ++leftPos2;
            ++rightPos2;
        }
        RexNode joinCond = rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.AND, joinConjs);
        RexNode filterCond = rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.AND, filterConjs);
        RelNode newNode = call.builder().push(union.getInput(1)).push(union.getInput(0)).join(JoinRelType.RIGHT, joinCond).filter(new RexNode[]{rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.OR, ImmutableList.of(joinCond, filterCond))}).project(projExprs).build();
        call.transformTo(newNode);
    }
}

