/*
 * Decompiled with CFR 0.152.
 */
package ai.sklearn4j.core.packaging;

import ai.sklearn4j.core.ScikitLearnCoreException;
import ai.sklearn4j.core.libraries.numpy.NumpyArray;
import ai.sklearn4j.core.libraries.numpy.NumpyArrayFactory;
import ai.sklearn4j.core.packaging.IBinaryModelPackagePrimitiveValueReader;
import java.io.BufferedInputStream;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Array;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Sequential reader for the binary model-package format.
 *
 * <p>Multi-byte integers are decoded in little-endian byte order. Floats, doubles, strings,
 * lists, dictionaries, numpy arrays and string arrays are each preceded by a one-byte presence
 * flag: a flag of {@code 1} means a value follows; any other flag means the value is absent and
 * the reader returns {@code null} (or {@code NaN} for floating-point values) without consuming
 * further bytes.
 *
 * <p>Instances are stateful (every read advances the underlying stream) and perform no
 * synchronization.
 */
public class BinaryModelPackage {
    // Type tags written before list elements, dictionary values and numpy array payloads.
    private static final int ELEMENT_TYPE_BYTE = 1;
    private static final int ELEMENT_TYPE_SHORT = 2;
    private static final int ELEMENT_TYPE_INT = 4;
    private static final int ELEMENT_TYPE_LONG = 8;
    private static final int ELEMENT_TYPE_UNSIGNED_BYTE = 17;
    private static final int ELEMENT_TYPE_UNSIGNED_SHORT = 18;
    private static final int ELEMENT_TYPE_UNSIGNED_INT = 20;
    private static final int ELEMENT_TYPE_UNSIGNED_LONG = 24;
    private static final int ELEMENT_TYPE_FLOAT = 32;
    private static final int ELEMENT_TYPE_DOUBLE = 33;
    private static final int ELEMENT_TYPE_STRING = 48;
    private static final int ELEMENT_TYPE_LIST = 64;
    private static final int ELEMENT_TYPE_DICTIONARY = 65;
    private static final int ELEMENT_TYPE_NUMPY_ARRAY = 66;
    private static final int ELEMENT_TYPE_STRING_ARRAY = 67;
    private static final int ELEMENT_TYPE_NULL = 16;
    private final InputStream stream;

    private BinaryModelPackage(InputStream stream) {
        this.stream = stream;
    }

    /**
     * Loads the entire file at {@code path} into memory and returns a reader over its contents.
     *
     * @param path file-system path of the package file
     * @return a reader positioned at the start of the file contents
     * @throws ScikitLearnCoreException if the file cannot be opened or read
     */
    public static BinaryModelPackage fromFile(String path) {
        try (InputStream input = new BufferedInputStream(new FileInputStream(path))) {
            // available() is only an estimate and one read() may return fewer bytes than asked
            // for, so drain the file explicitly until end-of-stream.
            ByteArrayOutputStream contents = new ByteArrayOutputStream();
            byte[] chunk = new byte[8192];
            int count;
            while ((count = input.read(chunk)) != -1) {
                contents.write(chunk, 0, count);
            }
            return BinaryModelPackage.fromStream(new ByteArrayInputStream(contents.toByteArray()));
        }
        catch (IOException ex) {
            throw new ScikitLearnCoreException("An error occurred while loading a package from file:\n" + ex.getMessage());
        }
    }

    /**
     * Wraps an existing stream. The stream is read lazily and is not closed by this class.
     *
     * @param stream source of package bytes
     * @return a reader over {@code stream}
     */
    public static BinaryModelPackage fromStream(InputStream stream) {
        return new BinaryModelPackage(stream);
    }

    /** Reads one raw (signed) byte. */
    public byte readByte() {
        return this.readBuffer(1)[0];
    }

    /** Reads a 2-byte little-endian signed integer. */
    public short readShort() {
        return (short) this.readLittleEndian(2);
    }

    /** Reads a 4-byte little-endian signed integer. */
    public int readInteger() {
        return (int) this.readLittleEndian(4);
    }

    /** Reads an 8-byte little-endian signed integer. */
    public long readLongInteger() {
        return this.readLittleEndian(8);
    }

    /**
     * Reads an optional IEEE-754 single-precision value.
     *
     * @return the value, or {@link Float#NaN} when the presence flag is not set
     */
    public float readFloat() {
        if (!this.readPresenceFlag()) {
            return Float.NaN;
        }
        return Float.intBitsToFloat(this.readInteger());
    }

