Introduce MethodInvoker API for TestExecutionListeners

In order to be able to support parameter injection in
@​BeforeTransaction and @​AfterTransaction methods (see gh-30736), this
commit introduces a MethodInvoker API for TestExecutionListeners as a
generic mechanism for delegating to the underlying testing framework to
invoke methods.

The default implementation simply invokes the method without arguments,
which allows TestExecutionListeners using this mechanism to operate
correctly when the underlying testing framework is JUnit 4, TestNG, etc.

A JUnit Jupiter specific implementation is registered in the
SpringExtension which delegates to the
ExtensionContext.getExecutableInvoker() mechanism introduced in JUnit
Jupiter 5.9. This allows a TestExecutionListener to transparently
benefit from registered ParameterResolvers in JUnit Jupiter (including
the SpringExtension) when invoking user methods, effectively providing
support for parameter injection for arbitrary methods.

Closes gh-31199
This commit is contained in:
Sam Brannen 2023-09-10 14:34:49 +02:00
parent 0ebdd8cb98
commit 41904d46ad
7 changed files with 330 additions and 89 deletions

View File

@ -0,0 +1,62 @@
/*
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.test.context;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.ReflectionUtils;
/**
* Default implementation of the {@link MethodInvoker} API.
*
* <p>This implementation never provides arguments to a {@link Method}.
*
* @author Sam Brannen
* @since 6.1
*/
final class DefaultMethodInvoker implements MethodInvoker {
private static final Log logger = LogFactory.getLog(DefaultMethodInvoker.class);
@Override
@Nullable
public Object invoke(Method method, @Nullable Object target) throws Exception {
Assert.notNull(method, "Method must not be null");
try {
ReflectionUtils.makeAccessible(method);
return method.invoke(target);
}
catch (InvocationTargetException ex) {
if (logger.isErrorEnabled()) {
logger.error("Exception encountered while invoking method [%s] on target [%s]"
.formatted(method, target), ex.getTargetException());
}
ReflectionUtils.rethrowException(ex.getTargetException());
// appease the compiler
return null;
}
}
}

View File

