/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.model.dashscope;

import com.alibaba.dashscope.embeddings.TextEmbedding;
import com.alibaba.dashscope.embeddings.TextEmbeddingOutput;
import com.alibaba.dashscope.embeddings.TextEmbeddingParam;
import com.alibaba.dashscope.embeddings.TextEmbeddingResult;
import com.alibaba.dashscope.embeddings.TextEmbeddingResultItem;
import com.alibaba.dashscope.exception.NoApiKeyException;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.dashscope.spi.QwenEmbeddingModelBuilderFactory;
import dev.langchain4j.model.embedding.DimensionAwareEmbeddingModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.spi.ServiceHelper;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

/**
 * Embedding model backed by the DashScope {@link TextEmbedding} API (Qwen text-embedding models).
 *
 * <p>Each {@link TextSegment} may carry a {@value #TYPE_KEY} metadata entry. Segments whose value equals
 * {@value #TYPE_QUERY} (case-insensitive) are embedded with {@link TextEmbeddingParam.TextType#QUERY};
 * every other segment, including one without the entry, is embedded with
 * {@link TextEmbeddingParam.TextType#DOCUMENT}.
 */
public class QwenEmbeddingModel extends DimensionAwareEmbeddingModel {
    /** Metadata key that selects the DashScope text type for a segment. */
    public static final String TYPE_KEY = "type";
    /** {@link #TYPE_KEY} value marking a segment as a search query. */
    public static final String TYPE_QUERY = "query";
    /** {@link #TYPE_KEY} value marking a segment as a document (also the default for untagged segments). */
    public static final String TYPE_DOCUMENT = "document";
    /**
     * Maximum number of segments sent in a single DashScope request.
     * NOTE(review): presumably the endpoint's per-request text limit — confirm against DashScope docs.
     */
    private static final int BATCH_SIZE = 6;
    /** Model used when the caller does not specify one. */
    private static final String DEFAULT_MODEL_NAME = "text-embedding-v2";

    private final String apiKey;
    private final String modelName;
    private final TextEmbedding embedding;

    /**
     * Creates a model that calls DashScope's text-embedding endpoint.
     *
     * @param baseUrl   endpoint override; when {@code null} or blank the SDK's default {@link TextEmbedding} is used
     * @param apiKey    DashScope API key; required
     * @param modelName embedding model name; defaults to {@code "text-embedding-v2"} when {@code null} or blank
     * @throws IllegalArgumentException if {@code apiKey} is {@code null} or blank
     */
    public QwenEmbeddingModel(String baseUrl, String apiKey, String modelName) {
        if (Utils.isNullOrBlank(apiKey)) {
            throw new IllegalArgumentException("DashScope api key must be defined. It can be generated here: https://dashscope.console.aliyun.com/apiKey");
        }
        this.modelName = Utils.isNullOrBlank(modelName) ? DEFAULT_MODEL_NAME : modelName;
        this.apiKey = apiKey;
        this.embedding = Utils.isNullOrBlank(baseUrl) ? new TextEmbedding() : new TextEmbedding(baseUrl);
    }

    /** Returns {@code true} if the segment's {@value #TYPE_KEY} metadata equals {@value #TYPE_QUERY}, ignoring case. */
    private static boolean isQuery(TextSegment segment) {
        return TYPE_QUERY.equalsIgnoreCase(segment.metadata().getString(TYPE_KEY));
    }

    /**
     * Embeds {@code textSegments} with a single text type. The segments are split into requests of at
     * most {@link #BATCH_SIZE} segments, and the token usage reported by each request is summed.
     */
    private Response<List<Embedding>> embedTexts(List<TextSegment> textSegments, TextEmbeddingParam.TextType textType) {
        int size = textSegments.size();
        if (size == 0) {
            // Nothing to embed: skip the remote call instead of sending a request with no texts.
            return Response.from(Collections.<Embedding>emptyList());
        }
        if (size <= BATCH_SIZE) {
            return batchEmbedTexts(textSegments, textType);
        }
        List<Embedding> allEmbeddings = new ArrayList<>(size);
        TokenUsage allUsage = null;
        for (int start = 0; start < size; start += BATCH_SIZE) {
            List<TextSegment> batch = textSegments.subList(start, Math.min(size, start + BATCH_SIZE));
            Response<List<Embedding>> batchResponse = batchEmbedTexts(batch, textType);
            allEmbeddings.addAll(batchResponse.content());
            allUsage = TokenUsage.sum(allUsage, batchResponse.tokenUsage());
        }
        return Response.from(allEmbeddings, allUsage);
    }

