/*
 * Decompiled with CFR 0.152.
 */
package io.trino.cost;

import com.google.common.base.Verify;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.PlanNodeStatsEstimateMath;
import io.trino.cost.RuntimeInfoProvider;
import io.trino.cost.SimpleStatsRule;
import io.trino.cost.StatsCalculator;
import io.trino.cost.StatsNormalizer;
import io.trino.cost.StatsProvider;
import io.trino.cost.SymbolStatsEstimate;
import io.trino.execution.scheduler.OutputDataSizeEstimate;
import io.trino.execution.scheduler.faulttolerant.OutputStatsEstimator;
import io.trino.matching.Pattern;
import io.trino.spi.type.FixedWidthType;
import io.trino.spi.type.Type;
import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanFragmentId;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.RemoteSourceNode;
import io.trino.util.MoreMath;
import java.util.List;
import java.util.Optional;

public class RemoteSourceStatsRule
extends SimpleStatsRule<RemoteSourceNode> {
    private static final Pattern<RemoteSourceNode> PATTERN = Patterns.remoteSourceNode();

    public RemoteSourceStatsRule(StatsNormalizer normalizer) {
        super(normalizer);
    }

    @Override
    public Pattern<RemoteSourceNode> getPattern() {
        return PATTERN;
    }

    @Override
    protected Optional<PlanNodeStatsEstimate> doCalculate(RemoteSourceNode node, StatsCalculator.Context context) {
        Optional<PlanNodeStatsEstimate> estimate = Optional.empty();
        RuntimeInfoProvider runtimeInfoProvider = context.runtimeInfoProvider();
        for (int i = 0; i < node.getSourceFragmentIds().size(); ++i) {
            PlanFragmentId planFragmentId = node.getSourceFragmentIds().get(i);
            OutputStatsEstimator.OutputStatsEstimateResult stageRuntimeStats = runtimeInfoProvider.getRuntimeOutputStats(planFragmentId);
            PlanNodeStatsEstimate stageEstimatedStats = this.getEstimatedStats(runtimeInfoProvider, context.statsProvider(), planFragmentId);
            PlanNodeStatsEstimate adjustedStageStats = this.adjustStats(node.getOutputSymbols(), context.types(), stageRuntimeStats, stageEstimatedStats);
            estimate = estimate.map(planNodeStatsEstimate -> PlanNodeStatsEstimateMath.addStatsAndMaxDistinctValues(planNodeStatsEstimate, adjustedStageStats)).or(() -> Optional.of(adjustedStageStats));
        }
        Verify.verify((boolean)estimate.isPresent());
        return estimate;
    }

    private PlanNodeStatsEstimate getEstimatedStats(RuntimeInfoProvider runtimeInfoProvider, StatsProvider statsProvider, PlanFragmentId fragmentId) {
        PlanFragment fragment = runtimeInfoProvider.getPlanFragment(fragmentId);
        PlanNode fragmentRoot = fragment.getRoot();
        PlanNodeStatsEstimate estimate = fragment.getStatsAndCosts().getStats().get(fragmentRoot.getId());
        if (estimate != null && !estimate.isOutputRowCountUnknown()) {
            return estimate;
        }
        return statsProvider.getStats(fragmentRoot);
    }

    private PlanNodeStatsEstimate adjustStats(List<Symbol> outputs, TypeProvider typeProvider, OutputStatsEstimator.OutputStatsEstimateResult runtimeStats, PlanNodeStatsEstimate estimateStats) {
        if (runtimeStats.isUnknown()) {
            return estimateStats;
        }
        OutputDataSizeEstimate outputDataSizeEstimate = runtimeStats.outputDataSizeEstimate();
        PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder().setOutputRowCount(runtimeStats.outputRowCountEstimate());
        double fixedWidthTypeSize = 0.0;
        double variableTypeValuesCount = 0.0;
        for (Symbol outputSymbol : outputs) {
            Type type = typeProvider.get(outputSymbol);
            SymbolStatsEstimate symbolStatistics = estimateStats.getSymbolStatistics(outputSymbol);
            double nullsFraction = MoreMath.firstNonNaN(symbolStatistics.getNullsFraction(), 0.0);
            double numberOfNonNullRows = (double)runtimeStats.outputRowCountEstimate() * (1.0 - nullsFraction);
            if (type instanceof FixedWidthType) {
                fixedWidthTypeSize += numberOfNonNullRows * (double)((FixedWidthType)type).getFixedSize();
                continue;
            }
            variableTypeValuesCount += numberOfNonNullRows;
        }
        double runtimeOutputDataSize = outputDataSizeEstimate.getTotalSizeInBytes();
        double variableTypeValueAverageSize = Double.NaN;
        if (variableTypeValuesCount > 0.0 && runtimeOutputDataSize > fixedWidthTypeSize) {
            variableTypeValueAverageSize = (runtimeOutputDataSize - fixedWidthTypeSize) / variableTypeValuesCount;
        }
        for (Symbol outputSymbol : outputs) {
            SymbolStatsEstimate symbolStatistics = estimateStats.getSymbolStatistics(outputSymbol);
            Type type = typeProvider.get(outputSymbol);
            if (!Double.isNaN(variableTypeValueAverageSize) && !(type instanceof FixedWidthType)) {
                symbolStatistics = SymbolStatsEstimate.buildFrom(symbolStatistics).setAverageRowSize(variableTypeValueAverageSize).build();
            }
            result.addSymbolStatistics(outputSymbol, symbolStatistics);
        }
        return result.build();
    }
}

