/*
 * Decompiled with CFR 0.152.
 */
package io.trino.parquet.spark;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

/**
 * Static helpers for decoding a binary "variant" value and its metadata dictionary.
 *
 * <p>Every value starts with a one-byte header: the low {@link #BASIC_TYPE_BITS} bits hold the
 * basic type ({@link #PRIMITIVE}, {@link #SHORT_STR}, {@link #OBJECT}, {@link #ARRAY}) and the
 * remaining six bits hold a type-specific "type info" field. Multi-byte integers are read in
 * little-endian order.
 *
 * <p>Malformed input is reported with {@link IllegalArgumentException}; asking for a value of the
 * wrong type is reported with {@link IllegalStateException}.
 */
public final class VariantUtil {
    // Header layout: basic type in the low bits, type info in the six bits above it.
    public static final int BASIC_TYPE_BITS = 2;
    public static final int BASIC_TYPE_MASK = 3;
    public static final int TYPE_INFO_MASK = 63;

    // Basic type ids (low two bits of the header byte).
    public static final int PRIMITIVE = 0;
    public static final int SHORT_STR = 1;
    public static final int OBJECT = 2;
    public static final int ARRAY = 3;

    // Primitive type ids (type info field when the basic type is PRIMITIVE).
    public static final int NULL = 0;
    public static final int TRUE = 1;
    public static final int FALSE = 2;
    public static final int INT1 = 3;
    public static final int INT2 = 4;
    public static final int INT4 = 5;
    public static final int INT8 = 6;
    public static final int DOUBLE = 7;
    public static final int DECIMAL4 = 8;
    public static final int DECIMAL8 = 9;
    public static final int DECIMAL16 = 10;
    public static final int DATE = 11;
    public static final int TIMESTAMP = 12;
    public static final int TIMESTAMP_NTZ = 13;
    public static final int FLOAT = 14;
    public static final int BINARY = 15;
    public static final int LONG_STR = 16;

    // NOTE(review): the next five constants are not referenced inside this class; they are kept
    // because they are part of the public surface. Their exact use by callers should be confirmed.
    public static final byte VERSION = 1;
    public static final byte VERSION_MASK = 15;
    public static final int U24_MAX = 0xFFFFFF;
    /** Width in bytes of the 4-byte length/size fields read by this class. */
    public static final int U32_SIZE = 4;
    public static final int SIZE_LIMIT = 0x1000000;

    // Maximum precision (and scale) accepted for each decimal width.
    public static final int MAX_DECIMAL4_PRECISION = 9;
    public static final int MAX_DECIMAL8_PRECISION = 18;
    public static final int MAX_DECIMAL16_PRECISION = 38;

    private VariantUtil() {
    }

    /**
     * Verifies that {@code position} is a valid index into an array of {@code length} elements.
     *
     * @throws IllegalArgumentException if {@code position} is negative or not below {@code length}
     */
    static void checkIndex(int position, int length) {
        if (position < 0 || position >= length) {
            throw new IllegalArgumentException("Index out of bound: %s (length: %s)".formatted(position, length));
        }
    }

    /**
     * Reads a little-endian signed integer of {@code numBytes} bytes starting at {@code position}.
     * The most significant byte is sign-extended. Callers in this class pass 1, 2, 4 or 8.
     *
     * @throws IllegalArgumentException if the byte range is not fully inside {@code bytes}
     */
    static long readLong(byte[] bytes, int position, int numBytes) {
        checkIndex(position, bytes.length);
        checkIndex(position + numBytes - 1, bytes.length);
        long result = 0L;
        // All bytes except the last are treated as unsigned.
        for (int i = 0; i < numBytes - 1; ++i) {
            long unsignedByteValue = bytes[position + i] & 0xFF;
            result |= unsignedByteValue << (8 * i);
        }
        // The last (most significant) byte keeps its sign so the result is sign-extended.
        long signedByteValue = bytes[position + numBytes - 1];
        result |= signedByteValue << (8 * (numBytes - 1));
        return result;
    }

