Better support for FactoryBeans in BeanOverrideBeanFactoryPostProcessor

This commit makes sure to account for FactoryBean names when registering
a bean override. In the case of ReplaceDefinition mode, if there is a
factory bean name, it is used to check singleton status and as the name
in the registrar.

Closes gh-32971
This commit is contained in:
Simon Baslé 2024-06-06 15:49:31 +02:00
parent 416eff1b04
commit d38e4d869f
2 changed files with 103 additions and 4 deletions

View File

@ -34,6 +34,7 @@ import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.config.DependencyDescriptor;
import org.springframework.beans.factory.config.SmartInstantiationAwareBeanPostProcessor;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.DefaultBeanNameGenerator;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.core.Ordered;
import org.springframework.core.PriorityOrdered;
@ -128,6 +129,7 @@ class BeanOverrideBeanFactoryPostProcessor implements BeanFactoryPostProcessor,
RootBeanDefinition beanDefinition = createBeanDefinition(overrideMetadata);
String beanName = overrideMetadata.getBeanName();
String beanNameIncludingFactory;
BeanDefinition existingBeanDefinition = null;
if (beanName == null) {
Set<String> candidateNames = getExistingBeanNamesByType(beanFactory, overrideMetadata, true);
@ -139,7 +141,8 @@ class BeanOverrideBeanFactoryPostProcessor implements BeanFactoryPostProcessor,
" (as required by annotated field '" + field.getDeclaringClass().getSimpleName() +
"." + field.getName() + "')" + (candidateCount > 0 ? ": " + candidateNames : ""));
}
beanName = candidateNames.iterator().next();
beanNameIncludingFactory = candidateNames.iterator().next();
beanName = BeanFactoryUtils.transformedBeanName(beanNameIncludingFactory);
existingBeanDefinition = beanFactory.getBeanDefinition(beanName);
}
else {
@ -151,6 +154,7 @@ class BeanOverrideBeanFactoryPostProcessor implements BeanFactoryPostProcessor,
throw new IllegalStateException("Unable to override bean '" + beanName + "': there is no " +
"bean definition to replace with that name of type " + overrideMetadata.getBeanType());
}
beanNameIncludingFactory = beanName;
}
if (existingBeanDefinition != null) {
@ -160,7 +164,7 @@ class BeanOverrideBeanFactoryPostProcessor implements BeanFactoryPostProcessor,
registry.registerBeanDefinition(beanName, beanDefinition);
Object override = overrideMetadata.createOverride(beanName, existingBeanDefinition, null);
if (beanFactory.isSingleton(beanName)) {
if (beanFactory.isSingleton(beanNameIncludingFactory)) {
// Now we have an instance (the override) that we can register.
// At this stage we don't expect a singleton instance to be present,
// and this call will throw if there is such an instance already.
@ -168,7 +172,7 @@ class BeanOverrideBeanFactoryPostProcessor implements BeanFactoryPostProcessor,
}
overrideMetadata.track(override, beanFactory);
this.overrideRegistrar.registerNameForMetadata(overrideMetadata, beanName);
this.overrideRegistrar.registerNameForMetadata(overrideMetadata, beanNameIncludingFactory);
}
/**
@ -190,7 +194,7 @@ class BeanOverrideBeanFactoryPostProcessor implements BeanFactoryPostProcessor,
" (as required by annotated field '" + field.getDeclaringClass().getSimpleName() +
"." + field.getName() + "')" + (candidateCount > 0 ? ": " + candidateNames : ""));
}
beanName = candidateNames.iterator().next();
beanName = BeanFactoryUtils.transformedBeanName(candidateNames.iterator().next());
}
else {
Set<String> candidates = getExistingBeanNamesByType(beanFactory, metadata, false);

View File

@ -0,0 +1,95 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.test.context.bean.override.mockito;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.springframework.beans.factory.FactoryBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.test.context.junit.jupiter.SpringExtension;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
/**
* Test {@link MockitoBean @MockitoBean} for a factory bean.
*
* @author Phillip Webb
*/
@ExtendWith(SpringExtension.class)
class MockitoBeanForBeanFactoryIntegrationTests {
// spring-boot/gh-7439
@MockitoBean
private TestFactoryBean testFactoryBean;
@Autowired
private ApplicationContext applicationContext;
@Test
@SuppressWarnings({ "unchecked", "rawtypes" })
void testName() {
TestBean testBean = mock(TestBean.class);
given(testBean.hello()).willReturn("amock");
given(this.testFactoryBean.getObjectType()).willReturn((Class) TestBean.class);
given(this.testFactoryBean.getObject()).willReturn(testBean);
TestBean bean = this.applicationContext.getBean(TestBean.class);
assertThat(bean.hello()).isEqualTo("amock");
}
@Configuration(proxyBeanMethods = false)
static class Config {
@Bean
TestFactoryBean testFactoryBean() {
return new TestFactoryBean();
}
}
static class TestFactoryBean implements FactoryBean<TestBean> {
@Override
public TestBean getObject() {
return () -> "normal";
}
@Override
public Class<?> getObjectType() {
return TestBean.class;
}
@Override
public boolean isSingleton() {
return false;
}
}
interface TestBean {
String hello();
}
}