/*
 * Decompiled with CFR 0.152.
 */
package io.trino.execution.scheduler.faulttolerant;

import com.google.common.primitives.ImmutableLongArray;
import io.airlift.units.DataSize;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.execution.StageId;
import io.trino.execution.scheduler.faulttolerant.EventDrivenFaultTolerantQueryScheduler;
import io.trino.execution.scheduler.faulttolerant.OutputStatsEstimator;
import io.trino.execution.scheduler.faulttolerant.OutputStatsEstimatorFactory;
import io.trino.spi.QueryId;
import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.plan.PlanFragmentId;
import io.trino.sql.planner.plan.RemoteSourceNode;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;

public class BySmallStageOutputStatsEstimator
implements OutputStatsEstimator {
    private final QueryId queryId;
    private final boolean smallStageEstimationEnabled;
    private final DataSize smallStageEstimationThreshold;
    private final double smallStageSourceSizeMultiplier;
    private final DataSize smallSizePartitionSizeEstimate;
    private final boolean smallStageRequireNoMorePartitions;

    private BySmallStageOutputStatsEstimator(QueryId queryId, boolean smallStageEstimationEnabled, DataSize smallStageEstimationThreshold, double smallStageSourceSizeMultiplier, DataSize smallSizePartitionSizeEstimate, boolean smallStageRequireNoMorePartitions) {
        this.queryId = Objects.requireNonNull(queryId, "queryId is null");
        this.smallStageEstimationEnabled = smallStageEstimationEnabled;
        this.smallStageEstimationThreshold = Objects.requireNonNull(smallStageEstimationThreshold, "smallStageEstimationThreshold is null");
        this.smallStageSourceSizeMultiplier = smallStageSourceSizeMultiplier;
        this.smallSizePartitionSizeEstimate = Objects.requireNonNull(smallSizePartitionSizeEstimate, "smallSizePartitionSizeEstimate is null");
        this.smallStageRequireNoMorePartitions = smallStageRequireNoMorePartitions;
    }

    @Override
    public Optional<OutputStatsEstimator.OutputStatsEstimateResult> getEstimatedOutputStats(EventDrivenFaultTolerantQueryScheduler.StageExecution stageExecution, Function<StageId, EventDrivenFaultTolerantQueryScheduler.StageExecution> stageExecutionLookup, boolean parentEager) {
        if (!this.smallStageEstimationEnabled) {
            return Optional.empty();
        }
        if (this.smallStageRequireNoMorePartitions && !stageExecution.isNoMorePartitions()) {
            return Optional.empty();
        }
        long[] currentOutputDataSize = stageExecution.currentOutputDataSize();
        long totaleOutputDataSize = 0L;
        for (long partitionOutputDataSize : currentOutputDataSize) {
            totaleOutputDataSize += partitionOutputDataSize;
        }
        if (totaleOutputDataSize > this.smallStageEstimationThreshold.toBytes()) {
            return Optional.empty();
        }
        PlanFragment planFragment = stageExecution.getStageInfo().getPlan();
        boolean hasPartitionedSources = planFragment.getPartitionedSources().size() > 0;
        List<RemoteSourceNode> remoteSourceNodes = planFragment.getRemoteSourceNodes();
        long partitionedInputSizeEstimate = 0L;
        if (hasPartitionedSources) {
            if (!stageExecution.isNoMorePartitions()) {
                return Optional.empty();
            }
            partitionedInputSizeEstimate += (long)stageExecution.getPartitionsCount() * this.smallSizePartitionSizeEstimate.toBytes();
        }
        long remoteInputSizeEstimate = 0L;
        for (RemoteSourceNode remoteSourceNode : remoteSourceNodes) {
            for (PlanFragmentId sourceFragmentId : remoteSourceNode.getSourceFragmentIds()) {
                StageId sourceStageId = StageId.create(this.queryId, sourceFragmentId);
                EventDrivenFaultTolerantQueryScheduler.StageExecution sourceStage = stageExecutionLookup.apply(sourceStageId);
                Objects.requireNonNull(sourceStage, "sourceStage is null");
                Optional<OutputStatsEstimator.OutputStatsEstimateResult> sourceStageOutputDataSize = sourceStage.getOutputStats(stageExecutionLookup, false);
                if (sourceStageOutputDataSize.isEmpty()) {
                    return Optional.empty();
                }
                remoteInputSizeEstimate += sourceStageOutputDataSize.orElseThrow().outputDataSizeEstimate().getTotalSizeInBytes();
            }
        }
        long inputSizeEstimate = (long)((double)(partitionedInputSizeEstimate + remoteInputSizeEstimate) * this.smallStageSourceSizeMultiplier);
        if (inputSizeEstimate > this.smallStageEstimationThreshold.toBytes()) {
            return Optional.empty();
        }
        int outputPartitionsCount = stageExecution.getSinkPartitioningScheme().getPartitionCount();
        ImmutableLongArray.Builder estimateBuilder = ImmutableLongArray.builder((int)outputPartitionsCount);
        for (int i = 0; i < outputPartitionsCount; ++i) {
            estimateBuilder.add(inputSizeEstimate / (long)outputPartitionsCount);
        }
        return Optional.of(new OutputStatsEstimator.OutputStatsEstimateResult(estimateBuilder.build(), 0L, "BY_SMALL_INPUT"));
    }

    public static class Factory
    implements OutputStatsEstimatorFactory {
        @Override
        public OutputStatsEstimator create(Session session) {
            return new BySmallStageOutputStatsEstimator(session.getQueryId(), SystemSessionProperties.isFaultTolerantExecutionSmallStageEstimationEnabled(session), SystemSessionProperties.getFaultTolerantExecutionSmallStageEstimationThreshold(session), SystemSessionProperties.getFaultTolerantExecutionSmallStageSourceSizeMultiplier(session), SystemSessionProperties.getFaultTolerantExecutionArbitraryDistributionComputeTaskTargetSizeMin(session), SystemSessionProperties.isFaultTolerantExecutionSmallStageRequireNoMorePartitions(session));
        }
    }
}

