/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.routine;

import com.google.common.base.Throwables;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.trino.Session;
import io.trino.execution.warnings.WarningCollector;
import io.trino.metadata.FunctionManager;
import io.trino.metadata.Metadata;
import io.trino.operator.scalar.SpecializedSqlScalarFunction;
import io.trino.security.AccessControl;
import io.trino.security.AllowAllAccessControl;
import io.trino.spi.function.FunctionId;
import io.trino.spi.function.FunctionMetadata;
import io.trino.spi.function.InvocationConvention;
import io.trino.spi.function.ScalarFunctionImplementation;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import io.trino.sql.PlannerContext;
import io.trino.sql.parser.SqlParser;
import io.trino.sql.planner.TestingPlannerContext;
import io.trino.sql.routine.SqlRoutineAnalysis;
import io.trino.sql.routine.SqlRoutineAnalyzer;
import io.trino.sql.routine.SqlRoutineCompiler;
import io.trino.sql.routine.SqlRoutinePlanner;
import io.trino.sql.routine.ir.IrRoutine;
import io.trino.sql.tree.FunctionSpecification;
import io.trino.testing.TestingSession;
import io.trino.testing.TransactionBuilder;
import io.trino.transaction.InMemoryTransactionManager;
import io.trino.transaction.TransactionManager;
import io.trino.type.UnknownType;
import java.lang.invoke.MethodHandle;
import java.util.concurrent.atomic.AtomicLong;
import org.assertj.core.api.Assertions;
import org.assertj.core.api.ThrowingConsumer;
import org.intellij.lang.annotations.Language;
import org.junit.jupiter.api.Test;

/**
 * Tests for SQL routines (inline {@code FUNCTION} definitions).
 *
 * <p>Each test parses a function specification, runs it through
 * {@link SqlRoutineAnalyzer}, {@link SqlRoutinePlanner} and {@link SqlRoutineCompiler},
 * and then invokes the resulting {@link MethodHandle} directly with Java values,
 * asserting on the returned value.
 */
class TestSqlFunctions {
    private static final SqlParser SQL_PARSER = new SqlParser();
    private static final TransactionManager TRANSACTION_MANAGER = InMemoryTransactionManager.createTestTransactionManager();
    private static final PlannerContext PLANNER_CONTEXT = TestingPlannerContext.plannerContextBuilder()
            .withTransactionManager(TRANSACTION_MANAGER)
            .build();
    private static final Session SESSION = TestingSession.testSessionBuilder().build();

    // Source of unique suffixes for the function names generated by testSingleExpression.
    private final AtomicLong nextId = new AtomicLong();

    @Test
    void testConstantReturn() {
        @Language("SQL") String sql = """
                FUNCTION answer()
                RETURNS BIGINT
                RETURN 42
                """;
        assertFunction(sql, handle -> Assertions.assertThat(handle.invoke()).isEqualTo(42L));
    }

    @Test
    void testSimpleReturn() {
        @Language("SQL") String sql = """
                FUNCTION hello(s VARCHAR)
                RETURNS VARCHAR
                RETURN 'Hello, ' || s || '!'
                """;
        assertFunction(sql, handle -> {
            Assertions.assertThat(handle.invoke(Slices.utf8Slice("world"))).isEqualTo(Slices.utf8Slice("Hello, world!"));
            Assertions.assertThat(handle.invoke(Slices.utf8Slice("WORLD"))).isEqualTo(Slices.utf8Slice("Hello, WORLD!"));
        });

        testSingleExpression(VarcharType.VARCHAR, Slices.utf8Slice("foo"), VarcharType.VARCHAR, "Hello, foo!", "'Hello, ' || p || '!'");
    }

