/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.sgd;

import com.google.common.io.Resources;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.Writer;
import java.util.ArrayList;
import java.util.Locale;
import java.util.Set;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.commons.cli2.option.DefaultOption;
import org.apache.commons.cli2.util.HelpFormatter;
import org.apache.commons.io.Charsets;
import org.apache.mahout.classifier.sgd.AdaptiveLogisticModelParameters;
import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression;
import org.apache.mahout.classifier.sgd.CrossFoldLearner;
import org.apache.mahout.classifier.sgd.CsvRecordFactory;
import org.apache.mahout.classifier.sgd.LogisticModelParameters;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.classifier.sgd.RecordFactory;
import org.apache.mahout.ep.State;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;

public final class TrainAdaptiveLogistic {
    private static String inputFile;
    private static String outputFile;
    private static AdaptiveLogisticModelParameters lmp;
    private static int passes;
    private static boolean showperf;
    private static int skipperfnum;
    private static AdaptiveLogisticRegression model;

    private TrainAdaptiveLogistic() {
    }

    public static void main(String[] args) throws Exception {
        TrainAdaptiveLogistic.mainToOutput(args, new PrintWriter((Writer)new OutputStreamWriter((OutputStream)System.out, Charsets.UTF_8), true));
    }

    static void mainToOutput(String[] args, PrintWriter output) throws Exception {
        if (TrainAdaptiveLogistic.parseArgs(args)) {
            State best;
            Object in;
            CsvRecordFactory csv = lmp.getCsvRecordFactory();
            model = lmp.createAdaptiveLogisticRegression();
            CrossFoldLearner learner = null;
            int k = 0;
            for (int pass = 0; pass < passes; ++pass) {
                in = TrainAdaptiveLogistic.open(inputFile);
                csv.firstLine(((BufferedReader)in).readLine());
                Object line = ((BufferedReader)in).readLine();
                while (line != null) {
                    RandomAccessSparseVector input = new RandomAccessSparseVector(lmp.getNumFeatures());
                    int targetValue = csv.processLine((String)line, (Vector)input);
                    model.train(targetValue, (Vector)input);
                    if (showperf && ++k % (skipperfnum + 1) == 0) {
                        best = model.getBest();
                        if (best != null) {
                            learner = ((AdaptiveLogisticRegression.Wrapper)best.getPayload()).getLearner();
                        }
                        if (learner != null) {
                            double averageCorrect = learner.percentCorrect();
                            double averageLL = learner.logLikelihood();
                            output.printf("%d\t%.3f\t%.2f%n", k, averageLL, averageCorrect * 100.0);
                        } else {
                            output.printf(Locale.ENGLISH, "%10d %2d %s%n", k, targetValue, "AdaptiveLogisticRegression has not found a good model ......");
                        }
                    }
                    line = ((BufferedReader)in).readLine();
                }
                ((BufferedReader)in).close();
            }
            best = model.getBest();
            if (best != null) {
                learner = ((AdaptiveLogisticRegression.Wrapper)best.getPayload()).getLearner();
            }
            if (learner == null) {
                output.println("AdaptiveLogisticRegression has failed to train a model.");
                return;
            }
            FileOutputStream modelOutput = new FileOutputStream(outputFile);
            in = null;
            try {
                lmp.saveTo(modelOutput);
            }
            catch (Throwable line) {
                in = line;
                throw line;
            }
            finally {
                if (modelOutput != null) {
                    if (in != null) {
                        try {
                            ((OutputStream)modelOutput).close();
                        }
                        catch (Throwable line) {
                            ((Throwable)in).addSuppressed(line);
                        }
                    } else {
                        ((OutputStream)modelOutput).close();
                    }
                }
            }
            OnlineLogisticRegression lr = (OnlineLogisticRegression)learner.getModels().get(0);
            output.println(lmp.getNumFeatures());
            output.println(lmp.getTargetVariable() + " ~ ");
            String sep = "";
            for (String v : csv.getTraceDictionary().keySet()) {
                double weight = TrainAdaptiveLogistic.predictorWeight(lr, 0, (RecordFactory)csv, v);
                if (weight == 0.0) continue;
                output.printf(Locale.ENGLISH, "%s%.3f*%s", sep, weight, v);
                sep = " + ";
            }
            output.printf("%n", new Object[0]);
            for (int row = 0; row < lr.getBeta().numRows(); ++row) {
                for (String key : csv.getTraceDictionary().keySet()) {
                    double weight = TrainAdaptiveLogistic.predictorWeight(lr, row, (RecordFactory)csv, key);
                    if (weight == 0.0) continue;
                    output.printf(Locale.ENGLISH, "%20s %.5f%n", key, weight);
                }
                for (int column = 0; column < lr.getBeta().numCols(); ++column) {
                    output.printf(Locale.ENGLISH, "%15.9f ", lr.getBeta().get(row, column));
                }
                output.println();
            }
        }
    }

