/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.optimizations;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.metadata.Metadata;
import io.trino.sql.planner.PlanOptimizers;
import io.trino.sql.planner.RuleStatsRecorder;
import io.trino.sql.planner.assertions.BasePlanTest;
import io.trino.sql.planner.assertions.ExpectedValueProvider;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.assertions.PlanTestSymbol;
import io.trino.sql.planner.iterative.IterativeOptimizer;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.MultipleDistinctAggregationToMarkDistinct;
import io.trino.sql.planner.iterative.rule.RemoveRedundantIdentityProjections;
import io.trino.sql.planner.iterative.rule.SingleDistinctAggregationToGroupBy;
import io.trino.sql.planner.optimizations.OptimizeMixedDistinctAggregations;
import io.trino.sql.planner.optimizations.PlanOptimizer;
import io.trino.sql.planner.optimizations.UnaliasSymbolReferences;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.tree.FunctionCall;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.testng.annotations.Test;

/**
 * Plan-shape tests for {@link OptimizeMixedDistinctAggregations}.
 *
 * <p>Every query planned by this class runs with the session property
 * {@code optimize_mixed_distinct_aggregations} set to {@code "true"}. Plans are
 * produced through a small, fixed optimizer pipeline (see {@link #assertUnitPlan})
 * rather than the full planner.
 */
public class TestOptimizeMixedDistinctAggregations extends BasePlanTest {

    public TestOptimizeMixedDistinctAggregations() {
        super(ImmutableMap.of("optimize_mixed_distinct_aggregations", "true"));
    }

    /**
     * A query mixing a plain aggregate ({@code max}) with a distinct aggregate
     * ({@code count(DISTINCT ...)}) is expected to be planned as:
     * {@code aggregation <- project <- aggregation <- groupId <- tableScan}.
     */
    @Test
    public void testMixedDistinctAggregationOptimizer() {
        String sql = "SELECT custkey, max(totalprice) AS s, count(DISTINCT orderdate) AS d FROM orders GROUP BY custkey";

        // Symbol aliases used throughout the expected plan pattern.
        String group = "GROUP";
        String groupBy = "CUSTKEY";
        String aggregate = "TOTALPRICE";
        String distinctAggregation = "ORDERDATE";

        // Outer (second) aggregation: grouped by the original GROUP BY key only.
        List<String> groupByKeysSecond = ImmutableList.of(groupBy);
        Map<Optional<String>, ExpectedValueProvider<FunctionCall>> aggregationsSecond = ImmutableMap.of(
                Optional.of("arbitrary"), PlanMatchPattern.functionCall("arbitrary", false, ImmutableList.of(PlanMatchPattern.anySymbol())),
                Optional.of("count"), PlanMatchPattern.functionCall("count", false, ImmutableList.of(PlanMatchPattern.anySymbol())));

        // Inner (first) aggregation: grouped by the original key, the distinct argument and the group id symbol.
        List<String> groupByKeysFirst = ImmutableList.of(groupBy, distinctAggregation, group);
        Map<Optional<String>, ExpectedValueProvider<FunctionCall>> aggregationsFirst = ImmutableMap.of(
                Optional.of("MAX"), PlanMatchPattern.functionCall("max", ImmutableList.of(aggregate)));

        // Maps each alias to its column in the orders table.
        PlanMatchPattern tableScan = PlanMatchPattern.tableScan("orders", ImmutableMap.of(
                aggregate, "totalprice",
                groupBy, "custkey",
                distinctAggregation, "orderdate"));

        // Grouping sets expected on the GroupId node that feeds the inner aggregation.
        List<List<String>> groups = ImmutableList.of(
                ImmutableList.of(groupBy, aggregate),
                ImmutableList.of(groupBy, distinctAggregation));

        PlanMatchPattern expectedPlanPattern = PlanMatchPattern.anyTree(
                PlanMatchPattern.aggregation(
                        PlanMatchPattern.singleGroupingSet(groupByKeysSecond),
                        aggregationsSecond,
                        Optional.empty(),
                        AggregationNode.Step.SINGLE,
                        PlanMatchPattern.project(
                                PlanMatchPattern.aggregation(
                                        PlanMatchPattern.singleGroupingSet(groupByKeysFirst),
                                        aggregationsFirst,
                                        Optional.empty(),
                                        AggregationNode.Step.SINGLE,
                                        PlanMatchPattern.groupId(groups, group, tableScan)))));

        assertUnitPlan(sql, expectedPlanPattern);
    }

