package com.aliyun.datahub.client.http.converter.batch;

import com.aliyun.datahub.client.exception.DatahubClientException;
import com.aliyun.datahub.client.exception.InvalidParameterException;
import com.aliyun.datahub.client.exception.MalformedRecordException;
import com.aliyun.datahub.client.model.Field;
import com.aliyun.datahub.client.model.RecordSchema;
import org.apache.commons.codec.Charsets;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

/**
 * One record in the DataHub binary batch format, used both to decode a record from an existing
 * buffer and to build a new record for serialization.
 *
 * <p>Layout of {@code recordBuffer} as used by this class:
 * <ul>
 *   <li>{@code [0, 16)}: the {@link RecordHeader} (four little-endian ints), written by
 *       {@link #serialize}.</li>
 *   <li>{@code [16, 20)}: a 4-byte field-count slot. NOTE(review): this class reserves it but never
 *       writes it; confirm whether the format expects it to be filled.</li>
 *   <li>A null bitmap, one bit per field (bit set = field is non-null), padded to a multiple of
 *       8 bytes.</li>
 *   <li>One 8-byte little-endian slot per field.</li>
 *   <li>A variable-length area for byte/string payloads longer than 7 bytes, each padded to an
 *       8-byte boundary.</li>
 * </ul>
 * Attributes are kept in a map while building; {@link #serialize} appends them after the record
 * data as an int count followed by length-prefixed UTF-8 key/value pairs.
 *
 * <p>With a {@code null} schema the record has exactly one field: a raw {@code byte[]} at
 * position 0.
 *
 * <p>Instances hold mutable state without any synchronization.
 */
public class BinaryRecord {
    private static final int RECORD_HEADER_SIZE = 16;
    private static final int BYTE_SIZE_ONE_FIELD = 8;
    private static final int FIELD_COUNT_BYTE_SIZE = 4;
    private static final int INT_BYTE_SIZE = 4;
    private static final byte[] PADDING_BYTES = new byte[] { 0, 0, 0, 0, 0, 0, 0, 0 };

    /** Number of field slots: the schema's field count, or 1 when there is no schema. */
    private final int fieldCnt;
    /** Byte offset of the first 8-byte field slot (header + field-count slot + null bitmap). */
    private final int fieldPos;
    /** Write mode: end of the data written so far / next free byte of the variable-length area. */
    private int nextPos;
    /** Serialized size in bytes of the attributes added via {@link #addAttribute}, excluding the count int. */
    private int attrLength;

    private final int versionId;
    private final RecordSchema schema;
    private byte[] recordBuffer;
    private final Map<String, String> attrMap = new HashMap<>();
    private boolean hasInitAttrMap;
    private RecordHeader recordHeader;

    /**
     * Wraps an already-encoded record for reading. Attributes are decoded lazily on the first call
     * to {@link #getAttrMap()}.
     *
     * @param buffer    the encoded record, starting with its 16-byte header
     * @param header    the parsed header of this record; its {@code attrOffset} locates the attributes
     * @param schema    the record schema, or {@code null} for a single raw-bytes field
     * @param versionId the schema version id
     */
    public BinaryRecord(byte[] buffer, RecordHeader header, RecordSchema schema, int versionId) {
        this.hasInitAttrMap = false;
        this.recordBuffer = buffer;
        this.schema = schema;
        this.versionId = versionId;
        this.fieldCnt = schema != null ? schema.getFields().size() : 1;
        this.fieldPos = getFixHeaderLength(fieldCnt);
        this.recordHeader = header;
    }

    /**
     * Creates an empty record for writing, with every field null and no attributes.
     *
     * @param schema    the record schema, or {@code null} for a single raw-bytes field
     * @param versionId the schema version id stored in the header on serialization
     */
    public BinaryRecord(RecordSchema schema, int versionId) {
        this.hasInitAttrMap = true;
        this.schema = schema;
        this.versionId = versionId;
        this.fieldCnt = schema != null ? schema.getFields().size() : 1;

        int minAllocSize = getMinAllocSize();
        this.recordBuffer = new byte[minAllocSize];
        this.fieldPos = getFixHeaderLength(fieldCnt);
        // Data so far ends right after the fixed slots; variable-length payloads are appended here.
        this.nextPos = minAllocSize;
    }

