/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.xgboost;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.DataField;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMML;
import org.dmg.pmml.Visitable;
import org.dmg.pmml.Visitor;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.Schema;
import org.jpmml.model.visitors.DataDictionaryCleaner;
import org.jpmml.model.visitors.DeepFieldResolver;
import org.jpmml.model.visitors.MiningSchemaCleaner;
import org.jpmml.xgboost.FeatureMap;
import org.jpmml.xgboost.GBTree;
import org.jpmml.xgboost.LinearRegression;
import org.jpmml.xgboost.LogisticClassification;
import org.jpmml.xgboost.LogisticRegression;
import org.jpmml.xgboost.ObjFunction;
import org.jpmml.xgboost.PoissonRegression;
import org.jpmml.xgboost.SoftMaxClassification;
import org.jpmml.xgboost.XGBoostDataInput;

public class Learner {
    /** Global bias read from the model header; passed through to {@link GBTree#encodeMiningModel}. */
    private float base_score;
    /** Feature count read from the model header. */
    private int num_features;
    /** Class count read from the model header; used to configure softmax objectives. */
    private int num_class;
    /** Non-zero when the model stream carries a trailing string-to-string attribute map. */
    private int contain_extra_attrs;
    /** Objective function selected by name in {@link #load(XGBoostDataInput)}. */
    private ObjFunction obj;
    /** Tree ensemble read in {@link #load(XGBoostDataInput)}. */
    private GBTree gbtree;
    /** Extra attributes; stays {@code null} unless {@link #contain_extra_attrs} is non-zero. */
    private Map<String, String> attributes = null;

    /**
     * Reads a learner from the given model stream.
     *
     * <p>Values are consumed in this order: base score, feature count, class count,
     * extra-attributes flag, a reserved section, the objective function name, the booster
     * name, the tree ensemble and finally, only when the flag is non-zero, a string map of
     * attributes.
     *
     * @param input the stream, positioned at the start of the learner section
     * @throws IOException if reading from {@code input} fails
     * @throws IllegalArgumentException if the objective function or booster name is unsupported
     */
    public void load(XGBoostDataInput input) throws IOException {
        this.base_score = input.readFloat();
        this.num_features = input.readInt();
        this.num_class = input.readInt();
        this.contain_extra_attrs = input.readInt();

        // NOTE(review): presumably skips the fixed-size reserved block of the header;
        // confirm the unit (ints vs. bytes) against XGBoostDataInput.readReserved.
        input.readReserved(30);

        // num_class must already be known here, because softmax objectives depend on it.
        this.obj = createObjFunction(input.readString(), this.num_class);

        checkBooster(input.readString());

        this.gbtree = new GBTree();
        this.gbtree.load(input);

        if (this.contain_extra_attrs != 0) {
            this.attributes = input.readStringMap();
        }
    }

    /**
     * Maps an XGBoost objective function name to its converter implementation.
     *
     * @param name the objective name as read from the model stream
     * @param numClass the class count, used only by the {@code multi:*} objectives
     * @return a new objective function instance
     * @throws IllegalArgumentException if {@code name} is not a supported objective
     */
    private static ObjFunction createObjFunction(String name, int numClass) {
        switch (name) {
            case "reg:linear":
                return new LinearRegression();
            case "reg:logistic":
                return new LogisticRegression();
            case "count:poisson":
                return new PoissonRegression();
            case "binary:logistic":
                return new LogisticClassification();
            case "multi:softmax":
            case "multi:softprob":
                return new SoftMaxClassification(numClass);
            default:
                throw new IllegalArgumentException(name);
        }
    }

    /**
     * Rejects every booster other than {@code "gbtree"}, the only one this class can load.
     *
     * @param name the booster name as read from the model stream
     * @throws IllegalArgumentException if {@code name} is not {@code "gbtree"}
     */
    private static void checkBooster(String name) {
        if (!name.equals("gbtree")) {
            throw new IllegalArgumentException(name);
        }
    }

    /**
     * Encodes this learner as a complete PMML document.
     *
     * @param targetField the name of the target field; {@code "_target"} is used when {@code null}
     * @param targetCategories target category labels. They are first passed through the
     *     objective's {@code prepareTargetCategories}, and its result is what gets used.
     * @param featureMap supplies the input data fields and builds the model schema
     * @return a PMML 4.3 document holding the data dictionary and a single mining model
     */
    public PMML encodePMML(FieldName targetField, List<String> targetCategories, FeatureMap featureMap) {
        if (targetField == null) {
            targetField = FieldName.create("_target");
        }

        DataField dataField = new DataField(targetField, this.obj.getOpType(), this.obj.getDataType());

        // The objective function may replace the caller's categories; use its result from here on.
        targetCategories = this.obj.prepareTargetCategories(targetCategories);
        if (targetCategories != null && !targetCategories.isEmpty()) {
            dataField.getValues().addAll(PMMLUtil.createValues(targetCategories));
        }

        Schema schema = featureMap.createSchema(targetField, targetCategories);
        MiningModel miningModel = encodeMiningModel(schema);

        // The target field comes first, followed by the feature fields.
        List<DataField> dataFields = new ArrayList<>();
        dataFields.add(dataField);
        dataFields.addAll(featureMap.getDataFields());

        DataDictionary dataDictionary = new DataDictionary(dataFields);
        PMML pmml = new PMML("4.3", PMMLUtil.createHeader(Learner.class), dataDictionary)
            .addModels(miningModel);

        // NOTE(review): judging by their names, these visitors prune unused fields from the
        // mining schema and the data dictionary; confirm against jpmml-model.
        List<? extends Visitor> visitors = Arrays.asList(new MiningSchemaCleaner(), new DataDictionaryCleaner());
        for (Visitor visitor : visitors) {
            visitor.applyTo(pmml);
        }
        return pmml;
    }

    /**
     * Encodes the tree ensemble as a mining model, using this learner's objective and base score.
     *
     * @param schema the schema describing the target and feature fields
     * @return the encoded mining model
     */
    public MiningModel encodeMiningModel(Schema schema) {
        return this.gbtree.encodeMiningModel(this.obj, this.base_score, schema);
    }

    /** @return the global bias read from the model header */
    public float getBaseScore() {
        return this.base_score;
    }

    /** @return the class count read from the model header */
    public int getNumClass() {
        return this.num_class;
    }

    /** @return the feature count read from the model header */
    public int getNumFeatures() {
        return this.num_features;
    }

    /** @return the objective function selected during {@link #load(XGBoostDataInput)} */
    public ObjFunction getObj() {
        return this.obj;
    }

    /** @return the tree ensemble read during {@link #load(XGBoostDataInput)} */
    public GBTree getGBTree() {
        return this.gbtree;
    }
}