    /**
     * Reads an optional IEEE-754 double-precision value.
     *
     * @return the value, or {@link Double#NaN} when the presence flag is not set
     */
    public double readDouble() {
        if (!this.readPresenceFlag()) {
            return Double.NaN;
        }
        return Double.longBitsToDouble(this.readLongInteger());
    }

    /**
     * Reads an optional UTF-8 string encoded as a 4-byte length followed by that many bytes.
     *
     * @return the decoded string, or {@code null} when the presence flag is not set
     */
    public String readString() {
        if (!this.readPresenceFlag()) {
            return null;
        }
        int length = this.readInteger();
        return new String(this.readBuffer(length), StandardCharsets.UTF_8);
    }

    /**
     * Reads an optional numpy array: dimension count, element type tag, each dimension's size,
     * then the elements in row-major order.
     *
     * @return the array, or {@code null} when the presence flag is not set
     * @throws ScikitLearnCoreException if the element type is not a supported numeric type
     */
    public NumpyArray readNumpyArray() {
        if (!this.readPresenceFlag()) {
            return null;
        }
        int[] shape = new int[this.readInteger()];
        byte elementType = this.readByte();
        for (int i = 0; i < shape.length; ++i) {
            shape[i] = this.readInteger();
        }
        NumpyArray result = this.createNumpyArray(elementType, shape);
        this.readNumpyDataFromStream(result.getWrapper().getRawArray(), shape, 0, elementType);
        return result;
    }

    // Allocates an empty numpy array whose Java element type matches the wire element type.
    private NumpyArray createNumpyArray(int elementType, int[] shape) {
        switch (elementType) {
            case ELEMENT_TYPE_BYTE:
            case ELEMENT_TYPE_UNSIGNED_BYTE:
                return NumpyArrayFactory.arrayOfInt8WithShape(shape);
            case ELEMENT_TYPE_SHORT:
            case ELEMENT_TYPE_UNSIGNED_SHORT:
                return NumpyArrayFactory.arrayOfInt16WithShape(shape);
            case ELEMENT_TYPE_INT:
            case ELEMENT_TYPE_UNSIGNED_INT:
                return NumpyArrayFactory.arrayOfInt32WithShape(shape);
            case ELEMENT_TYPE_LONG:
            case ELEMENT_TYPE_UNSIGNED_LONG:
                return NumpyArrayFactory.arrayOfInt64WithShape(shape);
            case ELEMENT_TYPE_FLOAT:
                return NumpyArrayFactory.arrayOfFloatWithShape(shape);
            case ELEMENT_TYPE_DOUBLE:
                return NumpyArrayFactory.arrayOfDoubleWithShape(shape);
            default:
                throw new ScikitLearnCoreException(String.format("Numpy array with element type %d is not supported.", elementType));
        }
    }

    /**
     * Reads an optional heterogeneous list; each element is prefixed by its type tag.
     *
     * @return the list (elements tagged as null become {@code null}), or {@code null} when the
     *     presence flag is not set
     */
    public List<Object> readList() {
        if (!this.readPresenceFlag()) {
            return null;
        }
        List<Object> result = new ArrayList<Object>();
        int count = this.readInteger();
        for (int i = 0; i < count; ++i) {
            byte elementType = this.readByte();
            if (elementType == ELEMENT_TYPE_NULL) {
                result.add(null);
            } else {
                result.add(this.getPrimitiveDataReader(elementType).readPrimitiveValue());
            }
        }
        return result;
    }

    /**
     * Reads an optional string-keyed dictionary; each entry is a key string, a type tag, and
     * the value.
     *
     * @return the dictionary, or {@code null} when the presence flag is not set
     */
    public Map<String, Object> readDictionary() {
        if (!this.readPresenceFlag()) {
            return null;
        }
        Map<String, Object> result = new HashMap<String, Object>();
        int count = this.readInteger();
        for (int i = 0; i < count; ++i) {
            String key = this.readString();
            byte elementType = this.readByte();
            if (elementType == ELEMENT_TYPE_NULL) {
                result.put(key, null);
            } else if (elementType == ELEMENT_TYPE_STRING_ARRAY) {
                result.put(key, this.readStringArray());
            } else {
                result.put(key, this.getPrimitiveDataReader(elementType).readPrimitiveValue());
            }
        }
        return result;
    }

