/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.metadata.OperatorNotFoundException;
import io.trino.operator.join.JoinUtils;
import io.trino.spi.function.OperatorType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeManager;
import io.trino.sql.DynamicFilters;
import io.trino.sql.PlannerContext;
import io.trino.sql.ir.BooleanLiteral;
import io.trino.sql.ir.Cast;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.ExpressionRewriter;
import io.trino.sql.ir.ExpressionTreeRewriter;
import io.trino.sql.ir.IrUtils;
import io.trino.sql.ir.LogicalExpression;
import io.trino.sql.ir.SymbolReference;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.optimizations.PlanOptimizer;
import io.trino.sql.planner.plan.ChildReplacer;
import io.trino.sql.planner.plan.DynamicFilterId;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanVisitor;
import io.trino.sql.planner.plan.SemiJoinNode;
import io.trino.sql.planner.plan.SpatialJoinNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.type.TypeCoercion;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Plan optimizer that strips dynamic filters which cannot be used.
 *
 * <p>The plan is traversed recursively. Each node receives the set of dynamic filter ids produced by an
 * enclosing {@link JoinNode} (probe side only) or {@link SemiJoinNode} (source side only). A dynamic
 * filter conjunct is retained only when all of the following hold:
 * <ul>
 *   <li>it sits in a {@link FilterNode} whose rewritten source is a {@link TableScanNode};</li>
 *   <li>its id is in the allowed set;</li>
 *   <li>its input is a symbol reference, or a cast of one that is an implicit coercion with a
 *       {@code SATURATED_FLOOR_CAST} back to the source type.</li>
 * </ul>
 * All other dynamic filter predicates are removed.
 *
 * <p>On the way back up, joins and semi-joins keep only the dynamic filter ids that were actually
 * consumed below them on the probe/source side.
 */
