/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.clustering.lda.cvb;

import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.cli2.Option;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.clustering.lda.cvb.CVB0DocInferenceMapper;
import org.apache.mahout.clustering.lda.cvb.CVB0TopicTermVectorNormalizerMapper;
import org.apache.mahout.clustering.lda.cvb.CachingCVB0Mapper;
import org.apache.mahout.clustering.lda.cvb.CachingCVB0PerplexityMapper;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
import org.apache.mahout.common.mapreduce.VectorSumReducer;
import org.apache.mahout.math.VectorWritable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Driver that learns an LDA topic model with Collapsed Variational Bayes (0th-derivative
 * approximation, "CVB0") as a chain of MapReduce jobs.
 *
 * <p>Iteration {@code i} reads the corpus plus the model written by iteration {@code i - 1}
 * ({@code <modelTempDir>/model-(i-1)}) and writes {@code <modelTempDir>/model-i}. When a test
 * fraction is configured, every {@code iteration_block_size} iterations a perplexity job writes
 * {@code <modelTempDir>/perplexity-i}. Existing {@code model-i} directories are detected on start-up,
 * so re-running with the same model temp dir continues after the last completed iteration.
 */
public class CVB0Driver extends AbstractJob {
    private static final Logger log = LoggerFactory.getLogger(CVB0Driver.class);
    public static final String NUM_TOPICS = "num_topics";
    public static final String NUM_TERMS = "num_terms";
    public static final String DOC_TOPIC_SMOOTHING = "doc_topic_smoothing";
    public static final String TERM_TOPIC_SMOOTHING = "term_topic_smoothing";
    public static final String DICTIONARY = "dictionary";
    public static final String DOC_TOPIC_OUTPUT = "doc_topic_output";
    public static final String MODEL_TEMP_DIR = "topic_model_temp_dir";
    public static final String ITERATION_BLOCK_SIZE = "iteration_block_size";
    public static final String RANDOM_SEED = "random_seed";
    public static final String TEST_SET_FRACTION = "test_set_fraction";
    public static final String NUM_TRAIN_THREADS = "num_train_threads";
    public static final String NUM_UPDATE_THREADS = "num_update_threads";
    public static final String MAX_ITERATIONS_PER_DOC = "max_doc_topic_iters";
    public static final String MODEL_WEIGHT = "prev_iter_mult";
    public static final String NUM_REDUCE_TASKS = "num_reduce_tasks";
    public static final String BACKFILL_PERPLEXITY = "backfill_perplexity";
    /** Configuration key holding the URIs of the current model's part files. */
    private static final String MODEL_PATHS = "mahout.lda.cvb.modelPath";
    private static final double DEFAULT_CONVERGENCE_DELTA = 0.0;
    private static final double DEFAULT_DOC_TOPIC_SMOOTHING = 1.0E-4;
    private static final double DEFAULT_TERM_TOPIC_SMOOTHING = 1.0E-4;
    private static final int DEFAULT_ITERATION_BLOCK_SIZE = 10;
    private static final double DEFAULT_TEST_SET_FRACTION = 0.0;
    private static final int DEFAULT_NUM_TRAIN_THREADS = 4;
    private static final int DEFAULT_NUM_UPDATE_THREADS = 1;
    private static final int DEFAULT_MAX_ITERATIONS_PER_DOC = 10;
    private static final int DEFAULT_NUM_REDUCE_TASKS = 10;

