/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.TrinoException;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.Int128;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.RealType;
import io.trino.spi.type.SmallintType;
import io.trino.spi.type.TinyintType;
import io.trino.spi.type.Type;
import io.trino.sql.PlannerContext;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.analyzer.TypeSignatureTranslator;
import io.trino.sql.planner.ExpressionInterpreter;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.TypeAnalyzer;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.tree.Cast;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.NodeRef;
import io.trino.sql.tree.QualifiedName;
import io.trino.sql.tree.SearchedCaseExpression;
import io.trino.sql.tree.SymbolReference;
import io.trino.sql.tree.WhenClause;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;

/**
 * Rewrites a group of {@code CASE}-guarded aggregations so that the expensive per-row aggregation
 * is computed once per value of a single extra grouping key and then combined.
 * <p>
 * The rule matches a {@code SINGLE}-step aggregation with exactly one grouping set whose source is a
 * projection that is not itself fed by an aggregation. Every aggregation must be {@code max},
 * {@code min} or {@code sum} over a projected symbol defined as
 * {@code [CAST(] CASE WHEN operand THEN result [ELSE default] END [)]}. At least
 * {@code MIN_AGGREGATION_COUNT} such aggregations are required. All {@code WHEN} operands together must
 * reference exactly one symbol (the "extra grouping key").
 * <pre>
 * Aggregation[keys](agg_i(CASE WHEN op_i THEN r_i ELSE d_i END))
 *   Project
 *     source
 * </pre>
 * becomes
 * <pre>
 * Aggregation[keys](cumulative_agg_i(p_i))                           -- original aggregation id
 *   Project(keys, p_i := CASE WHEN op_i THEN pre_i ELSE CAST(d_i) END) -- original project id
 *     Aggregation[keys, extraKey](pre_i := agg_i(r_i))
 *       Project(keys, extraKey, r_i)
 *         source
 * </pre>
 * Aggregations with an identical {@code (function, result)} pair share one pre-aggregation.
 */
