/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.Session;
import io.trino.connector.MockConnectorFactory;
import io.trino.connector.MockConnectorTableHandle;
import io.trino.spi.connector.ColumnMetadata;
import io.trino.spi.connector.ConnectorFactory;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.TinyintType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import io.trino.sql.ir.ArithmeticBinaryExpression;
import io.trino.sql.ir.Cast;
import io.trino.sql.ir.ComparisonExpression;
import io.trino.sql.ir.DecimalLiteral;
import io.trino.sql.ir.DoubleLiteral;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.FunctionCall;
import io.trino.sql.ir.GenericLiteral;
import io.trino.sql.ir.InPredicate;
import io.trino.sql.ir.SearchedCaseExpression;
import io.trino.sql.ir.SymbolReference;
import io.trino.sql.ir.WhenClause;
import io.trino.sql.planner.Plan;
import io.trino.sql.planner.assertions.AggregationFunction;
import io.trino.sql.planner.assertions.BasePlanTest;
import io.trino.sql.planner.assertions.ExpectedValueProvider;
import io.trino.sql.planner.assertions.ExpressionMatcher;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.optimizations.PlanNodeSearcher;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.tree.QualifiedName;
import io.trino.testing.PlanTester;
import io.trino.testing.TestingSession;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Predicate;
import org.assertj.core.api.Assertions;
import org.intellij.lang.annotations.Language;
import org.junit.jupiter.api.Test;

/**
 * Planner tests for pre-aggregating aggregations over single-branch {@code CASE} expressions.
 *
 * <p>The queries aggregate several {@code CASE WHEN <condition> THEN <value> [ELSE <default>] END}
 * expressions. When the optimization applies, the expected plans contain two {@link AggregationNode}s:
 * <ul>
 *   <li>an inner aggregation that groups by the query's grouping keys plus the column tested in the
 *       {@code CASE} conditions ({@code col_bigint}) and aggregates the {@code THEN} values;</li>
 *   <li>an outer aggregation over {@code CASE} expressions that select the pre-aggregated value
 *       for each condition.</li>
 * </ul>
 * When it does not apply, the plan contains a single {@link AggregationNode}.
 *
 * <p>All queries run against a mock connector that exposes one table, {@code local.default.t}.
 */