    /**
     * Command-line entry point: declares and parses the options, then delegates to
     * {@link #run(Configuration, Path, Path, int, int, double, double, int, int, double, Path, Path, Path,
     * long, float, int, int, int, int, boolean)}.
     *
     * @param args command-line arguments
     * @return {@code -1} if argument parsing fails, otherwise the result of the full run
     * @throws IllegalArgumentException if neither {@code num_terms} nor {@code dictionary} is given
     */
    @Override
    public int run(String[] args) throws Exception {
        addInputOption();
        addOutputOption();
        addOption(DefaultOptionCreator.maxIterationsOption().create());
        addOption("convergenceDelta", "cd", "The convergence delta value",
                String.valueOf(DEFAULT_CONVERGENCE_DELTA));
        addOption(DefaultOptionCreator.overwriteOption().create());
        addOption(NUM_TOPICS, "k", "Number of topics to learn", true);
        addOption(NUM_TERMS, "nt", "Vocabulary size", false);
        addOption(DOC_TOPIC_SMOOTHING, "a", "Smoothing for document/topic distribution",
                String.valueOf(DEFAULT_DOC_TOPIC_SMOOTHING));
        addOption(TERM_TOPIC_SMOOTHING, "e", "Smoothing for topic/term distribution",
                String.valueOf(DEFAULT_TERM_TOPIC_SMOOTHING));
        addOption(DICTIONARY, "dict", "Path to term-dictionary file(s) (glob expression supported)", false);
        addOption(DOC_TOPIC_OUTPUT, "dt", "Output path for the training doc/topic distribution", false);
        addOption(MODEL_TEMP_DIR, "mt", "Path to intermediate model path (useful for restarting)", false);
        addOption(ITERATION_BLOCK_SIZE, "block", "Number of iterations per perplexity check",
                String.valueOf(DEFAULT_ITERATION_BLOCK_SIZE));
        addOption(RANDOM_SEED, "seed", "Random seed", false);
        addOption(TEST_SET_FRACTION, "tf", "Fraction of data to hold out for testing",
                String.valueOf(DEFAULT_TEST_SET_FRACTION));
        addOption(NUM_TRAIN_THREADS, "ntt", "number of threads per mapper to train with",
                String.valueOf(DEFAULT_NUM_TRAIN_THREADS));
        addOption(NUM_UPDATE_THREADS, "nut", "number of threads per mapper to update the model with",
                String.valueOf(DEFAULT_NUM_UPDATE_THREADS));
        addOption(MAX_ITERATIONS_PER_DOC, "mipd", "max number of iterations per doc for p(topic|doc) learning",
                String.valueOf(DEFAULT_MAX_ITERATIONS_PER_DOC));
        addOption(NUM_REDUCE_TASKS, null, "number of reducers to use during model estimation",
                String.valueOf(DEFAULT_NUM_REDUCE_TASKS));
        addOption(buildOption(BACKFILL_PERPLEXITY, null, "enable backfilling of missing perplexity values",
                false, false, null));
        if (parseArguments(args) == null) {
            return -1;
        }

        int numTopics = Integer.parseInt(getOption(NUM_TOPICS));
        Path inputPath = getInputPath();
        Path topicModelOutputPath = getOutputPath();
        int maxIterations = Integer.parseInt(getOption("maxIter"));
        int iterationBlockSize = Integer.parseInt(getOption(ITERATION_BLOCK_SIZE));
        double convergenceDelta = Double.parseDouble(getOption("convergenceDelta"));
        double alpha = Double.parseDouble(getOption(DOC_TOPIC_SMOOTHING));
        double eta = Double.parseDouble(getOption(TERM_TOPIC_SMOOTHING));
        int numTrainThreads = Integer.parseInt(getOption(NUM_TRAIN_THREADS));
        int numUpdateThreads = Integer.parseInt(getOption(NUM_UPDATE_THREADS));
        int maxItersPerDoc = Integer.parseInt(getOption(MAX_ITERATIONS_PER_DOC));
        Path dictionaryPath = hasOption(DICTIONARY) ? new Path(getOption(DICTIONARY)) : null;
        // Without an explicit vocabulary size, derive it from the dictionary (validated in getNumTerms).
        int numTerms = hasOption(NUM_TERMS)
                ? Integer.parseInt(getOption(NUM_TERMS))
                : getNumTerms(getConf(), dictionaryPath);
        Path docTopicOutputPath = hasOption(DOC_TOPIC_OUTPUT) ? new Path(getOption(DOC_TOPIC_OUTPUT)) : null;
        Path modelTempPath = hasOption(MODEL_TEMP_DIR)
                ? new Path(getOption(MODEL_TEMP_DIR))
                : getTempPath("topicModelState");
        long seed = hasOption(RANDOM_SEED)
                ? Long.parseLong(getOption(RANDOM_SEED))
                : System.nanoTime() % 10000L;
        float testFraction = hasOption(TEST_SET_FRACTION)
                ? Float.parseFloat(getOption(TEST_SET_FRACTION))
                : (float) DEFAULT_TEST_SET_FRACTION;
        int numReduceTasks = Integer.parseInt(getOption(NUM_REDUCE_TASKS));
        boolean backfillPerplexity = hasOption(BACKFILL_PERPLEXITY);
        return run(getConf(), inputPath, topicModelOutputPath, numTopics, numTerms, alpha, eta, maxIterations,
                iterationBlockSize, convergenceDelta, dictionaryPath, docTopicOutputPath, modelTempPath, seed,
                testFraction, numTrainThreads, numUpdateThreads, maxItersPerDoc, numReduceTasks,
                backfillPerplexity);
    }

