/*
 * Decompiled with CFR 0.152.
 */
package hivemall.topicmodel;

import hivemall.annotations.VisibleForTesting;
import hivemall.topicmodel.AbstractProbabilisticTopicModel;
import hivemall.utils.lang.ArrayUtils;
import hivemall.utils.math.MathUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import org.apache.commons.math3.distribution.GammaDistribution;
import org.apache.commons.math3.special.Gamma;

/**
 * Latent Dirichlet Allocation trained by online (mini-batch) variational Bayes, after Hoffman,
 * Blei and Bach, "Online Learning for Latent Dirichlet Allocation" (NIPS 2010).
 *
 * <p>Each {@link #train(String[][])} call runs an E-step that iterates the per-document variational
 * parameters {@code phi} and {@code gamma} until the mean absolute change of {@code gamma} falls
 * below {@code delta}, followed by an M-step that blends the mini-batch estimate of the topic-word
 * parameters {@code lambda} into the global ones with step size
 * {@code rho_t = (tau0 + t)^(-kappa)}.
 *
 * <p>Instances hold mutable per-mini-batch state and perform no synchronization.
 *
 * <p>NOTE(review): the inherited {@code _miniBatchDocs} field is assumed to be a
 * {@code List<Map<String, Float>>} (word label -> value), as the original element accesses imply.
 */
public final class OnlineLDAModel extends AbstractProbabilisticTopicModel {

    /** Shape of the Gamma distribution used for random initialization of lambda and gamma. */
    private static final double SHAPE = 100.0;
    /** Scale of that Gamma distribution (its mean is SHAPE * SCALE = 1.0). */
    private static final double SCALE = 0.01;
    /** Fixed seed of the Gamma random generator so that initialization is reproducible. */
    private static final long SEED = 1001L;

    /** Dirichlet prior on the per-document topic proportions. */
    private final float _alpha;
    /** Dirichlet prior on the per-topic word distributions. */
    private final float _eta;
    /** Delay of the learning-rate schedule; larger values shrink the early step sizes. */
    @Nonnegative
    private final double _tau0;
    /** Forgetting rate of the learning-rate schedule, in (0.5, 1.0]. */
    @Nonnegative
    private final double _kappa;
    /** Convergence threshold on the mean absolute change of gamma in the E-step. */
    private final double _delta;

    /** Number of M-step updates performed so far (t in rho_t). */
    private long _updateCount = 0L;
    /** Step size rho_t of the current M-step. */
    private double _rhot;

    /** True when the corpus size D is counted on the fly (a non-positive D was given). */
    private final boolean _isAutoD;

    /** Variational phi of the current mini-batch: one map (word -> float[K]) per document. */
    private List<Map<String, float[]>> _phi;
    /** Variational gamma of the current mini-batch, indexed [document][topic]. */
    private float[][] _gamma;
    /** Global variational parameters lambda: word -> float[K]. */
    @Nonnull
    private final Map<String, float[]> _lambda;

    /** Random source for initializing lambda and gamma. */
    @Nonnull
    private final GammaDistribution _gd;

    /** D divided by the mini-batch size; scales mini-batch statistics up to the corpus. */
    private float _docRatio = 1.0f;
    /** Sum of all word values in the current mini-batch (denominator of per-word perplexity). */
    private double _valueSum = 0.0;

    /**
     * Creates a model with default hyper-parameters: {@code eta = 0.05}, automatically counted
     * corpus size, {@code tau0 = 1020} and {@code kappa = 0.7}.
     *
     * @param K number of topics
     * @param alpha Dirichlet prior on per-document topic proportions
     * @param delta E-step convergence threshold
     */
    public OnlineLDAModel(int K, float alpha, double delta) {
        this(K, alpha, 0.05f, -1L, 1020.0, 0.7, delta);
    }

    /**
     * Creates a model.
     *
     * @param K number of topics
     * @param alpha Dirichlet prior on per-document topic proportions
     * @param eta Dirichlet prior on per-topic word distributions
     * @param D total number of documents; a value {@code <= 0} means "count documents as they
     *        arrive"
     * @param tau0 delay of the learning-rate schedule; must be strictly positive
     * @param kappa forgetting rate of the learning-rate schedule; must be in (0.5, 1.0]
     * @param delta E-step convergence threshold
     * @throws IllegalArgumentException if {@code tau0} or {@code kappa} is out of range or NaN
     */
    public OnlineLDAModel(int K, float alpha, float eta, long D, double tau0, double kappa,
            double delta) {
        super(K);
        // tau0 == 0 would make rho_0 = 0^(-kappa) = +Infinity on the very first update (the update
        // counter starts at 0) and fill lambda with NaN. The negated form also rejects NaN.
        if (!(tau0 > 0.0)) {
            throw new IllegalArgumentException("tau0 MUST be positive: " + tau0);
        }
        if (!(kappa > 0.5 && kappa <= 1.0)) {
            throw new IllegalArgumentException("kappa MUST be in (0.5, 1.0]: " + kappa);
        }
        _alpha = alpha;
        _eta = eta;
        _D = D;
        _tau0 = tau0;
        _kappa = kappa;
        _delta = delta;
        _isAutoD = _D <= 0L;
        _gd = new GammaDistribution(SHAPE, SCALE);
        _gd.reseedRandomGenerator(SEED);
        _lambda = new HashMap<String, float[]>(100);
    }

