/*
 * Decompiled with CFR 0.152.
 */
package sklearn.ensemble.gradient_boosting;

import java.util.ArrayList;
import java.util.List;
import java.util.function.IntFunction;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.regression.RegressionModel;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.DiscreteLabel;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.Transformation;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.python.AttributeException;
import org.jpmml.python.PythonObject;
import sklearn.Estimator;
import sklearn.HasEstimatorEnsemble;
import sklearn.HasMultiDecisionFunctionField;
import sklearn.HasPriorProbability;
import sklearn.SkLearnClassifier;
import sklearn.VersionUtil;
import sklearn.ensemble.gradient_boosting.GradientBoostingUtil;
import sklearn.ensemble.gradient_boosting.LossFunction;
import sklearn.loss.BaseLoss;
import sklearn.loss.HalfLogitLink;
import sklearn.loss.Link;
import sklearn.tree.HasTreeOptions;
import sklearn.tree.TreeRegressor;
import sklearn.tree.TreeUtil;
import sklearn2pmml.EstimatorProxy;

/**
 * Converts a gradient boosting classifier into a PMML {@link MiningModel}.
 *
 * <p>The estimator's loss object may be either a {@link LossFunction} or a
 * {@link BaseLoss}; the two are handled differently when choosing the initial
 * predictions, the output transformations and the normalization method.
 */
