/*
 * Decompiled with CFR 0.152.
 */
package io.netty.handler.ssl;

import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import io.netty.handler.ssl.OpenSsl;
import io.netty.handler.ssl.OpenSslAsyncPrivateKeyMethod;
import io.netty.handler.ssl.OpenSslContext;
import io.netty.handler.ssl.OpenSslContextOption;
import io.netty.handler.ssl.OpenSslPrivateKeyMethod;
import io.netty.handler.ssl.OpenSslTestUtils;
import io.netty.handler.ssl.OpenSslX509KeyManagerFactory;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslContextOption;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslProvider;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.ImmediateEventExecutor;
import io.netty.util.concurrent.Promise;
import io.netty.util.internal.ThreadLocalRandom;
import java.net.SocketAddress;
import java.security.NoSuchAlgorithmException;
import java.security.Signature;
import java.security.SignatureException;
import java.security.cert.X509Certificate;
import java.security.spec.MGF1ParameterSpec;
import java.security.spec.PSSParameterSpec;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLHandshakeException;
import org.hamcrest.Matcher;
import org.hamcrest.MatcherAssert;
import org.hamcrest.Matchers;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Assumptions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;

/**
 * Exercises {@link OpenSslPrivateKeyMethod} and {@link OpenSslAsyncPrivateKeyMethod}.
 *
 * <p>A TLSv1.2 handshake runs over Netty's local transport between two peers:
 * <ul>
 *   <li>an OpenSSL (BoringSSL) server whose key manager is keyless, so all signing goes through the configured
 *   key method;</li>
 *   <li>a JDK client.</li>
 * </ul>
 */
public class OpenSslPrivateKeyMethodTest {
    private static final String RFC_CIPHER_NAME = "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256";
    private static EventLoopGroup GROUP;
    private static SelfSignedCertificate CERT;
    private static ExecutorService EXECUTOR;

    /**
     * Every combination of the {@code (delegate, async, newThread)} flags used by {@link #testPrivateKeyMethod}.
     */
    static Collection<Object[]> parameters() {
        List<Object[]> dst = new ArrayList<Object[]>();
        for (int a = 0; a < 2; ++a) {
            for (int b = 0; b < 2; ++b) {
                for (int c = 0; c < 2; ++c) {
                    dst.add(new Object[]{a == 0, b == 0, c == 0});
                }
            }
        }
        return dst;
    }

    /**
     * The two values of the single {@code delegate} flag taken by the failure tests.
     *
     * <p>Those tests used to share {@link #parameters()}. That source supplies three arguments to a one-parameter
     * method, and surplus arguments are ignored by default, so every delegate value ran four times.
     */
    static Collection<Object[]> delegateParameters() {
        List<Object[]> dst = new ArrayList<Object[]>();
        dst.add(new Object[]{true});
        dst.add(new Object[]{false});
        return dst;
    }

    /**
     * Skips the class unless BoringSSL and the test cipher are usable, then creates the shared fixtures.
     */
    @BeforeAll
    public static void init() throws Exception {
        OpenSslTestUtils.checkShouldUseKeyManagerFactory();
        Assumptions.assumeTrue(OpenSsl.isBoringSSL());
        assumeCipherAvailable(SslProvider.OPENSSL);
        assumeCipherAvailable(SslProvider.JDK);
        GROUP = new DefaultEventLoopGroup();
        CERT = new SelfSignedCertificate();
        // Pool threads are DelegateThreads so assertThread can recognise work run on the delegating executor.
        EXECUTOR = Executors.newCachedThreadPool(new ThreadFactory() {
            @Override
            public Thread newThread(Runnable r) {
                return new DelegateThread(r);
            }
        });
    }

    /**
     * Releases whichever shared fixtures were created.
     *
     * <p>{@link #init()} can stop part-way (failed assumption or exception) before all of them exist. The old
     * {@code isBoringSSL()} guard still dereferenced null fields in that case, so each one is null-checked.
     */
    @AfterAll
    public static void destroy() {
        if (GROUP != null) {
            GROUP.shutdownGracefully();
        }
        if (CERT != null) {
            CERT.delete();
        }
        if (EXECUTOR != null) {
            EXECUTOR.shutdown();
        }
    }

    /**
     * Aborts the test class (via an assumption) if {@link #RFC_CIPHER_NAME} is not supported by the given provider.
     *
     * @param provider {@link SslProvider#JDK} checks the default JDK {@link SSLContext}; any other value checks
     *                 {@link OpenSsl#isCipherSuiteAvailable(String)}
     */
    private static void assumeCipherAvailable(SslProvider provider) throws NoSuchAlgorithmException {
        boolean cipherSupported = false;
        if (provider == SslProvider.JDK) {
            SSLEngine engine = SSLContext.getDefault().createSSLEngine();
            for (String c : engine.getSupportedCipherSuites()) {
                if (RFC_CIPHER_NAME.equals(c)) {
                    cipherSupported = true;
                    break;
                }
            }
        } else {
            cipherSupported = OpenSsl.isCipherSuiteAvailable(RFC_CIPHER_NAME);
        }
        Assumptions.assumeTrue(cipherSupported, "Unsupported cipher: TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256");
    }

    /**
     * Creates an {@link SslHandler}.
     *
     * <p>When {@code executor} is non-null, it is passed to {@link SslContext#newHandler(ByteBufAllocator, Executor)}
     * as the delegated-task executor.
     */
    private static SslHandler newSslHandler(SslContext sslCtx, ByteBufAllocator allocator, Executor executor) {
        if (executor == null) {
            return sslCtx.newHandler(allocator);
        }
        return sslCtx.newHandler(allocator, executor);
    }

    /** Builds a keyless OpenSSL server context whose signing is done by the given synchronous key method. */
    private SslContext buildServerContext(OpenSslPrivateKeyMethod method) throws Exception {
        return buildKeylessServerContext(OpenSslContextOption.PRIVATE_KEY_METHOD, method);
    }

    /** Builds a keyless OpenSSL server context whose signing is done by the given asynchronous key method. */
    private SslContext buildServerContext(OpenSslAsyncPrivateKeyMethod method) throws Exception {
        return buildKeylessServerContext(OpenSslContextOption.ASYNC_PRIVATE_KEY_METHOD, method);
    }

    /**
     * Shared builder for both {@code buildServerContext} overloads.
     *
     * <p>Settings: OpenSSL provider, TLSv1.2 only, {@link #RFC_CIPHER_NAME} only. The key manager is built from
     * {@link OpenSslX509KeyManagerFactory#newKeyless} over CERT's certificate; {@code option} is set to
     * {@code method}.
     */
    private static <T> SslContext buildKeylessServerContext(SslContextOption<T> option, T method) throws Exception {
        KeyManagerFactory kmf = OpenSslX509KeyManagerFactory.newKeyless(CERT.cert());
        return SslContextBuilder.forServer(kmf)
                .sslProvider(SslProvider.OPENSSL)
                .ciphers(Collections.singletonList(RFC_CIPHER_NAME))
                .protocols("TLSv1.2")
                .option(option, method)
                .build();
    }

    /**
     * Builds a JDK client context.
     *
     * <p>It is restricted to TLSv1.2 and {@link #RFC_CIPHER_NAME}, and trusts any server certificate.
     */
    private SslContext buildClientContext() throws Exception {
        return SslContextBuilder.forClient()
                .sslProvider(SslProvider.JDK)
                .ciphers(Collections.singletonList(RFC_CIPHER_NAME))
                .protocols("TLSv1.2")
                .trustManager(InsecureTrustManagerFactory.INSTANCE)
                .build();
    }

    /** Returns the shared delegate executor when {@code delegate} is true, otherwise {@code null}. */
    private static Executor delegateExecutor(boolean delegate) {
        return delegate ? EXECUTOR : null;
    }

    /**
     * Checks which thread is running the key method.
     *
     * <p>When delegation is requested and {@link OpenSslContext#USE_TASKS} is enabled, the key method must run on a
     * {@link DelegateThread}.
     */
    private static void assertThread(boolean delegate) {
        if (delegate && OpenSslContext.USE_TASKS) {
            Assertions.assertEquals(DelegateThread.class, Thread.currentThread().getClass());
        }
    }

    /**
     * Performs a full handshake plus a PING/PONG exchange.
     *
     * <p>The key method (optionally wrapped as an async method) must be invoked to sign with CERT's private key.
     */
    @ParameterizedTest(name="{index}: delegate = {0}, async = {1}, newThread={2}")
    @MethodSource(value={"parameters"})
    public void testPrivateKeyMethod(final boolean delegate, boolean async, boolean newThread) throws Exception {
        final AtomicBoolean signCalled = new AtomicBoolean();
        OpenSslPrivateKeyMethod keyMethod = new OpenSslPrivateKeyMethod() {
            @Override
            public byte[] sign(SSLEngine engine, int signatureAlgorithm, byte[] input) throws Exception {
                signCalled.set(true);
                assertThread(delegate);
                // The server's key manager was built from CERT's certificate, so it must be the local certificate.
                Assertions.assertEquals(CERT.cert().getPublicKey(),
                        engine.getSession().getLocalCertificates()[0].getPublicKey());
                final Signature signature;
                if (signatureAlgorithm == OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA256) {
                    signature = Signature.getInstance("SHA256withRSA");
                } else if (signatureAlgorithm == OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA256) {
                    signature = Signature.getInstance("RSASSA-PSS");
                    signature.setParameter(new PSSParameterSpec("SHA-256", "MGF1", MGF1ParameterSpec.SHA256, 32, 1));
                } else {
                    throw new AssertionError("Unexpected signature algorithm " + signatureAlgorithm);
                }
                signature.initSign(CERT.key());
                signature.update(input);
                return signature.sign();
            }

            @Override
            public byte[] decrypt(SSLEngine engine, byte[] input) {
                throw new UnsupportedOperationException();
            }
        };
        final SslContext sslServerContext = async
                ? buildServerContext(new OpenSslPrivateKeyMethodAdapter(keyMethod, newThread))
                : buildServerContext(keyMethod);
        try {
            // Built inside the try so the server context is released even if this throws.
            final SslContext sslClientContext = buildClientContext();
            try {
                final Promise<Object> serverPromise = GROUP.next().newPromise();
                final Promise<Object> clientPromise = GROUP.next().newPromise();
                ChannelInitializer<Channel> serverHandler = new ChannelInitializer<Channel>() {
                    @Override
                    protected void initChannel(Channel ch) {
                        ChannelPipeline pipeline = ch.pipeline();
                        pipeline.addLast(newSslHandler(sslServerContext, ch.alloc(), delegateExecutor(delegate)));
                        pipeline.addLast(new SimpleChannelInboundHandler<Object>() {
                            @Override
                            public void channelInactive(ChannelHandlerContext ctx) {
                                serverPromise.cancel(true);
                                ctx.fireChannelInactive();
                            }

                            @Override
                            public void channelRead0(ChannelHandlerContext ctx, Object msg) {
                                // Answer the first message with "PONG", then close.
                                if (serverPromise.trySuccess(null)) {
                                    ctx.writeAndFlush(Unpooled.wrappedBuffer(new byte[]{'P', 'O', 'N', 'G'}));
                                }
                                ctx.close();
                            }

                            @Override
                            public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
                                if (!serverPromise.tryFailure(cause)) {
                                    ctx.fireExceptionCaught(cause);
                                }
                            }
                        });
                    }
                };
                LocalAddress address = new LocalAddress("test-" + SslProvider.OPENSSL + '-' + SslProvider.JDK + '-' + RFC_CIPHER_NAME + '-' + delegate);
                Channel server = server(address, serverHandler);
                try {
                    ChannelInitializer<Channel> clientHandler = new ChannelInitializer<Channel>() {
                        @Override
                        protected void initChannel(Channel ch) {
                            ChannelPipeline pipeline = ch.pipeline();
                            pipeline.addLast(newSslHandler(sslClientContext, ch.alloc(), delegateExecutor(delegate)));
                            pipeline.addLast(new SimpleChannelInboundHandler<Object>() {
                                @Override
                                public void channelInactive(ChannelHandlerContext ctx) {
                                    clientPromise.cancel(true);
                                    ctx.fireChannelInactive();
                                }

                                @Override
                                public void channelRead0(ChannelHandlerContext ctx, Object msg) {
                                    clientPromise.trySuccess(null);
                                    ctx.close();
                                }

                                @Override
                                public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
                                    if (!clientPromise.tryFailure(cause)) {
                                        ctx.fireExceptionCaught(cause);
                                    }
                                }
                            });
                        }
                    };
                    Channel client = client(server, clientHandler);
                    try {
                        client.writeAndFlush(Unpooled.wrappedBuffer(new byte[]{'P', 'I', 'N', 'G'}))
                                .syncUninterruptibly();
                        Assertions.assertTrue(clientPromise.await(5L, TimeUnit.SECONDS), "client timeout");
                        Assertions.assertTrue(serverPromise.await(5L, TimeUnit.SECONDS), "server timeout");
                        // sync() rethrows any failure recorded by the handlers above.
                        clientPromise.sync();
                        serverPromise.sync();
                        Assertions.assertTrue(signCalled.get());
                    } finally {
                        client.close().sync();
                    }
                } finally {
                    server.close().sync();
                }
            } finally {
                ReferenceCountUtil.release(sslClientContext);
            }
        } finally {
            ReferenceCountUtil.release(sslServerContext);
        }
    }

    /** The handshake must fail when the key method throws while signing. */
    @ParameterizedTest(name="{index}: delegate = {0}")
    @MethodSource(value={"delegateParameters"})
    public void testPrivateKeyMethodFailsBecauseOfException(boolean delegate) throws Exception {
        testPrivateKeyMethodFails(delegate, false);
    }

    /** The handshake must fail when the key method returns {@code null} from {@code sign}. */
    @ParameterizedTest(name="{index}: delegate = {0}")
    @MethodSource(value={"delegateParameters"})
    public void testPrivateKeyMethodFailsBecauseOfNull(boolean delegate) throws Exception {
        testPrivateKeyMethodFails(delegate, true);
    }

    /**
     * Runs a handshake whose server-side {@code sign} either returns {@code null} or throws
     * {@link SignatureException}.
     *
     * <p>Expected outcome: both handshake futures fail, and the server fails with an {@link SSLHandshakeException}.
     */
    private void testPrivateKeyMethodFails(final boolean delegate, final boolean returnNull) throws Exception {
        final SslContext sslServerContext = buildServerContext(new OpenSslPrivateKeyMethod() {
            @Override
            public byte[] sign(SSLEngine engine, int signatureAlgorithm, byte[] input) throws Exception {
                assertThread(delegate);
                if (returnNull) {
                    return null;
                }
                throw new SignatureException();
            }

            @Override
            public byte[] decrypt(SSLEngine engine, byte[] input) {
                throw new UnsupportedOperationException();
            }
        });
        try {
            // Built inside the try so the server context is released even if this throws.
            final SslContext sslClientContext = buildClientContext();
            try {
                SslHandler serverSslHandler = newSslHandler(
                        sslServerContext, UnpooledByteBufAllocator.DEFAULT, delegateExecutor(delegate));
                SslHandler clientSslHandler = newSslHandler(
                        sslClientContext, UnpooledByteBufAllocator.DEFAULT, delegateExecutor(delegate));
                LocalAddress address = new LocalAddress("test-" + SslProvider.OPENSSL + '-' + SslProvider.JDK + '-' + RFC_CIPHER_NAME + '-' + delegate);
                Channel server = server(address, serverSslHandler);
                try {
                    Channel client = client(server, clientSslHandler);
                    try {
                        Throwable clientCause = clientSslHandler.handshakeFuture().await().cause();
                        Throwable serverCause = serverSslHandler.handshakeFuture().await().cause();
                        Assertions.assertNotNull(clientCause);
                        MatcherAssert.assertThat(serverCause, Matchers.instanceOf(SSLHandshakeException.class));
                    } finally {
                        client.close().sync();
                    }
                } finally {
                    server.close().sync();
                }
            } finally {
                ReferenceCountUtil.release(sslClientContext);
            }
        } finally {
            ReferenceCountUtil.release(sslServerContext);
        }
    }

    /**
     * Binds a local-transport server on {@code address}.
     *
     * <p>The server runs on {@link #GROUP}; {@code handler} is installed as the child handler.
     */
    private static Channel server(LocalAddress address, ChannelHandler handler) throws Exception {
        ServerBootstrap bootstrap = new ServerBootstrap()
                .channel(LocalServerChannel.class)
                .group(GROUP)
                .childHandler(handler);
        return bootstrap.bind(address).sync().channel();
    }

    /** Connects a local-transport client on {@link #GROUP} to {@code server}, installing {@code handler}. */
    private static Channel client(Channel server, ChannelHandler handler) throws Exception {
        SocketAddress remoteAddress = server.localAddress();
        Bootstrap bootstrap = new Bootstrap()
                .channel(LocalChannel.class)
                .group(GROUP)
                .handler(handler);
        return bootstrap.connect(remoteAddress).sync().channel();
    }

    /**
     * Exposes a synchronous {@link OpenSslPrivateKeyMethod} as an {@link OpenSslAsyncPrivateKeyMethod}.
     *
     * <p>When {@code newThread} is true, each operation runs on a fresh {@link DelegateThread} after a random
     * 100-500 ms pause. Otherwise it completes inline.
     */
    private static final class OpenSslPrivateKeyMethodAdapter implements OpenSslAsyncPrivateKeyMethod {
        /** A single key-method invocation whose result feeds the returned promise. */
        private interface KeyOperation {
            byte[] call() throws Exception;
        }

        private final OpenSslPrivateKeyMethod keyMethod;
        private final boolean newThread;

        OpenSslPrivateKeyMethodAdapter(OpenSslPrivateKeyMethod keyMethod, boolean newThread) {
            this.keyMethod = keyMethod;
            this.newThread = newThread;
        }

        @Override
        public Future<byte[]> sign(final SSLEngine engine, final int signatureAlgorithm, final byte[] input) {
            return complete(new KeyOperation() {
                @Override
                public byte[] call() throws Exception {
                    return keyMethod.sign(engine, signatureAlgorithm, input);
                }
            });
        }

        @Override
        public Future<byte[]> decrypt(final SSLEngine engine, final byte[] input) {
            return complete(new KeyOperation() {
                @Override
                public byte[] call() throws Exception {
                    return keyMethod.decrypt(engine, input);
                }
            });
        }

        /**
         * Runs {@code operation} inline or on a new {@link DelegateThread}, depending on {@code newThread}.
         *
         * @return a promise completed with the operation's result, or failed with anything it (or the thread
         *         start-up) throws
         */
        private Future<byte[]> complete(final KeyOperation operation) {
            final Promise<byte[]> promise = ImmediateEventExecutor.INSTANCE.newPromise();
            try {
                if (newThread) {
                    new DelegateThread(new Runnable() {
                        @Override
                        public void run() {
                            try {
                                // Random delay so the handshake really has to wait for an asynchronous result.
                                Thread.sleep(ThreadLocalRandom.current().nextLong(100L, 500L));
                                promise.setSuccess(operation.call());
                            } catch (Throwable cause) {
                                promise.setFailure(cause);
                            }
                        }
                    }).start();
                } else {
                    promise.setSuccess(operation.call());
                }
            } catch (Throwable cause) {
                promise.setFailure(cause);
            }
            return promise;
        }
    }

    /** Marker thread type used to detect that a key operation ran on a delegate/executor thread. */
    private static final class DelegateThread extends Thread {
        DelegateThread(Runnable target) {
            super(target);
        }
    }
}

