/*
 * Decompiled with CFR 0.152.
 */
package com.hexadevlabs.gpt4all;

import com.hexadevlabs.gpt4all.LLModelLibrary;
import com.hexadevlabs.gpt4all.PromptIsTooLongException;
import com.hexadevlabs.gpt4all.Util;
import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import jnr.ffi.Pointer;
import jnr.ffi.Runtime;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class LLModel
implements AutoCloseable {
    public static String LIBRARY_SEARCH_PATH;
    public static boolean OUTPUT_DEBUG;
    private static final Logger logger;
    public static final String GPT4ALL_VERSION = "2.4.11";
    protected static LLModelLibrary library;
    protected Pointer model;
    protected String modelName;

    public static GenerationConfig.Builder config() {
        return new GenerationConfig.Builder();
    }

    LLModel() {
    }

    public LLModel(Path modelPath) {
        logger.info("Java bindings for gpt4all version: 2.4.11");
        if (library == null) {
            if (LIBRARY_SEARCH_PATH != null) {
                library = Util.loadSharedLibrary(LIBRARY_SEARCH_PATH);
                library.llmodel_set_implementation_search_path(LIBRARY_SEARCH_PATH);
            } else {
                Path tempLibraryDirectory = Util.copySharedLibraries();
                library = Util.loadSharedLibrary(tempLibraryDirectory.toString());
                library.llmodel_set_implementation_search_path(tempLibraryDirectory.toString());
            }
        }
        this.modelName = modelPath.getFileName().toString();
        String modelPathAbs = modelPath.toAbsolutePath().toString();
        LLModelLibrary.LLModelError error = new LLModelLibrary.LLModelError(Runtime.getSystemRuntime());
        if (!Files.exists(modelPath, new LinkOption[0])) {
            throw new IllegalStateException("Model file does not exist: " + modelPathAbs);
        }
        if (!Files.isReadable(modelPath)) {
            throw new IllegalStateException("Model file cannot be read: " + modelPathAbs);
        }
        this.model = library.llmodel_model_create2(modelPathAbs, "auto", error);
        if (this.model == null) {
            throw new IllegalStateException("Could not load, gpt4all backend returned error: " + error.message);
        }
        library.llmodel_loadModel(this.model, modelPathAbs);
        if (!library.llmodel_isModelLoaded(this.model)) {
            throw new IllegalStateException("The model " + this.modelName + " could not be loaded");
        }
    }

    public void setThreadCount(int nThreads) {
        library.llmodel_setThreadCount(this.model, nThreads);
    }

    public int threadCount() {
        return library.llmodel_threadCount(this.model);
    }

    public String generate(String prompt, GenerationConfig generationConfig) {
        return this.generate(prompt, generationConfig, false);
    }

    public String generate(String prompt, GenerationConfig generationConfig, boolean streamToStdOut) {
        ByteArrayOutputStream bufferingForStdOutStream = new ByteArrayOutputStream();
        ByteArrayOutputStream bufferingForWholeGeneration = new ByteArrayOutputStream();
        LLModelLibrary.ResponseCallback responseCallback = LLModel.getResponseCallback(streamToStdOut, bufferingForStdOutStream, bufferingForWholeGeneration);
        library.llmodel_prompt(this.model, prompt, tokenID -> {
            if (OUTPUT_DEBUG) {
                System.out.println("token " + tokenID);
            }
            return true;
        }, responseCallback, isRecalculating -> {
            if (OUTPUT_DEBUG) {
                System.out.println("recalculating");
            }
            return isRecalculating;
        }, generationConfig);
        return bufferingForWholeGeneration.toString(StandardCharsets.UTF_8);
    }

    static LLModelLibrary.ResponseCallback getResponseCallback(boolean streamToStdOut, ByteArrayOutputStream bufferingForStdOutStream, ByteArrayOutputStream bufferingForWholeGeneration) {
        return (tokenID, response) -> {
            byte nextByte;
            if (OUTPUT_DEBUG) {
                System.out.print("Response token " + tokenID + " ");
            }
            if (tokenID == -1) {
                throw new PromptIsTooLongException(response.getString(0L, 1000, StandardCharsets.UTF_8));
            }
            long len = 0L;
            do {
                try {
                    nextByte = response.getByte(len);
                }
                catch (IndexOutOfBoundsException e) {
                    throw new RuntimeException("Empty array or not null terminated");
                }
                ++len;
                if (nextByte == 0) continue;
                bufferingForWholeGeneration.write(nextByte);
                if (!streamToStdOut) continue;
                bufferingForStdOutStream.write(nextByte);
                byte[] currentBytes = bufferingForStdOutStream.toByteArray();
                String validString = Util.getValidUtf8(currentBytes);
                if (validString == null) continue;
                System.out.print(validString);
                bufferingForStdOutStream.reset();
            } while (nextByte != 0);
            return true;
        };
    }

    public ChatCompletionResponse chatCompletion(List<Map<String, String>> messages, GenerationConfig generationConfig) {
        return this.chatCompletion(messages, generationConfig, false, false);
    }

    public ChatCompletionResponse chatCompletion(List<Map<String, String>> messages, GenerationConfig generationConfig, boolean streamToStdOut, boolean outputFullPromptToStdOut) {
        String fullPrompt = LLModel.buildPrompt(messages);
        if (outputFullPromptToStdOut) {
            System.out.print(fullPrompt);
        }
        String generatedText = this.generate(fullPrompt, generationConfig, streamToStdOut);
        ChatCompletionResponse response = new ChatCompletionResponse();
        response.model = this.modelName;
        Usage usage = new Usage();
        usage.promptTokens = fullPrompt.length();
        usage.completionTokens = generatedText.length();
        usage.totalTokens = fullPrompt.length() + generatedText.length();
        response.usage = usage;
        HashMap<String, String> message = new HashMap<String, String>();
        message.put("role", "assistant");
        message.put("content", generatedText);
        response.choices = List.of(message);
        return response;
    }

    protected static String buildPrompt(List<Map<String, String>> messages) {
        StringBuilder fullPrompt = new StringBuilder();
        for (Map<String, String> message : messages) {
            if (!"system".equals(message.get("role"))) continue;
            String systemMessage = message.get("content") + "\n";
            fullPrompt.append(systemMessage);
        }
        fullPrompt.append("### Instruction: \nThe prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.\n### Prompt: ");
        for (Map<String, String> message : messages) {
            if ("user".equals(message.get("role"))) {
                String userMessage = "\n" + message.get("content");
                fullPrompt.append(userMessage);
            }
            if (!"assistant".equals(message.get("role"))) continue;
            String assistantMessage = "\n### Response: " + message.get("content");
            fullPrompt.append(assistantMessage);
        }
        fullPrompt.append("\n### Response:");
        return fullPrompt.toString();
    }

    @Override
    public void close() throws Exception {
        library.llmodel_model_destroy(this.model);
    }

    static {
        OUTPUT_DEBUG = false;
        logger = LoggerFactory.getLogger(LLModel.class);
    }

    public static class GenerationConfig
    extends LLModelLibrary.LLModelPromptContext {
        private GenerationConfig() {
            super(Runtime.getSystemRuntime());
            this.logits_size.set(0L);
            this.tokens_size.set(0L);
            this.n_past.set(0L);
            this.n_ctx.set(1024L);
            this.n_predict.set(128L);
            this.top_k.set(40L);
            this.top_p.set((Number)0.95);
            this.temp.set((Number)0.28);
            this.n_batch.set(8L);
            this.repeat_penalty.set((Number)1.1);
            this.repeat_last_n.set(10L);
            this.context_erase.set((Number)0.55);
        }

        public static class Builder {
            private final GenerationConfig configToBuild = new GenerationConfig();

            public Builder withNPast(int n_past) {
                this.configToBuild.n_past.set((long)n_past);
                return this;
            }

            public Builder withNCtx(int n_ctx) {
                this.configToBuild.n_ctx.set((long)n_ctx);
                return this;
            }

            public Builder withNPredict(int n_predict) {
                this.configToBuild.n_predict.set((long)n_predict);
                return this;
            }

            public Builder withTopK(int top_k) {
                this.configToBuild.top_k.set((long)top_k);
                return this;
            }

            public Builder withTopP(float top_p) {
                this.configToBuild.top_p.set(top_p);
                return this;
            }

            public Builder withTemp(float temp) {
                this.configToBuild.temp.set(temp);
                return this;
            }

            public Builder withNBatch(int n_batch) {
                this.configToBuild.n_batch.set((long)n_batch);
                return this;
            }

            public Builder withRepeatPenalty(float repeat_penalty) {
                this.configToBuild.repeat_penalty.set(repeat_penalty);
                return this;
            }

            public Builder withRepeatLastN(int repeat_last_n) {
                this.configToBuild.repeat_last_n.set((long)repeat_last_n);
                return this;
            }

            public Builder withContextErase(float context_erase) {
                this.configToBuild.context_erase.set(context_erase);
                return this;
            }

            public GenerationConfig build() {
                return this.configToBuild;
            }
        }
    }

    public static class ChatCompletionResponse {
        public String model;
        public Usage usage;
        public List<Map<String, String>> choices;
    }

    public static class Usage {
        public int promptTokens;
        public int completionTokens;
        public int totalTokens;
    }
}

