/*
 * Decompiled with CFR 0.152.
 */
package ee.carlrobert.llm.client.google;

import com.fasterxml.jackson.core.JacksonException;
import com.fasterxml.jackson.core.JsonProcessingException;
import ee.carlrobert.llm.PropertiesLoader;
import ee.carlrobert.llm.client.DeserializationUtil;
import ee.carlrobert.llm.client.google.completion.ApiResponseError;
import ee.carlrobert.llm.client.google.completion.ErrorDetails;
import ee.carlrobert.llm.client.google.completion.GoogleCompletionContent;
import ee.carlrobert.llm.client.google.completion.GoogleCompletionRequest;
import ee.carlrobert.llm.client.google.completion.GoogleCompletionResponse;
import ee.carlrobert.llm.client.google.completion.GoogleContentPart;
import ee.carlrobert.llm.client.google.embedding.ContentEmbedding;
import ee.carlrobert.llm.client.google.embedding.GoogleBatchEmbeddingResponse;
import ee.carlrobert.llm.client.google.embedding.GoogleEmbeddingContentRequest;
import ee.carlrobert.llm.client.google.embedding.GoogleEmbeddingRequest;
import ee.carlrobert.llm.client.google.embedding.GoogleEmbeddingResponse;
import ee.carlrobert.llm.client.google.models.GoogleModel;
import ee.carlrobert.llm.client.google.models.GoogleModelsResponse;
import ee.carlrobert.llm.client.google.models.GoogleTokensResponse;
import ee.carlrobert.llm.completion.CompletionEventListener;
import ee.carlrobert.llm.completion.CompletionEventSourceListener;
import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import okhttp3.HttpUrl;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSources;

/**
 * HTTP client for a Google generative-model REST endpoint.
 *
 * <p>Requests are sent to {@code {host}/{version}/models/{model}:{action}}. When an API key is
 * configured it is appended as the {@code key} query parameter. Blocking calls close the HTTP
 * response before returning and wrap any {@link IOException} in a {@link RuntimeException}.
 *
 * <p>Instances are created through {@link Builder}.
 */
public class GoogleClient {
    private static final MediaType APPLICATION_JSON = MediaType.parse("application/json");

    private final OkHttpClient httpClient;
    private final String host;
    private final String apiKey;

    /**
     * Creates a client from the given settings.
     *
     * @param builder           supplies the host and API key
     * @param httpClientBuilder used to build the underlying {@link OkHttpClient}
     */
    protected GoogleClient(Builder builder, OkHttpClient.Builder httpClientBuilder) {
        this.httpClient = httpClientBuilder.build();
        this.host = builder.host;
        this.apiKey = builder.apiKey;
    }

    /**
     * Starts a streaming chat completion for a model identified by its code.
     *
     * <p>If the code resolves to a known {@link GoogleModel}, the call is delegated to the
     * {@link GoogleModel} overload; otherwise the model is treated as non-experimental.
     *
     * @param request       completion request payload
     * @param model         model code
     * @param eventListener receives streamed messages and errors
     * @return the opened event source
     */
    public EventSource getChatCompletionAsync(
            GoogleCompletionRequest request,
            String model,
            CompletionEventListener<String> eventListener) {
        GoogleModel googleModel = GoogleModel.findByCode(model);
        if (googleModel != null) {
            return getChatCompletionAsync(request, googleModel, eventListener);
        }
        return getChatCompletionAsync(request, model, eventListener, false);
    }

    /**
     * Starts a streaming chat completion for a known model.
     *
     * @param request       completion request payload
     * @param model         target model; its experimental flag selects the API version
     * @param eventListener receives streamed messages and errors
     * @return the opened event source
     */
    public EventSource getChatCompletionAsync(
            GoogleCompletionRequest request,
            GoogleModel model,
            CompletionEventListener<String> eventListener) {
        return getChatCompletionAsync(request, model.getCode(), eventListener, model.isExperimental());
    }

    /** Opens the server-sent-event stream against {@code streamGenerateContent}. */
    private EventSource getChatCompletionAsync(
            GoogleCompletionRequest request,
            String model,
            CompletionEventListener<String> eventListener,
            boolean isExperimental) {
        Request httpRequest =
                buildPostRequest(request, model, "streamGenerateContent", true, isExperimental);
        return EventSources.createFactory(httpClient)
                .newEventSource(httpRequest, getEventSourceListener(eventListener));
    }