    /** Increments the corpus size D when it is counted automatically. */
    @Override
    protected void accumulateDocCount() {
        if (_isAutoD) {
            ++_D;
        }
    }

    /**
     * Runs one online variational-Bayes update (E-step then M-step) on a mini-batch.
     *
     * @param miniBatch documents of the mini-batch, each given as an array of word features
     */
    @Override
    protected void train(@Nonnull String[][] miniBatch) {
        preprocessMiniBatch(miniBatch);
        initParams(true);
        eStep();
        // rho_t = (tau0 + t)^(-kappa), where t is the number of updates done so far
        _rhot = Math.pow(_tau0 + _updateCount, -_kappa);
        mStep();
        ++_updateCount;
    }

    /** Parses the mini-batch and refreshes the per-batch totals {@code _valueSum}/{@code _docRatio}. */
    private void preprocessMiniBatch(@Nonnull String[][] miniBatch) {
        initMiniBatch(miniBatch, _miniBatchDocs);
        _miniBatchSize = _miniBatchDocs.size();

        double valueSum = 0.0;
        for (int d = 0; d < _miniBatchSize; d++) {
            for (Float n : _miniBatchDocs.get(d).values()) {
                valueSum += n.floatValue();
            }
        }
        _valueSum = valueSum;
        _docRatio = (float) ((double) _D / _miniBatchSize);
    }

    /**
     * Allocates phi and gamma for the current mini-batch and adds randomly initialized lambda
     * rows for words not seen before.
     *
     * @param gammaWithRandom draw the initial gamma from the Gamma distribution when true,
     *        otherwise initialize it to all ones
     */
    private void initParams(boolean gammaWithRandom) {
        final List<Map<String, float[]>> phi = new ArrayList<Map<String, float[]>>(_miniBatchSize);
        final float[][] gamma = new float[_miniBatchSize][];
        for (int d = 0; d < _miniBatchSize; d++) {
            gamma[d] = gammaWithRandom ? ArrayUtils.newRandomFloatArray(_K, _gd)
                    : ArrayUtils.newFloatArray(_K, 1.0f);

            final Map<String, float[]> phi_d = new HashMap<String, float[]>();
            phi.add(phi_d);
            for (String label : _miniBatchDocs.get(d).keySet()) {
                phi_d.put(label, new float[_K]);
                if (!_lambda.containsKey(label)) {
                    _lambda.put(label, ArrayUtils.newRandomFloatArray(_K, _gd));
                }
            }
        }
        _phi = phi;
        _gamma = gamma;
    }

    /** E-step: alternates phi and gamma updates per document until gamma converges. */
    private void eStep() {
        // lambdaSum[k] = sum over the whole vocabulary of lambda[w][k]
        final double[] lambdaSum = new double[_K];
        for (float[] lambda_label : _lambda.values()) {
            MathUtils.add(lambda_label, lambdaSum, _K);
        }
        final double[] digamma_lambdaSum = MathUtils.digamma(lambdaSum);

        // digamma(lambda[w]) is only needed for words occurring in this mini-batch, so skip the
        // rest of the vocabulary (initParams guarantees every batch word has a lambda row).
        final Map<String, float[]> digamma_lambda = new HashMap<String, float[]>();
        for (int d = 0; d < _miniBatchSize; d++) {
            for (String label : _miniBatchDocs.get(d).keySet()) {
                if (!digamma_lambda.containsKey(label)) {
                    digamma_lambda.put(label, MathUtils.digamma(_lambda.get(label)));
                }
            }
        }

        for (int d = 0; d < _miniBatchSize; d++) {
            final float[] gamma_d = _gamma[d];
            final Map<String, float[]> eLogBeta_d =
                    computeElogBetaPerDoc(d, digamma_lambda, digamma_lambdaSum);
            // NOTE(review): there is no iteration cap; if gamma ever became NaN the convergence
            // test would never succeed.
            float[] gammaPrev_d;
            do {
                gammaPrev_d = gamma_d.clone();
                updatePhiPerDoc(d, eLogBeta_d);
                updateGammaPerDoc(d);
            } while (!checkGammaDiff(gammaPrev_d, gamma_d));
        }
    }