public class RemoveUnsupportedDynamicFilters
        implements PlanOptimizer {
    private final PlannerContext plannerContext;

    public RemoveUnsupportedDynamicFilters(PlannerContext plannerContext) {
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
    }

    @Override
    public PlanNode optimize(PlanNode plan, PlanOptimizer.Context context) {
        // At the root no enclosing join produces dynamic filters, so nothing is allowed yet.
        PlanWithConsumedDynamicFilters result = plan.accept(new Rewriter(), ImmutableSet.of());
        return result.getNode();
    }

    /**
     * Visitor whose context is the set of dynamic filter ids that may be consumed at or below the
     * current node. Its result pairs the rewritten node with the ids consumed beneath it.
     *
     * <p>This is an inner (non-static) class because it reads the enclosing {@code plannerContext}.
     */
    private class Rewriter
            extends PlanVisitor<PlanWithConsumedDynamicFilters, Set<DynamicFilterId>> {
        private final TypeCoercion typeCoercion;

        Rewriter() {
            TypeManager typeManager = plannerContext.getTypeManager();
            this.typeCoercion = new TypeCoercion(typeManager::getType);
        }

        /**
         * Default handling: rewrite every source with the same allowed set, then union whatever the
         * children consumed.
         */
        @Override
        protected PlanWithConsumedDynamicFilters visitPlan(PlanNode node, Set<DynamicFilterId> allowedDynamicFilterIds) {
            List<PlanWithConsumedDynamicFilters> children = node.getSources().stream()
                    .map(source -> source.accept(this, allowedDynamicFilterIds))
                    .collect(ImmutableList.toImmutableList());

            List<PlanNode> newSources = children.stream()
                    .map(PlanWithConsumedDynamicFilters::getNode)
                    .collect(Collectors.toList());
            PlanNode result = ChildReplacer.replaceChildren(node, newSources);

            Set<DynamicFilterId> consumedDynamicFilterIds = children.stream()
                    .map(PlanWithConsumedDynamicFilters::getConsumedDynamicFilterIds)
                    .flatMap(Collection::stream)
                    .collect(ImmutableSet.toImmutableSet());
            return new PlanWithConsumedDynamicFilters(result, consumedDynamicFilterIds);
        }

        /**
         * Handles a join node.
         *
         * <ul>
         *   <li>The probe (left) side may consume this join's dynamic filters in addition to the
         *       inherited ones. The build (right) side may consume only the inherited ones.</li>
         *   <li>Dynamic filters not consumed on the probe side are dropped from the join.</li>
         *   <li>Dynamic filter conjuncts inside the join's own filter are always removed.</li>
         * </ul>
         */
        @Override
        public PlanWithConsumedDynamicFilters visitJoin(JoinNode node, Set<DynamicFilterId> allowedDynamicFilterIds) {
            Map<DynamicFilterId, Symbol> currentJoinDynamicFilters = JoinUtils.getJoinDynamicFilters(node);
            Set<DynamicFilterId> allowedDynamicFilterIdsProbeSide = ImmutableSet.<DynamicFilterId>builder()
                    .addAll(currentJoinDynamicFilters.keySet())
                    .addAll(allowedDynamicFilterIds)
                    .build();

            PlanWithConsumedDynamicFilters leftResult = node.getLeft().accept(this, allowedDynamicFilterIdsProbeSide);
            Set<DynamicFilterId> consumedProbeSide = leftResult.getConsumedDynamicFilterIds();

            // Keep only this join's dynamic filters that some probe-side filter actually consumed.
            Map<DynamicFilterId, Symbol> dynamicFilters = currentJoinDynamicFilters.entrySet().stream()
                    .filter(entry -> consumedProbeSide.contains(entry.getKey()))
                    .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));

            PlanWithConsumedDynamicFilters rightResult = node.getRight().accept(this, allowedDynamicFilterIds);

            // Filters produced by this join are satisfied here and must not be reported further up.
            Set<DynamicFilterId> consumed = new HashSet<>(rightResult.getConsumedDynamicFilterIds());
            consumed.addAll(consumedProbeSide);
            consumed.removeAll(dynamicFilters.keySet());

            Optional<Expression> filter = node.getFilter()
                    .map(this::removeAllDynamicFilters)
                    .filter(expression -> !expression.equals(BooleanLiteral.TRUE_LITERAL));

            PlanNode left = leftResult.getNode();
            PlanNode right = rightResult.getNode();
            boolean unchanged = left.equals(node.getLeft())
                    && right.equals(node.getRight())
                    && dynamicFilters.equals(currentJoinDynamicFilters)
                    && filter.equals(node.getFilter());
            if (unchanged) {
                return new PlanWithConsumedDynamicFilters(node, consumed);
            }

            // When the join declared no dynamic filters, keep its dynamic filter map empty.
            Map<DynamicFilterId, Symbol> newDynamicFilters = node.getDynamicFilters().isEmpty() ? ImmutableMap.of() : dynamicFilters;
            JoinNode rewritten = new JoinNode(
                    node.getId(),
                    node.getType(),
                    left,
                    right,
                    node.getCriteria(),
                    node.getLeftOutputSymbols(),
                    node.getRightOutputSymbols(),
                    node.isMaySkipOutputDuplicates(),
                    filter,
                    node.getLeftHashSymbol(),
                    node.getRightHashSymbol(),
                    node.getDistributionType(),
                    node.isSpillable(),
                    newDynamicFilters,
                    node.getReorderJoinStatsAndCost());
            return new PlanWithConsumedDynamicFilters(rewritten, consumed);
        }

        /**
         * Handles a spatial join. It produces no dynamic filters of its own, so both sides inherit
         * the same allowed set, and any dynamic filter inside its filter is removed.
         */
        @Override
        public PlanWithConsumedDynamicFilters visitSpatialJoin(SpatialJoinNode node, Set<DynamicFilterId> allowedDynamicFilterIds) {
            PlanWithConsumedDynamicFilters leftResult = node.getLeft().accept(this, allowedDynamicFilterIds);
            PlanWithConsumedDynamicFilters rightResult = node.getRight().accept(this, allowedDynamicFilterIds);
            Set<DynamicFilterId> consumed = ImmutableSet.<DynamicFilterId>builder()
                    .addAll(leftResult.getConsumedDynamicFilterIds())
                    .addAll(rightResult.getConsumedDynamicFilterIds())
                    .build();

            Expression filter = removeAllDynamicFilters(node.getFilter());
            if (!node.getFilter().equals(filter) || leftResult.getNode() != node.getLeft() || rightResult.getNode() != node.getRight()) {
                SpatialJoinNode rewritten = new SpatialJoinNode(
                        node.getId(),
                        node.getType(),
                        leftResult.getNode(),
                        rightResult.getNode(),
                        node.getOutputSymbols(),
                        filter,
                        node.getLeftPartitionSymbol(),
                        node.getRightPartitionSymbol(),
                        node.getKdbTree());
                return new PlanWithConsumedDynamicFilters(rewritten, consumed);
            }
            return new PlanWithConsumedDynamicFilters(node, consumed);
        }

        /**
         * Handles a semi-join. Its dynamic filter (if any) may be consumed only on the source side.
         * If nothing consumes it, the filter id is dropped from the node.
         */
        @Override
        public PlanWithConsumedDynamicFilters visitSemiJoin(SemiJoinNode node, Set<DynamicFilterId> allowedDynamicFilterIds) {
            Optional<DynamicFilterId> dynamicFilterIdOptional = JoinUtils.getSemiJoinDynamicFilterId(node);
            if (dynamicFilterIdOptional.isEmpty()) {
                return visitPlan(node, allowedDynamicFilterIds);
            }

            DynamicFilterId dynamicFilterId = dynamicFilterIdOptional.get();
            Set<DynamicFilterId> allowedDynamicFilterIdsSourceSide = ImmutableSet.<DynamicFilterId>builder()
                    .add(dynamicFilterId)
                    .addAll(allowedDynamicFilterIds)
                    .build();

            PlanWithConsumedDynamicFilters sourceResult = node.getSource().accept(this, allowedDynamicFilterIdsSourceSide);
            PlanWithConsumedDynamicFilters filteringSourceResult = node.getFilteringSource().accept(this, allowedDynamicFilterIds);

            Set<DynamicFilterId> consumed = new HashSet<>(filteringSourceResult.getConsumedDynamicFilterIds());
            consumed.addAll(sourceResult.getConsumedDynamicFilterIds());

            Optional<DynamicFilterId> newFilterId;
            if (consumed.remove(dynamicFilterId)) {
                // Consumed below: keep it on this node and stop propagating it upwards.
                newFilterId = Optional.of(dynamicFilterId);
            }
            else {
                newFilterId = Optional.empty();
            }

            PlanNode newSource = sourceResult.getNode();
            PlanNode newFilteringSource = filteringSourceResult.getNode();
            boolean unchanged = newSource.equals(node.getSource())
                    && newFilteringSource.equals(node.getFilteringSource())
                    && newFilterId.equals(dynamicFilterIdOptional);
            if (unchanged) {
                return new PlanWithConsumedDynamicFilters(node, consumed);
            }

            Optional<DynamicFilterId> rewrittenFilterId = node.getDynamicFilterId().isEmpty() ? Optional.empty() : newFilterId;
            SemiJoinNode rewritten = new SemiJoinNode(
                    node.getId(),
                    newSource,
                    newFilteringSource,
                    node.getSourceJoinSymbol(),
                    node.getFilteringSourceJoinSymbol(),
                    node.getSemiJoinOutput(),
                    node.getSourceHashSymbol(),
                    node.getFilteringSourceHashSymbol(),
                    node.getDistributionType(),
                    rewrittenFilterId);
            return new PlanWithConsumedDynamicFilters(rewritten, consumed);
        }

        /**
         * Rewrites a filter node.
         *
         * <ul>
         *   <li>Directly above a table scan: only allowed and supported dynamic filters are kept,
         *       and each kept one is recorded as consumed.</li>
         *   <li>Anywhere else: all dynamic filters are removed.</li>
         *   <li>If the predicate becomes {@code TRUE}, the filter node is dropped entirely.</li>
         * </ul>
         */
        @Override
        public PlanWithConsumedDynamicFilters visitFilter(FilterNode node, Set<DynamicFilterId> allowedDynamicFilterIds) {
            PlanWithConsumedDynamicFilters result = node.getSource().accept(this, allowedDynamicFilterIds);

            Expression original = node.getPredicate();
            ImmutableSet.Builder<DynamicFilterId> consumedDynamicFilterIds = ImmutableSet.<DynamicFilterId>builder()
                    .addAll(result.getConsumedDynamicFilterIds());

            PlanNode source = result.getNode();
            Expression modified;
            if (source instanceof TableScanNode) {
                modified = removeDynamicFilters(original, allowedDynamicFilterIds, consumedDynamicFilterIds);
            }
            else {
                modified = removeAllDynamicFilters(original);
            }

            if (BooleanLiteral.TRUE_LITERAL.equals(modified)) {
                return new PlanWithConsumedDynamicFilters(source, consumedDynamicFilterIds.build());
            }
            if (!original.equals(modified) || source != node.getSource()) {
                return new PlanWithConsumedDynamicFilters(new FilterNode(node.getId(), source, modified), consumedDynamicFilterIds.build());
            }
            return new PlanWithConsumedDynamicFilters(node, consumedDynamicFilterIds.build());
        }

        /**
         * Keeps the non-dynamic conjuncts of {@code expression}, plus those dynamic filter
         * conjuncts whose id is allowed and whose input is supported. Each kept id is added to
         * {@code consumedDynamicFilterIds}.
         *
         * <p>Dynamic filters nested inside AND/OR terms of a conjunct are always removed.
         */
        private Expression removeDynamicFilters(Expression expression, Set<DynamicFilterId> allowedDynamicFilterIds, ImmutableSet.Builder<DynamicFilterId> consumedDynamicFilterIds) {
            ImmutableList.Builder<Expression> retained = ImmutableList.builder();
            for (Expression conjunct : IrUtils.extractConjuncts(expression)) {
                Expression rewritten = removeNestedDynamicFilters(conjunct);
                boolean keep = DynamicFilters.getDescriptor(rewritten)
                        .map(descriptor -> {
                            if (allowedDynamicFilterIds.contains(descriptor.getId()) && isSupportedDynamicFilterExpression(descriptor.getInput())) {
                                consumedDynamicFilterIds.add(descriptor.getId());
                                return true;
                            }
                            return false;
                        })
                        // Not a dynamic filter: an ordinary predicate that must be preserved.
                        .orElse(true);
                if (keep) {
                    retained.add(rewritten);
                }
            }
            return IrUtils.combineConjuncts(retained.build());
        }

        /**
         * Returns whether a dynamic filter input expression is usable. Supported forms are a bare
         * symbol reference, or a cast of a symbol reference that satisfies both of these:
         * <ul>
         *   <li>the cast is an implicit coercion;</li>
         *   <li>a {@code SATURATED_FLOOR_CAST} exists from the cast's target type back to its
         *       source type.</li>
         * </ul>
         */
        private boolean isSupportedDynamicFilterExpression(Expression expression) {
            if (expression instanceof SymbolReference) {
                return true;
            }
            if (!(expression instanceof Cast)) {
                return false;
            }
            Cast castExpression = (Cast) expression;
            if (!(castExpression.expression() instanceof SymbolReference)) {
                return false;
            }
            Type castSourceType = castExpression.expression().type();
            Type castTargetType = castExpression.type();
            if (!typeCoercion.canCoerce(castSourceType, castTargetType)) {
                return false;
            }
            // The argument order (target -> source) is intentional: the check is for a conversion
            // from the cast's target type back to the probe symbol's type.
            return doesSaturatedFloorCastOperatorExist(castTargetType, castSourceType);
        }

        /** Returns whether a {@code SATURATED_FLOOR_CAST} operator is registered from {@code fromType} to {@code toType}. */
        private boolean doesSaturatedFloorCastOperatorExist(Type fromType, Type toType) {
            try {
                plannerContext.getMetadata().getCoercion(OperatorType.SATURATED_FLOOR_CAST, fromType, toType);
            }
            catch (OperatorNotFoundException ignored) {
                // Absence of the operator is an expected answer, not an error.
                return false;
            }
            return true;
        }

        /**
         * Removes every dynamic filter from {@code expression}: those nested in AND/OR terms, and
         * those appearing as top-level conjuncts.
         */
        private Expression removeAllDynamicFilters(Expression expression) {
            Expression rewrittenExpression = removeNestedDynamicFilters(expression);
            DynamicFilters.ExtractResult extractResult = DynamicFilters.extractDynamicFilters(rewrittenExpression);
            if (extractResult.getDynamicConjuncts().isEmpty()) {
                return rewrittenExpression;
            }
            return IrUtils.combineConjuncts(extractResult.getStaticConjuncts());
        }

        /**
         * Replaces each dynamic filter that is a direct term of a {@link LogicalExpression}
         * (at any depth) with {@code TRUE}. The logical expression is then rebuilt with
         * {@link IrUtils#combinePredicates}. Unmodified nodes are returned as-is.
         */
        private Expression removeNestedDynamicFilters(Expression expression) {
            return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<Void>() {
                @Override
                public Expression rewriteLogicalExpression(LogicalExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter) {
                    LogicalExpression rewrittenNode = treeRewriter.defaultRewrite(node, context);
                    boolean modified = node != rewrittenNode;

                    ImmutableList.Builder<Expression> terms = ImmutableList.builder();
                    for (Expression term : rewrittenNode.getTerms()) {
                        if (DynamicFilters.isDynamicFilter(term)) {
                            terms.add(BooleanLiteral.TRUE_LITERAL);
                            modified = true;
                        }
                        else {
                            terms.add(term);
                        }
                    }

                    if (!modified) {
                        return node;
                    }
                    return IrUtils.combinePredicates(node.getOperator(), terms.build());
                }
            }, expression);
        }
    }

    /** Immutable pair of a rewritten plan node and the dynamic filter ids consumed beneath it. */
    private static class PlanWithConsumedDynamicFilters {
        private final PlanNode node;
        private final Set<DynamicFilterId> consumedDynamicFilterIds;

        PlanWithConsumedDynamicFilters(PlanNode node, Set<DynamicFilterId> consumedDynamicFilterIds) {
            this.node = Objects.requireNonNull(node, "node is null");
            // Defensive copy; also rejects a null set.
            this.consumedDynamicFilterIds = ImmutableSet.copyOf(consumedDynamicFilterIds);
        }

        PlanNode getNode() {
            return node;
        }

        Set<DynamicFilterId> getConsumedDynamicFilterIds() {
            return consumedDynamicFilterIds;
        }
    }
}

