/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.spi.type.Type;
import io.trino.sql.ir.BooleanLiteral;
import io.trino.sql.ir.Cast;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.IfExpression;
import io.trino.sql.ir.NullLiteral;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.optimizations.QueryCardinalityUtil;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.CorrelatedJoinNode;
import io.trino.sql.planner.plan.DynamicFilterId;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.JoinType;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.ProjectNode;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

/**
 * Rewrites a {@link CorrelatedJoinNode} that has no correlation symbols into a
 * plain {@link JoinNode} (optionally wrapped in a {@link ProjectNode}).
 *
 * <ul>
 *   <li>INNER and LEFT correlated joins become a join of the same type that
 *       carries the original filter.</li>
 *   <li>RIGHT and FULL correlated joins whose filter is {@code TRUE} become an
 *       INNER and a LEFT join respectively, without a filter.</li>
 *   <li>A RIGHT correlated join with any other filter becomes an unfiltered
 *       INNER join under a projection that keeps subquery symbols as-is and
 *       wraps each input symbol in an {@code IF} on the filter.</li>
 *   <li>A FULL correlated join with any other filter is not rewritten.</li>
 * </ul>
 */
public class TransformUncorrelatedSubqueryToJoin
        implements Rule<CorrelatedJoinNode> {
    // Matches only correlated joins whose correlation list is empty.
    private static final Pattern<CorrelatedJoinNode> PATTERN = Patterns.correlatedJoin()
            .with(Pattern.empty(Patterns.CorrelatedJoin.correlation()));

    @Override
    public Pattern<CorrelatedJoinNode> getPattern() {
        return PATTERN;
    }

    /**
     * Applies the rewrite described on the class.
     *
     * @param correlatedJoinNode the matched node (its correlation list is empty per {@link #PATTERN})
     * @param captures unused
     * @param context supplies the lookup, symbol types and plan node ids
     * @return the replacement plan, or {@link Rule.Result#empty()} for a FULL
     *         correlated join whose filter is not {@code TRUE}
     * @throws IllegalStateException if the join type is none of INNER, LEFT, RIGHT, FULL
     */
    @Override
    public Rule.Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Rule.Context context) {
        JoinType correlatedType = correlatedJoinNode.getType();
        Expression filter = correlatedJoinNode.getFilter();

        // INNER and LEFT map directly onto a regular join with the same filter.
        if (correlatedType == JoinType.INNER || correlatedType == JoinType.LEFT) {
            return Rule.Result.ofPlanNode(rewriteToJoin(correlatedJoinNode, correlatedType, filter, context.getLookup()));
        }

        Preconditions.checkState(
                correlatedType == JoinType.RIGHT || correlatedType == JoinType.FULL,
                "unexpected CorrelatedJoin type: %s",
                correlatedType);

        boolean filterIsTrue = filter.equals(BooleanLiteral.TRUE_LITERAL);

        // FULL correlated join on a condition other than TRUE is not supported.
        // Bail out before building a join node that would only be discarded.
        if (correlatedType == JoinType.FULL && !filterIsTrue) {
            return Rule.Result.empty();
        }

        // RIGHT -> INNER and FULL -> LEFT; the join itself is always built without a filter.
        JoinType joinType = correlatedType == JoinType.RIGHT ? JoinType.INNER : JoinType.LEFT;
        JoinNode joinNode = rewriteToJoin(correlatedJoinNode, joinType, BooleanLiteral.TRUE_LITERAL, context.getLookup());
        if (filterIsTrue) {
            return Rule.Result.ofPlanNode(joinNode);
        }

        // Remaining case: RIGHT correlated join on a condition other than TRUE.
        // The filter is moved into a projection above the unfiltered INNER join.
        Set<Symbol> outputSymbols = ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols());
        Assignments.Builder assignments = Assignments.builder();

        // Subquery-side outputs pass through unchanged.
        assignments.putIdentities(Sets.intersection(
                ImmutableSet.copyOf(correlatedJoinNode.getSubquery().getOutputSymbols()),
                outputSymbols));

        // Input-side outputs become IF(filter, symbol, CAST(NULL AS type-of-symbol)).
        for (Symbol inputSymbol : Sets.intersection(
                ImmutableSet.copyOf(correlatedJoinNode.getInput().getOutputSymbols()),
                outputSymbols)) {
            Type inputType = context.getSymbolAllocator().getTypes().get(inputSymbol);
            assignments.put(
                    inputSymbol,
                    new IfExpression(filter, inputSymbol.toSymbolReference(), new Cast(new NullLiteral(), inputType)));
        }

        ProjectNode projectNode = new ProjectNode(context.getIdAllocator().getNextId(), joinNode, assignments.build());
        return Rule.Result.ofPlanNode(projectNode);
    }

    /**
     * Builds a {@link JoinNode} that reuses the id, input and subquery of {@code parent}.
     * The join has no equi-join criteria and outputs every symbol of both sides.
     *
     * @param parent the correlated join being replaced
     * @param type requested join type; LEFT is changed to INNER when the filter is
     *             {@code TRUE} and the subquery cardinality is reported as at least scalar
     * @param filter join filter; {@code TRUE} is stored as an absent filter
     * @param lookup used to resolve the subquery when extracting its cardinality
     * @return the replacement join node
     */
    private JoinNode rewriteToJoin(CorrelatedJoinNode parent, JoinType type, Expression filter, Lookup lookup) {
        boolean filterIsTrue = filter.equals(BooleanLiteral.TRUE_LITERAL);

        JoinType joinType = type;
        // Evaluation order mirrors the original condition: the type check comes first,
        // then the cardinality lookup, then the filter test.
        if (joinType == JoinType.LEFT
                && QueryCardinalityUtil.extractCardinality(parent.getSubquery(), lookup).isAtLeastScalar()
                && filterIsTrue) {
            joinType = JoinType.INNER;
        }

        return new JoinNode(
                parent.getId(),
                joinType,
                parent.getInput(),
                parent.getSubquery(),
                ImmutableList.of(),
                parent.getInput().getOutputSymbols(),
                parent.getSubquery().getOutputSymbols(),
                false,
                filterIsTrue ? Optional.empty() : Optional.of(filter),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                ImmutableMap.of(),
                Optional.empty());
    }
}

