/*
 * Decompiled with CFR 0.152.
 */
package ml.dmlc.xgboost4j.java;

import java.io.IOException;
import java.io.Serializable;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Map;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.NativeLibLoader;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.java.XGBoostJNI;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Static wrappers around the native Rabit calls exposed through {@code XGBoostJNI}.
 *
 * <p>Loading this class attempts to load the native XGBoost library (see the static initializer).
 */
public class Rabit {
    // Log under this class's own category so messages are attributed to Rabit.
    private static final Log logger = LogFactory.getLog(Rabit.class);

    /**
     * Converts a native status code into an exception.
     *
     * @param ret status code returned by an {@code XGBoostJNI} call; 0 means success
     * @throws XGBoostError if {@code ret} is nonzero, carrying {@code XGBoostJNI.XGBGetLastError()}
     */
    private static void checkCall(int ret) throws XGBoostError {
        if (ret != 0) {
            throw new XGBoostError(XGBoostJNI.XGBGetLastError());
        }
    }

    /**
     * Initializes Rabit by passing each entry of {@code envs} to native code as a
     * {@code "key=value"} string, in the map's iteration order.
     *
     * @param envs configuration entries to forward to the native initializer
     * @throws XGBoostError if the native call reports a nonzero status
     */
    public static void init(Map<String, String> envs) throws XGBoostError {
        String[] args = new String[envs.size()];
        int idx = 0;
        for (Map.Entry<String, String> e : envs.entrySet()) {
            args[idx++] = e.getKey() + '=' + e.getValue();
        }
        checkCall(XGBoostJNI.RabitInit(args));
    }

    /**
     * Calls the native {@code RabitFinalize}.
     *
     * @throws XGBoostError if the native call reports a nonzero status
     */
    public static void shutdown() throws XGBoostError {
        checkCall(XGBoostJNI.RabitFinalize());
    }

    /**
     * Forwards {@code msg} to the native {@code RabitTrackerPrint}.
     *
     * @param msg message to print via the tracker
     * @throws XGBoostError if the native call reports a nonzero status
     */
    public static void trackerPrint(String msg) throws XGBoostError {
        checkCall(XGBoostJNI.RabitTrackerPrint(msg));
    }

    /**
     * Returns the version number reported by the native Rabit library.
     *
     * @return the value written by {@code XGBoostJNI.RabitVersionNumber}
     * @throws XGBoostError if the native call reports a nonzero status
     */
    public static int versionNumber() throws XGBoostError {
        int[] out = new int[1];
        checkCall(XGBoostJNI.RabitVersionNumber(out));
        return out[0];
    }

    /**
     * Returns the rank reported by the native Rabit library (presumably this process's index
     * among the workers).
     *
     * @return the value written by {@code XGBoostJNI.RabitGetRank}
     * @throws XGBoostError if the native call reports a nonzero status
     */
    public static int getRank() throws XGBoostError {
        int[] out = new int[1];
        checkCall(XGBoostJNI.RabitGetRank(out));
        return out[0];
    }

    /**
     * Returns the world size reported by the native Rabit library (presumably the number of
     * participating workers).
     *
     * @return the value written by {@code XGBoostJNI.RabitGetWorldSize}
     * @throws XGBoostError if the native call reports a nonzero status
     */
    public static int getWorldSize() throws XGBoostError {
        int[] out = new int[1];
        checkCall(XGBoostJNI.RabitGetWorldSize(out));
        return out[0];
    }

    /**
     * Performs a native all-reduce of {@code elements} using operator {@code op}.
     *
     * <p>The input array is not modified. Its values are copied into a direct, native-byte-order
     * buffer, the native call operates on that buffer in place, and the buffer's contents are read
     * back into a new array.
     *
     * @param elements local values to reduce
     * @param op reduction operator
     * @return a new array of the same length holding the reduced values
     * @throws IllegalStateException if the native call reports a nonzero status; its cause is the
     *     {@link XGBoostError} carrying the native error message. An unchecked wrapper is used so
     *     the method signature stays source-compatible (it declares no checked exceptions).
     * @throws ArithmeticException if the required buffer size in bytes overflows an {@code int}
     */
    public static float[] allReduce(float[] elements, OpType op) {
        DataType dataType = DataType.FLOAT;
        // Guard against int overflow for very large arrays instead of allocating a wrong-size buffer.
        int byteCount = Math.multiplyExact(dataType.getSize(), elements.length);
        ByteBuffer buffer = ByteBuffer.allocateDirect(byteCount).order(ByteOrder.nativeOrder());
        // Bulk copy through a float view: the view shares content and byte order with 'buffer'
        // and leaves buffer's own position at 0 and limit at capacity, so no flip() is needed.
        buffer.asFloatBuffer().put(elements);
        int ret = XGBoostJNI.RabitAllreduce(
                buffer, elements.length, dataType.getEnumOp(), op.getOperand());
        try {
            checkCall(ret);
        } catch (XGBoostError err) {
            throw new IllegalStateException(err.getMessage(), err);
        }
        float[] results = new float[elements.length];
        buffer.asFloatBuffer().get(results);
        return results;
    }

    static {
        try {
            NativeLibLoader.initXGBoost();
        } catch (IOException ex) {
            // Deliberately best-effort: the failure is logged rather than rethrown, so loading this
            // class does not fail here. Subsequent native calls are expected to fail instead.
            logger.error("load native library failed.", ex);
        }
    }

    /**
     * Element types accepted by the native all-reduce.
     *
     * <p>Each constant carries the integer code passed to native code as the data-type argument
     * ({@link #getEnumOp()}) and the element width in bytes ({@link #getSize()}).
     */
    public enum DataType implements Serializable {
        CHAR(0, 1),
        UCHAR(1, 1),
        INT(2, 4),
        // NOTE(review): most likely a misspelling of "UINT"; kept as-is because renaming a public
        // enum constant would break existing callers.
        UNIT(3, 4),
        LONG(4, 8),
        ULONG(5, 8),
        FLOAT(6, 4),
        DOUBLE(7, 8),
        LONGLONG(8, 8),
        ULONGLONG(9, 8);

        private final int enumOp;
        private final int size;

        DataType(int enumOp, int size) {
            this.enumOp = enumOp;
            this.size = size;
        }

        /** Returns the integer code passed to native code for this data type. */
        public int getEnumOp() {
            return this.enumOp;
        }

        /** Returns the width of one element of this type, in bytes. */
        public int getSize() {
            return this.size;
        }
    }

    /** Reduction operators accepted by the native all-reduce. */
    public enum OpType implements Serializable {
        MAX(0),
        MIN(1),
        SUM(2),
        BITWISE_OR(3);

        private final int op;

        OpType(int op) {
            this.op = op;
        }

        /** Returns the integer code passed to native code for this operator. */
        public int getOperand() {
            return this.op;
        }
    }
}

