/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.primitives.ImmutableLongArray;
import io.trino.Session;
import io.trino.execution.scheduler.faulttolerant.OutputStatsEstimator;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.Type;
import io.trino.sql.planner.RuleStatsRecorder;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.assertions.BasePlanTest;
import io.trino.sql.planner.assertions.ExpectedValueProvider;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.assertions.SubPlanMatcher;
import io.trino.sql.planner.iterative.IterativeOptimizer;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.optimizations.AdaptivePlanOptimizer;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.JoinType;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanFragmentId;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.TableScanNode;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.junit.jupiter.api.Test;

public class TestAdaptivePlanner
extends BasePlanTest {
    /**
     * Plans a supplier/nation join with {@code join_distribution_type=PARTITIONED}, then runs the
     * adaptive planner with an {@link IterativeOptimizer} that carries only
     * {@link TestJoinOrderSwitchRule}, supplying runtime stats for fragments 1 and 2.
     *
     * <p>The expected sub-plan has root fragment 3 whose plan contains an adaptive-plan pattern
     * built from two INNER join patterns with opposite child order (remote source 1 / remote
     * source 2), and two child fragments (2 and 1) that are table scans.
     */
    @Test
    public void testJoinOrderSwitchRule() {
        Session session = Session.builder((Session)getPlanTester().getDefaultSession())
                .setSystemProperty("join_distribution_type", "PARTITIONED")
                .build();

        SubPlanMatcher matcher = SubPlanMatcher.builder()
                .fragmentMatcher(fragment -> fragment
                        .fragmentId(3)
                        .planPattern(PlanMatchPattern.any(PlanMatchPattern.adaptivePlan(
                                // First join pattern: fragment 1 on the left, fragment 2 on the right.
                                PlanMatchPattern.join(JoinType.INNER, joinSpec -> joinSpec
                                        .equiCriteria((List<ExpectedValueProvider<JoinNode.EquiJoinClause>>)ImmutableList.of(symbolAliases -> new JoinNode.EquiJoinClause(new Symbol((Type)BigintType.BIGINT, "nationkey"), new Symbol((Type)BigintType.BIGINT, "nationkey_1"))))
                                        .left(PlanMatchPattern.remoteSource((List<PlanFragmentId>)ImmutableList.of((Object)new PlanFragmentId("1"))))
                                        .right(PlanMatchPattern.any(PlanMatchPattern.remoteSource((List<PlanFragmentId>)ImmutableList.of((Object)new PlanFragmentId("2")))))),
                                // Second join pattern: same sources with the sides (and clause symbols) swapped.
                                PlanMatchPattern.join(JoinType.INNER, joinSpec -> joinSpec
                                        .equiCriteria((List<ExpectedValueProvider<JoinNode.EquiJoinClause>>)ImmutableList.of(symbolAliases -> new JoinNode.EquiJoinClause(new Symbol((Type)BigintType.BIGINT, "nationkey_1"), new Symbol((Type)BigintType.BIGINT, "nationkey"))))
                                        .right(PlanMatchPattern.remoteSource((List<PlanFragmentId>)ImmutableList.of((Object)new PlanFragmentId("1"))))
                                        .left(PlanMatchPattern.any(PlanMatchPattern.remoteSource((List<PlanFragmentId>)ImmutableList.of((Object)new PlanFragmentId("2"))))))))))
                .children(
                        child -> child.fragmentMatcher(fragment -> fragment
                                .fragmentId(2)
                                .planPattern(PlanMatchPattern.node(TableScanNode.class, new PlanMatchPattern[0]))),
                        child -> child.fragmentMatcher(fragment -> fragment
                                .fragmentId(1)
                                .planPattern(PlanMatchPattern.any(PlanMatchPattern.node(TableScanNode.class, new PlanMatchPattern[0])))))
                .build();

        // Single optimizer whose only rule is the join-flipping test rule.
        List<AdaptivePlanOptimizer> optimizers = (List<AdaptivePlanOptimizer>)ImmutableList.of((Object)new IterativeOptimizer(
                getPlanTester().getPlannerContext(),
                new RuleStatsRecorder(),
                getPlanTester().getStatsCalculator(),
                getPlanTester().getCostCalculator(),
                (Set)ImmutableSet.builder().add((Object)new TestJoinOrderSwitchRule()).build()));

        // Runtime stats keyed by fragment id.
        Map<PlanFragmentId, OutputStatsEstimator.OutputStatsEstimateResult> runtimeStats = (Map<PlanFragmentId, OutputStatsEstimator.OutputStatsEstimateResult>)ImmutableMap.of(
                (Object)new PlanFragmentId("1"), (Object)createRuntimeStats(ImmutableLongArray.of((long)10000L, (long)10000L, (long)10000L), 10000L),
                (Object)new PlanFragmentId("2"), (Object)createRuntimeStats(ImmutableLongArray.of((long)200L, (long)2000L, (long)1000L), 500L));

        assertAdaptivePlan(
                "SELECT n.name FROM supplier AS s JOIN nation AS n on s.nationkey = n.nationkey",
                session,
                optimizers,
                runtimeStats,
                matcher,
                true);
    }

    /**
     * Joins {@code supplier} against an aggregated CTE over {@code nation} with
     * {@code join_distribution_type=PARTITIONED}, supplying runtime stats for fragments 3 and 2.
     *
     * <p>The expected sub-plan has fragments 5 and 6 at the top (fragment 6 holding an
     * aggregation over an adaptive-plan pattern of two oppositely ordered INNER joins), while the
     * fragments below the join are expected with ids 3, 4 and 2.
     */
    @Test
    public void testNoChangeInFragmentIdsForUnchangedSubPlans() {
        Session session = Session.builder((Session)getPlanTester().getDefaultSession())
                .setSystemProperty("join_distribution_type", "PARTITIONED")
                .build();

        SubPlanMatcher matcher = SubPlanMatcher.builder()
                .fragmentMatcher(fragment -> fragment
                        .fragmentId(5)
                        .planPattern(PlanMatchPattern.output(PlanMatchPattern.node(AggregationNode.class, PlanMatchPattern.exchange(PlanMatchPattern.remoteSource((List<PlanFragmentId>)ImmutableList.of((Object)new PlanFragmentId("6"))))))))
                .children(child -> child
                        .fragmentMatcher(fragment -> fragment
                                .fragmentId(6)
                                .planPattern(PlanMatchPattern.node(AggregationNode.class, PlanMatchPattern.adaptivePlan(
                                        // First join pattern: fragment 2 on the left, fragment 3 on the right.
                                        PlanMatchPattern.join(JoinType.INNER, joinSpec -> joinSpec
                                                .equiCriteria((List<ExpectedValueProvider<JoinNode.EquiJoinClause>>)ImmutableList.of(symbolAliases -> new JoinNode.EquiJoinClause(new Symbol((Type)BigintType.BIGINT, "nationkey"), new Symbol((Type)BigintType.BIGINT, "count"))))
                                                .left(PlanMatchPattern.remoteSource((List<PlanFragmentId>)ImmutableList.of((Object)new PlanFragmentId("2"))))
                                                .right(PlanMatchPattern.any(PlanMatchPattern.remoteSource((List<PlanFragmentId>)ImmutableList.of((Object)new PlanFragmentId("3")))))),
                                        // Second join pattern: same sources with the sides swapped.
                                        PlanMatchPattern.join(JoinType.INNER, joinSpec -> joinSpec
                                                .equiCriteria((List<ExpectedValueProvider<JoinNode.EquiJoinClause>>)ImmutableList.of(symbolAliases -> new JoinNode.EquiJoinClause(new Symbol((Type)BigintType.BIGINT, "count"), new Symbol((Type)BigintType.BIGINT, "nationkey"))))
                                                .right(PlanMatchPattern.remoteSource((List<PlanFragmentId>)ImmutableList.of((Object)new PlanFragmentId("2"))))
                                                .left(PlanMatchPattern.any(PlanMatchPattern.remoteSource((List<PlanFragmentId>)ImmutableList.of((Object)new PlanFragmentId("3"))))))))))
                        .children(
                                grandchild -> grandchild
                                        .fragmentMatcher(fragment -> fragment
                                                .fragmentId(3)
                                                .planPattern(PlanMatchPattern.node(AggregationNode.class, PlanMatchPattern.exchange(PlanMatchPattern.remoteSource((List<PlanFragmentId>)ImmutableList.of((Object)new PlanFragmentId("4")))))))
                                        .children(leaf -> leaf.fragmentMatcher(fragment -> fragment
                                                .fragmentId(4)
                                                .planPattern(PlanMatchPattern.node(AggregationNode.class, PlanMatchPattern.node(TableScanNode.class, new PlanMatchPattern[0]))))),
                                grandchild -> grandchild.fragmentMatcher(fragment -> fragment
                                        .fragmentId(2)
                                        .planPattern(PlanMatchPattern.any(PlanMatchPattern.node(TableScanNode.class, new PlanMatchPattern[0]))))))
                .build();

        // Same text as a single literal; split per SQL line for readability (compile-time constant).
        String sql = "    WITH t AS (SELECT regionkey, count(*) as some_count FROM nation group by regionkey)\n"
                + "    SELECT max(s.nationkey), sum(t.regionkey)\n"
                + "    FROM supplier AS s\n"
                + "    JOIN t\n"
                + "    ON s.nationkey = t.some_count\n";

        // Single optimizer whose only rule is the join-flipping test rule.
        List<AdaptivePlanOptimizer> optimizers = (List<AdaptivePlanOptimizer>)ImmutableList.of((Object)new IterativeOptimizer(
                getPlanTester().getPlannerContext(),
                new RuleStatsRecorder(),
                getPlanTester().getStatsCalculator(),
                getPlanTester().getCostCalculator(),
                (Set)ImmutableSet.builder().add((Object)new TestJoinOrderSwitchRule()).build()));

        // Runtime stats keyed by fragment id.
        Map<PlanFragmentId, OutputStatsEstimator.OutputStatsEstimateResult> runtimeStats = (Map<PlanFragmentId, OutputStatsEstimator.OutputStatsEstimateResult>)ImmutableMap.of(
                (Object)new PlanFragmentId("3"), (Object)createRuntimeStats(ImmutableLongArray.of((long)10000L, (long)10000L, (long)10000L), 10000L),
                (Object)new PlanFragmentId("2"), (Object)createRuntimeStats(ImmutableLongArray.of((long)200L, (long)2000L, (long)1000L), 500L));

        assertAdaptivePlan(sql, session, optimizers, runtimeStats, matcher, true);
    }

    /**
     * Runs the distinct-aggregation variant of the supplier/CTE join with
     * {@code join_distribution_type=PARTITIONED},
     * {@code push_partial_aggregation_through_join=true} and
     * {@code distinct_aggregations_strategy=SINGLE_STEP}, passing an empty runtime-stats map.
     *
     * <p>The expected plan pattern of fragment 6 is the adaptive-plan pattern itself (two
     * oppositely ordered INNER joins), not wrapped in any other node pattern.
     */
    @Test
    public void testAdaptivePlanNodeAsRootOfFragment() {
        Session session = Session.builder((Session)getPlanTester().getDefaultSession())
                .setSystemProperty("join_distribution_type", "PARTITIONED")
                .setSystemProperty("push_partial_aggregation_through_join", "true")
                .setSystemProperty("distinct_aggregations_strategy", "SINGLE_STEP")
                .build();

        SubPlanMatcher matcher = SubPlanMatcher.builder()
                .fragmentMatcher(fragment -> fragment
                        .fragmentId(5)
                        .planPattern(PlanMatchPattern.output(PlanMatchPattern.node(AggregationNode.class, PlanMatchPattern.exchange(PlanMatchPattern.remoteSource((List<PlanFragmentId>)ImmutableList.of((Object)new PlanFragmentId("6"))))))))
                .children(child -> child
                        .fragmentMatcher(fragment -> fragment
                                .fragmentId(6)
                                .planPattern(PlanMatchPattern.adaptivePlan(
                                        // First join pattern: fragment 2 on the left, fragment 3 on the right.
                                        PlanMatchPattern.join(JoinType.INNER, joinSpec -> joinSpec
                                                .equiCriteria((List<ExpectedValueProvider<JoinNode.EquiJoinClause>>)ImmutableList.of(symbolAliases -> new JoinNode.EquiJoinClause(new Symbol((Type)BigintType.BIGINT, "nationkey"), new Symbol((Type)BigintType.BIGINT, "count"))))
                                                .left(PlanMatchPattern.remoteSource((List<PlanFragmentId>)ImmutableList.of((Object)new PlanFragmentId("2"))))
                                                .right(PlanMatchPattern.any(PlanMatchPattern.remoteSource((List<PlanFragmentId>)ImmutableList.of((Object)new PlanFragmentId("3")))))),
                                        // Second join pattern: same sources with the sides swapped.
                                        PlanMatchPattern.join(JoinType.INNER, joinSpec -> joinSpec
                                                .equiCriteria((List<ExpectedValueProvider<JoinNode.EquiJoinClause>>)ImmutableList.of(symbolAliases -> new JoinNode.EquiJoinClause(new Symbol((Type)BigintType.BIGINT, "count"), new Symbol((Type)BigintType.BIGINT, "nationkey"))))
                                                .right(PlanMatchPattern.remoteSource((List<PlanFragmentId>)ImmutableList.of((Object)new PlanFragmentId("2"))))
                                                .left(PlanMatchPattern.any(PlanMatchPattern.remoteSource((List<PlanFragmentId>)ImmutableList.of((Object)new PlanFragmentId("3")))))))))
                        .children(
                                grandchild -> grandchild
                                        .fragmentMatcher(fragment -> fragment
                                                .fragmentId(3)
                                                .planPattern(PlanMatchPattern.node(AggregationNode.class, PlanMatchPattern.exchange(PlanMatchPattern.remoteSource((List<PlanFragmentId>)ImmutableList.of((Object)new PlanFragmentId("4")))))))
                                        .children(leaf -> leaf.fragmentMatcher(fragment -> fragment
                                                .fragmentId(4)
                                                .planPattern(PlanMatchPattern.node(AggregationNode.class, PlanMatchPattern.node(TableScanNode.class, new PlanMatchPattern[0]))))),
                                grandchild -> grandchild.fragmentMatcher(fragment -> fragment
                                        .fragmentId(2)
                                        .planPattern(PlanMatchPattern.any(PlanMatchPattern.node(TableScanNode.class, new PlanMatchPattern[0]))))))
                .build();

        // Same text as a single literal; split per SQL line for readability (compile-time constant).
        String sql = "    WITH t AS (SELECT regionkey, count(*) as some_count FROM nation group by regionkey)\n"
                + "    SELECT max(distinct s.nationkey), sum(distinct t.regionkey)\n"
                + "    FROM supplier AS s\n"
                + "    JOIN t\n"
                + "    ON s.nationkey = t.some_count\n";

        // Single optimizer whose only rule is the join-flipping test rule.
        List<AdaptivePlanOptimizer> optimizers = (List<AdaptivePlanOptimizer>)ImmutableList.of((Object)new IterativeOptimizer(
                getPlanTester().getPlannerContext(),
                new RuleStatsRecorder(),
                getPlanTester().getStatsCalculator(),
                getPlanTester().getCostCalculator(),
                (Set)ImmutableSet.builder().add((Object)new TestJoinOrderSwitchRule()).build()));

        // No runtime stats are supplied for this scenario.
        Map<PlanFragmentId, OutputStatsEstimator.OutputStatsEstimateResult> runtimeStats = (Map<PlanFragmentId, OutputStatsEstimator.OutputStatsEstimateResult>)ImmutableMap.of();

        assertAdaptivePlan(sql, session, optimizers, runtimeStats, matcher, true);
    }

    /**
     * Plans the supplier/nation join with {@code join_distribution_type=PARTITIONED} and supplies
     * runtime stats for fragments 1, 2 and also the root fragment 0.
     *
     * <p>The expected sub-plan keeps fragment ids 0, 1 and 2, and the root pattern is a plain
     * INNER join (fragment 1 on the left, fragment 2 on the right) with no adaptive-plan pattern.
     */
    @Test
    public void testNoChangeToRootSubPlanIfStatsAreAccurate() {
        Session session = Session.builder((Session)getPlanTester().getDefaultSession())
                .setSystemProperty("join_distribution_type", "PARTITIONED")
                .build();

        SubPlanMatcher matcher = SubPlanMatcher.builder()
                .fragmentMatcher(fragment -> fragment
                        .fragmentId(0)
                        .planPattern(PlanMatchPattern.any(PlanMatchPattern.join(JoinType.INNER, joinSpec -> joinSpec
                                .equiCriteria((List<ExpectedValueProvider<JoinNode.EquiJoinClause>>)ImmutableList.of(symbolAliases -> new JoinNode.EquiJoinClause(new Symbol((Type)BigintType.BIGINT, "nationkey"), new Symbol((Type)BigintType.BIGINT, "nationkey_1"))))
                                .left(PlanMatchPattern.remoteSource((List<PlanFragmentId>)ImmutableList.of((Object)new PlanFragmentId("1"))))
                                .right(PlanMatchPattern.any(PlanMatchPattern.remoteSource((List<PlanFragmentId>)ImmutableList.of((Object)new PlanFragmentId("2")))))))))
                .children(
                        child -> child.fragmentMatcher(fragment -> fragment
                                .fragmentId(1)
                                .planPattern(PlanMatchPattern.any(PlanMatchPattern.node(TableScanNode.class, new PlanMatchPattern[0])))),
                        child -> child.fragmentMatcher(fragment -> fragment
                                .fragmentId(2)
                                .planPattern(PlanMatchPattern.node(TableScanNode.class, new PlanMatchPattern[0]))))
                .build();

        // Single optimizer whose only rule is the join-flipping test rule.
        List<AdaptivePlanOptimizer> optimizers = (List<AdaptivePlanOptimizer>)ImmutableList.of((Object)new IterativeOptimizer(
                getPlanTester().getPlannerContext(),
                new RuleStatsRecorder(),
                getPlanTester().getStatsCalculator(),
                getPlanTester().getCostCalculator(),
                (Set)ImmutableSet.builder().add((Object)new TestJoinOrderSwitchRule()).build()));

        // Runtime stats keyed by fragment id; note that the root fragment "0" is included.
        Map<PlanFragmentId, OutputStatsEstimator.OutputStatsEstimateResult> runtimeStats = (Map<PlanFragmentId, OutputStatsEstimator.OutputStatsEstimateResult>)ImmutableMap.of(
                (Object)new PlanFragmentId("1"), (Object)createRuntimeStats(ImmutableLongArray.of((long)10000L, (long)10000L, (long)10000L), 10000L),
                (Object)new PlanFragmentId("2"), (Object)createRuntimeStats(ImmutableLongArray.of((long)200L, (long)2000L, (long)1000L), 500L),
                (Object)new PlanFragmentId("0"), (Object)createRuntimeStats(ImmutableLongArray.of((long)10000L, (long)10000L, (long)10000L), 10000L));

        assertAdaptivePlan(
                "SELECT n.name FROM supplier AS s JOIN nation AS n on s.nationkey = n.nationkey",
                session,
                optimizers,
                runtimeStats,
                matcher,
                true);
    }

    /**
     * Joins {@code supplier} against an aggregated CTE over {@code nation} with
     * {@code join_distribution_type=PARTITIONED} and supplies runtime stats for fragments 1, 3, 4
     * and 2.
     *
     * <p>The expected sub-plan keeps fragment ids 0 through 4, and fragment 1 is an aggregation
     * over a plain INNER join (fragment 2 on the left, fragment 3 on the right) with no
     * adaptive-plan pattern.
     */
    @Test
    public void testNoChangeToNestedSubPlanIfStatsAreAccurate() {
        Session session = Session.builder((Session)getPlanTester().getDefaultSession())
                .setSystemProperty("join_distribution_type", "PARTITIONED")
                .build();

        SubPlanMatcher matcher = SubPlanMatcher.builder()
                .fragmentMatcher(fragment -> fragment
                        .fragmentId(0)
                        .planPattern(PlanMatchPattern.output(PlanMatchPattern.node(AggregationNode.class, PlanMatchPattern.exchange(PlanMatchPattern.remoteSource((List<PlanFragmentId>)ImmutableList.of((Object)new PlanFragmentId("1"))))))))
                .children(child -> child
                        .fragmentMatcher(fragment -> fragment
                                .fragmentId(1)
                                .planPattern(PlanMatchPattern.node(AggregationNode.class, PlanMatchPattern.join(JoinType.INNER, joinSpec -> joinSpec
                                        .equiCriteria((List<ExpectedValueProvider<JoinNode.EquiJoinClause>>)ImmutableList.of(symbolAliases -> new JoinNode.EquiJoinClause(new Symbol((Type)BigintType.BIGINT, "nationkey"), new Symbol((Type)BigintType.BIGINT, "count"))))
                                        .left(PlanMatchPattern.remoteSource((List<PlanFragmentId>)ImmutableList.of((Object)new PlanFragmentId("2"))))
                                        .right(PlanMatchPattern.any(PlanMatchPattern.remoteSource((List<PlanFragmentId>)ImmutableList.of((Object)new PlanFragmentId("3")))))))))
                        .children(
                                grandchild -> grandchild.fragmentMatcher(fragment -> fragment
                                        .fragmentId(2)
                                        .planPattern(PlanMatchPattern.any(PlanMatchPattern.node(TableScanNode.class, new PlanMatchPattern[0])))),
                                grandchild -> grandchild
                                        .fragmentMatcher(fragment -> fragment
                                                .fragmentId(3)
                                                .planPattern(PlanMatchPattern.node(AggregationNode.class, PlanMatchPattern.exchange(PlanMatchPattern.remoteSource((List<PlanFragmentId>)ImmutableList.of((Object)new PlanFragmentId("4")))))))
                                        .children(leaf -> leaf.fragmentMatcher(fragment -> fragment
                                                .fragmentId(4)
                                                .planPattern(PlanMatchPattern.node(AggregationNode.class, PlanMatchPattern.node(TableScanNode.class, new PlanMatchPattern[0])))))))
                .build();

        // Same text as a single literal; split per SQL line for readability (compile-time constant).
        String sql = "    WITH t AS (SELECT regionkey, count(*) as some_count FROM nation group by regionkey)\n"
                + "    SELECT max(s.nationkey), sum(t.regionkey)\n"
                + "    FROM supplier AS s\n"
                + "    JOIN t\n"
                + "    ON s.nationkey = t.some_count\n";

        // Single optimizer whose only rule is the join-flipping test rule.
        List<AdaptivePlanOptimizer> optimizers = (List<AdaptivePlanOptimizer>)ImmutableList.of((Object)new IterativeOptimizer(
                getPlanTester().getPlannerContext(),
                new RuleStatsRecorder(),
                getPlanTester().getStatsCalculator(),
                getPlanTester().getCostCalculator(),
                (Set)ImmutableSet.builder().add((Object)new TestJoinOrderSwitchRule()).build()));

        // Runtime stats keyed by fragment id; the join's own fragment "1" is included.
        Map<PlanFragmentId, OutputStatsEstimator.OutputStatsEstimateResult> runtimeStats = (Map<PlanFragmentId, OutputStatsEstimator.OutputStatsEstimateResult>)ImmutableMap.of(
                (Object)new PlanFragmentId("1"), (Object)createRuntimeStats(ImmutableLongArray.of((long)10000L, (long)10000L, (long)10000L), 10000L),
                (Object)new PlanFragmentId("3"), (Object)createRuntimeStats(ImmutableLongArray.of((long)10000L, (long)10000L, (long)10000L), 10000L),
                (Object)new PlanFragmentId("4"), (Object)createRuntimeStats(ImmutableLongArray.of((long)10000L, (long)10000L, (long)10000L), 10000L),
                (Object)new PlanFragmentId("2"), (Object)createRuntimeStats(ImmutableLongArray.of((long)200L, (long)2000L, (long)1000L), 500L));

        assertAdaptivePlan(sql, session, optimizers, runtimeStats, matcher, true);
    }

    /**
     * Aggregates over a UNION ALL of two projections of {@code customer} in which the same
     * constant expression is projected under two column names, with
     * {@code join_distribution_type=PARTITIONED} and {@code prefer_partial_aggregation=false},
     * passing an empty runtime-stats map.
     *
     * <p>The expected sub-plan is fragment 0 (output over an aggregation over an exchange reading
     * remote sources 1 and 2) with two child fragments, each a project over a table scan.
     */
    @Test
    public void testWhenSimilarColumnIsProjectedTwice() {
        Session session = Session.builder((Session)getPlanTester().getDefaultSession())
                .setSystemProperty("join_distribution_type", "PARTITIONED")
                .setSystemProperty("prefer_partial_aggregation", "false")
                .build();

        // Same text as a single literal; split per SQL line for readability (compile-time constant).
        String sql = "    SELECT\n"
                + "        sum(sales),\n"
                + "        sum(another_sales),\n"
                + "        sum(acctbal)\n"
                + "    FROM (\n"
                + "    SELECT\n"
                + "        CAST(0 AS DECIMAL(7,2)) \"sales\",\n"
                + "        CAST(0 AS DECIMAL(7,2)) \"another_sales\",\n"
                + "        cast(\"acctbal\" as DECIMAL(7,2)) \"acctbal\"\n"
                + "    FROM customer\n"
                + "    UNION ALL\n"
                + "    SELECT\n"
                + "        cast(\"acctbal\" as DECIMAL(7,2)) \"sales\",\n"
                + "        CAST(0 AS DECIMAL(7,2)) \"another_sales\",\n"
                + "        CAST(0 AS DECIMAL(7,2)) \"acctbal\"\n"
                + "    FROM customer\n"
                + "    ) test_table\n";

        // Single optimizer whose only rule is the join-flipping test rule.
        List<AdaptivePlanOptimizer> optimizers = (List<AdaptivePlanOptimizer>)ImmutableList.of((Object)new IterativeOptimizer(
                getPlanTester().getPlannerContext(),
                new RuleStatsRecorder(),
                getPlanTester().getStatsCalculator(),
                getPlanTester().getCostCalculator(),
                (Set)ImmutableSet.builder().add((Object)new TestJoinOrderSwitchRule()).build()));

        // No runtime stats are supplied for this scenario.
        Map<PlanFragmentId, OutputStatsEstimator.OutputStatsEstimateResult> runtimeStats = (Map<PlanFragmentId, OutputStatsEstimator.OutputStatsEstimateResult>)ImmutableMap.of();

        // Built after the optimizer list and stats map, matching the original argument order.
        SubPlanMatcher matcher = SubPlanMatcher.builder()
                .fragmentMatcher(fragment -> fragment
                        .fragmentId(0)
                        .planPattern(PlanMatchPattern.output(PlanMatchPattern.node(AggregationNode.class, PlanMatchPattern.exchange(PlanMatchPattern.remoteSource((List<PlanFragmentId>)ImmutableList.of((Object)new PlanFragmentId("1"), (Object)new PlanFragmentId("2"))))))))
                .children(
                        child -> child.fragmentMatcher(fragment -> fragment
                                .fragmentId(1)
                                .planPattern(PlanMatchPattern.node(ProjectNode.class, PlanMatchPattern.node(TableScanNode.class, new PlanMatchPattern[0])))),
                        child -> child.fragmentMatcher(fragment -> fragment
                                .fragmentId(2)
                                .planPattern(PlanMatchPattern.node(ProjectNode.class, PlanMatchPattern.node(TableScanNode.class, new PlanMatchPattern[0])))))
                .build();

        assertAdaptivePlan(sql, session, optimizers, runtimeStats, matcher, true);
    }

    /**
     * Builds an {@code OutputStatsEstimateResult} from the given partition sizes and row-count
     * estimate, always passing {@code "FINISHED"} and {@code true} as the last two constructor
     * arguments.
     *
     * @param partitionDataSizes sizes forwarded unchanged as the first constructor argument
     * @param outputRowCountEstimate row-count estimate forwarded unchanged as the second argument
     * @return the newly constructed estimate result
     */
    private OutputStatsEstimator.OutputStatsEstimateResult createRuntimeStats(ImmutableLongArray partitionDataSizes, long outputRowCountEstimate) {
        // NOTE(review): the meaning of the trailing "FINISHED"/true arguments is defined by
        // OutputStatsEstimateResult, which is not visible in this file.
        OutputStatsEstimator.OutputStatsEstimateResult estimate = new OutputStatsEstimator.OutputStatsEstimateResult(
                partitionDataSizes,
                outputRowCountEstimate,
                "FINISHED",
                true);
        return estimate;
    }

    private static class TestJoinOrderSwitchRule
    implements Rule<JoinNode> {
        private static final Pattern<JoinNode> PATTERN = Patterns.join();
        private final Set<PlanNodeId> alreadyVisited = new HashSet<PlanNodeId>();

        private TestJoinOrderSwitchRule() {
        }

        public Pattern<JoinNode> getPattern() {
            return PATTERN;
        }

        public Rule.Result apply(JoinNode node, Captures captures, Rule.Context context) {
            if (this.alreadyVisited.contains(node.getId())) {
                return Rule.Result.empty();
            }
            this.alreadyVisited.add(node.getId());
            return Rule.Result.ofPlanNode((PlanNode)node.flipChildren());
        }
    }
}

