/*
 * Decompiled with CFR 0.152.
 */
package ml.dmlc.xgboost4j.java;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.Serializable;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Map;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.java.XGBoostJNI;

/**
 * Static entry points into XGBoost's native communicator, reached through {@link XGBoostJNI}.
 *
 * <p>The native calls return an integer status code where {@code 0} means success. A non-zero
 * code is turned into an exception whose message comes from
 * {@link XGBoostJNI#XGBGetLastError()}.
 */
public class Communicator {
    /**
     * Converts a non-zero native status code into an {@link XGBoostError}.
     *
     * @param ret status code returned by a native call; {@code 0} means success
     * @throws XGBoostError if {@code ret} is non-zero; the message is the native last-error text
     */
    private static void checkCall(int ret) throws XGBoostError {
        if (ret != 0) {
            throw new XGBoostError(XGBoostJNI.XGBGetLastError());
        }
    }

    /**
     * Initializes the native communicator.
     *
     * @param envs communicator configuration. It is serialized to a JSON string with Jackson and
     *     passed to the native layer as-is; the accepted keys are defined by the native side.
     * @throws XGBoostError if {@code envs} cannot be serialized to JSON, or if native
     *     initialization reports an error
     */
    public static void init(Map<String, Object> envs) throws XGBoostError {
        ObjectMapper mapper = new ObjectMapper();
        String jconfig;
        // Only serialization can throw JsonProcessingException, so keep the try scoped to it.
        try {
            jconfig = mapper.writeValueAsString(envs);
        }
        catch (JsonProcessingException ex) {
            throw new XGBoostError("Failed to read arguments for the communicator.", ex);
        }
        Communicator.checkCall(XGBoostJNI.CommunicatorInit(jconfig));
    }

    /**
     * Finalizes (shuts down) the native communicator.
     *
     * @throws XGBoostError if the native finalize call reports an error
     */
    public static void shutdown() throws XGBoostError {
        Communicator.checkCall(XGBoostJNI.CommunicatorFinalize());
    }

    /**
     * Prints a message through the native communicator's print facility.
     *
     * @param msg the message to print
     * @throws XGBoostError if the native print call reports an error
     */
    public static void communicatorPrint(String msg) throws XGBoostError {
        Communicator.checkCall(XGBoostJNI.CommunicatorPrint(msg));
    }

    /**
     * Returns the rank reported by the native communicator (presumably this worker's index).
     *
     * @return the rank written by the native call
     * @throws XGBoostError if the native call reports an error
     */
    public static int getRank() throws XGBoostError {
        int[] out = new int[1];
        Communicator.checkCall(XGBoostJNI.CommunicatorGetRank(out));
        return out[0];
    }

    /**
     * Returns the world size reported by the native communicator (presumably the worker count).
     *
     * @return the world size written by the native call
     * @throws XGBoostError if the native call reports an error
     */
    public static int getWorldSize() throws XGBoostError {
        int[] out = new int[1];
        Communicator.checkCall(XGBoostJNI.CommunicatorGetWorldSize(out));
        return out[0];
    }

    /**
     * Performs an allreduce on a float array using the given operator.
     *
     * <p>The values are copied into a direct buffer, reduced in place by the native call, and
     * read back into a new array. The input array is not modified.
     *
     * @param elements local values to reduce
     * @param op reduction operator
     * @return a new array holding the reduced values, the same length as {@code elements}
     * @throws IllegalStateException if the native allreduce reports an error. The cause is an
     *     {@link XGBoostError} carrying the native last-error message. An unchecked exception
     *     is used because this method has never declared {@code throws XGBoostError}.
     */
    public static float[] allReduce(float[] elements, OpType op) {
        DataType dataType = DataType.FLOAT32;
        // Direct buffer in native byte order, presumably so the native side can use its memory
        // without another copy.
        ByteBuffer buffer = ByteBuffer.allocateDirect(dataType.getSize() * elements.length)
                .order(ByteOrder.nativeOrder());
        // The float view shares content and byte order with `buffer`. A bulk put through the
        // view leaves buffer's own position at 0 and limit at capacity, which is the state the
        // old put-loop + flip() produced.
        buffer.asFloatBuffer().put(elements);
        int ret = XGBoostJNI.CommunicatorAllreduce(buffer, elements.length, dataType.getEnumOp(),
                op.getOperand());
        if (ret != 0) {
            // Previously ignored, which would silently return the un-reduced input values.
            throw new IllegalStateException("Communicator allreduce (" + op + ") failed",
                    new XGBoostError(XGBoostJNI.XGBGetLastError()));
        }
        float[] results = new float[elements.length];
        buffer.asFloatBuffer().get(results);
        return results;
    }

    /**
     * Element data types understood by the native communicator.
     *
     * <p>Each constant pairs the integer code passed to the native layer with the element width
     * in bytes. Code 3 has no constant here.
     */
    public static enum DataType implements Serializable
    {
        FLOAT16(0, 2),
        FLOAT32(1, 4),
        FLOAT64(2, 8),
        INT8(4, 1),
        INT16(5, 2),
        INT32(6, 4),
        INT64(7, 8),
        UINT8(8, 1),
        UINT16(9, 2),
        UINT32(10, 4),
        UINT64(11, 8);

        /** Integer code passed to the native layer to identify this data type. */
        private final int enumOp;
        /** Width of one element, in bytes. */
        private final int size;

        /** @return the integer code passed to the native layer for this data type */
        public int getEnumOp() {
            return this.enumOp;
        }

        /** @return the width of one element in bytes */
        public int getSize() {
            return this.size;
        }

        private DataType(int enumOp, int size) {
            this.enumOp = enumOp;
            this.size = size;
        }
    }

    /** Reduction operators understood by the native allreduce. */
    public static enum OpType implements Serializable
    {
        MAX(0),
        MIN(1),
        SUM(2);

        /** Integer code passed to the native layer to identify this operator. */
        private final int op;

        /** @return the integer code passed to the native layer for this operator */
        public int getOperand() {
            return this.op;
        }

        private OpType(int op) {
            this.op = op;
        }
    }
}

