/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.TopNNode;
import com.facebook.presto.sql.Optimizer;
import com.facebook.presto.sql.planner.Plan;
import com.facebook.presto.sql.planner.assertions.BasePlanTest;
import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.google.common.collect.Iterables;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.testng.Assert;
import org.testng.annotations.Test;

/**
 * Plan-shape tests for {@code UNION ALL} queries.
 *
 * <p>Each test plans a query at {@link Optimizer.PlanStage#OPTIMIZED_AND_VALIDATED} and then
 * inspects the resulting plan tree. Depending on the test, it checks one or more of:
 *
 * <ul>
 *   <li>how many remote exchanges the plan contains, and of which type;</li>
 *   <li>how many partial {@link TopNNode}s it contains;</li>
 *   <li>that no fragment between remote exchanges holds more than one {@link AggregationNode};</li>
 *   <li>that no join, aggregation or partial TopN was left above the single remote GATHER
 *       exchange.</li>
 * </ul>
 */
public class TestUnion
extends BasePlanTest {
    public TestUnion() {
    }

    /**
     * Creates the test with extra session properties.
     *
     * @param sessionProperties session properties passed through to {@link BasePlanTest}
     */
    public TestUnion(Map<String, String> sessionProperties) {
        super(sessionProperties);
    }

    @Test
    public void testSimpleUnion() {
        Plan plan = this.plan("SELECT suppkey FROM supplier UNION ALL SELECT nationkey FROM nation", Optimizer.PlanStage.OPTIMIZED_AND_VALIDATED, false);
        List<PlanNode> remotes = findRemoteExchanges(plan);
        Assert.assertEquals(remotes.size(), 1, "There should be exactly one RemoteExchange");
        Assert.assertEquals(((ExchangeNode) Iterables.getOnlyElement(remotes)).getType(), ExchangeNode.Type.GATHER);
        assertPlanIsFullyDistributed(plan);
    }

    @Test
    public void testUnionUnderTopN() {
        Plan plan = this.plan("SELECT * FROM (   SELECT regionkey FROM nation    UNION ALL    SELECT nationkey FROM nation) t(a) ORDER BY a LIMIT 1", Optimizer.PlanStage.OPTIMIZED_AND_VALIDATED, false);
        List<PlanNode> remotes = findRemoteExchanges(plan);
        Assert.assertEquals(remotes.size(), 1, "There should be exactly one RemoteExchange");
        Assert.assertEquals(((ExchangeNode) Iterables.getOnlyElement(remotes)).getType(), ExchangeNode.Type.GATHER);
        int numberOfPartialTopN = PlanNodeSearcher.searchFrom(plan.getRoot())
                .where(TestUnion::isPartialTopN)
                .count();
        Assert.assertEquals(numberOfPartialTopN, 2, "There should be exactly two partial TopN nodes");
        assertPlanIsFullyDistributed(plan);
    }

    @Test
    public void testUnionOverSingleNodeAggregationAndUnion() {
        Plan plan = this.plan("SELECT count(*) FROM (SELECT 1 FROM nation GROUP BY regionkey UNION ALL (   SELECT 1 FROM nation    UNION ALL    SELECT 1 FROM nation))", Optimizer.PlanStage.OPTIMIZED_AND_VALIDATED, false);
        List<PlanNode> remotes = findRemoteExchanges(plan);
        Assert.assertEquals(remotes.size(), 2, "There should be exactly two RemoteExchanges");
        // The order of the list follows the searcher's traversal from the root.
        Assert.assertEquals(((ExchangeNode) remotes.get(0)).getType(), ExchangeNode.Type.GATHER);
        Assert.assertEquals(((ExchangeNode) remotes.get(1)).getType(), ExchangeNode.Type.REPARTITION);
    }

    @Test
    public void testPartialAggregationsWithUnion() {
        Plan plan = this.plan("SELECT orderstatus, sum(orderkey) FROM (SELECT orderkey, orderstatus FROM orders UNION ALL SELECT orderkey, orderstatus FROM orders) x GROUP BY (orderstatus)", Optimizer.PlanStage.OPTIMIZED_AND_VALIDATED, false);
        assertAtMostOneAggregationBetweenRemoteExchanges(plan);
        assertPlanIsFullyDistributed(plan);
    }

    @Test
    public void testPartialRollupAggregationsWithUnion() {
        Plan plan = this.plan("SELECT orderstatus, sum(orderkey) FROM (SELECT orderkey, orderstatus FROM orders UNION ALL SELECT orderkey, orderstatus FROM orders) x GROUP BY ROLLUP (orderstatus)", Optimizer.PlanStage.OPTIMIZED_AND_VALIDATED, false);
        assertAtMostOneAggregationBetweenRemoteExchanges(plan);
        assertPlanIsFullyDistributed(plan);
    }

    @Test
    public void testAggregationWithUnionAndValues() {
        Plan plan = this.plan("SELECT regionkey, count(*) FROM (SELECT regionkey FROM nation UNION ALL SELECT * FROM (VALUES 2, 100) t(regionkey)) GROUP BY regionkey", Optimizer.PlanStage.OPTIMIZED_AND_VALIDATED, false);
        assertAtMostOneAggregationBetweenRemoteExchanges(plan);
    }

    @Test
    public void testUnionOnProbeSide() {
        Plan plan = this.plan("SELECT * FROM (SELECT * FROM nation UNION ALL SELECT * from nation) n, region r WHERE n.regionkey=r.regionkey", Optimizer.PlanStage.OPTIMIZED_AND_VALIDATED, false);
        assertPlanIsFullyDistributed(plan);
    }

    /**
     * Returns every remote {@link ExchangeNode} in the plan, in the searcher's traversal order.
     */
    private static List<PlanNode> findRemoteExchanges(Plan plan) {
        return PlanNodeSearcher.searchFrom(plan.getRoot())
                .where(TestUnion::isRemoteExchange)
                .findAll();
    }

    /**
     * Asserts that no work which should be distributed sits above the remote GATHER exchange.
     *
     * <p>If the plan contains no remote GATHER exchange, the method returns without checking
     * anything. Otherwise it collects the nodes reached from the root without recursing past a
     * remote GATHER exchange. It asserts that none of them satisfies
     * {@link #shouldBeDistributed(PlanNode)}, and that the plan has exactly one remote GATHER
     * exchange.
     */
    private void assertPlanIsFullyDistributed(Plan plan) {
        int numberOfGathers = PlanNodeSearcher.searchFrom(plan.getRoot())
                .where(TestUnion::isRemoteGatheringExchange)
                .findAll()
                .size();
        if (numberOfGathers == 0) {
            // Nothing can sit above a GATHER that does not exist, so there is nothing to check.
            return;
        }
        List<PlanNode> nodesAboveGather = PlanNodeSearcher.searchFrom(plan.getRoot())
                .recurseOnlyWhen(TestUnion::isNotRemoteGatheringExchange)
                .findAll();
        Assert.assertTrue(
                nodesAboveGather.stream().noneMatch(this::shouldBeDistributed),
                "There is a node that should be distributed between output and first REMOTE GATHER ExchangeNode");
        Assert.assertEquals(numberOfGathers, 1, "Only a single REMOTE GATHER was expected");
    }

    /**
     * Returns {@code true} for joins, aggregations and partial TopN nodes. These are the node
     * kinds that must not appear above the remote GATHER exchange.
     */
    private boolean shouldBeDistributed(PlanNode planNode) {
        if (planNode instanceof JoinNode || planNode instanceof AggregationNode) {
            return true;
        }
        return isPartialTopN(planNode);
    }

    /**
     * Asserts that each subtree directly below a remote exchange contains at most one
     * {@link AggregationNode} before the next remote exchange is reached.
     */
    private static void assertAtMostOneAggregationBetweenRemoteExchanges(Plan plan) {
        // The sources of every remote exchange are the roots of the sections to inspect.
        List<PlanNode> fragments = findRemoteExchanges(plan).stream()
                .flatMap(exchangeNode -> exchangeNode.getSources().stream())
                .collect(Collectors.toList());
        for (PlanNode fragment : fragments) {
            List<PlanNode> aggregations = PlanNodeSearcher.searchFrom(fragment)
                    .where(AggregationNode.class::isInstance)
                    .recurseOnlyWhen(TestUnion::isNotRemoteExchange)
                    .findAll();
            Assert.assertFalse(aggregations.size() > 1, "More than a single AggregationNode between remote exchanges");
        }
    }

    /** Returns {@code true} if the node is a {@link TopNNode} in the PARTIAL step. */
    private static boolean isPartialTopN(PlanNode planNode) {
        return planNode instanceof TopNNode && ((TopNNode) planNode).getStep() == TopNNode.Step.PARTIAL;
    }

    private static boolean isNotRemoteGatheringExchange(PlanNode planNode) {
        return !isRemoteGatheringExchange(planNode);
    }

    /** Returns {@code true} if the node is a remote {@link ExchangeNode} of type GATHER. */
    private static boolean isRemoteGatheringExchange(PlanNode planNode) {
        return isRemoteExchange(planNode) && ((ExchangeNode) planNode).getType() == ExchangeNode.Type.GATHER;
    }

    private static boolean isNotRemoteExchange(PlanNode planNode) {
        return !isRemoteExchange(planNode);
    }

    /** Returns {@code true} if the node is an {@link ExchangeNode} whose scope is remote. */
    private static boolean isRemoteExchange(PlanNode planNode) {
        return planNode instanceof ExchangeNode && ((ExchangeNode) planNode).getScope().isRemote();
    }
}

