/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.airlift.slice.Slices;
import io.trino.Session;
import io.trino.connector.MockConnectorFactory;
import io.trino.connector.MockConnectorTableHandle;
import io.trino.metadata.ResolvedFunction;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.spi.connector.ColumnMetadata;
import io.trino.spi.connector.ConnectorFactory;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.function.OperatorType;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Decimals;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.TinyintType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.ir.Between;
import io.trino.sql.ir.Call;
import io.trino.sql.ir.Case;
import io.trino.sql.ir.Cast;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Reference;
import io.trino.sql.ir.WhenClause;
import io.trino.sql.planner.Plan;
import io.trino.sql.planner.assertions.AggregationFunction;
import io.trino.sql.planner.assertions.BasePlanTest;
import io.trino.sql.planner.assertions.ExpectedValueProvider;
import io.trino.sql.planner.assertions.ExpressionMatcher;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.optimizations.PlanNodeSearcher;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.testing.PlanTester;
import io.trino.testing.TestingSession;
import java.math.BigDecimal;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Predicate;
import org.assertj.core.api.Assertions;
import org.intellij.lang.annotations.Language;
import org.junit.jupiter.api.Test;

public class TestPreAggregateCaseAggregations
extends BasePlanTest {
    // Function/operator handles referenced by the expected-plan matchers; FUNCTIONS must be declared first
    // because the initializers below (including resolveBinaryOperator) read it during class initialization.
    private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution();
    private static final ResolvedFunction CONCAT = FUNCTIONS.resolveFunction(
            "concat",
            TypeSignatureProvider.fromTypes(VarcharType.VARCHAR, VarcharType.VARCHAR));
    private static final ResolvedFunction MULTIPLY_BIGINT = resolveBinaryOperator(OperatorType.MULTIPLY, BigintType.BIGINT);
    private static final ResolvedFunction MODULUS_BIGINT = resolveBinaryOperator(OperatorType.MODULUS, BigintType.BIGINT);
    private static final ResolvedFunction MULTIPLY_DECIMAL_10_0 = resolveBinaryOperator(OperatorType.MULTIPLY, DecimalType.createDecimalType(10));
    // The only table served by the mock connector.
    private static final SchemaTableName TABLE = new SchemaTableName("default", "t");

    /**
     * Resolves {@code operator} for two operands that are both of {@code operandType}.
     */
    private static ResolvedFunction resolveBinaryOperator(OperatorType operator, Type operandType) {
        return FUNCTIONS.resolveOperator(operator, ImmutableList.of(operandType, operandType));
    }

    /**
     * Builds the plan tester for this class.
     * <p>
     * The session uses catalog "local" and schema "default". It disables hash-generation optimization and
     * the partial-aggregation preference, and sets task concurrency to 1. The expected plans in this class
     * only contain single-step aggregations. The catalog is backed by a mock connector that serves exactly
     * one table, {@code default.t}.
     */
    @Override
    protected PlanTester createPlanTester() {
        Session session = TestingSession.testSessionBuilder()
                .setCatalog("local")
                .setSchema("default")
                .setSystemProperty("optimize_hash_generation", "false")
                .setSystemProperty("prefer_partial_aggregation", "false")
                .setSystemProperty("task_concurrency", "1")
                .build();
        PlanTester planTester = PlanTester.create(session);
        MockConnectorFactory.Builder builder = MockConnectorFactory.builder()
                .withGetTableHandle((connectorSession, schemaTableName) -> new MockConnectorTableHandle(schemaTableName))
                .withGetColumns(name -> {
                    if (!name.equals(TABLE)) {
                        // Only default.t exists; name the requested table so a bad lookup is diagnosable.
                        throw new IllegalArgumentException("Unexpected table: " + name);
                    }
                    // Typed list (no Object casts) so the lambda returns List<ColumnMetadata>.
                    return ImmutableList.of(
                            new ColumnMetadata("col_varchar", VarcharType.VARCHAR),
                            new ColumnMetadata("col_bigint", BigintType.BIGINT),
                            new ColumnMetadata("col_tinyint", TinyintType.TINYINT),
                            new ColumnMetadata("col_decimal", DecimalType.createDecimalType(2, 1)),
                            new ColumnMetadata("col_long_decimal", DecimalType.createDecimalType(19, 18)),
                            new ColumnMetadata("col_double", DoubleType.DOUBLE));
                });
        planTester.createCatalog("local", builder.build(), ImmutableMap.of());
        return planTester;
    }

    /**
     * Verifies the complete plan shape produced for a grouped query with several CASE aggregations.
     * <p>
     * Expected shape, outermost first:
     * <ol>
     *   <li>a projection casting SUM_2 to varchar(10) (the outer CAST in the SQL);</li>
     *   <li>a SINGLE aggregation grouped by KEY ({@code col_varchar || 'a'}) over the SUM_n_INPUT / MIN_1_INPUT symbols;</li>
     *   <li>a projection that re-applies each original CASE condition on COL_BIGINT to the pre-aggregated value
     *       (e.g. SUM_1_INPUT = CASE WHEN COL_BIGINT = 1 THEN SUM_BIGINT ELSE 0);</li>
     *   <li>a SINGLE pre-aggregation grouped by (KEY, COL_BIGINT);</li>
     *   <li>an exchange over a projection computing KEY and the pre-aggregation inputs (VALUE_*), over the scan of t.</li>
     * </ol>
     * SUM_1 and SUM_3 aggregate the same value ({@code col_bigint * 2}) under conditions = 1 and = 2. They therefore
     * share one pre-aggregated symbol, SUM_BIGINT, whose input is guarded by {@code COL_BIGINT BETWEEN 1 AND 2}.
     * SUM_4 sums COL_DECIMAL directly, so no extra projection is expected for it.
     * <p>
     * NOTE(review): the SQL condition for MIN_1 is {@code col_bigint % 2 > 1.23}, but the matcher expects a BIGINT
     * comparison against 1. The test therefore depends on the planner rewriting that decimal comparison; confirm
     * this is intended.
     * <p>
     * NOTE(review): the (Object)/raw-type casts are artifacts of the CFR decompiler (see file header). javac may
     * reject some of them, e.g. casting an ImmutableMap&lt;Object, Object&gt; to a parameterized Map. Consider
     * restoring typed builders such as {@code ImmutableMap.<String, ExpressionMatcher>builder()}.
     */
    @Test
    public void testPreAggregatesCaseAggregations() {
        this.assertPlan("SELECT (col_varchar || 'a'), sum(CASE WHEN col_bigint = 1 THEN col_bigint * 2 ELSE 0 END), CAST(sum(CASE WHEN col_bigint = 1 THEN CAST(col_bigint * 2 AS INTEGER) ELSE CAST(0 AS INTEGER) END) AS VARCHAR(10)), sum(CASE WHEN col_bigint = 2 THEN col_bigint * 2 ELSE null END), min(CASE WHEN col_bigint % 2 > 1.23 THEN col_bigint * 2 END), sum(CASE WHEN col_bigint = 3 THEN col_decimal END), sum(CAST(CASE WHEN col_bigint = 4 THEN col_decimal * 2 END AS BIGINT)) FROM t GROUP BY (col_varchar || 'a')", PlanMatchPattern.anyTree(PlanMatchPattern.project((Map<String, ExpressionMatcher>)ImmutableMap.of((Object)"SUM_2_CAST", (Object)PlanMatchPattern.expression((Expression)new Cast((Expression)new Reference((Type)BigintType.BIGINT, "SUM_2"), (Type)VarcharType.createVarcharType((int)10)))), PlanMatchPattern.aggregation(PlanMatchPattern.singleGroupingSet("KEY"), (Map<Optional<String>, ExpectedValueProvider<AggregationFunction>>)ImmutableMap.builder().put(Optional.of("SUM_1"), PlanMatchPattern.aggregationFunction("sum", (List<String>)ImmutableList.of((Object)"SUM_1_INPUT"))).put(Optional.of("SUM_2"), PlanMatchPattern.aggregationFunction("sum", (List<String>)ImmutableList.of((Object)"SUM_2_INPUT"))).put(Optional.of("SUM_3"), PlanMatchPattern.aggregationFunction("sum", (List<String>)ImmutableList.of((Object)"SUM_3_INPUT"))).put(Optional.of("MIN_1"), PlanMatchPattern.aggregationFunction("min", (List<String>)ImmutableList.of((Object)"MIN_1_INPUT"))).put(Optional.of("SUM_4"), PlanMatchPattern.aggregationFunction("sum", (List<String>)ImmutableList.of((Object)"SUM_4_INPUT"))).put(Optional.of("SUM_5"), PlanMatchPattern.aggregationFunction("sum", (List<String>)ImmutableList.of((Object)"SUM_5_INPUT"))).buildOrThrow(), Optional.empty(), AggregationNode.Step.SINGLE, PlanMatchPattern.project((Map<String, ExpressionMatcher>)ImmutableMap.builder().put((Object)"SUM_1_INPUT", (Object)PlanMatchPattern.expression((Expression)new Case((List)ImmutableList.of((Object)new 
WhenClause((Expression)new Comparison(Comparison.Operator.EQUAL, (Expression)new Reference((Type)BigintType.BIGINT, "COL_BIGINT"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)1L)), (Expression)new Reference((Type)BigintType.BIGINT, "SUM_BIGINT"))), (Expression)new Constant((Type)BigintType.BIGINT, (Object)0L)))).put((Object)"SUM_2_INPUT", (Object)PlanMatchPattern.expression((Expression)new Case((List)ImmutableList.of((Object)new WhenClause((Expression)new Comparison(Comparison.Operator.EQUAL, (Expression)new Reference((Type)BigintType.BIGINT, "COL_BIGINT"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)1L)), (Expression)new Reference((Type)BigintType.BIGINT, "SUM_INT_CAST"))), (Expression)new Constant((Type)BigintType.BIGINT, (Object)0L)))).put((Object)"SUM_3_INPUT", (Object)PlanMatchPattern.expression((Expression)new Case((List)ImmutableList.of((Object)new WhenClause((Expression)new Comparison(Comparison.Operator.EQUAL, (Expression)new Reference((Type)BigintType.BIGINT, "COL_BIGINT"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)2L)), (Expression)new Reference((Type)BigintType.BIGINT, "SUM_BIGINT"))), (Expression)new Constant((Type)BigintType.BIGINT, null)))).put((Object)"MIN_1_INPUT", (Object)PlanMatchPattern.expression((Expression)new Case((List)ImmutableList.of((Object)new WhenClause((Expression)new Comparison(Comparison.Operator.GREATER_THAN, (Expression)new Call(MODULUS_BIGINT, (List)ImmutableList.of((Object)new Reference((Type)BigintType.BIGINT, "COL_BIGINT"), (Object)new Constant((Type)BigintType.BIGINT, (Object)2L))), (Expression)new Constant((Type)BigintType.BIGINT, (Object)1L)), (Expression)new Reference((Type)BigintType.BIGINT, "MIN_BIGINT"))), (Expression)new Constant((Type)BigintType.BIGINT, null)))).put((Object)"SUM_4_INPUT", (Object)PlanMatchPattern.expression((Expression)new Case((List)ImmutableList.of((Object)new WhenClause((Expression)new Comparison(Comparison.Operator.EQUAL, (Expression)new 
Reference((Type)BigintType.BIGINT, "COL_BIGINT"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)3L)), (Expression)new Reference((Type)DecimalType.createDecimalType((int)38, (int)1), "SUM_DECIMAL"))), (Expression)new Constant((Type)DecimalType.createDecimalType((int)38, (int)1), null)))).put((Object)"SUM_5_INPUT", (Object)PlanMatchPattern.expression((Expression)new Case((List)ImmutableList.of((Object)new WhenClause((Expression)new Comparison(Comparison.Operator.EQUAL, (Expression)new Reference((Type)BigintType.BIGINT, "COL_BIGINT"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)4L)), (Expression)new Reference((Type)BigintType.BIGINT, "SUM_DECIMAL_CAST"))), (Expression)new Constant((Type)BigintType.BIGINT, null)))).buildOrThrow(), PlanMatchPattern.aggregation(PlanMatchPattern.singleGroupingSet("KEY", "COL_BIGINT"), (Map<Optional<String>, ExpectedValueProvider<AggregationFunction>>)ImmutableMap.of(Optional.of("SUM_BIGINT"), PlanMatchPattern.aggregationFunction("sum", (List<String>)ImmutableList.of((Object)"VALUE_BIGINT")), Optional.of("SUM_INT_CAST"), PlanMatchPattern.aggregationFunction("sum", (List<String>)ImmutableList.of((Object)"VALUE_INT_CAST")), Optional.of("MIN_BIGINT"), PlanMatchPattern.aggregationFunction("min", (List<String>)ImmutableList.of((Object)"VALUE_2_BIGINT")), Optional.of("SUM_DECIMAL"), PlanMatchPattern.aggregationFunction("sum", (List<String>)ImmutableList.of((Object)"COL_DECIMAL")), Optional.of("SUM_DECIMAL_CAST"), PlanMatchPattern.aggregationFunction("sum", (List<String>)ImmutableList.of((Object)"VALUE_DECIMAL_CAST"))), Optional.empty(), AggregationNode.Step.SINGLE, PlanMatchPattern.exchange(PlanMatchPattern.project((Map<String, ExpressionMatcher>)ImmutableMap.of((Object)"KEY", (Object)PlanMatchPattern.expression((Expression)new Call(CONCAT, (List)ImmutableList.of((Object)new Reference((Type)VarcharType.VARCHAR, "COL_VARCHAR"), (Object)new Constant((Type)VarcharType.VARCHAR, (Object)Slices.utf8Slice((String)"a"))))), 
(Object)"VALUE_BIGINT", (Object)PlanMatchPattern.expression((Expression)new Case((List)ImmutableList.of((Object)new WhenClause((Expression)new Between((Expression)new Reference((Type)BigintType.BIGINT, "COL_BIGINT"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)1L), (Expression)new Constant((Type)BigintType.BIGINT, (Object)2L)), (Expression)new Call(MULTIPLY_BIGINT, (List)ImmutableList.of((Object)new Reference((Type)BigintType.BIGINT, "COL_BIGINT"), (Object)new Constant((Type)BigintType.BIGINT, (Object)2L))))), (Expression)new Constant((Type)BigintType.BIGINT, null))), (Object)"VALUE_INT_CAST", (Object)PlanMatchPattern.expression((Expression)new Case((List)ImmutableList.of((Object)new WhenClause((Expression)new Comparison(Comparison.Operator.EQUAL, (Expression)new Reference((Type)BigintType.BIGINT, "COL_BIGINT"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)1L)), (Expression)new Cast((Expression)new Cast((Expression)new Call(MULTIPLY_BIGINT, (List)ImmutableList.of((Object)new Reference((Type)BigintType.BIGINT, "COL_BIGINT"), (Object)new Constant((Type)BigintType.BIGINT, (Object)2L))), (Type)IntegerType.INTEGER), (Type)BigintType.BIGINT))), (Expression)new Constant((Type)BigintType.BIGINT, null))), (Object)"VALUE_2_BIGINT", (Object)PlanMatchPattern.expression((Expression)new Case((List)ImmutableList.of((Object)new WhenClause((Expression)new Comparison(Comparison.Operator.GREATER_THAN, (Expression)new Call(MODULUS_BIGINT, (List)ImmutableList.of((Object)new Reference((Type)BigintType.BIGINT, "COL_BIGINT"), (Object)new Constant((Type)BigintType.BIGINT, (Object)2L))), (Expression)new Constant((Type)BigintType.BIGINT, (Object)1L)), (Expression)new Call(MULTIPLY_BIGINT, (List)ImmutableList.of((Object)new Reference((Type)BigintType.BIGINT, "COL_BIGINT"), (Object)new Constant((Type)BigintType.BIGINT, (Object)2L))))), (Expression)new Constant((Type)BigintType.BIGINT, null))), (Object)"VALUE_DECIMAL_CAST", 
(Object)PlanMatchPattern.expression((Expression)new Case((List)ImmutableList.of((Object)new WhenClause((Expression)new Comparison(Comparison.Operator.EQUAL, (Expression)new Reference((Type)BigintType.BIGINT, "COL_BIGINT"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)4L)), (Expression)new Cast((Expression)new Call(MULTIPLY_DECIMAL_10_0, (List)ImmutableList.of((Object)new Reference((Type)DecimalType.createDecimalType((int)10, (int)0), "COL_DECIMAL"), (Object)new Constant((Type)DecimalType.createDecimalType((int)10, (int)0), (Object)Decimals.valueOfShort((BigDecimal)new BigDecimal("2"))))), (Type)BigintType.BIGINT))), (Expression)new Constant((Type)BigintType.BIGINT, null)))), PlanMatchPattern.tableScan("t", (Map<String, String>)ImmutableMap.of((Object)"COL_VARCHAR", (Object)"col_varchar", (Object)"COL_BIGINT", (Object)"col_bigint", (Object)"COL_DECIMAL", (Object)"col_decimal"))))))))));
    }

    /**
     * Global (no GROUP BY) counterpart of the grouped CASE-aggregation plan test.
     * <p>
     * Expected shape, outermost first:
     * <ol>
     *   <li>a projection casting SUM_2 to varchar(10);</li>
     *   <li>a SINGLE global aggregation over the SUM_n_INPUT / MIN_1_INPUT symbols;</li>
     *   <li>a projection re-applying each CASE condition on COL_BIGINT to the pre-aggregated value;</li>
     *   <li>a SINGLE pre-aggregation grouped only by COL_BIGINT;</li>
     *   <li>an exchange over a projection computing the VALUE_* inputs, over a scan of t's col_bigint and col_decimal.</li>
     * </ol>
     * The MIN_BIGINT input is aliased VALUE_2_BIGINT. It is a plain bigint expression
     * ({@code CASE WHEN COL_BIGINT % 2 > 1 THEN COL_BIGINT * 2 END}) with no integer cast, matching the alias used
     * for the same expression in the grouped variant of this test.
     * <p>
     * NOTE(review): the SQL condition for MIN_1 is {@code col_bigint % 2 > 1.23}, but the matcher expects a BIGINT
     * comparison against 1. This relies on the planner rewriting that comparison; confirm this is intended.
     * <p>
     * NOTE(review): the (Object)/raw-type casts are CFR decompiler artifacts (see file header) and javac may reject
     * some of them. Consider restoring typed builders.
     */
    @Test
    public void testGlobalPreAggregatesCaseAggregations() {
        this.assertPlan("SELECT sum(CASE WHEN col_bigint = 1 THEN col_bigint * 2 ELSE 0 END), CAST(sum(CASE WHEN col_bigint = 1 THEN CAST(col_bigint * 2 AS INTEGER) ELSE CAST(0 AS INTEGER) END) AS VARCHAR(10)), sum(CASE WHEN col_bigint = 2 THEN col_bigint * 2 ELSE null END), min(CASE WHEN col_bigint % 2 > 1.23 THEN col_bigint * 2 END), sum(CASE WHEN col_bigint = 3 THEN col_decimal END), sum(CAST(CASE WHEN col_bigint = 4 THEN col_decimal * 2 END AS BIGINT)) FROM t", PlanMatchPattern.anyTree(PlanMatchPattern.project((Map<String, ExpressionMatcher>)ImmutableMap.of((Object)"SUM_2_CAST", (Object)PlanMatchPattern.expression((Expression)new Cast((Expression)new Reference((Type)BigintType.BIGINT, "SUM_2"), (Type)VarcharType.createVarcharType((int)10)))), PlanMatchPattern.aggregation(PlanMatchPattern.globalAggregation(), (Map<Optional<String>, ExpectedValueProvider<AggregationFunction>>)ImmutableMap.builder().put(Optional.of("SUM_1"), PlanMatchPattern.aggregationFunction("sum", (List<String>)ImmutableList.of((Object)"SUM_1_INPUT"))).put(Optional.of("SUM_2"), PlanMatchPattern.aggregationFunction("sum", (List<String>)ImmutableList.of((Object)"SUM_2_INPUT"))).put(Optional.of("SUM_3"), PlanMatchPattern.aggregationFunction("sum", (List<String>)ImmutableList.of((Object)"SUM_3_INPUT"))).put(Optional.of("MIN_1"), PlanMatchPattern.aggregationFunction("min", (List<String>)ImmutableList.of((Object)"MIN_1_INPUT"))).put(Optional.of("SUM_4"), PlanMatchPattern.aggregationFunction("sum", (List<String>)ImmutableList.of((Object)"SUM_4_INPUT"))).put(Optional.of("SUM_5"), PlanMatchPattern.aggregationFunction("sum", (List<String>)ImmutableList.of((Object)"SUM_5_INPUT"))).buildOrThrow(), Optional.empty(), AggregationNode.Step.SINGLE, PlanMatchPattern.project((Map<String, ExpressionMatcher>)ImmutableMap.builder().put((Object)"SUM_1_INPUT", (Object)PlanMatchPattern.expression((Expression)new Case((List)ImmutableList.of((Object)new WhenClause((Expression)new Comparison(Comparison.Operator.EQUAL, 
(Expression)new Reference((Type)BigintType.BIGINT, "COL_BIGINT"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)1L)), (Expression)new Reference((Type)BigintType.BIGINT, "SUM_BIGINT"))), (Expression)new Constant((Type)BigintType.BIGINT, (Object)0L)))).put((Object)"SUM_2_INPUT", (Object)PlanMatchPattern.expression((Expression)new Case((List)ImmutableList.of((Object)new WhenClause((Expression)new Comparison(Comparison.Operator.EQUAL, (Expression)new Reference((Type)BigintType.BIGINT, "COL_BIGINT"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)1L)), (Expression)new Reference((Type)BigintType.BIGINT, "SUM_INT_CAST"))), (Expression)new Constant((Type)BigintType.BIGINT, (Object)0L)))).put((Object)"SUM_3_INPUT", (Object)PlanMatchPattern.expression((Expression)new Case((List)ImmutableList.of((Object)new WhenClause((Expression)new Comparison(Comparison.Operator.EQUAL, (Expression)new Reference((Type)BigintType.BIGINT, "COL_BIGINT"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)2L)), (Expression)new Reference((Type)BigintType.BIGINT, "SUM_BIGINT"))), (Expression)new Constant((Type)BigintType.BIGINT, null)))).put((Object)"MIN_1_INPUT", (Object)PlanMatchPattern.expression((Expression)new Case((List)ImmutableList.of((Object)new WhenClause((Expression)new Comparison(Comparison.Operator.GREATER_THAN, (Expression)new Call(MODULUS_BIGINT, (List)ImmutableList.of((Object)new Reference((Type)BigintType.BIGINT, "COL_BIGINT"), (Object)new Constant((Type)BigintType.BIGINT, (Object)2L))), (Expression)new Constant((Type)BigintType.BIGINT, (Object)1L)), (Expression)new Reference((Type)BigintType.BIGINT, "MIN_BIGINT"))), (Expression)new Constant((Type)BigintType.BIGINT, null)))).put((Object)"SUM_4_INPUT", (Object)PlanMatchPattern.expression((Expression)new Case((List)ImmutableList.of((Object)new WhenClause((Expression)new Comparison(Comparison.Operator.EQUAL, (Expression)new Reference((Type)BigintType.BIGINT, "COL_BIGINT"), (Expression)new 
Constant((Type)BigintType.BIGINT, (Object)3L)), (Expression)new Reference((Type)DecimalType.createDecimalType((int)38, (int)1), "SUM_DECIMAL"))), (Expression)new Constant((Type)DecimalType.createDecimalType((int)38, (int)1), null)))).put((Object)"SUM_5_INPUT", (Object)PlanMatchPattern.expression((Expression)new Case((List)ImmutableList.of((Object)new WhenClause((Expression)new Comparison(Comparison.Operator.EQUAL, (Expression)new Reference((Type)BigintType.BIGINT, "COL_BIGINT"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)4L)), (Expression)new Reference((Type)BigintType.BIGINT, "SUM_DECIMAL_CAST"))), (Expression)new Constant((Type)BigintType.BIGINT, null)))).buildOrThrow(), PlanMatchPattern.aggregation(PlanMatchPattern.singleGroupingSet("COL_BIGINT"), (Map<Optional<String>, ExpectedValueProvider<AggregationFunction>>)ImmutableMap.of(Optional.of("SUM_BIGINT"), PlanMatchPattern.aggregationFunction("sum", (List<String>)ImmutableList.of((Object)"VALUE_BIGINT")), Optional.of("SUM_INT_CAST"), PlanMatchPattern.aggregationFunction("sum", (List<String>)ImmutableList.of((Object)"VALUE_INT_CAST")), Optional.of("MIN_BIGINT"), PlanMatchPattern.aggregationFunction("min", (List<String>)ImmutableList.of((Object)"VALUE_2_BIGINT")), Optional.of("SUM_DECIMAL"), PlanMatchPattern.aggregationFunction("sum", (List<String>)ImmutableList.of((Object)"COL_DECIMAL")), Optional.of("SUM_DECIMAL_CAST"), PlanMatchPattern.aggregationFunction("sum", (List<String>)ImmutableList.of((Object)"VALUE_DECIMAL_CAST"))), Optional.empty(), AggregationNode.Step.SINGLE, PlanMatchPattern.exchange(PlanMatchPattern.project((Map<String, ExpressionMatcher>)ImmutableMap.of((Object)"VALUE_BIGINT", (Object)PlanMatchPattern.expression((Expression)new Case((List)ImmutableList.of((Object)new WhenClause((Expression)new Between((Expression)new Reference((Type)BigintType.BIGINT, "COL_BIGINT"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)1L), (Expression)new Constant((Type)BigintType.BIGINT, 
(Object)2L)), (Expression)new Call(MULTIPLY_BIGINT, (List)ImmutableList.of((Object)new Reference((Type)BigintType.BIGINT, "COL_BIGINT"), (Object)new Constant((Type)BigintType.BIGINT, (Object)2L))))), (Expression)new Constant((Type)BigintType.BIGINT, null))), (Object)"VALUE_INT_CAST", (Object)PlanMatchPattern.expression((Expression)new Case((List)ImmutableList.of((Object)new WhenClause((Expression)new Comparison(Comparison.Operator.EQUAL, (Expression)new Reference((Type)BigintType.BIGINT, "COL_BIGINT"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)1L)), (Expression)new Cast((Expression)new Cast((Expression)new Call(MULTIPLY_BIGINT, (List)ImmutableList.of((Object)new Reference((Type)BigintType.BIGINT, "COL_BIGINT"), (Object)new Constant((Type)BigintType.BIGINT, (Object)2L))), (Type)IntegerType.INTEGER), (Type)BigintType.BIGINT))), (Expression)new Constant((Type)BigintType.BIGINT, null))), (Object)"VALUE_2_BIGINT", (Object)PlanMatchPattern.expression((Expression)new Case((List)ImmutableList.of((Object)new WhenClause((Expression)new Comparison(Comparison.Operator.GREATER_THAN, (Expression)new Call(MODULUS_BIGINT, (List)ImmutableList.of((Object)new Reference((Type)BigintType.BIGINT, "COL_BIGINT"), (Object)new Constant((Type)BigintType.BIGINT, (Object)2L))), (Expression)new Constant((Type)BigintType.BIGINT, (Object)1L)), (Expression)new Call(MULTIPLY_BIGINT, (List)ImmutableList.of((Object)new Reference((Type)BigintType.BIGINT, "COL_BIGINT"), (Object)new Constant((Type)BigintType.BIGINT, (Object)2L))))), (Expression)new Constant((Type)BigintType.BIGINT, null))), (Object)"VALUE_DECIMAL_CAST", (Object)PlanMatchPattern.expression((Expression)new Case((List)ImmutableList.of((Object)new WhenClause((Expression)new Comparison(Comparison.Operator.EQUAL, (Expression)new Reference((Type)BigintType.BIGINT, "COL_BIGINT"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)4L)), (Expression)new Cast((Expression)new Call(MULTIPLY_DECIMAL_10_0, 
(List)ImmutableList.of((Object)new Reference((Type)DecimalType.createDecimalType((int)10, (int)0), "COL_DECIMAL"), (Object)new Constant((Type)DecimalType.createDecimalType((int)10, (int)0), (Object)Decimals.valueOfShort((BigDecimal)new BigDecimal("2"))))), (Type)BigintType.BIGINT))), (Expression)new Constant((Type)BigintType.BIGINT, null)))), PlanMatchPattern.tableScan("t", (Map<String, String>)ImmutableMap.of((Object)"COL_BIGINT", (Object)"col_bigint", (Object)"COL_DECIMAL", (Object)"col_decimal"))))))))));
    }

    /**
     * Verifies that sums with and without a zero ELSE branch can share one pre-aggregated symbol.
     * <p>
     * The query sums each of two values (col_bigint under {@code col_bigint = 1}, and
     * {@code CAST(col_bigint AS INTEGER)} under {@code col_bigint = 2}) twice: once with a zero ELSE and once
     * without. The expected plan has a single pre-aggregation grouped by COL_BIGINT with two sums:
     * <ul>
     *   <li>SUM_BIGINT sums COL_BIGINT directly;</li>
     *   <li>SUM_INT_CAST sums VALUE_INT_CAST = CASE WHEN COL_BIGINT = 2 THEN CAST(CAST(COL_BIGINT AS INTEGER) AS BIGINT) END.</li>
     * </ul>
     * A projection then derives a null-default (*_FINAL) and a zero-default (*_FINAL_DEFAULT) variant of each.
     * A global SINGLE aggregation sums those variants.
     * <p>
     * NOTE(review): the (Object)/raw-type casts are CFR decompiler artifacts (see file header) and javac may reject
     * some of them. Consider restoring typed builders.
     */
    @Test
    public void testPreAggregatesWithDefaultValues() {
        this.assertPlan("SELECT sum(CASE WHEN col_bigint = 1 THEN col_bigint ELSE BIGINT '0' END), sum(CASE WHEN col_bigint = 1 THEN col_bigint END), sum(CASE WHEN col_bigint = 2 THEN CAST(col_bigint AS INTEGER) ELSE CAST(0 AS INTEGER) END), sum(CASE WHEN col_bigint = 2 THEN CAST(col_bigint AS INTEGER) END) FROM t", PlanMatchPattern.anyTree(PlanMatchPattern.aggregation(PlanMatchPattern.globalAggregation(), (Map<Optional<String>, ExpectedValueProvider<AggregationFunction>>)ImmutableMap.builder().put(Optional.of("SUM_1"), PlanMatchPattern.aggregationFunction("sum", (List<String>)ImmutableList.of((Object)"SUM_BIGINT_FINAL"))).put(Optional.of("SUM_1_DEFAULT"), PlanMatchPattern.aggregationFunction("sum", (List<String>)ImmutableList.of((Object)"SUM_BIGINT_FINAL_DEFAULT"))).put(Optional.of("SUM_2"), PlanMatchPattern.aggregationFunction("sum", (List<String>)ImmutableList.of((Object)"SUM_INT_CAST_FINAL"))).put(Optional.of("SUM_2_DEFAULT"), PlanMatchPattern.aggregationFunction("sum", (List<String>)ImmutableList.of((Object)"SUM_INT_CAST_FINAL_DEFAULT"))).buildOrThrow(), Optional.empty(), AggregationNode.Step.SINGLE, PlanMatchPattern.project((Map<String, ExpressionMatcher>)ImmutableMap.builder().put((Object)"SUM_BIGINT_FINAL", (Object)PlanMatchPattern.expression((Expression)new Case((List)ImmutableList.of((Object)new WhenClause((Expression)new Comparison(Comparison.Operator.EQUAL, (Expression)new Reference((Type)BigintType.BIGINT, "COL_BIGINT"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)1L)), (Expression)new Reference((Type)BigintType.BIGINT, "SUM_BIGINT"))), (Expression)new Constant((Type)BigintType.BIGINT, null)))).put((Object)"SUM_BIGINT_FINAL_DEFAULT", (Object)PlanMatchPattern.expression((Expression)new Case((List)ImmutableList.of((Object)new WhenClause((Expression)new Comparison(Comparison.Operator.EQUAL, (Expression)new Reference((Type)BigintType.BIGINT, "COL_BIGINT"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)1L)), (Expression)new 
Reference((Type)BigintType.BIGINT, "SUM_BIGINT"))), (Expression)new Constant((Type)BigintType.BIGINT, (Object)0L)))).put((Object)"SUM_INT_CAST_FINAL", (Object)PlanMatchPattern.expression((Expression)new Case((List)ImmutableList.of((Object)new WhenClause((Expression)new Comparison(Comparison.Operator.EQUAL, (Expression)new Reference((Type)BigintType.BIGINT, "COL_BIGINT"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)2L)), (Expression)new Reference((Type)BigintType.BIGINT, "SUM_INT_CAST"))), (Expression)new Constant((Type)BigintType.BIGINT, null)))).put((Object)"SUM_INT_CAST_FINAL_DEFAULT", (Object)PlanMatchPattern.expression((Expression)new Case((List)ImmutableList.of((Object)new WhenClause((Expression)new Comparison(Comparison.Operator.EQUAL, (Expression)new Reference((Type)BigintType.BIGINT, "COL_BIGINT"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)2L)), (Expression)new Reference((Type)BigintType.BIGINT, "SUM_INT_CAST"))), (Expression)new Constant((Type)BigintType.BIGINT, (Object)0L)))).buildOrThrow(), PlanMatchPattern.aggregation(PlanMatchPattern.singleGroupingSet("COL_BIGINT"), (Map<Optional<String>, ExpectedValueProvider<AggregationFunction>>)ImmutableMap.of(Optional.of("SUM_BIGINT"), PlanMatchPattern.aggregationFunction("sum", (List<String>)ImmutableList.of((Object)"COL_BIGINT")), Optional.of("SUM_INT_CAST"), PlanMatchPattern.aggregationFunction("sum", (List<String>)ImmutableList.of((Object)"VALUE_INT_CAST"))), Optional.empty(), AggregationNode.Step.SINGLE, PlanMatchPattern.exchange(PlanMatchPattern.project((Map<String, ExpressionMatcher>)ImmutableMap.of((Object)"VALUE_INT_CAST", (Object)PlanMatchPattern.expression((Expression)new Case((List)ImmutableList.of((Object)new WhenClause((Expression)new Comparison(Comparison.Operator.EQUAL, (Expression)new Reference((Type)BigintType.BIGINT, "COL_BIGINT"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)2L)), (Expression)new Cast((Expression)new Cast((Expression)new 
Reference((Type)BigintType.BIGINT, "COL_BIGINT"), (Type)IntegerType.INTEGER), (Type)BigintType.BIGINT))), (Expression)new Constant((Type)BigintType.BIGINT, null)))), PlanMatchPattern.tableScan("t", (Map<String, String>)ImmutableMap.of((Object)"COL_BIGINT", (Object)"col_bigint")))))))));
    }

    /**
     * A sum whose ELSE branch is a zero literal of the value's type (bigint, tinyint, double, short and long
     * decimal) must not prevent the rewrite.
     */
    @Test
    public void testPreAggregatesSumAggregationsWithZeroDefault() {
        assertFires("SELECT col_varchar, sum(CASE WHEN col_bigint = 1 THEN col_bigint ELSE BIGINT '0' END), sum(CASE WHEN col_bigint = 2 THEN col_bigint ELSE BIGINT '0' END), sum(CASE WHEN col_bigint = 2 THEN col_tinyint ELSE TINYINT '0' END), sum(CASE WHEN col_bigint = 3 THEN col_double ELSE DOUBLE '0' END), sum(CASE WHEN col_bigint = 4 THEN col_decimal ELSE DECIMAL '0.0' END), sum(CASE WHEN col_bigint = 5 THEN col_long_decimal ELSE DECIMAL '0.000000000000000000' END) FROM t GROUP BY col_varchar");
    }

    /**
     * The rewrite also applies when the CASE conditions reference a column (col_bigint) that is already a
     * grouping key of the query.
     */
    @Test
    public void testPreAggregatesWithoutNewExtraGroupingKeys() {
        assertFires("SELECT col_bigint, sum(CASE WHEN col_bigint = 1 THEN col_decimal END), sum(CASE WHEN col_bigint = 2 THEN col_decimal END), sum(CASE WHEN col_bigint = 3 THEN col_decimal END), sum(CASE WHEN col_bigint = 4 THEN col_decimal END) FROM t GROUP BY col_bigint");
    }

    /**
     * Aggregations over GROUPING SETS are left untouched.
     */
    @Test
    public void testDoesNotFireWithGroupingSets() {
        assertThatDoesNotFire("SELECT col_varchar, col_bigint, sum(CASE WHEN col_bigint = 1 THEN col_decimal END), sum(CASE WHEN col_bigint = 2 THEN col_decimal END), sum(CASE WHEN col_bigint = 3 THEN col_decimal END), sum(CASE WHEN col_bigint = 4 THEN col_decimal END) FROM t GROUP BY GROUPING SETS ((col_varchar), (col_bigint))");
    }

    /**
     * With only three CASE aggregations the rewrite is not applied.
     */
    @Test
    public void testDoesNotFireWithoutEnoughAggregations() {
        assertThatDoesNotFire("SELECT col_varchar, sum(CASE WHEN col_bigint = 1 THEN col_decimal END), sum(CASE WHEN col_bigint = 2 THEN col_decimal END), sum(CASE WHEN col_bigint = 3 THEN col_decimal END) FROM t GROUP BY col_varchar");
    }

    /**
     * The CASE conditions reference two different columns (col_bigint and col_decimal). Pre-aggregating would
     * need more than one extra grouping key, so the rewrite is not applied.
     */
    @Test
    public void testDoesNotFireWithMultipleExtraGroupingKeys() {
        assertThatDoesNotFire("SELECT col_varchar, sum(CASE WHEN col_bigint = 1 THEN col_decimal END), sum(CASE WHEN col_bigint = 2 THEN col_decimal END), sum(CASE WHEN col_bigint = 3 THEN col_decimal END), sum(CASE WHEN col_decimal = DECIMAL '4.1' THEN col_decimal END) FROM t GROUP BY col_varchar");
    }

    /**
     * A searched CASE with more than one WHEN clause blocks the rewrite. ("WithClauses" in the method name
     * presumably means WHEN clauses.)
     */
    @Test
    public void testDoesNotFireForSearchedCaseExpressionWithMultipleWithClauses() {
        assertThatDoesNotFire("SELECT col_varchar, sum(CASE WHEN col_bigint = 1 THEN col_decimal END), sum(CASE WHEN col_bigint = 2 THEN col_decimal END), sum(CASE WHEN col_bigint = 3 THEN col_decimal END), sum(CASE WHEN col_bigint = 4 THEN col_decimal END), sum(CASE WHEN col_bigint = 5 THEN col_decimal WHEN col_bigint = 6 THEN col_decimal * 2 END) FROM t GROUP BY col_varchar");
    }

    /**
     * A count(...) among the CASE aggregations blocks the rewrite.
     */
    @Test
    public void testDoesNotFireForNonCumulativeAggregation() {
        assertThatDoesNotFire("SELECT col_varchar, sum(CASE WHEN col_bigint = 1 THEN col_decimal END), sum(CASE WHEN col_bigint = 2 THEN col_decimal END), sum(CASE WHEN col_bigint = 3 THEN col_decimal END), count(CASE WHEN col_bigint = 4 THEN col_decimal END) FROM t GROUP BY col_varchar");
    }

    /**
     * A sum whose ELSE branch is a non-zero literal (1) blocks the rewrite.
     */
    @Test
    public void testDoesNotFireForSumAggregationWithNonZeroDefaultValue() {
        assertThatDoesNotFire("SELECT col_varchar, sum(CASE WHEN col_bigint = 1 THEN col_decimal END), sum(CASE WHEN col_bigint = 2 THEN col_decimal END), sum(CASE WHEN col_bigint = 3 THEN col_decimal END), sum(CASE WHEN col_bigint = 4 THEN col_decimal ELSE 1 END) FROM t GROUP BY col_varchar");
    }

    /**
     * A min whose ELSE branch is a non-null value (0) blocks the rewrite.
     */
    @Test
    public void testDoesNotFireForMinAggregationWithNonNullDefaultValue() {
        assertThatDoesNotFire("SELECT col_varchar, sum(CASE WHEN col_bigint = 1 THEN col_decimal END), sum(CASE WHEN col_bigint = 2 THEN col_decimal END), sum(CASE WHEN col_bigint = 3 THEN col_decimal END), min(CASE WHEN col_bigint = 4 THEN col_decimal ELSE 0 END) FROM t GROUP BY col_varchar");
    }

    /**
     * A plain aggregation without a CASE argument (sum(col_decimal)) alongside the CASE aggregations blocks the
     * rewrite.
     */
    @Test
    public void testDoesNotFireForNonCaseAggregation() {
        assertThatDoesNotFire("SELECT col_varchar, sum(CASE WHEN col_bigint = 1 THEN col_decimal END), sum(CASE WHEN col_bigint = 2 THEN col_decimal END), sum(CASE WHEN col_bigint = 3 THEN col_decimal END), sum(CASE WHEN col_bigint = 4 THEN col_decimal END), sum(col_decimal) FROM t GROUP BY col_varchar");
    }

    /**
     * Four IF aggregations share one condition but each sums a different value expression. Presumably
     * pre-aggregating would not reduce the number of aggregations, so the rewrite is not applied.
     */
    @Test
    public void testDoesNotFireIfAggregationsAreNotReduced() {
        assertThatDoesNotFire("SELECT\n    SUM(IF(col_varchar != 'V', col_bigint + col_decimal)),\n    SUM(IF(col_varchar != 'V', col_decimal + col_tinyint)),\n    SUM(IF(col_varchar != 'V', col_tinyint + col_double)),\n    SUM(IF(col_varchar != 'V', col_double + col_bigint))\nFROM t\n");
    }

    /**
     * Asserts that the pre-aggregation rewrite fired for {@code query}: the optimized plan contains exactly two
     * AggregationNodes (the outer aggregation plus the added pre-aggregation).
     */
    private void assertFires(@Language(value="SQL") String query) {
        assertAggregationNodeCount(query, 2, "expected the rewrite to add a pre-aggregation");
    }

    /**
     * Asserts that the pre-aggregation rewrite did not fire for {@code query}: the optimized plan contains
     * exactly one AggregationNode.
     */
    private void assertThatDoesNotFire(@Language(value="SQL") String query) {
        assertAggregationNodeCount(query, 1, "expected the rewrite not to add a pre-aggregation");
    }

    // Plans the query and checks its AggregationNode count. The failure message carries the reason and the
    // query text, so a failing test does not just report "expected 2 but was 1".
    private void assertAggregationNodeCount(String query, int expectedCount, String reason) {
        int actualCount = countOfMatchingNodes(plan(query), AggregationNode.class::isInstance);
        Assertions.assertThat(actualCount)
                .as("%s; AggregationNode count in plan of: %s", reason, query)
                .isEqualTo(expectedCount);
    }

    /**
     * Counts the plan nodes found by searching from the root of {@code plan} that satisfy {@code predicate}.
     */
    private static int countOfMatchingNodes(Plan plan, Predicate<PlanNode> predicate) {
        PlanNode root = plan.getRoot();
        return PlanNodeSearcher.searchFrom(root)
                .where(predicate)
                .count();
    }
}

