package io.github.javpower.vectorex.keynote.analysis;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.*;

/**
 * Character-level segmenter that labels every character of a string with one of the states
 * {@code B} (word begin), {@code M} (word middle), {@code E} (word end) or {@code S}
 * (single-character word) via Viterbi decoding, then cuts the string after each
 * {@code E}/{@code S} character.
 *
 * <p>Start and transition scores are hard-coded; per-character emission scores are read from
 * the classpath resource {@code /prob_emit.txt}. Scores are summed along a path, i.e. treated as
 * log-probabilities, and {@code MIN_FLOAT} stands in for any entry missing from the model.
 *
 * <p>The model tables are filled once in the private constructor and only read afterwards;
 * {@link #cut(String)} works on local arrays only.
 */
public class FinalSegmenter {
    private static volatile FinalSegmenter instance;
    private static final String PROB_EMIT = "/prob_emit.txt";
    private static final char[] STATES = {'B', 'M', 'E', 'S'};
    /** Score used for any start/transition/emission entry that is absent from the model. */
    private static final double MIN_FLOAT = -3.14e100;

    /** state -> (character -> emission score), loaded from {@code PROB_EMIT}. */
    private final Map<Character, Map<Character, Double>> emitProb;
    /** state -> score of the first character being in that state. */
    private final Map<Character, Double> startProb;
    /** from-state -> (to-state -> transition score). */
    private final Map<Character, Map<Character, Double>> transProb;
    /** state -> the states allowed to directly precede it. */
    private final Map<Character, char[]> prevStatus;

    private FinalSegmenter() {
        emitProb = new HashMap<>();
        startProb = new HashMap<>();
        transProb = new HashMap<>();
        prevStatus = new HashMap<>();

        loadModel();
    }

    /**
     * Returns the lazily created shared instance, building the model on first call.
     *
     * @return the singleton segmenter
     */
    public static FinalSegmenter getInstance() {
        if (instance == null) {
            synchronized (FinalSegmenter.class) {
                if (instance == null) {
                    instance = new FinalSegmenter();
                }
            }
        }
        return instance;
    }

    /** Fills the hard-coded predecessor/start/transition tables, then loads emission scores. */
    private void loadModel() {
        // Allowed predecessors of each state; must agree with the transProb entries below.
        prevStatus.put('B', new char[]{'E', 'S'});
        prevStatus.put('M', new char[]{'M', 'B'});
        prevStatus.put('E', new char[]{'B', 'M'});
        prevStatus.put('S', new char[]{'S', 'E'});

        // E and M get MIN_FLOAT: a sentence is effectively never allowed to start mid-word.
        startProb.put('B', -0.26268660809250016);
        startProb.put('E', MIN_FLOAT);
        startProb.put('M', MIN_FLOAT);
        startProb.put('S', -1.4652633398537678);

        // Transition scores, keyed by the state being left.
        Map<Character, Double> fromB = new HashMap<>();
        fromB.put('E', -0.510825623765990);
        fromB.put('M', -0.916290731874155);
        transProb.put('B', fromB);
        Map<Character, Double> fromM = new HashMap<>();
        fromM.put('E', -0.33344856811948514);
        fromM.put('M', -1.2603623820268226);
        transProb.put('M', fromM);
        Map<Character, Double> fromE = new HashMap<>();
        fromE.put('B', -0.5897149736854513);
        fromE.put('S', -0.8085250474669937);
        transProb.put('E', fromE);
        Map<Character, Double> fromS = new HashMap<>();
        fromS.put('B', -0.7211965654669841);
        fromS.put('S', -0.6658631448798212);
        transProb.put('S', fromS);

        loadEmitProb();
    }