    @Test
    void testSimpleExpression() {
        @Language("SQL") String sql = """
                FUNCTION test(a bigint)
                RETURNS bigint
                BEGIN
                  DECLARE x bigint DEFAULT CAST(99 AS bigint);
                  RETURN x * a;
                END
                """;
        assertFunction(sql, handle -> {
            Assertions.assertThat(handle.invoke(0L)).isEqualTo(0L);
            Assertions.assertThat(handle.invoke(1L)).isEqualTo(99L);
            Assertions.assertThat(handle.invoke(42L)).isEqualTo(4158L);
            Assertions.assertThat(handle.invoke(123L)).isEqualTo(12177L);
        });
    }

    @Test
    void testSimpleCase() {
        @Language("SQL") String sql = """
                FUNCTION simple_case(a bigint)
                RETURNS varchar
                BEGIN
                  CASE a
                    WHEN 0 THEN RETURN 'zero';
                    WHEN 1 THEN RETURN 'one';
                    WHEN DECIMAL '10.0' THEN RETURN 'ten';
                    WHEN 20.0E0 THEN RETURN 'twenty';
                    ELSE RETURN 'other';
                  END CASE;
                  RETURN NULL;
                END
                """;
        assertFunction(sql, handle -> {
            Assertions.assertThat(handle.invoke(0L)).isEqualTo(Slices.utf8Slice("zero"));
            Assertions.assertThat(handle.invoke(1L)).isEqualTo(Slices.utf8Slice("one"));
            Assertions.assertThat(handle.invoke(10L)).isEqualTo(Slices.utf8Slice("ten"));
            Assertions.assertThat(handle.invoke(20L)).isEqualTo(Slices.utf8Slice("twenty"));
            Assertions.assertThat(handle.invoke(42L)).isEqualTo(Slices.utf8Slice("other"));
        });
    }

    @Test
    void testSingleIf() {
        @Language("SQL") String sql = """
                FUNCTION test_if(a bigint)
                  RETURNS varchar
                  BEGIN
                    IF a = 0 THEN
                      RETURN 'zero';
                    END IF;
                    RETURN 'other';
                  END
                """;
        assertFunction(sql, handle -> {
            Assertions.assertThat(handle.invoke(0L)).isEqualTo(Slices.utf8Slice("zero"));
            Assertions.assertThat(handle.invoke(1L)).isEqualTo(Slices.utf8Slice("other"));
            Assertions.assertThat(handle.invoke(10L)).isEqualTo(Slices.utf8Slice("other"));
        });
    }

    @Test
    void testSingleBranchIfElse() {
        @Language("SQL") String sql = """
                FUNCTION if_else(a bigint)
                  RETURNS varchar
                  BEGIN
                    IF a = 0 THEN
                      RETURN 'zero';
                    ELSE
                      RETURN 'other';
                    END IF;
                    RETURN NULL;
                  END
                """;
        assertFunction(sql, handle -> {
            Assertions.assertThat(handle.invoke(0L)).isEqualTo(Slices.utf8Slice("zero"));
            Assertions.assertThat(handle.invoke(1L)).isEqualTo(Slices.utf8Slice("other"));
            Assertions.assertThat(handle.invoke(10L)).isEqualTo(Slices.utf8Slice("other"));
        });
    }

    @Test
    void testMultiBranchIfElse() {
        @Language("SQL") String sql = """
                FUNCTION multi_if_else(a bigint)
                  RETURNS varchar
                  BEGIN
                    IF a = 0 THEN
                      RETURN 'zero';
                    ELSEIF a = 1 THEN
                      RETURN 'one';
                    ELSEIF a = 2 THEN
                      RETURN 'two';
                    ELSE
                      RETURN 'other';
                    END IF;
                    RETURN NULL;
                  END
                """;
        assertFunction(sql, handle -> {
            Assertions.assertThat(handle.invoke(0L)).isEqualTo(Slices.utf8Slice("zero"));
            Assertions.assertThat(handle.invoke(1L)).isEqualTo(Slices.utf8Slice("one"));
            Assertions.assertThat(handle.invoke(2L)).isEqualTo(Slices.utf8Slice("two"));
            Assertions.assertThat(handle.invoke(10L)).isEqualTo(Slices.utf8Slice("other"));
        });
    }

