/*
 * Decompiled with CFR 0.152.
 */
package io.trino.execution.buffer;

import com.google.common.base.Preconditions;
import com.google.common.base.VerifyException;
import io.airlift.compress.Decompressor;
import io.airlift.compress.lz4.Lz4Decompressor;
import io.airlift.compress.lz4.Lz4RawCompressor;
import io.airlift.slice.SizeOf;
import io.airlift.slice.Slice;
import io.airlift.slice.SliceInput;
import io.airlift.slice.Slices;
import io.trino.execution.buffer.PagesSerdeUtil;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.Page;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.spi.block.BlockEncodingSerde;
import io.trino.util.Ciphers;
import java.io.IOException;
import java.io.OutputStream;
import java.io.UnsupportedEncodingException;
import java.security.GeneralSecurityException;
import java.security.Key;
import java.util.Objects;
import java.util.Optional;
import javax.crypto.Cipher;
import javax.crypto.SecretKey;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;

/**
 * Reconstructs {@link Page}s from their serialized {@link Slice} form.
 * <p>
 * The page body (everything after the fixed-size header) may be encrypted with
 * {@code AES/CBC/PKCS5Padding} and/or split into LZ4 blocks. Both transformations are undone
 * lazily, one chunk at a time, while {@code PagesSerdeUtil.readRawPage} pulls data through an
 * internal {@link SliceInput}.
 * <p>
 * Instances keep mutable, reusable buffers and a single {@link Cipher} without any
 * synchronization, so presumably an instance is meant to be used by one thread at a time
 * (NOTE(review): confirm with callers).
 */
public class PageDeserializer {
    private static final int INSTANCE_SIZE = SizeOf.instanceSize(PageDeserializer.class);
    private final BlockEncodingSerde blockEncodingSerde;
    private final SerializedPageInput input;

    /**
     * Creates a deserializer.
     *
     * @param blockEncodingSerde serde used to decode the individual blocks of a page
     * @param compressionEnabled whether page bodies consist of (possibly) LZ4-compressed blocks
     * @param encryptionKey key used to decrypt page bodies; when present it must be a
     *        {@link SecretKeySpec} holding a 256-bit key
     * @param blockSizeInBytes block size used to size the internal read buffers; presumably it must
     *        match the block size used when the pages were serialized (TODO confirm)
     * @throws NullPointerException if {@code blockEncodingSerde} or {@code encryptionKey} is null
     * @throws IllegalArgumentException if the encryption key is not a 256-bit {@link SecretKeySpec}
     */
    public PageDeserializer(BlockEncodingSerde blockEncodingSerde, boolean compressionEnabled, Optional<SecretKey> encryptionKey, int blockSizeInBytes) {
        this.blockEncodingSerde = Objects.requireNonNull(blockEncodingSerde, "blockEncodingSerde is null");
        Objects.requireNonNull(encryptionKey, "encryptionKey is null");
        encryptionKey.ifPresent(secretKey -> Preconditions.checkArgument(Ciphers.is256BitSecretKeySpec(secretKey), "encryptionKey is expected to be an instance of SecretKeySpec containing a 256bit key"));
        Optional<Lz4Decompressor> decompressor = compressionEnabled ? Optional.of(new Lz4Decompressor()) : Optional.empty();
        this.input = new SerializedPageInput(decompressor, encryptionKey, blockSizeInBytes);
    }

    /**
     * Deserializes a single page.
     *
     * @param serializedPage a complete serialized page, header included
     * @return the decoded page
     */
    public Page deserialize(Slice serializedPage) {
        int positionCount = this.input.startPage(serializedPage);
        Page page = PagesSerdeUtil.readRawPage(positionCount, this.input, this.blockEncodingSerde);
        this.input.finishPage();
        return page;
    }

    /**
     * Returns the estimated memory retained by this deserializer, including its internal buffers.
     */
    public long getRetainedSizeInBytes() {
        return INSTANCE_SIZE + this.input.getRetainedSize();
    }

