/*
 * Decompiled with CFR 0.152.
 */
package org.predict4all.nlp.trainer;

import java.io.File;
import java.io.IOException;
import java.text.DecimalFormat;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;
import org.predict4all.nlp.language.LanguageModel;
import org.predict4all.nlp.ngram.NGramDictionaryGenerator;
import org.predict4all.nlp.ngram.debug.NGramDebugger;
import org.predict4all.nlp.parser.Tokenizer;
import org.predict4all.nlp.parser.matcher.TokenConverter;
import org.predict4all.nlp.semantic.SemanticDictionaryGenerator;
import org.predict4all.nlp.trainer.DataTrainerResult;
import org.predict4all.nlp.trainer.TrainerTask;
import org.predict4all.nlp.trainer.configuration.TrainingConfiguration;
import org.predict4all.nlp.trainer.corpus.TrainingCorpus;
import org.predict4all.nlp.trainer.step.TrainingStep;
import org.predict4all.nlp.utils.Pair;
import org.predict4all.nlp.words.WordDictionary;
import org.predict4all.nlp.words.WordDictionaryGenerator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Drives the training pipeline (raw parsing, token conversion, word dictionary, n-gram dictionary and
 * semantic dictionary generation) over a {@link TrainingCorpus}. Parallel work is executed on a
 * fixed-size thread pool whose size is the number of available processors.
 */
public class DataTrainer {
    private static final Logger LOGGER = LoggerFactory.getLogger(DataTrainer.class);

    /** Conventional file name for the word dictionary (not referenced within this class). */
    public static final String FILENAME_WORD_DICTIONARY = "words.bin";
    /** Conventional file name for the n-gram dictionary (not referenced within this class). */
    public static final String FILENAME_NGRAM_DICTIONARY = "ngrams.bin";
    /** Conventional file name for the semantic dictionary (not referenced within this class). */
    public static final String FILENAME_LSA_DICTIONARY = "semantic.bin";
    /**
     * Shared two-decimal percentage format ({@code ##0.00}).
     * <p>
     * NOTE: {@link DecimalFormat} is not synchronized; code using this shared instance from several
     * threads must synchronize externally (or work on a clone).
     */
    public static final DecimalFormat PERCENT_FORMAT = new DecimalFormat("##0.00");

    // Optional debug hooks forwarded to the n-gram generator; may stay null.
    private NGramDebugger ngramDebugBeforePruning;
    private NGramDebugger ngramDebugAfterPruning;
    private String debugPrefix;

    private final File outputDictionary;
    private final File outputNGram;
    private final File outputSemantic;
    private final LanguageModel languageModel;
    private final TrainingConfiguration trainingConfiguration;
    private final TrainingCorpus corpus;
    private final File workingDirectory;

    /**
     * Creates a trainer and makes sure the working directory exists.
     *
     * @param workingDirectory      directory handed to the {@link TrainingCorpus}; created if missing
     * @param outputDictionary      word dictionary file: passed as output to word dictionary generation
     *                              during n-gram training, and loaded back by the n-gram and semantic steps
     * @param outputNGram           file receiving the n-gram dictionary
     * @param outputSemantic        file receiving the semantic (LSA) dictionary
     * @param languageModel         language model used for tokenization and dictionary generation
     * @param trainingConfiguration training configuration; its corpus setting builds the training corpus
     * @throws IOException if the working directory cannot be created, or as propagated from
     *                     {@link TrainingCorpus} construction
     */
    public DataTrainer(File workingDirectory, File outputDictionary, File outputNGram, File outputSemantic, LanguageModel languageModel, TrainingConfiguration trainingConfiguration) throws IOException {
        int concurrencyLevel = Runtime.getRuntime().availableProcessors();
        this.outputDictionary = outputDictionary;
        this.outputNGram = outputNGram;
        this.outputSemantic = outputSemantic;
        this.workingDirectory = workingDirectory;
        // mkdirs() returns false both on failure and when the directory already exists, so re-check.
        if (!this.workingDirectory.mkdirs() && !this.workingDirectory.isDirectory()) {
            throw new IOException("Could not create working directory " + this.workingDirectory.getAbsolutePath());
        }
        this.corpus = new TrainingCorpus(concurrencyLevel, trainingConfiguration.getCorpus(), workingDirectory, "UTF-8");
        this.languageModel = languageModel;
        this.trainingConfiguration = trainingConfiguration;
    }