    /**
     * Computes {@code E[log beta_kw] = digamma(lambda_kw) - digamma(sum_w lambda_kw)} for every
     * word of document {@code d}.
     */
    @Nonnull
    private Map<String, float[]> computeElogBetaPerDoc(@Nonnegative int d,
            @Nonnull Map<String, float[]> digamma_lambda, @Nonnull double[] digamma_lambdaSum) {
        final Map<String, Float> doc = _miniBatchDocs.get(d);
        final Map<String, float[]> eLogBeta_d = new HashMap<String, float[]>(doc.size());
        for (String label : doc.keySet()) {
            final float[] digamma_lambda_label = digamma_lambda.get(label);
            final float[] eLogBeta_label = new float[_K];
            for (int k = 0; k < _K; k++) {
                eLogBeta_label[k] = (float) (digamma_lambda_label[k] - digamma_lambdaSum[k]);
            }
            eLogBeta_d.put(label, eLogBeta_label);
        }
        return eLogBeta_d;
    }

    /**
     * Updates phi of document {@code d}:
     * {@code phi_wk ∝ exp(E[log theta_dk] + E[log beta_kw])}, normalized over topics.
     */
    private void updatePhiPerDoc(@Nonnegative int d, @Nonnull Map<String, float[]> eLogBeta_d) {
        final float[] gamma_d = _gamma[d];
        final double digamma_gammaSum_d = Gamma.digamma(MathUtils.sum(gamma_d));
        final double[] eLogTheta_d = new double[_K];
        for (int k = 0; k < _K; k++) {
            eLogTheta_d[k] = Gamma.digamma(gamma_d[k]) - digamma_gammaSum_d;
        }

        final Map<String, float[]> phi_d = _phi.get(d);
        for (String label : _miniBatchDocs.get(d).keySet()) {
            final float[] phi_label = phi_d.get(label);
            final float[] eLogBeta_label = eLogBeta_d.get(label);
            double normalizer = 0.0;
            for (int k = 0; k < _K; k++) {
                // the tiny epsilon keeps phi strictly positive when exp() underflows to zero
                final float phiVal =
                        (float) Math.exp(eLogBeta_label[k] + eLogTheta_d[k]) + 1.0E-20f;
                phi_label[k] = phiVal;
                normalizer += phiVal;
            }
            for (int k = 0; k < _K; k++) {
                phi_label[k] = (float) (phi_label[k] / normalizer);
            }
        }
    }

    /** Updates gamma of document {@code d}: {@code gamma_k = alpha + sum_w n_w * phi_wk}. */
    private void updateGammaPerDoc(@Nonnegative int d) {
        final Map<String, Float> doc = _miniBatchDocs.get(d);
        final Map<String, float[]> phi_d = _phi.get(d);
        final float[] gamma_d = _gamma[d];
        for (int k = 0; k < _K; k++) {
            gamma_d[k] = _alpha;
        }
        for (Map.Entry<String, Float> e : doc.entrySet()) {
            final float[] phi_label = phi_d.get(e.getKey());
            final float val = e.getValue().floatValue();
            for (int k = 0; k < _K; k++) {
                gamma_d[k] += phi_label[k] * val;
            }
        }
    }

    /** Returns true when the mean absolute change between two gamma vectors is below delta. */
    private boolean checkGammaDiff(@Nonnull float[] gammaPrev, @Nonnull float[] gammaNext) {
        double diff = 0.0;
        for (int k = 0; k < _K; k++) {
            diff += Math.abs(gammaPrev[k] - gammaNext[k]);
        }
        return diff / _K < _delta;
    }