    public RecordSchema getSchema() {
        return schema;
    }

    public int getSchemaVersionId() {
        return versionId;
    }

    /**
     * Decodes the value of a field.
     *
     * <p>Returned Java types by schema type: STRING → {@code String}, DECIMAL → {@code BigDecimal},
     * FLOAT → {@code Float}, DOUBLE → {@code Double}, BOOLEAN → {@code Boolean}, INTEGER →
     * {@code Integer}, TINYINT → {@code Byte}, SMALLINT → {@code Short}, BIGINT/TIMESTAMP →
     * {@code Long}. Without a schema the raw {@code byte[]} is returned.
     *
     * @param pos field position
     * @return the decoded value, or {@code null} if the field is null
     * @throws DatahubClientException   if {@code pos} is out of range
     * @throws MalformedRecordException if the field cannot be decoded
     */
    public Object getField(int pos) {
        if (isFieldNull(pos)) {
            return null;
        }

        try {
            if (schema == null) {
                // Binary data must not be round-tripped through String; return the raw bytes.
                return readBytesField(0);
            }
            Field field = schema.getField(pos);
            switch (field.getType()) {
                case STRING:
                    return readStrField(pos);
                case DECIMAL:
                    return new BigDecimal(readStrField(pos));
                case FLOAT:
                    return readField(pos).getFloat();
                case DOUBLE:
                    return readField(pos).getDouble();
                case BOOLEAN:
                    return readField(pos).getLong() != 0;
                case INTEGER:
                    return readField(pos).getInt();
                case TINYINT:
                    return (byte) (readField(pos).getInt());
                case SMALLINT:
                    return readField(pos).getShort();
                case BIGINT:
                case TIMESTAMP:
                    return readField(pos).getLong();
                default:
                    throw new InvalidParameterException("Unknown schema type");
            }
        } catch (Exception e) {
            throw new MalformedRecordException("Parse field fail. position:" + pos + ", error:" + e.getMessage());
        }
    }

    /**
     * Adds or replaces an attribute and keeps the serialized attribute size in sync.
     * Sizes are measured in UTF-8 bytes, matching what {@link #serialize} writes.
     *
     * @param key   attribute key, must not be null
     * @param value attribute value, must not be null
     */
    public void addAttribute(String key, String value) {
        // Compute first so a null key/value fails before the map is modified.
        int entrySize = attrEntrySize(key, value);
        String previous = attrMap.put(key, value);
        if (previous != null) {
            // A replaced entry is no longer serialized; drop its contribution.
            attrLength -= attrEntrySize(key, previous);
        }
        attrLength += entrySize;
    }

    /**
     * Returns the record's attributes, decoding them from the buffer on first use for a record
     * created by the reading constructor.
     *
     * <p>The returned map is the live internal map. Modifying it directly bypasses the size
     * accounting done by {@link #addAttribute}.
     */
    public Map<String, String> getAttrMap() {
        initAttrMapIfNeed();
        return attrMap;
    }

    /** Decodes attributes located at the header's {@code attrOffset}, once. */
    private void initAttrMapIfNeed() {
        if (hasInitAttrMap) {
            return;
        }

        int offset = constructRecordHeader().getAttrOffset();
        int attrCount = BatchUtil.readInt(recordBuffer, offset);
        offset += INT_BYTE_SIZE;
        for (int i = 0; i < attrCount; ++i) {
            int keyLen = BatchUtil.readInt(recordBuffer, offset);
            offset += INT_BYTE_SIZE;
            String key = new String(recordBuffer, offset, keyLen, StandardCharsets.UTF_8);
            offset += keyLen;

            int valueLen = BatchUtil.readInt(recordBuffer, offset);
            offset += INT_BYTE_SIZE;
            String value = new String(recordBuffer, offset, valueLen, StandardCharsets.UTF_8);
            offset += valueLen;

            attrMap.put(key, value);
        }
        hasInitAttrMap = true;
    }

