/*
 * Decompiled with CFR 0.152.
 */
package org.predict4all.nlp.ngram;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.LongAdder;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import org.predict4all.nlp.Tag;
import org.predict4all.nlp.io.TokenFileInputStream;
import org.predict4all.nlp.language.LanguageModel;
import org.predict4all.nlp.ngram.NGramKey;
import org.predict4all.nlp.ngram.debug.NGramDebugger;
import org.predict4all.nlp.ngram.dictionary.TrainingNGramDictionary;
import org.predict4all.nlp.parser.TokenProvider;
import org.predict4all.nlp.parser.token.Token;
import org.predict4all.nlp.trainer.TrainerTask;
import org.predict4all.nlp.trainer.configuration.NGramPruningMethod;
import org.predict4all.nlp.trainer.configuration.TrainingConfiguration;
import org.predict4all.nlp.trainer.corpus.AbstractTrainingDocument;
import org.predict4all.nlp.trainer.corpus.TrainingCorpus;
import org.predict4all.nlp.trainer.step.TrainingStep;
import org.predict4all.nlp.utils.Pair;
import org.predict4all.nlp.utils.Predict4AllUtils;
import org.predict4all.nlp.utils.progressindicator.LoggingProgressIndicator;
import org.predict4all.nlp.utils.progressindicator.ProgressIndicator;
import org.predict4all.nlp.words.WordDictionary;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Generates an n-gram dictionary from a {@link TrainingCorpus}.
 * <p>
 * Each corpus document is split into sentences (at sentence separators and at the last token of the
 * document). For every sentence, all n-grams of order 1 to {@link TrainingConfiguration#getNgramOrder()}
 * are produced, with a {@link Tag#START} tag standing before the first word. These n-grams are then
 * counted. The counts are inserted into a {@link TrainingNGramDictionary}, which is pruned according to
 * {@link TrainingConfiguration#getPruningMethod()} and optionally saved to a file.
 * <p>
 * N-grams containing a word whose id is {@link Tag#UNKNOWN}'s id are not counted.
 */
public class NGramDictionaryGenerator {
    private static final Logger LOGGER = LoggerFactory.getLogger(NGramDictionaryGenerator.class);

    /** Initial capacity of the shared n-gram count map (pre-sized, presumably to limit rehashing while counting). */
    private static final int INITIAL_NGRAM_COUNT_CAPACITY = 8_000_000;
    /** Load factor of the shared n-gram count map. */
    private static final float NGRAM_COUNT_LOAD_FACTOR = 0.9f;

    private final LanguageModel languageModel;
    private final TrainingConfiguration trainingConfiguration;
    /** Highest n-gram order generated, taken from the training configuration. */
    private final int maxOrder;
    private final WordDictionary wordDictionary;
    /** Optional space-separated word prefix selecting the dictionary node handed to the debuggers. */
    private String debugPrefix;
    private NGramDebugger ngramDebugBeforePruning;
    private NGramDebugger ngramDebugAfterPruning;

    public NGramDictionaryGenerator(LanguageModel languageModel, TrainingConfiguration trainingConfiguration, WordDictionary wordDictionary) {
        this.wordDictionary = wordDictionary;
        this.languageModel = languageModel;
        this.trainingConfiguration = trainingConfiguration;
        this.maxOrder = this.trainingConfiguration.getNgramOrder();
    }

    /**
     * Counts the n-grams of every corpus document, builds and prunes the n-gram dictionary, then optionally saves it.
     *
     * @param corpus               the training corpus; its {@link TrainingStep#NGRAM_DICTIONARY} step is initialized here
     * @param ngramOutputFile      file the dictionary is saved to, or {@code null} to skip saving
     * @param blockingTaskExecutor receives one counting task per document; expected to return only once all
     *                             tasks have run (the code right after it reads the final counts)
     * @return the value of {@link TrainingNGramDictionary#countNGrams()} on the final dictionary
     * @throws IOException if saving the dictionary (or another dictionary operation) fails
     */
    public Map<Integer, Pair<Integer, Integer>> executeNGramTraining(TrainingCorpus corpus, File ngramOutputFile, Consumer<List<TrainerTask>> blockingTaskExecutor) throws IOException {
        long startInsert = System.currentTimeMillis();
        corpus.initStep(TrainingStep.NGRAM_DICTIONARY);
        LoggingProgressIndicator progressIndicator = new LoggingProgressIndicator("Generating ngrams", corpus.getTotalCountFor(TrainingStep.NGRAM_DICTIONARY));

        // All document tasks increment the same concurrent map of counts.
        ConcurrentHashMap<NGramKey, LongAdder> ngramCounts = new ConcurrentHashMap<>(INITIAL_NGRAM_COUNT_CAPACITY, NGRAM_COUNT_LOAD_FACTOR,
                Runtime.getRuntime().availableProcessors());
        blockingTaskExecutor.accept(corpus.getDocuments(TrainingStep.NGRAM_DICTIONARY).stream()
                .map(d -> new TrainingNGramDictionaryTask(progressIndicator, (AbstractTrainingDocument) d, ngramCounts))
                .collect(Collectors.toList()));
        LOGGER.info("NGram generation tasks finished in {} s, will now insert to dictionary", (System.currentTimeMillis() - startInsert) / 1000.0);

        TrainingNGramDictionary ngramDictionary = TrainingNGramDictionary.create(this.maxOrder);
        LoggingProgressIndicator progressIndicatorInsert = new LoggingProgressIndicator("Generating ngram dictionary", ngramCounts.size());
        ngramCounts.forEach((ngram, sum) -> {
            // ConcurrentHashMap iteration is weakly consistent, so removing the visited entry is allowed;
            // presumably done to release the counts as they are moved into the dictionary.
            ngramCounts.remove(ngram);
            // NOTE(review): LongAdder.intValue() narrows, so a count above Integer.MAX_VALUE would wrap.
            ngramDictionary.putAndIncrementBy(ngram.ngram, sum.intValue());
            progressIndicatorInsert.increment();
        });
        LOGGER.info("Every ngram inserted in dictionary, will now compact");
        ngramDictionary.compact();
        LOGGER.info("Dictionary compacted");
        ngramDictionary.countNGrams();

        if (this.ngramDebugBeforePruning != null) {
            this.runDebugger(this.ngramDebugBeforePruning, ngramDictionary);
        }
        this.pruneOrComputeProbabilities(ngramDictionary);
        if (this.ngramDebugAfterPruning != null) {
            this.runDebugger(this.ngramDebugAfterPruning, ngramDictionary);
        }
        if (ngramOutputFile != null) {
            ngramDictionary.saveDictionary(ngramOutputFile);
        }
        return ngramDictionary.countNGrams();
    }

    /**
     * Applies the configured pruning method to the dictionary. With {@link NGramPruningMethod#NONE}, it
     * only updates probabilities using {@code computeD}.
     *
     * @throws IllegalArgumentException if the configured pruning method is not handled
     */
    private void pruneOrComputeProbabilities(TrainingNGramDictionary ngramDictionary) throws IOException {
        NGramPruningMethod pruningMethod = this.trainingConfiguration.getPruningMethod();
        switch (pruningMethod) {
            case NONE:
                ngramDictionary.updateProbabilities(ngramDictionary.computeD(this.trainingConfiguration));
                break;
            case WEIGHTED_DIFFERENCE_RAW_PROB:
            case WEIGHTED_DIFFERENCE_FULL_PROB:
                ngramDictionary.pruneNGramsWeightedDifference(this.trainingConfiguration.getNgramPruningWeightedDifferenceThreshold(),
                        this.trainingConfiguration, pruningMethod);
                break;
            case RAW_COUNT:
                ngramDictionary.pruneNGramsCount(this.trainingConfiguration.getNgramPruningCountThreshold(), this.trainingConfiguration);
                break;
            case ORDER_COUNT:
                ngramDictionary.pruneNGramsOrderCount(this.trainingConfiguration.getNgramPruningOrderCountThresholds(), this.trainingConfiguration);
                break;
            default:
                throw new IllegalArgumentException("Pruning method " + pruningMethod + " not implemented");
        }
    }

    /**
     * Runs a debugger on the dictionary node designated by {@link #debugPrefix}. When no prefix is set,
     * it runs on the dictionary root.
     */
    private void runDebugger(NGramDebugger debugger, TrainingNGramDictionary ngramDictionary) throws IOException {
        debugger.debug(this.wordDictionary, Predict4AllUtils.isNotBlank(this.debugPrefix)
                ? ngramDictionary.getNodeForPrefix(this.getDebugPrefixWordIds(), 0)
                : ngramDictionary.getRoot());
    }

    /** Converts the space-separated debug prefix to word ids, ignoring blank parts. */
    private int[] getDebugPrefixWordIds() {
        return Arrays.stream(this.debugPrefix.split(" "))
                .filter(s -> Predict4AllUtils.isNotBlank(s))
                .mapToInt(w -> this.wordDictionary.getWordId(w))
                .toArray();
    }

    public NGramDebugger getNgramDebugBeforePruning() {
        return this.ngramDebugBeforePruning;
    }

    public void setNgramDebugBeforePruning(NGramDebugger ngramDebugBeforePruning) {
        this.ngramDebugBeforePruning = ngramDebugBeforePruning;
    }

    public NGramDebugger getNgramDebugAfterPruning() {
        return this.ngramDebugAfterPruning;
    }

    public void setNgramDebugAfterPruning(NGramDebugger ngramDebugAfterPruning) {
        this.ngramDebugAfterPruning = ngramDebugAfterPruning;
    }

    public String getDebugPrefix() {
        return this.debugPrefix;
    }

    public void setDebugPrefix(String debugPrefix) {
        this.debugPrefix = debugPrefix;
    }

    /**
     * Splits the token stream into sentences and generates the n-grams of each one.
     * <p>
     * A sentence ends on a sentence separator or on the last token of the stream.
     *
     * @param tokenFis          source of tokens
     * @param progressIndicator incremented once per token read
     * @return every generated n-gram, as arrays of word/tag ids
     */
    private List<int[]> generateNGramForDocument(TokenProvider tokenFis, ProgressIndicator progressIndicator) throws IOException {
        List<int[]> generatedNGrams = new ArrayList<>(500);
        Token currentSentenceStart = null;
        for (Token token = tokenFis.getNext(); token != null; token = token.getNext(tokenFis)) {
            // Record the start before testing for the end. Otherwise a first token that is also the
            // sentence/document end (e.g. a one-word document) would get a null start and be lost.
            if (currentSentenceStart == null) {
                currentSentenceStart = token;
            }
            boolean sentenceSeparator = token.isSeparator() && token.getSeparator().isSentenceSeparator();
            if (sentenceSeparator || token.getNext(tokenFis) == null) {
                this.generateNGramForSentence(generatedNGrams, tokenFis, currentSentenceStart, token);
                currentSentenceStart = token.getNext(tokenFis);
            }
            progressIndicator.increment();
        }
        return generatedNGrams;
    }

    /**
     * Generates the n-grams of one sentence, from {@code start} to {@code end} inclusive.
     * <p>
     * Separator tokens are skipped. Start index -1 produces the n-grams beginning with {@link Tag#START}.
     */
    private void generateNGramForSentence(List<int[]> ngramsList, TokenProvider tokenProvider, Token start, Token end) throws IOException {
        List<Token> tokens = new ArrayList<>();
        for (Token current = start; current != null; current = current.getNext(tokenProvider)) {
            if (!current.isSeparator()) {
                tokens.add(current);
            }
            if (current == end) {
                break;
            }
        }
        if (!tokens.isEmpty()) {
            for (int i = -1; i < tokens.size(); ++i) {
                for (int order = 1; order <= this.maxOrder; ++order) {
                    List<int[]> ngrams = this.generateNGramsFromToken(order, tokens, i);
                    if (ngrams != null) {
                        ngramsList.addAll(ngrams);
                    }
                }
            }
        }
    }

    /**
     * Builds the n-gram of the given order whose first position is {@code startIndex} in {@code tokens}.
     * <p>
     * Positions before index 0 are filled with the {@link Tag#START} id. Only position 0 creates an
     * array, so the returned list holds a single n-gram.
     *
     * @return the n-gram, or {@code null} when it would extend past the last token or contains an
     *         unknown word
     */
    private List<int[]> generateNGramsFromToken(int order, List<Token> tokens, int startIndex) {
        List<int[]> ngrams = new ArrayList<>(order);
        for (int j = 0; j < order; ++j) {
            int tokenIndex = j + startIndex;
            if (tokenIndex < 0) {
                this.createOrAddNGramTag(order, ngrams, j, Tag.START);
                continue;
            }
            if (tokenIndex >= tokens.size()) {
                return null;
            }
            int wordId = tokens.get(tokenIndex).getWordId(this.wordDictionary);
            if (wordId == Tag.UNKNOWN.getId()) {
                return null;
            }
            if (j == 0) {
                ngrams.add(this.createArrayAndSetFirst(order, wordId));
            } else {
                for (int[] ngram : ngrams) {
                    ngram[j] = wordId;
                }
            }
        }
        return ngrams;
    }

    /** Creates an array of the given length whose first element is {@code wordId} (others are 0). */
    private int[] createArrayAndSetFirst(int length, int wordId) {
        int[] array = new int[length];
        array[0] = wordId;
        return array;
    }

    /**
     * Writes the tag id at {@code insertIndex}. At index 0 it creates a new n-gram; otherwise it sets
     * that position in every n-gram already in the list.
     */
    private void createOrAddNGramTag(int order, List<int[]> ngrams, int insertIndex, Tag tag) {
        if (insertIndex == 0) {
            ngrams.add(this.createArrayAndSetFirst(order, tag.getId()));
        } else {
            for (int[] ngram : ngrams) {
                ngram[insertIndex] = tag.getId();
            }
        }
    }

    /** Task that reads one training document and adds its n-gram counts to a shared concurrent map. */
    private class TrainingNGramDictionaryTask extends TrainerTask {
        private final ConcurrentHashMap<NGramKey, LongAdder> ngramCounts;

        public TrainingNGramDictionaryTask(ProgressIndicator progressIndicator, AbstractTrainingDocument document, ConcurrentHashMap<NGramKey, LongAdder> ngramCounts) {
            super(progressIndicator, document);
            this.ngramCounts = ngramCounts;
        }

        @Override
        public void run() throws Exception {
            try (TokenFileInputStream tfis = new TokenFileInputStream(this.document.getInputFile())) {
                List<int[]> ngrams = NGramDictionaryGenerator.this.generateNGramForDocument(tfis, this.progressIndicator);
                for (int[] ngram : ngrams) {
                    this.ngramCounts.computeIfAbsent(new NGramKey(ngram), k -> new LongAdder()).increment();
                }
            }
        }
    }
}

