/*
 * Decompiled with CFR 0.152.
 */
package org.predict4all.nlp.prediction;

import java.io.IOException;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.predict4all.nlp.Separator;
import org.predict4all.nlp.ngram.NGramWordPredictorUtils;
import org.predict4all.nlp.ngram.dictionary.AbstractNGramDictionary;
import org.predict4all.nlp.ngram.trie.AbstractNGramTrieNode;
import org.predict4all.nlp.ngram.trie.DynamicNGramTrieNode;
import org.predict4all.nlp.parser.Tokenizer;
import org.predict4all.nlp.parser.matcher.TokenConverter;
import org.predict4all.nlp.parser.token.Token;
import org.predict4all.nlp.prediction.PredictionParameter;
import org.predict4all.nlp.prediction.WordPrediction;
import org.predict4all.nlp.prediction.WordPredictionResult;
import org.predict4all.nlp.prediction.model.AbstractPredictionToCompute;
import org.predict4all.nlp.prediction.model.DoublePredictionToCompute;
import org.predict4all.nlp.prediction.model.UniquePredictionToCompute;
import org.predict4all.nlp.trainer.configuration.TrainingConfiguration;
import org.predict4all.nlp.utils.BiIntegerKey;
import org.predict4all.nlp.utils.Pair;
import org.predict4all.nlp.utils.Predict4AllUtils;
import org.predict4all.nlp.utils.Triple;
import org.predict4all.nlp.words.NextWord;
import org.predict4all.nlp.words.WordDictionary;
import org.predict4all.nlp.words.WordPrefixDetected;
import org.predict4all.nlp.words.WordPrefixDetector;
import org.predict4all.nlp.words.correction.WordCorrectionGenerator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Word predictor that combines a static n-gram dictionary with an optional dynamic (user trained) n-gram dictionary.
 * <p>
 * A prediction works from the text before the caret:
 * <ol>
 * <li>the text is tokenized;</li>
 * <li>the longest matching word prefix is detected;</li>
 * <li>candidate words are listed;</li>
 * <li>candidates are scored by interpolating static and dynamic n-gram probabilities;</li>
 * <li>results are sorted, deduplicated by case and normalized by the sum of scores.</li>
 * </ol>
 * <p>
 * NOTE(review): {@code DEBUG_DF} is a shared {@link DecimalFormat}, which is not thread-safe. Concurrent predictions
 * with debug information enabled may produce garbled debug strings.
 */
public class WordPredictor {
    private static final Logger LOGGER = LoggerFactory.getLogger(WordPredictor.class);
    /** Formatter for the numbers written into per-prediction debug information. */
    private static final DecimalFormat DEBUG_DF = new DecimalFormat("0.0000");
    /** Candidate lists are requested with {@code wantedCount * WANTED_COUNT_FACTOR} entries before ranking. */
    private static final int WANTED_COUNT_FACTOR = 3;
    /** Number of predictions requested by the overloads that do not take a wanted count. */
    private static final int DEFAULT_WANTED_COUNT = 5;
    /** Shared "exclude nothing" set used when callers pass {@code null}: {@code contains} always returns {@code false}. */
    private static final Set<Integer> EMPTY_INT_SET = new HashSet<Integer>(1) {

        @Override
        public boolean contains(Object o) {
            return false;
        }
    };
    /** Only this many trailing characters of the text before the caret are tokenized for prediction. */
    private static final int NGRAM_MAX_LAST_TEXT_LENGTH = 70;
    /** Configuration handed to {@code computeD} when updating the dynamic model probabilities. */
    private static final TrainingConfiguration TRAINING_CONFIGURATION = TrainingConfiguration.defaultConfiguration();
    private final NGramWordPredictorUtils ngramWordPredictorUtils;
    private final WordPrefixDetector wordPrefixDetector;
    private final Tokenizer tokenizer;
    private final TokenConverter termConverter;
    private final WordCorrectionGenerator wordCorrectionGenerator;
    private final WordDictionary wordDictionary;
    private final PredictionParameter predictionParameter;
    private final AbstractNGramDictionary<? extends AbstractNGramTrieNode<?>> staticNgramDictionary;
    private final AbstractNGramDictionary<? extends DynamicNGramTrieNode> dynamicNGramDictionary;

