/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.mlp;

import com.google.common.base.Preconditions;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.commons.cli2.option.DefaultOption;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.classifier.mlp.MultilayerPerceptron;
import org.apache.mahout.math.Arrays;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Command-line driver that trains a {@link MultilayerPerceptron} on a comma-separated training
 * file and writes the trained model to a file.
 *
 * <p>Each data line consists of numeric feature values followed by a class label in the last
 * column. The label must be one of the names passed via {@code --labels}; it is one-hot encoded
 * (in the order the labels were given), appended after the features, and the resulting vector is
 * passed to {@link MultilayerPerceptron#trainOnline(Vector)}, one line at a time.
 *
 * <p>If the model file exists and {@code --update} is given, training continues from that model;
 * otherwise any existing model file is deleted and a new network is built from
 * {@code --layerSize}.
 *
 * @deprecated this driver is marked deprecated; no replacement is referenced from this class.
 */
@Deprecated
public final class TrainMultilayerPerceptron {
    private static final Logger log = LoggerFactory.getLogger(TrainMultilayerPerceptron.class);

    /** Cost function assigned to networks that are built from scratch. */
    private static final String COST_FUNCTION = "Minus_Squared";

    /**
     * Entry point: parses the arguments, prepares the model, trains it on the input file and
     * writes it to the model path. Returns without doing anything when argument parsing does not
     * yield a command line.
     *
     * @param args command-line arguments (see {@link #parseArgs(String[], Parameters)})
     * @throws IllegalArgumentException if the training file does not exist, or a data line has a
     *     label that was not declared via {@code --labels}
     * @throws Exception on I/O or model errors
     */
    public static void main(String[] args) throws Exception {
        Parameters parameters = new Parameters();
        if (!parseArgs(args, parameters)) {
            return;
        }

        // Validate the training data first: loadOrBuildModel may delete an existing model, which
        // must not happen when the run is going to fail on a bad --input path anyway.
        Path trainingDataPath = new Path(parameters.inputFilePath);
        FileSystem dataFs = trainingDataPath.getFileSystem(new Configuration());
        Preconditions.checkArgument(dataFs.exists(trainingDataPath),
                "Training dataset %s cannot be found!", parameters.inputFilePath);

        MultilayerPerceptron mlp = loadOrBuildModel(parameters);
        try {
            // Apply the command-line hyper-parameters to both loaded and freshly built models.
            mlp.setLearningRate(parameters.learningRate)
                    .setMomentumWeight(parameters.momemtumWeight)
                    .setRegularizationWeight(parameters.regularizationWeight);
            trainModel(mlp, parameters, dataFs, trainingDataPath);
        } finally {
            // Close even when training or writing the model fails.
            mlp.close();
        }
    }

    /**
     * Loads the existing model for incremental update, or builds a new network from the
     * configured layer sizes (deleting any existing model file first).
     */
    private static MultilayerPerceptron loadOrBuildModel(Parameters parameters) throws IOException {
        log.info("Validate model...");
        Path modelPath = new Path(parameters.modelFilePath);
        FileSystem modelFs = modelPath.getFileSystem(new Configuration());
        boolean modelExists = modelFs.exists(modelPath);

        if (modelExists && parameters.updateModel) {
            log.info("Build model from existing model...");
            return new MultilayerPerceptron(parameters.modelFilePath);
        }

        if (modelExists) {
            // Not updating: remove the old model so it is replaced by one built from scratch.
            modelFs.delete(modelPath, true);
        }
        log.info("Build model from scratch...");
        MultilayerPerceptron mlp = new MultilayerPerceptron();
        int layerCount = parameters.layerSizeList.size();
        for (int i = 0; i < layerCount; ++i) {
            // Only the last layer is flagged as the final layer.
            boolean isFinalLayer = i == layerCount - 1;
            mlp.addLayer(parameters.layerSizeList.get(i), isFinalLayer, parameters.squashingFunctionName);
        }
        mlp.setCostFunction(COST_FUNCTION);
        mlp.setModelPath(parameters.modelFilePath);
        return mlp;
    }

    /**
     * Streams the training file line by line into {@code mlp.trainOnline} and then writes the
     * model to its configured path.
     */
    private static void trainModel(MultilayerPerceptron mlp, Parameters parameters,
                                   FileSystem dataFs, Path trainingDataPath) throws IOException {
        log.info("Read data and train model...");
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(dataFs.open(trainingDataPath), StandardCharsets.UTF_8))) {
            if (parameters.skipHeader) {
                reader.readLine();
            }
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    // Tolerate blank lines (e.g. extra newlines at the end of the file).
                    continue;
                }
                mlp.trainOnline(parseInstance(line, parameters.labelsIndex));
            }
        }
        log.info("Write trained model to {}", parameters.modelFilePath);
        mlp.writeModelToFile();
    }

    /**
     * Converts one CSV line into a training vector: the feature values followed by a one-hot
     * encoding of the label found in the last column.
     *
     * @throws IllegalArgumentException if the label is not a key of {@code labelsIndex}
     * @throws NumberFormatException if a feature column is not a number
     */
    private static Vector parseInstance(String line, Map<String, Integer> labelsIndex) {
        String[] tokens = line.split(",");
        int featureCount = tokens.length - 1;
        String label = tokens[featureCount];
        Integer labelIndex = labelsIndex.get(label);
        Preconditions.checkArgument(labelIndex != null, "Unknown label '%s' in line: %s", label, line);

        // Java zero-initializes the array, so only the label's slot needs to be set.
        double[] instance = new double[featureCount + labelsIndex.size()];
        for (int i = 0; i < featureCount; ++i) {
            instance[i] = Double.parseDouble(tokens[i]);
        }
        instance[featureCount + labelIndex] = 1.0;
        return new DenseVector(instance);
    }

    /**
     * Defines the command-line options, parses {@code args} and fills {@code parameters}.
     *
     * @return {@code false} if {@link Parser#parseAndHelp(String[])} returned {@code null},
     *     {@code true} once {@code parameters} has been populated
     * @throws IllegalArgumentException if the same label is passed more than once
     */
    private static boolean parseArgs(String[] args, Parameters parameters) throws Exception {
        log.info("Validate and parse arguments...");
        // The builders are reused; each create() call finishes the element being built.
        DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
        GroupBuilder groupBuilder = new GroupBuilder();
        ArgumentBuilder argumentBuilder = new ArgumentBuilder();

        DefaultOption skipHeaderOption = optionBuilder.withLongName("skipHeader").withShortName("sh").create();
        Group skipHeaderGroup = groupBuilder.withOption(skipHeaderOption).create();

        DefaultOption inputOption = optionBuilder.withLongName("input").withShortName("i").withRequired(true)
                .withChildren(skipHeaderGroup)
                .withArgument(argumentBuilder.withName("path").withMinimum(1).withMaximum(1).create())
                .withDescription("the file path of training dataset").create();

        DefaultOption labelsOption = optionBuilder.withLongName("labels").withShortName("labels").withRequired(true)
                .withArgument(argumentBuilder.withName("label-name").withMinimum(2).create())
                .withDescription("label names").create();

        DefaultOption updateOption = optionBuilder.withLongName("update").withShortName("u")
                .withDescription("whether to incrementally update model if the model exists").create();
        Group modelUpdateGroup = groupBuilder.withOption(updateOption).create();

        DefaultOption modelOption = optionBuilder.withLongName("model").withShortName("mo").withRequired(true)
                .withArgument(argumentBuilder.withName("model-path").withMinimum(1).withMaximum(1).create())
                .withDescription("the path to store the trained model")
                .withChildren(modelUpdateGroup).create();

        DefaultOption layerSizeOption = optionBuilder.withLongName("layerSize").withShortName("ls").withRequired(true)
                .withArgument(argumentBuilder.withName("size of layer").withMinimum(2).withMaximum(5).create())
                .withDescription("the size of each layer").create();

        DefaultOption squashingFunctionOption = optionBuilder.withLongName("squashingFunction").withShortName("sf")
                .withArgument(argumentBuilder.withName("squashing function").withMinimum(1).withMaximum(1)
                        .withDefault("Sigmoid").create())
                .withDescription("the name of squashing function (currently only supports Sigmoid)").create();

        DefaultOption learningRateOption = optionBuilder.withLongName("learningRate").withShortName("l")
                .withArgument(argumentBuilder.withName("learning rate").withMaximum(1).withMinimum(1)
                        .withDefault(0.5).create())
                .withDescription("learning rate").create();

        DefaultOption momemtumOption = optionBuilder.withLongName("momemtumWeight").withShortName("m")
                .withArgument(argumentBuilder.withName("momemtum weight").withMaximum(1).withMinimum(1)
                        .withDefault(0.1).create())
                .withDescription("momemtum weight").create();

        DefaultOption regularizationOption = optionBuilder.withLongName("regularizationWeight").withShortName("r")
                .withArgument(argumentBuilder.withName("regularization weight").withMaximum(1).withMinimum(1)
                        .withDefault(0.0).create())
                .withDescription("regularization weight").create();

        Parser parser = new Parser();
        Group normalOptions = groupBuilder
                .withOption(inputOption)
                .withOption(skipHeaderOption)
                .withOption(updateOption)
                .withOption(labelsOption)
                .withOption(modelOption)
                .withOption(layerSizeOption)
                .withOption(squashingFunctionOption)
                .withOption(learningRateOption)
                .withOption(momemtumOption)
                .withOption(regularizationOption)
                .create();
        parser.setGroup(normalOptions);

        CommandLine commandLine = parser.parseAndHelp(args);
        if (commandLine == null) {
            return false;
        }

        parameters.learningRate = getDouble(commandLine, learningRateOption);
        parameters.momemtumWeight = getDouble(commandLine, momemtumOption);
        parameters.regularizationWeight = getDouble(commandLine, regularizationOption);
        parameters.inputFilePath = getString(commandLine, inputOption);
        parameters.skipHeader = commandLine.hasOption(skipHeaderOption);

        // Labels are indexed 0..n-1 in the order given. A duplicate would leave a stored index
        // outside the one-hot block sized by labelsIndex.size(), so reject it here.
        for (String label : getStringList(commandLine, labelsOption)) {
            Preconditions.checkArgument(!parameters.labelsIndex.containsKey(label), "Duplicate label: %s", label);
            parameters.labelsIndex.put(label, parameters.labelsIndex.size());
        }

        parameters.modelFilePath = getString(commandLine, modelOption);
        parameters.updateModel = commandLine.hasOption(updateOption);
        parameters.layerSizeList = getIntegerList(commandLine, layerSizeOption);
        parameters.squashingFunctionName = getString(commandLine, squashingFunctionOption);

        System.out.printf("Input: %s, Model: %s, Update: %s, Layer size: %s, Squashing function: %s, "
                        + "Learning rate: %f, Momemtum weight: %f, Regularization Weight: %f%n",
                parameters.inputFilePath, parameters.modelFilePath, parameters.updateModel,
                Arrays.toString(parameters.layerSizeList.toArray()), parameters.squashingFunctionName,
                parameters.learningRate, parameters.momemtumWeight, parameters.regularizationWeight);
        return true;
    }

    /**
     * Returns the option's value parsed as a double, or {@code null} if the option has no value.
     *
     * @throws NumberFormatException if the value is not a valid double
     */
    static Double getDouble(CommandLine commandLine, Option option) {
        Object val = commandLine.getValue(option);
        return val == null ? null : Double.parseDouble(val.toString());
    }

    /** Returns the option's value as a string, or {@code null} if the option has no value. */
    static String getString(CommandLine commandLine, Option option) {
        Object val = commandLine.getValue(option);
        return val == null ? null : val.toString();
    }

    /**
     * Returns all of the option's values parsed as integers, in order, as a new list.
     *
     * @throws NumberFormatException if a value is not a valid integer
     */
    static List<Integer> getIntegerList(CommandLine commandLine, Option option) {
        List<?> values = commandLine.getValues(option);
        List<Integer> valList = new ArrayList<>(values.size());
        for (Object value : values) {
            valList.add(Integer.parseInt(value.toString()));
        }
        return valList;
    }

    /** Returns all of the option's values converted to strings, in order, as a new list. */
    static List<String> getStringList(CommandLine commandLine, Option option) {
        List<?> values = commandLine.getValues(option);
        List<String> valList = new ArrayList<>(values.size());
        for (Object value : values) {
            valList.add(value.toString());
        }
        return valList;
    }

    /** Values collected from the command line by {@link #parseArgs(String[], Parameters)}. */
    static class Parameters {
        /** Value of {@code --learningRate}. */
        double learningRate;
        /** Value of {@code --momemtumWeight}. */
        double momemtumWeight;
        /** Value of {@code --regularizationWeight}. */
        double regularizationWeight;
        /** Path of the training file ({@code --input}). */
        String inputFilePath;
        /** Whether the first line of the training file is skipped ({@code --skipHeader}). */
        boolean skipHeader;
        /** Label name to its position in the one-hot block, in the order given by {@code --labels}. */
        Map<String, Integer> labelsIndex = new HashMap<>();
        /** Path the model is read from / written to ({@code --model}). */
        String modelFilePath;
        /** Whether an existing model is trained further instead of rebuilt ({@code --update}). */
        boolean updateModel;
        /** Size of each layer, in order ({@code --layerSize}). */
        List<Integer> layerSizeList = new ArrayList<>();
        /** Squashing function name passed to every layer ({@code --squashingFunction}). */
        String squashingFunctionName;

        Parameters() {
        }
    }
}

