From 5d591edbf81b271fe2b5152f883bc14cb7c4f99c Mon Sep 17 00:00:00 2001 From: Jakub Kubrynski Date: Fri, 14 Feb 2014 22:53:37 -0800 Subject: [PATCH] Consider FactoryBean classes in OnBeanCondition Update OnBeanCondition to attempt to consider FactoryBean classes for bean type matches. To ensure early instantiation does not occur, the object type from the FactoryBean is deduced by resolving generics on the declaration. Fixes gh-355 --- .../condition/OnBeanCondition.java | 109 +++++++++++++++--- .../ConditionalOnMissingBeanTests.java | 93 +++++++++++---- 2 files changed, 167 insertions(+), 35 deletions(-) diff --git a/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/condition/OnBeanCondition.java b/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/condition/OnBeanCondition.java index 57ccd264d3e..435a5f659a0 100644 --- a/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/condition/OnBeanCondition.java +++ b/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/condition/OnBeanCondition.java @@ -20,15 +20,24 @@ import java.lang.annotation.Annotation; import java.lang.reflect.Method; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.LinkedHashSet; import java.util.List; +import java.util.Set; import org.springframework.beans.factory.BeanFactory; import org.springframework.beans.factory.BeanFactoryUtils; +import org.springframework.beans.factory.FactoryBean; +import org.springframework.beans.factory.HierarchicalBeanFactory; +import org.springframework.beans.factory.ListableBeanFactory; +import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Condition; import org.springframework.context.annotation.ConditionContext; import org.springframework.context.annotation.ConfigurationCondition; +import org.springframework.core.ResolvableType; import org.springframework.core.type.AnnotatedTypeMetadata; import org.springframework.core.type.MethodMetadata; import org.springframework.util.Assert; @@ -43,6 +52,7 @@ import org.springframework.util.StringUtils; * * @author Phillip Webb * @author Dave Syer + * @author Jakub Kubrynski */ class OnBeanCondition extends SpringBootCondition implements ConfigurationCondition { @@ -100,8 +110,8 @@ class OnBeanCondition extends SpringBootCondition implements ConfigurationCondit boolean considerHierarchy = beans.getStrategy() == SearchStrategy.ALL; for (String type : beans.getTypes()) { - beanNames.addAll(Arrays.asList(getBeanNamesForType(beanFactory, type, - context.getClassLoader(), considerHierarchy))); + beanNames.addAll(getBeanNamesForType(beanFactory, type, + context.getClassLoader(), considerHierarchy)); } for (String annotation : beans.getAnnotations()) { @@ -126,25 +136,94 @@ class OnBeanCondition extends SpringBootCondition implements ConfigurationCondit return beanFactory.containsLocalBean(beanName); } - private String[] getBeanNamesForType(ConfigurableListableBeanFactory beanFactory, - String type, ClassLoader classLoader, boolean considerHierarchy) - throws LinkageError { - // eagerInit set to false to prevent early instantiation (some - // factory beans will not be able to determine their object type at this - // stage, so those are not eligible for matching this condition) + private Collection getBeanNamesForType( + ConfigurableListableBeanFactory beanFactory, String type, + ClassLoader classLoader, boolean considerHierarchy) throws LinkageError { try { - Class typeClass = ClassUtils.forName(type, classLoader); - if (considerHierarchy) { - return BeanFactoryUtils.beanNamesForTypeIncludingAncestors(beanFactory, - typeClass, false, false); - } - return beanFactory.getBeanNamesForType(typeClass, false, false); + Set result = new LinkedHashSet(); + collectBeanNamesForType(result, beanFactory, + ClassUtils.forName(type, classLoader), considerHierarchy); + return result; } catch (ClassNotFoundException ex) { - return NO_BEANS; + return Collections.emptySet(); } } + private void collectBeanNamesForType(Set result, + ListableBeanFactory beanFactory, Class type, boolean considerHierarchy) { + // eagerInit set to false to prevent early instantiation + result.addAll(Arrays.asList(beanFactory.getBeanNamesForType(type, true, false))); + if (beanFactory instanceof ConfigurableListableBeanFactory) { + collectBeanNamesForTypeFromFactoryBeans(result, + (ConfigurableListableBeanFactory) beanFactory, type); + } + if (considerHierarchy && beanFactory instanceof HierarchicalBeanFactory) { + BeanFactory parent = ((HierarchicalBeanFactory) beanFactory) + .getParentBeanFactory(); + if (parent instanceof ListableBeanFactory) { + collectBeanNamesForType(result, (ListableBeanFactory) parent, type, + considerHierarchy); + } + } + } + + /** + * Attempt to collect bean names for type by considering FactoryBean generics. Some + * factory beans will not be able to determine their object type at this stage, so + * those are not eligible for matching this condition. + */ + private void collectBeanNamesForTypeFromFactoryBeans(Set result, + ConfigurableListableBeanFactory beanFactory, Class type) { + String[] names = beanFactory.getBeanNamesForType(FactoryBean.class, true, false); + for (String name : names) { + name = BeanFactoryUtils.transformedBeanName(name); + BeanDefinition beanDefinition = beanFactory.getBeanDefinition(name); + Class generic = getFactoryBeanGeneric(beanFactory, beanDefinition); + if (generic != null && ClassUtils.isAssignable(type, generic)) { + result.add(name); + } + } + } + + private Class getFactoryBeanGeneric(ConfigurableListableBeanFactory beanFactory, + BeanDefinition definition) { + try { + if (StringUtils.hasLength(definition.getFactoryBeanName()) + && StringUtils.hasLength(definition.getFactoryMethodName())) { + return getConfigurationClassFactoryBeanGeneric(beanFactory, definition); + } + if (StringUtils.hasLength(definition.getBeanClassName())) { + return getDirectFactoryBeanGeneric(beanFactory, definition); + } + } + catch (Exception ex) { + } + return null; + } + + private Class getConfigurationClassFactoryBeanGeneric( + ConfigurableListableBeanFactory beanFactory, BeanDefinition definition) + throws Exception { + BeanDefinition factoryDefinition = beanFactory.getBeanDefinition(definition + .getFactoryBeanName()); + Class factoryClass = ClassUtils.forName(factoryDefinition.getBeanClassName(), + beanFactory.getBeanClassLoader()); + Method method = ReflectionUtils.findMethod(factoryClass, + definition.getFactoryMethodName()); + return ResolvableType.forMethodReturnType(method).as(FactoryBean.class) + .resolveGeneric(); + } + + private Class getDirectFactoryBeanGeneric( + ConfigurableListableBeanFactory beanFactory, BeanDefinition definition) + throws ClassNotFoundException, LinkageError { + Class factoryBeanClass = ClassUtils.forName(definition.getBeanClassName(), + beanFactory.getBeanClassLoader()); + return ResolvableType.forClass(factoryBeanClass).as(FactoryBean.class) + .resolveGeneric(); + } + private String[] getBeanNamesForAnnotation( ConfigurableListableBeanFactory beanFactory, String type, ClassLoader classLoader, boolean considerHierarchy) throws LinkageError { diff --git a/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/condition/ConditionalOnMissingBeanTests.java b/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/condition/ConditionalOnMissingBeanTests.java index e96531ba5ce..50ab687e5da 100644 --- a/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/condition/ConditionalOnMissingBeanTests.java +++ b/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/condition/ConditionalOnMissingBeanTests.java @@ -16,7 +16,6 @@ package org.springframework.boot.autoconfigure.condition; -import org.junit.Ignore; import org.junit.Test; import org.springframework.beans.factory.FactoryBean; import org.springframework.boot.autoconfigure.PropertyPlaceholderAutoConfiguration; @@ -38,6 +37,7 @@ import static org.junit.Assert.assertTrue; * * @author Dave Syer * @author Phillip Webb + * @author Jakub Kubrynski */ @SuppressWarnings("resource") public class ConditionalOnMissingBeanTests { @@ -102,7 +102,7 @@ public class ConditionalOnMissingBeanTests { @Test public void testAnnotationOnMissingBeanConditionWithEagerFactoryBean() { this.context.register(FooConfiguration.class, OnAnnotationConfiguration.class, - ConfigurationWithFactoryBean.class, + FactoryBeanXmlConfiguration.class, PropertyPlaceholderAutoConfiguration.class); this.context.refresh(); assertFalse(this.context.containsBean("bar")); @@ -111,22 +111,44 @@ public class ConditionalOnMissingBeanTests { } @Test - @Ignore("This will never work - you need to use XML for FactoryBeans, or else call getObject() inside the @Bean method") public void testOnMissingBeanConditionWithFactoryBean() { - this.context.register(ExampleBeanAndFactoryBeanConfiguration.class, + this.context.register(FactoryBeanConfiguration.class, + ConditionalOnFactoryBean.class, PropertyPlaceholderAutoConfiguration.class); this.context.refresh(); - // There should be only one - this.context.getBean(ExampleBean.class); + assertThat(this.context.getBean(ExampleBean.class).toString(), + equalTo("fromFactory")); + } + + @Test + public void testOnMissingBeanConditionWithConcreteFactoryBean() { + this.context.register(ConcreteFactoryBeanConfiguration.class, + ConditionalOnFactoryBean.class, + PropertyPlaceholderAutoConfiguration.class); + this.context.refresh(); + assertThat(this.context.getBean(ExampleBean.class).toString(), + equalTo("fromFactory")); + } + + @Test + public void testOnMissingBeanConditionWithUnhelpfulFactoryBean() { + this.context.register(UnhelpfulFactoryBeanConfiguration.class, + ConditionalOnFactoryBean.class, + PropertyPlaceholderAutoConfiguration.class); + this.context.refresh(); + // We could not tell that the FactoryBean would ultimately create an ExampleBean + assertThat(this.context.getBeansOfType(ExampleBean.class).values().size(), + equalTo(2)); } @Test public void testOnMissingBeanConditionWithFactoryBeanInXml() { - this.context.register(ConfigurationWithFactoryBean.class, + this.context.register(FactoryBeanXmlConfiguration.class, + ConditionalOnFactoryBean.class, PropertyPlaceholderAutoConfiguration.class); this.context.refresh(); - // There should be only one - this.context.getBean(ExampleBean.class); + assertThat(this.context.getBean(ExampleBean.class).toString(), + equalTo("fromFactory")); } @Configuration @@ -139,17 +161,41 @@ public class ConditionalOnMissingBeanTests { } @Configuration - protected static class ExampleBeanAndFactoryBeanConfiguration { - + protected static class FactoryBeanConfiguration { @Bean public FactoryBean exampleBeanFactoryBean() { return new ExampleFactoryBean("foo"); } + } + @Configuration + protected static class ConcreteFactoryBeanConfiguration { + @Bean + public ExampleFactoryBean exampleBeanFactoryBean() { + return new ExampleFactoryBean("foo"); + } + } + + @Configuration + protected static class UnhelpfulFactoryBeanConfiguration { + @Bean + @SuppressWarnings("rawtypes") + public FactoryBean exampleBeanFactoryBean() { + return new ExampleFactoryBean("foo"); + } + } + + @Configuration + @ImportResource("org/springframework/boot/autoconfigure/condition/factorybean.xml") + protected static class FactoryBeanXmlConfiguration { + } + + @Configuration + protected static class ConditionalOnFactoryBean { @Bean @ConditionalOnMissingBean(ExampleBean.class) public ExampleBean createExampleBean() { - return new ExampleBean(); + return new ExampleBean("direct"); } } @@ -162,11 +208,6 @@ public class ConditionalOnMissingBeanTests { } } - @Configuration - @ImportResource("org/springframework/boot/autoconfigure/condition/factorybean.xml") - protected static class ConfigurationWithFactoryBean { - } - @Configuration @EnableScheduling protected static class FooConfiguration { @@ -198,7 +239,7 @@ public class ConditionalOnMissingBeanTests { protected static class ExampleBeanConfiguration { @Bean public ExampleBean exampleBean() { - return new ExampleBean(); + return new ExampleBean("test"); } } @@ -208,12 +249,24 @@ public class ConditionalOnMissingBeanTests { @Bean @ConditionalOnMissingBean public ExampleBean exampleBean2() { - return new ExampleBean(); + return new ExampleBean("test"); } } public static class ExampleBean { + + private String value; + + public ExampleBean(String value) { + this.value = value; + } + + @Override + public String toString() { + return this.value; + } + } public static class ExampleFactoryBean implements FactoryBean { @@ -224,7 +277,7 @@ public class ConditionalOnMissingBeanTests { @Override public ExampleBean getObject() throws Exception { - return new ExampleBean(); + return new ExampleBean("fromFactory"); } @Override