001package ca.uhn.fhir.test.utilities.server;
002
003/*-
004 * #%L
005 * HAPI FHIR Test Utilities
006 * %%
007 * Copyright (C) 2014 - 2023 Smile CDR, Inc.
008 * %%
009 * Licensed under the Apache License, Version 2.0 (the "License");
010 * you may not use this file except in compliance with the License.
011 * You may obtain a copy of the License at
012 *
013 *      http://www.apache.org/licenses/LICENSE-2.0
014 *
015 * Unless required by applicable law or agreed to in writing, software
016 * distributed under the License is distributed on an "AS IS" BASIS,
017 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
018 * See the License for the specific language governing permissions and
019 * limitations under the License.
020 * #L%
021 */
022
023import ca.uhn.fhir.rest.api.Constants;
024import ca.uhn.fhir.test.utilities.JettyUtil;
025import org.apache.commons.lang3.Validate;
026import org.apache.http.impl.client.CloseableHttpClient;
027import org.apache.http.impl.client.HttpClientBuilder;
028import org.apache.http.impl.conn.PoolingHttpClientConnectionManager;
029import org.eclipse.jetty.io.Connection;
030import org.eclipse.jetty.io.Connection.Listener;
031import org.eclipse.jetty.server.Connector;
032import org.eclipse.jetty.server.HttpConnectionFactory;
033import org.eclipse.jetty.server.Server;
034import org.eclipse.jetty.server.ServerConnector;
035import org.eclipse.jetty.server.handler.HandlerList;
036import org.eclipse.jetty.servlet.FilterHolder;
037import org.eclipse.jetty.servlet.ServletContextHandler;
038import org.eclipse.jetty.servlet.ServletHolder;
039import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
040import org.junit.jupiter.api.extension.AfterAllCallback;
041import org.junit.jupiter.api.extension.AfterEachCallback;
042import org.junit.jupiter.api.extension.BeforeEachCallback;
043import org.junit.jupiter.api.extension.ExtensionContext;
044import org.slf4j.Logger;
045import org.slf4j.LoggerFactory;
046import org.springframework.context.annotation.AnnotatedBeanDefinitionReader;
047import org.springframework.web.context.support.GenericWebApplicationContext;
048import org.springframework.web.servlet.DispatcherServlet;
049import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
050
051import javax.annotation.PreDestroy;
052import javax.servlet.DispatcherType;
053import javax.servlet.Filter;
054import javax.servlet.FilterChain;
055import javax.servlet.FilterConfig;
056import javax.servlet.ServletException;
057import javax.servlet.ServletRequest;
058import javax.servlet.ServletResponse;
059import javax.servlet.http.HttpServlet;
060import javax.servlet.http.HttpServletRequest;
061import java.io.IOException;
062import java.util.ArrayList;
063import java.util.EnumSet;
064import java.util.Enumeration;
065import java.util.List;
066import java.util.concurrent.TimeUnit;
067import java.util.concurrent.atomic.AtomicLong;
068
069import static org.apache.commons.lang3.StringUtils.defaultString;
070import static org.apache.commons.lang3.StringUtils.isNotBlank;
071
072public abstract class BaseJettyServerExtension<T extends BaseJettyServerExtension<?>> implements BeforeEachCallback, AfterEachCallback, AfterAllCallback {
073        private static final Logger ourLog = LoggerFactory.getLogger(BaseJettyServerExtension.class);
074        private final List<List<String>> myRequestHeaders = new ArrayList<>();
075        private final List<String> myRequestContentTypes = new ArrayList<>();
076        private String myServletPath = "/*";
077        private Server myServer;
078        private CloseableHttpClient myHttpClient;
079        private int myPort = 0;
080        private boolean myKeepAliveBetweenTests;
081        private String myContextPath = "";
082        private AtomicLong myConnectionsOpenedCounter;
083        private Class<? extends WebSocketConfigurer> myEnableSpringWebsocketSupport;
084        private String myEnableSpringWebsocketContextPath;
085        private long myIdleTimeoutMillis = 30000;
086
087        /**
088         * Sets the Jetty server "idle timeout" in millis. This is the amount of time that
089         * the HTTP processor will allow a request to take before it hangs up on the
090         * client. This means the amount of time receiving the request over the network or
091         * streaming the response, not the amount of time spent actually processing the
092         * request (ie this is a network timeout, not a CPU timeout). Default is
093         * 30000.
094         */
095        public T withIdleTimeout(long theIdleTimeoutMillis) {
096                myIdleTimeoutMillis = theIdleTimeoutMillis;
097                return (T) this;
098        }
099
100        @SuppressWarnings("unchecked")
101        public T withContextPath(String theContextPath) {
102                myContextPath = defaultString(theContextPath);
103                return (T) this;
104        }
105
106        /**
107         * Returns the total number of connections that this server has received. This
108         * is not the current number of open connections, it's the number of new
109         * connections that have been opened at any point.
110         */
111        public long getConnectionsOpenedCount() {
112                return myConnectionsOpenedCounter.get();
113        }
114
115        public void resetConnectionsOpenedCount() {
116                myConnectionsOpenedCounter.set(0);
117        }
118
119        public CloseableHttpClient getHttpClient() {
120                return myHttpClient;
121        }
122
123        public List<String> getRequestContentTypes() {
124                return myRequestContentTypes;
125        }
126
127        public List<List<String>> getRequestHeaders() {
128                return myRequestHeaders;
129        }
130
131        @PreDestroy
132        public void stopServer() throws Exception {
133                if (!isRunning()) {
134                        return;
135                }
136                JettyUtil.closeServer(myServer);
137                myServer = null;
138
139                myHttpClient.close();
140                myHttpClient = null;
141        }
142
143        protected void startServer() throws Exception {
144                if (isRunning()) {
145                        return;
146                }
147
148                myServer = new Server();
149                myConnectionsOpenedCounter = new AtomicLong(0);
150
151                ServerConnector connector = new ServerConnector(myServer);
152                connector.setIdleTimeout(myIdleTimeoutMillis);
153                connector.setPort(myPort);
154                myServer.setConnectors(new Connector[]{connector});
155
156                HttpConnectionFactory connectionFactory = (HttpConnectionFactory) connector.getConnectionFactories().iterator().next();
157                connectionFactory.addBean(new Listener() {
158                        @Override
159                        public void onOpened(Connection connection) {
160                                myConnectionsOpenedCounter.incrementAndGet();
161                        }
162
163                        @Override
164                        public void onClosed(Connection connection) {
165                                // nothing
166                        }
167                });
168
169                ServletHolder servletHolder = new ServletHolder(provideServlet());
170
171                HandlerList handlerList = new HandlerList();
172
173                ServletContextHandler contextHandler = new ServletContextHandler();
174                contextHandler.setContextPath(myContextPath);
175                contextHandler.addServlet(servletHolder, myServletPath);
176                contextHandler.addFilter(new FilterHolder(requestCapturingFilter()), "/*", EnumSet.allOf(DispatcherType.class));
177                handlerList.addHandler(contextHandler);
178
179                if (myEnableSpringWebsocketSupport != null) {
180
181                        GenericWebApplicationContext wac = new GenericWebApplicationContext();
182                        wac.setParent(SpringContextGrabbingTestExecutionListener.getApplicationContext());
183                        AnnotatedBeanDefinitionReader reader = new AnnotatedBeanDefinitionReader(wac);
184                        reader.register(myEnableSpringWebsocketSupport);
185
186                        DispatcherServlet dispatcherServlet = new DispatcherServlet();
187                        dispatcherServlet.setApplicationContext(wac);
188                        ServletHolder subsServletHolder = new ServletHolder();
189                        subsServletHolder.setServlet(dispatcherServlet);
190
191                        ServletContextHandler servletContextHandler = new ServletContextHandler();
192                        servletContextHandler.setContextPath(myEnableSpringWebsocketContextPath);
193                        servletContextHandler.setAllowNullPathInfo(true);
194                        servletContextHandler.addServlet(new ServletHolder(dispatcherServlet), "/*");
195                        JettyWebSocketServletContainerInitializer.configure(servletContextHandler, null);
196
197                        handlerList.addHandler(servletContextHandler);
198                }
199
200                myServer.setHandler(handlerList);
201                myServer.start();
202
203                myPort = JettyUtil.getPortForStartedServer(myServer);
204                ourLog.info("Server has started on port {}", myPort);
205                PoolingHttpClientConnectionManager connectionManager = new PoolingHttpClientConnectionManager(5000, TimeUnit.MILLISECONDS);
206                HttpClientBuilder builder = HttpClientBuilder.create();
207                builder.setConnectionManager(connectionManager);
208                myHttpClient = builder.build();
209        }
210
211        private Filter requestCapturingFilter() {
212                return new RequestCapturingFilter();
213        }
214
215        public int getPort() {
216                return myPort;
217        }
218
219        protected abstract HttpServlet provideServlet();
220
221
222        public String getWebsocketContextPath() {
223                return myEnableSpringWebsocketContextPath;
224        }
225
226        /**
227         * Should be in the format <code>/the/path/*</code>
228         */
229        @SuppressWarnings("unchecked")
230        public T withServletPath(String theServletPath) {
231                Validate.isTrue(theServletPath.startsWith("/"), "Servlet path should start with /");
232                Validate.isTrue(theServletPath.endsWith("/*"), "Servlet path should end with /*");
233                myServletPath = theServletPath;
234                return (T) this;
235        }
236
237        @SuppressWarnings("unchecked")
238        public T withPort(int thePort) {
239                myPort = thePort;
240                return (T) this;
241        }
242
243        @SuppressWarnings("unchecked")
244        public T keepAliveBetweenTests() {
245                myKeepAliveBetweenTests = true;
246                return (T) this;
247        }
248
249        protected boolean isRunning() {
250                return myServer != null;
251        }
252
253        /**
254         * Returns the server base URL with no trailing slash
255         */
256        public String getBaseUrl() {
257                return "http://localhost:" + myPort + myContextPath + myServletPath.substring(0, myServletPath.length() - 2);
258        }
259
260        @Override
261        public void beforeEach(ExtensionContext context) throws Exception {
262                startServer();
263                myRequestContentTypes.clear();
264                myRequestHeaders.clear();
265        }
266
267        @Override
268        public void afterEach(ExtensionContext context) throws Exception {
269                if (!myKeepAliveBetweenTests) {
270                        stopServer();
271                }
272        }
273
274        @Override
275        public void afterAll(ExtensionContext context) throws Exception {
276                stopServer();
277        }
278
279        /**
280         * To use this method, you need to add the following to your
281         * test class:
282         * <code>@TestExecutionListeners(value = SpringContextGrabbingTestExecutionListener.class, mergeMode = TestExecutionListeners.MergeMode.MERGE_WITH_DEFAULTS)</code>
283         */
284        @SuppressWarnings("unchecked")
285        public T withSpringWebsocketSupport(String theContextPath, Class<? extends WebSocketConfigurer> theContextConfigClass) {
286                assert !isRunning();
287                assert theContextConfigClass != null;
288                myEnableSpringWebsocketSupport = theContextConfigClass;
289                myEnableSpringWebsocketContextPath = theContextPath;
290                return (T) this;
291        }
292
293        private class RequestCapturingFilter implements Filter {
294                @Override
295                public void init(FilterConfig filterConfig) throws ServletException {
296                        // nothing
297                }
298
299                @Override
300                public void doFilter(ServletRequest theRequest, ServletResponse theResponse, FilterChain theChain) throws IOException, ServletException {
301                        HttpServletRequest request = (HttpServletRequest) theRequest;
302
303                        String header = request.getHeader(Constants.HEADER_CONTENT_TYPE);
304                        if (isNotBlank(header)) {
305                                myRequestContentTypes.add(header.replaceAll(";.*", ""));
306                        } else {
307                                myRequestContentTypes.add(null);
308                        }
309
310                        java.util.Enumeration<String> headerNamesEnum = request.getHeaderNames();
311                        List<String> requestHeaders = new ArrayList<>();
312                        myRequestHeaders.add(requestHeaders);
313                        while (headerNamesEnum.hasMoreElements()) {
314                                String nextName = headerNamesEnum.nextElement();
315                                Enumeration<String> valueEnum = request.getHeaders(nextName);
316                                while (valueEnum.hasMoreElements()) {
317                                        String nextValue = valueEnum.nextElement();
318                                        requestHeaders.add(nextName + ": " + nextValue);
319                                }
320                        }
321
322                        theChain.doFilter(theRequest, theResponse);
323                }
324
325                @Override
326                public void destroy() {
327                        // nothing
328                }
329        }
330}