/*
 * Decompiled with CFR 0.152.
 */
package io.trino.execution.buffer;

import com.google.common.base.Preconditions;
import com.google.common.base.VerifyException;
import io.airlift.compress.Compressor;
import io.airlift.compress.lz4.Lz4Compressor;
import io.airlift.slice.SizeOf;
import io.airlift.slice.Slice;
import io.airlift.slice.SliceOutput;
import io.airlift.slice.Slices;
import io.trino.execution.buffer.PageCodecMarker;
import io.trino.execution.buffer.PagesSerdeUtil;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.Page;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.spi.block.BlockEncodingSerde;
import io.trino.util.Ciphers;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.Charset;
import java.security.GeneralSecurityException;
import java.util.Objects;
import java.util.Optional;
import javax.crypto.Cipher;
import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;
import org.openjdk.jol.info.ClassLayout;

/**
 * Serializes {@link Page}s into one {@link Slice} per page, optionally LZ4-compressing and/or AES-encrypting
 * the payload.
 *
 * <p>Every serialized page starts with a 13-byte header:
 * <pre>
 *   int  positionCount
 *   byte codec markers (COMPRESSED and/or ENCRYPTED)
 *   int  uncompressed size: number of bytes written by the raw page writer
 *   int  payload size: number of serialized bytes that follow the header
 * </pre>
 * With compression enabled, the data is emitted in blocks. Each block is prefixed with an int marker: the high bit
 * says whether the block is LZ4-compressed, and the remaining bits hold the block length. With encryption enabled,
 * the (possibly compressed) data is emitted as encrypted blocks. Each block is an int ciphertext length, followed by
 * the cipher IV and then the ciphertext.
 *
 * <p>An instance mutates shared internal buffers without synchronization, so it must not be used concurrently.
 */
public class PageSerializer {
    private static final int INSTANCE_SIZE = Math.toIntExact(ClassLayout.parseClass(PageSerializer.class).instanceSize());
    private final BlockEncodingSerde blockEncodingSerde;
    private final SerializedPageOutput output;

    /**
     * Creates a serializer.
     *
     * @param blockEncodingSerde serde passed to the raw page writer; must not be null
     * @param compressionEnabled whether to LZ4-compress the serialized data block by block
     * @param encryptionKey optional key; when present it must be a {@link SecretKeySpec} holding a 256-bit key
     * @param blockSizeInBytes capacity of the staging buffer(s) used for compression/encryption blocks;
     *        unused when neither compression nor encryption is enabled
     * @throws NullPointerException if {@code blockEncodingSerde} or {@code encryptionKey} is null
     * @throws IllegalArgumentException if the encryption key is not a 256-bit {@link SecretKeySpec}
     */
    public PageSerializer(BlockEncodingSerde blockEncodingSerde, boolean compressionEnabled, Optional<SecretKey> encryptionKey, int blockSizeInBytes) {
        this.blockEncodingSerde = Objects.requireNonNull(blockEncodingSerde, "blockEncodingSerde is null");
        Objects.requireNonNull(encryptionKey, "encryptionKey is null");
        encryptionKey.ifPresent(secretKey -> Preconditions.checkArgument(
                Ciphers.is256BitSecretKeySpec(secretKey),
                "encryptionKey is expected to be an instance of SecretKeySpec containing a 256bit key"));
        Optional<Lz4Compressor> compressor = compressionEnabled ? Optional.of(new Lz4Compressor()) : Optional.empty();
        this.output = new SerializedPageOutput(compressor, encryptionKey, blockSizeInBytes);
    }

    /**
     * Serializes {@code page} into a slice laid out as described in the class documentation.
     * The returned slice is not overwritten by later calls.
     *
     * @throws ArithmeticException if the page's size in bytes does not fit in an {@code int}
     */
    public Slice serialize(Page page) {
        output.startPage(page.getPositionCount(), Math.toIntExact(page.getSizeInBytes()));
        PagesSerdeUtil.writeRawPage(page, output, blockEncodingSerde);
        return output.closePage();
    }