    /**
     * Creates a predictor.
     *
     * @param predictionParameter    prediction parameters (required)
     * @param wordDictionary         word dictionary (required)
     * @param staticNgramDictionary  static n-gram dictionary (required)
     * @param dynamicNGramDictionary dynamic n-gram dictionary, may be {@code null} to disable dynamic prediction/training
     */
    public WordPredictor(PredictionParameter predictionParameter, WordDictionary wordDictionary, AbstractNGramDictionary<? extends AbstractNGramTrieNode<?>> staticNgramDictionary, AbstractNGramDictionary<? extends DynamicNGramTrieNode> dynamicNGramDictionary) {
        this.predictionParameter = Predict4AllUtils.checkNull(predictionParameter, "Word predictor requires PredictionParameter to work");
        this.wordDictionary = Predict4AllUtils.checkNull(wordDictionary, "Word predictor requires WordDictionary to work");
        this.staticNgramDictionary = Predict4AllUtils.checkNull(staticNgramDictionary, "Word predictor requires a static ngram dictionary to work");
        this.dynamicNGramDictionary = dynamicNGramDictionary;
        this.wordCorrectionGenerator = new WordCorrectionGenerator(wordDictionary, staticNgramDictionary, this.predictionParameter);
        this.ngramWordPredictorUtils = new NGramWordPredictorUtils(wordDictionary, predictionParameter);
        this.wordPrefixDetector = new WordPrefixDetector(wordDictionary, this.wordCorrectionGenerator, predictionParameter);
        this.termConverter = new TokenConverter(this.predictionParameter.getLanguageModel().getTokenMatchersForNGram());
        this.tokenizer = new Tokenizer(this.predictionParameter.getLanguageModel());
    }

    /**
     * Creates a predictor without a dynamic n-gram dictionary.
     *
     * @param predictionParameter   prediction parameters (required)
     * @param wordDictionary        word dictionary (required)
     * @param staticNgramDictionary static n-gram dictionary (required)
     */
    public WordPredictor(PredictionParameter predictionParameter, WordDictionary wordDictionary, AbstractNGramDictionary<? extends AbstractNGramTrieNode<?>> staticNgramDictionary) {
        this(predictionParameter, wordDictionary, staticNgramDictionary, null);
    }

    /**
     * Predicts words for the given caret context.
     *
     * @param textBeforeCaret text before the caret
     * @param textAfterCaret  text after the caret; used to count characters up to the next separator
     * @param wantedCount     maximum number of predictions to return
     * @return the prediction result
     * @throws Exception if tokenization or dictionary access fails
     */
    public WordPredictionResult predict(String textBeforeCaret, String textAfterCaret, int wantedCount) throws Exception {
        return this._predict(textBeforeCaret, textAfterCaret, wantedCount, null);
    }

    /**
     * Predicts words for the text before the caret, with no text after the caret.
     *
     * @param textBeforeCaret text before the caret
     * @param wantedCount     maximum number of predictions to return
     * @return the prediction result
     * @throws Exception if tokenization or dictionary access fails
     */
    public WordPredictionResult predict(String textBeforeCaret, int wantedCount) throws Exception {
        return this._predict(textBeforeCaret, null, wantedCount, null);
    }

    /**
     * Predicts at most {@value #DEFAULT_WANTED_COUNT} words for the text before the caret.
     *
     * @param textBeforeCaret text before the caret
     * @return the prediction result
     * @throws Exception if tokenization or dictionary access fails
     */
    public WordPredictionResult predict(String textBeforeCaret) throws Exception {
        return this._predict(textBeforeCaret, null, DEFAULT_WANTED_COUNT, null);
    }

