fix: Preserve NULL Parameters in case of AOP Proxy

Signed-off-by: Shah Nisarg Pankaj <nisargshah693@gmail.com>
This commit is contained in:
Nisarg Shah 2025-06-08 14:12:01 +00:00 committed by Shah Nisarg Pankaj
parent de7d50d39f
commit 254bf589a3
4 changed files with 97 additions and 21 deletions

View File

@ -381,7 +381,7 @@ public abstract class AopUtils {
Continuation<?> continuation = (Continuation<?>) args[args.length -1]; Continuation<?> continuation = (Continuation<?>) args[args.length -1];
Assert.state(continuation != null, "No Continuation available"); Assert.state(continuation != null, "No Continuation available");
CoroutineContext context = continuation.getContext().minusKey(Job.Key); CoroutineContext context = continuation.getContext().minusKey(Job.Key);
return CoroutinesUtils.invokeSuspendingFunction(context, method, target, args); return CoroutinesUtils.invokeSuspendingFunctionPreserveNulls(context, method, target, args);
} }
} }

View File

@ -43,6 +43,17 @@ class AopUtilsKotlinTests {
} }
} }
@Test
fun `Invoking suspending function with null argument should not return default value`() {
val method = ReflectionUtils.findMethod(WithoutInterface::class.java, "handleWithDefaultParam",
String::class. java, Continuation::class.java)!!
val continuation = Continuation<Any>(CoroutineName("test")) { }
val result = AopUtils.invokeJoinpointUsingReflection(WithoutInterface(), method, arrayOf(null, continuation))
assertThat(result).isInstanceOfSatisfying(Mono::class.java) {
assertThat(it.block()).isEqualTo(null)
}
}
@Test @Test
fun `Invoking suspending function on bridged method should return Mono`() { fun `Invoking suspending function on bridged method should return Mono`() {
val value = "foo" val value = "foo"
@ -65,6 +76,11 @@ class AopUtilsKotlinTests {
delay(1) delay(1)
return value return value
} }
suspend fun handleWithDefaultParam(value: String? = "defaultVal") : String? {
delay(1)
return value
}
} }
interface ProxyInterface<T> { interface ProxyInterface<T> {

View File

@ -112,6 +112,35 @@ public abstract class CoroutinesUtils {
@SuppressWarnings({"DataFlowIssue", "NullAway"}) @SuppressWarnings({"DataFlowIssue", "NullAway"})
public static Publisher<?> invokeSuspendingFunction( public static Publisher<?> invokeSuspendingFunction(
CoroutineContext context, Method method, @Nullable Object target, @Nullable Object... args) { CoroutineContext context, Method method, @Nullable Object target, @Nullable Object... args) {
return invokeSuspendingFunctionCore(context, method, target, args, false);
}
/**
* Invoke a suspending function and convert it to {@link Mono} or
* {@link Flux}.
* @param context the coroutine context to use
* @param method the suspending function to invoke
* @param target the target to invoke {@code method} on
* @param args the function arguments. If the {@code Continuation} argument is specified as the last argument
* (typically {@code null}), it is ignored.
* @return the method invocation result as reactive stream
* @throws IllegalArgumentException if {@code method} is not a suspending function
* @since 6.0
* This function preservers the null parameter passed in argument
*/
@SuppressWarnings({"DataFlowIssue", "NullAway"})
public static Publisher<?> invokeSuspendingFunctionPreserveNulls(
CoroutineContext context, Method method, @Nullable Object target, @Nullable Object... args) {
return invokeSuspendingFunctionCore(context, method, target, args, true);
}
private static Publisher<?> invokeSuspendingFunctionCore(
CoroutineContext context,
Method method,
@Nullable Object target,
@Nullable Object[] args,
boolean preserveNulls)
{
Assert.isTrue(KotlinDetector.isSuspendingFunction(method), "Method must be a suspending function"); Assert.isTrue(KotlinDetector.isSuspendingFunction(method), "Method must be a suspending function");
KFunction<?> function = ReflectJvmMapping.getKotlinFunction(method); KFunction<?> function = ReflectJvmMapping.getKotlinFunction(method);
@ -120,26 +149,7 @@ public abstract class CoroutinesUtils {
KCallablesJvm.setAccessible(function, true); KCallablesJvm.setAccessible(function, true);
} }
Mono<Object> mono = MonoKt.mono(context, (scope, continuation) -> { Mono<Object> mono = MonoKt.mono(context, (scope, continuation) -> {
Map<KParameter, Object> argMap = CollectionUtils.newHashMap(args.length + 1); Map<KParameter, Object> argMap = buildArgMap(function, target, args, preserveNulls);
int index = 0;
for (KParameter parameter : function.getParameters()) {
switch (parameter.getKind()) {
case INSTANCE -> argMap.put(parameter, target);
case VALUE, EXTENSION_RECEIVER -> {
Object arg = args[index];
if (!(parameter.isOptional() && arg == null)) {
KType type = parameter.getType();
if (!(type.isMarkedNullable() && arg == null) &&
type.getClassifier() instanceof KClass<?> kClass &&
KotlinDetector.isInlineClass(JvmClassMappingKt.getJavaClass(kClass))) {
arg = box(kClass, arg);
}
argMap.put(parameter, arg);
}
index++;
}
}
}
return KCallables.callSuspendBy(function, argMap, continuation); return KCallables.callSuspendBy(function, argMap, continuation);
}) })
.filter(result -> result != Unit.INSTANCE) .filter(result -> result != Unit.INSTANCE)
@ -158,6 +168,40 @@ public abstract class CoroutinesUtils {
return mono; return mono;
} }
private static Map<KParameter, Object> buildArgMap(
KFunction<?> function,
@Nullable Object target,
@Nullable Object[] args,
boolean preserveNulls) {
Map<KParameter, Object> argMap = CollectionUtils.newHashMap(args.length + 1);
int index = 0;
for (KParameter parameter : function.getParameters()) {
switch (parameter.getKind()) {
case INSTANCE -> argMap.put(parameter, target);
case VALUE, EXTENSION_RECEIVER -> {
Object arg = args[index];
if (!(parameter.isOptional() && arg == null)) {
KType type = parameter.getType();
if (!(type.isMarkedNullable() && arg == null) &&
type.getClassifier() instanceof KClass<?> kClass &&
KotlinDetector.isInlineClass(JvmClassMappingKt.getJavaClass(kClass))) {
arg = box(kClass, arg);
}
argMap.put(parameter, arg);
} else if(preserveNulls) {
argMap.put(parameter, arg);
}
index++;
}
}
}
return argMap;
}
private static Object box(KClass<?> kClass, @Nullable Object arg) { private static Object box(KClass<?> kClass, @Nullable Object arg) {
KFunction<?> constructor = Objects.requireNonNull(KClasses.getPrimaryConstructor(kClass)); KFunction<?> constructor = Objects.requireNonNull(KClasses.getPrimaryConstructor(kClass));
KType type = constructor.getParameters().get(0).getType(); KType type = constructor.getParameters().get(0).getType();

View File

@ -93,6 +93,16 @@ class CoroutinesUtilsTests {
} }
} }
@Test
fun invokeSuspendingFunctionWithNullableParameterPreservesNull() {
val method = CoroutinesUtilsTests::class.java.getDeclaredMethod("suspendingFunctionWithOptionalParameterAndDefaultValue", String::class.java, Continuation::class.java)
val mono = CoroutinesUtils.invokeSuspendingFunctionPreserveNulls(Dispatchers.Unconfined, method, this, null, null) as Mono
runBlocking {
Assertions.assertThat(mono.awaitSingleOrNull()).isNull()
}
}
@Test @Test
fun invokePrivateSuspendingFunction() { fun invokePrivateSuspendingFunction() {
val method = CoroutinesUtilsTests::class.java.getDeclaredMethod("privateSuspendingFunction", String::class.java, Continuation::class.java) val method = CoroutinesUtilsTests::class.java.getDeclaredMethod("privateSuspendingFunction", String::class.java, Continuation::class.java)
@ -300,6 +310,12 @@ class CoroutinesUtilsTests {
return value return value
} }
suspend fun suspendingFunctionWithOptionalParameterAndDefaultValue(value: String? = "foo"): String? {
delay(1)
return value
}
suspend fun suspendingFunctionWithMono(): Mono<String> { suspend fun suspendingFunctionWithMono(): Mono<String> {
delay(1) delay(1)
return Mono.just("foo") return Mono.just("foo")