    /**
     * Derives the vocabulary size from a term dictionary: one more than the largest term id stored as
     * the {@code IntWritable} value of any {@code (Text, IntWritable)} record in the matching files.
     *
     * @param conf           Hadoop configuration used to resolve the file system
     * @param dictionaryPath dictionary file path or glob; must not be {@code null}
     * @return the largest term id plus one, or {@code 0} if no records are found
     * @throws IllegalArgumentException if {@code dictionaryPath} is {@code null}
     * @throws IOException              if the path does not exist or cannot be read
     */
    private static int getNumTerms(Configuration conf, Path dictionaryPath) throws IOException {
        Preconditions.checkArgument(dictionaryPath != null,
                "Either the '" + NUM_TERMS + "' or the '" + DICTIONARY + "' option must be specified");
        FileSystem fs = dictionaryPath.getFileSystem(conf);
        FileStatus[] dictionaryFiles = fs.globStatus(dictionaryPath);
        if (dictionaryFiles == null) {
            // globStatus returns null (not an empty array) for a non-glob path that does not exist.
            throw new IOException("Dictionary path does not exist: " + dictionaryPath);
        }
        Text key = new Text();
        IntWritable value = new IntWritable();
        int maxTermId = -1;
        for (FileStatus stat : dictionaryFiles) {
            SequenceFile.Reader reader = new SequenceFile.Reader(fs, stat.getPath(), conf);
            try {
                while (reader.next(key, value)) {
                    maxTermId = Math.max(maxTermId, value.get());
                }
            } finally {
                reader.close();
            }
        }
        return maxTermId + 1;
    }

