/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.sql.ExpressionUtils;
import io.trino.sql.planner.ExpressionSymbolInliner;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.optimizations.PlanNodeSearcher;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.JoinType;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.SemiJoinNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.tree.BooleanLiteral;
import io.trino.sql.tree.Expression;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Predicate;

/**
 * Rewrites a filter that requires a semi-join's output to be true into an inner join.
 *
 * <p>The rule matches a {@link FilterNode} whose source is a {@link SemiJoinNode} and fires only
 * when one of the filter's top-level conjuncts is exactly a reference to the semi-join output
 * symbol. The resulting plan is:
 *
 * <pre>
 * Project(identity of join outputs, semiJoinOutput := TRUE)
 *   InnerJoin(sourceJoinSymbol = filteringSourceJoinSymbol, filter = remaining conjuncts)
 *     source
 *     Aggregation(group by filteringSourceJoinSymbol, no aggregate functions)
 *       filteringSource
 * </pre>
 *
 * <p>The rule is enabled through
 * {@link SystemSessionProperties#isRewriteFilteringSemiJoinToInnerJoin(Session)}.
 */
public class TransformFilteringSemiJoinToInnerJoin
implements Rule<FilterNode> {
    private static final Capture<SemiJoinNode> SEMI_JOIN = Capture.newCapture();
    private static final Pattern<FilterNode> PATTERN = Patterns.filter().with(Patterns.source().matching(Patterns.semiJoin().capturedAs(SEMI_JOIN)));

    /**
     * @return the pattern matching a filter located directly above a semi-join
     */
    @Override
    public Pattern<FilterNode> getPattern() {
        return PATTERN;
    }

    /**
     * @param session the session whose system properties are consulted
     * @return whether the rewrite is switched on for {@code session}
     */
    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.isRewriteFilteringSemiJoinToInnerJoin(session);
    }

    /**
     * Replaces the matched filter-over-semi-join with a projection over an inner join.
     *
     * @param filterNode the matched filter
     * @param captures captures holding the semi-join below the filter
     * @param context rule context supplying the lookup and the plan node id allocator
     * @return the rewritten plan, or an empty result when the rule does not apply
     */
    @Override
    public Rule.Result apply(FilterNode filterNode, Captures captures, Rule.Context context) {
        SemiJoinNode semiJoin = captures.get(SEMI_JOIN);

        // Leave the plan untouched when the probe side contains a table scan flagged as an
        // update target.
        // NOTE(review): presumably this protects row-modifying statements (e.g. DELETE) - confirm.
        boolean sourceContainsUpdateTarget = PlanNodeSearcher.searchFrom(semiJoin.getSource(), context.getLookup())
                .where(node -> node instanceof TableScanNode && ((TableScanNode) node).isUpdateTarget())
                .matches();
        if (sourceContainsUpdateTarget) {
            return Rule.Result.empty();
        }

        Symbol semiJoinSymbol = semiJoin.getSemiJoinOutput();
        Predicate<Expression> isSemiJoinSymbol = expression -> expression.equals(semiJoinSymbol.toSymbolReference());

        // The rewrite is only valid when the filter demands the semi-join output itself as a
        // top-level conjunct.
        List<Expression> conjuncts = ExpressionUtils.extractConjuncts(filterNode.getPredicate());
        if (conjuncts.stream().noneMatch(isSemiJoinSymbol)) {
            return Rule.Result.empty();
        }

        // Drop the bare semi-join output conjunct(s); the inner join itself enforces them.
        List<Expression> remainingConjuncts = conjuncts.stream()
                .filter(Predicate.not(isSemiJoinSymbol))
                .collect(ImmutableList.toImmutableList());
        Expression filteredPredicate = ExpressionUtils.and(remainingConjuncts);

        // Inline the remaining predicate, mapping the semi-join output symbol to TRUE and every
        // other symbol to its own reference.
        Expression simplifiedPredicate = ExpressionSymbolInliner.inlineSymbols(symbol -> {
            if (symbol.equals(semiJoinSymbol)) {
                return BooleanLiteral.TRUE_LITERAL;
            }
            return symbol.toSymbolReference();
        }, filteredPredicate);

        // A predicate equal to TRUE means there is nothing left to evaluate on the join.
        Optional<Expression> joinFilter = simplifiedPredicate.equals(BooleanLiteral.TRUE_LITERAL)
                ? Optional.empty()
                : Optional.of(simplifiedPredicate);

        // Group the filtering side by its join symbol with no aggregate functions, so the inner
        // join sees each join key once.
        AggregationNode filteringSourceDistinct = AggregationNode.singleAggregation(
                context.getIdAllocator().getNextId(),
                semiJoin.getFilteringSource(),
                ImmutableMap.of(),
                AggregationNode.singleGroupingSet(ImmutableList.of(semiJoin.getFilteringSourceJoinSymbol())));

        // The join reuses the semi-join's id, outputs only the probe-side symbols and carries the
        // semi-join's dynamic filter (if any) keyed to the filtering-side join symbol.
        JoinNode innerJoin = new JoinNode(
                semiJoin.getId(),
                JoinType.INNER,
                semiJoin.getSource(),
                filteringSourceDistinct,
                ImmutableList.of(new JoinNode.EquiJoinClause(semiJoin.getSourceJoinSymbol(), semiJoin.getFilteringSourceJoinSymbol())),
                semiJoin.getSource().getOutputSymbols(),
                ImmutableList.of(),
                false,
                joinFilter,
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                semiJoin.getDynamicFilterId()
                        .map(id -> ImmutableMap.of(id, semiJoin.getFilteringSourceJoinSymbol()))
                        .orElse(ImmutableMap.of()),
                Optional.empty());

        // Re-expose the semi-join output symbol as constant TRUE so the projected symbols still
        // include it alongside the join outputs.
        ProjectNode project = new ProjectNode(
                context.getIdAllocator().getNextId(),
                innerJoin,
                Assignments.builder()
                        .putIdentities(innerJoin.getOutputSymbols())
                        .put(semiJoinSymbol, BooleanLiteral.TRUE_LITERAL)
                        .build());
        return Rule.Result.ofPlanNode(project);
    }
}

