/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.optimizations;

import com.google.common.collect.ImmutableList;
import io.airlift.log.Logger;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.cost.CachingStatsProvider;
import io.trino.cost.StatsCalculator;
import io.trino.cost.StatsProvider;
import io.trino.cost.TableStatsProvider;
import io.trino.execution.warnings.WarningCollector;
import io.trino.operator.RetryPolicy;
import io.trino.sql.planner.PartitioningHandle;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.SystemPartitioningHandle;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.optimizations.PlanNodeSearcher;
import io.trino.sql.planner.optimizations.PlanOptimizer;
import io.trino.sql.planner.optimizations.QueryCardinalityUtil;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.MergeWriterNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.SimplePlanRewriter;
import io.trino.sql.planner.plan.TableExecuteNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.planner.plan.TableWriterNode;
import io.trino.sql.planner.plan.UnionNode;
import io.trino.sql.planner.plan.UnnestNode;
import io.trino.sql.planner.plan.ValuesNode;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.ToDoubleFunction;

/**
 * Plan optimizer that estimates a partition count for the query's remote exchanges from plan
 * statistics (output size in bytes and output row count). It applies that count to every
 * REMOTE exchange whose partitioning handle wraps a {@link SystemPartitioningHandle}.
 *
 * <p>The plan is returned unchanged when:
 * <ul>
 *   <li>it contains a {@link TableExecuteNode}, {@link TableWriterNode} or {@link MergeWriterNode};</li>
 *   <li>the session retry policy is {@code TASK};</li>
 *   <li>either per-task minimum (input size or input rows) is configured as 0;</li>
 *   <li>an input-multiplying node is present (see {@link #isInputMultiplyingPlanNode});</li>
 *   <li>the relevant statistics are unknown (NaN);</li>
 *   <li>the estimate reaches the session's maximum hash partition count;</li>
 *   <li>estimation throws a {@link RuntimeException} (the failure is logged).</li>
 * </ul>
 */