    /**
     * Requests a blocking chat completion for a known model.
     *
     * @see #getChatCompletion(GoogleCompletionRequest, String)
     */
    public GoogleCompletionResponse getChatCompletion(GoogleCompletionRequest request, GoogleModel model) {
        return getChatCompletion(request, model.getCode());
    }

    /**
     * Requests a blocking chat completion via {@code generateContent}.
     *
     * @param request completion request payload
     * @param model   model code
     * @return the deserialized response
     * @throws RuntimeException if the HTTP call fails with an {@link IOException}
     */
    public GoogleCompletionResponse getChatCompletion(GoogleCompletionRequest request, String model) {
        Request httpRequest = buildPostRequest(request, model, "generateContent", false);
        try (Response response = httpClient.newCall(httpRequest).execute()) {
            return DeserializationUtil.mapResponse(response, GoogleCompletionResponse.class);
        } catch (IOException e) {
            throw new RuntimeException(
                    "Could not get google completion for the given request:\n" + request, e);
        }
    }

    /** Embeds a single text with a known model. */
    public double[] getEmbedding(String text, GoogleModel model) {
        return getEmbedding(List.of(text), model.getCode());
    }

    /** Embeds a single text with the model identified by {@code model}. */
    public double[] getEmbedding(String text, String model) {
        return getEmbedding(List.of(text), model);
    }

    /** Embeds the given texts (as one content) with a known model. */
    public double[] getEmbedding(List<String> texts, GoogleModel model) {
        return getEmbedding(texts, model.getCode());
    }

    /** Embeds the given texts, wrapped into a single {@link GoogleCompletionContent}. */
    public double[] getEmbedding(List<String> texts, String model) {
        GoogleEmbeddingRequest embeddingRequest =
                new GoogleEmbeddingRequest.Builder(new GoogleCompletionContent(texts)).build();
        return getEmbedding(embeddingRequest, model);
    }

    /** Sends a prepared embedding request to a known model. */
    public double[] getEmbedding(GoogleEmbeddingRequest request, GoogleModel model) {
        return getEmbedding(request, model.getCode());
    }

    /**
     * Sends a prepared embedding request via {@code embedContent}.
     *
     * @param request embedding request payload
     * @param model   model code
     * @return the embedding values, or {@code null} when the response, its embedding, or the
     *     embedding's values are absent
     * @throws RuntimeException if the HTTP call fails with an {@link IOException}
     */
    public double[] getEmbedding(GoogleEmbeddingRequest request, String model) {
        Request httpRequest = buildPostRequest(request, model, "embedContent", false);
        try (Response response = httpClient.newCall(httpRequest).execute()) {
            return Optional.ofNullable(
                            DeserializationUtil.mapResponse(response, GoogleEmbeddingResponse.class))
                    .map(GoogleEmbeddingResponse::getEmbedding)
                    .map(ContentEmbedding::getValues)
                    .orElse(null);
        } catch (IOException e) {
            throw new RuntimeException("Unable to fetch embedding", e);
        }
    }

    /** Sends a batch of embedding requests to a known model. */
    public List<double[]> getBatchEmbeddings(List<GoogleEmbeddingContentRequest> requests, GoogleModel model) {
        return getBatchEmbeddings(requests, model.getCode());
    }

    /**
     * Sends a batch of embedding requests via {@code batchEmbedContents}.
     *
     * @param requests individual embedding requests, posted under the {@code requests} key
     * @param model    model code
     * @return the non-null embedding value arrays, or {@code null} when there are none
     *     (existing contract; callers may rely on the null return)
     * @throws RuntimeException if the HTTP call fails with an {@link IOException}
     */
    public List<double[]> getBatchEmbeddings(List<GoogleEmbeddingContentRequest> requests, String model) {
        Request httpRequest =
                buildPostRequest(Map.of("requests", requests), model, "batchEmbedContents", false);
        try (Response response = httpClient.newCall(httpRequest).execute()) {
            List<double[]> embeddings = Optional.ofNullable(
                            DeserializationUtil.mapResponse(response, GoogleBatchEmbeddingResponse.class))
                    .map(GoogleBatchEmbeddingResponse::getEmbeddings)
                    .stream()
                    .flatMap(Collection::stream)
                    .filter(Objects::nonNull)
                    .map(ContentEmbedding::getValues)
                    .filter(Objects::nonNull)
                    .collect(Collectors.toList());
            return embeddings.isEmpty() ? null : embeddings;
        } catch (IOException e) {
            throw new RuntimeException("Unable to fetch embedding", e);
        }
    }

