/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.GroupReference;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.iterative.rule.DetermineJoinDistributionType;
import com.facebook.presto.sql.planner.optimizations.StreamPreferredProperties;
import com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.UnnestNode;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import io.airlift.units.DataSize;
import java.util.List;
import java.util.Optional;
import java.util.stream.Stream;

/**
 * Static helpers for flipping the children of a {@link JoinNode} and for comparing
 * output-size estimates of plan nodes.
 *
 * <p>This class is not instantiable.
 */
public class JoinSwappingUtils {
    /**
     * Node types at which {@link #getFirstKnownOutputSizeInBytes(PlanNode, Lookup, StatsProvider)}
     * stops descending when the node itself has no known size estimate.
     */
    static final List<Class<? extends PlanNode>> EXPANDING_NODE_CLASSES = ImmutableList.of(JoinNode.class, UnnestNode.class);

    /** Factor by which one side must be smaller than the other in {@link #isSmallerThanThreshold}. */
    private static final double SIZE_DIFFERENCE_THRESHOLD = 8.0;

    private JoinSwappingUtils() {
    }

    /**
     * Builds a join equivalent to {@code joinNode} with its children flipped.
     *
     * <p>Two adjustments are made on top of {@link JoinNode#flipChildren()}:
     * <ul>
     * <li>If the new left child resolves to an {@link ExchangeNode} with exactly one source and that
     * source already satisfies the probe-side stream property, the exchange is bypassed and its
     * source becomes the new left child.</li>
     * <li>If the new right child does not satisfy the build-side stream property, a local exchange
     * is placed on top of it.</li>
     * </ul>
     *
     * @param joinNode the join whose children should be swapped
     * @param metadata metadata used when deriving stream properties
     * @param lookup used to resolve plan nodes (e.g. group references)
     * @param session session whose properties select the required stream properties
     * @param idAllocator source of ids for any newly created exchange node
     * @return the swapped join, or {@link Optional#empty()} when the exchange on the new left side
     *         was bypassed but the join's own output contains the old left hash variable
     */
    public static Optional<JoinNode> createRuntimeSwappedJoinNode(JoinNode joinNode, Metadata metadata, Lookup lookup, Session session, PlanNodeIdAllocator idAllocator) {
        JoinNode swapped = joinNode.flipChildren();
        PlanNode newLeft = swapped.getLeft();
        Optional<VariableReferenceExpression> leftHashVariable = swapped.getLeftHashVariable();

        PlanNode resolvedSwappedLeft = lookup.resolve(newLeft);
        if (resolvedSwappedLeft instanceof ExchangeNode && resolvedSwappedLeft.getSources().size() == 1) {
            PlanNode exchangeSource = resolvedSwappedLeft.getSources().get(0);
            if (checkProbeSidePropertySatisfied(exchangeSource, metadata, lookup, session)) {
                // The exchange is not needed for the probe side; read directly from its source.
                newLeft = exchangeSource;
                if (swapped.getLeftHashVariable().isPresent()) {
                    VariableReferenceExpression oldLeftHashVariable = swapped.getLeftHashVariable().get();
                    // The replacement hash variable is looked up by position: the index of the old
                    // hash variable in the exchange's outputs is used to index the source's outputs.
                    // NOTE(review): this assumes the exchange output contains the old hash variable
                    // (indexOf would otherwise return -1 and get() would throw) and that exchange and
                    // source outputs are positionally aligned -- confirm against the planner.
                    int hashVariableIndex = resolvedSwappedLeft.getOutputVariables().indexOf(oldLeftHashVariable);
                    leftHashVariable = Optional.of(exchangeSource.getOutputVariables().get(hashVariableIndex));
                    // The join itself outputs the old hash variable; abort the swap. Presumably the
                    // bypass would otherwise require rewriting the join's output -- TODO confirm.
                    if (swapped.getOutputVariables().contains(oldLeftHashVariable)) {
                        return Optional.empty();
                    }
                }
            }
        }

        List<VariableReferenceExpression> buildJoinVariables = swapped.getCriteria().stream()
                .map(EquiJoinClause::getRight)
                .collect(ImmutableList.toImmutableList());
        PlanNode newRight = swapped.getRight();
        if (!checkBuildSidePropertySatisfied(swapped.getRight(), buildJoinVariables, metadata, lookup, session)) {
            // Add a local exchange so the build side gets the required stream property:
            // partitioned on the join keys when task concurrency > 1, gathered otherwise.
            if (SystemSessionProperties.getTaskConcurrency(session) > 1) {
                newRight = ExchangeNode.systemPartitionedExchange(
                        idAllocator.getNextId(),
                        ExchangeNode.Scope.LOCAL,
                        swapped.getRight(),
                        buildJoinVariables,
                        swapped.getRightHashVariable());
            } else {
                newRight = ExchangeNode.gatheringExchange(idAllocator.getNextId(), ExchangeNode.Scope.LOCAL, swapped.getRight());
            }
        }

        JoinNode newJoinNode = new JoinNode(
                swapped.getSourceLocation(),
                swapped.getId(),
                swapped.getType(),
                newLeft,
                newRight,
                swapped.getCriteria(),
                swapped.getOutputVariables(),
                swapped.getFilter(),
                leftHashVariable,
                swapped.getRightHashVariable(),
                swapped.getDistributionType(),
                swapped.getDynamicFilters());
        return Optional.of(newJoinNode);
    }