public class DeterminePartitionCount
implements PlanOptimizer {
    private static final Logger log = Logger.get(DeterminePartitionCount.class);
    // Plans containing any of these write nodes are never rewritten by this optimizer.
    private static final List<Class<? extends PlanNode>> INSERT_NODES = ImmutableList.of(TableExecuteNode.class, TableWriterNode.class, MergeWriterNode.class);
    private final StatsCalculator statsCalculator;

    /**
     * @param statsCalculator calculator used to derive per-node statistics; must not be null
     */
    public DeterminePartitionCount(StatsCalculator statsCalculator) {
        this.statsCalculator = Objects.requireNonNull(statsCalculator, "statsCalculator is null");
    }

    /**
     * Rewrites the partition count of eligible remote exchanges in {@code plan}.
     *
     * @return the rewritten plan, or {@code plan} itself when no partition count is applied
     * @throws NullPointerException if {@code plan}, {@code session}, {@code types} or
     *         {@code tableStatsProvider} is null
     */
    @Override
    public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector, TableStatsProvider tableStatsProvider) {
        Objects.requireNonNull(plan, "plan is null");
        Objects.requireNonNull(session, "session is null");
        Objects.requireNonNull(types, "types is null");
        Objects.requireNonNull(tableStatsProvider, "tableStatsProvider is null");
        if (PlanNodeSearcher.searchFrom(plan).whereIsInstanceOfAny(INSERT_NODES).matches()
                || SystemSessionProperties.getRetryPolicy(session) == RetryPolicy.TASK) {
            return plan;
        }
        try {
            return determinePartitionCount(plan, session, types, tableStatsProvider)
                    .map(partitionCount -> SimplePlanRewriter.rewriteWith(new Rewriter(partitionCount), plan))
                    .orElse(plan);
        }
        catch (RuntimeException e) {
            // Estimation is best-effort: log the failure and keep the original plan.
            log.warn(e, "Error occurred when determining hash partition count for query %s", session.getQueryId());
            return plan;
        }
    }

    /**
     * Computes the partition count to apply.
     *
     * <p>The result is the larger of the size-based and row-based estimates, raised to at least
     * the session's minimum hash partition count.
     *
     * @return the estimated count, or empty when the estimate should not be applied
     */
    private Optional<Integer> determinePartitionCount(PlanNode plan, Session session, TypeProvider types, TableStatsProvider tableStatsProvider) {
        long minInputSizePerTask = SystemSessionProperties.getMinInputSizePerTask(session).toBytes();
        long minInputRowsPerTask = SystemSessionProperties.getMinInputRowsPerTask(session);
        // A zero minimum disables the feature (and would otherwise be a divisor in getPartitionCount).
        if (minInputSizePerTask == 0L || minInputRowsPerTask == 0L) {
            return Optional.empty();
        }
        if (isInputMultiplyingPlanNodePresent(plan)) {
            return Optional.empty();
        }
        StatsProvider statsProvider = new CachingStatsProvider(this.statsCalculator, session, types, tableStatsProvider);
        long queryMaxMemoryPerNode = SystemSessionProperties.getQueryMaxMemoryPerNode(session).toBytes();
        Optional<Integer> partitionCountBasedOnOutputSize = getPartitionCountBasedOnOutputSize(plan, statsProvider, types, minInputSizePerTask, queryMaxMemoryPerNode);
        Optional<Integer> partitionCountBasedOnRows = getPartitionCountBasedOnRows(plan, statsProvider, minInputRowsPerTask);
        if (partitionCountBasedOnOutputSize.isEmpty() || partitionCountBasedOnRows.isEmpty()) {
            return Optional.empty();
        }
        int partitionCount = Math.max(
                Math.max(partitionCountBasedOnOutputSize.get(), partitionCountBasedOnRows.get()),
                SystemSessionProperties.getMinHashPartitionCount(session));
        // At or above the configured maximum, leave the exchanges' partitioning schemes untouched.
        if (partitionCount >= SystemSessionProperties.getMaxHashPartitionCount(session)) {
            return Optional.empty();
        }
        log.debug("Estimated remote exchange partition count for query %s is %s", session.getQueryId(), partitionCount);
        return Optional.of(partitionCount);
    }

    /**
     * Estimates a partition count from output sizes in bytes.
     *
     * <p>The input is the larger of (a) the summed output size of all table scans and values
     * nodes and (b) the largest output size of any expanding node. That value is divided by
     * {@code minInputSizePerTask}. A memory-based count is also computed (the same value times
     * 2.0 divided by {@code queryMaxMemoryPerNode}), and the larger of the two is returned.
     *
     * @return empty when either size statistic is NaN (unknown)
     */
    private static Optional<Integer> getPartitionCountBasedOnOutputSize(PlanNode plan, StatsProvider statsProvider, TypeProvider types, long minInputSizePerTask, long queryMaxMemoryPerNode) {
        ToDoubleFunction<PlanNode> outputSize = node -> statsProvider.getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types);
        double sourceTablesOutputSize = getSourceNodesOutputStats(plan, outputSize);
        double expandingNodesMaxOutputSize = getExpandingNodesMaxOutputStats(plan, outputSize);
        if (Double.isNaN(sourceTablesOutputSize) || Double.isNaN(expandingNodesMaxOutputSize)) {
            return Optional.empty();
        }
        double maxOutputSize = Math.max(sourceTablesOutputSize, expandingNodesMaxOutputSize);
        int partitionCountBasedOnOutputSize = getPartitionCount(maxOutputSize, minInputSizePerTask);
        // NOTE(review): the 2.0 factor is kept as-is; its rationale is not documented here.
        int partitionCountBasedOnMemory = (int) (maxOutputSize * 2.0 / (double) queryMaxMemoryPerNode);
        return Optional.of(Math.max(partitionCountBasedOnOutputSize, partitionCountBasedOnMemory));
    }

    /**
     * Estimates a partition count from output row counts.
     *
     * <p>Uses the same source and expanding node selection as
     * {@link #getPartitionCountBasedOnOutputSize}.
     *
     * @return empty when either row-count statistic is NaN (unknown)
     */
    private static Optional<Integer> getPartitionCountBasedOnRows(PlanNode plan, StatsProvider statsProvider, long minInputRowsPerTask) {
        ToDoubleFunction<PlanNode> rowCount = node -> statsProvider.getStats(node).getOutputRowCount();
        double sourceTablesRowCount = getSourceNodesOutputStats(plan, rowCount);
        double expandingNodesMaxRowCount = getExpandingNodesMaxOutputStats(plan, rowCount);
        if (Double.isNaN(sourceTablesRowCount) || Double.isNaN(expandingNodesMaxRowCount)) {
            return Optional.empty();
        }
        return Optional.of(getPartitionCount(Math.max(sourceTablesRowCount, expandingNodesMaxRowCount), minInputRowsPerTask));
    }

    /**
     * Returns {@code outputStats / minInputStatsPerTask}, truncated toward zero, and never less than 1.
     * The double-to-int cast saturates at {@code Integer.MAX_VALUE} for very large or infinite inputs.
     */
    private static int getPartitionCount(double outputStats, long minInputStatsPerTask) {
        return Math.max((int) (outputStats / (double) minInputStatsPerTask), 1);
    }

    /** Returns true if any node in the plan matches {@link #isInputMultiplyingPlanNode}. */
    private static boolean isInputMultiplyingPlanNodePresent(PlanNode root) {
        return PlanNodeSearcher.searchFrom(root).where(DeterminePartitionCount::isInputMultiplyingPlanNode).matches();
    }

    /**
     * Returns true for nodes whose presence disables the estimate. These are:
     * <ul>
     *   <li>any {@link UnnestNode};</li>
     *   <li>a cross join where neither side is at most scalar;</li>
     *   <li>a non-cross join with more than one join criterion (presumably because row-count
     *       estimates for multi-key joins are less reliable; TODO confirm).</li>
     * </ul>
     */
    private static boolean isInputMultiplyingPlanNode(PlanNode node) {
        if (node instanceof UnnestNode) {
            return true;
        }
        if (node instanceof JoinNode) {
            JoinNode joinNode = (JoinNode) node;
            if (joinNode.isCrossJoin()) {
                // A cross join against an at-most-scalar side does not multiply its input.
                return !QueryCardinalityUtil.isAtMostScalar(joinNode.getRight()) && !QueryCardinalityUtil.isAtMostScalar(joinNode.getLeft());
            }
            return joinNode.getCriteria().size() > 1;
        }
        return false;
    }

    /**
     * Returns the maximum of {@code statsMapper} over all expanding nodes, or 0.0 if there are none.
     * {@code DoubleStream.max} yields NaN if any mapped value is NaN.
     */
    private static double getExpandingNodesMaxOutputStats(PlanNode root, ToDoubleFunction<PlanNode> statsMapper) {
        List<PlanNode> expandingNodes = PlanNodeSearcher.searchFrom(root).where(DeterminePartitionCount::isExpandingPlanNode).findAll();
        return expandingNodes.stream().mapToDouble(statsMapper).max().orElse(0.0);
    }

    /**
     * Returns true for joins, unions, and exchanges with more than one source. These are the
     * nodes that combine rows from multiple inputs.
     */
    private static boolean isExpandingPlanNode(PlanNode node) {
        return node instanceof JoinNode
                || node instanceof UnionNode
                || (node instanceof ExchangeNode && node.getSources().size() > 1);
    }

    /**
     * Returns the sum of {@code statsMapper} over all {@link TableScanNode} and {@link ValuesNode}
     * nodes. The sum is NaN if any mapped value is NaN.
     */
    private static double getSourceNodesOutputStats(PlanNode root, ToDoubleFunction<PlanNode> statsMapper) {
        List<PlanNode> sourceNodes = PlanNodeSearcher.searchFrom(root).whereIsInstanceOfAny(TableScanNode.class, ValuesNode.class).findAll();
        return sourceNodes.stream().mapToDouble(statsMapper).sum();
    }

    /**
     * Sets {@code partitionCount} on every REMOTE exchange whose partitioning handle wraps a
     * {@link SystemPartitioningHandle}.
     */
    private static class Rewriter
    extends SimplePlanRewriter<Void> {
        private final int partitionCount;

        private Rewriter(int partitionCount) {
            this.partitionCount = partitionCount;
        }

        @Override
        public PlanNode visitExchange(ExchangeNode node, SimplePlanRewriter.RewriteContext<Void> context) {
            PartitioningHandle handle = node.getPartitioningScheme().getPartitioning().getHandle();
            if (node.getScope() != ExchangeNode.Scope.REMOTE || !(handle.getConnectorHandle() instanceof SystemPartitioningHandle)) {
                // Keep this exchange's own partitioning scheme, but still descend into its sources.
                // Remote system exchanges nested below it must receive the same partition count.
                return context.defaultRewrite(node);
            }
            List<PlanNode> sources = node.getSources().stream()
                    .map(context::rewrite)
                    .collect(ImmutableList.toImmutableList());
            return new ExchangeNode(
                    node.getId(),
                    node.getType(),
                    node.getScope(),
                    node.getPartitioningScheme().withPartitionCount(this.partitionCount),
                    sources,
                    node.getInputs(),
                    node.getOrderingScheme());
        }
    }
}

