/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.community.model.xinference.client;

import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.community.model.xinference.client.AuthorizationHeaderInjector;
import dev.langchain4j.community.model.xinference.client.GenericHeaderInjector;
import dev.langchain4j.community.model.xinference.client.RequestExecutor;
import dev.langchain4j.community.model.xinference.client.RequestLoggingInterceptor;
import dev.langchain4j.community.model.xinference.client.ResponseLoggingInterceptor;
import dev.langchain4j.community.model.xinference.client.SyncOrAsync;
import dev.langchain4j.community.model.xinference.client.SyncOrAsyncOrStreaming;
import dev.langchain4j.community.model.xinference.client.XinferenceApi;
import dev.langchain4j.community.model.xinference.client.chat.ChatCompletionRequest;
import dev.langchain4j.community.model.xinference.client.chat.ChatCompletionResponse;
import dev.langchain4j.community.model.xinference.client.completion.CompletionRequest;
import dev.langchain4j.community.model.xinference.client.completion.CompletionResponse;
import dev.langchain4j.community.model.xinference.client.embedding.EmbeddingRequest;
import dev.langchain4j.community.model.xinference.client.embedding.EmbeddingResponse;
import dev.langchain4j.community.model.xinference.client.image.ImageRequest;
import dev.langchain4j.community.model.xinference.client.image.ImageResponse;
import dev.langchain4j.community.model.xinference.client.image.OcrRequest;
import dev.langchain4j.community.model.xinference.client.rerank.RerankRequest;
import dev.langchain4j.community.model.xinference.client.rerank.RerankResponse;
import dev.langchain4j.community.model.xinference.client.shared.StreamOptions;
import dev.langchain4j.community.model.xinference.client.utils.JsonUtil;
import dev.langchain4j.internal.Utils;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.time.Duration;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import okhttp3.Cache;
import okhttp3.Interceptor;
import okhttp3.MediaType;
import okhttp3.MultipartBody;
import okhttp3.OkHttpClient;
import okhttp3.RequestBody;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import retrofit2.Converter;
import retrofit2.Retrofit;
import retrofit2.converter.jackson.JacksonConverterFactory;

/**
 * Low-level HTTP client for the Xinference REST API.
 *
 * <p>Wraps a Retrofit-generated {@link XinferenceApi} backed by an {@link OkHttpClient} configured
 * from a {@link Builder}: timeouts, an optional API key, custom headers, an optional proxy and
 * optional request/response logging interceptors. Create instances via {@link #builder()} and
 * release them with {@link #shutdown()}.
 */
public class XinferenceClient {
    private static final Logger log = LoggerFactory.getLogger(XinferenceClient.class);

    /**
     * Media type attached to uploaded image parts.
     *
     * <p>NOTE(review): {@code "image"} is not a well-formed {@code type/subtype}, so
     * {@link MediaType#parse(String)} returns {@code null} and image parts are sent without a
     * Content-Type header. Kept as-is for wire compatibility; consider
     * {@code "application/octet-stream"}.
     */
    private static final MediaType IMAGE_MEDIA_TYPE = MediaType.parse("image");

    private final String baseUrl;
    private final OkHttpClient okHttpClient;
    private final XinferenceApi xinferenceApi;
    private final boolean logStreamingResponses;