    /**
     * Predicts at most {@value #DEFAULT_WANTED_COUNT} words, excluding the given word ids.
     *
     * @param textBeforeCaret  text before the caret
     * @param wordIdsToExclude word ids that must not be predicted, may be {@code null}
     * @return the prediction result
     * @throws Exception if tokenization or dictionary access fails
     */
    public WordPredictionResult predict(String textBeforeCaret, Set<Integer> wordIdsToExclude) throws Exception {
        return this._predict(textBeforeCaret, null, DEFAULT_WANTED_COUNT, wordIdsToExclude);
    }

    /**
     * Predicts words for the given caret context, excluding the given word ids.
     *
     * @param textBeforeCaret  text before the caret
     * @param textAfterCaret   text after the caret; used to count characters up to the next separator
     * @param wantedCount      maximum number of predictions to return
     * @param wordIdsToExclude word ids that must not be predicted, may be {@code null}
     * @return the prediction result
     * @throws Exception if tokenization or dictionary access fails
     */
    public WordPredictionResult predict(String textBeforeCaret, String textAfterCaret, int wantedCount, Set<Integer> wordIdsToExclude) throws Exception {
        return this._predict(textBeforeCaret, textAfterCaret, wantedCount, wordIdsToExclude);
    }

    /**
     * Trains the dynamic model on raw text.
     * <p>
     * Does nothing when there is no dynamic dictionary or when the dynamic model is disabled.
     *
     * @param rawText            text to learn from
     * @param ignoreLastSentence when {@code true}, the last sentence of the text is not learned
     * @throws IOException if tokenization fails
     */
    public void trainDynamicModel(String rawText, boolean ignoreLastSentence) throws IOException {
        this._trainDynamicModel(rawText, ignoreLastSentence);
    }

    /**
     * Trains the dynamic model on every sentence of the raw text.
     *
     * @param rawText text to learn from
     * @throws IOException if tokenization fails
     */
    public void trainDynamicModel(String rawText) throws IOException {
        this._trainDynamicModel(rawText, false);
    }

    public PredictionParameter getPredictionParameter() {
        return this.predictionParameter;
    }

    public WordDictionary getWordDictionary() {
        return this.wordDictionary;
    }

    public AbstractNGramDictionary<? extends AbstractNGramTrieNode<?>> getStaticNgramDictionary() {
        return this.staticNgramDictionary;
    }

    /**
     * @return the dynamic n-gram dictionary, or {@code null} if this predictor was created without one
     */
    public AbstractNGramDictionary<? extends DynamicNGramTrieNode> getDynamicNGramDictionary() {
        return this.dynamicNGramDictionary;
    }

    /**
     * Releases resources held by this predictor.
     * <p>
     * Closes the static and dynamic dictionaries and disposes the correction generator. Each step is attempted
     * independently, and failures are logged rather than thrown.
     */
    public void dispose() {
        try {
            this.staticNgramDictionary.close();
        }
        catch (Exception e) {
            LOGGER.error("NGram static dictionary closing failed", e);
        }
        try {
            if (this.dynamicNGramDictionary != null) {
                this.dynamicNGramDictionary.close();
            }
        }
        catch (Exception e) {
            LOGGER.error("NGram dynamic dictionary closing failed", e);
        }
        try {
            if (this.wordCorrectionGenerator != null) {
                this.wordCorrectionGenerator.dispose();
            }
        }
        catch (Exception e) {
            LOGGER.error("Word correction generator disposing failed", e);
        }
        LOGGER.info("Word predictor disposed");
    }