    private static double predictorWeight(OnlineLogisticRegression lr, int row, RecordFactory csv, String predictor) {
        double weight = 0.0;
        for (Integer column : (Set)csv.getTraceDictionary().get(predictor)) {
            weight += lr.getBeta().get(row, column.intValue());
        }
        return weight;
    }

    private static boolean parseArgs(String[] args) {
        DefaultOptionBuilder builder = new DefaultOptionBuilder();
        DefaultOption help = builder.withLongName("help").withDescription("print this list").create();
        DefaultOption quiet = builder.withLongName("quiet").withDescription("be extra quiet").create();
        ArgumentBuilder argumentBuilder = new ArgumentBuilder();
        DefaultOption showperf = builder.withLongName("showperf").withDescription("output performance measures during training").create();
        DefaultOption inputFile = builder.withLongName("input").withRequired(true).withArgument(argumentBuilder.withName("input").withMaximum(1).create()).withDescription("where to get training data").create();
        DefaultOption outputFile = builder.withLongName("output").withRequired(true).withArgument(argumentBuilder.withName("output").withMaximum(1).create()).withDescription("where to write the model content").create();
        DefaultOption threads = builder.withLongName("threads").withArgument(argumentBuilder.withName("threads").withDefault((Object)"4").create()).withDescription("the number of threads AdaptiveLogisticRegression uses").create();
        DefaultOption predictors = builder.withLongName("predictors").withRequired(true).withArgument(argumentBuilder.withName("predictors").create()).withDescription("a list of predictor variables").create();
        DefaultOption types = builder.withLongName("types").withRequired(true).withArgument(argumentBuilder.withName("types").create()).withDescription("a list of predictor variable types (numeric, word, or text)").create();
        DefaultOption target = builder.withLongName("target").withDescription("the name of the target variable").withRequired(true).withArgument(argumentBuilder.withName("target").withMaximum(1).create()).create();
        DefaultOption targetCategories = builder.withLongName("categories").withDescription("the number of target categories to be considered").withRequired(true).withArgument(argumentBuilder.withName("categories").withMaximum(1).create()).create();
        DefaultOption features = builder.withLongName("features").withDescription("the number of internal hashed features to use").withArgument(argumentBuilder.withName("numFeatures").withDefault((Object)"1000").withMaximum(1).create()).create();
        DefaultOption passes = builder.withLongName("passes").withDescription("the number of times to pass over the input data").withArgument(argumentBuilder.withName("passes").withDefault((Object)"2").withMaximum(1).create()).create();
        DefaultOption interval = builder.withLongName("interval").withArgument(argumentBuilder.withName("interval").withDefault((Object)"500").create()).withDescription("the interval property of AdaptiveLogisticRegression").create();
        DefaultOption window = builder.withLongName("window").withArgument(argumentBuilder.withName("window").withDefault((Object)"800").create()).withDescription("the average propery of AdaptiveLogisticRegression").create();
        DefaultOption skipperfnum = builder.withLongName("skipperfnum").withArgument(argumentBuilder.withName("skipperfnum").withDefault((Object)"99").create()).withDescription("show performance measures every (skipperfnum + 1) rows").create();
        DefaultOption prior = builder.withLongName("prior").withArgument(argumentBuilder.withName("prior").withDefault((Object)"L1").create()).withDescription("the prior algorithm to use: L1, L2, ebp, tp, up").create();
        DefaultOption priorOption = builder.withLongName("prioroption").withArgument(argumentBuilder.withName("prioroption").create()).withDescription("constructor parameter for ElasticBandPrior and TPrior").create();
        DefaultOption auc = builder.withLongName("auc").withArgument(argumentBuilder.withName("auc").withDefault((Object)"global").create()).withDescription("the auc to use: global or grouped").create();
        Group normalArgs = new GroupBuilder().withOption((Option)help).withOption((Option)quiet).withOption((Option)inputFile).withOption((Option)outputFile).withOption((Option)target).withOption((Option)targetCategories).withOption((Option)predictors).withOption((Option)types).withOption((Option)passes).withOption((Option)interval).withOption((Option)window).withOption((Option)threads).withOption((Option)prior).withOption((Option)features).withOption((Option)showperf).withOption((Option)skipperfnum).withOption((Option)priorOption).withOption((Option)auc).create();
        Parser parser = new Parser();
        parser.setHelpOption((Option)help);
        parser.setHelpTrigger("--help");
        parser.setGroup(normalArgs);
        parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
        CommandLine cmdLine = parser.parseAndHelp(args);
        if (cmdLine == null) {
            return false;
        }
        TrainAdaptiveLogistic.inputFile = TrainAdaptiveLogistic.getStringArgument(cmdLine, (Option)inputFile);
        TrainAdaptiveLogistic.outputFile = TrainAdaptiveLogistic.getStringArgument(cmdLine, (Option)outputFile);
        ArrayList<String> typeList = new ArrayList<String>();
        for (Object x : cmdLine.getValues((Option)types)) {
            typeList.add(x.toString());
        }
        ArrayList<String> predictorList = new ArrayList<String>();
        for (Object x : cmdLine.getValues((Option)predictors)) {
            predictorList.add(x.toString());
        }
        lmp = new AdaptiveLogisticModelParameters();
        lmp.setTargetVariable(TrainAdaptiveLogistic.getStringArgument(cmdLine, (Option)target));
        lmp.setMaxTargetCategories(TrainAdaptiveLogistic.getIntegerArgument(cmdLine, (Option)targetCategories));
        lmp.setNumFeatures(TrainAdaptiveLogistic.getIntegerArgument(cmdLine, (Option)features));
        lmp.setInterval(TrainAdaptiveLogistic.getIntegerArgument(cmdLine, (Option)interval));
        lmp.setAverageWindow(TrainAdaptiveLogistic.getIntegerArgument(cmdLine, (Option)window));
        lmp.setThreads(TrainAdaptiveLogistic.getIntegerArgument(cmdLine, (Option)threads));
        lmp.setAuc(TrainAdaptiveLogistic.getStringArgument(cmdLine, (Option)auc));
        lmp.setPrior(TrainAdaptiveLogistic.getStringArgument(cmdLine, (Option)prior));
        if (cmdLine.getValue((Option)priorOption) != null) {
            lmp.setPriorOption(TrainAdaptiveLogistic.getDoubleArgument(cmdLine, (Option)priorOption));
        }
        lmp.setTypeMap(predictorList, typeList);
        TrainAdaptiveLogistic.showperf = TrainAdaptiveLogistic.getBooleanArgument(cmdLine, (Option)showperf);
        TrainAdaptiveLogistic.skipperfnum = TrainAdaptiveLogistic.getIntegerArgument(cmdLine, (Option)skipperfnum);
        TrainAdaptiveLogistic.passes = TrainAdaptiveLogistic.getIntegerArgument(cmdLine, (Option)passes);
        lmp.checkParameters();
        return true;
    }

    private static String getStringArgument(CommandLine cmdLine, Option inputFile) {
        return (String)cmdLine.getValue(inputFile);
    }

    private static boolean getBooleanArgument(CommandLine cmdLine, Option option) {
        return cmdLine.hasOption(option);
    }

    private static int getIntegerArgument(CommandLine cmdLine, Option features) {
        return Integer.parseInt((String)cmdLine.getValue(features));
    }

    private static double getDoubleArgument(CommandLine cmdLine, Option op) {
        return Double.parseDouble((String)cmdLine.getValue(op));
    }

    public static AdaptiveLogisticRegression getModel() {
        return model;
    }

    public static LogisticModelParameters getParameters() {
        return lmp;
    }

    static BufferedReader open(String inputFile) throws IOException {
        InputStream in;
        try {
            in = Resources.getResource((String)inputFile).openStream();
        }
        catch (IllegalArgumentException e) {
            in = new FileInputStream(new File(inputFile));
        }
        return new BufferedReader(new InputStreamReader(in, Charsets.UTF_8));
    }

    static {
        skipperfnum = 99;
    }
}

