/*
 * Decompiled with CFR 0.152.
 */
package io.trino.plugin.ai.functions;

import com.google.common.collect.ImmutableList;
import com.google.inject.Inject;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.trino.plugin.ai.functions.AiClient;
import io.trino.spi.block.Block;
import io.trino.spi.block.MapValueBuilder;
import io.trino.spi.block.SqlMap;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.FunctionDependencies;
import io.trino.spi.function.FunctionId;
import io.trino.spi.function.FunctionMetadata;
import io.trino.spi.function.FunctionProvider;
import io.trino.spi.function.InvocationConvention;
import io.trino.spi.function.ScalarFunctionAdapter;
import io.trino.spi.function.ScalarFunctionImplementation;
import io.trino.spi.function.Signature;
import io.trino.spi.type.MapType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;
import io.trino.spi.type.VarcharType;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;

public class AiFunctions
implements FunctionProvider {
    private static final TypeSignature TEXT = VarcharType.VARCHAR.getTypeSignature();
    private static final List<FunctionMetadata> FUNCTIONS = ImmutableList.builder().add((Object)AiFunctions.function("ai_analyze_sentiment").description("Perform sentiment analysis on text").signature(AiFunctions.signature(TEXT, TEXT)).build()).add((Object)AiFunctions.function("ai_classify").description("Classify text with the provided labels").signature(AiFunctions.signature(TEXT, TEXT, TypeSignature.arrayType((TypeSignature)TEXT))).build()).add((Object)AiFunctions.function("ai_extract").description("Extract values for the provided labels from text").signature(AiFunctions.signature(TypeSignature.mapType((TypeSignature)TEXT, (TypeSignature)TEXT), TEXT, TypeSignature.arrayType((TypeSignature)TEXT))).build()).add((Object)AiFunctions.function("ai_fix_grammar").description("Correct grammatical errors in text").signature(AiFunctions.signature(TEXT, TEXT)).build()).add((Object)AiFunctions.function("ai_gen").description("Generate text based on a prompt").signature(AiFunctions.signature(TEXT, TEXT)).build()).add((Object)AiFunctions.function("ai_mask").description("Mask values for the provided labels in text").signature(AiFunctions.signature(TEXT, TEXT, TypeSignature.arrayType((TypeSignature)TEXT))).build()).add((Object)AiFunctions.function("ai_translate").description("Translate text to the specified language").signature(AiFunctions.signature(TEXT, TEXT, TEXT)).build()).build();
    private static final MethodHandle AI_ANALYZE_SENTIMENT;
    private static final MethodHandle AI_CLASSIFY;
    private static final MethodHandle AI_EXTRACT;
    private static final MethodHandle AI_FIX_GRAMMAR;
    private static final MethodHandle AI_GEN;
    private static final MethodHandle AI_MASK;
    private static final MethodHandle AI_TRANSLATE;
    private final AiClient client;

    @Inject
    public AiFunctions(AiClient client) {
        this.client = Objects.requireNonNull(client, "client is null");
    }

    public List<FunctionMetadata> getFunctions() {
        return FUNCTIONS;
    }

    public ScalarFunctionImplementation getScalarFunctionImplementation(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies, InvocationConvention invocationConvention) {
        String name;
        MethodHandle handle = switch (name = functionId.toString()) {
            case "ai_analyze_sentiment" -> AI_ANALYZE_SENTIMENT;
            case "ai_classify" -> AI_CLASSIFY;
            case "ai_extract" -> AI_EXTRACT;
            case "ai_fix_grammar" -> AI_FIX_GRAMMAR;
            case "ai_gen" -> AI_GEN;
            case "ai_mask" -> AI_MASK;
            case "ai_translate" -> AI_TRANSLATE;
            default -> throw new IllegalArgumentException("Invalid function ID: " + String.valueOf(functionId));
        };
        handle = handle.bindTo(this);
        if (name.equals("ai_extract")) {
            handle = handle.bindTo(functionDependencies.getType(TypeSignature.mapType((TypeSignature)TEXT, (TypeSignature)TEXT)));
        }
        InvocationConvention actualConvention = new InvocationConvention(Collections.nCopies(boundSignature.getArity(), InvocationConvention.InvocationArgumentConvention.NEVER_NULL), InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL, false, false);
        handle = ScalarFunctionAdapter.adapt((MethodHandle)handle, (Type)boundSignature.getReturnType(), (List)boundSignature.getArgumentTypes(), (InvocationConvention)actualConvention, (InvocationConvention)invocationConvention);
        return ScalarFunctionImplementation.builder().methodHandle(handle).build();
    }

    public Slice aiAnalyzeSentiment(Slice text) {
        return Slices.utf8Slice((String)this.client.analyzeSentiment(text.toStringUtf8()));
    }

    public Slice aiClassify(Slice text, Block labels) {
        return Slices.utf8Slice((String)this.client.classify(text.toStringUtf8(), AiFunctions.fromSqlArray(labels)));
    }

    public SqlMap aiExtract(MapType mapType, Slice text, Block labels) {
        return AiFunctions.toSqlMap(mapType, this.client.extract(text.toStringUtf8(), AiFunctions.fromSqlArray(labels)));
    }

    public Slice aiFixGrammar(Slice text) {
        return Slices.utf8Slice((String)this.client.fixGrammar(text.toStringUtf8()));
    }

    public Slice aiGen(Slice prompt) {
        return Slices.utf8Slice((String)this.client.generate(prompt.toStringUtf8()));
    }

    public Slice aiMask(Slice text, Block labels) {
        return Slices.utf8Slice((String)this.client.mask(text.toStringUtf8(), AiFunctions.fromSqlArray(labels)));
    }

    public Slice aiTranslate(Slice text, Slice language) {
        return Slices.utf8Slice((String)this.client.translate(text.toStringUtf8(), language.toStringUtf8()));
    }

    private static List<String> fromSqlArray(Block block) {
        ArrayList<String> list = new ArrayList<String>();
        for (int i = 0; i < block.getPositionCount(); ++i) {
            list.add(VarcharType.VARCHAR.getSlice(block, i).toStringUtf8());
        }
        return list;
    }

    private static SqlMap toSqlMap(MapType type, Map<String, String> map) {
        return MapValueBuilder.buildMapValue((MapType)type, (int)map.size(), (keyBuilder, valueBuilder) -> map.forEach((key, value) -> {
            VarcharType.VARCHAR.writeSlice(keyBuilder, Slices.utf8Slice((String)key));
            if (value == null) {
                valueBuilder.appendNull();
            } else {
                VarcharType.VARCHAR.writeSlice(valueBuilder, Slices.utf8Slice((String)value));
            }
        }));
    }

    private static FunctionMetadata.Builder function(String name) {
        return FunctionMetadata.scalarBuilder((String)name).functionId(new FunctionId(name)).nondeterministic();
    }

    private static Signature signature(TypeSignature returnType, TypeSignature ... argumentTypes) {
        return Signature.builder().returnType(returnType).argumentTypes(List.of(argumentTypes)).build();
    }

    static {
        try {
            AI_ANALYZE_SENTIMENT = MethodHandles.lookup().findVirtual(AiFunctions.class, "aiAnalyzeSentiment", MethodType.methodType(Slice.class, Slice.class));
            AI_CLASSIFY = MethodHandles.lookup().findVirtual(AiFunctions.class, "aiClassify", MethodType.methodType(Slice.class, Slice.class, Block.class));
            AI_EXTRACT = MethodHandles.lookup().findVirtual(AiFunctions.class, "aiExtract", MethodType.methodType(SqlMap.class, MapType.class, Slice.class, Block.class));
            AI_FIX_GRAMMAR = MethodHandles.lookup().findVirtual(AiFunctions.class, "aiFixGrammar", MethodType.methodType(Slice.class, Slice.class));
            AI_GEN = MethodHandles.lookup().findVirtual(AiFunctions.class, "aiGen", MethodType.methodType(Slice.class, Slice.class));
            AI_MASK = MethodHandles.lookup().findVirtual(AiFunctions.class, "aiMask", MethodType.methodType(Slice.class, Slice.class, Block.class));
            AI_TRANSLATE = MethodHandles.lookup().findVirtual(AiFunctions.class, "aiTranslate", MethodType.methodType(Slice.class, Slice.class, Slice.class));
        }
        catch (ReflectiveOperationException e) {
            throw new AssertionError((Object)e);
        }
    }
}

