/*
 * Decompiled with CFR 0.152.
 */
package io.trino.cost;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.SymbolStatsEstimate;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.DateType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.SmallintType;
import io.trino.spi.type.TinyintType;
import io.trino.spi.type.Type;
import io.trino.sql.planner.Symbol;
import java.util.Collection;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Predicate;

/**
 * Post-processes a {@link PlanNodeStatsEstimate} so that its per-symbol statistics are
 * consistent with each other and with the estimated output row count.
 *
 * <p>For every symbol with known statistics the normalizer:
 * <ul>
 *   <li>drops the statistics when the symbol is not among the requested output symbols;</li>
 *   <li>caps the distinct values count by the number of values that can fit between the
 *       low and high value (for discrete types) and, when the row count is known, by the
 *       row count and the number of non-null rows;</li>
 *   <li>drops the statistics when the normalized result is unknown.</li>
 * </ul>
 */
public class StatsNormalizer {
    /**
     * Normalizes {@code stats}, keeping statistics for all symbols.
     *
     * @param stats the estimate to normalize
     * @return the normalized estimate
     */
    public PlanNodeStatsEstimate normalize(PlanNodeStatsEstimate stats) {
        return normalize(stats, Optional.empty());
    }

    /**
     * Normalizes {@code stats} and removes statistics of every symbol that is not contained
     * in {@code outputSymbols}.
     *
     * @param stats the estimate to normalize
     * @param outputSymbols the symbols whose statistics should be retained; must not be null
     * @return the normalized estimate
     * @throws NullPointerException if {@code outputSymbols} is null (from {@link Optional#of})
     */
    public PlanNodeStatsEstimate normalize(PlanNodeStatsEstimate stats, Collection<Symbol> outputSymbols) {
        return normalize(stats, Optional.of(outputSymbols));
    }

    /**
     * Shared implementation of the public {@code normalize} overloads.
     *
     * @param stats the estimate to normalize
     * @param outputSymbols symbols to retain, or empty to retain all symbols
     * @return the normalized estimate
     */
    private PlanNodeStatsEstimate normalize(PlanNodeStatsEstimate stats, Optional<Collection<Symbol>> outputSymbols) {
        PlanNodeStatsEstimate.Builder normalized = PlanNodeStatsEstimate.buildFrom(stats);

        // The explicit cast gives the method reference a target type; without it the
        // type argument of Optional.map cannot be inferred.
        Predicate<Symbol> symbolFilter = outputSymbols
                .map(ImmutableSet::copyOf)
                .map(set -> (Predicate<Symbol>) set::contains)
                .orElse(ignored -> true);

        for (Symbol symbol : stats.getSymbolsWithKnownStatistics()) {
            if (!symbolFilter.test(symbol)) {
                normalized.removeSymbolStatistics(symbol);
                continue;
            }

            SymbolStatsEstimate symbolStats = stats.getSymbolStatistics(symbol);
            // Row-count based caps can only be applied when the row count is known.
            SymbolStatsEstimate normalizedSymbolStats = stats.isOutputRowCountUnknown()
                    ? normalizeSymbolStatsWithoutRowCount(symbol, symbolStats)
                    : normalizeSymbolStats(symbol, symbolStats, stats);

            if (normalizedSymbolStats.isUnknown()) {
                normalized.removeSymbolStatistics(symbol);
                continue;
            }
            // Only write back when normalization actually changed something.
            if (!Objects.equals(normalizedSymbolStats, symbolStats)) {
                normalized.addSymbolStatistics(symbol, normalizedSymbolStats);
            }
        }
        return normalized.build();
    }

    /**
     * Normalizes the statistics of a single symbol when the output row count is unknown.
     * Only the low/high based cap on the distinct values count is applied.
     *
     * @param symbol the symbol the statistics belong to
     * @param symbolStats the statistics to normalize
     * @return {@link SymbolStatsEstimate#unknown()} if {@code symbolStats} is unknown,
     *         {@link SymbolStatsEstimate#zero()} if the resulting distinct values count is
     *         zero, otherwise a copy of {@code symbolStats} with the capped distinct values count
     */
    private SymbolStatsEstimate normalizeSymbolStatsWithoutRowCount(Symbol symbol, SymbolStatsEstimate symbolStats) {
        if (symbolStats.isUnknown()) {
            return SymbolStatsEstimate.unknown();
        }

        double distinctValuesCount = symbolStats.getDistinctValuesCount();
        if (!Double.isNaN(distinctValuesCount)) {
            double maxDistinctValuesByLowHigh = maxDistinctValuesByLowHigh(symbolStats, symbol.getType());
            // A NaN cap compares false here, so an unknown cap leaves the count untouched.
            if (distinctValuesCount > maxDistinctValuesByLowHigh) {
                distinctValuesCount = maxDistinctValuesByLowHigh;
            }
        }

        if (distinctValuesCount == 0.0) {
            return SymbolStatsEstimate.zero();
        }
        return SymbolStatsEstimate.buildFrom(symbolStats)
                .setDistinctValuesCount(distinctValuesCount)
                .build();
    }

