/*
 * Decompiled with CFR 0.152.
 */
package ai.optfor.springopenaiapi;

import ai.optfor.springopenaiapi.cache.DefaultPromptCache;
import ai.optfor.springopenaiapi.cache.PromptCache;
import ai.optfor.springopenaiapi.enums.EmbedModel;
import ai.optfor.springopenaiapi.enums.LLMModel;
import ai.optfor.springopenaiapi.enums.TTSModel;
import ai.optfor.springopenaiapi.enums.TTSVoice;
import ai.optfor.springopenaiapi.model.AudioResponse;
import ai.optfor.springopenaiapi.model.ChatCompletionRequest;
import ai.optfor.springopenaiapi.model.ChatCompletionResponse;
import ai.optfor.springopenaiapi.model.ChatMessage;
import ai.optfor.springopenaiapi.model.EmbeddingRequest;
import ai.optfor.springopenaiapi.model.EmbeddingResponse;
import ai.optfor.springopenaiapi.model.ResponseFormat;
import ai.optfor.springopenaiapi.model.STTRequest;
import ai.optfor.springopenaiapi.model.VisionCompletionRequest;
import ai.optfor.springopenaiapi.model.VisionMessage;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import io.micrometer.common.util.StringUtils;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpRequestInterceptor;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Flux;

/**
 * Client for the OpenAI HTTP API covering chat completions (blocking and streamed),
 * vision completions, text-to-speech, audio transcription and embeddings.
 *
 * <p>Blocking chat calls are executed on an internal fixed-size thread pool and are
 * retried a bounded number of times. Responses to chat requests issued with a
 * temperature of exactly {@code 0.0} are stored in, and served from, the configured
 * {@link PromptCache}.
 *
 * <p>Call {@link #close()} when the instance is no longer needed so the internal
 * thread pool is released.
 */
public class OpenAIApi implements AutoCloseable {
    private static final Logger log = LoggerFactory.getLogger(OpenAIApi.class);

    private static final String CHAT_COMPLETIONS_URL = "https://api.openai.com/v1/chat/completions";
    private static final String SPEECH_URL = "https://api.openai.com/v1/audio/speech";
    private static final String TRANSCRIPTIONS_URL = "https://api.openai.com/v1/audio/transcriptions";
    private static final String EMBEDDINGS_URL = "https://api.openai.com/v1/embeddings";

    /** Payload of the server-sent event that terminates a streamed completion. */
    private static final String STREAM_END_MARKER = "[DONE]";
    /** Total number of attempts (first call included) made by a blocking chat request. */
    private static final int MAX_CHAT_ATTEMPTS = 3;
    /** Number of worker threads that execute blocking chat requests. */
    private static final int CHAT_POOL_SIZE = 3;
    private static final int CONNECT_TIMEOUT_MS = 5000;
    private static final int READ_TIMEOUT_MS = 600000;

    private final ObjectMapper mapper = new ObjectMapper();
    private final PromptCache promptCache;
    private final ExecutorService executorService;

    /**
     * Creates a client.
     *
     * @param promptCache cache for deterministic (temperature {@code 0.0}) chat responses;
     *                    when {@code null} a {@link DefaultPromptCache} is used
     */
    public OpenAIApi(PromptCache promptCache) {
        this.promptCache = promptCache == null ? new DefaultPromptCache() : promptCache;
        this.mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
        this.mapper.enable(SerializationFeature.INDENT_OUTPUT);
        this.executorService = Executors.newFixedThreadPool(CHAT_POOL_SIZE);
    }

    /**
     * Streams a chat completion for a system / user / assistant message triple.
     *
     * @return a flux emitting the text delta of every streamed chunk
     */
    public Flux<String> streamingChat(LLMModel model, String system, String user, String assistant, Integer maxTokens, double temperature, String openaiKey) {
        List<ChatMessage> messages = List.of(
                ChatMessage.systemMessage(system),
                ChatMessage.userMessage(user),
                ChatMessage.assistantMessage(assistant));
        return this.streamingChat(model, messages, maxTokens, temperature, openaiKey);
    }

