/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.angel.ml.core.utils;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import java.util.Random;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Lossy fixed-point compression of {@code float[]} / {@code double[]} data into a
 * Netty {@link ByteBuf}.
 *
 * <p>Every element is scaled by the largest absolute value in the encoded range and
 * mapped to a small signed integer ("point"). Points are stored in sign-magnitude
 * form: either several points packed into one byte (2 or 4 bits per point) or one
 * point spread over one or more whole bytes (8, 16 or 32 bits per point).
 */
public class JCompressUtils {
    private static final Log LOG = LogFactory.getLog(JCompressUtils.class);

    /**
     * Encodes {@code value} as a sign-magnitude integer occupying {@code numBytes} bytes.
     *
     * <p>The magnitude is written with its least significant byte at the last index;
     * the top bit of the first byte is the sign flag.
     *
     * @param value    the integer to encode; its magnitude must be smaller than
     *                 {@code 2^(8 * numBytes - 1)} (checked with an assertion only)
     * @param numBytes number of bytes to produce
     * @return a new array of length {@code numBytes}
     */
    public static byte[] serializeInt(int value, int numBytes) {
        // Widen before taking the magnitude so negating Integer.MIN_VALUE cannot overflow.
        long magnitude = Math.abs((long) value);
        // Check the magnitude, not the signed value: a negative value would otherwise
        // always pass and silently collide with the sign bit.
        assert Math.pow(2.0, 8 * numBytes - 1) > (double) magnitude;
        byte[] rec = new byte[numBytes];
        for (int i = numBytes - 1; i >= 0; --i) {
            rec[i] = (byte) magnitude;
            magnitude >>>= 8;
        }
        if (value < 0) {
            rec[0] = (byte) (rec[0] | 0x80);
        }
        return rec;
    }

    /**
     * Packs several small sign-magnitude integers into a single byte.
     *
     * <p>The byte is split into {@code numItems} slots of {@code 8 / numItems} bits each,
     * first value in the highest slot. Within a slot the top bit is the sign flag and
     * the remaining bits hold the magnitude.
     *
     * @param values   the values to pack; {@code values.length} must equal {@code numItems}
     *                 and each magnitude must fit in the slot (both checked with assertions only)
     * @param numItems number of slots in the byte, expected to be greater than 1
     * @return the packed byte
     */
    public static byte serializeInt(int[] values, int numItems) {
        assert values.length == numItems && numItems > 1;
        int numBits = 8 / numItems;
        int packed = 0;
        int signMask = 0x80;
        // Magnitude of the first slot sits directly below its sign bit:
        // offset 6 for 2-bit slots, offset 4 for 4-bit slots.
        int valueOffset = 8 - numBits;
        for (int i = 0; i < values.length; ++i) {
            int magnitude = values[i];
            if (magnitude < 0) {
                magnitude = -magnitude;
                packed |= signMask;
            }
            assert Math.pow(2.0, numBits - 1) > (double) magnitude;
            packed |= magnitude << valueOffset;
            signMask >>= numBits;
            valueOffset -= numBits;
        }
        return (byte) packed;
    }

    /**
     * Decodes a sign-magnitude integer produced by {@link #serializeInt(int, int)}.
     *
     * <p>The input array is not modified.
     *
     * @param buf the encoded bytes, most significant byte first
     * @return the decoded value
     */
    private static int deserializeInt(byte[] buf) {
        boolean isNegative = (buf[0] & 0x80) == 0x80;
        // Mask the sign flag out of a local copy instead of writing back into buf.
        int magnitude = buf[0] & 0x7F;
        for (int i = 1; i < buf.length; ++i) {
            magnitude = (magnitude << 8) | (buf[i] & 0xFF);
        }
        return isNegative ? -magnitude : magnitude;
    }

