/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.GlobalFunctionCatalog;
import io.trino.spi.function.CatalogSchemaFunctionName;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.Type;
import io.trino.sql.PlannerContext;
import io.trino.sql.ir.Booleans;
import io.trino.sql.ir.Coalesce;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.IrUtils;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.DomainTranslator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.ValuesNode;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Optimizer rules that push a filter on the result of a grouped {@code bool_or} aggregation
 * below that aggregation.
 * <p>
 * Shape matched (the identity project is optional and handled by a separate rule):
 * <pre>
 * - Filter(condition on bool_or_result [AND other])
 *     - [identity Project]
 *         - Aggregation(GROUP BY keys; bool_or_result := bool_or(x))
 *             - source
 * </pre>
 * Shape produced:
 * <pre>
 * - [Filter(other)]                      (omitted when "other" reduces to TRUE)
 *     - [identity Project]
 *         - Project(keys..., bool_or_result := TRUE)
 *             - Aggregation(GROUP BY keys; no aggregates)
 *                 - Filter(x)
 *                     - source
 * </pre>
 * The rewrite only fires when every condition on the {@code bool_or} output requires it to be
 * exactly TRUE. That is either an extracted domain equal to {@code Domain.singleValue(BOOLEAN, true)}
 * or a {@code coalesce(bool_or_result, false)} conjunct. In that case, keeping only the input rows
 * whose argument is TRUE leaves the same set of groups, and the aggregate value of each surviving
 * group is TRUE.
 */
public class PushFilterThroughBoolOrAggregation {
    private static final CatalogSchemaFunctionName BOOL_OR = GlobalFunctionCatalog.builtinFunctionName("bool_or");
    private final PlannerContext plannerContext;

    /**
     * @param plannerContext planner services used for predicate/domain translation; must not be null
     */
    public PushFilterThroughBoolOrAggregation(PlannerContext plannerContext) {
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
    }

    /**
     * Returns both rule variants: one matching the aggregation directly under the filter,
     * and one matching an identity projection between them.
     */
    public Set<Rule<?>> rules() {
        return ImmutableSet.of(
                new PushFilterThroughBoolOrAggregationWithoutProject(plannerContext),
                new PushFilterThroughBoolOrAggregationWithProject(plannerContext));
    }

    /**
     * Shared implementation of both rule variants.
     *
     * @param filterNode the filter sitting above the aggregation (or above the identity project)
     * @param aggregationNode a node accepted by {@link #isGroupedBoolOr(AggregationNode)}
     * @param projectNode the identity project between filter and aggregation, if any
     * @param plannerContext planner services used for domain extraction/translation
     * @param context rule context providing the session and the plan node id allocator
     * @return the rewritten plan, an empty {@code ValuesNode} when the predicate can never match,
     *         or {@link Rule.Result#empty()} when the rewrite does not apply
     */
    private static Rule.Result pushFilter(FilterNode filterNode, AggregationNode aggregationNode, Optional<ProjectNode> projectNode, PlannerContext plannerContext, Rule.Context context) {
        Symbol boolOrSymbol = Iterables.getOnlyElement(aggregationNode.getAggregations().keySet());
        AggregationNode.Aggregation aggregation = Iterables.getOnlyElement(aggregationNode.getAggregations().values());

        DomainTranslator.ExtractionResult extractionResult = DomainTranslator.getExtractionResult(plannerContext, context.getSession(), filterNode.getPredicate());
        TupleDomain<Symbol> tupleDomain = extractionResult.getTupleDomain();
        Expression remainingExpression = extractionResult.getRemainingExpression();

        if (tupleDomain.isNone()) {
            // The predicate can never be satisfied: replace the subtree with an empty relation.
            return Rule.Result.ofPlanNode(new ValuesNode(filterNode.getId(), filterNode.getOutputSymbols(), ImmutableList.of()));
        }

        // Conjuncts not captured by the domain must not mention the bool_or output,
        // except in the supported coalesce(bool_or_result, false) form.
        List<Expression> conjuncts = IrUtils.extractConjuncts(remainingExpression);
        Map<Boolean, List<Expression>> expressions = conjuncts.stream()
                .filter(expression -> SymbolsExtractor.extractUnique(expression).contains(boolOrSymbol))
                .collect(Collectors.partitioningBy(expression -> isSupportedCoalesce(expression, boolOrSymbol)));
        if (!expressions.get(Boolean.FALSE).isEmpty()) {
            return Rule.Result.empty();
        }
        // partitioningBy always yields both keys, so the TRUE entry is never null.
        Optional<Expression> boolOrCoalesce = expressions.get(Boolean.TRUE).stream().findFirst();

        // Domains are present here because the tuple domain is not NONE.
        Optional<Domain> boolOrDomain = Optional.ofNullable(tupleDomain.getDomains().get().get(boolOrSymbol));
        if (boolOrDomain.isPresent() && !boolOrDomain.get().equals(Domain.singleValue(BooleanType.BOOLEAN, true))) {
            // Any domain other than exactly {TRUE} (e.g. one that also allows FALSE or NULL) is not pushable.
            return Rule.Result.empty();
        }
        if (boolOrCoalesce.isEmpty() && boolOrDomain.isEmpty()) {
            // The filter does not constrain the bool_or output at all.
            return Rule.Result.empty();
        }

        // Keep only input rows whose bool_or argument is TRUE; the argument is a plain Reference
        // (guaranteed by isGroupedBoolOr), so it can be used directly as the predicate.
        FilterNode source = new FilterNode(context.getIdAllocator().getNextId(), aggregationNode.getSource(), aggregation.getArguments().getFirst());
        AggregationNode newAggregationNode = AggregationNode.builderFrom(aggregationNode)
                .setSource(source)
                .setAggregations(ImmutableMap.of())
                .build();
        // Each surviving group had at least one TRUE argument value, so its bool_or result is TRUE.
        ProjectNode newProjectNode = new ProjectNode(
                context.getIdAllocator().getNextId(),
                newAggregationNode,
                Assignments.builder()
                        .putIdentities(newAggregationNode.getOutputSymbols())
                        .put(boolOrSymbol, Booleans.TRUE)
                        .build());
        PlanNode filterSource = projectNode
                .map(project -> project.replaceChildren(ImmutableList.of(newProjectNode)))
                .orElse(newProjectNode);

        // Conditions on the bool_or output now hold by construction; drop them from the residual filter.
        if (boolOrCoalesce.isPresent()) {
            Expression coalesce = boolOrCoalesce.get();
            remainingExpression = IrUtils.combineConjuncts(conjuncts.stream()
                    .filter(expression -> !expression.equals(coalesce))
                    .toList());
        }
        TupleDomain<Symbol> newTupleDomain = tupleDomain.filter((symbol, domain) -> !symbol.equals(boolOrSymbol));
        Expression newPredicate = IrUtils.combineConjuncts(new DomainTranslator(plannerContext.getMetadata()).toPredicate(newTupleDomain), remainingExpression);
        if (!newPredicate.equals(Booleans.TRUE)) {
            return Rule.Result.ofPlanNode(new FilterNode(filterNode.getId(), filterSource, newPredicate));
        }
        return Rule.Result.ofPlanNode(filterSource);
    }

