/*
 * Decompiled with CFR 0.152.
 */
package ml.dmlc.xgboost4j.java;

import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.KryoSerializable;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.IEvaluation;
import ml.dmlc.xgboost4j.java.IObjective;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.java.XGBoostJNI;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Java wrapper around a native XGBoost booster handle.
 *
 * <p>Every operation delegates to {@link XGBoostJNI} using the native handle held by this
 * object; results of native calls are passed through {@code XGBoostJNI.checkCall}, which may
 * throw {@link XGBoostError}. The native handle is released by {@link #dispose()}, which
 * {@link #finalize()} also invokes as a fallback.
 *
 * <p>Instances support both Java serialization and Kryo serialization; both forms store the
 * {@code version} counter together with the raw model bytes from {@link #toByteArray()}.
 */
public class Booster
implements Serializable,
KryoSerializable {
    // NOTE(review): currently unused, but kept on purpose. Removing this static field would
    // remove the class's static initializer, which changes the default serialVersionUID and
    // would break deserialization of previously serialized boosters.
    private static final Log logger = LogFactory.getLog(Booster.class);
    /** Native booster handle; {@code 0L} before {@link #init} and after {@link #dispose()}. */
    private long handle = 0L;
    /** Checkpoint version counter; incremented by {@link #saveRabitCheckpoint()}. */
    private int version = 0;

    /**
     * Creates a native booster over the given matrices and applies the given parameters.
     *
     * <p>{@code "seed"} is set to {@code "0"} before {@code params} are applied, so an explicit
     * seed in {@code params} overrides it.
     *
     * @param params    booster parameters; may be {@code null} (no parameters applied)
     * @param cacheMats matrices whose native handles are passed to the booster constructor
     * @throws XGBoostError if a native call fails
     */
    Booster(Map<String, Object> params, DMatrix[] cacheMats) throws XGBoostError {
        this.init(cacheMats);
        this.setParam("seed", "0");
        this.setParams(params);
    }

    /**
     * Loads a model from a path via the native loader.
     *
     * @param modelPath path handed to {@code XGBoosterLoadModel}
     * @return a booster holding the loaded model
     * @throws NullPointerException if {@code modelPath} is {@code null}
     * @throws XGBoostError         if the native load fails
     */
    static Booster loadModel(String modelPath) throws XGBoostError {
        if (modelPath == null) {
            throw new NullPointerException("modelPath : null");
        }
        Booster booster = Booster.newEmptyBooster();
        try {
            XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModel(booster.handle, modelPath));
        } catch (XGBoostError e) {
            // Release the native handle now instead of waiting for finalization.
            booster.dispose();
            throw e;
        }
        return booster;
    }

    /**
     * Reads the whole stream into memory, closes it, and loads the bytes as a model.
     *
     * <p>The stream is closed even if reading fails.
     *
     * @param modelStream stream containing the serialized model
     * @return a booster holding the loaded model
     * @throws NullPointerException if {@code modelStream} is {@code null}
     * @throws IOException          if reading or closing the stream fails
     * @throws XGBoostError         if the native load fails
     */
    static Booster loadModel(InputStream modelStream) throws XGBoostError, IOException {
        if (modelStream == null) {
            throw new NullPointerException("modelStream : null");
        }
        byte[] modelBytes;
        try {
            modelBytes = Booster.readFully(modelStream);
        } finally {
            modelStream.close();
        }
        Booster booster = Booster.newEmptyBooster();
        try {
            booster.loadModelFromBuffer(modelBytes);
        } catch (XGBoostError e) {
            // Release the native handle now instead of waiting for finalization.
            booster.dispose();
            throw e;
        }
        return booster;
    }

    /**
     * Sets a single booster parameter; the value is converted with {@code toString()}.
     *
     * @param key   parameter name
     * @param value parameter value; must not be {@code null}
     * @throws NullPointerException if {@code value} is {@code null}
     * @throws XGBoostError         if the native call fails
     */
    public final void setParam(String key, Object value) throws XGBoostError {
        if (value == null) {
            throw new NullPointerException("value of param " + key + " : null");
        }
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSetParam(this.handle, key, value.toString()));
    }

    /**
     * Applies every entry of {@code params} via {@link #setParam(String, Object)}.
     *
     * @param params parameters to apply; {@code null} is treated as "nothing to apply"
     * @throws NullPointerException if any value in {@code params} is {@code null}
     * @throws XGBoostError         if a native call fails
     */
    public void setParams(Map<String, Object> params) throws XGBoostError {
        if (params == null) {
            return;
        }
        for (Map.Entry<String, Object> entry : params.entrySet()) {
            this.setParam(entry.getKey(), entry.getValue());
        }
    }

    /**
     * Runs one native training iteration on {@code dtrain}.
     *
     * @param dtrain training matrix
     * @param iter   iteration index passed to the native update call
     * @throws XGBoostError if the native call fails
     */
    public void update(DMatrix dtrain, int iter) throws XGBoostError {
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterUpdateOneIter(this.handle, iter, dtrain.getHandle()));
    }

    /**
     * Runs one training iteration with a custom objective.
     *
     * <p>Margin predictions on {@code dtrain} are handed to {@code obj}. The first two arrays of
     * the returned list are then used as gradient and hessian for {@link #boost}.
     *
     * @param dtrain training matrix
     * @param obj    custom objective producing {@code [grad, hess]}
     * @throws XGBoostError if a native call fails
     */
    public void update(DMatrix dtrain, IObjective obj) throws XGBoostError {
        float[][] margins = this.predict(dtrain, true, 0, false, false);
        List<float[]> gradients = obj.getGradient(margins, dtrain);
        this.boost(dtrain, gradients.get(0), gradients.get(1));
    }

    /**
     * Performs one boosting step from explicit gradient and hessian arrays.
     *
     * @param dtrain training matrix
     * @param grad   first-order gradients
     * @param hess   second-order gradients; must have the same length as {@code grad}
     * @throws AssertionError if the two arrays differ in length
     * @throws XGBoostError   if the native call fails
     */
    public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError {
        if (grad.length != hess.length) {
            throw new AssertionError((Object)String.format("grad/hess length mismatch %s / %s", grad.length, hess.length));
        }
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterBoostOneIter(this.handle, dtrain.getHandle(), grad, hess));
    }

    /**
     * Evaluates the booster's configured metrics on the given matrices.
     *
     * @param evalMatrices matrices to evaluate
     * @param evalNames    names reported for each matrix
     * @param iter         iteration index passed to the native evaluation call
     * @return the evaluation string produced natively
     * @throws XGBoostError if the native call fails
     */
    public String evalSet(DMatrix[] evalMatrices, String[] evalNames, int iter) throws XGBoostError {
        long[] matrixHandles = Booster.dmatrixsToHandles(evalMatrices);
        String[] evalInfo = new String[1];
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterEvalOneIter(this.handle, iter, matrixHandles, evalNames, evalInfo));
        return evalInfo[0];
    }

    /**
     * Like {@link #evalSet(DMatrix[], String[], int)}, additionally parsing the metric values.
     *
     * <p>The result string is split on tabs. For every field after the first, the text after
     * the first {@code ':'} is parsed as a float into {@code metricsOut[i - 1]}.
     *
     * @param evalMatrices matrices to evaluate
     * @param evalNames    names reported for each matrix
     * @param iter         iteration index passed to the native evaluation call
     * @param metricsOut   receives the parsed metric values; must be large enough
     * @return the evaluation string produced natively
     * @throws XGBoostError if the native call fails
     */
    public String evalSet(DMatrix[] evalMatrices, String[] evalNames, int iter, float[] metricsOut) throws XGBoostError {
        String evalInfo = this.evalSet(evalMatrices, evalNames, iter);
        String[] fields = evalInfo.split("\t");
        // fields[0] is skipped; presumably the iteration label, each later field "<name>-<metric>:<value>".
        for (int i = 1; i < fields.length; ++i) {
            metricsOut[i - 1] = Float.parseFloat(fields[i].split(":")[1]);
        }
        return evalInfo;
    }

    /**
     * Evaluates a custom metric on each matrix, discarding the numeric results.
     *
     * @param evalMatrices matrices to evaluate (indexed in parallel with {@code evalNames})
     * @param evalNames    names reported for each matrix
     * @param eval         custom evaluation
     * @return a string of {@code "\t<name>-<metric>:<value>"} fields
     * @throws XGBoostError if prediction fails
     */
    public String evalSet(DMatrix[] evalMatrices, String[] evalNames, IEvaluation eval) throws XGBoostError {
        return this.evalSet(evalMatrices, evalNames, eval, new float[evalNames.length]);
    }

    /**
     * Evaluates a custom metric on the plain predictions of each matrix.
     *
     * @param evalMatrices matrices to evaluate (indexed in parallel with {@code evalNames})
     * @param evalNames    names reported for each matrix
     * @param eval         custom evaluation
     * @param metricsOut   receives one metric value per name
     * @return a string of {@code "\t<name>-<metric>:<value>"} fields
     * @throws XGBoostError if prediction fails
     */
    public String evalSet(DMatrix[] evalMatrices, String[] evalNames, IEvaluation eval, float[] metricsOut) throws XGBoostError {
        StringBuilder evalInfo = new StringBuilder();
        for (int i = 0; i < evalNames.length; ++i) {
            DMatrix matrix = evalMatrices[i];
            float metric = eval.eval(this.predict(matrix), matrix);
            evalInfo.append(String.format("\t%s-%s:%f", evalNames[i], eval.getMetric(), Float.valueOf(metric)));
            metricsOut[i] = metric;
        }
        return evalInfo.toString();
    }

    /**
     * Runs native prediction and reshapes the flat result into one row per matrix row.
     *
     * <p>The boolean flags are mapped to the native option mask with this precedence:
     * {@code predContribs} (4), then {@code predLeaf} (2), then {@code outputMargin} (1);
     * none set gives 0.
     *
     * @param data         matrix to predict on
     * @param outputMargin request option 1
     * @param treeLimit    tree limit passed to the native call
     * @param predLeaf     request option 2
     * @param predContribs request option 4
     * @return {@code rowNum()} rows of {@code raw.length / rowNum()} columns; empty for a
     *         matrix with zero rows
     * @throws XGBoostError if the native call fails
     */
    private synchronized float[][] predict(DMatrix data, boolean outputMargin, int treeLimit, boolean predLeaf, boolean predContribs) throws XGBoostError {
        int optionMask = 0;
        if (predContribs) {
            optionMask = 4;
        } else if (predLeaf) {
            optionMask = 2;
        } else if (outputMargin) {
            optionMask = 1;
        }
        float[][] rawHolder = new float[1][];
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterPredict(this.handle, data.getHandle(), optionMask, treeLimit, rawHolder));
        float[] raw = rawHolder[0];
        int rows = (int)data.rowNum();
        if (rows == 0) {
            // Without this guard the column computation below divides by zero.
            return new float[0][];
        }
        int cols = raw.length / rows;
        float[][] predictions = new float[rows][cols];
        for (int r = 0; r < rows; ++r) {
            System.arraycopy(raw, r * cols, predictions[r], 0, cols);
        }
        return predictions;
    }

    /**
     * Predicts with native option mask 2 (leaf prediction).
     *
     * @param data      matrix to predict on
     * @param treeLimit tree limit passed to the native call
     * @return one row per matrix row
     * @throws XGBoostError if the native call fails
     */
    public float[][] predictLeaf(DMatrix data, int treeLimit) throws XGBoostError {
        return this.predict(data, false, treeLimit, true, false);
    }

    /**
     * Predicts with native option mask 4 (feature contributions).
     *
     * @param data      matrix to predict on
     * @param treeLimit tree limit passed to the native call
     * @return one row per matrix row
     * @throws XGBoostError if the native call fails
     */
    public float[][] predictContrib(DMatrix data, int treeLimit) throws XGBoostError {
        return this.predict(data, false, treeLimit, true, true);
    }

    /**
     * Plain prediction (option mask 0, tree limit 0).
     *
     * @param data matrix to predict on
     * @return one row per matrix row
     * @throws XGBoostError if the native call fails
     */
    public float[][] predict(DMatrix data) throws XGBoostError {
        return this.predict(data, false, 0, false, false);
    }

    /**
     * Prediction with optional margin output (tree limit 0).
     *
     * @param data         matrix to predict on
     * @param outputMargin whether to request option mask 1
     * @return one row per matrix row
     * @throws XGBoostError if the native call fails
     */
    public float[][] predict(DMatrix data, boolean outputMargin) throws XGBoostError {
        return this.predict(data, outputMargin, 0, false, false);
    }

    /**
     * Prediction with optional margin output and an explicit tree limit.
     *
     * @param data         matrix to predict on
     * @param outputMargin whether to request option mask 1
     * @param treeLimit    tree limit passed to the native call
     * @return one row per matrix row
     * @throws XGBoostError if the native call fails
     */
    public float[][] predict(DMatrix data, boolean outputMargin, int treeLimit) throws XGBoostError {
        return this.predict(data, outputMargin, treeLimit, false, false);
    }

    /**
     * Saves the model to a path via the native saver.
     *
     * @param modelPath path handed to {@code XGBoosterSaveModel}
     * @throws XGBoostError if the native call fails
     */
    public void saveModel(String modelPath) throws XGBoostError {
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSaveModel(this.handle, modelPath));
    }

    /**
     * Writes {@link #toByteArray()} to the stream and closes it, even if writing fails.
     *
     * @param out destination stream; always closed by this method
     * @throws XGBoostError if fetching the model bytes fails
     * @throws IOException  if writing or closing fails
     */
    public void saveModel(OutputStream out) throws XGBoostError, IOException {
        try {
            out.write(this.toByteArray());
        } finally {
            out.close();
        }
    }

    /**
     * Dumps the model in {@code "text"} format.
     *
     * @see #getModelDump(String, boolean, String)
     */
    public String[] getModelDump(String featureMap, boolean withStats) throws XGBoostError {
        return this.getModelDump(featureMap, withStats, "text");
    }

    /**
     * Dumps the model via {@code XGBoosterDumpModelEx}.
     *
     * @param featureMap feature-map argument for the native dump (presumably a file path);
     *                   {@code null} is sent as {@code ""}
     * @param withStats  whether to pass the stats flag (1) to the native dump
     * @param format     dump format; {@code null} means {@code "text"}
     * @return one string per entry returned natively
     * @throws XGBoostError if the native call fails
     */
    public String[] getModelDump(String featureMap, boolean withStats, String format) throws XGBoostError {
        String fmap = featureMap == null ? "" : featureMap;
        String dumpFormat = format == null ? "text" : format;
        String[][] dumpHolder = new String[1][];
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterDumpModelEx(this.handle, fmap, withStats ? 1 : 0, dumpFormat, dumpHolder));
        return dumpHolder[0];
    }

    /**
     * Dumps the model in {@code "text"} format using explicit feature names.
     *
     * @see #getModelDump(String[], boolean, String)
     */
    public String[] getModelDump(String[] featureNames, boolean withStats) throws XGBoostError {
        return this.getModelDump(featureNames, withStats, "text");
    }

    /**
     * Dumps the model via {@code XGBoosterDumpModelExWithFeatures}.
     *
     * @param featureNames feature names passed to the native dump
     * @param withStats    whether to pass the stats flag (1) to the native dump
     * @param format       dump format; {@code null} means {@code "text"}
     * @return one string per entry returned natively
     * @throws XGBoostError if the native call fails
     */
    public String[] getModelDump(String[] featureNames, boolean withStats, String format) throws XGBoostError {
        String dumpFormat = format == null ? "text" : format;
        String[][] dumpHolder = new String[1][];
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterDumpModelExWithFeatures(this.handle, featureNames, withStats ? 1 : 0, dumpFormat, dumpHolder));
        return dumpHolder[0];
    }

    /**
     * Counts, per feature, how many dump lines contain a {@code [feature<...]} condition.
     *
     * @param featureNames feature names used for the text dump
     * @return feature name to occurrence count
     * @throws XGBoostError if dumping fails
     */
    public Map<String, Integer> getFeatureScore(String[] featureNames) throws XGBoostError {
        return Booster.countFeatureSplits(this.getModelDump(featureNames, false));
    }

    /**
     * Counts, per feature, how many dump lines contain a {@code [feature<...]} condition.
     *
     * @param featureMap feature-map argument for the text dump
     * @return feature name to occurrence count
     * @throws XGBoostError if dumping fails
     */
    public Map<String, Integer> getFeatureScore(String featureMap) throws XGBoostError {
        return Booster.countFeatureSplits(this.getModelDump(featureMap, false));
    }

    /**
     * Tallies the text between the first {@code '['} and the following {@code ']'} (truncated at
     * {@code '<'}) on every line of every dump entry; lines without {@code '['} are skipped.
     */
    private static Map<String, Integer> countFeatureSplits(String[] modelDump) {
        HashMap<String, Integer> counts = new HashMap<String, Integer>();
        for (String tree : modelDump) {
            for (String line : tree.split("\n")) {
                String[] parts = line.split("\\[");
                if (parts.length == 1) {
                    continue;
                }
                String feature = parts[1].split("\\]")[0].split("<")[0];
                Integer previous = counts.get(feature);
                counts.put(feature, previous == null ? 1 : previous + 1);
            }
        }
        return counts;
    }

    /** Returns the checkpoint version counter. */
    public int getVersion() {
        return this.version;
    }

    /** Overwrites the checkpoint version counter. */
    public void setVersion(int version) {
        this.version = version;
    }

    /**
     * Returns the raw model bytes from {@code XGBoosterGetModelRaw}.
     *
     * @throws XGBoostError if the native call fails
     */
    public byte[] toByteArray() throws XGBoostError {
        byte[][] rawHolder = new byte[1][];
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetModelRaw(this.handle, rawHolder));
        return rawHolder[0];
    }

    /**
     * Loads the Rabit checkpoint natively and stores the returned version.
     *
     * @return the version reported by the native call
     * @throws XGBoostError if the native call fails
     */
    int loadRabitCheckpoint() throws XGBoostError {
        int[] versionHolder = new int[1];
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, versionHolder));
        this.version = versionHolder[0];
        return this.version;
    }

    /**
     * Saves a Rabit checkpoint natively, then increments the version counter.
     *
     * @throws XGBoostError if the native call fails (the version is then left unchanged)
     */
    void saveRabitCheckpoint() throws XGBoostError {
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
        ++this.version;
    }

    /**
     * Creates the native booster and stores its handle.
     *
     * @param cacheMats matrices whose handles are passed natively; {@code null} passes a
     *                  {@code null} handle array
     */
    private void init(DMatrix[] cacheMats) throws XGBoostError {
        long[] matrixHandles = cacheMats == null ? null : Booster.dmatrixsToHandles(cacheMats);
        long[] handleHolder = new long[1];
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterCreate(matrixHandles, handleHolder));
        this.handle = handleHolder[0];
    }

    /** Collects the native handle of each matrix, in order. */
    private static long[] dmatrixsToHandles(DMatrix[] matrices) {
        long[] handles = new long[matrices.length];
        for (int i = 0; i < matrices.length; ++i) {
            handles[i] = matrices[i].getHandle();
        }
        return handles;
    }

    /** Creates a booster with no parameters and no cached matrices, ready for a model load. */
    private static Booster newEmptyBooster() throws XGBoostError {
        return new Booster(new HashMap<String, Object>(), new DMatrix[0]);
    }

    /** Loads serialized model bytes into this booster's native handle. */
    private void loadModelFromBuffer(byte[] modelBytes) throws XGBoostError {
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, modelBytes));
    }

    /** Reads {@code in} to exhaustion using a 1 MiB buffer; does not close the stream. */
    private static byte[] readFully(InputStream in) throws IOException {
        byte[] buffer = new byte[0x100000];
        ByteArrayOutputStream collected = new ByteArrayOutputStream();
        int count;
        while ((count = in.read(buffer)) != -1) {
            collected.write(buffer, 0, count);
        }
        return collected.toByteArray();
    }

    /**
     * Java serialization: writes the version as an int, then the model bytes as an object.
     *
     * <p>The model bytes are fetched before anything is written, so a native failure leaves the
     * stream untouched and is reported as an {@link IOException} instead of being swallowed.
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
        byte[] modelBytes;
        try {
            modelBytes = this.toByteArray();
        } catch (XGBoostError e) {
            throw new IOException("failed to serialize booster: " + e.getMessage(), e);
        }
        out.writeInt(this.version);
        out.writeObject(modelBytes);
    }

    /**
     * Java deserialization counterpart of {@link #writeObject}.
     *
     * <p>Creates a fresh native booster and loads the model bytes into it. A native failure
     * releases the handle and is reported as an {@link IOException}.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        this.version = in.readInt();
        byte[] modelBytes = (byte[])in.readObject();
        try {
            this.init(null);
            this.loadModelFromBuffer(modelBytes);
        } catch (XGBoostError e) {
            this.dispose();
            throw new IOException("failed to deserialize booster: " + e.getMessage(), e);
        }
    }

    /** Fallback release of the native handle; {@code super.finalize()} always runs. */
    @Override
    protected void finalize() throws Throwable {
        try {
            this.dispose();
        } finally {
            super.finalize();
        }
    }

    /**
     * Frees the native booster if it has not been freed yet.
     *
     * <p>Safe to call repeatedly: the handle is reset to {@code 0L} after freeing.
     */
    public synchronized void dispose() {
        if (this.handle != 0L) {
            XGBoostJNI.XGBoosterFree(this.handle);
            this.handle = 0L;
        }
    }

    /**
     * Kryo serialization: writes the byte length, the version, then the raw model bytes.
     *
     * @throws IllegalStateException if the model bytes cannot be obtained natively; nothing is
     *                               written in that case
     */
    @Override
    public void write(Kryo kryo, Output output) {
        byte[] modelBytes;
        try {
            modelBytes = this.toByteArray();
        } catch (XGBoostError e) {
            throw new IllegalStateException("failed to serialize booster: " + e.getMessage(), e);
        }
        output.writeInt(modelBytes.length);
        output.writeInt(this.version);
        output.write(modelBytes);
    }

    /**
     * Kryo deserialization counterpart of {@link #write(Kryo, Output)}.
     *
     * @throws IllegalStateException if the native booster cannot be created or loaded; the
     *                               native handle is released in that case
     */
    @Override
    public void read(Kryo kryo, Input input) {
        int length = input.readInt();
        this.version = input.readInt();
        byte[] modelBytes = new byte[length];
        input.readBytes(modelBytes);
        try {
            this.init(null);
            this.loadModelFromBuffer(modelBytes);
        } catch (XGBoostError e) {
            this.dispose();
            throw new IllegalStateException("failed to deserialize booster: " + e.getMessage(), e);
        }
    }
}