    /**
     * Unpacks a byte produced by {@link #serializeInt(int[], int)}.
     *
     * @param b        the packed byte
     * @param numItems number of slots the byte was packed with
     * @return the decoded values, first slot first
     */
    private static int[] deserializeInt(byte b, int numItems) {
        int[] rec = new int[numItems];
        int numBits = 8 / numItems;
        // All slot bits except the sign bit: 1 for 2-bit slots, 7 for 4-bit slots.
        int valueMask = (1 << (numBits - 1)) - 1;
        int signMask = 0x80;
        int valueOffset = 8 - numBits;
        for (int i = 0; i < numItems; ++i) {
            int magnitude = (b >> valueOffset) & valueMask;
            boolean isNeg = (b & signMask) == signMask;
            rec[i] = isNeg ? -magnitude : magnitude;
            signMask >>= numBits;
            valueOffset -= numBits;
        }
        return rec;
    }

    /**
     * Smoke run: compresses and decompresses 102 random values in [-0.5, 0.5) with
     * 4 bits per value, first as doubles and then as floats. Results are only logged.
     *
     * @param argv ignored
     */
    public static void main(String[] argv) {
        Random ran = new Random();
        int len = 102;
        int numBits = 4;
        double[] dArr = new double[len];
        float[] fArr = new float[len];
        for (int i = 0; i < dArr.length; ++i) {
            dArr[i] = ran.nextDouble() - 0.5;
        }
        ByteBuf buf1 = Unpooled.buffer(1000);
        Quantification.serializeDouble(buf1, dArr, numBits);
        Quantification.deserializeDouble(buf1);
        for (int i = 0; i < dArr.length; ++i) {
            fArr[i] = (float) dArr[i];
        }
        ByteBuf buf2 = Unpooled.buffer(1000);
        Quantification.serializeFloat(buf2, fArr, numBits);
        Quantification.deserializeFloat(buf2);
    }

    /**
     * Quantizing serializers and their matching deserializers.
     *
     * <p>Layout written to the buffer: element count (int), bits per element (int),
     * maximum absolute value (float or double), followed by the encoded points.
     */
    public static class Quantification {
        /**
         * Compresses the whole array; see
         * {@link #serializeFloat(ByteBuf, float[], int, int, int)}.
         */
        public static int serializeFloat(ByteBuf buf, float[] arr, int numBits) {
            return Quantification.serializeFloat(buf, arr, 0, arr.length, numBits);
        }

        /**
         * Compresses the whole array; see
         * {@link #serializeDouble(ByteBuf, double[], int, int, int)}.
         */
        public static int serializeDouble(ByteBuf buf, double[] arr, int numBits) {
            return Quantification.serializeDouble(buf, arr, 0, arr.length, numBits);
        }

        /**
         * Quantizes {@code arr[start..end)} and appends it to {@code buf}.
         *
         * @param buf     destination buffer
         * @param arr     source values
         * @param start   index of the first element to encode (inclusive)
         * @param end     index after the last element to encode (exclusive)
         * @param numBits bits per element; must be one of 2, 4, 8, 16, otherwise an error
         *                is logged and 8 is used
         * @return {@code arr.length}; note this is the length of the whole array, not the
         *         number of elements or bytes written
         */
        public static int serializeFloat(ByteBuf buf, float[] arr, int start, int end, int numBits) {
            long startTime = System.currentTimeMillis();
            boolean powerOfTwo = (numBits & (numBits - 1)) == 0;
            if (numBits < 2 || numBits > 16 || !powerOfTwo) {
                numBits = 8;
                LOG.error("Compression bits should be in {2,4,8,16}");
            }
            int len = end - start;
            buf.writeInt(len);
            buf.writeInt(numBits);

            // NaN elements never compare greater, so they do not influence the scale.
            float maxAbs = 0.0f;
            for (int i = start; i < end; ++i) {
                float abs = Math.abs(arr[i]);
                if (abs > maxAbs) {
                    maxAbs = abs;
                }
            }
            buf.writeFloat(maxAbs);

            int maxPoint = Quantification.maxPointFor(numBits);
            int itemPerByte = 8 / numBits;
            int bytePerItem = numBits / 8;
            int byteSum = 0;
            if (bytePerItem >= 1) {
                // One element per bytePerItem bytes.
                for (int i = start; i < end; ++i) {
                    int point = Quantification.quantify(arr[i], maxAbs, maxPoint);
                    buf.writeBytes(JCompressUtils.serializeInt(point, bytePerItem));
                    byteSum += bytePerItem;
                }
            } else {
                // Several elements per byte; slots past the end of the range stay zero.
                for (int i = start; i < end; i += itemPerByte) {
                    int[] slots = new int[itemPerByte];
                    for (int j = 0; j < itemPerByte && i + j < end; ++j) {
                        slots[j] = Quantification.quantify(arr[i + j], maxAbs, maxPoint);
                    }
                    buf.writeByte(JCompressUtils.serializeInt(slots, itemPerByte));
                    ++byteSum;
                }
            }
            LOG.info(String.format("compress %d floats to %d bytes, max abs: %f, max point: %d, cost %d ms", len, byteSum, maxAbs, maxPoint, System.currentTimeMillis() - startTime));
            return arr.length;
        }

