/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.sanity;

import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import io.trino.Session;
import io.trino.cost.StatsAndCosts;
import io.trino.execution.warnings.WarningCollector;
import io.trino.sql.DynamicFilters;
import io.trino.sql.PlannerContext;
import io.trino.sql.planner.SubExpressionExtractor;
import io.trino.sql.planner.TypeAnalyzer;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.plan.DynamicFilterId;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.OutputNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanVisitor;
import io.trino.sql.planner.plan.SemiJoinNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.planner.planprinter.PlanPrinter;
import io.trino.sql.planner.sanity.PlanSanityChecker;
import io.trino.sql.tree.Cast;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.SymbolReference;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;

/**
 * Plan sanity check for dynamic filters.
 *
 * <p>Walks the plan bottom-up, collecting the ids of dynamic filters that are consumed
 * (referenced in a {@link FilterNode} predicate) and matching them against the joins and
 * semi-joins that produce them. It verifies that:
 * <ul>
 *   <li>every dynamic filter of a {@link JoinNode} is consumed by its left (probe) side and
 *       not by its right (build) side;</li>
 *   <li>a {@link SemiJoinNode}'s dynamic filter is consumed by its source side and not by its
 *       filtering source side;</li>
 *   <li>a join's own filter predicate contains no dynamic filters (they must have been pushed down);</li>
 *   <li>dynamic filter predicates only appear in filters whose source is a {@link TableScanNode};</li>
 *   <li>each dynamic filter input is a {@link SymbolReference} or a {@link Cast} of one;</li>
 *   <li>no consumed dynamic filter is left unmatched when an {@link OutputNode} is reached.</li>
 * </ul>
 * Violations are reported via {@link Verify#verify}, which throws an unchecked exception.
 */