    /**
     * Normalizes the statistics of a single symbol when the output row count is known.
     *
     * <p>The distinct values count is capped, in order, by the low/high based maximum and by
     * the output row count. If it still exceeds the number of non-null rows, the excess is
     * split evenly: the distinct values count is lowered by half of the difference and the
     * nulls fraction is lowered so that the non-null row count rises by the other half.
     *
     * @param symbol the symbol the statistics belong to
     * @param symbolStats the statistics to normalize
     * @param stats the enclosing estimate, used for its output row count
     * @return {@link SymbolStatsEstimate#zero()} if the output row count or the resulting
     *         distinct values count is zero, {@link SymbolStatsEstimate#unknown()} if
     *         {@code symbolStats} is unknown, otherwise a copy of {@code symbolStats} with the
     *         adjusted distinct values count and nulls fraction
     * @throws IllegalArgumentException if the output row count is not greater than zero
     *         (and not exactly zero, which is handled above)
     */
    private SymbolStatsEstimate normalizeSymbolStats(Symbol symbol, SymbolStatsEstimate symbolStats, PlanNodeStatsEstimate stats) {
        if (stats.getOutputRowCount() == 0.0) {
            return SymbolStatsEstimate.zero();
        }
        if (symbolStats.isUnknown()) {
            return SymbolStatsEstimate.unknown();
        }

        double outputRowCount = stats.getOutputRowCount();
        Preconditions.checkArgument(outputRowCount > 0.0, "outputRowCount must be greater than zero: %s", outputRowCount);

        double distinctValuesCount = symbolStats.getDistinctValuesCount();
        double nullsFraction = symbolStats.getNullsFraction();

        if (!Double.isNaN(distinctValuesCount)) {
            // Comparisons (rather than Math.min) are used on purpose: a NaN bound compares
            // false and therefore does not alter the count.
            double maxDistinctValuesByLowHigh = maxDistinctValuesByLowHigh(symbolStats, symbol.getType());
            if (distinctValuesCount > maxDistinctValuesByLowHigh) {
                distinctValuesCount = maxDistinctValuesByLowHigh;
            }
            if (distinctValuesCount > outputRowCount) {
                distinctValuesCount = outputRowCount;
            }

            double nonNullValues = outputRowCount * (1.0 - nullsFraction);
            if (distinctValuesCount > nonNullValues) {
                // Meet in the middle: shrink the distinct count and grow the non-null count
                // by half of the gap each.
                double difference = distinctValuesCount - nonNullValues;
                distinctValuesCount -= difference / 2.0;
                nonNullValues += difference / 2.0;
                nullsFraction = 1.0 - nonNullValues / outputRowCount;
            }
        }

        if (distinctValuesCount == 0.0) {
            return SymbolStatsEstimate.zero();
        }
        return SymbolStatsEstimate.buildFrom(symbolStats)
                .setDistinctValuesCount(distinctValuesCount)
                .setNullsFraction(nullsFraction)
                .build();
    }

    /**
     * Computes the largest distinct values count that is possible given the low and high
     * value of {@code symbolStats}.
     *
     * @param symbolStats statistics providing the value range
     * @param type the type of the symbol
     * @return {@code 1} if the statistic range has zero length; {@code NaN} if {@code type}
     *         is not discrete or the difference between high and low value is NaN; otherwise
     *         {@code floor(high - low + 1)}, where for decimals the difference is first
     *         multiplied by {@code 10^scale}
     */
    private double maxDistinctValuesByLowHigh(SymbolStatsEstimate symbolStats, Type type) {
        if (symbolStats.statisticRange().length() == 0.0) {
            return 1.0;
        }
        if (!isDiscrete(type)) {
            return Double.NaN;
        }

        double length = symbolStats.getHighValue() - symbolStats.getLowValue();
        if (Double.isNaN(length)) {
            return Double.NaN;
        }
        if (type instanceof DecimalType) {
            // Scale the range so that one unit corresponds to one representable decimal step.
            length *= Math.pow(10.0, ((DecimalType) type).getScale());
        }
        return Math.floor(length + 1.0);
    }

    /**
     * Tells whether {@code type} is one of the types treated as discrete by the normalizer:
     * INTEGER, BIGINT, SMALLINT, TINYINT, BOOLEAN, DATE or any {@link DecimalType}.
     */
    private static boolean isDiscrete(Type type) {
        return type.equals(IntegerType.INTEGER)
                || type.equals(BigintType.BIGINT)
                || type.equals(SmallintType.SMALLINT)
                || type.equals(TinyintType.TINYINT)
                || type.equals(BooleanType.BOOLEAN)
                || type.equals(DateType.DATE)
                || type instanceof DecimalType;
    }
}

