/*
 * Decompiled with CFR 0.152.
 */
package io.trino.plugin.ai.functions;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.Maps;
import com.google.common.util.concurrent.UncheckedExecutionException;
import io.airlift.json.JsonCodec;
import io.trino.cache.SafeCaches;
import io.trino.plugin.ai.functions.AiClient;
import io.trino.plugin.ai.functions.AiConfig;
import io.trino.plugin.ai.functions.AiErrorCode;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.TrinoException;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ExecutionException;

/**
 * Base {@link AiClient} that turns each AI function into a text prompt and delegates the actual
 * model call to {@link #generateCompletion(String, String)}.
 *
 * <p>Completions are memoized per (model, prompt) pair in a cache bounded to
 * {@value #COMPLETION_CACHE_SIZE} entries, so repeated identical requests are sent to the model
 * only once while the entry remains cached.
 */
public abstract class AbstractAiClient
implements AiClient {
    protected static final JsonCodec<List<String>> LIST_CODEC = JsonCodec.listJsonCodec(String.class);
    protected static final JsonCodec<Map<String, String>> MAP_CODEC = JsonCodec.mapJsonCodec(String.class, String.class);
    protected static final JsonCodec<String> STRING_CODEC = JsonCodec.jsonCodec(String.class);

    /** Maximum number of memoized completions. */
    private static final long COMPLETION_CACHE_SIZE = 1000;

    // Per-function model names; each falls back to the general model from AiConfig when unset.
    protected final String analyzeSentimentModel;
    protected final String classifyModel;
    protected final String extractModel;
    protected final String fixGrammarModel;
    protected final String generateModel;
    protected final String maskModel;
    protected final String translateModel;

    // Keyed by model + NUL + prompt (see completion()).
    private final Cache<String, String> completionCache =
            SafeCaches.buildNonEvictableCache(CacheBuilder.newBuilder().maximumSize(COMPLETION_CACHE_SIZE));

    /**
     * Resolves the model used for each AI function.
     *
     * @param config plugin configuration; a function-specific model that is {@code null} falls back
     *        to {@code config.getModel()}
     * @throws NullPointerException if both a function-specific model and the general model are
     *         {@code null} (behavior of {@link Objects#requireNonNullElse})
     */
    protected AbstractAiClient(AiConfig config) {
        String defaultModel = config.getModel();
        this.analyzeSentimentModel = Objects.requireNonNullElse(config.getAnalyzeSentimentModel(), defaultModel);
        this.classifyModel = Objects.requireNonNullElse(config.getClassifyModel(), defaultModel);
        this.extractModel = Objects.requireNonNullElse(config.getExtractModel(), defaultModel);
        this.fixGrammarModel = Objects.requireNonNullElse(config.getFixGrammarModel(), defaultModel);
        this.generateModel = Objects.requireNonNullElse(config.getGenerateModel(), defaultModel);
        this.maskModel = Objects.requireNonNullElse(config.getMaskModel(), defaultModel);
        this.translateModel = Objects.requireNonNullElse(config.getTranslateModel(), defaultModel);
    }

    /**
     * Asks the model to label the sentiment of {@code text}.
     *
     * @return the model's answer lower-cased with {@link Locale#ROOT}; the prompt requests one of
     *         positive, negative, neutral or mixed, but the answer is not validated here
     */
    @Override
    public String analyzeSentiment(String text) {
        String prompt = "Classify the text below into one of the following labels: [positive, negative, neutral, mixed]\nOutput only the label.\n=====\n%s\n".formatted(text);
        String response = completion(analyzeSentimentModel, prompt);
        return response.toLowerCase(Locale.ROOT);
    }

    /**
     * Asks the model to pick one of {@code labels} for {@code text}.
     *
     * @return the label decoded from the model's JSON-string answer; not checked against
     *         {@code labels}
     * @throws TrinoException with {@code AI_ERROR} if the answer is not a valid JSON string
     */
    @Override
    public String classify(String text, List<String> labels) {
        String prompt = "Classify the text below into one of the following JSON encoded labels: %s\nOutput the label as a JSON string (not a JSON object).\nOutput only the label.\n=====\n%s\n".formatted(LIST_CODEC.toJson(labels), text);
        String response = completion(classifyModel, prompt);
        try {
            return STRING_CODEC.fromJson(response);
        }
        catch (IllegalArgumentException e) {
            throw new TrinoException(AiErrorCode.AI_ERROR, "Failed to parse AI response", e);
        }
    }

    /**
     * Asks the model to extract one value per label from {@code text}.
     *
     * @return a view of the decoded JSON object with {@code null} values removed
     * @throws TrinoException with {@code AI_ERROR} if the answer is not a JSON object of strings,
     *         including the bare JSON literal {@code null}
     */
    @Override
    public Map<String, String> extract(String text, List<String> labels) {
        String prompt = "Extract a value for each of the JSON encoded labels from the text below.\nFor each label, only extract a single value.\nLabels: %s\nOutput the extracted values as a JSON object.\nOutput only the JSON.\nDo not output a code block for the JSON.\n=====\n%s\n".formatted(LIST_CODEC.toJson(labels), text);
        String response = completion(extractModel, prompt);
        Map<String, String> values;
        try {
            values = MAP_CODEC.fromJson(response);
        }
        catch (IllegalArgumentException e) {
            throw new TrinoException(AiErrorCode.AI_ERROR, "Failed to parse AI response", e);
        }
        if (values == null) {
            // A bare JSON `null` decodes to null; Maps.filterValues would otherwise throw an NPE.
            throw new TrinoException(AiErrorCode.AI_ERROR, "Failed to parse AI response");
        }
        return Maps.filterValues(values, Objects::nonNull);
    }

    /** Asks the model to correct the grammar of {@code text} and returns its raw answer. */
    @Override
    public String fixGrammar(String text) {
        String prompt = "Fix the grammar in the text below.\nOutput only the text.\n=====\n%s\n".formatted(text);
        return completion(fixGrammarModel, prompt);
    }

    /** Sends {@code prompt} to the generate model verbatim and returns its raw answer. */
    @Override
    public String generate(String prompt) {
        return completion(generateModel, prompt);
    }

    /**
     * Asks the model to replace the values for each of {@code labels} in {@code text} with
     * {@code "[MASKED]"} and returns its raw answer.
     */
    @Override
    public String mask(String text, List<String> labels) {
        String prompt = "Mask the values for each of the JSON encoded labels in the text below.\nLabels: %s\nReplace the values with the text \"[MASKED]\".\nOutput only the masked text.\nDo not output anything else.\n=====\n%s\n".formatted(LIST_CODEC.toJson(labels), text);
        return completion(maskModel, prompt);
    }

    /**
     * Asks the model to translate {@code text} into {@code language} (passed to the model
     * JSON-encoded) and returns its raw answer.
     */
    @Override
    public String translate(String text, String language) {
        String prompt = "Translate the text below to the language specified.\nThe language is encoded as a JSON string.\nOutput only the translated text.\nLanguage: %s\n=====\n%s\n".formatted(STRING_CODEC.toJson(language), text);
        return completion(translateModel, prompt);
    }

    /**
     * Returns the cached completion for (model, prompt), calling {@link #generateCompletion} on a
     * cache miss.
     *
     * @throws TrinoException if {@code generateCompletion} threw one (unwrapped from the cache's
     *         wrapper exception)
     * @throws UncheckedExecutionException for any other exception raised while loading
     */
    private String completion(String model, String prompt) {
        // NUL separator keeps "ab"+"c" and "a"+"bc" distinct (assumes model names contain no NUL).
        String key = model + "\u0000" + prompt;
        try {
            return completionCache.get(key, () -> generateCompletion(model, prompt));
        }
        catch (ExecutionException e) {
            // Only reachable if a checked exception escapes generateCompletion, which declares none.
            throw new UncheckedExecutionException(e);
        }
        catch (UncheckedExecutionException e) {
            // Surface Trino errors from the subclass as-is instead of the cache's wrapper.
            if (e.getCause() instanceof TrinoException trinoException) {
                throw trinoException;
            }
            throw e;
        }
    }

    /**
     * Calls the backing model.
     *
     * @param model model name resolved for the calling function
     * @param prompt complete prompt text to send
     * @return the model's completion text
     */
    protected abstract String generateCompletion(String model, String prompt);
}

