/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.angel.ml.psf.compress;

import com.tencent.angel.PartitionKey;
import com.tencent.angel.ml.core.utils.JCompressUtils;
import com.tencent.angel.ml.matrix.psf.update.base.PartitionUpdateParam;
import com.tencent.angel.ml.matrix.psf.update.base.UpdateParam;
import com.tencent.angel.psagent.PSAgentContext;
import io.netty.buffer.ByteBuf;
import java.util.ArrayList;
import java.util.List;

public class QuantifyDoubleParam
extends UpdateParam {
    private final int rowId;
    private final double[] array;
    private int numBits;

    public QuantifyDoubleParam(int matrixId, int rowId, double[] array, int numBits) {
        super(matrixId, false);
        this.rowId = rowId;
        this.array = array;
        this.numBits = numBits;
    }

    public List<PartitionUpdateParam> split() {
        List partList = PSAgentContext.get().getMatrixMetaManager().getPartitions(this.matrixId, this.rowId);
        int size = partList.size();
        ArrayList<PartitionUpdateParam> partParams = new ArrayList<PartitionUpdateParam>(size);
        for (PartitionKey part : partList) {
            if (this.rowId < part.getStartRow() || this.rowId >= part.getEndRow()) {
                throw new RuntimeException("Wrong rowId!");
            }
            partParams.add(new QuantifyDoublePartUParam(this.matrixId, part, this.rowId, (int)part.getStartCol(), (int)part.getEndCol(), this.array, this.numBits));
        }
        return partParams;
    }

    public static void main(String[] argv) {
        int bitPerItem = 32;
        long maxPoint = (long)Math.pow(2.0, bitPerItem - 1) - 1L;
        double maxAbs = 2500.25;
        double item = 373.0;
        long point = (long)Math.floor(Math.abs(item) / maxAbs * (double)maxPoint);
        byte[] tmp = QuantifyDoublePartUParam.long2Byte(point, bitPerItem / 8, item < -1.0E-10);
        System.out.println("Length of bytes: " + tmp.length);
        long parsedPoint = QuantifyDoublePartUParam.byte2long(tmp);
        System.out.println("Max point: " + maxPoint + ", point: " + point + ", parsed point: " + parsedPoint);
    }

    public static class QuantifyDoublePartUParam
    extends PartitionUpdateParam {
        private int rowId;
        private int start;
        private int end;
        private double[] array;
        private double[] arraySlice;
        private int numBits;

        public QuantifyDoublePartUParam(int matrixId, PartitionKey partKey, int rowId, int start, int end, double[] array, int numBits) {
            super(matrixId, partKey, false);
            this.rowId = rowId;
            this.start = start;
            this.end = end;
            this.array = array;
            this.numBits = numBits;
        }

        public QuantifyDoublePartUParam() {
        }

        public void serialize(ByteBuf buf) {
            super.serialize(buf);
            buf.writeInt(this.rowId);
            JCompressUtils.Quantification.serializeDouble(buf, this.array, this.start, this.end, this.numBits);
        }

        public void deserialize(ByteBuf buf) {
            super.deserialize(buf);
            this.rowId = buf.readInt();
            this.arraySlice = JCompressUtils.Quantification.deserializeDouble(buf);
        }

        public int bufferLen() {
            return super.bufferLen() + 20 + (int)Math.ceil((this.end - this.start) * this.numBits / 8);
        }

        public int getRowId() {
            return this.rowId;
        }

        public double[] getArraySlice() {
            return this.arraySlice;
        }

        public String toString() {
            return "QuantifyDoublePartUParam [rowId=" + this.rowId + ", numBits=" + this.numBits + ", toString()=" + super.toString() + "]";
        }

        private static byte[] int2Byte(int value, int size, boolean isNeg) {
            assert (Math.pow(2.0, 8 * size - 1) > (double)value);
            byte[] rec = new byte[size];
            for (int i = 0; i < size; ++i) {
                rec[size - i - 1] = (byte)value;
                value >>>= 8;
            }
            if (isNeg) {
                rec[0] = (byte)(rec[0] | 0x80);
            }
            return rec;
        }

        public static byte[] long2Byte(long value, int size, boolean isNeg) {
            assert (Math.pow(2.0, 8 * size - 1) > (double)value);
            byte[] rec = new byte[size];
            for (int i = 0; i < size; ++i) {
                rec[size - i - 1] = (byte)value;
                value >>>= 8;
            }
            if (isNeg) {
                rec[0] = (byte)(rec[0] | 0x80);
            }
            return rec;
        }

        private static int byte2int(byte[] buffer) {
            int rec = 0;
            boolean isNegative = (buffer[0] & 0x80) == 128;
            buffer[0] = (byte)(buffer[0] & 0x7F);
            int base = 0;
            for (int i = buffer.length - 1; i >= 0; --i) {
                long value = buffer[i] & 0xFF;
                rec = (int)((long)rec + (value << base));
                base += 8;
            }
            if (isNegative) {
                rec = -1 * rec;
            }
            return rec;
        }

        public static long byte2long(byte[] buffer) {
            long rec = 0L;
            boolean isNegative = (buffer[0] & 0x80) == 128;
            buffer[0] = (byte)(buffer[0] & 0x7F);
            int base = 0;
            for (int i = buffer.length - 1; i >= 0; --i) {
                long value = buffer[i] & 0xFF;
                rec += value << base;
                base += 8;
            }
            if (isNegative) {
                rec = -1L * rec;
            }
            return rec;
        }
    }
}