@ -0,0 +1,72 @@
/*
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.test.context;
import java.lang.reflect.Method;
import org.springframework.lang.Nullable;
/**
* {@code MethodInvoker} defines a generic API for invoking a {@link Method}
* within the <em>Spring TestContext Framework</em>.
*
* <p>Specifically, a {@code MethodInvoker} is made available to a
* {@link TestExecutionListener} via {@link TestContext#getMethodInvoker()}, and
* a {@code TestExecutionListener} can use the invoker to transparently benefit
* from any special method invocation features of the underlying testing framework.
*
* <p>For example, when the underlying testing framework is JUnit Jupiter, a
* {@code TestExecutionListener} can use a {@code MethodInvoker} to invoke
* arbitrary methods with JUnit Jupiter's
* {@linkplain org.junit.jupiter.api.extension.ExecutableInvoker parameter resolution
* mechanism}. For other testing frameworks, the {@link #DEFAULT_INVOKER} will be
* used.
*
* @author Sam Brannen
* @since 6.1
* @see org.junit.jupiter.api.extension.ExecutableInvoker
* @see org.springframework.util.MethodInvoker
*/
public interface MethodInvoker {
/**
* Shared instance of the default {@link MethodInvoker}.
* <p>This invoker never provides arguments to a {@link Method}.
*/
static final MethodInvoker DEFAULT_INVOKER = new DefaultMethodInvoker();
/**
* Invoke the supplied {@link Method} on the supplied {@code target}.
* <p>When the {@link #DEFAULT_INVOKER} is used &mdash; for example, when
* the underlying testing framework is JUnit 4 or TestNG &mdash; the method
* must not declare any formal parameters. When the underlying testing
* framework is JUnit Jupiter, parameters will be dynamically resolved via
* registered {@link org.junit.jupiter.api.extension.ParameterResolver
* ParameterResolvers} (such as the
* {@link org.springframework.test.context.junit.jupiter.SpringExtension
* SpringExtension}).
* @param method the method to invoke
* @param target the object on which to invoke the method, may be {@code null}
* if the method is {@code static}
* @return the value returned from the method invocation, potentially {@code null}
* @throws Exception if any error occurs
*/
@Nullable
Object invoke(Method method, @Nullable Object target) throws Exception;
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -38,6 +38,9 @@ import org.springframework.test.annotation.DirtiesContext.HierarchyMode;
* that does not provide a copy constructor will likely fail in an environment
* that executes tests concurrently.
*
* <p>As of Spring Framework 6.1, concrete implementations are highly encouraged to
* override {@link #setMethodInvoker(MethodInvoker)} and {@link #getMethodInvoker()}.
*
* @author Sam Brannen
* @since 2.5
* @see TestContextManager
@ -150,4 +153,28 @@ public interface TestContext extends AttributeAccessor, Serializable {
*/
void updateState(@Nullable Object testInstance, @Nullable Method testMethod, @Nullable Throwable testException);
/**
* Set the {@link MethodInvoker} to use.
* <p>By default, this method does nothing.
* <p>Concrete implementations should track the supplied {@code MethodInvoker}
* and return it from {@link #getMethodInvoker()}. Note that the standard
* {@code TestContext} implementation in Spring overrides this method appropriately.
* @since 6.1
*/
default void setMethodInvoker(MethodInvoker methodInvoker) {
/* no-op */
}
/**
* Get the {@link MethodInvoker} to use.
* <p>By default, this method returns {@link MethodInvoker#DEFAULT_INVOKER}.
* <p>Concrete implementations should return the {@code MethodInvoker} supplied
* to {@link #setMethodInvoker(MethodInvoker)}. Note that the standard
* {@code TestContext} implementation in Spring overrides this method appropriately.
* @since 6.1
*/
default MethodInvoker getMethodInvoker() {
return MethodInvoker.DEFAULT_INVOKER;
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -195,21 +195,26 @@ public class TestContextManager {
* @see #getTestExecutionListeners()
*/
public void beforeTestClass() throws Exception {
Class<?> testClass = getTestContext().getTestClass();
if (logger.isTraceEnabled()) {
logger.trace("beforeTestClass(): class [" + typeName(testClass) + "]");
}
getTestContext().updateState(null, null, null);
try {
Class<?> testClass = getTestContext().getTestClass();
if (logger.isTraceEnabled()) {
logger.trace("beforeTestClass(): class [" + typeName(testClass) + "]");
}
getTestContext().updateState(null, null, null);
for (TestExecutionListener testExecutionListener : getTestExecutionListeners()) {
try {
testExecutionListener.beforeTestClass(getTestContext());
}
catch (Throwable ex) {
logException(ex, "beforeTestClass", testExecutionListener, testClass);
ReflectionUtils.rethrowException(ex);
for (TestExecutionListener testExecutionListener : getTestExecutionListeners()) {
try {
testExecutionListener.beforeTestClass(getTestContext());
}
catch (Throwable ex) {
logException(ex, "beforeTestClass", testExecutionListener, testClass);
ReflectionUtils.rethrowException(ex);
}
}
}
finally {
resetMethodInvoker();
}
}
/**
@ -231,25 +236,30 @@ public class TestContextManager {
* @see #getTestExecutionListeners()
*/
public void prepareTestInstance(Object testInstance) throws Exception {
if (logger.isTraceEnabled()) {
logger.trace("prepareTestInstance(): instance [" + testInstance + "]");
}
getTestContext().updateState(testInstance, null, null);
for (TestExecutionListener testExecutionListener : getTestExecutionListeners()) {
try {
testExecutionListener.prepareTestInstance(getTestContext());
try {
if (logger.isTraceEnabled()) {
logger.trace("prepareTestInstance(): instance [" + testInstance + "]");
}
catch (Throwable ex) {
if (logger.isErrorEnabled()) {
logger.error("""
getTestContext().updateState(testInstance, null, null);
for (TestExecutionListener testExecutionListener : getTestExecutionListeners()) {
try {
testExecutionListener.prepareTestInstance(getTestContext());
}
catch (Throwable ex) {
if (logger.isErrorEnabled()) {
logger.error("""
Caught exception while allowing TestExecutionListener [%s] to \
prepare test instance [%s]"""
.formatted(typeName(testExecutionListener), testInstance), ex);
.formatted(typeName(testExecutionListener), testInstance), ex);
}
ReflectionUtils.rethrowException(ex);
}
ReflectionUtils.rethrowException(ex);
}
}
finally {
resetMethodInvoker();
}
}
/**
@ -280,17 +290,22 @@ public class TestContextManager {
* @see #getTestExecutionListeners()
*/
public void beforeTestMethod(Object testInstance, Method testMethod) throws Exception {
String callbackName = "beforeTestMethod";
prepareForBeforeCallback(callbackName, testInstance, testMethod);
try {
String callbackName = "beforeTestMethod";
prepareForBeforeCallback(callbackName, testInstance, testMethod);
for (TestExecutionListener testExecutionListener : getTestExecutionListeners()) {
try {
testExecutionListener.beforeTestMethod(getTestContext());
}
catch (Throwable ex) {
handleBeforeException(ex, callbackName, testExecutionListener, testInstance, testMethod);
for (TestExecutionListener testExecutionListener : getTestExecutionListeners()) {
try {
testExecutionListener.beforeTestMethod(getTestContext());
}
catch (Throwable ex) {
handleBeforeException(ex, callbackName, testExecutionListener, testInstance, testMethod);
}
}
}
finally {
resetMethodInvoker();
}
}
/**
@ -319,17 +334,22 @@ public class TestContextManager {
* @see #getTestExecutionListeners()
*/
public void beforeTestExecution(Object testInstance, Method testMethod) throws Exception {
String callbackName = "beforeTestExecution";
prepareForBeforeCallback(callbackName, testInstance, testMethod);
try {
String callbackName = "beforeTestExecution";
prepareForBeforeCallback(callbackName, testInstance, testMethod);
for (TestExecutionListener testExecutionListener : getTestExecutionListeners()) {
try {
testExecutionListener.beforeTestExecution(getTestContext());
}
catch (Throwable ex) {
handleBeforeException(ex, callbackName, testExecutionListener, testInstance, testMethod);
for (TestExecutionListener testExecutionListener : getTestExecutionListeners()) {
try {
testExecutionListener.beforeTestExecution(getTestContext());
}
catch (Throwable ex) {
handleBeforeException(ex, callbackName, testExecutionListener, testInstance, testMethod);
}
}
}
finally {
resetMethodInvoker();
}
}
/**
@ -367,29 +387,34 @@ public class TestContextManager {
public void afterTestExecution(Object testInstance, Method testMethod, @Nullable Throwable exception)
throws Exception {
String callbackName = "afterTestExecution";
prepareForAfterCallback(callbackName, testInstance, testMethod, exception);
Throwable afterTestExecutionException = null;
try {
String callbackName = "afterTestExecution";
prepareForAfterCallback(callbackName, testInstance, testMethod, exception);
Throwable afterTestExecutionException = null;
// Traverse the TestExecutionListeners in reverse order to ensure proper
// "wrapper"-style execution of listeners.
for (TestExecutionListener testExecutionListener : getReversedTestExecutionListeners()) {
try {
testExecutionListener.afterTestExecution(getTestContext());
// Traverse the TestExecutionListeners in reverse order to ensure proper
// "wrapper"-style execution of listeners.
for (TestExecutionListener testExecutionListener : getReversedTestExecutionListeners()) {
try {
testExecutionListener.afterTestExecution(getTestContext());
}
catch (Throwable ex) {
logException(ex, callbackName, testExecutionListener, testInstance, testMethod);
if (afterTestExecutionException == null) {
afterTestExecutionException = ex;
}
else {
afterTestExecutionException.addSuppressed(ex);
}
}
}
catch (Throwable ex) {
logException(ex, callbackName, testExecutionListener, testInstance, testMethod);
if (afterTestExecutionException == null) {
afterTestExecutionException = ex;
}
else {
afterTestExecutionException.addSuppressed(ex);
}
if (afterTestExecutionException != null) {
ReflectionUtils.rethrowException(afterTestExecutionException);
}
}
if (afterTestExecutionException != null) {
ReflectionUtils.rethrowException(afterTestExecutionException);
finally {
resetMethodInvoker();
}
}
@ -429,29 +454,34 @@ public class TestContextManager {
public void afterTestMethod(Object testInstance, Method testMethod, @Nullable Throwable exception)
throws Exception {
String callbackName = "afterTestMethod";
prepareForAfterCallback(callbackName, testInstance, testMethod, exception);
Throwable afterTestMethodException = null;
try {
String callbackName = "afterTestMethod";
prepareForAfterCallback(callbackName, testInstance, testMethod, exception);
Throwable afterTestMethodException = null;
// Traverse the TestExecutionListeners in reverse order to ensure proper
// "wrapper"-style execution of listeners.
for (TestExecutionListener testExecutionListener : getReversedTestExecutionListeners()) {
try {
testExecutionListener.afterTestMethod(getTestContext());
// Traverse the TestExecutionListeners in reverse order to ensure proper
// "wrapper"-style execution of listeners.
for (TestExecutionListener testExecutionListener : getReversedTestExecutionListeners()) {
try {
testExecutionListener.afterTestMethod(getTestContext());
}
catch (Throwable ex) {
logException(ex, callbackName, testExecutionListener, testInstance, testMethod);
if (afterTestMethodException == null) {
afterTestMethodException = ex;
}
else {
afterTestMethodException.addSuppressed(ex);
}
}
}
catch (Throwable ex) {
logException(ex, callbackName, testExecutionListener, testInstance, testMethod);
if (afterTestMethodException == null) {
afterTestMethodException = ex;
}
else {
afterTestMethodException.addSuppressed(ex);
}
if (afterTestMethodException != null) {
ReflectionUtils.rethrowException(afterTestMethodException);
}
}
if (afterTestMethodException != null) {
ReflectionUtils.rethrowException(afterTestMethodException);
finally {
resetMethodInvoker();
}
}
@ -504,6 +534,15 @@ public class TestContextManager {
}
}
/**
* Reset the {@link MethodInvoker} to the default to ensure that a custom
* {@code MethodInvoker} for the current test execution is not retained for
* subsequent test executions.
*/
private void resetMethodInvoker() {
getTestContext().setMethodInvoker(MethodInvoker.DEFAULT_INVOKER);
}
private void prepareForBeforeCallback(String callbackName, Object testInstance, Method testMethod) {
if (logger.isTraceEnabled()) {
logger.trace("%s(): instance [%s], method [%s]".formatted(callbackName, testInstance, testMethod));

View File

@ -52,6 +52,7 @@ import org.springframework.core.annotation.MergedAnnotations;
import org.springframework.core.annotation.MergedAnnotations.SearchStrategy;
import org.springframework.core.annotation.RepeatableContainers;
import org.springframework.lang.Nullable;
import org.springframework.test.context.MethodInvoker;
import org.springframework.test.context.TestConstructor;
import org.springframework.test.context.TestContextAnnotationUtils;
import org.springframework.test.context.TestContextManager;
@ -122,7 +123,9 @@ public class SpringExtension implements BeforeAllCallback, AfterAllCallback, Tes
*/
@Override
public void beforeAll(ExtensionContext context) throws Exception {
getTestContextManager(context).beforeTestClass();
TestContextManager testContextManager = getTestContextManager(context);
registerMethodInvoker(testContextManager, context);
testContextManager.beforeTestClass();
}
/**
@ -131,7 +134,9 @@ public class SpringExtension implements BeforeAllCallback, AfterAllCallback, Tes
@Override
public void afterAll(ExtensionContext context) throws Exception {
try {
getTestContextManager(context).afterTestClass();
TestContextManager testContextManager = getTestContextManager(context);
registerMethodInvoker(testContextManager, context);
testContextManager.afterTestClass();
}
finally {
getStore(context).remove(context.getRequiredTestClass());
@ -148,7 +153,9 @@ public class SpringExtension implements BeforeAllCallback, AfterAllCallback, Tes
public void postProcessTestInstance(Object testInstance, ExtensionContext context) throws Exception {
validateAutowiredConfig(context);
validateRecordApplicationEventsConfig(context);
getTestContextManager(context).prepareTestInstance(testInstance);
TestContextManager testContextManager = getTestContextManager(context);
registerMethodInvoker(testContextManager, context);
testContextManager.prepareTestInstance(testInstance);
}
/**
@ -223,7 +230,9 @@ public class SpringExtension implements BeforeAllCallback, AfterAllCallback, Tes
public void beforeEach(ExtensionContext context) throws Exception {
Object testInstance = context.getRequiredTestInstance();
Method testMethod = context.getRequiredTestMethod();
getTestContextManager(context).beforeTestMethod(testInstance, testMethod);
TestContextManager testContextManager = getTestContextManager(context);
registerMethodInvoker(testContextManager, context);
testContextManager.beforeTestMethod(testInstance, testMethod);
}
/**
@ -233,7 +242,9 @@ public class SpringExtension implements BeforeAllCallback, AfterAllCallback, Tes
public void beforeTestExecution(ExtensionContext context) throws Exception {
Object testInstance = context.getRequiredTestInstance();
Method testMethod = context.getRequiredTestMethod();
getTestContextManager(context).beforeTestExecution(testInstance, testMethod);
TestContextManager testContextManager = getTestContextManager(context);
registerMethodInvoker(testContextManager, context);
testContextManager.beforeTestExecution(testInstance, testMethod);
}
/**
@ -244,7 +255,9 @@ public class SpringExtension implements BeforeAllCallback, AfterAllCallback, Tes
Object testInstance = context.getRequiredTestInstance();
Method testMethod = context.getRequiredTestMethod();
Throwable testException = context.getExecutionException().orElse(null);
getTestContextManager(context).afterTestExecution(testInstance, testMethod, testException);
TestContextManager testContextManager = getTestContextManager(context);
registerMethodInvoker(testContextManager, context);
testContextManager.afterTestExecution(testInstance, testMethod, testException);
}
/**
@ -255,7 +268,9 @@ public class SpringExtension implements BeforeAllCallback, AfterAllCallback, Tes
Object testInstance = context.getRequiredTestInstance();
Method testMethod = context.getRequiredTestMethod();
Throwable testException = context.getExecutionException().orElse(null);
getTestContextManager(context).afterTestMethod(testInstance, testMethod, testException);
TestContextManager testContextManager = getTestContextManager(context);
registerMethodInvoker(testContextManager, context);
testContextManager.afterTestMethod(testInstance, testMethod, testException);
}
/**
@ -350,6 +365,17 @@ public class SpringExtension implements BeforeAllCallback, AfterAllCallback, Tes
return context.getRoot().getStore(TEST_CONTEXT_MANAGER_NAMESPACE);
}
/**
* Register a {@link MethodInvoker} adaptor for Jupiter's
* {@link org.junit.jupiter.api.extension.ExecutableInvoker ExecutableInvoker}
* in the {@link org.springframework.test.context.TestContext TestContext} for
* the supplied {@link TestContextManager}.
* @since 6.1
*/
private static void registerMethodInvoker(TestContextManager testContextManager, ExtensionContext context) {
testContextManager.getTestContext().setMethodInvoker(context.getExecutableInvoker()::invoke);
}
private static boolean isAutowiredTestOrLifecycleMethod(Method method) {
MergedAnnotations mergedAnnotations =
MergedAnnotations.from(method, SearchStrategy.DIRECT, RepeatableContainers.none());

View File

@ -30,6 +30,7 @@ import org.springframework.lang.Nullable;
import org.springframework.test.annotation.DirtiesContext.HierarchyMode;
import org.springframework.test.context.CacheAwareContextLoaderDelegate;
import org.springframework.test.context.MergedContextConfiguration;
import org.springframework.test.context.MethodInvoker;
import org.springframework.test.context.TestContext;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
@ -64,6 +65,8 @@ public class DefaultTestContext implements TestContext {
@Nullable
private volatile Throwable testException;
private volatile MethodInvoker methodInvoker = MethodInvoker.DEFAULT_INVOKER;
/**
* <em>Copy constructor</em> for creating a new {@code DefaultTestContext}
@ -183,6 +186,17 @@ public class DefaultTestContext implements TestContext {
this.testException = testException;
}
@Override
public final void setMethodInvoker(MethodInvoker methodInvoker) {
Assert.notNull(methodInvoker, "MethodInvoker must not be null");
this.methodInvoker = methodInvoker;
}
@Override
public final MethodInvoker getMethodInvoker() {
return this.methodInvoker;
}
@Override
public void setAttribute(String name, @Nullable Object value) {
Assert.notNull(name, "Name must not be null");

View File

@ -37,6 +37,7 @@ import org.springframework.transaction.support.SimpleTransactionStatus;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.CALLS_REAL_METHODS;
import static org.mockito.Mockito.mock;
import static org.springframework.transaction.annotation.Propagation.NOT_SUPPORTED;
import static org.springframework.transaction.annotation.Propagation.REQUIRED;
@ -58,7 +59,7 @@ class TransactionalTestExecutionListenerTests {
}
};
private final TestContext testContext = mock();
private final TestContext testContext = mock(CALLS_REAL_METHODS);
@AfterEach