Support @MockBean/@SpyBean with @Primary

Update `MockitoPostProcessor` so that `@MockBean` and `@SpyBean`
work consistently when combined with `@Primary`.

See gh-11077

Co-authored-by: Andreas Neiser <andreas.neiser@gmail.com>
This commit is contained in:
Phillip Webb 2018-09-05 16:30:05 -07:00
parent 82b27c60a4
commit c777614d8f
1 changed files with 37 additions and 25 deletions

View File

@ -19,6 +19,7 @@ package org.springframework.boot.test.mock.mockito;
import java.beans.PropertyDescriptor; import java.beans.PropertyDescriptor;
import java.lang.reflect.Field; import java.lang.reflect.Field;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap; import java.util.HashMap;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
@ -184,6 +185,8 @@ public class MockitoPostProcessor extends InstantiationAwareBeanPostProcessorAda
String beanName = getBeanName(beanFactory, registry, definition, beanDefinition); String beanName = getBeanName(beanFactory, registry, definition, beanDefinition);
String transformedBeanName = BeanFactoryUtils.transformedBeanName(beanName); String transformedBeanName = BeanFactoryUtils.transformedBeanName(beanName);
if (registry.containsBeanDefinition(transformedBeanName)) { if (registry.containsBeanDefinition(transformedBeanName)) {
BeanDefinition existing = registry.getBeanDefinition(transformedBeanName);
copyBeanDefinitionDetails(existing, beanDefinition);
registry.removeBeanDefinition(transformedBeanName); registry.removeBeanDefinition(transformedBeanName);
} }
registry.registerBeanDefinition(transformedBeanName, beanDefinition); registry.registerBeanDefinition(transformedBeanName, beanDefinition);
@ -196,6 +199,10 @@ public class MockitoPostProcessor extends InstantiationAwareBeanPostProcessorAda
} }
} }
private void copyBeanDefinitionDetails(BeanDefinition from, RootBeanDefinition to) {
to.setPrimary(from.isPrimary());
}
private RootBeanDefinition createBeanDefinition(MockDefinition mockDefinition) { private RootBeanDefinition createBeanDefinition(MockDefinition mockDefinition) {
RootBeanDefinition definition = new RootBeanDefinition( RootBeanDefinition definition = new RootBeanDefinition(
mockDefinition.getTypeToMock().resolve()); mockDefinition.getTypeToMock().resolve());
@ -212,13 +219,19 @@ public class MockitoPostProcessor extends InstantiationAwareBeanPostProcessorAda
if (StringUtils.hasLength(mockDefinition.getName())) { if (StringUtils.hasLength(mockDefinition.getName())) {
return mockDefinition.getName(); return mockDefinition.getName();
} }
Set<String> existingBeans = findCandidateBeans(beanFactory, mockDefinition); Set<String> existingBeans = findCandidateBeans(beanFactory,
mockDefinition.getTypeToMock(), mockDefinition.getQualifier());
if (existingBeans.isEmpty()) { if (existingBeans.isEmpty()) {
return this.beanNameGenerator.generateBeanName(beanDefinition, registry); return this.beanNameGenerator.generateBeanName(beanDefinition, registry);
} }
if (existingBeans.size() == 1) { if (existingBeans.size() == 1) {
return existingBeans.iterator().next(); return existingBeans.iterator().next();
} }
String primaryCandidate = determinePrimaryCandidate(registry, existingBeans,
mockDefinition.getTypeToMock());
if (primaryCandidate != null) {
return primaryCandidate;
}
throw new IllegalStateException( throw new IllegalStateException(
"Unable to register mock bean " + mockDefinition.getTypeToMock() "Unable to register mock bean " + mockDefinition.getTypeToMock()
+ " expected a single matching bean to replace but found " + " expected a single matching bean to replace but found "
@ -226,22 +239,21 @@ public class MockitoPostProcessor extends InstantiationAwareBeanPostProcessorAda
} }
private void registerSpy(ConfigurableListableBeanFactory beanFactory, private void registerSpy(ConfigurableListableBeanFactory beanFactory,
BeanDefinitionRegistry registry, SpyDefinition definition, Field field) { BeanDefinitionRegistry registry, SpyDefinition spyDefinition, Field field) {
String[] existingBeans = getExistingBeans(beanFactory, definition.getTypeToSpy()); Set<String> existingBeans = findCandidateBeans(beanFactory,
spyDefinition.getTypeToSpy(), spyDefinition.getQualifier());
if (ObjectUtils.isEmpty(existingBeans)) { if (ObjectUtils.isEmpty(existingBeans)) {
createSpy(registry, definition, field); createSpy(registry, spyDefinition, field);
} }
else { else {
registerSpies(registry, definition, field, existingBeans); registerSpies(registry, spyDefinition, field, existingBeans);
} }
} }
private Set<String> findCandidateBeans(ConfigurableListableBeanFactory beanFactory, private Set<String> findCandidateBeans(ConfigurableListableBeanFactory beanFactory,
MockDefinition mockDefinition) { ResolvableType type, QualifierDefinition qualifier) {
QualifierDefinition qualifier = mockDefinition.getQualifier();
Set<String> candidates = new TreeSet<>(); Set<String> candidates = new TreeSet<>();
for (String candidate : getExistingBeans(beanFactory, for (String candidate : getExistingBeans(beanFactory, type)) {
mockDefinition.getTypeToMock())) {
if (qualifier == null || qualifier.matches(beanFactory, candidate)) { if (qualifier == null || qualifier.matches(beanFactory, candidate)) {
candidates.add(candidate); candidates.add(candidate);
} }
@ -249,7 +261,7 @@ public class MockitoPostProcessor extends InstantiationAwareBeanPostProcessorAda
return candidates; return candidates;
} }
private String[] getExistingBeans(ConfigurableListableBeanFactory beanFactory, private Set<String> getExistingBeans(ConfigurableListableBeanFactory beanFactory,
ResolvableType type) { ResolvableType type) {
Set<String> beans = new LinkedHashSet<>( Set<String> beans = new LinkedHashSet<>(
Arrays.asList(beanFactory.getBeanNamesForType(type))); Arrays.asList(beanFactory.getBeanNamesForType(type)));
@ -263,7 +275,7 @@ public class MockitoPostProcessor extends InstantiationAwareBeanPostProcessorAda
} }
} }
beans.removeIf(this::isScopedTarget); beans.removeIf(this::isScopedTarget);
return StringUtils.toStringArray(beans); return beans;
} }
private boolean isScopedTarget(String beanName) { private boolean isScopedTarget(String beanName) {
@ -275,49 +287,49 @@ public class MockitoPostProcessor extends InstantiationAwareBeanPostProcessorAda
} }
} }
private void createSpy(BeanDefinitionRegistry registry, SpyDefinition definition, private void createSpy(BeanDefinitionRegistry registry, SpyDefinition spyDefinition,
Field field) { Field field) {
RootBeanDefinition beanDefinition = new RootBeanDefinition( RootBeanDefinition beanDefinition = new RootBeanDefinition(
definition.getTypeToSpy().resolve()); spyDefinition.getTypeToSpy().resolve());
String beanName = this.beanNameGenerator.generateBeanName(beanDefinition, String beanName = this.beanNameGenerator.generateBeanName(beanDefinition,
registry); registry);
registry.registerBeanDefinition(beanName, beanDefinition); registry.registerBeanDefinition(beanName, beanDefinition);
registerSpy(definition, field, beanName); registerSpy(spyDefinition, field, beanName);
} }
private void registerSpies(BeanDefinitionRegistry registry, SpyDefinition definition, private void registerSpies(BeanDefinitionRegistry registry,
Field field, String[] existingBeans) { SpyDefinition spyDefinition, Field field, Collection<String> existingBeans) {
try { try {
registerSpy(definition, field, String beanName = determineBeanName(existingBeans, spyDefinition, registry);
determineBeanName(existingBeans, definition, registry)); registerSpy(spyDefinition, field, beanName);
} }
catch (RuntimeException ex) { catch (RuntimeException ex) {
throw new IllegalStateException( throw new IllegalStateException(
"Unable to register spy bean " + definition.getTypeToSpy(), ex); "Unable to register spy bean " + spyDefinition.getTypeToSpy(), ex);
} }
} }
private String determineBeanName(String[] existingBeans, SpyDefinition definition, private String determineBeanName(Collection<String> existingBeans,
BeanDefinitionRegistry registry) { SpyDefinition definition, BeanDefinitionRegistry registry) {
if (StringUtils.hasText(definition.getName())) { if (StringUtils.hasText(definition.getName())) {
return definition.getName(); return definition.getName();
} }
if (existingBeans.length == 1) { if (existingBeans.size() == 1) {
return existingBeans[0]; return existingBeans.iterator().next();
} }
return determinePrimaryCandidate(registry, existingBeans, return determinePrimaryCandidate(registry, existingBeans,
definition.getTypeToSpy()); definition.getTypeToSpy());
} }
private String determinePrimaryCandidate(BeanDefinitionRegistry registry, private String determinePrimaryCandidate(BeanDefinitionRegistry registry,
String[] candidateBeanNames, ResolvableType type) { Collection<String> candidateBeanNames, ResolvableType type) {
String primaryBeanName = null; String primaryBeanName = null;
for (String candidateBeanName : candidateBeanNames) { for (String candidateBeanName : candidateBeanNames) {
BeanDefinition beanDefinition = registry.getBeanDefinition(candidateBeanName); BeanDefinition beanDefinition = registry.getBeanDefinition(candidateBeanName);
if (beanDefinition.isPrimary()) { if (beanDefinition.isPrimary()) {
if (primaryBeanName != null) { if (primaryBeanName != null) {
throw new NoUniqueBeanDefinitionException(type.resolve(), throw new NoUniqueBeanDefinitionException(type.resolve(),
candidateBeanNames.length, candidateBeanNames.size(),
"more than one 'primary' bean found among candidates: " "more than one 'primary' bean found among candidates: "
+ Arrays.asList(candidateBeanNames)); + Arrays.asList(candidateBeanNames));
} }