From c777614d8f3a1072b9e3d03acf6adf2f766f9e48 Mon Sep 17 00:00:00 2001 From: Phillip Webb Date: Wed, 5 Sep 2018 16:30:05 -0700 Subject: [PATCH] Support @MockBean/@SpyBean with @Primary Update `MockitoPostProcessor` so that `@MockBean` and `@SpyBean` work consistently when combined with `@Primary`. See gh-11077 Co-authored-by: Andreas Neiser --- .../mock/mockito/MockitoPostProcessor.java | 62 +++++++++++-------- 1 file changed, 37 insertions(+), 25 deletions(-) diff --git a/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/mock/mockito/MockitoPostProcessor.java b/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/mock/mockito/MockitoPostProcessor.java index 91f4b8cdac3..82112e6bc59 100644 --- a/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/mock/mockito/MockitoPostProcessor.java +++ b/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/mock/mockito/MockitoPostProcessor.java @@ -19,6 +19,7 @@ package org.springframework.boot.test.mock.mockito; import java.beans.PropertyDescriptor; import java.lang.reflect.Field; import java.util.Arrays; +import java.util.Collection; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.LinkedHashSet; @@ -184,6 +185,8 @@ public class MockitoPostProcessor extends InstantiationAwareBeanPostProcessorAda String beanName = getBeanName(beanFactory, registry, definition, beanDefinition); String transformedBeanName = BeanFactoryUtils.transformedBeanName(beanName); if (registry.containsBeanDefinition(transformedBeanName)) { + BeanDefinition existing = registry.getBeanDefinition(transformedBeanName); + copyBeanDefinitionDetails(existing, beanDefinition); registry.removeBeanDefinition(transformedBeanName); } registry.registerBeanDefinition(transformedBeanName, beanDefinition); @@ -196,6 +199,10 @@ public class MockitoPostProcessor extends InstantiationAwareBeanPostProcessorAda } } + private void copyBeanDefinitionDetails(BeanDefinition from, RootBeanDefinition to) { + to.setPrimary(from.isPrimary()); + } + private RootBeanDefinition createBeanDefinition(MockDefinition mockDefinition) { RootBeanDefinition definition = new RootBeanDefinition( mockDefinition.getTypeToMock().resolve()); @@ -212,13 +219,19 @@ public class MockitoPostProcessor extends InstantiationAwareBeanPostProcessorAda if (StringUtils.hasLength(mockDefinition.getName())) { return mockDefinition.getName(); } - Set existingBeans = findCandidateBeans(beanFactory, mockDefinition); + Set existingBeans = findCandidateBeans(beanFactory, + mockDefinition.getTypeToMock(), mockDefinition.getQualifier()); if (existingBeans.isEmpty()) { return this.beanNameGenerator.generateBeanName(beanDefinition, registry); } if (existingBeans.size() == 1) { return existingBeans.iterator().next(); } + String primaryCandidate = determinePrimaryCandidate(registry, existingBeans, + mockDefinition.getTypeToMock()); + if (primaryCandidate != null) { + return primaryCandidate; + } throw new IllegalStateException( "Unable to register mock bean " + mockDefinition.getTypeToMock() + " expected a single matching bean to replace but found " @@ -226,22 +239,21 @@ public class MockitoPostProcessor extends InstantiationAwareBeanPostProcessorAda } private void registerSpy(ConfigurableListableBeanFactory beanFactory, - BeanDefinitionRegistry registry, SpyDefinition definition, Field field) { - String[] existingBeans = getExistingBeans(beanFactory, definition.getTypeToSpy()); + BeanDefinitionRegistry registry, SpyDefinition spyDefinition, Field field) { + Set existingBeans = findCandidateBeans(beanFactory, + spyDefinition.getTypeToSpy(), spyDefinition.getQualifier()); if (ObjectUtils.isEmpty(existingBeans)) { - createSpy(registry, definition, field); + createSpy(registry, spyDefinition, field); } else { - registerSpies(registry, definition, field, existingBeans); + registerSpies(registry, spyDefinition, field, existingBeans); } } private Set findCandidateBeans(ConfigurableListableBeanFactory beanFactory, - MockDefinition mockDefinition) { - QualifierDefinition qualifier = mockDefinition.getQualifier(); + ResolvableType type, QualifierDefinition qualifier) { Set candidates = new TreeSet<>(); - for (String candidate : getExistingBeans(beanFactory, - mockDefinition.getTypeToMock())) { + for (String candidate : getExistingBeans(beanFactory, type)) { if (qualifier == null || qualifier.matches(beanFactory, candidate)) { candidates.add(candidate); } @@ -249,7 +261,7 @@ public class MockitoPostProcessor extends InstantiationAwareBeanPostProcessorAda return candidates; } - private String[] getExistingBeans(ConfigurableListableBeanFactory beanFactory, + private Set getExistingBeans(ConfigurableListableBeanFactory beanFactory, ResolvableType type) { Set beans = new LinkedHashSet<>( Arrays.asList(beanFactory.getBeanNamesForType(type))); @@ -263,7 +275,7 @@ public class MockitoPostProcessor extends InstantiationAwareBeanPostProcessorAda } } beans.removeIf(this::isScopedTarget); - return StringUtils.toStringArray(beans); + return beans; } private boolean isScopedTarget(String beanName) { @@ -275,49 +287,49 @@ public class MockitoPostProcessor extends InstantiationAwareBeanPostProcessorAda } } - private void createSpy(BeanDefinitionRegistry registry, SpyDefinition definition, + private void createSpy(BeanDefinitionRegistry registry, SpyDefinition spyDefinition, Field field) { RootBeanDefinition beanDefinition = new RootBeanDefinition( - definition.getTypeToSpy().resolve()); + spyDefinition.getTypeToSpy().resolve()); String beanName = this.beanNameGenerator.generateBeanName(beanDefinition, registry); registry.registerBeanDefinition(beanName, beanDefinition); - registerSpy(definition, field, beanName); + registerSpy(spyDefinition, field, beanName); } - private void registerSpies(BeanDefinitionRegistry registry, SpyDefinition definition, - Field field, String[] existingBeans) { + private void registerSpies(BeanDefinitionRegistry registry, + SpyDefinition spyDefinition, Field field, Collection existingBeans) { try { - registerSpy(definition, field, - determineBeanName(existingBeans, definition, registry)); + String beanName = determineBeanName(existingBeans, spyDefinition, registry); + registerSpy(spyDefinition, field, beanName); } catch (RuntimeException ex) { throw new IllegalStateException( - "Unable to register spy bean " + definition.getTypeToSpy(), ex); + "Unable to register spy bean " + spyDefinition.getTypeToSpy(), ex); } } - private String determineBeanName(String[] existingBeans, SpyDefinition definition, - BeanDefinitionRegistry registry) { + private String determineBeanName(Collection existingBeans, + SpyDefinition definition, BeanDefinitionRegistry registry) { if (StringUtils.hasText(definition.getName())) { return definition.getName(); } - if (existingBeans.length == 1) { - return existingBeans[0]; + if (existingBeans.size() == 1) { + return existingBeans.iterator().next(); } return determinePrimaryCandidate(registry, existingBeans, definition.getTypeToSpy()); } private String determinePrimaryCandidate(BeanDefinitionRegistry registry, - String[] candidateBeanNames, ResolvableType type) { + Collection candidateBeanNames, ResolvableType type) { String primaryBeanName = null; for (String candidateBeanName : candidateBeanNames) { BeanDefinition beanDefinition = registry.getBeanDefinition(candidateBeanName); if (beanDefinition.isPrimary()) { if (primaryBeanName != null) { throw new NoUniqueBeanDefinitionException(type.resolve(), - candidateBeanNames.length, + candidateBeanNames.size(), "more than one 'primary' bean found among candidates: " + Arrays.asList(candidateBeanNames)); }