/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.operator.RetryPolicy;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.DynamicFilterId;
import io.trino.sql.planner.plan.DynamicFilterSourceNode;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.SemiJoinNode;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Rules, enabled only when the session retry policy is {@code TASK}, that move join and semi-join
 * dynamic filters off the join node and onto a {@link DynamicFilterSourceNode} inserted on the
 * build (filtering) side.
 *
 * <p>If the build side cannot host a {@link DynamicFilterSourceNode} (see
 * {@link #canAddDynamicFilterSource}), the dynamic filters are simply dropped from the join.
 * Once inserted, a {@link DynamicFilterSourceNode} is left in place only when its parent is a
 * remote exchange. Otherwise it is pushed below its child when the child permits it, or removed
 * when it does not.
 */
public final class AddDynamicFilterSource {
    private AddDynamicFilterSource() {
    }

    /**
     * Returns the rules that insert, push down, and remove {@link DynamicFilterSourceNode}s.
     */
    public static Set<Rule<?>> rules() {
        return ImmutableSet.of(
                new RewriteJoinDynamicFilter(),
                new RewriteSemiJoinDynamicFilter(),
                new PushOrRemoveDynamicFilterSource());
    }

    /**
     * Checks whether a {@link DynamicFilterSourceNode} may be placed directly beneath {@code node}.
     *
     * @param node candidate parent of the dynamic filter source
     * @param dynamicFilterSymbols symbols the dynamic filters are collected on
     * @return {@code true} if {@code node} is a {@link ProjectNode} whose assignments are identity
     *     for every dynamic filter symbol, or an {@link ExchangeNode} with exactly one source
     */
    private static boolean canAddDynamicFilterSource(PlanNode node, Collection<Symbol> dynamicFilterSymbols) {
        // Identity assignments keep each filter symbol under the same name below the projection,
        // so a dynamic filter source placed underneath still refers to the same symbols.
        boolean isIdentityProjection = node instanceof ProjectNode
                && dynamicFilterSymbols.stream().allMatch(symbol -> ((ProjectNode) node).getAssignments().isIdentity(symbol));
        return isIdentityProjection || (node instanceof ExchangeNode && node.getSources().size() == 1);
    }

    /**
     * Returns {@code true} if {@code node} is an {@link ExchangeNode} with {@link ExchangeNode.Scope#REMOTE} scope.
     */
    private static boolean isRemoteExchange(PlanNode node) {
        if (!(node instanceof ExchangeNode)) {
            return false;
        }
        return ((ExchangeNode) node).getScope() == ExchangeNode.Scope.REMOTE;
    }

    /**
     * Returns a copy of {@code node} whose only source is wrapped in a new {@link DynamicFilterSourceNode}
     * collecting {@code dynamicFilters}.
     *
     * <p>{@code node} must have exactly one source; {@link Iterables#getOnlyElement} fails otherwise.
     */
    private static PlanNode addDynamicFilterSourceBelow(PlanNode node, Map<DynamicFilterId, Symbol> dynamicFilters, Rule.Context context) {
        DynamicFilterSourceNode dynamicFilterSource = new DynamicFilterSourceNode(
                context.getIdAllocator().getNextId(),
                Iterables.getOnlyElement(node.getSources()),
                dynamicFilters);
        List<PlanNode> newSources = ImmutableList.of(dynamicFilterSource);
        return node.replaceChildren(newSources);
    }

    /**
     * Moves a join's dynamic filters into a {@link DynamicFilterSourceNode} beneath the join's
     * right (build) child. If that child cannot host one, the dynamic filters are dropped.
     */
    private static class RewriteJoinDynamicFilter
    implements Rule<JoinNode> {
        private static final Capture<PlanNode> BUILD_SIDE_NODE = Capture.newCapture();
        private static final Pattern<JoinNode> PATTERN = Patterns.join()
                .matching(node -> !node.getDynamicFilters().isEmpty())
                .with(Patterns.Join.right().capturedAs(BUILD_SIDE_NODE));

        private RewriteJoinDynamicFilter() {
        }

        @Override
        public Pattern<JoinNode> getPattern() {
            return PATTERN;
        }

        @Override
        public boolean isEnabled(Session session) {
            return SystemSessionProperties.getRetryPolicy(session) == RetryPolicy.TASK;
        }

        @Override
        public Rule.Result apply(JoinNode joinNode, Captures captures, Rule.Context context) {
            PlanNode buildSource = captures.get(BUILD_SIDE_NODE);
            if (!canAddDynamicFilterSource(buildSource, joinNode.getDynamicFilters().values())) {
                return Rule.Result.ofPlanNode(joinNode.withoutDynamicFilters());
            }
            List<PlanNode> newChildren = ImmutableList.of(
                    joinNode.getLeft(),
                    addDynamicFilterSourceBelow(buildSource, joinNode.getDynamicFilters(), context));
            return Rule.Result.ofPlanNode(joinNode.withoutDynamicFilters().replaceChildren(newChildren));
        }
    }

    /**
     * Moves a semi-join's dynamic filter into a {@link DynamicFilterSourceNode} beneath the
     * semi-join's filtering source. If that node cannot host one, the dynamic filter is dropped.
     */
    private static class RewriteSemiJoinDynamicFilter
    implements Rule<SemiJoinNode> {
        private static final Capture<PlanNode> FILTERING_SOURCE_NODE = Capture.newCapture();
        private static final Pattern<SemiJoinNode> PATTERN = Patterns.semiJoin()
                .matching(node -> node.getDynamicFilterId().isPresent())
                .with(Patterns.SemiJoin.getFilteringSource().capturedAs(FILTERING_SOURCE_NODE));

        private RewriteSemiJoinDynamicFilter() {
        }

        @Override
        public Pattern<SemiJoinNode> getPattern() {
            return PATTERN;
        }

        @Override
        public boolean isEnabled(Session session) {
            return SystemSessionProperties.getRetryPolicy(session) == RetryPolicy.TASK;
        }

        @Override
        public Rule.Result apply(SemiJoinNode semiJoinNode, Captures captures, Rule.Context context) {
            PlanNode filteringSource = captures.get(FILTERING_SOURCE_NODE);
            Symbol filteringSymbol = semiJoinNode.getFilteringSourceJoinSymbol();
            if (!canAddDynamicFilterSource(filteringSource, ImmutableList.of(filteringSymbol))) {
                return Rule.Result.ofPlanNode(semiJoinNode.withoutDynamicFilter());
            }
            // The pattern guarantees the dynamic filter id is present.
            Map<DynamicFilterId, Symbol> dynamicFilters = ImmutableMap.of(
                    semiJoinNode.getDynamicFilterId().orElseThrow(),
                    filteringSymbol);
            List<PlanNode> newChildren = ImmutableList.of(
                    semiJoinNode.getSource(),
                    addDynamicFilterSourceBelow(filteringSource, dynamicFilters, context));
            return Rule.Result.ofPlanNode(semiJoinNode.withoutDynamicFilter().replaceChildren(newChildren));
        }
    }

    /**
     * Handles a {@link DynamicFilterSourceNode} that is the single source of some node:
     * <ul>
     *   <li>under a remote exchange: left untouched;</li>
     *   <li>otherwise, if its child can host it: swapped below that child;</li>
     *   <li>otherwise: removed.</li>
     * </ul>
     */
    private static class PushOrRemoveDynamicFilterSource
    implements Rule<PlanNode> {
        private static final Capture<DynamicFilterSourceNode> DYNAMIC_FILTER_SOURCE = Capture.newCapture();
        private static final Pattern<PlanNode> PATTERN = Pattern.typeOf(PlanNode.class)
                .with(Patterns.source().matching(Patterns.dynamicFilterSource().capturedAs(DYNAMIC_FILTER_SOURCE)));

        private PushOrRemoveDynamicFilterSource() {
        }

        @Override
        public Pattern<PlanNode> getPattern() {
            return PATTERN;
        }

        @Override
        public boolean isEnabled(Session session) {
            return SystemSessionProperties.getRetryPolicy(session) == RetryPolicy.TASK;
        }

        @Override
        public Rule.Result apply(PlanNode node, Captures captures, Rule.Context context) {
            if (isRemoteExchange(node)) {
                // Directly beneath a remote exchange is where the dynamic filter source is kept.
                return Rule.Result.empty();
            }
            DynamicFilterSourceNode dynamicFilterSourceNode = captures.get(DYNAMIC_FILTER_SOURCE);
            PlanNode dynamicFilterChildNode = context.getLookup().resolve(dynamicFilterSourceNode.getSource());
            if (!canAddDynamicFilterSource(dynamicFilterChildNode, dynamicFilterSourceNode.getDynamicFilters().values())) {
                // Cannot be pushed further and is not under a remote exchange: drop it.
                List<PlanNode> childWithoutSource = ImmutableList.of(dynamicFilterChildNode);
                return Rule.Result.ofPlanNode(node.replaceChildren(childWithoutSource));
            }
            // Swap the dynamic filter source with its child:
            // node -> child -> dynamicFilterSource -> child's former sources.
            PlanNode dynamicFilterSourceRewritten = dynamicFilterSourceNode.replaceChildren(dynamicFilterChildNode.getSources());
            List<PlanNode> childSources = ImmutableList.of(dynamicFilterSourceRewritten);
            List<PlanNode> newChildren = ImmutableList.of(dynamicFilterChildNode.replaceChildren(childSources));
            return Rule.Result.ofPlanNode(node.replaceChildren(newChildren));
        }
    }
}

