/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner;

import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import io.trino.Session;
import io.trino.metadata.FunctionResolver;
import io.trino.metadata.Metadata;
import io.trino.metadata.ResolvedFunction;
import io.trino.security.AllowAllAccessControl;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.spi.expression.FunctionName;
import io.trino.spi.statistics.ColumnStatisticMetadata;
import io.trino.spi.statistics.ColumnStatisticType;
import io.trino.spi.statistics.TableStatisticType;
import io.trino.spi.statistics.TableStatisticsMetadata;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.Type;
import io.trino.sql.PlannerContext;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.StatisticAggregations;
import io.trino.sql.planner.plan.StatisticAggregationsDescriptor;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.QualifiedName;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

/**
 * Builds the aggregations that compute the table and column statistics described by a
 * {@link TableStatisticsMetadata}, allocating one output symbol per requested statistic and
 * recording the statistic-to-symbol mapping in a {@link StatisticAggregationsDescriptor}.
 */
public class StatisticsAggregationPlanner {
    private final SymbolAllocator symbolAllocator;
    private final Metadata metadata;
    private final Session session;
    private final FunctionResolver functionResolver;

    /**
     * @param symbolAllocator allocator used both to create output symbols and to look up input symbol types
     * @param plannerContext source of {@link Metadata} (built-in function resolution) and {@link FunctionResolver}
     * @param session session used when resolving connector-specified aggregation functions
     */
    public StatisticsAggregationPlanner(SymbolAllocator symbolAllocator, PlannerContext plannerContext, Session session) {
        this.symbolAllocator = Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
        Objects.requireNonNull(plannerContext, "plannerContext is null");
        this.metadata = plannerContext.getMetadata();
        this.session = Objects.requireNonNull(session, "session is null");
        this.functionResolver = plannerContext.getFunctionResolver();
    }

    /**
     * Creates the statistics aggregation requested by {@code statisticsMetadata}.
     *
     * @param statisticsMetadata grouping columns, table-wide statistics and column statistics to compute
     * @param columnToSymbolMap maps each column name to the plan symbol that produces it
     * @return the aggregations together with a descriptor mapping each statistic to its output symbol
     * @throws TrinoException with {@code NOT_SUPPORTED} if a table-wide statistic other than
     *         {@link TableStatisticType#ROW_COUNT} is requested
     */
    public TableStatisticAggregation createStatisticsAggregation(TableStatisticsMetadata statisticsMetadata, Map<String, Symbol> columnToSymbolMap) {
        StatisticAggregationsDescriptor.Builder<Symbol> descriptor = StatisticAggregationsDescriptor.builder();

        List<String> groupingColumns = statisticsMetadata.getGroupingColumns();
        List<Symbol> groupingSymbols = groupingColumns.stream()
                .map(columnToSymbolMap::get)
                .collect(ImmutableList.toImmutableList());
        // groupingSymbols is index-aligned with groupingColumns
        for (int i = 0; i < groupingSymbols.size(); i++) {
            descriptor.addGrouping(groupingColumns.get(i), groupingSymbols.get(i));
        }

        ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
        for (TableStatisticType type : statisticsMetadata.getTableStatistics()) {
            if (type != TableStatisticType.ROW_COUNT) {
                throw new TrinoException(StandardErrorCode.NOT_SUPPORTED, "Table-wide statistic type not supported: " + type);
            }
            // row count is computed as count() with no arguments
            AggregationNode.Aggregation aggregation = new AggregationNode.Aggregation(
                    metadata.resolveBuiltinFunction("count", ImmutableList.of()),
                    ImmutableList.of(),
                    false,
                    Optional.empty(),
                    Optional.empty(),
                    Optional.empty());
            Symbol symbol = symbolAllocator.newSymbol("rowCount", BigintType.BIGINT);
            aggregations.put(symbol, aggregation);
            descriptor.addTableStatistic(TableStatisticType.ROW_COUNT, symbol);
        }

        for (ColumnStatisticMetadata columnStatisticMetadata : statisticsMetadata.getColumnStatistics()) {
            String columnName = columnStatisticMetadata.getColumnName();
            Symbol inputSymbol = columnToSymbolMap.get(columnName);
            Verify.verifyNotNull(inputSymbol, "no symbol for [%s] column, these columns exist: %s", columnName, columnToSymbolMap.keySet());
            Type inputType = symbolAllocator.getTypes().get(inputSymbol);
            Verify.verifyNotNull(inputType, "inputType is null for symbol: %s", inputSymbol);

            // A column statistic is either a well-known statistic type or a connector-named aggregation function
            ColumnStatisticsAggregation aggregation;
            String symbolHint;
            if (columnStatisticMetadata.getStatisticTypeIfPresent().isPresent()) {
                ColumnStatisticType statisticType = columnStatisticMetadata.getStatisticType();
                aggregation = createColumnAggregation(statisticType, inputSymbol, inputType);
                symbolHint = statisticType + ":" + columnName;
            }
            else {
                FunctionName aggregationName = columnStatisticMetadata.getAggregation();
                aggregation = createColumnAggregation(aggregationName, inputSymbol, inputType);
                symbolHint = aggregationName.getName() + ":" + columnName;
            }

            Symbol symbol = symbolAllocator.newSymbol(symbolHint, aggregation.getOutputType());
            aggregations.put(symbol, aggregation.getAggregation());
            descriptor.addColumnStatistic(columnStatisticMetadata, symbol);
        }

        StatisticAggregations aggregation = new StatisticAggregations(aggregations.buildOrThrow(), groupingSymbols);
        return new TableStatisticAggregation(aggregation, descriptor.build());
    }