    /**
     * Streams a chat completion for the given conversation.
     *
     * @return a flux emitting the text delta of every streamed chunk; chunks without a
     *         delta are emitted as the empty string
     * @throws RuntimeException if the request cannot be serialised to JSON
     */
    public Flux<String> streamingChat(LLMModel model, List<ChatMessage> messages, Integer maxTokens, double temperature, String openaiKey) {
        ChatCompletionRequest request = new ChatCompletionRequest(model.getApiName(), messages, temperature, maxTokens, null, true);
        return this.streamDeltas(this.toJson(request), openaiKey);
    }

    /**
     * Performs a blocking vision completion using {@link LLMModel#GPT_4_VISION_PREVIEW}.
     *
     * @return the deserialised completion response
     */
    public ChatCompletionResponse vision(List<VisionMessage> messages, Integer maxTokens, double temperature, String openaiKey) {
        VisionCompletionRequest request = new VisionCompletionRequest(LLMModel.GPT_4_VISION_PREVIEW.getApiName(), messages, temperature, maxTokens, false);
        return this.prepareRestTemplate(openaiKey).postForObject(CHAT_COMPLETIONS_URL, request, ChatCompletionResponse.class);
    }

    /**
     * Streams a vision completion and, once the stream completes, logs the elapsed time
     * together with the concatenated response text.
     *
     * @return a flux emitting the text delta of every streamed chunk
     * @throws RuntimeException if the request cannot be serialised to JSON
     */
    public Flux<String> visionStreaming(LLMModel model, List<VisionMessage> messages, Integer maxTokens, double temperature, String openaiKey) {
        String conversation = messages.stream()
                .map(message -> message.role() + ":\n" + message.content())
                .collect(Collectors.joining("\n"));
        log.info("\nCalling OpenAI API:\nModel: {} Max tokens:{} Temperature:{}\n{}", model, maxTokens, temperature, conversation);

        VisionCompletionRequest request = new VisionCompletionRequest(model.getApiName(), messages, temperature, maxTokens, true);
        String json = this.toJson(request);

        // defer() gives every subscription its own transcript buffer and start time;
        // sharing them would merge the text of separate subscriptions and skew the timing.
        return Flux.defer(() -> {
            StringBuilder fullResponse = new StringBuilder();
            long start = System.currentTimeMillis();
            return this.streamDeltas(json, openaiKey)
                    .doOnNext(delta -> fullResponse.append(delta))
                    .doOnComplete(() -> {
                        double seconds = (double) (System.currentTimeMillis() - start) / 1000.0;
                        log.info("\nReceived response from OpenAI API: {} s.({})", seconds, fullResponse);
                    });
        });
    }

    /**
     * Converts text to speech.
     *
     * @return the raw audio bytes returned by the API
     * @throws RuntimeException if the response carries no body
     */
    public byte[] createSpeech(TTSModel model, String input, TTSVoice voice, String openaiKey) {
        STTRequest request = new STTRequest(model.getApiName(), input, voice.toApiName());
        ResponseEntity<byte[]> response = this.prepareRestTemplate(openaiKey).postForEntity(SPEECH_URL, request, byte[].class);
        if (response.hasBody()) {
            return response.getBody();
        }
        throw new RuntimeException("Failed to get audio response from OpenAI API");
    }