    @Test
    void testSearchCase() {
        @Language("SQL") String sql = """
                FUNCTION search_case(a bigint, b bigint)
                RETURNS varchar
                BEGIN
                  CASE
                    WHEN a = 0 THEN RETURN 'zero';
                    WHEN b = 1 THEN RETURN 'one';
                    WHEN a = DECIMAL '10.0' THEN RETURN 'ten';
                    WHEN b = 20.0E0 THEN RETURN 'twenty';
                    ELSE RETURN 'other';
                  END CASE;
                  RETURN NULL;
                END
                """;
        assertFunction(sql, handle -> {
            // Inputs that satisfy exactly one WHEN clause (or none).
            Assertions.assertThat(handle.invoke(0L, 42L)).isEqualTo(Slices.utf8Slice("zero"));
            Assertions.assertThat(handle.invoke(42L, 1L)).isEqualTo(Slices.utf8Slice("one"));
            Assertions.assertThat(handle.invoke(10L, 42L)).isEqualTo(Slices.utf8Slice("ten"));
            Assertions.assertThat(handle.invoke(42L, 20L)).isEqualTo(Slices.utf8Slice("twenty"));
            Assertions.assertThat(handle.invoke(42L, 42L)).isEqualTo(Slices.utf8Slice("other"));

            // Inputs that satisfy several WHEN clauses: the first matching clause must win.
            Assertions.assertThat(handle.invoke(0L, 1L)).isEqualTo(Slices.utf8Slice("zero"));
            Assertions.assertThat(handle.invoke(10L, 1L)).isEqualTo(Slices.utf8Slice("one"));
            Assertions.assertThat(handle.invoke(10L, 20L)).isEqualTo(Slices.utf8Slice("ten"));
            Assertions.assertThat(handle.invoke(42L, 20L)).isEqualTo(Slices.utf8Slice("twenty"));
        });
    }

    @Test
    void testFibonacciWhileLoop() {
        @Language("SQL") String sql = """
                FUNCTION fib(n bigint)
                RETURNS bigint
                BEGIN
                  DECLARE a, b bigint DEFAULT 1;
                  DECLARE c bigint;
                  IF n <= 2 THEN
                    RETURN 1;
                  END IF;
                  WHILE n > 2 DO
                    SET n = n - 1;
                    SET c = a + b;
                    SET a = b;
                    SET b = c;
                  END WHILE;
                  RETURN c;
                END
                """;
        assertFunction(sql, handle -> {
            Assertions.assertThat(handle.invoke(1L)).isEqualTo(1L);
            Assertions.assertThat(handle.invoke(2L)).isEqualTo(1L);
            Assertions.assertThat(handle.invoke(3L)).isEqualTo(2L);
            Assertions.assertThat(handle.invoke(4L)).isEqualTo(3L);
            Assertions.assertThat(handle.invoke(5L)).isEqualTo(5L);
            Assertions.assertThat(handle.invoke(6L)).isEqualTo(8L);
            Assertions.assertThat(handle.invoke(7L)).isEqualTo(13L);
            Assertions.assertThat(handle.invoke(8L)).isEqualTo(21L);
        });
    }

    @Test
    void testBreakContinue() {
        @Language("SQL") String sql = """
                FUNCTION test()
                RETURNS bigint
                BEGIN
                  DECLARE a, b int DEFAULT 0;
                  top: WHILE a < 10 DO
                    SET a = a + 1;
                    IF a < 3 THEN
                      ITERATE top;
                    END IF;
                    SET b = b + 1;
                    IF a > 6 THEN
                      LEAVE top;
                    END IF;
                  END WHILE;
                  RETURN b;
                END
                """;
        // b is incremented for a = 3..7; the loop is left once a exceeds 6.
        assertFunction(sql, handle -> Assertions.assertThat(handle.invoke()).isEqualTo(5L));
    }

