001package ca.uhn.fhir.test.utilities.server;
002
003/*-
004 * #%L
005 * HAPI FHIR Test Utilities
006 * %%
007 * Copyright (C) 2014 - 2023 Smile CDR, Inc.
008 * %%
009 * Licensed under the Apache License, Version 2.0 (the "License");
010 * you may not use this file except in compliance with the License.
011 * You may obtain a copy of the License at
012 *
013 *      http://www.apache.org/licenses/LICENSE-2.0
014 *
015 * Unless required by applicable law or agreed to in writing, software
016 * distributed under the License is distributed on an "AS IS" BASIS,
017 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
018 * See the License for the specific language governing permissions and
019 * limitations under the License.
020 * #L%
021 */
022
023import ca.uhn.fhir.rest.annotation.Transaction;
024import ca.uhn.fhir.rest.annotation.TransactionParam;
025import org.hl7.fhir.instance.model.api.IBaseBundle;
026import org.junit.jupiter.api.extension.AfterEachCallback;
027import org.junit.jupiter.api.extension.BeforeEachCallback;
028import org.junit.jupiter.api.extension.ExtensionContext;
029import org.slf4j.Logger;
030import org.slf4j.LoggerFactory;
031
032import java.util.ArrayList;
033import java.util.Collections;
034import java.util.List;
035
036import static org.awaitility.Awaitility.await;
037import static org.hamcrest.MatcherAssert.assertThat;
038import static org.hamcrest.Matchers.equalTo;
039import static org.hamcrest.Matchers.greaterThanOrEqualTo;
040
041public class TransactionCapturingProviderExtension<T extends IBaseBundle> implements BeforeEachCallback, AfterEachCallback {
042
043        private static final Logger ourLog = LoggerFactory.getLogger(TransactionCapturingProviderExtension.class);
044        private final RestfulServerExtension myRestfulServerExtension;
045        private final List<T> myInputBundles = Collections.synchronizedList(new ArrayList<>());
046        private PlainProvider myProvider;
047
048        /**
049         * Constructor
050         */
051        public TransactionCapturingProviderExtension(RestfulServerExtension theRestfulServerExtension, Class<T> theBundleType) {
052                myRestfulServerExtension = theRestfulServerExtension;
053        }
054
055        @Override
056        public void afterEach(ExtensionContext context) throws Exception {
057                myProvider = new PlainProvider();
058                myRestfulServerExtension.getRestfulServer().unregisterProvider(myProvider);
059        }
060
061        @Override
062        public void beforeEach(ExtensionContext context) throws Exception {
063                myRestfulServerExtension.getRestfulServer().registerProvider(myProvider);
064                myInputBundles.clear();
065        }
066
067        public void waitForTransactionCount(int theCount) {
068                assertThat(theCount, greaterThanOrEqualTo(myInputBundles.size()));
069                await().until(()->myInputBundles.size(), equalTo(theCount));
070        }
071
072        public List<T> getTransactions() {
073                return Collections.unmodifiableList(myInputBundles);
074        }
075
076        private class PlainProvider {
077
078                @Transaction
079                public T transaction(@TransactionParam T theInput) {
080                        ourLog.info("Received transaction update");
081                        myInputBundles.add(theInput);
082                        return theInput;
083                }
084
085        }
086
087
088}