    /**
     * @param ngramDebugBeforePruning debugger forwarded to the n-gram generator (may be null)
     */
    public void setNgramDebugBeforePruning(NGramDebugger ngramDebugBeforePruning) {
        this.ngramDebugBeforePruning = ngramDebugBeforePruning;
    }

    /**
     * @param ngramDebugAfterPruning debugger forwarded to the n-gram generator (may be null)
     */
    public void setNgramDebugAfterPruning(NGramDebugger ngramDebugAfterPruning) {
        this.ngramDebugAfterPruning = ngramDebugAfterPruning;
    }

    /**
     * @return prefix forwarded to the n-gram generator for debug output (may be null)
     */
    public String getDebugPrefix() {
        return this.debugPrefix;
    }

    /**
     * @param debugPrefix prefix forwarded to the n-gram generator for debug output
     */
    public void setDebugPrefix(String debugPrefix) {
        this.debugPrefix = debugPrefix;
    }

    /**
     * Runs n-gram training starting at {@code initialStep}.
     * <p>
     * The steps are raw parsing, token conversion (n-gram token matchers), word dictionary
     * generation (written to the output dictionary file) and n-gram dictionary generation
     * (written to the output n-gram file). Steps declared before {@code initialStep} in
     * {@link TrainingStep} are skipped.
     *
     * @param initialStep first step to execute
     * @return training result; holds the n-gram counts when the n-gram dictionary step ran
     * @throws Exception propagated from any training step
     */
    public DataTrainerResult launchNGramTraining(TrainingStep initialStep) throws Exception {
        return this.launchTraining(initialStep, false);
    }

    /**
     * Runs semantic (LSA) training starting at {@code initialStep}.
     * <p>
     * Raw parsing runs only when {@code initialStep} is not after {@link TrainingStep#PARSER}.
     * Token conversion (semantic token matchers) and word dictionary generation always run; the
     * latter is not given an output file. The semantic dictionary step (when {@code initialStep}
     * is not after it) loads the word dictionary from the output dictionary file and writes the
     * semantic dictionary to the output semantic file.
     *
     * @param initialStep first step to execute
     * @return training result
     * @throws Exception propagated from any training step
     */
    public DataTrainerResult launchLSATraining(TrainingStep initialStep) throws Exception {
        return this.launchTraining(initialStep, true);
    }

    /**
     * Shared implementation of n-gram and semantic training.
     *
     * @param initialStep first step to execute
     * @param semantic    {@code true} for semantic (LSA) training, {@code false} for n-gram training
     * @return the built training result
     * @throws Exception propagated from any training step
     */
    private DataTrainerResult launchTraining(TrainingStep initialStep, boolean semantic) throws Exception {
        final long startTotal = System.currentTimeMillis();
        final DataTrainerResult.Builder dataTrainerResultBuilder = DataTrainerResult.builder();
        final Tokenizer simpleTextTokenizer = new Tokenizer(this.languageModel);
        final TokenConverter tokenConverter = new TokenConverter(semantic
                ? this.languageModel.getTokenMatchersForSemanticAnalysis()
                : this.languageModel.getTokenMatchersForNGram());
        final WordDictionaryGenerator wordDictionaryGenerator = new WordDictionaryGenerator(this.languageModel, this.trainingConfiguration);
        LOGGER.info("Will run on {} processors", this.corpus.getConcurrencyLevel());
        final ExecutorService executorService = this.createExecutorService();
        try {
            if (isStepIncluded(initialStep, TrainingStep.PARSER)) {
                long start = System.currentTimeMillis();
                this.executeTasksBlocking(executorService, simpleTextTokenizer.tokenize(this.corpus));
                LOGGER.info("Raw parsing took {} s", secondsSince(start));
            }
            // Semantic training always re-runs token conversion with its own matchers.
            if (isStepIncluded(initialStep, TrainingStep.TOKEN_CONVERT) || semantic) {
                long start = System.currentTimeMillis();
                this.executeTasksBlocking(executorService, tokenConverter.executeTokenPatternMatching(this.corpus));
                LOGGER.info("Token convert took {} s", secondsSince(start));
            }
            if (isStepIncluded(initialStep, TrainingStep.WORDS_DICTIONARY) || semantic) {
                long start = System.currentTimeMillis();
                // Only n-gram training persists the word dictionary; semantic training passes no output file.
                // The raw List casts are kept because the callback's parameter type is declared outside this file.
                wordDictionaryGenerator.createWordDictionary(this.corpus, tasks -> this.executeTasksBlocking(executorService, (List) tasks), semantic ? null : this.outputDictionary);
                LOGGER.info("Word dictionary generation took {} s", secondsSince(start));
            }
            if (isStepIncluded(initialStep, TrainingStep.NGRAM_DICTIONARY) && !semantic) {
                WordDictionary wordDictionary = WordDictionary.loadDictionary(this.languageModel, this.outputDictionary);
                NGramDictionaryGenerator nGramDictionaryGenerator = new NGramDictionaryGenerator(this.languageModel, this.trainingConfiguration, wordDictionary);
                nGramDictionaryGenerator.setNgramDebugAfterPruning(this.ngramDebugAfterPruning);
                nGramDictionaryGenerator.setNgramDebugBeforePruning(this.ngramDebugBeforePruning);
                nGramDictionaryGenerator.setDebugPrefix(this.debugPrefix);
                long start = System.currentTimeMillis();
                Map<Integer, Pair<Integer, Integer>> ngramCounts = nGramDictionaryGenerator.executeNGramTraining(this.corpus, this.outputNGram, tasks -> this.executeTasksBlocking(executorService, (List) tasks));
                dataTrainerResultBuilder.withNgramCounts(ngramCounts);
                LOGGER.info("Ngram dictionary process took {} s", secondsSince(start));
            }
            if (isStepIncluded(initialStep, TrainingStep.SEMANTIC_DICTIONARY) && semantic) {
                long start = System.currentTimeMillis();
                SemanticDictionaryGenerator lsaGenerator = new SemanticDictionaryGenerator(this.languageModel, WordDictionary.loadDictionary(this.languageModel, this.outputDictionary), this.trainingConfiguration);
                lsaGenerator.executeLSATrainingForR(this.corpus, this.outputSemantic, tasks -> this.executeTasksBlocking(executorService, (List) tasks));
                LOGGER.info("Semantic dictionary process took {} s", secondsSince(start));
            }
        } finally {
            // Every task has completed by now (executeTasksBlocking waits), so shutdown() just releases the threads.
            executorService.shutdown();
        }
        LOGGER.info("Whole training process took {} s", secondsSince(startTotal));
        return dataTrainerResultBuilder.build();
    }

