package io.github.javpower.vectorex.keynote.bm25;


import io.github.javpower.vectorex.keynote.analysis.ScoredEntity;
import io.github.javpower.vectorex.keynote.analysis.SegMode;
import io.github.javpower.vectorex.keynote.analysis.TextSegmenter;

import java.util.*;
import java.util.stream.Collectors;

/**
 * BM25算法实现，用于计算文档与查询的相关性。
 */
public class BM25 {

    private final TextSegmenter segmenter = new TextSegmenter(SegMode.SEARCH);

    /**
     * 全部文档中的词汇集合。
     */
    private static List<String> corpusTerms = new ArrayList<>();

    /**
     * 文档集合。
     */
    private static List<List<String>> documentList = new ArrayList<>();

    /**
     * 存储文档ID与文档内容的映射。
     */
    private static Map<List<String>, String> corpusHashMap = new HashMap<>();

    /**
     * 自由参数，通常选择 k1 = 1.2。范围：1.2 ~ 2.0
     */
    private double k1;

    /**
     * 自由参数，通常选择 b = 0.75。范围：0 ~ 1.0
     */
    private double b;

    /**
     * 默认构造函数，k1 = 1.2, b = 0.75。
     */
    public BM25() {
        this(1.2, 0.75);
    }

    /**
     * 构造函数，主要用于加载数据以方便后续计算。
     *
     * @param k1 调整文档词频缩放的参数。
     * @param b  决定文档长度缩放程度的参数。
     */
    public BM25(double k1, double b) {
        if (k1 < 0) {
            throw new IllegalArgumentException("Negative k1 = " + k1);
        }
        if (b < 0 || b > 1) {
            throw new IllegalArgumentException("Invalid b = " + b);
        }
        this.k1 = k1;
        this.b = b;
    }

    /**
     * 使用 BM25 的方法计算文档词频。
     *
     * @param tfDocument 文档的词汇列表。
     * @param term       查询词汇。
     * @return 文档中词汇的 BM25 权重。
     */
    private double tf(List<String> tfDocument, String term) {
        double count = 0;
        int ld = tfDocument.stream().mapToInt(String::length).sum();
        int corpusSize = corpusTerms.stream().mapToInt(String::length).sum();
        double avgDocSize = (double) corpusSize / documentList.size();

        for (String word : tfDocument) {
            if (term.equalsIgnoreCase(word)) {
                count++;
            }
        }
        double freq = count / tfDocument.size();

        return (freq * (k1 + 1)) / (freq + k1 * (1 - b + b * ld / avgDocSize));
    }

    /**
     * 计算逆文档频率。
     *
     * @param term 查询词汇。
     * @return 词汇的逆文档频率。
     */
    private double idf(String term) {
        double count = 0;
        for (List<String> idfDoc : documentList) {
            if (idfDoc.stream().anyMatch(word -> term.equalsIgnoreCase(word))) {
                count++;
            }
        }
        return Math.log(1 + (documentList.size() - count + 0.5) / (count + 0.5));
    }

    /**
     * 计算查询与文档的相关性分数。
     *
     * @param queryTermList 查询词汇列表。
     * @return 文档ID及其分数的映射。
     */
    private Map<String, Double> score(List<String> queryTermList) {
        Map<String, Double> scoredDocument = new HashMap<>();
        for (List<String> docTerms : documentList) {
            double sumScore = 0.0;
            for (String queryTerm : queryTermList) {
                sumScore += tf(docTerms, queryTerm) * idf(queryTerm);
            }
            String docId = corpusHashMap.get(docTerms);
            scoredDocument.put(docId, sumScore);
        }
        return scoredDocument;
    }

    /**
     * 对文档进行 BM25 排序。
     *
     * @param query    查询内容。
     * @param documents 文档集合。
     * @param topNum   返回的文档数量。
     * @return 排序后的文档ID及其分数。
     */
    public Map<String, Double> rankBM25(String query, Map<String, String> documents, int topNum) {
        clear();
        List<String> segmentList = segByCharacter(query);
        for (Map.Entry<String, String> docEntry : documents.entrySet()) {
            String id = docEntry.getKey();
            String doc = docEntry.getValue();
            List<String> segs = segByCharacter(doc);
            documentList.add(segs);
            corpusTerms.addAll(segs);
            corpusHashMap.put(segs, id);
        }
        Map<String, Double> scoredDoc = score(segmentList);
        return getTopN(scoredDoc, topNum);
    }

    /**
     * 获取分数最高的 n 个文档。
     *
     * @param scoredDoc 文档分数映射。
     * @param topNum    返回数量。
     * @return 排序后的文档ID及其分数。
     */
    private Map<String, Double> getTopN(Map<String, Double> scoredDoc, int topNum) {
        PriorityQueue<ScoredEntity<String>> maxHeap = new PriorityQueue<>(Comparator.comparingDouble(entry -> -entry.getScore()));
        scoredDoc.forEach((id, score) -> maxHeap.add(new ScoredEntity(id, score)));
        Map<String, Double> topNDoc = new LinkedHashMap<>();
        for (int i = 0; i < Math.min(topNum, maxHeap.size()); i++) {
            ScoredEntity<String> entry = maxHeap.poll();
            topNDoc.put(entry.getItem(), entry.getScore());
        }
        return topNDoc;
    }

    /**
     * 对句子进行分词。
     * 待优化：通过词性过滤无意义的词。
     *
     * @param sentence 输入句子。
     * @return 分词结果。
     */
    private List<String> seg(String sentence) {
        List<String> segs = segmenter.process(sentence).stream()
                .map(token -> token.getWord().toLowerCase())
                .collect(Collectors.toList());
        return segs;
    }

    /**
     * 对于中文分词，把字当作一个词单元效果更好。
     *
     * @param sentence 输入句子。
     * @return 分词结果。
     */
    private List<String> segByCharacter(String sentence) {
        return sentence.chars()
                .mapToObj(c -> String.valueOf((char) c).toLowerCase())
                .collect(Collectors.toList());
    }

    /**
     * 清除缓存。
     */
    private void clear() {
        documentList.clear();
        corpusTerms.clear();
        corpusHashMap.clear();
    }
}