diff --git a/spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/AbstractMockitoTestExecutionListener.java b/spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/AbstractMockitoTestExecutionListener.java index b368af1032..c6f4dbf477 100644 --- a/spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/AbstractMockitoTestExecutionListener.java +++ b/spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/AbstractMockitoTestExecutionListener.java @@ -16,16 +16,16 @@ package org.springframework.test.context.bean.override.mockito; -import java.lang.annotation.Annotation; -import java.lang.reflect.AnnotatedElement; -import java.util.Arrays; -import java.util.concurrent.atomic.AtomicBoolean; +import java.lang.reflect.Field; import java.util.function.Predicate; +import org.springframework.core.annotation.MergedAnnotation; +import org.springframework.core.annotation.MergedAnnotations; +import org.springframework.core.annotation.MergedAnnotations.SearchStrategy; import org.springframework.test.context.TestContext; +import org.springframework.test.context.TestContextAnnotationUtils; import org.springframework.test.context.support.AbstractTestExecutionListener; import org.springframework.util.ClassUtils; -import org.springframework.util.ReflectionUtils; /** * Abstract base class for {@code TestExecutionListener} implementations involving @@ -44,8 +44,8 @@ abstract class AbstractMockitoTestExecutionListener extends AbstractTestExecutio private static final String ORG_MOCKITO_PACKAGE = "org.mockito"; - private static final Predicate isMockitoAnnotation = annotation -> { - String packageName = annotation.annotationType().getPackageName(); + private static final Predicate> isMockitoAnnotation = mergedAnnotation -> { + String packageName = mergedAnnotation.getType().getPackageName(); return (packageName.startsWith(SPRING_MOCKITO_PACKAGE) || packageName.startsWith(ORG_MOCKITO_PACKAGE)); }; @@ -60,25 +60,47 @@ abstract class AbstractMockitoTestExecutionListener extends AbstractTestExecutio return hasMockitoAnnotations(testContext.getTestClass()); } - private static boolean hasMockitoAnnotations(Class testClass) { - if (isAnnotated(testClass)) { + /** + * Determine if Mockito annotations are declared on the supplied class, on an + * interface it implements, on a superclass, or on an enclosing class or + * whether a field in any such class is annotated with a Mockito annotation. + */ + private static boolean hasMockitoAnnotations(Class clazz) { + // Declared on the class? + if (MergedAnnotations.from(clazz, SearchStrategy.DIRECT).stream().anyMatch(isMockitoAnnotation)) { return true; } - // TODO Ideally we should short-circuit the search once we've found a Mockito annotation, - // since there's no need to continue searching additional fields or further up the class - // hierarchy; however, that is not possible with ReflectionUtils#doWithFields. Plus, the - // previous invocation of isAnnotated(testClass) only finds annotations declared directly - // on the test class. So, we'll likely need a completely different approach that combines - // the "test class/interface is annotated?" and "field is annotated?" checks in a single - // search algorithm, and we'll also need to support @Nested class hierarchies. - AtomicBoolean found = new AtomicBoolean(); - ReflectionUtils.doWithFields(testClass, - field -> found.set(true), AbstractMockitoTestExecutionListener::isAnnotated); - return found.get(); - } - private static boolean isAnnotated(AnnotatedElement annotatedElement) { - return Arrays.stream(annotatedElement.getAnnotations()).anyMatch(isMockitoAnnotation); + // Declared on a field? + for (Field field : clazz.getDeclaredFields()) { + if (MergedAnnotations.from(field, SearchStrategy.DIRECT).stream().anyMatch(isMockitoAnnotation)) { + return true; + } + } + + // Declared on an interface? + for (Class ifc : clazz.getInterfaces()) { + if (hasMockitoAnnotations(ifc)) { + return true; + } + } + + // Declared on a superclass? + Class superclass = clazz.getSuperclass(); + if (superclass != null & superclass != Object.class) { + if (hasMockitoAnnotations(superclass)) { + return true; + } + } + + // Declared on an enclosing class of an inner class? + if (TestContextAnnotationUtils.searchEnclosingClass(clazz)) { + if (hasMockitoAnnotations(clazz.getEnclosingClass())) { + return true; + } + } + + return false; } } diff --git a/spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/MockitoTestExecutionListener.java b/spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/MockitoTestExecutionListener.java index b6c68876ba..5648657b69 100644 --- a/spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/MockitoTestExecutionListener.java +++ b/spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/MockitoTestExecutionListener.java @@ -63,45 +63,26 @@ public class MockitoTestExecutionListener extends AbstractMockitoTestExecutionLi return 1950; } - @Override - public void prepareTestInstance(TestContext testContext) { - if (mockitoPresent) { - closeMocks(testContext); - initMocks(testContext); - } - } - @Override public void beforeTestMethod(TestContext testContext) { - if (mockitoPresent && Boolean.TRUE.equals( - testContext.getAttribute(DependencyInjectionTestExecutionListener.REINJECT_DEPENDENCIES_ATTRIBUTE))) { - closeMocks(testContext); + if (mockitoPresent && hasMockitoAnnotations(testContext)) { initMocks(testContext); } } @Override public void afterTestMethod(TestContext testContext) { - if (mockitoPresent) { - closeMocks(testContext); - } - } - - @Override - public void afterTestClass(TestContext testContext) { - if (mockitoPresent) { + if (mockitoPresent && hasMockitoAnnotations(testContext)) { closeMocks(testContext); } } private static void initMocks(TestContext testContext) { - if (hasMockitoAnnotations(testContext)) { - Class testClass = testContext.getTestClass(); - Object testInstance = testContext.getTestInstance(); - MockitoBeanSettings annotation = AnnotationUtils.findAnnotation(testClass, MockitoBeanSettings.class); - Strictness strictness = (annotation != null ? annotation.value() : Strictness.STRICT_STUBS); - testContext.setAttribute(MOCKITO_SESSION_ATTRIBUTE_NAME, initMockitoSession(testInstance, strictness)); - } + Class testClass = testContext.getTestClass(); + Object testInstance = testContext.getTestInstance(); + MockitoBeanSettings annotation = AnnotationUtils.findAnnotation(testClass, MockitoBeanSettings.class); + Strictness strictness = (annotation != null ? annotation.value() : Strictness.STRICT_STUBS); + testContext.setAttribute(MOCKITO_SESSION_ATTRIBUTE_NAME, initMockitoSession(testInstance, strictness)); } private static MockitoSession initMockitoSession(Object testInstance, Strictness strictness) { diff --git a/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/MockitoBeanNestedTests.java b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/MockitoBeanNestedTests.java new file mode 100644 index 0000000000..6a02168eca --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/MockitoBeanNestedTests.java @@ -0,0 +1,99 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.bean.override.mockito; + +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit.jupiter.SpringExtension; + +import static org.mockito.BDDMockito.then; +import static org.mockito.Mockito.times; + +/** + * Verifies proper handling of the {@link org.mockito.MockitoSession MockitoSession} + * when a {@link MockitoBean @MockitoBean} field is declared in the enclosing class of + * a {@link Nested @Nested} test class. + * + * @author Andy Wilkinson + * @author Sam Brannen + * @since 6.2 + */ +@ExtendWith(SpringExtension.class) +// TODO Remove @ContextConfiguration declaration. +// @ContextConfiguration is currently required due to a bug in the TestContext framework. +@ContextConfiguration +class MockitoBeanNestedTests { + + @MockitoBean + Runnable action; + + @Autowired + Task task; + + @Test + void mockWasInvokedOnce() { + task.execute(); + then(action).should().run(); + } + + @Test + void mockWasInvokedTwice() { + task.execute(); + task.execute(); + then(action).should(times(2)).run(); + } + + @Nested + class MockitoBeanFieldInEnclosingClassTests { + + @Test + void mockWasInvokedOnce() { + task.execute(); + then(action).should().run(); + } + + @Test + void mockWasInvokedTwice() { + task.execute(); + task.execute(); + then(action).should(times(2)).run(); + } + } + + record Task(Runnable action) { + + void execute() { + this.action.run(); + } + } + + @Configuration(proxyBeanMethods = false) + static class TestConfiguration { + + @Bean + Task task(Runnable action) { + return new Task(action); + } + } + +}