/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.optimizations;

import com.google.common.collect.Iterables;
import io.trino.sql.planner.LogicalPlanner;
import io.trino.sql.planner.Plan;
import io.trino.sql.planner.assertions.BasePlanTest;
import io.trino.sql.planner.optimizations.PlanNodeSearcher;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.TopNNode;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.assertj.core.api.AbstractBooleanAssert;
import org.assertj.core.api.AbstractIntegerAssert;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Shared plan-shape tests for queries containing {@code UNION ALL}.
 *
 * <p>Each test plans a SQL statement up to {@link LogicalPlanner.Stage#OPTIMIZED_AND_VALIDATED}
 * and inspects the resulting plan tree for the number, kind and placement of remote
 * {@link ExchangeNode}s. Concrete subclasses provide the session properties, which are
 * forwarded unchanged to the {@link BasePlanTest} constructor.
 */
abstract class BaseTestUnion
        extends BasePlanTest {
    /**
     * @param sessionProperties session properties passed through to {@link BasePlanTest}
     */
    protected BaseTestUnion(Map<String, String> sessionProperties) {
        super(sessionProperties);
    }

    /** A plain two-branch UNION ALL must plan with exactly one remote exchange, of type GATHER. */
    @Test
    public void testSimpleUnion() {
        Plan plan = planOptimized("SELECT suppkey FROM supplier UNION ALL SELECT nationkey FROM nation");

        assertSingleRemoteGatherExchange(plan);
        assertPlanIsFullyDistributed(plan);
    }

    /**
     * UNION ALL below ORDER BY ... LIMIT must plan with one remote GATHER exchange and
     * exactly two partial TopN nodes.
     */
    @Test
    public void testUnionUnderTopN() {
        Plan plan = planOptimized("SELECT * FROM (   SELECT regionkey FROM nation    UNION ALL    SELECT nationkey FROM nation) t(a) ORDER BY a LIMIT 1");

        assertSingleRemoteGatherExchange(plan);

        int partialTopNCount = PlanNodeSearcher.searchFrom(plan.getRoot())
                .where(BaseTestUnion::isPartialTopN)
                .count();
        Assertions.assertThat(partialTopNCount)
                .describedAs("There should be exactly two partial TopN nodes")
                .isEqualTo(2);

        assertPlanIsFullyDistributed(plan);
    }

    /**
     * A union whose first branch is a grouped aggregation and whose second branch is a nested
     * union must plan with two remote exchanges: a GATHER followed by a REPARTITION, in the
     * order the searcher reports them.
     */
    @Test
    public void testUnionOverSingleNodeAggregationAndUnion() {
        Plan plan = planOptimized("SELECT count(*) FROM (SELECT 1 FROM nation GROUP BY regionkey UNION ALL (   SELECT 1 FROM nation    UNION ALL    SELECT 1 FROM nation))");

        List<PlanNode> remotes = findRemoteExchanges(plan);
        Assertions.assertThat(remotes.size())
                .describedAs("There should be exactly two RemoteExchanges")
                .isEqualTo(2);
        Assertions.assertThat(((ExchangeNode) remotes.get(0)).getType()).isEqualTo(ExchangeNode.Type.GATHER);
        Assertions.assertThat(((ExchangeNode) remotes.get(1)).getType()).isEqualTo(ExchangeNode.Type.REPARTITION);
    }

    /** Aggregating over a union must not place more than one aggregation under any remote exchange source. */
    @Test
    public void testPartialAggregationsWithUnion() {
        Plan plan = planOptimized("SELECT orderstatus, sum(orderkey) FROM (SELECT orderkey, orderstatus FROM orders UNION ALL SELECT orderkey, orderstatus FROM orders) x GROUP BY (orderstatus)");

        assertAtMostOneAggregationBetweenRemoteExchanges(plan);
        assertPlanIsFullyDistributed(plan);
    }

    /** Same check as {@link #testPartialAggregationsWithUnion()}, but grouping with ROLLUP. */
    @Test
    public void testPartialRollupAggregationsWithUnion() {
        Plan plan = planOptimized("SELECT orderstatus, sum(orderkey) FROM (SELECT orderkey, orderstatus FROM orders UNION ALL SELECT orderkey, orderstatus FROM orders) x GROUP BY ROLLUP (orderstatus)");

        assertAtMostOneAggregationBetweenRemoteExchanges(plan);
        assertPlanIsFullyDistributed(plan);
    }

    /** Aggregating over a union of a table and a VALUES list must keep at most one aggregation per exchange source. */
    @Test
    public void testAggregationWithUnionAndValues() {
        Plan plan = planOptimized("SELECT regionkey, count(*) FROM (SELECT regionkey FROM nation UNION ALL SELECT * FROM (VALUES 2, 100) t(regionkey)) GROUP BY regionkey");

        assertAtMostOneAggregationBetweenRemoteExchanges(plan);
    }

    /** A union used as one side of a join must still satisfy {@link #assertPlanIsFullyDistributed(Plan)}. */
    @Test
    public void testUnionOnProbeSide() {
        Plan plan = planOptimized("SELECT * FROM (SELECT * FROM nation UNION ALL SELECT * from nation) n, region r WHERE n.regionkey=r.regionkey");

        assertPlanIsFullyDistributed(plan);
    }

    /**
     * Plans {@code sql} up to the OPTIMIZED_AND_VALIDATED stage.
     *
     * <p>NOTE(review): the trailing {@code false} is forwarded to {@code BasePlanTest.plan};
     * it is presumably the "force single node" flag — confirm against BasePlanTest.
     */
    private Plan planOptimized(String sql) {
        return plan(sql, LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED, false);
    }

    /** Returns every node in {@code plan} for which {@link #isRemoteExchange(PlanNode)} holds. */
    private static List<PlanNode> findRemoteExchanges(Plan plan) {
        return PlanNodeSearcher.searchFrom(plan.getRoot())
                .where(BaseTestUnion::isRemoteExchange)
                .findAll();
    }

    /** Asserts that {@code plan} contains exactly one remote exchange and that its type is GATHER. */
    private static void assertSingleRemoteGatherExchange(Plan plan) {
        List<PlanNode> remotes = findRemoteExchanges(plan);
        Assertions.assertThat(remotes.size())
                .describedAs("There should be exactly one RemoteExchange")
                .isEqualTo(1);
        ExchangeNode exchange = (ExchangeNode) Iterables.getOnlyElement(remotes);
        Assertions.assertThat(exchange.getType()).isEqualTo(ExchangeNode.Type.GATHER);
    }

    /**
     * Verifies the distribution shape of {@code plan}.
     *
     * <p>If the plan has no remote GATHER exchange there is nothing to check. Otherwise the
     * nodes found by a search that recurses only through nodes that are not remote GATHER
     * exchanges must contain none for which {@link #shouldBeDistributed(PlanNode)} holds,
     * and the plan must contain exactly one remote GATHER exchange.
     */
    private void assertPlanIsFullyDistributed(Plan plan) {
        int numberOfGathers = PlanNodeSearcher.searchFrom(plan.getRoot())
                .where(BaseTestUnion::isRemoteGatheringExchange)
                .findAll()
                .size();
        if (numberOfGathers == 0) {
            // No remote gather at all: nothing to validate.
            return;
        }

        List<PlanNode> nodesAboveGather = PlanNodeSearcher.searchFrom(plan.getRoot())
                .recurseOnlyWhen(BaseTestUnion::isNotRemoteGatheringExchange)
                .findAll();
        Assertions.assertThat(nodesAboveGather.stream().noneMatch(BaseTestUnion::shouldBeDistributed))
                .describedAs("There is a node that should be distributed between output and first REMOTE GATHER ExchangeNode")
                .isTrue();

        Assertions.assertThat(numberOfGathers)
                .describedAs("Only a single REMOTE GATHER was expected")
                .isEqualTo(1);
    }

    /** Returns true for joins, aggregations and partial TopN nodes; false for everything else. */
    private static boolean shouldBeDistributed(PlanNode planNode) {
        return planNode instanceof JoinNode
                || planNode instanceof AggregationNode
                || isPartialTopN(planNode);
    }

    /** Returns true when {@code planNode} is a {@link TopNNode} whose step is PARTIAL. */
    private static boolean isPartialTopN(PlanNode planNode) {
        return planNode instanceof TopNNode && ((TopNNode) planNode).getStep() == TopNNode.Step.PARTIAL;
    }

    /**
     * For every source of every remote exchange in {@code plan}, asserts that a search from
     * that source, recursing only through nodes that are not remote exchanges, finds no more
     * than one {@link AggregationNode}.
     */
    private static void assertAtMostOneAggregationBetweenRemoteExchanges(Plan plan) {
        List<PlanNode> fragments = findRemoteExchanges(plan).stream()
                .flatMap(exchangeNode -> exchangeNode.getSources().stream())
                .collect(Collectors.toList());

        for (PlanNode fragment : fragments) {
            List<PlanNode> aggregations = PlanNodeSearcher.searchFrom(fragment)
                    .where(AggregationNode.class::isInstance)
                    .recurseOnlyWhen(BaseTestUnion::isNotRemoteExchange)
                    .findAll();
            Assertions.assertThat(aggregations.size() > 1)
                    .describedAs("More than a single AggregationNode between remote exchanges")
                    .isFalse();
        }
    }

    /** Negation of {@link #isRemoteGatheringExchange(PlanNode)}, usable as a method reference. */
    private static boolean isNotRemoteGatheringExchange(PlanNode planNode) {
        return !isRemoteGatheringExchange(planNode);
    }

    /** Returns true when {@code planNode} is a remote exchange whose type is GATHER. */
    private static boolean isRemoteGatheringExchange(PlanNode planNode) {
        return isRemoteExchange(planNode) && ((ExchangeNode) planNode).getType() == ExchangeNode.Type.GATHER;
    }

    /** Negation of {@link #isRemoteExchange(PlanNode)}, usable as a method reference. */
    private static boolean isNotRemoteExchange(PlanNode planNode) {
        return !isRemoteExchange(planNode);
    }

    /** Returns true when {@code planNode} is an {@link ExchangeNode} whose scope is REMOTE. */
    private static boolean isRemoteExchange(PlanNode planNode) {
        return planNode instanceof ExchangeNode && ((ExchangeNode) planNode).getScope() == ExchangeNode.Scope.REMOTE;
    }
}

