/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner;

import com.facebook.presto.Session;
import com.facebook.presto.execution.TaskManagerConfig;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.sql.planner.Plan;
import com.facebook.presto.sql.planner.assertions.BasePlanTest;
import com.facebook.presto.sql.planner.assertions.ExpressionMatcher;
import com.facebook.presto.sql.planner.assertions.PlanAssert;
import com.facebook.presto.sql.planner.assertions.PlanMatchPattern;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.GroupIdNode;
import com.facebook.presto.testing.LocalQueryRunner;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.airlift.units.DataSize;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.testng.annotations.Test;

public class TestLogicalAddExchangesBelowPartialAggregationOverGroupIdRuleSet
extends BasePlanTest {
    public TestLogicalAddExchangesBelowPartialAggregationOverGroupIdRuleSet() {
        super(TestLogicalAddExchangesBelowPartialAggregationOverGroupIdRuleSet::setup);
    }

    private static LocalQueryRunner setup() {
        TaskManagerConfig taskManagerConfig = new TaskManagerConfig().setMaxPartialAggregationMemoryUsage(DataSize.succinctDataSize((double)1.0, (DataSize.Unit)DataSize.Unit.KILOBYTE));
        return TestLogicalAddExchangesBelowPartialAggregationOverGroupIdRuleSet.createQueryRunner((Map<String, String>)ImmutableMap.of((Object)"add_exchange_below_partial_aggregation_over_group_id", (Object)"true"), taskManagerConfig);
    }

    @Test
    public void testRollup() {
        this.assertDistributedPlan("SELECT orderkey, suppkey, partkey, sum(quantity) from lineitem GROUP BY ROLLUP(orderkey, suppkey, partkey)", PlanMatchPattern.anyTree(PlanMatchPattern.node(GroupIdNode.class, PlanMatchPattern.anyTree(PlanMatchPattern.exchange(ExchangeNode.Scope.LOCAL, ExchangeNode.Type.REPARTITION, (List<PlanMatchPattern.Ordering>)ImmutableList.of(), (Set<String>)ImmutableSet.of((Object)"orderkey"), PlanMatchPattern.exchange(ExchangeNode.Scope.REMOTE_STREAMING, ExchangeNode.Type.REPARTITION, (List<PlanMatchPattern.Ordering>)ImmutableList.of(), (Set<String>)ImmutableSet.of((Object)"orderkey"), PlanMatchPattern.anyTree(PlanMatchPattern.tableScan("lineitem", (Map<String, String>)ImmutableMap.of((Object)"orderkey", (Object)"orderkey")))))))));
    }

    @Test
    public void testNegativeCases() {
        Session enableMergeAggregationWithAndWithoutFilter = Session.builder((Session)this.getQueryRunner().getDefaultSession()).setSystemProperty("merge_aggregations_with_and_without_filter", "true").build();
        String sql = "select partkey, sum(quantity), sum(quantity) filter (where discount > 0.1) from lineitem group by grouping sets((), (partkey))";
        this.assertDistributedPlan(sql, enableMergeAggregationWithAndWithoutFilter, PlanMatchPattern.anyTree(PlanMatchPattern.node(GroupIdNode.class, PlanMatchPattern.project((Map<String, ExpressionMatcher>)ImmutableMap.of((Object)"partkey", (Object)PlanMatchPattern.expression("partkey"), (Object)"quantity", (Object)PlanMatchPattern.expression("quantity"), (Object)"expr", (Object)PlanMatchPattern.expression("discount > DOUBLE'0.1'")), PlanMatchPattern.tableScan("lineitem", (Map<String, String>)ImmutableMap.of((Object)"partkey", (Object)"partkey", (Object)"quantity", (Object)"quantity", (Object)"discount", (Object)"discount"))))));
        TaskManagerConfig taskManagerConfig = new TaskManagerConfig().setMaxPartialAggregationMemoryUsage(DataSize.succinctDataSize((double)1.0, (DataSize.Unit)DataSize.Unit.MEGABYTE));
        try (LocalQueryRunner queryRunner = TestLogicalAddExchangesBelowPartialAggregationOverGroupIdRuleSet.createQueryRunner((Map<String, String>)ImmutableMap.of((Object)"add_exchange_below_partial_aggregation_over_group_id", (Object)"true"), taskManagerConfig);){
            queryRunner.inTransaction(queryRunner.getDefaultSession(), transactionSession -> {
                Plan plan = queryRunner.createPlan(transactionSession, "SELECT orderkey, suppkey, partkey, sum(quantity) from lineitem GROUP BY ROLLUP(orderkey, suppkey, partkey)", WarningCollector.NOOP);
                PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), queryRunner.getStatsCalculator(), plan, PlanMatchPattern.anyTree(PlanMatchPattern.node(GroupIdNode.class, PlanMatchPattern.tableScan("lineitem", (Map<String, String>)ImmutableMap.of((Object)"orderkey", (Object)"orderkey")))));
                return null;
            });
        }
    }
}