    /**
     * Returns an estimate of the memory retained by this serializer, including its reusable buffers.
     */
    public long getRetainedSizeInBytes() {
        return (long) INSTANCE_SIZE + output.getRetainedSize();
    }

    /**
     * {@link SliceOutput} that assembles one serialized page at a time by pushing written bytes through a
     * pipeline of buffers.
     *
     * <p>Buffer slots, depending on which codecs are enabled:
     * <ul>
     *   <li>none: {@code [page]}</li>
     *   <li>compression only: {@code [uncompressed, page]}</li>
     *   <li>encryption only: {@code [plaintext, page]}</li>
     *   <li>both: {@code [uncompressed, compressed, page]}</li>
     * </ul>
     * Writes always go to {@code buffers[0]}. When that buffer is full, it is compressed into the next slot, and the
     * result is then encrypted into the page buffer. The page buffer (last slot) is allocated by
     * {@link #startPage} and released by {@link #closePage}.
     */
    private static final class SerializedPageOutput extends SliceOutput {
        private static final int INSTANCE_SIZE = Math.toIntExact(ClassLayout.parseClass(SerializedPageOutput.class).instanceSize());

        // Assumed number of int entries in Lz4Compressor's internal table; used only for memory accounting.
        private static final int LZ4_TABLE_ENTRIES = 4096;
        // A 256-bit AES key is 32 bytes long (the PageSerializer constructor only accepts 256-bit keys).
        private static final int AES_256_KEY_BYTES = 32;
        // Rough fixed estimate of the memory retained by a Cipher instance.
        private static final long CIPHER_RETAINED_SIZE = 1024L;
        private static final int COMPRESSOR_RETAINED_SIZE = Math.toIntExact(ClassLayout.parseClass(Lz4Compressor.class).instanceSize() + SizeOf.sizeOfIntArray(LZ4_TABLE_ENTRIES));
        private static final int ENCRYPTION_KEY_RETAINED_SIZE = Math.toIntExact(ClassLayout.parseClass(SecretKeySpec.class).instanceSize() + SizeOf.sizeOfByteArray(AES_256_KEY_BYTES));

        // Header layout: int positionCount | byte codec markers | int uncompressed size | int payload size.
        private static final int SERIALIZED_PAGE_CODEC_MARKERS_OFFSET = Integer.BYTES;
        private static final int SERIALIZED_PAGE_UNCOMPRESSED_SIZE_OFFSET = SERIALIZED_PAGE_CODEC_MARKERS_OFFSET + Byte.BYTES;
        private static final int SERIALIZED_PAGE_COMPRESSED_SIZE_OFFSET = SERIALIZED_PAGE_UNCOMPRESSED_SIZE_OFFSET + Integer.BYTES;
        private static final int SERIALIZED_PAGE_HEADER_SIZE = SERIALIZED_PAGE_COMPRESSED_SIZE_OFFSET + Integer.BYTES;

        // Upper bound for the initial page buffer: a little below Integer.MAX_VALUE, because JVMs may refuse
        // arrays at the exact int limit.
        private static final int MAX_INITIAL_PAGE_BUFFER_CAPACITY = Integer.MAX_VALUE - 8;

        // A block is stored raw unless compression shrinks it below this fraction of its original size.
        private static final double MINIMUM_COMPRESSION_RATIO = 0.8;
        private static final String CIPHER_TRANSFORMATION = "AES/CBC/PKCS5Padding";

        private final Optional<Lz4Compressor> compressor;
        private final Optional<SecretKey> encryptionKey;
        private final int markers;
        private final Optional<Cipher> cipher;
        private final WriteBuffer[] buffers;
        // Number of bytes written through this output for the current page, before any compression/encryption.
        private int uncompressedSize;

