001package ca.uhn.test.concurrency;
002
003/*-
004 * #%L
005 * HAPI FHIR Test Utilities
006 * %%
007 * Copyright (C) 2014 - 2023 Smile CDR, Inc.
008 * %%
009 * Licensed under the Apache License, Version 2.0 (the "License");
010 * you may not use this file except in compliance with the License.
011 * You may obtain a copy of the License at
012 *
013 *      http://www.apache.org/licenses/LICENSE-2.0
014 *
015 * Unless required by applicable law or agreed to in writing, software
016 * distributed under the License is distributed on an "AS IS" BASIS,
017 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
018 * See the License for the specific language governing permissions and
019 * limitations under the License.
020 * #L%
021 */
022
023
024import ca.uhn.fhir.i18n.Msg;
025import ca.uhn.fhir.interceptor.api.HookParams;
026import ca.uhn.fhir.interceptor.api.IAnonymousInterceptor;
027import ca.uhn.fhir.interceptor.api.IPointcut;
028import com.google.common.collect.ListMultimap;
029import org.apache.commons.lang3.Validate;
030import org.apache.commons.lang3.builder.ToStringBuilder;
031import org.apache.commons.lang3.exception.ExceptionUtils;
032import org.slf4j.Logger;
033import org.slf4j.LoggerFactory;
034
035import java.util.ArrayList;
036import java.util.Collections;
037import java.util.List;
038import java.util.concurrent.CountDownLatch;
039import java.util.concurrent.TimeUnit;
040import java.util.concurrent.atomic.AtomicLong;
041import java.util.concurrent.atomic.AtomicReference;
042import java.util.stream.Collectors;
043
044// This class is primarily used for testing.
045public class PointcutLatch implements IAnonymousInterceptor, IPointcutLatch {
046        private static final Logger ourLog = LoggerFactory.getLogger(PointcutLatch.class);
047        private static final int DEFAULT_TIMEOUT_SECONDS = 10;
048        private static final FhirObjectPrinter ourFhirObjectToStringMapper = new FhirObjectPrinter();
049
050        private final String myName;
051        private final AtomicLong myLastInvoke = new AtomicLong();
052        private final AtomicReference<CountDownLatch> myCountdownLatch = new AtomicReference<>();
053        private final AtomicReference<String> myCountdownLatchSetStacktrace = new AtomicReference<>();
054        private final AtomicReference<List<String>> myFailures = new AtomicReference<>();
055        private final AtomicReference<List<HookParams>> myCalledWith = new AtomicReference<>();
056        private final IPointcut myPointcut;
057        private int myDefaultTimeoutSeconds = DEFAULT_TIMEOUT_SECONDS;
058        private int myInitialCount;
059        private boolean myExactMatch;
060
061        public PointcutLatch(IPointcut thePointcut) {
062                this.myName = thePointcut.name();
063                myPointcut = thePointcut;
064        }
065
066
067        public PointcutLatch(String theName) {
068                this.myName = theName;
069                myPointcut = null;
070        }
071
072        public void runWithExpectedCount(int theExpectedCount, Runnable r) throws InterruptedException {
073                this.setExpectedCount(theExpectedCount);
074                r.run();
075                this.awaitExpected();
076        }
077
078        public long getLastInvoke() {
079                return myLastInvoke.get();
080        }
081
082        // Useful for debugging when you need more time to step through a method
083        public PointcutLatch setDefaultTimeoutSeconds(int theDefaultTimeoutSeconds) {
084                myDefaultTimeoutSeconds = theDefaultTimeoutSeconds;
085                return this;
086        }
087
088        @Override
089        public void setExpectedCount(int theCount) {
090                this.setExpectedCount(theCount, true);
091        }
092
093        public void setExpectedCount(int theCount, boolean theExactMatch) {
094                if (myCountdownLatch.get() != null) {
095                        String previousStack = myCountdownLatchSetStacktrace.get();
096                        throw new PointcutLatchException(Msg.code(1480) + "setExpectedCount() called before previous awaitExpected() completed. Previous set stack:\n" + previousStack);
097                }
098                myExactMatch = theExactMatch;
099                createLatch(theCount);
100                if (theExactMatch) {
101                        ourLog.info("Expecting exactly {} calls to {} latch", theCount, myName);
102                } else {
103                        ourLog.info("Expecting at least {} calls to {} latch", theCount, myName);
104                }
105        }
106
107        public void setExpectAtLeast(int theCount) {
108                setExpectedCount(theCount, false);
109        }
110
111        public boolean isSet() {
112                return myCountdownLatch.get() != null;
113        }
114
115        private void createLatch(int theCount) {
116                myFailures.set(Collections.synchronizedList(new ArrayList<>()));
117                myCalledWith.set(Collections.synchronizedList(new ArrayList<>()));
118                myCountdownLatch.set(new CountDownLatch(theCount));
119                try {
120                        throw new Exception(Msg.code(1481));
121                } catch (Exception e) {
122                        myCountdownLatchSetStacktrace.set(ExceptionUtils.getStackTrace(e));
123                }
124                myInitialCount = theCount;
125        }
126
127        private void addFailure(String failure) {
128                if (myFailures.get() != null) {
129                        myFailures.get().add(failure);
130                } else {
131                        throw new PointcutLatchException(Msg.code(1482) + "trying to set failure on latch that hasn't been created: " + failure);
132                }
133        }
134
135        private String getName() {
136                return myName + " " + this.getClass().getSimpleName();
137        }
138
139        @Override
140        public List<HookParams> awaitExpected() throws InterruptedException {
141                return awaitExpectedWithTimeout(myDefaultTimeoutSeconds);
142        }
143
144        public List<HookParams> awaitExpectedWithTimeout(int timeoutSecond) throws InterruptedException {
145                List<HookParams> retval = myCalledWith.get();
146                try {
147                        CountDownLatch latch = myCountdownLatch.get();
148                        Validate.notNull(latch, getName() + " awaitExpected() called before setExpected() called.");
149                        if (!latch.await(timeoutSecond, TimeUnit.SECONDS)) {
150                                throw new LatchTimedOutError(Msg.code(1483) + getName() + " timed out waiting " + timeoutSecond + " seconds for latch to countdown from " + myInitialCount + " to 0.  Is " + latch.getCount() + ".");
151                        }
152
153                        // Defend against ConcurrentModificationException
154                        String error = getName();
155                        if (myFailures.get() != null && myFailures.get().size() > 0) {
156                                List<String> failures = new ArrayList<>(myFailures.get());
157                                if (failures.size() > 1) {
158                                        error += " ERRORS: \n";
159                                } else {
160                                        error += " ERROR: ";
161                                }
162                                error += String.join("\n", failures);
163                                error += "\nLatch called with values: " + toCalledWithString();
164                                throw new AssertionError(Msg.code(1484) + error);
165                        }
166                } finally {
167                        clear();
168                }
169                Validate.isTrue(retval.equals(myCalledWith.get()), "Concurrency error: Latch switched while waiting.");
170                return retval;
171        }
172
173        @Override
174        public void clear() {
175                myCountdownLatch.set(null);
176                myCountdownLatchSetStacktrace.set(null);
177        }
178
179        private String toCalledWithString() {
180                if (myCalledWith.get() == null) {
181                        return "[]";
182                }
183                // Defend against ConcurrentModificationException
184                List<HookParams> calledWith = new ArrayList<>(myCalledWith.get());
185                if (calledWith.isEmpty()) {
186                        return "[]";
187                }
188                String retVal = "[ ";
189                retVal += calledWith.stream().flatMap(hookParams -> hookParams.values().stream()).map(ourFhirObjectToStringMapper).collect(Collectors.joining(", "));
190                return retVal + " ]";
191        }
192
193        @Override
194        public void invoke(IPointcut thePointcut, HookParams theArgs) {
195                myLastInvoke.set(System.currentTimeMillis());
196                
197                CountDownLatch latch = myCountdownLatch.get();
198                if (myExactMatch) {
199                        if (latch == null) {
200                                throw new PointcutLatchException(Msg.code(1485) + "invoke() for " + myName + " called outside of setExpectedCount() .. awaitExpected().  Probably got more invocations than expected or clear() was called before invoke() arrived with args: " + theArgs, theArgs);
201                        } else if (latch.getCount() <= 0) {
202                                addFailure("invoke() called when countdown was zero.");
203                        }
204                } else if (latch == null || latch.getCount() <= 0) {
205                        return;
206                }
207
208                if (myCalledWith.get() != null) {
209                        myCalledWith.get().add(theArgs);
210                }
211                ourLog.debug("Called {} {} with {}", myName, latch, hookParamsToString(theArgs));
212
213                latch.countDown();
214        }
215
216        public void call(Object arg) {
217                this.invoke(myPointcut, new HookParams(arg));
218        }
219
220        @Override
221        public String toString() {
222                return new ToStringBuilder(this)
223                        .append("name", myName)
224                        .append("myCountdownLatch", myCountdownLatch)
225//                      .append("myFailures", myFailures)
226//                      .append("myCalledWith", myCalledWith)
227                        .append("myInitialCount", myInitialCount)
228                        .toString();
229        }
230
231        public Object getLatchInvocationParameter() {
232                return getLatchInvocationParameter(myCalledWith.get());
233        }
234
235        @SuppressWarnings("unchecked")
236        public <T> T getLatchInvocationParameterOfType(Class<T> theType) {
237                List<HookParams> hookParamsList = myCalledWith.get();
238                Validate.notNull(hookParamsList);
239                Validate.isTrue(hookParamsList.size() == 1, "Expected Pointcut to be invoked 1 time");
240                HookParams hookParams = hookParamsList.get(0);
241                ListMultimap<Class<?>, Object> paramsForType = hookParams.getParamsForType();
242                List<Object> objects = paramsForType.get(theType);
243                Validate.isTrue(objects.size() == 1);
244                return (T) objects.get(0);
245        }
246
247
248        private class PointcutLatchException extends IllegalStateException {
249                private static final long serialVersionUID = 1372636272233536829L;
250
251                PointcutLatchException(String message, HookParams theArgs) {
252                        super(getName() + ": " + message + " called with values: " + hookParamsToString(theArgs));
253                }
254
255                public PointcutLatchException(String message) {
256                        super(getName() + ": " + message);
257                }
258        }
259
260        private static String hookParamsToString(HookParams hookParams) {
261                return hookParams.values().stream().map(ourFhirObjectToStringMapper).collect(Collectors.joining(", "));
262        }
263
264        public static Object getLatchInvocationParameter(List<HookParams> theHookParams) {
265                Validate.notNull(theHookParams);
266                Validate.isTrue(theHookParams.size() == 1, "Expected Pointcut to be invoked 1 time");
267                return getLatchInvocationParameter(theHookParams, 0);
268        }
269
270        public static Object getLatchInvocationParameter(List<HookParams> theHookParams, int index) {
271                Validate.notNull(theHookParams);
272                HookParams arg = theHookParams.get(index);
273                Validate.isTrue(arg.values().size() == 1, "Expected pointcut to be invoked with 1 argument");
274                return arg.values().iterator().next();
275        }
276}