/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.Session;
import io.trino.connector.MockConnectorFactory;
import io.trino.connector.MockConnectorTableHandle;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.ScalarStatsCalculator;
import io.trino.cost.SymbolStatsEstimate;
import io.trino.plugin.tpch.TpchColumnHandle;
import io.trino.security.AccessControl;
import io.trino.spi.connector.Assignment;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ColumnMetadata;
import io.trino.spi.connector.ConnectorPartitioningHandle;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.connector.ConnectorTableHandle;
import io.trino.spi.connector.ConnectorTablePartitioning;
import io.trino.spi.connector.ConnectorTableProperties;
import io.trino.spi.connector.ProjectionApplicationResult;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.expression.Call;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.FieldDereference;
import io.trino.spi.expression.Variable;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import io.trino.sql.PlannerContext;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.FieldReference;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.ConnectorExpressionTranslator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.PushProjectionIntoTableScan;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.iterative.rule.test.RuleTester;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.testing.TestingSession;
import io.trino.transaction.TransactionId;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Rule tests for {@link PushProjectionIntoTableScan}.
 *
 * <p>Every test registers a mock connector exposing the single table {@code test_schema.test_table}. When projection
 * pushdown is enabled, {@link #mockApplyProjection} replaces each pushed expression with a new variable named after
 * the expression kind and its string form, and returns a table handle renamed to {@code projected_<table>}.
 */
public class TestPushProjectionIntoTableScan {
    private static final String TEST_SCHEMA = "test_schema";
    private static final String TEST_TABLE = "test_table";
    private static final Type ROW_TYPE = RowType.from(Arrays.asList(
            RowType.field("a", BigintType.BIGINT),
            RowType.field("b", BigintType.BIGINT)));
    private static final ConnectorPartitioningHandle PARTITIONING_HANDLE = new ConnectorPartitioningHandle() {};
    private static final Session MOCK_SESSION = TestingSession.testSessionBuilder()
            .setCatalog("test_catalog")
            .setSchema(TEST_SCHEMA)
            .build();

    /**
     * Verifies the rule does not fire when the mock connector is built without an {@code applyProjection} callback.
     */
    @Test
    public void testDoesNotFire() {
        String columnName = "input_column";
        Type columnType = ROW_TYPE;
        ColumnHandle inputColumnHandle = column(columnName, columnType);

        MockConnectorFactory factory = createMockFactory(ImmutableMap.of(columnName, inputColumnHandle), Optional.empty());
        try (RuleTester ruleTester = RuleTester.builder().withDefaultCatalogConnectorFactory(factory).build()) {
            PushProjectionIntoTableScan optimizer = createRule(ruleTester);

            ruleTester.assertThat(optimizer)
                    .withSession(MOCK_SESSION)
                    .on(p -> {
                        Symbol symbol = p.symbol(columnName, columnType);
                        return p.project(
                                Assignments.of(
                                        p.symbol("symbol_dereference", BigintType.BIGINT),
                                        new FieldReference(symbol.toSymbolReference(), 0)),
                                p.tableScan(
                                        ruleTester.getCurrentCatalogTableHandle(TEST_SCHEMA, TEST_TABLE),
                                        ImmutableList.of(symbol),
                                        ImmutableMap.of(symbol, inputColumnHandle)));
                    })
                    .doesNotFire();
        }
    }

    /**
     * Verifies that the identity and dereference projections are pushed into the table scan. The constant
     * projection must stay in the remaining project node, because {@link #mockApplyProjection} rejects constants.
     * The scan must be rewritten to the table handle and column handles produced by the mock connector.
     */
    @Test
    public void testPushProjection() {
        String columnName = "col0";
        Type columnType = ROW_TYPE;
        Symbol baseColumn = new Symbol(columnType, columnName);
        ColumnHandle columnHandle = new TpchColumnHandle(columnName, columnType);

        // Catalog with projection pushdown enabled
        MockConnectorFactory factory = createMockFactory(ImmutableMap.of(columnName, columnHandle), Optional.of(this::mockApplyProjection));
        try (RuleTester ruleTester = RuleTester.builder().withDefaultCatalogConnectorFactory(factory).build()) {
            // Symbols (and their types) used by the project node under test
            Symbol identity = new Symbol(ROW_TYPE, "symbol_identity");
            Symbol dereference = new Symbol(BigintType.BIGINT, "symbol_dereference");
            Symbol constant = new Symbol(BigintType.BIGINT, "symbol_constant");
            Map<Symbol, Type> types = ImmutableMap.of(
                    baseColumn, ROW_TYPE,
                    identity, ROW_TYPE,
                    dereference, BigintType.BIGINT,
                    constant, BigintType.BIGINT);

            Assignments inputProjections = Assignments.builder()
                    .put(identity, baseColumn.toSymbolReference())
                    .put(dereference, new FieldReference(baseColumn.toSymbolReference(), 0))
                    .put(constant, new Constant(BigintType.BIGINT, 5L))
                    .build();

            // Translate each projection to its connector expression to predict the names the mock connector assigns
            TransactionId transactionId = ruleTester.getPlanTester().getTransactionManager().beginTransaction(false);
            AccessControl accessControl = ruleTester.getPlanTester().getAccessControl();
            Session session = MOCK_SESSION.beginTransactionId(transactionId, ruleTester.getPlanTester().getTransactionManager(), accessControl);
            Map<Symbol, String> connectorNames = inputProjections.entrySet().stream()
                    .collect(ImmutableMap.toImmutableMap(
                            Map.Entry::getKey,
                            e -> ConnectorExpressionTranslator.translate(session, e.getValue()).orElseThrow().toString()));
            Map<Symbol, String> newNames = ImmutableMap.of(
                    identity, "projected_variable_" + connectorNames.get(identity),
                    dereference, "projected_dereference_" + connectorNames.get(dereference));
            Map<Symbol, Expression> constants = ImmutableMap.of(constant, Objects.requireNonNull(inputProjections.get(constant)));
            Map<String, ColumnHandle> expectedColumns = newNames.entrySet().stream()
                    .collect(ImmutableMap.toImmutableMap(
                            Map.Entry::getValue,
                            e -> column(e.getValue(), types.get(e.getKey()))));

            // Expected assignments of the remaining project node: pushed projections become references to the
            // connector-assigned names, constants are kept unchanged
            ImmutableMap.Builder<Symbol, Expression> expectedProjectionsBuilder = ImmutableMap.builder();
            for (Map.Entry<Symbol, String> entry : newNames.entrySet()) {
                expectedProjectionsBuilder.put(entry.getKey(), new Reference(BigintType.BIGINT, entry.getValue()));
            }
            expectedProjectionsBuilder.putAll(constants);
            Map<Symbol, Expression> expectedProjections = expectedProjectionsBuilder.buildOrThrow();

            MockConnectorTableHandle expectedTableHandle = new MockConnectorTableHandle(
                    new SchemaTableName(TEST_SCHEMA, "projected_test_table"),
                    TupleDomain.all(),
                    Optional.of(ImmutableList.copyOf(expectedColumns.values())));
            Map<String, Predicate<ColumnHandle>> expectedColumnPredicates = expectedColumns.entrySet().stream()
                    .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, e -> e.getValue()::equals));

            ruleTester.assertThat(createRule(ruleTester))
                    .withSession(MOCK_SESSION)
                    .on(p -> projectOverScanWithStatistics(p, ruleTester, types, inputProjections, columnHandle, baseColumn))
                    .matches(PlanMatchPattern.project(
                            expectedProjections.entrySet().stream()
                                    .collect(ImmutableMap.toImmutableMap(
                                            e -> e.getKey().name(),
                                            e -> PlanMatchPattern.expression(e.getValue()))),
                            PlanMatchPattern.tableScan(
                                    expectedTableHandle::equals,
                                    TupleDomain.all(),
                                    expectedColumnPredicates)));
        }
    }

    /**
     * Verifies the rule fails when pushdown swaps a partitioned table handle for one with different partitioning.
     * Only {@code test_table} reports partitioning in the mock factory. {@link #mockApplyProjection} returns a handle
     * for {@code projected_test_table}, so that handle gets default table properties.
     */
    @Test
    public void testPartitioningChanged() {
        String columnName = "col0";
        ColumnHandle columnHandle = new TpchColumnHandle(columnName, VarcharType.VARCHAR);

        // Catalog with projection pushdown enabled
        MockConnectorFactory factory = createMockFactory(ImmutableMap.of(columnName, columnHandle), Optional.of(this::mockApplyProjection));
        try (RuleTester ruleTester = RuleTester.builder().withDefaultCatalogConnectorFactory(factory).build()) {
            Assertions.assertThatThrownBy(() -> ruleTester.assertThat(createRule(ruleTester))
                            .withSession(MOCK_SESSION)
                            .on(p -> p.project(
                                    Assignments.of(),
                                    p.tableScan(
                                            ruleTester.getCurrentCatalogTableHandle(TEST_SCHEMA, TEST_TABLE),
                                            ImmutableList.of(p.symbol("col", VarcharType.VARCHAR)),
                                            ImmutableMap.of(p.symbol("col", VarcharType.VARCHAR), columnHandle),
                                            Optional.of(true))))
                            .matches(PlanMatchPattern.anyTree()))
                    .hasMessage("Partitioning must not change after projection is pushed down");
        }
    }

    /**
     * Creates a mock connector factory exposing the single table {@code test_schema.test_table}.
     *
     * @param assignments column name to column handle; each handle must be a {@link TpchColumnHandle}, whose type is
     *     used for the table's column metadata
     * @param applyProjection projection pushdown callback to register, if present
     * @return the configured factory
     */
    private MockConnectorFactory createMockFactory(Map<String, ColumnHandle> assignments, Optional<MockConnectorFactory.ApplyProjection> applyProjection) {
        List<ColumnMetadata> metadata = assignments.entrySet().stream()
                .map(entry -> new ColumnMetadata(entry.getKey(), ((TpchColumnHandle) entry.getValue()).type()))
                .collect(ImmutableList.toImmutableList());

        MockConnectorFactory.Builder builder = MockConnectorFactory.builder()
                .withListSchemaNames(connectorSession -> ImmutableList.of(TEST_SCHEMA))
                .withListTables((connectorSession, schema) -> TEST_SCHEMA.equals(schema) ? ImmutableList.of(TEST_TABLE) : ImmutableList.of())
                .withGetColumns(schemaTableName -> metadata)
                .withGetTableProperties((session, tableHandle) -> {
                    MockConnectorTableHandle mockTableHandle = (MockConnectorTableHandle) tableHandle;
                    // Only the original table reports partitioning; any other table name gets default properties
                    if (mockTableHandle.getTableName().getTableName().equals(TEST_TABLE)) {
                        return new ConnectorTableProperties(
                                TupleDomain.all(),
                                Optional.of(new ConnectorTablePartitioning(PARTITIONING_HANDLE, ImmutableList.of(column("col", VarcharType.VARCHAR)))),
                                Optional.empty(),
                                ImmutableList.of());
                    }
                    return new ConnectorTableProperties();
                });

        if (applyProjection.isPresent()) {
            builder = builder.withApplyProjection(applyProjection.get());
        }
        return builder.build();
    }

    /**
     * Projection pushdown callback for the mock connector.
     *
     * <p>Each projection is replaced by a new variable. Its name is a kind-specific prefix
     * ({@code projected_variable_}, {@code projected_dereference_} or {@code projected_call_}) followed by
     * {@code projection.toString()}. A {@link TpchColumnHandle} with the same name and type backs it. The returned
     * table handle is named {@code projected_<input table>} and lists the new column handles.
     *
     * @throws UnsupportedOperationException if a constant or any other unsupported expression kind is passed
     */
    private Optional<ProjectionApplicationResult<ConnectorTableHandle>> mockApplyProjection(ConnectorSession session, ConnectorTableHandle tableHandle, List<ConnectorExpression> projections, Map<String, ColumnHandle> assignments) {
        SchemaTableName inputSchemaTableName = ((MockConnectorTableHandle) tableHandle).getTableName();
        SchemaTableName projectedTableName = new SchemaTableName(inputSchemaTableName.getSchemaName(), "projected_" + inputSchemaTableName.getTableName());

        ImmutableList.Builder<ConnectorExpression> outputExpressions = ImmutableList.builder();
        ImmutableList.Builder<Assignment> outputAssignments = ImmutableList.builder();
        ImmutableList.Builder<ColumnHandle> projectedColumnsBuilder = ImmutableList.builder();

        for (ConnectorExpression projection : projections) {
            String variablePrefix;
            if (projection instanceof Variable) {
                variablePrefix = "projected_variable_";
            } else if (projection instanceof FieldDereference) {
                variablePrefix = "projected_dereference_";
            } else if (projection instanceof Call) {
                variablePrefix = "projected_call_";
            } else if (projection instanceof io.trino.spi.expression.Constant) {
                throw new UnsupportedOperationException("constant expression should not be pushed to the connector");
            } else {
                throw new UnsupportedOperationException();
            }

            String newVariableName = variablePrefix + projection.toString();
            Variable newVariable = new Variable(newVariableName, projection.getType());
            ColumnHandle newColumnHandle = new TpchColumnHandle(newVariableName, projection.getType());
            outputExpressions.add(newVariable);
            outputAssignments.add(new Assignment(newVariableName, newColumnHandle, projection.getType()));
            projectedColumnsBuilder.add(newColumnHandle);
        }

        return Optional.of(new ProjectionApplicationResult<>(
                new MockConnectorTableHandle(projectedTableName, TupleDomain.all(), Optional.of(projectedColumnsBuilder.build())),
                outputExpressions.build(),
                outputAssignments.build(),
                false));
    }

    /**
     * Builds {@code inputProjections} on top of a scan of {@code test_schema.test_table}.
     *
     * <p>The scan outputs every symbol in {@code types}, each mapped to {@code columnHandle}. It carries fixed
     * statistics: 42 output rows, and {@code baseColumn} has a nulls fraction of 0 and 33 distinct values.
     */
    private static PlanNode projectOverScanWithStatistics(
            PlanBuilder p,
            RuleTester ruleTester,
            Map<Symbol, Type> types,
            Assignments inputProjections,
            ColumnHandle columnHandle,
            Symbol baseColumn) {
        // Declare every symbol, with its type, on the plan builder
        types.forEach((symbol, type) -> p.symbol(symbol.name(), type));
        return p.project(
                inputProjections,
                p.tableScan(tableScan -> tableScan
                        .setTableHandle(ruleTester.getCurrentCatalogTableHandle(TEST_SCHEMA, TEST_TABLE))
                        .setSymbols(ImmutableList.copyOf(types.keySet()))
                        .setAssignments(types.keySet().stream().collect(Collectors.toMap(Function.identity(), v -> columnHandle)))
                        .setStatistics(Optional.of(PlanNodeStatsEstimate.builder()
                                .setOutputRowCount(42.0)
                                .addSymbolStatistics(baseColumn, SymbolStatsEstimate.builder()
                                        .setNullsFraction(0.0)
                                        .setDistinctValuesCount(33.0)
                                        .build())
                                .build()))));
    }

    /** Creates the rule under test from the tester's planner context. */
    private static PushProjectionIntoTableScan createRule(RuleTester tester) {
        PlannerContext plannerContext = tester.getPlannerContext();
        return new PushProjectionIntoTableScan(plannerContext, new ScalarStatsCalculator(plannerContext));
    }

    /** Creates a {@link TpchColumnHandle} with the given name and type. */
    private static ColumnHandle column(String name, Type type) {
        return new TpchColumnHandle(name, type);
    }
}