    /**
     * Sends one DashScope request for {@code textSegments}. The embeddings are returned sorted by the
     * text index reported in the response.
     *
     * @throws RuntimeException wrapping the {@link NoApiKeyException} raised by the SDK
     */
    private Response<List<Embedding>> batchEmbedTexts(List<TextSegment> textSegments, TextEmbeddingParam.TextType textType) {
        List<String> texts = textSegments.stream().map(TextSegment::text).collect(Collectors.toList());
        TextEmbeddingParam param = TextEmbeddingParam.builder()
                .apiKey(apiKey)
                .model(modelName)
                .textType(textType)
                .texts(texts)
                .build();
        TextEmbeddingResult result;
        try {
            result = embedding.call(param);
        } catch (NoApiKeyException e) {
            throw new RuntimeException(e);
        }
        // The response may omit usage information; report it as unknown rather than failing with an NPE.
        TokenUsage usage = result.getUsage() == null ? null : new TokenUsage(result.getUsage().getTotalTokens());
        List<Embedding> embeddings = Optional.ofNullable(result.getOutput())
                .map(TextEmbeddingOutput::getEmbeddings)
                .orElse(Collections.emptyList())
                .stream()
                .sorted(Comparator.comparing(TextEmbeddingResultItem::getTextIndex))
                .map(item -> toEmbedding(item.getEmbedding()))
                .collect(Collectors.toList());
        return Response.from(embeddings, usage);
    }

    /** Converts DashScope's double-precision vector into a langchain4j {@link Embedding}. */
    private static Embedding toEmbedding(List<Double> vector) {
        List<Float> floats = new ArrayList<>(vector.size());
        for (Double value : vector) {
            floats.add(value.floatValue());
        }
        return Embedding.from(floats);
    }

    /**
     * Embeds all segments, preserving their order.
     *
     * <p>If no segment is a query, everything is embedded as {@code DOCUMENT} in batches. If every
     * segment is a query, everything is embedded as {@code QUERY} in batches. Otherwise each segment is
     * embedded individually with its own text type.
     */
    @Override
    public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
        if (textSegments.stream().noneMatch(QwenEmbeddingModel::isQuery)) {
            return embedTexts(textSegments, TextEmbeddingParam.TextType.DOCUMENT);
        }
        // Only batch as QUERY when *every* segment is a query; untagged segments must stay documents.
        if (textSegments.stream().allMatch(QwenEmbeddingModel::isQuery)) {
            return embedTexts(textSegments, TextEmbeddingParam.TextType.QUERY);
        }
        // Mixed collection: a request carries a single text type, so embed segment by segment.
        List<Embedding> embeddings = new ArrayList<>(textSegments.size());
        Integer tokens = null;
        for (TextSegment textSegment : textSegments) {
            TextEmbeddingParam.TextType textType = isQuery(textSegment)
                    ? TextEmbeddingParam.TextType.QUERY
                    : TextEmbeddingParam.TextType.DOCUMENT;
            Response<List<Embedding>> result = embedTexts(Collections.singletonList(textSegment), textType);
            embeddings.addAll(result.content());
            Integer count = result.tokenUsage() == null ? null : result.tokenUsage().inputTokenCount();
            if (count != null) {
                // Skipping null counts avoids an NPE from unboxing in the running sum.
                tokens = tokens == null ? count : tokens + count;
            }
        }
        return Response.from(embeddings, new TokenUsage(tokens));
    }

    /**
     * Returns a builder. If a {@link QwenEmbeddingModelBuilderFactory} is found through
     * {@link ServiceHelper#loadFactories}, the first one supplies the builder.
     */
    public static QwenEmbeddingModelBuilder builder() {
        Iterator<QwenEmbeddingModelBuilderFactory> factories =
                ServiceHelper.loadFactories(QwenEmbeddingModelBuilderFactory.class).iterator();
        if (factories.hasNext()) {
            return factories.next().get();
        }
        return new QwenEmbeddingModelBuilder();
    }

    /** Fluent builder for {@link QwenEmbeddingModel}; all values are passed to the constructor as-is. */
    public static class QwenEmbeddingModelBuilder {
        private String baseUrl;
        private String apiKey;
        private String modelName;

        /** Sets the endpoint override; {@code null}/blank selects the SDK default. */
        public QwenEmbeddingModelBuilder baseUrl(String baseUrl) {
            this.baseUrl = baseUrl;
            return this;
        }

        /** Sets the DashScope API key (required by {@link #build()}). */
        public QwenEmbeddingModelBuilder apiKey(String apiKey) {
            this.apiKey = apiKey;
            return this;
        }

        /** Sets the embedding model name; {@code null}/blank selects {@code "text-embedding-v2"}. */
        public QwenEmbeddingModelBuilder modelName(String modelName) {
            this.modelName = modelName;
            return this;
        }

        /**
         * Builds the model.
         *
         * @throws IllegalArgumentException if no non-blank API key was set
         */
        public QwenEmbeddingModel build() {
            return new QwenEmbeddingModel(this.baseUrl, this.apiKey, this.modelName);
        }

        /** Describes the builder state; the API key is masked so it never ends up in logs. */
        @Override
        public String toString() {
            return "QwenEmbeddingModel.QwenEmbeddingModelBuilder(baseUrl=" + this.baseUrl
                    + ", apiKey=" + (this.apiKey == null ? null : "***")
                    + ", modelName=" + this.modelName + ")";
        }
    }
}

