/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.optimizations;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.Session;
import io.trino.connector.MockConnectorColumnHandle;
import io.trino.connector.MockConnectorFactory;
import io.trino.connector.MockConnectorTableHandle;
import io.trino.spi.connector.BucketFunction;
import io.trino.spi.connector.ColumnMetadata;
import io.trino.spi.connector.ConnectorBucketNodeMap;
import io.trino.spi.connector.ConnectorFactory;
import io.trino.spi.connector.ConnectorNodePartitioningProvider;
import io.trino.spi.connector.ConnectorPartitioningHandle;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.connector.ConnectorSplit;
import io.trino.spi.connector.ConnectorTablePartitioning;
import io.trino.spi.connector.ConnectorTableProperties;
import io.trino.spi.connector.ConnectorTransactionHandle;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import io.trino.sql.planner.OptimizerConfig;
import io.trino.sql.planner.assertions.BasePlanTest;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.testing.LocalQueryRunner;
import io.trino.testing.TestingSession;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.ToIntFunction;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

/**
 * Planner tests for a join of a mock table with an aggregation over the same table, where the table
 * reports connector-level partitioning on {@code column_a}.
 *
 * <p>The mock catalog exposes a single table, {@code orders}, with a bigint {@code column_a} and a
 * varchar {@code column_b}. Its table properties declare partitioning on {@code column_a} through
 * {@code PARTITIONING_HANDLE}. The partitioning provider maps that handle to {@code BUCKET_COUNT}
 * buckets. Each test plans the same query with {@code task_concurrency} fixed at 16. It varies
 * {@code table_scan_node_partitioning_min_bucket_to_task_ratio} (and {@code colocated_join}), then
 * asserts whether the second join input reaches the join through a REMOTE exchange or directly
 * from the table scan.
 */
