/*
 * Decompiled with CFR 0.152.
 */
package io.trino.cost;

import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.Session;
import io.trino.cost.ComparisonStatsCalculator;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.PlanNodeStatsEstimateMath;
import io.trino.cost.ScalarStatsCalculator;
import io.trino.cost.StatsNormalizer;
import io.trino.cost.SymbolStatsEstimate;
import io.trino.execution.warnings.WarningCollector;
import io.trino.security.AccessControl;
import io.trino.security.AllowAllAccessControl;
import io.trino.spi.statistics.StatsUtil;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.Type;
import io.trino.sql.DynamicFilters;
import io.trino.sql.ExpressionUtils;
import io.trino.sql.PlannerContext;
import io.trino.sql.analyzer.ExpressionAnalyzer;
import io.trino.sql.analyzer.Scope;
import io.trino.sql.planner.ExpressionInterpreter;
import io.trino.sql.planner.LiteralEncoder;
import io.trino.sql.planner.NoOpSymbolResolver;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.tree.AstVisitor;
import io.trino.sql.tree.BetweenPredicate;
import io.trino.sql.tree.BooleanLiteral;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FunctionCall;
import io.trino.sql.tree.InListExpression;
import io.trino.sql.tree.InPredicate;
import io.trino.sql.tree.IsNotNullPredicate;
import io.trino.sql.tree.IsNullPredicate;
import io.trino.sql.tree.LogicalExpression;
import io.trino.sql.tree.Node;
import io.trino.sql.tree.NodeRef;
import io.trino.sql.tree.NotExpression;
import io.trino.sql.tree.Parameter;
import io.trino.sql.tree.SymbolReference;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalDouble;
import javax.annotation.Nullable;
import javax.inject.Inject;

/**
 * Estimates the statistics of a plan node's output after a filter predicate has been applied.
 * <p>
 * The predicate is first simplified (constant-folded through {@link ExpressionInterpreter}) and then
 * walked by {@link FilterExpressionStatsCalculatingVisitor}. The visitor maps each supported
 * expression shape onto a {@link PlanNodeStatsEstimate}. Supported shapes are AND/OR/NOT,
 * comparisons, IN lists, BETWEEN, IS [NOT] NULL, boolean literals and dynamic filters. Any other
 * expression yields {@link PlanNodeStatsEstimate#unknown()}.
 */
public class FilterStatsCalculator {
    /**
     * Row-count multiplier applied by the AND estimator when not every conjunct can be estimated.
     * In that case the smallest estimable conjunct is scaled by this factor.
     */
    static final double UNKNOWN_FILTER_COEFFICIENT = 0.9;
    private final PlannerContext plannerContext;
    private final ScalarStatsCalculator scalarStatsCalculator;
    private final StatsNormalizer normalizer;

    @Inject
    public FilterStatsCalculator(PlannerContext plannerContext, ScalarStatsCalculator scalarStatsCalculator, StatsNormalizer normalizer) {
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
        this.scalarStatsCalculator = Objects.requireNonNull(scalarStatsCalculator, "scalarStatsCalculator is null");
        this.normalizer = Objects.requireNonNull(normalizer, "normalizer is null");
    }

    /**
     * Estimates the statistics that remain after filtering {@code statsEstimate} by {@code predicate}.
     *
     * @param statsEstimate statistics of the rows entering the filter
     * @param predicate the filter predicate
     * @param session current session, used for expression analysis and evaluation
     * @param types symbol types used to analyze the predicate
     * @return the normalized post-filter estimate, possibly {@link PlanNodeStatsEstimate#unknown()}
     */
    public PlanNodeStatsEstimate filterStats(PlanNodeStatsEstimate statsEstimate, Expression predicate, Session session, TypeProvider types) {
        Expression simplifiedExpression = simplifyExpression(session, predicate, types);
        return new FilterExpressionStatsCalculatingVisitor(statsEstimate, session, types).process(simplifiedExpression);
    }