    /**
     * Returns the number of bytes {@link #serialize} emits for a record being built:
     * record data, the attribute count int, and the attribute entries.
     */
    public int getRecordSize() {
        return INT_BYTE_SIZE + attrLength + nextPos;
    }

    /**
     * Reads a variable-length field slot.
     *
     * <p>If bit 63 of the slot is set, the payload (at most 7 bytes) is stored inline and its
     * length is in bits 56-58. Otherwise the high 32 bits hold the payload offset relative to the
     * end of the header, and the low 32 bits hold its length.
     */
    private byte[] readBytesField(int pos) {
        long data = readField(pos).getLong();
        boolean isLittleStr = (data & (0x80L << 56)) != 0;
        if (isLittleStr) {
            int len = (int) ((data >> 56) & 0x07);
            // Assumes BatchUtil.parseLong yields the slot's bytes in little-endian order.
            byte[] bytes = BatchUtil.parseLong(data);
            return Arrays.copyOfRange(bytes, 0, len);
        }

        int strOffset = RECORD_HEADER_SIZE + (int) (data >> 32);
        int len = (int) data;
        // Arrays.copyOfRange would silently zero-pad a range past the end; reject it instead.
        // strOffset < RECORD_HEADER_SIZE also catches int overflow of the addition above.
        if (strOffset < RECORD_HEADER_SIZE || len < 0 || len > recordBuffer.length - strOffset) {
            throw new IndexOutOfBoundsException("Variable-length field out of range. offset:" + strOffset
                    + ", length:" + len + ", bufferSize:" + recordBuffer.length);
        }
        return Arrays.copyOfRange(recordBuffer, strOffset, strOffset + len);
    }

    /** Reads a variable-length field as UTF-8 text (the encoding used by {@link #setField}). */
    private String readStrField(int pos) {
        return new String(readBytesField(pos), StandardCharsets.UTF_8);
    }

    /** Returns a little-endian view of the 8-byte slot for field {@code pos}. */
    private ByteBuffer readField(int pos) {
        int offset = getFieldOffset(pos);
        return ByteBuffer.wrap(recordBuffer, offset, BYTE_SIZE_ONE_FIELD).order(ByteOrder.LITTLE_ENDIAN);
    }

    /**
     * Encodes {@code value} into field {@code pos} and marks the field non-null.
     *
     * <p>The field is marked non-null only after the value has been accepted and written. A value
     * of the wrong type, or {@code null}, therefore leaves the field unchanged.
     *
     * @throws DatahubClientException if {@code pos} is out of range, or if there is no schema and
     *                                {@code value} is not a {@code byte[]}
     */
    public void setField(int pos, Object value) {
        checkPosValid(pos);

        if (schema == null) {
            if (!(value instanceof byte[])) {
                throw new DatahubClientException("Only support write byte[] for no schema");
            }
            writeBytesField(0, (byte[]) value);
        } else {
            Field field = schema.getField(pos);
            switch (field.getType()) {
                case STRING:
                    writeBytesField(pos, ((String) value).getBytes(StandardCharsets.UTF_8));
                    break;
                case DECIMAL:
                    writeBytesField(pos, ((BigDecimal) value).toPlainString().getBytes(StandardCharsets.UTF_8));
                    break;
                case FLOAT:
                    writeField(pos, BatchUtil.parseFloat((float) value));
                    break;
                case DOUBLE:
                    writeField(pos, BatchUtil.parseDouble((double) value));
                    break;
                case BOOLEAN:
                    writeField(pos, BatchUtil.parseLong((boolean) value ? 1L : 0L));
                    break;
                default:
                    // Integral types (and TIMESTAMP) are stored as a 64-bit long.
                    writeField(pos, BatchUtil.parseLong(Long.parseLong(value.toString())));
                    break;
            }
        }

        setNotNullAt(pos);
    }

