001package ca.uhn.fhir.test.utilities.server; 002 003/*- 004 * #%L 005 * HAPI FHIR Test Utilities 006 * %% 007 * Copyright (C) 2014 - 2023 Smile CDR, Inc. 008 * %% 009 * Licensed under the Apache License, Version 2.0 (the "License"); 010 * you may not use this file except in compliance with the License. 011 * You may obtain a copy of the License at 012 * 013 * http://www.apache.org/licenses/LICENSE-2.0 014 * 015 * Unless required by applicable law or agreed to in writing, software 016 * distributed under the License is distributed on an "AS IS" BASIS, 017 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 018 * See the License for the specific language governing permissions and 019 * limitations under the License. 020 * #L% 021 */ 022 023import ca.uhn.fhir.rest.api.Constants; 024import ca.uhn.fhir.test.utilities.JettyUtil; 025import org.apache.commons.lang3.Validate; 026import org.apache.http.impl.client.CloseableHttpClient; 027import org.apache.http.impl.client.HttpClientBuilder; 028import org.apache.http.impl.conn.PoolingHttpClientConnectionManager; 029import org.eclipse.jetty.io.Connection; 030import org.eclipse.jetty.io.Connection.Listener; 031import org.eclipse.jetty.server.Connector; 032import org.eclipse.jetty.server.HttpConnectionFactory; 033import org.eclipse.jetty.server.Server; 034import org.eclipse.jetty.server.ServerConnector; 035import org.eclipse.jetty.server.handler.HandlerList; 036import org.eclipse.jetty.servlet.FilterHolder; 037import org.eclipse.jetty.servlet.ServletContextHandler; 038import org.eclipse.jetty.servlet.ServletHolder; 039import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer; 040import org.junit.jupiter.api.extension.AfterAllCallback; 041import org.junit.jupiter.api.extension.AfterEachCallback; 042import org.junit.jupiter.api.extension.BeforeEachCallback; 043import org.junit.jupiter.api.extension.ExtensionContext; 044import org.slf4j.Logger; 045import org.slf4j.LoggerFactory; 046import org.springframework.context.annotation.AnnotatedBeanDefinitionReader; 047import org.springframework.web.context.support.GenericWebApplicationContext; 048import org.springframework.web.servlet.DispatcherServlet; 049import org.springframework.web.socket.config.annotation.WebSocketConfigurer; 050 051import javax.annotation.PreDestroy; 052import javax.servlet.DispatcherType; 053import javax.servlet.Filter; 054import javax.servlet.FilterChain; 055import javax.servlet.FilterConfig; 056import javax.servlet.ServletException; 057import javax.servlet.ServletRequest; 058import javax.servlet.ServletResponse; 059import javax.servlet.http.HttpServlet; 060import javax.servlet.http.HttpServletRequest; 061import java.io.IOException; 062import java.util.ArrayList; 063import java.util.EnumSet; 064import java.util.Enumeration; 065import java.util.List; 066import java.util.concurrent.TimeUnit; 067import java.util.concurrent.atomic.AtomicLong; 068 069import static org.apache.commons.lang3.StringUtils.defaultString; 070import static org.apache.commons.lang3.StringUtils.isNotBlank; 071 072public abstract class BaseJettyServerExtension<T extends BaseJettyServerExtension<?>> implements BeforeEachCallback, AfterEachCallback, AfterAllCallback { 073 private static final Logger ourLog = LoggerFactory.getLogger(BaseJettyServerExtension.class); 074 private final List<List<String>> myRequestHeaders = new ArrayList<>(); 075 private final List<String> myRequestContentTypes = new ArrayList<>(); 076 private String myServletPath = "/*"; 077 private Server myServer; 078 private CloseableHttpClient myHttpClient; 079 private int myPort = 0; 080 private boolean myKeepAliveBetweenTests; 081 private String myContextPath = ""; 082 private AtomicLong myConnectionsOpenedCounter; 083 private Class<? extends WebSocketConfigurer> myEnableSpringWebsocketSupport; 084 private String myEnableSpringWebsocketContextPath; 085 private long myIdleTimeoutMillis = 30000; 086 087 /** 088 * Sets the Jetty server "idle timeout" in millis. This is the amount of time that 089 * the HTTP processor will allow a request to take before it hangs up on the 090 * client. This means the amount of time receiving the request over the network or 091 * streaming the response, not the amount of time spent actually processing the 092 * request (ie this is a network timeout, not a CPU timeout). Default is 093 * 30000. 094 */ 095 public T withIdleTimeout(long theIdleTimeoutMillis) { 096 myIdleTimeoutMillis = theIdleTimeoutMillis; 097 return (T) this; 098 } 099 100 @SuppressWarnings("unchecked") 101 public T withContextPath(String theContextPath) { 102 myContextPath = defaultString(theContextPath); 103 return (T) this; 104 } 105 106 /** 107 * Returns the total number of connections that this server has received. This 108 * is not the current number of open connections, it's the number of new 109 * connections that have been opened at any point. 110 */ 111 public long getConnectionsOpenedCount() { 112 return myConnectionsOpenedCounter.get(); 113 } 114 115 public void resetConnectionsOpenedCount() { 116 myConnectionsOpenedCounter.set(0); 117 } 118 119 public CloseableHttpClient getHttpClient() { 120 return myHttpClient; 121 } 122 123 public List<String> getRequestContentTypes() { 124 return myRequestContentTypes; 125 } 126 127 public List<List<String>> getRequestHeaders() { 128 return myRequestHeaders; 129 } 130 131 @PreDestroy 132 public void stopServer() throws Exception { 133 if (!isRunning()) { 134 return; 135 } 136 JettyUtil.closeServer(myServer); 137 myServer = null; 138 139 myHttpClient.close(); 140 myHttpClient = null; 141 } 142 143 protected void startServer() throws Exception { 144 if (isRunning()) { 145 return; 146 } 147 148 myServer = new Server(); 149 myConnectionsOpenedCounter = new AtomicLong(0); 150 151 ServerConnector connector = new ServerConnector(myServer); 152 connector.setIdleTimeout(myIdleTimeoutMillis); 153 connector.setPort(myPort); 154 myServer.setConnectors(new Connector[]{connector}); 155 156 HttpConnectionFactory connectionFactory = (HttpConnectionFactory) connector.getConnectionFactories().iterator().next(); 157 connectionFactory.addBean(new Listener() { 158 @Override 159 public void onOpened(Connection connection) { 160 myConnectionsOpenedCounter.incrementAndGet(); 161 } 162 163 @Override 164 public void onClosed(Connection connection) { 165 // nothing 166 } 167 }); 168 169 ServletHolder servletHolder = new ServletHolder(provideServlet()); 170 171 HandlerList handlerList = new HandlerList(); 172 173 ServletContextHandler contextHandler = new ServletContextHandler(); 174 contextHandler.setContextPath(myContextPath); 175 contextHandler.addServlet(servletHolder, myServletPath); 176 contextHandler.addFilter(new FilterHolder(requestCapturingFilter()), "/*", EnumSet.allOf(DispatcherType.class)); 177 handlerList.addHandler(contextHandler); 178 179 if (myEnableSpringWebsocketSupport != null) { 180 181 GenericWebApplicationContext wac = new GenericWebApplicationContext(); 182 wac.setParent(SpringContextGrabbingTestExecutionListener.getApplicationContext()); 183 AnnotatedBeanDefinitionReader reader = new AnnotatedBeanDefinitionReader(wac); 184 reader.register(myEnableSpringWebsocketSupport); 185 186 DispatcherServlet dispatcherServlet = new DispatcherServlet(); 187 dispatcherServlet.setApplicationContext(wac); 188 ServletHolder subsServletHolder = new ServletHolder(); 189 subsServletHolder.setServlet(dispatcherServlet); 190 191 ServletContextHandler servletContextHandler = new ServletContextHandler(); 192 servletContextHandler.setContextPath(myEnableSpringWebsocketContextPath); 193 servletContextHandler.setAllowNullPathInfo(true); 194 servletContextHandler.addServlet(new ServletHolder(dispatcherServlet), "/*"); 195 JettyWebSocketServletContainerInitializer.configure(servletContextHandler, null); 196 197 handlerList.addHandler(servletContextHandler); 198 } 199 200 myServer.setHandler(handlerList); 201 myServer.start(); 202 203 myPort = JettyUtil.getPortForStartedServer(myServer); 204 ourLog.info("Server has started on port {}", myPort); 205 PoolingHttpClientConnectionManager connectionManager = new PoolingHttpClientConnectionManager(5000, TimeUnit.MILLISECONDS); 206 HttpClientBuilder builder = HttpClientBuilder.create(); 207 builder.setConnectionManager(connectionManager); 208 myHttpClient = builder.build(); 209 } 210 211 private Filter requestCapturingFilter() { 212 return new RequestCapturingFilter(); 213 } 214 215 public int getPort() { 216 return myPort; 217 } 218 219 protected abstract HttpServlet provideServlet(); 220 221 222 public String getWebsocketContextPath() { 223 return myEnableSpringWebsocketContextPath; 224 } 225 226 /** 227 * Should be in the format <code>/the/path/*</code> 228 */ 229 @SuppressWarnings("unchecked") 230 public T withServletPath(String theServletPath) { 231 Validate.isTrue(theServletPath.startsWith("/"), "Servlet path should start with /"); 232 Validate.isTrue(theServletPath.endsWith("/*"), "Servlet path should end with /*"); 233 myServletPath = theServletPath; 234 return (T) this; 235 } 236 237 @SuppressWarnings("unchecked") 238 public T withPort(int thePort) { 239 myPort = thePort; 240 return (T) this; 241 } 242 243 @SuppressWarnings("unchecked") 244 public T keepAliveBetweenTests() { 245 myKeepAliveBetweenTests = true; 246 return (T) this; 247 } 248 249 protected boolean isRunning() { 250 return myServer != null; 251 } 252 253 /** 254 * Returns the server base URL with no trailing slash 255 */ 256 public String getBaseUrl() { 257 return "http://localhost:" + myPort + myContextPath + myServletPath.substring(0, myServletPath.length() - 2); 258 } 259 260 @Override 261 public void beforeEach(ExtensionContext context) throws Exception { 262 startServer(); 263 myRequestContentTypes.clear(); 264 myRequestHeaders.clear(); 265 } 266 267 @Override 268 public void afterEach(ExtensionContext context) throws Exception { 269 if (!myKeepAliveBetweenTests) { 270 stopServer(); 271 } 272 } 273 274 @Override 275 public void afterAll(ExtensionContext context) throws Exception { 276 stopServer(); 277 } 278 279 /** 280 * To use this method, you need to add the following to your 281 * test class: 282 * <code>@TestExecutionListeners(value = SpringContextGrabbingTestExecutionListener.class, mergeMode = TestExecutionListeners.MergeMode.MERGE_WITH_DEFAULTS)</code> 283 */ 284 @SuppressWarnings("unchecked") 285 public T withSpringWebsocketSupport(String theContextPath, Class<? extends WebSocketConfigurer> theContextConfigClass) { 286 assert !isRunning(); 287 assert theContextConfigClass != null; 288 myEnableSpringWebsocketSupport = theContextConfigClass; 289 myEnableSpringWebsocketContextPath = theContextPath; 290 return (T) this; 291 } 292 293 private class RequestCapturingFilter implements Filter { 294 @Override 295 public void init(FilterConfig filterConfig) throws ServletException { 296 // nothing 297 } 298 299 @Override 300 public void doFilter(ServletRequest theRequest, ServletResponse theResponse, FilterChain theChain) throws IOException, ServletException { 301 HttpServletRequest request = (HttpServletRequest) theRequest; 302 303 String header = request.getHeader(Constants.HEADER_CONTENT_TYPE); 304 if (isNotBlank(header)) { 305 myRequestContentTypes.add(header.replaceAll(";.*", "")); 306 } else { 307 myRequestContentTypes.add(null); 308 } 309 310 java.util.Enumeration<String> headerNamesEnum = request.getHeaderNames(); 311 List<String> requestHeaders = new ArrayList<>(); 312 myRequestHeaders.add(requestHeaders); 313 while (headerNamesEnum.hasMoreElements()) { 314 String nextName = headerNamesEnum.nextElement(); 315 Enumeration<String> valueEnum = request.getHeaders(nextName); 316 while (valueEnum.hasMoreElements()) { 317 String nextValue = valueEnum.nextElement(); 318 requestHeaders.add(nextName + ": " + nextValue); 319 } 320 } 321 322 theChain.doFilter(theRequest, theResponse); 323 } 324 325 @Override 326 public void destroy() { 327 // nothing 328 } 329 } 330}