public class DynamicFiltersChecker
implements PlanSanityChecker.Checker {
    /**
     * Validates dynamic filter usage in {@code plan}.
     *
     * <p>On failure, a textual rendering of the plan is attached to the thrown exception as a
     * suppressed exception to aid debugging, and the original exception is rethrown.
     */
    @Override
    public void validate(PlanNode plan, Session session, PlannerContext plannerContext, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) {
        try {
            validate(plan);
        }
        catch (RuntimeException e) {
            try {
                int nestLevel = 4;
                String explain = PlanPrinter.textLogicalPlan(plan, types, plannerContext.getMetadata(), plannerContext.getFunctionManager(), StatsAndCosts.empty(), session, nestLevel, false);
                e.addSuppressed(new Exception("Current plan:\n" + explain));
            }
            catch (RuntimeException printingFailure) {
                // Rendering the plan is best-effort: it must never mask the validation failure,
                // but the rendering error is kept visible instead of being silently discarded.
                e.addSuppressed(printingFailure);
            }
            throw e;
        }
    }

    /**
     * Runs the dynamic-filter visitor over {@code plan}. Each visit method returns the ids of
     * dynamic filters consumed in that subtree that have not yet been matched with a producer.
     */
    private void validate(PlanNode plan) {
        plan.accept(new PlanVisitor<Set<DynamicFilterId>, Void>() {

            /**
             * Nodes without a dedicated visit method are transparent to this check: return the
             * union of the unmatched dynamic filters consumed by all sources.
             */
            @Override
            protected Set<DynamicFilterId> visitPlan(PlanNode node, Void context) {
                Set<DynamicFilterId> consumed = new HashSet<>();
                for (PlanNode source : node.getSources()) {
                    consumed.addAll(source.accept(this, context));
                }
                return consumed;
            }

            /** At the plan output every consumed dynamic filter must have been matched. */
            @Override
            public Set<DynamicFilterId> visitOutput(OutputNode node, Void context) {
                Set<DynamicFilterId> unmatched = visitPlan(node, context);
                Verify.verify(unmatched.isEmpty(), "All consumed dynamic filters could not be matched with a join/semi-join.");
                return unmatched;
            }

            /**
             * Checks that the join's dynamic filters are consumed only by the left (probe) side,
             * and that none remain in the join's own filter predicate. Returns what is still
             * unmatched from both sides after removing this join's filters.
             */
            @Override
            public Set<DynamicFilterId> visitJoin(JoinNode node, Void context) {
                Set<DynamicFilterId> currentJoinDynamicFilters = node.getDynamicFilters().keySet();

                Set<DynamicFilterId> consumedProbeSide = node.getLeft().accept(this, context);
                Set<DynamicFilterId> unconsumedByProbeSide = Sets.difference(currentJoinDynamicFilters, consumedProbeSide);
                Verify.verify(unconsumedByProbeSide.isEmpty(), "Dynamic filters %s present in join were not fully consumed by it's probe side.", unconsumedByProbeSide);

                Set<DynamicFilterId> consumedBuildSide = node.getRight().accept(this, context);
                // Filters produced by this join that the build side (wrongly) consumed.
                Set<DynamicFilterId> consumedByBuildSide = Sets.intersection(currentJoinDynamicFilters, consumedBuildSide);
                Verify.verify(consumedByBuildSide.isEmpty(), "Dynamic filters %s present in join were consumed by it's build side.", consumedByBuildSide);

                // Dynamic filters must not remain in the join's own filter predicate.
                List<DynamicFilters.Descriptor> nonPushedDownFilters = node.getFilter()
                        .map(DynamicFilters::extractDynamicFilters)
                        .map(DynamicFilters.ExtractResult::getDynamicConjuncts)
                        .orElse(ImmutableList.of());
                Verify.verify(nonPushedDownFilters.isEmpty(), "Dynamic filters %s present in join filter predicate were not pushed down.", nonPushedDownFilters);

                Set<DynamicFilterId> unmatched = new HashSet<>(consumedBuildSide);
                unmatched.addAll(consumedProbeSide);
                unmatched.removeAll(currentJoinDynamicFilters);
                return ImmutableSet.copyOf(unmatched);
            }

            /**
             * Checks that the semi-join's dynamic filter (if any) is consumed by its source side
             * and not by its filtering source side. Returns what is still unmatched from both sides.
             */
            @Override
            public Set<DynamicFilterId> visitSemiJoin(SemiJoinNode node, Void context) {
                Set<DynamicFilterId> consumedSourceSide = node.getSource().accept(this, context);
                Set<DynamicFilterId> consumedFilteringSourceSide = node.getFilteringSource().accept(this, context);

                Set<DynamicFilterId> unmatched = new HashSet<>(consumedSourceSide);
                unmatched.addAll(consumedFilteringSourceSide);

                if (node.getDynamicFilterId().isPresent()) {
                    DynamicFilterId dynamicFilterId = node.getDynamicFilterId().get();
                    Verify.verify(consumedSourceSide.contains(dynamicFilterId), "The dynamic filter %s present in semi-join was not consumed by it's source side.", dynamicFilterId);
                    Verify.verify(!consumedFilteringSourceSide.contains(dynamicFilterId), "The dynamic filter %s present in semi-join was consumed by it's filtering source side.", dynamicFilterId);
                    unmatched.remove(dynamicFilterId);
                }
                return ImmutableSet.copyOf(unmatched);
            }

            /**
             * Records the dynamic filters referenced in this filter's predicate as consumed,
             * after checking that the filter sits directly on a table scan and that each
             * dynamic filter input expression has an allowed shape.
             */
            @Override
            public Set<DynamicFilterId> visitFilter(FilterNode node, Void context) {
                List<DynamicFilters.Descriptor> dynamicFilters = DynamicFiltersChecker.extractDynamicPredicates(node.getPredicate());
                if (!dynamicFilters.isEmpty()) {
                    Verify.verify(node.getSource() instanceof TableScanNode, "Dynamic filters %s present in filter predicate whose source is not a table scan.", dynamicFilters);
                }
                ImmutableSet.Builder<DynamicFilterId> consumed = ImmutableSet.builder();
                for (DynamicFilters.Descriptor descriptor : dynamicFilters) {
                    DynamicFiltersChecker.validateDynamicFilterExpression(descriptor.getInput());
                    consumed.add(descriptor.getId());
                }
                consumed.addAll(node.getSource().accept(this, context));
                return consumed.build();
            }
        }, null);
    }

    /**
     * Verifies that a dynamic filter input is either a {@link SymbolReference} or a
     * {@link Cast} whose operand is a {@link SymbolReference}.
     */
    private static void validateDynamicFilterExpression(Expression expression) {
        if (expression instanceof SymbolReference) {
            return;
        }
        Verify.verify(expression instanceof Cast, "Dynamic filter expression %s must be a SymbolReference or a CAST of SymbolReference.", expression);
        Cast castExpression = (Cast) expression;
        Verify.verify(castExpression.getExpression() instanceof SymbolReference, "The expression %s within in a CAST in dynamic filter must be a SymbolReference.", castExpression.getExpression());
    }

    /**
     * Returns the descriptors of all dynamic filters found among the sub-expressions of
     * {@code expression}. Sub-expressions that are not dynamic filters are skipped.
     */
    private static List<DynamicFilters.Descriptor> extractDynamicPredicates(Expression expression) {
        return SubExpressionExtractor.extract(expression)
                .map(DynamicFilters::getDescriptor)
                .filter(Optional::isPresent)
                .map(Optional::get)
                .collect(ImmutableList.toImmutableList());
    }
}