    /**
     * Constant-folds {@code predicate} and re-encodes the result as a BOOLEAN expression.
     * A {@code null} optimizer result is encoded as {@code false}.
     */
    private Expression simplifyExpression(Session session, Expression predicate, TypeProvider types) {
        Map<NodeRef<Expression>, Type> expressionTypes = ExpressionUtils.getExpressionTypes(plannerContext, session, predicate, types);
        ExpressionInterpreter interpreter = new ExpressionInterpreter(predicate, plannerContext, session, expressionTypes);
        Object value = interpreter.optimize(NoOpSymbolResolver.INSTANCE);
        if (value == null) {
            // The predicate folded to SQL NULL; treat it as false (a filter keeps no rows for a
            // NULL condition). This assumes the predicate is a top-level filter expression.
            value = false;
        }
        return new LiteralEncoder(plannerContext).toExpression(session, value, BooleanType.BOOLEAN);
    }

    /**
     * Visitor that computes the estimate of {@link #input} after applying the visited predicate.
     * Every result returned through {@link #process(Node, Void)} is normalized.
     */
    private class FilterExpressionStatsCalculatingVisitor
            extends AstVisitor<PlanNodeStatsEstimate, Void> {
        private final PlanNodeStatsEstimate input;
        private final Session session;
        private final TypeProvider types;

        FilterExpressionStatsCalculatingVisitor(PlanNodeStatsEstimate input, Session session, TypeProvider types) {
            this.input = input;
            this.session = session;
            this.types = types;
        }

        /** Dispatches to the matching {@code visit*} method and normalizes its result. */
        @Override
        public PlanNodeStatsEstimate process(Node node, @Nullable Void context) {
            return normalizer.normalize(super.process(node, context), types);
        }

        /** Fallback for expression shapes that have no dedicated estimator. */
        @Override
        protected PlanNodeStatsEstimate visitExpression(Expression node, Void context) {
            return PlanNodeStatsEstimate.unknown();
        }

        @Override
        protected PlanNodeStatsEstimate visitNotExpression(NotExpression node, Void context) {
            if (node.getValue() instanceof IsNullPredicate) {
                // NOT (x IS NULL) is estimated through the dedicated IS NOT NULL estimator.
                return process(new IsNotNullPredicate(((IsNullPredicate) node.getValue()).getValue()));
            }
            // Rows kept by NOT p = input rows minus the rows kept by p.
            return PlanNodeStatsEstimateMath.subtractSubsetStats(input, process(node.getValue()));
        }

        @Override
        protected PlanNodeStatsEstimate visitLogicalExpression(LogicalExpression node, Void context) {
            switch (node.getOperator()) {
                case AND:
                    return estimateLogicalAnd(node.getTerms());
                case OR:
                    return estimateLogicalOr(node.getTerms());
            }
            throw new IllegalArgumentException("Unexpected binary operator: " + node.getOperator());
        }

        /**
         * Estimates a conjunction.
         * <p>
         * First, the terms are applied left to right, each one filtering the previous estimate. If any
         * term cannot be estimated, the fallback is used instead: the smallest estimable single-term
         * estimate, scaled by {@link #UNKNOWN_FILTER_COEFFICIENT}.
         */
        private PlanNodeStatsEstimate estimateLogicalAnd(List<Expression> terms) {
            PlanNodeStatsEstimate estimate = process(terms.get(0));
            if (!estimate.isOutputRowCountUnknown()) {
                for (int i = 1; i < terms.size(); i++) {
                    estimate = new FilterExpressionStatsCalculatingVisitor(estimate, session, types).process(terms.get(i));
                    if (estimate.isOutputRowCountUnknown()) {
                        break;
                    }
                }
                if (!estimate.isOutputRowCountUnknown()) {
                    return estimate;
                }
            }

            Optional<PlanNodeStatsEstimate> smallest = terms.stream()
                    .map(this::process)
                    .filter(termEstimate -> !termEstimate.isOutputRowCountUnknown())
                    .min(Comparator.comparingDouble(PlanNodeStatsEstimate::getOutputRowCount));
            if (smallest.isEmpty()) {
                return PlanNodeStatsEstimate.unknown();
            }
            return smallest.get().mapOutputRowCount(rowCount -> rowCount * UNKNOWN_FILTER_COEFFICIENT);
        }

        /**
         * Estimates a disjunction term by term with inclusion-exclusion:
         * {@code est(A OR B) = est(A) + est(B) - est(B applied on top of A)}, capped by the input.
         * Returns unknown as soon as any partial estimate is unknown.
         */
        private PlanNodeStatsEstimate estimateLogicalOr(List<Expression> terms) {
            PlanNodeStatsEstimate previous = process(terms.get(0));
            if (previous.isOutputRowCountUnknown()) {
                return PlanNodeStatsEstimate.unknown();
            }
            for (int i = 1; i < terms.size(); i++) {
                PlanNodeStatsEstimate current = process(terms.get(i));
                if (current.isOutputRowCountUnknown()) {
                    return PlanNodeStatsEstimate.unknown();
                }
                // Estimate of (previous AND term i): term i evaluated against the previous estimate.
                PlanNodeStatsEstimate andEstimate = new FilterExpressionStatsCalculatingVisitor(previous, session, types).process(terms.get(i));
                if (andEstimate.isOutputRowCountUnknown()) {
                    return PlanNodeStatsEstimate.unknown();
                }
                previous = PlanNodeStatsEstimateMath.capStats(
                        PlanNodeStatsEstimateMath.subtractSubsetStats(
                                PlanNodeStatsEstimateMath.addStatsAndSumDistinctValues(previous, current),
                                andEstimate),
                        input);
            }
            return previous;
        }

        /** TRUE keeps the input unchanged; FALSE yields zero rows and zeroed symbol statistics. */
        @Override
        protected PlanNodeStatsEstimate visitBooleanLiteral(BooleanLiteral node, Void context) {
            if (node.getValue()) {
                return input;
            }
            PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder();
            result.setOutputRowCount(0.0);
            input.getSymbolsWithKnownStatistics().forEach(symbol -> result.addSymbolStatistics(symbol, SymbolStatsEstimate.zero()));
            return result.build();
        }

        /** Keeps the non-null fraction of rows; only plain symbol references are supported. */
        @Override
        protected PlanNodeStatsEstimate visitIsNotNullPredicate(IsNotNullPredicate node, Void context) {
            if (!(node.getValue() instanceof SymbolReference)) {
                return PlanNodeStatsEstimate.unknown();
            }
            Symbol symbol = Symbol.from(node.getValue());
            SymbolStatsEstimate symbolStats = input.getSymbolStatistics(symbol);
            PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(input);
            result.setOutputRowCount(input.getOutputRowCount() * (1.0 - symbolStats.getNullsFraction()));
            result.addSymbolStatistics(symbol, symbolStats.mapNullsFraction(x -> 0.0));
            return result.build();
        }

        /** Keeps the null fraction of rows; the symbol becomes all-null with no range or distinct values. */
        @Override
        protected PlanNodeStatsEstimate visitIsNullPredicate(IsNullPredicate node, Void context) {
            if (!(node.getValue() instanceof SymbolReference)) {
                return PlanNodeStatsEstimate.unknown();
            }
            Symbol symbol = Symbol.from(node.getValue());
            SymbolStatsEstimate symbolStats = input.getSymbolStatistics(symbol);
            PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(input);
            result.setOutputRowCount(input.getOutputRowCount() * symbolStats.getNullsFraction());
            result.addSymbolStatistics(symbol, SymbolStatsEstimate.builder()
                    .setNullsFraction(1.0)
                    .setLowValue(Double.NaN)
                    .setHighValue(Double.NaN)
                    .setDistinctValuesCount(0.0)
                    .build());
            return result.build();
        }

        /**
         * Rewrites {@code x BETWEEN min AND max} as a conjunction of two range comparisons. This is
         * only done when x is a symbol and both bounds are single-valued.
         */
        @Override
        protected PlanNodeStatsEstimate visitBetweenPredicate(BetweenPredicate node, Void context) {
            if (!(node.getValue() instanceof SymbolReference)) {
                return PlanNodeStatsEstimate.unknown();
            }
            if (!getExpressionStats(node.getMin()).isSingleValue()) {
                return PlanNodeStatsEstimate.unknown();
            }
            if (!getExpressionStats(node.getMax()).isSingleValue()) {
                return PlanNodeStatsEstimate.unknown();
            }
            SymbolStatsEstimate valueStats = input.getSymbolStatistics(Symbol.from(node.getValue()));
            Expression lowerBound = new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL, node.getValue(), node.getMin());
            Expression upperBound = new ComparisonExpression(ComparisonExpression.Operator.LESS_THAN_OR_EQUAL, node.getValue(), node.getMax());
            // The AND estimator applies terms left to right, so this ordering decides which bound is
            // estimated first. The lower bound goes first when the value's low statistic is
            // infinite; otherwise the upper bound goes first.
            Expression transformed = Double.isInfinite(valueStats.getLowValue())
                    ? ExpressionUtils.and(lowerBound, upperBound)
                    : ExpressionUtils.and(upperBound, lowerBound);
            return process(transformed);
        }

