/*
 * Decompiled with CFR 0.152.
 */
package org.predict4all.nlp.semantic;

import gnu.trove.map.hash.TIntIntHashMap;
import gnu.trove.procedure.TIntIntProcedure;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.attribute.FileAttribute;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Scanner;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.LongAdder;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.zip.GZIPOutputStream;
import org.predict4all.nlp.Tag;
import org.predict4all.nlp.io.TokenFileInputStream;
import org.predict4all.nlp.language.LanguageModel;
import org.predict4all.nlp.language.StopWordDictionary;
import org.predict4all.nlp.parser.token.Token;
import org.predict4all.nlp.semantic.CoOccurrenceKey;
import org.predict4all.nlp.semantic.SemanticDictionary;
import org.predict4all.nlp.trainer.TrainerTask;
import org.predict4all.nlp.trainer.configuration.TrainingConfiguration;
import org.predict4all.nlp.trainer.corpus.AbstractTrainingDocument;
import org.predict4all.nlp.trainer.corpus.TrainingCorpus;
import org.predict4all.nlp.trainer.step.TrainingStep;
import org.predict4all.nlp.utils.Pair;
import org.predict4all.nlp.utils.progressindicator.LoggingProgressIndicator;
import org.predict4all.nlp.utils.progressindicator.ProgressIndicator;
import org.predict4all.nlp.words.WordDictionary;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class SemanticDictionaryGenerator {
    private static final Logger LOGGER = LoggerFactory.getLogger(SemanticDictionaryGenerator.class);
    private static final String FILE_SUFFIX = ".bin";
    private static final String FILE_PREFIX = "predict4all-r-computing";
    private static final int NO_VALUE_ENTRY = -1;
    private static final int DENSITY_BLOCK_SIZE = 1000;
    private static final String R_SCRIPT_NAME = "svd.r";
    private final LanguageModel languageModel;
    private final WordDictionary wordDictionary;
    private final TrainingConfiguration trainingConfiguration;
    private final StopWordDictionary stopWordDictionary;

    public SemanticDictionaryGenerator(LanguageModel languageModel, WordDictionary wordDictionary, TrainingConfiguration trainingConfiguration) throws IOException {
        this.wordDictionary = wordDictionary;
        this.languageModel = languageModel;
        this.trainingConfiguration = trainingConfiguration;
        this.stopWordDictionary = this.languageModel.getStopWordDictionary(trainingConfiguration);
        if (!this.stopWordDictionary.isInitialized()) {
            this.stopWordDictionary.initialize(wordDictionary);
        }
    }

    public void executeLSATrainingForR(TrainingCorpus corpus, File lsaOutputFile, Consumer<List<? extends TrainerTask>> blockingTaskExecutor) throws IOException {
        corpus.initStep(TrainingStep.SEMANTIC_DICTIONARY);
        LoggingProgressIndicator progressIndicator = new LoggingProgressIndicator("LSA generation", corpus.getTotalCountFor(TrainingStep.SEMANTIC_DICTIONARY) * 2);
        List<Pair<Integer, Integer>> wordOrdered = this.countWordAndGetSortedList(corpus, blockingTaskExecutor, progressIndicator);
        Pair<TIntIntHashMap, TIntIntHashMap> indexes = this.initializeOccurenceMatrix(wordOrdered);
        TIntIntHashMap rowIndexes = indexes.getLeft();
        TIntIntHashMap columnIndexes = indexes.getRight();
        int rowCount = rowIndexes.size();
        ConcurrentHashMap<CoOccurrenceKey, LongAdder> countMap = new ConcurrentHashMap<CoOccurrenceKey, LongAdder>();
        LOGGER.info("Count matrix initiliazed, will now start counting");
        int windowSize = this.trainingConfiguration.getLsaWindowSize() % 2 == 0 ? this.trainingConfiguration.getLsaWindowSize() + 1 : this.trainingConfiguration.getLsaWindowSize();
        blockingTaskExecutor.accept(corpus.getDocuments(TrainingStep.SEMANTIC_DICTIONARY).stream().map(d -> new FillCountMatrixTask(progressIndicator, (AbstractTrainingDocument)d, windowSize, windowSize / 2, rowIndexes, columnIndexes, countMap)).collect(Collectors.toList()));
        LOGGER.info("Matrix filled, filling percentage is {}%", (Object)(100.0 * (1.0 * (double)countMap.size()) / (1.0 * (double)rowIndexes.size() * (double)columnIndexes.size())));
        File sizeFile = File.createTempFile(FILE_PREFIX, FILE_SUFFIX);
        File rowIndexesFile = File.createTempFile(FILE_PREFIX, FILE_SUFFIX);
        File columnIndexesFile = File.createTempFile(FILE_PREFIX, FILE_SUFFIX);
        File valuesFile = File.createTempFile(FILE_PREFIX, FILE_SUFFIX);
        this.prepapreDataForR(sizeFile, rowIndexesFile, columnIndexesFile, valuesFile, rowIndexes, columnIndexes, countMap);
        System.gc();
        File ouputMatrixFile = this.launchRScript(sizeFile, rowIndexesFile, columnIndexesFile, valuesFile);
        try (DataOutputStream dos = new DataOutputStream(new GZIPOutputStream(new FileOutputStream(lsaOutputFile)));){
            dos.writeInt(rowCount);
            dos.writeInt(this.trainingConfiguration.getLsaTargetSvdSize());
            this.writeIndexes(dos, rowIndexes);
            double[][] semanticMatrix = this.readMatrixFromR(ouputMatrixFile, dos, rowIndexes);
            this.computeAndWriteDensities(blockingTaskExecutor, rowCount, dos, semanticMatrix);
        }
    }

    private void computeAndWriteDensities(Consumer<List<? extends TrainerTask>> blockingTaskExecutor, int rowCount, DataOutputStream dos, double[][] semanticMatrix) throws IOException {
        ArrayList<ComputeDensitiesTask> tasks = new ArrayList<ComputeDensitiesTask>();
        double[] densities = new double[rowCount];
        LoggingProgressIndicator proDens = new LoggingProgressIndicator("Computing densities", rowCount);
        for (int i = 0; i < rowCount; i += 1000) {
            tasks.add(new ComputeDensitiesTask(proDens, semanticMatrix, i, Math.min(rowCount, i + 1000), densities));
        }
        blockingTaskExecutor.accept(tasks);
        double minD = Double.MAX_VALUE;
        double maxD = Double.MIN_VALUE;
        for (double d : densities) {
            minD = Math.min(minD, d);
            maxD = Math.max(maxD, d);
            dos.writeDouble(d);
        }
        LOGGER.info("Wrote density matrix in file, min value is {}, max value is {}", (Object)minD, (Object)maxD);
    }

    private double[][] readMatrixFromR(File ouputMatrixFile, DataOutputStream dos, TIntIntHashMap rowIndexes) throws IOException, FileNotFoundException {
        double[][] semanticMatrix;
        try (DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(ouputMatrixFile)));){
            int rowCountR = dis.readInt();
            int colCountR = dis.readInt();
            LOGGER.info("Will read data matrix from r result : {}x{}", (Object)rowCountR, (Object)colCountR);
            semanticMatrix = new double[rowCountR][colCountR];
            for (int r = 0; r < rowCountR; ++r) {
                double[] values = new double[colCountR];
                for (int c = 0; c < colCountR; ++c) {
                    values[c] = dis.readDouble();
                }
                SemanticDictionary.normalizeRow(values);
                semanticMatrix[r] = values;
                double sum = 0.0;
                boolean foundZero = false;
                for (double v : values) {
                    dos.writeDouble(v);
                    sum += v;
                    if (v != 0.0) continue;
                    foundZero = true;
                }
                if (sum != 0.0 && !foundZero) continue;
                LOGGER.warn("Found a line with only zeros, index {}", (Object)r);
                final int rowIndexWithOnlyZeros = r;
                rowIndexes.forEachEntry(new TIntIntProcedure(){

                    public boolean execute(int wordId, int rowIndex) {
                        if (rowIndex == rowIndexWithOnlyZeros) {
                            LOGGER.warn("Found the corresponding word for zero line : {}", (Object)SemanticDictionaryGenerator.this.wordDictionary.getWord(wordId));
                            return false;
                        }
                        return true;
                    }
                });
            }
        }
        return semanticMatrix;
    }

    private void writeIndexes(final DataOutputStream dos, TIntIntHashMap rowIndexes) throws IOException {
        boolean writeIndexesSuccess = rowIndexes.forEachEntry(new TIntIntProcedure(){

            public boolean execute(int wordId, int rowIndex) {
                try {
                    dos.writeInt(wordId);
                    dos.writeInt(rowIndex);
                    return true;
                }
                catch (Exception e) {
                    return false;
                }
            }
        });
        if (!writeIndexesSuccess) {
            throw new IOException("Couldn't write whole row indexes into semantic data file");
        }
    }

    private File launchRScript(File sizeFile, File rowIndexesFile, File columnIndexesFile, File valuesFile) throws IOException {
        File rDir = Files.createTempDirectory("predict4all-r", new FileAttribute[0]).toFile();
        try (FileOutputStream os = new FileOutputStream(rDir.getPath() + File.separator + R_SCRIPT_NAME);
             InputStream is = this.getClass().getResourceAsStream("/r-script/svd.r");){
            int n;
            byte[] buffer = new byte[4096];
            while (-1 != (n = is.read(buffer))) {
                ((OutputStream)os).write(buffer, 0, n);
            }
        }
        LOGGER.info("R script created to {}", (Object)rDir.getAbsolutePath());
        File ouputMatrixFile = File.createTempFile(FILE_PREFIX, FILE_SUFFIX);
        LOGGER.info("Will launch R process, expected output R file : {}", (Object)ouputMatrixFile.getAbsolutePath());
        Process rProcess = new ProcessBuilder("Rscript", "--vanilla", R_SCRIPT_NAME, sizeFile.getAbsolutePath(), rowIndexesFile.getAbsolutePath(), columnIndexesFile.getAbsolutePath(), valuesFile.getAbsolutePath(), ouputMatrixFile.getAbsolutePath(), "" + this.trainingConfiguration.getLsaTargetSvdSize()).directory(rDir).start();
        this.createProcessLogger(rProcess, true);
        this.createProcessLogger(rProcess, false);
        try {
            int r = rProcess.waitFor();
            LOGGER.info("R script ended with result {}", (Object)r);
        }
        catch (InterruptedException e1) {
            throw new IOException("R script failed", e1);
        }
        return ouputMatrixFile;
    }

    private void prepapreDataForR(File sizeFile, File rowIndexesFile, File columnIndexesFile, File valuesFile, TIntIntHashMap rowIndexes, TIntIntHashMap columnIndexes, ConcurrentHashMap<CoOccurrenceKey, LongAdder> countMap) throws IOException, FileNotFoundException {
        try (DataOutputStream dos = new DataOutputStream(new FileOutputStream(sizeFile));){
            dos.writeInt(rowIndexes.size());
            dos.writeInt(columnIndexes.size());
            dos.writeInt(countMap.size());
        }
        LOGGER.info("Size matrix read");
        try (DataOutputStream dosRowIndex = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(rowIndexesFile)));
             DataOutputStream dosColumIndex = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(columnIndexesFile)));
             DataOutputStream dosValue = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(valuesFile)));){
            countMap.forEach((index, v) -> {
                try {
                    dosRowIndex.writeInt(index.rowIndex + 1);
                    dosColumIndex.writeInt(index.columnIndex + 1);
                    dosValue.writeInt(v.intValue());
                }
                catch (IOException e) {
                    throw new RuntimeException(e);
                }
            });
        }
        countMap.clear();
        LOGGER.info("File prepared for R script, will now launch it\n\tSize = {}\n\tRow indexes = {}\n\tColumn indexes = {}\n\tValues = {}", new Object[]{sizeFile, rowIndexesFile, columnIndexesFile, valuesFile});
    }

    private Pair<TIntIntHashMap, TIntIntHashMap> initializeOccurenceMatrix(List<Pair<Integer, Integer>> wordOrdered) {
        int columnCount = this.trainingConfiguration.getLsaFrequentWordSize();
        int rowCount = Math.min(wordOrdered.size(), this.trainingConfiguration.getLsaVocabularySize());
        TIntIntHashMap rowIndexes = new TIntIntHashMap(rowCount, 0.9f, -1, -1);
        TIntIntHashMap columnIndexes = new TIntIntHashMap(columnCount, 0.9f, -1, -1);
        LOGGER.info("{} words sorted by count, will initilize count matrix {} rows, {} columns", new Object[]{wordOrdered.size(), rowCount, columnCount});
        for (int i = 0; i < rowCount; ++i) {
            rowIndexes.put(wordOrdered.get(i).getLeft().intValue(), i);
            if (i >= columnCount) continue;
            columnIndexes.put(wordOrdered.get(i).getLeft().intValue(), i);
        }
        rowIndexes.compact();
        columnIndexes.compact();
        wordOrdered.clear();
        return Pair.of(rowIndexes, columnIndexes);
    }

    private List<Pair<Integer, Integer>> countWordAndGetSortedList(TrainingCorpus corpus, Consumer<List<? extends TrainerTask>> blockingTaskExecutor, LoggingProgressIndicator progressIndicator) {
        ConcurrentHashMap<Integer, LongAdder> wordCounts = new ConcurrentHashMap<Integer, LongAdder>((int)((double)(this.trainingConfiguration.getLsaWindowSize() * this.trainingConfiguration.getLsaFrequentWordSize()) * 1.05), 0.95f, Runtime.getRuntime().availableProcessors());
        blockingTaskExecutor.accept(corpus.getDocuments(TrainingStep.SEMANTIC_DICTIONARY).stream().map(d -> new CountWordTask(wordCounts, progressIndicator, (AbstractTrainingDocument)d)).collect(Collectors.toList()));
        LOGGER.info("Word count created, found {} differents words, will now sort ", (Object)wordCounts.size());
        ArrayList<Pair<Integer, Integer>> wordOrdered = new ArrayList<Pair<Integer, Integer>>();
        wordCounts.forEach((id, count) -> wordOrdered.add(Pair.of(id, count.intValue())));
        wordCounts.clear();
        Collections.sort(wordOrdered, (p1, p2) -> Integer.compare((Integer)p2.getRight(), (Integer)p1.getRight()));
        return wordOrdered;
    }

    private void createProcessLogger(Process process, boolean error) {
        Thread logThread = new Thread(() -> {
            InputStream is = error ? process.getErrorStream() : process.getInputStream();
            try (Scanner scan = new Scanner(is);){
                while (scan.hasNextLine()) {
                    String line = scan.nextLine();
                    if (error) {
                        LOGGER.warn("RScript : {}", (Object)line);
                        continue;
                    }
                    LOGGER.info("RScript : {}", (Object)line);
                }
            }
        }, "Process-Logger-" + (error ? "Err" : "Out"));
        logThread.setDaemon(true);
        logThread.start();
    }

    private void putLSADataInMatrix(int windowSize, Token[] windowArray, int middleIndex, TIntIntHashMap rowIndexes, TIntIntHashMap columnIndexes, ConcurrentHashMap<CoOccurrenceKey, LongAdder> countMap) throws IOException {
        int srcId;
        Token srcToken = windowArray[middleIndex];
        if (srcToken != null && (srcId = this.getWordIdForLSA(srcToken)) >= 0) {
            for (int i = 0; i < windowArray.length; ++i) {
                Token dstToken = windowArray[i];
                if (srcToken == dstToken || dstToken == null) continue;
                int dstId = this.getWordIdForLSA(dstToken);
                this.incrementCount(srcId, dstId, rowIndexes, columnIndexes, countMap);
                this.incrementCount(dstId, srcId, rowIndexes, columnIndexes, countMap);
            }
        }
    }

    private void incrementCount(int srcId, int dstId, TIntIntHashMap rowIndexes, TIntIntHashMap columnIndexes, ConcurrentHashMap<CoOccurrenceKey, LongAdder> countMap) {
        int currentWordRowIndex = rowIndexes.get(srcId);
        int wordInWindowColumnIndex = columnIndexes.get(dstId);
        if (currentWordRowIndex >= 0 && wordInWindowColumnIndex >= 0) {
            countMap.computeIfAbsent(new CoOccurrenceKey(currentWordRowIndex, wordInWindowColumnIndex), k -> new LongAdder()).increment();
        }
    }

    private int getWordIdForLSA(Token token) {
        int wordId;
        if (!token.isSeparator() && (wordId = token.getWordId(this.wordDictionary)) != Tag.UNKNOWN.getId() && !this.stopWordDictionary.containsWord(wordId)) {
            return wordId;
        }
        return -1;
    }

    private class ComputeDensitiesTask
    extends TrainerTask {
        private final double[][] matrix;
        private final int rowIndexStart;
        private final int rowIndexEnd;
        private final double[] densities;

        public ComputeDensitiesTask(ProgressIndicator progressIndicator, double[][] matrix, int rowIndexStart, int rowIndexEnd, double[] densities) {
            super(progressIndicator, null);
            this.matrix = matrix;
            this.rowIndexStart = rowIndexStart;
            this.rowIndexEnd = rowIndexEnd;
            this.densities = densities;
        }

        @Override
        public void run() throws Exception {
            for (int r = this.rowIndexStart; r < this.rowIndexEnd; ++r) {
                double[] row = this.matrix[r];
                ArrayList<Double> cosineAngles = new ArrayList<Double>(this.matrix.length);
                for (int r2 = 0; r2 < this.matrix.length; ++r2) {
                    if (r == r2) continue;
                    cosineAngles.add(SemanticDictionary.cosineAngle(row, this.matrix[r2]));
                }
                this.densities[r] = cosineAngles.stream().sorted((a, b) -> Double.compare(b, a)).limit(SemanticDictionaryGenerator.this.trainingConfiguration.getLsaDensitySize()).mapToDouble(v -> v).average().orElseGet(() -> 0.0);
                this.progressIndicator.increment();
            }
        }
    }

    private class CountWordTask
    extends TrainerTask {
        private final ConcurrentHashMap<Integer, LongAdder> wordCounts;

        public CountWordTask(ConcurrentHashMap<Integer, LongAdder> wordCounts, ProgressIndicator progressIndicator, AbstractTrainingDocument document) {
            super(progressIndicator, document);
            this.wordCounts = wordCounts;
        }

        @Override
        public void run() throws Exception {
            try (TokenFileInputStream tokenFis = new TokenFileInputStream(this.document.getInputFile());){
                for (Token token = tokenFis.readToken(); token != null; token = token.getNext(tokenFis)) {
                    int wordId = SemanticDictionaryGenerator.this.getWordIdForLSA(token);
                    if (wordId >= 0) {
                        this.wordCounts.computeIfAbsent(wordId, k -> new LongAdder()).increment();
                    }
                    this.progressIndicator.increment();
                }
            }
        }
    }

    private class FillCountMatrixTask
    extends TrainerTask {
        private final int windowSize;
        private final int middleIndex;
        private final TIntIntHashMap rowIndexes;
        private final TIntIntHashMap columnIndexes;
        private final ConcurrentHashMap<CoOccurrenceKey, LongAdder> countMap;

        public FillCountMatrixTask(ProgressIndicator progressIndicator, AbstractTrainingDocument document, int windowSize, int middleIndex, TIntIntHashMap rowIndexes, TIntIntHashMap columnIndexes, ConcurrentHashMap<CoOccurrenceKey, LongAdder> countMap) {
            super(progressIndicator, document);
            this.windowSize = windowSize;
            this.middleIndex = middleIndex;
            this.rowIndexes = rowIndexes;
            this.columnIndexes = columnIndexes;
            this.countMap = countMap;
        }

        @Override
        public void run() throws Exception {
            Token[] windowArray = new Token[this.windowSize];
            try (TokenFileInputStream tokenFis = new TokenFileInputStream(this.document.getInputFile());){
                for (Token token = tokenFis.getNext(); token != null; token = token.getNext(tokenFis)) {
                    if (!token.isSeparator()) {
                        System.arraycopy(windowArray, 1, windowArray, 0, windowArray.length - 1);
                        windowArray[windowArray.length - 1] = token;
                    }
                    SemanticDictionaryGenerator.this.putLSADataInMatrix(SemanticDictionaryGenerator.this.trainingConfiguration.getLsaWindowSize(), windowArray, this.middleIndex, this.rowIndexes, this.columnIndexes, this.countMap);
                    this.progressIndicator.increment();
                }
            }
        }
    }
}