    /**
     * {@link SliceInput} over one serialized page that decrypts and decompresses on demand.
     * <p>
     * {@code buffers} forms a pipeline. {@code buffers[buffers.length - 1]} is the serialized page
     * itself. {@code buffers[0]} always holds the plain bytes handed to readers. When both stages
     * are enabled, {@code buffers[1]} holds decrypted but still compressed data. With neither stage
     * enabled, the page itself is {@code buffers[0]} and nothing can be refilled.
     */
    private static class SerializedPageInput
    extends SliceInput {
        private static final int INSTANCE_SIZE = SizeOf.instanceSize(SerializedPageInput.class);
        private static final int DECOMPRESSOR_RETAINED_SIZE = SizeOf.instanceSize(Lz4Decompressor.class);
        // SecretKeySpec instance plus its 32-byte (256-bit) key array
        private static final int ENCRYPTION_KEY_RETAINED_SIZE = Math.toIntExact(SizeOf.instanceSize(SecretKeySpec.class) + SizeOf.sizeOfByteArray(32));
        // Rough fixed estimate of a Cipher's footprint (not measured)
        private static final long CIPHER_RETAINED_SIZE_ESTIMATE = 1024L;
        // Offset of the page body; the preceding bytes are the page header (its layout is owned by
        // PagesSerdeUtil and presumably must stay in sync with it)
        private static final int SERIALIZED_PAGE_HEADER_SIZE = 13;
        // Upper bound of the scratch array used by readBytes(OutputStream, int)
        private static final int COPY_CHUNK_SIZE = 8192;

        private final Optional<Lz4Decompressor> decompressor;
        private final Optional<SecretKey> encryptionKey;
        private final Optional<Cipher> cipher;
        private final ReadBuffer[] buffers;

        private SerializedPageInput(Optional<Lz4Decompressor> decompressor, Optional<SecretKey> encryptionKey, int blockSizeInBytes) {
            this.decompressor = Objects.requireNonNull(decompressor, "decompressor is null");
            this.encryptionKey = Objects.requireNonNull(encryptionKey, "encryptionKey is null");
            // one buffer per enabled stage plus one slot for the serialized page (set in startPage)
            this.buffers = new ReadBuffer[(decompressor.isPresent() ? 1 : 0) + (encryptionKey.isPresent() ? 1 : 0) + 1];
            if (decompressor.isPresent()) {
                // one decompressed block plus Long.BYTES of headroom for bytes kept by rollOver()
                this.buffers[0] = emptyReadBuffer(blockSizeInBytes + Long.BYTES);
            }
            if (encryptionKey.isPresent()) {
                // decrypted data is either a (possibly) compressed block with its int marker, or plain data
                int bufferSize = decompressor.isPresent()
                        ? Lz4RawCompressor.maxCompressedLength(blockSizeInBytes) + Integer.BYTES + Long.BYTES
                        : blockSizeInBytes + Long.BYTES;
                this.buffers[this.buffers.length - 2] = emptyReadBuffer(bufferSize);
                try {
                    this.cipher = Optional.of(Cipher.getInstance("AES/CBC/PKCS5Padding"));
                }
                catch (GeneralSecurityException e) {
                    throw new TrinoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, "Failed to create cipher: " + e.getMessage(), e);
                }
            } else {
                this.cipher = Optional.empty();
            }
        }

        // Allocates a buffer of the given size that initially has nothing available to read
        private static ReadBuffer emptyReadBuffer(int size) {
            ReadBuffer buffer = new ReadBuffer(Slices.allocate(size));
            buffer.setPosition(size);
            return buffer;
        }

        /**
         * Installs {@code page} as the source of the pipeline, positioned just past its header.
         *
         * @return the page's position count as read from its header
         */
        public int startPage(Slice page) {
            int positionCount = PagesSerdeUtil.getSerializedPagePositionCount(page);
            ReadBuffer buffer = new ReadBuffer(page);
            buffer.setPosition(SERIALIZED_PAGE_HEADER_SIZE);
            this.buffers[this.buffers.length - 1] = buffer;
            return positionCount;
        }

        public boolean readBoolean() {
            this.ensureReadable(1);
            return this.buffers[0].readBoolean();
        }

        public byte readByte() {
            this.ensureReadable(1);
            return this.buffers[0].readByte();
        }

        public short readShort() {
            this.ensureReadable(Short.BYTES);
            return this.buffers[0].readShort();
        }