    /**
     * Reads a little-endian unsigned integer of {@code numBytes} bytes starting at
     * {@code position}. Callers in this class pass between 1 and 4 bytes.
     *
     * @throws IllegalArgumentException if the byte range is not fully inside {@code bytes}, or if
     *         the decoded value does not fit in a non-negative {@code int}
     */
    static int readUnsigned(byte[] bytes, int position, int numBytes) {
        checkIndex(position, bytes.length);
        checkIndex(position + numBytes - 1, bytes.length);
        int result = 0;
        for (int i = 0; i < numBytes; ++i) {
            int unsignedByteValue = bytes[position + i] & 0xFF;
            result |= unsignedByteValue << (8 * i);
        }
        // A 4-byte value with the top bit set would show up as a negative int.
        if (result < 0) {
            throw new IllegalArgumentException("Value out of bound: %s".formatted(result));
        }
        return result;
    }

    /** Extracts the basic type (low two bits) from a value header byte. */
    private static int basicTypeOf(byte header) {
        return header & BASIC_TYPE_MASK;
    }

    /** Extracts the six-bit type info field from a value header byte. */
    private static int typeInfoOf(byte header) {
        return (header >> BASIC_TYPE_BITS) & TYPE_INFO_MASK;
    }

    /**
     * Returns the logical {@link Type} of the value whose header is at {@code position}.
     *
     * @throws IllegalArgumentException if {@code position} is out of bounds or the header carries
     *         an unknown primitive type id
     */
    public static Type getType(byte[] value, int position) {
        checkIndex(position, value.length);
        int basicType = basicTypeOf(value[position]);
        int typeInfo = typeInfoOf(value[position]);
        return switch (basicType) {
            case SHORT_STR -> Type.STRING;
            case OBJECT -> Type.OBJECT;
            case ARRAY -> Type.ARRAY;
            // PRIMITIVE: the type info field selects the concrete primitive type.
            default -> switch (typeInfo) {
                case NULL -> Type.NULL;
                case TRUE, FALSE -> Type.BOOLEAN;
                case INT1, INT2, INT4, INT8 -> Type.LONG;
                case DOUBLE -> Type.DOUBLE;
                case DECIMAL4, DECIMAL8, DECIMAL16 -> Type.DECIMAL;
                case DATE -> Type.DATE;
                case TIMESTAMP -> Type.TIMESTAMP;
                case TIMESTAMP_NTZ -> Type.TIMESTAMP_NTZ;
                case FLOAT -> Type.FLOAT;
                case BINARY -> Type.BINARY;
                case LONG_STR -> Type.STRING;
                default -> throw new IllegalArgumentException("Unexpected type: " + typeInfo);
            };
        };
    }

    private static IllegalStateException unexpectedType(Type type) {
        return new IllegalStateException("Expect type to be " + type);
    }

    /**
     * Decodes a boolean value.
     *
     * @throws IllegalArgumentException if {@code position} is out of bounds
     * @throws IllegalStateException if the value is not a boolean
     */
    public static boolean getBoolean(byte[] value, int position) {
        checkIndex(position, value.length);
        int basicType = basicTypeOf(value[position]);
        int typeInfo = typeInfoOf(value[position]);
        if (basicType != PRIMITIVE || (typeInfo != TRUE && typeInfo != FALSE)) {
            throw unexpectedType(Type.BOOLEAN);
        }
        return typeInfo == TRUE;
    }