    private XinferenceClient(Builder builder) {
        // Builder fields are public, so baseUrl may bypass Builder#baseUrl's normalisation.
        // formatUrl() concatenates endpoint paths onto it, so a trailing slash is required.
        this.baseUrl = withTrailingSlash(Objects.requireNonNull(builder.baseUrl, "baseUrl cannot be null"));

        OkHttpClient.Builder okHttpClientBuilder = new OkHttpClient.Builder()
                .callTimeout(builder.callTimeout)
                .connectTimeout(builder.connectTimeout)
                .readTimeout(builder.readTimeout)
                .writeTimeout(builder.writeTimeout);
        if (builder.apiKey != null) {
            // Presumably injects the key as an Authorization header (see AuthorizationHeaderInjector).
            okHttpClientBuilder.addInterceptor(new AuthorizationHeaderInjector(builder.apiKey));
        }
        // Copy the caller's map so later mutation of it does not affect this client.
        HashMap<String, String> headers = new HashMap<>();
        if (builder.customHeaders != null) {
            headers.putAll(builder.customHeaders);
        }
        if (!headers.isEmpty()) {
            okHttpClientBuilder.addInterceptor(new GenericHeaderInjector(headers));
        }
        if (builder.proxy != null) {
            okHttpClientBuilder.proxy(builder.proxy);
        }
        if (builder.logRequests) {
            okHttpClientBuilder.addInterceptor(new RequestLoggingInterceptor());
        }
        if (builder.logResponses) {
            okHttpClientBuilder.addInterceptor(new ResponseLoggingInterceptor());
        }
        this.logStreamingResponses = builder.logStreamingResponses;
        this.okHttpClient = okHttpClientBuilder.build();

        Retrofit retrofit = new Retrofit.Builder()
                .baseUrl(this.baseUrl)
                .client(this.okHttpClient)
                .addConverterFactory(JacksonConverterFactory.create(JsonUtil.getObjectMapper()))
                .build();
        this.xinferenceApi = retrofit.create(XinferenceApi.class);
    }

    /**
     * Releases the resources held by the underlying OkHttp client: shuts down the dispatcher's
     * executor, evicts pooled connections and closes the response cache if one is configured.
     * A failure to close the cache is logged rather than propagated.
     */
    public void shutdown() {
        okHttpClient.dispatcher().executorService().shutdown();
        okHttpClient.connectionPool().evictAll();
        Cache cache = okHttpClient.cache();
        if (cache != null) {
            try {
                cache.close();
            } catch (IOException e) {
                log.error("Failed to close cache", e);
            }
        }
    }

    /**
     * Text completion endpoint ({@code v1/completions}).
     *
     * <p>The synchronous/asynchronous call is sent with {@code stream} cleared. The streaming
     * variant re-sends the request with {@code stream=true} and stream options enabled.
     *
     * @param request the completion request
     * @return an executor that can run the request sync, async or streaming
     */
    public SyncOrAsyncOrStreaming<CompletionResponse> completions(CompletionRequest request) {
        return new RequestExecutor<CompletionRequest, CompletionResponse, CompletionResponse>(
                xinferenceApi.completions(
                        CompletionRequest.builder().from(request).stream(null).build()),
                r -> r,
                okHttpClient,
                formatUrl("v1/completions"),
                () -> CompletionRequest.builder()
                        .from(request)
                        .stream(true)
                        .streamOptions(StreamOptions.of(true))
                        .build(),
                CompletionResponse.class,
                r -> r,
                logStreamingResponses);
    }

    /**
     * Chat completion endpoint ({@code v1/chat/completions}).
     *
     * <p>The synchronous/asynchronous call is sent with {@code stream} cleared. The streaming
     * variant re-sends the request with {@code stream=true} and stream options enabled.
     *
     * @param request the chat completion request
     * @return an executor that can run the request sync, async or streaming
     */
    public SyncOrAsyncOrStreaming<ChatCompletionResponse> chatCompletions(ChatCompletionRequest request) {
        return new RequestExecutor<ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponse>(
                xinferenceApi.chatCompletions(
                        ChatCompletionRequest.builder().from(request).stream(null).build()),
                r -> r,
                okHttpClient,
                formatUrl("v1/chat/completions"),
                () -> ChatCompletionRequest.builder()
                        .from(request)
                        .stream(true)
                        .streamOptions(StreamOptions.of(true))
                        .build(),
                ChatCompletionResponse.class,
                r -> r,
                logStreamingResponses);
    }

    /**
     * Embedding endpoint.
     *
     * @param request the embedding request
     * @return an executor that can run the request sync or async
     */
    public SyncOrAsync<EmbeddingResponse> embeddings(EmbeddingRequest request) {
        return new RequestExecutor<>(xinferenceApi.embeddings(request), r -> r);
    }

    /**
     * Rerank endpoint.
     *
     * @param request the rerank request
     * @return an executor that can run the request sync or async
     */
    public SyncOrAsync<RerankResponse> rerank(RerankRequest request) {
        return new RequestExecutor<>(xinferenceApi.rerank(request), r -> r);
    }