        /**
         * Estimates {@code x IN (v1, ..., vn)} as the sum of the per-value equality estimates. The
         * result is capped by the number of non-null input values of x.
         */
        @Override
        protected PlanNodeStatsEstimate visitInPredicate(InPredicate node, Void context) {
            if (!(node.getValueList() instanceof InListExpression)) {
                return PlanNodeStatsEstimate.unknown();
            }
            InListExpression inList = (InListExpression) node.getValueList();
            ImmutableList<PlanNodeStatsEstimate> equalityEstimates = inList.getValues().stream()
                    .map(inValue -> process(new ComparisonExpression(ComparisonExpression.Operator.EQUAL, node.getValue(), inValue)))
                    .collect(ImmutableList.toImmutableList());
            if (equalityEstimates.stream().anyMatch(PlanNodeStatsEstimate::isOutputRowCountUnknown)) {
                return PlanNodeStatsEstimate.unknown();
            }
            PlanNodeStatsEstimate inEstimate = equalityEstimates.stream()
                    .reduce(PlanNodeStatsEstimateMath::addStatsAndSumDistinctValues)
                    .orElse(PlanNodeStatsEstimate.unknown());
            if (inEstimate.isOutputRowCountUnknown()) {
                return PlanNodeStatsEstimate.unknown();
            }
            SymbolStatsEstimate valueStats = getExpressionStats(node.getValue());
            if (valueStats.isUnknown()) {
                return PlanNodeStatsEstimate.unknown();
            }
            double notNullValuesBeforeIn = input.getOutputRowCount() * (1.0 - valueStats.getNullsFraction());
            PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(input);
            result.setOutputRowCount(Double.min(inEstimate.getOutputRowCount(), notNullValuesBeforeIn));
            if (node.getValue() instanceof SymbolReference) {
                Symbol valueSymbol = Symbol.from(node.getValue());
                // Summing per-value estimates can overshoot the symbol's original distinct count; cap it.
                SymbolStatsEstimate newSymbolStats = inEstimate.getSymbolStatistics(valueSymbol)
                        .mapDistinctValuesCount(newDistinctValuesCount -> Double.min(newDistinctValuesCount, valueStats.getDistinctValuesCount()));
                result.addSymbolStatistics(valueSymbol, newSymbolStats);
            }
            return result.build();
        }

