/*
 * Decompiled with CFR 0.152.
 */
package io.quarkiverse.langchain4j.jlama;

import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.ModelSupport;
import com.github.tjake.jlama.model.functions.Generator;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.SafeTensorSupport;
import com.github.tjake.jlama.safetensors.prompt.Function;
import com.github.tjake.jlama.safetensors.prompt.Tool;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
import dev.langchain4j.model.chat.request.json.JsonSchemaElementHelper;
import dev.langchain4j.model.output.FinishReason;
import io.quarkiverse.langchain4j.jlama.JlamaModelRegistry;
import java.io.File;
import java.nio.file.Path;
import java.util.Map;
import java.util.Optional;

/**
 * A model entry backed by a {@link JlamaModelRegistry}, plus static helpers that convert between
 * langchain4j and Jlama types (tool specifications and finish reasons).
 */
public class JlamaModel {
    private final JlamaModelRegistry registry;
    private final ModelSupport.ModelType modelType;
    private final String modelName;
    private final Optional<String> owner;
    private final String modelId;
    private final boolean isLocal;

    JlamaModel(JlamaModelRegistry registry, ModelSupport.ModelType modelType, String modelName,
            Optional<String> owner, String modelId, boolean isLocal) {
        this.registry = registry;
        this.modelType = modelType;
        this.modelName = modelName;
        this.owner = owner;
        this.modelId = modelId;
        this.isLocal = isLocal;
    }

    /** Returns the Jlama model type supplied at construction. */
    ModelSupport.ModelType getModelType() {
        return modelType;
    }

    /** Returns the model name; {@link Loader#load()} uses it to locate files under the registry cache path. */
    String getModelName() {
        return modelName;
    }

    /** Returns the owner supplied at construction, or an empty {@code Optional} if none was given. */
    Optional<String> getOwner() {
        return owner;
    }

    /** Returns the model id supplied at construction. */
    String getModelId() {
        return modelId;
    }

    /** Returns the {@code isLocal} flag supplied at construction. */
    boolean isLocal() {
        return isLocal;
    }

    /**
     * Creates a {@link Loader} bound to this model's registry and name.
     *
     * @return a new loader with default settings (I8 working quantization, full generation)
     */
    Loader loader() {
        return new Loader(registry, modelName);
    }

    /**
     * Converts a langchain4j {@link ToolSpecification} into a Jlama {@link Tool}.
     *
     * <p>Each property of the specification's parameter schema becomes a Jlama function
     * parameter. The parameter is flagged as required when its name appears in the schema's
     * required list.
     *
     * @param toolSpecification the langchain4j tool description to convert
     * @return the equivalent Jlama tool
     */
    static Tool toTool(ToolSpecification toolSpecification) {
        Function.Builder builder = Function.builder()
                .name(toolSpecification.name())
                .description(toolSpecification.description());

        var parameters = toolSpecification.parameters();
        if (parameters != null) {
            var properties = parameters.properties();
            var required = parameters.required();
            // Defensively treat a missing property map / required list as empty instead of
            // failing with a NullPointerException on tools that declare no required parameters.
            if (properties != null) {
                for (Map.Entry<String, JsonSchemaElement> property : properties.entrySet()) {
                    String name = property.getKey();
                    boolean isRequired = required != null && required.contains(name);
                    builder.addParameter(name, JsonSchemaElementHelper.toMap(property.getValue()), isRequired);
                }
            }
        }
        return Tool.from(builder.build());
    }

    /**
     * Maps a Jlama generator finish reason to the corresponding langchain4j {@link FinishReason}.
     *
     * @param reason the Jlama finish reason
     * @return the langchain4j finish reason
     * @throws IllegalArgumentException if {@code reason} is a constant not handled here
     */
    static FinishReason toFinishReason(Generator.FinishReason reason) {
        // Unqualified enum constants are valid case labels on every JDK that supports switch
        // expressions; qualified labels (Generator.FinishReason.X) require Java 21+.
        return switch (reason) {
            case STOP_TOKEN -> FinishReason.STOP;
            case MAX_TOKENS -> FinishReason.LENGTH;
            case ERROR -> FinishReason.OTHER;
            case TOOL_CALL -> FinishReason.TOOL_EXECUTION;
            // Kept so that constants added to the third-party enum later fail with a clear message.
            default -> throw new IllegalArgumentException("Unknown reason: " + reason);
        };
    }

    /**
     * Fluent configuration for loading this model with Jlama's {@link ModelSupport}.
     *
     * <p>Defaults: working quantization {@link DType#I8}, inference type
     * {@link AbstractModel.InferenceType#FULL_GENERATION}, no model quantization, no explicit
     * thread count and no working directory.
     */
    static class Loader {
        private final JlamaModelRegistry registry;
        private final String modelName;
        private Path workingDirectory;
        private DType workingQuantizationType = DType.I8;
        private DType quantizationType;
        private Integer threadCount;
        private AbstractModel.InferenceType inferenceType = AbstractModel.InferenceType.FULL_GENERATION;

        private Loader(JlamaModelRegistry registry, String modelName) {
            this.registry = registry;
            this.modelName = modelName;
        }

        /** Requests {@link DType#Q4} quantization of the model weights when loading. */
        public Loader quantized() {
            this.quantizationType = DType.Q4;
            return this;
        }

        /** Sets the working quantization type (default {@link DType#I8}). */
        public Loader workingQuantizationType(DType workingQuantizationType) {
            this.workingQuantizationType = workingQuantizationType;
            return this;
        }

        /** Sets the working directory passed to Jlama; {@code null} means none is passed. */
        public Loader workingDirectory(Path workingDirectory) {
            this.workingDirectory = workingDirectory;
            return this;
        }

        /** Sets the thread count; {@code null} leaves the choice to Jlama. */
        public Loader threadCount(Integer threadCount) {
            this.threadCount = threadCount;
            return this;
        }

        /** Sets the inference type (default {@code FULL_GENERATION}). */
        public Loader inferenceType(AbstractModel.InferenceType inferenceType) {
            this.inferenceType = inferenceType;
            return this;
        }

        /**
         * Loads the model from {@code <registry model cache path>/<modelName>} using the
         * configured settings, reading weights through {@link SafeTensorSupport#loadWeights}.
         *
         * @return the loaded Jlama model
         */
        public AbstractModel load() {
            File modelPath = new File(registry.getModelCachePath().toFile(), modelName);
            File workingDir = workingDirectory == null ? null : workingDirectory.toFile();
            return ModelSupport.loadModel(
                    inferenceType,
                    modelPath,
                    workingDir,
                    // presumably the working-memory dtype — confirm against ModelSupport.loadModel
                    DType.F32,
                    workingQuantizationType,
                    Optional.ofNullable(quantizationType),
                    Optional.ofNullable(threadCount),
                    Optional.empty(),
                    SafeTensorSupport::loadWeights);
        }
    }
}

