/*
 * Decompiled with CFR 0.152.
 */
package io.trino.cost;

import com.google.common.collect.MoreCollectors;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.cost.FilterStatsCalculator;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.SimpleStatsRule;
import io.trino.cost.StatsNormalizer;
import io.trino.cost.StatsProvider;
import io.trino.matching.Pattern;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.iterative.GroupReference;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import java.util.Objects;
import java.util.Optional;

/**
 * Stats rule for a {@link FilterNode} whose source resolves to an {@link AggregationNode}.
 * The aggregation may be the filter's direct source, or it may sit below an identity
 * {@link ProjectNode}.
 *
 * <p>The rule only produces an estimate when
 * {@link SystemSessionProperties#isNonEstimatablePredicateApproximationEnabled(Session)} returns
 * {@code true} for the session. It first asks the {@link FilterStatsCalculator} for an estimate.
 * If that estimate has an unknown output row count, the rule falls back to the aggregation's own
 * stats with the row count scaled by {@link #UNESTIMATABLE_FILTER_COEFFICIENT}.
 */
public class FilterProjectAggregationStatsRule
extends SimpleStatsRule<FilterNode> {
    private static final Pattern<FilterNode> PATTERN = Patterns.filter();

    // NOTE(review): presumably this mirrors the "unknown filter" coefficient used elsewhere in
    // the cost package (a decompiler inlines such constants). Consider referencing that shared
    // constant directly, if it is accessible, so the two cannot drift apart.
    /**
     * Fraction of the aggregation's output rows assumed to pass a predicate whose effect the
     * {@link FilterStatsCalculator} could not estimate.
     */
    private static final double UNESTIMATABLE_FILTER_COEFFICIENT = 0.9;

    private final FilterStatsCalculator filterStatsCalculator;

    /**
     * Creates the rule.
     *
     * @param normalizer normalizer handed to {@link SimpleStatsRule}
     * @param filterStatsCalculator calculator used to estimate the filter predicate; must not be null
     * @throws NullPointerException if {@code filterStatsCalculator} is null
     */
    public FilterProjectAggregationStatsRule(StatsNormalizer normalizer, FilterStatsCalculator filterStatsCalculator) {
        super(normalizer);
        this.filterStatsCalculator = Objects.requireNonNull(filterStatsCalculator, "filterStatsCalculator cannot be null");
    }

    /**
     * Returns the pattern this rule matches: any {@link FilterNode}.
     * The shape of the filter's source is checked later, in
     * {@link #doCalculate(FilterNode, StatsProvider, Lookup, Session, TypeProvider)}.
     */
    @Override
    public Pattern<FilterNode> getPattern() {
        return PATTERN;
    }

    /**
     * Estimates stats for {@code node} when its source is an aggregation, either directly or
     * through an identity projection.
     *
     * @return the estimate, or {@link Optional#empty()} in either of these cases:
     *     the approximation session property is disabled; or the filter's source is not one of
     *     the supported shapes (a non-identity projection, or a projection over something other
     *     than an aggregation)
     */
    @Override
    protected Optional<PlanNodeStatsEstimate> doCalculate(FilterNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types) {
        if (!SystemSessionProperties.isNonEstimatablePredicateApproximationEnabled(session)) {
            return Optional.empty();
        }
        PlanNode nodeSource = resolveGroup(lookup, node.getSource());
        AggregationNode aggregationNode;
        if (nodeSource instanceof ProjectNode) {
            ProjectNode projectNode = (ProjectNode) nodeSource;
            // Only an identity projection is transparent enough to look through to the aggregation.
            if (!projectNode.isIdentity()) {
                return Optional.empty();
            }
            PlanNode projectNodeSource = resolveGroup(lookup, projectNode.getSource());
            if (!(projectNodeSource instanceof AggregationNode)) {
                return Optional.empty();
            }
            aggregationNode = (AggregationNode) projectNodeSource;
        } else if (nodeSource instanceof AggregationNode) {
            aggregationNode = (AggregationNode) nodeSource;
        } else {
            return Optional.empty();
        }
        return calculate(node, aggregationNode, sourceStats, session, types);
    }

    /**
     * Computes the estimate for a filter that has already been matched to an aggregation.
     *
     * <p>The estimate is chosen as follows:
     * <ul>
     *   <li>If the regular filter estimate has a known output row count, it is returned.</li>
     *   <li>Otherwise, if the aggregation's stats have a known row count, those stats are returned
     *       with the row count multiplied by {@link #UNESTIMATABLE_FILTER_COEFFICIENT}.</li>
     *   <li>Otherwise, the filter estimate with its unknown row count is returned.</li>
     * </ul>
     * This method never returns {@link Optional#empty()}.
     */
    private Optional<PlanNodeStatsEstimate> calculate(FilterNode filterNode, AggregationNode aggregationNode, StatsProvider statsProvider, Session session, TypeProvider types) {
        PlanNodeStatsEstimate filteredStats = filterStatsCalculator.filterStats(
                statsProvider.getStats(filterNode.getSource()), filterNode.getPredicate(), session, types);
        if (!filteredStats.isOutputRowCountUnknown()) {
            return Optional.of(filteredStats);
        }
        PlanNodeStatsEstimate aggregationStats = statsProvider.getStats(aggregationNode);
        if (aggregationStats.isOutputRowCountUnknown()) {
            // There is no row count to scale from, so report the (unknown) filter estimate as-is.
            return Optional.of(filteredStats);
        }
        return Optional.of(aggregationStats.mapOutputRowCount(rowCount -> rowCount * UNESTIMATABLE_FILTER_COEFFICIENT));
    }

    /**
     * Resolves {@code node} to a concrete plan node if it is a {@link GroupReference}; otherwise
     * returns it unchanged. The group must contain exactly one node, which
     * {@link MoreCollectors#onlyElement()} enforces.
     */
    private static PlanNode resolveGroup(Lookup lookup, PlanNode node) {
        if (node instanceof GroupReference) {
            return (PlanNode) lookup.resolveGroup(node).collect(MoreCollectors.onlyElement());
        }
        return node;
    }
}