        /**
         * Estimates a binary comparison.
         * <p>
         * The comparison is first normalized so that a symbol (or, failing that, a non-literal) sits on
         * the left. It is then dispatched to one of three comparison estimators: against a literal,
         * against a single-valued expression, or against a general expression.
         *
         * @throws IllegalArgumentException if both sides are effectively literals
         */
        @Override
        protected PlanNodeStatsEstimate visitComparisonExpression(ComparisonExpression node, Void context) {
            ComparisonExpression.Operator operator = node.getOperator();
            Expression left = node.getLeft();
            Expression right = node.getRight();
            Preconditions.checkArgument(!(isEffectivelyLiteral(left) && isEffectivelyLiteral(right)), "Literal-to-literal not supported here, should be eliminated earlier");

            if (!(left instanceof SymbolReference) && right instanceof SymbolReference) {
                // normalize so that the symbol is on the left
                return process(new ComparisonExpression(operator.flip(), right, left));
            }
            if (isEffectivelyLiteral(left)) {
                Verify.verify(!isEffectivelyLiteral(right));
                // normalize so that the literal is on the right
                return process(new ComparisonExpression(operator.flip(), right, left));
            }
            if (left instanceof SymbolReference && left.equals(right)) {
                // x <op> x is estimated as x IS NOT NULL
                return process(new IsNotNullPredicate(left));
            }

            SymbolStatsEstimate leftStats = getExpressionStats(left);
            Optional<Symbol> leftSymbol = left instanceof SymbolReference ? Optional.of(Symbol.from(left)) : Optional.empty();
            if (isEffectivelyLiteral(right)) {
                OptionalDouble literal = doubleValueFromLiteral(getType(left), right);
                return ComparisonStatsCalculator.estimateExpressionToLiteralComparison(input, leftStats, leftSymbol, literal, operator);
            }

            SymbolStatsEstimate rightStats = getExpressionStats(right);
            if (rightStats.isSingleValue()) {
                OptionalDouble value = Double.isNaN(rightStats.getLowValue()) ? OptionalDouble.empty() : OptionalDouble.of(rightStats.getLowValue());
                return ComparisonStatsCalculator.estimateExpressionToLiteralComparison(input, leftStats, leftSymbol, value, operator);
            }

            Optional<Symbol> rightSymbol = right instanceof SymbolReference ? Optional.of(Symbol.from(right)) : Optional.empty();
            return ComparisonStatsCalculator.estimateExpressionToExpressionComparison(input, leftStats, leftSymbol, rightStats, rightSymbol, operator);
        }

