001package ca.uhn.test.concurrency; 002 003/*- 004 * #%L 005 * HAPI FHIR Test Utilities 006 * %% 007 * Copyright (C) 2014 - 2023 Smile CDR, Inc. 008 * %% 009 * Licensed under the Apache License, Version 2.0 (the "License"); 010 * you may not use this file except in compliance with the License. 011 * You may obtain a copy of the License at 012 * 013 * http://www.apache.org/licenses/LICENSE-2.0 014 * 015 * Unless required by applicable law or agreed to in writing, software 016 * distributed under the License is distributed on an "AS IS" BASIS, 017 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 018 * See the License for the specific language governing permissions and 019 * limitations under the License. 020 * #L% 021 */ 022 023 024import ca.uhn.fhir.i18n.Msg; 025import ca.uhn.fhir.interceptor.api.HookParams; 026import ca.uhn.fhir.interceptor.api.IAnonymousInterceptor; 027import ca.uhn.fhir.interceptor.api.IPointcut; 028import com.google.common.collect.ListMultimap; 029import org.apache.commons.lang3.Validate; 030import org.apache.commons.lang3.builder.ToStringBuilder; 031import org.apache.commons.lang3.exception.ExceptionUtils; 032import org.slf4j.Logger; 033import org.slf4j.LoggerFactory; 034 035import java.util.ArrayList; 036import java.util.Collections; 037import java.util.List; 038import java.util.concurrent.CountDownLatch; 039import java.util.concurrent.TimeUnit; 040import java.util.concurrent.atomic.AtomicLong; 041import java.util.concurrent.atomic.AtomicReference; 042import java.util.stream.Collectors; 043 044// This class is primarily used for testing. 045public class PointcutLatch implements IAnonymousInterceptor, IPointcutLatch { 046 private static final Logger ourLog = LoggerFactory.getLogger(PointcutLatch.class); 047 private static final int DEFAULT_TIMEOUT_SECONDS = 10; 048 private static final FhirObjectPrinter ourFhirObjectToStringMapper = new FhirObjectPrinter(); 049 050 private final String myName; 051 private final AtomicLong myLastInvoke = new AtomicLong(); 052 private final AtomicReference<CountDownLatch> myCountdownLatch = new AtomicReference<>(); 053 private final AtomicReference<String> myCountdownLatchSetStacktrace = new AtomicReference<>(); 054 private final AtomicReference<List<String>> myFailures = new AtomicReference<>(); 055 private final AtomicReference<List<HookParams>> myCalledWith = new AtomicReference<>(); 056 private final IPointcut myPointcut; 057 private int myDefaultTimeoutSeconds = DEFAULT_TIMEOUT_SECONDS; 058 private int myInitialCount; 059 private boolean myExactMatch; 060 061 public PointcutLatch(IPointcut thePointcut) { 062 this.myName = thePointcut.name(); 063 myPointcut = thePointcut; 064 } 065 066 067 public PointcutLatch(String theName) { 068 this.myName = theName; 069 myPointcut = null; 070 } 071 072 public void runWithExpectedCount(int theExpectedCount, Runnable r) throws InterruptedException { 073 this.setExpectedCount(theExpectedCount); 074 r.run(); 075 this.awaitExpected(); 076 } 077 078 public long getLastInvoke() { 079 return myLastInvoke.get(); 080 } 081 082 // Useful for debugging when you need more time to step through a method 083 public PointcutLatch setDefaultTimeoutSeconds(int theDefaultTimeoutSeconds) { 084 myDefaultTimeoutSeconds = theDefaultTimeoutSeconds; 085 return this; 086 } 087 088 @Override 089 public void setExpectedCount(int theCount) { 090 this.setExpectedCount(theCount, true); 091 } 092 093 public void setExpectedCount(int theCount, boolean theExactMatch) { 094 if (myCountdownLatch.get() != null) { 095 String previousStack = myCountdownLatchSetStacktrace.get(); 096 throw new PointcutLatchException(Msg.code(1480) + "setExpectedCount() called before previous awaitExpected() completed. Previous set stack:\n" + previousStack); 097 } 098 myExactMatch = theExactMatch; 099 createLatch(theCount); 100 if (theExactMatch) { 101 ourLog.info("Expecting exactly {} calls to {} latch", theCount, myName); 102 } else { 103 ourLog.info("Expecting at least {} calls to {} latch", theCount, myName); 104 } 105 } 106 107 public void setExpectAtLeast(int theCount) { 108 setExpectedCount(theCount, false); 109 } 110 111 public boolean isSet() { 112 return myCountdownLatch.get() != null; 113 } 114 115 private void createLatch(int theCount) { 116 myFailures.set(Collections.synchronizedList(new ArrayList<>())); 117 myCalledWith.set(Collections.synchronizedList(new ArrayList<>())); 118 myCountdownLatch.set(new CountDownLatch(theCount)); 119 try { 120 throw new Exception(Msg.code(1481)); 121 } catch (Exception e) { 122 myCountdownLatchSetStacktrace.set(ExceptionUtils.getStackTrace(e)); 123 } 124 myInitialCount = theCount; 125 } 126 127 private void addFailure(String failure) { 128 if (myFailures.get() != null) { 129 myFailures.get().add(failure); 130 } else { 131 throw new PointcutLatchException(Msg.code(1482) + "trying to set failure on latch that hasn't been created: " + failure); 132 } 133 } 134 135 private String getName() { 136 return myName + " " + this.getClass().getSimpleName(); 137 } 138 139 @Override 140 public List<HookParams> awaitExpected() throws InterruptedException { 141 return awaitExpectedWithTimeout(myDefaultTimeoutSeconds); 142 } 143 144 public List<HookParams> awaitExpectedWithTimeout(int timeoutSecond) throws InterruptedException { 145 List<HookParams> retval = myCalledWith.get(); 146 try { 147 CountDownLatch latch = myCountdownLatch.get(); 148 Validate.notNull(latch, getName() + " awaitExpected() called before setExpected() called."); 149 if (!latch.await(timeoutSecond, TimeUnit.SECONDS)) { 150 throw new LatchTimedOutError(Msg.code(1483) + getName() + " timed out waiting " + timeoutSecond + " seconds for latch to countdown from " + myInitialCount + " to 0. Is " + latch.getCount() + "."); 151 } 152 153 // Defend against ConcurrentModificationException 154 String error = getName(); 155 if (myFailures.get() != null && myFailures.get().size() > 0) { 156 List<String> failures = new ArrayList<>(myFailures.get()); 157 if (failures.size() > 1) { 158 error += " ERRORS: \n"; 159 } else { 160 error += " ERROR: "; 161 } 162 error += String.join("\n", failures); 163 error += "\nLatch called with values: " + toCalledWithString(); 164 throw new AssertionError(Msg.code(1484) + error); 165 } 166 } finally { 167 clear(); 168 } 169 Validate.isTrue(retval.equals(myCalledWith.get()), "Concurrency error: Latch switched while waiting."); 170 return retval; 171 } 172 173 @Override 174 public void clear() { 175 myCountdownLatch.set(null); 176 myCountdownLatchSetStacktrace.set(null); 177 } 178 179 private String toCalledWithString() { 180 if (myCalledWith.get() == null) { 181 return "[]"; 182 } 183 // Defend against ConcurrentModificationException 184 List<HookParams> calledWith = new ArrayList<>(myCalledWith.get()); 185 if (calledWith.isEmpty()) { 186 return "[]"; 187 } 188 String retVal = "[ "; 189 retVal += calledWith.stream().flatMap(hookParams -> hookParams.values().stream()).map(ourFhirObjectToStringMapper).collect(Collectors.joining(", ")); 190 return retVal + " ]"; 191 } 192 193 @Override 194 public void invoke(IPointcut thePointcut, HookParams theArgs) { 195 myLastInvoke.set(System.currentTimeMillis()); 196 197 CountDownLatch latch = myCountdownLatch.get(); 198 if (myExactMatch) { 199 if (latch == null) { 200 throw new PointcutLatchException(Msg.code(1485) + "invoke() for " + myName + " called outside of setExpectedCount() .. awaitExpected(). Probably got more invocations than expected or clear() was called before invoke() arrived with args: " + theArgs, theArgs); 201 } else if (latch.getCount() <= 0) { 202 addFailure("invoke() called when countdown was zero."); 203 } 204 } else if (latch == null || latch.getCount() <= 0) { 205 return; 206 } 207 208 if (myCalledWith.get() != null) { 209 myCalledWith.get().add(theArgs); 210 } 211 ourLog.debug("Called {} {} with {}", myName, latch, hookParamsToString(theArgs)); 212 213 latch.countDown(); 214 } 215 216 public void call(Object arg) { 217 this.invoke(myPointcut, new HookParams(arg)); 218 } 219 220 @Override 221 public String toString() { 222 return new ToStringBuilder(this) 223 .append("name", myName) 224 .append("myCountdownLatch", myCountdownLatch) 225// .append("myFailures", myFailures) 226// .append("myCalledWith", myCalledWith) 227 .append("myInitialCount", myInitialCount) 228 .toString(); 229 } 230 231 public Object getLatchInvocationParameter() { 232 return getLatchInvocationParameter(myCalledWith.get()); 233 } 234 235 @SuppressWarnings("unchecked") 236 public <T> T getLatchInvocationParameterOfType(Class<T> theType) { 237 List<HookParams> hookParamsList = myCalledWith.get(); 238 Validate.notNull(hookParamsList); 239 Validate.isTrue(hookParamsList.size() == 1, "Expected Pointcut to be invoked 1 time"); 240 HookParams hookParams = hookParamsList.get(0); 241 ListMultimap<Class<?>, Object> paramsForType = hookParams.getParamsForType(); 242 List<Object> objects = paramsForType.get(theType); 243 Validate.isTrue(objects.size() == 1); 244 return (T) objects.get(0); 245 } 246 247 248 private class PointcutLatchException extends IllegalStateException { 249 private static final long serialVersionUID = 1372636272233536829L; 250 251 PointcutLatchException(String message, HookParams theArgs) { 252 super(getName() + ": " + message + " called with values: " + hookParamsToString(theArgs)); 253 } 254 255 public PointcutLatchException(String message) { 256 super(getName() + ": " + message); 257 } 258 } 259 260 private static String hookParamsToString(HookParams hookParams) { 261 return hookParams.values().stream().map(ourFhirObjectToStringMapper).collect(Collectors.joining(", ")); 262 } 263 264 public static Object getLatchInvocationParameter(List<HookParams> theHookParams) { 265 Validate.notNull(theHookParams); 266 Validate.isTrue(theHookParams.size() == 1, "Expected Pointcut to be invoked 1 time"); 267 return getLatchInvocationParameter(theHookParams, 0); 268 } 269 270 public static Object getLatchInvocationParameter(List<HookParams> theHookParams, int index) { 271 Validate.notNull(theHookParams); 272 HookParams arg = theHookParams.get(index); 273 Validate.isTrue(arg.values().size() == 1, "Expected pointcut to be invoked with 1 argument"); 274 return arg.values().iterator().next(); 275 } 276}