/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.airlift.log.Logger;
import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.cost.CostComparator;
import com.facebook.presto.cost.CostProvider;
import com.facebook.presto.cost.PlanCostEstimate;
import com.facebook.presto.expressions.LogicalRowExpressions;
import com.facebook.presto.expressions.RowExpressionNodeInliner;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.function.FunctionMetadataManager;
import com.facebook.presto.spi.function.StandardFunctionResolution;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.JoinDistributionType;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.JoinType;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.DeterminismEvaluator;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.analyzer.FeaturesConfig;
import com.facebook.presto.sql.planner.EqualityInference;
import com.facebook.presto.sql.planner.PlannerUtils;
import com.facebook.presto.sql.planner.VariablesExtractor;
import com.facebook.presto.sql.planner.iterative.ConfidenceBasedBroadcastUtil;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.iterative.rule.DetermineJoinDistributionType;
import com.facebook.presto.sql.planner.optimizations.JoinNodeUtils;
import com.facebook.presto.sql.planner.optimizations.QueryCardinalityUtil;
import com.facebook.presto.sql.planner.plan.AssignmentUtils;
import com.facebook.presto.sql.planner.plan.MultiJoinNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Predicate;
import com.google.common.base.Predicates;
import com.google.common.base.Verify;
import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Ordering;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

