/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.sql.ir.Booleans;
import io.trino.sql.ir.Case;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.IrUtils;
import io.trino.sql.ir.IsNull;
import io.trino.sql.ir.Logical;
import io.trino.sql.ir.Not;
import io.trino.sql.ir.NullIf;
import io.trino.sql.ir.Switch;
import io.trino.sql.ir.WhenClause;
import io.trino.sql.planner.DeterminismEvaluator;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.Patterns;
import java.lang.runtime.SwitchBootstraps;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;

/**
 * Optimizer rule that rewrites individual conjuncts of a {@link FilterNode} predicate.
 * <p>
 * A filter keeps only the rows for which its predicate is TRUE, so inside a filter a FALSE
 * result and a NULL result have the same effect ("not true"). This rule relies on that to
 * replace conjuncts that are {@link NullIf}, {@link Case} or {@link Switch} expressions with
 * simpler boolean expressions. All other conjuncts are kept unchanged.
 */
public class SimplifyFilterPredicate
implements Rule<FilterNode> {
    private static final Pattern<FilterNode> PATTERN = Patterns.filter();

    @Override
    public Pattern<FilterNode> getPattern() {
        return PATTERN;
    }

    /**
     * Attempts to simplify every conjunct of the filter predicate independently.
     *
     * @return {@link Rule.Result#empty()} when no conjunct could be simplified; otherwise a new
     *         {@link FilterNode} with the same id and source and the rewritten predicate
     */
    @Override
    public Rule.Result apply(FilterNode node, Captures captures, Rule.Context context) {
        List<Expression> conjuncts = IrUtils.extractConjuncts(node.getPredicate());
        ImmutableList.Builder<Expression> newConjuncts = ImmutableList.builder();
        boolean simplified = false;
        for (Expression conjunct : conjuncts) {
            Optional<Expression> simplifiedConjunct = switch (conjunct) {
                // NULLIF(first, second) can only be TRUE when first is TRUE and second is not TRUE,
                // which is expressed as: first AND (second IS NULL OR NOT second)
                case NullIf nullIf -> Optional.of(Logical.and(nullIf.first(), isFalseOrNullPredicate(nullIf.second())));
                case Case caseExpression -> simplify(caseExpression);
                case Switch switchExpression -> simplify(switchExpression);
                default -> Optional.empty();
            };
            if (simplifiedConjunct.isPresent()) {
                simplified = true;
                newConjuncts.add(simplifiedConjunct.get());
            } else {
                newConjuncts.add(conjunct);
            }
        }
        if (!simplified) {
            return Rule.Result.empty();
        }
        Expression predicate = IrUtils.combineConjuncts(newConjuncts.build());
        // A NULL constant predicate rejects every row, exactly like FALSE; normalize it.
        if (predicate instanceof Constant constant && constant.value() == null) {
            predicate = Booleans.FALSE;
        }
        return Rule.Result.ofPlanNode(new FilterNode(node.getId(), node.getSource(), predicate));
    }

    /**
     * Simplifies an if-like expression {@code IF condition THEN trueValue ELSE falseValue}
     * used as a filter conjunct.
     *
     * @return the simplified expression, or empty if none of the rules apply
     */
    private static Optional<Expression> simplify(Expression condition, Expression trueValue, Expression falseValue) {
        if (trueValue.equals(Booleans.TRUE) && isNotTrue(falseValue)) {
            return Optional.of(condition);
        }
        if (isNotTrue(trueValue) && falseValue.equals(Booleans.TRUE)) {
            return Optional.of(isFalseOrNullPredicate(condition));
        }
        // Both branches are the same deterministic value, so the condition does not matter.
        if (falseValue.equals(trueValue) && DeterminismEvaluator.isDeterministic(trueValue)) {
            return Optional.of(trueValue);
        }
        if (isNotTrue(trueValue) && isNotTrue(falseValue)) {
            return Optional.of(Booleans.FALSE);
        }
        if (condition.equals(Booleans.TRUE)) {
            return Optional.of(trueValue);
        }
        if (isNotTrue(condition)) {
            return Optional.of(falseValue);
        }
        return Optional.empty();
    }

    /**
     * Simplifies a searched CASE expression ({@code CASE WHEN cond THEN result ... ELSE default END})
     * used as a filter conjunct.
     *
     * @return the simplified expression, or empty if none of the rules apply
     */
    private static Optional<Expression> simplify(Case caseExpression) {
        List<WhenClause> clauses = caseExpression.whenClauses();
        Expression defaultValue = caseExpression.defaultValue();

        // A single WHEN clause is just an IF.
        if (clauses.size() == 1) {
            WhenClause onlyClause = clauses.getFirst();
            return simplify(onlyClause.getOperand(), onlyClause.getResult(), defaultValue);
        }

        List<Expression> operands = clauses.stream()
                .map(WhenClause::getOperand)
                .collect(ImmutableList.toImmutableList());
        List<Expression> results = clauses.stream()
                .map(WhenClause::getResult)
                .collect(ImmutableList.toImmutableList());
        long trueResultsCount = results.stream()
                .filter(result -> result.equals(Booleans.TRUE))
                .count();
        long notTrueResultsCount = results.stream()
                .filter(SimplifyFilterPredicate::isNotTrue)
                .count();

        // All results (including the default) are TRUE.
        if (trueResultsCount == results.size() && defaultValue.equals(Booleans.TRUE)) {
            return Optional.of(Booleans.TRUE);
        }
        // All results (including the default) are not TRUE.
        if (notTrueResultsCount == results.size() && isNotTrue(defaultValue)) {
            return Optional.of(Booleans.FALSE);
        }
        // Exactly one result is TRUE, all other results and the default are not TRUE:
        // the row passes iff every earlier condition is not TRUE and the TRUE clause's condition is TRUE.
        if (trueResultsCount == 1 && notTrueResultsCount == results.size() - 1 && isNotTrue(defaultValue)) {
            ImmutableList.Builder<Expression> builder = ImmutableList.builder();
            for (WhenClause whenClause : clauses) {
                Expression operand = whenClause.getOperand();
                if (isNotTrue(whenClause.getResult())) {
                    builder.add(isFalseOrNullPredicate(operand));
                } else {
                    builder.add(operand);
                    return Optional.of(IrUtils.combineConjuncts(builder.build()));
                }
            }
        }
        // All results are not TRUE and the default is TRUE: the row passes iff no condition is TRUE.
        if (notTrueResultsCount == results.size() && defaultValue.equals(Booleans.TRUE)) {
            List<Expression> negatedOperands = operands.stream()
                    .map(SimplifyFilterPredicate::isFalseOrNullPredicate)
                    .collect(ImmutableList.toImmutableList());
            return Optional.of(IrUtils.combineConjuncts(negatedOperands));
        }

        // Drop clauses whose condition can never be TRUE; a literally TRUE condition ends the CASE.
        List<WhenClause> whenClauses = new ArrayList<>();
        for (WhenClause whenClause : clauses) {
            Expression operand = whenClause.getOperand();
            if (operand.equals(Booleans.TRUE)) {
                if (whenClauses.isEmpty()) {
                    return Optional.of(whenClause.getResult());
                }
                return Optional.of(new Case(whenClauses, whenClause.getResult()));
            }
            if (!isNotTrue(operand)) {
                whenClauses.add(whenClause);
            }
        }
        if (whenClauses.isEmpty()) {
            return Optional.of(defaultValue);
        }
        if (whenClauses.size() < clauses.size()) {
            return Optional.of(new Case(whenClauses, defaultValue));
        }
        return Optional.empty();
    }

    /**
     * Simplifies a {@link Switch} expression (a CASE with an operand compared against each WHEN value)
     * used as a filter conjunct.
     *
     * @return the simplified expression, or empty if none of the rules apply
     */
    private static Optional<Expression> simplify(Switch caseExpression) {
        Expression defaultValue = caseExpression.defaultValue();

        // A NULL operand matches no WHEN clause, so only the default can be produced.
        if (caseExpression.operand() instanceof Constant literal && literal.value() == null) {
            return Optional.of(defaultValue);
        }

        List<Expression> results = caseExpression.whenClauses().stream()
                .map(WhenClause::getResult)
                .collect(ImmutableList.toImmutableList());
        // All results (including the default) are TRUE.
        if (results.stream().allMatch(result -> result.equals(Booleans.TRUE)) && defaultValue.equals(Booleans.TRUE)) {
            return Optional.of(Booleans.TRUE);
        }
        // All results (including the default) are not TRUE.
        if (results.stream().allMatch(SimplifyFilterPredicate::isNotTrue) && isNotTrue(defaultValue)) {
            return Optional.of(Booleans.FALSE);
        }
        return Optional.empty();
    }

    /**
     * Returns {@code true} if the expression is the FALSE literal or a NULL constant, i.e. a value
     * that makes a filter reject the row.
     */
    private static boolean isNotTrue(Expression expression) {
        return expression.equals(Booleans.FALSE)
                || (expression instanceof Constant literal && literal.value() == null);
    }

    /**
     * Builds {@code expression IS NULL OR NOT expression}, which is TRUE exactly when
     * {@code expression} is not TRUE.
     */
    private static Expression isFalseOrNullPredicate(Expression expression) {
        return Logical.or(new IsNull(expression), new Not(expression));
    }
}

