/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.SymbolStatsEstimate;
import io.trino.metadata.ResolvedFunction;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.spi.Plugin;
import io.trino.spi.function.OperatorType;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.Type;
import io.trino.sql.ir.Call;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.assertions.AggregationFunction;
import io.trino.sql.planner.assertions.ExpectedValueProvider;
import io.trino.sql.planner.assertions.ExpressionMatcher;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.rule.PushPartialAggregationThroughJoin;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.DynamicFilterId;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.JoinType;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.junit.jupiter.api.Test;

/**
 * Rule tests for {@link PushPartialAggregationThroughJoin}.
 *
 * <p>Each case builds a PARTIAL {@code avg} aggregation on top of an INNER join (optionally with a
 * projection in between). It then checks either that the partial aggregation is moved below the
 * join, with a projection left above the join, or that the rule does not fire.
 */
public class TestPushPartialAggregationThroughJoin
        extends BaseRuleTest
{
    private static final PlanNodeId JOIN_ID = new PlanNodeId("join_id");
    private static final PlanNodeId CHILD_ID = new PlanNodeId("child_id");
    private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution();
    private static final ResolvedFunction ADD_BIGINT = FUNCTIONS.resolveOperator(OperatorType.ADD, ImmutableList.of(BigintType.BIGINT, BigintType.BIGINT));
    // Type list handed to addAggregation(...) for every avg aggregation built in this class.
    private static final List<Type> AVG_TYPES = ImmutableList.of(DoubleType.DOUBLE);

    public TestPushPartialAggregationThroughJoin()
    {
        // No additional plugins are installed for these rule tests.
        super(new Plugin[0]);
    }

    /**
     * The aggregation groups only by left-side symbols, and the join outputs only left-side
     * symbols. The partial aggregation is expected to move into the join's left child, with an
     * identity projection left above the join.
     */
    @Test
    public void testPushesPartialAggregationThroughJoinToLeftChildWithoutProjection()
    {
        tester().assertThat(new PushPartialAggregationThroughJoin().pushPartialAggregationThroughJoinWithoutProjection())
                .on(p -> p.aggregation(ab -> ab
                        .source(joinOutputtingLeftSymbols(p))
                        // "avg" (lower case) to match every other aggregation in this class and the expected aggregationFunction("avg", ...)
                        .addAggregation(p.symbol("AVG", DoubleType.DOUBLE), PlanBuilder.aggregation("avg", ImmutableList.of(bigintReference("LEFT_AGGR"))), AVG_TYPES)
                        .singleGroupingSet(p.symbol("LEFT_GROUP_BY"), p.symbol("LEFT_EQUI"), p.symbol("LEFT_NON_EQUI"))
                        .step(AggregationNode.Step.PARTIAL)))
                .matches(PlanMatchPattern.project(
                        leftGroupingSetAndAvgOutputs(),
                        PlanMatchPattern.join(JoinType.INNER, builder -> builder
                                .equiCriteria("LEFT_EQUI", "RIGHT_EQUI")
                                .filter(nonEquiFilter())
                                .left(PlanMatchPattern.aggregation(
                                        PlanMatchPattern.singleGroupingSet("LEFT_GROUP_BY", "LEFT_EQUI", "LEFT_NON_EQUI"),
                                        expectedAvg("LEFT_AGGR"),
                                        Optional.empty(),
                                        AggregationNode.Step.PARTIAL,
                                        PlanMatchPattern.values("LEFT_EQUI", "LEFT_NON_EQUI", "LEFT_GROUP_BY", "LEFT_AGGR")))
                                .right(PlanMatchPattern.values("RIGHT_EQUI", "RIGHT_NON_EQUI")))));
    }

    /**
     * Mirror image of the left-child case. The aggregation groups only by right-side symbols, and
     * the join outputs only right-side symbols. The partial aggregation is expected to move into
     * the join's right child.
     */
    @Test
    public void testPushesPartialAggregationThroughJoinToRightChildWithoutProjection()
    {
        tester().assertThat(new PushPartialAggregationThroughJoin().pushPartialAggregationThroughJoinWithoutProjection())
                .on(p -> p.aggregation(ab -> ab
                        .source(p.join(
                                JoinType.INNER,
                                p.values(p.symbol("LEFT_EQUI"), p.symbol("LEFT_NON_EQUI")),
                                p.values(p.symbol("RIGHT_EQUI"), p.symbol("RIGHT_NON_EQUI"), p.symbol("RIGHT_GROUP_BY"), p.symbol("RIGHT_AGGR")),
                                ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("LEFT_EQUI"), p.symbol("RIGHT_EQUI"))),
                                ImmutableList.of(),
                                ImmutableList.of(p.symbol("RIGHT_EQUI"), p.symbol("RIGHT_NON_EQUI"), p.symbol("RIGHT_GROUP_BY"), p.symbol("RIGHT_AGGR")),
                                Optional.of(nonEquiFilter())))
                        .addAggregation(p.symbol("AVG", DoubleType.DOUBLE), PlanBuilder.aggregation("avg", ImmutableList.of(bigintReference("RIGHT_AGGR"))), AVG_TYPES)
                        .singleGroupingSet(p.symbol("RIGHT_GROUP_BY"), p.symbol("RIGHT_EQUI"), p.symbol("RIGHT_NON_EQUI"))
                        .step(AggregationNode.Step.PARTIAL)))
                .matches(PlanMatchPattern.project(
                        ImmutableMap.of(
                                "RIGHT_GROUP_BY", PlanMatchPattern.expression(bigintReference("RIGHT_GROUP_BY")),
                                "RIGHT_EQUI", PlanMatchPattern.expression(bigintReference("RIGHT_EQUI")),
                                "RIGHT_NON_EQUI", PlanMatchPattern.expression(bigintReference("RIGHT_NON_EQUI")),
                                "AVG", PlanMatchPattern.expression(doubleReference("AVG"))),
                        PlanMatchPattern.join(JoinType.INNER, builder -> builder
                                .equiCriteria("LEFT_EQUI", "RIGHT_EQUI")
                                .filter(nonEquiFilter())
                                .left(PlanMatchPattern.values("LEFT_EQUI", "LEFT_NON_EQUI"))
                                .right(PlanMatchPattern.aggregation(
                                        PlanMatchPattern.singleGroupingSet("RIGHT_GROUP_BY", "RIGHT_EQUI", "RIGHT_NON_EQUI"),
                                        expectedAvg("RIGHT_AGGR"),
                                        Optional.empty(),
                                        AggregationNode.Step.PARTIAL,
                                        PlanMatchPattern.values("RIGHT_EQUI", "RIGHT_NON_EQUI", "RIGHT_GROUP_BY", "RIGHT_AGGR"))))));
    }

    /**
     * Stats overrides report 10 rows for the join's left child and 20 rows for the join itself.
     * For this expanding join the rule is expected not to fire.
     */
    @Test
    public void testDoesNotPushPartialAggregationForExpandingJoin()
    {
        tester().assertThat(new PushPartialAggregationThroughJoin().pushPartialAggregationThroughJoinWithoutProjection())
                .overrideStats(CHILD_ID.toString(), new PlanNodeStatsEstimate(10.0, ImmutableMap.of()))
                .overrideStats(JOIN_ID.toString(), new PlanNodeStatsEstimate(20.0, ImmutableMap.of()))
                .on(p -> p.aggregation(ab -> ab
                        // Explicit node ids so the stats overrides above apply to these nodes.
                        .source(p.join(
                                JOIN_ID,
                                JoinType.INNER,
                                p.values(CHILD_ID, p.symbol("LEFT_EQUI"), p.symbol("LEFT_NON_EQUI"), p.symbol("LEFT_GROUP_BY"), p.symbol("LEFT_AGGR")),
                                p.values(p.symbol("RIGHT_EQUI"), p.symbol("RIGHT_NON_EQUI")),
                                ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("LEFT_EQUI"), p.symbol("RIGHT_EQUI"))),
                                ImmutableList.of(p.symbol("LEFT_EQUI"), p.symbol("LEFT_NON_EQUI"), p.symbol("LEFT_GROUP_BY"), p.symbol("LEFT_AGGR")),
                                ImmutableList.of(),
                                Optional.of(nonEquiFilter()),
                                Optional.empty(),
                                Optional.empty(),
                                Optional.empty(),
                                ImmutableMap.<DynamicFilterId, Symbol>of()))
                        .addAggregation(p.symbol("AVG", DoubleType.DOUBLE), PlanBuilder.aggregation("avg", ImmutableList.of(bigintReference("LEFT_AGGR"))), AVG_TYPES)
                        .singleGroupingSet(p.symbol("LEFT_GROUP_BY"), p.symbol("LEFT_EQUI"), p.symbol("LEFT_NON_EQUI"))
                        .step(AggregationNode.Step.PARTIAL)))
                .doesNotFire();
    }

    /**
     * The aggregation groups by (LEFT_GROUP_BY, LEFT_EQUI) only. The join filter also reads
     * LEFT_NON_EQUI, so presumably a pushed aggregation would need a larger grouping set. The rule
     * is expected not to fire, both without and with a projection between aggregation and join.
     */
    @Test
    public void testDoesNotPushPartialAggregationIfPushedGroupingSetIsLarger()
    {
        tester().assertThat(new PushPartialAggregationThroughJoin().pushPartialAggregationThroughJoinWithoutProjection())
                .on(p -> p.aggregation(ab -> ab
                        .source(joinOutputtingLeftSymbols(p))
                        .addAggregation(p.symbol("AVG", DoubleType.DOUBLE), PlanBuilder.aggregation("avg", ImmutableList.of(bigintReference("LEFT_AGGR"))), AVG_TYPES)
                        .singleGroupingSet(p.symbol("LEFT_GROUP_BY"), p.symbol("LEFT_EQUI"))
                        .step(AggregationNode.Step.PARTIAL)))
                .doesNotFire();

        tester().assertThat(new PushPartialAggregationThroughJoin().pushPartialAggregationThroughJoinWithProjection())
                .on(p -> p.aggregation(ab -> ab
                        .source(projectDoubledLeftAggr(p))
                        .addAggregation(p.symbol("AVG", DoubleType.DOUBLE), PlanBuilder.aggregation("avg", ImmutableList.of(bigintReference("LEFT_AGGR_PRJ"))), AVG_TYPES)
                        .singleGroupingSet(p.symbol("LEFT_GROUP_BY"), p.symbol("LEFT_EQUI"))
                        .step(AggregationNode.Step.PARTIAL)))
                .doesNotFire();
    }

    /**
     * The aggregation groups only by the right-side DATE_DIM_YEAR. The expected plan has a PARTIAL
     * avg grouped by FACT_DATE_ID in the join's left child.
     *
     * <p>NOTE(review): despite the "DoesNotPush" name, this case expects the rule to fire
     * ({@code matches}, not {@code doesNotFire}). Confirm the intent and consider renaming.
     */
    @Test
    public void testDoesNotPushPartialAggregationIfPushedGroupingSetIsSame()
    {
        tester().assertThat(new PushPartialAggregationThroughJoin().pushPartialAggregationThroughJoinWithoutProjection())
                .on(p -> p.aggregation(ab -> ab
                        .source(joinFactWithDateDimension(p, p.values(p.symbol("FACT_DATE_ID"), p.symbol("AMOUNT"))))
                        .addAggregation(p.symbol("AVG", DoubleType.DOUBLE), PlanBuilder.aggregation("avg", ImmutableList.of(bigintReference("AMOUNT"))), AVG_TYPES)
                        .singleGroupingSet(p.symbol("DATE_DIM_YEAR"))
                        .step(AggregationNode.Step.PARTIAL)))
                .matches(PlanMatchPattern.project(
                        ImmutableMap.of(
                                "DATE_DIM_YEAR", PlanMatchPattern.expression(bigintReference("DATE_DIM_YEAR")),
                                "AVG", PlanMatchPattern.expression(doubleReference("AVG"))),
                        expectedFactJoinWithPartialAggregation()));
    }

    /**
     * The left child (CHILD_ID) is overridden to 10 rows. The stats for FACT_DATE_ID end with
     * 10.0, presumably its distinct-value count, so grouping by it below the join would not
     * shrink the input. The rule is expected not to fire.
     */
    @Test
    public void testDoesNotPushPartialAggregationIfGroupingSymbolHasBigNDV()
    {
        tester().assertThat(new PushPartialAggregationThroughJoin().pushPartialAggregationThroughJoinWithoutProjection())
                .overrideStats(CHILD_ID.toString(), new PlanNodeStatsEstimate(10.0, ImmutableMap.of(
                        new Symbol(BigintType.BIGINT, "FACT_DATE_ID"),
                        new SymbolStatsEstimate(Double.NaN, Double.NaN, 0.0, Double.NaN, 10.0))))
                .on(p -> p.aggregation(ab -> ab
                        .source(joinFactWithDateDimension(p, p.values(CHILD_ID, p.symbol("FACT_DATE_ID"), p.symbol("AMOUNT"))))
                        .addAggregation(p.symbol("AVG", DoubleType.DOUBLE), PlanBuilder.aggregation("avg", ImmutableList.of(bigintReference("AMOUNT"))), AVG_TYPES)
                        .singleGroupingSet(p.symbol("DATE_DIM_YEAR"))
                        .step(AggregationNode.Step.PARTIAL)))
                .doesNotFire();
    }

    /**
     * Both cases set {@code exchangeInputAggregation(true)}.
     * <ul>
     * <li>Grouping by DATE_DIM_YEAR: an INTERMEDIATE avg is expected above the projection.</li>
     * <li>Grouping by FACT_DATE_ID (the pushed grouping set): only the projection over the
     * pushed-down PARTIAL aggregation is expected.</li>
     * </ul>
     */
    @Test
    public void testKeepsIntermediateAggregation()
    {
        tester().assertThat(new PushPartialAggregationThroughJoin().pushPartialAggregationThroughJoinWithoutProjection())
                .on(p -> p.aggregation(ab -> ab
                        .source(joinFactWithDateDimension(p, p.values(p.symbol("FACT_DATE_ID"), p.symbol("AMOUNT"))))
                        .addAggregation(p.symbol("AVG", DoubleType.DOUBLE), PlanBuilder.aggregation("avg", ImmutableList.of(bigintReference("AMOUNT"))), AVG_TYPES)
                        .singleGroupingSet(p.symbol("DATE_DIM_YEAR"))
                        .step(AggregationNode.Step.PARTIAL)
                        .exchangeInputAggregation(true)))
                .matches(PlanMatchPattern.aggregation(
                        PlanMatchPattern.singleGroupingSet("DATE_DIM_YEAR"),
                        expectedAvg("AVG"),
                        Optional.empty(),
                        AggregationNode.Step.INTERMEDIATE,
                        PlanMatchPattern.project(
                                ImmutableMap.of(
                                        "DATE_DIM_YEAR", PlanMatchPattern.expression(bigintReference("DATE_DIM_YEAR")),
                                        "AVG", PlanMatchPattern.expression(doubleReference("AVG"))),
                                expectedFactJoinWithPartialAggregation())));

        tester().assertThat(new PushPartialAggregationThroughJoin().pushPartialAggregationThroughJoinWithoutProjection())
                .on(p -> p.aggregation(ab -> ab
                        .source(joinFactWithDateDimension(p, p.values(p.symbol("FACT_DATE_ID"), p.symbol("AMOUNT"))))
                        .addAggregation(p.symbol("AVG", DoubleType.DOUBLE), PlanBuilder.aggregation("avg", ImmutableList.of(bigintReference("AMOUNT"))), AVG_TYPES)
                        .singleGroupingSet(p.symbol("FACT_DATE_ID"))
                        .step(AggregationNode.Step.PARTIAL)
                        .exchangeInputAggregation(true)))
                .matches(PlanMatchPattern.project(
                        ImmutableMap.of(
                                "FACT_DATE_ID", PlanMatchPattern.expression(bigintReference("FACT_DATE_ID")),
                                "AVG", PlanMatchPattern.expression(doubleReference("AVG"))),
                        expectedFactJoinWithPartialAggregation()));
    }

    /**
     * A projection computing LEFT_AGGR_PRJ := LEFT_AGGR + LEFT_AGGR sits between the aggregation
     * and the join. The projection is expected to move below the join, together with the partial
     * aggregation, into the join's left child.
     */
    @Test
    public void testPushesPartialAggregationThroughJoinWithProjection()
    {
        tester().assertThat(new PushPartialAggregationThroughJoin().pushPartialAggregationThroughJoinWithProjection())
                .on(p -> p.aggregation(ab -> ab
                        .source(projectDoubledLeftAggr(p))
                        .addAggregation(p.symbol("AVG", DoubleType.DOUBLE), PlanBuilder.aggregation("avg", ImmutableList.of(bigintReference("LEFT_AGGR_PRJ"))), AVG_TYPES)
                        .singleGroupingSet(p.symbol("LEFT_GROUP_BY"), p.symbol("LEFT_EQUI"), p.symbol("LEFT_NON_EQUI"))
                        .step(AggregationNode.Step.PARTIAL)))
                .matches(PlanMatchPattern.project(
                        leftGroupingSetAndAvgOutputs(),
                        PlanMatchPattern.join(JoinType.INNER, builder -> builder
                                .equiCriteria("LEFT_EQUI", "RIGHT_EQUI")
                                .filter(nonEquiFilter())
                                .left(PlanMatchPattern.aggregation(
                                        PlanMatchPattern.singleGroupingSet("LEFT_GROUP_BY", "LEFT_EQUI", "LEFT_NON_EQUI"),
                                        expectedAvg("LEFT_AGGR_PRJ"),
                                        Optional.empty(),
                                        AggregationNode.Step.PARTIAL,
                                        PlanMatchPattern.project(
                                                ImmutableMap.of("LEFT_AGGR_PRJ", PlanMatchPattern.expression(leftAggrPlusLeftAggr())),
                                                PlanMatchPattern.values("LEFT_EQUI", "LEFT_NON_EQUI", "LEFT_GROUP_BY", "LEFT_AGGR"))))
                                .right(PlanMatchPattern.project(PlanMatchPattern.values("RIGHT_EQUI", "RIGHT_NON_EQUI"))))));
    }

    /**
     * Builds an INNER join of values(LEFT_EQUI, LEFT_NON_EQUI, LEFT_GROUP_BY, LEFT_AGGR) with
     * values(RIGHT_EQUI, RIGHT_NON_EQUI) on LEFT_EQUI = RIGHT_EQUI, filtered by
     * {@link #nonEquiFilter()}. The join outputs all four left symbols and no right symbols.
     */
    private static PlanNode joinOutputtingLeftSymbols(PlanBuilder p)
    {
        return p.join(
                JoinType.INNER,
                p.values(p.symbol("LEFT_EQUI"), p.symbol("LEFT_NON_EQUI"), p.symbol("LEFT_GROUP_BY"), p.symbol("LEFT_AGGR")),
                p.values(p.symbol("RIGHT_EQUI"), p.symbol("RIGHT_NON_EQUI")),
                ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("LEFT_EQUI"), p.symbol("RIGHT_EQUI"))),
                ImmutableList.of(p.symbol("LEFT_EQUI"), p.symbol("LEFT_NON_EQUI"), p.symbol("LEFT_GROUP_BY"), p.symbol("LEFT_AGGR")),
                ImmutableList.of(),
                Optional.of(nonEquiFilter()));
    }

    /**
     * Builds a projection over {@link #joinOutputtingLeftSymbols(PlanBuilder)}. It computes
     * LEFT_AGGR_PRJ := LEFT_AGGR + LEFT_AGGR and passes LEFT_GROUP_BY, LEFT_EQUI and LEFT_NON_EQUI
     * through unchanged.
     */
    private static PlanNode projectDoubledLeftAggr(PlanBuilder p)
    {
        return p.project(
                Assignments.builder()
                        .put(p.symbol("LEFT_AGGR_PRJ"), leftAggrPlusLeftAggr())
                        .putIdentity(p.symbol("LEFT_GROUP_BY"))
                        .putIdentity(p.symbol("LEFT_EQUI"))
                        .putIdentity(p.symbol("LEFT_NON_EQUI"))
                        .build(),
                joinOutputtingLeftSymbols(p));
    }

    /**
     * Builds an INNER join of {@code factSource} with values(DATE_DIM_DATE_ID, DATE_DIM_YEAR) on
     * FACT_DATE_ID = DATE_DIM_DATE_ID, with no filter. The join outputs FACT_DATE_ID and AMOUNT
     * from the left and DATE_DIM_YEAR from the right. {@code factSource} is expected to produce
     * FACT_DATE_ID and AMOUNT.
     */
    private static PlanNode joinFactWithDateDimension(PlanBuilder p, PlanNode factSource)
    {
        return p.join(
                JoinType.INNER,
                factSource,
                p.values(p.symbol("DATE_DIM_DATE_ID"), p.symbol("DATE_DIM_YEAR")),
                ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("FACT_DATE_ID"), p.symbol("DATE_DIM_DATE_ID"))),
                ImmutableList.of(p.symbol("FACT_DATE_ID"), p.symbol("AMOUNT")),
                ImmutableList.of(p.symbol("DATE_DIM_YEAR")),
                Optional.empty());
    }

    /**
     * Expected shape after pushing into the fact/date join: an INNER join whose left child is a
     * PARTIAL avg(AMOUNT) grouped by FACT_DATE_ID.
     */
    private static PlanMatchPattern expectedFactJoinWithPartialAggregation()
    {
        return PlanMatchPattern.join(JoinType.INNER, builder -> builder
                .equiCriteria("FACT_DATE_ID", "DATE_DIM_DATE_ID")
                .left(PlanMatchPattern.aggregation(
                        PlanMatchPattern.singleGroupingSet("FACT_DATE_ID"),
                        expectedAvg("AMOUNT"),
                        Optional.empty(),
                        AggregationNode.Step.PARTIAL,
                        PlanMatchPattern.values("FACT_DATE_ID", "AMOUNT")))
                .right(PlanMatchPattern.values("DATE_DIM_DATE_ID", "DATE_DIM_YEAR")));
    }

    /** Identity outputs expected in the projection above a join that received a left-side push. */
    private static Map<String, ExpressionMatcher> leftGroupingSetAndAvgOutputs()
    {
        return ImmutableMap.of(
                "LEFT_GROUP_BY", PlanMatchPattern.expression(bigintReference("LEFT_GROUP_BY")),
                "LEFT_EQUI", PlanMatchPattern.expression(bigintReference("LEFT_EQUI")),
                "LEFT_NON_EQUI", PlanMatchPattern.expression(bigintReference("LEFT_NON_EQUI")),
                "AVG", PlanMatchPattern.expression(doubleReference("AVG")));
    }

    /** Expected aggregation map: output "AVG" produced by {@code avg(argument)}. */
    private static Map<Optional<String>, ExpectedValueProvider<AggregationFunction>> expectedAvg(String argument)
    {
        return ImmutableMap.of(Optional.of("AVG"), PlanMatchPattern.aggregationFunction("avg", ImmutableList.of(argument)));
    }

    /** Join filter shared by the LEFT_ and RIGHT_ fixtures: LEFT_NON_EQUI &lt;= RIGHT_NON_EQUI. */
    private static Expression nonEquiFilter()
    {
        return new Comparison(Comparison.Operator.LESS_THAN_OR_EQUAL, bigintReference("LEFT_NON_EQUI"), bigintReference("RIGHT_NON_EQUI"));
    }

    /** LEFT_AGGR + LEFT_AGGR, resolved with the BIGINT addition operator. */
    private static Expression leftAggrPlusLeftAggr()
    {
        return new Call(ADD_BIGINT, ImmutableList.of(bigintReference("LEFT_AGGR"), bigintReference("LEFT_AGGR")));
    }

    /** Reference to a BIGINT symbol, typed as {@link Expression} so lists of it infer cleanly. */
    private static Expression bigintReference(String name)
    {
        return new Reference(BigintType.BIGINT, name);
    }

    /** Reference to a DOUBLE symbol, typed as {@link Expression}. */
    private static Expression doubleReference(String name)
    {
        return new Reference(DoubleType.DOUBLE, name);
    }
}