        public int readInt() {
            this.ensureReadable(Integer.BYTES);
            return this.buffers[0].readInt();
        }

        public long readLong() {
            this.ensureReadable(Long.BYTES);
            return this.buffers[0].readLong();
        }

        public float readFloat() {
            this.ensureReadable(Float.BYTES);
            return this.buffers[0].readFloat();
        }

        public double readDouble() {
            this.ensureReadable(Double.BYTES);
            return this.buffers[0].readDouble();
        }

        /**
         * Reads up to {@code length} bytes, following the {@link java.io.InputStream} contract.
         *
         * @return the number of bytes read, or -1 if {@code length > 0} and the input is exhausted
         */
        public int read(byte[] destination, int destinationIndex, int length) {
            ReadBuffer buffer = this.buffers[0];
            int bytesRemaining = length;
            while (bytesRemaining > 0) {
                this.ensureReadable(Math.min(Long.BYTES, bytesRemaining));
                int bytesToRead = Math.min(bytesRemaining, buffer.available());
                if (bytesToRead == 0 && !this.canRefill()) {
                    // the source page is exhausted; looping again could never make progress
                    break;
                }
                int bytesRead = buffer.read(destination, destinationIndex, bytesToRead);
                bytesRemaining -= bytesRead;
                destinationIndex += bytesRead;
            }
            int totalRead = length - bytesRemaining;
            return totalRead == 0 && length > 0 ? -1 : totalRead;
        }

        public void readBytes(byte[] destination, int destinationIndex, int length) {
            ReadBuffer buffer = this.buffers[0];
            int remaining = length;
            while (remaining > 0) {
                int count = this.prepareRead(remaining, Byte.BYTES);
                buffer.readBytes(destination, destinationIndex, count);
                remaining -= count;
                destinationIndex += count;
            }
        }

        public void readShorts(short[] destination, int destinationIndex, int length) {
            ReadBuffer buffer = this.buffers[0];
            int remaining = length;
            while (remaining > 0) {
                int count = this.prepareRead(remaining, Short.BYTES);
                buffer.readShorts(destination, destinationIndex, count);
                remaining -= count;
                destinationIndex += count;
            }
        }

        public void readInts(int[] destination, int destinationIndex, int length) {
            ReadBuffer buffer = this.buffers[0];
            int remaining = length;
            while (remaining > 0) {
                int count = this.prepareRead(remaining, Integer.BYTES);
                buffer.readInts(destination, destinationIndex, count);
                remaining -= count;
                destinationIndex += count;
            }
        }

        public void readLongs(long[] destination, int destinationIndex, int length) {
            ReadBuffer buffer = this.buffers[0];
            int remaining = length;
            while (remaining > 0) {
                int count = this.prepareRead(remaining, Long.BYTES);
                buffer.readLongs(destination, destinationIndex, count);
                remaining -= count;
                destinationIndex += count;
            }
        }

        public void readFloats(float[] destination, int destinationIndex, int length) {
            ReadBuffer buffer = this.buffers[0];
            int remaining = length;
            while (remaining > 0) {
                int count = this.prepareRead(remaining, Float.BYTES);
                buffer.readFloats(destination, destinationIndex, count);
                remaining -= count;
                destinationIndex += count;
            }
        }

        public void readDoubles(double[] destination, int destinationIndex, int length) {
            ReadBuffer buffer = this.buffers[0];
            int remaining = length;
            while (remaining > 0) {
                int count = this.prepareRead(remaining, Double.BYTES);
                buffer.readDoubles(destination, destinationIndex, count);
                remaining -= count;
                destinationIndex += count;
            }
        }

        public void readBytes(Slice destination, int destinationIndex, int length) {
            ReadBuffer buffer = this.buffers[0];
            int remaining = length;
            while (remaining > 0) {
                int count = this.prepareRead(remaining, Byte.BYTES);
                buffer.readBytes(destination, destinationIndex, count);
                remaining -= count;
                destinationIndex += count;
            }
        }

