/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.spi.Plugin;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.Type;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.assertions.ExpectedValueProvider;
import io.trino.sql.planner.assertions.ExpressionMatcher;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.PushAggregationThroughOuterJoin;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FunctionCall;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.testng.annotations.Test;

public class TestPushAggregationThroughOuterJoin
extends BaseRuleTest {
    public TestPushAggregationThroughOuterJoin() {
        super(new Plugin[0]);
    }

    @Test
    public void testPushesAggregationThroughLeftJoin() {
        this.tester().assertThat((Rule<?>)new PushAggregationThroughOuterJoin()).on(p -> p.aggregation(ab -> ab.source((PlanNode)p.join(JoinNode.Type.LEFT, (PlanNode)p.values((List<Symbol>)ImmutableList.of((Object)p.symbol("COL1")), (List<List<Expression>>)ImmutableList.of(PlanBuilder.expressions("10"))), (PlanNode)p.values(p.symbol("COL2")), (List<JoinNode.EquiJoinClause>)ImmutableList.of((Object)new JoinNode.EquiJoinClause(p.symbol("COL1"), p.symbol("COL2"))), (List<Symbol>)ImmutableList.of((Object)p.symbol("COL1")), (List<Symbol>)ImmutableList.of((Object)p.symbol("COL2")), Optional.empty(), Optional.empty(), Optional.empty())).addAggregation(p.symbol("AVG", (Type)DoubleType.DOUBLE), PlanBuilder.expression("avg(COL2)"), (List<Type>)ImmutableList.of((Object)DoubleType.DOUBLE)).singleGroupingSet(p.symbol("COL1")))).matches(PlanMatchPattern.project((Map<String, ExpressionMatcher>)ImmutableMap.of((Object)"COL1", (Object)PlanMatchPattern.expression("COL1"), (Object)"COALESCE", (Object)PlanMatchPattern.expression("coalesce(AVG, AVG_NULL)")), PlanMatchPattern.join(JoinNode.Type.INNER, builder -> builder.left(PlanMatchPattern.join(JoinNode.Type.LEFT, leftJoinBuilder -> leftJoinBuilder.equiCriteria("COL1", "COL2").left(PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"COL1", (Object)0))).right(PlanMatchPattern.aggregation(PlanMatchPattern.singleGroupingSet("COL2"), (Map<Optional<String>, ExpectedValueProvider<FunctionCall>>)ImmutableMap.of(Optional.of("AVG"), PlanMatchPattern.functionCall("avg", (List<String>)ImmutableList.of((Object)"COL2"))), Optional.empty(), AggregationNode.Step.SINGLE, PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"COL2", (Object)0)))))).right(PlanMatchPattern.aggregation(PlanMatchPattern.globalAggregation(), (Map<Optional<String>, ExpectedValueProvider<FunctionCall>>)ImmutableMap.of(Optional.of("AVG_NULL"), PlanMatchPattern.functionCall("avg", (List<String>)ImmutableList.of((Object)"null_literal"))), Optional.empty(), AggregationNode.Step.SINGLE, PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"null_literal", (Object)0)))))));
    }

    @Test
    public void testPushesAggregationThroughRightJoin() {
        this.tester().assertThat((Rule<?>)new PushAggregationThroughOuterJoin()).on(p -> p.aggregation(ab -> ab.source((PlanNode)p.join(JoinNode.Type.RIGHT, (PlanNode)p.values(p.symbol("COL2")), (PlanNode)p.values((List<Symbol>)ImmutableList.of((Object)p.symbol("COL1")), (List<List<Expression>>)ImmutableList.of(PlanBuilder.expressions("10"))), (List<JoinNode.EquiJoinClause>)ImmutableList.of((Object)new JoinNode.EquiJoinClause(p.symbol("COL2"), p.symbol("COL1"))), (List<Symbol>)ImmutableList.of((Object)p.symbol("COL2")), (List<Symbol>)ImmutableList.of((Object)p.symbol("COL1")), Optional.empty(), Optional.empty(), Optional.empty())).addAggregation(p.symbol("AVG", (Type)DoubleType.DOUBLE), PlanBuilder.expression("avg(COL2)"), (List<Type>)ImmutableList.of((Object)DoubleType.DOUBLE)).singleGroupingSet(p.symbol("COL1")))).matches(PlanMatchPattern.project((Map<String, ExpressionMatcher>)ImmutableMap.of((Object)"COALESCE", (Object)PlanMatchPattern.expression("coalesce(AVG, AVG_NULL)"), (Object)"COL1", (Object)PlanMatchPattern.expression("COL1")), PlanMatchPattern.join(JoinNode.Type.INNER, builder -> builder.left(PlanMatchPattern.join(JoinNode.Type.RIGHT, leftJoinBuilder -> leftJoinBuilder.equiCriteria("COL2", "COL1").left(PlanMatchPattern.aggregation(PlanMatchPattern.singleGroupingSet("COL2"), (Map<Optional<String>, ExpectedValueProvider<FunctionCall>>)ImmutableMap.of(Optional.of("AVG"), PlanMatchPattern.functionCall("avg", (List<String>)ImmutableList.of((Object)"COL2"))), Optional.empty(), AggregationNode.Step.SINGLE, PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"COL2", (Object)0)))).right(PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"COL1", (Object)0))))).right(PlanMatchPattern.aggregation(PlanMatchPattern.globalAggregation(), (Map<Optional<String>, ExpectedValueProvider<FunctionCall>>)ImmutableMap.of(Optional.of("AVG_NULL"), PlanMatchPattern.functionCall("avg", (List<String>)ImmutableList.of((Object)"null_literal"))), Optional.empty(), AggregationNode.Step.SINGLE, PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"null_literal", (Object)0)))))));
    }

    @Test
    public void testPushesAggregationWithMask() {
        this.tester().assertThat((Rule<?>)new PushAggregationThroughOuterJoin()).on(p -> p.aggregation(ab -> ab.source((PlanNode)p.join(JoinNode.Type.LEFT, (PlanNode)p.values((List<Symbol>)ImmutableList.of((Object)p.symbol("COL1")), (List<List<Expression>>)ImmutableList.of(PlanBuilder.expressions("10"))), (PlanNode)p.values(p.symbol("COL2"), p.symbol("MASK")), (List<JoinNode.EquiJoinClause>)ImmutableList.of((Object)new JoinNode.EquiJoinClause(p.symbol("COL1"), p.symbol("COL2"))), (List<Symbol>)ImmutableList.of((Object)p.symbol("COL1")), (List<Symbol>)ImmutableList.of((Object)p.symbol("COL2"), (Object)p.symbol("MASK")), Optional.empty(), Optional.empty(), Optional.empty())).addAggregation(p.symbol("AVG", (Type)DoubleType.DOUBLE), PlanBuilder.expression("avg(COL2)"), (List<Type>)ImmutableList.of((Object)DoubleType.DOUBLE), p.symbol("MASK")).singleGroupingSet(p.symbol("COL1")))).matches(PlanMatchPattern.project((Map<String, ExpressionMatcher>)ImmutableMap.of((Object)"COL1", (Object)PlanMatchPattern.expression("COL1"), (Object)"COALESCE", (Object)PlanMatchPattern.expression("coalesce(AVG, AVG_NULL)")), PlanMatchPattern.join(JoinNode.Type.INNER, builder -> builder.left(PlanMatchPattern.join(JoinNode.Type.LEFT, leftJoinBuilder -> leftJoinBuilder.equiCriteria("COL1", "COL2").left(PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"COL1", (Object)0))).right(PlanMatchPattern.aggregation(PlanMatchPattern.singleGroupingSet("COL2"), (Map<Optional<String>, ExpectedValueProvider<FunctionCall>>)ImmutableMap.of(Optional.of("AVG"), PlanMatchPattern.functionCall("avg", (List<String>)ImmutableList.of((Object)"COL2"))), (List<String>)ImmutableList.of(), (List<String>)ImmutableList.of((Object)"MASK"), Optional.empty(), AggregationNode.Step.SINGLE, PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"COL2", (Object)0, (Object)"MASK", (Object)1)))))).right(PlanMatchPattern.aggregation(PlanMatchPattern.globalAggregation(), (Map<Optional<String>, ExpectedValueProvider<FunctionCall>>)ImmutableMap.of(Optional.of("AVG_NULL"), PlanMatchPattern.functionCall("avg", (List<String>)ImmutableList.of((Object)"null_literal"))), (List<String>)ImmutableList.of(), (List<String>)ImmutableList.of((Object)"MASK_NULL"), Optional.empty(), AggregationNode.Step.SINGLE, PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"null_literal", (Object)0, (Object)"MASK_NULL", (Object)1)))))));
    }

    @Test
    public void testPushCountAllAggregation() {
        this.tester().assertThat((Rule<?>)new PushAggregationThroughOuterJoin()).on(p -> p.aggregation(ab -> ab.source((PlanNode)p.join(JoinNode.Type.LEFT, (PlanNode)p.values((List<Symbol>)ImmutableList.of((Object)p.symbol("COL1")), (List<List<Expression>>)ImmutableList.of(PlanBuilder.expressions("10"))), (PlanNode)p.values(p.symbol("COL2")), (List<JoinNode.EquiJoinClause>)ImmutableList.of((Object)new JoinNode.EquiJoinClause(p.symbol("COL1"), p.symbol("COL2"))), (List<Symbol>)ImmutableList.of((Object)p.symbol("COL1")), (List<Symbol>)ImmutableList.of((Object)p.symbol("COL2")), Optional.empty(), Optional.empty(), Optional.empty())).addAggregation(p.symbol("COUNT"), PlanBuilder.expression("count(*)"), (List<Type>)ImmutableList.of()).singleGroupingSet(p.symbol("COL1")))).matches(PlanMatchPattern.project((Map<String, ExpressionMatcher>)ImmutableMap.of((Object)"COL1", (Object)PlanMatchPattern.expression("COL1"), (Object)"COALESCE", (Object)PlanMatchPattern.expression("coalesce(COUNT, COUNT_NULL)")), PlanMatchPattern.join(JoinNode.Type.INNER, builder -> builder.left(PlanMatchPattern.join(JoinNode.Type.LEFT, leftJoinBuilder -> leftJoinBuilder.equiCriteria("COL1", "COL2").left(PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"COL1", (Object)0))).right(PlanMatchPattern.aggregation(PlanMatchPattern.singleGroupingSet("COL2"), (Map<Optional<String>, ExpectedValueProvider<FunctionCall>>)ImmutableMap.of(Optional.of("COUNT"), PlanMatchPattern.functionCall("count", (List<String>)ImmutableList.of())), Optional.empty(), AggregationNode.Step.SINGLE, PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"COL2", (Object)0)))))).right(PlanMatchPattern.aggregation(PlanMatchPattern.globalAggregation(), (Map<Optional<String>, ExpectedValueProvider<FunctionCall>>)ImmutableMap.of(Optional.of("COUNT_NULL"), PlanMatchPattern.functionCall("count", (List<String>)ImmutableList.of())), Optional.empty(), AggregationNode.Step.SINGLE, PlanMatchPattern.values((Map<String, Integer>)ImmutableMap.of((Object)"null_literal", (Object)0)))))));
    }

    @Test
    public void testDoesNotFireWhenMultipleGroupingSets() {
        this.tester().assertThat((Rule<?>)new PushAggregationThroughOuterJoin()).on(p -> p.aggregation(ab -> ab.source((PlanNode)p.join(JoinNode.Type.LEFT, (PlanNode)p.values((List<Symbol>)ImmutableList.of((Object)p.symbol("COL1"), (Object)p.symbol("COL2")), (List<List<Expression>>)ImmutableList.of(PlanBuilder.expressions("1", "2"))), (PlanNode)p.values((List<Symbol>)ImmutableList.of((Object)p.symbol("COL3")), (List<List<Expression>>)ImmutableList.of(PlanBuilder.expressions("1"))), (List<JoinNode.EquiJoinClause>)ImmutableList.of((Object)new JoinNode.EquiJoinClause(p.symbol("COL1"), p.symbol("COL3"))), (List<Symbol>)ImmutableList.of((Object)p.symbol("COL1")), (List<Symbol>)ImmutableList.of((Object)p.symbol("COL3")), Optional.empty(), Optional.empty(), Optional.empty())).addAggregation(p.symbol("COUNT"), PlanBuilder.expression("count(*)"), (List<Type>)ImmutableList.of()).groupingSets(AggregationNode.groupingSets((List)ImmutableList.of((Object)p.symbol("COL1"), (Object)p.symbol("COL2")), (int)2, (Set)ImmutableSet.of())))).doesNotFire();
    }

    @Test
    public void testDoesNotFireWhenNotDistinct() {
        this.tester().assertThat((Rule<?>)new PushAggregationThroughOuterJoin()).on(p -> p.aggregation(ab -> ab.source((PlanNode)p.join(JoinNode.Type.LEFT, (PlanNode)p.values((List<Symbol>)ImmutableList.of((Object)p.symbol("COL1")), (List<List<Expression>>)ImmutableList.of(PlanBuilder.expressions("10"), PlanBuilder.expressions("11"))), (PlanNode)p.values(new Symbol("COL2")), (List<JoinNode.EquiJoinClause>)ImmutableList.of((Object)new JoinNode.EquiJoinClause(new Symbol("COL1"), new Symbol("COL2"))), (List<Symbol>)ImmutableList.of((Object)p.symbol("COL1")), (List<Symbol>)ImmutableList.of((Object)p.symbol("COL2")), Optional.empty(), Optional.empty(), Optional.empty())).addAggregation(new Symbol("AVG"), PlanBuilder.expression("avg(COL2)"), (List<Type>)ImmutableList.of((Object)DoubleType.DOUBLE)).singleGroupingSet(new Symbol("COL1")))).doesNotFire();
        this.tester().assertThat((Rule<?>)new PushAggregationThroughOuterJoin()).on(p -> p.aggregation(ab -> ab.source((PlanNode)p.join(JoinNode.Type.LEFT, (PlanNode)p.project(Assignments.builder().putIdentity(p.symbol("COL1", (Type)BigintType.BIGINT)).build(), (PlanNode)p.aggregation(builder -> builder.singleGroupingSet(p.symbol("COL1"), p.symbol("unused")).source((PlanNode)p.values((List<Symbol>)ImmutableList.of((Object)p.symbol("COL1"), (Object)p.symbol("unused")), (List<List<Expression>>)ImmutableList.of(PlanBuilder.expressions("10", "1"), PlanBuilder.expressions("10", "2")))))), (PlanNode)p.values(p.symbol("COL2")), (List<JoinNode.EquiJoinClause>)ImmutableList.of((Object)new JoinNode.EquiJoinClause(p.symbol("COL1"), p.symbol("COL2"))), (List<Symbol>)ImmutableList.of((Object)p.symbol("COL1")), (List<Symbol>)ImmutableList.of((Object)p.symbol("COL2")), Optional.empty(), Optional.empty(), Optional.empty())).addAggregation(p.symbol("AVG", (Type)DoubleType.DOUBLE), PlanBuilder.expression("avg(COL2)"), (List<Type>)ImmutableList.of((Object)DoubleType.DOUBLE)).singleGroupingSet(p.symbol("COL1")))).doesNotFire();
    }

    @Test
    public void testDoesNotFireWhenGroupingOnInner() {
        this.tester().assertThat((Rule<?>)new PushAggregationThroughOuterJoin()).on(p -> p.aggregation(ab -> ab.source((PlanNode)p.join(JoinNode.Type.LEFT, (PlanNode)p.values((List<Symbol>)ImmutableList.of((Object)p.symbol("COL1")), (List<List<Expression>>)ImmutableList.of(PlanBuilder.expressions("10"))), (PlanNode)p.values(new Symbol("COL2"), new Symbol("COL3")), (List<JoinNode.EquiJoinClause>)ImmutableList.of((Object)new JoinNode.EquiJoinClause(new Symbol("COL1"), new Symbol("COL2"))), (List<Symbol>)ImmutableList.of((Object)p.symbol("COL1")), (List<Symbol>)ImmutableList.of((Object)p.symbol("COL2")), Optional.empty(), Optional.empty(), Optional.empty())).addAggregation(new Symbol("AVG"), PlanBuilder.expression("avg(COL2)"), (List<Type>)ImmutableList.of((Object)DoubleType.DOUBLE)).singleGroupingSet(new Symbol("COL1"), new Symbol("COL3")))).doesNotFire();
    }

    @Test
    public void testDoesNotFireWhenAggregationDoesNotHaveSymbols() {
        this.tester().assertThat((Rule<?>)new PushAggregationThroughOuterJoin()).on(p -> p.aggregation(ab -> ab.source((PlanNode)p.join(JoinNode.Type.LEFT, (PlanNode)p.values((List<Symbol>)ImmutableList.of((Object)p.symbol("COL1")), (List<List<Expression>>)ImmutableList.of(PlanBuilder.expressions("10"))), (PlanNode)p.values((List<Symbol>)ImmutableList.of((Object)p.symbol("COL2")), (List<List<Expression>>)ImmutableList.of(PlanBuilder.expressions("20"))), (List<JoinNode.EquiJoinClause>)ImmutableList.of((Object)new JoinNode.EquiJoinClause(new Symbol("COL1"), new Symbol("COL2"))), (List<Symbol>)ImmutableList.of((Object)p.symbol("COL1")), (List<Symbol>)ImmutableList.of((Object)p.symbol("COL2")), Optional.empty(), Optional.empty(), Optional.empty())).addAggregation(new Symbol("SUM"), PlanBuilder.expression("sum(COL1)"), (List<Type>)ImmutableList.of((Object)DoubleType.DOUBLE)).singleGroupingSet(new Symbol("COL1")))).doesNotFire();
    }

    @Test
    public void testDoesNotFireWhenAggregationOnMultipleSymbolsDoesNotHaveSomeSymbols() {
        this.tester().assertThat((Rule<?>)new PushAggregationThroughOuterJoin()).on(p -> p.aggregation(ab -> ab.source((PlanNode)p.join(JoinNode.Type.LEFT, (PlanNode)p.values((List<Symbol>)ImmutableList.of((Object)p.symbol("COL1")), (List<List<Expression>>)ImmutableList.of(PlanBuilder.expressions("10"))), (PlanNode)p.values((List<Symbol>)ImmutableList.of((Object)p.symbol("COL2"), (Object)p.symbol("COL3")), (List<List<Expression>>)ImmutableList.of(PlanBuilder.expressions("20", "30"))), (List<JoinNode.EquiJoinClause>)ImmutableList.of((Object)new JoinNode.EquiJoinClause(new Symbol("COL1"), new Symbol("COL2"))), (List<Symbol>)ImmutableList.of((Object)new Symbol("COL1")), (List<Symbol>)ImmutableList.of((Object)new Symbol("COL2")), Optional.empty(), Optional.empty(), Optional.empty())).addAggregation(new Symbol("MIN_BY"), PlanBuilder.expression("min_by(COL2, COL1)"), (List<Type>)ImmutableList.of((Object)DoubleType.DOUBLE, (Object)DoubleType.DOUBLE)).singleGroupingSet(new Symbol("COL1")))).doesNotFire();
        this.tester().assertThat((Rule<?>)new PushAggregationThroughOuterJoin()).on(p -> p.aggregation(ab -> ab.source((PlanNode)p.join(JoinNode.Type.LEFT, (PlanNode)p.values((List<Symbol>)ImmutableList.of((Object)p.symbol("COL1")), (List<List<Expression>>)ImmutableList.of(PlanBuilder.expressions("10"))), (PlanNode)p.values((List<Symbol>)ImmutableList.of((Object)p.symbol("COL2"), (Object)p.symbol("COL3")), (List<List<Expression>>)ImmutableList.of(PlanBuilder.expressions("20", "30"))), (List<JoinNode.EquiJoinClause>)ImmutableList.of((Object)new JoinNode.EquiJoinClause(new Symbol("COL1"), new Symbol("COL2"))), (List<Symbol>)ImmutableList.of((Object)new Symbol("COL1")), (List<Symbol>)ImmutableList.of((Object)new Symbol("COL2")), Optional.empty(), Optional.empty(), Optional.empty())).addAggregation(new Symbol("SUM"), PlanBuilder.expression("sum(COL2)"), (List<Type>)ImmutableList.of((Object)DoubleType.DOUBLE)).addAggregation(new Symbol("MIN_BY"), PlanBuilder.expression("min_by(COL2, COL3)"), (List<Type>)ImmutableList.of((Object)DoubleType.DOUBLE, (Object)DoubleType.DOUBLE)).addAggregation(new Symbol("MAX_BY"), PlanBuilder.expression("max_by(COL2, COL1)"), (List<Type>)ImmutableList.of((Object)DoubleType.DOUBLE, (Object)DoubleType.DOUBLE)).singleGroupingSet(new Symbol("COL1")))).doesNotFire();
    }
}