    /** Sets the null-bitmap bit for {@code pos} (bit set = non-null). */
    private void setNotNullAt(int pos) {
        checkPosValid(pos);
        int nullOffset = RECORD_HEADER_SIZE + FIELD_COUNT_BYTE_SIZE + (pos >> 3);
        recordBuffer[nullOffset] = (byte) (recordBuffer[nullOffset] | (1 << (pos & 0x07)));
    }

    /** Returns {@code true} when the null-bitmap bit for {@code pos} is clear. */
    private boolean isFieldNull(int pos) {
        checkPosValid(pos);
        int nullOffset = RECORD_HEADER_SIZE + FIELD_COUNT_BYTE_SIZE + (pos >> 3);
        return (recordBuffer[nullOffset] & (1 << (pos & 0x07))) == 0;
    }

    /**
     * Writes a byte payload. Payloads of at most 7 bytes are stored inline in the slot via
     * {@code BatchUtil.parseLittleStr}. Longer ones go to the variable-length area.
     */
    private void writeBytesField(int pos, byte[] bytes) {
        if (bytes.length <= 7) {
            writeField(pos, BatchUtil.parseLittleStr(bytes));
        } else {
            writeBigBytes(pos, bytes);
        }
    }

    /**
     * Appends {@code bytes} (zero-padded to 8-byte alignment) to the variable-length area. It then
     * stores {@code (offset - header size) << 32 | length} in the field slot.
     */
    private void writeBigBytes(int pos, byte[] bytes) {
        int byteLength = bytes.length;
        int needSize = alignSize(byteLength);
        ensureBufferCapacity(needSize);

        long offsetAndSize = ((long) (nextPos - RECORD_HEADER_SIZE) << 32) | byteLength;
        writeField(pos, BatchUtil.parseLong(offsetAndSize));
        System.arraycopy(bytes, 0, recordBuffer, nextPos, byteLength);

        int left = needSize - byteLength;
        if (left > 0) {
            System.arraycopy(PADDING_BYTES, 0, recordBuffer, nextPos + byteLength, left);
        }
        nextPos += needSize;
    }

    /** Grows the buffer (at least doubling) so {@code needSize} more bytes fit after {@code nextPos}. */
    private void ensureBufferCapacity(int needSize) {
        int minCap = nextPos + needSize;
        int currCap = recordBuffer.length;
        if (minCap > currCap) {
            int newCap = Math.max(currCap * 2, minCap);
            recordBuffer = Arrays.copyOf(recordBuffer, newCap);
        }
    }

    /** Copies an encoded value into the slot of field {@code pos}. */
    private void writeField(int pos, byte[] bytes) {
        System.arraycopy(bytes, 0, recordBuffer, getFieldOffset(pos), bytes.length);
    }

    /** Rounds {@code size} up to the next multiple of 8. */
    private int alignSize(int size) {
        return (size + 7) & ~7;
    }

    private int getFieldOffset(int pos) {
        return fieldPos + pos * BYTE_SIZE_ONE_FIELD;
    }

    /** Size of header, field-count slot, null bitmap and all field slots. */
    private int getMinAllocSize() {
        return getFixHeaderLength(fieldCnt) + fieldCnt * BYTE_SIZE_ONE_FIELD;
    }

    /** Header + field-count slot + null bitmap (one bit per field, rounded up to 8-byte words). */
    private int getFixHeaderLength(int fieldCount) {
        return RECORD_HEADER_SIZE + FIELD_COUNT_BYTE_SIZE + (((fieldCount + 63) >> 6) << 3);
    }

    /**
     * Returns the header supplied at construction or, if none, builds one from the current state
     * and caches it. NOTE(review): because it is cached, mutations made after the first
     * {@link #serialize} call are not reflected in the header.
     */
    private RecordHeader constructRecordHeader() {
        if (recordHeader == null) {
            RecordHeader header = new RecordHeader();
            header.setEncodeType(0);
            header.setSchemaVersion(versionId);
            header.setTotalSize(getRecordSize());
            header.setAttrOffset(nextPos);
            recordHeader = header;
        }
        return recordHeader;
    }