        private SerializedPageOutput(Optional<Lz4Compressor> compressor, Optional<SecretKey> encryptionKey, int blockSizeInBytes) {
            this.compressor = Objects.requireNonNull(compressor, "compressor is null");
            this.encryptionKey = Objects.requireNonNull(encryptionKey, "encryptionKey is null");
            // one staging buffer per enabled codec, plus the page buffer in the last slot
            this.buffers = new WriteBuffer[(compressor.isPresent() ? 1 : 0) + (encryptionKey.isPresent() ? 1 : 0) + 1];
            PageCodecMarker.MarkerSet markerSet = PageCodecMarker.MarkerSet.empty();
            if (compressor.isPresent()) {
                this.buffers[0] = new WriteBuffer(blockSizeInBytes);
                markerSet.add(PageCodecMarker.COMPRESSED);
            }
            if (encryptionKey.isPresent()) {
                int bufferSize = blockSizeInBytes;
                if (compressor.isPresent()) {
                    // must hold one compressed block: worst-case compressed length plus its int block marker
                    bufferSize = compressor.get().maxCompressedLength(blockSizeInBytes) + Integer.BYTES;
                }
                this.buffers[this.buffers.length - 2] = new WriteBuffer(bufferSize);
                markerSet.add(PageCodecMarker.ENCRYPTED);
                try {
                    // NOTE(review): CBC provides confidentiality only (no integrity check); changing the
                    // transformation must be coordinated with whatever reads these pages.
                    this.cipher = Optional.of(Cipher.getInstance(CIPHER_TRANSFORMATION));
                } catch (GeneralSecurityException e) {
                    throw new TrinoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, "Failed to create cipher: " + e.getMessage(), e);
                }
            } else {
                this.cipher = Optional.empty();
            }
            this.markers = markerSet.byteValue();
        }

        /**
         * Allocates a fresh page buffer and writes the header fields known up front. The two size fields are
         * reserved here and filled in by {@link #closePage()}.
         *
         * @param positionCount number of positions in the page
         * @param sizeInBytes in-memory size of the page, used only to size the initial buffer
         */
        public void startPage(int positionCount, int sizeInBytes) {
            // 20% headroom over the in-memory size plus the header. The sum is computed in long and clamped,
            // because Math.round saturates at Integer.MAX_VALUE and adding the header would wrap to a negative size.
            long estimatedCapacity = (long) Math.round(sizeInBytes * 1.2f) + SERIALIZED_PAGE_HEADER_SIZE;
            WriteBuffer buffer = new WriteBuffer((int) Math.min(estimatedCapacity, MAX_INITIAL_PAGE_BUFFER_CAPACITY));
            buffer.writeInt(positionCount);
            buffer.writeByte(this.markers);
            // placeholders for the uncompressed size and payload size
            buffer.skip(Integer.BYTES * 2);
            this.buffers[this.buffers.length - 1] = buffer;
            this.uncompressedSize = 0;
        }

        @Override
        public void writeByte(int value) {
            ensureCapacityFor(Byte.BYTES);
            this.buffers[0].writeByte(value);
            this.uncompressedSize += Byte.BYTES;
        }

        @Override
        public void writeShort(int value) {
            ensureCapacityFor(Short.BYTES);
            this.buffers[0].writeShort(value);
            this.uncompressedSize += Short.BYTES;
        }

        @Override
        public void writeInt(int value) {
            ensureCapacityFor(Integer.BYTES);
            this.buffers[0].writeInt(value);
            this.uncompressedSize += Integer.BYTES;
        }

        @Override
        public void writeLong(long value) {
            ensureCapacityFor(Long.BYTES);
            this.buffers[0].writeLong(value);
            this.uncompressedSize += Long.BYTES;
        }

        @Override
        public void writeFloat(float value) {
            ensureCapacityFor(Float.BYTES);
            this.buffers[0].writeFloat(value);
            this.uncompressedSize += Float.BYTES;
        }

        @Override
        public void writeDouble(double value) {
            ensureCapacityFor(Double.BYTES);
            this.buffers[0].writeDouble(value);
            this.uncompressedSize += Double.BYTES;
        }

        @Override
        public void writeBytes(Slice source, int sourceIndex, int length) {
            // buffers[0] is always the same WriteBuffer object; ensureCapacityFor grows or flushes it in place
            WriteBuffer buffer = this.buffers[0];
            int currentIndex = sourceIndex;
            int bytesRemaining = length;
            while (bytesRemaining > 0) {
                ensureCapacityFor(Math.min(Long.BYTES, bytesRemaining));
                int bytesToCopy = Math.min(bytesRemaining, buffer.remainingCapacity());
                buffer.writeBytes(source, currentIndex, bytesToCopy);
                currentIndex += bytesToCopy;
                bytesRemaining -= bytesToCopy;
            }
            this.uncompressedSize += length;
        }

