/*
 * Decompiled with CFR 0.152.
 */
package ml.dmlc.xgboost4j.java;

import java.io.IOException;
import java.util.Iterator;
import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.DataBatch;
import ml.dmlc.xgboost4j.java.JNIErrorHandle;
import ml.dmlc.xgboost4j.java.NativeLibLoader;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.java.XGBoostJNI;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Java wrapper around an XGBoost DMatrix that lives on the native side and is referenced through
 * an opaque {@code long} handle.
 *
 * <p>The native object is released by {@link #dispose()}, which {@link #finalize()} also calls as
 * a fallback. A handle value of {@code 0} means "no native matrix" (never created or already
 * disposed).
 */
public class DMatrix {
    private static final Log logger = LogFactory.getLog(DMatrix.class);

    /** Batch size handed to {@link DataBatch.BatchIterator} when building from an iterator. */
    private static final int ITERATOR_BATCH_SIZE = 32768;

    /** Opaque native handle; reset to {@code 0} by {@link #dispose()}. */
    protected long handle = 0L;

    static {
        try {
            NativeLibLoader.initXGBoost();
        } catch (IOException ex) {
            // Pass the exception as the Throwable argument so its stack trace is logged; the
            // class still finishes loading, and later native calls are expected to fail.
            logger.error("load native library failed.", ex);
        }
    }

    /**
     * Builds a matrix from an iterator of labeled points, consumed in batches of
     * {@value #ITERATOR_BATCH_SIZE}.
     *
     * @param iter source of points; must not be null
     * @param cacheInfo passed through unchanged to {@code XGDMatrixCreateFromDataIter}
     *     (may be null as far as this class is concerned)
     * @throws NullPointerException if {@code iter} is null
     * @throws XGBoostError if the native call reports an error
     */
    public DMatrix(Iterator<LabeledPoint> iter, String cacheInfo) throws XGBoostError {
        checkNotNull(iter, "iter");
        DataBatch.BatchIterator batchIter = new DataBatch.BatchIterator(iter, ITERATOR_BATCH_SIZE);
        long[] out = new long[1];
        JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromDataIter(batchIter, cacheInfo, out));
        this.handle = out[0];
    }

    /**
     * Loads a matrix from a file path via {@code XGDMatrixCreateFromFile}.
     *
     * @param dataPath path to load; must not be null
     * @throws NullPointerException if {@code dataPath} is null
     * @throws XGBoostError if the native call reports an error
     */
    public DMatrix(String dataPath) throws XGBoostError {
        checkNotNull(dataPath, "dataPath");
        long[] out = new long[1];
        // NOTE(review): the literal 1 is presumably the native "silent" flag — confirm.
        JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromFile(dataPath, 1, out));
        this.handle = out[0];
    }

    /**
     * Builds a sparse matrix; equivalent to the five-argument constructor with {@code shapeParam}
     * set to {@code 0}.
     *
     * @deprecated use {@link #DMatrix(long[], int[], float[], SparseType, int)} instead.
     */
    @Deprecated
    public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError {
        this(headers, indices, data, st, 0);
    }

    /**
     * Builds a sparse matrix from compressed row (CSR) or compressed column (CSC) arrays.
     *
     * @param headers row/column pointer array; must not be null
     * @param indices index array; must not be null
     * @param data value array; must not be null
     * @param st layout of the arrays
     * @param shapeParam forwarded as-is to the native constructor (presumably the column count
     *     for CSR / row count for CSC — confirm against the native API)
     * @throws NullPointerException if any array is null
     * @throws UnknownError if {@code st} is neither CSR nor CSC (i.e. null); this error type is
     *     kept for backward compatibility
     * @throws XGBoostError if the native call reports an error
     */
    public DMatrix(long[] headers, int[] indices, float[] data, SparseType st, int shapeParam)
            throws XGBoostError {
        checkNotNull(headers, "headers");
        checkNotNull(indices, "indices");
        checkNotNull(data, "data");
        long[] out = new long[1];
        if (st == SparseType.CSR) {
            JNIErrorHandle.checkCall(
                    XGBoostJNI.XGDMatrixCreateFromCSREx(headers, indices, data, shapeParam, out));
        } else if (st == SparseType.CSC) {
            JNIErrorHandle.checkCall(
                    XGBoostJNI.XGDMatrixCreateFromCSCEx(headers, indices, data, shapeParam, out));
        } else {
            throw new UnknownError("unknown sparse type: " + st);
        }
        this.handle = out[0];
    }

    /**
     * Builds a dense matrix from a flat array; equivalent to
     * {@code DMatrix(data, nrow, ncol, 0.0f)}.
     */
    public DMatrix(float[] data, int nrow, int ncol) throws XGBoostError {
        this(data, nrow, ncol, 0.0f);
    }

    /**
     * Builds a dense matrix from a flat array via {@code XGDMatrixCreateFromMat}.
     *
     * @param data values; must not be null (layout presumably row-major — confirm)
     * @param nrow number of rows
     * @param ncol number of columns
     * @param missing value forwarded to the native layer as the missing-value marker
     * @throws NullPointerException if {@code data} is null
     * @throws XGBoostError if the native call reports an error
     */
    public DMatrix(float[] data, int nrow, int ncol, float missing) throws XGBoostError {
        checkNotNull(data, "data");
        long[] out = new long[1];
        JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, missing, out));
        this.handle = out[0];
    }

    /**
     * Builds a dense matrix from a 2-D array; equivalent to
     * {@code DMatrix(data, nrow, ncol, 0.0f)}.
     */
    public DMatrix(float[][] data, int nrow, int ncol) throws XGBoostError {
        this(data, nrow, ncol, 0.0f);
    }

    /**
     * Builds a dense matrix from a 2-D array via {@code XGDMatrixCreateFrom2DMat}.
     *
     * @param data values; must not be null
     * @param nrow number of rows
     * @param ncol number of columns
     * @param missing value forwarded to the native layer as the missing-value marker
     * @throws NullPointerException if {@code data} is null
     * @throws XGBoostError if the native call reports an error
     */
    public DMatrix(float[][] data, int nrow, int ncol, float missing) throws XGBoostError {
        checkNotNull(data, "data");
        long[] out = new long[1];
        JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFrom2DMat(data, nrow, ncol, missing, out));
        this.handle = out[0];
    }

    /** Wraps an existing native handle (used by {@link #slice(int[])}). */
    protected DMatrix(long handle) {
        this.handle = handle;
    }

    /** Sets the {@code "label"} float info field. @throws NullPointerException if null */
    public void setLabel(float[] labels) throws XGBoostError {
        checkNotNull(labels, "labels");
        JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(this.handle, "label", labels));
    }

    /** Sets the {@code "weight"} float info field. @throws NullPointerException if null */
    public void setWeight(float[] weights) throws XGBoostError {
        checkNotNull(weights, "weights");
        JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(this.handle, "weight", weights));
    }

    /** Sets the {@code "base_margin"} float info field. @throws NullPointerException if null */
    public void setBaseMargin(float[] baseMargin) throws XGBoostError {
        checkNotNull(baseMargin, "baseMargin");
        JNIErrorHandle.checkCall(
                XGBoostJNI.XGDMatrixSetFloatInfo(this.handle, "base_margin", baseMargin));
    }

    /**
     * Sets {@code "base_margin"} from a 2-D array by concatenating its rows in order.
     *
     * @throws NullPointerException if {@code baseMargin} or any of its rows is null
     */
    public void setBaseMargin(float[][] baseMargin) throws XGBoostError {
        checkNotNull(baseMargin, "baseMargin");
        this.setBaseMargin(DMatrix.flatten(baseMargin));
    }

    /** Sets group information via {@code XGDMatrixSetGroup}. @throws NullPointerException if null */
    public void setGroup(int[] group) throws XGBoostError {
        checkNotNull(group, "group");
        JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixSetGroup(this.handle, group));
    }

    /** Reads a float info field through a one-element out-array filled by the native call. */
    private float[] getFloatInfo(String field) throws XGBoostError {
        float[][] infos = new float[1][];
        JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixGetFloatInfo(this.handle, field, infos));
        return infos[0];
    }

    /** Reads an unsigned-int info field through a one-element out-array (currently unused here). */
    private int[] getIntInfo(String field) throws XGBoostError {
        int[][] infos = new int[1][];
        JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixGetUIntInfo(this.handle, field, infos));
        return infos[0];
    }

    /** Returns the {@code "label"} float info field. */
    public float[] getLabel() throws XGBoostError {
        return this.getFloatInfo("label");
    }

    /** Returns the {@code "weight"} float info field. */
    public float[] getWeight() throws XGBoostError {
        return this.getFloatInfo("weight");
    }

    /** Returns the {@code "base_margin"} float info field. */
    public float[] getBaseMargin() throws XGBoostError {
        return this.getFloatInfo("base_margin");
    }

    /**
     * Creates a new matrix containing the given rows of this one. The result owns its own native
     * handle and must be disposed independently.
     *
     * @throws NullPointerException if {@code rowIndex} is null
     */
    public DMatrix slice(int[] rowIndex) throws XGBoostError {
        checkNotNull(rowIndex, "rowIndex");
        long[] out = new long[1];
        JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixSliceDMatrix(this.handle, rowIndex, out));
        return new DMatrix(out[0]);
    }

    /** Returns the row count reported by {@code XGDMatrixNumRow}. */
    public long rowNum() throws XGBoostError {
        long[] rowNum = new long[1];
        JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixNumRow(this.handle, rowNum));
        return rowNum[0];
    }

    /**
     * Saves the matrix to {@code filePath} via {@code XGDMatrixSaveBinary}.
     *
     * <p>NOTE(review): unlike the other methods, any status returned by the native call is
     * discarded rather than passed to {@code JNIErrorHandle.checkCall}. Checking it would require
     * adding {@code throws XGBoostError}, which would break existing callers.
     */
    public void saveBinary(String filePath) {
        XGBoostJNI.XGDMatrixSaveBinary(this.handle, filePath, 1);
    }

    /** Returns the raw native handle ({@code 0} after {@link #dispose()}). */
    public long getHandle() {
        return this.handle;
    }

    /** Concatenates the rows of {@code mat} in order into one new array. */
    private static float[] flatten(float[][] mat) {
        int size = 0;
        for (float[] row : mat) {
            size += row.length;
        }
        float[] result = new float[size];
        int pos = 0;
        for (float[] row : mat) {
            System.arraycopy(row, 0, result, pos, row.length);
            pos += row.length;
        }
        return result;
    }

    /** Throws {@code NullPointerException("<name>: null")} when {@code value} is null. */
    private static void checkNotNull(Object value, String name) {
        // Arrays are passed straight to XGBoostJNI; failing fast here gives a Java exception
        // instead of handing null to native code.
        if (value == null) {
            throw new NullPointerException(name + ": null");
        }
    }

    /** Fallback release of the native matrix if {@link #dispose()} was never called. */
    @Override
    protected void finalize() {
        this.dispose();
    }

    /**
     * Frees the native matrix if one is held and resets the handle to {@code 0}. Calling it again
     * afterwards does nothing. Synchronized so concurrent callers cannot free the handle twice.
     */
    public synchronized void dispose() {
        if (this.handle != 0L) {
            XGBoostJNI.XGDMatrixFree(this.handle);
            this.handle = 0L;
        }
    }

    /** Storage layout of the arrays passed to the sparse constructors. */
    public static enum SparseType {
        CSR,
        CSC;
    }
}