public class PreAggregateCaseAggregations
implements Rule<AggregationNode> {
    /** The rule gives up when fewer than this many CASE aggregations are present. */
    private static final int MIN_AGGREGATION_COUNT = 4;
    /** Aggregation functions whose partial results are combined again with the same function. */
    private static final Set<String> ALLOWED_FUNCTIONS = ImmutableSet.of("max", "min", "sum");
    private static final Capture<ProjectNode> PROJECT_CAPTURE = Capture.newCapture();
    // The rewritten plan is Aggregation -> Project -> Aggregation, so excluding projects fed by an
    // aggregation keeps the rule from matching its own output again.
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation()
            .matching(aggregation -> aggregation.getStep() == AggregationNode.Step.SINGLE
                    && aggregation.getGroupingSetCount() == 1)
            .with(Patterns.source().matching(Patterns.project()
                    .capturedAs(PROJECT_CAPTURE)
                    .with(Patterns.source().matching(Predicate.not(AggregationNode.class::isInstance)))));
    private final PlannerContext plannerContext;
    private final TypeAnalyzer typeAnalyzer;

    public PreAggregateCaseAggregations(PlannerContext plannerContext, TypeAnalyzer typeAnalyzer) {
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
        this.typeAnalyzer = Objects.requireNonNull(typeAnalyzer, "typeAnalyzer is null");
    }

    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.isPreAggregateCaseAggregationsEnabled(session);
    }

    /**
     * Applies the rewrite when every aggregation is a supported CASE aggregation.
     *
     * @return the rewritten plan, or {@link Rule.Result#empty()} when the shape is not supported
     */
    @Override
    public Rule.Result apply(AggregationNode aggregationNode, Captures captures, Rule.Context context) {
        ProjectNode projectNode = captures.get(PROJECT_CAPTURE);
        Optional<List<CaseAggregation>> aggregationsOptional = extractCaseAggregations(aggregationNode, projectNode, context);
        if (aggregationsOptional.isEmpty()) {
            // at least one aggregation is not a supported CASE aggregation
            return Rule.Result.empty();
        }
        List<CaseAggregation> aggregations = aggregationsOptional.get();
        if (aggregations.size() < MIN_AGGREGATION_COUNT) {
            return Rule.Result.empty();
        }

        Set<Symbol> extraGroupingKeys = aggregations.stream()
                .flatMap(aggregation -> SymbolsExtractor.extractUnique(aggregation.getOperand()).stream())
                .collect(ImmutableSet.toImmutableSet());
        if (extraGroupingKeys.size() != 1) {
            // the pre-aggregation adds exactly one extra grouping key
            return Rule.Result.empty();
        }

        Map<PreAggregationKey, PreAggregation> preAggregations = getPreAggregations(aggregations, context);

        // Pre-aggregation groups by the extra key plus the original grouping keys; the latter are
        // recomputed from the original projection because the pre-projection reads the project's source.
        Assignments.Builder preGroupingExpressionsBuilder = Assignments.builder();
        preGroupingExpressionsBuilder.putIdentities(extraGroupingKeys);
        for (Symbol groupingKey : aggregationNode.getGroupingKeys()) {
            preGroupingExpressionsBuilder.put(groupingKey, projectNode.getAssignments().get(groupingKey));
        }
        Assignments preGroupingExpressions = preGroupingExpressionsBuilder.build();

        ProjectNode preProjection = createPreProjection(projectNode.getSource(), preGroupingExpressions, preAggregations, context);
        AggregationNode preAggregation = createPreAggregation(preProjection, preGroupingExpressions.getOutputs(), preAggregations, context);
        Map<CaseAggregation, Symbol> newProjectionSymbols = getNewProjectionSymbols(aggregations, context);
        ProjectNode newProjection = createNewProjection(preAggregation, aggregationNode, projectNode, newProjectionSymbols, preAggregations);
        return Rule.Result.ofPlanNode(createNewAggregation(newProjection, aggregationNode, newProjectionSymbols));
    }

    /**
     * Builds the final aggregation, reusing the original node id, output symbols and grouping,
     * that combines the re-projected partial results with the cumulative functions.
     */
    private AggregationNode createNewAggregation(PlanNode source, AggregationNode aggregationNode, Map<CaseAggregation, Symbol> newProjectionSymbols) {
        Map<Symbol, AggregationNode.Aggregation> aggregations = newProjectionSymbols.entrySet().stream()
                .collect(ImmutableMap.toImmutableMap(
                        entry -> entry.getKey().getAggregationSymbol(),
                        entry -> new AggregationNode.Aggregation(
                                entry.getKey().getCumulativeFunction(),
                                ImmutableList.of(entry.getValue().toSymbolReference()),
                                false,
                                Optional.empty(),
                                Optional.empty(),
                                Optional.empty())));
        return new AggregationNode(
                aggregationNode.getId(),
                source,
                aggregations,
                aggregationNode.getGroupingSets(),
                aggregationNode.getPreGroupedSymbols(),
                aggregationNode.getStep(),
                aggregationNode.getHashSymbol(),
                aggregationNode.getGroupIdSymbol());
    }

    /**
     * Builds the projection above the pre-aggregation. It passes the original grouping keys through
     * and re-applies each CASE condition to the corresponding pre-aggregated value.
     */
    private ProjectNode createNewProjection(PlanNode source, AggregationNode aggregationNode, ProjectNode projectNode, Map<CaseAggregation, Symbol> newProjectionSymbols, Map<PreAggregationKey, PreAggregation> preAggregations) {
        Assignments.Builder assignments = Assignments.builder();
        assignments.putIdentities(aggregationNode.getGroupingKeys());
        // NOTE(review): each operand is now evaluated once per pre-aggregated group instead of once per
        // input row. That is equivalent only for deterministic operands; no determinism check is visible here.
        newProjectionSymbols.forEach((aggregation, symbol) -> {
            Symbol preAggregationSymbol = preAggregations.get(new PreAggregationKey(aggregation)).getAggregationSymbol();
            assignments.put(
                    symbol,
                    new SearchedCaseExpression(
                            ImmutableList.of(new WhenClause(aggregation.getOperand(), preAggregationSymbol.toSymbolReference())),
                            aggregation.getCumulativeAggregationDefaultValue()));
        });
        return new ProjectNode(projectNode.getId(), source, assignments.build());
    }

    /** Allocates one fresh symbol per CASE aggregation for its re-projected partial result. */
    private Map<CaseAggregation, Symbol> getNewProjectionSymbols(List<CaseAggregation> aggregations, Rule.Context context) {
        return aggregations.stream()
                .collect(ImmutableMap.toImmutableMap(
                        Function.identity(),
                        aggregation -> context.getSymbolAllocator().newSymbol(aggregation.getAggregationSymbol())));
    }

    /** Builds the partial aggregation grouped by the original grouping keys plus the extra key. */
    private AggregationNode createPreAggregation(PlanNode source, List<Symbol> groupingKeys, Map<PreAggregationKey, PreAggregation> preAggregations, Rule.Context context) {
        Map<Symbol, AggregationNode.Aggregation> aggregations = preAggregations.entrySet().stream()
                .collect(ImmutableMap.toImmutableMap(
                        entry -> entry.getValue().getAggregationSymbol(),
                        entry -> new AggregationNode.Aggregation(
                                entry.getKey().getFunction(),
                                ImmutableList.of(entry.getValue().getProjectionSymbol().toSymbolReference()),
                                false,
                                Optional.empty(),
                                Optional.empty(),
                                Optional.empty())));
        return new AggregationNode(
                context.getIdAllocator().getNextId(),
                source,
                aggregations,
                AggregationNode.singleGroupingSet(groupingKeys),
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                Optional.empty(),
                Optional.empty());
    }

    /** Builds the projection feeding the pre-aggregation: grouping expressions plus every CASE result. */
    private ProjectNode createPreProjection(PlanNode source, Assignments groupingExpressions, Map<PreAggregationKey, PreAggregation> preAggregations, Rule.Context context) {
        Assignments.Builder assignments = Assignments.builder();
        assignments.putAll(groupingExpressions);
        // NOTE(review): CASE results are projected for every input row, including rows whose WHEN operand
        // is false. A result expression the WHEN guarded (e.g. a division) may now fail. TODO confirm.
        preAggregations.values().forEach(aggregation -> assignments.put(aggregation.getProjectionSymbol(), aggregation.getProjection()));
        return new ProjectNode(context.getIdAllocator().getNextId(), source, assignments.build());
    }

    /**
     * Plans one pre-aggregation per distinct {@code (function, result)} pair. The result is cast to
     * the function's argument type when its analyzed type differs.
     */
    private Map<PreAggregationKey, PreAggregation> getPreAggregations(List<CaseAggregation> aggregations, Rule.Context context) {
        Set<PreAggregationKey> seenKeys = new HashSet<>();
        ImmutableMap.Builder<PreAggregationKey, PreAggregation> preAggregations = ImmutableMap.builder();
        for (CaseAggregation aggregation : aggregations) {
            PreAggregationKey preAggregationKey = new PreAggregationKey(aggregation);
            if (!seenKeys.add(preAggregationKey)) {
                // an identical pre-aggregation is already planned and will be shared
                continue;
            }
            Expression preProjection = aggregation.getResult();
            Type preProjectionType = getType(context, preProjection);
            Type aggregationInputType = Iterables.getOnlyElement(aggregation.getFunction().getSignature().getArgumentTypes());
            if (!preProjectionType.equals(aggregationInputType)) {
                // e.g. the CASE was wrapped in a top-level CAST that extractCaseAggregation stripped
                preProjection = new Cast(preProjection, TypeSignatureTranslator.toSqlType(aggregationInputType));
                preProjectionType = aggregationInputType;
            }
            Symbol preProjectionSymbol = context.getSymbolAllocator().newSymbol(preProjection, preProjectionType);
            Symbol preAggregationSymbol = context.getSymbolAllocator().newSymbol(aggregation.getAggregationSymbol());
            preAggregations.put(preAggregationKey, new PreAggregation(preAggregationSymbol, preProjection, preProjectionSymbol));
        }
        return preAggregations.buildOrThrow();
    }

    /**
     * Converts every aggregation of {@code aggregationNode} into a {@link CaseAggregation}.
     *
     * @return all converted aggregations, or empty if any single aggregation is unsupported
     */
    private Optional<List<CaseAggregation>> extractCaseAggregations(AggregationNode aggregationNode, ProjectNode projectNode, Rule.Context context) {
        ImmutableList.Builder<CaseAggregation> caseAggregations = ImmutableList.builder();
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : aggregationNode.getAggregations().entrySet()) {
            Optional<CaseAggregation> caseAggregation = extractCaseAggregation(entry.getKey(), entry.getValue(), projectNode, context);
            if (caseAggregation.isEmpty()) {
                return Optional.empty();
            }
            caseAggregations.add(caseAggregation.get());
        }
        return Optional.of(caseAggregations.build());
    }

    /**
     * Recognizes {@code agg([CAST(] CASE WHEN operand THEN result [ELSE default] END [)])} for an
     * allowed {@code agg}. The ELSE default must be NULL, or, for {@code sum} only, a numeric zero.
     *
     * @return the decomposed aggregation, or empty when the shape is not supported
     */
    private Optional<CaseAggregation> extractCaseAggregation(Symbol aggregationSymbol, AggregationNode.Aggregation aggregation, ProjectNode projectNode, Rule.Context context) {
        if (aggregation.getArguments().size() != 1
                || !(aggregation.getArguments().get(0) instanceof SymbolReference)
                || aggregation.isDistinct()
                || aggregation.getFilter().isPresent()
                || aggregation.getMask().isPresent()
                || aggregation.getOrderingScheme().isPresent()) {
            // only plain single-argument aggregations over a projected symbol are supported
            return Optional.empty();
        }
        ResolvedFunction function = aggregation.getResolvedFunction();
        String name = function.getSignature().getName();
        if (!ALLOWED_FUNCTIONS.contains(name)) {
            return Optional.empty();
        }

        Symbol projectionSymbol = Symbol.from(aggregation.getArguments().get(0));
        Expression projection = projectNode.getAssignments().get(projectionSymbol);
        // a single top-level CAST is tolerated; getPreAggregations casts the result back to the input type
        Expression unwrappedProjection = projection instanceof Cast ? ((Cast) projection).getExpression() : projection;
        if (!(unwrappedProjection instanceof SearchedCaseExpression)) {
            return Optional.empty();
        }
        SearchedCaseExpression caseExpression = (SearchedCaseExpression) unwrappedProjection;
        if (caseExpression.getWhenClauses().size() != 1) {
            return Optional.empty();
        }

        Type aggregationType = function.getSignature().getReturnType();
        ResolvedFunction cumulativeFunction;
        try {
            // the same function applied to its own output type combines the partial results
            cumulativeFunction = plannerContext.getMetadata().resolveFunction(context.getSession(), QualifiedName.of(name), TypeSignatureProvider.fromTypes(aggregationType));
        }
        catch (TrinoException e) {
            // no overload accepts the output type, so partial results cannot be combined
            return Optional.empty();
        }
        if (!cumulativeFunction.getSignature().getReturnType().equals(aggregationType)) {
            // the rewrite must not change the aggregation's output type
            return Optional.empty();
        }

        Optional<Expression> cumulativeAggregationDefaultValue = Optional.empty();
        Optional<Expression> elseClause = caseExpression.getDefaultValue();
        if (elseClause.isPresent()) {
            Expression defaultExpression = elseClause.get();
            Type defaultType = getType(context, defaultExpression);
            Object defaultValue = optimizeExpression(defaultExpression, context);
            // A NULL default contributes nothing to max/min/sum. A non-NULL default is only neutral for sum,
            // and only when it is a numeric zero; anything else would not survive re-aggregation.
            if (defaultValue != null && !(name.equals("sum") && isNumericZero(defaultType, defaultValue))) {
                return Optional.empty();
            }
            // the default now feeds the cumulative function, whose input is the aggregation's output type
            cumulativeAggregationDefaultValue = Optional.of(new Cast(defaultExpression, TypeSignatureTranslator.toSqlType(aggregationType)));
        }

        WhenClause whenClause = caseExpression.getWhenClauses().get(0);
        return Optional.of(new CaseAggregation(aggregationSymbol, function, cumulativeFunction, name, whenClause.getOperand(), whenClause.getResult(), cumulativeAggregationDefaultValue));
    }

    /**
     * Returns whether {@code value} is a zero of one of the supported numeric types.
     * The accepted zero representations are {@code 0L}, {@code 0.0} and {@link Int128#ZERO}.
     */
    private static boolean isNumericZero(Type type, Object value) {
        boolean numericType = type instanceof BigintType
                || type == IntegerType.INTEGER
                || type == SmallintType.SMALLINT
                || type == TinyintType.TINYINT
                || type == DoubleType.DOUBLE
                || type == RealType.REAL
                || type instanceof DecimalType;
        return numericType && (value.equals(0L) || value.equals(0.0) || value.equals(Int128.ZERO));
    }

    /** Analyzes the type of {@code expression} against the current symbol types. */
    private Type getType(Rule.Context context, Expression expression) {
        return typeAnalyzer.getType(context.getSession(), context.getSymbolAllocator().getTypes(), expression);
    }

    /**
     * Runs the expression optimizer on {@code expression}. Presumably the result is the folded constant,
     * {@code null} for SQL NULL, or a residual {@link Expression}; confirm against ExpressionInterpreter.
     */
    private Object optimizeExpression(Expression expression, Rule.Context context) {
        Map<NodeRef<Expression>, Type> expressionTypes = typeAnalyzer.getTypes(context.getSession(), context.getSymbolAllocator().getTypes(), expression);
        ExpressionInterpreter expressionInterpreter = new ExpressionInterpreter(expression, plannerContext, context.getSession(), expressionTypes);
        return expressionInterpreter.optimize(Symbol::toSymbolReference);
    }

    /** One recognized {@code agg(CASE WHEN operand THEN result [ELSE default] END)} aggregation. */
    private static final class CaseAggregation {
        private final Symbol aggregationSymbol;
        private final ResolvedFunction function;
        private final ResolvedFunction cumulativeFunction;
        private final String name;
        private final Expression operand;
        private final Expression result;
        private final Optional<Expression> cumulativeAggregationDefaultValue;

        public CaseAggregation(Symbol aggregationSymbol, ResolvedFunction function, ResolvedFunction cumulativeFunction, String name, Expression operand, Expression result, Optional<Expression> cumulativeAggregationDefaultValue) {
            this.aggregationSymbol = Objects.requireNonNull(aggregationSymbol, "aggregationSymbol is null");
            this.function = Objects.requireNonNull(function, "function is null");
            this.cumulativeFunction = Objects.requireNonNull(cumulativeFunction, "cumulativeFunction is null");
            this.name = Objects.requireNonNull(name, "name is null");
            this.operand = Objects.requireNonNull(operand, "operand is null");
            this.result = Objects.requireNonNull(result, "result is null");
            this.cumulativeAggregationDefaultValue = Objects.requireNonNull(cumulativeAggregationDefaultValue, "cumulativeAggregationDefaultValue is null");
        }

        /** Output symbol of the original aggregation. */
        public Symbol getAggregationSymbol() {
            return aggregationSymbol;
        }

        /** Original aggregation function (used for the pre-aggregation). */
        public ResolvedFunction getFunction() {
            return function;
        }

        /** Function that combines partial results (used for the final aggregation). */
        public ResolvedFunction getCumulativeFunction() {
            return cumulativeFunction;
        }

        public String getName() {
            return name;
        }

        /** The single WHEN condition. */
        public Expression getOperand() {
            return operand;
        }

        /** The THEN expression. */
        public Expression getResult() {
            return result;
        }

        /** The ELSE expression cast to the aggregation output type, if the CASE had one. */
        public Optional<Expression> getCumulativeAggregationDefaultValue() {
            return cumulativeAggregationDefaultValue;
        }
    }

    /** Identity of a pre-aggregation: aggregations with equal function and THEN expression share one. */
    private static final class PreAggregationKey {
        private final ResolvedFunction function;
        private final Expression projection;

        private PreAggregationKey(CaseAggregation aggregation) {
            this.function = aggregation.getFunction();
            this.projection = aggregation.getResult();
        }

        public ResolvedFunction getFunction() {
            return function;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            PreAggregationKey that = (PreAggregationKey) o;
            return Objects.equals(function, that.function) && Objects.equals(projection, that.projection);
        }

        @Override
        public int hashCode() {
            return Objects.hash(function, projection);
        }
    }

    /** A planned pre-aggregation: its output symbol plus the projected input and that input's symbol. */
    private static final class PreAggregation {
        private final Symbol aggregationSymbol;
        private final Expression projection;
        private final Symbol projectionSymbol;

        public PreAggregation(Symbol aggregationSymbol, Expression projection, Symbol projectionSymbol) {
            this.aggregationSymbol = Objects.requireNonNull(aggregationSymbol, "aggregationSymbol is null");
            this.projection = Objects.requireNonNull(projection, "projection is null");
            this.projectionSymbol = Objects.requireNonNull(projectionSymbol, "projectionSymbol is null");
        }

        public Symbol getAggregationSymbol() {
            return aggregationSymbol;
        }

        public Expression getProjection() {
            return projection;
        }

        public Symbol getProjectionSymbol() {
            return projectionSymbol;
        }
    }
}

