/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import io.airlift.units.DataSize;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.cost.PlanNodeStatsEstimateMath;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.Metadata;
import io.trino.operator.RetryPolicy;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SystemPartitioningHandle;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.optimizations.StreamPreferredProperties;
import io.trino.sql.planner.optimizations.StreamPropertyDerivations;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.ChildReplacer;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.SimplePlanRewriter;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;

/**
 * Rule that swaps the probe (left) and build (right) sides of a partitioned join when the
 * estimated output size of the build side is much larger than that of the probe side.
 *
 * <p>The rule matches a {@code PARTITIONED} join without left/right hash symbols whose build side
 * is either a non-{@code GATHER} local exchange, or a {@code PARTIAL} aggregation directly on top
 * of such an exchange. It only proceeds when that exchange is partitioned on exactly the
 * build-side join keys and has no hash column. After flipping:
 * <ul>
 *   <li>the captured local exchange on the old build side (now the probe side) is replaced by its
 *   single source, or by a local round-robin exchange when it has several sources;</li>
 *   <li>a local exchange partitioned on the new build keys is added to the new build side when
 *   its derived stream properties are not already partitioned on those keys.</li>
 * </ul>
 */
public class AdaptiveReorderPartitionedJoin
implements Rule<JoinNode> {
    private static final Capture<ExchangeNode> LOCAL_EXCHANGE_NODE = Capture.newCapture();
    // NOTE: must be declared after LOCAL_EXCHANGE_NODE, which capturedLocalNonGatherExchange() reads
    // during static initialization.
    private static final Pattern<JoinNode> PATTERN = Patterns.join()
            .matching(AdaptiveReorderPartitionedJoin::isPartitionedJoinWithNoHashSymbols)
            .or(
                    // build side is directly a local exchange
                    prev -> prev.with(Patterns.Join.right().matching(capturedLocalNonGatherExchange())),
                    // build side is a partial aggregation sitting on a local exchange
                    prev -> prev.with(Patterns.Join.right().matching(Patterns.aggregation()
                            .matching(node -> node.getStep() == AggregationNode.Step.PARTIAL)
                            .with(Patterns.source().matching(capturedLocalNonGatherExchange())))));

    private final Metadata metadata;

    /**
     * @param metadata used to derive stream properties of the new build side; must not be null
     * @throws NullPointerException if {@code metadata} is null
     */
    public AdaptiveReorderPartitionedJoin(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    /**
     * Returns a pattern matching a LOCAL-scope exchange whose type is not GATHER, capturing it as
     * {@link #LOCAL_EXCHANGE_NODE}.
     */
    private static Pattern<ExchangeNode> capturedLocalNonGatherExchange() {
        return Patterns.exchange()
                .matching(AdaptiveReorderPartitionedJoin::isLocalNonGatherExchange)
                .capturedAs(LOCAL_EXCHANGE_NODE);
    }

    private static boolean isLocalNonGatherExchange(ExchangeNode exchangeNode) {
        return exchangeNode.getScope() == ExchangeNode.Scope.LOCAL
                && exchangeNode.getType() != ExchangeNode.Type.GATHER;
    }

    private static boolean isPartitionedJoinWithNoHashSymbols(JoinNode joinNode) {
        return joinNode.getDistributionType().equals(Optional.of(JoinNode.DistributionType.PARTITIONED))
                && joinNode.getRightHashSymbol().isEmpty()
                && joinNode.getLeftHashSymbol().isEmpty();
    }

    @Override
    public Pattern<JoinNode> getPattern() {
        return PATTERN;
    }

    /**
     * The rule is active only with the TASK retry policy and when fault-tolerant adaptive join
     * reordering is enabled for the session.
     */
    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.getRetryPolicy(session) == RetryPolicy.TASK
                && SystemSessionProperties.isFaultTolerantExecutionAdaptiveJoinReorderingEnabled(session);
    }

    /**
     * Flips the join when the captured build-side local exchange is partitioned on the build keys
     * and the size statistics favour flipping; otherwise leaves the plan unchanged.
     */
    @Override
    public Rule.Result apply(JoinNode joinNode, Captures captures, Rule.Context context) {
        ExchangeNode localExchangeNode = captures.get(LOCAL_EXCHANGE_NODE);
        List<Symbol> buildSymbols = Lists.transform(joinNode.getCriteria(), JoinNode.EquiJoinClause::getRight);
        if (!isBuildSideLocalExchangeNode(localExchangeNode, ImmutableSet.copyOf(buildSymbols))) {
            return Rule.Result.empty();
        }
        if (!flipJoinBasedOnStats(joinNode, context)) {
            return Rule.Result.empty();
        }
        return Rule.Result.ofPlanNode(flipJoinAndFixLocalExchanges(joinNode, localExchangeNode.getId(), metadata, context));
    }

    /**
     * Returns true when {@code exchangeNode} is a LOCAL exchange partitioned on exactly
     * {@code rightSymbols} (compared as sets) and carrying no hash column.
     */
    private static boolean isBuildSideLocalExchangeNode(ExchangeNode exchangeNode, Set<Symbol> rightSymbols) {
        return exchangeNode.getScope() == ExchangeNode.Scope.LOCAL
                // normalize to a set so the comparison is order-insensitive regardless of the collection type returned
                && ImmutableSet.copyOf(exchangeNode.getPartitioningScheme().getPartitioning().getColumns()).equals(rightSymbols)
                // exchanges with a precomputed hash column are not handled by this rule
                && exchangeNode.getPartitioningScheme().getHashColumn().isEmpty();
    }

    /**
     * Swaps the join children and repairs local exchanges on both sides.
     *
     * @param joinNode the original join
     * @param buildSideLocalExchangeId id of the local exchange captured on the original build side
     * @param metadata used to derive stream properties of the new build side
     * @param context rule context providing lookup, id allocator and session
     * @return the flipped join with adjusted probe and build subtrees
     */
    private static JoinNode flipJoinAndFixLocalExchanges(JoinNode joinNode, PlanNodeId buildSideLocalExchangeId, Metadata metadata, Rule.Context context) {
        JoinNode flippedJoinNode = joinNode.flipChildren();

        // The old build side is now the probe side: drop its partitioned local exchange.
        PlanNode probeSide = SimplePlanRewriter.rewriteWith(
                new BuildToProbeLocalExchangeRewriter(buildSideLocalExchangeId, context),
                context.getLookup().resolve(flippedJoinNode.getLeft()));

        // The old probe side is now the build side: make sure it is locally partitioned on the new build keys.
        PlanNode buildSide = flippedJoinNode.getRight();
        StreamPropertyDerivations.StreamProperties rightProperties =
                deriveStreamPropertiesRecursively(buildSide, metadata, context.getLookup(), context.getSession());
        List<Symbol> buildSymbols = Lists.transform(flippedJoinNode.getCriteria(), JoinNode.EquiJoinClause::getRight);
        StreamPreferredProperties expectedRightProperties = StreamPreferredProperties.partitionedOn(buildSymbols);
        if (!expectedRightProperties.isSatisfiedBy(rightProperties)) {
            buildSide = SimplePlanRewriter.rewriteWith(
                    new ProbeToBuildLocalExchangeRewriter(buildSymbols, context),
                    context.getLookup().resolve(buildSide));
        }

        return new JoinNode(
                flippedJoinNode.getId(),
                flippedJoinNode.getType(),
                probeSide,
                buildSide,
                flippedJoinNode.getCriteria(),
                flippedJoinNode.getLeftOutputSymbols(),
                flippedJoinNode.getRightOutputSymbols(),
                flippedJoinNode.isMaySkipOutputDuplicates(),
                flippedJoinNode.getFilter(),
                flippedJoinNode.getLeftHashSymbol(),
                flippedJoinNode.getRightHashSymbol(),
                flippedJoinNode.getDistributionType(),
                flippedJoinNode.isSpillable(),
                flippedJoinNode.getDynamicFilters(),
                flippedJoinNode.getReorderJoinStatsAndCost());
    }

    /**
     * Returns true when the right (build) side's estimated output size exceeds both the session's
     * minimum size threshold and {@code sizeDifferenceRatio} times the left side's estimated size.
     */
    private static boolean flipJoinBasedOnStats(JoinNode joinNode, Rule.Context context) {
        double leftOutputSizeInBytes = PlanNodeStatsEstimateMath.getFirstKnownOutputSizeInBytes(joinNode.getLeft(), context.getLookup(), context.getStatsProvider());
        double rightOutputSizeInBytes = PlanNodeStatsEstimateMath.getFirstKnownOutputSizeInBytes(joinNode.getRight(), context.getLookup(), context.getStatsProvider());
        DataSize minSizeThreshold = SystemSessionProperties.getFaultTolerantExecutionAdaptiveJoinReorderingMinSizeThreshold(context.getSession());
        double sizeDifferenceRatio = SystemSessionProperties.getFaultTolerantExecutionAdaptiveJoinReorderingSizeDifferenceRatio(context.getSession());
        // If an estimate is NaN (presumably "unknown" — TODO confirm), both comparisons are false and the join is not flipped.
        return rightOutputSizeInBytes > (double) minSizeThreshold.toBytes()
                && rightOutputSizeInBytes > sizeDifferenceRatio * leftOutputSizeInBytes;
    }

    /**
     * Derives stream properties of {@code node} bottom-up, resolving group references through
     * {@code lookup} and without relying on actual (runtime) properties.
     */
    private static StreamPropertyDerivations.StreamProperties deriveStreamPropertiesRecursively(PlanNode node, Metadata metadata, Lookup lookup, Session session) {
        PlanNode resolvedNode = lookup.resolve(node);
        List<StreamPropertyDerivations.StreamProperties> inputProperties = resolvedNode.getSources().stream()
                .map(source -> deriveStreamPropertiesRecursively(source, metadata, lookup, session))
                .collect(ImmutableList.toImmutableList());
        return StreamPropertyDerivations.deriveStreamPropertiesWithoutActualProperties(resolvedNode, inputProperties, metadata, session);
    }

    /**
     * Applies {@code rewriter} to every (resolved) source of {@code node} and returns
     * {@code node} with the rewritten children.
     */
    private static PlanNode rewriteSources(SimplePlanRewriter<Void> rewriter, PlanNode node, Rule.Context context) {
        ImmutableList.Builder<PlanNode> children = ImmutableList.builderWithExpectedSize(node.getSources().size());
        for (PlanNode source : node.getSources()) {
            children.add(SimplePlanRewriter.rewriteWith(rewriter, context.getLookup().resolve(source)));
        }
        return ChildReplacer.replaceChildren(node, children.build());
    }

    /**
     * Rewrites the old build side (now the probe side). It passes through a PARTIAL aggregation and
     * removes the captured local exchange. Any other node is unexpected because {@link #PATTERN}
     * restricts the shape of this subtree.
     */
    private static final class BuildToProbeLocalExchangeRewriter
    extends SimplePlanRewriter<Void> {
        private final PlanNodeId localExchangeNodeId;
        private final Rule.Context context;

        private BuildToProbeLocalExchangeRewriter(PlanNodeId localExchangeNodeId, Rule.Context context) {
            this.localExchangeNodeId = Objects.requireNonNull(localExchangeNodeId, "localExchangeNodeId is null");
            this.context = Objects.requireNonNull(context, "context is null");
        }

        @Override
        public PlanNode visitPlan(PlanNode node, SimplePlanRewriter.RewriteContext<Void> ctx) {
            throw new UnsupportedOperationException("Unexpected plan node: " + node.getClass().getSimpleName());
        }

        @Override
        public PlanNode visitAggregation(AggregationNode node, SimplePlanRewriter.RewriteContext<Void> ctx) {
            Verify.verify(node.getStep() == AggregationNode.Step.PARTIAL, "Unexpected aggregation step: %s", node.getStep());
            return rewriteSources(this, node, context);
        }

        @Override
        public PlanNode visitExchange(ExchangeNode node, SimplePlanRewriter.RewriteContext<Void> ctx) {
            Verify.verify(
                    node.getScope() == ExchangeNode.Scope.LOCAL && node.getId().equals(localExchangeNodeId),
                    "Unexpected exchange node: %s",
                    node.getId());
            // A single input needs no exchange; several inputs still have to be merged, but without repartitioning.
            if (node.getSources().size() == 1) {
                return node.getSources().getFirst();
            }
            return ExchangeNode.roundRobinExchange(context.getIdAllocator().getNextId(), ExchangeNode.Scope.LOCAL, node.getSources(), node.getOutputSymbols());
        }
    }

    /**
     * Rewrites the old probe side (now the build side) so it is locally partitioned on
     * {@code buildSymbols}. A LOCAL exchange that has several sources and FIXED_ARBITRARY
     * distribution is converted in place into a partitioned exchange. Any other root is wrapped in
     * a new LOCAL partitioned exchange.
     */
    private static final class ProbeToBuildLocalExchangeRewriter
    extends SimplePlanRewriter<Void> {
        private final Rule.Context context;
        private final List<Symbol> buildSymbols;

        private ProbeToBuildLocalExchangeRewriter(List<Symbol> buildSymbols, Rule.Context context) {
            this.buildSymbols = Objects.requireNonNull(buildSymbols, "buildSymbols is null");
            this.context = Objects.requireNonNull(context, "context is null");
        }

        @Override
        public PlanNode visitPlan(PlanNode node, SimplePlanRewriter.RewriteContext<Void> ctx) {
            return ExchangeNode.partitionedExchange(context.getIdAllocator().getNextId(), ExchangeNode.Scope.LOCAL, node, buildSymbols, Optional.empty());
        }

        @Override
        public PlanNode visitExchange(ExchangeNode node, SimplePlanRewriter.RewriteContext<Void> ctx) {
            boolean arbitraryLocalMerge = node.getScope() == ExchangeNode.Scope.LOCAL
                    && node.getSources().size() > 1
                    && node.getPartitioningScheme().getPartitioning().getHandle().equals(SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION);
            if (arbitraryLocalMerge) {
                // Replace the arbitrary merge with a partitioned one instead of stacking a second exchange on top.
                return ExchangeNode.partitionedExchange(context.getIdAllocator().getNextId(), ExchangeNode.Scope.LOCAL, node.getSources(), buildSymbols, node.getOutputSymbols());
            }
            return visitPlan(node, ctx);
        }
    }
}

