/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.query;

import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.ValuesNode;
import com.facebook.presto.sql.planner.Plan;
import com.facebook.presto.sql.planner.assertions.ExpectedValueProvider;
import com.facebook.presto.sql.planner.assertions.PlanMatchPattern;
import com.facebook.presto.sql.planner.assertions.RvalueMatcher;
import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.query.QueryAssertions;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.testing.QueryRunner;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import org.intellij.lang.annotations.Language;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

public class TestSubqueries {
    private static final String UNSUPPORTED_CORRELATED_SUBQUERY_ERROR_MSG = "line .*: Given correlated subquery is not supported";
    private QueryAssertions assertions;

    @BeforeClass
    public void init() {
        this.assertions = new QueryAssertions();
    }

    @AfterClass(alwaysRun=true)
    public void teardown() {
        this.assertions.close();
        this.assertions = null;
    }

    @Test(expectedExceptions={PrestoException.class}, expectedExceptionsMessageRegExp="line .*: Given correlated subquery is not supported")
    public void testCorrelatedSubqueriesWithDistinct() {
        QueryRunner runner = this.assertions.getQueryRunner();
        runner.execute(runner.getDefaultSession(), "select a from (values (1, 10), (2, 20)) t(a,b) where a in (select distinct c from (values 1) t2(c) where b in (10, 11))");
    }

    @Test
    public void testCorrelatedExistsSubqueriesWithOrPredicateAndNull() {
        this.assertExistsRewrittenToAggregationAboveJoin("SELECT EXISTS(SELECT 1 FROM (VALUES null, 10) t(x) WHERE y > x OR y + 10 > x) FROM (values (11)) t2(y)", "VALUES true", false);
        this.assertExistsRewrittenToAggregationAboveJoin("SELECT EXISTS(SELECT 1 FROM (VALUES null) t(x) WHERE y > x OR y + 10 > x) FROM (values (11)) t2(y)", "VALUES false", false);
    }

    @Test
    public void testUnsupportedSubqueriesWithCoercions() {
        this.assertions.assertFails("select (select count(*) from (values 1) t(a) where t.a=t2.b limit 1) from (values 1.0) t2(b)", UNSUPPORTED_CORRELATED_SUBQUERY_ERROR_MSG);
        this.assertions.assertFails("select EXISTS(select 1 from (values (null, null)) t(a, b) where t.a=t2.b GROUP BY t.b) from (values 1, 2) t2(b)", UNSUPPORTED_CORRELATED_SUBQUERY_ERROR_MSG);
    }

    @Test
    public void testCorrelatedSubqueriesWithLimit() {
        this.assertions.assertQuery("select (select t.a from (values 1, 2) t(a) where t.a=t2.b limit 1) from (values 1) t2(b)", "VALUES 1");
        this.assertions.assertFails("select (select t.a from (values 1, 2) t(a) where t.a=t2.b limit 2) from (values 1) t2(b)", UNSUPPORTED_CORRELATED_SUBQUERY_ERROR_MSG);
        this.assertions.assertQuery("select (select sum(t.a) from (values 1, 2) t(a) where t.a=t2.b group by t.a limit 2) from (values 1) t2(b)", "VALUES BIGINT '1'");
        this.assertions.assertQuery("select (select count(*) from (select t.a from (values 1, 1, null, 3) t(a) limit 1) t where t.a=t2.b) from (values 1, 2) t2(b)", "VALUES BIGINT '1', BIGINT '0'");
        this.assertExistsRewrittenToAggregationBelowJoin("select EXISTS(select 1 from (values 1, 1, 3) t(a) where t.a=t2.b limit 1) from (values 1, 2) t2(b)", "VALUES true, false", false);
        this.assertions.assertFails("select (select count(*) from (values 1, 1, 3) t(a) where t.a=t2.b limit 1) from (values 1) t2(b)", UNSUPPORTED_CORRELATED_SUBQUERY_ERROR_MSG);
        this.assertExistsRewrittenToAggregationBelowJoin("SELECT EXISTS(SELECT 1 FROM (values ('x', 1)) u(x, cid) WHERE x = 'x' AND t.cid = cid LIMIT 1) FROM (values 1) t(cid)", "VALUES true", false);
    }

