/*
 * Decompiled with CFR 0.152.
 */
package io.trino.tests;

import io.trino.Session;
import io.trino.plugin.memory.MemoryPlugin;
import io.trino.spi.Plugin;
import io.trino.sql.planner.Plan;
import io.trino.sql.planner.optimizations.PlanNodeSearcher;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.testing.AbstractTestAggregations;
import io.trino.testing.DistributedQueryRunner;
import io.trino.testing.QueryRunner;
import io.trino.testing.TestingSession;
import io.trino.tests.tpch.TpchQueryRunner;
import java.util.function.Predicate;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

public class TestAggregations
extends AbstractTestAggregations {
    private final Session memorySession = TestingSession.testSessionBuilder().setCatalog("memory").setSchema("default").build();

    protected QueryRunner createQueryRunner() throws Exception {
        DistributedQueryRunner queryRunner = TpchQueryRunner.builder().build();
        queryRunner.installPlugin((Plugin)new MemoryPlugin());
        queryRunner.createCatalog("memory", "memory");
        queryRunner.execute(this.memorySession, "CREATE TABLE test_table (key VARCHAR, sequence BIGINT, value DECIMAL(2, 0))");
        queryRunner.execute(this.memorySession, "INSERT INTO test_table VALUES ('a', 0, 0),('a', 0, 1),('a', 1, 2),('a', 1, 3),('b', 0, 10),('b', 0, 11),('b', 1, 13),('b', 1, 14)");
        return queryRunner;
    }

    @Test
    public void testPreAggregate() {
        this.assertQuery(this.memorySession, "SELECT key, sum(CASE WHEN sequence = 0 THEN value END), sum(CASE WHEN sequence = 2 THEN value END), min(CASE WHEN sequence = 1 THEN value ELSE null END), max(CASE WHEN sequence = 0 THEN value END), sum(CASE WHEN sequence = 1 THEN cast(value * 2 as real) ELSE cast(0 as real) END) FROM test_table GROUP BY key", "VALUES ('a', 1, null, 2, 1, 10), ('b', 21, null, 13, 11, 54)", plan -> this.assertAggregationNodeCount((Plan)plan, 4));
        this.assertQuery(this.memorySession, "SELECT sum(CASE WHEN sequence = 0 THEN value END), sum(CASE WHEN sequence = 2 THEN value END), min(CASE WHEN sequence = 1 THEN value ELSE null END), max(CASE WHEN sequence = 0 THEN value END), sum(CASE WHEN sequence = 1 THEN value * 2 ELSE 0 END) FROM test_table", "VALUES (22, null, 2, 11, 64)", plan -> this.assertAggregationNodeCount((Plan)plan, 4));
        this.assertQuery(this.memorySession, "SELECT key, sum(CASE WHEN sequence = 0 THEN value END), sum(CASE WHEN sequence = 2 THEN value END), min(CASE WHEN sequence = 1 THEN value ELSE null END), max(CASE WHEN sequence = 0 THEN value END), sum(CASE WHEN sequence = 1 THEN value * 2 ELSE 1 END) FROM test_table GROUP BY key", "VALUES ('a', 1, null, 2, 1, 12), ('b', 21, null, 13, 11, 56)", plan -> this.assertAggregationNodeCount((Plan)plan, 2));
        this.assertQuery(this.memorySession, "SELECT key, sum(CASE WHEN sequence = 0 THEN value END), sum(CASE WHEN sequence = 2 THEN value END), min(CASE WHEN sequence = 1 THEN value ELSE null END), max(CASE WHEN sequence = 0 THEN value END), max(CASE WHEN sequence = 1 THEN value * 2 ELSE 100 END) FROM test_table GROUP BY key", "VALUES ('a', 1, null, 2, 1, 100), ('b', 21, null, 13, 11, 100)", plan -> this.assertAggregationNodeCount((Plan)plan, 2));
        this.assertQuery(this.memorySession, "SELECT key, sum(CASE WHEN sequence = 42 THEN value ELSE 0 END), sum(CASE WHEN sequence = 42 THEN value END), sum(CASE WHEN sequence = 24 THEN cast(value * 2 as real) ELSE cast(0 as real) END), sum(CASE WHEN sequence = 24 THEN cast(value * 2 as real) END) FROM test_table GROUP BY key", "VALUES ('a', 0, null, 0, null), ('b', 0, null, 0, null)", plan -> this.assertAggregationNodeCount((Plan)plan, 4));
        this.assertQuery(this.memorySession, "SELECT sum(CASE WHEN sequence = 42 THEN value ELSE 0 END), sum(CASE WHEN sequence = 42 THEN value END), sum(CASE WHEN sequence = 24 THEN cast(value * 2 as real) ELSE cast(0 as real) END), sum(CASE WHEN sequence = 24 THEN cast(value * 2 as real) END) FROM test_table", "VALUES (0, null, 0, null)", plan -> this.assertAggregationNodeCount((Plan)plan, 4));
    }

    @Test
    public void testPreAggregateWithFilter() {
        this.assertQuery(this.memorySession, "SELECT sum(CASE WHEN sequence = 0 THEN value END), sum(CASE WHEN sequence = 2 THEN value END), min(CASE WHEN sequence = 1 THEN value ELSE null END), max(CASE WHEN sequence = 0 THEN value END), sum(CASE WHEN sequence = 1 THEN value * 2 ELSE 0 END) FROM test_table WHERE sequence = 42", "VALUES (null, null, null, null, null)", plan -> this.assertAggregationNodeCount((Plan)plan, 4));
    }

    private void assertAggregationNodeCount(Plan plan, int count) {
        Assertions.assertThat((int)TestAggregations.countOfMatchingNodes(plan, AggregationNode.class::isInstance)).isEqualTo(count);
    }

    private static int countOfMatchingNodes(Plan plan, Predicate<PlanNode> predicate) {
        return PlanNodeSearcher.searchFrom((PlanNode)plan.getRoot()).where(predicate).count();
    }
}