    /**
     * Checks whether the stream properties derived for {@code node} satisfy the probe-side requirement.
     *
     * <p>The requirement is fixed parallelism when both spill and join spilling are enabled in the
     * session, and the session's default parallelism otherwise.
     *
     * @param node candidate probe-side node
     * @param metadata metadata used when deriving stream properties
     * @param lookup used to resolve plan nodes
     * @param session session supplying the relevant properties
     * @return {@code true} if the derived properties satisfy the requirement
     */
    public static boolean checkProbeSidePropertySatisfied(PlanNode node, Metadata metadata, Lookup lookup, Session session) {
        StreamPreferredProperties requiredProbeProperty;
        if (SystemSessionProperties.isSpillEnabled(session) && SystemSessionProperties.isJoinSpillingEnabled(session)) {
            requiredProbeProperty = StreamPreferredProperties.fixedParallelism();
        } else {
            requiredProbeProperty = StreamPreferredProperties.defaultParallelism(session);
        }
        StreamPropertyDerivations.StreamProperties nodeProperty = derivePropertiesRecursively(node, metadata, lookup, session);
        return requiredProbeProperty.isSatisfiedBy(nodeProperty);
    }

    /**
     * Checks whether the stream properties derived for {@code node} satisfy the build-side requirement:
     * exactly partitioned on {@code partitioningColumns} when task concurrency is greater than one,
     * a single stream otherwise.
     */
    private static boolean checkBuildSidePropertySatisfied(PlanNode node, List<VariableReferenceExpression> partitioningColumns, Metadata metadata, Lookup lookup, Session session) {
        StreamPreferredProperties requiredBuildProperty;
        if (SystemSessionProperties.getTaskConcurrency(session) > 1) {
            requiredBuildProperty = StreamPreferredProperties.exactlyPartitionedOn(partitioningColumns);
        } else {
            requiredBuildProperty = StreamPreferredProperties.singleStream();
        }
        StreamPropertyDerivations.StreamProperties nodeProperty = derivePropertiesRecursively(node, metadata, lookup, session);
        return requiredBuildProperty.isSatisfiedBy(nodeProperty);
    }

    /**
     * Derives the stream properties of {@code node} bottom-up: the node is resolved through
     * {@code lookup}, the properties of all its sources are derived first, and the results are fed
     * into {@link StreamPropertyDerivations#deriveProperties}.
     */
    private static StreamPropertyDerivations.StreamProperties derivePropertiesRecursively(PlanNode node, Metadata metadata, Lookup lookup, Session session) {
        PlanNode actual = lookup.resolve(node);
        List<StreamPropertyDerivations.StreamProperties> inputProperties = actual.getSources().stream()
                .map(source -> derivePropertiesRecursively(source, metadata, lookup, session))
                .collect(ImmutableList.toImmutableList());
        return StreamPropertyDerivations.deriveProperties(actual, inputProperties, metadata, session);
    }