    private WordPredictionResult _predict(String textBeforeCaret, String textAfterCaret, int wantedCount, Set<Integer> wordIdsToExclude) throws Exception {
        if (this.predictionParameter.isEnableDebugInformation()) {
            LOGGER.warn("Predictor debug is enabled, never enable this configuration in production !");
        }
        Set<Integer> excludedIds = wordIdsToExclude != null ? wordIdsToExclude : EMPTY_INT_SET;
        LOGGER.debug("Predict for \"{}\" (wanted count = {})", textBeforeCaret, wantedCount);
        long startTotal = System.currentTimeMillis();
        List<Token> tokens = this.getTokenForPrediction(textBeforeCaret).getLeft();
        LOGGER.debug("Tokens for prediction are : {}", tokens);

        long startLongestMatch = System.currentTimeMillis();
        WordPrefixDetected longestMatchingWords = this.wordPrefixDetector.getLongestMatchingWords(tokens, wantedCount * WANTED_COUNT_FACTOR, excludedIds);
        LOGGER.debug("Longest match detection took {} ms, found matching word : {} - {}", new Object[]{System.currentTimeMillis() - startLongestMatch, longestMatchingWords != null, longestMatchingWords != null ? longestMatchingWords.getLongestWordPrefix() : null});

        // Predictions are only provided once the typed prefix is long enough (when a minimum is configured)
        boolean enoughInput = this.predictionParameter.getMinCountToProvidePrediction() <= 0 || longestMatchingWords != null && longestMatchingWords.getLongestWordPrefix().length() >= this.predictionParameter.getMinCountToProvidePrediction();
        if (!enoughInput) {
            return new WordPredictionResult(null, 0, Collections.emptyList());
        }

        Triple<int[], Boolean, Boolean> prefixAndUnknownWord = this.ngramWordPredictorUtils.createPrefixFor(tokens, longestMatchingWords, this.staticNgramDictionary.getMaxOrder(), this.predictionParameter.isAddNewWordsEnabled());
        int[] prefixForNGram = prefixAndUnknownWord.getLeft();

        long startDebug = System.currentTimeMillis();
        Map<BiIntegerKey, NextWord> nextWords = this.getNextWords(wantedCount, excludedIds, longestMatchingWords, prefixForNGram);
        LOGGER.debug("getNextWords in {} ms", System.currentTimeMillis() - startDebug);

        startDebug = System.currentTimeMillis();
        List<AbstractPredictionToCompute> predictions = this.transformNextWordsToPrediction(prefixForNGram, nextWords, longestMatchingWords != null, excludedIds);
        LOGGER.debug("transformNextWordsToPrediction in {} ms", System.currentTimeMillis() - startDebug);

        startDebug = System.currentTimeMillis();
        double predSum = this.computeProbabilities(prefixForNGram, predictions);
        LOGGER.debug("computeProbabilities2 in {} ms", System.currentTimeMillis() - startDebug);

        startDebug = System.currentTimeMillis();
        Collections.sort(predictions);
        LOGGER.debug("sort in {} ms", System.currentTimeMillis() - startDebug);
        LOGGER.debug("Prediction prob sum before normalization = {}", predSum);

        // subList is a view: case deduplication below also removes the entries from "predictions"
        List<AbstractPredictionToCompute> predSubList = predictions.subList(0, Math.min(predictions.size(), wantedCount));
        boolean capitalize = this.handleDoubleWordByCase(tokens, longestMatchingWords, predSubList);
        List<WordPrediction> wordPredictions = new ArrayList<>(predSubList.size());
        for (AbstractPredictionToCompute prediction : predSubList) {
            this.createWordPrediction(longestMatchingWords, predSum, capitalize, wordPredictions, prediction);
        }
        LOGGER.info("Prediction took {} ms ({} results)", System.currentTimeMillis() - startTotal, predictions.size());
        return new WordPredictionResult(null, Predict4AllUtils.countEndUntilNextSeparator(textAfterCaret), wordPredictions);
    }

