/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.TableHandle;
import io.trino.spi.connector.BasicRelationStatistics;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.JoinApplicationResult;
import io.trino.spi.connector.JoinStatistics;
import io.trino.spi.connector.JoinType;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.type.Type;
import io.trino.sql.PlannerContext;
import io.trino.sql.ir.IrUtils;
import io.trino.sql.planner.ConnectorExpressionTranslator;
import io.trino.sql.planner.IrTypeAnalyzer;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.Rules;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.tree.BooleanLiteral;
import io.trino.sql.tree.Expression;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

/**
 * Optimizer rule that offers a join of two table scans to the connector via {@code Metadata.applyJoin}.
 *
 * <p>If the connector accepts the join, the {@link JoinNode} is replaced by a single {@link TableScanNode}
 * over the connector-provided join table handle. The new scan reuses the join's plan node id. It is wrapped
 * in an identity {@link ProjectNode} that restores exactly the join's output symbols.
 */
public class PushJoinIntoTableScan
        implements Rule<JoinNode> {
    private static final Capture<TableScanNode> LEFT_TABLE_SCAN = Capture.newCapture();
    private static final Capture<TableScanNode> RIGHT_TABLE_SCAN = Capture.newCapture();

    // Matches a join whose left and right inputs are both table scans.
    private static final Pattern<JoinNode> PATTERN = Patterns.join()
            .with(Patterns.Join.left().matching(Patterns.tableScan().capturedAs(LEFT_TABLE_SCAN)))
            .with(Patterns.Join.right().matching(Patterns.tableScan().capturedAs(RIGHT_TABLE_SCAN)));

    private final PlannerContext plannerContext;
    private final IrTypeAnalyzer typeAnalyzer;

    public PushJoinIntoTableScan(PlannerContext plannerContext, IrTypeAnalyzer typeAnalyzer) {
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
        this.typeAnalyzer = Objects.requireNonNull(typeAnalyzer, "typeAnalyzer is null");
    }

    @Override
    public Pattern<JoinNode> getPattern() {
        return PATTERN;
    }

    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.isAllowPushdownIntoConnectors(session);
    }

    @Override
    public Rule.Result apply(JoinNode joinNode, Captures captures, Rule.Context context) {
        if (joinNode.isCrossJoin()) {
            return Rule.Result.empty();
        }

        TableScanNode left = captures.get(LEFT_TABLE_SCAN);
        TableScanNode right = captures.get(RIGHT_TABLE_SCAN);
        Verify.verify(!left.isUpdateTarget() && !right.isUpdateTarget(), "Unexpected Join over for-update table scan");

        Expression effectiveFilter = getEffectiveFilter(joinNode);
        ConnectorExpressionTranslator.ConnectorExpressionTranslation translation = ConnectorExpressionTranslator.translateConjuncts(
                context.getSession(),
                effectiveFilter,
                context.getSymbolAllocator().getTypes(),
                plannerContext,
                typeAnalyzer);
        // Only translation.connectorExpression() is handed to the connector. Bail out unless the whole
        // join condition was translated; otherwise the untranslated remainder would be dropped.
        if (!translation.remainingExpression().equals(BooleanLiteral.TRUE_LITERAL)) {
            return Rule.Result.empty();
        }

        // Skip when either side is provably empty. The orElseThrow() calls on getDomains() below rely on
        // both enforced constraints being non-none.
        if (left.getEnforcedConstraint().isNone() || right.getEnforcedConstraint().isNone()) {
            return Rule.Result.empty();
        }

        Map<String, ColumnHandle> leftAssignments = assignmentsByName(left);
        Map<String, ColumnHandle> rightAssignments = assignmentsByName(right);
        JoinStatistics joinStatistics = getJoinStatistics(joinNode, left, right, context);
        Optional<JoinApplicationResult<TableHandle>> joinApplicationResult = plannerContext.getMetadata().applyJoin(
                context.getSession(),
                getJoinType(joinNode),
                left.getTable(),
                right.getTable(),
                translation.connectorExpression(),
                leftAssignments,
                rightAssignments,
                joinStatistics);
        if (joinApplicationResult.isEmpty()) {
            return Rule.Result.empty();
        }

        JoinApplicationResult<TableHandle> joinResult = joinApplicationResult.get();
        TableHandle handle = joinResult.getTableHandle();
        Map<ColumnHandle, ColumnHandle> leftColumnHandlesMapping = joinResult.getLeftColumnHandles();
        Map<ColumnHandle, ColumnHandle> rightColumnHandlesMapping = joinResult.getRightColumnHandles();

        // Left-side symbols first, then right-side symbols; this order becomes the new scan's output order.
        Map<Symbol, ColumnHandle> assignments = ImmutableMap.<Symbol, ColumnHandle>builder()
                .putAll(remapAssignments(left.getAssignments(), leftColumnHandlesMapping))
                .putAll(remapAssignments(right.getAssignments(), rightColumnHandlesMapping))
                .buildOrThrow();

        // Carry the inputs' enforced constraints over to the join table's columns. A side that an outer
        // join can null-extend (left of RIGHT/FULL, right of LEFT/FULL) must additionally admit NULL.
        io.trino.sql.planner.plan.JoinType joinType = joinNode.getType();
        boolean leftNullExtended = joinType == io.trino.sql.planner.plan.JoinType.RIGHT || joinType == io.trino.sql.planner.plan.JoinType.FULL;
        boolean rightNullExtended = joinType == io.trino.sql.planner.plan.JoinType.LEFT || joinType == io.trino.sql.planner.plan.JoinType.FULL;
        TupleDomain<ColumnHandle> leftConstraint = deriveConstraint(left.getEnforcedConstraint(), leftColumnHandlesMapping, leftNullExtended);
        TupleDomain<ColumnHandle> rightConstraint = deriveConstraint(right.getEnforcedConstraint(), rightColumnHandlesMapping, rightNullExtended);
        // Both inputs were verified non-none above, so their domain maps are present.
        TupleDomain<ColumnHandle> newEnforcedConstraint = TupleDomain.withColumnDomains(
                ImmutableMap.<ColumnHandle, Domain>builder()
                        .putAll(leftConstraint.getDomains().orElseThrow())
                        .putAll(rightConstraint.getDomains().orElseThrow())
                        .buildOrThrow());

        // The new scan exposes every column of both inputs; project back down to exactly the join's outputs.
        return Rule.Result.ofPlanNode(new ProjectNode(
                context.getIdAllocator().getNextId(),
                new TableScanNode(
                        joinNode.getId(),
                        handle,
                        ImmutableList.copyOf(assignments.keySet()),
                        assignments,
                        newEnforcedConstraint,
                        Rules.deriveTableStatisticsForPushdown(
                                context.getStatsProvider(),
                                context.getSession(),
                                joinResult.isPrecalculateStatistics(),
                                joinNode),
                        false,
                        Optional.empty()),
                Assignments.identity(joinNode.getOutputSymbols())));
    }

    // Keys a scan's assignments by symbol name, the form passed to Metadata.applyJoin.
    private static Map<String, ColumnHandle> assignmentsByName(TableScanNode tableScan) {
        return tableScan.getAssignments().entrySet().stream()
                .collect(ImmutableMap.toImmutableMap(entry -> entry.getKey().getName(), Map.Entry::getValue));
    }

    // Rewrites symbol -> source column assignments onto the join table's column handles, failing with a
    // descriptive message if the connector's mapping omits a column.
    private static Map<Symbol, ColumnHandle> remapAssignments(Map<Symbol, ColumnHandle> sourceAssignments, Map<ColumnHandle, ColumnHandle> columnMapping) {
        return sourceAssignments.entrySet().stream()
                .collect(ImmutableMap.toImmutableMap(
                        Map.Entry::getKey,
                        entry -> Verify.verifyNotNull(
                                columnMapping.get(entry.getValue()),
                                "Connector did not return a join column handle for %s",
                                entry.getValue())));
    }

    /**
     * Exposes estimated statistics for the two join inputs and for the join itself.
     * Each estimate is looked up from the stats provider only when its getter is called.
     */
    private JoinStatistics getJoinStatistics(JoinNode join, TableScanNode left, TableScanNode right, Rule.Context context) {
        return new JoinStatistics() {
            @Override
            public Optional<BasicRelationStatistics> getLeftStatistics() {
                return getBasicRelationStats(left, left.getOutputSymbols());
            }

            @Override
            public Optional<BasicRelationStatistics> getRightStatistics() {
                return getBasicRelationStats(right, right.getOutputSymbols());
            }

            @Override
            public Optional<BasicRelationStatistics> getJoinStatistics() {
                return getBasicRelationStats(join, join.getOutputSymbols());
            }

            // Row count and output size for the node; empty when either estimate is unknown (NaN).
            private Optional<BasicRelationStatistics> getBasicRelationStats(PlanNode node, List<Symbol> outputSymbols) {
                PlanNodeStatsEstimate stats = context.getStatsProvider().getStats(node);
                TypeProvider types = context.getSymbolAllocator().getTypes();
                double outputRowCount = stats.getOutputRowCount();
                double outputSize = stats.getOutputSizeInBytes(outputSymbols, types);
                if (Double.isNaN(outputRowCount) || Double.isNaN(outputSize)) {
                    return Optional.empty();
                }
                return Optional.of(new BasicRelationStatistics((long) outputRowCount, (long) outputSize));
            }
        };
    }

    /**
     * Re-keys an input's enforced constraint onto the join table's column handles.
     *
     * @param constraint enforced constraint of one join input
     * @param columnMapping source column handle to join-table column handle, as returned by the connector
     * @param nullable whether the outer join can null-extend this input; if so, every domain is widened to admit NULL
     * @return the constraint expressed over the join table's columns
     */
    private TupleDomain<ColumnHandle> deriveConstraint(TupleDomain<ColumnHandle> constraint, Map<ColumnHandle, ColumnHandle> columnMapping, boolean nullable) {
        TupleDomain<ColumnHandle> widened = nullable
                ? constraint.transformDomains((column, domain) -> domain.union(Domain.onlyNull(domain.getType())))
                : constraint;
        return widened.transformKeys(columnMapping::get);
    }

    /**
     * Builds the complete join condition. It is the conjunction of all equi-join clauses, additionally
     * ANDed with the join's residual filter when one is present.
     */
    public Expression getEffectiveFilter(JoinNode node) {
        List<Expression> equiJoinConjuncts = node.getCriteria().stream()
                .map(JoinNode.EquiJoinClause::toExpression)
                .collect(ImmutableList.toImmutableList());
        Expression equiJoinCondition = IrUtils.and(equiJoinConjuncts);
        return node.getFilter()
                .map(filter -> IrUtils.and(equiJoinCondition, filter))
                .orElse(equiJoinCondition);
    }

    // Maps the planner's join type onto the SPI join type understood by connectors. The switch is
    // exhaustive over the planner enum, so a newly added join type is a compile error here.
    private JoinType getJoinType(JoinNode joinNode) {
        return switch (joinNode.getType()) {
            case INNER -> JoinType.INNER;
            case LEFT -> JoinType.LEFT_OUTER;
            case RIGHT -> JoinType.RIGHT_OUTER;
            case FULL -> JoinType.FULL_OUTER;
        };
    }
}

