/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.Metadata;
import io.trino.sql.ir.BooleanLiteral;
import io.trino.sql.ir.Cast;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.IfExpression;
import io.trino.sql.ir.IrUtils;
import io.trino.sql.ir.IsNullPredicate;
import io.trino.sql.ir.LogicalExpression;
import io.trino.sql.ir.NotExpression;
import io.trino.sql.ir.NullIfExpression;
import io.trino.sql.ir.NullLiteral;
import io.trino.sql.ir.SearchedCaseExpression;
import io.trino.sql.ir.SimpleCaseExpression;
import io.trino.sql.ir.WhenClause;
import io.trino.sql.planner.DeterminismEvaluator;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.Patterns;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

/**
 * Rewrites conditional expressions ({@code IF}, {@code NULLIF}, searched and simple {@code CASE})
 * that appear as top-level conjuncts of a {@link FilterNode} predicate into simpler boolean logic,
 * or folds them to a constant.
 *
 * <p>The rewrites treat {@code FALSE} and {@code NULL} alike: e.g. {@code IF(c, TRUE, FALSE)} is
 * replaced by {@code c}, which may yield {@code NULL} where the original yielded {@code FALSE}.
 * A value is called "not true" below when it is the {@code FALSE} literal, a {@code NULL} literal,
 * or a cast of such a value.
 */
