/*
 * Decompiled with CFR 0.152.
 */
package org.predict4all.nlp.semantic;

import gnu.trove.map.hash.TIntDoubleHashMap;
import gnu.trove.map.hash.TIntIntHashMap;
import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.zip.GZIPInputStream;
import org.predict4all.nlp.prediction.model.AbstractPredictionToCompute;
import org.predict4all.nlp.semantic.SemanticDictionaryConfiguration;
import org.predict4all.nlp.utils.Pair;
import org.predict4all.nlp.utils.SingleThreadDoubleAdder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Semantic dictionary backed by a dense matrix of word vectors.
 * <p>
 * Each known word id maps to one row of the semantic matrix. Rows are combined and compared by
 * dot product to score candidate predictions against a context made of word ids. A per-word
 * density, rescaled at load time into the bounds given by {@link SemanticDictionaryConfiguration},
 * is also kept and exposed through {@link #getDensitiesMap()}.
 * <p>
 * Instances are only created by {@link #loadDictionary(File, SemanticDictionaryConfiguration)}.
 */
public class SemanticDictionary {
    private static final Logger LOGGER = LoggerFactory.getLogger(SemanticDictionary.class);

    /** No-entry key/value used by the internal Trove maps; a lookup of an unknown word yields this. */
    private static final int NO_VALUE_ENTRY = -1;

    /** Word id -> row index into {@link #semanticMatrix}; unknown words map to {@link #NO_VALUE_ENTRY}. */
    private final TIntIntHashMap rowIndexes;
    /** One vector per known word; every row has {@link #columnCount} entries. */
    private final double[][] semanticMatrix;
    /** Word id -> density scaled into the configured [min, max] bounds. */
    private final TIntDoubleHashMap densitiesMap;
    private final int rowCount;
    private final int columnCount;

    private SemanticDictionary(TIntIntHashMap rowIndexes, double[][] semanticMatrix, TIntDoubleHashMap densitiesMap) {
        this.rowIndexes = rowIndexes;
        this.semanticMatrix = semanticMatrix;
        this.rowCount = semanticMatrix.length;
        // An empty matrix has no first row to measure: treat it as zero columns instead of throwing.
        this.columnCount = semanticMatrix.length == 0 ? 0 : semanticMatrix[0].length;
        this.densitiesMap = densitiesMap;
    }

    /**
     * Returns the word id -> scaled density map.
     * <p>
     * The returned map is the internal instance (not a copy); callers must not modify it.
     * Looking up an unknown word id returns the map's default no-entry value (0.0).
     *
     * @return the densities map built at load time
     */
    public TIntDoubleHashMap getDensitiesMap() {
        return this.densitiesMap;
    }

    /**
     * Scores every prediction whose word is known by this dictionary.
     * <p>
     * The raw score is the dot product between {@code wordRow} and the prediction's matrix row.
     * Raw scores are then shifted by the smallest raw score (only when that minimum is negative,
     * because the minimum starts at 0.0) and raised to {@code constrastFactor}.
     *
     * @param wordRow         context vector, expected to be normalized by the caller
     * @param predictions     candidate predictions; those with unknown words are skipped
     * @param constrastFactor exponent applied to each shifted score
     * @return the sum of all transformed scores, paired with the word id -> transformed score map
     */
    private Pair<Double, TIntDoubleHashMap> computeScoreMapFor(double[] wordRow, List<AbstractPredictionToCompute> predictions, double constrastFactor) {
        // Presize for what is actually inserted (at most one entry per prediction, never more than
        // the known rows) rather than allocating a vocabulary-sized table on every call.
        int expectedSize = Math.max(1, Math.min(predictions.size(), this.rowCount));
        TIntDoubleHashMap scoreMap = new TIntDoubleHashMap(expectedSize, 0.9f, NO_VALUE_ENTRY, NO_VALUE_ENTRY);
        double minVal = 0.0;
        for (AbstractPredictionToCompute prediction : predictions) {
            int rowIndex = this.rowIndexes.get(prediction.getWordId());
            if (rowIndex < 0) {
                // Word absent from the semantic matrix: no score for it.
                continue;
            }
            double angle = cosineAngle(wordRow, this.semanticMatrix[rowIndex]);
            scoreMap.put(prediction.getWordId(), angle);
            minVal = Math.min(minVal, angle);
        }
        final double minValFinal = minVal;
        SingleThreadDoubleAdder sum = new SingleThreadDoubleAdder(0.0);
        // Shift so every score is >= 0 before applying the contrast exponent, accumulating the total.
        scoreMap.transformValues(value -> sum.addAndReturnAdded(Math.pow(value - minValFinal, constrastFactor)));
        return Pair.of(sum.sum(), scoreMap);
    }

    /**
     * Computes semantic similarity scores between a context and a list of predictions.
     * <p>
     * The rows of all known context words are summed and the result is normalized to unit length.
     * That vector is then compared with each prediction's row (see {@link #computeScoreMapFor}).
     *
     * @param wordIds         context word ids; ids unknown to the dictionary are ignored
     * @param predictions     candidate predictions to score
     * @param constrastFactor exponent applied to each shifted score
     * @return the sum of scores paired with the word id -> score map, or {@code null} when none of
     *         {@code wordIds} is present in the dictionary
     */
    public Pair<Double, TIntDoubleHashMap> getSimilarityCosineFor(Collection<Integer> wordIds, List<AbstractPredictionToCompute> predictions, double constrastFactor) {
        double[] wordFactors = new double[this.columnCount];
        boolean foundKnownWord = false;
        for (int wordId : wordIds) {
            int rowIndexForWord = this.rowIndexes.get(wordId);
            if (rowIndexForWord < 0) {
                continue;
            }
            foundKnownWord = true;
            double[] rowForWord = this.semanticMatrix[rowIndexForWord];
            for (int c = 0; c < this.columnCount; ++c) {
                wordFactors[c] += rowForWord[c];
            }
        }
        if (!foundKnownWord) {
            return null;
        }
        normalizeRow(wordFactors);
        return this.computeScoreMapFor(wordFactors, predictions, constrastFactor);
    }

    /**
     * Returns the dot product of two vectors.
     * <p>
     * This equals the cosine of the angle between the vectors only when both have unit length.
     * {@link #getSimilarityCosineFor} normalizes the context vector, but this class never
     * normalizes the stored matrix rows. NOTE(review): those rows are presumably stored
     * pre-normalized in the data file; confirm.
     *
     * @param row1 first vector; its length drives the iteration
     * @param row2 second vector, at least as long as {@code row1}
     * @return the sum of element-wise products
     */
    static double cosineAngle(double[] row1, double[] row2) {
        double prod = 0.0;
        for (int c = 0; c < row1.length; ++c) {
            prod += row1[c] * row2[c];
        }
        return prod;
    }

    /** Euclidean (L2) norm of {@code v}. */
    private static double length(double[] v) {
        double sum = 0.0;
        for (double value : v) {
            sum += value * value;
        }
        return Math.sqrt(sum);
    }

    /**
     * Scales {@code row} in place to unit Euclidean length.
     * A zero vector is left unchanged.
     *
     * @param row vector to normalize (modified in place)
     */
    public static void normalizeRow(double[] row) {
        double length = length(row);
        if (length > 0.0) {
            for (int c = 0; c < row.length; ++c) {
                row[c] = row[c] / length;
            }
        }
    }

    /**
     * Loads a semantic dictionary from a gzip-compressed binary file.
     * <p>
     * The layout is read with {@link DataInputStream} (big-endian), in this order:
     * <ol>
     * <li>{@code int rowCount}, {@code int columnCount}</li>
     * <li>{@code rowCount} pairs of {@code (int wordId, int rowIndex)}</li>
     * <li>{@code rowCount * columnCount} doubles, row by row (the semantic matrix)</li>
     * <li>{@code rowCount} doubles (raw densities, one per row)</li>
     * </ol>
     *
     * @param semanticDataFile gzip-compressed semantic data file
     * @param configuration    supplies the bounds densities are rescaled into
     * @return the loaded dictionary
     * @throws IOException if the file cannot be read, is not gzip data, is truncated, or declares
     *                     invalid sizes or row indexes
     */
    public static SemanticDictionary loadDictionary(File semanticDataFile, SemanticDictionaryConfiguration configuration) throws IOException {
        long start = System.currentTimeMillis();
        // Every stream is a separate resource, so the FileInputStream is still closed when the
        // GZIPInputStream constructor fails (e.g. the file is not gzip-compressed).
        // The buffer avoids many tiny reads through the inflater for each int/double.
        try (FileInputStream fileIn = new FileInputStream(semanticDataFile);
             GZIPInputStream gzipIn = new GZIPInputStream(fileIn);
             DataInputStream dis = new DataInputStream(new BufferedInputStream(gzipIn))) {
            int rowCount = dis.readInt();
            int columnCount = dis.readInt();
            if (rowCount < 0 || columnCount < 0) {
                throw new IOException("Invalid semantic matrix size in " + semanticDataFile + ": " + rowCount + "x" + columnCount);
            }
            TIntIntHashMap rowIndexes = readRowIndexes(dis, rowCount);
            LOGGER.info("Read {} matrix row indexes from semantic file", rowIndexes.size());
            double[][] matrix = readMatrix(dis, rowCount, columnCount);
            LOGGER.info("Read {}x{} semantic matrix in {} ms", new Object[]{rowCount, columnCount, System.currentTimeMillis() - start});
            double[] densities = readScaledDensities(dis, rowCount, configuration);
            // Keep the map's default no-entry value (0.0) for unknown words, as callers may rely on it.
            TIntDoubleHashMap densitiesMap = new TIntDoubleHashMap(densities.length);
            rowIndexes.forEachEntry((wordId, rowIndex) -> {
                densitiesMap.put(wordId, densities[rowIndex]);
                return true;
            });
            densitiesMap.compact();
            LOGGER.info("Semantic matrix and densities loaded in {} ms", System.currentTimeMillis() - start);
            return new SemanticDictionary(rowIndexes, matrix, densitiesMap);
        }
    }

    /**
     * Reads {@code rowCount} (word id, row index) pairs.
     * Rejects any row index that does not address a matrix row.
     */
    private static TIntIntHashMap readRowIndexes(DataInputStream dis, int rowCount) throws IOException {
        TIntIntHashMap rowIndexes = new TIntIntHashMap(rowCount, 0.9f, NO_VALUE_ENTRY, NO_VALUE_ENTRY);
        for (int i = 0; i < rowCount; ++i) {
            // Stream order is word id first, then row index.
            int wordId = dis.readInt();
            int rowIndex = dis.readInt();
            if (rowIndex < 0 || rowIndex >= rowCount) {
                throw new IOException("Row index " + rowIndex + " for word " + wordId + " is outside the semantic matrix (" + rowCount + " rows)");
            }
            rowIndexes.put(wordId, rowIndex);
        }
        rowIndexes.compact();
        return rowIndexes;
    }

    /** Reads a {@code rowCount x columnCount} matrix stored row by row. */
    private static double[][] readMatrix(DataInputStream dis, int rowCount, int columnCount) throws IOException {
        double[][] matrix = new double[rowCount][columnCount];
        for (double[] row : matrix) {
            for (int c = 0; c < columnCount; ++c) {
                row[c] = dis.readDouble();
            }
        }
        return matrix;
    }

    /**
     * Reads {@code rowCount} raw densities and rescales each one as
     * {@code minBound + (density / maxDensity) * (maxBound - minBound)}.
     * <p>
     * When no raw density is positive, every density is set to the min bound instead of dividing by zero.
     */
    private static double[] readScaledDensities(DataInputStream dis, int rowCount, SemanticDictionaryConfiguration configuration) throws IOException {
        long start = System.currentTimeMillis();
        double[] densities = new double[rowCount];
        // Starts at 0.0, so the divisor below is never negative.
        double max = 0.0;
        for (int r = 0; r < rowCount; ++r) {
            densities[r] = dis.readDouble();
            max = Math.max(max, densities[r]);
        }
        double minBound = configuration.getSemanticDensityMinBound();
        double boundRange = configuration.getSemanticDensityMaxBound() - minBound;
        // Infinite sentinels (Double.MIN_VALUE is the smallest *positive* double, not the most negative).
        double scaledMin = Double.POSITIVE_INFINITY;
        double scaledMax = Double.NEGATIVE_INFINITY;
        for (int r = 0; r < rowCount; ++r) {
            double ratio = max > 0.0 ? densities[r] / max : 0.0;
            densities[r] = minBound + ratio * boundRange;
            scaledMin = Math.min(scaledMin, densities[r]);
            scaledMax = Math.max(scaledMax, densities[r]);
        }
        LOGGER.info("LSA densities scaled in {} ms, min = {}, max = {}", new Object[]{System.currentTimeMillis() - start, scaledMin, scaledMax});
        return densities;
    }
}

