/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.regression.xgboost;

import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.UsageException;
import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.IOException;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.Model;
import org.tribuo.OutputFactory;
import org.tribuo.data.DataOptions;
import org.tribuo.regression.RegressionFactory;
import org.tribuo.regression.evaluation.RegressionEvaluation;
import org.tribuo.regression.xgboost.XGBoostRegressionTrainer;
import org.tribuo.util.Util;

public class TrainTest {
    private static final Logger logger = Logger.getLogger(TrainTest.class.getName());

    public static void main(String[] args) throws IOException {
        ConfigurationManager cm;
        LabsLogFormatter.setAllLogFormatters();
        XGBoostOptions o = new XGBoostOptions();
        try {
            cm = new ConfigurationManager(args, (Options)o);
        }
        catch (UsageException e) {
            logger.info(e.getMessage());
            return;
        }
        if (o.general.trainingPath == null || o.general.testingPath == null) {
            logger.info(cm.usage());
            logger.info("Please supply a training path and a testing path");
            return;
        }
        if (o.ensembleSize == -1) {
            logger.info(cm.usage());
            logger.info("Please supply the number of trees.");
            return;
        }
        RegressionFactory factory = new RegressionFactory();
        Pair data = o.general.load((OutputFactory)factory);
        Dataset train = (Dataset)data.getA();
        Dataset test = (Dataset)data.getB();
        XGBoostRegressionTrainer trainer = new XGBoostRegressionTrainer(o.rType, o.ensembleSize, o.eta, o.gamma, o.depth, o.minWeight, o.subsample, o.subsampleFeatures, o.lambda, o.alpha, o.numThreads, o.quiet, o.general.seed);
        logger.info("Training using " + trainer.toString());
        long trainStart = System.currentTimeMillis();
        Model model = trainer.train(train);
        long trainStop = System.currentTimeMillis();
        logger.info("Finished training regressor " + Util.formatDuration((long)trainStart, (long)trainStop));
        long testStart = System.currentTimeMillis();
        RegressionEvaluation evaluation = (RegressionEvaluation)factory.getEvaluator().evaluate(model, test);
        long testStop = System.currentTimeMillis();
        logger.info("Finished evaluating model " + Util.formatDuration((long)testStart, (long)testStop));
        System.out.println(evaluation.toString());
        if (o.general.outputPath != null) {
            o.general.saveModel(model);
        }
    }

    public static class XGBoostOptions
    implements Options {
        public DataOptions general;
        @Option(longName="regression-metric", usage="Regression type to use. Defaults to LINEAR.")
        public XGBoostRegressionTrainer.RegressionType rType = XGBoostRegressionTrainer.RegressionType.LINEAR;
        @Option(charName=109, longName="ensemble-size", usage="Number of trees in the ensemble.")
        public int ensembleSize = -1;
        @Option(charName=97, longName="alpha", usage="L1 regularization term for weights (default 0).")
        public float alpha = 0.0f;
        @Option(longName="min-weight", usage="Minimum sum of instance weights needed in a leaf (default 1, range [0,inf]).")
        public float minWeight = 1.0f;
        @Option(charName=100, longName="max-depth", usage="Max tree depth (default 6, range (0,inf]).")
        public int depth = 6;
        @Option(charName=101, longName="eta", usage="Step size shrinkage parameter (default 0.3, range [0,1]).")
        public float eta = 0.3f;
        @Option(longName="subsample-features", usage="Subsample features for each tree (default 1, range (0,1]).")
        public float subsampleFeatures = 1.0f;
        @Option(charName=103, longName="gamma", usage="Minimum loss reduction to make a split (default 0, range [0,inf]).")
        public float gamma = 0.0f;
        @Option(charName=108, longName="lambda", usage="L2 regularization term for weights (default 1).")
        public float lambda = 1.0f;
        @Option(charName=113, longName="quiet", usage="Make the XGBoost training procedure quiet.")
        public boolean quiet;
        @Option(longName="subsample", usage="Subsample size for each tree (default 1, range (0,1]).")
        public float subsample = 1.0f;
        @Option(charName=116, longName="num-threads", usage="Number of threads to use (default 4, range (1, num hw threads)).")
        public int numThreads = 4;

        public String getOptionsDescription() {
            return "Trains and tests an XGBoost regression model on the specified datasets.";
        }
    }
}

