/*
 * Decompiled with CFR 0.152.
 */
package hex.word2vec;

import hex.Model;
import hex.ModelBuilder;
import hex.schemas.Word2VecModelV2;
import hex.word2vec.Word2Vec;
import hex.word2vec.WordCountTask;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Random;
import water.DKV;
import water.Futures;
import water.H2O;
import water.Iced;
import water.Key;
import water.api.ModelSchema;
import water.fvec.AppendableVec;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.nbhm.NonBlockingHashMap;
import water.parser.ValueString;
import water.util.ArrayUtils;
import water.util.Log;

/**
 * Word2Vec model.
 *
 * <p>Holds the training state ({@link Word2VecModelInfo}) and, once
 * {@link #buildModelOutput()} has run, the key of a frame whose first column
 * is the vocabulary word and whose remaining columns are the components of
 * that word's vector. {@link #transform(String)} and the {@code findSynonyms}
 * methods read from that frame.
 */
public class Word2VecModel
extends Model<Word2VecModel, Word2VecModel.Word2VecParameters, Word2VecModel.Word2VecOutput> {
    /** Training state (weights, sampling tables, learning rate). */
    private volatile Word2VecModelInfo _modelInfo;
    /** Key of the word/vector frame; assigned by {@link #buildModelOutput()}, null before that. */
    private Key _w2vKey;

    /** Replaces the training state held by this model. */
    void setModelInfo(Word2VecModelInfo mi) {
        this._modelInfo = mi;
    }

    /** @return the training state currently held by this model */
    public final Word2VecModelInfo getModelInfo() {
        return this._modelInfo;
    }

    /**
     * Creates the model and initialises a fresh {@link Word2VecModelInfo}
     * from {@code params}.
     *
     * @param selfKey key under which the model is stored
     * @param params  training parameters
     * @param output  model output holder
     */
    public Word2VecModel(Key selfKey, Word2VecParameters params, Word2VecOutput output) {
        super(selfKey, params, output);
        this._modelInfo = new Word2VecModelInfo(params);
        assert Arrays.equals(this._key._kb, selfKey._kb);
    }

    /** @return always {@code false} */
    public boolean isSupervised() {
        return false;
    }

    /** @return a new {@link Word2VecModelV2} schema instance */
    public ModelSchema schema() {
        return new Word2VecModelV2();
    }

    /** Row scoring is not implemented for this model; always throws {@code H2O.unimpl()}. */
    protected float[] score0(Chunk[] cs, int foo, double[] data, float[] preds) {
        throw H2O.unimpl();
    }

    /** Row scoring is not implemented for this model; always throws {@code H2O.unimpl()}. */
    protected float[] score0(double[] data, float[] preds) {
        throw H2O.unimpl();
    }

    /**
     * Looks up the vector of a single word.
     *
     * <p>Note: the vocabulary lookup map is rebuilt on every call.
     *
     * @param target word to look up
     * @return the word's vector, or {@code null} (after logging a warning)
     *         when the word is not in the vocabulary
     */
    public float[] transform(String target) {
        NonBlockingHashMap<ValueString, Integer> vocabHM = this.buildVocabHashMap();
        Vec[] vs = ((Frame) this._w2vKey.get()).vecs();
        return this.transform(new ValueString(target), vocabHM, vs);
    }

    /**
     * Reads the vector for {@code tmp} out of the vector columns.
     *
     * @param tmp     word to look up
     * @param vocabHM map from word to its row index
     * @param vs      columns of the word/vector frame; {@code vs[0]} is the
     *                word column, {@code vs[1..]} are the vector components
     * @return the vector, or {@code null} when the word is unknown
     */
    private float[] transform(ValueString tmp, NonBlockingHashMap<ValueString, Integer> vocabHM, Vec[] vs) {
        Integer rowBoxed = vocabHM.get(tmp);
        if (rowBoxed == null) {
            Log.warn("Target word " + tmp + " isn't in vocabulary.");
            return null;
        }
        int row = rowBoxed;
        int vecSize = vs.length - 1;
        float[] vec = new float[vecSize];
        for (int i = 0; i < vecSize; ++i) {
            vec[i] = (float) vs[i + 1].at(row);
        }
        return vec;
    }

    /**
     * Finds the vocabulary words whose vectors are most similar (by cosine
     * similarity) to the vector of {@code target}.
     *
     * @param target word whose neighbours are wanted
     * @param cnt    maximum number of neighbours to return; must be &gt; 0
     * @return map from word to similarity score, iterating from most to least
     *         similar; {@code null} when {@code cnt <= 0} or when
     *         {@code target} is not in the vocabulary (both cases are logged)
     */
    public HashMap<String, Float> findSynonyms(String target, int cnt) {
        if (cnt <= 0) {
            Log.err("Synonym count must be greater than 0.");
            return null;
        }
        NonBlockingHashMap<ValueString, Integer> vocabHM = this.buildVocabHashMap();
        Vec[] vs = ((Frame) this._w2vKey.get()).vecs();
        float[] tarVec = this.transform(new ValueString(target), vocabHM, vs);
        if (tarVec == null) {
            // transform() has already logged that the word is unknown.
            return null;
        }
        return this.findSynonyms(tarVec, cnt, vs);
    }

    /**
     * Runs the neighbour search for an explicit target vector.
     *
     * <p>NOTE(review): the signature is {@code void}, so the computed result
     * is discarded; kept as-is for caller compatibility.
     *
     * @param tarVec target vector
     * @param cnt    maximum number of neighbours; must be &gt; 0
     */
    public void findSynonyms(float[] tarVec, int cnt) {
        if (cnt <= 0) {
            Log.err("Synonym count must be greater than 0.");
            return;
        }
        Vec[] vs = ((Frame) this._w2vKey.get()).vecs();
        this.findSynonyms(tarVec, cnt, vs);
    }

    /**
     * Scans every vocabulary row and keeps the {@code cnt} best-scoring rows.
     *
     * <p>Only rows with a strictly positive score below 0.999999 are kept;
     * the upper bound presumably exists to exclude the target word itself.
     * If fewer than {@code cnt} rows qualify, fewer entries are returned.
     *
     * @param tarVec target vector; must have one element per vector column
     * @param cnt    maximum number of results (callers guarantee &gt; 0)
     * @param vs     columns of the word/vector frame
     * @return ranked map from word to score, or {@code null} when
     *         {@code tarVec} is null or has the wrong length
     */
    private HashMap<String, Float> findSynonyms(float[] tarVec, int cnt, Vec[] vs) {
        int vecSize = vs.length - 1;
        if (tarVec == null) {
            Log.warn("Target vector is null.");
            return null;
        }
        if (tarVec.length != vecSize) {
            Log.warn("Target vector length differs from the vocab's vector length.");
            return null;
        }
        int vocabSize = (int) vs[0].length();
        int[] matches = new int[cnt];
        float[] scores = new float[cnt];   // descending; unfilled slots stay at 0
        float[] curVec = new float[vecSize];
        int found = 0;                     // number of slots actually filled
        for (int row = 0; row < vocabSize; ++row) {
            for (int j = 0; j < vecSize; ++j) {
                curVec[j] = (float) vs[j + 1].at(row);
            }
            float score = this.cosineSimilarity(tarVec, curVec);
            // Also rejects NaN, which compares false against everything.
            if (!((double) score < 0.999999)) {
                continue;
            }
            for (int j = 0; j < cnt; ++j) {
                if (!(score > scores[j])) {
                    continue;
                }
                // Make room at position j by shifting the tail down one slot.
                for (int k = cnt - 1; k > j; --k) {
                    scores[k] = scores[k - 1];
                    matches[k] = matches[k - 1];
                }
                scores[j] = score;
                matches[j] = row;
                if (found < cnt) {
                    ++found;
                }
                break;
            }
        }
        // LinkedHashMap keeps the ranking order; emit only the filled slots so
        // that empty slots do not show up as a bogus entry for row 0.
        HashMap<String, Float> res = new LinkedHashMap<String, Float>();
        for (int i = 0; i < found; ++i) {
            res.put(vs[0].atStr(new ValueString(), matches[i]).toString(), Float.valueOf(scores[i]));
        }
        return res;
    }

    /**
     * Cosine similarity of two vectors: dot product divided by the product of
     * the Euclidean norms. Sums are accumulated in double precision.
     *
     * @param target  first vector; its length drives the loop
     * @param current second vector; must be at least as long as {@code target}
     * @return the similarity; {@code NaN} when either vector has zero norm
     */
    public float cosineSimilarity(float[] target, float[] current) {
        double dotProd = 0.0;
        double tsqr = 0.0;
        double csqr = 0.0;
        for (int i = 0; i < target.length; ++i) {
            double t = target[i];
            double c = current[i];
            dotProd += t * c;
            tsqr += t * t;
            csqr += c * c;
        }
        return (float) (dotProd / (Math.sqrt(tsqr) * Math.sqrt(csqr)));
    }

    /**
     * Builds a map from each word in column 0 of the word/vector frame to its
     * row index.
     */
    private NonBlockingHashMap<ValueString, Integer> buildVocabHashMap() {
        Frame w2v = (Frame) this._w2vKey.get();
        Vec word = w2v.vec(0);
        int vocabSize = (int) w2v.numRows();
        NonBlockingHashMap<ValueString, Integer> vocabHM =
                new NonBlockingHashMap<ValueString, Integer>(vocabSize);
        for (int i = 0; i < vocabSize; ++i) {
            vocabHM.put(word.atStr(new ValueString(), i), Integer.valueOf(i));
        }
        return vocabHM;
    }

    /**
     * Materialises the trained weights ({@code _syn0}) as a frame.
     *
     * <p>One numeric column per vector component ("V0", "V1", ...) is built
     * from {@code _syn0}, which is read as {@code vocabSize} consecutive runs
     * of {@code vecSize} values. The frame is stored in the DKV with a "Word"
     * column taken from column 0 of the vocabulary frame, followed by the
     * vector columns.
     *
     * <p>NOTE(review): the frame key is the fixed name "w2v", so two models
     * would share the same key — confirm this is intended.
     */
    public void buildModelOutput() {
        int i;
        int vecSize = ((Word2VecParameters) this._parms)._vecSize;
        Word2VecModelInfo info = this._modelInfo;
        Futures fs = new Futures();
        String[] colNames = new String[vecSize];
        Vec[] vecs = new Vec[vecSize];
        Key[] keys = Vec.VectorGroup.VG_LEN1.addVecs(vecs.length);
        NewChunk[] cs = new NewChunk[vecs.length];
        AppendableVec[] avs = new AppendableVec[vecs.length];
        for (i = 0; i < vecs.length; ++i) {
            avs[i] = new AppendableVec(keys[i]);
            cs[i] = new NewChunk((Vec) avs[i], 0);
        }
        for (i = 0; i < info._vocabSize; ++i) {
            int base = i * vecSize;
            for (int j = 0; j < vecSize; ++j) {
                cs[j].addNum((double) info._syn0[base + j]);
            }
        }
        for (i = 0; i < vecs.length; ++i) {
            colNames[i] = "V" + i;
            cs[i].close(0, fs);
            vecs[i] = avs[i].close(fs);
        }
        fs.blockForPending();
        this._w2vKey = Key.make("w2v");
        Frame fr = new Frame(this._w2vKey);
        fr.add("Word", ((Frame) ((Word2VecParameters) this._parms)._vocabKey.get()).vec(0));
        fr.add(colNames, vecs);
        DKV.put((Key) this._w2vKey, (Iced) fr);
    }

    /**
     * Removes the vocabulary key and the word/vector frame key (when set),
     * then removes and deletes the model itself.
     */
    public void delete() {
        Key vocabKey = ((Word2VecParameters) this._parms)._vocabKey;
        if (vocabKey != null) {
            vocabKey.remove();
        }
        // _w2vKey is only assigned by buildModelOutput().
        if (this._w2vKey != null) {
            this._w2vKey.remove();
        }
        this.remove();
        super.delete();
    }

    /**
     * Mutable training state: the two weight arrays, the current learning
     * rate, and either the unigram sampling table or the Huffman code/point
     * tables, depending on {@code _normModel}.
     */
    public static class Word2VecModelInfo
    extends Iced {
        static final int UNIGRAM_TABLE_SIZE = 10000000;
        static final float UNIGRAM_POWER = 0.75f;
        static final int MAX_CODE_LENGTH = 40;
        /** Total number of rows across the string columns of the training frame. */
        long _trainFrameSize;
        /** Number of rows in the vocabulary frame. */
        int _vocabSize;
        float _curLearningRate;
        /** {@code vecSize * vocabSize} weights, randomly initialised. */
        float[] _syn0;
        /** {@code vecSize * vocabSize} weights, initially zero. */
        float[] _syn1;
        int[] _uniTable = null;
        int[][] _HBWTCode = null;
        int[][] _HBWTPoint = null;
        private Word2VecParameters _parameters;
        // Static, hence shared by every instance in this JVM; guarded by the
        // class monitor because instance-level locking cannot protect them.
        private static long _localWordCnt = 0L;
        private static long _globalWordCnt = 0L;

        /** @return the parameters this state was built from */
        public final Word2VecParameters getParams() {
            return this._parameters;
        }

        /** No-arg constructor; leaves every field at its default. */
        public Word2VecModelInfo() {
        }

        /**
         * Initialises the state for training.
         *
         * <p>If {@code params._vocabKey} is null, a {@link WordCountTask} is
         * run over the training frame and its {@code _wordCountKey} is stored
         * back into {@code params}. {@code _syn0} is filled with values in
         * {@code [-0.5, 0.5) / vecSize}; {@code _syn1} is left at zero. The
         * Huffman tables are built when {@code _normModel} is {@code HSM},
         * otherwise the unigram table is built.
         */
        public Word2VecModelInfo(Word2VecParameters params) {
            this._parameters = params;
            if (this._parameters._vocabKey == null) {
                this._parameters._vocabKey = ((WordCountTask) new WordCountTask((int) this._parameters._minWordFreq).doAll((Frame) this._parameters.train()))._wordCountKey;
            }
            this._vocabSize = (int) ((Frame) this._parameters._vocabKey.get()).numRows();
            this._trainFrameSize = this.getTrainFrameSize(this._parameters.train());
            Random rand = new Random();
            int vecSize = this._parameters._vecSize;
            int weightCnt = vecSize * this._vocabSize;
            this._syn1 = new float[weightCnt];
            this._syn0 = new float[weightCnt];
            for (int i = 0; i < weightCnt; ++i) {
                this._syn0[i] = (rand.nextFloat() - 0.5f) / (float) vecSize;
            }
            if (this._parameters._normModel == Word2Vec.NormModel.HSM) {
                this.buildHuffmanBinaryWordTree();
            } else {
                this.buildUnigramTable();
            }
        }

        /** Adds {@code p} to the JVM-wide local word counter. */
        public void addLocallyProcessed(long p) {
            synchronized (Word2VecModelInfo.class) {
                _localWordCnt += p;
            }
        }

        /** @return the JVM-wide local word counter */
        public long getLocallyProcessed() {
            synchronized (Word2VecModelInfo.class) {
                return _localWordCnt;
            }
        }

        /** Overwrites the JVM-wide local word counter. */
        public void setLocallyProcessed(int p) {
            synchronized (Word2VecModelInfo.class) {
                _localWordCnt = p;
            }
        }

        /** Adds {@code p} to the JVM-wide global word counter. */
        public void addGloballyProcessed(long p) {
            synchronized (Word2VecModelInfo.class) {
                _globalWordCnt += p;
            }
        }

        /** @return the JVM-wide global word counter */
        public long getGloballyProcessed() {
            synchronized (Word2VecModelInfo.class) {
                return _globalWordCnt;
            }
        }

        /** @return the sum of the global and local word counters */
        public long getTotalProcessed() {
            synchronized (Word2VecModelInfo.class) {
                return _globalWordCnt + _localWordCnt;
            }
        }

        /**
         * Element-wise adds {@code other}'s weights into this instance and
         * adds {@code other}'s local word count to the local counter.
         */
        protected void add(Word2VecModelInfo other) {
            ArrayUtils.add(this._syn0, other._syn0);
            ArrayUtils.add(this._syn1, other._syn1);
            this.addLocallyProcessed(other.getLocallyProcessed());
        }

        /** Divides both weight arrays by {@code N}; does nothing unless {@code N > 1}. */
        protected void div(float N) {
            if (N > 1.0f) {
                ArrayUtils.div(this._syn0, N);
                ArrayUtils.div(this._syn1, N);
            }
        }

        /**
         * Recomputes the learning rate as a linear decay from
         * {@code _initLearningRate} in the fraction of words processed
         * ({@code total / (epochs * trainFrameSize + 1)}), floored at
         * {@code _initLearningRate * 1e-4}.
         */
        public void updateLearningRate() {
            float initRate = this._parameters._initLearningRate;
            float minRate = initRate * 1.0E-4f;
            float progress = (float) this.getTotalProcessed()
                    / (float) ((long) this._parameters._epochs * this._trainFrameSize + 1L);
            float rate = initRate * (1.0f - progress);
            this._curLearningRate = rate < minRate ? minRate : rate;
        }

        /**
         * Fills {@code _uniTable} so that each word index occupies a share of
         * the table proportional to {@code count^0.75}, with counts read from
         * column 1 of the vocabulary frame.
         *
         * <p>The normaliser is accumulated in double precision (accumulating
         * it in a long drops the fractional part of every term), and the word
         * index is clamped at the last word so the tail of the table belongs
         * to that word instead of wrapping back to word 0.
         */
        private void buildUnigramTable() {
            this._uniTable = new int[UNIGRAM_TABLE_SIZE];
            Vec wCount = ((Frame) this._parameters._vocabKey.get()).vec(1);
            double vocabWordsPow = 0.0;
            long rows = wCount.length();
            for (long r = 0L; r < rows; ++r) {
                vocabWordsPow += Math.pow(wCount.at8(r), UNIGRAM_POWER);
            }
            int lastWord = this._vocabSize - 1;
            int word = 0;
            // Upper edge of the current word's share of [0, 1).
            double cumulative = Math.pow(wCount.at8(0L), UNIGRAM_POWER) / vocabWordsPow;
            for (int i = 0; i < UNIGRAM_TABLE_SIZE; ++i) {
                this._uniTable[i] = word;
                if (word < lastWord && (double) i / (double) UNIGRAM_TABLE_SIZE > cumulative) {
                    ++word;
                    cumulative += Math.pow(wCount.at8((long) word), UNIGRAM_POWER) / vocabWordsPow;
                }
            }
        }

        /**
         * Builds a Huffman tree over the word counts (column 1 of the
         * vocabulary frame) and records, for every word, its binary code
         * ({@code _HBWTCode}) and the inner-node indices along its path
         * ({@code _HBWTPoint}).
         *
         * <p>The two-pointer merge takes the smaller of {@code count[pos1]}
         * (leaves, walked from the last index downwards) and
         * {@code count[pos2]} (inner nodes, walked upwards), so it presumably
         * expects counts sorted in descending order — confirm against the
         * vocabulary producer. Paths longer than {@code MAX_CODE_LENGTH}
         * would overflow the scratch arrays.
         */
        private void buildHuffmanBinaryWordTree() {
            int i;
            int[] point = new int[MAX_CODE_LENGTH];
            int[] code = new int[MAX_CODE_LENGTH];
            long[] count = new long[this._vocabSize * 2 - 1];
            int[] binary = new int[this._vocabSize * 2 - 1];
            int[] parent_node = new int[this._vocabSize * 2 - 1];
            Vec wCount = ((Frame) this._parameters._vocabKey.get()).vec(1);
            this._HBWTCode = new int[this._vocabSize][];
            this._HBWTPoint = new int[this._vocabSize][];
            assert ((long) this._vocabSize == wCount.length());
            // Leaves carry the word counts ...
            for (i = 0; i < this._vocabSize; ++i) {
                count[i] = wCount.at8((long) i);
            }
            // ... inner nodes start at a huge sentinel until they are created.
            for (i = this._vocabSize; i < this._vocabSize * 2 - 1; ++i) {
                count[i] = 1000000000000000L;
            }
            int pos1 = this._vocabSize - 1;
            int pos2 = this._vocabSize;
            // Each round merges the two smallest remaining nodes into a new inner node.
            for (i = 0; i < this._vocabSize - 1; ++i) {
                int min1i = pos1 >= 0 ? (count[pos1] < count[pos2] ? pos1-- : pos2++) : pos2++;
                int min2i = pos1 >= 0 ? (count[pos1] < count[pos2] ? pos1-- : pos2++) : pos2++;
                count[this._vocabSize + i] = count[min1i] + count[min2i];
                parent_node[min1i] = this._vocabSize + i;
                parent_node[min2i] = this._vocabSize + i;
                binary[min2i] = 1;
            }
            for (int j = 0; j < this._vocabSize; ++j) {
                int k = j;
                int m = 0;
                // Walk upwards until a node whose parent entry is 0 (never assigned).
                do {
                    code[m] = binary[k];
                    point[m] = k;
                    ++m;
                } while ((k = parent_node[k]) != 0);
                this._HBWTCode[j] = new int[m];
                this._HBWTPoint[j] = new int[m + 1];
                this._HBWTPoint[j][0] = this._vocabSize - 2;
                // The walk was leaf-to-top; store it reversed.
                for (int l = 0; l < m; ++l) {
                    this._HBWTCode[j][m - l - 1] = code[l];
                    this._HBWTPoint[j][m - l] = point[l] - this._vocabSize;
                }
            }
        }

        /** @return the summed length of every string column in {@code tf} */
        private long getTrainFrameSize(Frame tf) {
            long count = 0L;
            for (Vec v : tf.vecs()) {
                if (v.isString()) {
                    count += v.length();
                }
            }
            return count;
        }
    }

    /**
     * Output holder for the model. Its fields share names with
     * {@link Word2VecParameters}; presumably they are filled in by the
     * builder — confirm against {@code Word2Vec}.
     */
    public static class Word2VecOutput
    extends Model.Output {
        public Word2Vec.WordModel _wordModel;
        public Word2Vec.NormModel _normModel;
        public int _minWordFreq;
        public int _vecSize;
        public int _windowSize;
        public int _epochs;
        public int _negSampleCnt;
        public float _initLearningRate;
        public float _sentSampleRate;

        public Word2VecOutput(Word2Vec b) {
            super((ModelBuilder) b);
        }

        /** @return always {@code Model.ModelCategory.Unknown} */
        public Model.ModelCategory getModelCategory() {
            return Model.ModelCategory.Unknown;
        }
    }

    /** Training parameters for Word2Vec. */
    public static class Word2VecParameters
    extends Model.Parameters {
        static final int MAX_VEC_SIZE = 10000;
        public Word2Vec.WordModel _wordModel;
        /** Selects Huffman tables ({@code HSM}) or the unigram table (any other value). */
        public Word2Vec.NormModel _normModel;
        /** Key of the vocabulary frame (column 0: word, column 1: count); computed when null. */
        public Key _vocabKey;
        /** Passed to {@link WordCountTask} when the vocabulary is computed. */
        public int _minWordFreq;
        /** Number of components per word vector. */
        public int _vecSize;
        public int _windowSize;
        public int _epochs;
        public int _negSampleCnt;
        public float _initLearningRate;
        public float _sentSampleRate;
    }
}