        /**
         * Tries to make at least one element of {@code elementSize} bytes readable and returns how
         * many whole elements (at most {@code remaining}) can be read from {@code buffers[0]} now.
         * A return of 0 is possible only when an upstream stage exists; the caller then loops and
         * this method refills again.
         *
         * @throws IndexOutOfBoundsException if the page is exhausted and nothing can refill it
         */
        private int prepareRead(int remaining, int elementSize) {
            // widen before multiplying so huge requests cannot overflow into a negative size
            this.ensureReadable((int) Math.min(Long.BYTES, (long) remaining * elementSize));
            int count = Math.min(remaining, this.buffers[0].available() / elementSize);
            if (count == 0 && !this.canRefill()) {
                // without this check the bulk read loops would spin forever on an exhausted page
                throw new IndexOutOfBoundsException("End of page data: " + remaining + " more element(s) of " + elementSize + " byte(s) requested");
            }
            return count;
        }

        // True when a decryption and/or decompression stage can produce more bytes for buffers[0];
        // false when buffers[0] is the serialized page itself.
        private boolean canRefill() {
            return this.buffers.length > 1;
        }

        /**
         * Refills {@code buffers[0]} when it holds fewer than {@code bytes} readable bytes. A refill
         * decrypts the next chunk (if encryption is enabled) and then processes the next block (if
         * compression is enabled). With neither stage enabled this is a no-op.
         */
        private void ensureReadable(int bytes) {
            if (this.buffers[0].available() >= bytes) {
                return;
            }
            this.decrypt();
            this.decompress();
        }

