Fixed type resolution for uninitialized factory-method declaration

Issue: SPR-11112
This commit is contained in:
Juergen Hoeller 2013-12-09 18:53:27 +01:00
parent a3b022aa48
commit 5dcd28761c
4 changed files with 119 additions and 13 deletions

View File

@ -657,10 +657,10 @@ public abstract class AbstractAutowireCapableBeanFactory extends AbstractBeanFac
// If all factory methods have the same return type, return that type. // If all factory methods have the same return type, return that type.
// Can't clearly figure out exact method due to type converting / autowiring! // Can't clearly figure out exact method due to type converting / autowiring!
Class<?> commonType = null;
boolean cache = false; boolean cache = false;
int minNrOfArgs = mbd.getConstructorArgumentValues().getArgumentCount(); int minNrOfArgs = mbd.getConstructorArgumentValues().getArgumentCount();
Method[] candidates = ReflectionUtils.getUniqueDeclaredMethods(factoryClass); Method[] candidates = ReflectionUtils.getUniqueDeclaredMethods(factoryClass);
Set<Class<?>> returnTypes = new HashSet<Class<?>>(1);
for (Method factoryMethod : candidates) { for (Method factoryMethod : candidates) {
if (Modifier.isStatic(factoryMethod.getModifiers()) == isStatic && if (Modifier.isStatic(factoryMethod.getModifiers()) == isStatic &&
factoryMethod.getName().equals(mbd.getFactoryMethodName()) && factoryMethod.getName().equals(mbd.getFactoryMethodName()) &&
@ -694,7 +694,7 @@ public abstract class AbstractAutowireCapableBeanFactory extends AbstractBeanFac
factoryMethod, args, getBeanClassLoader()); factoryMethod, args, getBeanClassLoader());
if (returnType != null) { if (returnType != null) {
cache = true; cache = true;
returnTypes.add(returnType); commonType = ClassUtils.determineCommonAncestor(returnType, commonType);
} }
} }
catch (Throwable ex) { catch (Throwable ex) {
@ -704,18 +704,17 @@ public abstract class AbstractAutowireCapableBeanFactory extends AbstractBeanFac
} }
} }
else { else {
returnTypes.add(factoryMethod.getReturnType()); commonType = ClassUtils.determineCommonAncestor(factoryMethod.getReturnType(), commonType);
} }
} }
} }
if (returnTypes.size() == 1) { if (commonType != null) {
// Clear return type found: all factory methods return same type. // Clear return type found: all factory methods return same type.
Class<?> result = returnTypes.iterator().next();
if (cache) { if (cache) {
mbd.resolvedFactoryMethodReturnType = result; mbd.resolvedFactoryMethodReturnType = commonType;
} }
return result; return commonType;
} }
else { else {
// Ambiguous return types found: return null to indicate "not determinable". // Ambiguous return types found: return null to indicate "not determinable".

View File

@ -54,11 +54,11 @@ public class FactoryMethods {
return new FactoryMethods(tb, name, num); return new FactoryMethods(tb, name, num);
} }
static FactoryMethods newInstance(TestBean tb, int num, Integer something) { static ExtendedFactoryMethods newInstance(TestBean tb, int num, Integer something) {
if (something != null) { if (something != null) {
throw new IllegalStateException("Should never be called with non-null value"); throw new IllegalStateException("Should never be called with non-null value");
} }
return new FactoryMethods(tb, null, num); return new ExtendedFactoryMethods(tb, null, num);
} }
@SuppressWarnings("unused") @SuppressWarnings("unused")
@ -119,4 +119,12 @@ public class FactoryMethods {
this.name = name; this.name = name;
} }
public static class ExtendedFactoryMethods extends FactoryMethods {
ExtendedFactoryMethods(TestBean tb, String name, int num) {
super(tb, name, num);
}
}
} }

View File

