/*
 * Decompiled with CFR 0.152.
 */
package com.nvidia.spark.rapids;

import ai.rapids.cudf.HostMemoryBuffer;
import com.nvidia.spark.rapids.RangeWithOffset;
import java.lang.reflect.InvocationTargetException;
import java.net.URI;
import java.nio.ByteBuffer;
import java.time.Duration;
import java.util.ArrayList;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import org.apache.commons.lang3.reflect.MethodUtils;
import org.apache.hadoop.conf.Configuration;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProviderChain;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
import software.amazon.awssdk.core.async.SdkPublisher;
import software.amazon.awssdk.core.client.config.SdkAdvancedAsyncClientOption;
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
import software.amazon.awssdk.http.crt.AwsCrtAsyncHttpClient;
import software.amazon.awssdk.http.crt.TcpKeepAliveConfiguration;
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.S3AsyncClientBuilder;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
import software.amazon.awssdk.utils.ThreadFactoryBuilder;

class RangeCopier {
    private static final Logger LOG = LoggerFactory.getLogger(RangeCopier.class);
    private static S3AsyncClient asyncClient;
    private static final String CREDENTIALS_PROVIDER = "fs.s3a.aws.credentials.provider";
    private static final String ACCESS_KEY = "fs.s3a.access.key";
    private static final String SECRET_KEY = "fs.s3a.secret.key";
    private static final String SESSION_KEY = "fs.s3a.session.key";
    private static final String AWSSDK = "software.amazon.awssdk";
    static volatile boolean useNetty;

    RangeCopier() {
    }

    static long copyToHMB(Configuration hconf, HostMemoryBuffer hmb, URI fromURI, Iterable<RangeWithOffset> ranges) {
        String bucket = fromURI.getAuthority();
        String objectKey = fromURI.getRawPath().substring(1);
        ArrayList<CompletableFuture> responseFutures = new ArrayList<CompletableFuture>(3);
        for (RangeWithOffset range : ranges) {
            long hmbOffset = range.destOffset();
            GetObjectRequest rangeReq = (GetObjectRequest)GetObjectRequest.builder().bucket(bucket).key(objectKey).range(range.rangeSpec()).build();
            responseFutures.add(RangeCopier.create(hconf).getObject(rangeReq, (AsyncResponseTransformer)new AsyncRangeRequestTransformer(hmb, hmbOffset, range.length())));
        }
        return responseFutures.stream().reduce(CompletableFuture.completedFuture(0L), (x, y) -> x.thenCombine((CompletionStage)y, Long::sum)).join();
    }

