/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.HashBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.sql.planner.ExpressionNodeInliner;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.TypeAnalyzer;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.DereferencePushdown;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.SubscriptExpression;
import io.trino.sql.tree.SymbolReference;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

public class PushDownDereferenceThroughJoin
implements Rule<ProjectNode> {
    private static final Capture<JoinNode> CHILD = Capture.newCapture();
    private final TypeAnalyzer typeAnalyzer;

    public PushDownDereferenceThroughJoin(TypeAnalyzer typeAnalyzer) {
        this.typeAnalyzer = Objects.requireNonNull(typeAnalyzer, "typeAnalyzer is null");
    }

    @Override
    public Pattern<ProjectNode> getPattern() {
        return Patterns.project().with(Patterns.source().matching(Patterns.join().capturedAs(CHILD)));
    }

    @Override
    public Rule.Result apply(ProjectNode projectNode, Captures captures, Rule.Context context) {
        JoinNode joinNode = (JoinNode)captures.get(CHILD);
        ImmutableList.Builder expressionsBuilder = ImmutableList.builder();
        expressionsBuilder.addAll(projectNode.getAssignments().getExpressions());
        joinNode.getFilter().ifPresent(arg_0 -> ((ImmutableList.Builder)expressionsBuilder).add(arg_0));
        Set dereferences = DereferencePushdown.extractRowSubscripts((Collection<Expression>)expressionsBuilder.build(), false, context.getSession(), this.typeAnalyzer, context.getSymbolAllocator().getTypes());
        ImmutableSet.Builder criteriaSymbolsBuilder = ImmutableSet.builder();
        joinNode.getCriteria().forEach(criteria -> {
            criteriaSymbolsBuilder.add((Object)criteria.getLeft());
            criteriaSymbolsBuilder.add((Object)criteria.getRight());
        });
        ImmutableSet excludeSymbols = criteriaSymbolsBuilder.build();
        dereferences = (Set)dereferences.stream().filter(arg_0 -> PushDownDereferenceThroughJoin.lambda$apply$1((Set)excludeSymbols, arg_0)).collect(ImmutableSet.toImmutableSet());
        if (dereferences.isEmpty()) {
            return Rule.Result.empty();
        }
        Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), this.typeAnalyzer);
        Map mappings = (Map)HashBiMap.create(dereferenceAssignments.getMap()).inverse().entrySet().stream().collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, entry -> ((Symbol)entry.getValue()).toSymbolReference()));
        Assignments newAssignments = projectNode.getAssignments().rewrite(expression -> ExpressionNodeInliner.replaceExpression(expression, mappings));
        Assignments.Builder leftAssignmentsBuilder = Assignments.builder();
        Assignments.Builder rightAssignmentsBuilder = Assignments.builder();
        dereferenceAssignments.entrySet().stream().forEach(entry -> {
            Symbol baseSymbol = (Symbol)Iterables.getOnlyElement(SymbolsExtractor.extractAll((Expression)entry.getValue()));
            if (joinNode.getLeft().getOutputSymbols().contains(baseSymbol)) {
                leftAssignmentsBuilder.put((Symbol)entry.getKey(), (Expression)entry.getValue());
            } else if (joinNode.getRight().getOutputSymbols().contains(baseSymbol)) {
                rightAssignmentsBuilder.put((Symbol)entry.getKey(), (Expression)entry.getValue());
            } else {
                throw new IllegalArgumentException(String.format("Unexpected symbol %s in projectNode", baseSymbol));
            }
        });
        Assignments leftAssignments = leftAssignmentsBuilder.build();
        Assignments rightAssignments = rightAssignmentsBuilder.build();
        PlanNode leftNode = PushDownDereferenceThroughJoin.createProjectNodeIfRequired(joinNode.getLeft(), leftAssignments, context.getIdAllocator());
        PlanNode rightNode = PushDownDereferenceThroughJoin.createProjectNodeIfRequired(joinNode.getRight(), rightAssignments, context.getIdAllocator());
        List referredSymbolsInAssignments = newAssignments.getExpressions().stream().flatMap(expression -> SymbolsExtractor.extractAll(expression).stream()).collect(Collectors.toList());
        List<Symbol> newLeftOutputSymbols = referredSymbolsInAssignments.stream().filter(symbol -> leftNode.getOutputSymbols().contains(symbol)).collect(Collectors.toList());
        List<Symbol> newRightOutputSymbols = referredSymbolsInAssignments.stream().filter(symbol -> rightNode.getOutputSymbols().contains(symbol)).collect(Collectors.toList());
        JoinNode newJoinNode = new JoinNode(context.getIdAllocator().getNextId(), joinNode.getType(), leftNode, rightNode, joinNode.getCriteria(), newLeftOutputSymbols, newRightOutputSymbols, joinNode.isMaySkipOutputDuplicates(), joinNode.getFilter().map(expression -> ExpressionNodeInliner.replaceExpression(expression, mappings)), joinNode.getLeftHashSymbol(), joinNode.getRightHashSymbol(), joinNode.getDistributionType(), joinNode.isSpillable(), joinNode.getDynamicFilters(), joinNode.getReorderJoinStatsAndCost());
        return Rule.Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), newJoinNode, newAssignments));
    }

    private static PlanNode createProjectNodeIfRequired(PlanNode planNode, Assignments dereferences, PlanNodeIdAllocator idAllocator) {
        if (dereferences.isEmpty()) {
            return planNode;
        }
        return new ProjectNode(idAllocator.getNextId(), planNode, Assignments.builder().putIdentities(planNode.getOutputSymbols()).putAll(dereferences).build());
    }

    private static /* synthetic */ boolean lambda$apply$1(Set excludeSymbols, SubscriptExpression expression) {
        return !excludeSymbols.contains(DereferencePushdown.getBase(expression));
    }
}