    @Test
    public void testCorrelatedSubqueriesWithGroupBy() {
        this.assertions.assertFails("select (select count(*) from (values 1, 2, 3, null) t(a) where t.a<t2.b GROUP BY t.a) from (values 1, 2, 3) t2(b)", "Scalar sub-query has returned multiple rows");
        this.assertions.assertQuery("select (select count(*) from (values 1, 1, 2, 3, null) t(a) where t.a<t2.b GROUP BY t.a HAVING count(*) > 1) from (values 1, 2) t2(b)", "VALUES null, BIGINT '2'");
        this.assertExistsRewrittenToAggregationBelowJoin("select EXISTS(select 1 from (values 1, 1, 3) t(a) where t.a=t2.b GROUP BY t.a) from (values 1, 2) t2(b)", "VALUES true, false", false);
        this.assertExistsRewrittenToAggregationBelowJoin("select EXISTS(select 1 from (values (1, 2), (1, 2), (null, null), (3, 3)) t(a, b) where t.a=t2.b GROUP BY t.a, t.b) from (values 1, 2) t2(b)", "VALUES true, false", true);
        this.assertExistsRewrittenToAggregationAboveJoin("select EXISTS(select 1 from (values (1, 2), (1, 2), (null, null), (3, 3)) t(a, b) where t.a<t2.b GROUP BY t.a, t.b) from (values 1, 2) t2(b)", "VALUES false, true", true);
        this.assertions.assertFails("select EXISTS(select 1 from (values (1, 1), (1, 1), (null, null), (3, 3)) t(a, b) where t.a+t.b<t2.b GROUP BY t.a) from (values 1, 2) t2(b)", UNSUPPORTED_CORRELATED_SUBQUERY_ERROR_MSG);
        this.assertExistsRewrittenToAggregationAboveJoin("select EXISTS(select 1 from (values (1, 1), (1, 1), (null, null), (3, 3)) t(a, b) where t.a+t.b<t2.b GROUP BY t.a, t.b) from (values 1, 4) t2(b)", "VALUES false, true", true);
        this.assertExistsRewrittenToAggregationBelowJoin("select EXISTS(select 1 from (values (1, 2), (1, 2), (null, null), (3, 3)) t(a, b) where t.a=t2.b GROUP BY t.b) from (values 1, 2) t2(b)", "VALUES true, false", true);
        this.assertExistsRewrittenToAggregationBelowJoin("select EXISTS(select * from (values 1, 1, 2, 3) t(a) where t.a=t2.b GROUP BY t.a HAVING count(*) > 1) from (values 1, 2) t2(b)", "VALUES true, false", false);
        this.assertions.assertQuery("select EXISTS(select * from (select t.a from (values (1, 1), (1, 1), (1, 2), (1, 2), (3, 3)) t(a, b) where t.b=t2.b GROUP BY t.a HAVING count(*) > 1) t where t.a=t2.b) from (values 1, 2) t2(b)", "VALUES true, false");
        this.assertExistsRewrittenToAggregationBelowJoin("select EXISTS(select * from (values 1, 1, 2, 3) t(a) where t.a=t2.b GROUP BY (t.a) HAVING count(*) > 1) from (values 1, 2) t2(b)", "VALUES true, false", false);
    }

    @Test
    public void testCorrelatedLateralWithGroupBy() {
        this.assertions.assertQuery("select * from (values 1, 2) t2(b), LATERAL (select t.a from (values 1, 1, 3) t(a) where t.a=t2.b GROUP BY t.a)", "VALUES (1, 1)");
        this.assertions.assertQuery("select * from (values 1, 2) t2(b), LATERAL (select count(*) from (values 1, 1, 2, 3) t(a) where t.a=t2.b GROUP BY t.a HAVING count(*) > 1)", "VALUES (1, BIGINT '2')");
        this.assertions.assertFails("select * from (values 1, 2) t2(b), LATERAL (select t.a, t.b, count(*) from (values (1, 1), (1, 2), (2, 2), (3, 3)) t(a, b) where t.a=t2.b GROUP BY GROUPING SETS ((t.a, t.b), (t.a)))", UNSUPPORTED_CORRELATED_SUBQUERY_ERROR_MSG);
    }

    @Test
    public void testLateralWithUnnest() {
        this.assertions.assertFails("SELECT * FROM (VALUES ARRAY[1]) t(x), LATERAL (SELECT * FROM UNNEST(x))", UNSUPPORTED_CORRELATED_SUBQUERY_ERROR_MSG);
    }

