/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.assertions.BasePlanTest;
import com.facebook.presto.sql.planner.assertions.ExpectedValueProvider;
import com.facebook.presto.sql.planner.assertions.ExpressionMatcher;
import com.facebook.presto.sql.planner.assertions.PlanMatchPattern;
import com.facebook.presto.sql.tree.FunctionCall;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.testng.annotations.Test;

/**
 * Plan tests for conditional aggregation ({@code IF(condition, aggregate)}) with the
 * {@code optimize_conditional_aggregation_enabled} session property turned on.
 *
 * <p>When the condition depends only on grouping information, the expected plan evaluates it
 * below the aggregation as a mask symbol on a PARTIAL aggregation. When the condition depends on
 * another aggregate, the {@code IF} is expected to remain a projection above the aggregation.
 */
public class TestRewriteIfOverAggregation
        extends BasePlanTest {
    private static final String OPTIMIZE_CONDITIONAL_AGGREGATION = "optimize_conditional_aggregation_enabled";

    /**
     * Builds a session from the query runner's default session with
     * {@code optimize_conditional_aggregation_enabled} set to {@code true}.
     */
    private Session enableOptimization() {
        return Session.builder(getQueryRunner().getDefaultSession())
                .setSystemProperty(OPTIMIZE_CONDITIONAL_AGGREGATION, "true")
                .build();
    }

    /**
     * {@code GROUPING(orderstatus, shippriority) = 0} only depends on which grouping set a row
     * belongs to. The {@code sum} is therefore expected to become a masked PARTIAL aggregation.
     */
    @Test
    public void testConditionOnGrouping() {
        assertPlan(
                "SELECT orderstatus, shippriority, IF(GROUPING(orderstatus, shippriority) = 0, sum(totalprice)) FROM orders GROUP BY GROUPING SETS ((orderstatus), (orderstatus, shippriority))",
                enableOptimization(),
                PlanMatchPattern.anyTree(
                        maskedAggregationOverGroupingSets(
                                "pricesum",
                                PlanMatchPattern.functionCall("sum", ImmutableList.of("totalprice")),
                                ImmutableMap.of("totalprice", "totalprice"))));
    }

    /**
     * The condition {@code count(1) > 3000} depends on another aggregate. The {@code IF} is
     * therefore expected to stay a projection on top of the (partial, exchange, final)
     * aggregation pipeline rather than becoming a mask.
     */
    @Test
    public void testConditionOnAggregation() {
        PlanMatchPattern scan = PlanMatchPattern.tableScan(
                "orders",
                ImmutableMap.of("totalprice", "totalprice", "orderpriority", "orderpriority"));
        Map<String, ExpressionMatcher> scanProjection = ImmutableMap.of(
                "totalprice", PlanMatchPattern.expression("totalprice"),
                "orderpriority", PlanMatchPattern.expression("orderpriority"));
        Map<String, ExpectedValueProvider<FunctionCall>> partialAggregations = ImmutableMap.of(
                "partial_avg", PlanMatchPattern.functionCall("avg", ImmutableList.of("totalprice")),
                "partial_count", PlanMatchPattern.functionCall("count", ImmutableList.of()));
        // The aggregation above the exchange consumes the partial_* outputs of the one below it.
        Map<String, ExpectedValueProvider<FunctionCall>> finalAggregations = ImmutableMap.of(
                "avg", PlanMatchPattern.functionCall("avg", ImmutableList.of("partial_avg")),
                "count", PlanMatchPattern.functionCall("count", ImmutableList.of("partial_count")));
        Map<String, ExpressionMatcher> ifProjection = ImmutableMap.of(
                "ifexp", PlanMatchPattern.expression("if(count > 3000, avg, null)"));

        assertPlan(
                "select orderpriority, if(count(1)>3000, avg(totalprice)) from orders group by orderpriority ",
                enableOptimization(),
                PlanMatchPattern.anyTree(
                        PlanMatchPattern.project(
                                ifProjection,
                                PlanMatchPattern.aggregation(
                                        finalAggregations,
                                        PlanMatchPattern.exchange(
                                                PlanMatchPattern.aggregation(
                                                        partialAggregations,
                                                        PlanMatchPattern.project(scanProjection, scan)))))));
    }

    /**
     * Same shape as {@link #testConditionOnGrouping()}, but the aggregate takes two arguments.
     * Both arguments ({@code shippriority}, {@code totalprice}) are expected on the grouping-set
     * node.
     */
    @Test
    public void testMultipleArgumentsAggregation() {
        assertPlan(
                "SELECT orderstatus, shippriority, IF(GROUPING(orderstatus, shippriority) = 0, max_by(shippriority, totalprice)) FROM orders GROUP BY GROUPING SETS ((orderstatus), (orderstatus, shippriority))",
                enableOptimization(),
                PlanMatchPattern.anyTree(
                        maskedAggregationOverGroupingSets(
                                "result",
                                PlanMatchPattern.functionCall("max_by", ImmutableList.of("shippriority", "totalprice")),
                                ImmutableMap.of("totalprice", "totalprice", "shippriority", "shippriority"))));
    }

    /**
     * Builds the expected PARTIAL aggregation for a query of the form
     * {@code IF(GROUPING(orderstatus, shippriority) = 0, agg(...))} over
     * {@code GROUPING SETS ((orderstatus), (orderstatus, shippriority))} on {@code orders}.
     * In that plan the single aggregate is masked by the rewritten GROUPING condition.
     *
     * @param aggregationAlias pattern alias of the aggregate's output symbol; it is also the key of
     *     the mask mapping
     * @param aggregationCall the expected aggregate function call
     * @param groupingSetArguments aggregation-argument columns expected on the grouping-set node
     *     (alias to column)
     * @return a pattern rooted at the masked PARTIAL aggregation
     */
    private static PlanMatchPattern maskedAggregationOverGroupingSets(
            String aggregationAlias,
            ExpectedValueProvider<FunctionCall> aggregationCall,
            Map<String, String> groupingSetArguments) {
        List<List<String>> groupingSets = ImmutableList.of(
                ImmutableList.of("orderstatus"),
                ImmutableList.of("orderstatus", "shippriority"));
        Map<String, ExpressionMatcher> groupingColumns = ImmutableMap.of(
                "orderstatus$gid", PlanMatchPattern.expression("orderstatus"),
                "shippriority$gid", PlanMatchPattern.expression("shippriority"));

        PlanMatchPattern scan = PlanMatchPattern.tableScan(
                "orders",
                ImmutableMap.of(
                        "totalprice", "totalprice",
                        "orderstatus", "orderstatus",
                        "shippriority", "shippriority"));
        PlanMatchPattern groupIdNode = PlanMatchPattern.groupingSet(
                groupingSets, groupingSetArguments, "groupid", groupingColumns, scan);

        // The expected mask looks up GROUPING's value per groupid. array[1, 0] lists that value for
        // the (orderstatus) and (orderstatus, shippriority) sets, in that order. Array indexing is
        // 1-based, hence groupid+1.
        Map<String, ExpressionMatcher> maskAssignment = ImmutableMap.of(
                "mask", PlanMatchPattern.expression("array[1, 0][groupid+1]=0"));
        PlanMatchPattern maskProjection = PlanMatchPattern.project(maskAssignment, groupIdNode);

        List<String> groupingKeys = ImmutableList.of("orderstatus$gid", "shippriority$gid", "groupid");
        Set<Integer> globalGroupingSets = ImmutableSet.of();
        Map<Optional<String>, ExpectedValueProvider<FunctionCall>> aggregations =
                ImmutableMap.of(Optional.of(aggregationAlias), aggregationCall);
        Map<Symbol, Symbol> masks = ImmutableMap.of(new Symbol(aggregationAlias), new Symbol("mask"));

        return PlanMatchPattern.aggregation(
                new PlanMatchPattern.GroupingSetDescriptor(groupingKeys, groupingSets.size(), globalGroupingSets),
                aggregations,
                masks,
                Optional.of(new Symbol("groupid")),
                AggregationNode.Step.PARTIAL,
                maskProjection);
    }
}