        @Override
        public void writeBytes(byte[] source, int sourceIndex, int length) {
            // buffers[0] is always the same WriteBuffer object; ensureCapacityFor grows or flushes it in place
            WriteBuffer buffer = this.buffers[0];
            int currentIndex = sourceIndex;
            int bytesRemaining = length;
            while (bytesRemaining > 0) {
                ensureCapacityFor(Math.min(Long.BYTES, bytesRemaining));
                int bytesToCopy = Math.min(bytesRemaining, buffer.remainingCapacity());
                buffer.writeBytes(source, currentIndex, bytesToCopy);
                currentIndex += bytesToCopy;
                bytesRemaining -= bytesToCopy;
            }
            this.uncompressedSize += length;
        }

        /**
         * Flushes any staged bytes into the page buffer, fills in the header size fields and returns the
         * serialized page. The page buffer is then released, so the next page needs {@link #startPage}.
         */
        public Slice closePage() {
            compress();
            encrypt();
            WriteBuffer pageBuffer = this.buffers[this.buffers.length - 1];
            int serializedPageSize = pageBuffer.getPosition();
            Slice slice = pageBuffer.getSlice();
            slice.setInt(SERIALIZED_PAGE_UNCOMPRESSED_SIZE_OFFSET, this.uncompressedSize);
            slice.setInt(SERIALIZED_PAGE_COMPRESSED_SIZE_OFFSET, serializedPageSize - SERIALIZED_PAGE_HEADER_SIZE);
            // Copy when less than half of the backing array is used, so the result does not keep a mostly empty
            // array alive; otherwise return a view. The page buffer is dropped below, so the view is never overwritten.
            Slice page = serializedPageSize < slice.length() / 2
                    ? Slices.copyOf(slice, 0, serializedPageSize)
                    : slice.slice(0, serializedPageSize);
            for (WriteBuffer buffer : this.buffers) {
                buffer.reset();
            }
            this.buffers[this.buffers.length - 1] = null;
            this.uncompressedSize = 0;
            return page;
        }

        /**
         * Makes room for {@code bytes} more bytes in {@code buffers[0]}.
         *
         * <p>With no codecs, {@code buffers[0]} is the page buffer and grows in place. Otherwise the full staging
         * buffer is flushed through compression and/or encryption.
         */
        private void ensureCapacityFor(int bytes) {
            if (this.buffers[0].remainingCapacity() >= bytes) {
                return;
            }
            this.buffers[this.buffers.length - 1].ensureCapacityFor(bytes);
            compress();
            encrypt();
        }

        /**
         * Moves the contents of {@code buffers[0]} into {@code buffers[1]} as one block. The block is an int marker
         * followed by either the LZ4 output or, when compression does not beat {@link #MINIMUM_COMPRESSION_RATIO},
         * the raw bytes. No-op when compression is disabled.
         */
        private void compress() {
            if (this.compressor.isEmpty()) {
                return;
            }
            Compressor compressor = this.compressor.get();
            WriteBuffer sourceBuffer = this.buffers[0];
            WriteBuffer sinkBuffer = this.buffers[1];
            int uncompressedSize = sourceBuffer.getPosition();
            int maxCompressedLength = compressor.maxCompressedLength(uncompressedSize);
            sinkBuffer.ensureCapacityFor(maxCompressedLength + Integer.BYTES);

            byte[] sourceArray = sourceBuffer.getSlice().byteArray();
            int sourceOffset = sourceBuffer.getSlice().byteArrayOffset();
            byte[] sinkArray = sinkBuffer.getSlice().byteArray();
            // block data starts right after the int slot reserved for the block marker
            int dataOffset = sinkBuffer.getSlice().byteArrayOffset() + sinkBuffer.getPosition() + Integer.BYTES;

            int compressedSize = compressor.compress(sourceArray, sourceOffset, uncompressedSize, sinkArray, dataOffset, maxCompressedLength);
            boolean compressed = uncompressedSize * MINIMUM_COMPRESSION_RATIO > compressedSize;
            int blockSize;
            if (compressed) {
                blockSize = compressedSize;
            } else {
                // not worth it: overwrite the compressor output with the raw bytes
                System.arraycopy(sourceArray, sourceOffset, sinkArray, dataOffset, uncompressedSize);
                blockSize = uncompressedSize;
            }
            sinkBuffer.writeInt(createBlockMarker(compressed, blockSize));
            sinkBuffer.skip(blockSize);
            sourceBuffer.reset();
        }