    /**
     * Maps a well-known column statistic type to the built-in aggregation function that computes it.
     */
    private ColumnStatisticsAggregation createColumnAggregation(ColumnStatisticType statisticType, Symbol input, Type inputType) {
        // Exhaustive over the enum: adding a new ColumnStatisticType becomes a compile error here
        return switch (statisticType) {
            case MIN_VALUE -> createAggregation("min", input, inputType);
            case MAX_VALUE -> createAggregation("max", input, inputType);
            case NUMBER_OF_DISTINCT_VALUES -> createAggregation("approx_distinct", input, inputType);
            case NUMBER_OF_DISTINCT_VALUES_SUMMARY -> createAggregation("$approx_set", input, inputType);
            case NUMBER_OF_NON_NULL_VALUES -> createAggregation("count", input, inputType);
            // count_if is always resolved for BOOLEAN, independent of the column's inputType
            case NUMBER_OF_TRUE_VALUES -> createAggregation("count_if", input, BooleanType.BOOLEAN);
            case TOTAL_SIZE_IN_BYTES -> createAggregation("$internal$sum_data_size_for_stats", input, inputType);
            case MAX_VALUE_SIZE_IN_BYTES -> createAggregation("$internal$max_data_size_for_stats", input, inputType);
        };
    }

    /**
     * Resolves a connector-named aggregation function (optionally catalog/schema qualified) for the
     * given input type. Resolution uses {@link AllowAllAccessControl}, so no access check is applied here.
     */
    private ColumnStatisticsAggregation createColumnAggregation(FunctionName aggregation, Symbol input, Type inputType) {
        QualifiedName name = aggregation.getCatalogSchema()
                .map(catalogSchemaName -> QualifiedName.of(catalogSchemaName.getCatalogName(), catalogSchemaName.getSchemaName(), aggregation.getName()))
                .orElseGet(() -> QualifiedName.of(aggregation.getName()));
        ResolvedFunction resolvedFunction = functionResolver.resolveFunction(session, name, TypeSignatureProvider.fromTypes(inputType), new AllowAllAccessControl());
        return createAggregation(resolvedFunction, input, inputType);
    }

    /**
     * Resolves a built-in aggregation function by name for the given input type.
     */
    private ColumnStatisticsAggregation createAggregation(String functionName, Symbol input, Type inputType) {
        return createAggregation(metadata.resolveBuiltinFunction(functionName, TypeSignatureProvider.fromTypes(inputType)), input, inputType);
    }

    /**
     * Wraps a resolved single-argument aggregation applied to {@code input}.
     *
     * @throws com.google.common.base.VerifyException if the function does not take exactly one argument
     *         of exactly {@code inputType}
     */
    private static ColumnStatisticsAggregation createAggregation(ResolvedFunction resolvedFunction, Symbol input, Type inputType) {
        Type resolvedType = Iterables.getOnlyElement(resolvedFunction.getSignature().getArgumentTypes());
        // Guard against implicit coercion: the aggregation reads the symbol directly, without a cast
        Verify.verify(resolvedType.equals(inputType), "resolved function input type does not match the input type: %s != %s", resolvedType, inputType);
        AggregationNode.Aggregation aggregation = new AggregationNode.Aggregation(
                resolvedFunction,
                ImmutableList.of(input.toSymbolReference()),
                false,
                Optional.empty(),
                Optional.empty(),
                Optional.empty());
        return new ColumnStatisticsAggregation(aggregation, resolvedFunction.getSignature().getReturnType());
    }

    /**
     * A single column-statistic aggregation and the type of the value it produces.
     */
    public static class ColumnStatisticsAggregation {
        private final AggregationNode.Aggregation aggregation;
        private final Type outputType;

        private ColumnStatisticsAggregation(AggregationNode.Aggregation aggregation, Type outputType) {
            this.aggregation = Objects.requireNonNull(aggregation, "aggregation is null");
            this.outputType = Objects.requireNonNull(outputType, "outputType is null");
        }

        public AggregationNode.Aggregation getAggregation() {
            return aggregation;
        }

        public Type getOutputType() {
            return outputType;
        }
    }

    /**
     * Result of {@link #createStatisticsAggregation}: the aggregations to plan plus the descriptor
     * mapping each requested statistic to its output symbol.
     */
    public static class TableStatisticAggregation {
        private final StatisticAggregations aggregations;
        private final StatisticAggregationsDescriptor<Symbol> descriptor;

        private TableStatisticAggregation(StatisticAggregations aggregations, StatisticAggregationsDescriptor<Symbol> descriptor) {
            this.aggregations = Objects.requireNonNull(aggregations, "aggregations is null");
            this.descriptor = Objects.requireNonNull(descriptor, "descriptor is null");
        }

        public StatisticAggregations getAggregations() {
            return aggregations;
        }

        public StatisticAggregationsDescriptor<Symbol> getDescriptor() {
            return descriptor;
        }
    }
}

