/*
 * Decompiled with CFR 0.152.
 */
package org.eclipse.jetty.servlet;

import jakarta.servlet.AsyncContext;
import jakarta.servlet.AsyncEvent;
import jakarta.servlet.AsyncListener;
import jakarta.servlet.Servlet;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import org.eclipse.jetty.http.HttpTester;
import org.eclipse.jetty.server.Connector;
import org.eclipse.jetty.server.HandlerContainer;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.servlet.ServletContextHandler;
import org.eclipse.jetty.servlet.ServletHolder;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

public class AsyncContextListenersTest {
    private Server _server;
    private ServerConnector _connector;

    public void prepare(String path, HttpServlet servlet) throws Exception {
        this._server = new Server();
        this._connector = new ServerConnector(this._server);
        this._server.addConnector((Connector)this._connector);
        ServletContextHandler context = new ServletContextHandler((HandlerContainer)this._server, "/", false, false);
        context.addServlet(new ServletHolder((Servlet)servlet), path);
        this._server.start();
    }

    @AfterEach
    public void dispose() throws Exception {
        this._server.stop();
    }

    @Test
    public void testListenerClearedOnSecondRequest() throws Exception {
        final AtomicReference<CountDownLatch> completes = new AtomicReference<CountDownLatch>(new CountDownLatch(1));
        String path = "/path";
        this.prepare(path, new HttpServlet(){

            protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
                AsyncContext asyncContext = request.startAsync((ServletRequest)request, (ServletResponse)response);
                asyncContext.addListener(new AsyncListener(){

                    public void onStartAsync(AsyncEvent event) throws IOException {
                    }

                    public void onComplete(AsyncEvent event) throws IOException {
                        ((CountDownLatch)completes.get()).countDown();
                    }

                    public void onTimeout(AsyncEvent event) throws IOException {
                    }

                    public void onError(AsyncEvent event) throws IOException {
                    }
                });
                asyncContext.complete();
            }
        });
        try (Socket socket = new Socket("localhost", this._connector.getLocalPort());){
            OutputStream output = socket.getOutputStream();
            String request = "GET " + path + " HTTP/1.1\r\nHost: localhost\r\n\r\n";
            output.write(request.getBytes(StandardCharsets.UTF_8));
            output.flush();
            HttpTester.Input input = HttpTester.from((InputStream)socket.getInputStream());
            HttpTester.Response response = HttpTester.parseResponse((HttpTester.Input)input);
            Assertions.assertEquals((int)200, (int)response.getStatus());
            completes.get().await(10L, TimeUnit.SECONDS);
            completes.set(new CountDownLatch(1));
            output.write(request.getBytes(StandardCharsets.UTF_8));
            output.flush();
            response = HttpTester.parseResponse((HttpTester.Input)input);
            Assertions.assertEquals((int)200, (int)response.getStatus());
            completes.get().await(10L, TimeUnit.SECONDS);
        }
    }

    @Test
    public void testListenerAddedFromListener() throws Exception {
        final AtomicReference<CountDownLatch> completes = new AtomicReference<CountDownLatch>(new CountDownLatch(1));
        String path = "/path";
        this.prepare(path, new HttpServlet(){

            protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
                AsyncContext asyncContext = request.startAsync((ServletRequest)request, (ServletResponse)response);
                asyncContext.addListener(new AsyncListener(){

                    public void onStartAsync(AsyncEvent event) throws IOException {
                        event.getAsyncContext().addListener((AsyncListener)this);
                    }

                    public void onComplete(AsyncEvent event) throws IOException {
                        ((CountDownLatch)completes.get()).countDown();
                    }

                    public void onTimeout(AsyncEvent event) throws IOException {
                    }

                    public void onError(AsyncEvent event) throws IOException {
                    }
                });
                asyncContext.complete();
            }
        });
        try (Socket socket = new Socket("localhost", this._connector.getLocalPort());){
            OutputStream output = socket.getOutputStream();
            String request = "GET " + path + " HTTP/1.1\r\nHost: localhost\r\n\r\n";
            output.write(request.getBytes(StandardCharsets.UTF_8));
            output.flush();
            HttpTester.Input input = HttpTester.from((InputStream)socket.getInputStream());
            HttpTester.Response response = HttpTester.parseResponse((HttpTester.Input)input);
            Assertions.assertEquals((int)200, (int)response.getStatus());
            completes.get().await(10L, TimeUnit.SECONDS);
            completes.set(new CountDownLatch(1));
            output.write(request.getBytes(StandardCharsets.UTF_8));
            output.flush();
            response = HttpTester.parseResponse((HttpTester.Input)input);
            Assertions.assertEquals((int)200, (int)response.getStatus());
            completes.get().await(10L, TimeUnit.SECONDS);
        }
    }

    @Test
    public void testAsyncDispatchAsyncCompletePreservesListener() throws Exception {
        final AtomicReference<CountDownLatch> completes = new AtomicReference<CountDownLatch>(new CountDownLatch(1));
        String path = "/path";
        this.prepare("/path/*", new HttpServlet(){

            protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
                String requestURI = request.getRequestURI();
                if (requestURI.endsWith("/one")) {
                    AsyncContext asyncContext = request.startAsync((ServletRequest)request, (ServletResponse)response);
                    asyncContext.addListener(new AsyncListener(){

                        public void onStartAsync(AsyncEvent event) throws IOException {
                            event.getAsyncContext().addListener((AsyncListener)this);
                        }

                        public void onComplete(AsyncEvent event) throws IOException {
                            ((CountDownLatch)completes.get()).countDown();
                        }

                        public void onTimeout(AsyncEvent event) throws IOException {
                        }

                        public void onError(AsyncEvent event) throws IOException {
                        }
                    });
                    asyncContext.dispatch("/path/two");
                } else if (requestURI.endsWith("/two")) {
                    AsyncContext asyncContext = request.startAsync((ServletRequest)request, (ServletResponse)response);
                    asyncContext.complete();
                }
            }
        });
        try (Socket socket = new Socket("localhost", this._connector.getLocalPort());){
            OutputStream output = socket.getOutputStream();
            String request = "GET /path/one HTTP/1.1\r\nHost: localhost\r\n\r\n";
            output.write(request.getBytes(StandardCharsets.UTF_8));
            output.flush();
            HttpTester.Input input = HttpTester.from((InputStream)socket.getInputStream());
            HttpTester.Response response = HttpTester.parseResponse((HttpTester.Input)input);
            Assertions.assertEquals((int)200, (int)response.getStatus());
            completes.get().await(10L, TimeUnit.SECONDS);
            completes.set(new CountDownLatch(1));
            output.write(request.getBytes(StandardCharsets.UTF_8));
            output.flush();
            response = HttpTester.parseResponse((HttpTester.Input)input);
            Assertions.assertEquals((int)200, (int)response.getStatus());
            completes.get().await(10L, TimeUnit.SECONDS);
        }
    }
}