    /**
     * Decodes an integer-backed value: a 1/2/4/8-byte integer, a date (4 bytes) or a
     * timestamp / timestamp-ntz (8 bytes). The raw stored number is returned unchanged.
     *
     * @throws IllegalArgumentException if the header or payload is out of bounds
     * @throws IllegalStateException if the value is not one of the types listed above
     */
    public static long getLong(byte[] value, int position) {
        checkIndex(position, value.length);
        int basicType = basicTypeOf(value[position]);
        int typeInfo = typeInfoOf(value[position]);
        String exceptionMessage = "Expect type to be LONG/DATE/TIMESTAMP/TIMESTAMP_NTZ";
        if (basicType != PRIMITIVE) {
            throw new IllegalStateException(exceptionMessage);
        }
        return switch (typeInfo) {
            case INT1 -> readLong(value, position + 1, 1);
            case INT2 -> readLong(value, position + 1, 2);
            case INT4, DATE -> readLong(value, position + 1, 4);
            case INT8, TIMESTAMP, TIMESTAMP_NTZ -> readLong(value, position + 1, 8);
            default -> throw new IllegalStateException(exceptionMessage);
        };
    }

    /**
     * Decodes an 8-byte double value.
     *
     * @throws IllegalArgumentException if the header or payload is out of bounds
     * @throws IllegalStateException if the value is not a double
     */
    public static double getDouble(byte[] value, int position) {
        checkIndex(position, value.length);
        int basicType = basicTypeOf(value[position]);
        int typeInfo = typeInfoOf(value[position]);
        if (basicType != PRIMITIVE || typeInfo != DOUBLE) {
            throw unexpectedType(Type.DOUBLE);
        }
        return Double.longBitsToDouble(readLong(value, position + 1, 8));
    }

    /** Rejects decimals whose precision or scale exceeds {@code maxPrecision}. */
    private static void checkDecimal(BigDecimal decimal, int maxPrecision) {
        if (decimal.precision() > maxPrecision || decimal.scale() > maxPrecision) {
            throw new IllegalArgumentException("Decimal out of bound: " + decimal);
        }
    }

    /**
     * Decodes a decimal value. The payload is one unsigned scale byte followed by a 4-, 8- or
     * 16-byte little-endian two's-complement unscaled value. Trailing zeros are stripped from the
     * result.
     *
     * @throws IllegalArgumentException if the header, scale byte or payload is out of bounds, or
     *         if the decoded decimal exceeds the precision limit for its width
     * @throws IllegalStateException if the value is not a decimal
     */
    public static BigDecimal getDecimal(byte[] value, int position) {
        checkIndex(position, value.length);
        int basicType = basicTypeOf(value[position]);
        int typeInfo = typeInfoOf(value[position]);
        // Validate the type before touching the payload so that a non-decimal value at the very
        // end of the buffer reports a type error rather than an array index failure.
        boolean decimalType = typeInfo == DECIMAL4 || typeInfo == DECIMAL8 || typeInfo == DECIMAL16;
        if (basicType != PRIMITIVE || !decimalType) {
            throw unexpectedType(Type.DECIMAL);
        }
        checkIndex(position + 1, value.length);
        // The scale byte is unsigned; values above the precision limit are rejected by checkDecimal.
        int scale = value[position + 1] & 0xFF;

        BigDecimal result;
        int maxPrecision;
        if (typeInfo == DECIMAL4) {
            result = BigDecimal.valueOf(readLong(value, position + 2, 4), scale);
            maxPrecision = MAX_DECIMAL4_PRECISION;
        } else if (typeInfo == DECIMAL8) {
            result = BigDecimal.valueOf(readLong(value, position + 2, 8), scale);
            maxPrecision = MAX_DECIMAL8_PRECISION;
        } else {
            // DECIMAL16: BigInteger expects big-endian bytes, so copy the 16 payload bytes reversed.
            checkIndex(position + 17, value.length);
            byte[] bigEndian = new byte[16];
            for (int i = 0; i < 16; ++i) {
                bigEndian[i] = value[position + 17 - i];
            }
            result = new BigDecimal(new BigInteger(bigEndian), scale);
            maxPrecision = MAX_DECIMAL16_PRECISION;
        }
        checkDecimal(result, maxPrecision);
        return result.stripTrailingZeros();
    }