        /**
         * Quantizes {@code arr[start..end)} and appends it to {@code buf}.
         *
         * @param buf     destination buffer
         * @param arr     source values
         * @param start   index of the first element to encode (inclusive)
         * @param end     index after the last element to encode (exclusive)
         * @param numBits bits per element; must be one of 2, 4, 8, 16, 32, otherwise an
         *                error is logged and 8 is used
         * @return {@code arr.length}; note this is the length of the whole array, not the
         *         number of elements or bytes written
         */
        public static int serializeDouble(ByteBuf buf, double[] arr, int start, int end, int numBits) {
            long startTime = System.currentTimeMillis();
            boolean powerOfTwo = (numBits & (numBits - 1)) == 0;
            if (numBits < 2 || numBits > 32 || !powerOfTwo) {
                numBits = 8;
                LOG.error("Compression bits should be in {2,4,8,16,32}");
            }
            int len = end - start;
            buf.writeInt(len);
            buf.writeInt(numBits);

            // NaN elements never compare greater, so they do not influence the scale.
            double maxAbs = 0.0;
            for (int i = start; i < end; ++i) {
                double abs = Math.abs(arr[i]);
                if (abs > maxAbs) {
                    maxAbs = abs;
                }
            }
            buf.writeDouble(maxAbs);

            int maxPoint = Quantification.maxPointFor(numBits);
            int itemPerByte = 8 / numBits;
            int bytePerItem = numBits / 8;
            int byteSum = 0;
            if (bytePerItem >= 1) {
                // One element per bytePerItem bytes.
                for (int i = start; i < end; ++i) {
                    int point = Quantification.quantify(arr[i], maxAbs, maxPoint);
                    buf.writeBytes(JCompressUtils.serializeInt(point, bytePerItem));
                    byteSum += bytePerItem;
                }
            } else {
                // Several elements per byte; slots past the end of the range stay zero.
                for (int i = start; i < end; i += itemPerByte) {
                    int[] slots = new int[itemPerByte];
                    for (int j = 0; j < itemPerByte && i + j < end; ++j) {
                        slots[j] = Quantification.quantify(arr[i + j], maxAbs, maxPoint);
                    }
                    buf.writeByte(JCompressUtils.serializeInt(slots, itemPerByte));
                    ++byteSum;
                }
            }
            LOG.info(String.format("compress %d doubles to %d bytes, max abs: %f, max point: %d, cost %d ms", len, byteSum, maxAbs, maxPoint, System.currentTimeMillis() - startTime));
            return arr.length;
        }

        /**
         * Reads an array written by one of the {@code serializeFloat} methods.
         *
         * @param buf source buffer positioned at the start of the encoded block
         * @return the reconstructed (lossy) values
         */
        public static float[] deserializeFloat(ByteBuf buf) {
            long startTime = System.currentTimeMillis();
            int length = buf.readInt();
            int numBits = buf.readInt();
            float maxAbs = buf.readFloat();
            int maxPoint = Quantification.maxPointFor(numBits);
            int itemPerByte = 8 / numBits;
            int bytePerItem = numBits / 8;
            // Value represented by one quantization step.
            float step = maxAbs / (float) maxPoint;
            float[] arr = new float[length];
            int i = 0;
            while (i < length) {
                if (bytePerItem >= 1) {
                    byte[] itemBytes = new byte[bytePerItem];
                    buf.readBytes(itemBytes);
                    arr[i++] = step * (float) JCompressUtils.deserializeInt(itemBytes);
                } else {
                    int[] points = JCompressUtils.deserializeInt(buf.readByte(), itemPerByte);
                    // The last byte may carry padding slots beyond length; ignore them.
                    for (int j = 0; j < points.length && i < length; ++j) {
                        arr[i++] = step * (float) points[j];
                    }
                }
            }
            LOG.info(String.format("parse %d floats, max abs: %f, max point: %d, cost %d ms", length, maxAbs, maxPoint, System.currentTimeMillis() - startTime));
            return arr;
        }