    @Test
    void testRepeat() {
        @Language("SQL") String sql = """
                FUNCTION test_repeat(a bigint)
                RETURNS bigint
                BEGIN
                  REPEAT
                    SET a = a + 1;
                  UNTIL a >= 10 END REPEAT;
                  RETURN a;
                END
                """;
        assertFunction(sql, handle -> {
            Assertions.assertThat(handle.invoke(0L)).isEqualTo(10L);
            // The body runs once before UNTIL is evaluated, even when the condition already holds.
            Assertions.assertThat(handle.invoke(100L)).isEqualTo(101L);
        });
    }

    @Test
    void testRepeatContinue() {
        @Language("SQL") String sql = """
                FUNCTION test_repeat_continue()
                RETURNS bigint
                BEGIN
                  DECLARE a int DEFAULT 0;
                  DECLARE b int DEFAULT 0;
                  top: REPEAT
                    SET a = a + 1;
                    IF a <= 3 THEN
                      ITERATE top;
                    END IF;
                    SET b = b + 1;
                  UNTIL a >= 10 END REPEAT;
                  RETURN b;
                END
                """;
        // b is incremented for a = 4..10.
        assertFunction(sql, handle -> Assertions.assertThat(handle.invoke()).isEqualTo(7L));
    }

    @Test
    void testReuseLabels() {
        @Language("SQL") String sql = """
                FUNCTION test()
                RETURNS int
                BEGIN
                  DECLARE r int DEFAULT 0;
                  abc: LOOP
                    SET r = r + 1;
                    LEAVE abc;
                  END LOOP;
                  abc: LOOP
                    SET r = r + 1;
                    LEAVE abc;
                  END LOOP;
                  RETURN r;
                END
                """;
        assertFunction(sql, handle -> Assertions.assertThat(handle.invoke()).isEqualTo(2L));
    }

    @Test
    void testReuseVariables() {
        @Language("SQL") String sql = """
                FUNCTION test()
                RETURNS bigint
                BEGIN
                  DECLARE r bigint DEFAULT 0;
                  BEGIN
                    DECLARE x varchar DEFAULT 'hello';
                    SET r = r + length(x);
                  END;
                  BEGIN
                    DECLARE x array(int) DEFAULT array[1, 2, 3];
                    SET r = r + cardinality(x);
                  END;
                  RETURN r;
                END
                """;
        // length('hello') + cardinality(array[1, 2, 3])
        assertFunction(sql, handle -> Assertions.assertThat(handle.invoke()).isEqualTo(8L));
    }

    @Test
    void testAssignParameter() {
        @Language("SQL") String sql = """
                FUNCTION test(x int)
                RETURNS int
                BEGIN
                  SET x = x * 3;
                  RETURN x;
                END
                """;
        assertFunction(sql, handle -> Assertions.assertThat(handle.invoke(2L)).isEqualTo(6L));
    }

    @Test
    void testCall() {
        testSingleExpression(BigintType.BIGINT, -123L, BigintType.BIGINT, 123L, "abs(p)");
    }

    @Test
    void testCallNested() {
        testSingleExpression(BigintType.BIGINT, -123L, BigintType.BIGINT, 123L, "abs(ceiling(p))");
        testSingleExpression(BigintType.BIGINT, 42L, DoubleType.DOUBLE, 42.0, "to_unixTime(from_unixtime(p))");
    }

    @Test
    void testArray() {
        testSingleExpression(BigintType.BIGINT, 3L, BigintType.BIGINT, 5L, "array[3,4,5,6,7][p]");
        testSingleExpression(BigintType.BIGINT, 0L, BigintType.BIGINT, 0L, "array_sort(array[3,2,4,5,1,p])[1]");
    }

    @Test
    void testRow() {
        testSingleExpression(BigintType.BIGINT, 8L, BigintType.BIGINT, 8L, "ROW(1, 'a', p)[3]");
    }

    @Test
    void testLambda() {
        // This function is declared NOT DETERMINISTIC (last argument is false).
        testSingleExpression(BigintType.BIGINT, 3L, BigintType.BIGINT, 9L, "(transform(ARRAY [5, 6], x -> x + p)[2])", false);
    }

