/*
 * Decompiled with CFR 0.152.
 */
package jsat.text.topicmodel;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.ScaledVector;
import jsat.linear.SparseVector;
import jsat.linear.Vec;
import jsat.math.FastMath;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.DoubleList;
import jsat.utils.FakeExecutor;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

/**
 * Latent Dirichlet Allocation (LDA) topic model fit with stochastic (online) variational
 * inference. Documents are consumed in mini-batches. After each batch, every topic's word
 * weights {@code lambda_k} are decayed by {@code (1 - rho_t)}. The batch's contribution,
 * scaled by {@code rho_t * D / batchSize}, is then added, where
 * {@code rho_t = (tau0 + t)^-kappa}.
 * <p>
 * Per-document work inside a batch is split into {@link SystemInfo#LogicalCores} tasks
 * submitted to the supplied {@link ExecutorService}. Writes to each topic's {@code lambda}
 * row are guarded by a per-topic lock. The update method itself is not synchronized
 * ({@code t} is incremented without a lock), so callers should not run updates on one
 * instance concurrently.
 * <p>
 * NOTE(review): the update rules presumably follow Hoffman, Blei &amp; Bach, "Online Learning
 * for Latent Dirichlet Allocation" (2010). Confirm against that paper before changing them.
 */
public class OnlineLDAsvi
implements Parameterized {
    /** Prior added to every entry of a document's topic weights {@code gamma} in computePhi. */
    private double alpha = 1.0;
    /** Offset used when sampling initial lambda/gamma values and when computing E[log beta]. */
    private double eta = 1.0;
    /** Delay term of the learning rate {@code rho_t = (tau0 + t)^-kappa}. */
    private double tau0 = 128.0;
    /** Exponent of the learning rate {@code rho_t = (tau0 + t)^-kappa}. */
    private double kappa = 0.7;
    /** Number of passes {@link #model(DataSet, int, ExecutorService)} makes over the data. */
    private int epochs = 1;
    /** Expected total number of documents; scales each mini-batch's contribution. */
    private int D = -1;
    /** Number of topics. */
    private int K = -1;
    /** Vocabulary size, i.e. the length of the document vectors. */
    private int W = -1;
    /** Number of documents per mini-batch used by {@link #model(DataSet, int, ExecutorService)}. */
    private int miniBatchSize = 256;
    /** Number of mini-batch updates performed since {@link #initialize()}. */
    private int t;
    /** Per-topic word weights, one length-W vector per topic; null until the first update. */
    private List<Vec> lambda;
    /** One lock per topic, guarding writes to {@code lambda} and {@code lambdaSums}. */
    private List<Lock> lambdaLocks;
    /** Running sum of each topic's {@code lambda} vector. */
    private DoubleList lambdaSums;
    /** For each word, the value of {@code t} at which its E[log beta] entries were last refreshed. */
    private int[] lastUsed;
    /** E[log beta] per topic. NOTE(review): written in updateBetas but never read in this class. */
    private List<Vec> ELogBeta;
    /** exp(E[log beta]) per topic; entries are refreshed only for words seen in a mini-batch. */
    private List<Vec> ExpELogBeta;
    /** Per-thread scratch vectors of length K, reused across the documents a thread processes. */
    private ThreadLocal<Vec> gammaLocal;
    private ThreadLocal<Vec> logThetaLocal;
    private ThreadLocal<Vec> expLogThetaLocal;

    /**
     * Creates a model whose topic count, document count and vocabulary size must be set
     * before training.
     */
    public OnlineLDAsvi() {
        this.W = -1;
        this.D = -1;
        this.K = -1;
    }

    /**
     * Creates a model with the given sizes.
     *
     * @param K the number of topics, at least 2
     * @param D the expected number of documents, positive
     * @param W the vocabulary size, positive
     * @throws IllegalArgumentException if any value is out of range
     */
    public OnlineLDAsvi(int K, int D, int W) {
        this.setK(K);
        this.setD(D);
        this.setVocabSize(W);
    }

    /**
     * Sets the number of topics. Any previously learned topics are discarded (they are
     * re-initialized on the next update), and the per-thread scratch vectors are resized.
     *
     * @param K the number of topics, at least 2
     * @throws IllegalArgumentException if {@code K < 2}
     */
    public void setK(final int K) {
        if (K < 2) {
            throw new IllegalArgumentException("At least 2 topics must be learned");
        }
        this.K = K;
        this.gammaLocal = OnlineLDAsvi.newScratchVector(K);
        this.logThetaLocal = OnlineLDAsvi.newScratchVector(K);
        this.expLogThetaLocal = OnlineLDAsvi.newScratchVector(K);
        this.lambda = null;
    }

    /** Returns a ThreadLocal that gives each thread its own DenseVector of the given length. */
    private static ThreadLocal<Vec> newScratchVector(final int length) {
        return new ThreadLocal<Vec>() {
            @Override
            protected Vec initialValue() {
                return new DenseVector(length);
            }
        };
    }

    /** @return the number of topics, or -1 if not yet set */
    public int getK() {
        return this.K;
    }

    /**
     * Sets the expected total number of documents.
     *
     * @param D the document count, positive
     * @throws IllegalArgumentException if {@code D < 1}
     */
    public void setD(int D) {
        if (D < 1) {
            throw new IllegalArgumentException("The number of documents must be positive, not " + D);
        }
        this.D = D;
    }

    /** @return the expected number of documents, or -1 if not yet set */
    public int getD() {
        return this.D;
    }

    /**
     * Sets the vocabulary size, i.e. the length of every document vector.
     *
     * @param W the vocabulary size, positive
     * @throws IllegalArgumentException if {@code W < 1}
     */
    public void setVocabSize(int W) {
        if (W < 1) {
            throw new IllegalArgumentException("Vocabulary size must be positive, not " + W);
        }
        this.W = W;
    }

    /** @return the vocabulary size, or -1 if not yet set */
    public int getVocabSize() {
        return this.W;
    }

    /**
     * @param alpha prior added to a document's topic weights; must be positive and finite
     * @throws IllegalArgumentException if alpha is not positive and finite
     */
    public void setAlpha(double alpha) {
        if (alpha <= 0.0 || Double.isInfinite(alpha) || Double.isNaN(alpha)) {
            throw new IllegalArgumentException("Alpha must be a positive constant, not " + alpha);
        }
        this.alpha = alpha;
    }

    /** @return the alpha prior */
    public double getAlpha() {
        return this.alpha;
    }

    /**
     * @param eta prior offset for topic-word weights; must be positive and finite
     * @throws IllegalArgumentException if eta is not positive and finite
     */
    public void setEta(double eta) {
        if (eta <= 0.0 || Double.isInfinite(eta) || Double.isNaN(eta)) {
            throw new IllegalArgumentException("Eta must be a positive constant, not " + eta);
        }
        this.eta = eta;
    }

    /** @return the eta prior */
    public double getEta() {
        return this.eta;
    }

    /**
     * @param tau0 delay term of the learning rate {@code (tau0 + t)^-kappa}; must be positive and finite
     * @throws IllegalArgumentException if tau0 is not positive and finite
     */
    public void setTau0(double tau0) {
        if (tau0 <= 0.0 || Double.isInfinite(tau0) || Double.isNaN(tau0)) {
            throw new IllegalArgumentException("Tau0 must be a positive constant, not " + tau0);
        }
        this.tau0 = tau0;
    }

    /** @return the learning-rate delay term tau0 */
    public double getTau0() {
        return this.tau0;
    }

    /**
     * @param epochs number of passes over the data made by {@code model}; must be positive
     * @throws IllegalArgumentException if {@code epochs < 1}
     */
    public void setEpochs(int epochs) {
        if (epochs < 1) {
            throw new IllegalArgumentException("The number of epochs must be positive, not " + epochs);
        }
        this.epochs = epochs;
    }

    /** @return the number of training epochs */
    public int getEpochs() {
        return this.epochs;
    }

    /**
     * @param kappa exponent of the learning rate {@code (tau0 + t)^-kappa}, in [0.5, 1]
     * @throws IllegalArgumentException if kappa is outside [0.5, 1] or NaN
     */
    public void setKappa(double kappa) {
        if (kappa < 0.5 || kappa > 1.0 || Double.isNaN(kappa)) {
            throw new IllegalArgumentException("Kappa must be in [0.5, 1], not " + kappa);
        }
        this.kappa = kappa;
    }

    /** @return the learning-rate exponent kappa */
    public double getKappa() {
        return this.kappa;
    }

    /**
     * @param miniBatchSize number of documents per update in {@code model}; must be positive
     * @throws IllegalArgumentException if {@code miniBatchSize < 1}
     */
    public void setMiniBatchSize(int miniBatchSize) {
        if (miniBatchSize < 1) {
            throw new IllegalArgumentException("the batch size must be a positive constant, not " + miniBatchSize);
        }
        this.miniBatchSize = miniBatchSize;
    }

    /** @return the number of documents per mini-batch */
    public int getMiniBatchSize() {
        return this.miniBatchSize;
    }

    /**
     * Returns topic {@code k}'s word weights normalized to sum to one. The result is a scaled
     * view over the internal lambda vector, not a copy.
     *
     * @param k the topic index
     * @return the normalized word weights of topic k
     * @throws IllegalStateException if the model has not been trained
     */
    public Vec getTopicVec(int k) {
        if (this.lambda == null) {
            throw new IllegalStateException("The model has not been trained yet");
        }
        Vec lambda_k = this.lambda.get(k);
        return new ScaledVector(1.0 / lambda_k.sum(), lambda_k);
    }

    /** Writes {@code digamma(input[i]) - digamma(sum)} into {@code output[i]} for every i. */
    private void expandPsiMinusPsiSum(Vec input, double sum, Vec output) {
        double psiSum = FastMath.digamma(sum);
        for (int i = 0; i < input.length(); ++i) {
            output.set(i, FastMath.digamma(input.get(i)) - psiSum);
        }
    }

    /** Inverse-CDF sample of an exponential distribution with mean {@code lambdaInv}, given uniform {@code p}. */
    private static double sampleExpoDist(double lambdaInv, double p) {
        return -lambdaInv * FastMath.log(1.0 - p);
    }

    /**
     * Performs one mini-batch update using a {@link FakeExecutor}.
     *
     * @param docs the mini-batch of document vectors (length W)
     */
    public void update(List<Vec> docs) {
        this.update(docs, new FakeExecutor());
    }

    /**
     * Performs one mini-batch update of the topics. The model is initialized on the first
     * call. An empty batch only triggers initialization: lambda is not decayed and {@code t}
     * does not advance.
     *
     * @param docs the mini-batch of document vectors (length W)
     * @param ex executor used for the parallel work; {@code null} means a {@link FakeExecutor}
     * @throws FailedToFitException if K, D or the vocabulary size has not been set
     */
    public void update(final List<Vec> docs, ExecutorService ex) {
        if (ex == null) {
            ex = new FakeExecutor();
        }
        if (this.lambda == null) {
            this.initialize();
        }
        if (docs.isEmpty()) {
            // No documents means no gradient; decaying lambda here would only shrink the topics.
            return;
        }
        this.updateBetas(docs, ex);
        final double rho_t = Math.pow(this.tau0 + (double)this.t++, -this.kappa);
        // Decay the existing statistics: lambda <- (1 - rho_t) * lambda.
        for (int k = 0; k < this.K; ++k) {
            this.lambda.get(k).mutableMultiply(1.0 - rho_t);
            this.lambdaSums.set(k, this.lambdaSums.getD(k) * (1.0 - rho_t));
        }
        final int P = SystemInfo.LogicalCores;
        final CountDownLatch latch = new CountDownLatch(P);
        final List<Throwable> failures = Collections.synchronizedList(new ArrayList<Throwable>());
        for (int id = 0; id < P; ++id) {
            final int ID = id;
            ex.submit(new Runnable() {
                @Override
                public void run() {
                    try {
                        Random rand = RandomUtil.getRandom();
                        int start = ParallelUtils.getStartBlock(docs.size(), ID, P);
                        int end = ParallelUtils.getEndBlock(docs.size(), ID, P);
                        for (int d = start; d < end; ++d) {
                            addDocumentToLambda(docs.get(d), rho_t, docs.size(), rand);
                        }
                    }
                    catch (Throwable th) {
                        // Recorded and rethrown by awaitWorkers on the calling thread.
                        failures.add(th);
                    }
                    finally {
                        // Always count down so a failing worker cannot deadlock the caller.
                        latch.countDown();
                    }
                }
            });
        }
        OnlineLDAsvi.awaitWorkers(latch, failures);
    }

    /**
     * Runs the per-document inference for {@code doc} and adds its contribution, scaled by
     * {@code rho_t * D / batchSize}, to every topic's lambda row under that topic's lock.
     * Documents with no non-zero entries are skipped.
     */
    private void addDocumentToLambda(Vec doc, double rho_t, int batchSize, Random rand) {
        if (doc.nnz() == 0) {
            return;
        }
        Vec ELogTheta_d = this.logThetaLocal.get();
        Vec ExpELogTheta_d = this.expLogThetaLocal.get();
        Vec gamma_d = this.gammaLocal.get();
        this.prepareGammaTheta(gamma_d, ELogTheta_d, ExpELogTheta_d, rand);
        int[] indexMap = new int[doc.nnz()];
        double[] phiCols = new double[doc.nnz()];
        this.computePhi(doc, indexMap, phiCols, this.K, gamma_d, ELogTheta_d, ExpELogTheta_d);
        // Visit topics in random order. A topic whose lock is held by another thread is
        // retried later instead of blocking on it.
        IntList toUpdate = new IntList(this.K);
        ListUtils.addRange(toUpdate, 0, this.K, 1);
        Collections.shuffle(toUpdate, rand);
        int updatePos = 0;
        while (!toUpdate.isEmpty()) {
            int k = toUpdate.getI(updatePos);
            Lock lock_k = this.lambdaLocks.get(k);
            if (lock_k.tryLock()) {
                try {
                    double coeff = ExpELogTheta_d.get(k) * rho_t * (double)this.D / (double)batchSize;
                    Vec lambda_k = this.lambda.get(k);
                    Vec ExpELogBeta_k = this.ExpELogBeta.get(k);
                    double lambdaSum_k = this.lambdaSums.getD(k);
                    for (int i = 0; i < indexMap.length; ++i) {
                        int indx = indexMap[i];
                        double toAdd = coeff * phiCols[i] * ExpELogBeta_k.get(indx);
                        lambda_k.increment(indx, toAdd);
                        lambdaSum_k += toAdd;
                    }
                    this.lambdaSums.set(k, lambdaSum_k);
                }
                finally {
                    lock_k.unlock();
                }
                toUpdate.remove(updatePos);
            }
            if (!toUpdate.isEmpty()) {
                updatePos = (updatePos + 1) % toUpdate.size();
            }
        }
    }

    /**
     * Waits for all worker tasks counted by {@code latch}, then rethrows the first failure a
     * worker recorded in {@code failures}, if any. If the wait is interrupted, the interrupt
     * is logged, the thread's interrupt status is restored, and the method returns.
     */
    private static void awaitWorkers(CountDownLatch latch, List<Throwable> failures) {
        try {
            latch.await();
        }
        catch (InterruptedException ex1) {
            Thread.currentThread().interrupt();
            Logger.getLogger(OnlineLDAsvi.class.getName()).log(Level.SEVERE, null, ex1);
            return;
        }
        if (failures.isEmpty()) {
            return;
        }
        Throwable first = failures.get(0);
        if (first instanceof RuntimeException) {
            throw (RuntimeException)first;
        }
        if (first instanceof Error) {
            throw (Error)first;
        }
        throw new RuntimeException(first);
    }

    /**
     * Fits the model from scratch using a {@link FakeExecutor}.
     *
     * @param dataSet the documents, one vector per data point
     * @param topics the number of topics, at least 2
     */
    public void model(DataSet dataSet, int topics) {
        this.model(dataSet, topics, new FakeExecutor());
    }

    /**
     * Fits the model from scratch. K, D and the vocabulary size are taken from the arguments
     * and the data set, and any previous topics are discarded. The method then makes
     * {@code epochs} shuffled passes over the documents in mini-batches of
     * {@code miniBatchSize}.
     *
     * @param dataSet the documents, one vector per data point
     * @param topics the number of topics, at least 2
     * @param ex executor used for the parallel work; {@code null} means a {@link FakeExecutor}
     */
    public void model(DataSet dataSet, int topics, ExecutorService ex) {
        if (ex == null) {
            ex = new FakeExecutor();
        }
        this.setK(topics);
        this.setD(dataSet.getSampleSize());
        this.setVocabSize(dataSet.getNumNumericalVars());
        List<Vec> docs = dataSet.getDataVectors();
        // Use the project's RandomUtil, like every other random draw in this class.
        Random rand = RandomUtil.getRandom();
        for (int epoch = 0; epoch < this.epochs; ++epoch) {
            Collections.shuffle(docs, rand);
            for (int i = 0; i < docs.size(); i += this.miniBatchSize) {
                int to = Math.min(i + this.miniBatchSize, docs.size());
                this.update(docs.subList(i, to), ex);
            }
        }
    }

    /**
     * Infers the topic proportions of a single document. The returned vector has length K
     * and is normalized to sum to one.
     * <p>
     * NOTE(review): ExpELogBeta entries are refreshed only for words seen in a training
     * mini-batch (see updateBetas). Words absent from recent batches therefore use values
     * computed from an older lambda, and words never seen keep their initial DenseVector
     * value (presumably 0).
     *
     * @param doc the document vector (length W)
     * @return the document's topic proportions
     * @throws IllegalStateException if the model has not been trained
     */
    public Vec getTopics(Vec doc) {
        if (this.lambda == null) {
            throw new IllegalStateException("The model has not been trained yet");
        }
        DenseVector gamma = new DenseVector(this.K);
        DenseVector eLogTheta_i = new DenseVector(this.K);
        DenseVector expLogTheta_i = new DenseVector(this.K);
        this.prepareGammaTheta(gamma, eLogTheta_i, expLogTheta_i, RandomUtil.getRandom());
        this.computePhi(doc, new int[doc.nnz()], new double[doc.nnz()], this.K, gamma, eLogTheta_i, expLogTheta_i);
        gamma.mutableDivide(gamma.sum());
        return gamma;
    }

    /**
     * Recomputes ELogBeta and ExpELogBeta for every word that appears in {@code docs} and has
     * not yet been refreshed at the current {@code t}. Work is split across
     * {@link SystemInfo#LogicalCores} tasks.
     */
    private void updateBetas(List<Vec> docs, ExecutorService ex) {
        final double[] digammaLambdaSum = new double[this.K];
        for (int k = 0; k < this.K; ++k) {
            digammaLambdaSum[k] = FastMath.digamma((double)this.W * this.eta + this.lambdaSums.getD(k));
        }
        List<List<Vec>> docSplits = ListUtils.splitList(docs, SystemInfo.LogicalCores);
        final CountDownLatch latch = new CountDownLatch(docSplits.size());
        final List<Throwable> failures = Collections.synchronizedList(new ArrayList<Throwable>());
        for (final List<Vec> docsSub : docSplits) {
            ex.submit(new Runnable() {
                @Override
                public void run() {
                    try {
                        for (Vec doc : docsSub) {
                            for (IndexValue iv : doc) {
                                int indx = iv.getIndex();
                                if (lastUsed[indx] == t) {
                                    continue; // already refreshed during this update
                                }
                                for (int k = 0; k < K; ++k) {
                                    double lambda_kj = lambda.get(k).get(indx);
                                    double logBeta_kj = FastMath.digamma(eta + lambda_kj) - digammaLambdaSum[k];
                                    ELogBeta.get(k).set(indx, logBeta_kj);
                                    ExpELogBeta.get(k).set(indx, FastMath.exp(logBeta_kj));
                                }
                                lastUsed[indx] = t;
                            }
                        }
                    }
                    catch (Throwable th) {
                        // Recorded and rethrown by awaitWorkers on the calling thread.
                        failures.add(th);
                    }
                    finally {
                        latch.countDown();
                    }
                }
            });
        }
        OnlineLDAsvi.awaitWorkers(latch, failures);
    }

    /**
     * Randomly initializes a document's topic weights {@code gamma_i} (exponential samples
     * plus eta). It then fills {@code eLogTheta_i} with {@code digamma(gamma) - digamma(sum(gamma))}
     * and {@code expLogTheta_i} with its element-wise exponential.
     */
    private void prepareGammaTheta(Vec gamma_i, Vec eLogTheta_i, Vec expLogTheta_i, Random rand) {
        // Multiply in double: W * K can overflow int for large vocabularies.
        double lambdaInv = (double)this.W * (double)this.K / ((double)this.D * 100.0);
        for (int j = 0; j < gamma_i.length(); ++j) {
            gamma_i.set(j, OnlineLDAsvi.sampleExpoDist(lambdaInv, rand.nextDouble()) + this.eta);
        }
        this.expandPsiMinusPsiSum(gamma_i, gamma_i.sum(), eLogTheta_i);
        for (int j = 0; j < eLogTheta_i.length(); ++j) {
            expLogTheta_i.set(j, FastMath.exp(eLogTheta_i.get(j)));
        }
    }

    /**
     * Iteratively refines a document's topic weights {@code gamma_d} (and the derived
     * E[log theta] vectors) together with the per-word normalized counts in {@code phiCols}.
     * It runs until the mean absolute change of gamma drops below 0.001, or for at most 100
     * iterations.
     *
     * @param indexMap output: the document's non-zero word indices, length {@code doc.nnz()}
     * @param phiCols output: word count divided by the phi normalizer, length {@code doc.nnz()}
     */
    private void computePhi(Vec doc, int[] indexMap, double[] phiCols, int K, Vec gamma_d, Vec ELogTheta_d, Vec ExpELogTheta_d) {
        // NOTE(review): updateVec presumably wraps indexMap/phiCols without copying. The loop
        // below relies on dot() seeing the later writes to phiCols.
        SparseVector updateVec = new SparseVector(indexMap, phiCols, this.W, doc.nnz());
        this.computePhiColumns(doc, indexMap, phiCols, ExpELogTheta_d);
        for (int iter = 0; iter < 100; ++iter) {
            double absChange = 0.0;
            double gamma_d_sum = 0.0;
            for (int k = 0; k < K; ++k) {
                double gamma_dk = this.alpha + ExpELogTheta_d.get(k) * updateVec.dot(this.ExpELogBeta.get(k));
                absChange += Math.abs(gamma_dk - gamma_d.get(k));
                gamma_d.set(k, gamma_dk);
                gamma_d_sum += gamma_dk;
            }
            this.expandPsiMinusPsiSum(gamma_d, gamma_d_sum, ELogTheta_d);
            for (int k = 0; k < ELogTheta_d.length(); ++k) {
                ExpELogTheta_d.set(k, FastMath.exp(ELogTheta_d.get(k)));
            }
            this.computePhiColumns(doc, indexMap, phiCols, ExpELogTheta_d);
            // Equivalent to "mean absolute change per topic < 0.001".
            if (absChange < 0.001 * (double)K) {
                break;
            }
        }
    }

    /**
     * For each non-zero word of {@code doc} (in iteration order), stores its index in
     * {@code indexMap}. It also stores the word's value divided by
     * {@code sum_k ExpELogTheta_d[k] * ExpELogBeta[k][word]} in {@code phiCols}; a tiny
     * epsilon in the denominator avoids division by zero.
     */
    private void computePhiColumns(Vec doc, int[] indexMap, double[] phiCols, Vec ExpELogTheta_d) {
        int pos = 0;
        for (IndexValue iv : doc) {
            int wordIndex = iv.getIndex();
            double norm = 0.0;
            for (int k = 0; k < ExpELogTheta_d.length(); ++k) {
                norm += ExpELogTheta_d.get(k) * this.ExpELogBeta.get(k).get(wordIndex);
            }
            indexMap[pos] = wordIndex;
            phiCols[pos] = iv.getValue() / (norm + 1.0E-15);
            ++pos;
        }
    }

    /**
     * Allocates and randomly initializes lambda (exponential samples plus eta), the per-topic
     * locks, lambda row sums, the E[log beta] buffers and the per-word refresh markers, and
     * resets {@code t} to 0.
     *
     * @throws FailedToFitException if K, D or the vocabulary size has not been set
     */
    private void initialize() {
        if (this.K < 1) {
            throw new FailedToFitException("Topic number for LDA has not yet been specified");
        }
        if (this.D < 1) {
            throw new FailedToFitException("Expected number of documents has not yet been specified");
        }
        if (this.W < 1) {
            throw new FailedToFitException("Topic vocabulary size has not yet been specified");
        }
        this.t = 0;
        this.lambda = new ArrayList<Vec>(this.K);
        this.lambdaLocks = new ArrayList<Lock>(this.K);
        this.lambdaSums = new DoubleList(this.K);
        this.ELogBeta = new ArrayList<Vec>(this.K);
        this.ExpELogBeta = new ArrayList<Vec>(this.K);
        this.lastUsed = new int[this.W];
        // -1 never equals a valid t, so every word is refreshed on first use.
        Arrays.fill(this.lastUsed, -1);
        // Multiply in double: K * W can overflow int for large vocabularies.
        double lambdaInv = (double)this.K * (double)this.W / ((double)this.D * 100.0);
        Random rand = RandomUtil.getRandom();
        for (int i = 0; i < this.K; ++i) {
            DenseVector lambda_i = new DenseVector(this.W);
            // Wrapped in a ScaledVector so the per-update decay (mutableMultiply) is cheap.
            this.lambda.add(new ScaledVector(lambda_i));
            this.lambdaLocks.add(new ReentrantLock());
            this.ELogBeta.add(new DenseVector(this.W));
            this.ExpELogBeta.add(new DenseVector(this.W));
            double rowSum = 0.0;
            for (int j = 0; j < this.W; ++j) {
                double sample = OnlineLDAsvi.sampleExpoDist(lambdaInv, rand.nextDouble()) + this.eta;
                lambda_i.set(j, sample);
                rowSum += sample;
            }
            this.lambdaSums.add(rowSum);
        }
    }

    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(this.getParameters()).get(paramName);
    }
}

