/*
 * Decompiled with CFR 0.152.
 */
package com.azure.storage.blob.specialized.cryptography;

import com.azure.core.cryptography.AsyncKeyEncryptionKey;
import com.azure.core.cryptography.AsyncKeyEncryptionKeyResolver;
import com.azure.core.http.HttpHeaders;
import com.azure.core.http.HttpMethod;
import com.azure.core.http.HttpPipelineCallContext;
import com.azure.core.http.HttpPipelineNextPolicy;
import com.azure.core.http.HttpResponse;
import com.azure.core.http.policy.HttpPipelinePolicy;
import com.azure.core.util.FluxUtil;
import com.azure.core.util.logging.ClientLogger;
import com.azure.storage.blob.specialized.cryptography.EncryptedBlobRange;
import com.azure.storage.blob.specialized.cryptography.EncryptionData;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.NoSuchAlgorithmException;
import java.util.Locale;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicLong;
import javax.crypto.Cipher;
import javax.crypto.NoSuchPaddingException;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import reactor.core.Exceptions;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

public class BlobDecryptionPolicy
implements HttpPipelinePolicy {
    private final ClientLogger logger = new ClientLogger(BlobDecryptionPolicy.class);
    private final AsyncKeyEncryptionKeyResolver keyResolver;
    private final AsyncKeyEncryptionKey keyWrapper;
    private final boolean requiresEncryption;

    BlobDecryptionPolicy(AsyncKeyEncryptionKey key, AsyncKeyEncryptionKeyResolver keyResolver, boolean requiresEncryption) {
        this.keyWrapper = key;
        this.keyResolver = keyResolver;
        this.requiresEncryption = requiresEncryption;
    }

    public Mono<HttpResponse> process(HttpPipelineCallContext context, HttpPipelineNextPolicy next) {
        HttpHeaders requestHeaders = context.getHttpRequest().getHeaders();
        EncryptedBlobRange encryptedRange = EncryptedBlobRange.getEncryptedBlobRangeFromHeader(requestHeaders.getValue("x-ms-range"));
        if (requestHeaders.getValue("x-ms-range") != null) {
            requestHeaders.put("x-ms-range", encryptedRange.toBlobRange().toString());
        }
        return next.process().flatMap(httpResponse -> {
            if (httpResponse.getRequest().getHttpMethod() == HttpMethod.GET && httpResponse.getBody() != null) {
                HttpHeaders responseHeaders = httpResponse.getHeaders();
                encryptedRange.setAdjustedDownloadCount(Long.parseLong(responseHeaders.getValue("Content-Length")));
                boolean padding = encryptedRange.toBlobRange().getOffset() + encryptedRange.toBlobRange().getCount() > this.blobSize(responseHeaders) - 16L;
                String encryptedDataString = responseHeaders.getValue("x-ms-meta-encryptiondata");
                Flux<ByteBuffer> plainTextData = this.decryptBlob(encryptedDataString, (Flux<ByteBuffer>)httpResponse.getBody(), encryptedRange, padding);
                return Mono.just((Object)((Object)new DecryptedResponse((HttpResponse)httpResponse, plainTextData)));
            }
            return Mono.just((Object)httpResponse);
        });
    }

    Flux<ByteBuffer> decryptBlob(String encryptedDataString, Flux<ByteBuffer> encryptedFlux, EncryptedBlobRange encryptedBlobRange, boolean padding) {
        EncryptionData encryptionData = this.getAndValidateEncryptionData(encryptedDataString);
        AtomicLong totalInputBytes = new AtomicLong(0L);
        AtomicLong totalOutputBytes = new AtomicLong(0L);
        Flux dataToTrim = encryptionData == null ? encryptedFlux : this.getKeyEncryptionKey(encryptionData).flatMapMany(contentEncryptionKey -> {
            Cipher cipher;
            byte[] iv = encryptedBlobRange.getOffsetAdjustment() <= 16 ? encryptionData.getContentEncryptionIV() : new byte[16];
            try {
                cipher = this.getCipher((byte[])contentEncryptionKey, encryptionData, iv, padding);
            }
            catch (InvalidKeyException e) {
                throw this.logger.logExceptionAsError(Exceptions.propagate((Throwable)e));
            }
            return encryptedFlux.map(encryptedByteBuffer -> {
                ByteBuffer plaintextByteBuffer = ByteBuffer.allocate(cipher.getOutputSize(encryptedByteBuffer.remaining()));
                int bytesToInput = encryptedByteBuffer.remaining();
                try {
                    if (totalInputBytes.longValue() + (long)bytesToInput >= encryptedBlobRange.getAdjustedDownloadCount()) {
                        cipher.doFinal((ByteBuffer)encryptedByteBuffer, plaintextByteBuffer);
                    } else {
                        cipher.update((ByteBuffer)encryptedByteBuffer, plaintextByteBuffer);
                    }
                }
                catch (GeneralSecurityException e) {
                    throw this.logger.logExceptionAsError(Exceptions.propagate((Throwable)e));
                }
                totalInputBytes.addAndGet(bytesToInput);
                plaintextByteBuffer.flip();
                return plaintextByteBuffer;
            });
        });
        return dataToTrim.map(plaintextByteBuffer -> {
            int decryptedBytes = plaintextByteBuffer.limit();
            if (totalOutputBytes.longValue() <= (long)encryptedBlobRange.getOffsetAdjustment()) {
                int remainingAdjustment = encryptedBlobRange.getOffsetAdjustment() - (int)totalOutputBytes.longValue();
                int newPosition = Math.min(remainingAdjustment, plaintextByteBuffer.limit());
                plaintextByteBuffer.position(newPosition);
            }
            long beginningOfEndAdjustment = encryptedBlobRange.getOriginalRange().getCount() == null ? Long.MAX_VALUE : (long)encryptedBlobRange.getOffsetAdjustment() + encryptedBlobRange.getOriginalRange().getCount();
            if ((long)decryptedBytes + totalOutputBytes.longValue() > beginningOfEndAdjustment) {
                long amountPastEnd = (long)decryptedBytes + totalOutputBytes.longValue() - beginningOfEndAdjustment;
                int newLimit = totalOutputBytes.longValue() <= beginningOfEndAdjustment ? decryptedBytes - (int)amountPastEnd : plaintextByteBuffer.position();
                plaintextByteBuffer.limit(newLimit);
            } else if ((long)decryptedBytes + totalOutputBytes.longValue() > (long)encryptedBlobRange.getOffsetAdjustment()) {
                plaintextByteBuffer.limit(decryptedBytes);
            } else {
                plaintextByteBuffer.limit(plaintextByteBuffer.position());
            }
            totalOutputBytes.addAndGet(decryptedBytes);
            return plaintextByteBuffer;
        });
    }

    private EncryptionData getAndValidateEncryptionData(String encryptedDataString) {
        if (encryptedDataString == null) {
            if (this.requiresEncryption) {
                throw this.logger.logExceptionAsError((RuntimeException)new IllegalStateException("'requiresEncryption' set to true but downloaded data is not encrypted."));
            }
            return null;
        }
        try {
            EncryptionData encryptionData = EncryptionData.fromJsonString(encryptedDataString);
            if (encryptionData == null) {
                if (this.requiresEncryption) {
                    throw this.logger.logExceptionAsError((RuntimeException)new IllegalStateException("'requiresEncryption' set to true but downloaded data is not encrypted."));
                }
                return null;
            }
            Objects.requireNonNull(encryptionData.getContentEncryptionIV(), "contentEncryptionIV in encryptionData cannot be null");
            Objects.requireNonNull(encryptionData.getWrappedContentKey().getEncryptedKey(), "encryptedKey in encryptionData.wrappedContentKey cannot be null");
            if (!"1.0".equals(encryptionData.getEncryptionAgent().getProtocol())) {
                throw this.logger.logExceptionAsError((RuntimeException)new IllegalArgumentException(String.format(Locale.ROOT, "Invalid Encryption Agent. This version of the client library does not understand the Encryption Agent set on the blob message: %s", encryptionData.getEncryptionAgent())));
            }
            return encryptionData;
        }
        catch (IOException e) {
            throw this.logger.logExceptionAsError(new RuntimeException(e));
        }
    }

    private Mono<byte[]> getKeyEncryptionKey(EncryptionData encryptionData) {
        Mono keyMono = this.keyResolver != null ? this.keyResolver.buildAsyncKeyEncryptionKey(encryptionData.getWrappedContentKey().getKeyId()).onErrorResume(NullPointerException.class, e -> {
            throw this.logger.logExceptionAsError(Exceptions.propagate((Throwable)e));
        }) : this.keyWrapper.getKeyId().flatMap(keyId -> {
            if (encryptionData.getWrappedContentKey().getKeyId().equals(keyId)) {
                return Mono.just((Object)this.keyWrapper);
            }
            throw this.logger.logExceptionAsError(Exceptions.propagate((Throwable)new IllegalArgumentException("Key mismatch. The key id stored on the service does not match the specified key.")));
        });
        return keyMono.flatMap(keyEncryptionKey -> keyEncryptionKey.unwrapKey(encryptionData.getWrappedContentKey().getAlgorithm(), encryptionData.getWrappedContentKey().getEncryptedKey()));
    }

    private Cipher getCipher(byte[] contentEncryptionKey, EncryptionData encryptionData, byte[] iv, boolean padding) throws InvalidKeyException {
        try {
            switch (encryptionData.getEncryptionAgent().getAlgorithm()) {
                case AES_CBC_256: {
                    Cipher cipher = padding ? Cipher.getInstance("AES/CBC/PKCS5Padding") : Cipher.getInstance("AES/CBC/NoPadding");
                    IvParameterSpec ivParameterSpec = new IvParameterSpec(iv);
                    SecretKeySpec keySpec = new SecretKeySpec(contentEncryptionKey, 0, contentEncryptionKey.length, "AES");
                    cipher.init(2, (Key)keySpec, ivParameterSpec);
                    return cipher;
                }
            }
            throw this.logger.logExceptionAsError((RuntimeException)new IllegalArgumentException("Invalid Encryption Algorithm found on the resource. This version of the client library does not support the specified encryption algorithm."));
        }
        catch (InvalidAlgorithmParameterException | NoSuchAlgorithmException | NoSuchPaddingException e) {
            throw this.logger.logExceptionAsError(Exceptions.propagate((Throwable)e));
        }
    }

    private Long blobSize(HttpHeaders headers) {
        if (headers.getValue("Content-Range") != null) {
            String range = headers.getValue("Content-Range");
            return Long.valueOf(range.split("/")[1]);
        }
        return Long.valueOf(headers.getValue("Content-Length"));
    }

    static class DecryptedResponse
    extends HttpResponse {
        private final Flux<ByteBuffer> plainTextBody;
        private final HttpHeaders httpHeaders;
        private final int statusCode;

        DecryptedResponse(HttpResponse httpResponse, Flux<ByteBuffer> plainTextBody) {
            super(httpResponse.getRequest());
            this.plainTextBody = plainTextBody;
            this.httpHeaders = httpResponse.getHeaders();
            this.statusCode = httpResponse.getStatusCode();
        }

        public int getStatusCode() {
            return this.statusCode;
        }

        public String getHeaderValue(String name) {
            return this.httpHeaders.getValue(name);
        }

        public HttpHeaders getHeaders() {
            return this.httpHeaders;
        }

        public Flux<ByteBuffer> getBody() {
            return this.plainTextBody;
        }

        public Mono<byte[]> getBodyAsByteArray() {
            return FluxUtil.collectBytesInByteBufferStream(this.plainTextBody);
        }

        public Mono<String> getBodyAsString() {
            return FluxUtil.collectBytesInByteBufferStream(this.plainTextBody).map(String::new);
        }

        public Mono<String> getBodyAsString(Charset charset) {
            return FluxUtil.collectBytesInByteBufferStream(this.plainTextBody).map(b -> new String((byte[])b, charset));
        }
    }
}

