/*
 * Decompiled with CFR 0.152.
 */
package io.druid.sql.calcite.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import io.druid.sql.calcite.planner.Calcites;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.tools.RelBuilder;

public class CaseFilteredAggregatorRule
extends RelOptRule {
    private static final CaseFilteredAggregatorRule INSTANCE = new CaseFilteredAggregatorRule();

    private CaseFilteredAggregatorRule() {
        super(CaseFilteredAggregatorRule.operand(Aggregate.class, (RelOptRuleOperand)CaseFilteredAggregatorRule.operand(Project.class, (RelOptRuleOperandChildren)CaseFilteredAggregatorRule.any()), (RelOptRuleOperand[])new RelOptRuleOperand[0]));
    }

    public static CaseFilteredAggregatorRule instance() {
        return INSTANCE;
    }

    public boolean matches(RelOptRuleCall call) {
        Aggregate aggregate = (Aggregate)call.rel(0);
        Project project = (Project)call.rel(1);
        if (aggregate.indicator || aggregate.getGroupSets().size() != 1) {
            return false;
        }
        for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
            if (!CaseFilteredAggregatorRule.isOneArgAggregateCall(aggregateCall) || !CaseFilteredAggregatorRule.isThreeArgCase((RexNode)project.getChildExps().get((Integer)Iterables.getOnlyElement((Iterable)aggregateCall.getArgList())))) continue;
            return true;
        }
        return false;
    }

    public void onMatch(RelOptRuleCall call) {
        Aggregate aggregate = (Aggregate)call.rel(0);
        Project project = (Project)call.rel(1);
        RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
        ArrayList<AggregateCall> newCalls = new ArrayList<AggregateCall>(aggregate.getAggCallList().size());
        ArrayList<RexNode> newProjects = new ArrayList<RexNode>(project.getChildExps());
        ArrayList<Object> newCasts = new ArrayList<Object>(aggregate.getGroupCount() + aggregate.getAggCallList().size());
        RelDataTypeFactory typeFactory = aggregate.getCluster().getTypeFactory();
        Iterator iterator = aggregate.getGroupSet().iterator();
        while (iterator.hasNext()) {
            int fieldNumber = (Integer)iterator.next();
            newCasts.add(rexBuilder.makeInputRef(((RexNode)project.getChildExps().get(fieldNumber)).getType(), fieldNumber));
        }
        for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
            RexNode rexNode;
            AggregateCall newCall = null;
            if (CaseFilteredAggregatorRule.isOneArgAggregateCall(aggregateCall) && CaseFilteredAggregatorRule.isThreeArgCase(rexNode = (RexNode)project.getChildExps().get((Integer)Iterables.getOnlyElement((Iterable)aggregateCall.getArgList())))) {
                RexCall caseCall = (RexCall)rexNode;
                boolean flip = RexLiteral.isNullLiteral((RexNode)((RexNode)caseCall.getOperands().get(1))) && !RexLiteral.isNullLiteral((RexNode)((RexNode)caseCall.getOperands().get(2)));
                RexNode arg1 = (RexNode)caseCall.getOperands().get(flip ? 2 : 1);
                RexNode arg2 = (RexNode)caseCall.getOperands().get(flip ? 1 : 2);
                RelDataType booleanType = Calcites.createSqlType(typeFactory, SqlTypeName.BOOLEAN);
                RexNode filterFromCase = rexBuilder.makeCall(booleanType, (SqlOperator)(flip ? SqlStdOperatorTable.IS_FALSE : SqlStdOperatorTable.IS_TRUE), (List)ImmutableList.of(caseCall.getOperands().get(0)));
                RexNode filter = aggregateCall.filterArg >= 0 ? rexBuilder.makeCall(booleanType, (SqlOperator)SqlStdOperatorTable.AND, (List)ImmutableList.of(project.getProjects().get(aggregateCall.filterArg), (Object)filterFromCase)) : filterFromCase;
                if (aggregateCall.isDistinct()) {
                    if (aggregateCall.getAggregation().getKind() == SqlKind.COUNT && RexLiteral.isNullLiteral((RexNode)arg2)) {
                        newProjects.add(arg1);
                        newProjects.add(filter);
                        newCall = AggregateCall.create((SqlAggFunction)SqlStdOperatorTable.COUNT, (boolean)true, (List)ImmutableList.of((Object)(newProjects.size() - 2)), (int)(newProjects.size() - 1), (RelDataType)aggregateCall.getType(), (String)aggregateCall.getName());
                    }
                } else if (aggregateCall.getAggregation().getKind() == SqlKind.COUNT && arg1.isA(SqlKind.LITERAL) && !RexLiteral.isNullLiteral((RexNode)arg1) && RexLiteral.isNullLiteral((RexNode)arg2)) {
                    newProjects.add(filter);
                    newCall = AggregateCall.create((SqlAggFunction)SqlStdOperatorTable.COUNT, (boolean)false, (List)ImmutableList.of(), (int)(newProjects.size() - 1), (RelDataType)aggregateCall.getType(), (String)aggregateCall.getName());
                } else if (aggregateCall.getAggregation().getKind() == SqlKind.SUM && Calcites.isIntLiteral(arg1) && RexLiteral.intValue((RexNode)arg1) == 1 && Calcites.isIntLiteral(arg2) && RexLiteral.intValue((RexNode)arg2) == 0) {
                    newProjects.add(filter);
                    newCall = AggregateCall.create((SqlAggFunction)SqlStdOperatorTable.COUNT, (boolean)false, (List)ImmutableList.of(), (int)(newProjects.size() - 1), (RelDataType)Calcites.createSqlType(typeFactory, SqlTypeName.BIGINT), (String)aggregateCall.getName());
                } else if (RexLiteral.isNullLiteral((RexNode)arg2) || aggregateCall.getAggregation().getKind() == SqlKind.SUM && Calcites.isIntLiteral(arg2) && RexLiteral.intValue((RexNode)arg2) == 0) {
                    newProjects.add(arg1);
                    newProjects.add(filter);
                    newCall = AggregateCall.create((SqlAggFunction)aggregateCall.getAggregation(), (boolean)false, (List)ImmutableList.of((Object)(newProjects.size() - 2)), (int)(newProjects.size() - 1), (RelDataType)aggregateCall.getType(), (String)aggregateCall.getName());
                }
            }
            newCalls.add(newCall == null ? aggregateCall : newCall);
            int i = newCasts.size();
            RelDataType oldType = ((RelDataTypeField)aggregate.getRowType().getFieldList().get(i)).getType();
            if (newCall == null) {
                newCasts.add(rexBuilder.makeInputRef(oldType, i));
                continue;
            }
            newCasts.add(rexBuilder.makeCast(oldType, (RexNode)rexBuilder.makeInputRef(newCall.getType(), i)));
        }
        if (!newCalls.equals(aggregate.getAggCallList())) {
            RelBuilder relBuilder = call.builder().push(project.getInput()).project(newProjects);
            RelBuilder.GroupKey groupKey = relBuilder.groupKey(aggregate.getGroupSet(), aggregate.getGroupSets());
            RelNode newAggregate = relBuilder.aggregate(groupKey, newCalls).project(newCasts).build();
            call.transformTo(newAggregate);
            call.getPlanner().setImportance((RelNode)aggregate, 0.0);
        }
    }

    private static boolean isOneArgAggregateCall(AggregateCall aggregateCall) {
        return aggregateCall.getArgList().size() == 1;
    }

    private static boolean isThreeArgCase(RexNode rexNode) {
        return rexNode.getKind() == SqlKind.CASE && ((RexCall)rexNode).getOperands().size() == 3;
    }
}