    @Test
    void testTry() {
        testSingleExpression(VarcharType.VARCHAR, Slices.utf8Slice("42"), BigintType.BIGINT, 42L, "try(cast(p AS bigint))");
        testSingleExpression(VarcharType.VARCHAR, Slices.utf8Slice("abc"), BigintType.BIGINT, null, "try(cast(p AS bigint))");
    }

    @Test
    void testTryCast() {
        testSingleExpression(VarcharType.VARCHAR, Slices.utf8Slice("42"), BigintType.BIGINT, 42L, "try_cast(p AS bigint)");
        testSingleExpression(VarcharType.VARCHAR, Slices.utf8Slice("abc"), BigintType.BIGINT, null, "try_cast(p AS bigint)");
    }

    @Test
    void testNonCanonical() {
        testSingleExpression(BigintType.BIGINT, 100000L, BigintType.BIGINT, 1970L, "EXTRACT(YEAR FROM from_unixtime(p))");
    }

    @Test
    void testAtTimeZone() {
        testSingleExpression(
                UnknownType.UNKNOWN,
                null,
                VarcharType.VARCHAR,
                "2012-10-30 18:00:00 America/Los_Angeles",
                "CAST(TIMESTAMP '2012-10-31 01:00 UTC' AT TIME ZONE 'America/Los_Angeles' AS VARCHAR)");
    }

    @Test
    void testSession() {
        testSingleExpression(
                UnknownType.UNKNOWN,
                null,
                DoubleType.DOUBLE,
                Math.floor(SESSION.getStart().toEpochMilli() / 1000.0),
                "floor(to_unixtime(localtimestamp))");
        testSingleExpression(UnknownType.UNKNOWN, null, VarcharType.VARCHAR, SESSION.getUser(), "current_user");
    }

    @Test
    void testSpecialType() {
        testSingleExpression(VarcharType.VARCHAR, Slices.utf8Slice("abc"), BooleanType.BOOLEAN, true, "(p LIKE '%bc')");
        testSingleExpression(VarcharType.VARCHAR, Slices.utf8Slice("xb"), BooleanType.BOOLEAN, false, "(p LIKE '%bc')");
        testSingleExpression(VarcharType.VARCHAR, Slices.utf8Slice("abc"), BooleanType.BOOLEAN, false, "regexp_like(p, '\\d')");
        testSingleExpression(VarcharType.VARCHAR, Slices.utf8Slice("123"), BooleanType.BOOLEAN, true, "regexp_like(p, '\\d')");
        testSingleExpression(VarcharType.VARCHAR, Slices.utf8Slice("[4,5,6]"), VarcharType.VARCHAR, "6", "json_extract_scalar(p, '$[2]')");
    }

    /**
     * Same as {@link #testSingleExpression(Type, Object, Type, Object, String, boolean)}
     * with the function declared {@code DETERMINISTIC}.
     */
    private void testSingleExpression(Type inputType, Object input, Type outputType, Object output, String expression) {
        testSingleExpression(inputType, input, outputType, output, expression, true);
    }

    /**
     * Compiles a one-parameter function whose body is {@code RETURN <expression>} and
     * asserts that invoking it with {@code input} yields {@code output}.
     *
     * @param inputType type of the single parameter, which is named {@code p} in the SQL
     * @param input value passed to the compiled method handle
     * @param outputType declared return type of the function
     * @param output expected result; for varchar return types a {@link Slice} result is
     *        converted to a {@link String} before comparison
     * @param expression SQL expression used as the function body
     * @param deterministic whether the function is declared {@code DETERMINISTIC} or
     *        {@code NOT DETERMINISTIC}
     */
    private void testSingleExpression(Type inputType, Object input, Type outputType, Object output, String expression, boolean deterministic) {
        String functionName = "test" + nextId.incrementAndGet();
        String determinism = deterministic ? "DETERMINISTIC" : "NOT DETERMINISTIC";
        @Language("SQL") String sql = "FUNCTION %s(p %s)\nRETURNS %s\n%s\nRETURN %s".formatted(
                functionName,
                inputType.getTypeSignature(),
                outputType.getTypeSignature(),
                determinism,
                expression);

        assertFunction(sql, handle -> {
            Object result = handle.invoke(input);
            // Callers state varchar expectations as String, so unwrap a Slice result first.
            if (outputType instanceof VarcharType && result instanceof Slice slice) {
                result = slice.toStringUtf8();
            }
            Assertions.assertThat(result).isEqualTo(output);
        });
    }