    /**
     * M-step: {@code lambda = (1 - rho_t) * lambda + rho_t * lambdaTilde}, where
     * {@code lambdaTilde_kw = eta + (D / S) * sum_d n_dw * phi_dwk} (Hoffman et al., Algorithm 2).
     */
    private void mStep() {
        final Map<String, float[]> lambdaTilde = new HashMap<String, float[]>();
        for (int d = 0; d < _miniBatchSize; d++) {
            final Map<String, float[]> phi_d = _phi.get(d);
            for (Map.Entry<String, Float> e : _miniBatchDocs.get(d).entrySet()) {
                final String label = e.getKey();
                float[] lambdaTilde_label = lambdaTilde.get(label);
                if (lambdaTilde_label == null) {
                    lambdaTilde_label = ArrayUtils.newFloatArray(_K, _eta);
                    lambdaTilde.put(label, lambdaTilde_label);
                }
                final float[] phi_label = phi_d.get(label);
                // weight by the word's count n_dw, consistently with updateGammaPerDoc and the
                // likelihood term of computeApproxBound
                final float scale = _docRatio * e.getValue().floatValue();
                for (int k = 0; k < _K; k++) {
                    lambdaTilde_label[k] += scale * phi_label[k];
                }
            }
        }

        // Words absent from this mini-batch have lambdaTilde = eta; share one read-only array
        // instead of allocating one per word.
        final float[] etaOnly = ArrayUtils.newFloatArray(_K, _eta);
        for (Map.Entry<String, float[]> e : _lambda.entrySet()) {
            final float[] lambda_label = e.getValue();
            float[] lambdaTilde_label = lambdaTilde.get(e.getKey());
            if (lambdaTilde_label == null) {
                lambdaTilde_label = etaOnly;
            }
            for (int k = 0; k < _K; k++) {
                lambda_label[k] = (float) ((1.0 - _rhot) * lambda_label[k]
                        + _rhot * lambdaTilde_label[k]);
            }
        }
    }

    /**
     * Estimates the perplexity of the current mini-batch as {@code exp(-bound / (D/S * #words))}.
     *
     * @return perplexity estimate
     */
    @Override
    protected float computePerplexity() {
        final double bound = computeApproxBound();
        final double perWordBound = bound / (_docRatio * _valueSum);
        return (float) Math.exp(-1.0 * perWordBound);
    }

    /** Computes the variational lower bound (ELBO) estimate for the current mini-batch. */
    private double computeApproxBound() {
        final double[] gammaSum = new double[_miniBatchSize];
        for (int d = 0; d < _miniBatchSize; d++) {
            gammaSum[d] = MathUtils.sum(_gamma[d]);
        }
        final double[] digamma_gammaSum = MathUtils.digamma(gammaSum);

        final double[] lambdaSum = new double[_K];
        for (float[] lambda_label : _lambda.values()) {
            MathUtils.add(lambda_label, lambdaSum, _K);
        }
        final double[] digamma_lambdaSum = MathUtils.digamma(lambdaSum);

        final double logGamma_alpha = Gamma.logGamma(_alpha);
        final double logGamma_alphaSum = Gamma.logGamma(_K * _alpha);

        double score = 0.0;
        final double[] eLogTheta_d = new double[_K];
        final double[] temp = new double[_K];
        for (int d = 0; d < _miniBatchSize; d++) {
            final double digamma_gammaSum_d = digamma_gammaSum[d];
            final float[] gamma_d = _gamma[d];
            // E[log theta_dk] does not depend on the word, so compute it once per document
            for (int k = 0; k < _K; k++) {
                eLogTheta_d[k] = Gamma.digamma(gamma_d[k]) - digamma_gammaSum_d;
            }

            // E[log p(doc | theta, beta)]
            for (Map.Entry<String, Float> e : _miniBatchDocs.get(d).entrySet()) {
                final float[] lambda_label = _lambda.get(e.getKey());
                // Double.MIN_VALUE is the smallest *positive* double; these terms are expected
                // log-probabilities (typically negative), so start from -Infinity to obtain the
                // true maximum used as the log-sum-exp shift.
                double max = Double.NEGATIVE_INFINITY;
                for (int k = 0; k < _K; k++) {
                    final double eLogBeta_kw = Gamma.digamma(lambda_label[k]) - digamma_lambdaSum[k];
                    final double tempK = eLogTheta_d[k] + eLogBeta_kw;
                    if (tempK > max) {
                        max = tempK;
                    }
                    temp[k] = tempK;
                }
                score += e.getValue().floatValue() * MathUtils.logsumexp(temp, max);
            }

            // E[log p(theta | alpha) - log q(theta | gamma)]
            for (int k = 0; k < _K; k++) {
                final float gamma_dk = gamma_d[k];
                score += (_alpha - gamma_dk) * eLogTheta_d[k];
                score += Gamma.logGamma(gamma_dk) - logGamma_alpha;
            }
            score += logGamma_alphaSum;
            score -= Gamma.logGamma(gammaSum[d]);
        }
        // scale the mini-batch contribution up to the whole corpus
        score *= _docRatio;

        // E[log p(beta | eta) - log q(beta | lambda)]
        final double logGamma_eta = Gamma.logGamma(_eta);
        final double logGamma_etaSum = Gamma.logGamma(_eta * _lambda.size());
        for (float[] lambda_label : _lambda.values()) {
            for (int k = 0; k < _K; k++) {
                final float lambda_label_k = lambda_label[k];
                score += (_eta - lambda_label_k)
                        * (Gamma.digamma(lambda_label_k) - digamma_lambdaSum[k]);
                score += Gamma.logGamma(lambda_label_k) - logGamma_eta;
            }
        }
        for (int k = 0; k < _K; k++) {
            score += logGamma_etaSum - Gamma.logGamma(lambdaSum[k]);
        }
        return score;
    }