    /**
     * Returns whether {@code expression} is exactly {@code coalesce(<boolOrSymbol>, false)},
     * which is TRUE only when the bool_or output is TRUE.
     */
    private static boolean isSupportedCoalesce(Expression expression, Symbol boolOrSymbol) {
        if (!(expression instanceof Coalesce coalesce) || coalesce.operands().size() != 2) {
            return false;
        }
        Expression firstOperand = coalesce.operands().getFirst();
        Expression secondOperand = coalesce.operands().getLast();
        return firstOperand.equals(boolOrSymbol.toSymbolReference()) && secondOperand.equals(Booleans.FALSE);
    }

    /**
     * Returns whether {@code node} is a single-step aggregation with exactly one non-empty grouping set
     * whose only aggregate is an unfiltered, unmasked {@code bool_or} over a plain column reference.
     */
    public static boolean isGroupedBoolOr(AggregationNode node) {
        if (!isGroupedAggregation(node)) {
            return false;
        }
        if (node.getAggregations().size() != 1) {
            return false;
        }
        AggregationNode.Aggregation aggregation = Iterables.getOnlyElement(node.getAggregations().values());
        if (aggregation.getFilter().isPresent() || aggregation.getMask().isPresent()) {
            return false;
        }
        return aggregation.getResolvedFunction().name().equals(BOOL_OR) && aggregation.getArguments().getFirst() instanceof Reference;
    }

    /**
     * Only grouped aggregations qualify. A global aggregation (empty grouping set) emits a row even for
     * empty input, so filtering its input would not be equivalent to filtering its output.
     */
    private static boolean isGroupedAggregation(AggregationNode node) {
        return node.hasNonEmptyGroupingSet() && node.getGroupingSetCount() == 1 && node.getStep() == AggregationNode.Step.SINGLE;
    }

    /**
     * Variant matching {@code Filter -> Aggregation(bool_or)} with no projection in between.
     */
    @VisibleForTesting
    public static final class PushFilterThroughBoolOrAggregationWithoutProject
    implements Rule<FilterNode> {
        private static final Capture<AggregationNode> AGGREGATION = Capture.newCapture();
        private final PlannerContext plannerContext;
        private final Pattern<FilterNode> pattern;

        public PushFilterThroughBoolOrAggregationWithoutProject(PlannerContext plannerContext) {
            this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
            this.pattern = Patterns.filter()
                    .with(Patterns.source().matching(Patterns.aggregation()
                            .matching(PushFilterThroughBoolOrAggregation::isGroupedBoolOr)
                            .capturedAs(AGGREGATION)));
        }

        @Override
        public Pattern<FilterNode> getPattern() {
            return pattern;
        }

        @Override
        public Rule.Result apply(FilterNode node, Captures captures, Rule.Context context) {
            return pushFilter(node, captures.get(AGGREGATION), Optional.empty(), plannerContext, context);
        }
    }

    /**
     * Variant matching {@code Filter -> identity Project -> Aggregation(bool_or)}; the identity
     * projection is preserved above the rewritten subtree.
     */
    @VisibleForTesting
    public static final class PushFilterThroughBoolOrAggregationWithProject
    implements Rule<FilterNode> {
        private static final Capture<ProjectNode> PROJECT = Capture.newCapture();
        private static final Capture<AggregationNode> AGGREGATION = Capture.newCapture();
        private final PlannerContext plannerContext;
        private final Pattern<FilterNode> pattern;

        public PushFilterThroughBoolOrAggregationWithProject(PlannerContext plannerContext) {
            this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
            this.pattern = Patterns.filter()
                    .with(Patterns.source().matching(Patterns.project()
                            .matching(ProjectNode::isIdentity)
                            .capturedAs(PROJECT)
                            .with(Patterns.source().matching(Patterns.aggregation()
                                    .matching(PushFilterThroughBoolOrAggregation::isGroupedBoolOr)
                                    .capturedAs(AGGREGATION)))));
        }

        @Override
        public Pattern<FilterNode> getPattern() {
            return pattern;
        }

        @Override
        public Rule.Result apply(FilterNode node, Captures captures, Rule.Context context) {
            return pushFilter(node, captures.get(AGGREGATION), Optional.of(captures.get(PROJECT)), plannerContext, context);
        }
    }
}