        /**
         * Reads an array written by one of the {@code serializeDouble} methods.
         *
         * @param buf source buffer positioned at the start of the encoded block
         * @return the reconstructed (lossy) values
         */
        public static double[] deserializeDouble(ByteBuf buf) {
            long startTime = System.currentTimeMillis();
            int length = buf.readInt();
            int numBits = buf.readInt();
            double maxAbs = buf.readDouble();
            int maxPoint = Quantification.maxPointFor(numBits);
            int itemPerByte = 8 / numBits;
            int bytePerItem = numBits / 8;
            // Value represented by one quantization step.
            double step = maxAbs / (double) maxPoint;
            double[] arr = new double[length];
            int i = 0;
            while (i < length) {
                if (bytePerItem >= 1) {
                    byte[] itemBytes = new byte[bytePerItem];
                    buf.readBytes(itemBytes);
                    arr[i++] = step * (double) JCompressUtils.deserializeInt(itemBytes);
                } else {
                    int[] points = JCompressUtils.deserializeInt(buf.readByte(), itemPerByte);
                    // The last byte may carry padding slots beyond length; ignore them.
                    for (int j = 0; j < points.length && i < length; ++j) {
                        arr[i++] = step * (double) points[j];
                    }
                }
            }
            LOG.info(String.format("parse %d double, max abs: %f, max point: %d, cost %d ms", length, maxAbs, maxPoint, System.currentTimeMillis() - startTime));
            return arr;
        }

        /**
         * Largest point magnitude for the given width, computed as
         * {@code (int) 2^(numBits - 1) - 1}. Encoder and decoder must use the same
         * expression so that the scale agrees on both sides.
         */
        private static int maxPointFor(int numBits) {
            return (int) Math.pow(2.0, numBits - 1) - 1;
        }

        /**
         * Maps {@code item} to an integer point in {@code [-maxPoint, maxPoint]}.
         *
         * <p>The scaled value is rounded down, then bumped up by one with probability
         * one half as long as that stays within {@code maxPoint}.
         *
         * @param item      value to quantize
         * @param threshold largest absolute value in the encoded range
         * @param maxPoint  largest point magnitude
         * @return the quantized point
         */
        private static int quantify(float item, float threshold, int maxPoint) {
            if (threshold == 0.0f) {
                // Avoid 0/0 = NaN. The decoder multiplies by a zero step in this case,
                // so every point decodes to zero regardless.
                return 0;
            }
            int point = (int) Math.floor(item / threshold * (float) maxPoint);
            if (point < maxPoint && Math.random() > 0.5) {
                ++point;
            }
            return point;
        }

        /**
         * Maps {@code item} to an integer point in {@code [-maxPoint, maxPoint]}.
         *
         * <p>The scaled value is rounded down, then bumped up by one with probability
         * one half as long as that stays within {@code maxPoint}.
         *
         * @param item      value to quantize
         * @param threshold largest absolute value in the encoded range
         * @param maxPoint  largest point magnitude
         * @return the quantized point
         */
        private static int quantify(double item, double threshold, int maxPoint) {
            if (threshold == 0.0) {
                // Avoid 0/0 = NaN. The decoder multiplies by a zero step in this case,
                // so every point decodes to zero regardless.
                return 0;
            }
            int point = (int) Math.floor(item / threshold * (double) maxPoint);
            if (point < maxPoint && Math.random() > 0.5) {
                ++point;
            }
            return point;
        }
    }
}

