Reject null return value from MethodReplacer for primitive return type
This commit throws an exception instead of silently converting a null return value from a MethodReplacer to a primitive 0/false value. See gh-32412
This commit is contained in:
parent
f285971cb3
commit
3e48031601
|
|
@ -18,6 +18,7 @@ package org.springframework.beans.factory.support;
|
||||||
|
|
||||||
import java.lang.reflect.Constructor;
|
import java.lang.reflect.Constructor;
|
||||||
import java.lang.reflect.Method;
|
import java.lang.reflect.Method;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
import org.apache.commons.logging.Log;
|
import org.apache.commons.logging.Log;
|
||||||
import org.apache.commons.logging.LogFactory;
|
import org.apache.commons.logging.LogFactory;
|
||||||
|
|
@ -275,13 +276,24 @@ public class CglibSubclassingInstantiationStrategy extends SimpleInstantiationSt
|
||||||
this.owner = owner;
|
this.owner = owner;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Nullable
|
||||||
@Override
|
@Override
|
||||||
public Object intercept(Object obj, Method method, Object[] args, MethodProxy mp) throws Throwable {
|
public Object intercept(Object obj, Method method, Object[] args, MethodProxy mp) throws Throwable {
|
||||||
ReplaceOverride ro = (ReplaceOverride) getBeanDefinition().getMethodOverrides().getOverride(method);
|
ReplaceOverride ro = (ReplaceOverride) getBeanDefinition().getMethodOverrides().getOverride(method);
|
||||||
Assert.state(ro != null, "ReplaceOverride not found");
|
Assert.state(ro != null, "ReplaceOverride not found");
|
||||||
// TODO could cache if a singleton for minor performance optimization
|
// TODO could cache if a singleton for minor performance optimization
|
||||||
MethodReplacer mr = this.owner.getBean(ro.getMethodReplacerBeanName(), MethodReplacer.class);
|
MethodReplacer mr = this.owner.getBean(ro.getMethodReplacerBeanName(), MethodReplacer.class);
|
||||||
return mr.reimplement(obj, method, args);
|
return processReturnType(method, mr.reimplement(obj, method, args));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Nullable
|
||||||
|
private <T> T processReturnType(Method method, @Nullable T returnValue) {
|
||||||
|
Class<?> returnType = method.getReturnType();
|
||||||
|
if (returnType != void.class && returnType.isPrimitive()) {
|
||||||
|
return Objects.requireNonNull(returnValue, () -> "Null return value from replacer does not match primitive return type for: " + method);
|
||||||
|
}
|
||||||
|
|
||||||
|
return returnValue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,120 @@
|
||||||
|
package org.springframework.beans.factory.support;
|
||||||
|
|
||||||
|
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||||
|
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
|
||||||
|
|
||||||
|
import java.lang.reflect.Method;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
|
import org.assertj.core.api.ThrowableAssert;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.springframework.lang.Nullable;
|
||||||
|
|
||||||
|
class CglibSubclassingInstantiationStrategyTests {
|
||||||
|
|
||||||
|
private final CglibSubclassingInstantiationStrategy strategy = new CglibSubclassingInstantiationStrategy();
|
||||||
|
|
||||||
|
@Nullable
|
||||||
|
public static Object valueToReturnFromReplacer;
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void methodOverride() {
|
||||||
|
StaticListableBeanFactory beanFactory = new StaticListableBeanFactory(Map.of(
|
||||||
|
"myBean", new MyBean(),
|
||||||
|
"replacer", new MyReplacer()
|
||||||
|
));
|
||||||
|
|
||||||
|
RootBeanDefinition bd = new RootBeanDefinition(MyBean.class);
|
||||||
|
MethodOverrides methodOverrides = new MethodOverrides();
|
||||||
|
Stream.of("getBoolean", "getShort", "getInt", "getLong", "getFloat", "getDouble", "getByte")
|
||||||
|
.forEach(methodToOverride -> addOverride(methodOverrides, methodToOverride));
|
||||||
|
bd.setMethodOverrides(methodOverrides);
|
||||||
|
|
||||||
|
MyBean bean = (MyBean) strategy.instantiate(bd, "myBean", beanFactory);
|
||||||
|
|
||||||
|
valueToReturnFromReplacer = null;
|
||||||
|
assertCorrectExceptionThrownBy(bean::getBoolean);
|
||||||
|
valueToReturnFromReplacer = true;
|
||||||
|
assertThat(bean.getBoolean()).isTrue();
|
||||||
|
|
||||||
|
valueToReturnFromReplacer = null;
|
||||||
|
assertCorrectExceptionThrownBy(bean::getShort);
|
||||||
|
valueToReturnFromReplacer = 123;
|
||||||
|
assertThat(bean.getShort()).isEqualTo((short) 123);
|
||||||
|
|
||||||
|
valueToReturnFromReplacer = null;
|
||||||
|
assertCorrectExceptionThrownBy(bean::getInt);
|
||||||
|
valueToReturnFromReplacer = 123;
|
||||||
|
assertThat(bean.getInt()).isEqualTo(123);
|
||||||
|
|
||||||
|
valueToReturnFromReplacer = null;
|
||||||
|
assertCorrectExceptionThrownBy(bean::getLong);
|
||||||
|
valueToReturnFromReplacer = 123;
|
||||||
|
assertThat(bean.getLong()).isEqualTo(123L);
|
||||||
|
|
||||||
|
valueToReturnFromReplacer = null;
|
||||||
|
assertCorrectExceptionThrownBy(bean::getFloat);
|
||||||
|
valueToReturnFromReplacer = 123;
|
||||||
|
assertThat(bean.getFloat()).isEqualTo(123f);
|
||||||
|
|
||||||
|
valueToReturnFromReplacer = null;
|
||||||
|
assertCorrectExceptionThrownBy(bean::getDouble);
|
||||||
|
valueToReturnFromReplacer = 123;
|
||||||
|
assertThat(bean.getDouble()).isEqualTo(123d);
|
||||||
|
|
||||||
|
valueToReturnFromReplacer = null;
|
||||||
|
assertCorrectExceptionThrownBy(bean::getByte);
|
||||||
|
valueToReturnFromReplacer = 123;
|
||||||
|
assertThat(bean.getByte()).isEqualTo((byte) 123);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void assertCorrectExceptionThrownBy(ThrowableAssert.ThrowingCallable runnable) {
|
||||||
|
assertThatThrownBy(runnable)
|
||||||
|
.isInstanceOf(NullPointerException.class)
|
||||||
|
.hasMessageMatching("Null return value from replacer does not match primitive return type for: "
|
||||||
|
+ "\\w+ org\\.springframework\\.beans\\.factory\\.support\\.CglibSubclassingInstantiationStrategyTests\\$MyBean\\.\\w+\\(\\)");
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addOverride(MethodOverrides methodOverrides, String methodToOverride) {
|
||||||
|
methodOverrides.addOverride(new ReplaceOverride(methodToOverride, "replacer"));
|
||||||
|
}
|
||||||
|
|
||||||
|
static class MyBean {
|
||||||
|
boolean getBoolean() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
short getShort() {
|
||||||
|
return 123;
|
||||||
|
}
|
||||||
|
|
||||||
|
int getInt() {
|
||||||
|
return 123;
|
||||||
|
}
|
||||||
|
|
||||||
|
long getLong() {
|
||||||
|
return 123;
|
||||||
|
}
|
||||||
|
|
||||||
|
float getFloat() {
|
||||||
|
return 123;
|
||||||
|
}
|
||||||
|
|
||||||
|
double getDouble() {
|
||||||
|
return 123;
|
||||||
|
}
|
||||||
|
|
||||||
|
byte getByte() {
|
||||||
|
return 123;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static class MyReplacer implements MethodReplacer {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Object reimplement(Object obj, Method method, Object[] args) {
|
||||||
|
return CglibSubclassingInstantiationStrategyTests.valueToReturnFromReplacer;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue