/*
 * Decompiled with CFR 0.152.
 */
package alluxio.grpc;

import alluxio.conf.AlluxioConfiguration;
import alluxio.conf.Configuration;
import alluxio.conf.PropertyKey;
import alluxio.grpc.GrpcChannel;
import alluxio.grpc.GrpcChannelKey;
import alluxio.grpc.GrpcNetworkGroup;
import alluxio.grpc.GrpcServerAddress;
import alluxio.network.ChannelType;
import alluxio.util.CommonUtils;
import alluxio.util.WaitForOptions;
import alluxio.util.network.NettyUtils;
import alluxio.util.network.tls.SslContextProvider;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import io.grpc.ConnectivityState;
import io.grpc.ManagedChannel;
import io.grpc.netty.NettyChannelBuilder;
import io.netty.channel.EventLoopGroup;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.Arrays;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import javax.annotation.concurrent.ThreadSafe;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Process-wide pool of gRPC {@link ManagedChannel}s.
 *
 * <p>Channels are keyed by {@link GrpcChannelKey} (network group, server address and a
 * per-group slot index) and are reference counted: {@link #acquireChannel} reuses or creates a
 * channel and adds one reference, {@link #releaseConnection} removes one and shuts the channel
 * down when the count reaches zero. Netty {@link EventLoopGroup}s are shared by all channels of
 * the same {@link GrpcNetworkGroup} and are reference counted the same way, with one reference
 * held per pooled channel.
 */
@ThreadSafe
public class GrpcChannelPool {
    /** Shared pool instance. */
    public static final GrpcChannelPool INSTANCE = new GrpcChannelPool();
    private static final Logger LOG = LoggerFactory.getLogger(GrpcChannelPool.class);
    /** Time to wait for an orderly {@link ManagedChannel#shutdown()} to finish. */
    private static final long GRACEFUL_TIMEOUT_MS = Configuration.getMs(PropertyKey.NETWORK_CONNECTION_SHUTDOWN_GRACEFUL_TIMEOUT);
    /** Time to wait for {@link ManagedChannel#shutdownNow()} to finish. */
    private static final long SHUTDOWN_TIMEOUT_MS = Configuration.getMs(PropertyKey.NETWORK_CONNECTION_SHUTDOWN_TIMEOUT);

    /** Pooled channels, keyed by network group, server address and slot index. */
    private final ConcurrentMap<GrpcChannelKey, CountingReference<ManagedChannel>> mChannels = new ConcurrentHashMap<>();
    /** Event loops shared by every channel of a network group. */
    private final ConcurrentMap<GrpcNetworkGroup, CountingReference<EventLoopGroup>> mEventLoops = new ConcurrentHashMap<>();
    /** Per-group counters used to spread acquisitions round-robin over the group's slots. */
    private final Map<GrpcNetworkGroup, AtomicLong> mNetworkGroupCounters =
            Arrays.stream(GrpcNetworkGroup.values())
                    .collect(ImmutableMap.toImmutableMap(Function.identity(), group -> new AtomicLong()));
    private final SslContextProvider mSslContextProvider = SslContextProvider.Factory.create(Configuration.global());

    private GrpcChannelPool() {
    }

    /**
     * Acquires a channel to the given server, reusing a pooled one when it is healthy.
     *
     * <p>If a channel already exists for the computed key and becomes READY within
     * {@link PropertyKey#NETWORK_CONNECTION_HEALTH_CHECK_TIMEOUT}, its reference count is
     * incremented and it is returned. Otherwise any existing (unhealthy) channel is shut down
     * and replaced by a new one that inherits the existing reference count plus one. Each
     * successful call adds one reference that should be balanced by a
     * {@link #releaseConnection} call for the same key.
     *
     * <p>The health check runs inside {@link ConcurrentMap#compute}, so other updates to the
     * same key may block while it is in progress.
     *
     * @param networkGroup the network group the channel belongs to
     * @param serverAddress the server to connect to
     * @param conf configuration used for the slot count, health-check timeout and channel settings
     * @param alwaysEnableTLS when true, use a self-signed client TLS context even if
     *        {@link PropertyKey#NETWORK_TLS_ENABLED} is false
     * @return a handle wrapping the pooled channel and its key
     */
    public GrpcChannel acquireChannel(GrpcNetworkGroup networkGroup, GrpcServerAddress serverAddress, AlluxioConfiguration conf, boolean alwaysEnableTLS) {
        GrpcChannelKey channelKey = getChannelKey(networkGroup, serverAddress, conf);
        CountingReference<ManagedChannel> channelRef = mChannels.compute(channelKey, (key, ref) -> {
            int existingRefCount = 0;
            if (ref != null) {
                if (waitForConnectionReady(ref.get(), conf)) {
                    LOG.debug("Acquiring an existing connection. ConnectionKey: {}. Ref-count: {}", key, ref.getRefCount());
                    return ref.reference();
                }
                // Health check failed. Existing holders still release against this key, so the
                // replacement inherits their reference count.
                existingRefCount = ref.getRefCount();
                LOG.debug("Shutting down an existing unhealthy connection. ConnectionKey: {}. Ref-count: {}", key, existingRefCount);
                shutdownManagedChannel(ref.get());
            }
            LOG.debug("Creating a new managed channel. ConnectionKey: {}. Ref-count:{}, alwaysEnableTLS:{} config TLS:{}",
                    key, existingRefCount, alwaysEnableTLS, conf.getBoolean(PropertyKey.NETWORK_TLS_ENABLED));
            ManagedChannel managedChannel = createManagedChannel(channelKey, conf, alwaysEnableTLS);
            if (ref != null) {
                // createManagedChannel took a fresh event-loop reference for the replacement; drop
                // the one held by the replaced channel, otherwise the group's count never returns
                // to zero and its event loop is never shut down. Releasing only after the new
                // channel exists avoids tearing the loop down and immediately rebuilding it.
                releaseNetworkEventLoop(channelKey);
            }
            return new CountingReference<>(managedChannel, existingRefCount).reference();
        });
        return new GrpcChannel(channelKey, channelRef.get());
    }

    /**
     * Releases one reference to the channel for the given key.
     *
     * <p>When the count reaches zero the channel is shut down, its event-loop reference is
     * released and the key is removed from the pool.
     *
     * @param channelKey the key the channel was acquired with
     * @throws NullPointerException if no channel is pooled for the key
     */
    public void releaseConnection(GrpcChannelKey channelKey) {
        mChannels.compute(channelKey, (key, ref) -> {
            Preconditions.checkNotNull(ref, "Cannot release nonexistent connection");
            LOG.debug("Releasing connection for: {}. Ref-count: {}", key, ref.getRefCount());
            if (ref.dereference() == 0) {
                LOG.debug("Shutting down connection after: {}", channelKey);
                shutdownManagedChannel(ref.get());
                releaseNetworkEventLoop(channelKey);
                // Returning null removes the mapping.
                return null;
            }
            return ref;
        });
    }

    /**
     * Builds the pool key for a new acquisition.
     *
     * <p>Successive calls for the same group are spread round-robin over
     * {@code USER_NETWORK_MAX_CONNECTIONS} slots for that group.
     */
    private GrpcChannelKey getChannelKey(GrpcNetworkGroup networkGroup, GrpcServerAddress serverAddress, AlluxioConfiguration conf) {
        long groupIndex = mNetworkGroupCounters.get(networkGroup).incrementAndGet();
        long maxConnectionsForGroup = conf.getLong(PropertyKey.Template.USER_NETWORK_MAX_CONNECTIONS.format(networkGroup.getPropertyCode()));
        // floorMod keeps the slot non-negative even if the counter ever wraps past Long.MAX_VALUE.
        return new GrpcChannelKey(networkGroup, serverAddress, (int) Math.floorMod(groupIndex, maxConnectionsForGroup));
    }

    /**
     * Creates a Netty-backed channel for the key using the group's keep-alive, message-size,
     * flow-control and channel-type settings.
     *
     * <p>TLS selection: the SECRET group always uses the self-signed client context; otherwise
     * the configured client context is used when {@link PropertyKey#NETWORK_TLS_ENABLED} is set;
     * otherwise the self-signed context is used if {@code alwaysEnableTLS}; otherwise plaintext.
     * This acquires one reference on the group's event loop.
     */
    private ManagedChannel createManagedChannel(GrpcChannelKey channelKey, AlluxioConfiguration conf, boolean alwaysEnableTLS) {
        GrpcNetworkGroup group = channelKey.getNetworkGroup();
        SocketAddress address = channelKey.getServerAddress().getSocketAddress();
        boolean isInetAddress = address instanceof InetSocketAddress;
        NettyChannelBuilder channelBuilder;
        if (isInetAddress) {
            InetSocketAddress inetServerAddress = (InetSocketAddress) address;
            channelBuilder = NettyChannelBuilder.forAddress(inetServerAddress.getHostName(), inetServerAddress.getPort());
        } else {
            channelBuilder = NettyChannelBuilder.forAddress(address);
        }
        channelBuilder.keepAliveTime(conf.getMs(PropertyKey.Template.USER_NETWORK_KEEPALIVE_TIME_MS.format(group.getPropertyCode())), TimeUnit.MILLISECONDS);
        channelBuilder.keepAliveTimeout(conf.getMs(PropertyKey.Template.USER_NETWORK_KEEPALIVE_TIMEOUT_MS.format(group.getPropertyCode())), TimeUnit.MILLISECONDS);
        channelBuilder.maxInboundMessageSize((int) conf.getBytes(PropertyKey.Template.USER_NETWORK_MAX_INBOUND_MESSAGE_SIZE.format(group.getPropertyCode())));
        channelBuilder.flowControlWindow((int) conf.getBytes(PropertyKey.Template.USER_NETWORK_FLOWCONTROL_WINDOW.format(group.getPropertyCode())));
        // Non-inet addresses are treated as domain sockets when choosing the Netty channel class.
        channelBuilder.channelType(NettyUtils.getChannelClass(!isInetAddress, PropertyKey.Template.USER_NETWORK_NETTY_CHANNEL.format(group.getPropertyCode()), conf));
        channelBuilder.eventLoopGroup(acquireNetworkEventLoop(channelKey, conf));
        // Plaintext by default; a useTransportSecurity() call below overrides it when TLS applies.
        channelBuilder.usePlaintext();
        if (group == GrpcNetworkGroup.SECRET) {
            channelBuilder.sslContext(mSslContextProvider.getSelfSignedClientSslContext());
            channelBuilder.useTransportSecurity();
        } else if (conf.getBoolean(PropertyKey.NETWORK_TLS_ENABLED)) {
            channelBuilder.sslContext(mSslContextProvider.getClientSslContext());
            channelBuilder.useTransportSecurity();
        } else if (alwaysEnableTLS) {
            channelBuilder.sslContext(mSslContextProvider.getSelfSignedClientSslContext());
            channelBuilder.useTransportSecurity();
        }
        return channelBuilder.build();
    }

    /**
     * Polls the channel (requesting a connection if idle) until it is READY or fails.
     *
     * @return true if the channel reached READY within the health-check timeout; false on
     *         TRANSIENT_FAILURE, SHUTDOWN, timeout or interruption (the interrupt flag is restored)
     */
    private boolean waitForConnectionReady(ManagedChannel managedChannel, AlluxioConfiguration conf) {
        long healthCheckTimeoutMs = conf.getMs(PropertyKey.NETWORK_CONNECTION_HEALTH_CHECK_TIMEOUT);
        // WaitForOptions takes an int; clamp instead of letting a large value wrap around.
        int timeoutMs = (int) Math.min(healthCheckTimeoutMs, Integer.MAX_VALUE);
        try {
            return CommonUtils.waitForResult("channel to be ready", () -> {
                ConnectivityState currentState = managedChannel.getState(true);
                switch (currentState) {
                    case READY:
                        return true;
                    case TRANSIENT_FAILURE:
                    case SHUTDOWN:
                        return false;
                    default:
                        // IDLE / CONNECTING: no verdict yet, keep polling.
                        return null;
                }
            }, Objects::nonNull, WaitForOptions.defaults().setTimeoutMs(timeoutMs));
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        } catch (TimeoutException e) {
            return false;
        }
    }

    /**
     * Shuts a channel down, first gracefully and then forcefully if it has not terminated.
     * Timeouts are logged as warnings; interruption restores the thread's interrupt flag.
     */
    private void shutdownManagedChannel(ManagedChannel managedChannel) {
        if (!managedChannel.isShutdown()) {
            managedChannel.shutdown();
            try {
                if (!managedChannel.awaitTermination(GRACEFUL_TIMEOUT_MS, TimeUnit.MILLISECONDS)) {
                    LOG.warn("Timed out gracefully shutting down connection: {}. ", managedChannel);
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }
        // Fall back to a forced shutdown if the graceful one did not finish.
        if (!managedChannel.isTerminated()) {
            managedChannel.shutdownNow();
            try {
                if (!managedChannel.awaitTermination(SHUTDOWN_TIMEOUT_MS, TimeUnit.MILLISECONDS)) {
                    LOG.warn("Timed out forcefully shutting down connection: {}. ", managedChannel);
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Returns the event loop for the key's network group, creating it on first use, and adds
     * one reference to it.
     */
    private EventLoopGroup acquireNetworkEventLoop(GrpcChannelKey channelKey, AlluxioConfiguration conf) {
        return mEventLoops.compute(channelKey.getNetworkGroup(), (key, ref) -> {
            if (ref != null) {
                LOG.debug("Acquiring an existing event-loop for {}. Ref-Count:{}", channelKey, ref.getRefCount());
                return ref.reference();
            }
            ChannelType nettyChannelType = NettyUtils.getChannelType(PropertyKey.Template.USER_NETWORK_NETTY_CHANNEL.format(key.getPropertyCode()), conf);
            int nettyWorkerThreadCount = conf.getInt(PropertyKey.Template.USER_NETWORK_NETTY_WORKER_THREADS.format(key.getPropertyCode()));
            LOG.debug("Created a new event loop. NetworkGroup: {}. NettyChannelType: {}, NettyThreadCount: {}",
                    key, nettyChannelType, nettyWorkerThreadCount);
            EventLoopGroup eventLoop = NettyUtils.createEventLoop(nettyChannelType, nettyWorkerThreadCount,
                    String.format("alluxio-client-netty-event-loop-%s-%%d", key.name()), true);
            return new CountingReference<>(eventLoop, 1);
        }).get();
    }

    /**
     * Removes one reference from the key's network-group event loop; when it reaches zero the
     * loop is shut down gracefully and removed.
     *
     * @throws NullPointerException if the group has no event loop
     */
    private void releaseNetworkEventLoop(GrpcChannelKey channelKey) {
        mEventLoops.compute(channelKey.getNetworkGroup(), (key, ref) -> {
            Preconditions.checkNotNull(ref, "Cannot release nonexistent event-loop");
            LOG.debug("Releasing event-loop for: {}. Ref-count: {}", channelKey, ref.getRefCount());
            if (ref.dereference() == 0) {
                LOG.debug("Shutting down event-loop: {}", ref.get());
                ref.get().shutdownGracefully();
                return null;
            }
            return ref;
        });
    }

    /**
     * Pairs an object with an atomic reference count.
     *
     * @param <T> the type of the referenced object
     */
    private static final class CountingReference<T> {
        private final T mObject;
        private final AtomicInteger mRefCount;

        private CountingReference(T object, int initialRefCount) {
            mObject = object;
            mRefCount = new AtomicInteger(initialRefCount);
        }

        /** Increments the count and returns this reference, for chaining. */
        private CountingReference<T> reference() {
            mRefCount.incrementAndGet();
            return this;
        }

        /** Decrements the count and returns the new value. */
        private int dereference() {
            return mRefCount.decrementAndGet();
        }

        private int getRefCount() {
            return mRefCount.get();
        }

        private T get() {
            return mObject;
        }
    }
}