    /**
     * Returns {@code lambda[label][k]}.
     *
     * @throws IllegalArgumentException if the word is unknown or {@code k} is not in [0, K)
     */
    @Override
    @VisibleForTesting
    float getWordScore(@Nonnull String label, @Nonnegative int k) {
        final float[] lambda_label = _lambda.get(label);
        if (lambda_label == null) {
            throw new IllegalArgumentException("Word `" + label + "` is not in the corpus.");
        }
        if (k < 0 || k >= lambda_label.length) {
            throw new IllegalArgumentException(
                "Topic index must be in [0, " + lambda_label.length + ")");
        }
        return lambda_label[k];
    }

    /**
     * Sets {@code lambda[label][k]}, creating a randomly initialized row for an unknown word.
     */
    @Override
    protected void setWordScore(@Nonnull String label, @Nonnegative int k, float lambda_k) {
        float[] lambda_label = _lambda.get(label);
        if (lambda_label == null) {
            lambda_label = ArrayUtils.newRandomFloatArray(_K, _gd);
            _lambda.put(label, lambda_label);
        }
        lambda_label[k] = lambda_k;
    }

    /** Returns all words of topic {@code k}; see {@link #getTopicWords(int, int)}. */
    @Override
    @Nonnull
    protected SortedMap<Float, List<String>> getTopicWords(@Nonnegative int k) {
        return getTopicWords(k, _lambda.size());
    }

    /**
     * Returns words of topic {@code k} grouped by their normalized weight
     * {@code lambda_kw / sum_w lambda_kw}, in descending order of weight.
     *
     * @param k topic index
     * @param topN maximum number of distinct weights (map entries) to return
     * @return descending weight -> words having that weight
     */
    @Nonnull
    public SortedMap<Float, List<String>> getTopicWords(@Nonnegative int k,
            @Nonnegative int topN) {
        double lambdaSum = 0.0;
        final SortedMap<Float, List<String>> sortedLambda =
                new TreeMap<Float, List<String>>(Collections.reverseOrder());
        for (Map.Entry<String, float[]> e : _lambda.entrySet()) {
            final float lambda_k = e.getValue()[k];
            lambdaSum += lambda_k;
            List<String> labels = sortedLambda.get(lambda_k);
            if (labels == null) {
                labels = new ArrayList<String>();
                sortedLambda.put(lambda_k, labels);
            }
            labels.add(e.getKey());
        }

        final SortedMap<Float, List<String>> ret =
                new TreeMap<Float, List<String>>(Collections.reverseOrder());
        final int limit = Math.min(topN, _lambda.size());
        int taken = 0;
        for (Map.Entry<Float, List<String>> e : sortedLambda.entrySet()) {
            // checked before adding so that topN == 0 yields an empty result
            if (taken >= limit) {
                break;
            }
            final float key = (float) (e.getKey().floatValue() / lambdaSum);
            // distinct raw weights can round to the same normalized float; merge, don't overwrite
            final List<String> existing = ret.get(key);
            if (existing == null) {
                ret.put(key, e.getValue());
            } else {
                existing.addAll(e.getValue());
            }
            taken++;
        }
        return ret;
    }

    /**
     * Infers the topic distribution of a single document by running the E-step with gamma
     * initialized to ones, then normalizing gamma.
     *
     * <p>Side effects: overwrites the current mini-batch state, and adds randomly initialized
     * lambda rows for words not seen during training.
     *
     * @param doc word features of the document
     * @return normalized topic proportions, one entry per topic
     */
    @Override
    @Nonnull
    protected float[] getTopicDistribution(@Nonnull String[] doc) {
        preprocessMiniBatch(new String[][] {doc});
        initParams(false);
        eStep();

        final float[] gamma0 = _gamma[0];
        final double gammaSum = MathUtils.sum(gamma0);
        final float[] topicDistr = new float[_K];
        for (int k = 0; k < _K; k++) {
            topicDistr[k] = (float) (gamma0[k] / gammaSum);
        }
        return topicDistr;
    }
}

