/*
 * Decompiled with CFR 0.152.
 */
package ai.optfor.springopenaiapi;

import ai.optfor.springopenaiapi.cache.DefaultPromptCache;
import ai.optfor.springopenaiapi.cache.PromptCache;
import ai.optfor.springopenaiapi.model.ChatCompletionRequest;
import ai.optfor.springopenaiapi.model.ChatCompletionResponse;
import ai.optfor.springopenaiapi.model.ChatMessage;
import ai.optfor.springopenaiapi.model.EmbeddingRequest;
import ai.optfor.springopenaiapi.model.EmbeddingResponse;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.MediaType;
import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpRequestInterceptor;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Flux;

/**
 * Client for the OpenAI chat-completions and embeddings HTTP endpoints.
 *
 * <p>Blocking chat calls are executed on a fixed pool of {@value #CHAT_WORKER_THREADS} worker
 * threads owned by this instance, so at most that many blocking chat requests run concurrently;
 * the calling thread waits for the result. When the requested temperature is exactly {@code 0},
 * blocking chat responses are stored in and served from the configured {@link PromptCache}.
 */
public class OpenAIApi {
    private static final Logger log = LoggerFactory.getLogger(OpenAIApi.class);

    /** Endpoint used by both the streaming and the blocking chat calls. */
    private static final String CHAT_COMPLETIONS_URL = "https://api.openai.com/v1/chat/completions";
    /** Endpoint used by the embedding calls. */
    private static final String EMBEDDINGS_URL = "https://api.openai.com/v1/embeddings";
    /** Size of the worker pool, i.e. the cap on concurrently executing blocking chat calls. */
    private static final int CHAT_WORKER_THREADS = 3;
    /** Total number of attempts (initial try plus retries) made by one blocking chat call. */
    private static final int MAX_CHAT_ATTEMPTS = 3;
    /** Base delay between chat attempts; attempt {@code n} waits {@code n} times this value. */
    private static final long RETRY_BACKOFF_MILLIS = 1000L;
    private static final String BEARER_PREFIX = "Bearer ";

    private final ObjectMapper mapper = new ObjectMapper();
    private final PromptCache promptCache;
    private final ExecutorService executorService;

    /**
     * Creates a client.
     *
     * @param promptCache cache for temperature-0 chat responses; when {@code null} a
     *     {@link DefaultPromptCache} is used
     */
    public OpenAIApi(PromptCache promptCache) {
        this.promptCache = promptCache == null ? new DefaultPromptCache() : promptCache;
        this.mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
        this.mapper.enable(SerializationFeature.INDENT_OUTPUT);
        // The pool is never shut down by this class, so its threads are daemons to avoid keeping
        // the JVM alive after the application is otherwise finished.
        this.executorService = Executors.newFixedThreadPool(CHAT_WORKER_THREADS, task -> {
            Thread worker = new Thread(task, "openai-api-chat");
            worker.setDaemon(true);
            return worker;
        });
    }

    /**
     * Streams a chat completion for a single prompt, asking for a server-sent-event response
     * ({@code Accept: text/event-stream}). Streamed responses are not cached.
     *
     * @param model model name placed in the request
     * @param prompt content of the message built with {@link ChatMessage#contentMessage}
     * @param role text passed to {@link ChatMessage#roleMessage}, sent before the prompt
     * @param maxTokens completion token limit, forwarded as-is (may be {@code null})
     * @param temperature sampling temperature
     * @param openaiKey API key; the {@code Bearer } prefix is added unless already present
     * @return a Flux of delta strings (an empty string when a chunk carries no delta), completing
     *     when the {@code [DONE]} marker arrives; it errors with a {@link RuntimeException} if a
     *     chunk cannot be parsed as a {@link ChatCompletionResponse}
     * @throws RuntimeException if the request cannot be serialized to JSON
     */
    public Flux<String> streamingChat(String model, String prompt, String role, Integer maxTokens, double temperature, String openaiKey) {
        ChatCompletionRequest request = new ChatCompletionRequest(model,
                List.of(ChatMessage.roleMessage(role), ChatMessage.contentMessage(prompt)), temperature, maxTokens, true);
        String json;
        try {
            json = this.mapper.writeValueAsString(request);
        } catch (JsonProcessingException e) {
            throw new RuntimeException(e);
        }
        return WebClient.builder()
                .baseUrl(CHAT_COMPLETIONS_URL)
                // Same Bearer scheme as the RestTemplate path (prepareRestTemplate).
                .defaultHeader("Authorization", bearer(openaiKey))
                .build()
                .post()
                .contentType(MediaType.APPLICATION_JSON)
                .bodyValue(json)
                .accept(MediaType.TEXT_EVENT_STREAM)
                .exchangeToFlux(r -> r.bodyToFlux(String.class))
                .takeWhile(chunk -> !chunk.equals("[DONE]"))
                .handle((chunk, sink) -> {
                    try {
                        String delta = this.mapper.readValue(chunk, ChatCompletionResponse.class).getDelta();
                        sink.next(delta == null ? "" : delta);
                    } catch (JsonProcessingException e) {
                        sink.error(new RuntimeException("Error while processing JSON response", e));
                    }
                });
    }