    /**
     * Reads an optional array of optional strings.
     *
     * @return the array, or {@code null} when the presence flag is not set
     */
    public String[] readStringArray() {
        if (!this.readPresenceFlag()) {
            return null;
        }
        String[] result = new String[this.readInteger()];
        for (int i = 0; i < result.length; ++i) {
            result[i] = this.readString();
        }
        return result;
    }

    // Recursively fills a nested Java array (one level per dimension) in row-major order.
    private void readNumpyDataFromStream(Object array, int[] shape, int dimension, int elementType) {
        if (shape.length == 0) {
            // The recursion below needs at least one dimension to index into.
            throw new ScikitLearnCoreException("Zero-dimensional numpy arrays are not supported.");
        }
        if (dimension == shape.length - 1) {
            IBinaryModelPackagePrimitiveValueReader reader = this.getPrimitiveDataReader(elementType);
            for (int i = 0; i < shape[dimension]; ++i) {
                Array.set(array, i, reader.readPrimitiveValue());
            }
        } else {
            for (int i = 0; i < shape[dimension]; ++i) {
                this.readNumpyDataFromStream(Array.get(array, i), shape, dimension + 1, elementType);
            }
        }
    }

    // Maps a type tag to the method that decodes a value of that type.
    // Unsigned tags share the signed readers, so values with the high bit set come back negative.
    private IBinaryModelPackagePrimitiveValueReader getPrimitiveDataReader(int elementType) {
        switch (elementType) {
            case ELEMENT_TYPE_BYTE:
            case ELEMENT_TYPE_UNSIGNED_BYTE:
                return this::readByte;
            case ELEMENT_TYPE_SHORT:
            case ELEMENT_TYPE_UNSIGNED_SHORT:
                return this::readShort;
            case ELEMENT_TYPE_INT:
            case ELEMENT_TYPE_UNSIGNED_INT:
                return this::readInteger;
            case ELEMENT_TYPE_LONG:
            case ELEMENT_TYPE_UNSIGNED_LONG:
                return this::readLongInteger;
            case ELEMENT_TYPE_FLOAT:
                return this::readFloat;
            case ELEMENT_TYPE_DOUBLE:
                return this::readDouble;
            case ELEMENT_TYPE_STRING:
                return this::readString;
            case ELEMENT_TYPE_DICTIONARY:
                return this::readDictionary;
            case ELEMENT_TYPE_NUMPY_ARRAY:
                return this::readNumpyArray;
            case ELEMENT_TYPE_LIST:
                return this::readList;
            default:
                throw new ScikitLearnCoreException(String.format("Element type %d is not supported.", elementType));
        }
    }

    // Reads a presence flag byte; only the value 1 means "a value follows".
    private boolean readPresenceFlag() {
        return this.readByte() == 1;
    }

    // Decodes a size-byte little-endian integer (byte 0 is least significant).
    private long readLittleEndian(int size) {
        byte[] data = this.readBuffer(size);
        long result = 0L;
        for (int i = size - 1; i >= 0; --i) {
            result = (result << 8) | (data[i] & 0xFFL);
        }
        return result;
    }

    // Reads exactly size bytes, looping because InputStream.read may return short counts.
    private byte[] readBuffer(int size) {
        if (size < 0) {
            // A negative length can only come from corrupt input.
            throw new ScikitLearnCoreException(String.format("Unable to read %d bytes from the stream.", size));
        }
        byte[] buffer = new byte[size];
        int offset = 0;
        try {
            while (offset < size) {
                int count = this.stream.read(buffer, offset, size - offset);
                if (count < 0) {
                    break;
                }
                offset += count;
            }
        }
        catch (IOException e) {
            throw new ScikitLearnCoreException("Unable to read from buffer:\n" + e.getMessage());
        }
        if (offset != size) {
            throw new ScikitLearnCoreException(String.format("Unable to read %d bytes from the stream.", size));
        }
        return buffer;
    }

    /**
     * Reports whether more bytes appear to be available.
     *
     * <p>Relies on {@link InputStream#available()}, which is exact for the in-memory stream
     * produced by {@link #fromFile(String)}. For arbitrary streams passed to
     * {@link #fromStream(InputStream)} it is only an estimate and may report {@code false}
     * before the real end of the data.
     */
    public boolean canRead() {
        try {
            return this.stream.available() > 0;
        }
        catch (IOException e) {
            throw new ScikitLearnCoreException("An error occurred while assessing if the stream reached end or not:\n" + e.getMessage());
        }
    }
}