    /**
     * Text-to-image generation endpoint (JSON body).
     *
     * @param request the image request
     * @return an executor that can run the request sync or async
     */
    public SyncOrAsync<ImageResponse> generations(ImageRequest request) {
        return new RequestExecutor<>(xinferenceApi.generations(request), r -> r);
    }

    /**
     * Image-to-image variation endpoint, sent as multipart form data.
     *
     * @param request the image request whose fields become form parts
     * @param image raw bytes of the source image, sent as the {@code image} part
     * @return an executor that can run the request sync or async
     */
    public SyncOrAsync<ImageResponse> variations(ImageRequest request, byte[] image) {
        MultipartBody.Builder builder = toMultipartBuilder(request);
        addImagePart(builder, "image", image);
        return new RequestExecutor<>(xinferenceApi.variations(builder.build()), r -> r);
    }

    /**
     * Inpainting endpoint, sent as multipart form data.
     *
     * @param request the image request whose fields become form parts
     * @param image raw bytes of the source image, sent as the {@code image} part
     * @param maskImage raw bytes of the mask, sent as the {@code mask_image} part
     * @return an executor that can run the request sync or async
     */
    public SyncOrAsync<ImageResponse> inpainting(ImageRequest request, byte[] image, byte[] maskImage) {
        MultipartBody.Builder builder = toMultipartBuilder(request);
        addImagePart(builder, "image", image);
        addImagePart(builder, "mask_image", maskImage);
        return new RequestExecutor<>(xinferenceApi.inpainting(builder.build()), r -> r);
    }

    /**
     * OCR endpoint, sent as multipart form data with {@code model}, {@code image} and an
     * optional {@code kwargs} part.
     *
     * @param request the OCR request
     * @return an executor that can run the request sync or async
     */
    public SyncOrAsync<String> ocr(OcrRequest request) {
        MultipartBody.Builder builder = new MultipartBody.Builder()
                .setType(MultipartBody.FORM)
                .addFormDataPart("model", request.getModel());
        addImagePart(builder, "image", request.getImage());
        if (Utils.isNotNullOrBlank(request.getKwargs())) {
            builder.addFormDataPart("kwargs", request.getKwargs());
        }
        return new RequestExecutor<>(xinferenceApi.ocr(builder.build()), r -> r);
    }

    /** Builds an absolute URL for {@code endpoint}; {@link #baseUrl} always ends with '/'. */
    private String formatUrl(String endpoint) {
        return baseUrl + endpoint;
    }

    /** Returns {@code url} unchanged if it already ends with '/', otherwise with '/' appended. */
    private static String withTrailingSlash(String url) {
        return url.endsWith("/") ? url : url + "/";
    }

    /** Appends a binary image part whose form field name and filename are both {@code name}. */
    private static void addImagePart(MultipartBody.Builder builder, String name, byte[] bytes) {
        builder.addFormDataPart(name, name, RequestBody.create(bytes, IMAGE_MEDIA_TYPE));
    }

    /**
     * Converts the scalar fields of an {@link ImageRequest} into multipart form parts.
     *
     * <p>OkHttp throws on null form values, so optional fields are added only when set.
     * {@code model} is always sent.
     */
    private static MultipartBody.Builder toMultipartBuilder(ImageRequest request) {
        MultipartBody.Builder builder = new MultipartBody.Builder()
                .setType(MultipartBody.FORM)
                .addFormDataPart("model", request.getModel());
        if (request.getPrompt() != null) {
            builder.addFormDataPart("prompt", request.getPrompt());
        }
        if (request.getResponseFormat() != null) {
            builder.addFormDataPart("response_format", request.getResponseFormat().getValue());
        }
        if (Utils.isNotNullOrBlank(request.getNegativePrompt())) {
            builder.addFormDataPart("negative_prompt", request.getNegativePrompt());
        }
        if (Objects.nonNull(request.getN())) {
            builder.addFormDataPart("n", String.valueOf(request.getN()));
        }
        if (Utils.isNotNullOrBlank(request.getSize())) {
            builder.addFormDataPart("size", request.getSize());
        }
        if (Utils.isNotNullOrBlank(request.getKwargs())) {
            builder.addFormDataPart("kwargs", request.getKwargs());
        }
        return builder;
    }