    /**
     * Scores every prediction in place and returns the sum of all scores, which is used later for normalization.
     * <p>
     * A single prediction is scored with its own probability. A double prediction is scored with the product of the
     * probabilities of its two words. Each score is then multiplied by the word's probability factor from the dictionary.
     */
    private double computeProbabilities(int[] prefixForNGram, List<AbstractPredictionToCompute> predictions) {
        Pair<Double, Double> probInter = this.computeProbInterpolation();
        double staticNgramWeight = probInter.getLeft();
        double dynamicNGramWeight = probInter.getRight();
        double probSum = 0.0;
        for (AbstractPredictionToCompute prediction : predictions) {
            if (!prediction.isDouble()) {
                prediction.setScore(this.computeProb(prefixForNGram, staticNgramWeight, dynamicNGramWeight, prediction.getWordId(), prediction));
            } else {
                DoublePredictionToCompute mp = (DoublePredictionToCompute) prediction;
                // NOTE(review): computeProb multiplies by prediction.getFactor(), so for double predictions the factor
                // is applied twice (squared) - confirm this extra penalty is intended.
                mp.setScore(this.computeProb(mp.getFirstPrefix(), staticNgramWeight, dynamicNGramWeight, mp.getFirstWordId(), prediction) * this.computeProb(mp.getSecondPrefix(), staticNgramWeight, dynamicNGramWeight, mp.getSecondWordId(), prediction));
            }
            double probFactor = this.wordDictionary.getWord(prediction.getWordId()).getProbFactor();
            prediction.setScore(prediction.getScore() * probFactor);
            if (this.predictionParameter.isEnableDebugInformation()) {
                prediction.getDebugInformation().append("\nWDF = ").append(probFactor);
            }
            probSum += prediction.getScore();
        }
        return probSum;
    }

    /**
     * Computes {@code factor * (staticWeight * P_static + dynamicWeight * P_dynamic)} for one word after a prefix.
     * The dynamic term is only added when the dynamic model is present and enabled.
     */
    private double computeProb(int[] prefixForNGram, double staticNgramWeight, double dynamicNGramWeight, int wordId, AbstractPredictionToCompute prediction) {
        double staticProb = this.staticNgramDictionary.getProbability(prefixForNGram, 0, prefixForNGram.length, wordId);
        double probability = staticNgramWeight * staticProb;
        if (this.predictionParameter.isEnableDebugInformation()) {
            prediction.getDebugInformation().append("\n").append(DEBUG_DF.format(staticNgramWeight)).append(" * ").append(DEBUG_DF.format(staticProb));
        }
        if (this.dynamicNGramDictionary != null && this.predictionParameter.isDynamicModelEnabled()) {
            double dynProb = this.dynamicNGramDictionary.getProbability(prefixForNGram, 0, prefixForNGram.length, wordId);
            probability += dynamicNGramWeight * dynProb;
            if (this.predictionParameter.isEnableDebugInformation()) {
                prediction.getDebugInformation().append(" + ").append(DEBUG_DF.format(dynamicNGramWeight)).append(" * ").append(DEBUG_DF.format(dynProb));
            }
        }
        if (this.predictionParameter.isEnableDebugInformation()) {
            prediction.getDebugInformation().append(" * ").append(DEBUG_DF.format(prediction.getFactor()));
        }
        return prediction.getFactor() * probability;
    }

    /**
     * Computes the (static, dynamic) interpolation weights.
     * <p>
     * The dynamic weight is the ratio of the dynamic root's children count sum to the static root's children size. It
     * is bounded below by the configured minimum weight and clamped to at most 1.0, so the static weight never becomes
     * negative. If the static root has no children, the dynamic model gets the full weight instead of dividing by zero.
     *
     * @return {@code (1, 0)} when the dynamic model is absent or disabled
     */
    private Pair<Double, Double> computeProbInterpolation() {
        if (!this.predictionParameter.isDynamicModelEnabled() || this.dynamicNGramDictionary == null) {
            return Pair.of(1.0, 0.0);
        }
        DynamicNGramTrieNode dynRoot = this.dynamicNGramDictionary.getRoot();
        double dynamicCountSum = dynRoot.getChildrenCountSum();
        double staticChildrenSize = this.staticNgramDictionary.getRoot().getChildrenSize();
        double ratio = staticChildrenSize > 0.0 ? dynamicCountSum / staticChildrenSize : 1.0;
        double dynamicWeight = Math.min(1.0, Math.max(ratio, this.predictionParameter.getDynamicModelMinimumWeight()));
        return Pair.of(1.0 - dynamicWeight, dynamicWeight);
    }