    /**
     * A distinct aggregate over a ROW-typed column combined with a plain aggregate.
     * This is expected to produce the same two-level aggregation shape over a VALUES source.
     */
    @Test
    public void testNestedType() {
        Map<String, ExpectedValueProvider<FunctionCall>> aggregationsSecond = ImmutableMap.of(
                "arbitrary", PlanMatchPattern.functionCall("arbitrary", false, ImmutableList.of(PlanMatchPattern.anySymbol())),
                "count", PlanMatchPattern.functionCall("count", false, ImmutableList.of(PlanMatchPattern.anySymbol())));
        Map<String, ExpectedValueProvider<FunctionCall>> aggregationsFirst = ImmutableMap.of(
                "max", PlanMatchPattern.functionCall("max", false, ImmutableList.of(PlanMatchPattern.anySymbol())));

        assertUnitPlan(
                "SELECT count(DISTINCT a), max(b) FROM (VALUES (ROW(1, 2), 3)) t(a, b)",
                PlanMatchPattern.anyTree(
                        PlanMatchPattern.aggregation(
                                aggregationsSecond,
                                PlanMatchPattern.project(
                                        PlanMatchPattern.aggregation(
                                                aggregationsFirst,
                                                PlanMatchPattern.anyTree(PlanMatchPattern.values(ImmutableMap.<String, Integer>of())))))));
    }

    /**
     * Plans {@code sql} with a fixed optimizer pipeline and asserts the result matches {@code pattern}.
     *
     * <p>The pipeline, in order:
     * <ol>
     *   <li>{@code UnaliasSymbolReferences};</li>
     *   <li>an iterative pass with {@code RemoveRedundantIdentityProjections},
     *       {@code SingleDistinctAggregationToGroupBy} and {@code MultipleDistinctAggregationToMarkDistinct};</li>
     *   <li>{@code OptimizeMixedDistinctAggregations} (the optimizer under test);</li>
     *   <li>an iterative pass with {@code RemoveRedundantIdentityProjections} plus the column-pruning rules.</li>
     * </ol>
     *
     * @param sql the query to plan
     * @param pattern the expected plan shape
     */
    private void assertUnitPlan(String sql, PlanMatchPattern pattern) {
        Metadata metadata = getQueryRunner().getMetadata();
        List<PlanOptimizer> optimizers = ImmutableList.of(
                new UnaliasSymbolReferences(metadata),
                iterativeOptimizer(ImmutableSet.of(
                        new RemoveRedundantIdentityProjections(),
                        new SingleDistinctAggregationToGroupBy(),
                        new MultipleDistinctAggregationToMarkDistinct())),
                new OptimizeMixedDistinctAggregations(metadata),
                iterativeOptimizer(ImmutableSet.<Rule<?>>builder()
                        .add(new RemoveRedundantIdentityProjections())
                        .addAll(PlanOptimizers.columnPruningRules(metadata))
                        .build()));
        assertPlan(sql, pattern, optimizers);
    }

    /**
     * Builds an {@link IterativeOptimizer} over {@code rules}.
     *
     * <p>The planner context, stats calculator and exchange cost calculator come from
     * this test's query runner. Each call creates a fresh {@link RuleStatsRecorder}.
     */
    private IterativeOptimizer iterativeOptimizer(Set<Rule<?>> rules) {
        return new IterativeOptimizer(
                getQueryRunner().getPlannerContext(),
                new RuleStatsRecorder(),
                getQueryRunner().getStatsCalculator(),
                getQueryRunner().getEstimatedExchangesCostCalculator(),
                rules);
    }
}

