/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.angel.ml.lda.algo;

import com.tencent.angel.PartitionKey;
import com.tencent.angel.exception.AngelException;
import com.tencent.angel.ml.lda.LDAModel;
import com.tencent.angel.ml.lda.algo.BinarySearch;
import com.tencent.angel.ml.lda.algo.CSRTokens;
import com.tencent.angel.ml.lda.algo.structures.FTree;
import com.tencent.angel.ml.lda.algo.structures.I2ITranverseMap;
import com.tencent.angel.ml.lda.algo.structures.S2BTraverseMap;
import com.tencent.angel.ml.lda.algo.structures.S2ITraverseMap;
import com.tencent.angel.ml.lda.algo.structures.S2STraverseMap;
import com.tencent.angel.ml.lda.algo.structures.TraverseHashMap;
import com.tencent.angel.ml.lda.psf.CSRPartUpdateParam;
import com.tencent.angel.ml.lda.psf.PartCSRResult;
import com.tencent.angel.ml.lda.psf.UpdatePartFunc;
import com.tencent.angel.ml.matrix.psf.update.base.PartitionUpdateParam;
import com.tencent.angel.ml.matrix.psf.update.base.UpdateFunc;
import com.tencent.angel.ml.matrix.psf.update.base.VoidResult;
import com.tencent.angel.psagent.PSAgentContext;
import com.tencent.angel.psagent.matrix.transport.FutureResult;
import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap;
import java.util.Random;
import java.util.concurrent.Future;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.commons.math.special.Gamma;

/**
 * Per-partition topic sampler for the LDA model in this package.
 *
 * <p>An instance owns several scratch buffers ({@link #wk}, {@link #psum}, {@link #tidx},
 * {@link #maxDoc}, {@link #tree}) that are overwritten while sampling, so a single instance
 * must not be driven by more than one thread at a time. Document-level state reached through
 * {@link #data} is guarded with {@code synchronized} blocks inside the methods below.
 */
public class Sampler {
    /** Token store; read and written through its public arrays ({@code ws}, {@code docs}, {@code topics}, ...). */
    public CSRTokens data;
    /** Model supplying {@code K}, {@code alpha}, {@code beta} and the word-topic matrix id. */
    public LDAModel model;
    /** Sampling tree over the {@link #K} per-topic weights of the word currently being processed. */
    public FTree tree;
    /** Scratch buffer of length K: per-topic weights in {@code buildFTree}, cumulative weights in the {@code build(...)} overloads. */
    public float[] psum;
    /** Scratch buffer of length K: topic ids parallel to the cumulative weights stored in {@link #psum}. */
    public int[] tidx;
    /** Per-topic counters of length K; adjusted by +/-1 as tokens are (re)assigned. */
    public long[] nk;
    /** Per-topic counters for the current word, filled by {@code csr.read(wk)}. */
    public int[] wk;
    /** Scratch cumulative weights used for documents that have no entry in {@code data.dks}. */
    public float[] maxDoc;
    /** Number of topics, taken from {@code model.K()}. */
    public int K;
    public float alpha;
    public float beta;
    /** {@code data.n_words * beta}. */
    public float vbeta;
    public double lgammaAlpha;
    public double lgammaBeta;
    public double lgammaAlphaSum;
    /** Not written anywhere in this class; kept because it is a public field. */
    public boolean error = false;
    private static final Log LOG = LogFactory.getLog(Sampler.class);

    /**
     * Creates a sampler bound to the given tokens and model and allocates all scratch buffers.
     *
     * @param data  token store to sample over
     * @param model model providing the hyper-parameters and topic count
     */
    public Sampler(CSRTokens data, LDAModel model) {
        this.data = data;
        this.model = model;
        this.K = model.K();
        this.alpha = model.alpha();
        this.beta = model.beta();
        this.vbeta = (float)data.n_words * this.beta;
        this.lgammaBeta = Gamma.logGamma((double)this.beta);
        this.lgammaAlpha = Gamma.logGamma((double)this.alpha);
        this.lgammaAlphaSum = Gamma.logGamma((double)(this.alpha * (float)this.K));
        this.nk = new long[this.K];
        this.wk = new int[this.K];
        this.tidx = new int[this.K];
        this.psum = new float[this.K];
        this.maxDoc = new float[Math.min(data.maxDocLen, this.K)];
        this.tree = new FTree(this.K);
    }

    /**
     * Resamples every token of the partition and pushes the resulting count deltas.
     * Equivalent to {@code sample(pkey, csr, true)}.
     */
    public Future<VoidResult> sample(PartitionKey pkey, PartCSRResult csr) {
        return this.sample(pkey, csr, true);
    }