    /**
     * Lists available models from {@code /v1/models}.
     *
     * @param pageSize  optional page size; omitted from the query when {@code null}
     * @param pageToken optional page token; omitted from the query when {@code null}
     * @return the deserialized models response
     * @throws IllegalArgumentException if the configured host does not form a valid URL
     * @throws RuntimeException         if the HTTP call fails with an {@link IOException}
     */
    public GoogleModelsResponse getModels(Integer pageSize, String pageToken) {
        HttpUrl.Builder urlBuilder = newUrlBuilder(host + "/v1/models");
        if (pageSize != null) {
            urlBuilder.addQueryParameter("pageSize", pageSize.toString());
        }
        if (pageToken != null) {
            urlBuilder.addQueryParameter("pageToken", pageToken);
        }
        Request httpRequest = defaultRequestBuilder(urlBuilder, false).get().build();
        try (Response response = httpClient.newCall(httpRequest).execute()) {
            return DeserializationUtil.mapResponse(response, GoogleModelsResponse.class);
        } catch (IOException e) {
            throw new RuntimeException("Unable to fetch models", e);
        }
    }

    /**
     * Fetches details of a single model from {@code /v1/models/{name}}.
     *
     * @param name model name, appended to the path as-is
     * @return the deserialized model details
     * @throws IllegalArgumentException if the resulting URL is invalid
     * @throws RuntimeException         if the HTTP call fails with an {@link IOException}
     */
    public GoogleModelsResponse.GeminiModelDetails getModel(String name) {
        Request httpRequest = defaultRequestBuilder(host + "/v1/models/" + name, false).get().build();
        try (Response response = httpClient.newCall(httpRequest).execute()) {
            return DeserializationUtil.mapResponse(response, GoogleModelsResponse.GeminiModelDetails.class);
        } catch (IOException e) {
            throw new RuntimeException("Unable to fetch model", e);
        }
    }

    /** Counts tokens for the given contents using a known model. */
    public GoogleTokensResponse getCountTokens(List<GoogleCompletionContent> contents, GoogleModel model) {
        return getCountTokens(contents, model.getCode());
    }

    /**
     * Counts tokens for the given contents via {@code countTokens}.
     *
     * @param contents contents to count, posted under the {@code contents} key
     * @param model    model code
     * @return the deserialized token-count response
     * @throws RuntimeException if the HTTP call fails with an {@link IOException}
     */
    public GoogleTokensResponse getCountTokens(List<GoogleCompletionContent> contents, String model) {
        Request httpRequest = buildPostRequest(Map.of("contents", contents), model, "countTokens", false);
        try (Response response = httpClient.newCall(httpRequest).execute()) {
            return DeserializationUtil.mapResponse(response, GoogleTokensResponse.class);
        } catch (IOException e) {
            throw new RuntimeException("Unable to fetch tokens count", e);
        }
    }

    /** Builds a POST request against the {@code v1} API version. */
    private Request buildPostRequest(Object request, String model, String path, boolean stream) {
        return buildPostRequest(request, model, path, stream, false);
    }

    /**
     * Serializes {@code request} to JSON and builds a POST to
     * {@code {host}/{version}/models/{model}:{path}}.
     *
     * @param experimentalModel selects {@code v1alpha} instead of {@code v1}
     * @throws RuntimeException if the payload cannot be serialized
     */
    private Request buildPostRequest(
            Object request, String model, String path, boolean stream, boolean experimentalModel) {
        String apiVersion = experimentalModel ? "v1alpha" : "v1";
        String url = host + String.format("/%s/models/%s:%s", apiVersion, model, path);
        try {
            String json = DeserializationUtil.OBJECT_MAPPER.writeValueAsString(request);
            return defaultRequestBuilder(url, stream)
                    .post(RequestBody.create(json, APPLICATION_JSON))
                    .build();
        } catch (JsonProcessingException e) {
            throw new RuntimeException(e);
        }
    }

    /** Parses {@code url} and delegates to {@link #defaultRequestBuilder(HttpUrl.Builder, boolean)}. */
    private Request.Builder defaultRequestBuilder(String url, boolean stream) {
        return defaultRequestBuilder(newUrlBuilder(url), stream);
    }