    /**
     * Compiles {@code sql} inside a single-statement transaction and hands the resulting
     * method handle to {@code consumer}. The handle is already bound to the function
     * instance and to the connector session, so the consumer supplies only the SQL
     * function's own arguments.
     *
     * @param sql function specification to compile
     * @param consumer assertions to run against the bound method handle
     */
    private static void assertFunction(@Language("SQL") String sql, ThrowingConsumer<MethodHandle> consumer) {
        TransactionBuilder.transaction(TRANSACTION_MANAGER, PLANNER_CONTEXT.getMetadata(), new AllowAllAccessControl())
                .singleStatement()
                .execute(SESSION, session -> {
                    ScalarFunctionImplementation implementation = compileFunction(sql, session);
                    MethodHandle handle = implementation.getMethodHandle()
                            .bindTo(getInstance(implementation))
                            .bindTo(session.toConnectorSession());
                    consumer.accept(handle);
                });
    }

    /**
     * Creates the instance the compiled method handle is bound to by invoking the
     * implementation's instance factory.
     *
     * @throws java.util.NoSuchElementException if the implementation has no instance factory
     * @throws RuntimeException wrapping any checked exception thrown by the factory;
     *         unchecked exceptions and errors are rethrown as-is
     */
    private static Object getInstance(ScalarFunctionImplementation implementation) {
        try {
            return implementation.getInstanceFactory().orElseThrow().invoke();
        }
        catch (Throwable t) {
            Throwables.throwIfUnchecked(t);
            throw new RuntimeException(t);
        }
    }

    /**
     * Parses, analyzes, plans and compiles {@code sql} into a scalar function implementation.
     *
     * <p>The invocation convention is derived from the function's nullability metadata:
     * nullable arguments are passed boxed, and a nullable return is permitted only when
     * the metadata declares it.
     */
    private static ScalarFunctionImplementation compileFunction(@Language("SQL") String sql, Session session) {
        FunctionSpecification function = SQL_PARSER.createFunctionSpecification(sql);
        FunctionMetadata metadata = SqlRoutineAnalyzer.extractFunctionMetadata(new FunctionId("test"), function);

        SqlRoutineAnalyzer analyzer = new SqlRoutineAnalyzer(PLANNER_CONTEXT, WarningCollector.NOOP);
        SqlRoutineAnalysis analysis = analyzer.analyze(session, new AllowAllAccessControl(), function);

        SqlRoutinePlanner planner = new SqlRoutinePlanner(PLANNER_CONTEXT, WarningCollector.NOOP);
        IrRoutine routine = planner.planSqlFunction(session, function, analysis);

        SqlRoutineCompiler compiler = new SqlRoutineCompiler(FunctionManager.createTestingFunctionManager());
        SpecializedSqlScalarFunction sqlScalarFunction = compiler.compile(routine);

        InvocationConvention invocationConvention = new InvocationConvention(
                metadata.getFunctionNullability().getArgumentNullable().stream()
                        .map(nullable -> nullable
                                ? InvocationConvention.InvocationArgumentConvention.BOXED_NULLABLE
                                : InvocationConvention.InvocationArgumentConvention.NEVER_NULL)
                        .toList(),
                metadata.getFunctionNullability().isReturnNullable()
                        ? InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN
                        : InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL,
                true,
                true);

        return sqlScalarFunction.getScalarFunctionImplementation(invocationConvention);
    }
}