    /**
     * Tells whether the source-table size reported for {@code planNode} by
     * {@link DetermineJoinDistributionType#getSourceTablesSizeInBytes} does not exceed the session's
     * maximum broadcast table size.
     *
     * @param planNode node whose source tables are measured
     * @param context rule context supplying the session
     * @return {@code true} if the size is less than or equal to the limit
     */
    public static boolean isBelowBroadcastLimit(PlanNode planNode, Rule.Context context) {
        DataSize joinMaxBroadcastTableSize = SystemSessionProperties.getJoinMaxBroadcastTableSize(context.getSession());
        return DetermineJoinDistributionType.getSourceTablesSizeInBytes(planNode, context) <= (double) joinMaxBroadcastTableSize.toBytes();
    }

    /**
     * Tells whether {@code planNodeA} is at least {@code SIZE_DIFFERENCE_THRESHOLD} times smaller
     * than {@code planNodeB}, based on their first known output sizes.
     *
     * <p>If either size is unknown (NaN) the comparison, and therefore the result, is {@code false}.
     *
     * @param planNodeA node expected to be the smaller one
     * @param planNodeB node expected to be the larger one
     * @param context rule context supplying lookup and statistics
     * @return {@code true} if {@code size(A) * SIZE_DIFFERENCE_THRESHOLD < size(B)}
     */
    public static boolean isSmallerThanThreshold(PlanNode planNodeA, PlanNode planNodeB, Rule.Context context) {
        double aOutputSize = getFirstKnownOutputSizeInBytes(planNodeA, context);
        double bOutputSize = getFirstKnownOutputSizeInBytes(planNodeB, context);
        return aOutputSize * SIZE_DIFFERENCE_THRESHOLD < bOutputSize;
    }

    /** Convenience overload that takes lookup and stats provider from the rule context. */
    private static double getFirstKnownOutputSizeInBytes(PlanNode node, Rule.Context context) {
        return getFirstKnownOutputSizeInBytes(node, context.getLookup(), context.getStatsProvider());
    }

    /**
     * Returns the output size estimate of {@code node}, falling back to the summed estimates of its
     * sources when the node's own estimate is unknown.
     *
     * <p>A {@link GroupReference} is first expanded through {@link Lookup#resolveGroup}; the sizes
     * computed for all nodes it resolves to are summed.
     *
     * @param node node to measure
     * @param lookup used to resolve group references
     * @param statsProvider source of statistics
     * @return the size in bytes, or NaN if no estimate could be determined
     */
    @VisibleForTesting
    public static double getFirstKnownOutputSizeInBytes(PlanNode node, Lookup lookup, StatsProvider statsProvider) {
        return Stream.of(node)
                .flatMap(planNode -> {
                    if (planNode instanceof GroupReference) {
                        return lookup.resolveGroup(planNode);
                    }
                    return Stream.of(planNode);
                })
                .mapToDouble(resolvedNode -> firstKnownOutputSizeOfResolved(resolvedNode, lookup, statsProvider))
                .sum();
    }

    /**
     * Size estimate for an already-resolved node: its own estimate if known; otherwise NaN for
     * expanding nodes and leaves, or the sum of its sources' first known sizes (NaN if any is unknown).
     */
    private static double firstKnownOutputSizeOfResolved(PlanNode resolvedNode, Lookup lookup, StatsProvider statsProvider) {
        double outputSizeInBytes = statsProvider.getStats(resolvedNode).getOutputSizeInBytes(resolvedNode);
        if (!Double.isNaN(outputSizeInBytes)) {
            return outputSizeInBytes;
        }
        // Do not look below nodes listed in EXPANDING_NODE_CLASSES when their own size is unknown.
        if (EXPANDING_NODE_CLASSES.stream().anyMatch(clazz -> clazz.isInstance(resolvedNode))) {
            return Double.NaN;
        }
        List<PlanNode> sourceNodes = resolvedNode.getSources();
        if (sourceNodes.isEmpty()) {
            return Double.NaN;
        }
        double sourcesOutputSizeInBytes = 0.0;
        for (PlanNode sourceNode : sourceNodes) {
            double firstKnownOutputSizeInBytes = getFirstKnownOutputSizeInBytes(sourceNode, lookup, statsProvider);
            if (Double.isNaN(firstKnownOutputSizeInBytes)) {
                return Double.NaN;
            }
            sourcesOutputSizeInBytes += firstKnownOutputSizeInBytes;
        }
        return sourcesOutputSizeInBytes;
    }
}

