001package com.unitils.boot;
002
003import org.springframework.beans.BeansException;
004import org.springframework.context.ApplicationContext;
005import org.springframework.transaction.PlatformTransactionManager;
006import org.unitils.core.Module;
007import org.unitils.core.TestListener;
008import org.unitils.core.Unitils;
009import org.unitils.core.UnitilsException;
010import org.unitils.database.DatabaseModule;
011import org.unitils.database.annotations.Transactional;
012import org.unitils.database.transaction.impl.UnitilsTransactionManagementConfiguration;
013import org.unitils.spring.annotation.*;
014import org.unitils.spring.enums.LoadTime;
015import org.unitils.spring.util.ApplicationContextFactory;
016import org.unitils.spring.util.ApplicationContextManager;
017import org.unitils.util.AnnotationUtils;
018import org.unitils.util.ReflectionUtils;
019
020import java.lang.reflect.Field;
021import java.lang.reflect.InvocationTargetException;
022import java.lang.reflect.Method;
023import java.util.Map;
024import java.util.Properties;
025import java.util.Set;
026
027import static org.apache.commons.lang.StringUtils.isEmpty;
028import static org.unitils.util.AnnotationUtils.*;
029import static org.unitils.util.PropertyUtils.getInstance;
030import static org.unitils.util.ReflectionUtils.*;
031
032/**
033 * @Author: yangjianzhou
034 * @Description:
035 * @Date:Created in 2018-07-08
036 */
037public class SpringBootModule implements Module {
038
039    /* Property key of the class name of the application context factory */
040    public static final String PROPKEY_APPLICATION_CONTEXT_FACTORY_CLASS_NAME = "SpringModule.applicationContextFactory.implClassName";
041
042    /* Manager for storing and creating org.unitils.spring application contexts */
043    private ApplicationContextManager applicationContextManager;
044
045    private  static ApplicationContext applicationContext;
046
047    public static void setApplicationContext(ApplicationContext applicationContext) {
048        SpringBootModule.applicationContext = applicationContext;
049    }
050
051    /**
052     * Initializes this module using the given configuration
053     *
054     * @param configuration The configuration, not null
055     */
056    public void init(Properties configuration) {
057        // create application context manager that stores and creates the application contexts
058        ApplicationContextFactory applicationContextFactory = getInstance(PROPKEY_APPLICATION_CONTEXT_FACTORY_CLASS_NAME, configuration);
059        applicationContextManager = new ApplicationContextManager(applicationContextFactory);
060    }
061
062
063    /**
064     * No after initialization needed for this module
065     */
066    public void afterInit() {
067        // Make sure that, if a custom transaction manager is configured in the org.unitils.spring ApplicationContext associated with
068        // the current test, it is used for managing transactions.
069        if (isDatabaseModuleEnabled()) {
070            getDatabaseModule().registerTransactionManagementConfiguration(new UnitilsTransactionManagementConfiguration() {
071
072                public boolean isApplicableFor(Object testObject) {
073                    if (!isApplicationContextConfiguredFor(testObject)) {
074                        return false;
075                    }
076                    ApplicationContext context = getApplicationContext(testObject);
077                    return context.getBeansOfType(getPlatformTransactionManagerClass()).size() != 0;
078                }
079
080                @SuppressWarnings("unchecked")
081                public PlatformTransactionManager getSpringPlatformTransactionManager(Object testObject) {
082                    ApplicationContext context = getApplicationContext(testObject);
083                    Class<?> platformTransactionManagerClass = getPlatformTransactionManagerClass();
084                    Map<String, PlatformTransactionManager> platformTransactionManagers = (Map<String, PlatformTransactionManager>) context.getBeansOfType(platformTransactionManagerClass);
085                    if (platformTransactionManagers.size() == 0) {
086                        throw new UnitilsException("Could not find a model of type " + platformTransactionManagerClass.getSimpleName()
087                                + " in the org.unitils.spring ApplicationContext for this class");
088                    }
089                    if (platformTransactionManagers.size() > 1) {
090                        Method testMethod = Unitils.getInstance().getTestContext().getTestMethod();
091                        String transactionManagerName = getMethodOrClassLevelAnnotationProperty(Transactional.class, "transactionManagerName", "",
092                                testMethod, testObject.getClass());
093                        if (isEmpty(transactionManagerName))
094                            throw new UnitilsException("Found more than one model of type " + platformTransactionManagerClass.getSimpleName()
095                                    + " in the org.unitils.spring ApplicationContext for this class. Use the transactionManagerName on the @Transactional"
096                                    + " annotation to select the correct one.");
097                        if (!platformTransactionManagers.containsKey(transactionManagerName))
098                            throw new UnitilsException("No model of type " + platformTransactionManagerClass.getSimpleName()
099                                    + " found in the org.unitils.spring ApplicationContext with the name " + transactionManagerName);
100                        return platformTransactionManagers.get(transactionManagerName);
101                    }
102                    return platformTransactionManagers.values().iterator().next();
103                }
104
105                public boolean isTransactionalResourceAvailable(Object testObject) {
106                    return true;
107                }
108
109                public Integer getPreference() {
110                    return 20;
111                }
112
113                protected Class<?> getPlatformTransactionManagerClass() {
114                    return ReflectionUtils.getClassWithName("org.springframework.transaction.PlatformTransactionManager");
115                }
116
117            });
118        }
119    }
120
121
122    /**
123     * Gets the org.unitils.spring model with the given name. The given test instance, by using {@link SpringApplicationContext},
124     * determines the application context in which to look for the model.
125     * <p/>
126     * A UnitilsException is thrown when the no model could be found for the given name.
127     *
128     * @param testObject The test instance, not null
129     * @param name       The name, not null
130     * @return The model, not null
131     */
132    public Object getSpringBean(Object testObject, String name) {
133        try {
134            return getApplicationContext(testObject).getBean(name);
135
136        } catch (BeansException e) {
137            throw new UnitilsException("Unable to get Spring model. No Spring model found for name " + name);
138        }
139    }
140
141    /**
142     * Gets the org.unitils.spring model with the given type. The given test instance, by using {@link SpringApplicationContext},
143     * determines the application context in which to look for the model.
144     * If more there is not exactly 1 possible model assignment, an UnitilsException will be thrown.
145     *
146     * @param testObject The test instance, not null
147     * @param type       The type, not null
148     * @return The model, not null
149     */
150    public <T> T getSpringBeanByType(Object testObject, Class<T> type) {
151        Map<String, T> beans = getApplicationContext(testObject).getBeansOfType(type);
152        if (beans == null || beans.size() == 0) {
153            throw new UnitilsException("Unable to get Spring model by type. No Spring model found for type " + type.getSimpleName());
154        }
155        if (beans.size() > 1) {
156            throw new UnitilsException("Unable to get Spring model by type. More than one possible Spring model for type " + type.getSimpleName() + ". Possible beans; " + beans);
157        }
158        return beans.values().iterator().next();
159    }
160
161    /**
162     * @param testObject The test object
163     * @return Whether an ApplicationContext has been configured for the given testObject
164     */
165    public boolean isApplicationContextConfiguredFor(Object testObject) {
166        return applicationContextManager.hasApplicationContext(testObject);
167    }
168
169
170    /**
171     * Gets the application context for this test. A new one will be created if it does not exist yet. If a superclass
172     * has also declared the creation of an application context, this one will be retrieved (or created if it was not
173     * created yet) and used as parent context for this classes context.
174     * <p/>
175     * If needed, an application context will be created using the settings of the {@link SpringApplicationContext}
176     * annotation.
177     * <p/>
178     * If a class level {@link SpringApplicationContext} annotation is found, the passed locations will be loaded using
179     * a <code>ClassPathXmlApplicationContext</code>.
180     * Custom creation methods can be created by annotating them with {@link SpringApplicationContext}. They
181     * should have an <code>ApplicationContext</code> as return type and either no or exactly 1 argument of type
182     * <code>ApplicationContext</code>. In the latter case, the current configured application context is passed as the argument.
183     * <p/>
184     * A UnitilsException will be thrown if no context could be retrieved or created.
185     *
186      * @param testObject
187     * @return
188     */
189    public ApplicationContext getApplicationContext(Object testObject) {
190
191        if (applicationContext == null) {
192            applicationContext = applicationContextManager.getApplicationContext(testObject);
193        }
194        return applicationContext;
195    }
196
197
198    /**
199     * Forces the reloading of the application context the next time that it is requested. If classes are given
200     * only contexts that are linked to those classes will be reset. If no classes are given, all cached
201     * contexts will be reset.
202     *
203     * @param classes The classes for which to reset the contexts
204     */
205    public void invalidateApplicationContext(Class<?>... classes) {
206        applicationContextManager.invalidateApplicationContext(classes);
207        applicationContext = null;
208    }
209
210
211    /**
212     * Gets the application context for this class and sets it on the fields and setter methods that are
213     * annotated with {@link SpringApplicationContext}. If no application context could be created, an
214     * UnitilsException will be raised.
215     *
216     * @param testObject The test instance, not null
217     */
218    public void injectApplicationContext(Object testObject) {
219
220
221        // inject into fields annotated with @SpringApplicationContext
222        Set<Field> fields = getFieldsAnnotatedWith(testObject.getClass(), SpringApplicationContext.class);
223        for (Field field : fields) {
224            try {
225                setFieldValue(testObject, field, getApplicationContext(testObject));
226            } catch (UnitilsException e) {
227                throw new UnitilsException("Unable to assign the application context to field annotated with @" + SpringApplicationContext.class.getSimpleName(), e);
228            }
229        }
230
231        // inject into setter methods annotated with @SpringApplicationContext
232        Set<Method> methods = getMethodsAnnotatedWith(testObject.getClass(), SpringApplicationContext.class, false);
233        for (Method method : methods) {
234            // ignore custom create methods
235            if (method.getReturnType() != Void.TYPE) {
236                continue;
237            }
238            try {
239                invokeMethod(testObject, method, getApplicationContext(testObject));
240
241            } catch (Exception e) {
242                throw new UnitilsException("Unable to assign the application context to setter annotated with @" + SpringApplicationContext.class.getSimpleName(), e);
243            }
244        }
245    }
246
247
248    /**
249     * Injects org.unitils.spring beans into all fields that are annotated with {@link SpringBean}.
250     *
251     * @param testObject The test instance, not null
252     */
253    public void injectSpringBeans(Object testObject) {
254        // assign to fields
255        Set<Field> fields = getFieldsAnnotatedWith(testObject.getClass(), SpringBean.class);
256        for (Field field : fields) {
257            try {
258                SpringBean springBeanAnnotation = field.getAnnotation(SpringBean.class);
259                setFieldValue(testObject, field, getSpringBean(testObject, springBeanAnnotation.value()));
260
261            } catch (UnitilsException e) {
262                throw new UnitilsException("Unable to assign the Spring model value to field annotated with @" + SpringBean.class.getSimpleName(), e);
263            }
264        }
265
266        // assign to setters
267        Set<Method> methods = getMethodsAnnotatedWith(testObject.getClass(), SpringBean.class);
268        for (Method method : methods) {
269            try {
270                if (!isSetter(method)) {
271                    throw new UnitilsException("Unable to assign the Spring model value to method annotated with @" + SpringBean.class.getSimpleName() + ". Method " +
272                            method.getName() + " is not a setter method.");
273                }
274                SpringBean springBeanAnnotation = method.getAnnotation(SpringBean.class);
275                invokeMethod(testObject, method, getSpringBean(testObject, springBeanAnnotation.value()));
276
277            } catch (UnitilsException e) {
278                throw new UnitilsException("Unable to assign the Spring model value to method annotated with @" + SpringBean.class.getSimpleName(), e);
279            } catch (InvocationTargetException e) {
280                throw new UnitilsException("Unable to assign the Spring model value to method annotated with @" + SpringBean.class.getSimpleName() + ". Method " +
281                        "has thrown an exception.", e.getCause());
282            }
283        }
284    }
285
286
287    /**
288     * Injects org.unitils.spring beans into all fields methods that are annotated with {@link SpringBeanByType}.
289     *
290     * @param testObject The test instance, not null
291     */
292    public void injectSpringBeansByType(Object testObject) {
293        // assign to fields
294        Set<Field> fields = getFieldsAnnotatedWith(testObject.getClass(), SpringBeanByType.class);
295        for (Field field : fields) {
296            try {
297                setFieldValue(testObject, field, getSpringBeanByType(testObject, field.getType()));
298
299            } catch (UnitilsException e) {
300                throw new UnitilsException("Unable to assign the Spring model value to field annotated with @" + SpringBeanByType.class.getSimpleName(), e);
301            }
302        }
303
304        // assign to setters
305        Set<Method> methods = getMethodsAnnotatedWith(testObject.getClass(), SpringBeanByType.class);
306        for (Method method : methods) {
307            try {
308                if (!isSetter(method)) {
309                    throw new UnitilsException("Unable to assign the Spring model value to method annotated with @" + SpringBeanByType.class.getSimpleName() + ". Method " +
310                            method.getName() + " is not a setter method.");
311                }
312                invokeMethod(testObject, method, getSpringBeanByType(testObject, method.getParameterTypes()[0]));
313
314            } catch (UnitilsException e) {
315                throw new UnitilsException("Unable to assign the Spring model value to method annotated with @" + SpringBeanByType.class.getSimpleName(), e);
316            } catch (InvocationTargetException e) {
317                throw new UnitilsException("Unable to assign the Spring model value to method annotated with @" + SpringBeanByType.class.getSimpleName() + ". Method " +
318                        "has thrown an exception.", e.getCause());
319            }
320        }
321    }
322
323
324    /**
325     * Injects org.unitils.spring beans into all fields that are annotated with {@link SpringBeanByName}.
326     *
327     * @param testObject The test instance, not null
328     */
329    public void injectSpringBeansByName(Object testObject) {
330        // assign to fields
331        Set<Field> fields = getFieldsAnnotatedWith(testObject.getClass(), SpringBeanByName.class);
332        for (Field field : fields) {
333            try {
334                setFieldValue(testObject, field, getSpringBean(testObject, field.getName()));
335
336            } catch (UnitilsException e) {
337                throw new UnitilsException("Unable to assign the Spring model value to field annotated with @" + SpringBeanByName.class.getSimpleName(), e);
338            }
339        }
340
341        // assign to setters
342        Set<Method> methods = getMethodsAnnotatedWith(testObject.getClass(), SpringBeanByName.class);
343        for (Method method : methods) {
344            try {
345                if (!isSetter(method)) {
346                    throw new UnitilsException("Unable to assign the Spring model value to method annotated with @" + SpringBeanByName.class.getSimpleName() + ". Method " +
347                            method.getName() + " is not a setter method.");
348                }
349                invokeMethod(testObject, method, getSpringBean(testObject, getPropertyName(method)));
350
351            } catch (UnitilsException e) {
352                throw new UnitilsException("Unable to assign the Spring model value to method annotated with @" + SpringBeanByName.class.getSimpleName(), e);
353            } catch (InvocationTargetException e) {
354                throw new UnitilsException("Unable to assign the Spring model value to method annotated with @" + SpringBeanByName.class.getSimpleName() + ". Method " +
355                        "has thrown an exception.", e.getCause());
356            }
357        }
358    }
359
360    protected void closeApplicationContextIfNeeded(Object testObject) {
361        if (this.isApplicationContextConfiguredFor(testObject)) {
362            this.invalidateApplicationContext(testObject.getClass());
363        }
364    }
365
366    protected boolean isDatabaseModuleEnabled() {
367        return Unitils.getInstance().getModulesRepository().isModuleEnabled(DatabaseModule.class);
368    }
369
370
371    protected DatabaseModule getDatabaseModule() {
372        return Unitils.getInstance().getModulesRepository().getModuleOfType(DatabaseModule.class);
373    }
374
375    public void initialize(Object testObject) {
376        injectApplicationContext(testObject);
377        injectSpringBeans(testObject);
378        injectSpringBeansByType(testObject);
379        injectSpringBeansByName(testObject);
380    }
381
382
383    /**
384     * @return The {@link TestListener} for this module
385     */
386    public TestListener getTestListener() {
387        return new SpringTestListener();
388    }
389
390    public LoadTime findLoadTime(Class<?> clzz) {
391        LoadOn loadOnAnnotation = AnnotationUtils.getClassLevelAnnotation(LoadOn.class, clzz);
392        if (loadOnAnnotation == null) {
393            return LoadTime.METHOD;
394        } else {
395            return loadOnAnnotation.load();
396        }
397
398    }
399
400    /**
401     * The {@link TestListener} for this module
402     */
403    protected class SpringTestListener extends TestListener {
404
405        @Override
406        public void beforeTestSetUp(Object testObject, Method testMethod) {
407            if (findLoadTime(testObject.getClass()) == LoadTime.METHOD || applicationContext == null) {
408                initialize(testObject);
409            }
410        }
411
412        /**
413         * @see TestListener#afterTestTearDown(Object, Method)
414         */
415        @Override
416        public void afterTestTearDown(Object testObject, Method testMethod) {
417            if (findLoadTime(testObject.getClass()) == LoadTime.METHOD) {
418                closeApplicationContextIfNeeded(testObject);
419            }
420
421        }
422
423    }
424
425}