/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.regression.rtree;

import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.UsageException;
import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.IOException;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.Model;
import org.tribuo.OutputFactory;
import org.tribuo.SparseModel;
import org.tribuo.common.tree.AbstractCARTTrainer;
import org.tribuo.data.DataOptions;
import org.tribuo.regression.RegressionFactory;
import org.tribuo.regression.evaluation.RegressionEvaluation;
import org.tribuo.regression.rtree.CARTJointRegressionTrainer;
import org.tribuo.regression.rtree.CARTRegressionTrainer;
import org.tribuo.regression.rtree.impurity.MeanAbsoluteError;
import org.tribuo.regression.rtree.impurity.MeanSquaredError;
import org.tribuo.regression.rtree.impurity.RegressorImpurity;
import org.tribuo.util.Util;

public class TrainTest {
    private static final Logger logger = Logger.getLogger(TrainTest.class.getName());

    public static void main(String[] args) throws IOException {
        AbstractCARTTrainer trainer;
        RegressorImpurity impurity;
        ConfigurationManager cm;
        LabsLogFormatter.setAllLogFormatters();
        RegressionTreeOptions o = new RegressionTreeOptions();
        try {
            cm = new ConfigurationManager(args, (Options)o);
        }
        catch (UsageException e) {
            logger.info(e.getMessage());
            return;
        }
        RegressionFactory factory = new RegressionFactory(o.splitChar);
        Pair data = o.general.load((OutputFactory)factory);
        Dataset train = (Dataset)data.getA();
        Dataset test = (Dataset)data.getB();
        switch (o.impurityType) {
            case MAE: {
                impurity = new MeanAbsoluteError();
                break;
            }
            case MSE: {
                impurity = new MeanSquaredError();
                break;
            }
            default: {
                logger.severe("unknown impurity type " + (Object)((Object)o.impurityType));
                return;
            }
        }
        if (o.general.trainingPath == null || o.general.testingPath == null) {
            logger.info(cm.usage());
            return;
        }
        switch (o.treeType) {
            case CART_INDEPENDENT: {
                trainer = new CARTRegressionTrainer(o.depth, o.minChildWeight, o.minImpurityDecrease, o.fraction, o.useRandomSplitPoints, impurity, o.general.seed);
                break;
            }
            case CART_JOINT: {
                trainer = new CARTJointRegressionTrainer(o.depth, o.minChildWeight, o.minImpurityDecrease, o.fraction, o.useRandomSplitPoints, impurity, o.normalize, o.general.seed);
                break;
            }
            default: {
                logger.severe("unknown tree type " + (Object)((Object)o.treeType));
                return;
            }
        }
        logger.info("Training using " + trainer.toString());
        long trainStart = System.currentTimeMillis();
        SparseModel model = trainer.train(train);
        long trainStop = System.currentTimeMillis();
        logger.info("Finished training regressor " + Util.formatDuration((long)trainStart, (long)trainStop));
        if (o.printTree) {
            logger.info(model.toString());
        }
        logger.info("Selected features: " + model.getActiveFeatures());
        long testStart = System.currentTimeMillis();
        RegressionEvaluation evaluation = (RegressionEvaluation)factory.getEvaluator().evaluate((Model)model, test);
        long testStop = System.currentTimeMillis();
        logger.info("Finished evaluating model " + Util.formatDuration((long)testStart, (long)testStop));
        System.out.println(evaluation.toString());
        if (o.general.outputPath != null) {
            o.general.saveModel((Model)model);
        }
    }

    public static class RegressionTreeOptions
    implements Options {
        public DataOptions general;
        @Option(longName="csv-response-split-char", usage="Character to split the CSV response on to generate multiple regression dimensions. Defaults to ':'.")
        public char splitChar = (char)58;
        @Option(charName=100, longName="max-depth", usage="Maximum depth in the decision tree.")
        public int depth = 6;
        @Option(charName=101, longName="split-fraction", usage="Fraction of features in split.")
        public float fraction = 1.0f;
        @Option(charName=109, longName="min-child-weight", usage="Minimum child weight.")
        public float minChildWeight = 5.0f;
        @Option(charName=112, longName="min-impurity-decrease", usage="Minimumum decrease in impurity required in order for the node to be split.")
        public float minImpurityDecrease = 0.0f;
        @Option(charName=114, longName="use-random-split-points", usage="Whether to choose split points for features at random.")
        public boolean useRandomSplitPoints = false;
        @Option(charName=110, longName="normalize", usage="Normalize the leaf outputs so each leaf sums to 1.0.")
        public boolean normalize = false;
        @Option(charName=105, longName="impurity", usage="Impurity measure to use. Defaults to MSE.")
        public ImpurityType impurityType = ImpurityType.MSE;
        @Option(charName=116, longName="tree-type", usage="Tree type.")
        public TreeType treeType = TreeType.CART_INDEPENDENT;
        @Option(longName="print-tree", usage="Prints the decision tree.")
        public boolean printTree;

        public String getOptionsDescription() {
            return "Trains and tests a CART regression model on the specified datasets.";
        }
    }

    public static enum ImpurityType {
        MSE,
        MAE;

    }

    public static enum TreeType {
        CART_INDEPENDENT,
        CART_JOINT;

    }
}