    /**
     * Resamples the topic of every token whose word row lies in
     * {@code [pkey.getStartRow(), pkey.getEndRow())}.
     *
     * <p>For each word with at least one token, the word's topic counts are read from
     * {@code csr} into {@link #wk} and the sampling tree is rebuilt. {@code csr.clear()} is
     * always invoked, even when an exception escapes the loop.
     *
     * @param pkey    partition whose rows are processed
     * @param csr     source of per-word topic counts, read once per non-empty word
     * @param update2 when {@code true}, {@link #wk}/{@link #nk} are adjusted and the per-word
     *                deltas are pushed to the parameter server; when {@code false} the counts
     *                are left untouched and nothing is pushed
     * @return the future of the pushed update, or {@code null} when {@code update2} is false
     * @throws AngelException if {@code csr.read} reports failure
     */
    public Future<VoidResult> sample(PartitionKey pkey, PartCSRResult csr, boolean update2) {
        int ws = pkey.getStartRow();
        int we = pkey.getEndRow();
        // Seedless constructor: a wall-clock seed hands identical random streams to every
        // call that starts within the same millisecond (e.g. partitions sampled in parallel).
        Random rand = new Random();
        Int2IntOpenHashMap[] updates = null;
        try {
            if (update2) {
                updates = new Int2IntOpenHashMap[we - ws];
            }
            for (int w = ws; w < we; ++w) {
                if (this.data.ws[w + 1] - this.data.ws[w] == 0) {
                    // Word has no tokens; note that csr is not advanced for it.
                    continue;
                }
                if (!csr.read(this.wk)) {
                    throw new AngelException("some error happens");
                }
                this.buildFTree();
                if (update2) {
                    updates[w - ws] = new Int2IntOpenHashMap();
                }
                for (int wi = this.data.ws[w]; wi < this.data.ws[w + 1]; ++wi) {
                    int d = this.data.docs[wi];
                    int tt = this.data.topics[this.data.dindex[wi]];
                    if (update2 && this.wk[tt] <= 0) {
                        // Local assignment disagrees with the counts just read; skip the token.
                        LOG.error((Object)String.format("Error wk[%d] = %d for word %d", tt, this.wk[tt], w));
                        continue;
                    }
                    if (update2) {
                        // Take the token out of the counts before drawing its new topic.
                        this.wk[tt]--;
                        this.nk[tt]--;
                        this.tree.update(tt, this.topicWeight(tt));
                        updates[w - ws].addTo(tt, -1);
                    }
                    // NOTE(review): the lock object here is data.docIds[d] whereas initialize()
                    // locks data.dks[d]; confirm the two are never expected to exclude each other.
                    synchronized (this.data.docIds[d]) {
                        tt = this.drawTopic(d, tt, rand);
                    }
                    if (update2) {
                        this.wk[tt]++;
                        this.nk[tt]++;
                        this.tree.update(tt, this.topicWeight(tt));
                        updates[w - ws].addTo(tt, 1);
                    }
                    this.data.topics[this.data.dindex[wi]] = tt;
                }
            }
        }
        finally {
            csr.clear();
        }
        return update2 ? this.pushUpdates(pkey, updates) : null;
    }

    /**
     * Draws a new topic for one token of document {@code d}. Must be called while holding the
     * document lock taken in {@link #sample(PartitionKey, PartCSRResult, boolean)}.
     *
     * <p>The document part of the distribution is built either from the raw topic array
     * (when {@code data.dks[d]} is {@code null}) or from the document's topic-count map; the
     * remaining mass {@code alpha * tree.first()} is sampled from the tree.
     *
     * @param d       document index
     * @param current topic currently assigned to the token
     * @param rand    random source
     * @return the newly drawn topic
     */
    private int drawTopic(int d, int current, Random rand) {
        float sum;
        if (this.data.dks[d] == null) {
            sum = this.build(d, this.maxDoc, this.tree, current);
        } else {
            this.data.dks[d].dec(current);
            sum = this.build(this.data.dks[d]);
        }
        int drawn;
        float u = rand.nextFloat() * (sum + this.alpha * this.tree.first());
        if (u < sum) {
            // Document part selected: redraw u inside [0, sum) and locate its bucket.
            u = rand.nextFloat() * sum;
            if (this.data.dks[d] == null) {
                int length = this.data.ds[d + 1] - this.data.ds[d];
                int idx = BinarySearch.binarySearch(this.maxDoc, u, 0, length - 1);
                drawn = this.data.topics[this.data.ds[d] + idx];
            } else if (this.data.dks[d].size == 1) {
                drawn = this.tidx[0];
            } else {
                int idx = BinarySearch.binarySearch(this.psum, u, 0, this.data.dks[d].size - 1);
                drawn = this.tidx[idx];
            }
        } else {
            drawn = this.tree.sample(rand.nextFloat() * this.tree.first());
        }
        if (this.data.dks[d] != null) {
            this.data.dks[d].inc(drawn);
        }
        return drawn;
    }