        /**
         * Dynamic filters are estimated as TRUE, so they do not change the estimate. Other function
         * calls are unknown.
         */
        @Override
        protected PlanNodeStatsEstimate visitFunctionCall(FunctionCall node, Void context) {
            if (DynamicFilters.isDynamicFilter(node)) {
                return process(BooleanLiteral.TRUE_LITERAL, context);
            }
            return PlanNodeStatsEstimate.unknown();
        }

        /** Resolves the type of a symbol from {@link #types}, or analyzes a non-symbol expression. */
        private Type getType(Expression expression) {
            if (expression instanceof SymbolReference) {
                Symbol symbol = Symbol.from(expression);
                return Objects.requireNonNull(types.get(symbol), () -> String.format("No type for symbol %s", symbol));
            }
            ExpressionAnalyzer expressionAnalyzer = ExpressionAnalyzer.createWithoutSubqueries(
                    plannerContext,
                    new AllowAllAccessControl(),
                    session,
                    types,
                    ImmutableMap.of(),
                    node -> new VerifyException("Unexpected subquery"),
                    WarningCollector.NOOP,
                    false);
            return expressionAnalyzer.analyze(expression, Scope.create());
        }

        /**
         * Returns input statistics for a symbol, or computes them with the scalar stats calculator
         * for any other expression.
         */
        private SymbolStatsEstimate getExpressionStats(Expression expression) {
            if (expression instanceof SymbolReference) {
                Symbol symbol = Symbol.from(expression);
                return Objects.requireNonNull(input.getSymbolStatistics(symbol), () -> String.format("No statistics for symbol %s", symbol));
            }
            return scalarStatsCalculator.calculate(expression, input, session, types);
        }

        private boolean isEffectivelyLiteral(Expression expression) {
            return ExpressionUtils.isEffectivelyLiteral(plannerContext, session, expression);
        }

        /** Evaluates a constant expression as {@code type} and converts it to its stats (double) representation. */
        private OptionalDouble doubleValueFromLiteral(Type type, Expression literal) {
            Object literalValue = ExpressionInterpreter.evaluateConstantExpression(literal, type, plannerContext, session, new AllowAllAccessControl(), ImmutableMap.of());
            return StatsUtil.toStatsRepresentation(type, literalValue);
        }
    }
}

