/*
 * Decompiled with CFR 0.152.
 */
package hivemall.topicmodel;

import com.google.common.base.Preconditions;
import hivemall.UDTFWithOptions;
import hivemall.annotations.VisibleForTesting;
import hivemall.topicmodel.AbstractProbabilisticTopicModel;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.io.FileUtils;
import hivemall.utils.io.NIOUtils;
import hivemall.utils.io.NioStatefulSegment;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.lang.Primitives;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.Counters;
import org.apache.hadoop.mapred.Reporter;

public abstract class ProbabilisticTopicModelBaseUDTF
extends UDTFWithOptions {
    private static final Log logger = LogFactory.getLog(ProbabilisticTopicModelBaseUDTF.class);
    public static final int DEFAULT_TOPICS = 10;
    protected int topics = 10;
    protected int iterations = 10;
    protected double eps = 0.1;
    protected int miniBatchSize = 128;
    protected String[][] miniBatch;
    protected int miniBatchCount;
    protected transient AbstractProbabilisticTopicModel model;
    protected ListObjectInspector wordCountsOI;
    protected transient NioStatefulSegment fileIO;
    protected transient ByteBuffer inputBuf;
    private float cumPerplexity;

    @Override
    protected Options getOptions() {
        Options opts = new Options();
        opts.addOption("k", "topics", true, "The number of topics [default: 10]");
        opts.addOption("iter", "iterations", true, "The maximum number of iterations [default: 10]");
        opts.addOption("eps", "epsilon", true, "Check convergence based on the difference of perplexity [default: 1E-1]");
        opts.addOption("s", "mini_batch_size", true, "Repeat model updating per mini-batch [default: 128]");
        return opts;
    }

    @Override
    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        CommandLine cl = null;
        if (argOIs.length >= 2) {
            String rawArgs = HiveUtils.getConstString(argOIs[1]);
            cl = this.parseOptions(rawArgs);
            this.topics = Primitives.parseInt(cl.getOptionValue("topics"), 10);
            this.iterations = Primitives.parseInt(cl.getOptionValue("iterations"), 10);
            if (this.iterations < 1) {
                throw new UDFArgumentException("'-iterations' must be greater than or equals to 1: " + this.iterations);
            }
            this.eps = Primitives.parseDouble(cl.getOptionValue("epsilon"), 0.1);
            this.miniBatchSize = Primitives.parseInt(cl.getOptionValue("mini_batch_size"), 128);
        }
        return cl;
    }

    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length < 1) {
            throw new UDFArgumentException("_FUNC_ takes 1 arguments: array<string> words [, const string options]");
        }
        this.wordCountsOI = HiveUtils.asListOI(argOIs[0]);
        HiveUtils.validateFeatureOI(this.wordCountsOI.getListElementObjectInspector());
        this.processOptions(argOIs);
        this.model = null;
        this.miniBatch = new String[this.miniBatchSize][];
        this.miniBatchCount = 0;
        this.cumPerplexity = 0.0f;
        ArrayList<String> fieldNames = new ArrayList<String>();
        ArrayList<Object> fieldOIs = new ArrayList<Object>();
        fieldNames.add("topic");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        fieldNames.add("word");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
        fieldNames.add("score");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    @Nonnull
    protected abstract AbstractProbabilisticTopicModel createModel();

    public void process(Object[] args) throws HiveException {
        if (this.model == null) {
            this.model = this.createModel();
        }
        Preconditions.checkArgument((args.length >= 1 ? 1 : 0) != 0);
        Object arg0 = args[0];
        if (arg0 == null) {
            return;
        }
        int length = this.wordCountsOI.getListLength(arg0);
        String[] wordCounts = new String[length];
        int j = 0;
        for (int i = 0; i < length; ++i) {
            String s;
            Object o = this.wordCountsOI.getListElement(arg0, i);
            if (o == null) {
                throw new HiveException("Given feature vector contains invalid null elements");
            }
            wordCounts[j] = s = o.toString();
            ++j;
        }
        if (j == 0) {
            return;
        }
        this.model.accumulateDocCount();
        this.update(wordCounts);
        this.recordTrainSampleToTempFile(wordCounts);
    }

    protected void recordTrainSampleToTempFile(@Nonnull String[] wordCounts) throws HiveException {
        if (this.iterations == 1) {
            return;
        }
        ByteBuffer buf = this.inputBuf;
        NioStatefulSegment dst = this.fileIO;
        if (buf == null) {
            File file;
            try {
                file = File.createTempFile("hivemall_topicmodel", ".sgmt");
                file.deleteOnExit();
                if (!file.canWrite()) {
                    throw new UDFArgumentException("Cannot write a temporary file: " + file.getAbsolutePath());
                }
                logger.info((Object)("Record training samples to a file: " + file.getAbsolutePath()));
            }
            catch (IOException ioe) {
                throw new UDFArgumentException((Throwable)ioe);
            }
            catch (Throwable e) {
                throw new UDFArgumentException(e);
            }
            this.inputBuf = buf = ByteBuffer.allocateDirect(0x100000);
            this.fileIO = dst = new NioStatefulSegment(file, false);
        }
        int wcLengthTotal = 0;
        for (String wc : wordCounts) {
            if (wc == null) continue;
            wcLengthTotal += wc.length();
        }
        int recordBytes = 4 + 4 * wordCounts.length + wcLengthTotal * 2;
        int requiredBytes = 4 + recordBytes;
        int remain = buf.remaining();
        if (remain < requiredBytes) {
            ProbabilisticTopicModelBaseUDTF.writeBuffer(buf, dst);
        }
        buf.putInt(recordBytes);
        buf.putInt(wordCounts.length);
        for (String wc : wordCounts) {
            NIOUtils.putString(wc, buf);
        }
    }

    private void update(@Nonnull String[] wordCounts) {
        this.miniBatch[this.miniBatchCount] = wordCounts;
        ++this.miniBatchCount;
        if (this.miniBatchCount == this.miniBatchSize) {
            this.train();
        }
    }

    protected void train() {
        if (this.miniBatchCount == 0) {
            return;
        }
        this.model.train(this.miniBatch);
        this.cumPerplexity += this.model.computePerplexity();
        Arrays.fill((Object[])this.miniBatch, null);
        this.miniBatchCount = 0;
    }

    private static void writeBuffer(@Nonnull ByteBuffer srcBuf, @Nonnull NioStatefulSegment dst) throws HiveException {
        srcBuf.flip();
        try {
            dst.write(srcBuf);
        }
        catch (IOException e) {
            throw new HiveException("Exception causes while writing a buffer to file", (Throwable)e);
        }
        srcBuf.clear();
    }

    public void close() throws HiveException {
        if (this.model == null) {
            logger.warn((Object)"Model is not initialized bacause no training exmples to learn. Better to revise input data.");
            return;
        }
        if (this.model.getDocCount() == 0L) {
            logger.warn((Object)"model.getDocCount() is zero because no training exmples to learn. Better to revise input data.");
            this.model = null;
            return;
        }
        this.finalizeTraining();
        this.forwardModel();
        this.model = null;
    }

    @VisibleForTesting
    void finalizeTraining() throws HiveException {
        if (this.miniBatchCount > 0) {
            this.model.train((String[][])Arrays.copyOfRange(this.miniBatch, 0, this.miniBatchCount));
        }
        if (this.iterations > 1) {
            this.runIterativeTraining(this.iterations);
        }
    }

    protected final void runIterativeTraining(@Nonnegative int iterations) throws HiveException {
        block34: {
            Reporter reporter;
            ByteBuffer buf = this.inputBuf;
            NioStatefulSegment dst = this.fileIO;
            assert (buf != null);
            assert (dst != null);
            long numTrainingExamples = this.model.getDocCount();
            long numTrain = numTrainingExamples / (long)this.miniBatchSize;
            if (numTrainingExamples % (long)this.miniBatchSize != 0L) {
                ++numTrain;
            }
            Counters.Counter iterCounter = (reporter = this.getReporter()) == null ? null : reporter.getCounter("hivemall.topicmodel.ProbabilisticTopicModel$Counter", "iteration");
            try {
                int iter;
                if (dst.getPosition() == 0L) {
                    if (buf.position() == 0) {
                        return;
                    }
                    buf.flip();
                    float perplexity = this.cumPerplexity / (float)numTrain;
                    for (iter = 2; iter <= iterations; ++iter) {
                        float perplexityPrev = perplexity;
                        this.cumPerplexity = 0.0f;
                        ProbabilisticTopicModelBaseUDTF.reportProgress(reporter);
                        ProbabilisticTopicModelBaseUDTF.setCounterValue(iterCounter, iter);
                        while (buf.remaining() > 0) {
                            int recordBytes = buf.getInt();
                            assert (recordBytes > 0) : recordBytes;
                            int wcLength = buf.getInt();
                            String[] wordCounts = new String[wcLength];
                            for (int j = 0; j < wcLength; ++j) {
                                wordCounts[j] = NIOUtils.getString(buf);
                            }
                            this.update(wordCounts);
                        }
                        buf.rewind();
                        perplexity = this.cumPerplexity / (float)numTrain;
                        logger.info((Object)("Mean perplexity over mini-batches: " + perplexity));
                        if ((double)Math.abs(perplexityPrev - perplexity) < this.eps) break;
                    }
                    logger.info((Object)("Performed " + Math.min(iter, iterations) + " iterations of " + NumberUtils.formatNumber(numTrainingExamples) + " training examples on memory (thus " + NumberUtils.formatNumber(numTrainingExamples * (long)Math.min(iter, iterations)) + " training updates in total) "));
                    break block34;
                }
                if (buf.remaining() > 0) {
                    ProbabilisticTopicModelBaseUDTF.writeBuffer(buf, dst);
                }
                try {
                    dst.flush();
                }
                catch (IOException e) {
                    throw new HiveException("Failed to flush a file: " + dst.getFile().getAbsolutePath(), (Throwable)e);
                }
                if (logger.isInfoEnabled()) {
                    File tmpFile = dst.getFile();
                    logger.info((Object)("Wrote " + numTrainingExamples + " records to a temporary file for iterative training: " + tmpFile.getAbsolutePath() + " (" + FileUtils.prettyFileSize(tmpFile) + ")"));
                }
                float perplexity = this.cumPerplexity / (float)numTrain;
                for (iter = 2; iter <= iterations; ++iter) {
                    float perplexityPrev = perplexity;
                    this.cumPerplexity = 0.0f;
                    ProbabilisticTopicModelBaseUDTF.setCounterValue(iterCounter, iter);
                    buf.clear();
                    dst.resetPosition();
                    while (true) {
                        int recordBytes;
                        int remain;
                        int bytesRead;
                        ProbabilisticTopicModelBaseUDTF.reportProgress(reporter);
                        try {
                            bytesRead = dst.read(buf);
                        }
                        catch (IOException e) {
                            throw new HiveException("Failed to read a file: " + dst.getFile().getAbsolutePath(), (Throwable)e);
                        }
                        if (bytesRead == 0) break;
                        assert (bytesRead > 0) : bytesRead;
                        buf.flip();
                        if (remain < 4) {
                            throw new HiveException("Illegal file format was detected");
                        }
                        for (remain = buf.remaining(); remain >= 4; remain -= recordBytes) {
                            int pos = buf.position();
                            recordBytes = buf.getInt();
                            if ((remain -= 4) < recordBytes) {
                                buf.position(pos);
                                break;
                            }
                            int wcLength = buf.getInt();
                            String[] wordCounts = new String[wcLength];
                            for (int j = 0; j < wcLength; ++j) {
                                wordCounts[j] = NIOUtils.getString(buf);
                            }
                            this.update(wordCounts);
                        }
                        buf.compact();
                    }
                    perplexity = this.cumPerplexity / (float)numTrain;
                    logger.info((Object)("Mean perplexity over mini-batches: " + perplexity));
                    if ((double)Math.abs(perplexityPrev - perplexity) < this.eps) break;
                }
                logger.info((Object)("Performed " + Math.min(iter, iterations) + " iterations of " + NumberUtils.formatNumber(numTrainingExamples) + " training examples on a secondary storage (thus " + NumberUtils.formatNumber(numTrainingExamples * (long)Math.min(iter, iterations)) + " training updates in total)"));
            }
            catch (Throwable e) {
                throw new HiveException("Exception caused in the iterative training", e);
            }
            finally {
                try {
                    dst.close(true);
                }
                catch (IOException e) {
                    throw new HiveException("Failed to close a file: " + dst.getFile().getAbsolutePath(), (Throwable)e);
                }
                this.inputBuf = null;
                this.fileIO = null;
            }
        }
    }

    protected void forwardModel() throws HiveException {
        IntWritable topicIdx = new IntWritable();
        Text word = new Text();
        FloatWritable score = new FloatWritable();
        Object[] forwardObjs = new Object[]{topicIdx, word, score};
        for (int k = 0; k < this.topics; ++k) {
            topicIdx.set(k);
            SortedMap<Float, List<String>> topicWords = this.model.getTopicWords(k);
            if (topicWords == null) continue;
            for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) {
                score.set(e.getKey().floatValue());
                for (String v : e.getValue()) {
                    word.set(v);
                    this.forward(forwardObjs);
                }
            }
        }
        logger.info((Object)("Forwarded topic words each of " + this.topics + " topics"));
    }

    @VisibleForTesting
    float getWordScore(String label, int k) {
        return this.model.getWordScore(label, k);
    }

    @VisibleForTesting
    SortedMap<Float, List<String>> getTopicWords(int k) {
        return this.model.getTopicWords(k);
    }

    @VisibleForTesting
    float[] getTopicDistribution(@Nonnull String[] doc) {
        return this.model.getTopicDistribution(doc);
    }
}

