Reject null return value from MethodReplacer for primitive return type
This commit throws an exception instead of silently converting a null return value from a MethodReplacer to a primitive 0/false value. See gh-32412
This commit is contained in:
parent
f285971cb3
commit
3e48031601
|
|
@ -18,6 +18,7 @@ package org.springframework.beans.factory.support;
|
|||
|
||||
import java.lang.reflect.Constructor;
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.Objects;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
|
|
@ -275,13 +276,24 @@ public class CglibSubclassingInstantiationStrategy extends SimpleInstantiationSt
|
|||
this.owner = owner;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
@Override
|
||||
public Object intercept(Object obj, Method method, Object[] args, MethodProxy mp) throws Throwable {
|
||||
ReplaceOverride ro = (ReplaceOverride) getBeanDefinition().getMethodOverrides().getOverride(method);
|
||||
Assert.state(ro != null, "ReplaceOverride not found");
|
||||
// TODO could cache if a singleton for minor performance optimization
|
||||
MethodReplacer mr = this.owner.getBean(ro.getMethodReplacerBeanName(), MethodReplacer.class);
|
||||
return mr.reimplement(obj, method, args);
|
||||
return processReturnType(method, mr.reimplement(obj, method, args));
|
||||
}
|
||||
|
||||
@Nullable
|
||||
private <T> T processReturnType(Method method, @Nullable T returnValue) {
|
||||
Class<?> returnType = method.getReturnType();
|
||||
if (returnType != void.class && returnType.isPrimitive()) {
|
||||
return Objects.requireNonNull(returnValue, () -> "Null return value from replacer does not match primitive return type for: " + method);
|
||||
}
|
||||
|
||||
return returnValue;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,120 @@
|
|||
package org.springframework.beans.factory.support;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
|
||||
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import org.assertj.core.api.ThrowableAssert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.lang.Nullable;
|
||||
|
||||
class CglibSubclassingInstantiationStrategyTests {
|
||||
|
||||
private final CglibSubclassingInstantiationStrategy strategy = new CglibSubclassingInstantiationStrategy();
|
||||
|
||||
@Nullable
|
||||
public static Object valueToReturnFromReplacer;
|
||||
|
||||
@Test
|
||||
void methodOverride() {
|
||||
StaticListableBeanFactory beanFactory = new StaticListableBeanFactory(Map.of(
|
||||
"myBean", new MyBean(),
|
||||
"replacer", new MyReplacer()
|
||||
));
|
||||
|
||||
RootBeanDefinition bd = new RootBeanDefinition(MyBean.class);
|
||||
MethodOverrides methodOverrides = new MethodOverrides();
|
||||
Stream.of("getBoolean", "getShort", "getInt", "getLong", "getFloat", "getDouble", "getByte")
|
||||
.forEach(methodToOverride -> addOverride(methodOverrides, methodToOverride));
|
||||
bd.setMethodOverrides(methodOverrides);
|
||||
|
||||
MyBean bean = (MyBean) strategy.instantiate(bd, "myBean", beanFactory);
|
||||
|
||||
valueToReturnFromReplacer = null;
|
||||
assertCorrectExceptionThrownBy(bean::getBoolean);
|
||||
valueToReturnFromReplacer = true;
|
||||
assertThat(bean.getBoolean()).isTrue();
|
||||
|
||||
valueToReturnFromReplacer = null;
|
||||
assertCorrectExceptionThrownBy(bean::getShort);
|
||||
valueToReturnFromReplacer = 123;
|
||||
assertThat(bean.getShort()).isEqualTo((short) 123);
|
||||
|
||||
valueToReturnFromReplacer = null;
|
||||
assertCorrectExceptionThrownBy(bean::getInt);
|
||||
valueToReturnFromReplacer = 123;
|
||||
assertThat(bean.getInt()).isEqualTo(123);
|
||||
|
||||
valueToReturnFromReplacer = null;
|
||||
assertCorrectExceptionThrownBy(bean::getLong);
|
||||
valueToReturnFromReplacer = 123;
|
||||
assertThat(bean.getLong()).isEqualTo(123L);
|
||||
|
||||
valueToReturnFromReplacer = null;
|
||||
assertCorrectExceptionThrownBy(bean::getFloat);
|
||||
valueToReturnFromReplacer = 123;
|
||||
assertThat(bean.getFloat()).isEqualTo(123f);
|
||||
|
||||
valueToReturnFromReplacer = null;
|
||||
assertCorrectExceptionThrownBy(bean::getDouble);
|
||||
valueToReturnFromReplacer = 123;
|
||||
assertThat(bean.getDouble()).isEqualTo(123d);
|
||||
|
||||
valueToReturnFromReplacer = null;
|
||||
assertCorrectExceptionThrownBy(bean::getByte);
|
||||
valueToReturnFromReplacer = 123;
|
||||
assertThat(bean.getByte()).isEqualTo((byte) 123);
|
||||
}
|
||||
|
||||
private void assertCorrectExceptionThrownBy(ThrowableAssert.ThrowingCallable runnable) {
|
||||
assertThatThrownBy(runnable)
|
||||
.isInstanceOf(NullPointerException.class)
|
||||
.hasMessageMatching("Null return value from replacer does not match primitive return type for: "
|
||||
+ "\\w+ org\\.springframework\\.beans\\.factory\\.support\\.CglibSubclassingInstantiationStrategyTests\\$MyBean\\.\\w+\\(\\)");
|
||||
}
|
||||
|
||||
private void addOverride(MethodOverrides methodOverrides, String methodToOverride) {
|
||||
methodOverrides.addOverride(new ReplaceOverride(methodToOverride, "replacer"));
|
||||
}
|
||||
|
||||
static class MyBean {
|
||||
boolean getBoolean() {
|
||||
return true;
|
||||
}
|
||||
|
||||
short getShort() {
|
||||
return 123;
|
||||
}
|
||||
|
||||
int getInt() {
|
||||
return 123;
|
||||
}
|
||||
|
||||
long getLong() {
|
||||
return 123;
|
||||
}
|
||||
|
||||
float getFloat() {
|
||||
return 123;
|
||||
}
|
||||
|
||||
double getDouble() {
|
||||
return 123;
|
||||
}
|
||||
|
||||
byte getByte() {
|
||||
return 123;
|
||||
}
|
||||
}
|
||||
|
||||
static class MyReplacer implements MethodReplacer {
|
||||
|
||||
@Override
|
||||
public Object reimplement(Object obj, Method method, Object[] args) {
|
||||
return CglibSubclassingInstantiationStrategyTests.valueToReturnFromReplacer;
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue