/*
 * Decompiled with CFR 0.152.
 */
package hex.word2vec;

import hex.word2vec.Word2Vec;
import hex.word2vec.Word2VecModel;
import java.util.Random;
import water.H2O;
import water.MRTask;
import water.fvec.CStrChunk;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.nbhm.NonBlockingHashMap;
import water.parser.BufferedString;
import water.util.Log;

public class WordVectorTrainer
extends MRTask<WordVectorTrainer> {
    /** Maximum number of word indices assembled into one sentence (capacity of the sentence buffer). */
    static final int MAX_SENTENCE_LEN = 1000;
    /** Minimum number of remaining chunk rows for which a sentence is still assembled. */
    static final int MIN_SENTENCE_LEN = 10;
    /** Number of entries in the precomputed logistic lookup table {@link #_expTable}. */
    static final int EXP_TABLE_SIZE = 1000;
    /** Dot products are only looked up in {@link #_expTable} while they lie inside (-MAX_EXP, MAX_EXP). */
    static final int MAX_EXP = 6;
    /** Model info handed to the constructor; moved into {@link #_output} and nulled by setupLocal(). */
    private Word2VecModel.Word2VecModelInfo _input;
    /** Model info being trained; exposed through getModelInfo() and merged in reduce(). */
    Word2VecModel.Word2VecModelInfo _output;
    /** Vocabulary frame: column 0 is read as the word, column 1 as its count. Nulled by closeLocal(). */
    Frame _vocab;
    /** Shared (static) lookup from word to its row index in the vocabulary frame. */
    static NonBlockingHashMap<BufferedString, Integer> _vocabHM;
    /** Word model selected in the parameters (the code branches on CBOW and SkipGram). */
    final Word2Vec.WordModel _wordModel;
    /** Normalisation model; NegSampling selects negative sampling, anything else hierarchical softmax. */
    final Word2Vec.NormModel _normModel;
    /** Row count of the vocabulary frame, truncated to int. */
    final int _vocabSize;
    /** Number of components per word vector; stride used to index the flat weight arrays. */
    final int _wordVecSize;
    /** Maximum number of context positions considered on each side of the current word. */
    final int _windowSize;
    /** Number of epochs; used by updateAlpha() to compute the total amount of planned work. */
    final int _epochs;
    /** Number of negative samples drawn per training pair (0 when negative sampling is not used). */
    final int _negExCnt;
    /** Learning rate at the start of training; the current rate decays from this value. */
    final float _initLearningRate;
    /** Sub-sampling rate for words while assembling sentences; values <= 0 disable sub-sampling. */
    final float _sentSampleRate;
    /** Shared flat weight array, indexed as {@code word * _wordVecSize + component}; updated in place. */
    static float[] _syn0;
    /** Shared flat weight array of the output side, indexed with the same stride; updated in place. */
    static float[] _syn1;
    /** Shared lookup table of exp(x) / (exp(x) + 1), filled by initExpTable(). */
    static float[] _expTable;
    /** Table from which negative-sample word indices are drawn; null unless NegSampling is used. */
    final int[] _unigramTable;
    /** Per-word code bits used by hierarchical softmax; null under NegSampling. (HBWT: presumably a Huffman binary word tree -- confirm.) */
    final int[][] _HBWTCode;
    /** Per-word node indices (rows of {@link #_syn1}) used by hierarchical softmax; null under NegSampling. */
    final int[][] _HBWTPoint;
    /** Number of contributing results merged into this task: summed in reduce(), divisor in postGlobal(). */
    int _chunkNodeCount = 1;
    /** Learning rate currently applied by the training kernels; recomputed by updateAlpha(). */
    transient float _curLearningRate;
    /** Row cursor into the string chunk currently being read by getSentence(). */
    transient int _chkIdx = 0;
    /** Random source for the sub-sampling decision in getSentence(); created in setupLocal(). */
    transient Random _rand;
    /** Shared state of the xorshift generator in cheapRandInt(); the constructor seeds it from System.nanoTime(). */
    static transient long _seed;
    /** Wall-clock time (ms) of the last "nodes not contributing" warning; used to throttle it. */
    static long _lastWarn;
    /** Number of "nodes not contributing" warnings logged so far; the warning stops after three. */
    static long _warnCount;

    public WordVectorTrainer(Word2VecModel.Word2VecModelInfo input) {
        super(null);
        this._input = input;
        this._wordModel = input.getParams()._wordModel;
        this._normModel = input.getParams()._normModel;
        this._vocab = (Frame)input.getParams()._vocabKey.get();
        this._vocabSize = (int)this._vocab.numRows();
        this._wordVecSize = input.getParams()._vecSize;
        this._windowSize = input.getParams()._windowSize;
        _syn0 = input._syn0;
        _syn1 = input._syn1;
        this._initLearningRate = input.getParams()._initLearningRate;
        this._sentSampleRate = input.getParams()._sentSampleRate;
        this._epochs = input.getParams()._epochs;
        _seed = System.nanoTime();
        assert (this._output == null);
        assert (this._vocab.numRows() > 0L);
        if (input.getParams()._normModel == Word2Vec.NormModel.NegSampling) {
            this._negExCnt = input.getParams()._negSampleCnt;
            this._unigramTable = input._uniTable;
            this._HBWTCode = null;
            this._HBWTPoint = null;
        } else {
            this._negExCnt = 0;
            this._unigramTable = null;
            this._HBWTCode = input._HBWTCode;
            this._HBWTPoint = input._HBWTPoint;
        }
    }

    /**
     * Returns the model info this trainer writes its results into.
     *
     * @return the current output model info; {@code null} until setupLocal() has moved the input over
     */
    public final Word2VecModel.Word2VecModelInfo getModelInfo() {
        return _output;
    }

    /**
     * Prepares the per-JVM training state: rebinds the shared weight arrays, turns the input model
     * info into the output, creates the random source, rebuilds the logistic table and the
     * vocabulary index, and resets the local progress counter.
     * (Presumably an MRTask lifecycle hook that runs before map() -- confirm against MRTask.)
     */
    protected void setupLocal() {
        final Word2VecModel.Word2VecModelInfo info = _input;
        _syn0 = info._syn0;
        _syn1 = info._syn1;
        // From here on the same object is treated as the result; the input reference is dropped.
        _output = info;
        _input = null;
        _rand = new Random();
        initExpTable();
        buildVocabHashMap();
        _curLearningRate = info._curLearningRate;
        info.setLocallyProcessed(0);
    }

    /**
     * Rebuilds the shared word-to-row-index map from column 0 of the vocabulary frame.
     *
     * <p>The map is filled through a local reference and only then stored in the static field, so
     * readers of {@code _vocabHM} never see a half-built index. The map is created with its
     * declared type parameters instead of a raw type, and the row count is read once.
     */
    private void buildVocabHashMap() {
        final Vec words = _vocab.vec(0);
        final long rows = _vocab.numRows();
        final NonBlockingHashMap<BufferedString, Integer> index =
                new NonBlockingHashMap<BufferedString, Integer>((int)rows);
        for (int row = 0; row < rows; ++row) {
            // A fresh BufferedString per row: the instance is retained as the map key.
            index.put(words.atStr(new BufferedString(), row), row);
        }
        _vocabHM = index;
    }

    /**
     * Recomputes the current learning rate from training progress.
     *
     * <p>The rate falls linearly from the initial rate as (globally processed + locally processed)
     * words approach {@code epochs * trainFrameSize + 1}, and is never allowed below
     * 1e-4 times the initial rate.
     *
     * @param localWordCnt number of words processed so far by the calling map() invocation
     */
    private void updateAlpha(int localWordCnt) {
        final long seen = _output.getGloballyProcessed() + (long)localWordCnt;
        final long planned = (long)_epochs * _output._trainFrameSize + 1L;
        final float floor = _initLearningRate * 1.0E-4f;
        float rate = _initLearningRate * (1.0f - (float)seen / (float)planned);
        if (rate < floor) {
            rate = floor;
        }
        _curLearningRate = rate;
    }

    /**
     * Fills {@code sentence} with vocabulary indices read from {@code cs}, starting at the current
     * chunk cursor {@code _chkIdx}, and advances the cursor past every row it consumed.
     *
     * <p>Rows whose word is not in the vocabulary are skipped. When the sample rate is positive,
     * a word is also dropped at random, with a keep score computed from its count relative to
     * {@code _sentSampleRate * trainFrameSize}.
     *
     * <p>Unlike the previous version this returns the number of indices actually written, so the
     * caller never trains on stale buffer slots. The cursor also moves past the last stored word,
     * so that word is not read again as the first word of the next sentence.
     *
     * @param sentence output buffer; must hold at least {@code MAX_SENTENCE_LEN} entries
     * @param cs       string chunk to read words from
     * @return number of word indices stored in {@code sentence}; 0 when fewer than
     *         {@code MIN_SENTENCE_LEN} rows remain or no remaining word qualified
     */
    private int getSentence(int[] sentence, CStrChunk cs) {
        final Vec count = this._vocab.vec(1);
        final BufferedString tmp = new BufferedString();

        int maxLen = cs._len - 1 - this._chkIdx;
        if (maxLen >= MAX_SENTENCE_LEN) {
            maxLen = MAX_SENTENCE_LEN;
        } else if (maxLen < MIN_SENTENCE_LEN) {
            return 0;
        }

        int stored = 0;
        while (stored < maxLen && this._chkIdx < cs._len) {
            // Post-increment: the row is consumed whether or not its word is kept.
            cs.atStr(tmp, this._chkIdx++);
            final Integer boxedIdx = _vocabHM.get(tmp);
            if (boxedIdx == null) {
                continue; // word is not in the vocabulary
            }
            final int wIdx = boxedIdx;
            if (this._sentSampleRate > 0.0f) {
                final float wordCount = (float)count.at8((long)wIdx);
                final float scaled = this._sentSampleRate * (float)this._output._trainFrameSize;
                final float keep = ((float)Math.sqrt(wordCount / scaled) + 1.0f) * scaled / wordCount;
                if (keep < this._rand.nextFloat()) {
                    continue; // sub-sampled away
                }
            }
            sentence[stored++] = wIdx;
        }
        return stored;
    }

    /**
     * Allocates the shared logistic lookup table and fills slot {@code i} with
     * {@code e / (e + 1)} where {@code e = exp((i / EXP_TABLE_SIZE * 2 - 1) * MAX_EXP)}.
     */
    private void initExpTable() {
        _expTable = new float[EXP_TABLE_SIZE];
        final float[] table = _expTable;
        for (int slot = 0; slot < EXP_TABLE_SIZE; ++slot) {
            final float e = (float)Math.exp(((float)slot / (float)EXP_TABLE_SIZE * 2.0f - 1.0f) * (float)MAX_EXP);
            table[slot] = e / (e + 1.0f);
        }
    }

    /**
     * Trains on every string chunk in {@code cs}.
     *
     * <p>For each chunk, sentences are pulled with getSentence() until it returns 0. For every word
     * of a sentence a randomly shrunk window is walked. In skip-gram mode each (current word,
     * context word) pair is trained immediately. Otherwise the context vectors are summed into
     * {@code neu1} and, for the CBOW model with a non-empty bag, trained once per current word.
     * The learning rate is refreshed every 10000 processed words, and the total word count is
     * added to the output's local progress at the end.
     *
     * <p>The chunk cursor {@code _chkIdx} is reset before each chunk. Previously it kept the end
     * position of the preceding chunk, so any further string chunk in {@code cs} was skipped.
     *
     * @param cs chunks of the current row range; chunks that are not {@code CStrChunk} are ignored
     */
    public void map(Chunk[] cs) {
        final int winSize = this._windowSize;
        final int vecSize = this._wordVecSize;
        final boolean skipGramModel = this._wordModel == Word2Vec.WordModel.SkipGram;
        final boolean cbowModel = this._wordModel == Word2Vec.WordModel.CBOW;
        final float[] neu1 = new float[vecSize];
        final float[] neu1e = new float[vecSize];
        final int[] sentence = new int[MAX_SENTENCE_LEN];
        int wrdCnt = 0;
        int bagSize = 0;

        for (Chunk chk : cs) {
            if (!(chk instanceof CStrChunk)) {
                continue;
            }
            final CStrChunk words = (CStrChunk)chk;
            // Each chunk is read from its first row; the cursor is per chunk, not per task.
            this._chkIdx = 0;

            int sentLen;
            while ((sentLen = this.getSentence(sentence, words)) > 0) {
                for (int sentIdx = 0; sentIdx < sentLen; ++sentIdx) {
                    if (wrdCnt % 10000 == 0) {
                        this.updateAlpha(wrdCnt);
                    }
                    final int curWord = sentence[sentIdx];
                    ++wrdCnt;

                    if (cbowModel) {
                        // Fresh accumulators for this word's bag of context vectors.
                        for (int j = 0; j < vecSize; ++j) {
                            neu1[j] = 0.0f;
                            neu1e[j] = 0.0f;
                        }
                        bagSize = 0;
                    }

                    // Shrink the window symmetrically by a random amount in [0, winSize).
                    final int winSizeMod = this.cheapRandInt(winSize);
                    final int winEnd = winSize * 2 + 1 - winSizeMod;
                    for (int winIdx = winSizeMod; winIdx < winEnd; ++winIdx) {
                        if (winIdx == winSize) {
                            continue; // the current word itself
                        }
                        final int ctxIdx = sentIdx - winSize + winIdx;
                        if (ctxIdx < 0 || ctxIdx >= sentLen) {
                            continue; // window reaches past the sentence boundary
                        }
                        final int winWord = sentence[ctxIdx];
                        if (skipGramModel) {
                            this.skipGram(curWord, winWord, neu1e);
                            continue;
                        }
                        final int ctxOffset = winWord * vecSize;
                        for (int j = 0; j < vecSize; ++j) {
                            neu1[j] += _syn0[j + ctxOffset];
                        }
                        ++bagSize;
                    }

                    if (cbowModel && bagSize > 0) {
                        this.CBOW(curWord, sentence, sentIdx, sentLen, winSizeMod, bagSize, neu1, neu1e);
                    }
                }
            }
        }
        this._output.addLocallyProcessed(wrdCnt);
    }

    /**
     * Merges the result of another trainer into this one.
     *
     * <p>Nothing happens when {@code other} processed no words or shares this trainer's output
     * object. If this trainer has not processed anything itself it adopts the other's output and
     * contributor count; otherwise the outputs are added and the contributor counts summed.
     *
     * @param other trainer whose output should be folded into this one
     */
    public void reduce(WordVectorTrainer other) {
        final boolean otherContributed = other._output.getLocallyProcessed() > 0L;
        if (!otherContributed || other._output == _output) {
            return;
        }
        if (_output.getLocallyProcessed() == 0L) {
            // Nothing of our own to keep: take over the other result wholesale.
            _output = other._output;
            _chunkNodeCount = other._chunkNodeCount;
            return;
        }
        _output.add(other._output);
        _chunkNodeCount += other._chunkNodeCount;
    }

    /**
     * Drops the reference to the vocabulary frame once local work is finished.
     * (Presumably an MRTask lifecycle hook -- confirm against MRTask.)
     */
    protected void closeLocal() {
        _vocab = null;
    }

    /**
     * Finalises the merged result: divides the output by the number of contributors, moves the
     * locally processed word count into the global counter, and resets the local counter.
     *
     * <p>On a multi-node cloud a warning is logged when fewer contributors than nodes were merged.
     * The warning is issued at most three times and not more often than every five seconds.
     */
    protected void postGlobal() {
        final int cloudSize = H2O.CLOUD.size();
        if (cloudSize > 1 && _chunkNodeCount < cloudSize) {
            final long now = System.currentTimeMillis();
            if (now - _lastWarn > 5000L && _warnCount < 3L) {
                final String msg = cloudSize - _chunkNodeCount + " node(s) (out of " + cloudSize + ") are not contributing to model updates. Consider setting replicate_training_data to true or using a larger training dataset (or fewer H2O nodes).";
                Log.warn(new Object[]{msg});
                _lastWarn = now;
                ++_warnCount;
            }
        }
        _output.div(_chunkNodeCount);
        _output.addGloballyProcessed(_output.getLocallyProcessed());
        _output.setLocallyProcessed(0);
        assert (_input == null);
    }

    /**
     * Trains one skip-gram pair: clears the error buffer, lets the configured norm model compute
     * the error for the context word's vector, and adds that error to the vector in {@code _syn0}.
     *
     * @param curWord vocabulary index of the word at the centre of the window
     * @param winWord vocabulary index of the context word whose {@code _syn0} row is updated
     * @param neu1e   scratch error buffer with at least {@code _wordVecSize} entries; overwritten
     */
    private void skipGram(int curWord, int winWord, float[] neu1e) {
        final int dim = _wordVecSize;
        final int rowOffset = winWord * dim;

        for (int k = 0; k < dim; ++k) {
            neu1e[k] = 0.0f;
        }

        if (_normModel == Word2Vec.NormModel.NegSampling) {
            negSamplingSG(curWord, rowOffset, neu1e);
        } else {
            hierarchicalSoftmaxSG(curWord, rowOffset, neu1e);
        }

        // Apply the accumulated error to the context word's row.
        for (int k = 0; k < dim; ++k) {
            _syn0[k + rowOffset] += neu1e[k];
        }
    }

    /**
     * Trains one CBOW example: averages the summed context vectors in {@code neu1}, lets the
     * configured norm model compute the error, and adds that error to the {@code _syn0} row of
     * every context word inside the (randomly shrunk) window.
     *
     * <p>The window upper bound is {@code 2 * windowSize + 1 - winSizeMod}, the same bound the
     * caller uses when it builds {@code neu1}. The previous bound subtracted the window size
     * instead of {@code winSizeMod}, so the loop stopped at the centre word and right-hand
     * context words never received their update.
     *
     * @param curWord    vocabulary index of the word at the centre of the window
     * @param sentence   buffer of vocabulary indices for the current sentence
     * @param sentIdx    position of {@code curWord} in {@code sentence}
     * @param sentLen    number of valid entries in {@code sentence}
     * @param winSizeMod random shrink applied to both ends of the window by the caller
     * @param bagSize    number of context vectors summed into {@code neu1}; nothing is done if <= 0
     * @param neu1       summed context vectors on entry; holds their average afterwards
     * @param neu1e      error buffer the norm model accumulates into
     */
    private void CBOW(int curWord, int[] sentence, int sentIdx, int sentLen, int winSizeMod, int bagSize, float[] neu1, float[] neu1e) {
        if (bagSize <= 0) {
            return; // no context words: averaging would divide by zero
        }
        final int vecSize = this._wordVecSize;
        final int winSize = this._windowSize;
        final int winEnd = winSize * 2 + 1 - winSizeMod;

        for (int i = 0; i < vecSize; ++i) {
            neu1[i] = neu1[i] / (float)bagSize;
        }

        if (this._normModel == Word2Vec.NormModel.NegSampling) {
            this.negSamplingCBOW(curWord, neu1, neu1e);
        } else {
            this.hierarchicalSoftmaxCBOW(curWord, neu1, neu1e);
        }

        // Push the error back to every context word that contributed to the bag.
        for (int winIdx = winSizeMod; winIdx < winEnd; ++winIdx) {
            if (winIdx == winSize) {
                continue; // the current word itself
            }
            final int ctxIdx = sentIdx - winSize + winIdx;
            if (ctxIdx < 0 || ctxIdx >= sentLen) {
                continue; // window reaches past the sentence boundary
            }
            final int ctxOffset = sentence[ctxIdx] * vecSize;
            for (int i = 0; i < vecSize; ++i) {
                _syn0[i + ctxOffset] += neu1e[i];
            }
        }
    }

    /**
     * Negative-sampling update for CBOW.
     *
     * <p>First the current word is trained as the positive example, then up to {@code _negExCnt}
     * words drawn from the unigram table are trained as negative examples (a draw equal to the
     * current word is skipped). Each step adds {@code gradient * _syn1 row} to {@code neu1e} and
     * {@code gradient * neu1} to the {@code _syn1} row.
     *
     * @param curWord vocabulary index of the positive target word
     * @param neu1    averaged context vector (read only)
     * @param neu1e   error buffer that is accumulated into
     */
    private void negSamplingCBOW(int curWord, float[] neu1, float[] neu1e) {
        final int dim = _wordVecSize;
        final int samples = _negExCnt;
        final int tableLen = _unigramTable.length;
        final float rate = _curLearningRate;
        final float limit = (float)MAX_EXP;
        final float slotsPerUnit = (float)(EXP_TABLE_SIZE / MAX_EXP / 2);

        // Positive example: the current word.
        int outOffset = curWord * dim;
        float dot = 0.0f;
        for (int k = 0; k < dim; ++k) {
            dot += neu1[k] * _syn1[k + outOffset];
        }
        float gradient;
        if (dot > limit) {
            gradient = 0.0f;
        } else if (dot < -limit) {
            gradient = rate;
        } else {
            gradient = (1.0f - _expTable[(int)((dot + limit) * slotsPerUnit)]) * rate;
        }
        for (int k = 0; k < dim; ++k) {
            neu1e[k] += gradient * _syn1[k + outOffset];
        }
        for (int k = 0; k < dim; ++k) {
            _syn1[k + outOffset] += gradient * neu1[k];
        }

        // Negative examples drawn from the unigram table.
        for (int sample = 0; sample < samples; ++sample) {
            final int negWord = _unigramTable[cheapRandInt(tableLen)];
            if (negWord == curWord) {
                continue;
            }
            outOffset = negWord * dim;
            dot = 0.0f;
            for (int k = 0; k < dim; ++k) {
                dot += neu1[k] * _syn1[k + outOffset];
            }
            if (dot > limit) {
                gradient = -rate;
            } else if (dot < -limit) {
                gradient = 0.0f;
            } else {
                gradient = -_expTable[(int)((dot + limit) * slotsPerUnit)] * rate;
            }
            for (int k = 0; k < dim; ++k) {
                neu1e[k] += gradient * _syn1[k + outOffset];
            }
            for (int k = 0; k < dim; ++k) {
                _syn1[k + outOffset] += gradient * neu1[k];
            }
        }
    }

    /**
     * Negative-sampling update for skip-gram.
     *
     * <p>Trains the {@code _syn0} row starting at {@code l1} against the current word as the
     * positive example, then against up to {@code _negExCnt} words drawn from the unigram table as
     * negative examples (a draw equal to the current word is skipped). Each step adds
     * {@code gradient * _syn1 row} to {@code neu1e} and {@code gradient * _syn0 row} to the
     * {@code _syn1} row. The {@code _syn0} row itself is not modified here.
     *
     * @param curWord vocabulary index of the positive target word
     * @param l1      offset of the input word's row in {@code _syn0}
     * @param neu1e   error buffer that is accumulated into
     */
    private void negSamplingSG(int curWord, int l1, float[] neu1e) {
        final int dim = _wordVecSize;
        final int samples = _negExCnt;
        final int tableLen = _unigramTable.length;
        final float rate = _curLearningRate;
        final float limit = (float)MAX_EXP;
        final float slotsPerUnit = (float)(EXP_TABLE_SIZE / MAX_EXP / 2);

        // Positive example: the current word.
        int outOffset = curWord * dim;
        float dot = 0.0f;
        for (int k = 0; k < dim; ++k) {
            dot += _syn0[k + l1] * _syn1[k + outOffset];
        }
        float gradient;
        if (dot > limit) {
            gradient = 0.0f;
        } else if (dot < -limit) {
            gradient = rate;
        } else {
            gradient = (1.0f - _expTable[(int)((dot + limit) * slotsPerUnit)]) * rate;
        }
        for (int k = 0; k < dim; ++k) {
            neu1e[k] += gradient * _syn1[k + outOffset];
        }
        for (int k = 0; k < dim; ++k) {
            _syn1[k + outOffset] += gradient * _syn0[k + l1];
        }

        // Negative examples drawn from the unigram table.
        for (int sample = 0; sample < samples; ++sample) {
            final int negWord = _unigramTable[cheapRandInt(tableLen)];
            if (negWord == curWord) {
                continue;
            }
            outOffset = negWord * dim;
            dot = 0.0f;
            for (int k = 0; k < dim; ++k) {
                dot += _syn0[k + l1] * _syn1[k + outOffset];
            }
            if (dot > limit) {
                gradient = -rate;
            } else if (dot < -limit) {
                gradient = 0.0f;
            } else {
                gradient = -_expTable[(int)((dot + limit) * slotsPerUnit)] * rate;
            }
            for (int k = 0; k < dim; ++k) {
                neu1e[k] += gradient * _syn1[k + outOffset];
            }
            for (int k = 0; k < dim; ++k) {
                _syn1[k + outOffset] += gradient * _syn0[k + l1];
            }
        }
    }

    /**
     * Returns a cheap pseudo-random value with magnitude below {@code max}, using a 64-bit
     * xorshift step (shifts 21, 35, 4) on the shared {@code _seed}.
     *
     * <p>State 0 is a fixed point of xorshift: once the state is 0 every result is 0. Because
     * {@code _seed} is a static that only the constructor seeds, it can still hold its default 0
     * (for instance in a JVM where no constructor call has run), so a zero state is re-seeded
     * here. The state is read once and written once per call rather than updated in place three
     * times, which narrows the window for interleaved updates by concurrent callers. The
     * generator is still not synchronised.
     *
     * @param max exclusive bound on the magnitude of the result; must not be 0
     * @return a value in {@code [0, |max|)}
     * @throws ArithmeticException if {@code max} is 0
     */
    private int cheapRandInt(int max) {
        long state = _seed;
        if (state == 0L) {
            // "| 1L" guarantees a non-zero seed even if nanoTime() happens to return 0.
            state = System.nanoTime() | 1L;
        }
        state ^= state << 21;
        state ^= state >>> 35;
        state ^= state << 4;
        _seed = state;
        final int r = (int)state % max;
        return r > 0 ? r : -r;
    }

    /**
     * Hierarchical-softmax update for CBOW.
     *
     * <p>Walks the code of {@code targetWord}. For every position the dot product of {@code neu1}
     * with the {@code _syn1} row named by the matching point is computed. Positions whose dot
     * product is at or beyond +/-MAX_EXP are skipped. Otherwise the gradient
     * {@code (1 - codeBit - table(dot)) * learningRate} is applied: {@code gradient * _syn1 row}
     * is added to {@code neu1e} and {@code gradient * neu1} to the {@code _syn1} row.
     *
     * @param targetWord vocabulary index whose code/point path is walked
     * @param neu1       averaged context vector (read only)
     * @param neu1e      error buffer that is accumulated into
     */
    private void hierarchicalSoftmaxCBOW(int targetWord, float[] neu1, float[] neu1e) {
        final int dim = _wordVecSize;
        final int[] code = _HBWTCode[targetWord];
        final int pathLen = code.length;
        final float rate = _curLearningRate;
        final float limit = (float)MAX_EXP;
        final float slotsPerUnit = (float)(EXP_TABLE_SIZE / MAX_EXP / 2);

        for (int step = 0; step < pathLen; ++step) {
            final int nodeOffset = _HBWTPoint[targetWord][step] * dim;
            float dot = 0.0f;
            for (int k = 0; k < dim; ++k) {
                dot += neu1[k] * _syn1[k + nodeOffset];
            }
            if (dot <= -limit || dot >= limit) {
                continue; // outside the table's range: no update for this node
            }
            final float sigmoid = _expTable[(int)((dot + limit) * slotsPerUnit)];
            final float gradient = ((float)(1 - code[step]) - sigmoid) * rate;
            for (int k = 0; k < dim; ++k) {
                neu1e[k] += gradient * _syn1[k + nodeOffset];
            }
            for (int k = 0; k < dim; ++k) {
                _syn1[k + nodeOffset] += gradient * neu1[k];
            }
        }
    }

    /**
     * Hierarchical-softmax update for skip-gram.
     *
     * <p>Walks the code of {@code targetWord}. For every position the dot product of the
     * {@code _syn0} row at {@code l1} with the {@code _syn1} row named by the matching point is
     * computed. Positions whose dot product is at or beyond +/-MAX_EXP are skipped. Otherwise the
     * gradient {@code (1 - codeBit - table(dot)) * learningRate} is applied:
     * {@code gradient * _syn1 row} is added to {@code neu1e} and {@code gradient * _syn0 row} to
     * the {@code _syn1} row. The {@code _syn0} row itself is not modified here.
     *
     * @param targetWord vocabulary index whose code/point path is walked
     * @param l1         offset of the input word's row in {@code _syn0}
     * @param neu1e      error buffer that is accumulated into
     */
    private void hierarchicalSoftmaxSG(int targetWord, int l1, float[] neu1e) {
        final int dim = _wordVecSize;
        final int[] code = _HBWTCode[targetWord];
        final int pathLen = code.length;
        final float rate = _curLearningRate;
        final float limit = (float)MAX_EXP;
        final float slotsPerUnit = (float)(EXP_TABLE_SIZE / MAX_EXP / 2);

        for (int step = 0; step < pathLen; ++step) {
            final int nodeOffset = _HBWTPoint[targetWord][step] * dim;
            float dot = 0.0f;
            for (int k = 0; k < dim; ++k) {
                dot += _syn0[k + l1] * _syn1[k + nodeOffset];
            }
            if (dot <= -limit || dot >= limit) {
                continue; // outside the table's range: no update for this node
            }
            final float sigmoid = _expTable[(int)((dot + limit) * slotsPerUnit)];
            final float gradient = ((float)(1 - code[step]) - sigmoid) * rate;
            for (int k = 0; k < dim; ++k) {
                neu1e[k] += gradient * _syn1[k + nodeOffset];
            }
            for (int k = 0; k < dim; ++k) {
                _syn1[k + nodeOffset] += gradient * _syn0[k + l1];
            }
        }
    }
}