    /**
     * Runs (or resumes) CVB0 training. Afterwards it optionally writes the normalized topic model and
     * the per-document topic inference.
     *
     * <p>Training stops after {@code maxIterations} iterations, or earlier when
     * {@code convergenceDelta > 0} and the relative change between the last two recorded perplexities
     * drops below it.
     *
     * @return {@code 0} on success, {@code -1} if one of the final output jobs fails
     * @throws IllegalArgumentException if {@code testFraction} is outside [0, 1], or if
     *                                  {@code backfillPerplexity} is set while {@code testFraction} is 0
     * @throws InterruptedException     also thrown when a training or perplexity job fails
     */
    public int run(Configuration conf, Path inputPath, Path topicModelOutputPath, int numTopics, int numTerms,
            double alpha, double eta, int maxIterations, int iterationBlockSize, double convergenceDelta,
            Path dictionaryPath, Path docTopicOutputPath, Path topicModelStateTempPath, long randomSeed,
            float testFraction, int numTrainThreads, int numUpdateThreads, int maxItersPerDoc,
            int numReduceTasks, boolean backfillPerplexity)
            throws ClassNotFoundException, IOException, InterruptedException {
        setConf(conf);
        Preconditions.checkArgument(testFraction >= 0.0 && testFraction <= 1.0,
                "Expected 'testFraction' value in range [0, 1] but found value '%s'", testFraction);
        // Backfilling needs a held-out test set to compute perplexity on.
        Preconditions.checkArgument(!backfillPerplexity || testFraction > 0.0,
                "Expected 'testFraction' value in range (0, 1] but found value '%s'", testFraction);

        String infoString = "Will run Collapsed Variational Bayes (0th-derivative approximation) learning for LDA on {} (numTerms: {}), finding {}-topics, with document/topic prior {}, topic/term prior {}.  Maximum iterations to run will be {}, unless the change in perplexity is less than {}.  Topic model output (p(term|topic) for each topic) will be stored {}.  Random initialization seed is {}, holding out {} of the data for perplexity check\n";
        log.info(infoString, new Object[] {inputPath, numTerms, numTopics, alpha, eta, maxIterations,
                convergenceDelta, topicModelOutputPath, randomSeed, testFraction});
        infoString = dictionaryPath == null
                ? ""
                : "Dictionary to be used located " + dictionaryPath.toString() + '\n';
        infoString = infoString + (docTopicOutputPath == null
                ? ""
                : "p(topic|docId) will be stored " + docTopicOutputPath.toString() + '\n');
        log.info(infoString);

        int iterationNumber = getCurrentIterationNumber(conf, topicModelStateTempPath, maxIterations);
        log.info("Current iteration number: {}", iterationNumber);

        // The mappers read these parameters from the job configuration.
        conf.set(NUM_TOPICS, String.valueOf(numTopics));
        conf.set(NUM_TERMS, String.valueOf(numTerms));
        conf.set(DOC_TOPIC_SMOOTHING, String.valueOf(alpha));
        conf.set(TERM_TOPIC_SMOOTHING, String.valueOf(eta));
        conf.set(RANDOM_SEED, String.valueOf(randomSeed));
        conf.set(NUM_TRAIN_THREADS, String.valueOf(numTrainThreads));
        conf.set(NUM_UPDATE_THREADS, String.valueOf(numUpdateThreads));
        conf.set(MAX_ITERATIONS_PER_DOC, String.valueOf(maxItersPerDoc));
        conf.set(MODEL_WEIGHT, "1");
        conf.set(TEST_SET_FRACTION, String.valueOf(testFraction));

        List<Double> perplexities = readOrBackfillPerplexities(conf, inputPath, topicModelStateTempPath,
                iterationNumber, iterationBlockSize, backfillPerplexity);

        long startTime = System.currentTimeMillis();
        while (iterationNumber < maxIterations) {
            if (convergenceDelta > 0.0) {
                double delta = rateOfChange(perplexities);
                if (delta < convergenceDelta) {
                    log.info("Convergence achieved at iteration {} with perplexity {} and delta {}",
                            new Object[] {iterationNumber, perplexities.get(perplexities.size() - 1), delta});
                    break;
                }
            }
            iterationNumber++;
            log.info("About to run iteration {} of {}", iterationNumber, maxIterations);
            Path modelInputPath = modelPath(topicModelStateTempPath, iterationNumber - 1);
            Path modelOutputPath = modelPath(topicModelStateTempPath, iterationNumber);
            runIteration(conf, inputPath, modelInputPath, modelOutputPath, iterationNumber, maxIterations,
                    numReduceTasks);
            if (testFraction > 0.0f && iterationNumber % iterationBlockSize == 0) {
                perplexities.add(calculatePerplexity(conf, inputPath, modelOutputPath, iterationNumber));
                log.info("Current perplexity = {}", perplexities.get(perplexities.size() - 1));
                log.info("(p_{} - p_{}) / p_0 = {}; target = {}", new Object[] {iterationNumber,
                        iterationNumber - iterationBlockSize, rateOfChange(perplexities), convergenceDelta});
            }
        }
        log.info("Completed {} iterations in {} seconds", iterationNumber,
                (System.currentTimeMillis() - startTime) / 1000L);
        log.info("Perplexities: ({})", Joiner.on(", ").join(perplexities));

        // Both output jobs are submitted before waiting on either, so they can run concurrently.
        Path finalIterationData = modelPath(topicModelStateTempPath, iterationNumber);
        Job topicModelOutputJob = topicModelOutputPath != null
                ? writeTopicModel(conf, finalIterationData, topicModelOutputPath)
                : null;
        Job docInferenceJob = docTopicOutputPath != null
                ? writeDocTopicInference(conf, inputPath, finalIterationData, docTopicOutputPath)
                : null;
        if (topicModelOutputJob != null && !topicModelOutputJob.waitForCompletion(true)) {
            return -1;
        }
        if (docInferenceJob != null && !docInferenceJob.waitForCompletion(true)) {
            return -1;
        }
        return 0;
    }