    /** Returns {@code (wk[k] + beta) / (nk[k] + vbeta)}, the tree weight of topic {@code k}. */
    private float topicWeight(int k) {
        return ((float)this.wk[k] + this.beta) / ((float)this.nk[k] + this.vbeta);
    }

    /** Sends the per-word count deltas of one partition to the word-topic matrix. */
    private Future<VoidResult> pushUpdates(PartitionKey pkey, Int2IntOpenHashMap[] updates) {
        CSRPartUpdateParam param = new CSRPartUpdateParam(this.model.wtMat().getMatrixId(), pkey, updates);
        return PSAgentContext.get().getMatrixTransportClient().update((UpdateFunc)new UpdatePartFunc(null), (PartitionUpdateParam)param);
    }

    /** Fills {@link #psum} with the weight of every topic and rebuilds {@link #tree} from it. */
    private void buildFTree() {
        for (int k = 0; k < this.K; ++k) {
            this.psum[k] = this.topicWeight(k);
        }
        this.tree.build(this.psum);
    }

    /**
     * Writes the running sum of tree weights over the tokens of document {@code d} into
     * {@code p}, leaving out the first token whose topic equals {@code remove}.
     *
     * @param d      document index into {@code data.ds}
     * @param p      output buffer; entry {@code i} receives the cumulative weight up to token {@code i}
     * @param tree   tree queried for each token's topic weight
     * @param remove topic whose first occurrence contributes no weight
     * @return the total accumulated weight
     */
    public float build(int d, float[] p, FTree tree, int remove) {
        float psum = 0.0f;
        boolean find = false;
        for (int i = this.data.ds[d]; i < this.data.ds[d + 1]; ++i) {
            int tt = this.data.topics[i];
            if (!find && tt == remove) {
                find = true;
            } else {
                psum += tree.get(tt);
            }
            p[i - this.data.ds[d]] = psum;
        }
        return psum;
    }

    /** Cumulative {@code count * weight} over the entries of {@code dk}; fills {@link #psum} and {@link #tidx}. */
    private float build(S2STraverseMap dk) {
        float sum = 0.0f;
        for (int i = 0; i < dk.size; ++i) {
            int k = dk.key[dk.idx[i]];
            short v = dk.value[dk.idx[i]];
            this.psum[i] = sum += (float)v * this.tree.get(k);
            this.tidx[i] = k;
        }
        return sum;
    }

    /** Cumulative {@code count * weight} over the entries of {@code dk}; fills {@link #psum} and {@link #tidx}. */
    private float build(S2BTraverseMap dk) {
        float sum = 0.0f;
        for (int i = 0; i < dk.size; ++i) {
            int k = dk.key[dk.idx[i]];
            short v = dk.value[dk.idx[i]];
            this.psum[i] = sum += (float)v * this.tree.get(k);
            this.tidx[i] = k;
        }
        return sum;
    }

    /** Cumulative {@code count * weight} over the entries of {@code dk}; fills {@link #psum} and {@link #tidx}. */
    private float build(S2ITraverseMap dk) {
        float sum = 0.0f;
        for (int i = 0; i < dk.size; ++i) {
            int k = dk.key[dk.idx[i]];
            int v = dk.value[dk.idx[i]];
            this.psum[i] = sum += (float)v * this.tree.get(k);
            this.tidx[i] = k;
        }
        return sum;
    }

    /** Cumulative {@code count * weight} over the entries of {@code dk}; fills {@link #psum} and {@link #tidx}. */
    private float build(I2ITranverseMap dk) {
        float sum = 0.0f;
        for (int i = 0; i < dk.size; ++i) {
            int k = dk.key[dk.idx[i]];
            int v = dk.value[dk.idx[i]];
            this.psum[i] = sum += (float)v * this.tree.get(k);
            this.tidx[i] = k;
        }
        return sum;
    }