    /**
     * Transcribes audio with the {@code whisper-1} model.
     *
     * @param audioBytes  audio content, uploaded under the file name {@code audio.oga}
     * @param languageKey value sent as the {@code language} form field
     * @return the transcribed text
     * @throws RuntimeException if the response is not 2xx or its body cannot be parsed
     */
    public String transcribeAudio(byte[] audioBytes, String languageKey, String openaiKey) {
        // The multipart part needs a file name, which a plain ByteArrayResource does not provide.
        ByteArrayResource audioResource = new ByteArrayResource(audioBytes) {
            @Override
            public String getFilename() {
                return "audio.oga";
            }
        };

        MultiValueMap<String, Object> body = new LinkedMultiValueMap<>();
        body.add("file", audioResource);
        body.add("model", "whisper-1");
        body.add("language", languageKey);

        HttpHeaders headers = new HttpHeaders();
        headers.setContentType(MediaType.MULTIPART_FORM_DATA);
        HttpEntity<MultiValueMap<String, Object>> requestEntity = new HttpEntity<>(body, headers);

        ResponseEntity<String> response = this.prepareRestTemplate(openaiKey).postForEntity(TRANSCRIPTIONS_URL, requestEntity, String.class);
        if (!response.getStatusCode().is2xxSuccessful()) {
            throw new RuntimeException(response.toString());
        }
        try {
            return this.mapper.readValue(response.getBody(), AudioResponse.class).text();
        } catch (JsonProcessingException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Blocking chat completion for a system / user message pair.
     *
     * @param maxTokens must not be {@code null}; it is unboxed to {@code int}
     */
    public ChatCompletionResponse chat(LLMModel model, String system, String user, Integer maxTokens, double temperature, boolean jsonMode, String openaiKey) {
        List<ChatMessage> messages = List.of(ChatMessage.systemMessage(system), ChatMessage.userMessage(user));
        return this.chat(model, messages, maxTokens, temperature, jsonMode, openaiKey);
    }

    /**
     * Blocking chat completion for a system / user / assistant message triple.
     *
     * @param maxTokens must not be {@code null}; it is unboxed to {@code int}
     */
    public ChatCompletionResponse chat(LLMModel model, String system, String user, String assistant, Integer maxTokens, double temperature, boolean jsonMode, String openaiKey) {
        List<ChatMessage> messages = List.of(
                ChatMessage.systemMessage(system),
                ChatMessage.userMessage(user),
                ChatMessage.assistantMessage(assistant));
        return this.chat(model, messages, maxTokens, temperature, jsonMode, openaiKey);
    }

    /**
     * Blocking chat completion for the given conversation. Messages with blank content
     * are dropped before the request is sent. The call runs on the internal thread pool
     * and this method waits for its result.
     *
     * @param jsonMode when {@code true} the request asks for the {@code json_object} response format
     * @return the completion response (possibly served from the prompt cache)
     * @throws RuntimeException wrapping any failure of the underlying call, or the
     *                          {@link InterruptedException} if the waiting thread is interrupted
     *                          (the interrupt flag is restored in that case)
     */
    public ChatCompletionResponse chat(LLMModel model, List<ChatMessage> chats, int maxTokens, double temperature, boolean jsonMode, String openaiKey) {
        List<ChatMessage> filteredChats = chats.stream()
                .filter(c -> !StringUtils.isBlank((String) c.content()))
                .toList();
        Future<ChatCompletionResponse> future = this.executorService.submit(
                () -> this.chatInternal(model, filteredChats, maxTokens, temperature, jsonMode, openaiKey));
        try {
            return future.get();
        } catch (InterruptedException e) {
            // Restore the flag so callers further up the stack can still observe the interruption.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Executes one blocking chat completion, consulting the prompt cache for
     * temperature-{@code 0.0} requests and making up to {@link #MAX_CHAT_ATTEMPTS} attempts.
     *
     * @throws RuntimeException wrapping the failure of the last attempt
     */
    private ChatCompletionResponse chatInternal(LLMModel model, List<ChatMessage> chats, int maxTokens, double temperature, boolean jsonMode, String openaiKey) {
        String conversation = chats.stream()
                .map(chatMessage -> chatMessage.role() + ":\n" + chatMessage.content())
                .collect(Collectors.joining("\n"));
        log.info("\nCalling OpenAI API:\nModel: {} Max tokens:{} Temperature:{}\n{}", model, maxTokens, temperature, conversation);

        RestTemplate restTemplate = this.prepareRestTemplate(openaiKey);
        ChatCompletionRequest request = new ChatCompletionRequest(
                model.getApiName(), chats, temperature, maxTokens, jsonMode ? new ResponseFormat("json_object") : null, false);
        boolean cacheable = Double.compare(temperature, 0.0) == 0;
        String cacheKey = cacheable ? this.createKey(model, chats, maxTokens, jsonMode) : null;

        Exception lastFailure = null;
        for (int attempt = 1; attempt <= MAX_CHAT_ATTEMPTS; attempt++) {
            try {
                if (cacheable) {
                    String cached = this.promptCache.get(cacheKey);
                    if (cached != null) {
                        ChatCompletionResponse cachedResponse = this.mapper.readValue(cached, ChatCompletionResponse.class);
                        log.info("\nReturning cached response: {}", this.mapper.writeValueAsString(cachedResponse));
                        return cachedResponse;
                    }
                }

                long start = System.currentTimeMillis();
                ChatCompletionResponse response = restTemplate.postForObject(CHAT_COMPLETIONS_URL, request, ChatCompletionResponse.class);
                double seconds = (double) (System.currentTimeMillis() - start) / 1000.0;
                if (response == null) {
                    throw new IllegalStateException("OpenAI API returned an empty chat completion response");
                }

                String responseJson = this.mapper.writeValueAsString(response);
                double tokensPerSecond = (double) response.usage().completion_tokens() / seconds;
                log.info("\nReceived response from OpenAI API: {} s.({} TPS) {}", seconds, tokensPerSecond, responseJson);
                if (cacheable) {
                    this.promptCache.put(cacheKey, responseJson);
                }
                return response;
            } catch (Exception e) {
                lastFailure = e;
                log.warn("OpenAI chat attempt {} of {} failed: {}", attempt, MAX_CHAT_ATTEMPTS, e.toString());
            }
        }
        throw new RuntimeException(lastFailure);
    }

    /** Computes the embedding of a single text. */
    public EmbeddingResponse embedding(EmbedModel model, String content, String openaiKey) {
        return this.embedding(model, List.of(content), openaiKey);
    }

    /**
     * Computes embeddings for the given texts.
     *
     * @throws RuntimeException if the request cannot be serialised for logging
     */
    public EmbeddingResponse embedding(EmbedModel model, List<String> content, String openaiKey) {
        RestTemplate restTemplate = this.prepareRestTemplate(openaiKey);
        EmbeddingRequest request = new EmbeddingRequest(model.getApiName(), content);
        log.info("Sending request to OpenAI API: {}", this.toJson(request));
        return restTemplate.postForObject(EMBEDDINGS_URL, request, EmbeddingResponse.class);
    }

    /**
     * Stops accepting new blocking chat requests and lets the internal thread pool
     * terminate once already submitted requests have finished.
     */
    @Override
    public void close() {
        this.executorService.shutdown();
    }

    /**
     * Builds the prompt-cache key. JSON-mode requests get a distinct key so that a
     * plain-text answer is never served for a JSON-mode request (or vice versa); keys
     * of non-JSON requests keep the form {@code apiName + chats + maxTokens}.
     */
    private String createKey(LLMModel model, List<ChatMessage> chats, int maxTokens, boolean jsonMode) {
        String key = model.getApiName() + chats + maxTokens;
        return jsonMode ? key + "|json_object" : key;
    }

    /** Serialises a value to JSON, converting the checked Jackson exception to unchecked. */
    private String toJson(Object value) {
        try {
            return this.mapper.writeValueAsString(value);
        } catch (JsonProcessingException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Posts an already serialised streaming request to the chat completions endpoint and
     * maps every received event to its text delta, stopping at the end-of-stream marker.
     */
    private Flux<String> streamDeltas(String json, String openaiKey) {
        return WebClient.builder()
                .baseUrl(CHAT_COMPLETIONS_URL)
                .defaultHeader("Authorization", "Bearer " + openaiKey)
                .build()
                .post()
                .contentType(MediaType.APPLICATION_JSON)
                .bodyValue(json)
                .accept(MediaType.TEXT_EVENT_STREAM)
                .exchangeToFlux(r -> r.bodyToFlux(String.class))
                .takeWhile(chunk -> !STREAM_END_MARKER.equals(chunk))
                .<String>handle((chunk, sink) -> {
                    try {
                        String delta = this.mapper.readValue(chunk, ChatCompletionResponse.class).getDelta();
                        sink.next(delta == null ? "" : delta);
                    } catch (JsonProcessingException e) {
                        sink.error(new RuntimeException("Error while processing JSON response", e));
                    }
                });
    }

    /**
     * Creates a {@link RestTemplate} with the connect/read timeouts of this client that
     * adds the bearer token to every outgoing request.
     */
    private RestTemplate prepareRestTemplate(String openaiKey) {
        SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();
        requestFactory.setConnectTimeout(CONNECT_TIMEOUT_MS);
        requestFactory.setReadTimeout(READ_TIMEOUT_MS);

        ClientHttpRequestInterceptor authInterceptor = (request, body, execution) -> {
            request.getHeaders().add("Authorization", "Bearer " + openaiKey);
            return execution.execute(request, body);
        };

        RestTemplate restTemplate = new RestTemplate();
        restTemplate.setRequestFactory(requestFactory);
        restTemplate.setInterceptors(List.of(authInterceptor));
        return restTemplate;
    }
}

