/*
 * Decompiled with CFR 0.152.
 */
package hivemall.topicmodel;

import hivemall.annotations.VisibleForTesting;
import hivemall.topicmodel.AbstractProbabilisticTopicModel;
import hivemall.utils.lang.ArrayUtils;
import hivemall.utils.math.MathUtils;
import hivemall.utils.random.PRNG;
import hivemall.utils.random.RandomNumberGeneratorFactory;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;

/**
 * Incremental (online) probabilistic latent semantic analysis (PLSA) topic model.
 *
 * <p>Each call to {@link #train(String[][])} processes one mini-batch. Per-document topic
 * distributions P(z|d) and per-(document, word) posteriors P(z|d,w) are freshly initialized at
 * random. Then, document by document, an E step and an M step are repeated until P(z|d) changes by
 * less than {@code delta} (mean absolute change per topic). The word -> topic-score table
 * {@code _p_zw} is kept across mini-batches. On every M step each stored score is scaled by
 * {@code alpha}, the current document's weighted posteriors are added, and every topic column is
 * re-normalized over the vocabulary.
 *
 * <p>NOTE(review): instances hold mutable state without any synchronization. The PRNG is seeded
 * with the fixed value 1001, so results depend only on the sequence of inputs.
 */
public final class IncrementalPLSAModel
extends AbstractProbabilisticTopicModel {
    /** Multiplier applied to every stored topic-word score on each M step. */
    private final float _alpha;
    /** Convergence threshold on the mean absolute per-topic change of P(z|d). */
    private final double _delta;
    @Nonnull
    private final PRNG _rnd;
    /** Current mini-batch: for each document, word -> K-dim posterior P(z|d,w). */
    private List<Map<String, float[]>> _p_dwz;
    /** Current mini-batch: for each document, its K-dim topic distribution P(z|d). */
    private List<float[]> _p_dz;
    /** Word -> K topic scores, re-normalized so each topic column sums to one over all words. */
    @Nonnull
    private final Map<String, float[]> _p_zw;

    /**
     * @param K number of topics
     * @param alpha multiplier applied to the existing topic-word scores on every M step
     * @param delta convergence threshold for the per-document EM loop
     */
    public IncrementalPLSAModel(int K, float alpha, double delta) {
        super(K);
        this._alpha = alpha;
        this._delta = delta;
        this._rnd = RandomNumberGeneratorFactory.createPRNG(1001L);
        this._p_zw = new HashMap<String, float[]>();
    }

    /**
     * Runs EM over one mini-batch, updating the shared topic-word table.
     *
     * @param miniBatch documents to learn from, parsed by {@code initMiniBatch}
     */
    @Override
    protected void train(@Nonnull String[][] miniBatch) {
        initMiniBatch(miniBatch, this._miniBatchDocs);
        this._miniBatchSize = this._miniBatchDocs.size();
        initParams();

        for (int d = 0; d < this._miniBatchSize; ++d) {
            // eStep(d)/mStep(d) only mutate document d's P(z|d) row (in place), so that is the
            // only row that needs a snapshot for the convergence check.
            final float[] p_dz_d = this._p_dz.get(d);
            float[] pPrev_dz_d;
            do {
                pPrev_dz_d = p_dz_d.clone();
                eStep(d);
                mStep(d);
            } while (!isPdzConverged(pPrev_dz_d, p_dz_d));
        }
    }

    /**
     * Returns document {@code d} of the current mini-batch as a word -> weight map.
     *
     * <p>The cast is unchecked. The values are read as {@code Float} throughout this class.
     */
    @SuppressWarnings("unchecked")
    @Nonnull
    private Map<String, Float> miniBatchDoc(@Nonnegative int d) {
        return (Map<String, Float>) this._miniBatchDocs.get(d);
    }

    /**
     * Randomly initializes P(z|d) and P(z|d,w) for the current mini-batch. It also adds random
     * topic scores for previously unseen words and then re-normalizes every topic column of
     * {@code _p_zw}.
     */
    private void initParams() {
        final List<float[]> p_dz = new ArrayList<float[]>(this._miniBatchSize);
        final List<Map<String, float[]>> p_dwz =
                new ArrayList<Map<String, float[]>>(this._miniBatchSize);
        for (int d = 0; d < this._miniBatchSize; ++d) {
            p_dz.add(MathUtils.l1normalize(ArrayUtils.newRandomFloatArray(this._K, this._rnd)));
            final Map<String, float[]> p_dwz_d = new HashMap<String, float[]>();
            p_dwz.add(p_dwz_d);
            for (String w : miniBatchDoc(d).keySet()) {
                p_dwz_d.put(w,
                    MathUtils.l1normalize(ArrayUtils.newRandomFloatArray(this._K, this._rnd)));
                if (!this._p_zw.containsKey(w)) {
                    this._p_zw.put(w, ArrayUtils.newRandomFloatArray(this._K, this._rnd));
                }
            }
        }
        normalizeTopicWordColumns(topicWordColumnSums());
        this._p_dz = p_dz;
        this._p_dwz = p_dwz;
    }

    /** Accumulates every word's topic scores into a fresh K-dim array via {@code MathUtils.add}. */
    @Nonnull
    private double[] topicWordColumnSums() {
        final double[] sums = new double[this._K];
        for (float[] p_zw_w : this._p_zw.values()) {
            MathUtils.add(p_zw_w, sums, this._K);
        }
        return sums;
    }

    /** Divides each word's score for topic z by {@code sums[z]}, for every word in the table. */
    private void normalizeTopicWordColumns(@Nonnull double[] sums) {
        for (float[] p_zw_w : this._p_zw.values()) {
            for (int z = 0; z < this._K; ++z) {
                p_zw_w[z] = (float) (p_zw_w[z] / sums[z]);
            }
        }
    }

    /** E step for document d: P(z|d,w) is proportional to P(z|d) * score(w, z), L1-normalized. */
    private void eStep(@Nonnegative int d) {
        final Map<String, float[]> p_dwz_d = this._p_dwz.get(d);
        final float[] p_dz_d = this._p_dz.get(d);
        for (String w : miniBatchDoc(d).keySet()) {
            final float[] p_dwz_dw = p_dwz_d.get(w);
            final float[] p_zw_w = this._p_zw.get(w);
            for (int z = 0; z < this._K; ++z) {
                p_dwz_dw[z] = p_dz_d[z] * p_zw_w[z];
            }
            MathUtils.l1normalize(p_dwz_dw);
        }
    }

    /**
     * M step for document d. It first recomputes P(z|d) in place from the weighted posteriors. It
     * then updates and re-normalizes the topic score of every word in the vocabulary.
     */
    private void mStep(@Nonnegative int d) {
        final Map<String, Float> doc = miniBatchDoc(d);
        final Map<String, float[]> p_dwz_d = this._p_dwz.get(d);

        // P(z|d) <- L1-normalized sum over words of weight(d,w) * P(z|d,w)
        final float[] p_dz_d = this._p_dz.get(d);
        Arrays.fill(p_dz_d, 0.0f);
        for (Map.Entry<String, Float> e : doc.entrySet()) {
            final float[] p_dwz_dw = p_dwz_d.get(e.getKey());
            final float n = e.getValue().floatValue();
            for (int z = 0; z < this._K; ++z) {
                p_dz_d[z] += n * p_dwz_dw[z];
            }
        }
        MathUtils.l1normalize(p_dz_d);

        // Every word's scores are scaled by alpha. Words occurring in d also receive this
        // document's weighted posterior. Column sums are accumulated in the same pass.
        final double[] sums = new double[this._K];
        for (Map.Entry<String, float[]> e : this._p_zw.entrySet()) {
            final String w = e.getKey();
            final float[] p_zw_w = e.getValue();
            final Float w_value = doc.get(w);
            if (w_value == null) {
                for (int z = 0; z < this._K; ++z) {
                    p_zw_w[z] = this._alpha * p_zw_w[z];
                }
            } else {
                final float n = w_value.floatValue();
                final float[] p_dwz_dw = p_dwz_d.get(w);
                for (int z = 0; z < this._K; ++z) {
                    p_zw_w[z] = n * p_dwz_dw[z] + this._alpha * p_zw_w[z];
                }
            }
            MathUtils.add(p_zw_w, sums, this._K);
        }
        normalizeTopicWordColumns(sums);
    }

    /**
     * Returns true when the mean absolute per-topic change between two snapshots of a document's
     * P(z|d) is below {@code delta}.
     */
    private boolean isPdzConverged(@Nonnull float[] pPrev_dz_d, @Nonnull float[] p_dz_d) {
        double diff = 0.0;
        for (int z = 0; z < this._K; ++z) {
            diff += Math.abs(pPrev_dz_d[z] - p_dz_d[z]);
        }
        return diff / this._K < this._delta;
    }

    /**
     * Computes the perplexity of the current mini-batch:
     * exp(-sum(weight * log P(w|d)) / sum(weight)).
     *
     * <p>NOTE(review): with an empty mini-batch the denominator is zero and the result is NaN.
     *
     * @throws IllegalStateException if some word gets probability exactly zero
     */
    @Override
    protected float computePerplexity() {
        double numer = 0.0;
        double denom = 0.0;
        for (int d = 0; d < this._miniBatchSize; ++d) {
            final float[] p_dz_d = this._p_dz.get(d);
            for (Map.Entry<String, Float> e : miniBatchDoc(d).entrySet()) {
                final float w_value = e.getValue().floatValue();
                final float[] p_zw_w = this._p_zw.get(e.getKey());
                double p_dw = 0.0;
                for (int z = 0; z < this._K; ++z) {
                    p_dw += (double) p_zw_w[z] * p_dz_d[z];
                }
                if (p_dw == 0.0) {
                    throw new IllegalStateException("Perplexity would be Infinity. Try different mini-batch size `-s`, larger `-delta` and/or larger `-alpha`.");
                }
                numer += w_value * Math.log(p_dw);
                denom += w_value;
            }
        }
        return (float) Math.exp(-1.0 * (numer / denom));
    }

    /**
     * Groups all known words by their score for topic {@code z}, ordered from highest to lowest
     * score.
     */
    @Override
    @Nonnull
    protected SortedMap<Float, List<String>> getTopicWords(@Nonnegative int z) {
        final SortedMap<Float, List<String>> res =
                new TreeMap<Float, List<String>>(Collections.reverseOrder());
        for (Map.Entry<String, float[]> e : this._p_zw.entrySet()) {
            final Float prob = Float.valueOf(e.getValue()[z]);
            List<String> words = res.get(prob);
            if (words == null) {
                words = new ArrayList<String>();
                res.put(prob, words);
            }
            words.add(e.getKey());
        }
        return res;
    }

    /**
     * Trains on {@code doc} as a single-document mini-batch and returns its P(z|d).
     *
     * <p>Note that this also updates the shared topic-word table.
     */
    @Override
    @Nonnull
    protected float[] getTopicDistribution(@Nonnull String[] doc) {
        this.train(new String[][]{doc});
        return this._p_dz.get(0);
    }

    /** Returns the stored score of word {@code w} for topic {@code z}. Throws NPE for unknown words. */
    @Override
    @VisibleForTesting
    float getWordScore(@Nonnull String w, @Nonnegative int z) {
        return this._p_zw.get(w)[z];
    }

    /**
     * Sets the score of word {@code w} for topic {@code z}. Other topics of a previously unseen
     * word are filled randomly. Every topic column is then re-normalized.
     *
     * <p>NOTE(review): the re-normalization touches the whole vocabulary on every call, so loading
     * V words costs O(V^2 * K).
     */
    @Override
    protected void setWordScore(@Nonnull String w, @Nonnegative int z, float prob) {
        float[] prob_label = this._p_zw.get(w);
        if (prob_label == null) {
            prob_label = ArrayUtils.newRandomFloatArray(this._K, this._rnd);
            this._p_zw.put(w, prob_label);
        }
        prob_label[z] = prob;
        normalizeTopicWordColumns(topicWordColumnSums());
    }
}

