/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.models.embeddings.word2vec;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.spark.models.embeddings.word2vec.NegativeHolder;
import org.deeplearning4j.spark.models.embeddings.word2vec.VocabHolder;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import scala.Tuple2;

/**
 * Spark partition function that trains word vectors on the sentences of one partition.
 *
 * <p>Each input element is a tuple of (sentence as a list of {@link VocabWord}, cumulative word
 * count up to that sentence). Sentences are consumed in batches of {@code batchSize}; every batch
 * is trained {@code iterations} times. When the partition is exhausted, the vectors held by the
 * {@link VocabHolder} are returned via {@code getSplit(vocab)}.
 *
 * <p>All hyper-parameters are read once, in the constructor, from the broadcast variable map.
 */
public class SecondIterationFunction
implements FlatMapFunction<Iterator<Tuple2<List<VocabWord>, Long>>, Map.Entry<VocabWord, INDArray>> {
    /** Multiplier of the linear congruential step (same constant java.util.Random uses). */
    private static final long RANDOM_MULTIPLIER = 25214903917L;
    /** Increment of the linear congruential step (same constant java.util.Random uses). */
    private static final long RANDOM_INCREMENT = 11L;

    private int ithIteration = 1;
    private int vectorLength;
    private boolean useAdaGrad;
    private int batchSize = 0;
    private double negative;
    private int window;
    private double alpha;
    private double minAlpha;
    private long totalWordCount;
    private long seed;
    private int maxExp;
    private double[] expTable;
    private int iterations;
    private AtomicLong nextRandom = new AtomicLong(5L);
    private volatile VocabCache<VocabWord> vocab;
    private volatile transient NegativeHolder negativeHolder;
    private volatile transient VocabHolder vocabHolder;
    private AtomicLong cid = new AtomicLong(0L);
    private AtomicLong aff = new AtomicLong(0L);

    /**
     * Reads the training configuration from the broadcast variables.
     *
     * @param word2vecVarMapBroadcast map holding the keys {@code vectorLength}, {@code useAdaGrad},
     *        {@code negative}, {@code window}, {@code alpha}, {@code minAlpha},
     *        {@code totalWordCount}, {@code seed}, {@code maxExp}, {@code iterations} and
     *        {@code batchSize}; a missing key fails with a {@link NullPointerException} on unboxing
     * @param expTableBroadcast lookup table indexed by the scaled dot product in
     *        {@link #iterateSample}
     * @param vocabCacheBroadcast the vocabulary; its value must not be {@code null}
     * @throws RuntimeException if the broadcast vocabulary value is {@code null}
     */
    public SecondIterationFunction(Broadcast<Map<String, Object>> word2vecVarMapBroadcast, Broadcast<double[]> expTableBroadcast, Broadcast<VocabCache<VocabWord>> vocabCacheBroadcast) {
        Map<String, Object> word2vecVarMap = word2vecVarMapBroadcast.getValue();
        this.expTable = expTableBroadcast.getValue();
        this.vectorLength = (Integer) word2vecVarMap.get("vectorLength");
        this.useAdaGrad = (Boolean) word2vecVarMap.get("useAdaGrad");
        this.negative = (Double) word2vecVarMap.get("negative");
        this.window = (Integer) word2vecVarMap.get("window");
        this.alpha = (Double) word2vecVarMap.get("alpha");
        this.minAlpha = (Double) word2vecVarMap.get("minAlpha");
        this.totalWordCount = (Long) word2vecVarMap.get("totalWordCount");
        this.seed = (Long) word2vecVarMap.get("seed");
        this.maxExp = (Integer) word2vecVarMap.get("maxExp");
        this.iterations = (Integer) word2vecVarMap.get("iterations");
        this.batchSize = (Integer) word2vecVarMap.get("batchSize");
        this.vocab = vocabCacheBroadcast.getValue();
        if (this.vocab == null) {
            throw new RuntimeException("VocabCache is null");
        }
    }

    /**
     * Trains on every sentence of the partition and returns the resulting vectors.
     *
     * @param pairIter iterator over (sentence, cumulative word count) tuples
     * @return iterator over the entries produced by {@code vocabHolder.getSplit(vocab)}
     */
    public Iterator<Map.Entry<VocabWord, INDArray>> call(Iterator<Tuple2<List<VocabWord>, Long>> pairIter) {
        this.vocabHolder = VocabHolder.getInstance();
        this.vocabHolder.setSeed(this.seed, this.vectorLength);
        if (this.negative > 0.0) {
            this.negativeHolder = NegativeHolder.getInstance();
            this.negativeHolder.initHolder(this.vocab, this.expTable, this.vectorLength);
        }

        // A non-positive batch size would leave the inner loop unable to consume anything and the
        // outer loop would spin forever; always take at least one sentence per batch.
        int effectiveBatchSize = Math.max(1, this.batchSize);

        while (pairIter.hasNext()) {
            List<Pair<List<VocabWord>, Long>> batch = new ArrayList<Pair<List<VocabWord>, Long>>();
            while (pairIter.hasNext() && batch.size() < effectiveBatchSize) {
                Tuple2<List<VocabWord>, Long> pair = pairIter.next();
                batch.add(Pair.of(pair._1(), pair._2()));
            }

            for (int i = 0; i < this.iterations; ++i) {
                for (Pair<List<VocabWord>, Long> entry : batch) {
                    List<VocabWord> vocabWordsList = entry.getKey();
                    Long sentenceCumSumCount = entry.getValue();
                    // Learning rate decays linearly with the share of words seen, floored at minAlpha.
                    double progress = (double) sentenceCumSumCount.longValue() / (double) this.totalWordCount;
                    double currentSentenceAlpha = Math.max(this.minAlpha, this.alpha - (this.alpha - this.minAlpha) * progress);
                    this.trainSentence(vocabWordsList, currentSentenceAlpha);
                }
            }
        }
        return this.vocabHolder.getSplit(this.vocab).iterator();
    }

    /**
     * Runs {@link #skipGram} for every non-null word of the sentence.
     *
     * <p>The random state is advanced once per position (including positions holding
     * {@code null}) to draw the window offset {@code b} in {@code [0, window)}.
     *
     * @param vocabWordsList the sentence; {@code null} or empty sentences are ignored
     * @param currentSentenceAlpha learning rate to use for this sentence
     */
    public void trainSentence(List<VocabWord> vocabWordsList, double currentSentenceAlpha) {
        if (vocabWordsList == null || vocabWordsList.isEmpty()) {
            return;
        }
        for (int ithWordInSentence = 0; ithWordInSentence < vocabWordsList.size(); ++ithWordInSentence) {
            // Reduce the full long value; casting to int before the modulo could produce a
            // negative offset and widen the context window beyond its configured size.
            int b = (int) Math.floorMod(this.advanceRandom(), (long) this.window);
            if (vocabWordsList.get(ithWordInSentence) == null) {
                continue;
            }
            this.skipGram(ithWordInSentence, vocabWordsList, b, currentSentenceAlpha);
        }
    }

    /**
     * Trains the word at {@code ithWordInSentence} against each neighbour inside the window.
     *
     * @param ithWordInSentence index of the centre word in {@code vocabWordsList}
     * @param vocabWordsList the sentence
     * @param b window offset; neighbours at relative positions {@code [b - window, window - b]}
     *        (excluding the centre itself) are visited
     * @param currentSentenceAlpha learning rate to use for this sentence
     */
    public void skipGram(int ithWordInSentence, List<VocabWord> vocabWordsList, int b, double currentSentenceAlpha) {
        VocabWord currentWord = vocabWordsList.get(ithWordInSentence);
        if (currentWord == null) {
            return;
        }
        int end = this.window * 2 + 1 - b;
        for (int a = b; a < end; ++a) {
            if (a == this.window) {
                continue; // the centre word itself
            }
            int c = ithWordInSentence - this.window + a;
            if (c < 0 || c >= vocabWordsList.size()) {
                continue; // neighbour falls outside the sentence
            }
            this.iterateSample(currentWord, vocabWordsList.get(c), currentSentenceAlpha);
        }
    }

    /**
     * Performs one training step for the pair ({@code w1}, {@code w2}).
     *
     * <p>The syn0 vector of {@code w2} is updated using the accumulated error from a pass over
     * {@code w1}'s codes/points and, when {@code negative > 0}, from a negative-sampling pass.
     * Nothing happens if either word is {@code null}, if {@code w2} has a negative index, or if
     * both words share the same index.
     *
     * @param w1 word whose codes/points (and index, as the positive target) drive the update
     * @param w2 word whose syn0 vector is updated
     * @param currentSentenceAlpha learning rate used in the codes/points pass
     * @throws IllegalStateException if one of {@code w1}'s points is negative
     */
    public void iterateSample(VocabWord w1, VocabWord w2, double currentSentenceAlpha) {
        if (w1 == null || w2 == null || w2.getIndex() < 0 || w2.getIndex() == w1.getIndex()) {
            return;
        }
        int currentWordIndex = w2.getIndex();
        INDArray neu1e = Nd4j.create(this.vectorLength);
        INDArray l1 = this.vocabHolder.getSyn0Vector(currentWordIndex, this.vocab);

        this.applyCodePointUpdates(w1, l1, neu1e, currentSentenceAlpha);
        if (this.negative > 0.0) {
            this.applyNegativeSampling(w1, l1, neu1e);
        }

        // l1 += neu1e
        Nd4j.getBlasWrapper().level1().axpy(this.vectorLength, 1.0, neu1e, l1);
    }

    /** Walks w1's codes/points, updating each syn1 vector and accumulating the error in neu1e. */
    private void applyCodePointUpdates(VocabWord w1, INDArray l1, INDArray neu1e, double currentSentenceAlpha) {
        for (int i = 0; i < w1.getCodeLength(); ++i) {
            int code = w1.getCodes().get(i);
            int point = w1.getPoints().get(i);
            if (point < 0) {
                throw new IllegalStateException("Illegal point " + point);
            }
            INDArray syn1 = this.vocabHolder.getSyn1Vector(point);
            double dot = Nd4j.getBlasWrapper().level1().dot(this.vectorLength, 1.0, l1, syn1);
            if (dot < (double) (-this.maxExp) || dot >= (double) this.maxExp) {
                continue; // outside the range covered by expTable
            }
            int idx = (int) ((dot + (double) this.maxExp) * ((double) this.expTable.length / (double) this.maxExp / 2.0));
            if (idx >= this.expTable.length) {
                continue;
            }
            double f = this.expTable[idx];
            double rate = this.useAdaGrad ? w1.getGradient(i, currentSentenceAlpha, currentSentenceAlpha) : currentSentenceAlpha;
            double g = ((double) (1 - code) - f) * rate;
            // neu1e += g * syn1, then syn1 += g * l1 (order matters: neu1e must see the old syn1)
            Nd4j.getBlasWrapper().level1().axpy(this.vectorLength, g, syn1, neu1e);
            Nd4j.getBlasWrapper().level1().axpy(this.vectorLength, g, l1, syn1);
        }
    }

    /**
     * Runs one positive update (target = w1's index, label 1) followed by sampled negative
     * updates (label 0) against the syn1Neg rows, accumulating the error in neu1e.
     */
    private void applyNegativeSampling(VocabWord w1, INDArray l1, INDArray neu1e) {
        // NOTE(review): this pass uses this.alpha, not the per-sentence learning rate used by the
        // codes/points pass - kept as in the original; confirm whether that is intended.
        int target = w1.getIndex();
        for (int d = 0; (double) d < this.negative + 1.0; ++d) {
            int label;
            if (d == 0) {
                label = 1;
            } else {
                long random = this.advanceRandom();
                int tableIdx = (int) Math.abs((long) ((int) (random >> 16)) % this.negativeHolder.getTable().length());
                target = this.negativeHolder.getTable().getInt(tableIdx);
                if (target <= 0) {
                    // Fall back to a pseudo-random index in [1, numWords - 1]. The modulo is taken
                    // on the long value; casting to int first could give a negative target that
                    // the bounds check below would silently discard.
                    target = (int) Math.floorMod(random, (long) (this.vocab.numWords() - 1)) + 1;
                }
                if (target == w1.getIndex()) {
                    continue; // never use the positive word as a negative sample
                }
                label = 0;
            }

            if (target >= this.negativeHolder.getSyn1Neg().rows() || target < 0) {
                continue;
            }
            INDArray syn1NegRow = this.negativeHolder.getSyn1Neg().slice(target);
            double f = Nd4j.getBlasWrapper().dot(l1, syn1NegRow);

            double g;
            if (f > (double) this.maxExp) {
                g = this.useAdaGrad ? w1.getGradient(target, (double) (label - 1), this.alpha) : (double) (label - 1) * this.alpha;
            } else if (f < (double) (-this.maxExp)) {
                g = (double) label * (this.useAdaGrad ? w1.getGradient(target, this.alpha, this.alpha) : this.alpha);
            } else {
                // NOTE(review): integer division here, whereas the codes/points pass scales with
                // double division - kept as in the original; confirm which one matches expTable.
                int idx = (int) ((f + (double) this.maxExp) * (double) (this.expTable.length / this.maxExp / 2));
                if (idx >= this.expTable.length) {
                    continue;
                }
                g = this.useAdaGrad ? w1.getGradient(target, (double) label - this.expTable[idx], this.alpha) : ((double) label - this.expTable[idx]) * this.alpha;
            }

            // neu1e += g * row, then row += g * l1 (order matters: neu1e must see the old row)
            Nd4j.getBlasWrapper().level1().axpy(this.vectorLength, g, syn1NegRow, neu1e);
            Nd4j.getBlasWrapper().level1().axpy(this.vectorLength, g, l1, syn1NegRow);
        }
    }

    /** Advances the linear congruential random state and returns the new value. */
    private long advanceRandom() {
        long next = Math.abs(this.nextRandom.get() * RANDOM_MULTIPLIER + RANDOM_INCREMENT);
        this.nextRandom.set(next);
        return next;
    }

    /**
     * Creates a 1 x vectorLength random row: {@code (rand - 0.5) / vectorLength}, seeded with
     * {@code lseed * seed}.
     */
    private INDArray getRandomSyn0Vec(int vectorLength, long lseed) {
        return Nd4j.rand(new int[]{1, vectorLength}, lseed * this.seed).subi(0.5).divi(vectorLength);
    }
}

