/*
 * Decompiled with CFR 0.152.
 */
package edu.umn.biomedicus.tnt;

import edu.umn.biomedicus.common.tuples.WordCap;
import edu.umn.biomedicus.common.types.syntax.PartOfSpeech;
import edu.umn.biomedicus.common.types.syntax.PartsOfSpeech;
import edu.umn.biomedicus.tagging.PosTag;
import edu.umn.biomedicus.tnt.DataStoreFactory;
import edu.umn.biomedicus.tnt.FilteredWordPosFrequencies;
import edu.umn.biomedicus.tnt.KnownWordProbabilityModel;
import edu.umn.biomedicus.tnt.PosCapTrigramModel;
import edu.umn.biomedicus.tnt.PosCapTrigramModelTrainer;
import edu.umn.biomedicus.tnt.SuffixWordProbabilityModel;
import edu.umn.biomedicus.tnt.TntModel;
import edu.umn.biomedicus.tnt.WordCapAdapter;
import edu.umn.biomedicus.tnt.WordCapFilter;
import edu.umn.biomedicus.tnt.WordPosFrequencies;
import edu.umn.biomedicus.tnt.WordProbabilityModel;
import edu.umn.biomedicus.tokenization.ParseToken;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Trains a {@link TntModel} from part-of-speech tagged sentences.
 *
 * <p>Sentences are fed in with {@link #addSentence(List, List)}, which records per-word
 * part-of-speech frequencies in every configured {@link FilteredWordPosFrequencies} and forwards
 * the sentence to a {@link PosCapTrigramModelTrainer}. {@link #createModel()} then turns the
 * accumulated statistics into known-word and suffix word-probability models plus the trigram
 * model. Instances are obtained through {@link #builder()}.
 *
 * <p>No synchronization is performed; treat instances as not thread-safe.
 */
public class TntModelTrainer {
    private static final Logger LOGGER = LoggerFactory.getLogger(TntModelTrainer.class);

    /** One frequency set per word/capitalization filter configured by the builder. */
    private final List<FilteredWordPosFrequencies> filteredWordPosFrequencies;
    private final PosCapTrigramModelTrainer posCapTrigramModelTrainer;
    private final int maxSuffixLength;
    private final int maxWordFrequency;
    private final boolean useMslSuffixModel;
    private final boolean restrictToOpenClass;
    private final DataStoreFactory dataStoreFactory;

    private TntModelTrainer(List<FilteredWordPosFrequencies> filteredWordPosFrequencies, PosCapTrigramModelTrainer posCapTrigramModelTrainer, int maxSuffixLength, int maxWordFrequency, boolean useMslSuffixModel, boolean restrictToOpenClass, DataStoreFactory dataStoreFactory) {
        this.filteredWordPosFrequencies = filteredWordPosFrequencies;
        this.posCapTrigramModelTrainer = posCapTrigramModelTrainer;
        this.maxSuffixLength = maxSuffixLength;
        this.maxWordFrequency = maxWordFrequency;
        this.useMslSuffixModel = useMslSuffixModel;
        this.restrictToOpenClass = restrictToOpenClass;
        this.dataStoreFactory = dataStoreFactory;
    }

    /**
     * Creates a new builder with default settings.
     *
     * @return a fresh {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Records one tagged sentence.
     *
     * <p>For each token, a {@link WordCap} is built from the token text and whether its first
     * character is upper case. It is then added, together with the token's part of speech, to
     * every filtered frequency set. Tokens whose tag has a {@code null} part of speech are not
     * counted in the word frequencies. The whole sentence is finally passed to the trigram model
     * trainer.
     *
     * @param tokens the sentence tokens
     * @param partOfSpeeches the tags, read by token index; must contain at least one entry per token
     * @throws IllegalArgumentException if there are fewer tags than tokens (checked before any
     *     statistics are updated, so a bad sentence is not partially recorded)
     */
    public void addSentence(List<ParseToken> tokens, List<PosTag> partOfSpeeches) {
        if (partOfSpeeches.size() < tokens.size()) {
            throw new IllegalArgumentException("Expected at least one tag per token, but got "
                    + tokens.size() + " tokens and " + partOfSpeeches.size() + " tags");
        }
        for (int i = 0; i < tokens.size(); ++i) {
            PartOfSpeech partOfSpeech = partOfSpeeches.get(i).getPartOfSpeech();
            if (partOfSpeech == null) {
                // Untagged tokens contribute nothing to the word/tag frequencies.
                continue;
            }
            String tokenText = tokens.get(i).getText();
            WordCap wordCap = new WordCap(tokenText, startsWithUpperCase(tokenText));
            for (FilteredWordPosFrequencies frequencies : this.filteredWordPosFrequencies) {
                frequencies.addWord(wordCap, partOfSpeech);
            }
        }
        this.posCapTrigramModelTrainer.addSentence(tokens, partOfSpeeches);
    }

    /**
     * Returns whether {@code text} is non-empty and its first code point is upper case.
     * Code points (not chars) are used so that supplementary characters are classified correctly.
     */
    private static boolean startsWithUpperCase(String text) {
        return !text.isEmpty() && Character.isUpperCase(text.codePointAt(0));
    }

    /**
     * Builds the model from all sentences added so far.
     *
     * <p>The tag set is {@link PartsOfSpeech#getOpenClass()} when restricted to open-class tags,
     * otherwise {@link PartsOfSpeech#getRealTags()}. Each filtered frequency set at index
     * {@code i} yields two models:
     * <ul>
     *   <li>a {@link KnownWordProbabilityModel} with id {@code i};</li>
     *   <li>a {@link SuffixWordProbabilityModel} with id {@code count + i}. It is trained with
     *       {@code trainPI} on the frequencies filtered by
     *       {@code onlyWordsOccurringUpTo(maxWordFrequency)} and expanded by
     *       {@code expandSuffixes(maxSuffixLength)}.</li>
     * </ul>
     * All known-word models precede all suffix models in the list given to {@link TntModel}.
     *
     * @return the trained model
     * @throws UnsupportedOperationException if the MSL suffix model was requested; this is
     *     checked before any model is built or data store created
     */
    public TntModel createModel() {
        if (this.useMslSuffixModel) {
            throw new UnsupportedOperationException("The MSL suffix model is not supported");
        }
        PosCapTrigramModel posCapTrigramModel = this.posCapTrigramModelTrainer.build();
        Set<PartOfSpeech> tagSet = this.restrictToOpenClass ? PartsOfSpeech.getOpenClass() : PartsOfSpeech.getRealTags();
        int modelCount = this.filteredWordPosFrequencies.size();
        List<WordProbabilityModel> wordModels = new ArrayList<>(2 * modelCount);
        List<SuffixWordProbabilityModel> suffixModels = new ArrayList<>(modelCount);
        for (int priority = 0; priority < modelCount; ++priority) {
            FilteredWordPosFrequencies filteredFreqs = this.filteredWordPosFrequencies.get(priority);
            wordModels.add(trainKnownWordModel(filteredFreqs, priority, tagSet));
            suffixModels.add(trainSuffixModel(filteredFreqs, modelCount + priority, tagSet));
        }
        // Known-word models come first, then suffix models.
        wordModels.addAll(suffixModels);
        LOGGER.debug("Word models: {}", wordModels);
        return new TntModel(posCapTrigramModel, wordModels);
    }

    /** Creates, configures and trains the known-word model for one filtered frequency set. */
    private KnownWordProbabilityModel trainKnownWordModel(FilteredWordPosFrequencies filteredFreqs, int id, Set<PartOfSpeech> tagSet) {
        KnownWordProbabilityModel model = new KnownWordProbabilityModel();
        model.setId(id);
        model.setFilter(filteredFreqs.getFilter());
        model.setWordCapAdapter(filteredFreqs.getWordCapAdapter());
        model.createDataStore(this.dataStoreFactory);
        model.train(filteredFreqs.getWordPosFrequencies(), tagSet);
        return model;
    }

    /** Creates, configures and trains the suffix model for one filtered frequency set. */
    private SuffixWordProbabilityModel trainSuffixModel(FilteredWordPosFrequencies filteredFreqs, int id, Set<PartOfSpeech> tagSet) {
        // Only words at or below the frequency threshold feed the suffix statistics.
        WordPosFrequencies suffixFrequencies = filteredFreqs.getWordPosFrequencies()
                .onlyWordsOccurringUpTo(this.maxWordFrequency)
                .expandSuffixes(this.maxSuffixLength);
        SuffixWordProbabilityModel model = new SuffixWordProbabilityModel();
        model.setMaxSuffixLength(this.maxSuffixLength);
        model.setId(id);
        model.setWordCapAdapter(filteredFreqs.getWordCapAdapter());
        model.createDataStore(this.dataStoreFactory);
        model.setFilter(filteredFreqs.getFilter());
        model.trainPI(suffixFrequencies, tagSet);
        return model;
    }

    /**
     * Builder for {@link TntModelTrainer}.
     *
     * <p>Defaults:
     * <ul>
     *   <li>{@code maxSuffixLength} = 5</li>
     *   <li>{@code maxWordFrequency} = 10</li>
     *   <li>{@code useCapitalization} = true</li>
     *   <li>{@code useMslSuffixModel} = false</li>
     *   <li>{@code restrictToOpenClass} = false</li>
     *   <li>{@code dataStoreFactory} = null</li>
     * </ul>
     */
    public static final class Builder {
        private int maxSuffixLength = 5;
        private int maxWordFrequency = 10;
        private boolean useCapitalization = true;
        private boolean useMslSuffixModel = false;
        private boolean restrictToOpenClass = false;
        private DataStoreFactory dataStoreFactory;

        private Builder() {
        }

        /**
         * Sets the maximum suffix length used by the suffix models.
         *
         * @param maxSuffixLength the maximum suffix length
         * @return this builder
         */
        public Builder maxSuffixLength(int maxSuffixLength) {
            this.maxSuffixLength = maxSuffixLength;
            return this;
        }

        /**
         * Sets the word-frequency threshold passed to {@code onlyWordsOccurringUpTo} when
         * selecting words for suffix training.
         *
         * @param maxWordFrequency the frequency threshold
         * @return this builder
         */
        public Builder maxWordFrequency(int maxWordFrequency) {
            this.maxWordFrequency = maxWordFrequency;
            return this;
        }

        /**
         * Sets whether word statistics are split by capitalization.
         *
         * @param useCapitalization whether to use separate capitalization-filtered frequency sets
         * @return this builder
         */
        public Builder useCapitalization(boolean useCapitalization) {
            this.useCapitalization = useCapitalization;
            return this;
        }

        /**
         * Requests the MSL suffix model. Not currently supported:
         * {@link TntModelTrainer#createModel()} throws {@link UnsupportedOperationException} when set.
         *
         * @param useMslSuffixModel whether to use the MSL suffix model
         * @return this builder
         */
        public Builder useMslSuffixModel(boolean useMslSuffixModel) {
            this.useMslSuffixModel = useMslSuffixModel;
            return this;
        }

        /**
         * Sets whether models are trained only over {@link PartsOfSpeech#getOpenClass()} tags
         * instead of {@link PartsOfSpeech#getRealTags()}.
         *
         * @param restrictToOpenClass whether to restrict the tag set to open-class tags
         * @return this builder
         */
        public Builder restrictToOpenClass(boolean restrictToOpenClass) {
            this.restrictToOpenClass = restrictToOpenClass;
            return this;
        }

        /**
         * Sets the factory used by the word models to create their data stores.
         * NOTE(review): not validated here; a null factory is passed through to
         * {@code createDataStore} as-is.
         *
         * @param dataStoreFactory the data store factory
         * @return this builder
         */
        public Builder dataStoreFactory(DataStoreFactory dataStoreFactory) {
            this.dataStoreFactory = dataStoreFactory;
            return this;
        }

        /**
         * Builds the trainer.
         *
         * <p>With capitalization enabled, two frequency sets are created, filtered by
         * {@code WordCapFilter(true, false)} and {@code WordCapFilter(false, true)} (presumably
         * capitalized vs. non-capitalized words; confirm against {@link WordCapFilter}).
         * Otherwise a single unfiltered-by-capitalization set is created.
         *
         * @return a new {@link TntModelTrainer}
         */
        public TntModelTrainer build() {
            List<FilteredWordPosFrequencies> filteredWordPosFrequencies = new ArrayList<>();
            if (this.useCapitalization) {
                filteredWordPosFrequencies.add(new FilteredWordPosFrequencies(new WordCapFilter(true, false), new WordCapAdapter(true, false)));
                filteredWordPosFrequencies.add(new FilteredWordPosFrequencies(new WordCapFilter(false, true), new WordCapAdapter(true, false)));
            } else {
                filteredWordPosFrequencies.add(new FilteredWordPosFrequencies(new WordCapFilter(false, false), new WordCapAdapter(true, true)));
            }
            PosCapTrigramModelTrainer posCapTrigramModelTrainer = new PosCapTrigramModelTrainer();
            return new TntModelTrainer(filteredWordPosFrequencies, posCapTrigramModelTrainer, this.maxSuffixLength, this.maxWordFrequency, this.useMslSuffixModel, this.restrictToOpenClass, this.dataStoreFactory);
        }
    }
}