    private static synchronized S3AsyncClient create(Configuration hadoopConf) {
        if (asyncClient == null) {
            LOG.debug("Initializing RAPIDS S3 Range Copier ...");
            Conf conf = new Conf(hadoopConf);
            asyncClient = (S3AsyncClient)((S3AsyncClientBuilder)((S3AsyncClientBuilder)((S3AsyncClientBuilder)((S3AsyncClientBuilder)S3AsyncClient.builder().credentialsProvider(conf.getAwsCredentialsProvider())).forcePathStyle(Boolean.valueOf(conf.pathStyle))).httpClientBuilder((SdkAsyncHttpClient.Builder)(useNetty ? RangeCopier.nettyBuilder(conf) : RangeCopier.crtBuilder(conf)))).asyncConfiguration(b -> b.advancedOption(SdkAdvancedAsyncClientOption.FUTURE_COMPLETION_EXECUTOR, (Object)RangeCopier.createThreadPoolExecutor(conf)))).build();
            LOG.debug("Done initializing RAPIDS S3 Range Copier: {}", (Object)asyncClient);
        }
        return asyncClient;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    static synchronized void shutdown() {
        if (asyncClient != null) {
            try (S3AsyncClient ignored = asyncClient;){
                LOG.debug("Closing client {}", (Object)asyncClient);
            }
            finally {
                asyncClient = null;
            }
        }
    }

    private static ThreadPoolExecutor createThreadPoolExecutor(Conf conf) {
        ThreadPoolExecutor executor = new ThreadPoolExecutor(50, conf.maxThreads, conf.threadsKeepAliveTime, TimeUnit.SECONDS, new LinkedBlockingQueue<Runnable>(conf.maxTasks), new ThreadFactoryBuilder().threadNamePrefix("spark-rapids-async-s3").build());
        executor.allowCoreThreadTimeOut(true);
        return executor;
    }

    private static NettyNioAsyncHttpClient.Builder nettyBuilder(Conf conf) {
        return NettyNioAsyncHttpClient.builder().maxConcurrency(Integer.valueOf(conf.maxConcurrency)).tcpKeepAlive(Boolean.valueOf(conf.connectionKeepAlive)).connectionTimeToLive(Duration.ofMinutes(conf.connectionTTL));
    }

    private static AwsCrtAsyncHttpClient.Builder crtBuilder(Conf conf) {
        AwsCrtAsyncHttpClient.Builder builder = AwsCrtAsyncHttpClient.builder().maxConcurrency(Integer.valueOf(conf.maxConcurrency));
        if (conf.connectionKeepAlive) {
            builder.tcpKeepAliveConfiguration(TcpKeepAliveConfiguration.builder().keepAliveInterval(Duration.ofMinutes(5L)).keepAliveTimeout(Duration.ofSeconds(30L)).build());
        }
        return builder;
    }

    private static class AsyncRangeRequestTransformer
    implements AsyncResponseTransformer<GetObjectResponse, Long> {
        private final HostMemoryBuffer hostMemoryBuffer;
        private final long outputOffset;
        private final long rangeLength;
        private volatile CompletableFuture<Long> resultFuture;

        public AsyncRangeRequestTransformer(HostMemoryBuffer hostMemoryBuffer, long outputOffset, long rangeLength) {
            this.hostMemoryBuffer = hostMemoryBuffer;
            this.outputOffset = outputOffset;
            this.rangeLength = rangeLength;
        }

        public CompletableFuture<Long> prepare() {
            this.resultFuture = new CompletableFuture();
            return this.resultFuture;
        }

        public void onResponse(GetObjectResponse response) {
            LOG.debug("Response available: {}", (Object)response);
        }

        public void onStream(SdkPublisher<ByteBuffer> publisher) {
            publisher.subscribe((Subscriber)new ByteBufferSubscriber(this.resultFuture, this.hostMemoryBuffer, this.outputOffset, this.rangeLength));
        }

        public void exceptionOccurred(Throwable error) {
            this.resultFuture.completeExceptionally(error);
        }

        private static class ByteBufferSubscriber
        implements Subscriber<ByteBuffer> {
            private final CompletableFuture<Long> resultFuture;
            private final HostMemoryBuffer hostMemoryBuffer;
            private long pos;
            private final long limit;
            private long totalCopied;
            private Subscription byteBufferSubscription;

            public ByteBufferSubscriber(CompletableFuture<Long> cf, HostMemoryBuffer hostMemoryBuffer, long pos, long len) {
                this.resultFuture = cf;
                this.hostMemoryBuffer = hostMemoryBuffer;
                this.pos = pos;
                this.limit = pos + len;
            }

            public void onSubscribe(Subscription s) {
                if (this.byteBufferSubscription != null) {
                    this.byteBufferSubscription.cancel();
                    return;
                }
                this.byteBufferSubscription = s;
                this.byteBufferSubscription.request(Long.MAX_VALUE);
            }

            public void onNext(ByteBuffer byteBuffer) {
                int chunkLength = byteBuffer.remaining();
                this.hostMemoryBuffer.asByteBuffer(this.pos, chunkLength).put(byteBuffer);
                this.pos += (long)chunkLength;
                this.totalCopied += (long)chunkLength;
                if (this.pos < this.limit) {
                    this.byteBufferSubscription.request(1L);
                } else if (this.pos > this.limit) {
                    this.resultFuture.completeExceptionally(new IllegalStateException("INFEASIBLE: Remaining zero bytes expected, read past the range by bytes: " + (this.pos - this.limit)));
                }
            }

            public void onError(Throwable t) {
                this.resultFuture.completeExceptionally(t);
            }

            public void onComplete() {
                this.resultFuture.complete(this.totalCopied);
            }
        }
    }

    private static class Conf {
        private final Configuration hadoopConf;
        final int maxConcurrency;
        final int maxTasks;
        final long threadsKeepAliveTime;
        final int maxThreads;
        final boolean connectionKeepAlive;
        final long connectionTTL;
        final boolean pathStyle;

        Conf(Configuration hadoopConf) {
            LOG.debug("Creating async S3 client conf from S3AFileSystem conf: {}", (Object)hadoopConf);
            this.hadoopConf = hadoopConf;
            this.maxConcurrency = hadoopConf.getInt("fs.s3a.connection.maximum", 200);
            this.maxTasks = hadoopConf.getInt("fs.s3a.max.total.tasks", 1000);
            this.threadsKeepAliveTime = hadoopConf.getTimeDuration("fs.s3a.threads.keepalivetime", 60L, TimeUnit.SECONDS);
            this.maxThreads = hadoopConf.getInt("fs.s3a.threads.max", 136);
            this.connectionKeepAlive = hadoopConf.getBoolean("fs.s3a.connection.keepalive", true);
            this.connectionTTL = hadoopConf.getTimeDuration("fs.s3a.connection.ttl", 5L, TimeUnit.MINUTES);
            this.pathStyle = hadoopConf.getBoolean("fs.s3a.path.style.access", false);
        }

        AwsCredentialsProvider getAwsCredentialsProvider() {
            DefaultCredentialsProvider creds;
            LOG.debug("Building AwsCredentialsProvider ...");
            String accessKey = this.hadoopConf.get(RangeCopier.ACCESS_KEY);
            String secretKey = this.hadoopConf.get(RangeCopier.SECRET_KEY);
            String sessionKey = this.hadoopConf.get(RangeCopier.SESSION_KEY);
            String credsProvider = this.hadoopConf.get(RangeCopier.CREDENTIALS_PROVIDER);
            if (accessKey != null && secretKey != null && sessionKey != null) {
                LOG.debug("StaticCredentialsProvider using {}, {}, {}", new Object[]{RangeCopier.ACCESS_KEY, RangeCopier.SECRET_KEY, RangeCopier.SESSION_KEY});
                creds = StaticCredentialsProvider.create((AwsCredentials)AwsSessionCredentials.create((String)accessKey, (String)secretKey, (String)sessionKey));
            } else if (accessKey != null && secretKey != null) {
                LOG.debug("StaticCredentialsProvider using {}, {}", (Object)RangeCopier.ACCESS_KEY, (Object)RangeCopier.SECRET_KEY);
                creds = StaticCredentialsProvider.create((AwsCredentials)AwsBasicCredentials.create((String)accessKey, (String)secretKey));
            } else if (credsProvider != null) {
                LOG.debug("AwsCredentialsProviderChain using {}", (Object)RangeCopier.CREDENTIALS_PROVIDER);
                AwsCredentialsProviderChain.Builder chainBuilder = AwsCredentialsProviderChain.builder();
                for (Class<?> clazz : this.hadoopConf.getClasses(RangeCopier.CREDENTIALS_PROVIDER, new Class[0])) {
                    try {
                        String packageName = clazz.getPackage().getName();
                        Class<?> v2credentialsProviderCls = packageName.startsWith("com.amazonaws.") ? Class.forName("software.amazon.awssdk.auth.credentials." + clazz.getSimpleName()) : clazz;
                        AwsCredentialsProvider awsCredentialsProvider = (AwsCredentialsProvider)MethodUtils.invokeStaticMethod((Class)v2credentialsProviderCls, (String)"create", (Object[])new Object[0]);
                        chainBuilder.addCredentialsProvider(awsCredentialsProvider);
                    }
                    catch (ClassNotFoundException | IllegalAccessException | NoSuchMethodException | InvocationTargetException e) {
                        throw new RuntimeException(e);
                    }
                }
                creds = chainBuilder.build();
            } else {
                LOG.warn("Missing fs.s3a access config, using default CredentialsProvider");
                creds = DefaultCredentialsProvider.create();
            }
            LOG.info("Configured CredentialsProvider object for S3 Client: {}", (Object)creds.getClass().getName());
            return creds;
        }
    }
}