    /**
     * Collects the perplexities already recorded for iterations {@code 1..lastIteration}. Iterations
     * with no recorded perplexity are skipped, except on block boundaries when
     * {@code backfillPerplexity} is set and that iteration's model still exists: those are recomputed.
     *
     * @return a mutable list of perplexities in iteration order
     */
    private List<Double> readOrBackfillPerplexities(Configuration conf, Path inputPath,
            Path topicModelStateTempPath, int lastIteration, int iterationBlockSize, boolean backfillPerplexity)
            throws IOException, ClassNotFoundException, InterruptedException {
        FileSystem fs = FileSystem.get(topicModelStateTempPath.toUri(), conf);
        List<Double> perplexities = new ArrayList<Double>();
        for (int i = 1; i <= lastIteration; ++i) {
            double perplexity = readPerplexity(conf, topicModelStateTempPath, i);
            if (Double.isNaN(perplexity)) {
                if (!backfillPerplexity || i % iterationBlockSize != 0) {
                    continue;
                }
                log.info("Backfilling perplexity at iteration {}", i);
                Path iterationModelPath = modelPath(topicModelStateTempPath, i);
                if (!fs.exists(iterationModelPath)) {
                    log.error("Model path '{}' does not exist; Skipping iteration {} perplexity calculation",
                            iterationModelPath.toString(), i);
                    continue;
                }
                perplexity = calculatePerplexity(conf, inputPath, iterationModelPath, i);
            }
            perplexities.add(perplexity);
            log.info("Perplexity at iteration {} = {}", i, perplexity);
        }
        return perplexities;
    }

    /**
     * Computes the relative change between the last two perplexities, normalized by the first one:
     * {@code |p[n-1] - p[n-2]| / p[0]}.
     *
     * @return the relative change, or {@link Double#MAX_VALUE} if fewer than two values are available
     */
    private static double rateOfChange(List<Double> perplexities) {
        int sz = perplexities.size();
        if (sz < 2) {
            return Double.MAX_VALUE;
        }
        return Math.abs(perplexities.get(sz - 1) - perplexities.get(sz - 2)) / perplexities.get(0);
    }

    /**
     * Runs a single-reducer job that evaluates perplexity of {@code modelPath} over the corpus. The
     * result goes to the {@code perplexity-<iteration>} directory next to the model and is read back.
     *
     * @throws InterruptedException if the perplexity job fails
     */
    private double calculatePerplexity(Configuration conf, Path corpusPath, Path modelPath, int iteration)
            throws IOException, ClassNotFoundException, InterruptedException {
        String jobName = "Calculating perplexity for " + modelPath;
        log.info("About to run: {}", jobName);
        Path outputPath = perplexityPath(modelPath.getParent(), iteration);
        Job job = prepareJob(corpusPath, outputPath, CachingCVB0PerplexityMapper.class, DoubleWritable.class,
                DoubleWritable.class, DualDoubleSumReducer.class, DoubleWritable.class, DoubleWritable.class);
        job.setJobName(jobName);
        job.setCombinerClass(DualDoubleSumReducer.class);
        // A single reducer produces one (total key, total value) record.
        job.setNumReduceTasks(1);
        setModelPaths(job, modelPath);
        HadoopUtil.delete(conf, outputPath);
        if (!job.waitForCompletion(true)) {
            throw new InterruptedException("Failed to calculate perplexity for: " + modelPath);
        }
        return readPerplexity(conf, modelPath.getParent(), iteration);
    }

