/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.regression.slm;

import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.UsageException;
import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.IOException;
import java.util.Map;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.Model;
import org.tribuo.OutputFactory;
import org.tribuo.SparseModel;
import org.tribuo.SparseTrainer;
import org.tribuo.data.DataOptions;
import org.tribuo.math.la.SparseVector;
import org.tribuo.regression.RegressionFactory;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.evaluation.RegressionEvaluation;
import org.tribuo.regression.slm.ElasticNetCDTrainer;
import org.tribuo.regression.slm.LARSLassoTrainer;
import org.tribuo.regression.slm.LARSTrainer;
import org.tribuo.regression.slm.SLMTrainer;
import org.tribuo.regression.slm.SparseLinearModel;
import org.tribuo.util.Util;

public class TrainTest {
    private static final Logger logger = Logger.getLogger(TrainTest.class.getName());

    public static void main(String[] args) throws IOException {
        Object trainer;
        ConfigurationManager cm;
        LabsLogFormatter.setAllLogFormatters();
        SLMOptions o = new SLMOptions();
        try {
            cm = new ConfigurationManager(args, (Options)o);
        }
        catch (UsageException e) {
            logger.info(e.getMessage());
            return;
        }
        if (o.general.trainingPath == null || o.general.testingPath == null) {
            logger.info(cm.usage());
            return;
        }
        RegressionFactory factory = new RegressionFactory();
        Pair data = o.general.load((OutputFactory)factory);
        Dataset train = (Dataset)data.getA();
        Dataset test = (Dataset)data.getB();
        switch (o.algorithm) {
            case SFS: {
                trainer = new SLMTrainer(false, Math.min(train.getFeatureMap().size(), o.maxNumFeatures));
                break;
            }
            case LARS: {
                trainer = new LARSTrainer(Math.min(train.getFeatureMap().size(), o.maxNumFeatures));
                break;
            }
            case LARSLASSO: {
                trainer = new LARSLassoTrainer(Math.min(train.getFeatureMap().size(), o.maxNumFeatures));
                break;
            }
            case SFSN: {
                trainer = new SLMTrainer(true, Math.min(train.getFeatureMap().size(), o.maxNumFeatures));
                break;
            }
            case ELASTICNET: {
                trainer = new ElasticNetCDTrainer(o.alpha, o.l1Ratio, 1.0E-4, o.iterations, false, o.general.seed);
                break;
            }
            default: {
                logger.warning("Unknown SLMType, found " + (Object)((Object)o.algorithm));
                return;
            }
        }
        logger.info("Training using " + trainer.toString());
        long trainStart = System.currentTimeMillis();
        SparseModel model = trainer.train(train);
        long trainStop = System.currentTimeMillis();
        logger.info("Finished training regressor " + Util.formatDuration((long)trainStart, (long)trainStop));
        logger.info("Selected features: " + model.getActiveFeatures());
        Map<String, SparseVector> weights = ((SparseLinearModel)model).getWeights();
        for (Map.Entry<String, SparseVector> e : weights.entrySet()) {
            logger.info("Target:" + e.getKey());
            logger.info("\tWeights: " + e.getValue());
            logger.info("\tWeights one norm: " + e.getValue().oneNorm());
            logger.info("\tWeights two norm: " + e.getValue().twoNorm());
        }
        long testStart = System.currentTimeMillis();
        RegressionEvaluation evaluation = (RegressionEvaluation)factory.getEvaluator().evaluate((Model)model, test);
        long testStop = System.currentTimeMillis();
        logger.info("Finished evaluating model " + Util.formatDuration((long)testStart, (long)testStop));
        System.out.println(evaluation.toString());
        if (o.general.outputPath != null) {
            o.general.saveModel((Model)model);
        }
    }

    public static class SLMOptions
    implements Options {
        public DataOptions general;
        @Option(charName=109, longName="max-features-num", usage="Set the maximum number of features.")
        public int maxNumFeatures = -1;
        @Option(charName=97, longName="algorithm", usage="Choose the training algorithm (stepwise forward selection or least angle regression).")
        public SLMType algorithm = SLMType.LARS;
        @Option(charName=98, longName="alpha", usage="Regularisation strength in the Elastic Net.")
        public double alpha = 1.0;
        @Option(charName=108, longName="l1Ratio", usage="Ratio between the l1 and l2 penalties in the Elastic Net. Must be between 0 and 1.")
        public double l1Ratio = 1.0;
        @Option(longName="iterations", usage="Iterations of Elastic Net.")
        public int iterations = 500;

        public String getOptionsDescription() {
            return "Trains and tests a sparse linear regression model on the specified datasets.";
        }
    }

    public static enum SLMType {
        SFS,
        SFSN,
        LARS,
        LARSLASSO,
        ELASTICNET;

    }
}

