/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Streams;
import io.trino.sql.ir.Expression;
import io.trino.sql.planner.DeterminismEvaluator;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.rule.InlineProjections;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.JoinType;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;

/**
 * Static helper that pushes a {@link ProjectNode} below an inner {@link JoinNode} by splitting
 * the projection's assignments between the two join inputs.
 */
public final class PushProjectionThroughJoin {
    /**
     * Attempts to rewrite {@code Project(InnerJoin(left, right))} into
     * {@code InnerJoin(Project(left), Project(right))}.
     * <p>
     * The rewrite applies only when all of the following hold:
     * <ul>
     * <li>every assignment expression of {@code projectNode} is deterministic;</li>
     * <li>its source resolves to a {@link JoinNode} of type {@link JoinType#INNER};</li>
     * <li>each assignment references symbols from a single join input.</li>
     * </ul>
     * An assignment that references no symbols at all is placed on the left side. Symbols the join
     * itself reads (equi-join criteria, join filter symbols, hash symbols) are passed through as
     * identity assignments on the side that produces them. Each new projection is merged with any
     * projection directly beneath it via {@link InlineProjections}. The returned join reuses the
     * original join's id and all of its other properties. Its output symbols are restricted to the
     * outputs of {@code projectNode}.
     *
     * @param projectNode the projection to push below the join
     * @param lookup resolves the projection's source and the sources of the new per-side projections
     * @param planNodeIdAllocator supplies ids for the two new projection nodes (left first, then right)
     * @return the rewritten join, or {@link Optional#empty()} if the rewrite does not apply
     * @throws IllegalStateException if a symbol required by the join is produced by neither input
     */
    public static Optional<PlanNode> pushProjectionThroughJoin(ProjectNode projectNode, Lookup lookup, PlanNodeIdAllocator planNodeIdAllocator) {
        // Pushing an expression below the join changes how many times it is evaluated
        // (per input row rather than per joined row), which is only safe for deterministic expressions.
        boolean allDeterministic = projectNode.getAssignments().getExpressions().stream()
                .allMatch(expression -> DeterminismEvaluator.isDeterministic(expression));
        if (!allDeterministic) {
            return Optional.empty();
        }

        PlanNode child = lookup.resolve(projectNode.getSource());
        if (!(child instanceof JoinNode)) {
            return Optional.empty();
        }
        JoinNode joinNode = (JoinNode) child;
        if (joinNode.getType() != JoinType.INNER) {
            return Optional.empty();
        }
        PlanNode leftChild = joinNode.getLeft();
        PlanNode rightChild = joinNode.getRight();

        Assignments.Builder leftAssignmentsBuilder = Assignments.builder();
        Assignments.Builder rightAssignmentsBuilder = Assignments.builder();
        for (Map.Entry<Symbol, Expression> assignment : projectNode.getAssignments().entrySet()) {
            Expression expression = assignment.getValue();
            Set<Symbol> symbols = SymbolsExtractor.extractUnique(expression);
            if (leftChild.getOutputSymbols().containsAll(symbols)) {
                leftAssignmentsBuilder.put(assignment.getKey(), expression);
            }
            else if (rightChild.getOutputSymbols().containsAll(symbols)) {
                rightAssignmentsBuilder.put(assignment.getKey(), expression);
            }
            else {
                // The expression mixes symbols from both inputs, so neither side can compute it alone.
                return Optional.empty();
            }
        }

        // Keep every symbol the join itself reads available on the side that produces it.
        for (Symbol requiredSymbol : getJoinRequiredSymbols(joinNode)) {
            if (leftChild.getOutputSymbols().contains(requiredSymbol)) {
                leftAssignmentsBuilder.putIdentity(requiredSymbol);
            }
            else {
                Preconditions.checkState(
                        rightChild.getOutputSymbols().contains(requiredSymbol),
                        "Join-required symbol %s is not produced by either join input",
                        requiredSymbol);
                rightAssignmentsBuilder.putIdentity(requiredSymbol);
            }
        }

        Assignments leftAssignments = leftAssignmentsBuilder.build();
        Assignments rightAssignments = rightAssignmentsBuilder.build();

        // Built once and shared by both sides; the join exposes only what the projection exposed.
        Set<Symbol> projectOutputs = ImmutableSet.copyOf(projectNode.getOutputSymbols());
        List<Symbol> leftOutputSymbols = retainOutputs(leftAssignments, projectOutputs);
        List<Symbol> rightOutputSymbols = retainOutputs(rightAssignments, projectOutputs);

        // The left side is built first, so its projection receives the first allocated id.
        PlanNode newLeft = inlineProjections(new ProjectNode(planNodeIdAllocator.getNextId(), leftChild, leftAssignments), lookup);
        PlanNode newRight = inlineProjections(new ProjectNode(planNodeIdAllocator.getNextId(), rightChild, rightAssignments), lookup);

        return Optional.of(new JoinNode(
                joinNode.getId(),
                joinNode.getType(),
                newLeft,
                newRight,
                joinNode.getCriteria(),
                leftOutputSymbols,
                rightOutputSymbols,
                joinNode.isMaySkipOutputDuplicates(),
                joinNode.getFilter(),
                joinNode.getLeftHashSymbol(),
                joinNode.getRightHashSymbol(),
                joinNode.getDistributionType(),
                joinNode.isSpillable(),
                joinNode.getDynamicFilters(),
                joinNode.getReorderJoinStatsAndCost()));
    }

    /**
     * Returns the outputs of {@code assignments} that are also contained in {@code wanted}, in the
     * order produced by {@link Assignments#getOutputs()}.
     */
    private static List<Symbol> retainOutputs(Assignments assignments, Set<Symbol> wanted) {
        return assignments.getOutputs().stream()
                .filter(wanted::contains)
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Merges {@code parentProjection} with the projection directly beneath it, if any. Merging
     * repeats until the source is no longer a {@link ProjectNode} or
     * {@link InlineProjections#inlineProjections} declines to merge. In either of those cases the
     * current projection is returned unchanged.
     */
    private static PlanNode inlineProjections(ProjectNode parentProjection, Lookup lookup) {
        PlanNode source = lookup.resolve(parentProjection.getSource());
        if (!(source instanceof ProjectNode)) {
            return parentProjection;
        }
        ProjectNode childProjection = (ProjectNode) source;
        return InlineProjections.inlineProjections(parentProjection, childProjection)
                .map(merged -> inlineProjections(merged, lookup))
                .orElse(parentProjection);
    }

    /**
     * Collects the symbols the join reads directly. These are:
     * <ul>
     * <li>both sides of every equi-join clause;</li>
     * <li>the symbols of the optional join filter;</li>
     * <li>the optional left and right hash symbols.</li>
     * </ul>
     * Duplicates are removed, and the first occurrence determines iteration order.
     */
    private static Set<Symbol> getJoinRequiredSymbols(JoinNode node) {
        Stream<Symbol> leftKeys = node.getCriteria().stream().map(JoinNode.EquiJoinClause::getLeft);
        Stream<Symbol> rightKeys = node.getCriteria().stream().map(JoinNode.EquiJoinClause::getRight);
        Stream<Symbol> filterSymbols = node.getFilter()
                .map(SymbolsExtractor::extractUnique)
                .orElse(ImmutableSet.of())
                .stream();
        Stream<Symbol> leftHash = node.getLeftHashSymbol()
                .map(ImmutableSet::of)
                .orElse(ImmutableSet.of())
                .stream();
        Stream<Symbol> rightHash = node.getRightHashSymbol()
                .map(ImmutableSet::of)
                .orElse(ImmutableSet.of())
                .stream();
        return Streams.concat(leftKeys, rightKeys, filterSymbols, leftHash, rightHash)
                .collect(ImmutableSet.toImmutableSet());
    }

    private PushProjectionThroughJoin() {
        // static utility; not instantiable
    }
}