    /**
     * Reads per-state emission scores from {@code PROB_EMIT} into {@link #emitProb}.
     *
     * <p>Format (UTF-8): a line holding only a state character opens that state's section;
     * each following {@code <char>\t<score>} line adds one entry to it. Blank lines, key-less
     * lines, and entries appearing before any section header are skipped.
     *
     * <p>Loading is best-effort: a missing resource or a read/number-format error is reported
     * on stderr, and whatever was loaded up to that point is kept.
     */
    private void loadEmitProb() {
        InputStream is = FinalSegmenter.class.getResourceAsStream(PROB_EMIT);
        if (is == null) {
            // Keep an empty emission table (every character then scores MIN_FLOAT) instead of
            // letting InputStreamReader throw NPE and abort construction of the singleton.
            System.err.println("Emission probability resource not found: " + PROB_EMIT);
            return;
        }
        int lineNo = 0;
        // try-with-resources closes the reader and, through it, the resource stream.
        try (BufferedReader reader =
                     new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8))) {
            Map<Character, Double> currentMap = null;
            String line;
            while ((line = reader.readLine()) != null) {
                lineNo++;
                String[] tokens = line.split("\t");
                // "" splits to [""] and "\t" splits to [] -- neither carries a key.
                if (tokens.length == 0 || tokens[0].isEmpty()) {
                    continue;
                }
                char key = tokens[0].charAt(0);
                if (tokens.length == 1) {
                    currentMap = new HashMap<>();
                    emitProb.put(key, currentMap);
                } else if (currentMap != null) {
                    currentMap.put(key, Double.parseDouble(tokens[1]));
                }
            }
        } catch (IOException | NumberFormatException e) {
            System.err.println("Failed to load emission probabilities: " + PROB_EMIT
                    + " (line " + lineNo + "): " + e);
        }
    }

    /**
     * Splits {@code sentence} into tokens using the highest-scoring B/M/E/S labelling.
     *
     * <p>A token ends after every character labelled {@code E} or {@code S}; the returned
     * tokens always concatenate back to {@code sentence}. Characters are UTF-16 {@code char}s,
     * so the halves of a surrogate pair may end up in different tokens.
     *
     * @param sentence text to segment; {@code null} or empty yields an empty list
     * @return the tokens in order, as a new mutable list
     */
    public List<String> cut(String sentence) {
        List<String> tokens = new ArrayList<>();
        if (sentence == null || sentence.isEmpty()) {
            return tokens;
        }

        int[] labels = viterbi(sentence);
        int begin = 0;
        for (int i = 0; i < labels.length; i++) {
            char state = STATES[labels[i]];
            if (state == 'E' || state == 'S') {
                tokens.add(sentence.substring(begin, i + 1));
                begin = i + 1;
            }
        }
        // Defensive: never drop trailing characters even if the path ended inside a word.
        if (begin < sentence.length()) {
            tokens.add(sentence.substring(begin));
        }
        return tokens;
    }

    /**
     * Viterbi decoding: returns, for each character of the non-empty {@code sentence}, the
     * index into {@link #STATES} of its state on the best path. Each state's predecessor is
     * restricted to {@link #prevStatus}, and the last character must be {@code E} or {@code S}.
     */
    private int[] viterbi(String sentence) {
        int length = sentence.length();
        double[][] score = new double[length][STATES.length];
        int[][] backPointer = new int[length][STATES.length];

        char first = sentence.charAt(0);
        for (int s = 0; s < STATES.length; s++) {
            char state = STATES[s];
            score[0][s] = startProb.getOrDefault(state, MIN_FLOAT) + emission(state, first);
        }

        for (int t = 1; t < length; t++) {
            char ch = sentence.charAt(t);
            for (int s = 0; s < STATES.length; s++) {
                char state = STATES[s];
                char[] predecessors = prevStatus.get(state);
                double emit = emission(state, ch); // does not depend on the predecessor
                double best = Double.NEGATIVE_INFINITY;
                int bestPrev = stateIndex(predecessors[0]);
                for (char prev : predecessors) {
                    int p = stateIndex(prev);
                    double candidate = score[t - 1][p] + transition(prev, state) + emit;
                    if (candidate > best) {
                        best = candidate;
                        bestPrev = p;
                    }
                }
                score[t][s] = best;
                backPointer[t][s] = bestPrev;
            }
        }

        // The sentence must end on a word boundary, so only E and S may be final; ties go to S.
        int end = stateIndex('E');
        int single = stateIndex('S');
        int current = score[length - 1][single] >= score[length - 1][end] ? single : end;

        int[] labels = new int[length];
        for (int t = length - 1; t >= 0; t--) {
            labels[t] = current;
            current = backPointer[t][current];
        }
        return labels;
    }

    /** Emission score of {@code ch} in {@code state}, or {@code MIN_FLOAT} if unknown. */
    private double emission(char state, char ch) {
        Map<Character, Double> table = emitProb.get(state);
        return table == null ? MIN_FLOAT : table.getOrDefault(ch, MIN_FLOAT);
    }

    /** Transition score {@code from -> to}, or {@code MIN_FLOAT} if not in the model. */
    private double transition(char from, char to) {
        Map<Character, Double> row = transProb.get(from);
        return row == null ? MIN_FLOAT : row.getOrDefault(to, MIN_FLOAT);
    }

    /** Position of {@code state} in {@link #STATES}. */
    private static int stateIndex(char state) {
        for (int i = 0; i < STATES.length; i++) {
            if (STATES[i] == state) {
                return i;
            }
        }
        throw new IllegalArgumentException("Unknown HMM state: " + state);
    }
}