public class TestColocatedJoin
        extends BasePlanTest {
    private static final String TABLE_NAME = "orders";
    private static final String CATALOG_NAME = "mock";
    private static final String SCHEMA_NAME = "default";
    private static final ConnectorPartitioningHandle PARTITIONING_HANDLE = new ConnectorPartitioningHandle() {};
    private static final int BUCKET_COUNT = 10;
    private static final String COLUMN_A = "column_a";
    private static final String COLUMN_B = "column_b";

    /**
     * Creates a local query runner with a single mock catalog that serves the partitioned
     * {@code orders} table.
     *
     * @return a query runner whose default session points at {@code mock.default}
     */
    @Override
    protected LocalQueryRunner createLocalQueryRunner() {
        MockConnectorFactory connectorFactory = MockConnectorFactory.builder()
                .withGetTableHandle((session, tableName) -> {
                    // Only "orders" resolves to a table; any other name yields no handle.
                    if (tableName.getTableName().equals(TABLE_NAME)) {
                        return new MockConnectorTableHandle(tableName);
                    }
                    return null;
                })
                .withPartitionProvider(new TestPartitioningProvider())
                .withGetColumns(schemaTableName -> ImmutableList.of(
                        new ColumnMetadata(COLUMN_A, BigintType.BIGINT),
                        new ColumnMetadata(COLUMN_B, VarcharType.VARCHAR)))
                .withName(CATALOG_NAME)
                .withGetTableProperties((session, tableHandle) -> new ConnectorTableProperties(
                        TupleDomain.all(),
                        // Report the table as partitioned on column_a using PARTITIONING_HANDLE.
                        Optional.of(new ConnectorTablePartitioning(
                                PARTITIONING_HANDLE,
                                ImmutableList.of(new MockConnectorColumnHandle(COLUMN_A, BigintType.BIGINT)))),
                        Optional.empty(),
                        Optional.empty(),
                        ImmutableList.of()))
                .build();

        Session defaultSession = TestingSession.testSessionBuilder()
                .setCatalog(CATALOG_NAME)
                .setSchema(SCHEMA_NAME)
                .build();
        LocalQueryRunner queryRunner = LocalQueryRunner.create(defaultSession);
        queryRunner.createCatalog(CATALOG_NAME, connectorFactory, ImmutableMap.of());
        return queryRunner;
    }

    /**
     * Supplies both values of the {@code colocated_join} session property.
     *
     * @return one row per value: {@code true}, then {@code false}
     */
    @DataProvider(name = "colocated_join_enabled")
    public Object[][] colocatedJoinEnabled() {
        return new Object[][] {{true}, {false}};
    }

    /**
     * With a bucket-to-task ratio threshold of 20.0, the expected plan has a REMOTE exchange
     * between the second table scan and the join's LOCAL exchange. The same plan is expected
     * whether or not {@code colocated_join} is enabled.
     */
    @Test(dataProvider = "colocated_join_enabled")
    public void testColocatedJoinWhenNumberOfBucketsInTableScanIsNotSufficient(boolean colocatedJoinEnabled) {
        assertDistributedPlan(
                "    SELECT\n        orders.column_a,\n        orders.column_b\n    FROM (\n        SELECT\n            column_a,\n            ARBITRARY(column_b) AS column_b,\n            COUNT(*)\n        FROM orders\n        GROUP BY\n            column_a\n        ) t,\n        orders\n        WHERE\n            orders.column_a = t.column_a\n        AND orders.column_b = t.column_b\n",
                prepareSession(20.0, colocatedJoinEnabled),
                PlanMatchPattern.anyTree(
                        PlanMatchPattern.project(
                                PlanMatchPattern.anyTree(PlanMatchPattern.tableScan(TABLE_NAME))),
                        PlanMatchPattern.exchange(ExchangeNode.Scope.LOCAL,
                                PlanMatchPattern.project(
                                        PlanMatchPattern.exchange(ExchangeNode.Scope.REMOTE,
                                                PlanMatchPattern.anyTree(PlanMatchPattern.tableScan(TABLE_NAME)))))));
    }

    /**
     * With a bucket-to-task ratio threshold of 0.01 and {@code colocated_join} enabled, the second
     * table scan feeds the join's LOCAL exchange directly (through a projection only). No REMOTE
     * exchange is expected on that side.
     */
    @Test
    public void testColocatedJoinWhenNumberOfBucketsInTableScanIsSufficient() {
        assertDistributedPlan(
                "    SELECT\n        orders.column_a,\n        orders.column_b\n    FROM (\n        SELECT\n            column_a,\n            ARBITRARY(column_b) AS column_b,\n            COUNT(*)\n        FROM orders\n        GROUP BY\n            column_a\n        ) t,\n        orders\n        WHERE\n            orders.column_a = t.column_a\n            AND orders.column_b = t.column_b\n",
                prepareSession(0.01, true),
                PlanMatchPattern.anyTree(
                        PlanMatchPattern.project(
                                PlanMatchPattern.anyTree(PlanMatchPattern.tableScan(TABLE_NAME))),
                        PlanMatchPattern.exchange(ExchangeNode.Scope.LOCAL,
                                PlanMatchPattern.project(PlanMatchPattern.tableScan(TABLE_NAME)))));
    }

    /**
     * Builds a session derived from the runner's default session. It pins join ordering (NONE),
     * join distribution (BROADCAST) and task concurrency (16), so that only the two tested
     * properties vary between tests.
     *
     * @param tableScanNodePartitioningMinBucketToTaskRatio value for
     *        {@code table_scan_node_partitioning_min_bucket_to_task_ratio}
     * @param colocatedJoinEnabled value for {@code colocated_join}
     * @return the configured session
     */
    private Session prepareSession(double tableScanNodePartitioningMinBucketToTaskRatio, boolean colocatedJoinEnabled) {
        return Session.builder(getQueryRunner().getDefaultSession())
                .setSystemProperty("join_reordering_strategy", OptimizerConfig.JoinReorderingStrategy.NONE.name())
                .setSystemProperty("join_distribution_type", OptimizerConfig.JoinDistributionType.BROADCAST.name())
                .setSystemProperty("task_concurrency", "16")
                .setSystemProperty("table_scan_node_partitioning_min_bucket_to_task_ratio", Double.toString(tableScanNodePartitioningMinBucketToTaskRatio))
                .setSystemProperty("colocated_join", Boolean.toString(colocatedJoinEnabled))
                .build();
    }

    /**
     * Partitioning provider for the mock catalog. It recognizes only {@code PARTITIONING_HANDLE},
     * which it maps to a bucket node map of {@code BUCKET_COUNT} buckets. The split-bucket and
     * bucket functions are not implemented (they throw); these tests only build plans.
     */
    public static class TestPartitioningProvider
            implements ConnectorNodePartitioningProvider {
        @Override
        public Optional<ConnectorBucketNodeMap> getBucketNodeMapping(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioningHandle) {
            if (partitioningHandle.equals(PARTITIONING_HANDLE)) {
                // Use the shared constant so the table's bucket count has a single source of truth.
                return Optional.of(ConnectorBucketNodeMap.createBucketNodeMap(BUCKET_COUNT));
            }
            throw new IllegalArgumentException("Unexpected partitioning handle: " + partitioningHandle);
        }

        @Override
        public ToIntFunction<ConnectorSplit> getSplitBucketFunction(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioningHandle) {
            throw new UnsupportedOperationException();
        }

        @Override
        public BucketFunction getBucketFunction(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioningHandle, List<Type> partitionChannelTypes, int bucketCount) {
            throw new UnsupportedOperationException();
        }
    }
}