public class SimplifyFilterPredicate
        implements Rule<FilterNode> {
    private static final Pattern<FilterNode> PATTERN = Patterns.filter();

    private final Metadata metadata;

    /**
     * Creates the rule.
     *
     * @param metadata used for determinism checks and when recombining conjuncts
     * @throws NullPointerException if {@code metadata} is null
     */
    public SimplifyFilterPredicate(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    @Override
    public Pattern<FilterNode> getPattern() {
        return PATTERN;
    }

    /**
     * Simplifies each conjunct of the filter predicate independently.
     *
     * @return a new {@link FilterNode} whose predicate is the recombined conjuncts when at least one
     *     conjunct was simplified; otherwise an empty result
     */
    @Override
    public Rule.Result apply(FilterNode node, Captures captures, Rule.Context context) {
        List<Expression> conjuncts = IrUtils.extractConjuncts(node.getPredicate());
        ImmutableList.Builder<Expression> newConjuncts = ImmutableList.builder();
        boolean simplified = false;
        for (Expression conjunct : conjuncts) {
            Optional<Expression> simplifiedConjunct = simplifyFilterExpression(conjunct);
            if (simplifiedConjunct.isPresent()) {
                simplified = true;
            }
            newConjuncts.add(simplifiedConjunct.orElse(conjunct));
        }
        if (!simplified) {
            // nothing changed: report no rewrite rather than rebuilding an identical node
            return Rule.Result.empty();
        }
        return Rule.Result.ofPlanNode(new FilterNode(
                node.getId(),
                node.getSource(),
                IrUtils.combineConjuncts(metadata, newConjuncts.build())));
    }

    /**
     * Dispatches a single conjunct to the simplification for its expression kind.
     *
     * @return the simplified expression, or empty when no rewrite applies
     */
    private Optional<Expression> simplifyFilterExpression(Expression expression) {
        if (expression instanceof IfExpression) {
            return simplifyIf((IfExpression) expression);
        }
        if (expression instanceof NullIfExpression) {
            return Optional.of(simplifyNullIf((NullIfExpression) expression));
        }
        if (expression instanceof SearchedCaseExpression) {
            return simplifySearchedCase((SearchedCaseExpression) expression);
        }
        if (expression instanceof SimpleCaseExpression) {
            return simplifySimpleCase((SimpleCaseExpression) expression);
        }
        return Optional.empty();
    }

    /**
     * Simplifies {@code IF(condition, trueValue[, falseValue])}; a missing false value is treated
     * as not true.
     */
    private Optional<Expression> simplifyIf(IfExpression ifExpression) {
        Expression condition = ifExpression.getCondition();
        Expression trueValue = ifExpression.getTrueValue();
        Optional<Expression> falseValue = ifExpression.getFalseValue();
        boolean falseIsTrue = falseValue.isPresent() && isTrue(falseValue.get());
        boolean falseIsNotTrue = falseValue.isEmpty() || isNotTrue(falseValue.get());

        // IF(c, TRUE, <not true>) keeps exactly the rows where c is TRUE
        if (isTrue(trueValue) && falseIsNotTrue) {
            return Optional.of(condition);
        }
        // IF(c, <not true>, TRUE) keeps exactly the rows where c is FALSE or NULL
        if (isNotTrue(trueValue) && falseIsTrue) {
            return Optional.of(isFalseOrNullPredicate(condition));
        }
        // IF(c, x, x) collapses to x, but only for a deterministic x
        if (falseValue.isPresent() && falseValue.get().equals(trueValue)
                && DeterminismEvaluator.isDeterministic(trueValue, metadata)) {
            return Optional.of(trueValue);
        }
        // both branches reject the row
        if (isNotTrue(trueValue) && falseIsNotTrue) {
            return Optional.of(BooleanLiteral.FALSE_LITERAL);
        }
        // a constant condition selects a single branch
        if (isTrue(condition)) {
            return Optional.of(trueValue);
        }
        if (isNotTrue(condition)) {
            return Optional.of(falseValue.orElse(BooleanLiteral.FALSE_LITERAL));
        }
        return Optional.empty();
    }

    /** Rewrites {@code NULLIF(first, second)} as {@code first AND (second IS NULL OR NOT second)}. */
    private static Expression simplifyNullIf(NullIfExpression nullIfExpression) {
        return LogicalExpression.and(
                nullIfExpression.getFirst(),
                isFalseOrNullPredicate(nullIfExpression.getSecond()));
    }

    /**
     * Simplifies {@code CASE WHEN cond THEN result ... [ELSE default] END}; a missing default is
     * treated as not true.
     */
    private Optional<Expression> simplifySearchedCase(SearchedCaseExpression caseExpression) {
        Optional<Expression> defaultValue = caseExpression.getDefaultValue();
        boolean defaultIsTrue = defaultValue.isPresent() && isTrue(defaultValue.get());
        boolean defaultIsNotTrue = defaultValue.isEmpty() || isNotTrue(defaultValue.get());

        List<Expression> results = caseExpression.getWhenClauses().stream()
                .map(WhenClause::getResult)
                .collect(ImmutableList.toImmutableList());
        long trueResultsCount = results.stream().filter(SimplifyFilterPredicate::isTrue).count();
        long notTrueResultsCount = results.stream().filter(SimplifyFilterPredicate::isNotTrue).count();

        // every branch, including the default, is TRUE
        if (trueResultsCount == results.size() && defaultIsTrue) {
            return Optional.of(BooleanLiteral.TRUE_LITERAL);
        }
        // every branch, including the (possibly implicit) default, is not true
        if (notTrueResultsCount == results.size() && defaultIsNotTrue) {
            return Optional.of(BooleanLiteral.FALSE_LITERAL);
        }
        // exactly one branch is TRUE and all others are not true: keep the row when every earlier
        // condition is FALSE or NULL and that branch's condition holds
        if (trueResultsCount == 1 && notTrueResultsCount == results.size() - 1 && defaultIsNotTrue) {
            ImmutableList.Builder<Expression> conjuncts = ImmutableList.builder();
            for (WhenClause whenClause : caseExpression.getWhenClauses()) {
                Expression operand = whenClause.getOperand();
                if (isNotTrue(whenClause.getResult())) {
                    conjuncts.add(isFalseOrNullPredicate(operand));
                } else {
                    conjuncts.add(operand);
                    return Optional.of(IrUtils.combineConjuncts(metadata, conjuncts.build()));
                }
            }
        }
        // every WHEN branch is not true but the default is TRUE: keep rows matching no condition
        if (notTrueResultsCount == results.size() && defaultIsTrue) {
            List<Expression> conjuncts = caseExpression.getWhenClauses().stream()
                    .map(WhenClause::getOperand)
                    .map(SimplifyFilterPredicate::isFalseOrNullPredicate)
                    .collect(ImmutableList.toImmutableList());
            return Optional.of(IrUtils.combineConjuncts(metadata, conjuncts));
        }
        return pruneSearchedCaseClauses(caseExpression);
    }

    /**
     * Drops WHEN clauses whose condition is not true, and cuts the CASE off at the first clause whose
     * condition is the TRUE literal (that clause's result becomes the default).
     *
     * @return the simplified expression, or empty when no clause could be removed
     */
    private static Optional<Expression> pruneSearchedCaseClauses(SearchedCaseExpression caseExpression) {
        Optional<Expression> defaultValue = caseExpression.getDefaultValue();
        List<WhenClause> keptClauses = new ArrayList<>();
        for (WhenClause whenClause : caseExpression.getWhenClauses()) {
            Expression operand = whenClause.getOperand();
            if (isTrue(operand)) {
                // later clauses and the original default are never consulted
                if (keptClauses.isEmpty()) {
                    return Optional.of(whenClause.getResult());
                }
                return Optional.of(new SearchedCaseExpression(keptClauses, Optional.of(whenClause.getResult())));
            }
            if (!isNotTrue(operand)) {
                keptClauses.add(whenClause);
            }
        }
        if (keptClauses.isEmpty()) {
            return Optional.of(defaultValue.orElse(BooleanLiteral.FALSE_LITERAL));
        }
        if (keptClauses.size() < caseExpression.getWhenClauses().size()) {
            return Optional.of(new SearchedCaseExpression(keptClauses, defaultValue));
        }
        return Optional.empty();
    }

    /**
     * Simplifies {@code CASE operand WHEN value THEN result ... [ELSE default] END}; a missing
     * default is treated as not true.
     */
    private static Optional<Expression> simplifySimpleCase(SimpleCaseExpression caseExpression) {
        Optional<Expression> defaultValue = caseExpression.getDefaultValue();
        if (caseExpression.getOperand() instanceof NullLiteral) {
            // a NULL operand is treated as matching no WHEN value, so only the default applies
            return Optional.of(defaultValue.orElse(BooleanLiteral.FALSE_LITERAL));
        }

        List<Expression> results = caseExpression.getWhenClauses().stream()
                .map(WhenClause::getResult)
                .collect(ImmutableList.toImmutableList());
        if (results.stream().allMatch(SimplifyFilterPredicate::isTrue)
                && defaultValue.isPresent() && isTrue(defaultValue.get())) {
            return Optional.of(BooleanLiteral.TRUE_LITERAL);
        }
        if (results.stream().allMatch(SimplifyFilterPredicate::isNotTrue)
                && (defaultValue.isEmpty() || isNotTrue(defaultValue.get()))) {
            return Optional.of(BooleanLiteral.FALSE_LITERAL);
        }
        return Optional.empty();
    }

    /** Returns whether {@code expression} is the {@code TRUE} literal. */
    private static boolean isTrue(Expression expression) {
        return expression.equals(BooleanLiteral.TRUE_LITERAL);
    }

    /**
     * Returns whether {@code expression} is syntactically known never to be TRUE: the {@code FALSE}
     * literal, a {@code NULL} literal, or a cast of such an expression (the cast's target type is
     * not checked; presumably it is boolean in predicate position).
     */
    private static boolean isNotTrue(Expression expression) {
        return expression.equals(BooleanLiteral.FALSE_LITERAL)
                || expression instanceof NullLiteral
                || (expression instanceof Cast && isNotTrue(((Cast) expression).getExpression()));
    }

    /**
     * Builds {@code expression IS NULL OR NOT expression}, which is TRUE exactly when
     * {@code expression} is FALSE or NULL.
     *
     * <p>NOTE(review): {@code expression} appears twice in the result, and callers do not check it
     * for determinism.
     */
    private static Expression isFalseOrNullPredicate(Expression expression) {
        return LogicalExpression.or(new IsNullPredicate(expression), new NotExpression(expression));
    }
}

