diff --git a/spring-test/src/main/java/org/springframework/test/context/DefaultMethodInvoker.java b/spring-test/src/main/java/org/springframework/test/context/DefaultMethodInvoker.java new file mode 100644 index 00000000000..a0f96ed47e0 --- /dev/null +++ b/spring-test/src/main/java/org/springframework/test/context/DefaultMethodInvoker.java @@ -0,0 +1,62 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ReflectionUtils; + +/** + * Default implementation of the {@link MethodInvoker} API. + * + *

This implementation never provides arguments to a {@link Method}. + * + * @author Sam Brannen + * @since 6.1 + */ +final class DefaultMethodInvoker implements MethodInvoker { + + private static final Log logger = LogFactory.getLog(DefaultMethodInvoker.class); + + + @Override + @Nullable + public Object invoke(Method method, @Nullable Object target) throws Exception { + Assert.notNull(method, "Method must not be null"); + + try { + ReflectionUtils.makeAccessible(method); + return method.invoke(target); + } + catch (InvocationTargetException ex) { + if (logger.isErrorEnabled()) { + logger.error("Exception encountered while invoking method [%s] on target [%s]" + .formatted(method, target), ex.getTargetException()); + } + ReflectionUtils.rethrowException(ex.getTargetException()); + // appease the compiler + return null; + } + } + +} diff --git a/spring-test/src/main/java/org/springframework/test/context/MethodInvoker.java b/spring-test/src/main/java/org/springframework/test/context/MethodInvoker.java new file mode 100644 index 00000000000..f98ef3f3e6b --- /dev/null +++ b/spring-test/src/main/java/org/springframework/test/context/MethodInvoker.java @@ -0,0 +1,72 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context; + +import java.lang.reflect.Method; + +import org.springframework.lang.Nullable; + +/** + * {@code MethodInvoker} defines a generic API for invoking a {@link Method} + * within the Spring TestContext Framework. + * + *

Specifically, a {@code MethodInvoker} is made available to a + * {@link TestExecutionListener} via {@link TestContext#getMethodInvoker()}, and + * a {@code TestExecutionListener} can use the invoker to transparently benefit + * from any special method invocation features of the underlying testing framework. + * + *

For example, when the underlying testing framework is JUnit Jupiter, a + * {@code TestExecutionListener} can use a {@code MethodInvoker} to invoke + * arbitrary methods with JUnit Jupiter's + * {@linkplain org.junit.jupiter.api.extension.ExecutableInvoker parameter resolution + * mechanism}. For other testing frameworks, the {@link #DEFAULT_INVOKER} will be + * used. + * + * @author Sam Brannen + * @since 6.1 + * @see org.junit.jupiter.api.extension.ExecutableInvoker + * @see org.springframework.util.MethodInvoker + */ +public interface MethodInvoker { + + /** + * Shared instance of the default {@link MethodInvoker}. + *

This invoker never provides arguments to a {@link Method}. + */ + static final MethodInvoker DEFAULT_INVOKER = new DefaultMethodInvoker(); + + + /** + * Invoke the supplied {@link Method} on the supplied {@code target}. + *

When the {@link #DEFAULT_INVOKER} is used — for example, when + * the underlying testing framework is JUnit 4 or TestNG — the method + * must not declare any formal parameters. When the underlying testing + * framework is JUnit Jupiter, parameters will be dynamically resolved via + * registered {@link org.junit.jupiter.api.extension.ParameterResolver + * ParameterResolvers} (such as the + * {@link org.springframework.test.context.junit.jupiter.SpringExtension + * SpringExtension}). + * @param method the method to invoke + * @param target the object on which to invoke the method, may be {@code null} + * if the method is {@code static} + * @return the value returned from the method invocation, potentially {@code null} + * @throws Exception if any error occurs + */ + @Nullable + Object invoke(Method method, @Nullable Object target) throws Exception; + +} diff --git a/spring-test/src/main/java/org/springframework/test/context/TestContext.java b/spring-test/src/main/java/org/springframework/test/context/TestContext.java index f0b21a9479f..cce8366ece8 100644 --- a/spring-test/src/main/java/org/springframework/test/context/TestContext.java +++ b/spring-test/src/main/java/org/springframework/test/context/TestContext.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -38,6 +38,9 @@ import org.springframework.test.annotation.DirtiesContext.HierarchyMode; * that does not provide a copy constructor will likely fail in an environment * that executes tests concurrently. * + *

As of Spring Framework 6.1, concrete implementations are highly encouraged to + * override {@link #setMethodInvoker(MethodInvoker)} and {@link #getMethodInvoker()}. + * * @author Sam Brannen * @since 2.5 * @see TestContextManager @@ -150,4 +153,28 @@ public interface TestContext extends AttributeAccessor, Serializable { */ void updateState(@Nullable Object testInstance, @Nullable Method testMethod, @Nullable Throwable testException); + /** + * Set the {@link MethodInvoker} to use. + *

By default, this method does nothing. + *

Concrete implementations should track the supplied {@code MethodInvoker} + * and return it from {@link #getMethodInvoker()}. Note that the standard + * {@code TestContext} implementation in Spring overrides this method appropriately. + * @since 6.1 + */ + default void setMethodInvoker(MethodInvoker methodInvoker) { + /* no-op */ + } + + /** + * Get the {@link MethodInvoker} to use. + *

By default, this method returns {@link MethodInvoker#DEFAULT_INVOKER}. + *

Concrete implementations should return the {@code MethodInvoker} supplied + * to {@link #setMethodInvoker(MethodInvoker)}. Note that the standard + * {@code TestContext} implementation in Spring overrides this method appropriately. + * @since 6.1 + */ + default MethodInvoker getMethodInvoker() { + return MethodInvoker.DEFAULT_INVOKER; + } + } diff --git a/spring-test/src/main/java/org/springframework/test/context/TestContextManager.java b/spring-test/src/main/java/org/springframework/test/context/TestContextManager.java index 5d8d96dd9a1..db62ee87005 100644 --- a/spring-test/src/main/java/org/springframework/test/context/TestContextManager.java +++ b/spring-test/src/main/java/org/springframework/test/context/TestContextManager.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -195,21 +195,26 @@ public class TestContextManager { * @see #getTestExecutionListeners() */ public void beforeTestClass() throws Exception { - Class testClass = getTestContext().getTestClass(); - if (logger.isTraceEnabled()) { - logger.trace("beforeTestClass(): class [" + typeName(testClass) + "]"); - } - getTestContext().updateState(null, null, null); + try { + Class testClass = getTestContext().getTestClass(); + if (logger.isTraceEnabled()) { + logger.trace("beforeTestClass(): class [" + typeName(testClass) + "]"); + } + getTestContext().updateState(null, null, null); - for (TestExecutionListener testExecutionListener : getTestExecutionListeners()) { - try { - testExecutionListener.beforeTestClass(getTestContext()); - } - catch (Throwable ex) { - logException(ex, "beforeTestClass", testExecutionListener, testClass); - ReflectionUtils.rethrowException(ex); + for (TestExecutionListener testExecutionListener : getTestExecutionListeners()) { + try { + testExecutionListener.beforeTestClass(getTestContext()); + } + catch (Throwable ex) { + logException(ex, "beforeTestClass", testExecutionListener, testClass); + ReflectionUtils.rethrowException(ex); + } } } + finally { + resetMethodInvoker(); + } } /** @@ -231,25 +236,30 @@ public class TestContextManager { * @see #getTestExecutionListeners() */ public void prepareTestInstance(Object testInstance) throws Exception { - if (logger.isTraceEnabled()) { - logger.trace("prepareTestInstance(): instance [" + testInstance + "]"); - } - getTestContext().updateState(testInstance, null, null); - - for (TestExecutionListener testExecutionListener : getTestExecutionListeners()) { - try { - testExecutionListener.prepareTestInstance(getTestContext()); + try { + if (logger.isTraceEnabled()) { + logger.trace("prepareTestInstance(): instance [" + testInstance + "]"); } - catch (Throwable ex) { - if (logger.isErrorEnabled()) { - logger.error(""" + getTestContext().updateState(testInstance, null, null); + + for (TestExecutionListener testExecutionListener : getTestExecutionListeners()) { + try { + testExecutionListener.prepareTestInstance(getTestContext()); + } + catch (Throwable ex) { + if (logger.isErrorEnabled()) { + logger.error(""" Caught exception while allowing TestExecutionListener [%s] to \ prepare test instance [%s]""" - .formatted(typeName(testExecutionListener), testInstance), ex); + .formatted(typeName(testExecutionListener), testInstance), ex); + } + ReflectionUtils.rethrowException(ex); } - ReflectionUtils.rethrowException(ex); } } + finally { + resetMethodInvoker(); + } } /** @@ -280,17 +290,22 @@ public class TestContextManager { * @see #getTestExecutionListeners() */ public void beforeTestMethod(Object testInstance, Method testMethod) throws Exception { - String callbackName = "beforeTestMethod"; - prepareForBeforeCallback(callbackName, testInstance, testMethod); + try { + String callbackName = "beforeTestMethod"; + prepareForBeforeCallback(callbackName, testInstance, testMethod); - for (TestExecutionListener testExecutionListener : getTestExecutionListeners()) { - try { - testExecutionListener.beforeTestMethod(getTestContext()); - } - catch (Throwable ex) { - handleBeforeException(ex, callbackName, testExecutionListener, testInstance, testMethod); + for (TestExecutionListener testExecutionListener : getTestExecutionListeners()) { + try { + testExecutionListener.beforeTestMethod(getTestContext()); + } + catch (Throwable ex) { + handleBeforeException(ex, callbackName, testExecutionListener, testInstance, testMethod); + } } } + finally { + resetMethodInvoker(); + } } /** @@ -319,17 +334,22 @@ public class TestContextManager { * @see #getTestExecutionListeners() */ public void beforeTestExecution(Object testInstance, Method testMethod) throws Exception { - String callbackName = "beforeTestExecution"; - prepareForBeforeCallback(callbackName, testInstance, testMethod); + try { + String callbackName = "beforeTestExecution"; + prepareForBeforeCallback(callbackName, testInstance, testMethod); - for (TestExecutionListener testExecutionListener : getTestExecutionListeners()) { - try { - testExecutionListener.beforeTestExecution(getTestContext()); - } - catch (Throwable ex) { - handleBeforeException(ex, callbackName, testExecutionListener, testInstance, testMethod); + for (TestExecutionListener testExecutionListener : getTestExecutionListeners()) { + try { + testExecutionListener.beforeTestExecution(getTestContext()); + } + catch (Throwable ex) { + handleBeforeException(ex, callbackName, testExecutionListener, testInstance, testMethod); + } } } + finally { + resetMethodInvoker(); + } } /** @@ -367,29 +387,34 @@ public class TestContextManager { public void afterTestExecution(Object testInstance, Method testMethod, @Nullable Throwable exception) throws Exception { - String callbackName = "afterTestExecution"; - prepareForAfterCallback(callbackName, testInstance, testMethod, exception); - Throwable afterTestExecutionException = null; + try { + String callbackName = "afterTestExecution"; + prepareForAfterCallback(callbackName, testInstance, testMethod, exception); + Throwable afterTestExecutionException = null; - // Traverse the TestExecutionListeners in reverse order to ensure proper - // "wrapper"-style execution of listeners. - for (TestExecutionListener testExecutionListener : getReversedTestExecutionListeners()) { - try { - testExecutionListener.afterTestExecution(getTestContext()); + // Traverse the TestExecutionListeners in reverse order to ensure proper + // "wrapper"-style execution of listeners. + for (TestExecutionListener testExecutionListener : getReversedTestExecutionListeners()) { + try { + testExecutionListener.afterTestExecution(getTestContext()); + } + catch (Throwable ex) { + logException(ex, callbackName, testExecutionListener, testInstance, testMethod); + if (afterTestExecutionException == null) { + afterTestExecutionException = ex; + } + else { + afterTestExecutionException.addSuppressed(ex); + } + } } - catch (Throwable ex) { - logException(ex, callbackName, testExecutionListener, testInstance, testMethod); - if (afterTestExecutionException == null) { - afterTestExecutionException = ex; - } - else { - afterTestExecutionException.addSuppressed(ex); - } + + if (afterTestExecutionException != null) { + ReflectionUtils.rethrowException(afterTestExecutionException); } } - - if (afterTestExecutionException != null) { - ReflectionUtils.rethrowException(afterTestExecutionException); + finally { + resetMethodInvoker(); } } @@ -429,29 +454,34 @@ public class TestContextManager { public void afterTestMethod(Object testInstance, Method testMethod, @Nullable Throwable exception) throws Exception { - String callbackName = "afterTestMethod"; - prepareForAfterCallback(callbackName, testInstance, testMethod, exception); - Throwable afterTestMethodException = null; + try { + String callbackName = "afterTestMethod"; + prepareForAfterCallback(callbackName, testInstance, testMethod, exception); + Throwable afterTestMethodException = null; - // Traverse the TestExecutionListeners in reverse order to ensure proper - // "wrapper"-style execution of listeners. - for (TestExecutionListener testExecutionListener : getReversedTestExecutionListeners()) { - try { - testExecutionListener.afterTestMethod(getTestContext()); + // Traverse the TestExecutionListeners in reverse order to ensure proper + // "wrapper"-style execution of listeners. + for (TestExecutionListener testExecutionListener : getReversedTestExecutionListeners()) { + try { + testExecutionListener.afterTestMethod(getTestContext()); + } + catch (Throwable ex) { + logException(ex, callbackName, testExecutionListener, testInstance, testMethod); + if (afterTestMethodException == null) { + afterTestMethodException = ex; + } + else { + afterTestMethodException.addSuppressed(ex); + } + } } - catch (Throwable ex) { - logException(ex, callbackName, testExecutionListener, testInstance, testMethod); - if (afterTestMethodException == null) { - afterTestMethodException = ex; - } - else { - afterTestMethodException.addSuppressed(ex); - } + + if (afterTestMethodException != null) { + ReflectionUtils.rethrowException(afterTestMethodException); } } - - if (afterTestMethodException != null) { - ReflectionUtils.rethrowException(afterTestMethodException); + finally { + resetMethodInvoker(); } } @@ -504,6 +534,15 @@ public class TestContextManager { } } + /** + * Reset the {@link MethodInvoker} to the default to ensure that a custom + * {@code MethodInvoker} for the current test execution is not retained for + * subsequent test executions. + */ + private void resetMethodInvoker() { + getTestContext().setMethodInvoker(MethodInvoker.DEFAULT_INVOKER); + } + private void prepareForBeforeCallback(String callbackName, Object testInstance, Method testMethod) { if (logger.isTraceEnabled()) { logger.trace("%s(): instance [%s], method [%s]".formatted(callbackName, testInstance, testMethod)); diff --git a/spring-test/src/main/java/org/springframework/test/context/junit/jupiter/SpringExtension.java b/spring-test/src/main/java/org/springframework/test/context/junit/jupiter/SpringExtension.java index 5eaa16e5a8d..8b10a7e9bbc 100644 --- a/spring-test/src/main/java/org/springframework/test/context/junit/jupiter/SpringExtension.java +++ b/spring-test/src/main/java/org/springframework/test/context/junit/jupiter/SpringExtension.java @@ -52,6 +52,7 @@ import org.springframework.core.annotation.MergedAnnotations; import org.springframework.core.annotation.MergedAnnotations.SearchStrategy; import org.springframework.core.annotation.RepeatableContainers; import org.springframework.lang.Nullable; +import org.springframework.test.context.MethodInvoker; import org.springframework.test.context.TestConstructor; import org.springframework.test.context.TestContextAnnotationUtils; import org.springframework.test.context.TestContextManager; @@ -122,7 +123,9 @@ public class SpringExtension implements BeforeAllCallback, AfterAllCallback, Tes */ @Override public void beforeAll(ExtensionContext context) throws Exception { - getTestContextManager(context).beforeTestClass(); + TestContextManager testContextManager = getTestContextManager(context); + registerMethodInvoker(testContextManager, context); + testContextManager.beforeTestClass(); } /** @@ -131,7 +134,9 @@ public class SpringExtension implements BeforeAllCallback, AfterAllCallback, Tes @Override public void afterAll(ExtensionContext context) throws Exception { try { - getTestContextManager(context).afterTestClass(); + TestContextManager testContextManager = getTestContextManager(context); + registerMethodInvoker(testContextManager, context); + testContextManager.afterTestClass(); } finally { getStore(context).remove(context.getRequiredTestClass()); @@ -148,7 +153,9 @@ public class SpringExtension implements BeforeAllCallback, AfterAllCallback, Tes public void postProcessTestInstance(Object testInstance, ExtensionContext context) throws Exception { validateAutowiredConfig(context); validateRecordApplicationEventsConfig(context); - getTestContextManager(context).prepareTestInstance(testInstance); + TestContextManager testContextManager = getTestContextManager(context); + registerMethodInvoker(testContextManager, context); + testContextManager.prepareTestInstance(testInstance); } /** @@ -223,7 +230,9 @@ public class SpringExtension implements BeforeAllCallback, AfterAllCallback, Tes public void beforeEach(ExtensionContext context) throws Exception { Object testInstance = context.getRequiredTestInstance(); Method testMethod = context.getRequiredTestMethod(); - getTestContextManager(context).beforeTestMethod(testInstance, testMethod); + TestContextManager testContextManager = getTestContextManager(context); + registerMethodInvoker(testContextManager, context); + testContextManager.beforeTestMethod(testInstance, testMethod); } /** @@ -233,7 +242,9 @@ public class SpringExtension implements BeforeAllCallback, AfterAllCallback, Tes public void beforeTestExecution(ExtensionContext context) throws Exception { Object testInstance = context.getRequiredTestInstance(); Method testMethod = context.getRequiredTestMethod(); - getTestContextManager(context).beforeTestExecution(testInstance, testMethod); + TestContextManager testContextManager = getTestContextManager(context); + registerMethodInvoker(testContextManager, context); + testContextManager.beforeTestExecution(testInstance, testMethod); } /** @@ -244,7 +255,9 @@ public class SpringExtension implements BeforeAllCallback, AfterAllCallback, Tes Object testInstance = context.getRequiredTestInstance(); Method testMethod = context.getRequiredTestMethod(); Throwable testException = context.getExecutionException().orElse(null); - getTestContextManager(context).afterTestExecution(testInstance, testMethod, testException); + TestContextManager testContextManager = getTestContextManager(context); + registerMethodInvoker(testContextManager, context); + testContextManager.afterTestExecution(testInstance, testMethod, testException); } /** @@ -255,7 +268,9 @@ public class SpringExtension implements BeforeAllCallback, AfterAllCallback, Tes Object testInstance = context.getRequiredTestInstance(); Method testMethod = context.getRequiredTestMethod(); Throwable testException = context.getExecutionException().orElse(null); - getTestContextManager(context).afterTestMethod(testInstance, testMethod, testException); + TestContextManager testContextManager = getTestContextManager(context); + registerMethodInvoker(testContextManager, context); + testContextManager.afterTestMethod(testInstance, testMethod, testException); } /** @@ -350,6 +365,17 @@ public class SpringExtension implements BeforeAllCallback, AfterAllCallback, Tes return context.getRoot().getStore(TEST_CONTEXT_MANAGER_NAMESPACE); } + /** + * Register a {@link MethodInvoker} adaptor for Jupiter's + * {@link org.junit.jupiter.api.extension.ExecutableInvoker ExecutableInvoker} + * in the {@link org.springframework.test.context.TestContext TestContext} for + * the supplied {@link TestContextManager}. + * @since 6.1 + */ + private static void registerMethodInvoker(TestContextManager testContextManager, ExtensionContext context) { + testContextManager.getTestContext().setMethodInvoker(context.getExecutableInvoker()::invoke); + } + private static boolean isAutowiredTestOrLifecycleMethod(Method method) { MergedAnnotations mergedAnnotations = MergedAnnotations.from(method, SearchStrategy.DIRECT, RepeatableContainers.none()); diff --git a/spring-test/src/main/java/org/springframework/test/context/support/DefaultTestContext.java b/spring-test/src/main/java/org/springframework/test/context/support/DefaultTestContext.java index be33676662b..03b6a25827e 100644 --- a/spring-test/src/main/java/org/springframework/test/context/support/DefaultTestContext.java +++ b/spring-test/src/main/java/org/springframework/test/context/support/DefaultTestContext.java @@ -30,6 +30,7 @@ import org.springframework.lang.Nullable; import org.springframework.test.annotation.DirtiesContext.HierarchyMode; import org.springframework.test.context.CacheAwareContextLoaderDelegate; import org.springframework.test.context.MergedContextConfiguration; +import org.springframework.test.context.MethodInvoker; import org.springframework.test.context.TestContext; import org.springframework.util.Assert; import org.springframework.util.StringUtils; @@ -64,6 +65,8 @@ public class DefaultTestContext implements TestContext { @Nullable private volatile Throwable testException; + private volatile MethodInvoker methodInvoker = MethodInvoker.DEFAULT_INVOKER; + /** * Copy constructor for creating a new {@code DefaultTestContext} @@ -183,6 +186,17 @@ public class DefaultTestContext implements TestContext { this.testException = testException; } + @Override + public final void setMethodInvoker(MethodInvoker methodInvoker) { + Assert.notNull(methodInvoker, "MethodInvoker must not be null"); + this.methodInvoker = methodInvoker; + } + + @Override + public final MethodInvoker getMethodInvoker() { + return this.methodInvoker; + } + @Override public void setAttribute(String name, @Nullable Object value) { Assert.notNull(name, "Name must not be null"); diff --git a/spring-test/src/test/java/org/springframework/test/context/transaction/TransactionalTestExecutionListenerTests.java b/spring-test/src/test/java/org/springframework/test/context/transaction/TransactionalTestExecutionListenerTests.java index b21b0d1fc1d..37b57aa756c 100644 --- a/spring-test/src/test/java/org/springframework/test/context/transaction/TransactionalTestExecutionListenerTests.java +++ b/spring-test/src/test/java/org/springframework/test/context/transaction/TransactionalTestExecutionListenerTests.java @@ -37,6 +37,7 @@ import org.springframework.transaction.support.SimpleTransactionStatus; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalStateException; import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.CALLS_REAL_METHODS; import static org.mockito.Mockito.mock; import static org.springframework.transaction.annotation.Propagation.NOT_SUPPORTED; import static org.springframework.transaction.annotation.Propagation.REQUIRED; @@ -58,7 +59,7 @@ class TransactionalTestExecutionListenerTests { } }; - private final TestContext testContext = mock(); + private final TestContext testContext = mock(CALLS_REAL_METHODS); @AfterEach