public class ReorderJoins
implements Rule<JoinNode> {
    private static final Logger log = Logger.get(ReorderJoins.class);
    private final Pattern<JoinNode> joinNodePattern;
    private final CostComparator costComparator;
    private final Metadata metadata;
    private final FunctionResolution functionResolution;
    private final DeterminismEvaluator determinismEvaluator;
    // Stats source name of the join most recently rewritten by apply(); exposed via getStatsSource().
    // NOTE(review): mutable state on the rule instance — confirm the rule is not shared by concurrently optimized queries.
    private String statsSource;

    /**
     * Creates the rule.
     *
     * @param costComparator orders candidate plans by estimated cost
     * @param metadata source of the function/type manager used for function resolution and determinism checks
     */
    public ReorderJoins(CostComparator costComparator, Metadata metadata) {
        this.costComparator = Objects.requireNonNull(costComparator, "costComparator is null");
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        this.functionResolution = new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver());
        this.determinismEvaluator = new RowExpressionDeterminismEvaluator(metadata.getFunctionAndTypeManager());
        this.joinNodePattern = Patterns.join().matching(this::isReorderCandidate);
    }

    // A join qualifies when it has no distribution type yet, is INNER, and its filter (TRUE if absent) is deterministic.
    private boolean isReorderCandidate(JoinNode joinNode) {
        if (joinNode.getDistributionType().isPresent()) {
            return false;
        }
        if (joinNode.getType() != JoinType.INNER) {
            return false;
        }
        RowExpression filter = joinNode.getFilter().orElse(LogicalRowExpressions.TRUE_CONSTANT);
        return this.determinismEvaluator.isDeterministic(filter);
    }

    /** Returns the pattern built in the constructor: reorderable inner joins. */
    @Override
    public Pattern<JoinNode> getPattern() {
        return joinNodePattern;
    }

    /** The rule runs only when the session's join reordering strategy is AUTOMATIC. */
    @Override
    public boolean isEnabled(Session session) {
        FeaturesConfig.JoinReorderingStrategy strategy = SystemSessionProperties.getJoinReorderingStrategy(session);
        return strategy == FeaturesConfig.JoinReorderingStrategy.AUTOMATIC;
    }

    /** The rule is cost based exactly when it is enabled for the session. */
    @Override
    public boolean isCostBased(Session session) {
        return isEnabled(session);
    }

    /** Returns the stats source recorded by the last successful apply(), or null if none yet. */
    @Override
    public String getStatsSource() {
        return statsSource;
    }

    /**
     * Flattens the join tree rooted at {@code joinNode} into a multi-join and replaces it with the
     * cheapest join order found. When the enumeration yields no plan (unknown or infinite cost) the
     * rule does nothing; otherwise any intermediate assignments are re-applied in a ProjectNode.
     */
    @Override
    public Rule.Result apply(JoinNode joinNode, Captures captures, Rule.Context context) {
        Session session = context.getSession();
        MultiJoinNode multiJoinNode = ReorderJoins.toMultiJoinNode(
                joinNode,
                context.getLookup(),
                SystemSessionProperties.getMaxReorderedJoins(session),
                SystemSessionProperties.shouldHandleComplexEquiJoins(session),
                this.functionResolution,
                this.determinismEvaluator);
        JoinEnumerator enumerator = new JoinEnumerator(this.costComparator, multiJoinNode.getFilter(), context, this.determinismEvaluator, this.functionResolution, this.metadata);
        Optional<PlanNode> chosenPlan = enumerator.chooseJoinOrder(multiJoinNode.getSources(), multiJoinNode.getOutputVariables()).getPlanNode();
        if (!chosenPlan.isPresent()) {
            return Rule.Result.empty();
        }
        this.statsSource = context.getStatsProvider().getStats((PlanNode) joinNode).getSourceInfo().getSourceInfoName();
        PlanNode reordered = chosenPlan.get();
        if (multiJoinNode.getAssignments().isEmpty()) {
            return Rule.Result.ofPlanNode(reordered);
        }
        // Restore the projections that were folded away while flattening.
        return Rule.Result.ofPlanNode(new ProjectNode(reordered.getSourceLocation(), context.getIdAllocator().getNextId(), reordered, multiJoinNode.getAssignments(), ProjectNode.Locality.LOCAL));
    }

    /**
     * Flattens the tree of inner joins rooted at {@code joinNode} into a single multi-join.
     *
     * @param joinLimit maximum number of joins to flatten; the flattener receives {@code joinLimit + 1} as its source limit
     */
    public static MultiJoinNode toMultiJoinNode(JoinNode joinNode, Lookup lookup, int joinLimit, boolean handleComplexEquiJoins, FunctionResolution functionResolution, DeterminismEvaluator determinismEvaluator) {
        // N joins combine N + 1 sources, hence the +1.
        int sourceLimit = joinLimit + 1;
        JoinNodeFlattener flattener = new JoinNodeFlattener(joinNode, lookup, sourceLimit, handleComplexEquiJoins, functionResolution, determinismEvaluator);
        return flattener.toMultiJoinNode();
    }

    /**
     * Exhaustively enumerates join orders for the sources of a flattened multi-join and keeps the
     * cheapest plan according to the session's {@link CostComparator}.
     *
     * <p>For a set of sources every split into a non-empty left and right part is tried
     * ({@link #generatePartitions}); each part is planned recursively and the best result for each
     * set of sources is memoized.
     */
    @VisibleForTesting
    static class JoinEnumerator {
        private final Session session;
        private final CostProvider costProvider;
        // Orders results by their estimated cost using the session's cost comparator.
        private final Ordering<JoinEnumerationResult> resultComparator;
        private final PlanNodeIdAllocator idAllocator;
        private final Metadata metadata;
        // Filter of the whole multi-join; join predicates and per-source filters are derived from it.
        private final RowExpression allFilter;
        private final EqualityInference allFilterInference;
        private final LogicalRowExpressions logicalRowExpressions;
        private final Lookup lookup;
        private final Rule.Context context;
        // Best result per set of sources. NOTE(review): the key ignores the requested output variables.
        private final Map<Set<PlanNode>, JoinEnumerationResult> memo = new HashMap<>();
        private final FunctionResolution functionResolution;

        @VisibleForTesting
        JoinEnumerator(CostComparator costComparator, RowExpression filter, Rule.Context context, DeterminismEvaluator determinismEvaluator, FunctionResolution functionResolution, Metadata metadata) {
            this.context = Objects.requireNonNull(context, "context is null");
            this.session = Objects.requireNonNull(context.getSession(), "session is null");
            this.costProvider = Objects.requireNonNull(context.getCostProvider(), "costProvider is null");
            this.resultComparator = costComparator.forSession(this.session).onResultOf(JoinEnumerationResult::getCost);
            this.idAllocator = Objects.requireNonNull(context.getIdAllocator(), "idAllocator is null");
            this.allFilter = Objects.requireNonNull(filter, "filter is null");
            this.lookup = Objects.requireNonNull(context.getLookup(), "lookup is null");
            this.metadata = Objects.requireNonNull(metadata, "metadata is null");
            this.allFilterInference = EqualityInference.createEqualityInference(metadata, filter);
            this.logicalRowExpressions = new LogicalRowExpressions(determinismEvaluator, functionResolution, metadata.getFunctionAndTypeManager());
            this.functionResolution = functionResolution;
        }

        /**
         * Returns the cheapest plan joining all {@code sources}. Returns {@link JoinEnumerationResult#UNKNOWN_COST_RESULT}
         * as soon as any split has unknown cost, and {@link JoinEnumerationResult#INFINITE_COST_RESULT} when every split
         * is infinitely expensive. Results are memoized per set of sources.
         *
         * @throws IllegalStateException if fewer than two sources are given
         */
        private JoinEnumerationResult chooseJoinOrder(LinkedHashSet<PlanNode> sources, List<VariableReferenceExpression> outputVariables) {
            this.context.checkTimeoutNotExhausted();
            Set<PlanNode> multiJoinKey = ImmutableSet.copyOf(sources);
            JoinEnumerationResult bestResult = this.memo.get(multiJoinKey);
            if (bestResult == null) {
                Preconditions.checkState(sources.size() > 1, "sources size is less than or equal to one");
                ImmutableList.Builder<JoinEnumerationResult> resultBuilder = ImmutableList.builder();
                for (Set<Integer> partition : generatePartitions(sources.size())) {
                    JoinEnumerationResult result = this.createJoinAccordingToPartitioning(sources, outputVariables, partition);
                    if (result.equals(JoinEnumerationResult.UNKNOWN_COST_RESULT)) {
                        // One unknown cost makes the comparison meaningless; stop enumerating this set.
                        this.memo.put(multiJoinKey, result);
                        return result;
                    }
                    if (!result.equals(JoinEnumerationResult.INFINITE_COST_RESULT)) {
                        resultBuilder.add(result);
                    }
                }
                List<JoinEnumerationResult> results = resultBuilder.build();
                if (results.isEmpty()) {
                    this.memo.put(multiJoinKey, JoinEnumerationResult.INFINITE_COST_RESULT);
                    return JoinEnumerationResult.INFINITE_COST_RESULT;
                }
                bestResult = this.resultComparator.min(results);
                this.memo.put(multiJoinKey, bestResult);
            }
            bestResult.getPlanNode().ifPresent(planNode -> log.debug("Least cost join was: %s", planNode));
            return bestResult;
        }

        /**
         * Returns every subset of {@code [0, totalNodes)} that contains 0 and is not the full set. Each subset names the
         * left side of a split; requiring index 0 yields every unordered split exactly once, and excluding the full set
         * keeps the right side non-empty.
         *
         * @throws IllegalArgumentException if {@code totalNodes <= 1}
         */
        @VisibleForTesting
        static Set<Set<Integer>> generatePartitions(int totalNodes) {
            Preconditions.checkArgument(totalNodes > 1, "totalNodes must be greater than 1");
            Set<Integer> numbers = IntStream.range(0, totalNodes).boxed().collect(ImmutableSet.toImmutableSet());
            return Sets.powerSet(numbers).stream()
                    .filter(subSet -> subSet.contains(0))
                    .filter(subSet -> subSet.size() < numbers.size())
                    .collect(ImmutableSet.toImmutableSet());
        }

        /**
         * Joins the sources whose positions (in {@code sources} iteration order) are in {@code partitioning} with the
         * remaining sources.
         */
        @VisibleForTesting
        JoinEnumerationResult createJoinAccordingToPartitioning(LinkedHashSet<PlanNode> sources, List<VariableReferenceExpression> outputVariables, Set<Integer> partitioning) {
            List<PlanNode> sourceList = ImmutableList.copyOf(sources);
            LinkedHashSet<PlanNode> leftSources = partitioning.stream()
                    .map(sourceList::get)
                    .collect(Collectors.toCollection(LinkedHashSet::new));
            LinkedHashSet<PlanNode> rightSources = sources.stream()
                    .filter(source -> !leftSources.contains(source))
                    .collect(Collectors.toCollection(LinkedHashSet::new));
            return this.createJoin(leftSources, rightSources, outputVariables);
        }

        /**
         * Plans the inner join of {@code leftSources} with {@code rightSources}. Returns INFINITE_COST_RESULT when no
         * equi-join clause connects the two sides and propagates a plan-less result of either side. Otherwise it builds
         * the join, projecting synthesized join keys where needed, and picks its distribution via
         * {@link #setJoinNodeProperties}.
         */
        private JoinEnumerationResult createJoin(LinkedHashSet<PlanNode> leftSources, LinkedHashSet<PlanNode> rightSources, List<VariableReferenceExpression> outputVariables) {
            Set<VariableReferenceExpression> leftVariables = collectOutputVariables(leftSources);
            Set<VariableReferenceExpression> rightVariables = collectOutputVariables(rightSources);
            List<RowExpression> joinPredicates = this.getJoinPredicates(leftVariables, rightVariables);
            JoinCondition joinConditions = this.extractJoinConditions(joinPredicates, leftVariables, rightVariables, this.context.getVariableAllocator());
            List<EquiJoinClause> joinClauses = joinConditions.getJoinClauses();
            if (joinClauses.isEmpty()) {
                // A split without any equi-join clause is never chosen.
                return JoinEnumerationResult.INFINITE_COST_RESULT;
            }
            List<RowExpression> joinFilters = joinConditions.getJoinFilters();
            // Variables synthesized for non-variable equi-join operands become available on their side.
            leftVariables.addAll(joinConditions.getNewLeftAssignments().keySet());
            rightVariables.addAll(joinConditions.getNewRightAssignments().keySet());

            Set<VariableReferenceExpression> requiredJoinVariables = ImmutableSet.<VariableReferenceExpression>builder()
                    .addAll(outputVariables)
                    .addAll(VariablesExtractor.extractUnique(joinPredicates))
                    .build();

            // The left side (and its projection) is planned before the right side, as before, so plan node ids are allocated in the same order.
            JoinEnumerationResult leftResult = this.getJoinSource(leftSources, variablesWithin(requiredJoinVariables, leftVariables));
            if (leftResult.equals(JoinEnumerationResult.UNKNOWN_COST_RESULT) || leftResult.equals(JoinEnumerationResult.INFINITE_COST_RESULT)) {
                return leftResult;
            }
            PlanNode left = this.projectNewVariables(
                    leftResult.getPlanNode().orElseThrow(() -> new VerifyException("Plan node is not present")),
                    joinConditions.getNewLeftAssignments());

            JoinEnumerationResult rightResult = this.getJoinSource(rightSources, variablesWithin(requiredJoinVariables, rightVariables));
            if (rightResult.equals(JoinEnumerationResult.UNKNOWN_COST_RESULT) || rightResult.equals(JoinEnumerationResult.INFINITE_COST_RESULT)) {
                return rightResult;
            }
            PlanNode right = this.projectNewVariables(
                    rightResult.getPlanNode().orElseThrow(() -> new VerifyException("Plan node is not present")),
                    joinConditions.getNewRightAssignments());

            // Left outputs first, then right outputs, restricted to what the caller asked for.
            List<VariableReferenceExpression> sortedOutputVariables = Stream.concat(left.getOutputVariables().stream(), right.getOutputVariables().stream())
                    .filter(outputVariables::contains)
                    .collect(ImmutableList.toImmutableList());
            Optional<RowExpression> joinFilter = joinFilters.isEmpty() ? Optional.empty() : Optional.of(LogicalRowExpressions.and(joinFilters));
            return this.setJoinNodeProperties(new JoinNode(left.getSourceLocation(), this.idAllocator.getNextId(), JoinType.INNER, left, right, joinClauses, sortedOutputVariables, joinFilter, Optional.empty(), Optional.empty(), Optional.empty(), ImmutableMap.of()));
        }

        // Collects the output variables of all given nodes into a fresh, mutable set.
        private static Set<VariableReferenceExpression> collectOutputVariables(Collection<PlanNode> nodes) {
            return nodes.stream()
                    .flatMap(node -> node.getOutputVariables().stream())
                    .collect(Collectors.toCollection(HashSet::new));
        }

        // Returns the members of required that are in available, preserving required's iteration order.
        private static List<VariableReferenceExpression> variablesWithin(Set<VariableReferenceExpression> required, Set<VariableReferenceExpression> available) {
            return required.stream()
                    .filter(available::contains)
                    .collect(ImmutableList.toImmutableList());
        }

        // Adds a projection defining newAssignments on top of node, passing all of node's outputs through; no-op when empty.
        private PlanNode projectNewVariables(PlanNode node, Map<VariableReferenceExpression, RowExpression> newAssignments) {
            if (newAssignments.isEmpty()) {
                return node;
            }
            ImmutableMap.Builder<VariableReferenceExpression, RowExpression> assignments = ImmutableMap.builder();
            for (VariableReferenceExpression variable : node.getOutputVariables()) {
                assignments.put(variable, variable);
            }
            assignments.putAll(newAssignments);
            return PlannerUtils.addProjections(node, this.idAllocator, assignments.build());
        }

        /**
         * Returns the predicates that connect the two sides. The first group is the non-inferable conjuncts of
         * {@link #allFilter} that are expressible over the union of both sides but over neither side alone. The second
         * group is the inferred equalities that straddle the left/right boundary.
         */
        private List<RowExpression> getJoinPredicates(Set<VariableReferenceExpression> leftVariables, Set<VariableReferenceExpression> rightVariables) {
            Predicate<VariableReferenceExpression> inJoinScope = variable -> leftVariables.contains(variable) || rightVariables.contains(variable);
            ImmutableList.Builder<RowExpression> joinPredicatesBuilder = ImmutableList.builder();

            EqualityInference.Builder builder = new EqualityInference.Builder(this.metadata);
            for (RowExpression conjunct : builder.nonInferableConjuncts(this.allFilter)) {
                RowExpression rewritten = this.allFilterInference.rewriteExpression(conjunct, inJoinScope);
                if (rewritten == null) {
                    continue;
                }
                // Predicates over a single side belong to that side's filter, not to the join.
                boolean leftOnly = this.allFilterInference.rewriteExpression(rewritten, leftVariables::contains) != null;
                if (leftOnly) {
                    continue;
                }
                boolean rightOnly = this.allFilterInference.rewriteExpression(rewritten, rightVariables::contains) != null;
                if (!rightOnly) {
                    joinPredicatesBuilder.add(rewritten);
                }
            }

            List<RowExpression> joinEqualities = this.allFilterInference.generateEqualitiesPartitionedBy(inJoinScope).getScopeEqualities();
            EqualityInference joinInference = EqualityInference.createEqualityInference(this.metadata, joinEqualities.toArray(new RowExpression[0]));
            joinPredicatesBuilder.addAll(joinInference.generateEqualitiesPartitionedBy(Predicates.in(leftVariables)).getScopeStraddlingEqualities());
            return joinPredicatesBuilder.build();
        }

        /**
         * Plans one side of a join. A single source is wrapped in a FilterNode holding every conjunct expressible over
         * {@code outputVariables}, if there are any. Several sources are enumerated recursively via {@link #chooseJoinOrder}.
         */
        private JoinEnumerationResult getJoinSource(LinkedHashSet<PlanNode> nodes, List<VariableReferenceExpression> outputVariables) {
            if (nodes.size() != 1) {
                return this.chooseJoinOrder(nodes, outputVariables);
            }
            PlanNode planNode = Iterables.getOnlyElement(nodes);
            // Hash-based scope: the predicate is probed once per variable reference.
            Set<VariableReferenceExpression> outputScope = ImmutableSet.copyOf(outputVariables);
            ImmutableList.Builder<RowExpression> predicates = ImmutableList.builder();
            predicates.addAll(this.allFilterInference.generateEqualitiesPartitionedBy(outputScope::contains).getScopeEqualities());
            EqualityInference.Builder builder = new EqualityInference.Builder(this.metadata);
            for (RowExpression conjunct : builder.nonInferableConjuncts(this.allFilter)) {
                RowExpression rewritten = this.allFilterInference.rewriteExpression(conjunct, outputScope::contains);
                if (rewritten != null) {
                    predicates.add(rewritten);
                }
            }
            RowExpression filter = this.logicalRowExpressions.combineConjuncts(predicates.build());
            if (!LogicalRowExpressions.TRUE_CONSTANT.equals(filter)) {
                planNode = new FilterNode(planNode.getSourceLocation(), this.idAllocator.getNextId(), planNode, filter);
            }
            return this.createJoinEnumerationResult(planNode);
        }

        /**
         * Splits {@code joinPredicates} into equi-join clauses and residual join filters. A two-argument equality whose
         * operands each reference only one side becomes an {@link EquiJoinClause} oriented left-to-right. A non-variable
         * operand is replaced by a freshly allocated variable, whose definition is returned in the new left/right
         * assignments. Every other predicate becomes a join filter.
         */
        @VisibleForTesting
        JoinCondition extractJoinConditions(List<RowExpression> joinPredicates, Set<VariableReferenceExpression> leftVariables, Set<VariableReferenceExpression> rightVariables, VariableAllocator variableAllocator) {
            ImmutableMap.Builder<VariableReferenceExpression, RowExpression> newLeftAssignments = ImmutableMap.builder();
            ImmutableMap.Builder<VariableReferenceExpression, RowExpression> newRightAssignments = ImmutableMap.builder();
            ImmutableList.Builder<EquiJoinClause> joinClauses = ImmutableList.builder();
            ImmutableList.Builder<RowExpression> joinFilters = ImmutableList.builder();
            for (RowExpression predicate : joinPredicates) {
                if (!this.isBinaryEquality(predicate)) {
                    joinFilters.add(predicate);
                    continue;
                }
                CallExpression call = (CallExpression) predicate;
                RowExpression argument0 = call.getArguments().get(0);
                RowExpression argument1 = call.getArguments().get(1);
                Set<VariableReferenceExpression> argument0Vars = VariablesExtractor.extractUnique(argument0);
                Set<VariableReferenceExpression> argument1Vars = VariablesExtractor.extractUnique(argument1);
                boolean leftToRight = leftVariables.containsAll(argument0Vars) && rightVariables.containsAll(argument1Vars);
                boolean rightToLeft = rightVariables.containsAll(argument0Vars) && leftVariables.containsAll(argument1Vars);
                if (!leftToRight && !rightToLeft) {
                    joinFilters.add(predicate);
                    continue;
                }
                // Swap only when the clause is not already left-to-right. The old test (left covers argument1's variables)
                // also fired when argument1 had no variables, moving a left-side expression onto the right input.
                if (!leftToRight) {
                    RowExpression temp = argument1;
                    argument1 = argument0;
                    argument0 = temp;
                }
                if (!(argument0 instanceof VariableReferenceExpression)) {
                    VariableReferenceExpression newLeft = variableAllocator.newVariable(argument0);
                    newLeftAssignments.put(newLeft, argument0);
                    argument0 = newLeft;
                }
                if (!(argument1 instanceof VariableReferenceExpression)) {
                    VariableReferenceExpression newRight = variableAllocator.newVariable(argument1);
                    newRightAssignments.put(newRight, argument1);
                    argument1 = newRight;
                }
                joinClauses.add(new EquiJoinClause((VariableReferenceExpression) argument0, (VariableReferenceExpression) argument1));
            }
            return new JoinCondition(joinClauses.build(), joinFilters.build(), newLeftAssignments.build(), newRightAssignments.build());
        }

        // True for a call to the equality function with exactly two arguments.
        private boolean isBinaryEquality(RowExpression expression) {
            if (!(expression instanceof CallExpression)) {
                return false;
            }
            CallExpression call = (CallExpression) expression;
            return this.functionResolution.isEqualFunction(call.getFunctionHandle()) && call.getArguments().size() == 2;
        }

        /**
         * Chooses the distribution type and side order for {@code joinNode}. A side that is at most scalar is replicated.
         * Otherwise confidence-based broadcast is used when enabled and applicable. Failing both, the cheapest candidate
         * from {@link #getPossibleJoinNodes} is returned, or UNKNOWN_COST_RESULT if any candidate's cost is unknown.
         */
        private JoinEnumerationResult setJoinNodeProperties(JoinNode joinNode) {
            if (QueryCardinalityUtil.isAtMostScalar(joinNode.getRight(), this.lookup)) {
                return this.createJoinEnumerationResult(joinNode.withDistributionType(JoinDistributionType.REPLICATED));
            }
            if (QueryCardinalityUtil.isAtMostScalar(joinNode.getLeft(), this.lookup)) {
                return this.createJoinEnumerationResult(joinNode.flipChildren().withDistributionType(JoinDistributionType.REPLICATED));
            }
            // The session flag is the cheapest test, so it goes first.
            // NOTE(review): the remaining checks are assumed side-effect free, so reordering them does not change the outcome.
            if (SystemSessionProperties.confidenceBasedBroadcastEnabled(this.session)
                    && !DetermineJoinDistributionType.mustPartition(joinNode)
                    && DetermineJoinDistributionType.isBelowMaxBroadcastSize(joinNode, this.context)
                    && DetermineJoinDistributionType.isBelowMaxBroadcastSize(joinNode.flipChildren(), this.context)) {
                Optional<JoinNode> broadcast = ConfidenceBasedBroadcastUtil.confidenceBasedBroadcast(joinNode, this.context);
                if (broadcast.isPresent()) {
                    return this.createJoinEnumerationResult(broadcast.get());
                }
            }
            List<JoinEnumerationResult> possibleJoinNodes = this.getPossibleJoinNodes(joinNode, SystemSessionProperties.getJoinDistributionType(this.session));
            Verify.verify(!possibleJoinNodes.isEmpty(), "possibleJoinNodes is empty");
            if (possibleJoinNodes.stream().anyMatch(JoinEnumerationResult.UNKNOWN_COST_RESULT::equals)) {
                return JoinEnumerationResult.UNKNOWN_COST_RESULT;
            }
            return this.resultComparator.min(possibleJoinNodes);
        }

        /**
         * Returns the candidate (distribution, side order) variants of {@code joinNode} allowed by the session's join
         * distribution type. Cross joins are always replicated. AUTOMATIC considers partitioned variants plus replicated
         * variants below the max broadcast size.
         *
         * @throws IllegalArgumentException if the join is not INNER or the distribution type is unexpected
         */
        private List<JoinEnumerationResult> getPossibleJoinNodes(JoinNode joinNode, FeaturesConfig.JoinDistributionType distributionType) {
            Preconditions.checkArgument(joinNode.getType() == JoinType.INNER, "unexpected join node type: %s", joinNode.getType());
            if (joinNode.isCrossJoin()) {
                return this.getPossibleJoinNodes(joinNode, JoinDistributionType.REPLICATED);
            }
            switch (distributionType) {
                case PARTITIONED:
                    return this.getPossibleJoinNodes(joinNode, JoinDistributionType.PARTITIONED);
                case BROADCAST:
                    return this.getPossibleJoinNodes(joinNode, JoinDistributionType.REPLICATED);
                case AUTOMATIC:
                    return ImmutableList.<JoinEnumerationResult>builder()
                            .addAll(this.getPossibleJoinNodes(joinNode, JoinDistributionType.PARTITIONED))
                            .addAll(this.getPossibleJoinNodes(joinNode, JoinDistributionType.REPLICATED, node -> DetermineJoinDistributionType.isBelowMaxBroadcastSize(node, this.context)))
                            .build();
                default:
                    throw new IllegalArgumentException("unexpected join distribution type: " + distributionType);
            }
        }

        // Both side orders of joinNode with the given distribution type.
        private List<JoinEnumerationResult> getPossibleJoinNodes(JoinNode joinNode, JoinDistributionType distributionType) {
            return this.getPossibleJoinNodes(joinNode, distributionType, Predicates.alwaysTrue());
        }

        // Both side orders of joinNode with the given distribution type, keeping those accepted by isAllowed.
        private List<JoinEnumerationResult> getPossibleJoinNodes(JoinNode joinNode, JoinDistributionType distributionType, Predicate<JoinNode> isAllowed) {
            List<JoinNode> nodes = ImmutableList.of(joinNode.withDistributionType(distributionType), joinNode.flipChildren().withDistributionType(distributionType));
            return nodes.stream()
                    .filter(isAllowed)
                    .map(this::createJoinEnumerationResult)
                    .collect(ImmutableList.toImmutableList());
        }

        // Wraps planNode with its estimated cost; unknown/infinite costs collapse to the shared sentinels.
        private JoinEnumerationResult createJoinEnumerationResult(PlanNode planNode) {
            return JoinEnumerationResult.createJoinEnumerationResult(Optional.of(planNode), this.costProvider.getCost(planNode));
        }

        /**
         * Equi-join clauses and residual filters extracted for one join, plus the definitions of variables synthesized
         * for non-variable equi-join operands on each side.
         */
        @VisibleForTesting
        static class JoinCondition {
            final List<EquiJoinClause> joinClauses;
            final List<RowExpression> joinFilters;
            final Map<VariableReferenceExpression, RowExpression> newLeftAssignments;
            final Map<VariableReferenceExpression, RowExpression> newRightAssignments;

            public JoinCondition(List<EquiJoinClause> joinClauses, List<RowExpression> joinFilters, Map<VariableReferenceExpression, RowExpression> left, Map<VariableReferenceExpression, RowExpression> right) {
                this.joinClauses = joinClauses;
                this.joinFilters = joinFilters;
                this.newLeftAssignments = left;
                this.newRightAssignments = right;
            }

            public List<EquiJoinClause> getJoinClauses() {
                return this.joinClauses;
            }

            public List<RowExpression> getJoinFilters() {
                return this.joinFilters;
            }

            public Map<VariableReferenceExpression, RowExpression> getNewLeftAssignments() {
                return this.newLeftAssignments;
            }

            public Map<VariableReferenceExpression, RowExpression> getNewRightAssignments() {
                return this.newRightAssignments;
            }
        }
    }

    /**
     * Outcome of planning a set of join sources. It is either a plan with its known, finite estimated cost, or one of
     * the two plan-less sentinels {@link #UNKNOWN_COST_RESULT} and {@link #INFINITE_COST_RESULT}.
     */
    @VisibleForTesting
    static class JoinEnumerationResult {
        // Shared sentinels. This class does not override equals, so callers' equals() checks are identity checks.
        public static final JoinEnumerationResult UNKNOWN_COST_RESULT = new JoinEnumerationResult(Optional.empty(), PlanCostEstimate.unknown());
        public static final JoinEnumerationResult INFINITE_COST_RESULT = new JoinEnumerationResult(Optional.empty(), PlanCostEstimate.infinite());
        private final Optional<PlanNode> planNode;
        private final PlanCostEstimate cost;

        private JoinEnumerationResult(Optional<PlanNode> planNode, PlanCostEstimate cost) {
            this.planNode = Objects.requireNonNull(planNode, "planNode is null");
            this.cost = Objects.requireNonNull(cost, "cost is null");
            // The old condition used (!unknown || !infinite), which is always true, so it never rejected a present
            // plan with an unknown cost. State the invariant directly instead.
            boolean costIsKnown = !cost.hasUnknownComponents() && !cost.equals(PlanCostEstimate.infinite());
            Preconditions.checkArgument(planNode.isPresent() == costIsKnown, "planNode should be present if and only if cost is known");
        }

        public Optional<PlanNode> getPlanNode() {
            return this.planNode;
        }

        public PlanCostEstimate getCost() {
            return this.cost;
        }

        /**
         * Creates a result for {@code planNode} with {@code cost}. An unknown or infinite cost yields the corresponding
         * shared sentinel and discards the plan.
         */
        static JoinEnumerationResult createJoinEnumerationResult(Optional<PlanNode> planNode, PlanCostEstimate cost) {
            if (cost.hasUnknownComponents()) {
                return UNKNOWN_COST_RESULT;
            }
            if (cost.equals(PlanCostEstimate.infinite())) {
                return INFINITE_COST_RESULT;
            }
            return new JoinEnumerationResult(planNode, cost);
        }
    }

    @VisibleForTesting
    private static class JoinNodeFlattener {
        // Inputs of the multi-join gathered by flattenNode; insertion order is preserved.
        private final LinkedHashSet<PlanNode> sources = new LinkedHashSet<>();
        // Non-identity project assignments folded away while flattening, after inlining (see resolveAssignments).
        private final Assignments intermediateAssignments;
        private final boolean handleComplexEquiJoins;
        // Filters gathered while flattening (presumably appended by flattenNode); replaced by their inlined form at the end of construction.
        private List<RowExpression> filters = new ArrayList<>();
        private final List<VariableReferenceExpression> outputVariables;
        private final FunctionResolution functionResolution;
        private final DeterminismEvaluator determinismEvaluator;
        private final Lookup lookup;

        /**
         * Flattens the inner join tree rooted at {@code node}, then inlines the collected intermediate assignments so
         * they, and the collected filters, are expressed over the output variables of the collected sources.
         *
         * @throws NullPointerException if {@code node}, {@code lookup}, {@code functionResolution} or {@code determinismEvaluator} is null
         * @throws IllegalStateException if {@code node} is not an INNER join
         */
        JoinNodeFlattener(JoinNode node, Lookup lookup, int sourceLimit, boolean handleComplexEquiJoins, FunctionResolution functionResolution, DeterminismEvaluator determinismEvaluator) {
            Objects.requireNonNull(node, "node is null");
            Preconditions.checkState(node.getType() == JoinType.INNER, "join type must be INNER");
            this.outputVariables = node.getOutputVariables();
            this.lookup = Objects.requireNonNull(lookup, "lookup is null");
            this.functionResolution = Objects.requireNonNull(functionResolution, "functionResolution is null");
            this.determinismEvaluator = Objects.requireNonNull(determinismEvaluator, "determinismEvaluator is null");
            this.handleComplexEquiJoins = handleComplexEquiJoins;
            Map<VariableReferenceExpression, RowExpression> intermediateAssignments = new HashMap<>();
            this.flattenNode(node, sourceLimit, intermediateAssignments);
            Set<VariableReferenceExpression> inputVariables = this.sources.stream()
                    .flatMap(source -> source.getOutputVariables().stream())
                    .collect(ImmutableSet.toImmutableSet());
            this.intermediateAssignments = this.resolveAssignments(intermediateAssignments, inputVariables);
            this.rewriteFilterWithInlinedAssignments(this.intermediateAssignments);
        }

        /**
         * Inlines assignments into each other so every expression refers, as far as possible, only to
         * {@code availableVariables}. {@code assignments} is updated in place; its final contents are returned as
         * {@link Assignments}.
         */
        private Assignments resolveAssignments(Map<VariableReferenceExpression, RowExpression> assignments, Set<VariableReferenceExpression> availableVariables) {
            HashSet<VariableReferenceExpression> resolvedVariables = new HashSet<>();
            // Iterate over a snapshot of the keys; resolveVariable writes values back into the map.
            for (VariableReferenceExpression variable : ImmutableList.copyOf(assignments.keySet())) {
                this.resolveVariable(variable, resolvedVariables, assignments, availableVariables);
            }
            return Assignments.builder().putAll(assignments).build();
        }

        /**
         * Rewrites the assignment for {@code variable} so that its expression references only
         * {@code availableVariables}. Every assignment it depends on is resolved first, recursively.
         * A variable that is already in {@code resolvedVariables} is skipped, so each assignment is
         * rewritten at most once even though the caller visits every key.
         *
         * @param variable the assignment target to resolve; expected to be a key of {@code assignments}
         * @param resolvedVariables targets whose expressions are already fully inlined; updated in place
         * @param assignments intermediate assignments; the entry for {@code variable} is replaced in place
         * @param availableVariables output variables of the flattened sources, which never need inlining
         */
        private void resolveVariable(VariableReferenceExpression variable, HashSet<VariableReferenceExpression> resolvedVariables, Map<VariableReferenceExpression, RowExpression> assignments, Set<VariableReferenceExpression> availableVariables) {
            if (resolvedVariables.contains(variable)) {
                // Already inlined while resolving an earlier dependency chain; redoing it is pure overhead.
                return;
            }
            RowExpression expression = assignments.get(variable);
            // Sets.difference is a live view: dependencies resolved by an earlier iteration's recursion
            // are filtered out before they are visited.
            Sets.SetView<VariableReferenceExpression> variablesToResolve = Sets.difference(
                    Sets.difference(VariablesExtractor.extractUnique(expression), availableVariables),
                    resolvedVariables);
            for (VariableReferenceExpression dependency : variablesToResolve) {
                this.resolveVariable(dependency, resolvedVariables, assignments, availableVariables);
            }
            // All dependencies are now fully inlined, so a single substitution pass suffices.
            assignments.put(variable, RowExpressionNodeInliner.replaceExpression(expression, assignments));
            resolvedVariables.add(variable);
        }

        /**
         * Replaces the collected filters with copies in which every reference to an intermediate
         * assignment target is substituted by its resolved expression.
         *
         * @param assignments resolved intermediate assignments to inline into the filters
         */
        private void rewriteFilterWithInlinedAssignments(Assignments assignments) {
            Map<VariableReferenceExpression, RowExpression> substitutions = assignments.getMap();
            ImmutableList.Builder<RowExpression> rewritten = ImmutableList.builder();
            for (RowExpression filter : this.filters) {
                rewritten.add(RowExpressionNodeInliner.replaceExpression(filter, substitutions));
            }
            this.filters = rewritten.build();
        }

        /**
         * Walks the plan rooted at {@code node}, appending leaf plans to {@code sources} and join
         * predicates to {@code filters}.
         * <ul>
         *   <li>A {@link ProjectNode} whose source resolves to a {@link JoinNode} is looked through
         *       when {@code handleComplexEquiJoins} is set and all of its assignment expressions are
         *       deterministic. Its non-identity assignments are recorded in
         *       {@code assignmentsBuilder}. Any other projection becomes a source.</li>
         *   <li>An INNER {@link JoinNode} with a deterministic (or absent) filter and no distribution
         *       type is expanded. The left side is flattened first (with {@code limit - 1}), then the
         *       right side, and then the join's equi-join criteria and filter are appended to
         *       {@code filters}.</li>
         *   <li>Every other node is added to {@code sources}. This includes any join reached while
         *       {@code sources} already holds more than {@code limit - 2} entries.</li>
         * </ul>
         * The node is stored as passed in, not in its resolved form.
         *
         * @param node the plan node to flatten
         * @param limit source-count bound used to decide whether a join may still be expanded
         * @param assignmentsBuilder accumulates the non-identity assignments of looked-through projections
         */
        private void flattenNode(PlanNode node, int limit, Map<VariableReferenceExpression, RowExpression> assignmentsBuilder) {
            PlanNode resolved = this.lookup.resolve(node);

            if (resolved instanceof ProjectNode) {
                ProjectNode projectNode = (ProjectNode) resolved;
                boolean projectsOverJoin = this.handleComplexEquiJoins
                        && this.lookup.resolve(projectNode.getSource()) instanceof JoinNode;
                if (projectsOverJoin
                        && projectNode.getAssignments().getExpressions().stream().allMatch(this.determinismEvaluator::isDeterministic)) {
                    // Look through the projection; its computed expressions are inlined later.
                    assignmentsBuilder.putAll(AssignmentUtils.getNonIdentityAssignments(projectNode.getAssignments()));
                    this.flattenNode(projectNode.getSource(), limit, assignmentsBuilder);
                } else {
                    this.sources.add(node);
                }
                return;
            }

            // Only expand a join while there is still room for both of its children.
            if (resolved instanceof JoinNode && this.sources.size() <= limit - 2) {
                JoinNode joinNode = (JoinNode) resolved;
                boolean expandable = joinNode.getType() == JoinType.INNER
                        && this.determinismEvaluator.isDeterministic(joinNode.getFilter().orElse(LogicalRowExpressions.TRUE_CONSTANT))
                        && !joinNode.getDistributionType().isPresent();
                if (expandable) {
                    // The left side gets limit - 1 to account for the right side.
                    this.flattenNode(joinNode.getLeft(), limit - 1, assignmentsBuilder);
                    this.flattenNode(joinNode.getRight(), limit, assignmentsBuilder);
                    for (EquiJoinClause clause : joinNode.getCriteria()) {
                        this.filters.add(JoinNodeUtils.toRowExpression(clause, this.functionResolution));
                    }
                    joinNode.getFilter().ifPresent(this.filters::add);
                    return;
                }
            }
            this.sources.add(node);
        }

        /**
         * Builds the {@link MultiJoinNode} for the flattened join tree.
         * <p>
         * An output variable produced directly by a source is passed through unchanged. Any other
         * output variable must have an intermediate assignment. In that case the variables referenced
         * by its resolved expression are exposed instead, and the result carries the complete output
         * assignments. When every output is a pass-through, empty assignments are used.
         *
         * @return the multi-join node built from the collected sources and conjunction of filters
         * @throws IllegalStateException if an output variable is neither produced by a source nor assigned
         */
        MultiJoinNode toMultiJoinNode() {
            Set<VariableReferenceExpression> sourceVariables = this.sources.stream()
                    .flatMap(source -> source.getOutputVariables().stream())
                    .collect(ImmutableSet.toImmutableSet());
            ImmutableSet.Builder<VariableReferenceExpression> exposedVariables = ImmutableSet.builder();
            Assignments.Builder outputAssignments = Assignments.builder();
            boolean hasComputedOutput = false;

            for (VariableReferenceExpression output : this.outputVariables) {
                if (sourceVariables.contains(output)) {
                    outputAssignments.put(output, output);
                    exposedVariables.add(output);
                } else {
                    Preconditions.checkState(this.intermediateAssignments.getMap().containsKey(output), "Output variable [%s] not found in input variables or in intermediate assignments", output);
                    RowExpression computed = this.intermediateAssignments.get(output);
                    hasComputedOutput = true;
                    outputAssignments.put(output, computed);
                    exposedVariables.addAll(VariablesExtractor.extractUnique(computed));
                }
            }

            return new MultiJoinNode(
                    this.sources,
                    LogicalRowExpressions.and(this.filters),
                    exposedVariables.build().asList(),
                    hasComputedOutput ? outputAssignments.build() : Assignments.of(),
                    false,
                    Optional.empty());
        }
    }
}