@ -1143,6 +1143,39 @@ public abstract class ClassUtils {
return Proxy.getProxyClass(classLoader, interfaces); return Proxy.getProxyClass(classLoader, interfaces);
} }
/**
* Determine the common ancestor of the given classes, if any.
* @param clazz1 the class to introspect
* @param clazz2 the other class to introspect
* @return the common ancestor (i.e. common superclass, one interface
* extending the other), or {@code null} if none found. If any of the
* given classes is {@code null}, the other class will be returned.
* @since 3.2.6
*/
public static Class<?> determineCommonAncestor(Class<?> clazz1, Class<?> clazz2) {
if (clazz1 == null) {
return clazz2;
}
if (clazz2 == null) {
return clazz1;
}
if (clazz1.isAssignableFrom(clazz2)) {
return clazz1;
}
if (clazz2.isAssignableFrom(clazz1)) {
return clazz2;
}
Class<?> ancestor = clazz1;
do {
ancestor = ancestor.getSuperclass();
if (ancestor == null || Object.class.equals(ancestor)) {
return null;
}
}
while (!ancestor.isAssignableFrom(clazz2));
return ancestor;
}
/** /**
* Check whether the given class is visible in the given ClassLoader. * Check whether the given class is visible in the given ClassLoader.
* @param clazz the class to check (typically an interface) * @param clazz the class to check (typically an interface)

View File

@ -20,41 +20,48 @@ import java.io.Serializable;
import java.lang.reflect.InvocationTargetException; import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.lang.reflect.Proxy; import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Set;
import junit.framework.TestCase;
import org.springframework.tests.sample.objects.DerivedTestObject; import org.springframework.tests.sample.objects.DerivedTestObject;
import org.springframework.tests.sample.objects.ITestInterface; import org.springframework.tests.sample.objects.ITestInterface;
import org.springframework.tests.sample.objects.ITestObject; import org.springframework.tests.sample.objects.ITestObject;
import org.springframework.tests.sample.objects.TestObject; import org.springframework.tests.sample.objects.TestObject;
import static org.junit.Assert.*;
import org.junit.Before;
import org.junit.Test;
/** /**
* @author Colin Sampaleanu * @author Colin Sampaleanu
* @author Juergen Hoeller * @author Juergen Hoeller
* @author Rob Harrop * @author Rob Harrop
* @author Rick Evans * @author Rick Evans
*/ */
public class ClassUtilsTests extends TestCase { public class ClassUtilsTests {
private ClassLoader classLoader = getClass().getClassLoader(); private ClassLoader classLoader = getClass().getClassLoader();
@Override @Before
public void setUp() { public void setUp() {
InnerClass.noArgCalled = false; InnerClass.noArgCalled = false;
InnerClass.argCalled = false; InnerClass.argCalled = false;
InnerClass.overloadedCalled = false; InnerClass.overloadedCalled = false;
} }
@Test
public void testIsPresent() throws Exception { public void testIsPresent() throws Exception {
assertTrue(ClassUtils.isPresent("java.lang.String", classLoader)); assertTrue(ClassUtils.isPresent("java.lang.String", classLoader));
assertFalse(ClassUtils.isPresent("java.lang.MySpecialString", classLoader)); assertFalse(ClassUtils.isPresent("java.lang.MySpecialString", classLoader));
} }
@Test
public void testForName() throws ClassNotFoundException { public void testForName() throws ClassNotFoundException {
assertEquals(String.class, ClassUtils.forName("java.lang.String", classLoader)); assertEquals(String.class, ClassUtils.forName("java.lang.String", classLoader));
assertEquals(String[].class, ClassUtils.forName("java.lang.String[]", classLoader)); assertEquals(String[].class, ClassUtils.forName("java.lang.String[]", classLoader));
@ -69,6 +76,7 @@ public class ClassUtilsTests extends TestCase {
assertEquals(short[][][].class, ClassUtils.forName("[[[S", classLoader)); assertEquals(short[][][].class, ClassUtils.forName("[[[S", classLoader));
} }
@Test
public void testForNameWithPrimitiveClasses() throws ClassNotFoundException { public void testForNameWithPrimitiveClasses() throws ClassNotFoundException {
assertEquals(boolean.class, ClassUtils.forName("boolean", classLoader)); assertEquals(boolean.class, ClassUtils.forName("boolean", classLoader));
assertEquals(byte.class, ClassUtils.forName("byte", classLoader)); assertEquals(byte.class, ClassUtils.forName("byte", classLoader));
@ -81,6 +89,7 @@ public class ClassUtilsTests extends TestCase {
assertEquals(void.class, ClassUtils.forName("void", classLoader)); assertEquals(void.class, ClassUtils.forName("void", classLoader));
} }
@Test
public void testForNameWithPrimitiveArrays() throws ClassNotFoundException { public void testForNameWithPrimitiveArrays() throws ClassNotFoundException {
assertEquals(boolean[].class, ClassUtils.forName("boolean[]", classLoader)); assertEquals(boolean[].class, ClassUtils.forName("boolean[]", classLoader));
assertEquals(byte[].class, ClassUtils.forName("byte[]", classLoader)); assertEquals(byte[].class, ClassUtils.forName("byte[]", classLoader));
@ -92,6 +101,7 @@ public class ClassUtilsTests extends TestCase {
assertEquals(double[].class, ClassUtils.forName("double[]", classLoader)); assertEquals(double[].class, ClassUtils.forName("double[]", classLoader));
} }
@Test
public void testForNameWithPrimitiveArraysInternalName() throws ClassNotFoundException { public void testForNameWithPrimitiveArraysInternalName() throws ClassNotFoundException {
assertEquals(boolean[].class, ClassUtils.forName(boolean[].class.getName(), classLoader)); assertEquals(boolean[].class, ClassUtils.forName(boolean[].class.getName(), classLoader));
assertEquals(byte[].class, ClassUtils.forName(byte[].class.getName(), classLoader)); assertEquals(byte[].class, ClassUtils.forName(byte[].class.getName(), classLoader));
@ -103,76 +113,91 @@ public class ClassUtilsTests extends TestCase {
assertEquals(double[].class, ClassUtils.forName(double[].class.getName(), classLoader)); assertEquals(double[].class, ClassUtils.forName(double[].class.getName(), classLoader));
} }
@Test
public void testGetShortName() { public void testGetShortName() {
String className = ClassUtils.getShortName(getClass()); String className = ClassUtils.getShortName(getClass());
assertEquals("Class name did not match", "ClassUtilsTests", className); assertEquals("Class name did not match", "ClassUtilsTests", className);
} }
@Test
public void testGetShortNameForObjectArrayClass() { public void testGetShortNameForObjectArrayClass() {
String className = ClassUtils.getShortName(Object[].class); String className = ClassUtils.getShortName(Object[].class);
assertEquals("Class name did not match", "Object[]", className); assertEquals("Class name did not match", "Object[]", className);
} }
@Test
public void testGetShortNameForMultiDimensionalObjectArrayClass() { public void testGetShortNameForMultiDimensionalObjectArrayClass() {
String className = ClassUtils.getShortName(Object[][].class); String className = ClassUtils.getShortName(Object[][].class);
assertEquals("Class name did not match", "Object[][]", className); assertEquals("Class name did not match", "Object[][]", className);
} }
@Test
public void testGetShortNameForPrimitiveArrayClass() { public void testGetShortNameForPrimitiveArrayClass() {
String className = ClassUtils.getShortName(byte[].class); String className = ClassUtils.getShortName(byte[].class);
assertEquals("Class name did not match", "byte[]", className); assertEquals("Class name did not match", "byte[]", className);
} }
@Test
public void testGetShortNameForMultiDimensionalPrimitiveArrayClass() { public void testGetShortNameForMultiDimensionalPrimitiveArrayClass() {
String className = ClassUtils.getShortName(byte[][][].class); String className = ClassUtils.getShortName(byte[][][].class);
assertEquals("Class name did not match", "byte[][][]", className); assertEquals("Class name did not match", "byte[][][]", className);
} }
@Test
public void testGetShortNameForInnerClass() { public void testGetShortNameForInnerClass() {
String className = ClassUtils.getShortName(InnerClass.class); String className = ClassUtils.getShortName(InnerClass.class);
assertEquals("Class name did not match", "ClassUtilsTests.InnerClass", className); assertEquals("Class name did not match", "ClassUtilsTests.InnerClass", className);
} }
@Test
public void testGetShortNameAsProperty() { public void testGetShortNameAsProperty() {
String shortName = ClassUtils.getShortNameAsProperty(this.getClass()); String shortName = ClassUtils.getShortNameAsProperty(this.getClass());
assertEquals("Class name did not match", "classUtilsTests", shortName); assertEquals("Class name did not match", "classUtilsTests", shortName);
} }
@Test
public void testGetClassFileName() { public void testGetClassFileName() {
assertEquals("String.class", ClassUtils.getClassFileName(String.class)); assertEquals("String.class", ClassUtils.getClassFileName(String.class));
assertEquals("ClassUtilsTests.class", ClassUtils.getClassFileName(getClass())); assertEquals("ClassUtilsTests.class", ClassUtils.getClassFileName(getClass()));
} }
@Test
public void testGetPackageName() { public void testGetPackageName() {
assertEquals("java.lang", ClassUtils.getPackageName(String.class)); assertEquals("java.lang", ClassUtils.getPackageName(String.class));
assertEquals(getClass().getPackage().getName(), ClassUtils.getPackageName(getClass())); assertEquals(getClass().getPackage().getName(), ClassUtils.getPackageName(getClass()));
} }
@Test
public void testGetQualifiedName() { public void testGetQualifiedName() {
String className = ClassUtils.getQualifiedName(getClass()); String className = ClassUtils.getQualifiedName(getClass());
assertEquals("Class name did not match", "org.springframework.util.ClassUtilsTests", className); assertEquals("Class name did not match", "org.springframework.util.ClassUtilsTests", className);
} }
@Test
public void testGetQualifiedNameForObjectArrayClass() { public void testGetQualifiedNameForObjectArrayClass() {
String className = ClassUtils.getQualifiedName(Object[].class); String className = ClassUtils.getQualifiedName(Object[].class);
assertEquals("Class name did not match", "java.lang.Object[]", className); assertEquals("Class name did not match", "java.lang.Object[]", className);
} }
@Test
public void testGetQualifiedNameForMultiDimensionalObjectArrayClass() { public void testGetQualifiedNameForMultiDimensionalObjectArrayClass() {
String className = ClassUtils.getQualifiedName(Object[][].class); String className = ClassUtils.getQualifiedName(Object[][].class);
assertEquals("Class name did not match", "java.lang.Object[][]", className); assertEquals("Class name did not match", "java.lang.Object[][]", className);
} }
@Test
public void testGetQualifiedNameForPrimitiveArrayClass() { public void testGetQualifiedNameForPrimitiveArrayClass() {
String className = ClassUtils.getQualifiedName(byte[].class); String className = ClassUtils.getQualifiedName(byte[].class);
assertEquals("Class name did not match", "byte[]", className); assertEquals("Class name did not match", "byte[]", className);
} }
@Test
public void testGetQualifiedNameForMultiDimensionalPrimitiveArrayClass() { public void testGetQualifiedNameForMultiDimensionalPrimitiveArrayClass() {
String className = ClassUtils.getQualifiedName(byte[][].class); String className = ClassUtils.getQualifiedName(byte[][].class);
assertEquals("Class name did not match", "byte[][]", className); assertEquals("Class name did not match", "byte[][]", className);
} }
@Test
public void testHasMethod() throws Exception { public void testHasMethod() throws Exception {
assertTrue(ClassUtils.hasMethod(Collection.class, "size")); assertTrue(ClassUtils.hasMethod(Collection.class, "size"));
assertTrue(ClassUtils.hasMethod(Collection.class, "remove", Object.class)); assertTrue(ClassUtils.hasMethod(Collection.class, "remove", Object.class));
@ -180,6 +205,7 @@ public class ClassUtilsTests extends TestCase {
assertFalse(ClassUtils.hasMethod(Collection.class, "someOtherMethod")); assertFalse(ClassUtils.hasMethod(Collection.class, "someOtherMethod"));
} }
@Test
public void testGetMethodIfAvailable() throws Exception { public void testGetMethodIfAvailable() throws Exception {
Method method = ClassUtils.getMethodIfAvailable(Collection.class, "size"); Method method = ClassUtils.getMethodIfAvailable(Collection.class, "size");
assertNotNull(method); assertNotNull(method);
@ -193,6 +219,7 @@ public class ClassUtilsTests extends TestCase {
assertNull(ClassUtils.getMethodIfAvailable(Collection.class, "someOtherMethod")); assertNull(ClassUtils.getMethodIfAvailable(Collection.class, "someOtherMethod"));
} }
@Test
public void testGetMethodCountForName() { public void testGetMethodCountForName() {
assertEquals("Verifying number of overloaded 'print' methods for OverloadedMethodsClass.", 2, assertEquals("Verifying number of overloaded 'print' methods for OverloadedMethodsClass.", 2,
ClassUtils.getMethodCountForName(OverloadedMethodsClass.class, "print")); ClassUtils.getMethodCountForName(OverloadedMethodsClass.class, "print"));
@ -200,6 +227,7 @@ public class ClassUtilsTests extends TestCase {
ClassUtils.getMethodCountForName(SubOverloadedMethodsClass.class, "print")); ClassUtils.getMethodCountForName(SubOverloadedMethodsClass.class, "print"));
} }
@Test
public void testCountOverloadedMethods() { public void testCountOverloadedMethods() {
assertFalse(ClassUtils.hasAtLeastOneMethodWithName(TestObject.class, "foobar")); assertFalse(ClassUtils.hasAtLeastOneMethodWithName(TestObject.class, "foobar"));
// no args // no args
@ -208,6 +236,7 @@ public class ClassUtilsTests extends TestCase {
assertTrue(ClassUtils.hasAtLeastOneMethodWithName(TestObject.class, "setAge")); assertTrue(ClassUtils.hasAtLeastOneMethodWithName(TestObject.class, "setAge"));
} }
@Test
public void testNoArgsStaticMethod() throws IllegalAccessException, InvocationTargetException { public void testNoArgsStaticMethod() throws IllegalAccessException, InvocationTargetException {
Method method = ClassUtils.getStaticMethod(InnerClass.class, "staticMethod", (Class[]) null); Method method = ClassUtils.getStaticMethod(InnerClass.class, "staticMethod", (Class[]) null);
method.invoke(null, (Object[]) null); method.invoke(null, (Object[]) null);
@ -215,6 +244,7 @@ public class ClassUtilsTests extends TestCase {
InnerClass.noArgCalled); InnerClass.noArgCalled);
} }
@Test
public void testArgsStaticMethod() throws IllegalAccessException, InvocationTargetException { public void testArgsStaticMethod() throws IllegalAccessException, InvocationTargetException {
Method method = ClassUtils.getStaticMethod(InnerClass.class, "argStaticMethod", Method method = ClassUtils.getStaticMethod(InnerClass.class, "argStaticMethod",
new Class[] {String.class}); new Class[] {String.class});
@ -222,6 +252,7 @@ public class ClassUtilsTests extends TestCase {
assertTrue("argument method was not invoked.", InnerClass.argCalled); assertTrue("argument method was not invoked.", InnerClass.argCalled);
} }
@Test
public void testOverloadedStaticMethod() throws IllegalAccessException, InvocationTargetException { public void testOverloadedStaticMethod() throws IllegalAccessException, InvocationTargetException {
Method method = ClassUtils.getStaticMethod(InnerClass.class, "staticMethod", Method method = ClassUtils.getStaticMethod(InnerClass.class, "staticMethod",
new Class[] {String.class}); new Class[] {String.class});
@ -230,6 +261,7 @@ public class ClassUtilsTests extends TestCase {
InnerClass.overloadedCalled); InnerClass.overloadedCalled);
} }
@Test
public void testIsAssignable() { public void testIsAssignable() {
assertTrue(ClassUtils.isAssignable(Object.class, Object.class)); assertTrue(ClassUtils.isAssignable(Object.class, Object.class));
assertTrue(ClassUtils.isAssignable(String.class, String.class)); assertTrue(ClassUtils.isAssignable(String.class, String.class));
@ -245,11 +277,13 @@ public class ClassUtilsTests extends TestCase {
assertFalse(ClassUtils.isAssignable(double.class, Integer.class)); assertFalse(ClassUtils.isAssignable(double.class, Integer.class));
} }
@Test
public void testClassPackageAsResourcePath() { public void testClassPackageAsResourcePath() {
String result = ClassUtils.classPackageAsResourcePath(Proxy.class); String result = ClassUtils.classPackageAsResourcePath(Proxy.class);
assertTrue(result.equals("java/lang/reflect")); assertTrue(result.equals("java/lang/reflect"));
} }
@Test
public void testAddResourcePathToPackagePath() { public void testAddResourcePathToPackagePath() {
String result = "java/lang/reflect/xyzabc.xml"; String result = "java/lang/reflect/xyzabc.xml";
assertEquals(result, ClassUtils.addResourcePathToPackagePath(Proxy.class, "xyzabc.xml")); assertEquals(result, ClassUtils.addResourcePathToPackagePath(Proxy.class, "xyzabc.xml"));
@ -259,6 +293,7 @@ public class ClassUtilsTests extends TestCase {
ClassUtils.addResourcePathToPackagePath(Proxy.class, "a/b/c/d.xml")); ClassUtils.addResourcePathToPackagePath(Proxy.class, "a/b/c/d.xml"));
} }
@Test
public void testGetAllInterfaces() { public void testGetAllInterfaces() {
DerivedTestObject testBean = new DerivedTestObject(); DerivedTestObject testBean = new DerivedTestObject();
List ifcs = Arrays.asList(ClassUtils.getAllInterfaces(testBean)); List ifcs = Arrays.asList(ClassUtils.getAllInterfaces(testBean));
@ -268,6 +303,7 @@ public class ClassUtilsTests extends TestCase {
assertTrue("Contains IOther", ifcs.contains(ITestInterface.class)); assertTrue("Contains IOther", ifcs.contains(ITestInterface.class));
} }
@Test
public void testClassNamesToString() { public void testClassNamesToString() {
List ifcs = new LinkedList(); List ifcs = new LinkedList();
ifcs.add(Serializable.class); ifcs.add(Serializable.class);
@ -288,6 +324,36 @@ public class ClassUtilsTests extends TestCase {
assertEquals("[]", ClassUtils.classNamesToString(Collections.EMPTY_LIST)); assertEquals("[]", ClassUtils.classNamesToString(Collections.EMPTY_LIST));
} }
@Test
public void testDetermineCommonAncestor() {
assertEquals(Number.class, ClassUtils.determineCommonAncestor(Integer.class, Number.class));
assertEquals(Number.class, ClassUtils.determineCommonAncestor(Number.class, Integer.class));
assertEquals(Number.class, ClassUtils.determineCommonAncestor(Number.class, null));
assertEquals(Integer.class, ClassUtils.determineCommonAncestor(null, Integer.class));
assertEquals(Integer.class, ClassUtils.determineCommonAncestor(Integer.class, Integer.class));
assertEquals(Number.class, ClassUtils.determineCommonAncestor(Integer.class, Float.class));
assertEquals(Number.class, ClassUtils.determineCommonAncestor(Float.class, Integer.class));
assertNull(ClassUtils.determineCommonAncestor(Integer.class, String.class));
assertNull(ClassUtils.determineCommonAncestor(String.class, Integer.class));
assertEquals(Collection.class, ClassUtils.determineCommonAncestor(List.class, Collection.class));
assertEquals(Collection.class, ClassUtils.determineCommonAncestor(Collection.class, List.class));
assertEquals(Collection.class, ClassUtils.determineCommonAncestor(Collection.class, null));
assertEquals(List.class, ClassUtils.determineCommonAncestor(null, List.class));
assertEquals(List.class, ClassUtils.determineCommonAncestor(List.class, List.class));
assertNull(ClassUtils.determineCommonAncestor(List.class, Set.class));
assertNull(ClassUtils.determineCommonAncestor(Set.class, List.class));
assertNull(ClassUtils.determineCommonAncestor(List.class, Runnable.class));
assertNull(ClassUtils.determineCommonAncestor(Runnable.class, List.class));
assertEquals(List.class, ClassUtils.determineCommonAncestor(List.class, ArrayList.class));
assertEquals(List.class, ClassUtils.determineCommonAncestor(ArrayList.class, List.class));
assertNull(ClassUtils.determineCommonAncestor(List.class, String.class));
assertNull(ClassUtils.determineCommonAncestor(String.class, List.class));
}
public static class InnerClass { public static class InnerClass {