    /**
     * Reads the perplexity recorded for an iteration. The result is the sum of the values divided by
     * the sum of the keys over all {@code (DoubleWritable, DoubleWritable)} records in the part files.
     *
     * @return the weighted perplexity, or {@code NaN} if the perplexity directory does not exist
     */
    public static double readPerplexity(Configuration conf, Path topicModelStateTemp, int iteration)
            throws IOException {
        Path perplexityPath = perplexityPath(topicModelStateTemp, iteration);
        FileSystem fs = FileSystem.get(perplexityPath.toUri(), conf);
        if (!fs.exists(perplexityPath)) {
            log.warn("Perplexity path {} does not exist, returning NaN", perplexityPath);
            return Double.NaN;
        }
        double perplexity = 0.0;
        double modelWeight = 0.0;
        long n = 0L;
        for (Pair<DoubleWritable, DoubleWritable> pair : new SequenceFileDirIterable<DoubleWritable, DoubleWritable>(
                perplexityPath, PathType.LIST, PathFilters.partFilter(), null, true, conf)) {
            modelWeight += pair.getFirst().get();
            perplexity += pair.getSecond().get();
            ++n;
        }
        log.info("Read {} entries with total perplexity {} and model weight {}",
                new Object[] {n, perplexity, modelWeight});
        return perplexity / modelWeight;
    }

    /**
     * Submits (without waiting for) a map-only job that normalizes the topic/term vectors of
     * {@code modelInput} into {@code output}.
     *
     * @return the submitted job
     */
    private Job writeTopicModel(Configuration conf, Path modelInput, Path output)
            throws IOException, InterruptedException, ClassNotFoundException {
        String jobName = String.format("Writing final topic/term distributions from %s to %s", modelInput, output);
        log.info("About to run: {}", jobName);
        Job job = prepareJob(modelInput, output, SequenceFileInputFormat.class,
                CVB0TopicTermVectorNormalizerMapper.class, IntWritable.class, VectorWritable.class,
                SequenceFileOutputFormat.class, jobName);
        job.submit();
        return job;
    }

    /**
     * Submits (without waiting for) a job that infers p(topic|doc) for every document in
     * {@code corpus} using the model in {@code modelInput}. The output is written to {@code output}.
     *
     * @return the submitted job
     */
    private Job writeDocTopicInference(Configuration conf, Path corpus, Path modelInput, Path output)
            throws IOException, ClassNotFoundException, InterruptedException {
        String jobName = String.format("Writing final document/topic inference from %s to %s", corpus, output);
        log.info("About to run: {}", jobName);
        Job job = prepareJob(corpus, output, SequenceFileInputFormat.class, CVB0DocInferenceMapper.class,
                IntWritable.class, VectorWritable.class, SequenceFileOutputFormat.class, jobName);
        if (modelInput != null) {
            // Resolve the model's own file system; it need not be the corpus's.
            FileSystem fs = modelInput.getFileSystem(conf);
            if (fs.exists(modelInput)) {
                FileStatus[] statuses = fs.listStatus(modelInput, PathFilters.partFilter());
                URI[] modelUris = new URI[statuses.length];
                for (int i = 0; i < statuses.length; ++i) {
                    modelUris[i] = statuses[i].getPath().toUri();
                }
                // A Hadoop Job works on its own copy of the Configuration it was created from, so the
                // cache files must be registered on the job's configuration to have any effect.
                DistributedCache.setCacheFiles(modelUris, job.getConfiguration());
                setModelPaths(job, modelInput);
            }
        }
        job.submit();
        return job;
    }

    /** Returns the directory holding the model produced by the given iteration. */
    public static Path modelPath(Path topicModelStateTempPath, int iterationNumber) {
        return new Path(topicModelStateTempPath, "model-" + iterationNumber);
    }

    /** Returns the directory holding the perplexity computed for the given iteration. */
    public static Path perplexityPath(Path topicModelStateTempPath, int iterationNumber) {
        return new Path(topicModelStateTempPath, "perplexity-" + iterationNumber);
    }

    /**
     * Returns the number of consecutive model directories {@code model-1, model-2, ...} that already
     * exist in {@code modelTempDir}, capped at {@code maxIterations}. Returns {@code 0} for a fresh run.
     */
    private static int getCurrentIterationNumber(Configuration config, Path modelTempDir, int maxIterations)
            throws IOException {
        FileSystem fs = FileSystem.get(modelTempDir.toUri(), config);
        int iterationNumber = 1;
        Path iterationPath = modelPath(modelTempDir, iterationNumber);
        while (fs.exists(iterationPath) && iterationNumber <= maxIterations) {
            log.info("Found previous state: {}", iterationPath);
            iterationPath = modelPath(modelTempDir, ++iterationNumber);
        }
        return iterationNumber - 1;
    }