public class TestPreAggregateCaseAggregations
        extends BasePlanTest {
    private static final SchemaTableName TABLE = new SchemaTableName("default", "t");

    /**
     * Creates a planner backed by a mock {@code local} catalog whose only table is {@link #TABLE}.
     *
     * @return the plan tester used by the {@link BasePlanTest} assertion helpers
     */
    @Override
    protected PlanTester createPlanTester() {
        Session session = TestingSession.testSessionBuilder()
                .setCatalog("local")
                .setSchema("default")
                // NOTE(review): presumably set to keep the plan shape deterministic; the expected
                // patterns below only contain SINGLE-step aggregations.
                .setSystemProperty("optimize_hash_generation", "false")
                .setSystemProperty("prefer_partial_aggregation", "false")
                .setSystemProperty("task_concurrency", "1")
                .build();
        PlanTester planTester = PlanTester.create(session);
        ConnectorFactory connectorFactory = MockConnectorFactory.builder()
                .withGetTableHandle((connectorSession, tableName) -> new MockConnectorTableHandle(tableName))
                .withGetColumns(tableName -> {
                    if (!tableName.equals(TABLE)) {
                        throw new IllegalArgumentException("Unexpected table: " + tableName);
                    }
                    return ImmutableList.of(
                            new ColumnMetadata("col_varchar", VarcharType.VARCHAR),
                            new ColumnMetadata("col_bigint", BigintType.BIGINT),
                            new ColumnMetadata("col_tinyint", TinyintType.TINYINT),
                            new ColumnMetadata("col_decimal", DecimalType.createDecimalType(2, 1)),
                            new ColumnMetadata("col_long_decimal", DecimalType.createDecimalType(19, 18)),
                            new ColumnMetadata("col_double", DoubleType.DOUBLE));
                })
                .build();
        planTester.createCatalog("local", connectorFactory, ImmutableMap.of());
        return planTester;
    }

    /**
     * Grouped query: the outer aggregation groups by the computed key {@code KEY}
     * ({@code col_varchar || 'a'}), while the pre-aggregation groups by {@code KEY} and {@code COL_BIGINT}.
     * Both {@code sum(CASE ... THEN col_bigint * 2 ...)} aggregations reuse one pre-aggregated
     * {@code SUM_BIGINT}, whose input is restricted to {@code COL_BIGINT IN (1, 2)}.
     */
    @Test
    public void testPreAggregatesCaseAggregations() {
        assertPlan(
                "SELECT (col_varchar || 'a'), "
                        + "sum(CASE WHEN col_bigint = 1 THEN col_bigint * 2 ELSE 0 END), "
                        + "CAST(sum(CASE WHEN col_bigint = 1 THEN CAST(col_bigint * 2 AS INTEGER) ELSE CAST(0 AS INTEGER) END) AS VARCHAR(10)), "
                        + "sum(CASE WHEN col_bigint = 2 THEN col_bigint * 2 ELSE null END), "
                        + "min(CASE WHEN col_bigint % 2 > 1.23 THEN col_bigint * 2 END), "
                        + "sum(CASE WHEN col_bigint = 3 THEN col_decimal END), "
                        + "sum(CAST(CASE WHEN col_bigint = 4 THEN col_decimal * 2 END AS BIGINT)) "
                        + "FROM t "
                        + "GROUP BY (col_varchar || 'a')",
                PlanMatchPattern.anyTree(
                        PlanMatchPattern.project(
                                ImmutableMap.of("SUM_2_CAST", PlanMatchPattern.expression(new Cast(symbolRef("SUM_2"), VarcharType.createVarcharType(10)))),
                                PlanMatchPattern.aggregation(
                                        PlanMatchPattern.singleGroupingSet("KEY"),
                                        finalCaseAggregations(),
                                        Optional.empty(),
                                        AggregationNode.Step.SINGLE,
                                        PlanMatchPattern.project(
                                                finalCaseAggregationInputs(),
                                                PlanMatchPattern.aggregation(
                                                        PlanMatchPattern.singleGroupingSet("KEY", "COL_BIGINT"),
                                                        caseValuePreAggregations(),
                                                        Optional.empty(),
                                                        AggregationNode.Step.SINGLE,
                                                        PlanMatchPattern.exchange(
                                                                PlanMatchPattern.project(
                                                                        ImmutableMap.<String, ExpressionMatcher>builder()
                                                                                .put("KEY", PlanMatchPattern.expression(new FunctionCall(
                                                                                        QualifiedName.of("concat"),
                                                                                        ImmutableList.of(symbolRef("COL_VARCHAR"), new GenericLiteral(VarcharType.VARCHAR, "a")))))
                                                                                .putAll(caseValueProjections())
                                                                                .buildOrThrow(),
                                                                        PlanMatchPattern.tableScan("t", ImmutableMap.of(
                                                                                "COL_VARCHAR", "col_varchar",
                                                                                "COL_BIGINT", "col_bigint",
                                                                                "COL_DECIMAL", "col_decimal"))))))))));
    }

    /**
     * Same aggregations as {@link #testPreAggregatesCaseAggregations()} but without {@code GROUP BY}.
     * The outer aggregation is global and the pre-aggregation groups by {@code COL_BIGINT} only.
     */
    @Test
    public void testGlobalPreAggregatesCaseAggregations() {
        assertPlan(
                "SELECT sum(CASE WHEN col_bigint = 1 THEN col_bigint * 2 ELSE 0 END), "
                        + "CAST(sum(CASE WHEN col_bigint = 1 THEN CAST(col_bigint * 2 AS INTEGER) ELSE CAST(0 AS INTEGER) END) AS VARCHAR(10)), "
                        + "sum(CASE WHEN col_bigint = 2 THEN col_bigint * 2 ELSE null END), "
                        + "min(CASE WHEN col_bigint % 2 > 1.23 THEN col_bigint * 2 END), "
                        + "sum(CASE WHEN col_bigint = 3 THEN col_decimal END), "
                        + "sum(CAST(CASE WHEN col_bigint = 4 THEN col_decimal * 2 END AS BIGINT)) "
                        + "FROM t",
                PlanMatchPattern.anyTree(
                        PlanMatchPattern.project(
                                ImmutableMap.of("SUM_2_CAST", PlanMatchPattern.expression(new Cast(symbolRef("SUM_2"), VarcharType.createVarcharType(10)))),
                                PlanMatchPattern.aggregation(
                                        PlanMatchPattern.globalAggregation(),
                                        finalCaseAggregations(),
                                        Optional.empty(),
                                        AggregationNode.Step.SINGLE,
                                        PlanMatchPattern.project(
                                                finalCaseAggregationInputs(),
                                                PlanMatchPattern.aggregation(
                                                        PlanMatchPattern.singleGroupingSet("COL_BIGINT"),
                                                        caseValuePreAggregations(),
                                                        Optional.empty(),
                                                        AggregationNode.Step.SINGLE,
                                                        PlanMatchPattern.exchange(
                                                                PlanMatchPattern.project(
                                                                        caseValueProjections(),
                                                                        PlanMatchPattern.tableScan("t", ImmutableMap.of(
                                                                                "COL_BIGINT", "col_bigint",
                                                                                "COL_DECIMAL", "col_decimal"))))))))));
    }

    /**
     * Each value column is pre-aggregated once and shared by its {@code ELSE <zero>} and no-{@code ELSE}
     * variants; the two variants differ only in the {@code ELSE} of the outer {@code CASE}.
     * Expected defaults:
     * <ul>
     *   <li>{@code BIGINT '0'} for the integral sums;</li>
     *   <li>{@code 0.0} cast to {@code DECIMAL(38, 1)} or {@code DECIMAL(38, 18)} for the decimal sums;</li>
     *   <li>{@code 0.0} for the double sum.</li>
     * </ul>
     */
    @Test
    public void testPreAggregatesWithDefaultValues() {
        assertPlan(
                "SELECT sum(CASE WHEN col_bigint = 1 THEN col_bigint ELSE BIGINT '0' END), "
                        + "sum(CASE WHEN col_bigint = 1 THEN col_bigint END), "
                        + "sum(CASE WHEN col_bigint = 2 THEN CAST(col_bigint AS INTEGER) ELSE CAST(0 AS INTEGER) END), "
                        + "sum(CASE WHEN col_bigint = 2 THEN CAST(col_bigint AS INTEGER) END), "
                        + "sum(CASE WHEN col_bigint = 3 THEN col_tinyint ELSE TINYINT '0' END), "
                        + "sum(CASE WHEN col_bigint = 3 THEN col_tinyint END), "
                        + "sum(CASE WHEN col_bigint = 4 THEN col_decimal ELSE CAST(0 AS DECIMAL(2, 1)) END), "
                        + "sum(CASE WHEN col_bigint = 4 THEN col_decimal END), "
                        + "sum(CASE WHEN col_bigint = 5 THEN col_long_decimal ELSE CAST(0 AS DECIMAL(19, 18)) END), "
                        + "sum(CASE WHEN col_bigint = 5 THEN col_long_decimal END), "
                        + "sum(CASE WHEN col_bigint = 6 THEN col_double ELSE DOUBLE '0' END), "
                        + "sum(CASE WHEN col_bigint = 6 THEN col_double END) "
                        + "FROM t",
                PlanMatchPattern.anyTree(
                        PlanMatchPattern.aggregation(
                                PlanMatchPattern.globalAggregation(),
                                ImmutableMap.<Optional<String>, ExpectedValueProvider<AggregationFunction>>builder()
                                        .put(Optional.of("SUM_1"), aggregationOf("sum", "SUM_BIGINT_FINAL"))
                                        .put(Optional.of("SUM_1_DEFAULT"), aggregationOf("sum", "SUM_BIGINT_FINAL_DEFAULT"))
                                        .put(Optional.of("SUM_2"), aggregationOf("sum", "SUM_INT_CAST_FINAL"))
                                        .put(Optional.of("SUM_2_DEFAULT"), aggregationOf("sum", "SUM_INT_CAST_FINAL_DEFAULT"))
                                        .put(Optional.of("SUM_3"), aggregationOf("sum", "SUM_TINYINT_FINAL"))
                                        .put(Optional.of("SUM_3_DEFAULT"), aggregationOf("sum", "SUM_TINYINT_FINAL_DEFAULT"))
                                        .put(Optional.of("SUM_4"), aggregationOf("sum", "SUM_DECIMAL_FINAL"))
                                        .put(Optional.of("SUM_4_DEFAULT"), aggregationOf("sum", "SUM_DECIMAL_FINAL_DEFAULT"))
                                        .put(Optional.of("SUM_5"), aggregationOf("sum", "SUM_LONG_DECIMAL_FINAL"))
                                        .put(Optional.of("SUM_5_DEFAULT"), aggregationOf("sum", "SUM_LONG_DECIMAL_FINAL_DEFAULT"))
                                        .put(Optional.of("SUM_6"), aggregationOf("sum", "SUM_DOUBLE_FINAL"))
                                        .put(Optional.of("SUM_6_DEFAULT"), aggregationOf("sum", "SUM_DOUBLE_FINAL_DEFAULT"))
                                        .buildOrThrow(),
                                Optional.empty(),
                                AggregationNode.Step.SINGLE,
                                PlanMatchPattern.project(
                                        ImmutableMap.<String, ExpressionMatcher>builder()
                                                .put("SUM_BIGINT_FINAL", searchedCase(colBigintEquals("1"), symbolRef("SUM_BIGINT")))
                                                .put("SUM_BIGINT_FINAL_DEFAULT", searchedCase(colBigintEquals("1"), symbolRef("SUM_BIGINT"), bigintLiteral("0")))
                                                .put("SUM_INT_CAST_FINAL", searchedCase(colBigintEquals("2"), symbolRef("SUM_INT_CAST")))
                                                .put("SUM_INT_CAST_FINAL_DEFAULT", searchedCase(colBigintEquals("2"), symbolRef("SUM_INT_CAST"), bigintLiteral("0")))
                                                .put("SUM_TINYINT_FINAL", searchedCase(colBigintEquals("3"), symbolRef("SUM_TINYINT")))
                                                .put("SUM_TINYINT_FINAL_DEFAULT", searchedCase(colBigintEquals("3"), symbolRef("SUM_TINYINT"), bigintLiteral("0")))
                                                .put("SUM_DECIMAL_FINAL", searchedCase(colBigintEquals("4"), symbolRef("SUM_DECIMAL")))
                                                .put("SUM_DECIMAL_FINAL_DEFAULT", searchedCase(
                                                        colBigintEquals("4"),
                                                        symbolRef("SUM_DECIMAL"),
                                                        new Cast(new DecimalLiteral("0.0"), DecimalType.createDecimalType(38, 1))))
                                                .put("SUM_LONG_DECIMAL_FINAL", searchedCase(colBigintEquals("5"), symbolRef("SUM_LONG_DECIMAL")))
                                                .put("SUM_LONG_DECIMAL_FINAL_DEFAULT", searchedCase(
                                                        colBigintEquals("5"),
                                                        symbolRef("SUM_LONG_DECIMAL"),
                                                        new Cast(new DecimalLiteral("0.000000000000000000"), DecimalType.createDecimalType(38, 18))))
                                                .put("SUM_DOUBLE_FINAL", searchedCase(colBigintEquals("6"), symbolRef("SUM_DOUBLE")))
                                                .put("SUM_DOUBLE_FINAL_DEFAULT", searchedCase(colBigintEquals("6"), symbolRef("SUM_DOUBLE"), new DoubleLiteral(0.0)))
                                                .buildOrThrow(),
                                        PlanMatchPattern.aggregation(
                                                PlanMatchPattern.singleGroupingSet("COL_BIGINT"),
                                                ImmutableMap.of(
                                                        Optional.of("SUM_BIGINT"), aggregationOf("sum", "COL_BIGINT"),
                                                        Optional.of("SUM_INT_CAST"), aggregationOf("sum", "VALUE_INT_CAST"),
                                                        Optional.of("SUM_TINYINT"), aggregationOf("sum", "VALUE_TINYINT_CAST"),
                                                        Optional.of("SUM_DECIMAL"), aggregationOf("sum", "COL_DECIMAL"),
                                                        Optional.of("SUM_LONG_DECIMAL"), aggregationOf("sum", "COL_LONG_DECIMAL"),
                                                        Optional.of("SUM_DOUBLE"), aggregationOf("sum", "COL_DOUBLE")),
                                                Optional.empty(),
                                                AggregationNode.Step.SINGLE,
                                                PlanMatchPattern.exchange(
                                                        PlanMatchPattern.project(
                                                                ImmutableMap.of(
                                                                        "VALUE_INT_CAST", searchedCase(
                                                                                colBigintEquals("2"),
                                                                                new Cast(new Cast(symbolRef("COL_BIGINT"), IntegerType.INTEGER), BigintType.BIGINT)),
                                                                        "VALUE_TINYINT_CAST", searchedCase(
                                                                                colBigintEquals("3"),
                                                                                new Cast(symbolRef("COL_TINYINT"), BigintType.BIGINT))),
                                                                PlanMatchPattern.tableScan("t", ImmutableMap.of(
                                                                        "COL_BIGINT", "col_bigint",
                                                                        "COL_TINYINT", "col_tinyint",
                                                                        "COL_DECIMAL", "col_decimal",
                                                                        "COL_LONG_DECIMAL", "col_long_decimal",
                                                                        "COL_DOUBLE", "col_double")))))))));
    }

    /** Sum aggregations whose {@code ELSE} branch is a typed zero literal are still rewritten. */
    @Test
    public void testPreAggregatesSumAggregationsWithZeroDefault() {
        assertFires("SELECT col_varchar, "
                + "sum(CASE WHEN col_bigint = 1 THEN col_bigint ELSE BIGINT '0' END), "
                + "sum(CASE WHEN col_bigint = 2 THEN col_tinyint ELSE TINYINT '0' END), "
                + "sum(CASE WHEN col_bigint = 3 THEN col_double ELSE DOUBLE '0' END), "
                + "sum(CASE WHEN col_bigint = 4 THEN col_decimal ELSE DECIMAL '0.0' END), "
                + "sum(CASE WHEN col_bigint = 5 THEN col_long_decimal ELSE DECIMAL '0.000000000000000000' END) "
                + "FROM t GROUP BY col_varchar");
    }

    /** The rewrite applies when the CASE-condition column is already a grouping key. */
    @Test
    public void testPreAggregatesWithoutNewExtraGroupingKeys() {
        assertFires("SELECT col_bigint, "
                + "sum(CASE WHEN col_bigint = 1 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 2 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 3 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 4 THEN col_decimal END) "
                + "FROM t GROUP BY col_bigint");
    }

    /** Aggregations with multiple grouping sets are left untouched. */
    @Test
    public void testDoesNotFireWithGroupingSets() {
        assertThatDoesNotFire("SELECT col_varchar, col_bigint, "
                + "sum(CASE WHEN col_bigint = 1 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 2 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 3 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 4 THEN col_decimal END) "
                + "FROM t GROUP BY GROUPING SETS ((col_varchar), (col_bigint))");
    }

    /** Only three CASE aggregations: not enough for the rewrite, unlike the four-aggregation queries above. */
    @Test
    public void testDoesNotFireWithoutEnoughAggregations() {
        assertThatDoesNotFire("SELECT col_varchar, "
                + "sum(CASE WHEN col_bigint = 1 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 2 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 3 THEN col_decimal END) "
                + "FROM t GROUP BY col_varchar");
    }

    /** The CASE conditions test two different columns, which would require two extra grouping keys. */
    @Test
    public void testDoesNotFireWithMultipleExtraGroupingKeys() {
        assertThatDoesNotFire("SELECT col_varchar, "
                + "sum(CASE WHEN col_bigint = 1 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 2 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 3 THEN col_decimal END), "
                + "sum(CASE WHEN col_decimal = DECIMAL '4.1' THEN col_decimal END) "
                + "FROM t GROUP BY col_varchar");
    }

    /** A CASE expression with more than one WHEN clause prevents the rewrite. */
    @Test
    public void testDoesNotFireForSearchedCaseExpressionWithMultipleWithClauses() {
        assertThatDoesNotFire("SELECT col_varchar, "
                + "sum(CASE WHEN col_bigint = 1 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 2 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 3 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 4 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 5 THEN col_decimal WHEN col_bigint = 6 THEN col_decimal * 2 END) "
                + "FROM t GROUP BY col_varchar");
    }

    /** A {@code count} over a CASE expression prevents the rewrite. */
    @Test
    public void testDoesNotFireForNonCumulativeAggregation() {
        assertThatDoesNotFire("SELECT col_varchar, "
                + "sum(CASE WHEN col_bigint = 1 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 2 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 3 THEN col_decimal END), "
                + "count(CASE WHEN col_bigint = 4 THEN col_decimal END) "
                + "FROM t GROUP BY col_varchar");
    }

    /** A {@code sum} whose {@code ELSE} branch is a non-zero value prevents the rewrite. */
    @Test
    public void testDoesNotFireForSumAggregationWithNonZeroDefaultValue() {
        assertThatDoesNotFire("SELECT col_varchar, "
                + "sum(CASE WHEN col_bigint = 1 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 2 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 3 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 4 THEN col_decimal ELSE 1 END) "
                + "FROM t GROUP BY col_varchar");
    }

    /** A {@code min} whose {@code ELSE} branch is non-null prevents the rewrite. */
    @Test
    public void testDoesNotFireForMinAggregationWithNonNullDefaultValue() {
        assertThatDoesNotFire("SELECT col_varchar, "
                + "sum(CASE WHEN col_bigint = 1 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 2 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 3 THEN col_decimal END), "
                + "min(CASE WHEN col_bigint = 4 THEN col_decimal ELSE 0 END) "
                + "FROM t GROUP BY col_varchar");
    }

    /** An aggregation whose argument is not a CASE expression prevents the rewrite. */
    @Test
    public void testDoesNotFireForNonCaseAggregation() {
        assertThatDoesNotFire("SELECT col_varchar, "
                + "sum(CASE WHEN col_bigint = 1 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 2 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 3 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 4 THEN col_decimal END), "
                + "sum(col_decimal) "
                + "FROM t GROUP BY col_varchar");
    }

    /** Asserts that the rewrite was applied, i.e. the plan for {@code query} has exactly two aggregations. */
    private void assertFires(@Language("SQL") String query) {
        Assertions.assertThat(countOfMatchingNodes(plan(query), AggregationNode.class::isInstance)).isEqualTo(2);
    }

    /** Asserts that the rewrite was not applied, i.e. the plan for {@code query} has exactly one aggregation. */
    private void assertThatDoesNotFire(@Language("SQL") String query) {
        Assertions.assertThat(countOfMatchingNodes(plan(query), AggregationNode.class::isInstance)).isEqualTo(1);
    }

    /** Counts the nodes in the plan tree rooted at {@code plan.getRoot()} that satisfy {@code predicate}. */
    private static int countOfMatchingNodes(Plan plan, Predicate<PlanNode> predicate) {
        return PlanNodeSearcher.searchFrom(plan.getRoot()).where(predicate).count();
    }

    // ---- Expected-plan fragments shared by the two CASE-aggregation plan tests ----

    /** Outer aggregations over the per-condition CASE inputs, one per aggregation in the query. */
    private static Map<Optional<String>, ExpectedValueProvider<AggregationFunction>> finalCaseAggregations() {
        return ImmutableMap.<Optional<String>, ExpectedValueProvider<AggregationFunction>>builder()
                .put(Optional.of("SUM_1"), aggregationOf("sum", "SUM_1_INPUT"))
                .put(Optional.of("SUM_2"), aggregationOf("sum", "SUM_2_INPUT"))
                .put(Optional.of("SUM_3"), aggregationOf("sum", "SUM_3_INPUT"))
                .put(Optional.of("MIN_1"), aggregationOf("min", "MIN_1_INPUT"))
                .put(Optional.of("SUM_4"), aggregationOf("sum", "SUM_4_INPUT"))
                .put(Optional.of("SUM_5"), aggregationOf("sum", "SUM_5_INPUT"))
                .buildOrThrow();
    }

    /** CASE expressions that pick the pre-aggregated value matching each original condition. */
    private static Map<String, ExpressionMatcher> finalCaseAggregationInputs() {
        return ImmutableMap.<String, ExpressionMatcher>builder()
                .put("SUM_1_INPUT", searchedCase(colBigintEquals("1"), symbolRef("SUM_BIGINT"), bigintLiteral("0")))
                .put("SUM_2_INPUT", searchedCase(colBigintEquals("1"), symbolRef("SUM_INT_CAST"), bigintLiteral("0")))
                .put("SUM_3_INPUT", searchedCase(colBigintEquals("2"), symbolRef("SUM_BIGINT")))
                .put("MIN_1_INPUT", searchedCase(colBigintModTwoGreaterThanOne(), symbolRef("MIN_BIGINT")))
                .put("SUM_4_INPUT", searchedCase(colBigintEquals("3"), symbolRef("SUM_DECIMAL")))
                .put("SUM_5_INPUT", searchedCase(colBigintEquals("4"), symbolRef("SUM_DECIMAL_CAST")))
                .buildOrThrow();
    }

    /** Inner aggregations of the (deduplicated) CASE values, grouped additionally by {@code COL_BIGINT}. */
    private static Map<Optional<String>, ExpectedValueProvider<AggregationFunction>> caseValuePreAggregations() {
        return ImmutableMap.of(
                Optional.of("SUM_BIGINT"), aggregationOf("sum", "VALUE_BIGINT"),
                Optional.of("SUM_INT_CAST"), aggregationOf("sum", "VALUE_INT_CAST"),
                Optional.of("MIN_BIGINT"), aggregationOf("min", "VALUE_2_BIGINT"),
                Optional.of("SUM_DECIMAL"), aggregationOf("sum", "COL_DECIMAL"),
                Optional.of("SUM_DECIMAL_CAST"), aggregationOf("sum", "VALUE_DECIMAL_CAST"));
    }

    /** Projections computing the inner aggregation inputs below the exchange. */
    private static Map<String, ExpressionMatcher> caseValueProjections() {
        return ImmutableMap.of(
                // the two "col_bigint * 2" sums (conditions = 1 and = 2) share one input
                "VALUE_BIGINT", searchedCase(
                        new InPredicate(symbolRef("COL_BIGINT"), ImmutableList.of(bigintLiteral("1"), bigintLiteral("2"))),
                        colBigintTimesTwo()),
                "VALUE_INT_CAST", searchedCase(
                        colBigintEquals("1"),
                        new Cast(new Cast(colBigintTimesTwo(), IntegerType.INTEGER), BigintType.BIGINT)),
                "VALUE_2_BIGINT", searchedCase(colBigintModTwoGreaterThanOne(), colBigintTimesTwo()),
                "VALUE_DECIMAL_CAST", searchedCase(
                        colBigintEquals("4"),
                        new Cast(
                                new ArithmeticBinaryExpression(
                                        ArithmeticBinaryExpression.Operator.MULTIPLY,
                                        symbolRef("COL_DECIMAL"),
                                        new Cast(new DecimalLiteral("2"), DecimalType.createDecimalType(10, 0))),
                                BigintType.BIGINT)));
    }

    // ---- Small expression builders ----

    /** Returns a {@code BIGINT} literal with the given textual value. */
    private static GenericLiteral bigintLiteral(String value) {
        return new GenericLiteral(BigintType.BIGINT, value);
    }

    /** Returns a symbol reference named after a pattern alias (e.g. {@code COL_BIGINT}). */
    private static SymbolReference symbolRef(String alias) {
        return new SymbolReference(alias);
    }

    /** Returns {@code COL_BIGINT = BIGINT '<value>'}. */
    private static Expression colBigintEquals(String value) {
        return new ComparisonExpression(ComparisonExpression.Operator.EQUAL, symbolRef("COL_BIGINT"), bigintLiteral(value));
    }

    /** Returns {@code COL_BIGINT * BIGINT '2'}. */
    private static Expression colBigintTimesTwo() {
        return new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.MULTIPLY, symbolRef("COL_BIGINT"), bigintLiteral("2"));
    }

    /**
     * Returns {@code (COL_BIGINT % BIGINT '2') > BIGINT '1'}, the form the expected plans use for the
     * query predicate {@code col_bigint % 2 > 1.23}.
     */
    private static Expression colBigintModTwoGreaterThanOne() {
        return new ComparisonExpression(
                ComparisonExpression.Operator.GREATER_THAN,
                new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.MODULUS, symbolRef("COL_BIGINT"), bigintLiteral("2")),
                bigintLiteral("1"));
    }

    /** Matches {@code CASE WHEN condition THEN result END} (no {@code ELSE} branch). */
    private static ExpressionMatcher searchedCase(Expression condition, Expression result) {
        return PlanMatchPattern.expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(condition, result)), Optional.empty()));
    }

    /** Matches {@code CASE WHEN condition THEN result ELSE defaultValue END}. */
    private static ExpressionMatcher searchedCase(Expression condition, Expression result, Expression defaultValue) {
        return PlanMatchPattern.expression(new SearchedCaseExpression(ImmutableList.of(new WhenClause(condition, result)), Optional.of(defaultValue)));
    }

    /** Expects the single-argument aggregation {@code function(argument)}, where {@code argument} is a pattern alias. */
    private static ExpectedValueProvider<AggregationFunction> aggregationOf(String function, String argument) {
        return PlanMatchPattern.aggregationFunction(function, ImmutableList.of(argument));
    }
}

