/*
 * Decompiled with CFR 0.152.
 */
package ml.dmlc.xgboost4j.java;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.IEvaluation;
import ml.dmlc.xgboost4j.java.IObjective;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.java.XGBoostJNI;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Java wrapper around a native XGBoost booster.
 *
 * <p>An instance owns a native handle obtained from {@code XGBoostJNI.XGBoosterCreate}
 * and released by {@link #dispose()} (also invoked from {@link #finalize()}). Every
 * operation is forwarded to {@code XGBoostJNI}; a failing native call surfaces as an
 * {@link XGBoostError} through {@code XGBoostJNI.checkCall}.
 *
 * <p>Java serialization is customised: {@code writeObject} stores the version counter
 * followed by the model bytes in {@code "ubj"} format, and {@code readObject} creates
 * a fresh native booster and loads those bytes into it.
 */
public class Booster implements Serializable {
    /** Format name handed to the native buffer serializer when the caller does not pick one. */
    public static final String DEFAULT_FORMAT = "deprecated";
    private static final Log logger = LogFactory.getLog(Booster.class);
    /** Native booster handle; {@code 0} means "no native object attached". */
    private long handle = 0L;
    /** Version counter maintained by the Rabit checkpoint methods and {@link #setVersion(int)}. */
    private int version = 0;

    /**
     * Creates a native booster and applies the given parameters to it.
     *
     * @param map          parameters to set; {@code null} is accepted and sets nothing
     * @param dMatrixArray matrices whose handles are passed to the native constructor;
     *                     {@code null} is forwarded to the native call as a null handle array
     * @throws XGBoostError if a native call fails
     */
    Booster(Map<String, Object> map, DMatrix[] dMatrixArray) throws XGBoostError {
        this.init(dMatrixArray);
        this.setParams(map);
    }

    /**
     * Creates a booster and loads a model from the given path.
     *
     * @param string path of the model to load; must not be {@code null}
     * @return the loaded booster
     * @throws NullPointerException if {@code string} is {@code null}
     * @throws XGBoostError         if a native call fails
     */
    static Booster loadModel(String string) throws XGBoostError {
        if (string == null) {
            throw new NullPointerException("modelPath : null");
        }
        Booster loaded = new Booster(new HashMap<String, Object>(), new DMatrix[0]);
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModel(loaded.handle, string));
        return loaded;
    }

    /**
     * Creates a booster and loads a model from an in-memory buffer.
     *
     * @param byArray serialized model bytes; must not be {@code null}
     * @return the loaded booster
     * @throws NullPointerException if {@code byArray} is {@code null}
     * @throws XGBoostError         if a native call fails
     */
    static Booster loadModel(byte[] byArray) throws XGBoostError {
        // Reject null up front (mirrors the path-based overload) so that a null array
        // never reaches native code and no native booster is allocated needlessly.
        if (byArray == null) {
            throw new NullPointerException("buffer : null");
        }
        Booster loaded = new Booster(new HashMap<String, Object>(), new DMatrix[0]);
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(loaded.handle, byArray));
        return loaded;
    }

    /**
     * Sets a single booster parameter; the value is passed as {@code object.toString()}.
     *
     * @param string parameter name
     * @param object parameter value; must not be {@code null}
     * @throws XGBoostError if the native call fails
     */
    public final void setParam(String string, Object object) throws XGBoostError {
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSetParam(this.handle, string, object.toString()));
    }

    /**
     * Sets every entry of the map via {@link #setParam(String, Object)}.
     *
     * @param map parameters to set; {@code null} is a no-op. Values must not be {@code null}.
     * @throws XGBoostError if a native call fails
     */
    public void setParams(Map<String, Object> map) throws XGBoostError {
        if (map == null) {
            return;
        }
        for (Map.Entry<String, Object> param : map.entrySet()) {
            this.setParam(param.getKey(), param.getValue().toString());
        }
    }

    /**
     * Reads all attributes: fetches the attribute names natively, then looks each one up
     * with {@link #getAttr(String)}.
     *
     * @return a new map from attribute name to value
     * @throws XGBoostError if a native call fails
     */
    public final Map<String, String> getAttrs() throws XGBoostError {
        String[][] namesOut = new String[1][];
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetAttrNames(this.handle, namesOut));
        HashMap<String, String> attrs = new HashMap<String, String>();
        for (String name : namesOut[0]) {
            attrs.put(name, this.getAttr(name));
        }
        return attrs;
    }

    /**
     * Reads one attribute.
     *
     * @param string attribute name
     * @return whatever the native call stored in the output slot
     * @throws XGBoostError if the native call fails
     */
    public final String getAttr(String string) throws XGBoostError {
        String[] valueOut = new String[1];
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetAttr(this.handle, string, valueOut));
        return valueOut[0];
    }

    /**
     * Sets one attribute.
     *
     * @param string  attribute name
     * @param string2 attribute value
     * @throws XGBoostError if the native call fails
     */
    public final void setAttr(String string, String string2) throws XGBoostError {
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSetAttr(this.handle, string, string2));
    }

    /**
     * Sets every entry of the map via {@link #setAttr(String, String)}.
     *
     * @param map attributes to set; {@code null} is a no-op
     * @throws XGBoostError if a native call fails
     */
    public void setAttrs(Map<String, String> map) throws XGBoostError {
        if (map == null) {
            return;
        }
        for (Map.Entry<String, String> attr : map.entrySet()) {
            this.setAttr(attr.getKey(), attr.getValue());
        }
    }

    /**
     * Reads the {@code "feature_name"} string info into an array sized by
     * {@link #getNumFeature()}.
     *
     * @return the array filled by the native call
     * @throws XGBoostError if a native call fails
     */
    public final String[] getFeatureNames() throws XGBoostError {
        String[] names = new String[(int) this.getNumFeature()];
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetStrFeatureInfo(this.handle, "feature_name", names));
        return names;
    }

    /**
     * Stores the given values as the {@code "feature_name"} string info.
     *
     * @param stringArray feature names
     * @throws XGBoostError if the native call fails
     */
    public void setFeatureNames(String[] stringArray) throws XGBoostError {
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSetStrFeatureInfo(this.handle, "feature_name", stringArray));
    }

    /**
     * Reads the {@code "feature_type"} string info into an array sized by
     * {@link #getNumFeature()}.
     *
     * @return the array filled by the native call
     * @throws XGBoostError if a native call fails
     */
    public final String[] getFeatureTypes() throws XGBoostError {
        String[] types = new String[(int) this.getNumFeature()];
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetStrFeatureInfo(this.handle, "feature_type", types));
        return types;
    }

    /**
     * Stores the given values as the {@code "feature_type"} string info.
     *
     * @param stringArray feature types
     * @throws XGBoostError if the native call fails
     */
    public void setFeatureTypes(String[] stringArray) throws XGBoostError {
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSetStrFeatureInfo(this.handle, "feature_type", stringArray));
    }

    /**
     * Runs one native update iteration on the given matrix.
     *
     * @param dMatrix training data
     * @param n       iteration number forwarded to the native call
     * @throws XGBoostError if the native call fails
     */
    public void update(DMatrix dMatrix, int n) throws XGBoostError {
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterUpdateOneIter(this.handle, n, dMatrix.getHandle()));
    }

    /**
     * Runs one boosting step with a custom objective: predicts with option mask 1,
     * asks the objective for its gradient list and boosts with elements 0 and 1 of it.
     *
     * @param dMatrix    training data
     * @param iObjective objective whose {@code getGradient} result supplies grad (index 0)
     *                   and hess (index 1)
     * @throws XGBoostError if a native call fails
     */
    public void update(DMatrix dMatrix, IObjective iObjective) throws XGBoostError {
        float[][] predicts = this.predict(dMatrix, true, 0, false, false);
        List<float[]> gradients = iObjective.getGradient(predicts, dMatrix);
        this.boost(dMatrix, gradients.get(0), gradients.get(1));
    }

    /**
     * Runs one native boosting iteration with explicit gradient statistics.
     *
     * @param dMatrix training data
     * @param fArray  first-order gradients
     * @param fArray2 second-order gradients; must have the same length as {@code fArray}
     * @throws AssertionError if the two arrays differ in length
     * @throws XGBoostError   if the native call fails
     */
    public void boost(DMatrix dMatrix, float[] fArray, float[] fArray2) throws XGBoostError {
        if (fArray.length != fArray2.length) {
            throw new AssertionError(String.format("grad/hess length mismatch %s / %s", fArray.length, fArray2.length));
        }
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterBoostOneIter(this.handle, dMatrix.getHandle(), fArray, fArray2));
    }

    /**
     * Evaluates the given matrices natively.
     *
     * @param dMatrixArray matrices to evaluate
     * @param stringArray  names forwarded to the native call alongside the matrices
     * @param n            iteration number forwarded to the native call
     * @return the evaluation string produced by the native call
     * @throws XGBoostError if the native call fails
     */
    public String evalSet(DMatrix[] dMatrixArray, String[] stringArray, int n) throws XGBoostError {
        long[] handles = Booster.dmatrixsToHandles(dMatrixArray);
        String[] evalInfoOut = new String[1];
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterEvalOneIter(this.handle, n, handles, stringArray, evalInfoOut));
        return evalInfoOut[0];
    }

    /**
     * Evaluates natively and additionally parses the metric values out of the result.
     *
     * <p>The result is split on tabs; the first token is skipped and every following
     * token is expected to look like {@code name:value}. {@code "nan"} and {@code "-nan"}
     * (case-insensitive) are stored as {@link Float#NaN}.
     *
     * @param dMatrixArray matrices to evaluate
     * @param stringArray  names forwarded to the native call alongside the matrices
     * @param n            iteration number forwarded to the native call
     * @param fArray       output array; slot {@code i - 1} receives the value of token {@code i}
     * @return the unparsed evaluation string
     * @throws XGBoostError if the native call fails
     */
    public String evalSet(DMatrix[] dMatrixArray, String[] stringArray, int n, float[] fArray) throws XGBoostError {
        String evalInfo = this.evalSet(dMatrixArray, stringArray, n);
        String[] metricPairs = evalInfo.split("\t");
        for (int i = 1; i < metricPairs.length; ++i) {
            String value = metricPairs[i].split(":")[1];
            if (value.equalsIgnoreCase("nan") || value.equalsIgnoreCase("-nan")) {
                fArray[i - 1] = Float.NaN;
            } else {
                fArray[i - 1] = Float.parseFloat(value);
            }
        }
        return evalInfo;
    }

    /**
     * Evaluates with a custom metric, discarding the numeric outputs.
     *
     * @param dMatrixArray matrices to evaluate
     * @param stringArray  one name per matrix
     * @param iEvaluation  custom metric
     * @return the formatted evaluation string
     * @throws XGBoostError if a native call fails
     */
    public String evalSet(DMatrix[] dMatrixArray, String[] stringArray, IEvaluation iEvaluation) throws XGBoostError {
        return this.evalSet(dMatrixArray, stringArray, iEvaluation, new float[stringArray.length]);
    }

    /**
     * Evaluates with a custom metric. For each index of {@code stringArray} the matching
     * matrix is predicted, scored by {@code iEvaluation}, and appended to the result as
     * {@code "\t<name>-<metric>:<value>"}.
     *
     * @param dMatrixArray matrices to evaluate, indexed in parallel with {@code stringArray}
     * @param stringArray  one name per matrix
     * @param iEvaluation  custom metric
     * @param fArray       output array; slot {@code i} receives the score of matrix {@code i}
     * @return the formatted evaluation string
     * @throws XGBoostError if a native call fails
     */
    public String evalSet(DMatrix[] dMatrixArray, String[] stringArray, IEvaluation iEvaluation, float[] fArray) throws XGBoostError {
        StringBuilder evalInfo = new StringBuilder();
        for (int i = 0; i < stringArray.length; ++i) {
            String evalName = stringArray[i];
            DMatrix evalMat = dMatrixArray[i];
            float metricValue = iEvaluation.eval(this.predict(evalMat), evalMat);
            String metricName = iEvaluation.getMetric();
            evalInfo.append(String.format("\t%s-%s:%f", evalName, metricName, Float.valueOf(metricValue)));
            fArray[i] = metricValue;
        }
        return evalInfo.toString();
    }

    /**
     * Core prediction routine shared by all public predict variants.
     *
     * <p>The option mask is chosen by the last flag that is set: {@code bl} gives 1,
     * {@code bl2} gives 2, {@code bl3} gives 4 (later flags overwrite earlier ones; they
     * are not OR-ed together). The flat native result is reshaped into one row per
     * matrix row.
     *
     * @param dMatrix data to predict on
     * @param bl      selects option mask 1
     * @param n       tree limit forwarded to the native call
     * @param bl2     selects option mask 2
     * @param bl3     selects option mask 4
     * @return predictions with {@code dMatrix.rowNum()} rows; an empty array when the
     *         matrix reports zero rows
     * @throws XGBoostError if a native call fails
     */
    private synchronized float[][] predict(DMatrix dMatrix, boolean bl, int n, boolean bl2, boolean bl3) throws XGBoostError {
        int optionMask = 0;
        if (bl) {
            optionMask = 1;
        }
        if (bl2) {
            optionMask = 2;
        }
        if (bl3) {
            optionMask = 4;
        }
        float[][] rawOut = new float[1][];
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterPredict(this.handle, dMatrix.getHandle(), optionMask, n, rawOut));
        int rowCount = (int) dMatrix.rowNum();
        if (rowCount == 0) {
            // Nothing to reshape; also avoids dividing by zero below.
            return new float[0][];
        }
        float[] flat = rawOut[0];
        int colCount = flat.length / rowCount;
        float[][] predicts = new float[rowCount][colCount];
        for (int i = 0; i < flat.length; ++i) {
            predicts[i / colCount][i % colCount] = flat[i];
        }
        return predicts;
    }

    /**
     * Predicts with option mask 2 (the native "leaf" option, going by the method name).
     *
     * @param dMatrix data to predict on
     * @param n       tree limit forwarded to the native call
     * @return reshaped native prediction output
     * @throws XGBoostError if a native call fails
     */
    public float[][] predictLeaf(DMatrix dMatrix, int n) throws XGBoostError {
        return this.predict(dMatrix, false, n, true, false);
    }

    /**
     * Predicts with option mask 4 (the native "contribution" option, going by the method
     * name). Both trailing flags are passed as {@code true}; the last one wins.
     *
     * @param dMatrix data to predict on
     * @param n       tree limit forwarded to the native call
     * @return reshaped native prediction output
     * @throws XGBoostError if a native call fails
     */
    public float[][] predictContrib(DMatrix dMatrix, int n) throws XGBoostError {
        return this.predict(dMatrix, false, n, true, true);
    }

    /**
     * Predicts with option mask 0 and tree limit 0.
     *
     * @param dMatrix data to predict on
     * @return reshaped native prediction output
     * @throws XGBoostError if a native call fails
     */
    public float[][] predict(DMatrix dMatrix) throws XGBoostError {
        return this.predict(dMatrix, false, 0, false, false);
    }

    /**
     * Predicts with tree limit 0.
     *
     * @param dMatrix data to predict on
     * @param bl      when {@code true}, option mask 1 is used instead of 0
     * @return reshaped native prediction output
     * @throws XGBoostError if a native call fails
     */
    public float[][] predict(DMatrix dMatrix, boolean bl) throws XGBoostError {
        return this.predict(dMatrix, bl, 0, false, false);
    }

    /**
     * Predicts with an explicit tree limit.
     *
     * @param dMatrix data to predict on
     * @param bl      when {@code true}, option mask 1 is used instead of 0
     * @param n       tree limit forwarded to the native call
     * @return reshaped native prediction output
     * @throws XGBoostError if a native call fails
     */
    public float[][] predict(DMatrix dMatrix, boolean bl, int n) throws XGBoostError {
        return this.predict(dMatrix, bl, n, false, false);
    }

    /**
     * Saves the model to the given path through the native library.
     *
     * @param string destination path
     * @throws XGBoostError if the native call fails
     */
    public void saveModel(String string) throws XGBoostError {
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSaveModel(this.handle, string));
    }

    /**
     * Writes the model in {@link #DEFAULT_FORMAT} to the stream and closes the stream.
     *
     * @param outputStream destination; closed by this method
     * @throws XGBoostError if the native call fails
     * @throws IOException  if writing or closing the stream fails
     */
    public void saveModel(OutputStream outputStream) throws XGBoostError, IOException {
        this.saveModel(outputStream, DEFAULT_FORMAT);
    }

    /**
     * Writes the model in the given format to the stream and closes the stream.
     *
     * @param outputStream destination; closed by this method, also when serialization or
     *                     writing fails
     * @param string       format name forwarded to {@link #toByteArray(String)}
     * @throws XGBoostError if the native call fails
     * @throws IOException  if writing or closing the stream fails
     */
    public void saveModel(OutputStream outputStream, String string) throws XGBoostError, IOException {
        try {
            outputStream.write(this.toByteArray(string));
        } finally {
            // This method takes ownership of the stream, so release it on every path.
            outputStream.close();
        }
    }

    /**
     * Dumps the model in {@code "text"} format.
     *
     * @param string feature map path; {@code null} is replaced by the empty string
     * @param bl     whether the stats flag (1) is passed to the native call
     * @return the dump array produced by the native call
     * @throws XGBoostError if the native call fails
     */
    public String[] getModelDump(String string, boolean bl) throws XGBoostError {
        return this.getModelDump(string, bl, "text");
    }

    /**
     * Dumps the model.
     *
     * @param string  feature map path; {@code null} is replaced by the empty string
     * @param bl      whether the stats flag (1) is passed to the native call
     * @param string2 dump format; {@code null} is replaced by {@code "text"}
     * @return the dump array produced by the native call
     * @throws XGBoostError if the native call fails
     */
    public String[] getModelDump(String string, boolean bl, String string2) throws XGBoostError {
        String featureMap = string == null ? "" : string;
        String format = string2 == null ? "text" : string2;
        int statsFlag = bl ? 1 : 0;
        String[][] dumpOut = new String[1][];
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterDumpModelEx(this.handle, featureMap, statsFlag, format, dumpOut));
        return dumpOut[0];
    }

    /**
     * Dumps the model in {@code "text"} format using explicit feature names.
     *
     * @param stringArray feature names forwarded to the native call
     * @param bl          whether the stats flag (1) is passed to the native call
     * @return the dump array produced by the native call
     * @throws XGBoostError if the native call fails
     */
    public String[] getModelDump(String[] stringArray, boolean bl) throws XGBoostError {
        return this.getModelDump(stringArray, bl, "text");
    }

    /**
     * Dumps the model using explicit feature names.
     *
     * @param stringArray feature names forwarded to the native call
     * @param bl          whether the stats flag (1) is passed to the native call
     * @param string      dump format; {@code null} is replaced by {@code "text"}
     * @return the dump array produced by the native call
     * @throws XGBoostError if the native call fails
     */
    public String[] getModelDump(String[] stringArray, boolean bl, String string) throws XGBoostError {
        String format = string == null ? "text" : string;
        int statsFlag = bl ? 1 : 0;
        String[][] dumpOut = new String[1][];
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterDumpModelExWithFeatures(this.handle, stringArray, statsFlag, format, dumpOut));
        return dumpOut[0];
    }

    /**
     * Counts how often each feature appears in a split of the text dump.
     *
     * @param stringArray feature names forwarded to the model dump
     * @return map from feature name to number of occurrences
     * @throws XGBoostError if the native call fails
     */
    public Map<String, Integer> getFeatureScore(String[] stringArray) throws XGBoostError {
        return this.getFeatureWeightsFromModel(this.getModelDump(stringArray, false));
    }

    /**
     * Counts how often each feature appears in a split of the text dump.
     *
     * @param string feature map path forwarded to the model dump
     * @return map from feature name to number of occurrences
     * @throws XGBoostError if the native call fails
     */
    public Map<String, Integer> getFeatureScore(String string) throws XGBoostError {
        return this.getFeatureWeightsFromModel(this.getModelDump(string, false));
    }

    /**
     * Parses a text dump and counts feature occurrences. Every line containing a
     * {@code [feature<...]} section contributes one count to the text found between
     * {@code '['} and the first {@code '<'}; lines without {@code '['} are skipped.
     *
     * @param stringArray dump strings, each holding newline-separated lines
     * @return map from feature name to number of occurrences
     */
    private Map<String, Integer> getFeatureWeightsFromModel(String[] stringArray) throws XGBoostError {
        HashMap<String, Integer> weights = new HashMap<String, Integer>();
        for (String tree : stringArray) {
            for (String node : tree.split("\n")) {
                String[] bracketParts = node.split("\\[");
                if (bracketParts.length == 1) {
                    continue;
                }
                String feature = bracketParts[1].split("\\]")[0].split("<")[0];
                Integer seen = weights.get(feature);
                weights.put(feature, seen == null ? 1 : 1 + seen);
            }
        }
        return weights;
    }

    /**
     * Computes feature importance from a stats-enabled text dump.
     *
     * @param stringArray feature names forwarded to the model dump
     * @param string      one of {@link FeatureImportanceType#ACCEPTED_TYPES}
     * @return map from feature name to importance
     * @throws AssertionError if the importance type is not accepted
     * @throws XGBoostError   if the native call fails
     */
    public Map<String, Double> getScore(String[] stringArray, String string) throws XGBoostError {
        return this.getFeatureImportanceFromModel(this.getModelDump(stringArray, true), string);
    }

    /**
     * Computes feature importance from a stats-enabled text dump.
     *
     * @param string  feature map path forwarded to the model dump
     * @param string2 one of {@link FeatureImportanceType#ACCEPTED_TYPES}
     * @return map from feature name to importance
     * @throws AssertionError if the importance type is not accepted
     * @throws XGBoostError   if the native call fails
     */
    public Map<String, Double> getScore(String string, String string2) throws XGBoostError {
        return this.getFeatureImportanceFromModel(this.getModelDump(string, true), string2);
    }

    /**
     * Derives feature importance from a text dump.
     *
     * <ul>
     *   <li>{@code "weight"}: occurrence counts from
     *       {@link #getFeatureWeightsFromModel(String[])}, as doubles.</li>
     *   <li>{@code "total_gain"} / {@code "total_cover"}: per-feature sum of the value
     *       following {@code "gain="} / {@code "cover="} on each split line.</li>
     *   <li>{@code "gain"} / {@code "cover"}: that sum divided by the number of split
     *       lines the feature appeared on.</li>
     * </ul>
     *
     * @param stringArray dump strings, each holding newline-separated lines
     * @param string      importance type
     * @return map from feature name to importance
     * @throws AssertionError if the importance type is not accepted
     */
    private Map<String, Double> getFeatureImportanceFromModel(String[] stringArray, String string) throws XGBoostError {
        if (!FeatureImportanceType.ACCEPTED_TYPES.contains(string)) {
            throw new AssertionError(String.format("Importance type %s is not supported", string));
        }
        HashMap<String, Double> importance = new HashMap<String, Double>();
        if (string.equals(FeatureImportanceType.WEIGHT)) {
            for (Map.Entry<String, Integer> weight : this.getFeatureWeightsFromModel(stringArray).entrySet()) {
                importance.put(weight.getKey(), Double.valueOf(weight.getValue().intValue()));
            }
            return importance;
        }

        boolean coverBased = string.equals(FeatureImportanceType.COVER)
                || string.equals(FeatureImportanceType.TOTAL_COVER);
        String statKey = coverBased ? "cover=" : "gain=";

        HashMap<String, Double> occurrences = new HashMap<String, Double>();
        for (String tree : stringArray) {
            for (String node : tree.split("\n")) {
                String[] bracketParts = node.split("\\[");
                if (bracketParts.length == 1) {
                    continue;
                }
                // bracketParts[1] looks like "<feature><<threshold>] <stats>".
                String[] splitAndStats = bracketParts[1].split("\\]");
                double stat = Double.parseDouble(splitAndStats[1].split(statKey)[1].split(",")[0]);
                String feature = splitAndStats[0].split("<")[0];
                Double runningTotal = importance.get(feature);
                if (runningTotal != null) {
                    importance.put(feature, stat + runningTotal);
                    occurrences.put(feature, 1.0 + occurrences.get(feature));
                } else {
                    importance.put(feature, stat);
                    occurrences.put(feature, 1.0);
                }
            }
        }

        boolean averaged = string.equals(FeatureImportanceType.COVER)
                || string.equals(FeatureImportanceType.GAIN);
        if (averaged) {
            for (Map.Entry<String, Double> total : importance.entrySet()) {
                total.setValue(total.getValue() / occurrences.get(total.getKey()));
            }
        }
        return importance;
    }

    /**
     * Dumps the model in {@code "text"} format with an empty feature map.
     *
     * @param bl whether the stats flag (1) is passed to the native call
     * @return the dump array produced by the native call
     * @throws XGBoostError if the native call fails
     */
    private String[] getDumpInfo(boolean bl) throws XGBoostError {
        int statsFlag = bl ? 1 : 0;
        String[][] dumpOut = new String[1][];
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterDumpModelEx(this.handle, "", statsFlag, "text", dumpOut));
        return dumpOut[0];
    }

    /** @return the current version counter */
    public int getVersion() {
        return this.version;
    }

    /** @param n new value of the version counter */
    public void setVersion(int n) {
        this.version = n;
    }

    /**
     * Serializes the model in {@link #DEFAULT_FORMAT}.
     *
     * @return the serialized model
     * @throws XGBoostError if the native call fails
     */
    public byte[] toByteArray() throws XGBoostError {
        return this.toByteArray(DEFAULT_FORMAT);
    }

    /**
     * Serializes the model into a byte array.
     *
     * @param string format name forwarded to the native call
     * @return the serialized model
     * @throws XGBoostError if the native call fails
     */
    public byte[] toByteArray(String string) throws XGBoostError {
        byte[][] bytesOut = new byte[1][];
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSaveModelToBuffer(this.handle, string, bytesOut));
        return bytesOut[0];
    }

    /**
     * Loads the Rabit checkpoint and adopts the version reported by the native call.
     *
     * @return the new version counter
     * @throws XGBoostError if the native call fails
     */
    int loadRabitCheckpoint() throws XGBoostError {
        int[] versionOut = new int[1];
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, versionOut));
        this.version = versionOut[0];
        return this.version;
    }

    /**
     * Saves a Rabit checkpoint and, on success, increments the version counter.
     *
     * @throws XGBoostError if the native call fails
     */
    void saveRabitCheckpoint() throws XGBoostError {
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
        ++this.version;
    }

    /**
     * Queries the number of features from the native booster.
     *
     * @return the feature count reported by the native call
     * @throws XGBoostError if the native call fails
     */
    public long getNumFeature() throws XGBoostError {
        long[] countOut = new long[1];
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetNumFeature(this.handle, countOut));
        return countOut[0];
    }

    /**
     * Creates the native booster and stores its handle.
     *
     * @param dMatrixArray matrices to hand to the native constructor; when {@code null},
     *                     a null handle array is passed
     * @throws XGBoostError if the native call fails
     */
    private void init(DMatrix[] dMatrixArray) throws XGBoostError {
        long[] matrixHandles = dMatrixArray == null ? null : Booster.dmatrixsToHandles(dMatrixArray);
        long[] handleOut = new long[1];
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterCreate(matrixHandles, handleOut));
        this.handle = handleOut[0];
    }

    /**
     * Collects the native handles of the given matrices, preserving order.
     *
     * @param dMatrixArray matrices; must not be {@code null}
     * @return one handle per matrix
     */
    private static long[] dmatrixsToHandles(DMatrix[] dMatrixArray) {
        long[] handles = new long[dMatrixArray.length];
        for (int i = 0; i < dMatrixArray.length; ++i) {
            handles[i] = dMatrixArray[i].getHandle();
        }
        return handles;
    }

    /**
     * Serialization hook: writes the version counter followed by the model bytes in
     * {@code "ubj"} format.
     *
     * @throws IOException if the stream fails, or if the native model export fails
     *                     (the {@link XGBoostError} is attached as the cause)
     */
    private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
        byte[] modelBytes;
        try {
            // Export first so that a native failure cannot leave a half-written record
            // (version without model) in the stream.
            modelBytes = this.toByteArray("ubj");
        } catch (XGBoostError xGBoostError) {
            logger.error(xGBoostError.getMessage(), xGBoostError);
            throw new IOException("Failed to serialize booster: " + xGBoostError.getMessage(), xGBoostError);
        }
        objectOutputStream.writeInt(this.version);
        objectOutputStream.writeObject(modelBytes);
    }

    /**
     * Deserialization hook: creates a fresh native booster, then reads the version
     * counter and the model bytes and loads the model into the native booster.
     *
     * @throws IOException            if the stream fails, or if a native call fails
     *                                (the {@link XGBoostError} is attached as the cause)
     * @throws ClassNotFoundException if the stream content cannot be resolved
     */
    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        try {
            this.init(null);
            this.version = objectInputStream.readInt();
            byte[] modelBytes = (byte[]) objectInputStream.readObject();
            XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, modelBytes));
        } catch (XGBoostError xGBoostError) {
            // Do not hand out an instance without a model: free the native handle and fail.
            this.dispose();
            logger.error(xGBoostError.getMessage(), xGBoostError);
            throw new IOException("Failed to deserialize booster: " + xGBoostError.getMessage(), xGBoostError);
        }
    }

    /** Releases the native handle when the instance is garbage collected. */
    @Override
    protected void finalize() throws Throwable {
        try {
            super.finalize();
        } finally {
            this.dispose();
        }
    }

    /**
     * Frees the native booster if one is attached and resets the handle to {@code 0},
     * so repeated calls are harmless.
     */
    public synchronized void dispose() {
        if (this.handle == 0L) {
            return;
        }
        XGBoostJNI.XGBoosterFree(this.handle);
        this.handle = 0L;
    }

    /** Names of the importance types accepted by the {@code getScore} methods. */
    public static class FeatureImportanceType {
        public static final String WEIGHT = "weight";
        public static final String GAIN = "gain";
        public static final String COVER = "cover";
        public static final String TOTAL_GAIN = "total_gain";
        public static final String TOTAL_COVER = "total_cover";
        /** All importance type names that {@code getScore} accepts. */
        public static final Set<String> ACCEPTED_TYPES = new HashSet<String>(Arrays.asList(WEIGHT, GAIN, COVER, TOTAL_GAIN, TOTAL_COVER));
    }
}

