/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.query;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.Type;
import io.trino.sql.ir.ComparisonExpression;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.LogicalExpression;
import io.trino.sql.ir.SymbolReference;
import io.trino.sql.planner.LogicalPlanner;
import io.trino.sql.planner.assertions.BasePlanTest;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.optimizations.PlanNodeSearcher;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.query.QueryAssertions;
import java.util.List;
import java.util.Map;
import org.assertj.core.api.AbstractBooleanAssert;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Test;

public class TestFilteredAggregations
extends BasePlanTest {
    private final QueryAssertions assertions = new QueryAssertions();

    @AfterAll
    public void teardown() {
        this.assertions.close();
    }

    @Test
    public void testAddPredicateForFilterClauses() {
        ((QueryAssertions.QueryAssert)Assertions.assertThat(this.assertions.query("SELECT sum(x) FILTER(WHERE x > 0) FROM (VALUES 1, 1, 0, 2, 3, 3) t(x)"))).matches("VALUES (BIGINT '10')");
        ((QueryAssertions.QueryAssert)Assertions.assertThat(this.assertions.query("SELECT sum(x) FILTER(WHERE x > 0), sum(x) FILTER(WHERE x < 3) FROM (VALUES 1, 1, 0, 5, 3, 8) t(x)"))).matches("VALUES (BIGINT '18', BIGINT '2')");
        ((QueryAssertions.QueryAssert)Assertions.assertThat(this.assertions.query("SELECT sum(x) FILTER(WHERE x > 1), sum(x) FROM (VALUES 1, 1, 0, 2, 3, 3) t(x)"))).matches("VALUES (BIGINT '8', BIGINT '10')");
    }

    @Test
    public void testGroupAll() {
        ((QueryAssertions.QueryAssert)Assertions.assertThat(this.assertions.query("SELECT count(DISTINCT x) FILTER (WHERE x > 1) FROM (VALUES 1, 1, 1, 2, 3, 3) t(x)"))).matches("VALUES BIGINT '2'");
        ((QueryAssertions.QueryAssert)Assertions.assertThat(this.assertions.query("SELECT count(DISTINCT x) FILTER (WHERE x > 1), sum(DISTINCT x) FROM (VALUES 1, 1, 1, 2, 3, 3) t(x)"))).matches("VALUES (BIGINT '2', BIGINT '6')");
        ((QueryAssertions.QueryAssert)Assertions.assertThat(this.assertions.query("SELECT count(DISTINCT x) FILTER (WHERE x > 1), sum(DISTINCT y) FILTER (WHERE x < 3)FROM (VALUES (1, 10),(1, 20),(1, 20),(2, 20),(3, 30)) t(x, y)"))).matches("VALUES (BIGINT '2', BIGINT '30')");
        ((QueryAssertions.QueryAssert)Assertions.assertThat(this.assertions.query("SELECT count(x) FILTER (WHERE x > 1), sum(DISTINCT x) FROM (VALUES 1, 2, 3, 3) t(x)"))).matches("VALUES (BIGINT '3', BIGINT '6')");
    }

    @Test
    public void testGroupingSets() {
        ((QueryAssertions.QueryAssert)Assertions.assertThat(this.assertions.query("SELECT k, count(DISTINCT x) FILTER (WHERE y = 100), count(DISTINCT x) FILTER (WHERE y = 200) FROM (VALUES    (1, 1, 100),   (1, 1, 200),   (1, 2, 100),   (1, 3, 300),   (2, 1, 100),   (2, 10, 100),   (2, 20, 100),   (2, 20, 200),   (2, 30, 300),   (2, 40, 100)) t(k, x, y) GROUP BY GROUPING SETS ((), (k))"))).matches("VALUES (1, BIGINT '2', BIGINT '1'), (2, BIGINT '4', BIGINT '1'), (CAST(NULL AS INTEGER), BIGINT '5', BIGINT '2')");
    }

    @Test
    public void rewriteAddFilterWithMultipleFilters() {
        PlanMatchPattern source = PlanMatchPattern.tableScan("orders", (Map<String, String>)ImmutableMap.of((Object)"totalprice", (Object)"totalprice", (Object)"custkey", (Object)"custkey"));
        this.assertPlan("SELECT sum(totalprice) FILTER(WHERE totalprice > 0), sum(custkey) FILTER(WHERE custkey > 0) FROM orders", PlanMatchPattern.anyTree(PlanMatchPattern.filter((Expression)new LogicalExpression(LogicalExpression.Operator.OR, (List)ImmutableList.of((Object)new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN, (Expression)new SymbolReference((Type)DoubleType.DOUBLE, "totalprice"), (Expression)new Constant((Type)DoubleType.DOUBLE, (Object)0.0)), (Object)new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN, (Expression)new SymbolReference((Type)BigintType.BIGINT, "custkey"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)0L)))), source)));
    }

    @Test
    public void testDoNotPushdownPredicateIfNonFilteredAggregateIsPresent() {
        this.assertPlanContainsNoFilter("SELECT sum(totalprice) FILTER(WHERE totalprice > 0), sum(custkey) FROM orders");
    }

    @Test
    public void testPushDownConstantFilterPredicate() {
        this.assertPlanContainsNoFilter("SELECT sum(totalprice) FILTER(WHERE FALSE) FROM orders");
        this.assertPlanContainsNoFilter("SELECT sum(totalprice) FILTER(WHERE TRUE) FROM orders");
    }

    @Test
    public void testNoFilterAddedForConstantValueFilters() {
        this.assertPlanContainsNoFilter("SELECT sum(x) FILTER(WHERE x > 0) FROM (VALUES 1, 1, 0, 2, 3, 3) t(x) GROUP BY x");
        this.assertPlanContainsNoFilter("SELECT sum(totalprice) FILTER(WHERE totalprice > 0) FROM orders GROUP BY totalprice");
    }

    private void assertPlanContainsNoFilter(String sql) {
        ((AbstractBooleanAssert)Assertions.assertThat((boolean)PlanNodeSearcher.searchFrom((PlanNode)this.plan(sql, LogicalPlanner.Stage.OPTIMIZED).getRoot()).whereIsInstanceOfAny(new Class[]{FilterNode.class}).matches()).describedAs("Unexpected node for query: " + sql, new Object[0])).isFalse();
    }
}