    /**
     * Tokenizes the last {@value #NGRAM_MAX_LAST_TEXT_LENGTH} characters of the text and runs term detection.
     *
     * @return the tokens, and whether the last raw token (before term detection) was a separator
     */
    private Pair<List<Token>, Boolean> getTokenForPrediction(String rawText) throws IOException {
        long startTokenize = System.currentTimeMillis();
        String textForNGram = rawText.substring(Math.max(0, rawText.length() - NGRAM_MAX_LAST_TEXT_LENGTH));
        List<Token> tokens = this.tokenizer.tokenize(textForNGram);
        boolean lastTokenSeparator = !tokens.isEmpty() && tokens.get(tokens.size() - 1).isSeparator();
        tokens = this.termConverter.executeTermDetection(tokens);
        LOGGER.debug("Tokenization before word prediction took {} ms", System.currentTimeMillis() - startTokenize);
        return Pair.of(tokens, lastTokenSeparator);
    }

    /**
     * Computes each prediction's text and removes predictions that differ only by case.
     * <p>
     * The list is expected to be sorted best-first, so only the first occurrence of each lower-case text is kept.
     * This holds however many case variants there are. Removal goes through the list's iterator, so a subList view
     * also updates its backing list.
     *
     * @return whether predictions should be displayed capitalized
     */
    private boolean handleDoubleWordByCase(List<Token> tokens, WordPrefixDetected longestMatchingWords, List<AbstractPredictionToCompute> predSubList) {
        boolean capitalize = longestMatchingWords != null ? longestMatchingWords.isCapitalizedWord() : this.wordPrefixDetector.isNextWordsCapitalized(tokens, null, 0);
        Set<String> seenLowerCasePredictions = new HashSet<>();
        ListIterator<AbstractPredictionToCompute> predIterator = predSubList.listIterator();
        while (predIterator.hasNext()) {
            AbstractPredictionToCompute pred = predIterator.next();
            pred.computePrediction(this.wordDictionary);
            if (!seenLowerCasePredictions.add(Predict4AllUtils.lowerCase(pred.getPrediction()))) {
                predIterator.remove();
            }
        }
        return capitalize;
    }

    /**
     * Converts a scored prediction into a {@link WordPrediction} and appends it to {@code wordPredictions}.
     * <p>
     * If the displayed text starts with the typed prefix, only the missing end is inserted. Otherwise the typed prefix
     * length is reported as characters to remove and the whole text is inserted. The probability is the score divided
     * by {@code predSum}, or 0 when the sum is not positive (this avoids NaN).
     */
    private void createWordPrediction(WordPrefixDetected longestMatchingWords, double predSum, boolean capitalize, List<WordPrediction> wordPredictions, AbstractPredictionToCompute prediction) {
        String predictionToDisplay = capitalize ? Predict4AllUtils.capitalize(prediction.getPrediction()) : prediction.getPrediction();
        int previousCharCountToRemove = 0;
        String predictionToInsert = predictionToDisplay;
        if (longestMatchingWords != null) {
            String typedPrefix = longestMatchingWords.getLongestWordPrefix();
            if (predictionToDisplay.startsWith(typedPrefix)) {
                predictionToInsert = predictionToDisplay.substring(typedPrefix.length());
            } else {
                previousCharCountToRemove = typedPrefix.length();
            }
        }
        // No space should follow a word ending with an apostrophe
        boolean insertSpacePossible = Separator.getSeparatorFor(predictionToDisplay.charAt(predictionToDisplay.length() - 1)) != Separator.APOSTROPHE;
        double probability = predSum > 0.0 ? prediction.getScore() / predSum : 0.0;
        wordPredictions.add(new WordPrediction(predictionToDisplay, predictionToInsert, insertSpacePossible, probability, previousCharCountToRemove, prediction.isCorrection(), prediction.getWordId(), prediction.getDebugInformation() != null ? prediction.getDebugInformation().toString() : null));
    }