    /**
     * Blocking chat completion for a single prompt preceded by a role message.
     *
     * @param model model name placed in the request
     * @param prompt content of the message built with {@link ChatMessage#contentMessage}
     * @param role text passed to {@link ChatMessage#roleMessage}, sent before the prompt
     * @param maxTokens completion token limit; {@code null} is forwarded to the request unchanged,
     *     exactly as {@link #streamingChat} does
     * @param temperature sampling temperature; exactly {@code 0} enables caching
     * @param openaiKey API key; the {@code Bearer } prefix is added unless already present
     * @return the parsed response
     * @throws RuntimeException if every attempt fails or the calling thread is interrupted
     */
    public ChatCompletionResponse chat(String model, String prompt, String role, Integer maxTokens, double temperature, String openaiKey) {
        List<ChatMessage> chats = List.of(ChatMessage.roleMessage(role), ChatMessage.contentMessage(prompt));
        // Call the Integer-based path directly: delegating to the int overload would unbox and
        // throw NullPointerException for a null maxTokens.
        return this.submitChat(model, chats, maxTokens, temperature, openaiKey);
    }

    /**
     * Blocking chat completion for an arbitrary message list.
     *
     * <p>The call runs on this instance's worker pool while the calling thread waits. Up to
     * {@value #MAX_CHAT_ATTEMPTS} attempts are made, with a growing pause between them.
     *
     * @param model model name placed in the request
     * @param chats messages sent in the request, in order
     * @param maxTokens completion token limit
     * @param temperature sampling temperature; exactly {@code 0} enables caching keyed on
     *     model, messages and maxTokens
     * @param openaiKey API key; the {@code Bearer } prefix is added unless already present
     * @return the parsed response
     * @throws RuntimeException if every attempt fails or the calling thread is interrupted
     */
    public ChatCompletionResponse chat(String model, List<ChatMessage> chats, int maxTokens, double temperature, String openaiKey) {
        return this.submitChat(model, chats, maxTokens, temperature, openaiKey);
    }

    /** Runs {@link #chatInternal} on the worker pool and waits for its result. */
    private ChatCompletionResponse submitChat(String model, List<ChatMessage> chats, Integer maxTokens, double temperature, String openaiKey) {
        Future<ChatCompletionResponse> future =
                this.executorService.submit(() -> this.chatInternal(model, chats, maxTokens, temperature, openaiKey));
        try {
            return future.get();
        } catch (InterruptedException e) {
            // Preserve the caller's interrupt status and stop the now-orphaned worker task.
            Thread.currentThread().interrupt();
            future.cancel(true);
            throw new RuntimeException("Interrupted while waiting for OpenAI chat response", e);
        } catch (ExecutionException e) {
            // Surface the task's own failure instead of wrapping it a second time.
            Throwable cause = e.getCause();
            if (cause instanceof RuntimeException) {
                throw (RuntimeException) cause;
            }
            throw new RuntimeException(cause);
        }
    }

    /** Retry loop around {@link #attemptChat}; gives up after {@value #MAX_CHAT_ATTEMPTS} attempts. */
    private ChatCompletionResponse chatInternal(String model, List<ChatMessage> chats, Integer maxTokens, double temperature, String openaiKey) {
        RestTemplate restTemplate = this.prepareRestTemplate(openaiKey);
        for (int attempt = 1; ; attempt++) {
            try {
                return this.attemptChat(restTemplate, model, chats, maxTokens, temperature);
            } catch (Exception e) {
                if (attempt >= MAX_CHAT_ATTEMPTS) {
                    throw new RuntimeException("OpenAI chat request failed after " + attempt + " attempts", e);
                }
                log.warn("OpenAI chat attempt {}/{} failed, retrying", attempt, MAX_CHAT_ATTEMPTS, e);
                sleepBeforeRetry(attempt, e);
            }
        }
    }

    /** Waits {@code attempt * RETRY_BACKOFF_MILLIS} ms; aborts the retry loop if interrupted. */
    private static void sleepBeforeRetry(int attempt, Exception lastFailure) {
        try {
            Thread.sleep(RETRY_BACKOFF_MILLIS * attempt);
        } catch (InterruptedException interrupted) {
            Thread.currentThread().interrupt();
            RuntimeException failure =
                    new RuntimeException("Interrupted while waiting to retry OpenAI chat request", interrupted);
            failure.addSuppressed(lastFailure);
            throw failure;
        }
    }

