/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.optimizations;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import io.trino.SystemSessionProperties;
import io.trino.metadata.GlobalFunctionCatalog;
import io.trino.metadata.Metadata;
import io.trino.spi.function.CatalogSchemaFunctionName;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.Type;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.ir.CoalesceExpression;
import io.trino.sql.ir.ComparisonExpression;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.IrExpressions;
import io.trino.sql.ir.SymbolReference;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.optimizations.PlanOptimizer;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.GroupIdNode;
import io.trino.sql.planner.plan.MarkDistinctNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.SimplePlanRewriter;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

public class OptimizeMixedDistinctAggregations
implements PlanOptimizer {
    private static final CatalogSchemaFunctionName COUNT_NAME = GlobalFunctionCatalog.builtinFunctionName("count");
    private static final CatalogSchemaFunctionName COUNT_IF_NAME = GlobalFunctionCatalog.builtinFunctionName("count_if");
    private static final CatalogSchemaFunctionName APPROX_DISTINCT_NAME = GlobalFunctionCatalog.builtinFunctionName("approx_distinct");
    private final Metadata metadata;

    public OptimizeMixedDistinctAggregations(Metadata metadata) {
        this.metadata = metadata;
    }

    @Override
    public PlanNode optimize(PlanNode plan, PlanOptimizer.Context context) {
        if (SystemSessionProperties.isOptimizeDistinctAggregationEnabled(context.session())) {
            return SimplePlanRewriter.rewriteWith(new Optimizer(context.idAllocator(), context.symbolAllocator(), this.metadata), plan, Optional.empty());
        }
        return plan;
    }

    private static class Optimizer
    extends SimplePlanRewriter<Optional<AggregateInfo>> {
        private final PlanNodeIdAllocator idAllocator;
        private final SymbolAllocator symbolAllocator;
        private final Metadata metadata;

        private Optimizer(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Metadata metadata) {
            this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
            this.symbolAllocator = Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
            this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        }

        @Override
        public PlanNode visitAggregation(AggregationNode node, SimplePlanRewriter.RewriteContext<Optional<AggregateInfo>> context) {
            List masks = (List)node.getAggregations().values().stream().map(AggregationNode.Aggregation::getMask).filter(Optional::isPresent).map(Optional::get).collect(ImmutableList.toImmutableList());
            ImmutableSet uniqueMasks = ImmutableSet.copyOf((Collection)masks);
            if (uniqueMasks.size() != 1 || masks.size() == node.getAggregations().size()) {
                return context.defaultRewrite(node, Optional.empty());
            }
            if (node.getAggregations().values().stream().map(AggregationNode.Aggregation::getFilter).anyMatch(Optional::isPresent)) {
                return context.defaultRewrite(node, Optional.empty());
            }
            if (node.hasOrderings()) {
                return context.defaultRewrite(node, Optional.empty());
            }
            AggregateInfo aggregateInfo = new AggregateInfo(node.getGroupingKeys(), (Symbol)Iterables.getOnlyElement((Iterable)uniqueMasks), node.getAggregations());
            if (!this.checkAllEquatableTypes(aggregateInfo)) {
                return context.defaultRewrite(node, Optional.empty());
            }
            PlanNode source = context.rewrite(node.getSource(), Optional.of(aggregateInfo));
            if (!aggregateInfo.isFoundMarkDistinct()) {
                return context.defaultRewrite(node, Optional.empty());
            }
            ImmutableMap.Builder aggregations = ImmutableMap.builder();
            ImmutableMap.Builder coalesceSymbolsBuilder = ImmutableMap.builder();
            for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : node.getAggregations().entrySet()) {
                AggregationNode.Aggregation aggregation = entry.getValue();
                if (aggregation.getMask().isPresent()) {
                    aggregations.put((Object)entry.getKey(), (Object)new AggregationNode.Aggregation(aggregation.getResolvedFunction(), (List<Expression>)ImmutableList.of((Object)aggregateInfo.getNewDistinctAggregateSymbol().toSymbolReference()), false, Optional.empty(), Optional.empty(), Optional.empty()));
                    continue;
                }
                Symbol argument = aggregateInfo.getNewNonDistinctAggregateSymbols().get(entry.getKey());
                AggregationNode.Aggregation newAggregation = new AggregationNode.Aggregation(this.metadata.resolveBuiltinFunction("arbitrary", TypeSignatureProvider.fromTypes(argument.getType())), (List<Expression>)ImmutableList.of((Object)argument.toSymbolReference()), false, Optional.empty(), Optional.empty(), Optional.empty());
                CatalogSchemaFunctionName signatureName = aggregation.getResolvedFunction().getSignature().getName();
                if (signatureName.equals((Object)COUNT_NAME) || signatureName.equals((Object)COUNT_IF_NAME) || signatureName.equals((Object)APPROX_DISTINCT_NAME)) {
                    Symbol newSymbol = this.symbolAllocator.newSymbol("expr", entry.getKey().getType());
                    aggregations.put((Object)newSymbol, (Object)newAggregation);
                    coalesceSymbolsBuilder.put((Object)newSymbol, (Object)entry.getKey());
                    continue;
                }
                aggregations.put((Object)entry.getKey(), (Object)newAggregation);
            }
            ImmutableMap coalesceSymbols = coalesceSymbolsBuilder.buildOrThrow();
            AggregationNode aggregationNode = new AggregationNode(this.idAllocator.getNextId(), source, (Map<Symbol, AggregationNode.Aggregation>)aggregations.buildOrThrow(), node.getGroupingSets(), (List<Symbol>)ImmutableList.of(), node.getStep(), Optional.empty(), node.getGroupIdSymbol());
            if (coalesceSymbols.isEmpty()) {
                return aggregationNode;
            }
            Assignments.Builder outputSymbols = Assignments.builder();
            for (Symbol symbol : aggregationNode.getOutputSymbols()) {
                if (coalesceSymbols.containsKey(symbol)) {
                    CoalesceExpression expression = new CoalesceExpression(symbol.toSymbolReference(), new Constant((Type)BigintType.BIGINT, 0L), new Expression[0]);
                    outputSymbols.put((Symbol)coalesceSymbols.get(symbol), expression);
                    continue;
                }
                outputSymbols.putIdentity(symbol);
            }
            return new ProjectNode(this.idAllocator.getNextId(), aggregationNode, outputSymbols.build());
        }

        @Override
        public PlanNode visitMarkDistinct(MarkDistinctNode node, SimplePlanRewriter.RewriteContext<Optional<AggregateInfo>> context) {
            Symbol distinctSymbol;
            Optional<AggregateInfo> aggregateInfo = context.get();
            if (aggregateInfo.isEmpty() || !aggregateInfo.get().getMask().equals(node.getMarkerSymbol())) {
                return context.defaultRewrite(node, Optional.empty());
            }
            aggregateInfo.get().foundMarkDistinct();
            PlanNode source = context.rewrite(node.getSource(), Optional.empty());
            HashSet<Symbol> allSymbols = new HashSet<Symbol>();
            List<Symbol> groupBySymbols = aggregateInfo.get().getGroupBySymbols();
            List<Symbol> nonDistinctAggregateSymbols = aggregateInfo.get().getOriginalNonDistinctAggregateArgs();
            Symbol duplicatedDistinctSymbol = distinctSymbol = (Symbol)Iterables.getOnlyElement(aggregateInfo.get().getOriginalDistinctAggregateArgs());
            if (nonDistinctAggregateSymbols.contains(distinctSymbol)) {
                Symbol newSymbol = this.symbolAllocator.newSymbol(distinctSymbol.getName(), distinctSymbol.getType());
                nonDistinctAggregateSymbols.set(nonDistinctAggregateSymbols.indexOf(distinctSymbol), newSymbol);
                duplicatedDistinctSymbol = newSymbol;
            }
            allSymbols.addAll(groupBySymbols);
            allSymbols.addAll(nonDistinctAggregateSymbols);
            allSymbols.add(distinctSymbol);
            Symbol groupSymbol = this.symbolAllocator.newSymbol("group", (Type)BigintType.BIGINT);
            GroupIdNode groupIdNode = this.createGroupIdNode(groupBySymbols, nonDistinctAggregateSymbols, distinctSymbol, duplicatedDistinctSymbol, groupSymbol, allSymbols, source);
            HashSet<Symbol> groupByKeys = new HashSet<Symbol>(groupBySymbols);
            groupByKeys.add(distinctSymbol);
            groupByKeys.add(groupSymbol);
            ImmutableMap.Builder aggregationOutputSymbolsMapBuilder = ImmutableMap.builder();
            AggregationNode aggregationNode = this.createNonDistinctAggregation(aggregateInfo.get(), distinctSymbol, duplicatedDistinctSymbol, groupByKeys, groupIdNode, node, (ImmutableMap.Builder<Symbol, Symbol>)aggregationOutputSymbolsMapBuilder);
            ImmutableMap aggregationOutputSymbolsMap = aggregationOutputSymbolsMapBuilder.buildOrThrow();
            ProjectNode projectNode = this.createProjectNode(aggregationNode, aggregateInfo.get(), distinctSymbol, groupSymbol, groupBySymbols, (Map<Symbol, Symbol>)aggregationOutputSymbolsMap);
            return projectNode;
        }

        private boolean checkAllEquatableTypes(AggregateInfo aggregateInfo) {
            for (Symbol symbol : aggregateInfo.getOriginalNonDistinctAggregateArgs()) {
                Type type = symbol.getType();
                if (type.isComparable()) continue;
                return false;
            }
            return aggregateInfo.getMask().getType().isComparable();
        }

        private ProjectNode createProjectNode(AggregationNode source, AggregateInfo aggregateInfo, Symbol distinctSymbol, Symbol groupSymbol, List<Symbol> groupBySymbols, Map<Symbol, Symbol> aggregationOutputSymbolsMap) {
            Assignments.Builder outputSymbols = Assignments.builder();
            ImmutableMap.Builder outputNonDistinctAggregateSymbols = ImmutableMap.builder();
            for (Symbol symbol : source.getOutputSymbols()) {
                if (distinctSymbol.equals(symbol)) {
                    newSymbol = this.symbolAllocator.newSymbol("expr", symbol.getType());
                    aggregateInfo.setNewDistinctAggregateSymbol(newSymbol);
                    expression = Optimizer.createIfExpression(groupSymbol.toSymbolReference(), new Constant((Type)BigintType.BIGINT, 1L), ComparisonExpression.Operator.EQUAL, symbol.toSymbolReference(), symbol.getType());
                    outputSymbols.put(newSymbol, expression);
                } else if (aggregationOutputSymbolsMap.containsKey(symbol)) {
                    newSymbol = this.symbolAllocator.newSymbol("expr", symbol.getType());
                    outputNonDistinctAggregateSymbols.put((Object)aggregationOutputSymbolsMap.get(symbol), (Object)newSymbol);
                    expression = Optimizer.createIfExpression(groupSymbol.toSymbolReference(), new Constant((Type)BigintType.BIGINT, 0L), ComparisonExpression.Operator.EQUAL, symbol.toSymbolReference(), symbol.getType());
                    outputSymbols.put(newSymbol, expression);
                }
                if (!groupBySymbols.contains(symbol)) continue;
                SymbolReference expression = symbol.toSymbolReference();
                outputSymbols.put(symbol, expression);
            }
            outputSymbols.put(aggregateInfo.getMask(), new Constant((Type)BooleanType.BOOLEAN, null));
            aggregateInfo.setNewNonDistinctAggregateSymbols((Map<Symbol, Symbol>)outputNonDistinctAggregateSymbols.buildOrThrow());
            return new ProjectNode(this.idAllocator.getNextId(), source, outputSymbols.build());
        }

        private GroupIdNode createGroupIdNode(List<Symbol> groupBySymbols, List<Symbol> nonDistinctAggregateSymbols, Symbol distinctSymbol, Symbol duplicatedDistinctSymbol, Symbol groupSymbol, Set<Symbol> allSymbols, PlanNode source) {
            ArrayList<List<Symbol>> groups = new ArrayList<List<Symbol>>();
            HashSet<Symbol> group0 = new HashSet<Symbol>();
            group0.addAll(groupBySymbols);
            group0.addAll(nonDistinctAggregateSymbols);
            groups.add((List<Symbol>)ImmutableList.copyOf(group0));
            HashSet<Symbol> group1 = new HashSet<Symbol>(groupBySymbols);
            group1.add(distinctSymbol);
            groups.add((List<Symbol>)ImmutableList.copyOf(group1));
            return new GroupIdNode(this.idAllocator.getNextId(), source, groups, allSymbols.stream().collect(Collectors.toMap(symbol -> symbol, symbol -> symbol.equals(duplicatedDistinctSymbol) ? distinctSymbol : symbol)), (List<Symbol>)ImmutableList.of(), groupSymbol);
        }

        private AggregationNode createNonDistinctAggregation(AggregateInfo aggregateInfo, Symbol distinctSymbol, Symbol duplicatedDistinctSymbol, Set<Symbol> groupByKeys, GroupIdNode groupIdNode, MarkDistinctNode originalNode, ImmutableMap.Builder<Symbol, Symbol> aggregationOutputSymbolsMapBuilder) {
            ImmutableMap.Builder aggregations = ImmutableMap.builder();
            for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : aggregateInfo.getAggregations().entrySet()) {
                AggregationNode.Aggregation aggregation = entry.getValue();
                if (!aggregation.getMask().isEmpty()) continue;
                Symbol newSymbol = this.symbolAllocator.newSymbol(entry.getKey().toSymbolReference(), entry.getKey().getType());
                aggregationOutputSymbolsMapBuilder.put((Object)newSymbol, (Object)entry.getKey());
                if (!duplicatedDistinctSymbol.equals(distinctSymbol) && aggregation.getArguments().contains(distinctSymbol.toSymbolReference())) {
                    ImmutableList.Builder arguments = ImmutableList.builder();
                    for (Expression argument : aggregation.getArguments()) {
                        if (distinctSymbol.toSymbolReference().equals(argument)) {
                            arguments.add((Object)duplicatedDistinctSymbol.toSymbolReference());
                            continue;
                        }
                        arguments.add((Object)argument);
                    }
                    aggregation = new AggregationNode.Aggregation(aggregation.getResolvedFunction(), (List<Expression>)arguments.build(), false, Optional.empty(), Optional.empty(), Optional.empty());
                }
                aggregations.put((Object)newSymbol, (Object)aggregation);
            }
            return new AggregationNode(this.idAllocator.getNextId(), groupIdNode, (Map<Symbol, AggregationNode.Aggregation>)aggregations.buildOrThrow(), AggregationNode.singleGroupingSet((List<Symbol>)ImmutableList.copyOf(groupByKeys)), (List<Symbol>)ImmutableList.of(), AggregationNode.Step.SINGLE, originalNode.getHashSymbol(), Optional.empty());
        }

        private static Expression createIfExpression(Expression left, Expression right, ComparisonExpression.Operator operator, Expression result, Type trueValueType) {
            return IrExpressions.ifExpression(new ComparisonExpression(operator, left, right), result, new Constant(trueValueType, null));
        }
    }

    private static class AggregateInfo {
        private final List<Symbol> groupBySymbols;
        private final Symbol mask;
        private final Map<Symbol, AggregationNode.Aggregation> aggregations;
        private Map<Symbol, Symbol> newNonDistinctAggregateSymbols;
        private Symbol newDistinctAggregateSymbol;
        private boolean foundMarkDistinct;

        public AggregateInfo(List<Symbol> groupBySymbols, Symbol mask, Map<Symbol, AggregationNode.Aggregation> aggregations) {
            this.groupBySymbols = ImmutableList.copyOf(groupBySymbols);
            this.mask = mask;
            this.aggregations = ImmutableMap.copyOf(aggregations);
        }

        public List<Symbol> getOriginalNonDistinctAggregateArgs() {
            return this.aggregations.values().stream().filter(aggregation -> aggregation.getMask().isEmpty()).flatMap(aggregation -> aggregation.getArguments().stream()).distinct().map(Symbol::from).collect(Collectors.toList());
        }

        public List<Symbol> getOriginalDistinctAggregateArgs() {
            return this.aggregations.values().stream().filter(aggregation -> aggregation.getMask().isPresent()).flatMap(aggregation -> aggregation.getArguments().stream()).distinct().map(Symbol::from).collect(Collectors.toList());
        }

        public Symbol getNewDistinctAggregateSymbol() {
            return this.newDistinctAggregateSymbol;
        }

        public void setNewDistinctAggregateSymbol(Symbol newDistinctAggregateSymbol) {
            this.newDistinctAggregateSymbol = newDistinctAggregateSymbol;
        }

        public Map<Symbol, Symbol> getNewNonDistinctAggregateSymbols() {
            return this.newNonDistinctAggregateSymbols;
        }

        public void setNewNonDistinctAggregateSymbols(Map<Symbol, Symbol> newNonDistinctAggregateSymbols) {
            this.newNonDistinctAggregateSymbols = newNonDistinctAggregateSymbols;
        }

        public Symbol getMask() {
            return this.mask;
        }

        public List<Symbol> getGroupBySymbols() {
            return this.groupBySymbols;
        }

        public Map<Symbol, AggregationNode.Aggregation> getAggregations() {
            return this.aggregations;
        }

        public void foundMarkDistinct() {
            this.foundMarkDistinct = true;
        }

        public boolean isFoundMarkDistinct() {
            return this.foundMarkDistinct;
        }
    }
}