    /**
     * Runs one training iteration. It reads the corpus with the model from {@code modelInput} (if any)
     * and writes the summed model vectors to {@code modelOutput}, replacing any previous contents.
     *
     * @throws InterruptedException if the iteration job fails
     */
    public void runIteration(Configuration conf, Path corpusInput, Path modelInput, Path modelOutput,
            int iterationNumber, int maxIterations, int numReduceTasks)
            throws IOException, ClassNotFoundException, InterruptedException {
        String jobName = String.format("Iteration %d of %d, input path: %s", iterationNumber, maxIterations,
                modelInput);
        log.info("About to run: {}", jobName);
        Job job = prepareJob(corpusInput, modelOutput, CachingCVB0Mapper.class, IntWritable.class,
                VectorWritable.class, VectorSumReducer.class, IntWritable.class, VectorWritable.class);
        job.setCombinerClass(VectorSumReducer.class);
        job.setNumReduceTasks(numReduceTasks);
        job.setJobName(jobName);
        setModelPaths(job, modelInput);
        HadoopUtil.delete(conf, modelOutput);
        if (!job.waitForCompletion(true)) {
            throw new InterruptedException(String.format("Failed to complete iteration %d stage 1", iterationNumber));
        }
    }

    /**
     * Stores the URIs of the part files under {@code modelPath} in the job configuration under
     * {@link #MODEL_PATHS}. Does nothing if {@code modelPath} is {@code null} or does not exist.
     *
     * @throws IllegalStateException if the model directory contains no part files
     */
    private static void setModelPaths(Job job, Path modelPath) throws IOException {
        Configuration conf = job.getConfiguration();
        if (modelPath == null) {
            return;
        }
        FileSystem fs = FileSystem.get(modelPath.toUri(), conf);
        if (!fs.exists(modelPath)) {
            return;
        }
        FileStatus[] statuses = fs.listStatus(modelPath, PathFilters.partFilter());
        Preconditions.checkState(statuses.length > 0, "No part files found in model path '%s'",
                modelPath.toString());
        String[] modelPaths = new String[statuses.length];
        for (int i = 0; i < statuses.length; ++i) {
            modelPaths[i] = statuses[i].getPath().toUri().toString();
        }
        conf.setStrings(MODEL_PATHS, modelPaths);
    }

    /**
     * Returns the model part-file paths previously stored by {@code setModelPaths}, or {@code null} if
     * none are configured.
     */
    public static Path[] getModelPaths(Configuration conf) {
        String[] modelPathNames = conf.getStrings(MODEL_PATHS);
        if (modelPathNames == null || modelPathNames.length == 0) {
            return null;
        }
        Path[] modelPaths = new Path[modelPathNames.length];
        for (int i = 0; i < modelPathNames.length; ++i) {
            modelPaths[i] = new Path(modelPathNames[i]);
        }
        return modelPaths;
    }

    public static void main(String[] args) throws Exception {
        ToolRunner.run(new Configuration(), new CVB0Driver(), args);
    }

    /**
     * Reducer (also used as the combiner) that emits a single record: the sum of the keys of all input
     * records and the sum of all their values.
     */
    public static class DualDoubleSumReducer
            extends Reducer<DoubleWritable, DoubleWritable, DoubleWritable, DoubleWritable> {
        private final DoubleWritable outKey = new DoubleWritable();
        private final DoubleWritable outValue = new DoubleWritable();

        @Override
        public void run(Context context) throws IOException, InterruptedException {
            double keySum = 0.0;
            double valueSum = 0.0;
            while (context.nextKey()) {
                // Records with equal keys are grouped into one call, so the key is added once per
                // value (record), not once per group; otherwise duplicate keys would be undercounted.
                double key = context.getCurrentKey().get();
                for (DoubleWritable value : context.getValues()) {
                    keySum += key;
                    valueSum += value.get();
                }
            }
            outKey.set(keySum);
            outValue.set(valueSum);
            context.write(outKey, outValue);
        }
    }
}

