/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.search.llm;

import ai.vespa.llm.LanguageModel;
import ai.vespa.llm.completion.Prompt;
import ai.vespa.llm.completion.StringPrompt;
import ai.vespa.search.llm.LLMSearcher;
import ai.vespa.search.llm.LlmSearcherConfig;
import com.yahoo.api.annotations.Beta;
import com.yahoo.component.annotation.Inject;
import com.yahoo.component.provider.ComponentRegistry;
import com.yahoo.search.Query;
import com.yahoo.search.Result;
import com.yahoo.search.result.Hit;
import com.yahoo.search.result.HitGroup;
import com.yahoo.search.searchchain.Execution;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import java.util.logging.Logger;
import java.util.stream.Collectors;

@Beta
/**
 * Retrieval-augmented generation searcher.
 *
 * <p>Runs the query down the search chain and fills the returned hits. It then renders those hits as a
 * textual "context" and builds a prompt combining that context with the query. The prompt is handed to
 * {@code complete}, inherited from {@link LLMSearcher}.
 */
public class RAGSearcher
extends LLMSearcher {
    private static final Logger log = Logger.getLogger(RAGSearcher.class.getName());

    /** Query property controlling context injection; see {@link #SKIP_CONTEXT}. */
    private static final String CONTEXT_PROPERTY = "context";

    /** Query property holding a comma-separated list of hit field names to render into the context. */
    private static final String FIELDS_TO_INCLUDE_PROPERTY = "fields";

    /** Value of {@link #CONTEXT_PROPERTY} that disables context injection entirely. */
    private static final String SKIP_CONTEXT = "skip";

    /** Prompt placeholder replaced by the query string. */
    private static final String QUERY_PLACEHOLDER = "@query";

    /** Prompt placeholder replaced by the rendered hits. */
    private static final String CONTEXT_PLACEHOLDER = "{context}";

    @Inject
    public RAGSearcher(LlmSearcherConfig config, ComponentRegistry<LanguageModel> languageModels) {
        super(config, languageModels);
        log.info("Starting " + RAGSearcher.class.getName() + " with language model " + config.providerId());
    }

    /**
     * Executes the query, fills the hits, and completes a prompt built from the query and those hits.
     *
     * @param query the incoming query
     * @param execution the execution used to search, fill, and complete
     * @return the result produced by {@code complete}
     */
    @Override
    public Result search(Query query, Execution execution) {
        Result result = execution.search(query);
        execution.fill(result);
        return complete(query, buildPrompt(query, result), result, execution);
    }

    /**
     * Builds the prompt sent to the language model.
     *
     * <p>The template comes from {@code getPrompt}. Every {@code @query} in it is replaced by the query
     * string. Unless the {@code context} property is {@code skip}, every {@code {context}} is replaced by
     * the rendered hits. If the template has no {@code {context}} placeholder, the context is prepended on
     * its own line instead.
     *
     * <p>Both placeholders are substituted in a single pass over the template. A query string or a
     * document that itself contains {@code @query} or {@code {context}} is therefore inserted verbatim,
     * not treated as template syntax.
     *
     * @param query the query whose prompt, query string and properties are used
     * @param result the (filled) result whose hits form the context
     * @return the prompt to complete
     */
    protected Prompt buildPrompt(Query query, Result result) {
        String template = (String) getPrompt(query);

        String queryString = "";
        if (template.contains(QUERY_PLACEHOLDER)) {
            String original = query.getModel().getQueryString();
            // String.replace used to throw on a null replacement; substitute nothing rather than "null".
            queryString = original == null ? "" : original;
        }

        String context = null; // null means: leave {context} untouched
        if (!SKIP_CONTEXT.equals(lookupProperty(CONTEXT_PROPERTY, query))) {
            // Decide from the template alone, so user text containing "{context}" cannot suppress this.
            if (!template.contains(CONTEXT_PLACEHOLDER)) {
                template = CONTEXT_PLACEHOLDER + "\n" + template;
            }
            context = buildContext(result);
        }
        return StringPrompt.from(substitutePlaceholders(template, queryString, context));
    }

    /**
     * Replaces placeholders in a single left-to-right pass over the template.
     *
     * <p>Each {@code @query} becomes {@code queryString}. Each {@code {context}} becomes {@code context},
     * unless {@code context} is null. Inserted text is never rescanned.
     *
     * @param template the prompt template
     * @param queryString the text to substitute for {@code @query}; must not be null
     * @param context the text to substitute for {@code {context}}, or null to keep that placeholder as-is
     * @return the template with placeholders substituted
     */
    private static String substitutePlaceholders(String template, String queryString, String context) {
        StringBuilder out = new StringBuilder(template.length());
        int pos = 0;
        while (pos < template.length()) {
            if (template.startsWith(QUERY_PLACEHOLDER, pos)) {
                out.append(queryString);
                pos += QUERY_PLACEHOLDER.length();
            } else if (context != null && template.startsWith(CONTEXT_PLACEHOLDER, pos)) {
                out.append(context);
                pos += CONTEXT_PLACEHOLDER.length();
            } else {
                out.append(template.charAt(pos++));
            }
        }
        return out.toString();
    }

    /**
     * Renders the hits of a result as numbered documents.
     *
     * <p>Each hit becomes a {@code document [n]:} header, starting at 1, followed by one
     * {@code name: value} line per field and a blank line. Only fields listed in the {@code fields} query
     * property are rendered; when that list is empty, all fields are rendered.
     *
     * @param result the result whose hits are rendered
     * @return the rendered context text
     */
    private String buildContext(Result result) {
        Set<String> fieldsToInclude = getFieldsToInclude(result.getQuery());
        StringBuilder sb = new StringBuilder();
        HitGroup hits = result.hits();
        int counter = 1;
        for (Hit hit : hits) {
            sb.append("document [").append(counter++).append("]:\n");
            hit.fields().forEach((key, value) -> {
                if (fieldsToInclude.isEmpty() || fieldsToInclude.contains(key)) {
                    sb.append(key).append(": ").append(value).append("\n");
                }
            });
            sb.append("\n");
        }
        return sb.toString();
    }

    /**
     * Parses the {@code fields} query property into a set of field names.
     *
     * <p>Names are trimmed, and blank entries (e.g. from {@code "a,,b"} or an empty property) are dropped.
     * A blank property therefore means "no restriction", instead of a set containing only {@code ""} that
     * would exclude every field.
     *
     * @param query the query to read the property from
     * @return the field names to include; empty when all fields should be included
     */
    private Set<String> getFieldsToInclude(Query query) {
        String includedFields = lookupProperty(FIELDS_TO_INCLUDE_PROPERTY, query);
        if (includedFields == null) {
            return new HashSet<String>();
        }
        return Arrays.stream(includedFields.split(","))
                .map(String::trim)
                .filter(name -> !name.isEmpty())
                .collect(Collectors.toSet());
    }
}

