From a143b57d4b397f5778428be40cc7a5083c42ee89 Mon Sep 17 00:00:00 2001 From: Sebastien Deleuze Date: Mon, 21 Nov 2016 23:59:41 +0100 Subject: [PATCH] Polish Kotlin nullable support This commit polishes Kotlin nullable support by reusing MethodParameter#isOptional() instead of adding a new MethodParameter#isNullable() method, adds Kotlin tests and introduces Spring Web Reactive support. Issue: SPR-14165 --- build.gradle | 16 +- .../springframework/core/MethodParameter.java | 19 +- .../org/springframework/util/KotlinUtils.java | 64 ++--- .../springframework/util/KotlinUtilsTests.kt | 30 +-- ...tractNamedValueMethodArgumentResolver.java | 2 +- ...notationMethodMessageHandlerKotlinTests.kt | 197 ++++++++++++++++ ...tParamMethodArgumentResolverKotlinTests.kt | 119 ++++++++++ ...tractNamedValueMethodArgumentResolver.java | 2 +- ...tParamMethodArgumentResolverKotlinTests.kt | 218 ++++++++++++++++++ .../RequestPartMethodArgumentResolver.java | 5 +- 10 files changed, 609 insertions(+), 63 deletions(-) create mode 100644 spring-messaging/src/test/kotlin/org/springframework/messaging/simp/annotation/support/SimpAnnotationMethodMessageHandlerKotlinTests.kt create mode 100644 spring-web-reactive/src/test/kotlin/org/springframework/web/reactive/result/method/annotation/RequestParamMethodArgumentResolverKotlinTests.kt create mode 100644 spring-web/src/test/kotlin/org/springframework/web/method/annotation/RequestParamMethodArgumentResolverKotlinTests.kt diff --git a/build.gradle b/build.gradle index b8580f9a32a..c6751fb634e 100644 --- a/build.gradle +++ b/build.gradle @@ -5,7 +5,7 @@ buildscript { dependencies { classpath("org.springframework.build.gradle:propdeps-plugin:0.0.7") classpath("org.asciidoctor:asciidoctor-gradle-plugin:1.5.3") - classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:1.0.4" + classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:1.0.5-2" classpath("io.spring.gradle:docbook-reference-plugin:0.3.1") } } @@ -71,7 +71,7 @@ configure(allprojects) { project -> ext.junitPlatformVersion = '1.0.0-M2' ext.log4jVersion = '2.7' ext.nettyVersion = "4.1.6.Final" - ext.kotlinVersion = "1.0.4" + ext.kotlinVersion = "1.0.5-2" ext.okhttpVersion = "2.7.5" ext.okhttp3Version = "3.4.2" ext.poiVersion = "3.15" @@ -569,6 +569,8 @@ project("spring-oxm") { project("spring-messaging") { description = "Spring Messaging" + apply plugin: "kotlin" + dependencies { compile(project(":spring-beans")) compile(project(":spring-core")) @@ -604,6 +606,8 @@ project("spring-messaging") { testCompile("io.netty:netty-all:${nettyVersion}") testCompile("org.xmlunit:xmlunit-matchers:${xmlunitVersion}") testCompile("org.slf4j:slf4j-jcl:${slf4jVersion}") + testCompile("org.jetbrains.kotlin:kotlin-reflect:${kotlinVersion}") + testCompile("org.jetbrains.kotlin:kotlin-stdlib:${kotlinVersion}") testRuntime("javax.activation:activation:${activationApiVersion}") testRuntime("com.sun.xml.bind:jaxb-core:${jaxbVersion}") testRuntime("com.sun.xml.bind:jaxb-impl:${jaxbVersion}") @@ -710,7 +714,9 @@ project("spring-context-indexer") { project("spring-web") { description = "Spring Web" + apply plugin: "groovy" + apply plugin: "kotlin" dependencies { compile(project(":spring-aop")) // for JaxWsPortProxyFactoryBean @@ -781,6 +787,8 @@ project("spring-web") { testCompile("com.squareup.okhttp3:mockwebserver:${okhttp3Version}") testCompile("org.xmlunit:xmlunit-matchers:${xmlunitVersion}") testCompile("org.slf4j:slf4j-jcl:${slf4jVersion}") + testCompile("org.jetbrains.kotlin:kotlin-reflect:${kotlinVersion}") + testCompile("org.jetbrains.kotlin:kotlin-stdlib:${kotlinVersion}") testRuntime("com.sun.mail:javax.mail:${javamailVersion}") testRuntime("com.sun.xml.bind:jaxb-core:${jaxbVersion}") testRuntime("com.sun.xml.bind:jaxb-impl:${jaxbVersion}") @@ -797,6 +805,8 @@ project("spring-web") { project("spring-web-reactive") { description = "Spring Web Reactive" + apply plugin: "kotlin" + dependencies { compile(project(":spring-core")) compile(project(":spring-web")) @@ -828,6 +838,8 @@ project("spring-web-reactive") { testCompile("com.fasterxml:aalto-xml:1.0.0") testCompile("org.xmlunit:xmlunit-matchers:${xmlunitVersion}") testCompile("org.slf4j:slf4j-jcl:${slf4jVersion}") + testCompile("org.jetbrains.kotlin:kotlin-reflect:${kotlinVersion}") + testCompile("org.jetbrains.kotlin:kotlin-stdlib:${kotlinVersion}") testRuntime("javax.el:javax.el-api:${elApiVersion}") testRuntime("org.glassfish:javax.el:3.0.1-b08") testRuntime("com.sun.xml.bind:jaxb-core:${jaxbVersion}") diff --git a/spring-core/src/main/java/org/springframework/core/MethodParameter.java b/spring-core/src/main/java/org/springframework/core/MethodParameter.java index d55381ad492..03de8085f44 100644 --- a/spring-core/src/main/java/org/springframework/core/MethodParameter.java +++ b/spring-core/src/main/java/org/springframework/core/MethodParameter.java @@ -310,12 +310,12 @@ public class MethodParameter { } /** - * Return whether this method parameter is declared as optional - * in the form of Java 8's {@link java.util.Optional}. + * Return whether this method indicates a parameter which is not required + * (either in the form of Java 8's {@link java.util.Optional} or Kotlin nullable type). * @since 4.3 */ public boolean isOptional() { - return (getParameterType() == Optional.class); + return (getParameterType() == Optional.class || KotlinUtils.isNullable(this)); } /** @@ -327,18 +327,7 @@ public class MethodParameter { * @see #nested() */ public MethodParameter nestedIfOptional() { - return (isOptional() ? nested() : this); - } - - /** - * Return whether this method parameter is declared as a "nullable" value, if supported by - * the underlying language. Currently the only supported language is Kotlin. - * @since 5.0 - */ - public boolean isNullable() { - return KotlinUtils.isKotlinPresent() && - KotlinUtils.isKotlinClass(getContainingClass()) && - KotlinUtils.isNullable(this.parameterIndex, this.method, this.constructor); + return (getParameterType() == Optional.class ? nested() : this); } /** diff --git a/spring-core/src/main/java/org/springframework/util/KotlinUtils.java b/spring-core/src/main/java/org/springframework/util/KotlinUtils.java index 1a08d7079c8..d47ab71a668 100644 --- a/spring-core/src/main/java/org/springframework/util/KotlinUtils.java +++ b/spring-core/src/main/java/org/springframework/util/KotlinUtils.java @@ -22,7 +22,6 @@ import kotlin.reflect.KParameter; import kotlin.reflect.jvm.ReflectJvmMapping; import org.springframework.core.MethodParameter; -import java.lang.reflect.Constructor; import java.lang.reflect.Method; import java.util.List; import java.util.stream.Collectors; @@ -31,44 +30,55 @@ import java.util.stream.Collectors; * Miscellaneous Kotlin utility methods. * * @author Raman Gupta + * @author Sebastien Deleuze * @since 5.0 */ -public class KotlinUtils { +public abstract class KotlinUtils { - private static final boolean kotlinPresent; - - static { - kotlinPresent = ClassUtils.isPresent("kotlin.Unit", MethodParameter.class.getClassLoader()); - } + private static final boolean kotlinPresent = ClassUtils.isPresent("kotlin.Unit", KotlinUtils.class.getClassLoader()); + /** + * Return whether Kotlin is available on the classpath or not. + */ public static boolean isKotlinPresent() { return kotlinPresent; } + /** + * Return whether the specified type is a Kotlin class or not. + */ public static boolean isKotlinClass(Class type) { - return type != null && type.getDeclaredAnnotation(Metadata.class) != null; + Assert.notNull(type, "Type must not be null"); + return isKotlinPresent() && type.getDeclaredAnnotation(Metadata.class) != null; } - public static boolean isNullable(int parameterIndex, Method method, Constructor constructor) { - if(parameterIndex < 0) { - KFunction function = ReflectJvmMapping.getKotlinFunction(method); - return function != null && function.getReturnType().isMarkedNullable(); - } else { - KFunction function = method != null ? - ReflectJvmMapping.getKotlinFunction(method) : - ReflectJvmMapping.getKotlinFunction(constructor); - if(function != null) { - @SuppressWarnings("unchecked") - List parameters = function.getParameters(); - return parameters - .stream() - .filter(p -> KParameter.Kind.VALUE.equals(p.getKind())) - .collect(Collectors.toList()) - .get(parameterIndex) - .getType() - .isMarkedNullable(); + /** + * Check whether the specified {@link MethodParameter} represents a nullable Kotlin type or not. + */ + public static boolean isNullable(MethodParameter methodParameter) { + Method method = methodParameter.getMethod(); + int parameterIndex = methodParameter.getParameterIndex(); + if (isKotlinClass(methodParameter.getContainingClass())) { + if (parameterIndex < 0) { + KFunction function = ReflectJvmMapping.getKotlinFunction(method); + return function != null && function.getReturnType().isMarkedNullable(); + } + else { + KFunction function = (method != null ? ReflectJvmMapping.getKotlinFunction(method) : + ReflectJvmMapping.getKotlinFunction(methodParameter.getConstructor())); + if (function != null) { + List parameters = function.getParameters(); + return parameters + .stream() + .filter(p -> KParameter.Kind.VALUE.equals(p.getKind())) + .collect(Collectors.toList()) + .get(parameterIndex) + .getType() + .isMarkedNullable(); + } } - return false; } + return false; } + } diff --git a/spring-core/src/test/kotlin/org/springframework/util/KotlinUtilsTests.kt b/spring-core/src/test/kotlin/org/springframework/util/KotlinUtilsTests.kt index e09700b99e8..10eb4786e1c 100644 --- a/spring-core/src/test/kotlin/org/springframework/util/KotlinUtilsTests.kt +++ b/spring-core/src/test/kotlin/org/springframework/util/KotlinUtilsTests.kt @@ -11,20 +11,22 @@ import org.springframework.core.MethodParameter import org.springframework.util.KotlinUtils.* /** - * Unit tests for [KotlinUtils]. - + * Test fixture for [KotlinUtils]. + * * @author Raman Gupta + * @author Sebastien Deleuze */ class KotlinUtilsTests { - private lateinit var methodNullable: Method - private lateinit var methodNonNullable: Method + lateinit var nullableMethod: Method + + lateinit var nonNullableMethod: Method @Before @Throws(NoSuchMethodException::class) - fun setUp() { - methodNullable = javaClass.getMethod("methodNullable", String::class.java, java.lang.Long.TYPE) - methodNonNullable = javaClass.getMethod("methodNonNullable", String::class.java, java.lang.Long.TYPE) + fun setup() { + nullableMethod = javaClass.getMethod("nullable", String::class.java) + nonNullableMethod = javaClass.getMethod("nonNullable", String::class.java) } @Test @@ -35,26 +37,26 @@ class KotlinUtilsTests { @Test fun `Are kotlin classes detected`() { - assertFalse(isKotlinClass(null)) assertFalse(isKotlinClass(MethodParameter::class.java)) assertTrue(isKotlinClass(javaClass)) } @Test fun `Obtains method return type nullability`() { - assertTrue(isNullable(-1, methodNullable, null)) - assertFalse(isNullable(-1, methodNonNullable, null)) + assertTrue(isNullable(MethodParameter(nullableMethod, -1))) + assertFalse(isNullable(MethodParameter(nonNullableMethod, -1))) } @Test fun `Obtains method parameter nullability`() { - assertTrue(isNullable(0, methodNullable, null)) - assertFalse(isNullable(1, methodNullable, null)) + assertTrue(isNullable(MethodParameter(nullableMethod, 0))) + assertFalse(isNullable(MethodParameter(nonNullableMethod, 0))) } @Suppress("unused", "unused_parameter") - fun methodNullable(p1: String?, p2: Long): Int? = 42 + fun nullable(p1: String?): Int? = 42 @Suppress("unused", "unused_parameter") - fun methodNonNullable(p1: String?, p2: Long): Int = 42 + fun nonNullable(p1: String): Int = 42 + } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/handler/annotation/support/AbstractNamedValueMethodArgumentResolver.java b/spring-messaging/src/main/java/org/springframework/messaging/handler/annotation/support/AbstractNamedValueMethodArgumentResolver.java index d7e730d644f..8f26c0c9e8c 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/handler/annotation/support/AbstractNamedValueMethodArgumentResolver.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/handler/annotation/support/AbstractNamedValueMethodArgumentResolver.java @@ -98,7 +98,7 @@ public abstract class AbstractNamedValueMethodArgumentResolver implements Handle if (namedValueInfo.defaultValue != null) { arg = resolveStringValue(namedValueInfo.defaultValue); } - else if (namedValueInfo.required && !nestedParameter.isOptional() && !nestedParameter.isNullable()) { + else if (namedValueInfo.required && !nestedParameter.isOptional()) { handleMissingValue(namedValueInfo.name, nestedParameter, message); } arg = handleNullValue(namedValueInfo.name, arg, nestedParameter.getNestedParameterType()); diff --git a/spring-messaging/src/test/kotlin/org/springframework/messaging/simp/annotation/support/SimpAnnotationMethodMessageHandlerKotlinTests.kt b/spring-messaging/src/test/kotlin/org/springframework/messaging/simp/annotation/support/SimpAnnotationMethodMessageHandlerKotlinTests.kt new file mode 100644 index 00000000000..ea43eae237c --- /dev/null +++ b/spring-messaging/src/test/kotlin/org/springframework/messaging/simp/annotation/support/SimpAnnotationMethodMessageHandlerKotlinTests.kt @@ -0,0 +1,197 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.simp.annotation.support + +import java.util.Collections +import java.util.HashMap + +import org.junit.Before +import org.junit.Test +import org.mockito.Mock +import org.mockito.MockitoAnnotations + +import org.springframework.context.support.StaticApplicationContext +import org.springframework.messaging.Message +import org.springframework.messaging.MessageChannel +import org.springframework.messaging.SubscribableChannel +import org.springframework.messaging.converter.MessageConverter +import org.springframework.messaging.handler.annotation.Header +import org.springframework.messaging.handler.annotation.MessageMapping +import org.springframework.messaging.simp.SimpMessageHeaderAccessor +import org.springframework.messaging.simp.SimpMessageSendingOperations +import org.springframework.messaging.simp.SimpMessagingTemplate +import org.springframework.messaging.support.MessageBuilder +import org.springframework.stereotype.Controller + +import org.junit.Assert.* +import org.springframework.messaging.MessageHandlingException +import org.springframework.messaging.handler.annotation.MessageExceptionHandler + +/** + * Kotlin test fixture for [SimpAnnotationMethodMessageHandler]. + * + * @author Sebastien Deleuze + */ +class SimpAnnotationMethodMessageHandlerKotlinTests { + + + lateinit var messageHandler: TestSimpAnnotationMethodMessageHandler + + lateinit var testController: TestController + + @Mock + lateinit var channel: SubscribableChannel + + @Mock + lateinit var converter: MessageConverter + + @Before + fun setup() { + MockitoAnnotations.initMocks(this) + val brokerTemplate = SimpMessagingTemplate(channel) + brokerTemplate.messageConverter = converter + messageHandler = TestSimpAnnotationMethodMessageHandler(brokerTemplate, channel, channel) + messageHandler.applicationContext = StaticApplicationContext() + messageHandler.afterPropertiesSet() + testController = TestController() + } + + @Test + fun nullableHeaderWithHeader() { + val message = createMessage("/nullableHeader", Collections.singletonMap("foo", "bar")) + messageHandler.registerHandler(testController) + messageHandler.handleMessage(message) + assertNull(testController.exception) + assertEquals("bar", testController.header) + } + + @Test + fun nullableHeaderWithoutHeader() { + val message = createMessage("/nullableHeader", Collections.emptyMap()) + messageHandler.registerHandler(testController) + messageHandler.handleMessage(message) + assertNull(testController.exception) + assertNull(testController.header) + } + + @Test + fun nonNullableHeaderWithHeader() { + val message = createMessage("/nonNullableHeader", Collections.singletonMap("foo", "bar")) + messageHandler.registerHandler(testController) + messageHandler.handleMessage(message) + assertEquals("bar", testController.header) + } + + @Test + fun nonNullableHeaderWithoutHeader() { + val message = createMessage("/nonNullableHeader", Collections.emptyMap()) + messageHandler.registerHandler(testController) + messageHandler.handleMessage(message) + assertNotNull(testController.exception) + assertTrue(testController.exception is MessageHandlingException) + } + + @Test + fun nullableHeaderNotRequiredWithHeader() { + val message = createMessage("/nullableHeaderNotRequired", Collections.singletonMap("foo", "bar")) + messageHandler.registerHandler(testController) + messageHandler.handleMessage(message) + assertNull(testController.exception) + assertEquals("bar", testController.header) + } + + @Test + fun nullableHeaderNotRequiredWithoutHeader() { + val message = createMessage("/nullableHeaderNotRequired", Collections.emptyMap()) + messageHandler.registerHandler(testController) + messageHandler.handleMessage(message) + assertNull(testController.exception) + assertNull(testController.header) + } + + @Test + fun nonNullableHeaderNotRequiredWithHeader() { + val message = createMessage("/nonNullableHeaderNotRequired", Collections.singletonMap("foo", "bar")) + messageHandler.registerHandler(testController) + messageHandler.handleMessage(message) + assertEquals("bar", testController.header) + } + + @Test + fun nonNullableHeaderNotRequiredWithoutHeader() { + val message = createMessage("/nonNullableHeaderNotRequired", Collections.emptyMap()) + messageHandler.registerHandler(testController) + messageHandler.handleMessage(message) + assertNotNull(testController.exception) + assertTrue(testController.exception is IllegalArgumentException) + } + + private fun createMessage(destination: String, headers: Map): Message { + val accessor = SimpMessageHeaderAccessor.create() + accessor.sessionId = "session1" + accessor.sessionAttributes = HashMap() + accessor.destination = destination + for (entry in headers.entries) accessor.setHeader(entry.key, entry.value) + + return MessageBuilder.withPayload(ByteArray(0)).setHeaders(accessor).build() + } + + class TestSimpAnnotationMethodMessageHandler(brokerTemplate: SimpMessageSendingOperations, + clientInboundChannel: SubscribableChannel, + clientOutboundChannel: MessageChannel) : + SimpAnnotationMethodMessageHandler(clientInboundChannel, clientOutboundChannel, brokerTemplate) { + + fun registerHandler(handler: Any?) { + super.detectHandlerMethods(handler) + } + } + + @Suppress("unused") + @Controller + @MessageMapping + class TestController { + + var header: String? = null + var exception: Throwable? = null + + @MessageMapping("/nullableHeader") + fun nullableHeader(@Header("foo") foo: String?) { + header = foo + } + + @MessageMapping("/nonNullableHeader") + fun nonNullableHeader(@Header("foo") foo: String) { + header = foo + } + + @MessageMapping("/nullableHeaderNotRequired") + fun nullableHeaderNotRequired(@Header("foo", required = false) foo: String?) { + header = foo + } + + @MessageMapping("/nonNullableHeaderNotRequired") + fun nonNullableHeaderNotRequired(@Header("foo", required = false) foo: String) { + header = foo + } + + @MessageExceptionHandler + fun handleIllegalArgumentException(exception: Throwable) { + this.exception = exception + } + } + +} diff --git a/spring-web-reactive/src/test/kotlin/org/springframework/web/reactive/result/method/annotation/RequestParamMethodArgumentResolverKotlinTests.kt b/spring-web-reactive/src/test/kotlin/org/springframework/web/reactive/result/method/annotation/RequestParamMethodArgumentResolverKotlinTests.kt new file mode 100644 index 00000000000..2d83546f12f --- /dev/null +++ b/spring-web-reactive/src/test/kotlin/org/springframework/web/reactive/result/method/annotation/RequestParamMethodArgumentResolverKotlinTests.kt @@ -0,0 +1,119 @@ +package org.springframework.web.reactive.result.method.annotation + +import org.junit.Before +import org.junit.Test +import org.springframework.core.MethodParameter +import org.springframework.core.annotation.SynthesizingMethodParameter +import org.springframework.format.support.DefaultFormattingConversionService +import org.springframework.http.HttpMethod +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest +import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse +import org.springframework.util.ReflectionUtils +import org.springframework.web.bind.annotation.RequestParam +import org.springframework.web.bind.support.ConfigurableWebBindingInitializer +import org.springframework.web.reactive.BindingContext +import org.springframework.web.server.ServerWebExchange +import org.springframework.web.server.ServerWebInputException +import org.springframework.web.server.adapter.DefaultServerWebExchange +import org.springframework.web.server.session.MockWebSessionManager +import reactor.test.StepVerifier + +/** + * Kotlin test fixture for [RequestParamMethodArgumentResolver]. + * + * @author Sebastien Deleuze + */ +class RequestParamMethodArgumentResolverKotlinTests { + + lateinit var resolver: RequestParamMethodArgumentResolver + lateinit var exchange: ServerWebExchange + lateinit var bindingContext: BindingContext + + lateinit var nullableParamRequired: MethodParameter + lateinit var nullableParamNotRequired: MethodParameter + lateinit var nonNullableParamRequired: MethodParameter + lateinit var nonNullableParamNotRequired: MethodParameter + + + @Before + fun setup() { + resolver = RequestParamMethodArgumentResolver(null, true) + val request = MockServerHttpRequest(HttpMethod.GET, "/") + val sessionManager = MockWebSessionManager() + exchange = DefaultServerWebExchange(request, MockServerHttpResponse(), sessionManager) + val initializer = ConfigurableWebBindingInitializer() + initializer.conversionService = DefaultFormattingConversionService() + bindingContext = BindingContext(initializer) + + val method = ReflectionUtils.findMethod(javaClass, "handle", String::class.java, + String::class.java, String::class.java, String::class.java) + + nullableParamRequired = SynthesizingMethodParameter(method, 0) + nullableParamNotRequired = SynthesizingMethodParameter(method, 1) + nonNullableParamRequired = SynthesizingMethodParameter(method, 2) + nonNullableParamNotRequired = SynthesizingMethodParameter(method, 3) + } + + @Test + fun resolveNullableRequiredWithParameter() { + exchange.request.queryParams.set("name", "123") + var result = resolver.resolveArgument(nullableParamRequired, bindingContext, exchange) + StepVerifier.create(result).expectNext("123").expectComplete().verify() + } + + @Test + fun resolveNullableRequiredWithoutParameter() { + var result = resolver.resolveArgument(nullableParamRequired, bindingContext, exchange) + StepVerifier.create(result).expectComplete().verify() + } + + @Test + fun resolveNullableNotRequiredWithParameter() { + exchange.request.queryParams.set("name", "123") + var result = resolver.resolveArgument(nullableParamNotRequired, bindingContext, exchange) + StepVerifier.create(result).expectNext("123").expectComplete().verify() + } + + @Test + fun resolveNullableNotRequiredWithoutParameter() { + var result = resolver.resolveArgument(nullableParamNotRequired, bindingContext, exchange) + StepVerifier.create(result).expectComplete().verify() + } + + @Test + fun resolveNonNullableRequiredWithParameter() { + exchange.request.queryParams.set("name", "123") + var result = resolver.resolveArgument(nonNullableParamRequired, bindingContext, exchange) + StepVerifier.create(result).expectNext("123").expectComplete().verify() + } + + @Test + fun resolveNonNullableRequiredWithoutParameter() { + var result = resolver.resolveArgument(nonNullableParamRequired, bindingContext, exchange) + StepVerifier.create(result).expectError(ServerWebInputException::class.java).verify() + } + + @Test + fun resolveNonNullableNotRequiredWithParameter() { + exchange.request.queryParams.set("name", "123") + var result = resolver.resolveArgument(nonNullableParamNotRequired, bindingContext, exchange) + StepVerifier.create(result).expectNext("123").expectComplete().verify() + } + + @Test + fun resolveNonNullableNotRequiredWithoutParameter() { + var result = resolver.resolveArgument(nonNullableParamNotRequired, bindingContext, exchange) + StepVerifier.create(result).expectComplete().verify() + } + + + @Suppress("unused_parameter") + fun handle( + @RequestParam("name") nullableParamRequired: String?, + @RequestParam("name", required = false) nullableParamNotRequired: String?, + @RequestParam("name") nonNullableParamRequired: String, + @RequestParam("name", required = false) nonNullableParamNotRequired: String) { + } + +} + diff --git a/spring-web/src/main/java/org/springframework/web/method/annotation/AbstractNamedValueMethodArgumentResolver.java b/spring-web/src/main/java/org/springframework/web/method/annotation/AbstractNamedValueMethodArgumentResolver.java index 104ad867eb8..678fc300d06 100644 --- a/spring-web/src/main/java/org/springframework/web/method/annotation/AbstractNamedValueMethodArgumentResolver.java +++ b/spring-web/src/main/java/org/springframework/web/method/annotation/AbstractNamedValueMethodArgumentResolver.java @@ -100,7 +100,7 @@ public abstract class AbstractNamedValueMethodArgumentResolver implements Handle if (namedValueInfo.defaultValue != null) { arg = resolveStringValue(namedValueInfo.defaultValue); } - else if (namedValueInfo.required && !nestedParameter.isOptional() && !nestedParameter.isNullable()) { + else if (namedValueInfo.required && !nestedParameter.isOptional()) { handleMissingValue(namedValueInfo.name, nestedParameter, webRequest); } arg = handleNullValue(namedValueInfo.name, arg, nestedParameter.getNestedParameterType()); diff --git a/spring-web/src/test/kotlin/org/springframework/web/method/annotation/RequestParamMethodArgumentResolverKotlinTests.kt b/spring-web/src/test/kotlin/org/springframework/web/method/annotation/RequestParamMethodArgumentResolverKotlinTests.kt new file mode 100644 index 00000000000..451307203a0 --- /dev/null +++ b/spring-web/src/test/kotlin/org/springframework/web/method/annotation/RequestParamMethodArgumentResolverKotlinTests.kt @@ -0,0 +1,218 @@ +package org.springframework.web.method.annotation + + +import org.junit.Assert.assertEquals +import org.junit.Assert.assertNull +import org.junit.Before +import org.junit.Test +import org.springframework.core.MethodParameter +import org.springframework.core.annotation.SynthesizingMethodParameter +import org.springframework.core.convert.support.DefaultConversionService +import org.springframework.http.HttpMethod +import org.springframework.http.MediaType +import org.springframework.mock.web.test.MockHttpServletRequest +import org.springframework.mock.web.test.MockHttpServletResponse +import org.springframework.mock.web.test.MockMultipartFile +import org.springframework.mock.web.test.MockMultipartHttpServletRequest +import org.springframework.util.ReflectionUtils +import org.springframework.web.bind.MissingServletRequestParameterException +import org.springframework.web.bind.annotation.RequestParam +import org.springframework.web.bind.support.ConfigurableWebBindingInitializer +import org.springframework.web.bind.support.DefaultDataBinderFactory +import org.springframework.web.bind.support.WebDataBinderFactory +import org.springframework.web.context.request.NativeWebRequest +import org.springframework.web.context.request.ServletWebRequest +import org.springframework.web.multipart.MultipartFile +import org.springframework.web.multipart.support.MissingServletRequestPartException + +/** + * Kotlin test fixture for [RequestParamMethodArgumentResolver]. + * + * @author Sebastien Deleuze + */ +class RequestParamMethodArgumentResolverKotlinTests { + + lateinit var resolver: RequestParamMethodArgumentResolver + lateinit var webRequest: NativeWebRequest + lateinit var binderFactory: WebDataBinderFactory + lateinit var request: MockHttpServletRequest + + lateinit var nullableParamRequired: MethodParameter + lateinit var nullableParamNotRequired: MethodParameter + lateinit var nonNullableParamRequired: MethodParameter + lateinit var nonNullableParamNotRequired: MethodParameter + + lateinit var nullableMultipartParamRequired: MethodParameter + lateinit var nullableMultipartParamNotRequired: MethodParameter + lateinit var nonNullableMultipartParamRequired: MethodParameter + lateinit var nonNullableMultipartParamNotRequired: MethodParameter + + + @Before + fun setup() { + resolver = RequestParamMethodArgumentResolver(null, true) + request = MockHttpServletRequest() + val initializer = ConfigurableWebBindingInitializer() + initializer.conversionService = DefaultConversionService() + binderFactory = DefaultDataBinderFactory(initializer) + webRequest = ServletWebRequest(request, MockHttpServletResponse()) + + val method = ReflectionUtils.findMethod(javaClass, "handle", String::class.java, + String::class.java, String::class.java, String::class.java, + MultipartFile::class.java, MultipartFile::class.java, + MultipartFile::class.java, MultipartFile::class.java) + + nullableParamRequired = SynthesizingMethodParameter(method, 0) + nullableParamNotRequired = SynthesizingMethodParameter(method, 1) + nonNullableParamRequired = SynthesizingMethodParameter(method, 2) + nonNullableParamNotRequired = SynthesizingMethodParameter(method, 3) + + nullableMultipartParamRequired = SynthesizingMethodParameter(method, 4) + nullableMultipartParamNotRequired = SynthesizingMethodParameter(method, 5) + nonNullableMultipartParamRequired = SynthesizingMethodParameter(method, 6) + nonNullableMultipartParamNotRequired = SynthesizingMethodParameter(method, 7) + } + + @Test + fun resolveNullableRequiredWithParameter() { + request.addParameter("name", "123") + var result = resolver.resolveArgument(nullableParamRequired, null, webRequest, binderFactory) + assertEquals("123", result) + } + + @Test + fun resolveNullableRequiredWithoutParameter() { + var result = resolver.resolveArgument(nullableParamRequired, null, webRequest, binderFactory) + assertNull(result) + } + + @Test + fun resolveNullableNotRequiredWithParameter() { + request.addParameter("name", "123") + var result = resolver.resolveArgument(nullableParamNotRequired, null, webRequest, binderFactory) + assertEquals("123", result) + } + + @Test + fun resolveNullableNotRequiredWithoutParameter() { + var result = resolver.resolveArgument(nullableParamNotRequired, null, webRequest, binderFactory) + assertNull(result) + } + + @Test + fun resolveNonNullableRequiredWithParameter() { + request.addParameter("name", "123") + var result = resolver.resolveArgument(nonNullableParamRequired, null, webRequest, binderFactory) + assertEquals("123", result) + } + + @Test(expected = MissingServletRequestParameterException::class) + fun resolveNonNullableRequiredWithoutParameter() { + resolver.resolveArgument(nonNullableParamRequired, null, webRequest, binderFactory) + } + + @Test + fun resolveNonNullableNotRequiredWithParameter() { + request.addParameter("name", "123") + var result = resolver.resolveArgument(nonNullableParamNotRequired, null, webRequest, binderFactory) + assertEquals("123", result) + } + + @Test(expected = TypeCastException::class) + fun resolveNonNullableNotRequiredWithoutParameter() { + resolver.resolveArgument(nonNullableParamNotRequired, null, webRequest, binderFactory) as String + } + + + @Test + fun resolveNullableRequiredWithMultipartParameter() { + val request = MockMultipartHttpServletRequest() + val expected = MockMultipartFile("mfile", "Hello World".toByteArray()) + request.addFile(expected) + webRequest = ServletWebRequest(request) + + var result = resolver.resolveArgument(nullableMultipartParamRequired, null, webRequest, binderFactory) + assertEquals(expected, result) + } + + @Test + fun resolveNullableRequiredWithoutMultipartParameter() { + request.method = HttpMethod.POST.name + request.contentType = MediaType.MULTIPART_FORM_DATA_VALUE + + var result = resolver.resolveArgument(nullableMultipartParamRequired, null, webRequest, binderFactory) + assertNull(result) + } + + @Test + fun resolveNullableNotRequiredWithMultipartParameter() { + val request = MockMultipartHttpServletRequest() + val expected = MockMultipartFile("mfile", "Hello World".toByteArray()) + request.addFile(expected) + webRequest = ServletWebRequest(request) + + var result = resolver.resolveArgument(nullableMultipartParamNotRequired, null, webRequest, binderFactory) + assertEquals(expected, result) + } + + @Test + fun resolveNullableNotRequiredWithoutMultipartParameter() { + request.method = HttpMethod.POST.name + request.contentType = MediaType.MULTIPART_FORM_DATA_VALUE + + var result = resolver.resolveArgument(nullableMultipartParamNotRequired, null, webRequest, binderFactory) + assertNull(result) + } + + @Test + fun resolveNonNullableRequiredWithMultipartParameter() { + val request = MockMultipartHttpServletRequest() + val expected = MockMultipartFile("mfile", "Hello World".toByteArray()) + request.addFile(expected) + webRequest = ServletWebRequest(request) + + var result = resolver.resolveArgument(nonNullableMultipartParamRequired, null, webRequest, binderFactory) + assertEquals(expected, result) + } + + @Test(expected = MissingServletRequestPartException::class) + fun resolveNonNullableRequiredWithoutMultipartParameter() { + request.method = HttpMethod.POST.name + request.contentType = MediaType.MULTIPART_FORM_DATA_VALUE + resolver.resolveArgument(nonNullableMultipartParamRequired, null, webRequest, binderFactory) + } + + @Test + fun resolveNonNullableNotRequiredWithMultipartParameter() { + val request = MockMultipartHttpServletRequest() + val expected = MockMultipartFile("mfile", "Hello World".toByteArray()) + request.addFile(expected) + webRequest = ServletWebRequest(request) + + var result = resolver.resolveArgument(nonNullableMultipartParamNotRequired, null, webRequest, binderFactory) + assertEquals(expected, result) + } + + @Test(expected = TypeCastException::class) + fun resolveNonNullableNotRequiredWithoutMultipartParameter() { + request.method = HttpMethod.POST.name + request.contentType = MediaType.MULTIPART_FORM_DATA_VALUE + resolver.resolveArgument(nonNullableMultipartParamNotRequired, null, webRequest, binderFactory) as MultipartFile + } + + + @Suppress("unused_parameter") + fun handle( + @RequestParam("name") nullableParamRequired: String?, + @RequestParam("name", required = false) nullableParamNotRequired: String?, + @RequestParam("name") nonNullableParamRequired: String, + @RequestParam("name", required = false) nonNullableParamNotRequired: String, + + @RequestParam("mfile") nullableMultipartParamRequired: MultipartFile?, + @RequestParam("mfile", required = false) nullableMultipartParamNotRequired: MultipartFile?, + @RequestParam("mfile") nonNullableMultipartParamRequired: MultipartFile, + @RequestParam("mfile", required = false) nonNullableMultipartParamNotRequired: MultipartFile) { + } + +} + diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolver.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolver.java index b198d002474..4cef39e45b5 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolver.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolver.java @@ -113,8 +113,7 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM HttpServletRequest servletRequest = request.getNativeRequest(HttpServletRequest.class); RequestPart requestPart = parameter.getParameterAnnotation(RequestPart.class); - boolean isRequired = ((requestPart == null || requestPart.required()) && !parameter.isOptional() && - !parameter.isNullable()); + boolean isRequired = ((requestPart == null || requestPart.required()) && !parameter.isOptional()); String name = getPartName(parameter, requestPart); parameter = parameter.nestedIfOptional(); @@ -157,7 +156,7 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM throw new MissingServletRequestPartException(name); } } - if (parameter.isOptional()) { + if (parameter.getParameterType() == Optional.class) { if (arg == null || (arg instanceof Collection && ((Collection) arg).isEmpty()) || (arg instanceof Object[] && ((Object[]) arg).length == 0)) { arg = Optional.empty();