public class GradientBoostingClassifier
extends SkLearnClassifier
implements HasEstimatorEnsemble<TreeRegressor>,
HasMultiDecisionFunctionField,
HasTreeOptions {
    /** scikit-learn version from which {@link LossFunction} computes the initial predictions itself. */
    private static final String LOSS_FUNCTION_INITIAL_PREDICTIONS_VERSION = "0.21";

    /** Oldest scikit-learn version for which a {@link BaseLoss} loss is accepted. */
    private static final String BASE_LOSS_MINIMUM_VERSION = "1.4.0";

    public GradientBoostingClassifier(String module, String name) {
        super(module, name);
    }

    /**
     * Returns the number of input features.
     *
     * @return the value of the {@code n_features} attribute when the estimator has one,
     *         otherwise whatever the superclass reports
     */
    @Override
    public int getNumberOfFeatures() {
        if (hasattr("n_features")) {
            return getInteger("n_features");
        }
        return super.getNumberOfFeatures();
    }

    /**
     * @return {@link DataType#FLOAT}, unconditionally
     */
    @Override
    public DataType getDataType() {
        return DataType.FLOAT;
    }

    /**
     * Encodes this classifier as a PMML mining model.
     *
     * <p>A two-category label produces a single boosted ensemble that is wrapped into a
     * binary logistic classification. A label with more than two categories produces one
     * boosted ensemble per category (built from the matching column of the estimator
     * matrix), combined into a classification model. In both cases the predicted
     * probability outputs are attached before returning.
     *
     * @param schema the schema of the model; its label is cast to {@link CategoricalLabel}
     * @return the encoded mining model
     * @throws IllegalArgumentException if the label has fewer than two categories, or if the
     *         loss is a {@link BaseLoss} and the scikit-learn version is unknown or older
     *         than {@code 1.4.0}
     */
    public MiningModel encodeModel(Schema schema) {
        String sklearnVersion = getSkLearnVersion();
        HasPriorProbability init = getInit();
        Number learningRate = getLearningRate();
        PythonObject loss = getLoss();

        IntFunction<Number> initialPredictions = resolveInitialPredictions(loss, init, sklearnVersion);

        Schema segmentSchema = schema.toAnonymousRegressorSchema(DataType.DOUBLE);
        CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

        // Only the LossFunction flavour contributes an output transformation;
        // for BaseLoss the decision function output is emitted untransformed.
        Transformation[] transformations = new Transformation[0];
        if (loss instanceof LossFunction) {
            LossFunction lossFunction = (LossFunction) loss;
            transformations = new Transformation[]{lossFunction.createTransformation()};
        }

        MiningModel miningModel;
        if (categoricalLabel.size() == 2) {
            SchemaUtil.checkSize(2, categoricalLabel);

            // The single ensemble scores the category at index 1.
            MiningModel model = GradientBoostingUtil.encodeGradientBoosting(this, initialPredictions.apply(1), learningRate, segmentSchema)
                .setOutput(ModelUtil.createPredictedOutput(getMultiDecisionFunctionField(categoricalLabel.getValue(1)), OpType.CONTINUOUS, DataType.DOUBLE, transformations));

            double coefficient = 1.0;
            RegressionModel.NormalizationMethod normalizationMethod = RegressionModel.NormalizationMethod.NONE;
            if (loss instanceof BaseLoss) {
                BaseLoss baseLoss = (BaseLoss) loss;
                Link link = baseLoss.getLink();

                normalizationMethod = RegressionModel.NormalizationMethod.LOGIT;
                // NOTE(review): presumably a half-logit link yields half the log-odds,
                // hence the doubling before the logit normalization - confirm against the link definition.
                if (link instanceof HalfLogitLink) {
                    coefficient = 2.0;
                }
            }

            miningModel = MiningModelUtil.createBinaryLogisticClassification(model, coefficient, 0.0, normalizationMethod, false, schema);
        } else if (categoricalLabel.size() > 2) {
            List<TreeRegressor> estimators = getEstimators();
            List<MiningModel> models = new ArrayList<>();

            // The flat estimator list is read as a (rows x columns) matrix with one column per category.
            int columns = categoricalLabel.size();
            int rows = estimators.size() / columns;
            for (int i = 0; i < columns; ++i) {
                final List<TreeRegressor> columnEstimators = CMatrixUtil.getColumn(estimators, rows, columns, i);

                // The proxy exposes only this category's estimators while delegating
                // everything else to the enclosing classifier.
                GradientBoostingClassifierProxy estimatorProxy = new GradientBoostingClassifierProxy() {

                    @Override
                    public List<TreeRegressor> getEstimators() {
                        return columnEstimators;
                    }
                };

                MiningModel model = GradientBoostingUtil.encodeGradientBoosting(estimatorProxy, initialPredictions.apply(i), learningRate, segmentSchema)
                    .setOutput(ModelUtil.createPredictedOutput(getMultiDecisionFunctionField(categoricalLabel.getValue(i)), OpType.CONTINUOUS, DataType.DOUBLE, transformations));
                models.add(model);
            }

            RegressionModel.NormalizationMethod normalizationMethod = RegressionModel.NormalizationMethod.SIMPLEMAX;
            if (loss instanceof BaseLoss) {
                normalizationMethod = RegressionModel.NormalizationMethod.SOFTMAX;
            }

            miningModel = MiningModelUtil.createClassification(models, normalizationMethod, false, schema);
        } else {
            throw new IllegalArgumentException("Expected a categorical label with two or more categories, got " + categoricalLabel.size());
        }

        encodePredictProbaOutput(miningModel, DataType.DOUBLE, categoricalLabel);
        return miningModel;
    }

    /**
     * Chooses the source of the per-category initial predictions.
     *
     * <p>The default is the prior probability reported by {@code init}. A
     * {@link LossFunction} overrides it from scikit-learn {@code 0.21} onwards, and a
     * {@link BaseLoss} always overrides it via its link.
     *
     * @param loss the loss object of the estimator
     * @param init the initial estimator
     * @param sklearnVersion the scikit-learn version, possibly {@code null}
     * @return a function from category index to initial prediction
     * @throws IllegalArgumentException if {@code loss} is a {@link BaseLoss} and
     *         {@code sklearnVersion} is {@code null} or older than {@code 1.4.0}
     */
    private static IntFunction<Number> resolveInitialPredictions(PythonObject loss, HasPriorProbability init, String sklearnVersion) {
        // Bound first so that a null init fails here, before any loss-specific work.
        IntFunction<Number> priorProbabilities = init::getPriorProbability;

        if (loss instanceof LossFunction) {
            LossFunction lossFunction = (LossFunction) loss;
            if (sklearnVersion != null && VersionUtil.compareVersion(sklearnVersion, LOSS_FUNCTION_INITIAL_PREDICTIONS_VERSION) >= 0) {
                List<? extends Number> computedInitialPredictions = lossFunction.computeInitialPredictions(init);
                return computedInitialPredictions::get;
            }
        } else if (loss instanceof BaseLoss) {
            BaseLoss baseLoss = (BaseLoss) loss;
            if (sklearnVersion == null || VersionUtil.compareVersion(sklearnVersion, BASE_LOSS_MINIMUM_VERSION) < 0) {
                throw new IllegalArgumentException("Loss of type " + BaseLoss.class.getName() + " requires scikit-learn version " + BASE_LOSS_MINIMUM_VERSION + " or newer, got " + sklearnVersion);
            }
            Link link = baseLoss.getLink();
            int numClasses = baseLoss.getNumClasses();
            List<? extends Number> computedInitialPredictions = link.computeInitialPredictions(numClasses, init);
            return computedInitialPredictions::get;
        }

        return priorProbabilities;
    }

    @Override
    public Schema configureSchema(Schema schema) {
        return TreeUtil.configureSchema(this, schema);
    }

    @Override
    public Model configureModel(Model model) {
        return TreeUtil.configureModel(this, model);
    }

    /**
     * @return the contents of the {@code estimators_} attribute as a flat list of tree regressors
     */
    @Override
    public List<TreeRegressor> getEstimators() {
        return getArray("estimators_", TreeRegressor.class);
    }

    /**
     * @return the {@code init_} attribute
     */
    public HasPriorProbability getInit() {
        return get("init_", HasPriorProbability.class);
    }

    /**
     * @return the {@code learning_rate} attribute
     */
    public Number getLearningRate() {
        return getNumber("learning_rate");
    }

    /**
     * Returns the loss object of the estimator.
     *
     * <p>The {@code loss_} attribute is used when present. Otherwise the {@code _loss}
     * attribute is read, first as a {@link LossFunction} and, failing that, as a
     * {@link BaseLoss}.
     *
     * @return either a {@link LossFunction} or a {@link BaseLoss}
     */
    public PythonObject getLoss() {
        if (hasattr("loss_")) {
            return get("loss_", LossFunction.class);
        }
        try {
            return get("_loss", LossFunction.class);
        } catch (AttributeException ignored) {
            // Deliberate fallback: the attribute could not be read as a LossFunction,
            // so retry with the other supported loss representation.
            return get("_loss", BaseLoss.class);
        }
    }

    /**
     * Estimator view that delegates to the enclosing classifier while letting subclasses
     * supply their own estimator list. Kept as an inner (non-static) class because
     * {@link #getEstimator()} returns the enclosing instance.
     */
    private abstract class GradientBoostingClassifierProxy
    extends EstimatorProxy
    implements HasEstimatorEnsemble<TreeRegressor>,
    HasTreeOptions {
        private GradientBoostingClassifierProxy() {
        }

        @Override
        public Estimator getEstimator() {
            return GradientBoostingClassifier.this;
        }
    }
}

