/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.cost.TaskCountEstimator;
import com.facebook.presto.cost.VariableStatsEstimate;
import com.facebook.presto.execution.TaskManagerConfig;
import com.facebook.presto.spi.Plugin;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Partitioning;
import com.facebook.presto.spi.plan.PartitioningHandle;
import com.facebook.presto.spi.plan.PartitioningScheme;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SystemPartitioningHandle;
import com.facebook.presto.sql.planner.assertions.ExpectedValueProvider;
import com.facebook.presto.sql.planner.assertions.ExpressionMatcher;
import com.facebook.presto.sql.planner.assertions.GroupIdMatcher;
import com.facebook.presto.sql.planner.assertions.PlanMatchPattern;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.iterative.rule.AddExchangesBelowPartialAggregationOverGroupIdRuleSet;
import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest;
import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder;
import com.facebook.presto.sql.planner.iterative.rule.test.RuleAssert;
import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester;
import com.facebook.presto.sql.planner.plan.AssignmentUtils;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.GroupIdNode;
import com.facebook.presto.sql.tree.FunctionCall;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

/**
 * Rule tests for {@link AddExchangesBelowPartialAggregationOverGroupIdRuleSet}.
 *
 * <p>Every test plans the same shape: a {@code REMOTE_STREAMING} exchange over a partial
 * aggregation over a {@link GroupIdNode} over a values node, optionally with an identity
 * projection between the exchange and the aggregation. The values node carries overridden
 * statistics so that each test controls the distinct-value count (NDV) of the three grouping keys.
 */
