/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.optimizations;

import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import io.airlift.log.Logger;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.cost.CachingStatsProvider;
import io.trino.cost.StatsCalculator;
import io.trino.cost.StatsProvider;
import io.trino.cost.TableStatsProvider;
import io.trino.cost.TaskCountEstimator;
import io.trino.operator.RetryPolicy;
import io.trino.sql.planner.PartitioningHandle;
import io.trino.sql.planner.PartitioningScheme;
import io.trino.sql.planner.SystemPartitioningHandle;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.optimizations.PlanNodeSearcher;
import io.trino.sql.planner.optimizations.PlanOptimizer;
import io.trino.sql.planner.optimizations.QueryCardinalityUtil;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.MergeWriterNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.SimplePlanRewriter;
import io.trino.sql.planner.plan.TableExecuteNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.planner.plan.TableWriterNode;
import io.trino.sql.planner.plan.UnionNode;
import io.trino.sql.planner.plan.UnnestNode;
import io.trino.sql.planner.plan.ValuesNode;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.ToDoubleFunction;

/**
 * Plan optimizer that chooses a partition count for remote repartitioning exchanges.
 * <p>
 * The count is derived from estimated plan statistics. It considers the summed output of the
 * source nodes ({@link TableScanNode}, {@link ValuesNode}) and the largest output of the
 * "expanding" nodes (joins, unions, multi-source exchanges). Both are measured in bytes and in
 * rows. When a count can be determined, every eligible exchange is rewritten to carry it.
 * Otherwise the plan is returned unchanged.
 */