        /**
         * Encodes a block marker: the block size with the sign (high) bit set when the block is compressed.
         */
        private static int createBlockMarker(boolean compressed, int size) {
            return compressed ? size | Integer.MIN_VALUE : size;
        }

        /**
         * Encrypts the contents of the second-to-last buffer into the page buffer as one block: an int ciphertext
         * length, the IV, then the ciphertext. No-op when encryption is disabled.
         */
        private void encrypt() {
            if (this.encryptionKey.isEmpty()) {
                return;
            }
            Cipher cipher = initCipher(this.encryptionKey.get());
            byte[] iv = cipher.getIV();
            WriteBuffer sourceBuffer = this.buffers[this.buffers.length - 2];
            WriteBuffer sinkBuffer = this.buffers[this.buffers.length - 1];
            int maxEncryptedSize = cipher.getOutputSize(sourceBuffer.getPosition()) + iv.length;
            // length slot + IV + ciphertext (the IV is counted twice here, which merely over-reserves)
            sinkBuffer.ensureCapacityFor(maxEncryptedSize + Integer.BYTES + iv.length);
            int lengthSlotPosition = sinkBuffer.getPosition();
            sinkBuffer.skip(Integer.BYTES);
            sinkBuffer.writeBytes(iv, 0, iv.length);

            byte[] sinkArray = sinkBuffer.getSlice().byteArray();
            int ciphertextOffset = sinkBuffer.getSlice().byteArrayOffset() + sinkBuffer.getPosition();
            int encryptedSize;
            try {
                // update + doFinal(output, offset) write straight into the sink array instead of returning a new array
                encryptedSize = cipher.update(sourceBuffer.getSlice().byteArray(), sourceBuffer.getSlice().byteArrayOffset(), sourceBuffer.getPosition(), sinkArray, ciphertextOffset);
                encryptedSize += cipher.doFinal(sinkArray, ciphertextOffset + encryptedSize);
            } catch (GeneralSecurityException e) {
                throw new TrinoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, "Failed to encrypt data: " + e.getMessage(), e);
            }
            sinkBuffer.getSlice().setInt(lengthSlotPosition, encryptedSize);
            sinkBuffer.skip(encryptedSize);
            sourceBuffer.reset();
        }

        /**
         * Re-initializes the shared cipher for encryption with {@code key}. No IV is supplied, so the provider
         * generates one per the {@link Cipher#init(int, java.security.Key)} contract; it is read back via getIV().
         */
        private Cipher initCipher(SecretKey key) {
            Cipher cipher = this.cipher.orElseThrow(() -> new VerifyException("cipher is expected to be present"));
            try {
                cipher.init(Cipher.ENCRYPT_MODE, key);
            } catch (GeneralSecurityException e) {
                throw new TrinoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, "Failed to init cipher: " + e.getMessage(), e);
            }
            return cipher;
        }

        @Override
        public long getRetainedSize() {
            long size = INSTANCE_SIZE;
            size += SizeOf.sizeOf(this.compressor, ignored -> COMPRESSOR_RETAINED_SIZE);
            size += SizeOf.sizeOf(this.encryptionKey, ignored -> ENCRYPTION_KEY_RETAINED_SIZE);
            size += SizeOf.sizeOf(this.cipher, ignored -> CIPHER_RETAINED_SIZE);
            for (WriteBuffer buffer : this.buffers) {
                // the page buffer slot is null between pages
                if (buffer != null) {
                    size += buffer.getRetainedSizeInBytes();
                }
            }
            return size;
        }

        @Override
        public int writableBytes() {
            return Integer.MAX_VALUE;
        }

        @Override
        public boolean isWritable() {
            return true;
        }

        @Override
        public void writeBytes(byte[] source) {
            writeBytes(source, 0, source.length);
        }

        @Override
        public void writeBytes(Slice source) {
            writeBytes(source, 0, source.length());
        }

        @Override
        public void writeBytes(InputStream in, int length) throws IOException {
            throw new UnsupportedOperationException();
        }

        @Override
        public Slice slice() {
            throw new UnsupportedOperationException();
        }

        @Override
        public Slice getUnderlyingSlice() {
            throw new UnsupportedOperationException();
        }

        @Override
        public void reset() {
            throw new UnsupportedOperationException();
        }

        @Override
        public void reset(int position) {
            throw new UnsupportedOperationException();
        }

        @Override
        public int size() {
            throw new UnsupportedOperationException();
        }

        @Override
        public String toString(Charset charset) {
            throw new UnsupportedOperationException();
        }

        @Override
        public SliceOutput appendLong(long value) {
            writeLong(value);
            return this;
        }

        @Override
        public SliceOutput appendDouble(double value) {
            writeDouble(value);
            return this;
        }

        @Override
        public SliceOutput appendInt(int value) {
            writeInt(value);
            return this;
        }

        @Override
        public SliceOutput appendShort(int value) {
            writeShort(value);
            return this;
        }

        @Override
        public SliceOutput appendByte(int value) {
            writeByte(value);
            return this;
        }

        @Override
        public SliceOutput appendBytes(byte[] source, int sourceIndex, int length) {
            writeBytes(source, sourceIndex, length);
            return this;
        }

        @Override
        public SliceOutput appendBytes(byte[] source) {
            return appendBytes(source, 0, source.length);
        }

        @Override
        public SliceOutput appendBytes(Slice slice) {
            writeBytes(slice);
            return this;
        }
    }

    /**
     * Minimal cursor over a {@link Slice}. Writes are not bounds-checked here beyond what {@link Slice} enforces;
     * callers grow the buffer explicitly via {@link #ensureCapacityFor(int)}.
     */
    private static final class WriteBuffer {
        private static final int INSTANCE_SIZE = Math.toIntExact(ClassLayout.parseClass(WriteBuffer.class).instanceSize());
        private Slice slice;
        private int position;

        public WriteBuffer(int initialCapacity) {
            this.slice = Slices.allocate(initialCapacity);
        }

        public void writeByte(int value) {
            slice.setByte(position, value);
            position += Byte.BYTES;
        }

        public void writeShort(int value) {
            slice.setShort(position, value);
            position += Short.BYTES;
        }

        public void writeInt(int value) {
            slice.setInt(position, value);
            position += Integer.BYTES;
        }

        public void writeLong(long value) {
            slice.setLong(position, value);
            position += Long.BYTES;
        }

        public void writeFloat(float value) {
            slice.setFloat(position, value);
            position += Float.BYTES;
        }

        public void writeDouble(double value) {
            slice.setDouble(position, value);
            position += Double.BYTES;
        }

        public void writeBytes(Slice source, int sourceIndex, int length) {
            slice.setBytes(position, source, sourceIndex, length);
            position += length;
        }

        public void writeBytes(byte[] source, int sourceIndex, int length) {
            slice.setBytes(position, source, sourceIndex, length);
            position += length;
        }

        /** Advances the write position without writing, leaving a gap to be filled in later. */
        public void skip(int length) {
            position += length;
        }

        public int remainingCapacity() {
            return slice.length() - position;
        }

        public int getPosition() {
            return position;
        }

        public Slice getSlice() {
            return slice;
        }

        /** Rewinds to the start; the backing slice is kept for reuse. */
        public void reset() {
            position = 0;
        }

        public long getRetainedSizeInBytes() {
            return (long) INSTANCE_SIZE + slice.getRetainedSize();
        }

        /** Grows the backing slice (via {@link Slices#ensureSize}) so {@code bytes} more bytes fit after the current position. */
        public void ensureCapacityFor(int bytes) {
            slice = Slices.ensureSize(slice, position + bytes);
        }
    }
}

