/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.ir.optimizer.rule;

import com.google.common.collect.ImmutableList;
import io.trino.Session;
import io.trino.metadata.Metadata;
import io.trino.sql.PlannerContext;
import io.trino.sql.ir.Booleans;
import io.trino.sql.ir.Case;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.IrExpressions;
import io.trino.sql.ir.IrUtils;
import io.trino.sql.ir.WhenClause;
import io.trino.sql.ir.optimizer.IrOptimizerRule;
import io.trino.sql.planner.DeterminismEvaluator;
import io.trino.sql.planner.Symbol;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;

public class SimplifyRedundantCase
implements IrOptimizerRule {
    private final Metadata metadata;

    public SimplifyRedundantCase(PlannerContext context) {
        this.metadata = context.getMetadata();
    }

    @Override
    public Optional<Expression> apply(Expression expression, Session session, Map<Symbol, Expression> bindings) {
        if (!(expression instanceof Case)) {
            return Optional.empty();
        }
        Case caseTerm = (Case)expression;
        Expression defaultValue = caseTerm.defaultValue();
        if (!caseTerm.whenClauses().stream().map(WhenClause::getResult).allMatch(result -> result.equals(Booleans.TRUE) || result.equals(Booleans.FALSE)) || !defaultValue.equals(Booleans.TRUE) && !defaultValue.equals(Booleans.FALSE) || caseTerm.whenClauses().stream().map(WhenClause::getOperand).anyMatch(e -> !DeterminismEvaluator.isDeterministic(e))) {
            return Optional.empty();
        }
        return this.transformRecursive(0, caseTerm.whenClauses(), defaultValue).or(() -> Optional.of(Booleans.FALSE));
    }

    private Optional<Expression> transformRecursive(int start, List<WhenClause> clauses, Expression defaultExpression) {
        int end;
        for (end = start; end < clauses.size() && clauses.get(end).getResult().equals(Booleans.FALSE); ++end) {
        }
        List<Expression> falseTerms = clauses.subList(start, end).stream().map(clause -> IrExpressions.not(this.metadata, new Comparison(Comparison.Operator.IDENTICAL, clause.getOperand(), Booleans.TRUE))).toList();
        if (end < clauses.size()) {
            ArrayList<Expression> terms = new ArrayList<Expression>();
            terms.add(new Comparison(Comparison.Operator.IDENTICAL, clauses.get(end).getOperand(), Booleans.TRUE));
            this.transformRecursive(end + 1, clauses, defaultExpression).ifPresent(terms::add);
            return Optional.of(IrUtils.and((Collection<Expression>)ImmutableList.builder().addAll(falseTerms).add((Object)IrUtils.or(terms)).build()));
        }
        if (defaultExpression.equals(Booleans.TRUE)) {
            return Optional.of(IrUtils.and(falseTerms));
        }
        return Optional.empty();
    }
}

