package io.github.javpower.vectorex.keynote.tfidf;


import io.github.javpower.vectorex.keynote.analysis.ScoredEntity;
import io.github.javpower.vectorex.keynote.analysis.SegMode;
import io.github.javpower.vectorex.keynote.analysis.SegToken;
import io.github.javpower.vectorex.keynote.analysis.TextSegmenter;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.*;

/**
 * TF-IDF分析器类，用于从文本中提取关键词并计算其TF-IDF值。
 */
public class TFIDF {

    // IDF值的映射表，存储词语及其对应的IDF值
    private static Map<String, Double> idfMap;

    // 停用词集合，用于过滤无意义的词语
    private static Set<String> stopWordsSet;

    // IDF值的中位数，用于处理未登录词
    private static double idfMedian;

    /**
     * 分析文本内容，提取关键词并计算其TF-IDF值，返回前topN个关键词。
     *
     * @param content 待分析的文本内容
     * @param topN    返回的关键词数量
     * @return 包含关键词及其TF-IDF值的列表，按TF-IDF值降序排列
     */
    public List<ScoredEntity> analyze(String content, int topN) {
        // 初始化停用词集合和IDF映射表
        initializeResources();

        // 计算文本中每个词语的TF值
        Map<String, Double> tfMap = calculateTermFrequency(content);

        // 计算每个词语的TF-IDF值并生成关键词列表
        List<ScoredEntity> keywordList = calculateTFIDFKeywords(tfMap);

        // 对关键词列表按TF-IDF值降序排序
        keywordList.sort((k1, k2) -> Double.compare(k2.getScore(), k1.getScore()));

        // 返回前topN个关键词
        return keywordList.size() > topN ? keywordList.subList(0, topN) : keywordList;
    }

    /**
     * 初始化停用词集合和IDF映射表。
     */
    private void initializeResources() {
        if (stopWordsSet == null) {
            stopWordsSet = new HashSet<>();
            loadStopWords(stopWordsSet, this.getClass().getResourceAsStream("/stop_words.txt"));
        }
        if (idfMap == null) {
            idfMap = new HashMap<>();
            loadIDFMap(idfMap, this.getClass().getResourceAsStream("/idf_dict.txt"));
        }
    }

    /**
     * 计算文本中每个词语的TF值。
     *
     * @param content 待分析的文本内容
     * @return 包含词语及其TF值的映射表
     */
    private Map<String, Double> calculateTermFrequency(String content) {
        Map<String, Double> tfMap = new HashMap<>();
        if (content == null || content.trim().isEmpty()) {
            return tfMap;
        }

        // 使用TextSegmenter对文本进行分词
        TextSegmenter segmenter = new TextSegmenter(SegMode.INDEX);
        List<SegToken> tokens = segmenter.process(content);

        // 统计每个词语的频率
        Map<String, Integer> freqMap = new HashMap<>();
        int wordSum = 0;
        for (SegToken token : tokens) {
            String word = token.getWord();
            if (!stopWordsSet.contains(word) && word.length() > 1) {
                wordSum++;
                freqMap.put(word, freqMap.getOrDefault(word, 0) + 1);
            }
        }
        // 计算每个词语的TF值
        for (Map.Entry<String, Integer> entry : freqMap.entrySet()) {
            tfMap.put(entry.getKey(), (double) entry.getValue() / wordSum);
        }

        return tfMap;
    }

    /**
     * 计算每个词语的TF-IDF值并生成关键词列表。
     *
     * @param tfMap 包含词语及其TF值的映射表
     * @return 包含关键词及其TF-IDF值的列表
     */
    private List<ScoredEntity> calculateTFIDFKeywords(Map<String, Double> tfMap) {
        List<ScoredEntity> keywordList = new ArrayList<>();
        for (Map.Entry<String, Double> entry : tfMap.entrySet()) {
            String word = entry.getKey();
            double tfValue = entry.getValue();
            double idfValue = idfMap.getOrDefault(word, idfMedian); // 使用中位数处理未登录词
            keywordList.add(new ScoredEntity(word, tfValue * idfValue));
        }
        return keywordList;
    }

    /**
     * 加载停用词集合。
     *
     * @param set 停用词集合
     * @param in  停用词文件的输入流
     */
    private void loadStopWords(Set<String> set, InputStream in) {
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(in))) {
            String line;
            while ((line = reader.readLine()) != null) {
                set.add(line.trim());
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * 加载IDF映射表并计算IDF值的中位数。
     *
     * @param map IDF映射表
     * @param in  IDF字典文件的输入流
     */
    private void loadIDFMap(Map<String, Double> map, InputStream in) {
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(in))) {
            String line;
            while ((line = reader.readLine()) != null) {
                String[] kv = line.trim().split(" ");
                map.put(kv[0], Double.parseDouble(kv[1]));
            }

            // 计算IDF值的中位数
            List<Double> idfList = new ArrayList<>(map.values());
            Collections.sort(idfList);
            idfMedian = idfList.get(idfList.size() / 2);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

}