/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.GlobalFunctionCatalog;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.CatalogSchemaFunctionName;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.Range;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.predicate.ValueSet;
import io.trino.spi.type.Type;
import io.trino.sql.PlannerContext;
import io.trino.sql.ir.Booleans;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.IrUtils;
import io.trino.sql.planner.DomainTranslator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.ValuesNode;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

/**
 * Optimizer rules that move the mask of a grouped {@code count()} aggregation below the
 * aggregation when a filter above it rejects groups whose count is zero.
 * <p>
 * The rules match {@code Filter -> [identity Project ->] Aggregation}. The aggregation must have:
 * <ul>
 *   <li>a single, non-empty grouping set;</li>
 *   <li>step {@code SINGLE};</li>
 *   <li>exactly one argument-less {@code count} that has a mask and no filter.</li>
 * </ul>
 * When the filter predicate excludes {@code count = 0}, the mask can become a filter beneath the
 * aggregation. A group made up only of masked-out rows would have counted {@code 0} and been
 * discarded by the filter anyway. Every other group keeps the same count. If the predicate on the
 * count symbol also accepts every value {@code >= 1}, that conjunct becomes redundant and is
 * removed.
 */
public class PushFilterThroughCountAggregation {
    private static final CatalogSchemaFunctionName COUNT_NAME = GlobalFunctionCatalog.builtinFunctionName("count");
    private final PlannerContext plannerContext;

    public PushFilterThroughCountAggregation(PlannerContext plannerContext) {
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
    }

    /**
     * Returns both rule variants. One matches the aggregation directly under the filter. The other
     * matches it under an identity projection.
     */
    public Set<Rule<?>> rules() {
        return ImmutableSet.of(
                new PushFilterThroughCountAggregationWithoutProject(this.plannerContext),
                new PushFilterThroughCountAggregationWithProject(this.plannerContext));
    }

    /**
     * Shared rewrite for both rule variants.
     *
     * @param filterNode the matched filter
     * @param aggregationNode the matched aggregation; {@link #isGroupedCountWithMask} holds for it
     * @param projectNode the identity projection between the filter and the aggregation, if any
     * @param plannerContext planner services used for predicate/domain translation
     * @param context rule context supplying the session and the plan node id allocator
     * @return an empty result if {@code count = 0} may satisfy the predicate, otherwise the
     *         rewritten plan
     */
    private static Rule.Result pushFilter(FilterNode filterNode, AggregationNode aggregationNode, Optional<ProjectNode> projectNode, PlannerContext plannerContext, Rule.Context context) {
        Symbol countSymbol = Iterables.getOnlyElement(aggregationNode.getAggregations().keySet());
        AggregationNode.Aggregation aggregation = Iterables.getOnlyElement(aggregationNode.getAggregations().values());

        DomainTranslator.ExtractionResult extractionResult = DomainTranslator.getExtractionResult(plannerContext, context.getSession(), filterNode.getPredicate());
        TupleDomain<Symbol> tupleDomain = extractionResult.getTupleDomain();
        if (tupleDomain.isNone()) {
            // The predicate can never be satisfied, so replace the subtree with an empty relation.
            return Rule.Result.ofPlanNode(new ValuesNode(filterNode.getId(), filterNode.getOutputSymbols()));
        }

        // isNone() was ruled out above, so the domain map is present.
        Domain countDomain = tupleDomain.getDomains().orElseThrow().get(countSymbol);
        if (countDomain == null) {
            // The predicate places no constraint on the count, so groups with count 0 may pass.
            return Rule.Result.empty();
        }
        if (countDomain.contains(Domain.singleValue(countDomain.getType(), 0L))) {
            // count = 0 is accepted, so groups whose rows are all masked out must be kept.
            return Rule.Result.empty();
        }

        // Turn the mask into a filter beneath the aggregation and drop it from the aggregation.
        // The mask is guaranteed present by isGroupedCountWithMask.
        Symbol mask = aggregation.getMask().orElseThrow();
        FilterNode maskFilter = new FilterNode(context.getIdAllocator().getNextId(), aggregationNode.getSource(), mask.toSymbolReference());
        AggregationNode.Aggregation unmaskedAggregation = new AggregationNode.Aggregation(
                aggregation.getResolvedFunction(),
                aggregation.getArguments(),
                aggregation.isDistinct(),
                aggregation.getFilter(),
                aggregation.getOrderingScheme(),
                Optional.empty());
        AggregationNode newAggregationNode = AggregationNode.builderFrom(aggregationNode)
                .setSource(maskFilter)
                .setAggregations(ImmutableMap.of(countSymbol, unmaskedAggregation))
                .build();

        // Re-attach the identity projection when the matched plan had one.
        PlanNode filterSource = projectNode
                .map(project -> project.replaceChildren(ImmutableList.of(newAggregationNode)))
                .orElse(newAggregationNode);

        if (countDomain.getValues().contains(ValueSet.ofRanges(Range.greaterThanOrEqual(countDomain.getType(), 1L)))) {
            // Every surviving group now has count >= 1, which the predicate accepts.
            // Keep only the non-count part of the predicate.
            TupleDomain<Symbol> remainingDomain = tupleDomain.filter((symbol, domain) -> !symbol.equals(countSymbol));
            Expression newPredicate = IrUtils.combineConjuncts(
                    new DomainTranslator(plannerContext.getMetadata()).toPredicate(remainingDomain),
                    extractionResult.getRemainingExpression());
            if (newPredicate.equals(Booleans.TRUE)) {
                return Rule.Result.ofPlanNode(filterSource);
            }
            return Rule.Result.ofPlanNode(new FilterNode(filterNode.getId(), filterSource, newPredicate));
        }

        // The predicate still restricts the count, so keep the original filter over the new source.
        return Rule.Result.ofPlanNode(filterNode.replaceChildren(ImmutableList.of(filterSource)));
    }

