/*
 * Decompiled with CFR 0.152.
 */
package org.predict4all.nlp.ngram.dictionary;

import java.io.File;
import java.io.IOException;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.StandardOpenOption;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import org.predict4all.nlp.ngram.dictionary.AbstractNGramDictionary;
import org.predict4all.nlp.ngram.trie.DynamicNGramTrieNode;
import org.predict4all.nlp.trainer.configuration.NGramPruningMethod;
import org.predict4all.nlp.trainer.configuration.TrainingConfiguration;
import org.predict4all.nlp.utils.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * N-gram dictionary backed by a mutable {@link DynamicNGramTrieNode} trie, used while training.
 *
 * <p>It supports counting n-grams, computing discount values and probabilities, pruning, and
 * writing the trie to a file through {@link #saveDictionary(File)}. It cannot be opened from an
 * existing file: {@link #openDictionary(File)} always throws.
 */
public class TrainingNGramDictionary extends AbstractNGramDictionary<DynamicNGramTrieNode> {
    private static final Logger LOGGER = LoggerFactory.getLogger(TrainingNGramDictionary.class);

    /** Formatter used to log n-gram counts with thousands grouping. */
    public static final DecimalFormat NGRAM_COUNT_FORMAT = new DecimalFormat("###,###,###,###,###");

    /**
     * Size in bytes of the general-information header written at the start of the file: it is both
     * the capacity of the buffer handed to {@code writeDictionaryInfo} and the extra offset skipped
     * before the n-gram levels.
     */
    private static final int DICTIONARY_INFO_BYTES = 4;

    /** Initial capacity of the per-order list of n-grams selected for pruning. */
    private static final int PRUNING_LIST_INITIAL_CAPACITY = 20000;

    /**
     * Creates an empty dictionary with a fresh root node.
     *
     * @param maxOrderP the maximum n-gram order
     */
    protected TrainingNGramDictionary(int maxOrderP) {
        this(new DynamicNGramTrieNode(), maxOrderP);
    }

    /**
     * Creates a dictionary on top of an existing root node.
     *
     * @param root the trie root
     * @param maxOrderP the maximum n-gram order
     */
    protected TrainingNGramDictionary(DynamicNGramTrieNode root, int maxOrderP) {
        super(root, maxOrderP);
    }

    /**
     * Creates an empty training dictionary.
     *
     * @param maxOrder the maximum n-gram order
     * @return a new dictionary
     */
    public static TrainingNGramDictionary create(int maxOrder) {
        return new TrainingNGramDictionary(maxOrder);
    }

    /** Returns the trie root as its concrete type; single place for the cast. */
    private DynamicNGramTrieNode root() {
        return (DynamicNGramTrieNode) this.rootNode;
    }

    /** Delegates to the root's {@code getNodeFor(prefix, index, -1)}. */
    @Override
    public DynamicNGramTrieNode getNodeForPrefix(int[] prefix, int index) {
        return this.root().getNodeFor(prefix, index, -1);
    }

    /**
     * Reports whether the node has a children container.
     *
     * @return {@code true} when {@code node.getChildren()} is not {@code null}
     */
    @Override
    public boolean checkChildrenLoading(DynamicNGramTrieNode node) {
        return node.getChildren() != null;
    }

    /** Same as {@link #putAndIncrementBy(int[], int, int)} with a start index of 0. */
    @Override
    public void putAndIncrementBy(int[] ngram, int increment) {
        this.putAndIncrementBy(ngram, 0, increment);
    }

    /** Delegates to the root's {@code putAndIncrementBy} with the same arguments. */
    @Override
    public void putAndIncrementBy(int[] ngram, int index, int increment) {
        this.root().putAndIncrementBy(ngram, index, increment);
    }

    /**
     * Writes this dictionary to {@code dictionaryFile}, creating or truncating it.
     *
     * <p>Write order: the levels {@code maxOrder} down to 1 are written first, starting right after
     * the space reserved for the header and the root block; the channel is then rewound to write
     * the header followed by level 0.
     *
     * @param dictionaryFile the destination file
     * @throws IOException if the file cannot be opened or written
     */
    @Override
    public void saveDictionary(File dictionaryFile) throws IOException {
        try (FileChannel fileChannel = FileChannel.open(dictionaryFile.toPath(), StandardOpenOption.CREATE,
                StandardOpenOption.WRITE, StandardOpenOption.TRUNCATE_EXISTING)) {
            // Leave room at the start of the file for the header and the root block
            fileChannel.position(this.getRootBlockSize() + DICTIONARY_INFO_BYTES);
            for (int order = this.maxOrder; order >= 1; --order) {
                LOGGER.info("Save NGram level {}", order);
                this.executeWriteLevelOnRoot(fileChannel, order);
            }
            LOGGER.info("All level saved, will now save general information and root block");
            fileChannel.position(0L);
            ByteBuffer buffWrite = ByteBuffer.allocateDirect(DICTIONARY_INFO_BYTES);
            this.writeDictionaryInfo(buffWrite);
            // Keep the cast: it binds the call to Buffer.flip(). ByteBuffer only declares its own
            // covariant flip() since Java 9, which would not link on a Java 8 runtime.
            ((Buffer) buffWrite).flip();
            fileChannel.write(buffWrite);
            this.executeWriteLevelOnRoot(fileChannel, 0);
        }
    }

    /**
     * Asks the root node to write one trie level to the channel.
     *
     * @param fileChannel the destination channel
     * @param level the level to write
     * @throws IOException if writing fails
     */
    protected void executeWriteLevelOnRoot(FileChannel fileChannel, int level) throws IOException {
        this.root().writeLevelForStaticUse(fileChannel, -1, 0, level);
    }

    /**
     * @return the number of bytes reserved for the root block in the saved file
     */
    protected long getRootBlockSize() {
        return 28L;
    }

    /**
     * Recomputes probabilities for the whole trie, starting at the root, and logs the elapsed time.
     *
     * @param d the discount values passed to the trie
     */
    @Override
    public void updateProbabilities(double[] d) {
        final long startProb = System.currentTimeMillis();
        // Explicit Object cast: the array is one log argument, not a varargs list
        LOGGER.info("Start computing ngram probabilities with d={}", (Object) d);
        this.root().computeProbabilityForChildren(0, d, true);
        LOGGER.info("NGram probabilities computed in {} ms", System.currentTimeMillis() - startProb);
    }

    /**
     * Recomputes probabilities for the children of the node designated by {@code prefix}.
     *
     * @param prefix the prefix identifying the node
     * @param prefixIndex the start index in {@code prefix}
     * @param d the discount values passed to the trie
     * @throws IllegalArgumentException if no node exists for the prefix
     */
    @Override
    public void updateProbabilities(int[] prefix, int prefixIndex, double[] d) {
        final int level = prefix.length - prefixIndex - 1;
        final DynamicNGramTrieNode nodeFor = this.root().getNodeFor(prefix, prefixIndex, prefix.length - 1);
        if (nodeFor == null) {
            throw new IllegalArgumentException("No existing node for " + Arrays.toString(prefix) + " : can't compute probabilities");
        }
        nodeFor.computeProbabilityForChildren(level, d, false);
    }

    /**
     * Computes one discount value per order.
     *
     * <p>When the configured smoothing discount value is strictly positive, it is used for every
     * order. Otherwise, for each order, {@code d = n1 / (n1 + 2 * n2)} where {@code n1} and
     * {@code n2} are the counters filled by the trie for n-grams seen once and twice; a NaN result
     * (both counters at zero) becomes 0.5, and the value is then clamped to the configured lower
     * and upper bounds.
     *
     * @param configuration the training configuration providing the discount value and its bounds
     * @return an array of {@code getMaxOrder()} discount values
     */
    @Override
    public double[] computeD(TrainingConfiguration configuration) {
        final int orderCount = this.getMaxOrder();
        final double fixedDiscount = configuration.getSmoothingDiscountValue();
        if (fixedDiscount > 0.0) {
            double[] fixed = new double[orderCount];
            Arrays.fill(fixed, fixedDiscount);
            return fixed;
        }

        AtomicInteger[] n1Count = new AtomicInteger[orderCount];
        Arrays.setAll(n1Count, i -> new AtomicInteger(0));
        AtomicInteger[] n2Count = new AtomicInteger[orderCount];
        Arrays.setAll(n2Count, i -> new AtomicInteger(0));
        LOGGER.info("D will be computed automatically, count ngram with count=1 and count=2 to compute d");
        this.root().countOneAndTwoOccurenceNGrams(0, n1Count, n2Count);

        double[] dArray = new double[orderCount];
        for (int i = 0; i < orderCount; ++i) {
            final double n1 = n1Count[i].get();
            final double n2 = n2Count[i].get();
            double d = n1 / (n1 + 2.0 * n2);
            if (Double.isNaN(d)) {
                // 0/0: no n-gram with count 1 or 2 at this order
                d = 0.5;
            }
            dArray[i] = Math.min(Math.max(d, configuration.getSmoothingDiscountValueLowerBound()),
                    configuration.getSmoothingDiscountValueUpperBound());
            LOGGER.info("[ORDER {}] Found {} ngram with count=1, {} ngram with count=2, d is set to {}",
                    i + 1, n1Count[i].get(), n2Count[i].get(), dArray[i]);
        }
        return dArray;
    }

    /**
     * Prunes n-grams whose weighted difference is at or below {@code thresholdPruning}.
     *
     * <p>Probabilities are always recomputed first. If the threshold is not strictly positive,
     * nothing else happens. Otherwise, for each order from {@code maxOrder} down to 2, every leaf
     * of that order gets the score {@code p * (log(p) - log(pLower))}, where {@code p} is the
     * probability for the full prefix and {@code pLower} the one for the prefix without its first
     * word. Leaves scoring {@code <= thresholdPruning} are removed, then the dictionary is
     * compacted and probabilities are recomputed before the next order is handled.
     *
     * @param thresholdPruning pruning threshold; values {@code <= 0} disable pruning
     * @param configuration the training configuration used to compute discount values
     * @param pruningMethod {@link NGramPruningMethod#WEIGHTED_DIFFERENCE_FULL_PROB} selects
     *        {@code getProbability}; any other value selects {@code getRawProbability}
     */
    public void pruneNGramsWeightedDifference(double thresholdPruning, TrainingConfiguration configuration, NGramPruningMethod pruningMethod) {
        this.updateProbabilities(this.computeD(configuration));
        if (thresholdPruning <= 0.0) {
            LOGGER.info("Ignore ngram pruning because threshold = {}", thresholdPruning);
            return;
        }

        LOGGER.info("Start pruning ngrams using weighted difference ({}) and threshold {}", pruningMethod, thresholdPruning);
        // Loop invariant: decide once which probability flavor is used
        final boolean useFullProbability = pruningMethod == NGramPruningMethod.WEIGHTED_DIFFERENCE_FULL_PROB;
        for (int order = this.maxOrder; order > 1; --order) {
            // Removal is deferred so the trie is not modified while it is being traversed.
            // NOTE(review): the prefix array is stored as received; this assumes listTrieLeaves
            // hands out a distinct array per callback - confirm in DynamicNGramTrieNode.
            final ArrayList<Pair<int[], Integer>> toDelete = new ArrayList<>(PRUNING_LIST_INITIAL_CAPACITY);
            this.root().listTrieLeaves(new int[order - 1], -1, 0, order, (ngram, wantedWord) -> {
                final double weightDiff = this.weightedDifference(ngram, wantedWord, useFullProbability);
                if (weightDiff <= thresholdPruning) {
                    toDelete.add(Pair.of(ngram, wantedWord));
                }
            });
            LOGGER.info("Found {} {}-gram to prune", toDelete.size(), order);
            for (Pair<int[], Integer> ngramToDelete : toDelete) {
                this.getNodeForPrefix(ngramToDelete.getLeft(), 0).getChildren().remove(ngramToDelete.getRight());
            }
            LOGGER.info("{} {}-gram removed, will now compact and compute probabilities again", toDelete.size(), order);
            this.compact();
            this.updateProbabilities(this.computeD(configuration));
        }
    }

    /**
     * Computes {@code p * (log(p) - log(pLower))} for {@code word} after {@code prefix}, where
     * {@code pLower} uses the prefix without its first element.
     */
    private double weightedDifference(int[] prefix, int word, boolean useFullProbability) {
        final double orderProb;
        final double lowerOrderProb;
        if (useFullProbability) {
            orderProb = this.getProbability(prefix, 0, prefix.length, word);
            lowerOrderProb = this.getProbability(prefix, 1, prefix.length - 1, word);
        } else {
            orderProb = this.getRawProbability(prefix, 0, prefix.length, word);
            lowerOrderProb = this.getRawProbability(prefix, 1, prefix.length - 1, word);
        }
        return orderProb * (Math.log(orderProb) - Math.log(lowerOrderProb));
    }

    /**
     * Prunes n-grams of order {@code maxOrder} down to 2 using one count threshold for all orders,
     * compacting after each order, then recomputes probabilities.
     *
     * <p>The discount values are computed before pruning, so they reflect the unpruned counts.
     *
     * @param countThreshold the threshold passed to the trie for every order
     * @param configuration the training configuration used to compute discount values
     */
    public void pruneNGramsCount(int countThreshold, TrainingConfiguration configuration) {
        LOGGER.info("Start pruning ngrams using raw count and threshold {}", countThreshold);
        final double[] d = this.computeD(configuration);
        for (int order = this.maxOrder; order > 1; --order) {
            this.root().pruningCountingNGram(0, order, countThreshold);
            this.compact();
        }
        this.updateProbabilities(d);
    }

    /**
     * Prunes n-grams of order {@code maxOrder} down to 2 using a per-order count threshold,
     * compacting after each order, then recomputes probabilities.
     *
     * <p>The threshold for order {@code n} is {@code counts[n - 1]}; {@code counts[0]} is not read.
     * The discount values are computed before pruning, so they reflect the unpruned counts.
     *
     * @param counts per-order thresholds; needs at least {@code maxOrder} entries when
     *        {@code maxOrder > 1}
     * @param configuration the training configuration used to compute discount values
     * @throws IllegalArgumentException if {@code counts} is null or too short while
     *         {@code maxOrder > 1}
     */
    public void pruneNGramsOrderCount(int[] counts, TrainingConfiguration configuration) {
        LOGGER.info("Start pruning ngrams using order counts and threshold {}", Arrays.toString(counts));
        // Fail fast, before the costly discount computation, instead of hitting an index or null
        // error inside the pruning loop
        if (this.maxOrder > 1 && (counts == null || counts.length < this.maxOrder)) {
            throw new IllegalArgumentException("Order counts " + Arrays.toString(counts)
                    + " must contain at least " + this.maxOrder + " values (one per ngram order)");
        }
        final double[] d = this.computeD(configuration);
        for (int order = this.maxOrder; order > 1; --order) {
            this.root().pruningCountingNGram(0, order, counts[order - 1]);
            this.compact();
        }
        this.updateProbabilities(d);
    }

    /** Does nothing: this dictionary holds no resource to release. */
    @Override
    public void close() throws Exception {
    }

    /**
     * Always fails: a training dictionary cannot be opened from a file.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    protected void openDictionary(File dictionaryFile) throws IOException {
        throw new UnsupportedOperationException("Training ngram dictionary can't be opened");
    }

    /**
     * Counts n-grams for each order from 1 to {@code maxOrder} and logs the result.
     *
     * @return a map from order to a pair whose left value is the unique counter and whose right
     *         value is the total counter, as filled by the trie
     */
    public Map<Integer, Pair<Integer, Integer>> countNGrams() {
        final Map<Integer, Pair<Integer, Integer>> ngramCounts = new HashMap<>();
        for (int order = 1; order <= this.maxOrder; ++order) {
            final AtomicInteger uniqueCounter = new AtomicInteger(0);
            final AtomicInteger totalCounter = new AtomicInteger(0);
            this.root().countNGram(0, order, totalCounter, uniqueCounter);
            ngramCounts.put(order, Pair.of(uniqueCounter.get(), totalCounter.get()));
            LOGGER.info("{} total {}-gram detected, with {} unique {}-gram",
                    NGRAM_COUNT_FORMAT.format(totalCounter.get()), order,
                    NGRAM_COUNT_FORMAT.format(uniqueCounter.get()), order);
        }
        return ngramCounts;
    }
}

