/*
 * Decompiled with CFR 0.152.
 */
package io.trino.cost;

import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.cost.FilterStatsCalculator;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.SimpleStatsRule;
import io.trino.cost.StatsCalculator;
import io.trino.cost.StatsNormalizer;
import io.trino.cost.StatsProvider;
import io.trino.matching.Pattern;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import java.util.Objects;
import java.util.Optional;

/**
 * Stats rule for a {@link FilterNode} whose source is an {@link AggregationNode},
 * either directly or through an identity {@link ProjectNode}.
 *
 * <p>The rule only produces a result when the non-estimatable predicate approximation
 * session property is enabled. When the filter's output row count cannot be estimated,
 * it falls back to a fixed fraction of the aggregation's output row count.
 */
public class FilterProjectAggregationStatsRule
        extends SimpleStatsRule<FilterNode> {
    private static final Pattern<FilterNode> PATTERN = Patterns.filter();

    /**
     * Fraction of the aggregation's output rows assumed to pass the filter when the
     * filter's own row count estimate is unknown.
     */
    private static final double UNKNOWN_FILTER_ROW_COUNT_FACTOR = 0.9;

    private final FilterStatsCalculator filterStatsCalculator;

    /**
     * @param normalizer forwarded to the {@link SimpleStatsRule} constructor
     * @param filterStatsCalculator used to estimate the filter predicate; must not be null
     * @throws NullPointerException if {@code filterStatsCalculator} is null
     */
    public FilterProjectAggregationStatsRule(StatsNormalizer normalizer, FilterStatsCalculator filterStatsCalculator) {
        super(normalizer);
        this.filterStatsCalculator = Objects.requireNonNull(filterStatsCalculator, "filterStatsCalculator cannot be null");
    }

    @Override
    public Pattern<FilterNode> getPattern() {
        return PATTERN;
    }

    /**
     * Computes stats for {@code node} when the approximation session property is enabled and
     * the filter sits on top of an aggregation (optionally behind an identity projection).
     *
     * @return the estimate, or {@link Optional#empty()} when the rule does not apply
     */
    @Override
    protected Optional<PlanNodeStatsEstimate> doCalculate(FilterNode node, StatsCalculator.Context context) {
        boolean approximationEnabled = SystemSessionProperties.isNonEstimatablePredicateApproximationEnabled(context.session());
        if (!approximationEnabled) {
            return Optional.empty();
        }
        return findAggregationSource(node, context)
                .flatMap(aggregation -> calculate(node, aggregation, context.statsProvider(), context.session()));
    }

    /**
     * Resolves the filter's source and returns it when it is an aggregation, looking through
     * at most one identity projection. Any other plan shape yields an empty result.
     */
    private static Optional<AggregationNode> findAggregationSource(FilterNode filterNode, StatsCalculator.Context context) {
        PlanNode candidate = context.lookup().resolve(filterNode.getSource());
        if (candidate instanceof ProjectNode) {
            ProjectNode projection = (ProjectNode) candidate;
            if (!projection.isIdentity()) {
                return Optional.empty();
            }
            // Step past the identity projection; only a single projection level is unwrapped.
            candidate = context.lookup().resolve(projection.getSource());
        }
        if (candidate instanceof AggregationNode) {
            return Optional.of((AggregationNode) candidate);
        }
        return Optional.empty();
    }

    /**
     * Estimates the filter output. The regular filter estimate is returned whenever its row
     * count is known, or when the aggregation's row count is unknown as well; otherwise the
     * aggregation's stats are returned with the row count scaled by
     * {@link #UNKNOWN_FILTER_ROW_COUNT_FACTOR}.
     */
    private Optional<PlanNodeStatsEstimate> calculate(FilterNode filterNode, AggregationNode aggregationNode, StatsProvider statsProvider, Session session) {
        PlanNodeStatsEstimate inputStats = statsProvider.getStats(filterNode.getSource());
        PlanNodeStatsEstimate filterEstimate = this.filterStatsCalculator.filterStats(inputStats, filterNode.getPredicate(), session);
        if (!filterEstimate.isOutputRowCountUnknown()) {
            return Optional.of(filterEstimate);
        }

        PlanNodeStatsEstimate aggregationStats = statsProvider.getStats(aggregationNode);
        if (aggregationStats.isOutputRowCountUnknown()) {
            // Nothing to scale from, so hand back the (unknown) filter estimate unchanged.
            return Optional.of(filterEstimate);
        }
        return Optional.of(aggregationStats.mapOutputRowCount(rows -> rows * UNKNOWN_FILTER_ROW_COUNT_FACTOR));
    }
}