public class TestAddExchangesBelowPartialAggregationOverGroupIdRuleSet
        extends BaseRuleTest {
    /** Name of the session property that the tests toggle to enable or disable the rule. */
    private static final String ADD_EXCHANGE_BELOW_PARTIAL_AGGREGATION_OVER_GROUP_ID = "add_exchange_below_partial_aggregation_over_group_id";

    /** Plan node id of the values node below the group-id node; stats overrides are keyed by it. */
    private static final String GROUP_ID_SOURCE_ID = "groupIdSourceId";

    /** Hands an empty plugin array to the base rule test. */
    public TestAddExchangesBelowPartialAggregationOverGroupIdRuleSet() {
        super(new Plugin[0]);
    }

    /**
     * Builds the rule set under test with a task count estimator that always reports 4, a default
     * {@link TaskManagerConfig} and the tester's metadata.
     *
     * <p>NOTE(review): the meaning of the trailing {@code false} flag is defined by the rule set
     * constructor, which is not visible from this file.
     */
    private static AddExchangesBelowPartialAggregationOverGroupIdRuleSet newRuleSet(RuleTester ruleTester) {
        TaskCountEstimator taskCountEstimator = new TaskCountEstimator(() -> 4);
        TaskManagerConfig taskManagerConfig = new TaskManagerConfig();
        return new AddExchangesBelowPartialAggregationOverGroupIdRuleSet(taskCountEstimator, taskManagerConfig, ruleTester.getMetadata(), false);
    }

    /** Returns the rule variant obtained from {@code belowExchangeRule()} on a fresh rule set. */
    private static AddExchangesBelowPartialAggregationOverGroupIdRuleSet.AddExchangesBelowExchangePartialAggregationGroupId belowExchangeRule(RuleTester ruleTester) {
        return newRuleSet(ruleTester).belowExchangeRule();
    }

    /** Returns the rule variant obtained from {@code belowProjectionRule()} on a fresh rule set. */
    private static AddExchangesBelowPartialAggregationOverGroupIdRuleSet.AddExchangesBelowProjectionPartialAggregationGroupId belowProjectionRule(RuleTester ruleTester) {
        return newRuleSet(ruleTester).belowProjectionRule();
    }

    /**
     * Rows of {@code (groupingKey1 NDV, groupingKey2 NDV, groupingKey3 NDV, expected repartition symbol)}.
     *
     * <p>In the first two rows the expected symbol is the key with the largest NDV; in the last row
     * all NDVs are equal and {@code groupingKey1} is expected. NOTE(review): this looks like
     * "repartition on the highest-NDV key" — confirm the tie-breaking against the rule itself.
     */
    @DataProvider
    public static Object[][] testDataProvider() {
        return new Object[][] {
                {1000.0, 10000.0, 1000000.0, "groupingKey3"},
                {1000.0, 2000000.0, 1000000.0, "groupingKey2"},
                {1000.0, 1000.0, 1000.0, "groupingKey1"},
        };
    }

    /** Rows of grouping key NDVs in which exactly one of the three values is {@code NaN}. */
    @DataProvider
    public static Object[][] testDataProviderMissingStats() {
        return new Object[][] {
                {Double.NaN, 10000.0, 1000000.0},
                {1000.0, Double.NaN, 1000000.0},
                {1000.0, 10000.0, Double.NaN},
        };
    }

    /**
     * Without a projection, the rewritten plan must keep the gathering exchange and the partial
     * aggregation and gain a local plus a remote repartitioning exchange below the group-id node.
     */
    @Test(dataProvider = "testDataProvider")
    public void testAddExchangesWithoutProjection(double groupingKey1NDV, double groupingKey2NDV, double groupingKey3NDV, String expectedRepartitionSymbol) {
        this.buildRuleAssert(groupingKey1NDV, groupingKey2NDV, groupingKey3NDV, false)
                .matches(PlanMatchPattern.exchange(
                        ExchangeNode.Scope.REMOTE_STREAMING,
                        ExchangeNode.Type.GATHER,
                        expectedPartialAggregation(expectedRepartitionSymbol)));
    }

    /**
     * Same expectation as the projection-less case, with an identity projection of the three
     * grouping keys sitting between the gathering exchange and the partial aggregation.
     */
    @Test(dataProvider = "testDataProvider")
    public void testAddExchangesWithProjection(double groupingKey1NDV, double groupingKey2NDV, double groupingKey3NDV, String expectedRepartitionSymbol) {
        this.buildRuleAssert(groupingKey1NDV, groupingKey2NDV, groupingKey3NDV, true)
                .matches(PlanMatchPattern.exchange(
                        ExchangeNode.Scope.REMOTE_STREAMING,
                        ExchangeNode.Type.GATHER,
                        PlanMatchPattern.project(
                                ImmutableMap.of(
                                        "groupingKey1", PlanMatchPattern.expression("groupingKey1"),
                                        "groupingKey2", PlanMatchPattern.expression("groupingKey2"),
                                        "groupingKey3", PlanMatchPattern.expression("groupingKey3")),
                                expectedPartialAggregation(expectedRepartitionSymbol))));
    }

    /** Neither rule variant may fire when the NDV of any grouping key is {@code NaN}. */
    @Test(dataProvider = "testDataProviderMissingStats")
    public void testDoesNotFireIfAnySourceSymbolIsMissingStats(double groupingKey1NDV, double groupingKey2NDV, double groupingKey3NDV) {
        this.buildRuleAssert(groupingKey1NDV, groupingKey2NDV, groupingKey3NDV, true).doesNotFire();
        this.buildRuleAssert(groupingKey1NDV, groupingKey2NDV, groupingKey3NDV, false).doesNotFire();
    }

    /**
     * The rule must not fire once the session property is set to {@code "false"}.
     *
     * <p>{@link #buildRuleAssert} sets the property to {@code "true"} first; this test relies on the
     * later {@code setSystemProperty} call taking precedence.
     */
    @Test
    public void testDoesNotFireIfSessionPropertyIsDisabled() {
        this.buildRuleAssert(1000.0, 1000.0, 1000.0, false)
                .setSystemProperty(ADD_EXCHANGE_BELOW_PARTIAL_AGGREGATION_OVER_GROUP_ID, "false")
                .doesNotFire();
    }

    /**
     * Expected plan below the gathering exchange (and below the projection, when present): a
     * partial aggregation grouped on the three keys plus {@code groupId}, over a group-id node with
     * grouping sets {@code (groupingKey1, groupingKey2)} and {@code (groupingKey1, groupingKey3)},
     * over a local and then a remote repartitioning exchange, over the values node.
     *
     * @param expectedRepartitionSymbol symbol both repartitioning exchanges are expected to partition on
     */
    private static PlanMatchPattern expectedPartialAggregation(String expectedRepartitionSymbol) {
        PlanMatchPattern remoteRepartition = PlanMatchPattern.exchange(
                ExchangeNode.Scope.REMOTE_STREAMING,
                ExchangeNode.Type.REPARTITION,
                ImmutableList.of(),
                ImmutableSet.of(expectedRepartitionSymbol),
                PlanMatchPattern.values("groupingKey1", "groupingKey2", "groupingKey3"));
        PlanMatchPattern localRepartition = PlanMatchPattern.exchange(
                ExchangeNode.Scope.LOCAL,
                ExchangeNode.Type.REPARTITION,
                ImmutableList.of(),
                ImmutableSet.of(expectedRepartitionSymbol),
                remoteRepartition);
        PlanMatchPattern groupIdNode = PlanMatchPattern.node(GroupIdNode.class, localRepartition)
                .with(new GroupIdMatcher(
                        ImmutableList.of(
                                ImmutableList.of("groupingKey1", "groupingKey2"),
                                ImmutableList.of("groupingKey1", "groupingKey3")),
                        ImmutableMap.of(),
                        "groupId"));
        return PlanMatchPattern.aggregation(
                PlanMatchPattern.singleGroupingSet(ImmutableList.of("groupingKey1", "groupingKey2", "groupingKey3", "groupId")),
                ImmutableMap.of(),
                ImmutableList.of(),
                ImmutableMap.of(),
                Optional.empty(),
                AggregationNode.Step.PARTIAL,
                groupIdNode);
    }

    /**
     * Prepares a rule assertion for the input plan with the session property enabled and with
     * overridden statistics on the values node.
     *
     * @param groupingKey1NDV distinct-value count reported for {@code groupingKey1}
     * @param groupingKey2NDV distinct-value count reported for {@code groupingKey2}
     * @param groupingKey3NDV distinct-value count reported for {@code groupingKey3}
     * @param withProjection {@code true} to test the below-projection rule on a plan containing an
     *        identity projection, {@code false} to test the below-exchange rule without one
     * @return the assertion, ready for {@code matches(...)} or {@code doesNotFire()}
     */
    private RuleAssert buildRuleAssert(double groupingKey1NDV, double groupingKey2NDV, double groupingKey3NDV, boolean withProjection) {
        RuleTester ruleTester = this.tester();
        Rule<?> rule = withProjection ? belowProjectionRule(ruleTester) : belowExchangeRule(ruleTester);
        return ruleTester.assertThat(rule)
                .setSystemProperty(ADD_EXCHANGE_BELOW_PARTIAL_AGGREGATION_OVER_GROUP_ID, "true")
                .overrideStats(GROUP_ID_SOURCE_ID, sourceStats(groupingKey1NDV, groupingKey2NDV, groupingKey3NDV))
                .on(p -> buildInputPlan(p, withProjection));
    }

    /** Statistics for the values node: 1e8 output rows and the given NDV for each BIGINT grouping key. */
    private static PlanNodeStatsEstimate sourceStats(double groupingKey1NDV, double groupingKey2NDV, double groupingKey3NDV) {
        return PlanNodeStatsEstimate.builder()
                .setOutputRowCount(1.0E8)
                .addVariableStatistics(ImmutableMap.of(
                        bigintVariable("groupingKey1"), distinctValues(groupingKey1NDV),
                        bigintVariable("groupingKey2"), distinctValues(groupingKey2NDV),
                        bigintVariable("groupingKey3"), distinctValues(groupingKey3NDV)))
                .build();
    }

    /** Creates a BIGINT variable reference with the given name and no source location. */
    private static VariableReferenceExpression bigintVariable(String name) {
        return new VariableReferenceExpression(Optional.empty(), name, BigintType.BIGINT);
    }

    /** Creates a variable stats estimate that only sets the distinct-value count. */
    private static VariableStatsEstimate distinctValues(double distinctValuesCount) {
        return VariableStatsEstimate.builder().setDistinctValuesCount(distinctValuesCount).build();
    }

    /**
     * Builds the plan the rule is applied to: a {@code REMOTE_STREAMING} exchange with
     * {@code FIXED_ARBITRARY_DISTRIBUTION} partitioning whose single source is the partial
     * aggregation, optionally wrapped in an identity projection.
     */
    private static PlanNode buildInputPlan(PlanBuilder p, boolean withProjection) {
        VariableReferenceExpression groupingKey1 = p.variable("groupingKey1", BigintType.BIGINT);
        VariableReferenceExpression groupingKey2 = p.variable("groupingKey2", BigintType.BIGINT);
        VariableReferenceExpression groupingKey3 = p.variable("groupingKey3", BigintType.BIGINT);
        VariableReferenceExpression groupId = p.variable("groupId", BigintType.BIGINT);
        AggregationNode partialAgg = p.aggregation(builder -> builder
                .singleGroupingSet(groupingKey1, groupingKey2, groupingKey3, groupId)
                .step(AggregationNode.Step.PARTIAL)
                .source(p.groupId(
                        ImmutableList.of(
                                ImmutableList.of(groupingKey1, groupingKey2),
                                ImmutableList.of(groupingKey1, groupingKey3)),
                        ImmutableList.of(),
                        groupId,
                        p.values(new PlanNodeId(GROUP_ID_SOURCE_ID), groupingKey1, groupingKey2, groupingKey3))));
        // The optional projection is created inside the exchange builder callback so that plan
        // nodes are constructed in the same order as before.
        return p.exchange(exchangeBuilder -> exchangeBuilder
                .scope(ExchangeNode.Scope.REMOTE_STREAMING)
                .partitioningScheme(new PartitioningScheme(
                        Partitioning.create(SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION, ImmutableList.of()),
                        ImmutableList.of(groupingKey1, groupingKey2, groupingKey3, groupId)))
                .addInputsSet(groupingKey1, groupingKey2, groupingKey3, groupId)
                .addSource(withProjection
                        ? p.project(AssignmentUtils.identityAssignments(partialAgg.getOutputVariables()), partialAgg)
                        : partialAgg));
    }
}

