/*
 * Decompiled with CFR 0.152.
 */
package ml.dmlc.xgboost4j.java;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.IEvaluation;
import ml.dmlc.xgboost4j.java.IObjective;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.java.XGBoostJNI;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Java wrapper around a native XGBoost booster handle.
 *
 * <p>Each instance owns a handle obtained from {@link XGBoostJNI#XGBoosterCreate}. Call
 * {@link #dispose()} to release it deterministically; {@link #finalize()} is only a fallback.
 * When both Kryo and the {@code KryoBooster} subclass are on the classpath (detected once in
 * the static initializer), {@link #newBooster} instantiates {@code KryoBooster}. In that case,
 * constructing a plain {@code Booster} is rejected.
 */
public class Booster
implements Serializable {
    private static final Log logger;
    /** True when both Kryo and {@code KryoBooster} were found during class initialization. */
    private static final boolean USE_KRYO_BOOSTER;
    /** The {@code KryoBooster} class, or {@code null} when Kryo support is not in use. */
    private static final Class<?> KRYO_BOOSTER_CLASS;
    /** Native booster handle; {@code 0L} means not yet created or already freed. */
    private long handle = 0L;

    /**
     * Creates a booster, choosing {@code KryoBooster} when Kryo support was detected.
     *
     * @param params initial parameters applied through {@link #setParams(Map)}
     * @param cacheMats matrices whose handles are passed to the native create call
     * @return a new booster owning a native handle
     * @throws XGBoostError if native creation or parameter setting fails
     */
    static Booster newBooster(Map<String, Object> params, DMatrix[] cacheMats) throws XGBoostError {
        if (USE_KRYO_BOOSTER) {
            return Booster.newKryoBooster(params, cacheMats);
        }
        return new Booster(params, cacheMats, false);
    }

    /**
     * Reflectively instantiates {@code KryoBooster(params, cacheMats)}.
     *
     * <p>An {@link XGBoostError} thrown by the constructor is rethrown unchanged. Other
     * failures are wrapped in an {@link XGBoostError} that keeps the original cause.
     */
    private static Booster newKryoBooster(Map<String, Object> params, DMatrix[] cacheMats) throws XGBoostError {
        try {
            // NOTE(review): relies on KryoBooster declaring a single (params, cacheMats) constructor.
            Constructor<?> constructor = KRYO_BOOSTER_CLASS.getDeclaredConstructors()[0];
            return (Booster) constructor.newInstance(params, cacheMats);
        }
        catch (InvocationTargetException e) {
            // Report the constructor's own failure rather than the reflective wrapper.
            Throwable target = e.getCause();
            if (target instanceof XGBoostError) {
                throw (XGBoostError) target;
            }
            throw Booster.wrapReflectionFailure(target != null ? target : e);
        }
        catch (IllegalArgumentException | ReflectiveOperationException e) {
            throw Booster.wrapReflectionFailure(e);
        }
    }

    /** Logs {@code cause} and converts it into an XGBoostError that keeps it as its cause. */
    private static XGBoostError wrapReflectionFailure(Throwable cause) {
        logger.error("Failed to instantiate KryoBooster", cause);
        XGBoostError error = new XGBoostError(cause.getMessage());
        error.initCause(cause);
        return error;
    }

    /**
     * Creates the native booster, sets a default seed, then applies {@code params}.
     *
     * @param params parameters to apply; {@code null} applies none
     * @param cacheMats matrices passed to the native create call; may be {@code null}
     * @param isKryoBooster whether the caller is the Kryo-enabled subclass
     * @throws XGBoostError if a native call fails
     */
    protected Booster(Map<String, Object> params, DMatrix[] cacheMats, boolean isKryoBooster) throws XGBoostError {
        this(isKryoBooster);
        this.init(cacheMats);
        // The default seed goes first so that an explicit "seed" entry in params overrides it.
        this.setParam("seed", "0");
        this.setParams(params);
    }

    /**
     * Checks that the concrete class matches the detected Kryo environment. No native handle
     * is allocated here.
     *
     * @param isKryoBooster whether the caller is the Kryo-enabled subclass
     * @throws IllegalStateException if the class does not match the environment
     */
    protected Booster(boolean isKryoBooster) {
        if (USE_KRYO_BOOSTER != isKryoBooster) {
            throw new IllegalStateException(USE_KRYO_BOOSTER
                    ? "Attempt to instantiate a Booster without support for Kryo in an environment that supports Kryo."
                    : "Attempt to instantiate a Booster with support for Kryo in an environment that does not support Kryo.");
        }
    }

    /**
     * Loads a model from a file path.
     *
     * @param modelPath path handed to the native loader
     * @return a booster holding the loaded model
     * @throws NullPointerException if {@code modelPath} is {@code null}
     * @throws XGBoostError if creation or loading fails; the new booster is freed in that case
     */
    static Booster loadModel(String modelPath) throws XGBoostError {
        if (modelPath == null) {
            throw new NullPointerException("modelPath : null");
        }
        Booster ret = Booster.newBooster(new HashMap<String, Object>(), new DMatrix[0]);
        try {
            XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModel(ret.handle, modelPath));
        }
        catch (XGBoostError e) {
            // Release the native booster now instead of leaving it to finalization.
            ret.dispose();
            throw e;
        }
        return ret;
    }

    /**
     * Loads a model from a stream, reading it fully and then closing it.
     *
     * <p>The stream is closed even when reading fails.
     *
     * @param in stream containing the serialized model; always closed by this method
     * @return a booster holding the loaded model
     * @throws IOException if reading or closing the stream fails
     * @throws XGBoostError if creation or loading fails; the new booster is freed in that case
     */
    static Booster loadModel(InputStream in) throws XGBoostError, IOException {
        ByteArrayOutputStream os = new ByteArrayOutputStream();
        try (InputStream input = in) {
            byte[] buf = new byte[0x100000];
            int size;
            while ((size = input.read(buf)) != -1) {
                os.write(buf, 0, size);
            }
        }
        Booster ret = Booster.newBooster(new HashMap<String, Object>(), new DMatrix[0]);
        try {
            XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(ret.handle, os.toByteArray()));
        }
        catch (XGBoostError e) {
            ret.dispose();
            throw e;
        }
        return ret;
    }

    /**
     * Sets one booster parameter. The value is sent to the native side as {@code value.toString()}.
     *
     * @throws XGBoostError if the native call fails
     */
    public final void setParam(String key, Object value) throws XGBoostError {
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSetParam(this.handle, key, value.toString()));
    }

    /**
     * Applies every entry of {@code params} via {@link #setParam}. A {@code null} map is a no-op.
     *
     * @throws XGBoostError if any native call fails; earlier entries remain applied
     */
    public void setParams(Map<String, Object> params) throws XGBoostError {
        if (params == null) {
            return;
        }
        for (Map.Entry<String, Object> entry : params.entrySet()) {
            // setParam already stringifies the value.
            this.setParam(entry.getKey(), entry.getValue());
        }
    }

    /**
     * Runs one native training iteration on {@code dtrain}.
     *
     * @param iter iteration index passed to the native update call
     * @throws XGBoostError if the native call fails
     */
    public void update(DMatrix dtrain, int iter) throws XGBoostError {
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterUpdateOneIter(this.handle, iter, dtrain.getHandle()));
    }

    /**
     * Runs one iteration with a custom objective. Margin predictions are fed to {@code obj},
     * and the returned gradient (index 0) and hessian (index 1) are boosted.
     *
     * @throws XGBoostError if prediction or boosting fails
     */
    public void update(DMatrix dtrain, IObjective obj) throws XGBoostError {
        float[][] predicts = this.predict(dtrain, true, 0, false, false);
        List<float[]> gradients = obj.getGradient(predicts, dtrain);
        this.boost(dtrain, gradients.get(0), gradients.get(1));
    }

    /**
     * Boosts one iteration with explicitly supplied gradient and hessian.
     *
     * @throws AssertionError if {@code grad} and {@code hess} differ in length
     * @throws XGBoostError if the native call fails
     */
    public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError {
        if (grad.length != hess.length) {
            throw new AssertionError((Object)String.format("grad/hess length mismatch %s / %s", grad.length, hess.length));
        }
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterBoostOneIter(this.handle, dtrain.getHandle(), grad, hess));
    }

    /**
     * Evaluates the configured built-in metrics on each matrix.
     *
     * @return the evaluation string produced by the native library
     * @throws XGBoostError if the native call fails
     */
    public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError {
        long[] handles = Booster.dmatrixsToHandles(evalMatrixs);
        String[] evalInfo = new String[1];
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterEvalOneIter(this.handle, iter, handles, evalNames, evalInfo));
        return evalInfo[0];
    }

    /**
     * Like {@link #evalSet(DMatrix[], String[], int)}, but also parses the metric values into
     * {@code metricsOut}.
     *
     * <p>The native string is split on tabs. The first field is skipped (presumably the
     * iteration prefix — confirm against native output). Each remaining {@code name:value}
     * field's value is stored in order.
     *
     * @param metricsOut receives one value per parsed field; must be large enough
     * @return the raw evaluation string
     * @throws XGBoostError if the native call fails
     */
    public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter, float[] metricsOut) throws XGBoostError {
        String stringFormat = this.evalSet(evalMatrixs, evalNames, iter);
        String[] metricPairs = stringFormat.split("\t");
        for (int i = 1; i < metricPairs.length; ++i) {
            metricsOut[i - 1] = Float.parseFloat(metricPairs[i].split(":")[1]);
        }
        return stringFormat;
    }

    /** Evaluates a custom metric on each matrix; see the four-argument overload. */
    public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval2) throws XGBoostError {
        return this.evalSet(evalMatrixs, evalNames, eval2, new float[evalNames.length]);
    }

    /**
     * Evaluates a custom metric on each matrix, writing the values into {@code metricsOut}.
     *
     * @return one {@code "\tname-metric:value"} segment per matrix, concatenated
     * @throws XGBoostError if prediction fails
     */
    public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval2, float[] metricsOut) throws XGBoostError {
        StringBuilder evalInfo = new StringBuilder();
        for (int i = 0; i < evalNames.length; ++i) {
            String evalName = evalNames[i];
            DMatrix evalMat = evalMatrixs[i];
            float evalResult = eval2.eval(this.predict(evalMat), evalMat);
            String evalMetric = eval2.getMetric();
            evalInfo.append(String.format("\t%s-%s:%f", evalName, evalMetric, Float.valueOf(evalResult)));
            metricsOut[i] = evalResult;
        }
        return evalInfo.toString();
    }

    /**
     * Runs native prediction and reshapes the flat result into {@code rows x cols}.
     *
     * <p>Only one option bit is sent, with precedence contributions (4) &gt; leaf (2) &gt;
     * margin (1).
     *
     * @return one row per input row; an empty array when {@code data} has no rows
     * @throws XGBoostError if the native call fails, or if the result size is not a multiple
     *     of the row count
     */
    private synchronized float[][] predict(DMatrix data, boolean outputMargin, int treeLimit, boolean predLeaf, boolean predContribs) throws XGBoostError {
        int optionMask;
        if (predContribs) {
            optionMask = 4;
        } else if (predLeaf) {
            optionMask = 2;
        } else if (outputMargin) {
            optionMask = 1;
        } else {
            optionMask = 0;
        }
        float[][] rawPredicts = new float[1][];
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterPredict(this.handle, data.getHandle(), optionMask, treeLimit, rawPredicts));
        int row = (int)data.rowNum();
        if (row == 0) {
            // Avoid dividing by zero when computing the column count.
            return new float[0][0];
        }
        float[] flat = rawPredicts[0];
        if (flat.length % row != 0) {
            throw new XGBoostError(String.format(
                    "prediction size %d is not a multiple of row count %d", flat.length, row));
        }
        int col = flat.length / row;
        float[][] predicts = new float[row][col];
        for (int r = 0; r < row; ++r) {
            System.arraycopy(flat, r * col, predicts[r], 0, col);
        }
        return predicts;
    }

    /** Predicts leaf indices, using at most {@code treeLimit} trees as interpreted natively. */
    public float[][] predictLeaf(DMatrix data, int treeLimit) throws XGBoostError {
        return this.predict(data, false, treeLimit, true, false);
    }

    /** Predicts feature contributions, using at most {@code treeLimit} trees as interpreted natively. */
    public float[][] predictContrib(DMatrix data, int treeLimit) throws XGBoostError {
        // The contribution flag takes precedence in predict(), so the leaf flag is irrelevant here.
        return this.predict(data, false, treeLimit, false, true);
    }

    /** Predicts transformed outputs using all trees (tree limit 0). */
    public float[][] predict(DMatrix data) throws XGBoostError {
        return this.predict(data, false, 0, false, false);
    }

    /** Predicts outputs; raw margins when {@code outputMargin} is true. */
    public float[][] predict(DMatrix data, boolean outputMargin) throws XGBoostError {
        return this.predict(data, outputMargin, 0, false, false);
    }

    /** Predicts outputs with the given tree limit; raw margins when {@code outputMargin} is true. */
    public float[][] predict(DMatrix data, boolean outputMargin, int treeLimit) throws XGBoostError {
        return this.predict(data, outputMargin, treeLimit, false, false);
    }

    /**
     * Saves the model to a file path via the native saver.
     *
     * @throws XGBoostError if the native call fails
     */
    public void saveModel(String modelPath) throws XGBoostError {
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSaveModel(this.handle, modelPath));
    }

    /**
     * Writes {@link #toByteArray()} to {@code out}, then closes it. The stream is closed
     * even when serialization or writing fails.
     *
     * @throws XGBoostError if the model cannot be serialized
     * @throws IOException if writing or closing fails
     */
    public void saveModel(OutputStream out) throws XGBoostError, IOException {
        try (OutputStream stream = out) {
            stream.write(this.toByteArray());
        }
    }

    /** Dumps the model in {@code "text"} format; see the three-argument overload. */
    public String[] getModelDump(String featureMap, boolean withStats) throws XGBoostError {
        return this.getModelDump(featureMap, withStats, "text");
    }

    /**
     * Dumps the model, one string per element of the native result.
     *
     * @param featureMap feature-map argument; {@code null} is sent as {@code ""}
     * @param withStats whether to request statistics in the dump
     * @param format dump format; {@code null} is sent as {@code "text"}
     * @throws XGBoostError if the native call fails
     */
    public String[] getModelDump(String featureMap, boolean withStats, String format) throws XGBoostError {
        String fmap = featureMap == null ? "" : featureMap;
        String fmt = format == null ? "text" : format;
        int statsFlag = withStats ? 1 : 0;
        String[][] modelInfos = new String[1][];
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterDumpModelEx(this.handle, fmap, statsFlag, fmt, modelInfos));
        return modelInfos[0];
    }

    /**
     * Counts how often each feature appears in split conditions of the text dump.
     *
     * <p>For each dump line containing {@code '['}, the feature id is the text between the
     * first {@code '['} and the following {@code ']'}, truncated at {@code '<'}.
     *
     * @return feature id mapped to occurrence count
     * @throws XGBoostError if dumping fails
     */
    public Map<String, Integer> getFeatureScore(String featureMap) throws XGBoostError {
        String[] modelInfos = this.getModelDump(featureMap, false);
        HashMap<String, Integer> featureScore = new HashMap<String, Integer>();
        for (String tree : modelInfos) {
            for (String node : tree.split("\n")) {
                String[] array = node.split("\\[");
                if (array.length == 1) {
                    continue;
                }
                String fid = array[1].split("\\]")[0].split("<")[0];
                Integer count = featureScore.get(fid);
                featureScore.put(fid, count == null ? 1 : count + 1);
            }
        }
        return featureScore;
    }

    /**
     * Serializes the model to bytes via the native raw-model call.
     *
     * @throws XGBoostError if the native call fails
     */
    public byte[] toByteArray() throws XGBoostError {
        byte[][] bytes = new byte[1][];
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetModelRaw(this.handle, bytes));
        return bytes[0];
    }

    /** Loads a Rabit checkpoint; returns the int reported natively (presumably a version — confirm). */
    int loadRabitCheckpoint() throws XGBoostError {
        int[] out = new int[1];
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
        return out[0];
    }

    /** Saves a Rabit checkpoint of this booster. */
    void saveRabitCheckpoint() throws XGBoostError {
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
    }

    /** Creates the native booster and stores its handle; {@code null} cacheMats sends {@code null}. */
    private void init(DMatrix[] cacheMats) throws XGBoostError {
        long[] handles = cacheMats == null ? null : Booster.dmatrixsToHandles(cacheMats);
        long[] out = new long[1];
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterCreate(handles, out));
        this.handle = out[0];
    }

    /** Collects {@link DMatrix#getHandle()} of each matrix, preserving order. */
    private static long[] dmatrixsToHandles(DMatrix[] dmatrixs) {
        long[] handles = new long[dmatrixs.length];
        for (int i = 0; i < dmatrixs.length; ++i) {
            handles[i] = dmatrixs[i].getHandle();
        }
        return handles;
    }

    /**
     * Java serialization hook that writes the model bytes as a single object.
     *
     * @throws IOException if writing fails, or (with the XGBoostError as cause) if the model
     *     cannot be serialized
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
        try {
            out.writeObject(this.toByteArray());
        }
        catch (XGBoostError ex) {
            // Fail loudly: swallowing here would produce a stream readObject cannot decode.
            throw new IOException("Failed to serialize Booster: " + ex.getMessage(), ex);
        }
    }

    /**
     * Java serialization hook that reads the model bytes and rebuilds the native booster.
     *
     * @throws IOException if reading fails, or (with the XGBoostError as cause) if the model
     *     cannot be loaded
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        byte[] bytes = (byte[])in.readObject();
        try {
            this.initFromBytes(bytes);
        }
        catch (XGBoostError ex) {
            // Fail loudly rather than returning a Booster with no native handle.
            throw new IOException("Failed to deserialize Booster: " + ex.getMessage(), ex);
        }
    }

    /**
     * Creates a fresh native booster and loads {@code bytes} into it.
     *
     * @throws XGBoostError if creation or loading fails
     */
    protected void initFromBytes(byte[] bytes) throws XGBoostError {
        this.init(null);
        XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
    }

    /** Fallback release of the native handle; prefer {@link #dispose()}. */
    @Override
    protected void finalize() throws Throwable {
        try {
            this.dispose();
        } finally {
            super.finalize();
        }
    }

    /** Frees the native handle if one is held. Safe to call more than once. */
    public synchronized void dispose() {
        if (this.handle != 0L) {
            XGBoostJNI.XGBoosterFree(this.handle);
            this.handle = 0L;
        }
    }

    // Detect Kryo support once: Kryo itself must be loadable, and then KryoBooster (same package).
    static {
        Class<?> kryoBoosterClass;
        boolean useKryo;
        logger = LogFactory.getLog(Booster.class);
        try {
            Class.forName("com.esotericsoftware.kryo.KryoSerializable");
            useKryo = true;
        }
        catch (ClassNotFoundException e) {
            useKryo = false;
            logger.debug("Kryo is not available", e);
        }
        if (useKryo) {
            try {
                String kryoBoosterClassName = Booster.class.getPackage().getName() + ".KryoBooster";
                kryoBoosterClass = Class.forName(kryoBoosterClassName);
            }
            catch (ClassNotFoundException e) {
                logger.error("KryoBooster is not available", e);
                kryoBoosterClass = null;
            }
        } else {
            kryoBoosterClass = null;
        }
        USE_KRYO_BOOSTER = kryoBoosterClass != null;
        KRYO_BOOSTER_CLASS = kryoBoosterClass;
    }
}