    /**
     * Returns true when the aggregation passes {@link #isGroupedAggregation} and computes exactly
     * one argument-less {@code count} with a mask and without a filter.
     */
    private static boolean isGroupedCountWithMask(AggregationNode aggregationNode) {
        if (!isGroupedAggregation(aggregationNode)) {
            return false;
        }
        if (aggregationNode.getAggregations().size() != 1) {
            return false;
        }
        AggregationNode.Aggregation aggregation = Iterables.getOnlyElement(aggregationNode.getAggregations().values());
        if (aggregation.getMask().isEmpty() || aggregation.getFilter().isPresent()) {
            return false;
        }
        BoundSignature signature = aggregation.getResolvedFunction().signature();
        return signature.getArgumentTypes().isEmpty() && signature.getName().equals(COUNT_NAME);
    }

    /**
     * Returns true for a {@code SINGLE}-step aggregation with exactly one, non-empty grouping set.
     */
    private static boolean isGroupedAggregation(AggregationNode aggregationNode) {
        return aggregationNode.hasNonEmptyGroupingSet()
                && aggregationNode.getGroupingSetCount() == 1
                && aggregationNode.getStep() == AggregationNode.Step.SINGLE;
    }

    /**
     * Matches {@code Filter -> Aggregation}, where the aggregation is a grouped, masked
     * {@code count()}.
     */
    @VisibleForTesting
    public static final class PushFilterThroughCountAggregationWithoutProject
    implements Rule<FilterNode> {
        private static final Capture<AggregationNode> AGGREGATION = Capture.newCapture();
        private final PlannerContext plannerContext;
        private final Pattern<FilterNode> pattern;

        public PushFilterThroughCountAggregationWithoutProject(PlannerContext plannerContext) {
            this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
            this.pattern = Patterns.filter()
                    .with(Patterns.source().matching(Patterns.aggregation()
                            .matching(PushFilterThroughCountAggregation::isGroupedCountWithMask)
                            .capturedAs(AGGREGATION)));
        }

        @Override
        public Pattern<FilterNode> getPattern() {
            return this.pattern;
        }

        @Override
        public Rule.Result apply(FilterNode node, Captures captures, Rule.Context context) {
            return pushFilter(node, captures.get(AGGREGATION), Optional.empty(), this.plannerContext, context);
        }
    }

    /**
     * Matches {@code Filter -> identity Project -> Aggregation}, where the aggregation is a
     * grouped, masked {@code count()}.
     */
    @VisibleForTesting
    public static final class PushFilterThroughCountAggregationWithProject
    implements Rule<FilterNode> {
        private static final Capture<ProjectNode> PROJECT = Capture.newCapture();
        private static final Capture<AggregationNode> AGGREGATION = Capture.newCapture();
        private final PlannerContext plannerContext;
        private final Pattern<FilterNode> pattern;

        public PushFilterThroughCountAggregationWithProject(PlannerContext plannerContext) {
            this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
            this.pattern = Patterns.filter()
                    .with(Patterns.source().matching(Patterns.project()
                            .matching(ProjectNode::isIdentity)
                            .capturedAs(PROJECT)
                            .with(Patterns.source().matching(Patterns.aggregation()
                                    .matching(PushFilterThroughCountAggregation::isGroupedCountWithMask)
                                    .capturedAs(AGGREGATION)))));
        }

        @Override
        public Pattern<FilterNode> getPattern() {
            return this.pattern;
        }

        @Override
        public Rule.Result apply(FilterNode node, Captures captures, Rule.Context context) {
            return pushFilter(node, captures.get(AGGREGATION), Optional.of(captures.get(PROJECT)), this.plannerContext, context);
        }
    }
}