    /**
     * Returns the candidate words.
     * <p>
     * When a prefix was matched, the candidates are the words of that match. Otherwise they are listed from the static
     * dictionary and, if enabled, the dynamic dictionary after the n-gram prefix.
     */
    private Map<BiIntegerKey, NextWord> getNextWords(int wantedCount, Set<Integer> wordIdsToExclude, WordPrefixDetected longestMatchingWords, int[] prefixForNGram) throws IOException {
        if (longestMatchingWords != null) {
            return longestMatchingWords.getWords();
        }
        int candidateCount = wantedCount * WANTED_COUNT_FACTOR;
        HashMap<BiIntegerKey, NextWord> nextWords = new HashMap<>(candidateCount);
        this.staticNgramDictionary.listNextWords(prefixForNGram, this.wordDictionary, this.predictionParameter, wordIdsToExclude, nextWords, candidateCount, true);
        if (this.predictionParameter.isDynamicModelEnabled() && this.dynamicNGramDictionary != null) {
            this.dynamicNGramDictionary.listNextWords(prefixForNGram, this.wordDictionary, this.predictionParameter, wordIdsToExclude, nextWords, candidateCount, false);
        }
        return nextWords;
    }

    /**
     * Converts candidate words into predictions to score.
     * <p>
     * A unique word always yields one single-word prediction. When double predictions are enabled and the word ends
     * with an apostrophe, it also yields double predictions pairing it with a following word. Non-unique candidates
     * become a double prediction directly.
     */
    private List<AbstractPredictionToCompute> transformNextWordsToPrediction(int[] prefixForNGram, Map<BiIntegerKey, NextWord> nextWords, boolean enableDoublePrediction, Set<Integer> wordIdsToExclude) {
        List<AbstractPredictionToCompute> predictions = new ArrayList<>();
        nextWords.forEach((key, nextWord) -> {
            if (nextWord.isUnique()) {
                int wordId = nextWord.getWordId1();
                predictions.add(new UniquePredictionToCompute(wordId, nextWord.getFactor(), nextWord.isCorrection(), nextWord.getDebugInformation()));
                boolean endWithApostrophe = Predict4AllUtils.endsWith(this.wordDictionary.getWord(wordId).getWord(), Separator.APOSTROPHE.getOfficialCharString());
                if (enableDoublePrediction && endWithApostrophe) {
                    int[] modifiedPrefix = shiftPrefix(prefixForNGram, wordId);
                    HashMap<BiIntegerKey, NextWord> nextWordsAfterWrittenWord = new HashMap<>(10);
                    // A single following word is requested for the apostrophe word
                    this.staticNgramDictionary.listNextWords(modifiedPrefix, this.wordDictionary, this.predictionParameter, wordIdsToExclude, nextWordsAfterWrittenWord, 1, false);
                    nextWordsAfterWrittenWord.forEach((key2, nextWord2) -> {
                        if (!nextWord2.isUnique()) {
                            throw new IllegalStateException("List next word should never return not unique predictions");
                        }
                        predictions.add(new DoublePredictionToCompute(wordId, nextWord2.getWordId1(), false, prefixForNGram, modifiedPrefix, nextWord.getFactor() * nextWord2.getFactor(), nextWord.isCorrection(), this.predictionParameter.isEnableDebugInformation() ? new StringBuilder(nextWord.getDebugInformation()) : null));
                    });
                }
            } else {
                int[] modifiedPrefix = shiftPrefix(prefixForNGram, nextWord.getWordId1());
                predictions.add(new DoublePredictionToCompute(nextWord.getWordId1(), nextWord.getWordId2(), nextWord.getSeparator() != Separator.APOSTROPHE, prefixForNGram, modifiedPrefix, nextWord.getFactor(), nextWord.isCorrection(), nextWord.getDebugInformation()));
            }
        });
        return predictions;
    }

    /**
     * Builds a prefix of the same length as {@code prefix}: drops its first entry, shifts the rest left, and appends
     * {@code wordId}.
     */
    private static int[] shiftPrefix(int[] prefix, int wordId) {
        int[] shifted = new int[prefix.length];
        System.arraycopy(prefix, 1, shifted, 0, prefix.length - 1);
        shifted[shifted.length - 1] = wordId;
        return shifted;
    }

