/*
 * Decompiled with CFR 0.152.
 */
package io.trino.plugin.jdbc.expression;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.expression.ConnectorExpressionPatterns;
import io.trino.plugin.base.expression.ConnectorExpressionRule;
import io.trino.plugin.jdbc.CaseSensitivity;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.QueryParameter;
import io.trino.plugin.jdbc.expression.ComparisonOperator;
import io.trino.plugin.jdbc.expression.ParameterizedExpression;
import io.trino.spi.expression.Call;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.Variable;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.VarcharType;
import java.util.List;
import java.util.Optional;
import java.util.Set;

/**
 * Connector expression rule that rewrites a binary comparison between two {@code varchar}
 * columns into remote SQL of the form {@code (left) OP (right)}.
 *
 * <p>The rule matches a {@code boolean}-typed {@link Call} with exactly two arguments when:
 * <ul>
 *   <li>the call's function name belongs to one of the enabled {@link ComparisonOperator}s;</li>
 *   <li>both arguments are {@link Variable}s whose type is a {@link VarcharType}.</li>
 * </ul>
 * A matched call is only rewritten when both columns' JDBC type handles report
 * {@link CaseSensitivity#CASE_SENSITIVE}; otherwise no rewrite is produced.
 */
public class RewriteCaseSensitiveComparison
implements ConnectorExpressionRule<Call, ParameterizedExpression> {
    private static final Capture<Variable> LEFT = Capture.newCapture();
    private static final Capture<Variable> RIGHT = Capture.newCapture();
    private final Pattern<Call> pattern;

    /**
     * Creates the rule for a fixed set of comparison operators.
     *
     * @param enabledOperators operators this rule may rewrite; calls whose function name does
     *     not correspond to one of these operators are not matched by {@link #getPattern()}
     */
    public RewriteCaseSensitiveComparison(Set<ComparisonOperator> enabledOperators) {
        // Let the compiler infer the element type instead of the raw Set the decompiler emitted.
        var functionNames = enabledOperators.stream()
                .map(ComparisonOperator::getFunctionName)
                .collect(ImmutableSet.toImmutableSet());
        this.pattern = ConnectorExpressionPatterns.call()
                .with(ConnectorExpressionPatterns.type().equalTo(BooleanType.BOOLEAN))
                .with(ConnectorExpressionPatterns.functionName().matching(functionNames::contains))
                .with(ConnectorExpressionPatterns.argumentCount().equalTo(2))
                .with(ConnectorExpressionPatterns.argument(0).matching(
                        ConnectorExpressionPatterns.variable()
                                .with(ConnectorExpressionPatterns.type().matching(VarcharType.class::isInstance))
                                .capturedAs(LEFT)))
                .with(ConnectorExpressionPatterns.argument(1).matching(
                        ConnectorExpressionPatterns.variable()
                                .with(ConnectorExpressionPatterns.type().matching(VarcharType.class::isInstance))
                                .capturedAs(RIGHT)));
    }

    /**
     * Returns the pattern built in the constructor: a boolean, two-argument call to an enabled
     * comparison function whose arguments are both varchar variables.
     */
    @Override
    public Pattern<Call> getPattern() {
        return this.pattern;
    }

    /**
     * Rewrites a matched comparison into a {@link ParameterizedExpression}.
     *
     * @param expression the matched comparison call
     * @param captures captures holding the left and right varchar variables
     * @param context rewrite context used to resolve column assignments and to rewrite each operand
     * @return the rewritten {@code (left) OP (right)} expression, or empty when either column is
     *     not known to be case-sensitive or when either operand cannot itself be rewritten
     */
    @Override
    public Optional<ParameterizedExpression> rewrite(Call expression, Captures captures, ConnectorExpressionRule.RewriteContext<ParameterizedExpression> context) {
        ComparisonOperator comparison = ComparisonOperator.forFunctionName(expression.getFunctionName());
        Variable firstArgument = captures.get(LEFT);
        Variable secondArgument = captures.get(RIGHT);
        // NOTE(review): presumably refused because a case-insensitive (or unknown) remote collation
        // could order/compare strings differently than Trino's varchar semantics — confirm.
        if (!isCaseSensitive(firstArgument, context) || !isCaseSensitive(secondArgument, context)) {
            return Optional.empty();
        }
        return context.defaultRewrite(firstArgument).flatMap(first ->
                context.defaultRewrite(secondArgument).map(second -> new ParameterizedExpression(
                        "(%s) %s (%s)".formatted(first.expression(), comparison.getOperator(), second.expression()),
                        // Parameters follow the operand order in the generated SQL text: left first, then right
                        // (presumably bound positionally — confirm against the query builder).
                        ImmutableList.<QueryParameter>builder()
                                .addAll(first.parameters())
                                .addAll(second.parameters())
                                .build())));
    }

    /**
     * Reports whether the column assigned to {@code variable} is declared case-sensitive.
     *
     * <p>An absent case-sensitivity on the type handle is treated as "not case-sensitive".
     * The assignment is cast to {@link JdbcColumnHandle}, so any other handle type fails with
     * a {@link ClassCastException}.
     */
    private static boolean isCaseSensitive(Variable variable, ConnectorExpressionRule.RewriteContext<?> context) {
        JdbcColumnHandle column = (JdbcColumnHandle) context.getAssignment(variable.getName());
        return column.getJdbcTypeHandle().caseSensitivity().equals(Optional.of(CaseSensitivity.CASE_SENSITIVE));
    }
}

