/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.angel.ml.lda.psf;

import com.tencent.angel.PartitionKey;
import com.tencent.angel.exception.AngelException;
import com.tencent.angel.ml.lda.psf.LikelihoodParam;
import com.tencent.angel.ml.matrix.psf.aggr.enhance.ScalarAggrResult;
import com.tencent.angel.ml.matrix.psf.aggr.enhance.ScalarPartitionAggrResult;
import com.tencent.angel.ml.matrix.psf.get.base.GetFunc;
import com.tencent.angel.ml.matrix.psf.get.base.GetParam;
import com.tencent.angel.ml.matrix.psf.get.base.GetResult;
import com.tencent.angel.ml.matrix.psf.get.base.PartitionGetParam;
import com.tencent.angel.ml.matrix.psf.get.base.PartitionGetResult;
import com.tencent.angel.ps.storage.vector.ServerIntIntRow;
import com.tencent.angel.ps.storage.vector.ServerRow;
import com.tencent.angel.ps.storage.vector.ServerRowUtils;
import it.unimi.dsi.fastutil.ints.Int2IntMap;
import it.unimi.dsi.fastutil.objects.ObjectIterator;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.commons.math.special.Gamma;

/**
 * PS get-function that sums {@code logGamma(v + beta) - logGamma(beta)} over every strictly
 * positive entry {@code v} of every row of an int-valued matrix.
 *
 * <p>Each server partition computes its own partial sum in {@link #partitionGet}. The client
 * then adds the partial sums together in {@link #merge}. NOTE(review): given the package this is
 * presumably the word-topic term of the LDA log-likelihood — confirm against the caller.
 */
public class LikelihoodFunc
extends GetFunc {
    private static final Log LOG = LogFactory.getLog(LikelihoodFunc.class);

    /**
     * Creates the function for the given matrix.
     *
     * @param matrixId id of the PS matrix whose int rows are summed
     * @param beta     value added to every positive entry before taking log-gamma
     */
    public LikelihoodFunc(int matrixId, float beta) {
        super(new LikelihoodParam(matrixId, beta));
    }

    /**
     * No-arg constructor; leaves the parameter {@code null}. NOTE(review): presumably needed
     * for reflective instantiation on the server side — confirm.
     */
    public LikelihoodFunc() {
        super(null);
    }

    /**
     * Computes this partition's partial sum: for every row in the partition's row range and
     * every entry {@code v > 0} of that row, adds {@code logGamma(v + beta) - logGamma(beta)}.
     *
     * @param partParam must be a {@link LikelihoodParam.LikelihoodPartParam}
     * @return a {@link ScalarPartitionAggrResult} holding the partition's sum
     * @throws AngelException if {@code partParam} has the wrong type, or a row is neither
     *                        dense nor sparse
     */
    public PartitionGetResult partitionGet(PartitionGetParam partParam) {
        // Validate the parameter type before doing any metadata lookups.
        if (!(partParam instanceof LikelihoodParam.LikelihoodPartParam)) {
            throw new AngelException("Should be LikelihoodParam.LikelihoodPartParam");
        }
        LikelihoodParam.LikelihoodPartParam param = (LikelihoodParam.LikelihoodPartParam) partParam;

        // Use the partition key recorded in the server-side matrix metadata rather than the
        // one carried by the request.
        PartitionKey requestKey = partParam.getPartKey();
        PartitionKey pkey = this.psContext.getMatrixMetaManager()
                .getMatrixMeta(requestKey.getMatrixId())
                .getPartitionMeta(requestKey.getPartitionId())
                .getPartitionKey();
        int startRow = pkey.getStartRow();
        int endRow = pkey.getEndRow();

        float beta = param.getBeta();
        double lgammaBeta = Gamma.logGamma((double) beta);
        double ll = 0.0;
        for (int w = startRow; w < endRow; ++w) {
            ServerRow row = this.psContext.getMatrixStorageManager().getRow(pkey, w);
            // startRead()/endRead() are a paired bracket (presumably a read lock). Enter it
            // before the try so endRead() only runs if startRead() actually succeeded.
            row.startRead();
            try {
                ll += this.likelihood((ServerIntIntRow) row, beta, lgammaBeta);
            } finally {
                row.endRead();
            }
        }
        return new ScalarPartitionAggrResult(ll);
    }

    /**
     * Sums {@code logGamma(v + beta) - logGamma(beta)} over the positive entries of one row.
     *
     * @param row        the row to read; the caller is expected to hold its read bracket
     * @param beta       value added to each positive entry
     * @param lgammaBeta precomputed {@code logGamma(beta)}
     * @return the row's contribution (0 if it has no positive entries)
     * @throws AngelException if the row is neither dense nor sparse
     */
    private double likelihood(ServerIntIntRow row, float beta, double lgammaBeta) {
        int len = (int) (row.getEndCol() - row.getStartCol());
        double ll = 0.0;
        if (row.isDense()) {
            int[] values = ServerRowUtils.getVector(row).getStorage().getValues();
            for (int i = 0; i < len; ++i) {
                ll += entryTerm(values[i], beta, lgammaBeta);
            }
        } else if (row.isSparse()) {
            ObjectIterator<Int2IntMap.Entry> iterator =
                    ServerRowUtils.getVector(row).getStorage().entryIterator();
            while (iterator.hasNext()) {
                ll += entryTerm(iterator.next().getIntValue(), beta, lgammaBeta);
            }
        } else {
            throw new AngelException("should be a dense or sparse ServerIntIntRow");
        }
        return ll;
    }

    /**
     * Returns one entry's term: {@code logGamma(count + beta) - lgammaBeta} if
     * {@code count > 0}, otherwise 0.
     */
    private static double entryTerm(int count, float beta, double lgammaBeta) {
        if (count <= 0) {
            return 0.0;
        }
        // Add in double. A float sum keeps only ~7 significant digits, so beta is partly or
        // entirely lost for large counts, and counts above 2^24 are not even exact in float.
        return Gamma.logGamma(count + (double) beta) - lgammaBeta;
    }

    /**
     * Adds the per-partition partial sums into the total.
     *
     * @param partResults results produced by {@link #partitionGet}; each must be a
     *                    {@link ScalarPartitionAggrResult}
     * @return a {@link ScalarAggrResult} holding the total
     */
    public GetResult merge(List<PartitionGetResult> partResults) {
        double ll = 0.0;
        for (PartitionGetResult r : partResults) {
            ll += ((ScalarPartitionAggrResult) r).result;
        }
        return new ScalarAggrResult(ll);
    }
}

