Honor @⁠Primary for test Bean Overrides such as @⁠MockitoBean

Spring Boot has honored @⁠Primary for selecting which candidate bean
@⁠MockBean and @⁠SpyBean should mock or spy since Spring Boot 1.4.3;
however, the support for @⁠Primary was not ported from Spring Boot to
Spring Framework's new Bean Overrides feature in the TestContext
framework.

To address that, this commit introduces support for @⁠Primary for
selecting bean overrides -- for example, for annotations such as
@⁠TestBean, @⁠MockitoBean, and @⁠MockitoSpyBean.

See https://github.com/spring-projects/spring-boot/issues/7621
Closes gh-33819
This commit is contained in:
Sam Brannen 2024-10-30 15:44:06 +01:00
parent 9166688b6f
commit 08e0baac94
4 changed files with 106 additions and 22 deletions

View File

@ -26,6 +26,7 @@ import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanFactoryUtils;
import org.springframework.beans.factory.FactoryBean;
import org.springframework.beans.factory.NoUniqueBeanDefinitionException;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
@ -219,29 +220,43 @@ class BeanOverrideBeanFactoryPostProcessor implements BeanFactoryPostProcessor,
*/
private void wrapBean(ConfigurableListableBeanFactory beanFactory, BeanOverrideHandler handler) {
String beanName = handler.getBeanName();
ResolvableType beanType = handler.getBeanType();
if (beanName == null) {
// We are wrapping an existing bean by-type.
Set<String> candidateNames = getExistingBeanNamesByType(beanFactory, handler, true);
int candidateCount = candidateNames.size();
if (candidateCount != 1) {
Field field = handler.getField();
throw new IllegalStateException("""
Unable to select a bean to override by wrapping: found %d bean instances of type %s \
(as required by annotated field '%s.%s')%s"""
.formatted(candidateCount, handler.getBeanType(),
field.getDeclaringClass().getSimpleName(), field.getName(),
(candidateCount > 0 ? ": " + candidateNames : "")));
if (candidateCount == 1) {
beanName = candidateNames.iterator().next();
}
beanName = BeanFactoryUtils.transformedBeanName(candidateNames.iterator().next());
else {
String primaryCandidate = determinePrimaryCandidate(beanFactory, candidateNames, beanType.toClass());
if (primaryCandidate != null) {
beanName = primaryCandidate;
}
else {
Field field = handler.getField();
throw new IllegalStateException("""
Unable to select a bean to override by wrapping: found %d bean instances of type %s \
(as required by annotated field '%s.%s')%s"""
.formatted(candidateCount, beanType, field.getDeclaringClass().getSimpleName(),
field.getName(), (candidateCount > 0 ? ": " + candidateNames : "")));
}
}
beanName = BeanFactoryUtils.transformedBeanName(beanName);
}
else {
// We are wrapping an existing bean by-name.
Set<String> candidates = getExistingBeanNamesByType(beanFactory, handler, false);
if (!candidates.contains(beanName)) {
throw new IllegalStateException("""
Unable to override bean by wrapping: there is no existing bean \
with name [%s] and type [%s]."""
.formatted(beanName, handler.getBeanType()));
.formatted(beanName, beanType));
}
}
validateBeanDefinition(beanFactory, beanName);
this.beanOverrideRegistry.registerBeanOverrideHandler(handler, beanName);
}
@ -250,6 +265,9 @@ class BeanOverrideBeanFactoryPostProcessor implements BeanFactoryPostProcessor,
private String getBeanNameForType(ConfigurableListableBeanFactory beanFactory, BeanOverrideHandler handler,
boolean requireExistingBean) {
Field field = handler.getField();
ResolvableType beanType = handler.getBeanType();
Set<String> candidateNames = getExistingBeanNamesByType(beanFactory, handler, true);
int candidateCount = candidateNames.size();
if (candidateCount == 1) {
@ -257,19 +275,22 @@ class BeanOverrideBeanFactoryPostProcessor implements BeanFactoryPostProcessor,
}
else if (candidateCount == 0) {
if (requireExistingBean) {
Field field = handler.getField();
throw new IllegalStateException(
"Unable to override bean: no beans of type %s (as required by annotated field '%s.%s')"
.formatted(handler.getBeanType(), field.getDeclaringClass().getSimpleName(), field.getName()));
.formatted(beanType, field.getDeclaringClass().getSimpleName(), field.getName()));
}
return null;
}
Field field = handler.getField();
String primaryCandidate = determinePrimaryCandidate(beanFactory, candidateNames, beanType.toClass());
if (primaryCandidate != null) {
return primaryCandidate;
}
throw new IllegalStateException("""
Unable to select a bean to override: found %s beans of type %s \
(as required by annotated field '%s.%s'): %s"""
.formatted(candidateCount, handler.getBeanType(), field.getDeclaringClass().getSimpleName(),
.formatted(candidateCount, beanType, field.getDeclaringClass().getSimpleName(),
field.getName(), candidateNames));
}
@ -310,6 +331,30 @@ class BeanOverrideBeanFactoryPostProcessor implements BeanFactoryPostProcessor,
return beanNames;
}
@Nullable
private static String determinePrimaryCandidate(
ConfigurableListableBeanFactory beanFactory, Set<String> candidateBeanNames, Class<?> beanType) {
if (candidateBeanNames.isEmpty()) {
return null;
}
String primaryBeanName = null;
for (String candidateBeanName : candidateBeanNames) {
if (beanFactory.containsBeanDefinition(candidateBeanName)) {
BeanDefinition beanDefinition = beanFactory.getBeanDefinition(candidateBeanName);
if (beanDefinition.isPrimary()) {
if (primaryBeanName != null) {
throw new NoUniqueBeanDefinitionException(beanType, candidateBeanNames.size(),
"more than one 'primary' bean found among candidates: " + candidateBeanNames);
}
primaryBeanName = candidateBeanName;
}
}
}
return primaryBeanName;
}
/**
* Create a pseudo-{@link BeanDefinition} for the supplied {@link BeanOverrideHandler},
* whose {@linkplain RootBeanDefinition#getTargetType() target type} and

View File

@ -26,6 +26,7 @@ import org.junit.jupiter.api.Test;
import org.springframework.beans.BeanWrapper;
import org.springframework.beans.factory.FactoryBean;
import org.springframework.beans.factory.NoUniqueBeanDefinitionException;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
@ -42,6 +43,7 @@ import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.util.Assert;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
import static org.assertj.core.api.Assertions.assertThatNoException;
import static org.mockito.Mockito.mock;
@ -167,6 +169,41 @@ class BeanOverrideBeanFactoryPostProcessorTests {
assertThat(context.getBean("counter")).isSameAs(42);
}
@Test // gh-33819
void replaceBeanByTypeWithMultipleCandidatesAndOnePrimary() {
AnnotationConfigApplicationContext context = createContext(TestBeanByTypeTestCase.class);
context.registerBean("description1", String.class, () -> "one");
RootBeanDefinition beanDefinition2 = new RootBeanDefinition(String.class);
beanDefinition2.getConstructorArgumentValues().addIndexedArgumentValue(0, "two");
beanDefinition2.setPrimary(true);
context.registerBeanDefinition("description2", beanDefinition2);
context.refresh();
assertThat(context.getBean("description1", String.class)).isEqualTo("one");
assertThat(context.getBean("description2", String.class)).isEqualTo("overridden");
assertThat(context.getBean(String.class)).isEqualTo("overridden");
}
@Test // gh-33819
void replaceBeanByTypeWithMultipleCandidatesAndMultiplePrimaryBeansFails() {
AnnotationConfigApplicationContext context = createContext(TestBeanByTypeTestCase.class);
RootBeanDefinition beanDefinition1 = new RootBeanDefinition(String.class);
beanDefinition1.getConstructorArgumentValues().addIndexedArgumentValue(0, "one");
beanDefinition1.setPrimary(true);
context.registerBeanDefinition("description1", beanDefinition1);
RootBeanDefinition beanDefinition2 = new RootBeanDefinition(String.class);
beanDefinition2.getConstructorArgumentValues().addIndexedArgumentValue(0, "two");
beanDefinition2.setPrimary(true);
context.registerBeanDefinition("description2", beanDefinition2);
assertThatExceptionOfType(NoUniqueBeanDefinitionException.class)
.isThrownBy(context::refresh)
.withMessage("No qualifying bean of type 'java.lang.String' available: " +
"more than one 'primary' bean found among candidates: [description1, description2]");
}
@Test
void createOrReplaceBeanByNameWithMatchingBeanDefinition() {
AnnotationConfigApplicationContext context = createContext(CaseByNameWithReplaceOrCreateStrategy.class);
@ -428,6 +465,16 @@ class BeanOverrideBeanFactoryPostProcessorTests {
}
}
static class TestBeanByTypeTestCase {
@TestBean
String description;
static String description() {
return "overridden";
}
}
static class TestFactoryBean implements FactoryBean<Object> {
@Override

View File

@ -16,7 +16,6 @@
package org.springframework.test.context.bean.override.mockito.integration;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.MockingDetails;
@ -27,7 +26,6 @@ import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;
import org.springframework.context.annotation.Primary;
import org.springframework.test.context.aot.DisabledInAotMode;
import org.springframework.test.context.bean.override.example.ExampleGenericServiceCaller;
import org.springframework.test.context.bean.override.example.IntegerExampleGenericService;
import org.springframework.test.context.bean.override.example.StringExampleGenericService;
@ -49,8 +47,6 @@ import static org.mockito.Mockito.mockingDetails;
* @see MockitoBeanWithMultipleExistingBeansAndExplicitBeanNameIntegrationTests
* @see MockitoBeanWithMultipleExistingBeansAndExplicitQualifierIntegrationTests
*/
@Disabled("Disabled until @Primary is supported for BeanOverrideStrategy.REPLACE_OR_CREATE")
@DisabledInAotMode
@ExtendWith(SpringExtension.class)
class MockitoBeanWithMultipleExistingBeansAndOnePrimaryIntegrationTests {

View File

@ -16,7 +16,6 @@
package org.springframework.test.context.bean.override.mockito.integration;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.MockingDetails;
@ -27,7 +26,6 @@ import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;
import org.springframework.context.annotation.Primary;
import org.springframework.test.context.aot.DisabledInAotMode;
import org.springframework.test.context.bean.override.example.ExampleGenericServiceCaller;
import org.springframework.test.context.bean.override.example.IntegerExampleGenericService;
import org.springframework.test.context.bean.override.example.StringExampleGenericService;
@ -48,8 +46,6 @@ import static org.mockito.Mockito.mockingDetails;
* @see MockitoSpyBeanWithMultipleExistingBeansAndExplicitBeanNameIntegrationTests
* @see MockitoSpyBeanWithMultipleExistingBeansAndExplicitQualifierIntegrationTests
*/
@Disabled("Disabled until @Primary is supported for @MockitoSpyBean")
@DisabledInAotMode
@ExtendWith(SpringExtension.class)
class MockitoSpyBeanWithMultipleExistingBeansAndOnePrimaryIntegrationTests {