/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.regression.xgboost;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Trainer;
import org.tribuo.common.xgboost.XGBoostModel;
import org.tribuo.common.xgboost.XGBoostTrainer;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.regression.ImmutableRegressionInfo;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.xgboost.XGBoostRegressionConverter;

public final class XGBoostRegressionTrainer
extends XGBoostTrainer<Regressor> {
    private static final Logger logger = Logger.getLogger(XGBoostRegressionTrainer.class.getName());
    @Config(description="The type of regression.")
    private RegressionType rType = RegressionType.LINEAR;

    public XGBoostRegressionTrainer(int numTrees) {
        this(RegressionType.LINEAR, numTrees);
    }

    public XGBoostRegressionTrainer(RegressionType rType, int numTrees) {
        this(rType, numTrees, 0.3, 0.0, 6, 1.0, 1.0, 1.0, 1.0, 0.0, 4, true, 12345L);
    }

    public XGBoostRegressionTrainer(RegressionType rType, int numTrees, int numThreads, boolean silent) {
        this(rType, numTrees, 0.3, 0.0, 6, 1.0, 1.0, 1.0, 1.0, 0.0, numThreads, silent, 12345L);
    }

    public XGBoostRegressionTrainer(RegressionType rType, int numTrees, double eta, double gamma, int maxDepth, double minChildWeight, double subsample, double featureSubsample, double lambda, double alpha, int nThread, boolean silent, long seed) {
        super(numTrees, eta, gamma, maxDepth, minChildWeight, subsample, featureSubsample, lambda, alpha, nThread, silent, seed);
        this.rType = rType;
        this.postConfig();
    }

    public XGBoostRegressionTrainer(XGBoostTrainer.BoosterType boosterType, XGBoostTrainer.TreeMethod treeMethod, RegressionType rType, int numTrees, double eta, double gamma, int maxDepth, double minChildWeight, double subsample, double featureSubsample, double lambda, double alpha, int nThread, XGBoostTrainer.LoggingVerbosity verbosity, long seed) {
        super(boosterType, treeMethod, numTrees, eta, gamma, maxDepth, minChildWeight, subsample, featureSubsample, lambda, alpha, nThread, verbosity, seed);
        this.rType = rType;
        this.postConfig();
    }

    public XGBoostRegressionTrainer(RegressionType rType, int numTrees, Map<String, Object> parameters) {
        super(numTrees, parameters);
        this.rType = rType;
        this.postConfig();
    }

    private XGBoostRegressionTrainer() {
    }

    public void postConfig() {
        super.postConfig();
        this.parameters.put("objective", this.rType.paramName);
        if (!this.overrideParameters.isEmpty() && !((String)this.overrideParameters.get("objective")).equals(this.rType.paramName)) {
            throw new PropertyException("", "overrideParameters", "The objective in overrideParameters must match the one supplied as rType.");
        }
    }

    public synchronized XGBoostModel<Regressor> train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance) {
        return this.train(examples, runProvenance, -1);
    }

    public synchronized XGBoostModel<Regressor> train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance, int invocationCount) {
        if (examples.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
        ImmutableOutputInfo outputInfo = examples.getOutputIDInfo();
        int numOutputs = outputInfo.size();
        if (invocationCount != -1) {
            this.setInvocationCount(invocationCount);
        }
        TrainerProvenance trainerProvenance = this.getProvenance();
        ++this.trainInvocationCounter;
        ArrayList<Booster> models = new ArrayList<Booster>();
        try {
            XGBoostTrainer.DMatrixTuple trainingData = XGBoostRegressionTrainer.convertExamples(examples, (ImmutableFeatureMap)featureMap, null);
            int[] dimensionIds = ((ImmutableRegressionInfo)outputInfo).getNaturalOrderToIDMapping();
            float[][] outputs = new float[numOutputs][examples.size()];
            float[] weights = new float[examples.size()];
            int i = 0;
            for (Example e : examples) {
                weights[i] = e.getWeight();
                double[] curOutputs = ((Regressor)e.getOutput()).getValues();
                for (int j = 0; j < numOutputs; ++j) {
                    outputs[dimensionIds[j]][i] = (float)curOutputs[j];
                }
                ++i;
            }
            trainingData.data.setWeight(weights);
            Map curParams = this.overrideParameters.isEmpty() ? this.copyParams(this.parameters) : this.copyParams(this.overrideParameters);
            for (i = 0; i < numOutputs; ++i) {
                trainingData.data.setLabel(outputs[i]);
                models.add(XGBoost.train((DMatrix)trainingData.data, (Map)curParams, (int)this.numTrees, Collections.emptyMap(), null, null));
            }
        }
        catch (XGBoostError e) {
            logger.log(Level.SEVERE, "XGBoost threw an error", e);
            throw new IllegalStateException(e);
        }
        ModelProvenance provenance = new ModelProvenance(XGBoostModel.class.getName(), OffsetDateTime.now(), (DatasetProvenance)examples.getProvenance(), trainerProvenance, runProvenance);
        XGBoostModel xgModel = this.createModel("xgboost-regression-model", provenance, featureMap, outputInfo, models, new XGBoostRegressionConverter());
        return xgModel;
    }

    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl((Trainer)this);
    }

    public static enum RegressionType {
        LINEAR("reg:squarederror"),
        GAMMA("reg:gamma"),
        TWEEDIE("reg:tweedie"),
        PSEUDOHUBER("reg:pseudohubererror");

        public final String paramName;

        private RegressionType(String paramName) {
            this.paramName = paramName;
        }
    }
}