    /**
     * One chat attempt: serve from the cache when allowed, otherwise call the API and cache the
     * result when allowed. Caching is allowed only when temperature is exactly {@code 0}.
     */
    private ChatCompletionResponse attemptChat(RestTemplate restTemplate, String model, List<ChatMessage> chats, Integer maxTokens, double temperature) throws JsonProcessingException {
        boolean cacheable = Double.compare(temperature, 0.0) == 0;
        String cacheKey = this.createKey(model, chats, maxTokens);
        if (cacheable) {
            String cached = this.promptCache.get(cacheKey);
            if (cached != null) {
                ChatCompletionResponse response = this.mapper.readValue(cached, ChatCompletionResponse.class);
                log.info("Returning cached response: {}", this.mapper.writeValueAsString(response));
                return response;
            }
        }
        ChatCompletionRequest request = new ChatCompletionRequest(model, chats, temperature, maxTokens, false);
        log.info("Sending request to OpenAI API: {}", this.mapper.writeValueAsString(request));
        // nanoTime is monotonic, unlike currentTimeMillis, so it is the right clock for latency.
        long start = System.nanoTime();
        ChatCompletionResponse response = restTemplate.postForObject(CHAT_COMPLETIONS_URL, request, ChatCompletionResponse.class);
        double seconds = (System.nanoTime() - start) / 1_000_000_000.0;
        if (response == null) {
            throw new IllegalStateException("Empty response body from OpenAI API");
        }
        // Guard usage: an NPE here would make the retry loop re-send an already-completed request.
        double tokensPerSecond = response.usage() == null
                ? Double.NaN
                : (double) response.usage().completion_tokens() / seconds;
        log.info("Received response from OpenAI API: {} s.({} TPS) {}", seconds, tokensPerSecond,
                this.mapper.writeValueAsString(response));
        if (cacheable) {
            this.promptCache.put(cacheKey, this.mapper.writeValueAsString(response));
        }
        return response;
    }

    /**
     * Requests an embedding for a single input string.
     *
     * @see #embedding(String, List, String)
     */
    public EmbeddingResponse embedding(String model, String content, String openaiKey) {
        return this.embedding(model, List.of(content), openaiKey);
    }

    /**
     * Requests embeddings for the given inputs. No retries and no caching are applied.
     *
     * @param model model name placed in the request
     * @param content input strings
     * @param openaiKey API key; the {@code Bearer } prefix is added unless already present
     * @return the parsed response as returned by {@link RestTemplate#postForObject}
     * @throws RuntimeException if the request cannot be serialized for logging
     */
    public EmbeddingResponse embedding(String model, List<String> content, String openaiKey) {
        RestTemplate restTemplate = this.prepareRestTemplate(openaiKey);
        EmbeddingRequest request = new EmbeddingRequest(model, content);
        try {
            log.info("Sending request to OpenAI API: {}", this.mapper.writeValueAsString(request));
        } catch (JsonProcessingException e) {
            throw new RuntimeException(e);
        }
        return restTemplate.postForObject(EMBEDDINGS_URL, request, EmbeddingResponse.class);
    }

    /**
     * Builds the prompt-cache key by plain string concatenation. Relies on {@code toString()} of
     * the message list; the format is kept unchanged so existing cache entries remain valid.
     */
    private String createKey(String model, List<ChatMessage> chats, Integer maxTokens) {
        return model + chats + maxTokens;
    }

    /** Returns the Authorization header value, adding the Bearer prefix unless already present. */
    private static String bearer(String openaiKey) {
        return openaiKey != null && openaiKey.startsWith(BEARER_PREFIX) ? openaiKey : BEARER_PREFIX + openaiKey;
    }

    /**
     * Creates a RestTemplate with a 5 s connect timeout, a 10 min read timeout and an interceptor
     * that adds the Authorization header to every request.
     */
    private RestTemplate prepareRestTemplate(String openaiKey) {
        RestTemplate restTemplate = new RestTemplate();
        SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();
        requestFactory.setConnectTimeout(5000);
        requestFactory.setReadTimeout(600000);
        restTemplate.setRequestFactory(requestFactory);
        ClientHttpRequestInterceptor interceptor = (request, body, execution) -> {
            request.getHeaders().add("Authorization", bearer(openaiKey));
            return execution.execute(request, body);
        };
        restTemplate.setInterceptors(List.of(interceptor));
        return restTemplate;
    }
}