        /**
         * Decrypts the next chunk of the serialized page into the decrypted buffer, keeping its unread
         * bytes at the front. A chunk is an int ciphertext length, an IV of one cipher block, then
         * the ciphertext.
         */
        private void decrypt() {
            if (this.encryptionKey.isEmpty()) {
                return;
            }
            ReadBuffer source = this.buffers[this.buffers.length - 1];
            ReadBuffer sink = this.buffers[this.buffers.length - 2];
            int bytesPreserved = sink.rollOver();
            int encryptedSize = source.readInt();
            int ivSize = this.cipher.orElseThrow().getBlockSize();
            Slice sourceSlice = source.getSlice();
            Slice sinkSlice = sink.getSlice();
            IvParameterSpec iv = new IvParameterSpec(sourceSlice.byteArray(), sourceSlice.byteArrayOffset() + source.getPosition(), ivSize);
            source.setPosition(source.getPosition() + ivSize);
            Cipher cipher = this.initCipher(this.encryptionKey.get(), iv);
            int decryptedSize;
            try {
                decryptedSize = cipher.update(sourceSlice.byteArray(), sourceSlice.byteArrayOffset() + source.getPosition(), encryptedSize, sinkSlice.byteArray(), sinkSlice.byteArrayOffset() + bytesPreserved);
                decryptedSize += cipher.doFinal(sinkSlice.byteArray(), sinkSlice.byteArrayOffset() + bytesPreserved + decryptedSize);
            }
            catch (GeneralSecurityException e) {
                throw new TrinoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, "Cannot decrypt previously encrypted data: " + e.getMessage(), e);
            }
            source.setPosition(source.getPosition() + encryptedSize);
            sink.setLimit(bytesPreserved + decryptedSize);
        }

        // Re-initializes the shared cipher for decryption with the given key and IV
        private Cipher initCipher(SecretKey key, IvParameterSpec iv) {
            Cipher cipher = this.cipher.orElseThrow(() -> new VerifyException("cipher is expected to be present"));
            try {
                cipher.init(Cipher.DECRYPT_MODE, key, iv);
            }
            catch (GeneralSecurityException e) {
                throw new TrinoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, "Failed to init cipher: " + e.getMessage(), e);
            }
            return cipher;
        }

        /**
         * Moves the next block from {@code buffers[1]} into {@code buffers[0]}, keeping the latter's
         * unread bytes at the front. A block starts with an int marker: its sign bit flags LZ4
         * compression and its low 31 bits give the stored size. Uncompressed blocks are copied
         * verbatim.
         */
        private void decompress() {
            if (this.decompressor.isEmpty()) {
                return;
            }
            Decompressor decompressor = this.decompressor.get();
            ReadBuffer source = this.buffers[1];
            ReadBuffer sink = this.buffers[0];
            int bytesPreserved = sink.rollOver();
            int compressedBlockMarker = source.readInt();
            int blockSize = getCompressedBlockSize(compressedBlockMarker);
            Slice sourceSlice = source.getSlice();
            Slice sinkSlice = sink.getSlice();
            int decompressedSize;
            if (isCompressed(compressedBlockMarker)) {
                decompressedSize = decompressor.decompress(sourceSlice.byteArray(), sourceSlice.byteArrayOffset() + source.getPosition(), blockSize, sinkSlice.byteArray(), sinkSlice.byteArrayOffset() + bytesPreserved, sinkSlice.length() - bytesPreserved);
            } else {
                System.arraycopy(sourceSlice.byteArray(), sourceSlice.byteArrayOffset() + source.getPosition(), sinkSlice.byteArray(), sinkSlice.byteArrayOffset() + bytesPreserved, blockSize);
                decompressedSize = blockSize;
            }
            source.setPosition(source.getPosition() + blockSize);
            sink.setLimit(bytesPreserved + decompressedSize);
        }

        // Low 31 bits of the block marker: the stored size of the block
        private static int getCompressedBlockSize(int compressedBlockMarker) {
            return compressedBlockMarker & Integer.MAX_VALUE;
        }

        // Sign bit of the block marker: set when the block is LZ4 compressed
        private static boolean isCompressed(int compressedBlockMarker) {
            return (compressedBlockMarker & Integer.MIN_VALUE) == Integer.MIN_VALUE;
        }

        /**
         * Releases the current page and marks every reusable buffer as empty.
         */
        public void finishPage() {
            this.buffers[this.buffers.length - 1] = null;
            for (ReadBuffer buffer : this.buffers) {
                if (buffer == null) {
                    continue;
                }
                int end = buffer.getSlice().length();
                buffer.setPosition(end);
                buffer.setLimit(end);
            }
        }

        /**
         * Reads one byte as an unsigned value in 0..255, or returns -1 when the serialized page is
         * exhausted and nothing can refill it (the {@link java.io.InputStream#read()} contract).
         */
        public int read() {
            this.ensureReadable(1);
            ReadBuffer buffer = this.buffers[0];
            if (buffer.available() == 0 && !this.canRefill()) {
                return -1;
            }
            // mask so that a 0xFF data byte is not confused with the -1 end-of-stream marker
            return buffer.readByte() & 0xFF;
        }

        public int readUnsignedByte() {
            return this.readByte() & 0xFF;
        }

        public int readUnsignedShort() {
            return this.readShort() & 0xFFFF;
        }

        public Slice readSlice(int length) {
            Slice slice = Slices.allocate(length);
            this.readBytes(slice, 0, length);
            return slice;
        }

        public boolean isReadable() {
            return this.available() > 0;
        }

        // Bytes readable without a refill; does not count data still pending in upstream stages
        public int available() {
            return this.buffers[0].available();
        }

        // Skipping is not supported: always reports that zero bytes were skipped
        public long skip(long length) {
            return 0L;
        }

        public int skipBytes(int length) {
            return Math.toIntExact(this.skip(length));
        }

        public long getRetainedSize() {
            long size = INSTANCE_SIZE;
            size += SizeOf.sizeOf(this.decompressor, compressor -> DECOMPRESSOR_RETAINED_SIZE);
            size += SizeOf.sizeOf(this.encryptionKey, encryptionKey -> ENCRYPTION_KEY_RETAINED_SIZE);
            size += SizeOf.sizeOf(this.cipher, cipher -> CIPHER_RETAINED_SIZE_ESTIMATE);
            for (ReadBuffer buffer : this.buffers) {
                if (buffer != null) {
                    size += buffer.getRetainedSizeInBytes();
                }
            }
            return size;
        }

        /**
         * Copies the next {@code length} bytes to {@code out}. The bytes pass through a bounded
         * scratch array, and decryption/decompression are driven on demand.
         *
         * @throws IllegalArgumentException if {@code length} is negative
         * @throws IOException if writing to {@code out} fails
         */
        public void readBytes(OutputStream out, int length) throws IOException {
            Preconditions.checkArgument(length >= 0, "length is negative");
            byte[] chunk = new byte[Math.min(length, COPY_CHUNK_SIZE)];
            int remaining = length;
            while (remaining > 0) {
                int chunkSize = Math.min(remaining, chunk.length);
                this.readBytes(chunk, 0, chunkSize);
                out.write(chunk, 0, chunkSize);
                remaining -= chunkSize;
            }
        }

        public long position() {
            throw new UnsupportedOperationException();
        }

        public void setPosition(long position) {
            throw new UnsupportedOperationException();
        }
    }

    /**
     * Cursor over a {@link Slice}: {@code position} is the next byte to read and {@code limit} is
     * the end of valid data (initially the slice length). Getters do no bounds checks of their own
     * beyond what the slice performs.
     */
    private static class ReadBuffer {
        private static final int INSTANCE_SIZE = SizeOf.instanceSize(ReadBuffer.class);
        private final Slice slice;
        private int position;
        private int limit;

        public ReadBuffer(Slice slice) {
            this.slice = Objects.requireNonNull(slice, "slice is null");
            this.limit = slice.length();
        }

        public int available() {
            return this.limit - this.position;
        }

        public Slice getSlice() {
            return this.slice;
        }

        public int getPosition() {
            return this.position;
        }

        public void setPosition(int position) {
            this.position = position;
        }

        public void setLimit(int limit) {
            this.limit = limit;
        }

        /**
         * Moves the unread bytes to the start of the slice and rewinds the position to 0.
         *
         * @return the number of bytes moved, i.e. the offset where new data should be appended
         */
        public int rollOver() {
            int bytesToCopy = this.available();
            if (bytesToCopy != 0) {
                this.slice.setBytes(0, this.slice, this.position, bytesToCopy);
            }
            this.position = 0;
            return bytesToCopy;
        }

        public boolean readBoolean() {
            boolean value = this.slice.getByte(this.position) == 1;
            ++this.position;
            return value;
        }

        public byte readByte() {
            byte value = this.slice.getByte(this.position);
            ++this.position;
            return value;
        }

        public short readShort() {
            short value = this.slice.getShort(this.position);
            this.position += Short.BYTES;
            return value;
        }

        public int readInt() {
            int value = this.slice.getInt(this.position);
            this.position += Integer.BYTES;
            return value;
        }

        public long readLong() {
            long value = this.slice.getLong(this.position);
            this.position += Long.BYTES;
            return value;
        }

        public float readFloat() {
            float value = this.slice.getFloat(this.position);
            this.position += Float.BYTES;
            return value;
        }

        public double readDouble() {
            double value = this.slice.getDouble(this.position);
            this.position += Double.BYTES;
            return value;
        }

        // Reads at most length bytes, never past the limit of valid data; returns the count read
        public int read(byte[] destination, int destinationIndex, int length) {
            int bytesToRead = Math.min(length, this.available());
            this.slice.getBytes(this.position, destination, destinationIndex, bytesToRead);
            this.position += bytesToRead;
            return bytesToRead;
        }

        public void readBytes(byte[] destination, int destinationIndex, int length) {
            this.slice.getBytes(this.position, destination, destinationIndex, length);
            this.position += length;
        }

        public void readShorts(short[] destination, int destinationIndex, int length) {
            this.slice.getShorts(this.position, destination, destinationIndex, length);
            this.position += length * Short.BYTES;
        }

        public void readInts(int[] destination, int destinationIndex, int length) {
            this.slice.getInts(this.position, destination, destinationIndex, length);
            this.position += length * Integer.BYTES;
        }

        public void readLongs(long[] destination, int destinationIndex, int length) {
            this.slice.getLongs(this.position, destination, destinationIndex, length);
            this.position += length * Long.BYTES;
        }

        public void readFloats(float[] destination, int destinationIndex, int length) {
            this.slice.getFloats(this.position, destination, destinationIndex, length);
            this.position += length * Float.BYTES;
        }

        public void readDoubles(double[] destination, int destinationIndex, int length) {
            this.slice.getDoubles(this.position, destination, destinationIndex, length);
            this.position += length * Double.BYTES;
        }

        public void readBytes(Slice destination, int destinationIndex, int length) {
            this.slice.getBytes(this.position, destination, destinationIndex, length);
            this.position += length;
        }

        public long getRetainedSizeInBytes() {
            return INSTANCE_SIZE + this.slice.getRetainedSize();
        }
    }
}

