/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import java.lang.runtime.SwitchBootstraps;
import java.util.HashSet;
import java.util.Set;

/**
 * Optimizer rule that drops an aggregation node when the rows feeding it are
 * already known to be distinct over its grouping keys.
 *
 * <p>The rule only matches aggregation nodes for which
 * {@link AggregationNode#producesDistinctRows()} holds. When the source subtree
 * (looking through project and filter nodes) bottoms out in another aggregation
 * whose grouping keys are all contained in the matched node's grouping keys, the
 * matched node is replaced by its source.
 */
public class RemoveRedundantDistinctAggregation
implements Rule<AggregationNode> {
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation().matching(AggregationNode::producesDistinctRows);

    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    /**
     * Replaces {@code aggregationNode} with its source when the source is already
     * distinct over the node's grouping keys; otherwise leaves the plan untouched.
     *
     * @param aggregationNode the matched aggregation node
     * @param captures pattern captures (unused by this rule)
     * @param context rule context, used to obtain the {@link Lookup} for resolving sources
     * @return a result carrying the source node, or an empty result when the rule does not apply
     */
    @Override
    public Rule.Result apply(AggregationNode aggregationNode, Captures captures, Rule.Context context) {
        Lookup lookup = context.getLookup();
        Set<Symbol> groupingKeys = new HashSet<>(aggregationNode.getGroupingKeys());
        if (isDistinctOverGroupingKeys(lookup.resolve(aggregationNode.getSource()), lookup, groupingKeys)) {
            return Rule.Result.ofPlanNode(aggregationNode.getSource());
        }
        return Rule.Result.empty();
    }

    /**
     * Walks down from {@code node} through project and filter nodes and reports
     * whether an aggregation is reached whose grouping keys are all contained in
     * {@code parentSymbols}.
     *
     * @param node the (already resolved) plan node to inspect
     * @param lookup used to resolve the sources of project and filter nodes
     * @param parentSymbols the parent's grouping keys, expressed in terms of {@code node}'s outputs
     * @return {@code true} if a qualifying aggregation is found; {@code false} for any
     *         other node type or when the containment check fails
     */
    private static boolean isDistinctOverGroupingKeys(PlanNode node, Lookup lookup, Set<Symbol> parentSymbols) {
        return switch (node) {
            // A single grouping set whose keys are a subset of the parent's keys:
            // rows unique over the subset are also unique over the superset.
            case AggregationNode aggregationNode ->
                    aggregationNode.getGroupingSets().getGroupingSetCount() == 1
                            && parentSymbols.containsAll(aggregationNode.getGroupingSets().getGroupingKeys());
            // Re-express the parent's keys in terms of the projection's input symbols
            // before descending further.
            case ProjectNode projectNode ->
                    isDistinctOverGroupingKeys(
                            lookup.resolve(projectNode.getSource()),
                            lookup,
                            translateProjectReferences(projectNode, parentSymbols));
            // A filter passes symbols through unchanged, so the same key set applies.
            case FilterNode filterNode ->
                    isDistinctOverGroupingKeys(lookup.resolve(filterNode.getSource()), lookup, parentSymbols);
            default -> false;
        };
    }

    /**
     * Maps each grouping key to the input symbol it is assigned from in
     * {@code projectNode}, keeping only keys whose assignment is a plain
     * {@link Reference}.
     *
     * <p>Keys with no assignment or with a non-reference expression are dropped.
     * Dropping only shrinks the set that the child's grouping keys must be
     * contained in, so it can make the caller's check fail but never pass wrongly.
     *
     * @param projectNode the projection whose assignments are consulted
     * @param groupingKeys grouping keys expressed in terms of the projection's outputs
     * @return the corresponding symbols on the projection's input side
     */
    private static Set<Symbol> translateProjectReferences(ProjectNode projectNode, Set<Symbol> groupingKeys) {
        Set<Symbol> translated = new HashSet<>();
        Assignments assignments = projectNode.getAssignments();
        for (Symbol parentSymbol : groupingKeys) {
            // instanceof is false for a missing (null) assignment, so no separate null check is needed.
            if (assignments.get(parentSymbol) instanceof Reference reference) {
                translated.add(Symbol.from(reference));
            }
        }
        return translated;
    }
}

