/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.cf.taste.hadoop.als;

import java.io.BufferedWriter;
import java.io.IOException;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.net.URI;
import java.util.List;
import java.util.Map;
import org.apache.commons.io.Charsets;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
import org.apache.mahout.cf.taste.hadoop.als.ALS;
import org.apache.mahout.cf.taste.hadoop.als.ParallelALSFactorizationJob;
import org.apache.mahout.cf.taste.hadoop.als.RecommenderJob;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.map.OpenIntObjectHashMap;

public class FactorizationEvaluator
extends AbstractJob {
    private static final String USER_FEATURES_PATH = RecommenderJob.class.getName() + ".userFeatures";
    private static final String ITEM_FEATURES_PATH = RecommenderJob.class.getName() + ".itemFeatures";

    public static void main(String[] args) throws Exception {
        ToolRunner.run((Tool)new FactorizationEvaluator(), (String[])args);
    }

    public int run(String[] args) throws Exception {
        boolean succeeded;
        this.addInputOption();
        this.addOption("userFeatures", null, "path to the user feature matrix", true);
        this.addOption("itemFeatures", null, "path to the item feature matrix", true);
        this.addOption("usesLongIDs", null, "input contains long IDs that need to be translated");
        this.addOutputOption();
        Map<String, List<String>> parsedArgs = this.parseArguments(args);
        if (parsedArgs == null) {
            return -1;
        }
        Path errors = this.getTempPath("errors");
        Job predictRatings = this.prepareJob(this.getInputPath(), errors, TextInputFormat.class, PredictRatingsMapper.class, DoubleWritable.class, NullWritable.class, SequenceFileOutputFormat.class);
        Configuration conf = predictRatings.getConfiguration();
        conf.set(USER_FEATURES_PATH, this.getOption("userFeatures"));
        conf.set(ITEM_FEATURES_PATH, this.getOption("itemFeatures"));
        boolean usesLongIDs = Boolean.parseBoolean(this.getOption("usesLongIDs"));
        if (usesLongIDs) {
            conf.set(ParallelALSFactorizationJob.USES_LONG_IDS, String.valueOf(true));
        }
        if (!(succeeded = predictRatings.waitForCompletion(true))) {
            return -1;
        }
        FileSystem fs = FileSystem.get((URI)this.getOutputPath().toUri(), (Configuration)this.getConf());
        FSDataOutputStream outputStream = fs.create(this.getOutputPath("rmse.txt"));
        try (BufferedWriter writer = new BufferedWriter(new OutputStreamWriter((OutputStream)outputStream, Charsets.UTF_8));){
            double rmse = this.computeRmse(errors);
            writer.write(String.valueOf(rmse));
        }
        return 0;
    }

    private double computeRmse(Path errors) {
        FullRunningAverage average = new FullRunningAverage();
        for (Pair entry : new SequenceFileDirIterable(errors, PathType.LIST, PathFilters.logsCRCFilter(), this.getConf())) {
            DoubleWritable error = (DoubleWritable)entry.getFirst();
            average.addDatum(error.get() * error.get());
        }
        return Math.sqrt(average.getAverage());
    }

    public static class PredictRatingsMapper
    extends Mapper<LongWritable, Text, DoubleWritable, NullWritable> {
        private OpenIntObjectHashMap<Vector> U;
        private OpenIntObjectHashMap<Vector> M;
        private boolean usesLongIDs;
        private final DoubleWritable error = new DoubleWritable();

        protected void setup(Mapper.Context ctx) throws IOException, InterruptedException {
            Configuration conf = ctx.getConfiguration();
            Path pathToU = new Path(conf.get(USER_FEATURES_PATH));
            Path pathToM = new Path(conf.get(ITEM_FEATURES_PATH));
            this.U = ALS.readMatrixByRows(pathToU, conf);
            this.M = ALS.readMatrixByRows(pathToM, conf);
            this.usesLongIDs = conf.getBoolean(ParallelALSFactorizationJob.USES_LONG_IDS, false);
        }

        protected void map(LongWritable key, Text value, Mapper.Context ctx) throws IOException, InterruptedException {
            String[] tokens = TasteHadoopUtils.splitPrefTokens(value.toString());
            int userID = TasteHadoopUtils.readID(tokens[0], this.usesLongIDs);
            int itemID = TasteHadoopUtils.readID(tokens[1], this.usesLongIDs);
            double rating = Double.parseDouble(tokens[2]);
            if (this.U.containsKey(userID) && this.M.containsKey(itemID)) {
                double estimate = ((Vector)this.U.get(userID)).dot((Vector)this.M.get(itemID));
                this.error.set(rating - estimate);
                ctx.write((Object)this.error, (Object)NullWritable.get());
            }
        }
    }
}