    /**
     * Writes the record to {@code output}: the header-stamped record data {@code [0, nextPos)},
     * then the attribute count and each key/value as a length-prefixed UTF-8 string.
     *
     * @throws IOException if writing to {@code output} fails
     */
    public void serialize(ByteArrayOutputStream output) throws IOException {
        RecordHeader header = constructRecordHeader();
        byte[] headerBuffer = RecordHeader.serialize(header);
        System.arraycopy(headerBuffer, 0, recordBuffer, 0, headerBuffer.length);
        output.write(recordBuffer, 0, nextPos);

        output.write(BatchUtil.parseInt(attrMap.size()));
        for (Map.Entry<String, String> entry : attrMap.entrySet()) {
            writeLengthPrefixed(output, entry.getKey());
            writeLengthPrefixed(output, entry.getValue());
        }
    }

    /** Writes the UTF-8 byte length of {@code text} followed by those bytes. */
    private static void writeLengthPrefixed(ByteArrayOutputStream output, String text) throws IOException {
        byte[] bytes = text.getBytes(StandardCharsets.UTF_8);
        output.write(BatchUtil.parseInt(bytes.length));
        output.write(bytes);
    }

    /** Serialized size of one attribute entry: two length ints plus the UTF-8 key and value. */
    private static int attrEntrySize(String key, String value) {
        return INT_BYTE_SIZE * 2 + utf8Length(key) + utf8Length(value);
    }

    private static int utf8Length(String text) {
        return text.getBytes(StandardCharsets.UTF_8).length;
    }

    /** @throws DatahubClientException if {@code pos} is outside {@code [0, fieldCnt)} */
    private void checkPosValid(int pos) {
        if (pos < 0 || pos >= fieldCnt) {
            throw new DatahubClientException("Invalid position. position:" + pos + ", fieldCount:" + fieldCnt);
        }
    }

    /** The fixed 16-byte record header: four little-endian ints. */
    public static class RecordHeader {
        private int encodeType;
        private int schemaVersion;
        private int totalSize;
        private int attrOffset;

        public int getEncodeType() {
            return encodeType;
        }

        public void setEncodeType(int encodeType) {
            this.encodeType = encodeType;
        }

        public int getSchemaVersion() {
            return schemaVersion;
        }

        public void setSchemaVersion(int schemaVersion) {
            this.schemaVersion = schemaVersion;
        }

        public int getTotalSize() {
            return totalSize;
        }

        public void setTotalSize(int totalSize) {
            this.totalSize = totalSize;
        }

        public int getAttrOffset() {
            return attrOffset;
        }

        public void setAttrOffset(int attrOffset) {
            this.attrOffset = attrOffset;
        }

        /**
         * Reads a header from the next 16 bytes of {@code input}.
         *
         * @throws DatahubClientException if fewer than 16 bytes are available
         */
        public static RecordHeader parseFrom(ByteArrayInputStream input) {
            byte[] buffer = new byte[RECORD_HEADER_SIZE];
            int len = input.read(buffer, 0, RECORD_HEADER_SIZE);
            if (len < RECORD_HEADER_SIZE) {
                throw new DatahubClientException("read batch header fail");
            }

            ByteBuffer byteBuffer = ByteBuffer.wrap(buffer).order(ByteOrder.LITTLE_ENDIAN);
            RecordHeader header = new RecordHeader();
            header.setEncodeType(byteBuffer.getInt());
            header.setSchemaVersion(byteBuffer.getInt());
            header.setTotalSize(byteBuffer.getInt());
            header.setAttrOffset(byteBuffer.getInt());
            return header;
        }

        /**
         * Encodes {@code header} as 16 little-endian bytes. Each call returns a new array, so the
         * result may be retained safely.
         */
        public static byte[] serialize(RecordHeader header) {
            ByteBuffer byteBuffer = ByteBuffer.allocate(RECORD_HEADER_SIZE).order(ByteOrder.LITTLE_ENDIAN);
            byteBuffer.putInt(header.getEncodeType());
            byteBuffer.putInt(header.getSchemaVersion());
            byteBuffer.putInt(header.getTotalSize());
            byteBuffer.putInt(header.getAttrOffset());
            return byteBuffer.array();
        }
    }
}