    /**
     * Decodes a 4-byte float value.
     *
     * @throws IllegalArgumentException if the header or payload is out of bounds
     * @throws IllegalStateException if the value is not a float
     */
    public static float getFloat(byte[] value, int position) {
        checkIndex(position, value.length);
        int basicType = basicTypeOf(value[position]);
        int typeInfo = typeInfoOf(value[position]);
        if (basicType != PRIMITIVE || typeInfo != FLOAT) {
            throw unexpectedType(Type.FLOAT);
        }
        return Float.intBitsToFloat((int) readLong(value, position + 1, 4));
    }

    /**
     * Decodes a binary value: a 4-byte unsigned length followed by that many bytes. A fresh copy
     * of the bytes is returned.
     *
     * @throws IllegalArgumentException if the header, length or payload is out of bounds
     * @throws IllegalStateException if the value is not binary
     */
    public static byte[] getBinary(byte[] value, int position) {
        checkIndex(position, value.length);
        int basicType = basicTypeOf(value[position]);
        int typeInfo = typeInfoOf(value[position]);
        if (basicType != PRIMITIVE || typeInfo != BINARY) {
            throw unexpectedType(Type.BINARY);
        }
        int start = position + 1 + U32_SIZE;
        int length = readUnsigned(value, position + 1, U32_SIZE);
        checkIndex(start + length - 1, value.length);
        return Arrays.copyOfRange(value, start, start + length);
    }

    /**
     * Decodes a UTF-8 string. Short strings store their length in the header's type info field;
     * long strings store a 4-byte unsigned length after the header.
     *
     * @throws IllegalArgumentException if the header, length or payload is out of bounds
     * @throws IllegalStateException if the value is not a string
     */
    public static String getString(byte[] value, int position) {
        checkIndex(position, value.length);
        int basicType = basicTypeOf(value[position]);
        int typeInfo = typeInfoOf(value[position]);
        int start;
        int length;
        if (basicType == SHORT_STR) {
            start = position + 1;
            length = typeInfo;
        } else if (basicType == PRIMITIVE && typeInfo == LONG_STR) {
            start = position + 1 + U32_SIZE;
            length = readUnsigned(value, position + 1, U32_SIZE);
        } else {
            throw unexpectedType(Type.STRING);
        }
        checkIndex(start + length - 1, value.length);
        return new String(value, start, length, StandardCharsets.UTF_8);
    }

    /**
     * Decodes an object header and passes the resulting layout to {@code handler}.
     *
     * <p>Type info layout: bit 4 selects a 4-byte (set) or 1-byte (clear) element count, bits 2-3
     * hold {@code idSize - 1} and bits 0-1 hold {@code offsetSize - 1}. After the count come
     * {@code size} field ids, then {@code size + 1} offsets, then the field data. Only the header
     * and count are bounds-checked here; the computed start positions are passed on unvalidated.
     *
     * @return whatever {@code handler} returns
     * @throws IllegalArgumentException if the header or element count is out of bounds
     * @throws IllegalStateException if the value is not an object
     */
    public static <T> T handleObject(byte[] value, int position, ObjectHandler<T> handler) {
        checkIndex(position, value.length);
        int basicType = basicTypeOf(value[position]);
        int typeInfo = typeInfoOf(value[position]);
        if (basicType != OBJECT) {
            throw unexpectedType(Type.OBJECT);
        }
        boolean largeSize = ((typeInfo >> 4) & 1) != 0;
        int sizeBytes = largeSize ? U32_SIZE : 1;
        int size = readUnsigned(value, position + 1, sizeBytes);
        int idSize = ((typeInfo >> 2) & 3) + 1;
        int offsetSize = (typeInfo & 3) + 1;
        int idStart = position + 1 + sizeBytes;
        int offsetStart = idStart + size * idSize;
        int dataStart = offsetStart + (size + 1) * offsetSize;
        return handler.apply(size, idSize, offsetSize, idStart, offsetStart, dataStart);
    }