    /**
     * Trains the dynamic model sentence by sentence. Each growing prefix of a sentence becomes one training sample.
     */
    private void _trainDynamicModel(String rawText, boolean ignoreLastSentence) throws IOException {
        if (this.dynamicNGramDictionary != null && this.predictionParameter.isDynamicModelEnabled()) {
            long start = System.currentTimeMillis();
            List<Token> tokens = this.termConverter.executeTermDetection(this.tokenizer.tokenize(rawText));
            List<List<Token>> sentences = this.divideInSentencesAndRemoveSeparator(tokens);
            for (int i = 0; i < sentences.size(); ++i) {
                List<Token> sentence = sentences.get(i);
                if (sentence.isEmpty() || ignoreLastSentence && i == sentences.size() - 1) {
                    continue;
                }
                for (int endIndex = 1; endIndex <= sentence.size(); ++endIndex) {
                    List<Token> sentencePart = sentence.subList(0, endIndex);
                    Triple<int[], Boolean, Boolean> prefixAndUnknownWord = this.ngramWordPredictorUtils.createPrefixFor(sentencePart, null, this.dynamicNGramDictionary.getMaxOrder() + 1, this.predictionParameter.isAddNewWordsEnabled());
                    this.trainDynamicNGramModel(prefixAndUnknownWord, sentencePart.size() == 1 ? this.dynamicNGramDictionary.getMaxOrder() - 2 : 0);
                }
            }
            LOGGER.info("Trained dynamic prediction model in {} ms", System.currentTimeMillis() - start);
        }
    }

    /**
     * Learns one training sample.
     * <p>
     * Increments the user count of the last word of the prefix. Unless the sample is skipped, it then increments and
     * re-estimates every order from {@code trainingStartOrder} up to (excluding) the dynamic dictionary's max order.
     * The sample is skipped when the "middle" flag of {@code prefixAndUnknownWord} is set, or when its "right" flag is
     * set while adding new words is disabled.
     */
    private void trainDynamicNGramModel(Triple<int[], Boolean, Boolean> prefixAndUnknownWord, int trainingStartOrder) {
        if (this.dynamicNGramDictionary != null && this.predictionParameter.isDynamicModelEnabled()) {
            int[] predictionPrefix = prefixAndUnknownWord.getLeft();
            this.wordDictionary.incrementUserWord(predictionPrefix[predictionPrefix.length - 1]);
            if (prefixAndUnknownWord.getMiddle().booleanValue() || prefixAndUnknownWord.getRight().booleanValue() && !this.predictionParameter.isAddNewWordsEnabled()) {
                LOGGER.info("Skip user training because there is only start tag = {}, or because new word learning is disabled = {}", prefixAndUnknownWord.getMiddle(), prefixAndUnknownWord.getRight());
                return;
            }
            for (int i = trainingStartOrder; i < this.dynamicNGramDictionary.getMaxOrder(); ++i) {
                this.dynamicNGramDictionary.putAndIncrementBy(predictionPrefix, i, 1);
                this.dynamicNGramDictionary.updateProbabilities(predictionPrefix, i, this.dynamicNGramDictionary.computeD(TRAINING_CONFIGURATION));
            }
        }
    }

    /**
     * Splits tokens into sentences on sentence separators and drops every separator token.
     * <p>
     * The result always holds at least one (possibly empty) sentence.
     */
    private List<List<Token>> divideInSentencesAndRemoveSeparator(List<Token> tokens) {
        List<List<Token>> sentences = new ArrayList<>();
        List<Token> currentSentence = new ArrayList<>();
        sentences.add(currentSentence);
        for (Token token : tokens) {
            if (token.isSeparator()) {
                if (token.getSeparator().isSentenceSeparator()) {
                    currentSentence = new ArrayList<>();
                    sentences.add(currentSentence);
                }
                continue;
            }
            currentSentence.add(token);
        }
        return sentences;
    }
}