    /**
     * Tells whether a run starting at {@code initialStep} reaches {@code step}, i.e. whether
     * {@code initialStep} is declared no later than {@code step} in {@link TrainingStep}.
     */
    private static boolean isStepIncluded(TrainingStep initialStep, TrainingStep step) {
        return initialStep.compareTo(step) <= 0;
    }

    /** Returns the wall-clock seconds elapsed since {@code startMillis}. */
    private static double secondsSince(long startMillis) {
        return (System.currentTimeMillis() - startMillis) / 1000.0;
    }

    /**
     * Runs every task on the executor and blocks until all of them have completed.
     * <p>
     * Task failures are logged (once, with the task's own exception) and do not abort the caller.
     * If the calling thread is interrupted while waiting, the interruption is logged and the
     * thread's interrupt status is restored so that callers can still observe it.
     *
     * @param executorService executor running the tasks
     * @param tasks           tasks to run
     * @param <T>             concrete task type
     */
    private <T extends TrainerTask> void executeTasksBlocking(ExecutorService executorService, List<T> tasks) {
        try {
            for (Future<?> future : executorService.invokeAll(tasks)) {
                try {
                    future.get();
                } catch (ExecutionException e) {
                    // Report the task's own failure rather than logging both the wrapper and its cause.
                    Throwable cause = e.getCause() != null ? e.getCause() : e;
                    LOGGER.error("Problem in task execution", cause);
                } catch (CancellationException e) {
                    LOGGER.error("Problem in task execution", e);
                }
            }
        } catch (InterruptedException e) {
            // Restore the interrupt flag: swallowing it would hide a cancellation request from callers.
            Thread.currentThread().interrupt();
            LOGGER.error("Problem in task execution", e);
        }
    }

    /**
     * Creates the fixed-size pool used by the training steps: one non-daemon, maximum-priority
     * thread per configured concurrency level, named {@code DataTrainer-Thread-N}.
     */
    private ExecutorService createExecutorService() {
        return Executors.newFixedThreadPool(this.corpus.getConcurrencyLevel(), new ThreadFactory() {
            private final AtomicInteger threadNumber = new AtomicInteger(1);

            @Override
            public Thread newThread(Runnable r) {
                Thread t = new Thread(r, "DataTrainer-Thread-" + this.threadNumber.getAndIncrement());
                t.setDaemon(false);
                t.setPriority(Thread.MAX_PRIORITY);
                return t;
            }
        });
    }
}

