/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.airlift.log.Logger;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher;
import com.facebook.presto.sql.planner.optimizations.StreamPreferredProperties;
import com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

/**
 * Iterative rule that flips the probe (left) and build (right) sides of a {@link JoinNode}
 * when statistics report that the right side produces more output bytes than the left side.
 * <p>
 * The rule only considers join subtrees whose leaf nodes are all {@link TableScanNode}s.
 * After flipping the children it:
 * <ul>
 *   <li>drops a single-source {@link ExchangeNode} sitting directly on top of the new probe side,
 *       provided the exchange's source alone already satisfies the required probe stream properties;</li>
 *   <li>adds a local exchange on top of the new build side when that side does not already satisfy
 *       the required build stream properties.</li>
 * </ul>
 */
public class RuntimeReorderJoinSides
implements Rule<JoinNode> {
    private static final Logger log = Logger.get(RuntimeReorderJoinSides.class);
    private static final Pattern<JoinNode> PATTERN = Patterns.join();
    private final Metadata metadata;
    private final SqlParser parser;

    /**
     * @param metadata metadata used when deriving stream properties; must not be null
     * @param parser SQL parser used when deriving stream properties; must not be null
     * @throws NullPointerException if either argument is null
     */
    public RuntimeReorderJoinSides(Metadata metadata, SqlParser parser) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        this.parser = Objects.requireNonNull(parser, "parser is null");
    }

    @Override
    public Pattern<JoinNode> getPattern() {
        return PATTERN;
    }

    /**
     * Swaps the join sides when the right side's estimated output size is strictly larger than
     * the left side's and the swapped join is valid.
     *
     * @return the rewritten join, or {@link Rule.Result#empty()} when the rule does not apply
     */
    @Override
    public Rule.Result apply(JoinNode joinNode, Captures captures, Rule.Context context) {
        // Early exit when any leaf of the join subtree is something other than a table scan.
        if (PlanNodeSearcher.searchFrom(joinNode, context.getLookup())
                .where(node -> node.getSources().isEmpty() && !(node instanceof TableScanNode))
                .matches()) {
            return Rule.Result.empty();
        }

        StatsProvider statsProvider = context.getStatsProvider();
        double leftOutputSizeInBytes = Double.NaN;
        double rightOutputSizeInBytes = Double.NaN;
        // If the join itself is the only node that is neither a table scan nor an exchange,
        // use the plain per-side output size statistics first.
        if (PlanNodeSearcher.searchFrom(joinNode, context.getLookup())
                .where(node -> !(node instanceof TableScanNode) && !(node instanceof ExchangeNode))
                .findAll().size() == 1) {
            leftOutputSizeInBytes = statsProvider.getStats(joinNode.getLeft()).getOutputSizeInBytes();
            rightOutputSizeInBytes = statsProvider.getStats(joinNode.getRight()).getOutputSizeInBytes();
        }
        // Otherwise (or if those were unknown) estimate from each side's output variables.
        if (Double.isNaN(leftOutputSizeInBytes) || Double.isNaN(rightOutputSizeInBytes)) {
            leftOutputSizeInBytes = statsProvider.getStats(joinNode.getLeft()).getOutputSizeInBytes(joinNode.getLeft().getOutputVariables());
            rightOutputSizeInBytes = statsProvider.getStats(joinNode.getRight()).getOutputSizeInBytes(joinNode.getRight().getOutputVariables());
        }
        if (Double.isNaN(leftOutputSizeInBytes) || Double.isNaN(rightOutputSizeInBytes)) {
            return Rule.Result.empty();
        }
        // Only swap when the current build (right) side is strictly larger.
        if (rightOutputSizeInBytes <= leftOutputSizeInBytes) {
            return Rule.Result.empty();
        }
        if (!isSwappedJoinValid(joinNode)) {
            return Rule.Result.empty();
        }

        JoinNode swapped = joinNode.flipChildren();
        PlanNode newLeft = swapped.getLeft();
        Optional<VariableReferenceExpression> leftHashVariable = swapped.getLeftHashVariable();

        // If the new probe side is a single-source exchange whose source already satisfies the
        // probe-side stream properties, skip the exchange and probe its source directly.
        PlanNode resolvedSwappedLeft = context.getLookup().resolve(newLeft);
        if (resolvedSwappedLeft instanceof ExchangeNode
                && resolvedSwappedLeft.getSources().size() == 1
                && checkProbeSidePropertySatisfied(resolvedSwappedLeft.getSources().get(0), context)) {
            PlanNode exchangeSource = resolvedSwappedLeft.getSources().get(0);
            if (swapped.getLeftHashVariable().isPresent()) {
                VariableReferenceExpression oldHashVariable = swapped.getLeftHashVariable().get();
                // The join's output layout still references the exchange's hash variable; removing the
                // exchange would leave it dangling, so do not rewrite.
                if (swapped.getOutputVariables().contains(oldHashVariable)) {
                    return Rule.Result.empty();
                }
                // Map the hash variable to the source by its position in the exchange's output layout.
                // NOTE(review): this relies on the exchange and its source sharing the same layout order
                // for the hash variable; if the position is not usable, give up instead of throwing.
                int hashVariableIndex = resolvedSwappedLeft.getOutputVariables().indexOf(oldHashVariable);
                List<VariableReferenceExpression> sourceOutputVariables = exchangeSource.getOutputVariables();
                if (hashVariableIndex < 0 || hashVariableIndex >= sourceOutputVariables.size()) {
                    return Rule.Result.empty();
                }
                leftHashVariable = Optional.of(sourceOutputVariables.get(hashVariableIndex));
            }
            newLeft = exchangeSource;
        }

        // Ensure the new build side satisfies the required stream properties, adding a local exchange if not.
        List<VariableReferenceExpression> buildJoinVariables = swapped.getCriteria().stream()
                .map(JoinNode.EquiJoinClause::getRight)
                .collect(ImmutableList.toImmutableList());
        PlanNode newRight = swapped.getRight();
        if (!checkBuildSidePropertySatisfied(swapped.getRight(), buildJoinVariables, context)) {
            if (SystemSessionProperties.getTaskConcurrency(context.getSession()) > 1) {
                newRight = ExchangeNode.systemPartitionedExchange(
                        context.getIdAllocator().getNextId(),
                        ExchangeNode.Scope.LOCAL,
                        swapped.getRight(),
                        buildJoinVariables,
                        swapped.getRightHashVariable());
            }
            else {
                newRight = ExchangeNode.gatheringExchange(
                        context.getIdAllocator().getNextId(),
                        ExchangeNode.Scope.LOCAL,
                        swapped.getRight());
            }
        }

        JoinNode newJoinNode = new JoinNode(
                swapped.getId(),
                swapped.getType(),
                newLeft,
                newRight,
                swapped.getCriteria(),
                swapped.getOutputVariables(),
                swapped.getFilter(),
                leftHashVariable,
                swapped.getRightHashVariable(),
                swapped.getDistributionType(),
                swapped.getDynamicFilters());
        // Format-style overload: the message is only formatted when debug logging is enabled.
        log.debug("Probe size: %.2f is smaller than Build size: %.2f => invoke runtime join swapping on JoinNode ID: %s.",
                leftOutputSizeInBytes, rightOutputSizeInBytes, newJoinNode.getId());
        return Rule.Result.ofPlanNode(newJoinNode);
    }

    /**
     * Returns whether flipping the children of {@code join} yields a valid join.
     * <p>
     * A REPLICATED LEFT join, and a PARTITIONED RIGHT join with no equi-join criteria, are rejected.
     * A join without a distribution type is conservatively treated as not swappable.
     */
    private boolean isSwappedJoinValid(JoinNode join) {
        Optional<JoinNode.DistributionType> distributionType = join.getDistributionType();
        if (!distributionType.isPresent()) {
            return false;
        }
        JoinNode.DistributionType type = distributionType.get();
        boolean replicatedLeft = type == JoinNode.DistributionType.REPLICATED
                && join.getType() == JoinNode.Type.LEFT;
        boolean partitionedRightWithoutCriteria = type == JoinNode.DistributionType.PARTITIONED
                && join.getCriteria().isEmpty()
                && join.getType() == JoinNode.Type.RIGHT;
        return !replicatedLeft && !partitionedRightWithoutCriteria;
    }

    /**
     * Checks whether {@code node} satisfies the probe-side stream properties: fixed parallelism
     * when spill and join spilling are both enabled, default parallelism otherwise.
     */
    private boolean checkProbeSidePropertySatisfied(PlanNode node, Rule.Context context) {
        StreamPreferredProperties requiredProbeProperty =
                SystemSessionProperties.isSpillEnabled(context.getSession()) && SystemSessionProperties.isJoinSpillingEnabled(context.getSession())
                        ? StreamPreferredProperties.fixedParallelism()
                        : StreamPreferredProperties.defaultParallelism(context.getSession());
        StreamPropertyDerivations.StreamProperties nodeProperty = derivePropertiesRecursively(node, metadata, parser, context);
        return requiredProbeProperty.isSatisfiedBy(nodeProperty);
    }

    /**
     * Checks whether {@code node} satisfies the build-side stream properties: exactly partitioned on
     * {@code partitioningColumns} when task concurrency is greater than 1, a single stream otherwise.
     */
    private boolean checkBuildSidePropertySatisfied(PlanNode node, List<VariableReferenceExpression> partitioningColumns, Rule.Context context) {
        StreamPreferredProperties requiredBuildProperty = SystemSessionProperties.getTaskConcurrency(context.getSession()) > 1
                ? StreamPreferredProperties.exactlyPartitionedOn(partitioningColumns)
                : StreamPreferredProperties.singleStream();
        StreamPropertyDerivations.StreamProperties nodeProperty = derivePropertiesRecursively(node, metadata, parser, context);
        return requiredBuildProperty.isSatisfiedBy(nodeProperty);
    }

    /**
     * Derives the stream properties of {@code node} bottom-up, first resolving it through the lookup
     * and then deriving the properties of each of its sources recursively.
     */
    private StreamPropertyDerivations.StreamProperties derivePropertiesRecursively(PlanNode node, Metadata metadata, SqlParser parser, Rule.Context context) {
        PlanNode actual = context.getLookup().resolve(node);
        List<StreamPropertyDerivations.StreamProperties> inputProperties = actual.getSources().stream()
                .map(source -> derivePropertiesRecursively(source, metadata, parser, context))
                .collect(ImmutableList.toImmutableList());
        return StreamPropertyDerivations.deriveProperties(actual, inputProperties, metadata, context.getSession(), context.getVariableAllocator().getTypes(), parser);
    }
}

