/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.model.jlama;

import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.ModelSupport;
import com.github.tjake.jlama.model.bert.BertModel;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.RetryUtils;
import dev.langchain4j.model.embedding.DimensionAwareEmbeddingModel;
import dev.langchain4j.model.jlama.JlamaModel;
import dev.langchain4j.model.jlama.JlamaModelRegistry;
import dev.langchain4j.model.jlama.spi.JlamaEmbeddingModelBuilderFactory;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.spi.ServiceHelper;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

public class JlamaEmbeddingModel
extends DimensionAwareEmbeddingModel {
    private final BertModel model;

    public JlamaEmbeddingModel(Path modelCachePath, String modelName, String authToken, Integer threadCount, Boolean quantizeModelAtRuntime, Path workingDirectory) {
        JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath);
        JlamaModel jlamaModel = (JlamaModel)RetryUtils.withRetry(() -> registry.downloadModel(modelName, Optional.ofNullable(authToken)), (int)3);
        if (jlamaModel.getModelType() != ModelSupport.ModelType.BERT) {
            throw new IllegalArgumentException("Model type must be BERT");
        }
        JlamaModel.Loader loader = jlamaModel.loader();
        if (quantizeModelAtRuntime != null && quantizeModelAtRuntime.booleanValue()) {
            loader = loader.quantized();
        }
        if (threadCount != null) {
            loader = loader.threadCount(threadCount);
        }
        if (workingDirectory != null) {
            loader = loader.workingDirectory(workingDirectory);
        }
        loader = loader.inferenceType(AbstractModel.InferenceType.FORWARD_PASS);
        this.model = (BertModel)loader.load();
        this.dimension = this.model.getConfig().embeddingLength;
    }

    public static JlamaEmbeddingModelBuilder builder() {
        Iterator iterator = ServiceHelper.loadFactories(JlamaEmbeddingModelBuilderFactory.class).iterator();
        if (iterator.hasNext()) {
            JlamaEmbeddingModelBuilderFactory factory = (JlamaEmbeddingModelBuilderFactory)iterator.next();
            return (JlamaEmbeddingModelBuilder)factory.get();
        }
        return new JlamaEmbeddingModelBuilder();
    }

    public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
        ArrayList embeddings = new ArrayList();
        textSegments.forEach(textSegment -> embeddings.add(Embedding.from((float[])this.model.embed(textSegment.text()))));
        return Response.from(embeddings);
    }

    public static class JlamaEmbeddingModelBuilder {
        private Path modelCachePath;
        private String modelName;
        private String authToken;
        private Integer threadCount;
        private Boolean quantizeModelAtRuntime;
        private Path workingDirectory;

        public JlamaEmbeddingModelBuilder modelCachePath(Path modelCachePath) {
            this.modelCachePath = modelCachePath;
            return this;
        }

        public JlamaEmbeddingModelBuilder modelName(String modelName) {
            this.modelName = modelName;
            return this;
        }

        public JlamaEmbeddingModelBuilder authToken(String authToken) {
            this.authToken = authToken;
            return this;
        }

        public JlamaEmbeddingModelBuilder threadCount(Integer threadCount) {
            this.threadCount = threadCount;
            return this;
        }

        public JlamaEmbeddingModelBuilder quantizeModelAtRuntime(Boolean quantizeModelAtRuntime) {
            this.quantizeModelAtRuntime = quantizeModelAtRuntime;
            return this;
        }

        public JlamaEmbeddingModelBuilder workingDirectory(Path workingDirectory) {
            this.workingDirectory = workingDirectory;
            return this;
        }

        public JlamaEmbeddingModel build() {
            return new JlamaEmbeddingModel(this.modelCachePath, this.modelName, this.authToken, this.threadCount, this.quantizeModelAtRuntime, this.workingDirectory);
        }

        public String toString() {
            return "JlamaEmbeddingModel.JlamaEmbeddingModelBuilder(modelCachePath=" + String.valueOf(this.modelCachePath) + ", modelName=" + this.modelName + ", authToken=" + this.authToken + ", threadCount=" + this.threadCount + ", quantizeModelAtRuntime=" + this.quantizeModelAtRuntime + ", workingDirectory=" + String.valueOf(this.workingDirectory) + ")";
        }
    }
}