    /**
     * Decodes an array header and passes the resulting layout to {@code handler}.
     *
     * <p>Type info layout: bit 2 selects a 4-byte (set) or 1-byte (clear) element count and bits
     * 0-1 hold {@code offsetSize - 1}. After the count come {@code size + 1} offsets, then the
     * element data. Only the header and count are bounds-checked here; the computed start
     * positions are passed on unvalidated.
     *
     * @return whatever {@code handler} returns
     * @throws IllegalArgumentException if the header or element count is out of bounds
     * @throws IllegalStateException if the value is not an array
     */
    public static <T> T handleArray(byte[] value, int position, ArrayHandler<T> handler) {
        checkIndex(position, value.length);
        int basicType = basicTypeOf(value[position]);
        int typeInfo = typeInfoOf(value[position]);
        if (basicType != ARRAY) {
            throw unexpectedType(Type.ARRAY);
        }
        boolean largeSize = ((typeInfo >> 2) & 1) != 0;
        int sizeBytes = largeSize ? U32_SIZE : 1;
        int size = readUnsigned(value, position + 1, sizeBytes);
        int offsetSize = (typeInfo & 3) + 1;
        int offsetStart = position + 1 + sizeBytes;
        int dataStart = offsetStart + (size + 1) * offsetSize;
        return handler.apply(size, offsetSize, offsetStart, dataStart);
    }

    /**
     * Looks up the dictionary key with the given {@code id} in a metadata buffer.
     *
     * <p>Layout read by this method: one header byte whose top two bits hold
     * {@code offsetSize - 1}, then the dictionary size, then {@code dictSize + 1} offsets (each
     * {@code offsetSize} bytes), then the UTF-8 key bytes that the offsets index into.
     *
     * @throws IllegalArgumentException if {@code id} is negative or not below the dictionary
     *         size, if the offsets are not in ascending order, or if any read is out of bounds
     */
    public static String getMetadataKey(byte[] metadata, int id) {
        checkIndex(0, metadata.length);
        int offsetSize = ((metadata[0] >> 6) & 3) + 1;
        int dictSize = readUnsigned(metadata, 1, offsetSize);
        // A negative id would otherwise make the reads below land on the header or size bytes.
        if (id < 0 || id >= dictSize) {
            throw new IllegalArgumentException("Index out of bound: %s (size: %s)".formatted(id, dictSize));
        }
        int stringStart = 1 + (dictSize + 2) * offsetSize;
        int offset = readUnsigned(metadata, 1 + (id + 1) * offsetSize, offsetSize);
        int nextOffset = readUnsigned(metadata, 1 + (id + 2) * offsetSize, offsetSize);
        if (offset > nextOffset) {
            throw new IllegalArgumentException("Invalid offset: %s > %s".formatted(offset, nextOffset));
        }
        checkIndex(stringStart + nextOffset - 1, metadata.length);
        return new String(metadata, stringStart + offset, nextOffset - offset, StandardCharsets.UTF_8);
    }

    /** Logical value types reported by {@link #getType(byte[], int)}. */
    public enum Type {
        NULL,
        BOOLEAN,
        LONG,
        FLOAT,
        DOUBLE,
        DECIMAL,
        STRING,
        BINARY,
        DATE,
        TIMESTAMP,
        TIMESTAMP_NTZ,
        ARRAY,
        OBJECT
    }

    /** Callback receiving the decoded layout of an object value. */
    @FunctionalInterface
    public interface ObjectHandler<T> {
        /**
         * @param size number of fields
         * @param idSize width in bytes of each field id
         * @param offsetSize width in bytes of each offset
         * @param idStart index of the first field id
         * @param offsetStart index of the first offset
         * @param dataStart index of the first byte after the offsets
         */
        T apply(int size, int idSize, int offsetSize, int idStart, int offsetStart, int dataStart);
    }

    /** Callback receiving the decoded layout of an array value. */
    @FunctionalInterface
    public interface ArrayHandler<T> {
        /**
         * @param size number of elements
         * @param offsetSize width in bytes of each offset
         * @param offsetStart index of the first offset
         * @param dataStart index of the first byte after the offsets
         */
        T apply(int size, int offsetSize, int offsetStart, int dataStart);
    }
}