public class DeterminePartitionCount
implements PlanOptimizer {
    private static final Logger log = Logger.get(DeterminePartitionCount.class);
    // Plan node types whose presence classifies the query as a write query.
    private static final List<Class<? extends PlanNode>> INSERT_NODES = ImmutableList.of(TableExecuteNode.class, TableWriterNode.class, MergeWriterNode.class);
    private final StatsCalculator statsCalculator;
    private final TaskCountEstimator taskCountEstimator;

    /**
     * @param statsCalculator used to estimate output size and row count of plan nodes
     * @param taskCountEstimator used to compare the computed partition count against the
     *        estimated number of hashed tasks
     * @throws NullPointerException if either argument is null
     */
    public DeterminePartitionCount(StatsCalculator statsCalculator, TaskCountEstimator taskCountEstimator) {
        this.statsCalculator = Objects.requireNonNull(statsCalculator, "statsCalculator is null");
        this.taskCountEstimator = Objects.requireNonNull(taskCountEstimator, "taskCountEstimator is null");
    }

    /**
     * Returns {@code plan} with eligible remote exchanges rewritten to use an estimated partition
     * count. Returns {@code plan} itself when no exchange is eligible, or when the query is a write
     * query and the write variant of this optimization is disabled. It is also returned unchanged
     * when no partition count could be determined.
     */
    @Override
    public PlanNode optimize(PlanNode plan, PlanOptimizer.Context context) {
        Objects.requireNonNull(plan, "plan is null");
        Session session = context.session();
        boolean taskRetries = isTaskRetryPolicy(session);
        if (!isEligibleRemoteExchangePresent(plan, taskRetries)) {
            return plan;
        }
        boolean isWriteQuery = PlanNodeSearcher.searchFrom(plan).whereIsInstanceOfAny(INSERT_NODES).matches();
        if (isWriteQuery && !SystemSessionProperties.isDeterminePartitionCountForWriteEnabled(session)) {
            return plan;
        }
        return determinePartitionCount(plan, session, context.types(), context.tableStatsProvider(), isWriteQuery, taskRetries)
                .map(partitionCount -> SimplePlanRewriter.rewriteWith(new Rewriter(partitionCount, taskRetries), plan))
                .orElse(plan);
    }

    /**
     * Returns whether the session's retry policy is {@link RetryPolicy#TASK}.
     */
    private static boolean isTaskRetryPolicy(Session session) {
        return SystemSessionProperties.getRetryPolicy(session) == RetryPolicy.TASK;
    }

    /**
     * Estimates the partition count to use for the plan.
     *
     * @return the partition count, or empty when it should not be overridden. That happens when:
     *         a per-task threshold is zero; an input-multiplying node is present; statistics are
     *         unknown; the count reaches the configured maximum; or, without task retries, the count
     *         is at least half of the estimated hashed task count
     */
    private Optional<Integer> determinePartitionCount(PlanNode plan, Session session, TypeProvider types, TableStatsProvider tableStatsProvider, boolean isWriteQuery, boolean taskRetries) {
        long minInputSizePerTask = SystemSessionProperties.getMinInputSizePerTask(session).toBytes();
        long minInputRowsPerTask = SystemSessionProperties.getMinInputRowsPerTask(session);
        // Both thresholds are divisors in getPartitionCount; a zero value disables the estimate.
        if (minInputSizePerTask == 0L || minInputRowsPerTask == 0L) {
            return Optional.empty();
        }
        if (isInputMultiplyingPlanNodePresent(plan)) {
            return Optional.empty();
        }

        int minPartitionCount;
        int maxPartitionCount;
        if (taskRetries) {
            minPartitionCount = isWriteQuery
                    ? SystemSessionProperties.getFaultTolerantExecutionMinPartitionCountForWrite(session)
                    : SystemSessionProperties.getFaultTolerantExecutionMinPartitionCount(session);
            maxPartitionCount = SystemSessionProperties.getFaultTolerantExecutionMaxPartitionCount(session);
        }
        else {
            minPartitionCount = isWriteQuery
                    ? SystemSessionProperties.getMinHashPartitionCountForWrite(session)
                    : SystemSessionProperties.getMinHashPartitionCount(session);
            maxPartitionCount = SystemSessionProperties.getMaxHashPartitionCount(session);
        }
        Verify.verify(minPartitionCount <= maxPartitionCount, "minPartitionCount %s larger than maxPartitionCount %s", minPartitionCount, maxPartitionCount);

        CachingStatsProvider statsProvider = new CachingStatsProvider(this.statsCalculator, session, types, tableStatsProvider);
        long queryMaxMemoryPerNode = SystemSessionProperties.getQueryMaxMemoryPerNode(session).toBytes();
        Optional<Integer> partitionCountBasedOnOutputSize = getPartitionCountBasedOnOutputSize(plan, statsProvider, types, minInputSizePerTask, queryMaxMemoryPerNode);
        Optional<Integer> partitionCountBasedOnRows = getPartitionCountBasedOnRows(plan, statsProvider, minInputRowsPerTask);
        if (partitionCountBasedOnOutputSize.isEmpty() || partitionCountBasedOnRows.isEmpty()) {
            return Optional.empty();
        }

        int partitionCount = Math.max(Math.max(partitionCountBasedOnOutputSize.get(), partitionCountBasedOnRows.get()), minPartitionCount);
        if (partitionCount >= maxPartitionCount) {
            return Optional.empty();
        }
        // NOTE(review): presumably, when the count is already close to the number of tasks that
        // would run anyway, overriding it brings little benefit outside task retries.
        if (!taskRetries && partitionCount * 2 >= this.taskCountEstimator.estimateHashedTaskCount(session)) {
            return Optional.empty();
        }
        log.debug("Estimated remote exchange partition count for query %s is %s", session.getQueryId(), partitionCount);
        return Optional.of(partitionCount);
    }

    /**
     * Computes a partition count from estimated output size in bytes.
     * <p>
     * The result is the larger of:
     * <ul>
     *   <li>the size-based count, one partition per {@code minInputSizePerTask} bytes;</li>
     *   <li>the memory-based count, twice the size divided by {@code queryMaxMemoryPerNode}.</li>
     * </ul>
     *
     * @return the count, or empty if any size estimate is unknown (NaN)
     */
    private static Optional<Integer> getPartitionCountBasedOnOutputSize(PlanNode plan, StatsProvider statsProvider, TypeProvider types, long minInputSizePerTask, long queryMaxMemoryPerNode) {
        ToDoubleFunction<PlanNode> outputSize = node -> statsProvider.getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types);
        double sourceTablesOutputSize = getSourceNodesOutputStats(plan, outputSize);
        double expandingNodesMaxOutputSize = getExpandingNodesMaxOutputStats(plan, outputSize);
        if (Double.isNaN(sourceTablesOutputSize) || Double.isNaN(expandingNodesMaxOutputSize)) {
            return Optional.empty();
        }
        double maxOutputSize = Math.max(sourceTablesOutputSize, expandingNodesMaxOutputSize);
        int partitionCountBasedOnOutputSize = getPartitionCount(maxOutputSize, minInputSizePerTask);
        int partitionCountBasedOnMemory = (int) (maxOutputSize * 2.0 / (double) queryMaxMemoryPerNode);
        return Optional.of(Math.max(partitionCountBasedOnOutputSize, partitionCountBasedOnMemory));
    }

    /**
     * Computes a partition count from the estimated row count, with one partition per
     * {@code minInputRowsPerTask} rows.
     *
     * @return the count, or empty if any row-count estimate is unknown (NaN)
     */
    private static Optional<Integer> getPartitionCountBasedOnRows(PlanNode plan, StatsProvider statsProvider, long minInputRowsPerTask) {
        ToDoubleFunction<PlanNode> rowCount = node -> statsProvider.getStats(node).getOutputRowCount();
        double sourceTablesRowCount = getSourceNodesOutputStats(plan, rowCount);
        double expandingNodesMaxRowCount = getExpandingNodesMaxOutputStats(plan, rowCount);
        if (Double.isNaN(sourceTablesRowCount) || Double.isNaN(expandingNodesMaxRowCount)) {
            return Optional.empty();
        }
        return Optional.of(getPartitionCount(Math.max(sourceTablesRowCount, expandingNodesMaxRowCount), minInputRowsPerTask));
    }

    /**
     * Returns {@code outputStats / minInputStatsPerTask} truncated to an int, and never less than 1.
     */
    private static int getPartitionCount(double outputStats, long minInputStatsPerTask) {
        return Math.max((int) (outputStats / (double) minInputStatsPerTask), 1);
    }

    /**
     * Returns whether any node in the plan is input-multiplying; see
     * {@link #isInputMultiplyingPlanNode(PlanNode)}.
     */
    private static boolean isInputMultiplyingPlanNodePresent(PlanNode root) {
        return PlanNodeSearcher.searchFrom(root).where(DeterminePartitionCount::isInputMultiplyingPlanNode).matches();
    }

    /**
     * Returns true for nodes that make the statistics-based estimate unusable:
     * <ul>
     *   <li>an {@link UnnestNode};</li>
     *   <li>a cross join where neither side is at most scalar;</li>
     *   <li>a join with more than one criterion.</li>
     * </ul>
     */
    private static boolean isInputMultiplyingPlanNode(PlanNode node) {
        if (node instanceof UnnestNode) {
            return true;
        }
        if (node instanceof JoinNode) {
            JoinNode joinNode = (JoinNode) node;
            if (joinNode.isCrossJoin()) {
                // A cross join with an at-most-scalar side does not multiply its other input.
                return !QueryCardinalityUtil.isAtMostScalar(joinNode.getRight()) && !QueryCardinalityUtil.isAtMostScalar(joinNode.getLeft());
            }
            // NOTE(review): presumably multi-key join row estimates are considered unreliable.
            return joinNode.getCriteria().size() > 1;
        }
        return false;
    }

    /**
     * Returns the maximum of {@code statsMapper} over all expanding nodes.
     * Returns 0.0 if there are none, and NaN if any mapped value is NaN.
     */
    private static double getExpandingNodesMaxOutputStats(PlanNode root, ToDoubleFunction<PlanNode> statsMapper) {
        List<PlanNode> expandingNodes = PlanNodeSearcher.searchFrom(root).where(DeterminePartitionCount::isExpandingPlanNode).findAll();
        return expandingNodes.stream().mapToDouble(statsMapper).max().orElse(0.0);
    }

    /**
     * Joins, unions, and exchanges with more than one source are treated as expanding nodes.
     */
    private static boolean isExpandingPlanNode(PlanNode node) {
        return node instanceof JoinNode
                || node instanceof UnionNode
                || (node instanceof ExchangeNode && node.getSources().size() > 1);
    }

    /**
     * Returns the sum of {@code statsMapper} over all table scan and values nodes.
     * The sum is NaN if any mapped value is NaN.
     */
    private static double getSourceNodesOutputStats(PlanNode root, ToDoubleFunction<PlanNode> statsMapper) {
        List<PlanNode> sourceNodes = PlanNodeSearcher.searchFrom(root).whereIsInstanceOfAny(TableScanNode.class, ValuesNode.class).findAll();
        return sourceNodes.stream().mapToDouble(statsMapper).sum();
    }

    /**
     * Returns whether the plan contains at least one exchange accepted by
     * {@link #isEligibleRemoteExchange(ExchangeNode, boolean)}.
     */
    private static boolean isEligibleRemoteExchangePresent(PlanNode root, boolean taskRetries) {
        return PlanNodeSearcher.searchFrom(root)
                .where(node -> node instanceof ExchangeNode && isEligibleRemoteExchange((ExchangeNode) node, taskRetries))
                .matches();
    }

    /**
     * Returns whether the exchange is eligible to have its partition count overridden. It must be:
     * <ul>
     *   <li>a remote repartition;</li>
     *   <li>not scale-writers and not single-node;</li>
     *   <li>partitioned with a system partitioning handle;</li>
     *   <li>under task retries, specifically the fixed hash distribution.</li>
     * </ul>
     */
    private static boolean isEligibleRemoteExchange(ExchangeNode exchangeNode, boolean taskRetries) {
        if (exchangeNode.getScope() != ExchangeNode.Scope.REMOTE || exchangeNode.getType() != ExchangeNode.Type.REPARTITION) {
            return false;
        }
        PartitioningHandle partitioningHandle = exchangeNode.getPartitioningScheme().getPartitioning().getHandle();
        return !partitioningHandle.isScaleWriters()
                && !partitioningHandle.isSingleNode()
                && partitioningHandle.getConnectorHandle() instanceof SystemPartitioningHandle
                // Compare by value: an equal handle need not be the same instance as the constant.
                && (!taskRetries || partitioningHandle.equals(SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION));
    }

    /**
     * Rewrites every eligible exchange so that its partitioning scheme carries a fixed partition
     * count. Sources of all exchanges are rewritten recursively.
     */
    private static final class Rewriter
    extends SimplePlanRewriter<Void> {
        private final int partitionCount;
        private final boolean taskRetries;

        private Rewriter(int partitionCount, boolean taskRetries) {
            this.partitionCount = partitionCount;
            this.taskRetries = taskRetries;
        }

        @Override
        public PlanNode visitExchange(ExchangeNode node, SimplePlanRewriter.RewriteContext<Void> context) {
            List<PlanNode> sources = node.getSources().stream()
                    .map(context::rewrite)
                    .collect(ImmutableList.toImmutableList());
            PartitioningScheme partitioningScheme = node.getPartitioningScheme();
            if (isEligibleRemoteExchange(node, this.taskRetries)) {
                partitioningScheme = partitioningScheme.withPartitionCount(Optional.of(this.partitionCount));
            }
            return new ExchangeNode(node.getId(), node.getType(), node.getScope(), partitioningScheme, sources, node.getInputs(), node.getOrderingScheme());
        }
    }
}

