/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.prestosql.Session;
import io.prestosql.connector.CatalogName;
import io.prestosql.connector.MockConnectorColumnHandle;
import io.prestosql.connector.MockConnectorFactory;
import io.prestosql.connector.MockConnectorTableHandle;
import io.prestosql.connector.MockConnectorTransactionHandle;
import io.prestosql.metadata.TableHandle;
import io.prestosql.spi.Plugin;
import io.prestosql.spi.connector.AggregationApplicationResult;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.ConnectorFactory;
import io.prestosql.spi.connector.ConnectorTableHandle;
import io.prestosql.spi.connector.ConnectorTransactionHandle;
import io.prestosql.spi.connector.SchemaTableName;
import io.prestosql.spi.type.BigintType;
import io.prestosql.spi.type.Type;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.assertions.PlanMatchPattern;
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.iterative.rule.PushDistinctLimitIntoTableScan;
import io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.testing.LocalQueryRunner;
import io.prestosql.testing.TestingSession;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import org.assertj.core.api.Assertions;
import org.assertj.core.api.AtomicIntegerAssert;
import org.assertj.core.api.AtomicReferenceAssert;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

@Test(singleThreaded=true)
public class TestPushDistinctLimitIntoTableScan
extends BaseRuleTest {
    private static final CatalogName TEST_CATALOG = new CatalogName("test_push_dl_catalog");
    private PushDistinctLimitIntoTableScan rule;
    private TableHandle tableHandle;
    private MockConnectorFactory.ApplyAggregation testApplyAggregation;

    public TestPushDistinctLimitIntoTableScan() {
        super(new Plugin[0]);
    }

    @Override
    protected Optional<LocalQueryRunner> createLocalQueryRunner() {
        Session defaultSession = TestingSession.testSessionBuilder().setCatalog(TEST_CATALOG.getCatalogName()).setSchema("tiny").build();
        LocalQueryRunner queryRunner = LocalQueryRunner.create((Session)defaultSession);
        queryRunner.createCatalog(TEST_CATALOG.getCatalogName(), (ConnectorFactory)MockConnectorFactory.builder().withApplyAggregation((session, handle, aggregates, assignments, groupingSets) -> {
            if (this.testApplyAggregation != null) {
                return this.testApplyAggregation.apply(session, handle, aggregates, assignments, groupingSets);
            }
            return Optional.empty();
        }).build(), Map.of());
        return Optional.of(queryRunner);
    }

    @BeforeClass
    public void init() {
        this.rule = new PushDistinctLimitIntoTableScan(this.tester().getMetadata());
        this.tableHandle = new TableHandle(TEST_CATALOG, (ConnectorTableHandle)new MockConnectorTableHandle(new SchemaTableName("mock_schema", "mock_nation")), (ConnectorTransactionHandle)MockConnectorTransactionHandle.INSTANCE, Optional.empty());
    }

    @BeforeMethod
    public void reset() {
        this.testApplyAggregation = null;
    }

    @Test
    public void testDoesNotFireIfNoTableScan() {
        this.tester().assertThat((Rule<?>)this.rule).on(p -> p.values(p.symbol("a", (Type)BigintType.BIGINT))).doesNotFire();
    }

    @Test
    public void testNoEffect() {
        AtomicInteger applyCallCounter = new AtomicInteger();
        this.testApplyAggregation = (session, handle, aggregates, assignments, groupingSets) -> {
            applyCallCounter.incrementAndGet();
            return Optional.empty();
        };
        this.tester().assertThat((Rule<?>)this.rule).on(p -> {
            Symbol regionkey = p.symbol("regionkey");
            return p.distinctLimit(10L, List.of(regionkey), (PlanNode)p.tableScan(this.tableHandle, (List<Symbol>)ImmutableList.of((Object)regionkey), (Map<Symbol, ColumnHandle>)ImmutableMap.of((Object)regionkey, (Object)new MockConnectorColumnHandle("regionkey", (Type)BigintType.BIGINT))));
        }).doesNotFire();
        ((AtomicIntegerAssert)Assertions.assertThat((AtomicInteger)applyCallCounter).as("applyCallCounter", new Object[0])).hasValue(1);
    }

    @Test
    public void testPushDistinct() {
        AtomicInteger applyCallCounter = new AtomicInteger();
        AtomicReference applyAggregates = new AtomicReference();
        AtomicReference applyAssignments = new AtomicReference();
        AtomicReference applyGroupingSets = new AtomicReference();
        this.testApplyAggregation = (session, handle, aggregates, assignments, groupingSets) -> {
            applyCallCounter.incrementAndGet();
            applyAggregates.set(List.copyOf(aggregates));
            applyAssignments.set(Map.copyOf(assignments));
            applyGroupingSets.set(groupingSets.stream().map(List::copyOf).collect(Collectors.toUnmodifiableList()));
            return Optional.of(new AggregationApplicationResult((Object)new MockConnectorTableHandle(new SchemaTableName("mock_schema", "mock_nation_aggregated")), List.of(), List.of(), Map.of()));
        };
        MockConnectorColumnHandle regionkeyHandle = new MockConnectorColumnHandle("regionkey", (Type)BigintType.BIGINT);
        this.tester().assertThat((Rule<?>)this.rule).on(p -> {
            Symbol regionkeySymbol = p.symbol("regionkey_symbol");
            return p.distinctLimit(43L, List.of(regionkeySymbol), (PlanNode)p.tableScan(this.tableHandle, (List<Symbol>)ImmutableList.of((Object)regionkeySymbol), (Map<Symbol, ColumnHandle>)ImmutableMap.of((Object)regionkeySymbol, (Object)regionkeyHandle)));
        }).matches(PlanMatchPattern.limit(43L, PlanMatchPattern.project(PlanMatchPattern.tableScan("mock_nation_aggregated"))));
        ((AtomicIntegerAssert)Assertions.assertThat((AtomicInteger)applyCallCounter).as("applyCallCounter", new Object[0])).hasValue(1);
        ((AtomicReferenceAssert)Assertions.assertThat(applyAggregates).as("applyAggregates", new Object[0])).hasValue(List.of());
        ((AtomicReferenceAssert)Assertions.assertThat(applyAssignments).as("applyAssignments", new Object[0])).hasValue(Map.of("regionkey_symbol", regionkeyHandle));
        ((AtomicReferenceAssert)Assertions.assertThat(applyGroupingSets).as("applyGroupingSets", new Object[0])).hasValue(List.of(List.of(regionkeyHandle)));
    }
}