    /**
     * Applies the query parameters and headers shared by all requests: the {@code key}
     * parameter when an API key is set, {@code alt=sse} for streaming, and the cache,
     * content-type and accept headers.
     */
    private Request.Builder defaultRequestBuilder(HttpUrl.Builder url, boolean stream) {
        if (apiKey != null && !apiKey.isEmpty()) {
            url.addQueryParameter("key", apiKey);
        }
        if (stream) {
            url.addQueryParameter("alt", "sse");
        }
        // NOTE(review): "text/json" is kept byte-for-byte from the original; confirm whether
        // "application/json" was intended before changing it.
        return new Request.Builder()
                .url(url.build())
                .header("Cache-Control", "no-cache")
                .header("Content-Type", "application/json")
                .header("Accept", stream ? "text/event-stream" : "text/json");
    }

    /**
     * Parses {@code url} into a mutable builder, failing fast with a descriptive message
     * instead of a bare {@link NullPointerException} when the URL is malformed.
     */
    private static HttpUrl.Builder newUrlBuilder(String url) {
        HttpUrl parsed = HttpUrl.parse(url);
        if (parsed == null) {
            throw new IllegalArgumentException("Invalid Google API URL: " + url);
        }
        return parsed.newBuilder();
    }

    /** Adapts the caller's listener to the streaming response format. */
    private CompletionEventSourceListener<String> getEventSourceListener(
            CompletionEventListener<String> eventListener) {
        return new CompletionEventSourceListener<String>(eventListener) {

            /**
             * Extracts the text of the first non-null part of the first candidate that has
             * parts; returns an empty string when there is none or the chunk is not valid JSON.
             */
            @Override
            protected String getMessage(String data) {
                try {
                    List<GoogleCompletionResponse.Candidate> candidates = DeserializationUtil.OBJECT_MAPPER
                            .readValue(data, GoogleCompletionResponse.class)
                            .getCandidates();
                    if (candidates == null) {
                        return "";
                    }
                    return candidates.stream()
                            .filter(Objects::nonNull)
                            .flatMap(candidate -> {
                                if (candidate.getContent() == null
                                        || candidate.getContent().getParts() == null) {
                                    return Stream.<GoogleContentPart>empty();
                                }
                                return candidate.getContent().getParts().stream();
                            })
                            .filter(Objects::nonNull)
                            .findFirst()
                            .map(GoogleContentPart::getText)
                            .orElse("");
                } catch (JacksonException ignored) {
                    // Best effort: an unparsable chunk contributes no text to the stream.
                    return "";
                }
            }

            /**
             * Maps the Google error payload onto the shared error-details type; returns
             * {@code null} when the payload carries no error object.
             */
            @Override
            protected ee.carlrobert.llm.client.openai.completion.ErrorDetails getErrorDetails(String data)
                    throws JsonProcessingException {
                ErrorDetails googleError = DeserializationUtil.OBJECT_MAPPER
                        .readValue(data, ApiResponseError.class)
                        .getError();
                if (googleError == null) {
                    return null;
                }
                return new ee.carlrobert.llm.client.openai.completion.ErrorDetails(
                        googleError.getMessage(), googleError.getStatus(), null, googleError.getCode());
            }
        };
    }

    /** Fluent builder for {@link GoogleClient}. */
    public static class Builder {
        // Defaults to the "google.baseUrl" property; override with setHost.
        private String host = PropertiesLoader.getValue("google.baseUrl");
        private String apiKey;

        /**
         * @param apiKey API key sent as the {@code key} query parameter; when {@code null} or
         *     empty, no key is sent
         */
        public Builder(String apiKey) {
            this.apiKey = apiKey;
        }

        /** Overrides the base URL that request paths are appended to. */
        public Builder setHost(String host) {
            this.host = host;
            return this;
        }

        /** Replaces the API key given at construction. */
        public Builder setApiKey(String apiKey) {
            this.apiKey = apiKey;
            return this;
        }

        /** Builds a client whose HTTP client is created from the supplied builder. */
        public GoogleClient build(OkHttpClient.Builder builder) {
            return new GoogleClient(this, builder);
        }

        /** Builds a client with a default {@link OkHttpClient.Builder}. */
        public GoogleClient build() {
            return build(new OkHttpClient.Builder());
        }
    }
}