    @Test
    public void testCorrelatedScalarSubquery() {
        this.assertions.assertQuery("SELECT * FROM (VALUES 1, 2) t2(b) WHERE (SELECT b) = 2", "VALUES 2");
    }

    @Test
    public void testCorrelatedSubqueryWithExplicitCoercion() {
        this.assertions.assertQuery("SELECT 1 FROM (VALUES 1, 2) t1(b) WHERE 1 = (SELECT cast(b as decimal(7,2)))", "VALUES 1");
    }

    private void assertExistsRewrittenToAggregationBelowJoin(@Language(value="SQL") String actual, @Language(value="SQL") String expected, boolean extraAggregation) {
        PlanMatchPattern source = PlanMatchPattern.node(ValuesNode.class, new PlanMatchPattern[0]);
        if (extraAggregation) {
            source = PlanMatchPattern.aggregation((Map<String, ExpectedValueProvider<FunctionCall>>)ImmutableMap.of(), PlanMatchPattern.exchange(ExchangeNode.Scope.LOCAL, ExchangeNode.Type.REPARTITION, PlanMatchPattern.aggregation((Map<String, ExpectedValueProvider<FunctionCall>>)ImmutableMap.of(), PlanMatchPattern.anyTree(PlanMatchPattern.node(ValuesNode.class, new PlanMatchPattern[0])))));
        }
        this.assertions.assertQueryAndPlan(actual, expected, PlanMatchPattern.anyTree(PlanMatchPattern.node(JoinNode.class, PlanMatchPattern.anyTree(PlanMatchPattern.node(ValuesNode.class, new PlanMatchPattern[0])), PlanMatchPattern.anyTree(PlanMatchPattern.aggregation((Map<String, ExpectedValueProvider<FunctionCall>>)ImmutableMap.of(), AggregationNode.Step.FINAL, PlanMatchPattern.exchange(ExchangeNode.Scope.LOCAL, ExchangeNode.Type.REPARTITION, PlanMatchPattern.aggregation((Map<String, ExpectedValueProvider<FunctionCall>>)ImmutableMap.of(), AggregationNode.Step.PARTIAL, PlanMatchPattern.anyTree(source))))))), plan -> Assert.assertEquals((int)TestSubqueries.countFinalAggregationNodes(plan), (int)(extraAggregation ? 2 : 1)));
    }

    private void assertExistsRewrittenToAggregationAboveJoin(@Language(value="SQL") String actual, @Language(value="SQL") String expected, boolean extraAggregation) {
        Consumer<Plan> singleStreamingAggregationValidator = plan -> Assert.assertEquals((int)TestSubqueries.countSingleStreamingAggregations(plan), (int)1);
        Consumer<Plan> finalAggregationValidator = plan -> Assert.assertEquals((int)TestSubqueries.countFinalAggregationNodes(plan), (int)(extraAggregation ? 1 : 0));
        this.assertions.assertQueryAndPlan(actual, expected, PlanMatchPattern.anyTree(PlanMatchPattern.aggregation((Map<String, ExpectedValueProvider<FunctionCall>>)ImmutableMap.of((Object)"COUNT", PlanMatchPattern.functionCall("count", (List<String>)ImmutableList.of((Object)"NON_NULL"))), AggregationNode.Step.SINGLE, PlanMatchPattern.node(JoinNode.class, PlanMatchPattern.anyTree(PlanMatchPattern.node(ValuesNode.class, new PlanMatchPattern[0])), PlanMatchPattern.anyTree(PlanMatchPattern.node(ProjectNode.class, PlanMatchPattern.anyTree(PlanMatchPattern.node(ValuesNode.class, new PlanMatchPattern[0]))).withAlias("NON_NULL", (RvalueMatcher)PlanMatchPattern.expression("true")))))), singleStreamingAggregationValidator.andThen(finalAggregationValidator));
    }

    private static int countFinalAggregationNodes(Plan plan) {
        return PlanNodeSearcher.searchFrom((PlanNode)plan.getRoot()).where(node -> node instanceof AggregationNode && ((AggregationNode)node).getStep() == AggregationNode.Step.FINAL).count();
    }

    private static int countSingleStreamingAggregations(Plan plan) {
        return PlanNodeSearcher.searchFrom((PlanNode)plan.getRoot()).where(node -> node instanceof AggregationNode && ((AggregationNode)node).getStep() == AggregationNode.Step.SINGLE && ((AggregationNode)node).isStreamable()).count();
    }
}

