/*
 * Decompiled with CFR 0.152.
 */
package io.trino.plugin.jdbc.expression;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.MoreCollectors;
import io.trino.matching.Match;
import io.trino.plugin.base.expression.ConnectorExpressionRule;
import io.trino.plugin.jdbc.expression.GenericRewrite;
import io.trino.plugin.jdbc.expression.ParameterizedExpression;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.expression.Call;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.FunctionName;
import io.trino.spi.expression.Variable;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.Type;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

public class TestGenericRewrite {
    @Test
    public void testRewriteCall() {
        GenericRewrite rewrite = new GenericRewrite(Map.of(), session -> true, "add(foo: decimal(p, s), bar: bigint): decimal(rp, rs)", "foo + bar::decimal(rp,rs)");
        Call expression = new Call((Type)DecimalType.createDecimalType((int)21, (int)2), new FunctionName("add"), List.of(new Variable("first", (Type)DecimalType.createDecimalType((int)10, (int)2)), new Variable("second", (Type)BigintType.BIGINT)));
        ParameterizedExpression rewritten = TestGenericRewrite.apply(rewrite, (ConnectorExpression)expression).orElseThrow();
        Assertions.assertThat((String)rewritten.expression()).isEqualTo("(\"first\") + (\"second\")::decimal(21,2)");
        Assertions.assertThat((List)rewritten.parameters()).isEqualTo(List.of());
    }

    @Test
    public void testRewriteCallWithTypeClass() {
        Map<String, Set<String>> typeClasses = Map.of("integer_class", Set.of("integer", "bigint"));
        GenericRewrite rewrite = new GenericRewrite(typeClasses, session -> true, "add(foo: integer_class, bar: bigint): integer_class", "foo + bar");
        Assertions.assertThat((String)TestGenericRewrite.apply(rewrite, (ConnectorExpression)new Call((Type)BigintType.BIGINT, new FunctionName("add"), List.of(new Variable("first", (Type)IntegerType.INTEGER), new Variable("second", (Type)BigintType.BIGINT)))).orElseThrow().expression()).isEqualTo("(\"first\") + (\"second\")");
        Assertions.assertThat(TestGenericRewrite.apply(rewrite, (ConnectorExpression)new Call((Type)BigintType.BIGINT, new FunctionName("add"), List.of(new Variable("first", (Type)DoubleType.DOUBLE), new Variable("second", (Type)BigintType.BIGINT))))).isEmpty();
        Assertions.assertThat(TestGenericRewrite.apply(rewrite, (ConnectorExpression)new Call((Type)DoubleType.DOUBLE, new FunctionName("add"), List.of(new Variable("first", (Type)IntegerType.INTEGER), new Variable("second", (Type)BigintType.BIGINT))))).isEmpty();
    }

    private static Optional<ParameterizedExpression> apply(GenericRewrite rewrite, ConnectorExpression expression) {
        Optional match = (Optional)rewrite.getPattern().match((Object)expression).collect(MoreCollectors.toOptional());
        if (match.isEmpty()) {
            return Optional.empty();
        }
        return rewrite.rewrite(expression, ((Match)match.get()).captures(), (ConnectorExpressionRule.RewriteContext)new ConnectorExpressionRule.RewriteContext<ParameterizedExpression>(){

            public Map<String, ColumnHandle> getAssignments() {
                throw new UnsupportedOperationException();
            }

            public ConnectorSession getSession() {
                throw new UnsupportedOperationException();
            }

            public Optional<ParameterizedExpression> defaultRewrite(ConnectorExpression expression) {
                if (expression instanceof Variable) {
                    Variable variable = (Variable)expression;
                    return Optional.of(new ParameterizedExpression("\"" + variable.getName().replace("\"", "\"\"") + "\"", (List)ImmutableList.of()));
                }
                return Optional.empty();
            }
        });
    }
}

