This commit reviews the structure of several classes to comply with
our guidelines. Also, rather than exposing a static method to configure
the context in a test, we call the high-level API directly.
This commit is contained in:
Stéphane Nicoll 2024-06-07 16:05:23 +02:00
parent 0165529d97
commit 02517e5011
7 changed files with 94 additions and 97 deletions

View File

@ -77,12 +77,6 @@ class BeanOverrideBeanFactoryPostProcessor implements BeanFactoryPostProcessor,
}
@Override
public int getOrder() {
return Ordered.LOWEST_PRECEDENCE - 10;
}
@Override
public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
if (!(beanFactory instanceof BeanDefinitionRegistry registry)) {
@ -92,6 +86,11 @@ class BeanOverrideBeanFactoryPostProcessor implements BeanFactoryPostProcessor,
postProcessWithRegistry(beanFactory, registry);
}
@Override
public int getOrder() {
return Ordered.LOWEST_PRECEDENCE - 10;
}
private void postProcessWithRegistry(ConfigurableListableBeanFactory beanFactory, BeanDefinitionRegistry registry) {
for (OverrideMetadata metadata : this.overrideRegistrar.getOverrideMetadata()) {
registerBeanOverride(beanFactory, registry, metadata);
@ -240,7 +239,7 @@ class BeanOverrideBeanFactoryPostProcessor implements BeanFactoryPostProcessor,
}
static final class WrapEarlyBeanPostProcessor implements SmartInstantiationAwareBeanPostProcessor,
static class WrapEarlyBeanPostProcessor implements SmartInstantiationAwareBeanPostProcessor,
PriorityOrdered {
private final Map<String, Object> earlyReferences = new ConcurrentHashMap<>(16);
@ -248,7 +247,7 @@ class BeanOverrideBeanFactoryPostProcessor implements BeanFactoryPostProcessor,
private final BeanOverrideRegistrar overrideRegistrar;
private WrapEarlyBeanPostProcessor(BeanOverrideRegistrar registrar) {
WrapEarlyBeanPostProcessor(BeanOverrideRegistrar registrar) {
this.overrideRegistrar = registrar;
}

View File

@ -55,7 +55,16 @@ class BeanOverrideContextCustomizer implements ContextCustomizer {
this.detectedClasses = detectedClasses;
}
static void registerInfrastructure(BeanDefinitionRegistry registry, Set<Class<?>> detectedClasses) {
@Override
public void customizeContext(ConfigurableApplicationContext context, MergedContextConfiguration mergedConfig) {
if (!(context instanceof BeanDefinitionRegistry registry)) {
throw new IllegalStateException("Cannot process bean overrides with an ApplicationContext " +
"that doesn't implement BeanDefinitionRegistry: " + context.getClass());
}
registerInfrastructure(registry, this.detectedClasses);
}
private void registerInfrastructure(BeanDefinitionRegistry registry, Set<Class<?>> detectedClasses) {
addInfrastructureBeanDefinition(registry, BeanOverrideRegistrar.class, REGISTRAR_BEAN_NAME,
constructorArgs -> constructorArgs.addIndexedArgumentValue(0, detectedClasses));
RuntimeBeanReference registrarReference = new RuntimeBeanReference(REGISTRAR_BEAN_NAME);
@ -67,7 +76,7 @@ class BeanOverrideContextCustomizer implements ContextCustomizer {
constructorArgs -> constructorArgs.addIndexedArgumentValue(0, registrarReference));
}
private static void addInfrastructureBeanDefinition(BeanDefinitionRegistry registry,
private void addInfrastructureBeanDefinition(BeanDefinitionRegistry registry,
Class<?> clazz, String beanName, Consumer<ConstructorArgumentValues> constructorArgumentsConsumer) {
if (!registry.containsBeanDefinition(beanName)) {
@ -80,24 +89,15 @@ class BeanOverrideContextCustomizer implements ContextCustomizer {
}
@Override
public void customizeContext(ConfigurableApplicationContext context, MergedContextConfiguration mergedConfig) {
if (!(context instanceof BeanDefinitionRegistry registry)) {
throw new IllegalStateException("Cannot process bean overrides with an ApplicationContext " +
"that doesn't implement BeanDefinitionRegistry: " + context.getClass());
}
registerInfrastructure(registry, this.detectedClasses);
}
@Override
public boolean equals(Object obj) {
if (obj == this) {
public boolean equals(Object other) {
if (other == this) {
return true;
}
if (obj == null || obj.getClass() != getClass()) {
if (other == null || other.getClass() != getClass()) {
return false;
}
BeanOverrideContextCustomizer other = (BeanOverrideContextCustomizer) obj;
return this.detectedClasses.equals(other.detectedClasses);
BeanOverrideContextCustomizer that = (BeanOverrideContextCustomizer) other;
return this.detectedClasses.equals(that.detectedClasses);
}
@Override

View File

@ -45,11 +45,39 @@ import org.springframework.util.StringUtils;
*/
class TestBeanOverrideProcessor implements BeanOverrideProcessor {
@Override
public TestBeanOverrideMetadata createMetadata(Annotation overrideAnnotation, Class<?> testClass, Field field) {
if (!(overrideAnnotation instanceof TestBean testBeanAnnotation)) {
throw new IllegalStateException("Invalid annotation passed to %s: expected @TestBean on field %s.%s"
.formatted(getClass().getSimpleName(), field.getDeclaringClass().getName(), field.getName()));
}
Method overrideMethod;
String methodName = testBeanAnnotation.methodName();
if (!methodName.isBlank()) {
// If the user specified an explicit method name, search for that.
overrideMethod = findTestBeanFactoryMethod(testClass, field.getType(), methodName);
}
else {
// Otherwise, search for candidate factory methods using the convention
// suffix and the field name or explicit bean name (if any).
List<String> candidateMethodNames = new ArrayList<>();
candidateMethodNames.add(field.getName() + TestBean.CONVENTION_SUFFIX);
String beanName = testBeanAnnotation.name();
if (StringUtils.hasText(beanName)) {
candidateMethodNames.add(beanName + TestBean.CONVENTION_SUFFIX);
}
overrideMethod = findTestBeanFactoryMethod(testClass, field.getType(), candidateMethodNames);
}
String beanName = (StringUtils.hasText(testBeanAnnotation.name()) ? testBeanAnnotation.name() : null);
return new TestBeanOverrideMetadata(field, ResolvableType.forField(field, testClass), beanName, overrideMethod);
}
/**
* Find a test bean factory {@link Method} for the given {@link Class}.
* <p>Delegates to {@link #findTestBeanFactoryMethod(Class, Class, List)}.
*/
static Method findTestBeanFactoryMethod(Class<?> clazz, Class<?> methodReturnType, String... methodNames) {
Method findTestBeanFactoryMethod(Class<?> clazz, Class<?> methodReturnType, String... methodNames) {
return findTestBeanFactoryMethod(clazz, methodReturnType, List.of(methodNames));
}
@ -76,7 +104,7 @@ class TestBeanOverrideProcessor implements BeanOverrideProcessor {
* @throws IllegalStateException if a matching factory method cannot
* be found or multiple methods match
*/
static Method findTestBeanFactoryMethod(Class<?> clazz, Class<?> methodReturnType, List<String> methodNames) {
Method findTestBeanFactoryMethod(Class<?> clazz, Class<?> methodReturnType, List<String> methodNames) {
Assert.notEmpty(methodNames, "At least one candidate method name is required");
Set<String> supportedNames = new LinkedHashSet<>(methodNames);
MethodFilter methodFilter = method -> (Modifier.isStatic(method.getModifiers()) &&
@ -88,47 +116,17 @@ class TestBeanOverrideProcessor implements BeanOverrideProcessor {
Assert.state(!methods.isEmpty(), () -> """
Failed to find a static test bean factory method in %s with return type %s \
whose name matches one of the supported candidates %s""".formatted(
clazz.getName(), methodReturnType.getName(), supportedNames));
clazz.getName(), methodReturnType.getName(), supportedNames));
long uniqueMethodNameCount = methods.stream().map(Method::getName).distinct().count();
Assert.state(uniqueMethodNameCount == 1, () -> """
Found %d competing static test bean factory methods in %s with return type %s \
whose name matches one of the supported candidates %s""".formatted(
uniqueMethodNameCount, clazz.getName(), methodReturnType.getName(), supportedNames));
uniqueMethodNameCount, clazz.getName(), methodReturnType.getName(), supportedNames));
return methods.iterator().next();
}
@Override
public TestBeanOverrideMetadata createMetadata(Annotation overrideAnnotation, Class<?> testClass, Field field) {
if (!(overrideAnnotation instanceof TestBean testBeanAnnotation)) {
throw new IllegalStateException("Invalid annotation passed to %s: expected @TestBean on field %s.%s"
.formatted(getClass().getSimpleName(), field.getDeclaringClass().getName(), field.getName()));
}
Method overrideMethod;
String methodName = testBeanAnnotation.methodName();
if (!methodName.isBlank()) {
// If the user specified an explicit method name, search for that.
overrideMethod = findTestBeanFactoryMethod(testClass, field.getType(), methodName);
}
else {
// Otherwise, search for candidate factory methods using the convention
// suffix and the field name or explicit bean name (if any).
List<String> candidateMethodNames = new ArrayList<>();
candidateMethodNames.add(field.getName() + TestBean.CONVENTION_SUFFIX);
String beanName = testBeanAnnotation.name();
if (StringUtils.hasText(beanName)) {
candidateMethodNames.add(beanName + TestBean.CONVENTION_SUFFIX);
}
overrideMethod = findTestBeanFactoryMethod(testClass, field.getType(), candidateMethodNames);
}
String beanName = (StringUtils.hasText(testBeanAnnotation.name()) ? testBeanAnnotation.name() : null);
return new TestBeanOverrideMetadata(field, ResolvableType.forField(field, testClass), beanName, overrideMethod);
}
private static Set<Method> findMethods(Class<?> clazz, MethodFilter methodFilter) {
Set<Method> methods = MethodIntrospector.selectMethods(clazz, methodFilter);
if (methods.isEmpty() && TestContextAnnotationUtils.searchEnclosingClass(clazz)) {

View File

@ -69,6 +69,14 @@ class MockitoBeanOverrideMetadata extends MockitoOverrideMetadata {
this.serializable = serializable;
}
private static Set<Class<?>> asClassSet(@Nullable Class<?>[] classes) {
Set<Class<?>> classSet = new LinkedHashSet<>();
if (classes != null) {
classSet.addAll(Arrays.asList(classes));
}
return Collections.unmodifiableSet(classSet);
}
/**
* Return the extra interfaces.
@ -99,12 +107,21 @@ class MockitoBeanOverrideMetadata extends MockitoOverrideMetadata {
return createMock(beanName);
}
private Set<Class<?>> asClassSet(@Nullable Class<?>[] classes) {
Set<Class<?>> classSet = new LinkedHashSet<>();
if (classes != null) {
classSet.addAll(Arrays.asList(classes));
@SuppressWarnings("unchecked")
<T> T createMock(String name) {
MockSettings settings = MockReset.withSettings(getReset());
if (StringUtils.hasLength(name)) {
settings.name(name);
}
return Collections.unmodifiableSet(classSet);
if (!this.extraInterfaces.isEmpty()) {
settings.extraInterfaces(ClassUtils.toClassArray(this.extraInterfaces));
}
settings.defaultAnswer(this.answer);
if (this.serializable) {
settings.serializable();
}
Class<?> targetType = getBeanType().resolve();
return (T) mock(targetType, settings);
}
@Override
@ -140,21 +157,4 @@ class MockitoBeanOverrideMetadata extends MockitoOverrideMetadata {
.toString();
}
@SuppressWarnings("unchecked")
<T> T createMock(String name) {
MockSettings settings = MockReset.withSettings(getReset());
if (StringUtils.hasLength(name)) {
settings.name(name);
}
if (!this.extraInterfaces.isEmpty()) {
settings.extraInterfaces(ClassUtils.toClassArray(this.extraInterfaces));
}
settings.defaultAnswer(this.answer);
if (this.serializable) {
settings.serializable();
}
Class<?> targetType = getBeanType().resolve();
return (T) mock(targetType, settings);
}
}

View File

@ -37,6 +37,7 @@ import org.springframework.context.annotation.Configuration;
import org.springframework.context.support.SimpleThreadScope;
import org.springframework.core.Ordered;
import org.springframework.core.ResolvableType;
import org.springframework.test.context.MergedContextConfiguration;
import org.springframework.test.context.bean.override.example.ExampleBeanOverrideAnnotation;
import org.springframework.test.context.bean.override.example.ExampleService;
import org.springframework.test.context.bean.override.example.FailingExampleService;
@ -47,6 +48,7 @@ import org.springframework.util.Assert;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
import static org.assertj.core.api.Assertions.assertThatNoException;
import static org.mockito.Mockito.mock;
/**
* Tests for {@link BeanOverrideBeanFactoryPostProcessor} combined with a
@ -199,9 +201,9 @@ class BeanOverrideBeanFactoryPostProcessorTests {
});
}
private AnnotationConfigApplicationContext createContext(Class<?>... classes) {
private AnnotationConfigApplicationContext createContext(Class<?> testClass) {
AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext();
BeanOverrideContextCustomizer.registerInfrastructure(context, Set.of(classes));
new BeanOverrideContextCustomizer(Set.of(testClass)).customizeContext(context, mock(MergedContextConfiguration.class));
return context;
}

View File

@ -24,7 +24,7 @@ package org.springframework.test.context.bean.override.convention;
*/
interface TestBeanFactory {
public static String createTestMessage() {
static String createTestMessage() {
return "test";
}

View File

@ -30,7 +30,6 @@ import org.springframework.util.ReflectionUtils;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
import static org.springframework.test.context.bean.override.convention.TestBeanOverrideProcessor.findTestBeanFactoryMethod;
/**
* Tests for {@link TestBeanOverrideProcessor}.
@ -41,12 +40,15 @@ import static org.springframework.test.context.bean.override.convention.TestBean
*/
class TestBeanOverrideProcessorTests {
private final TestBeanOverrideProcessor processor = new TestBeanOverrideProcessor();
@Test
void findTestBeanFactoryMethodFindsFromCandidateNames() {
Class<?> clazz = MethodConventionTestCase.class;
Class<?> returnType = ExampleService.class;
Method method = findTestBeanFactoryMethod(clazz, returnType, "example1", "example2", "example3");
Method method = this.processor.findTestBeanFactoryMethod(
clazz, returnType, "example1", "example2", "example3");
assertThat(method.getName()).isEqualTo("example2");
}
@ -56,7 +58,7 @@ class TestBeanOverrideProcessorTests {
Class<?> clazz = SubTestCase.class;
Class<?> returnType = String.class;
Method method = findTestBeanFactoryMethod(clazz, returnType, "factory");
Method method = this.processor.findTestBeanFactoryMethod(clazz, returnType, "factory");
assertThat(method).isEqualTo(ReflectionUtils.findMethod(clazz, "factory"));
}
@ -67,7 +69,7 @@ class TestBeanOverrideProcessorTests {
Class<?> returnType = ExampleService.class;
assertThatIllegalStateException()
.isThrownBy(() -> findTestBeanFactoryMethod(clazz, returnType, "example1", "example3"))
.isThrownBy(() -> this.processor.findTestBeanFactoryMethod(clazz, returnType, "example1", "example3"))
.withMessage("""
Failed to find a static test bean factory method in %s with return type %s \
whose name matches one of the supported candidates %s""",
@ -80,7 +82,7 @@ class TestBeanOverrideProcessorTests {
Class<?> returnType = ExampleService.class;
assertThatIllegalStateException()
.isThrownBy(() -> findTestBeanFactoryMethod(clazz, returnType, "example2", "example4"))
.isThrownBy(() -> this.processor.findTestBeanFactoryMethod(clazz, returnType, "example2", "example4"))
.withMessage("""
Found %d competing static test bean factory methods in %s with return type %s \
whose name matches one of the supported candidates %s""".formatted(
@ -90,7 +92,7 @@ class TestBeanOverrideProcessorTests {
@Test
void findTestBeanFactoryMethodNoNameProvided() {
assertThatIllegalArgumentException()
.isThrownBy(() -> findTestBeanFactoryMethod(MethodConventionTestCase.class, ExampleService.class))
.isThrownBy(() -> this.processor.findTestBeanFactoryMethod(MethodConventionTestCase.class, ExampleService.class))
.withMessage("At least one candidate method name is required");
}
@ -102,9 +104,8 @@ class TestBeanOverrideProcessorTests {
TestBean overrideAnnotation = field.getAnnotation(TestBean.class);
assertThat(overrideAnnotation).isNotNull();
TestBeanOverrideProcessor processor = new TestBeanOverrideProcessor();
assertThatIllegalStateException()
.isThrownBy(() -> processor.createMetadata(overrideAnnotation, clazz, field))
.isThrownBy(() -> this.processor.createMetadata(overrideAnnotation, clazz, field))
.withMessage("""
Failed to find a static test bean factory method in %s with return type %s \
whose name matches one of the supported candidates %s""",
@ -118,8 +119,7 @@ class TestBeanOverrideProcessorTests {
TestBean overrideAnnotation = field.getAnnotation(TestBean.class);
assertThat(overrideAnnotation).isNotNull();
TestBeanOverrideProcessor processor = new TestBeanOverrideProcessor();
assertThat(processor.createMetadata(overrideAnnotation, clazz, field))
assertThat(this.processor.createMetadata(overrideAnnotation, clazz, field))
.isInstanceOf(TestBeanOverrideMetadata.class);
}
@ -131,8 +131,7 @@ class TestBeanOverrideProcessorTests {
TestBean overrideAnnotation = field.getAnnotation(TestBean.class);
assertThat(overrideAnnotation).isNotNull();
TestBeanOverrideProcessor processor = new TestBeanOverrideProcessor();
assertThatIllegalStateException().isThrownBy(() -> processor.createMetadata(
assertThatIllegalStateException().isThrownBy(() -> this.processor.createMetadata(
overrideAnnotation, clazz, field))
.withMessage("""
Failed to find a static test bean factory method in %s with return type %s \
@ -146,8 +145,7 @@ class TestBeanOverrideProcessorTests {
Field field = clazz.getField("field");
NonNull badAnnotation = AnnotationUtils.synthesizeAnnotation(NonNull.class);
TestBeanOverrideProcessor processor = new TestBeanOverrideProcessor();
assertThatIllegalStateException().isThrownBy(() -> processor.createMetadata(badAnnotation, clazz, field))
assertThatIllegalStateException().isThrownBy(() -> this.processor.createMetadata(badAnnotation, clazz, field))
.withMessage("Invalid annotation passed to TestBeanOverrideProcessor: expected @TestBean" +
" on field %s.%s", field.getDeclaringClass().getName(), field.getName());
}