    /**
     * Dispatches to the overload matching the runtime type of {@code dk}.
     *
     * @return the accumulated weight, or {@code 0.0f} (with {@link #psum}/{@link #tidx} left
     *         unchanged) when {@code dk} is none of the four known map types
     */
    private float build(TraverseHashMap dk) {
        if (dk instanceof S2STraverseMap) {
            return this.build((S2STraverseMap)dk);
        }
        if (dk instanceof S2BTraverseMap) {
            return this.build((S2BTraverseMap)dk);
        }
        if (dk instanceof S2ITraverseMap) {
            return this.build((S2ITraverseMap)dk);
        }
        if (dk instanceof I2ITranverseMap) {
            return this.build((I2ITranverseMap)dk);
        }
        return 0.0f;
    }

    /**
     * Randomly assigns a topic to every token of the partition and pushes the counts.
     * Equivalent to {@code initialize(pkey, true)}.
     */
    public Future<VoidResult> initialize(PartitionKey pkey) {
        return this.initialize(pkey, true);
    }

    /**
     * Assigns a uniformly random topic in {@code [0, K)} to every token whose word row lies in
     * {@code [pkey.getStartRow(), pkey.getEndRow())} and records it in the document's
     * topic-count map when one exists.
     *
     * @param pkey    partition whose rows are processed
     * @param update2 when {@code true}, {@link #nk} is incremented and the per-word counts are
     *                pushed to the parameter server; when {@code false} neither happens
     * @return the future of the pushed update, or {@code null} when {@code update2} is false
     */
    public Future<VoidResult> initialize(PartitionKey pkey, boolean update2) {
        int ws = pkey.getStartRow();
        int es = pkey.getEndRow();
        // Seedless constructor: a wall-clock seed hands identical random streams to every
        // call that starts within the same millisecond.
        Random rand = new Random();
        Int2IntOpenHashMap[] updates = null;
        if (update2) {
            updates = new Int2IntOpenHashMap[es - ws];
        }
        for (int w = ws; w < es; ++w) {
            if (this.data.ws[w + 1] == this.data.ws[w]) {
                continue;
            }
            if (update2) {
                updates[w - ws] = new Int2IntOpenHashMap();
            }
            for (int wi = this.data.ws[w]; wi < this.data.ws[w + 1]; ++wi) {
                int t = rand.nextInt(this.K);
                this.data.topics[this.data.dindex[wi]] = t;
                if (update2) {
                    updates[w - ws].addTo(t, 1);
                    this.nk[t]++;
                }
                int d = this.data.docs[wi];
                if (this.data.dks[d] == null) {
                    continue;
                }
                // Lock on the map itself; no concrete map subtype is assumed here.
                synchronized (this.data.dks[d]) {
                    this.data.dks[d].inc(t);
                }
            }
        }
        return update2 ? this.pushUpdates(pkey, updates) : null;
    }

    /**
     * Copies the first {@link #K} entries of {@code nk} into this sampler's own counters.
     *
     * @param nk source counters; must hold at least K entries
     * @return this sampler, for chaining
     */
    public Sampler set(long[] nk) {
        System.arraycopy(nk, 0, this.nk, 0, this.K);
        return this;
    }

    /**
     * Re-counts the current topic assignment of every token in the partition, adding each to
     * {@link #nk} and to the per-word deltas, and pushes those deltas to the parameter server.
     * {@link #nk} is incremented, not cleared, by this method.
     *
     * @param pkey partition whose rows are processed
     * @return the future of the pushed update
     */
    public Future<VoidResult> reset(PartitionKey pkey) {
        int ws = pkey.getStartRow();
        int es = pkey.getEndRow();
        Int2IntOpenHashMap[] updates = new Int2IntOpenHashMap[es - ws];
        for (int w = ws; w < es; ++w) {
            if (this.data.ws[w + 1] == this.data.ws[w]) {
                continue;
            }
            updates[w - ws] = new Int2IntOpenHashMap();
            for (int wi = this.data.ws[w]; wi < this.data.ws[w + 1]; ++wi) {
                int tt = this.data.topics[this.data.dindex[wi]];
                updates[w - ws].addTo(tt, 1);
                this.nk[tt]++;
            }
        }
        return this.pushUpdates(pkey, updates);
    }

    /** Random topic initialisation without touching {@link #nk} or the parameter server. */
    public void initForInference(PartitionKey pkey) {
        this.initialize(pkey, false);
    }

    /** One sampling pass without modifying {@link #wk}/{@link #nk} or pushing updates. */
    public void inference(PartitionKey pkey, PartCSRResult csr) {
        this.sample(pkey, csr, false);
    }
}

