/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.airlift.log.Logger;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.plan.JoinDistributionType;
import com.facebook.presto.spi.plan.JoinType;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.iterative.rule.JoinSwappingUtils;
import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import java.util.Objects;
import java.util.Optional;

/**
 * Iterative optimizer rule that flips the two sides of a {@link JoinNode} when the right
 * (build) side is estimated to produce more output bytes than the left (probe) side.
 *
 * <p>The rule only fires when:
 * <ul>
 *   <li>every leaf of the join subtree is a {@link TableScanNode};</li>
 *   <li>output-size estimates for both sides are available (not NaN);</li>
 *   <li>the right side is strictly larger than the left side; and</li>
 *   <li>the join's distribution type / join type combination allows flipping.</li>
 * </ul>
 * The rewrite itself is delegated to {@link JoinSwappingUtils#createRuntimeSwappedJoinNode};
 * if that helper declines (returns empty), the plan is left unchanged.
 */
public class RuntimeReorderJoinSides
implements Rule<JoinNode> {
    private static final Logger log = Logger.get(RuntimeReorderJoinSides.class);
    private static final Pattern<JoinNode> PATTERN = Patterns.join();
    // Both are forwarded to JoinSwappingUtils.createRuntimeSwappedJoinNode when rewriting.
    private final Metadata metadata;
    private final SqlParser parser;

    /**
     * @param metadata metadata handed to the join-swapping helper; must not be null
     * @param parser SQL parser handed to the join-swapping helper; must not be null
     * @throws NullPointerException if either argument is null
     */
    public RuntimeReorderJoinSides(Metadata metadata, SqlParser parser) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        this.parser = Objects.requireNonNull(parser, "parser is null");
    }

    @Override
    public Pattern<JoinNode> getPattern() {
        return PATTERN;
    }

    /**
     * Flips the join's sides when the right side is estimated to be larger than the left side
     * and the flip is permitted.
     *
     * @return the swapped join, or an empty result when any precondition fails
     */
    @Override
    public Rule.Result apply(JoinNode joinNode, Captures captures, Rule.Context context) {
        // Only join subtrees whose leaves are all table scans are considered: a node with no
        // sources that is not a TableScanNode disqualifies the whole join.
        if (PlanNodeSearcher.searchFrom(joinNode, context.getLookup())
                .where(node -> node.getSources().isEmpty() && !(node instanceof TableScanNode))
                .matches()) {
            return Rule.Result.empty();
        }

        StatsProvider statsProvider = context.getStatsProvider();
        double leftOutputSizeInBytes = Double.NaN;
        double rightOutputSizeInBytes = Double.NaN;

        // "Simple" plan: exactly one node in the subtree is neither a TableScanNode nor an
        // ExchangeNode (presumably the join itself), i.e. the join sits on table scans with only
        // exchanges in between. Use the no-argument output-size estimate in that case.
        if (PlanNodeSearcher.searchFrom(joinNode, context.getLookup())
                .where(node -> !(node instanceof TableScanNode) && !(node instanceof ExchangeNode))
                .findAll().size() == 1) {
            leftOutputSizeInBytes = statsProvider.getStats(joinNode.getLeft()).getOutputSizeInBytes();
            rightOutputSizeInBytes = statsProvider.getStats(joinNode.getRight()).getOutputSizeInBytes();
        }

        // Complex plan, or the simple estimate was unknown: fall back to the estimate computed
        // against each side's own plan node.
        if (Double.isNaN(leftOutputSizeInBytes) || Double.isNaN(rightOutputSizeInBytes)) {
            leftOutputSizeInBytes = statsProvider.getStats(joinNode.getLeft()).getOutputSizeInBytes(joinNode.getLeft());
            rightOutputSizeInBytes = statsProvider.getStats(joinNode.getRight()).getOutputSizeInBytes(joinNode.getRight());
        }

        // Without both estimates there is no basis for reordering.
        if (Double.isNaN(leftOutputSizeInBytes) || Double.isNaN(rightOutputSizeInBytes)) {
            return Rule.Result.empty();
        }

        // The build (right) side is already no larger than the probe (left) side.
        if (rightOutputSizeInBytes <= leftOutputSizeInBytes) {
            return Rule.Result.empty();
        }

        if (!isSwappedJoinValid(joinNode)) {
            return Rule.Result.empty();
        }

        Optional<JoinNode> rewrittenNode = JoinSwappingUtils.createRuntimeSwappedJoinNode(
                joinNode,
                this.metadata,
                this.parser,
                context.getLookup(),
                context.getSession(),
                context.getVariableAllocator(),
                context.getIdAllocator());
        if (!rewrittenNode.isPresent()) {
            return Rule.Result.empty();
        }

        // Hand the format and arguments to the logger instead of pre-formatting with
        // String.format, so the message is only built when debug logging is enabled.
        log.debug(
                "Probe size: %.2f is smaller than Build size: %.2f => invoke runtime join swapping on JoinNode ID: %s.",
                leftOutputSizeInBytes,
                rightOutputSizeInBytes,
                joinNode.getId());
        return Rule.Result.ofPlanNode(rewrittenNode.get());
    }

    /**
     * Returns whether flipping the sides of {@code join} is allowed.
     *
     * <p>Flipping is refused for:
     * <ul>
     *   <li>joins with no distribution type set (previously an unchecked {@code Optional.get()}
     *       that would throw; now treated conservatively as "do not swap");</li>
     *   <li>REPLICATED LEFT joins; and</li>
     *   <li>PARTITIONED RIGHT joins with empty join criteria.</li>
     * </ul>
     * NOTE(review): presumably the flipped forms of the last two would need a join
     * type/distribution combination the engine cannot execute; confirm against the planner's rules.
     */
    private boolean isSwappedJoinValid(JoinNode join) {
        if (!join.getDistributionType().isPresent()) {
            return false;
        }
        JoinDistributionType distributionType = join.getDistributionType().get();
        if (distributionType == JoinDistributionType.REPLICATED && join.getType() == JoinType.LEFT) {
            return false;
        }
        return !(distributionType == JoinDistributionType.PARTITIONED
                && join.getCriteria().isEmpty()
                && join.getType() == JoinType.RIGHT);
    }
}