    /** Returns a new {@link Builder}. */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Fluent builder for {@link XinferenceClient}. All four timeouts default to 60 seconds;
     * logging is disabled by default.
     */
    public static class Builder {
        public String baseUrl;
        public String apiKey;
        public Duration callTimeout = Duration.ofSeconds(60L);
        public Duration connectTimeout = Duration.ofSeconds(60L);
        public Duration readTimeout = Duration.ofSeconds(60L);
        public Duration writeTimeout = Duration.ofSeconds(60L);
        private Proxy proxy;
        public boolean logRequests;
        public boolean logResponses;
        public boolean logStreamingResponses;
        public Map<String, String> customHeaders;

        /**
         * Sets the server base URL, appending a trailing '/' if missing.
         *
         * @throws IllegalArgumentException if {@code baseUrl} is null or blank
         */
        public Builder baseUrl(String baseUrl) {
            if (baseUrl == null || baseUrl.trim().isEmpty()) {
                throw new IllegalArgumentException("baseUrl cannot be null or empty");
            }
            this.baseUrl = withTrailingSlash(baseUrl);
            return this;
        }

        /** Sets the API key; {@code null} disables the authorization interceptor. */
        public Builder apiKey(String apiKey) {
            this.apiKey = apiKey;
            return this;
        }

        /** @throws IllegalArgumentException if {@code callTimeout} is null */
        public Builder callTimeout(Duration callTimeout) {
            if (callTimeout == null) {
                throw new IllegalArgumentException("callTimeout cannot be null");
            }
            this.callTimeout = callTimeout;
            return this;
        }

        /** @throws IllegalArgumentException if {@code connectTimeout} is null */
        public Builder connectTimeout(Duration connectTimeout) {
            if (connectTimeout == null) {
                throw new IllegalArgumentException("connectTimeout cannot be null");
            }
            this.connectTimeout = connectTimeout;
            return this;
        }

        /** @throws IllegalArgumentException if {@code readTimeout} is null */
        public Builder readTimeout(Duration readTimeout) {
            if (readTimeout == null) {
                throw new IllegalArgumentException("readTimeout cannot be null");
            }
            this.readTimeout = readTimeout;
            return this;
        }

        /** @throws IllegalArgumentException if {@code writeTimeout} is null */
        public Builder writeTimeout(Duration writeTimeout) {
            if (writeTimeout == null) {
                throw new IllegalArgumentException("writeTimeout cannot be null");
            }
            this.writeTimeout = writeTimeout;
            return this;
        }

        /** Routes traffic through a proxy of the given type at {@code ip:port}. */
        public Builder proxy(Proxy.Type type, String ip, int port) {
            this.proxy = new Proxy(type, new InetSocketAddress(ip, port));
            return this;
        }

        /** Routes traffic through {@code proxy}; {@code null} means no explicit proxy. */
        public Builder proxy(Proxy proxy) {
            this.proxy = proxy;
            return this;
        }

        /** Enables request logging. */
        public Builder logRequests() {
            return logRequests(true);
        }

        /** Enables or disables request logging; {@code null} is treated as {@code false}. */
        public Builder logRequests(Boolean logRequests) {
            this.logRequests = Boolean.TRUE.equals(logRequests);
            return this;
        }

        /** Enables response logging. */
        public Builder logResponses() {
            return logResponses(true);
        }

        /** Enables or disables response logging; {@code null} is treated as {@code false}. */
        public Builder logResponses(Boolean logResponses) {
            this.logResponses = Boolean.TRUE.equals(logResponses);
            return this;
        }

        /** Enables logging of streaming responses. */
        public Builder logStreamingResponses() {
            return logStreamingResponses(true);
        }

        /** Enables or disables streaming-response logging; {@code null} is treated as {@code false}. */
        public Builder logStreamingResponses(Boolean logStreamingResponses) {
            this.logStreamingResponses = Boolean.TRUE.equals(logStreamingResponses);
            return this;
        }

        /** Sets extra headers added to every request; copied when the client is built. */
        public Builder customHeaders(Map<String, String> customHeaders) {
            this.customHeaders = customHeaders;
            return this;
        }

        /** Builds the client from the current builder state. */
        public XinferenceClient build() {
            return new XinferenceClient(this);
        }
    }
}

