Polish Kotlin nullable support

This commit polishes Kotlin nullable support by reusing
MethodParameter#isOptional() instead of adding a new
MethodParameter#isNullable() method, adds
Kotlin tests and introduces Spring Web Reactive
support.

Issue: SPR-14165
This commit is contained in:
Sebastien Deleuze 2016-11-21 23:59:41 +01:00
parent fada91e538
commit a143b57d4b
10 changed files with 609 additions and 63 deletions

View File

@ -5,7 +5,7 @@ buildscript {
dependencies {
classpath("org.springframework.build.gradle:propdeps-plugin:0.0.7")
classpath("org.asciidoctor:asciidoctor-gradle-plugin:1.5.3")
classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:1.0.4"
classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:1.0.5-2"
classpath("io.spring.gradle:docbook-reference-plugin:0.3.1")
}
}
@ -71,7 +71,7 @@ configure(allprojects) { project ->
ext.junitPlatformVersion = '1.0.0-M2'
ext.log4jVersion = '2.7'
ext.nettyVersion = "4.1.6.Final"
ext.kotlinVersion = "1.0.4"
ext.kotlinVersion = "1.0.5-2"
ext.okhttpVersion = "2.7.5"
ext.okhttp3Version = "3.4.2"
ext.poiVersion = "3.15"
@ -569,6 +569,8 @@ project("spring-oxm") {
project("spring-messaging") {
description = "Spring Messaging"
apply plugin: "kotlin"
dependencies {
compile(project(":spring-beans"))
compile(project(":spring-core"))
@ -604,6 +606,8 @@ project("spring-messaging") {
testCompile("io.netty:netty-all:${nettyVersion}")
testCompile("org.xmlunit:xmlunit-matchers:${xmlunitVersion}")
testCompile("org.slf4j:slf4j-jcl:${slf4jVersion}")
testCompile("org.jetbrains.kotlin:kotlin-reflect:${kotlinVersion}")
testCompile("org.jetbrains.kotlin:kotlin-stdlib:${kotlinVersion}")
testRuntime("javax.activation:activation:${activationApiVersion}")
testRuntime("com.sun.xml.bind:jaxb-core:${jaxbVersion}")
testRuntime("com.sun.xml.bind:jaxb-impl:${jaxbVersion}")
@ -710,7 +714,9 @@ project("spring-context-indexer") {
project("spring-web") {
description = "Spring Web"
apply plugin: "groovy"
apply plugin: "kotlin"
dependencies {
compile(project(":spring-aop")) // for JaxWsPortProxyFactoryBean
@ -781,6 +787,8 @@ project("spring-web") {
testCompile("com.squareup.okhttp3:mockwebserver:${okhttp3Version}")
testCompile("org.xmlunit:xmlunit-matchers:${xmlunitVersion}")
testCompile("org.slf4j:slf4j-jcl:${slf4jVersion}")
testCompile("org.jetbrains.kotlin:kotlin-reflect:${kotlinVersion}")
testCompile("org.jetbrains.kotlin:kotlin-stdlib:${kotlinVersion}")
testRuntime("com.sun.mail:javax.mail:${javamailVersion}")
testRuntime("com.sun.xml.bind:jaxb-core:${jaxbVersion}")
testRuntime("com.sun.xml.bind:jaxb-impl:${jaxbVersion}")
@ -797,6 +805,8 @@ project("spring-web") {
project("spring-web-reactive") {
description = "Spring Web Reactive"
apply plugin: "kotlin"
dependencies {
compile(project(":spring-core"))
compile(project(":spring-web"))
@ -828,6 +838,8 @@ project("spring-web-reactive") {
testCompile("com.fasterxml:aalto-xml:1.0.0")
testCompile("org.xmlunit:xmlunit-matchers:${xmlunitVersion}")
testCompile("org.slf4j:slf4j-jcl:${slf4jVersion}")
testCompile("org.jetbrains.kotlin:kotlin-reflect:${kotlinVersion}")
testCompile("org.jetbrains.kotlin:kotlin-stdlib:${kotlinVersion}")
testRuntime("javax.el:javax.el-api:${elApiVersion}")
testRuntime("org.glassfish:javax.el:3.0.1-b08")
testRuntime("com.sun.xml.bind:jaxb-core:${jaxbVersion}")

View File

@ -310,12 +310,12 @@ public class MethodParameter {
}
/**
* Return whether this method parameter is declared as optional
* in the form of Java 8's {@link java.util.Optional}.
* Return whether this method indicates a parameter which is not required
* (either in the form of Java 8's {@link java.util.Optional} or Kotlin nullable type).
* @since 4.3
*/
public boolean isOptional() {
return (getParameterType() == Optional.class);
return (getParameterType() == Optional.class || KotlinUtils.isNullable(this));
}
/**
@ -327,18 +327,7 @@ public class MethodParameter {
* @see #nested()
*/
public MethodParameter nestedIfOptional() {
return (isOptional() ? nested() : this);
}
/**
* Return whether this method parameter is declared as a "nullable" value, if supported by
* the underlying language. Currently the only supported language is Kotlin.
* @since 5.0
*/
public boolean isNullable() {
return KotlinUtils.isKotlinPresent() &&
KotlinUtils.isKotlinClass(getContainingClass()) &&
KotlinUtils.isNullable(this.parameterIndex, this.method, this.constructor);
return (getParameterType() == Optional.class ? nested() : this);
}
/**

View File

@ -22,7 +22,6 @@ import kotlin.reflect.KParameter;
import kotlin.reflect.jvm.ReflectJvmMapping;
import org.springframework.core.MethodParameter;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.util.List;
import java.util.stream.Collectors;
@ -31,44 +30,55 @@ import java.util.stream.Collectors;
* Miscellaneous Kotlin utility methods.
*
* @author Raman Gupta
* @author Sebastien Deleuze
* @since 5.0
*/
public class KotlinUtils {
public abstract class KotlinUtils {
private static final boolean kotlinPresent;
static {
kotlinPresent = ClassUtils.isPresent("kotlin.Unit", MethodParameter.class.getClassLoader());
}
private static final boolean kotlinPresent = ClassUtils.isPresent("kotlin.Unit", KotlinUtils.class.getClassLoader());
/**
* Return whether Kotlin is available on the classpath or not.
*/
public static boolean isKotlinPresent() {
return kotlinPresent;
}
/**
* Return whether the specified type is a Kotlin class or not.
*/
public static boolean isKotlinClass(Class<?> type) {
return type != null && type.getDeclaredAnnotation(Metadata.class) != null;
Assert.notNull(type, "Type must not be null");
return isKotlinPresent() && type.getDeclaredAnnotation(Metadata.class) != null;
}
public static boolean isNullable(int parameterIndex, Method method, Constructor<?> constructor) {
if(parameterIndex < 0) {
KFunction<?> function = ReflectJvmMapping.getKotlinFunction(method);
return function != null && function.getReturnType().isMarkedNullable();
} else {
KFunction<?> function = method != null ?
ReflectJvmMapping.getKotlinFunction(method) :
ReflectJvmMapping.getKotlinFunction(constructor);
if(function != null) {
@SuppressWarnings("unchecked")
List<KParameter> parameters = function.getParameters();
return parameters
.stream()
.filter(p -> KParameter.Kind.VALUE.equals(p.getKind()))
.collect(Collectors.toList())
.get(parameterIndex)
.getType()
.isMarkedNullable();
/**
* Check whether the specified {@link MethodParameter} represents a nullable Kotlin type or not.
*/
public static boolean isNullable(MethodParameter methodParameter) {
Method method = methodParameter.getMethod();
int parameterIndex = methodParameter.getParameterIndex();
if (isKotlinClass(methodParameter.getContainingClass())) {
if (parameterIndex < 0) {
KFunction<?> function = ReflectJvmMapping.getKotlinFunction(method);
return function != null && function.getReturnType().isMarkedNullable();
}
else {
KFunction<?> function = (method != null ? ReflectJvmMapping.getKotlinFunction(method) :
ReflectJvmMapping.getKotlinFunction(methodParameter.getConstructor()));
if (function != null) {
List<KParameter> parameters = function.getParameters();
return parameters
.stream()
.filter(p -> KParameter.Kind.VALUE.equals(p.getKind()))
.collect(Collectors.toList())
.get(parameterIndex)
.getType()
.isMarkedNullable();
}
}
return false;
}
return false;
}
}

View File

@ -11,20 +11,22 @@ import org.springframework.core.MethodParameter
import org.springframework.util.KotlinUtils.*
/**
* Unit tests for [KotlinUtils].
* Test fixture for [KotlinUtils].
*
* @author Raman Gupta
* @author Sebastien Deleuze
*/
class KotlinUtilsTests {
private lateinit var methodNullable: Method
private lateinit var methodNonNullable: Method
lateinit var nullableMethod: Method
lateinit var nonNullableMethod: Method
@Before
@Throws(NoSuchMethodException::class)
fun setUp() {
methodNullable = javaClass.getMethod("methodNullable", String::class.java, java.lang.Long.TYPE)
methodNonNullable = javaClass.getMethod("methodNonNullable", String::class.java, java.lang.Long.TYPE)
fun setup() {
nullableMethod = javaClass.getMethod("nullable", String::class.java)
nonNullableMethod = javaClass.getMethod("nonNullable", String::class.java)
}
@Test
@ -35,26 +37,26 @@ class KotlinUtilsTests {
@Test
fun `Are kotlin classes detected`() {
assertFalse(isKotlinClass(null))
assertFalse(isKotlinClass(MethodParameter::class.java))
assertTrue(isKotlinClass(javaClass))
}
@Test
fun `Obtains method return type nullability`() {
assertTrue(isNullable(-1, methodNullable, null))
assertFalse(isNullable(-1, methodNonNullable, null))
assertTrue(isNullable(MethodParameter(nullableMethod, -1)))
assertFalse(isNullable(MethodParameter(nonNullableMethod, -1)))
}
@Test
fun `Obtains method parameter nullability`() {
assertTrue(isNullable(0, methodNullable, null))
assertFalse(isNullable(1, methodNullable, null))
assertTrue(isNullable(MethodParameter(nullableMethod, 0)))
assertFalse(isNullable(MethodParameter(nonNullableMethod, 0)))
}
@Suppress("unused", "unused_parameter")
fun methodNullable(p1: String?, p2: Long): Int? = 42
fun nullable(p1: String?): Int? = 42
@Suppress("unused", "unused_parameter")
fun methodNonNullable(p1: String?, p2: Long): Int = 42
fun nonNullable(p1: String): Int = 42
}

View File

@ -98,7 +98,7 @@ public abstract class AbstractNamedValueMethodArgumentResolver implements Handle
if (namedValueInfo.defaultValue != null) {
arg = resolveStringValue(namedValueInfo.defaultValue);
}
else if (namedValueInfo.required && !nestedParameter.isOptional() && !nestedParameter.isNullable()) {
else if (namedValueInfo.required && !nestedParameter.isOptional()) {
handleMissingValue(namedValueInfo.name, nestedParameter, message);
}
arg = handleNullValue(namedValueInfo.name, arg, nestedParameter.getNestedParameterType());

View File

@ -0,0 +1,197 @@
/*
* Copyright 2002-2016 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.annotation.support
import java.util.Collections
import java.util.HashMap
import org.junit.Before
import org.junit.Test
import org.mockito.Mock
import org.mockito.MockitoAnnotations
import org.springframework.context.support.StaticApplicationContext
import org.springframework.messaging.Message
import org.springframework.messaging.MessageChannel
import org.springframework.messaging.SubscribableChannel
import org.springframework.messaging.converter.MessageConverter
import org.springframework.messaging.handler.annotation.Header
import org.springframework.messaging.handler.annotation.MessageMapping
import org.springframework.messaging.simp.SimpMessageHeaderAccessor
import org.springframework.messaging.simp.SimpMessageSendingOperations
import org.springframework.messaging.simp.SimpMessagingTemplate
import org.springframework.messaging.support.MessageBuilder
import org.springframework.stereotype.Controller
import org.junit.Assert.*
import org.springframework.messaging.MessageHandlingException
import org.springframework.messaging.handler.annotation.MessageExceptionHandler
/**
* Kotlin test fixture for [SimpAnnotationMethodMessageHandler].
*
* @author Sebastien Deleuze
*/
class SimpAnnotationMethodMessageHandlerKotlinTests {
lateinit var messageHandler: TestSimpAnnotationMethodMessageHandler
lateinit var testController: TestController
@Mock
lateinit var channel: SubscribableChannel
@Mock
lateinit var converter: MessageConverter
@Before
fun setup() {
MockitoAnnotations.initMocks(this)
val brokerTemplate = SimpMessagingTemplate(channel)
brokerTemplate.messageConverter = converter
messageHandler = TestSimpAnnotationMethodMessageHandler(brokerTemplate, channel, channel)
messageHandler.applicationContext = StaticApplicationContext()
messageHandler.afterPropertiesSet()
testController = TestController()
}
@Test
fun nullableHeaderWithHeader() {
val message = createMessage("/nullableHeader", Collections.singletonMap("foo", "bar"))
messageHandler.registerHandler(testController)
messageHandler.handleMessage(message)
assertNull(testController.exception)
assertEquals("bar", testController.header)
}
@Test
fun nullableHeaderWithoutHeader() {
val message = createMessage("/nullableHeader", Collections.emptyMap())
messageHandler.registerHandler(testController)
messageHandler.handleMessage(message)
assertNull(testController.exception)
assertNull(testController.header)
}
@Test
fun nonNullableHeaderWithHeader() {
val message = createMessage("/nonNullableHeader", Collections.singletonMap("foo", "bar"))
messageHandler.registerHandler(testController)
messageHandler.handleMessage(message)
assertEquals("bar", testController.header)
}
@Test
fun nonNullableHeaderWithoutHeader() {
val message = createMessage("/nonNullableHeader", Collections.emptyMap())
messageHandler.registerHandler(testController)
messageHandler.handleMessage(message)
assertNotNull(testController.exception)
assertTrue(testController.exception is MessageHandlingException)
}
@Test
fun nullableHeaderNotRequiredWithHeader() {
val message = createMessage("/nullableHeaderNotRequired", Collections.singletonMap("foo", "bar"))
messageHandler.registerHandler(testController)
messageHandler.handleMessage(message)
assertNull(testController.exception)
assertEquals("bar", testController.header)
}
@Test
fun nullableHeaderNotRequiredWithoutHeader() {
val message = createMessage("/nullableHeaderNotRequired", Collections.emptyMap())
messageHandler.registerHandler(testController)
messageHandler.handleMessage(message)
assertNull(testController.exception)
assertNull(testController.header)
}
@Test
fun nonNullableHeaderNotRequiredWithHeader() {
val message = createMessage("/nonNullableHeaderNotRequired", Collections.singletonMap("foo", "bar"))
messageHandler.registerHandler(testController)
messageHandler.handleMessage(message)
assertEquals("bar", testController.header)
}
@Test
fun nonNullableHeaderNotRequiredWithoutHeader() {
val message = createMessage("/nonNullableHeaderNotRequired", Collections.emptyMap())
messageHandler.registerHandler(testController)
messageHandler.handleMessage(message)
assertNotNull(testController.exception)
assertTrue(testController.exception is IllegalArgumentException)
}
private fun createMessage(destination: String, headers: Map<String, String?>): Message<ByteArray> {
val accessor = SimpMessageHeaderAccessor.create()
accessor.sessionId = "session1"
accessor.sessionAttributes = HashMap()
accessor.destination = destination
for (entry in headers.entries) accessor.setHeader(entry.key, entry.value)
return MessageBuilder.withPayload(ByteArray(0)).setHeaders(accessor).build()
}
class TestSimpAnnotationMethodMessageHandler(brokerTemplate: SimpMessageSendingOperations,
clientInboundChannel: SubscribableChannel,
clientOutboundChannel: MessageChannel) :
SimpAnnotationMethodMessageHandler(clientInboundChannel, clientOutboundChannel, brokerTemplate) {
fun registerHandler(handler: Any?) {
super.detectHandlerMethods(handler)
}
}
@Suppress("unused")
@Controller
@MessageMapping
class TestController {
var header: String? = null
var exception: Throwable? = null
@MessageMapping("/nullableHeader")
fun nullableHeader(@Header("foo") foo: String?) {
header = foo
}
@MessageMapping("/nonNullableHeader")
fun nonNullableHeader(@Header("foo") foo: String) {
header = foo
}
@MessageMapping("/nullableHeaderNotRequired")
fun nullableHeaderNotRequired(@Header("foo", required = false) foo: String?) {
header = foo
}
@MessageMapping("/nonNullableHeaderNotRequired")
fun nonNullableHeaderNotRequired(@Header("foo", required = false) foo: String) {
header = foo
}
@MessageExceptionHandler
fun handleIllegalArgumentException(exception: Throwable) {
this.exception = exception
}
}
}

View File

@ -0,0 +1,119 @@
package org.springframework.web.reactive.result.method.annotation
import org.junit.Before
import org.junit.Test
import org.springframework.core.MethodParameter
import org.springframework.core.annotation.SynthesizingMethodParameter
import org.springframework.format.support.DefaultFormattingConversionService
import org.springframework.http.HttpMethod
import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest
import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse
import org.springframework.util.ReflectionUtils
import org.springframework.web.bind.annotation.RequestParam
import org.springframework.web.bind.support.ConfigurableWebBindingInitializer
import org.springframework.web.reactive.BindingContext
import org.springframework.web.server.ServerWebExchange
import org.springframework.web.server.ServerWebInputException
import org.springframework.web.server.adapter.DefaultServerWebExchange
import org.springframework.web.server.session.MockWebSessionManager
import reactor.test.StepVerifier
/**
* Kotlin test fixture for [RequestParamMethodArgumentResolver].
*
* @author Sebastien Deleuze
*/
class RequestParamMethodArgumentResolverKotlinTests {
lateinit var resolver: RequestParamMethodArgumentResolver
lateinit var exchange: ServerWebExchange
lateinit var bindingContext: BindingContext
lateinit var nullableParamRequired: MethodParameter
lateinit var nullableParamNotRequired: MethodParameter
lateinit var nonNullableParamRequired: MethodParameter
lateinit var nonNullableParamNotRequired: MethodParameter
@Before
fun setup() {
resolver = RequestParamMethodArgumentResolver(null, true)
val request = MockServerHttpRequest(HttpMethod.GET, "/")
val sessionManager = MockWebSessionManager()
exchange = DefaultServerWebExchange(request, MockServerHttpResponse(), sessionManager)
val initializer = ConfigurableWebBindingInitializer()
initializer.conversionService = DefaultFormattingConversionService()
bindingContext = BindingContext(initializer)
val method = ReflectionUtils.findMethod(javaClass, "handle", String::class.java,
String::class.java, String::class.java, String::class.java)
nullableParamRequired = SynthesizingMethodParameter(method, 0)
nullableParamNotRequired = SynthesizingMethodParameter(method, 1)
nonNullableParamRequired = SynthesizingMethodParameter(method, 2)
nonNullableParamNotRequired = SynthesizingMethodParameter(method, 3)
}
@Test
fun resolveNullableRequiredWithParameter() {
exchange.request.queryParams.set("name", "123")
var result = resolver.resolveArgument(nullableParamRequired, bindingContext, exchange)
StepVerifier.create(result).expectNext("123").expectComplete().verify()
}
@Test
fun resolveNullableRequiredWithoutParameter() {
var result = resolver.resolveArgument(nullableParamRequired, bindingContext, exchange)
StepVerifier.create(result).expectComplete().verify()
}
@Test
fun resolveNullableNotRequiredWithParameter() {
exchange.request.queryParams.set("name", "123")
var result = resolver.resolveArgument(nullableParamNotRequired, bindingContext, exchange)
StepVerifier.create(result).expectNext("123").expectComplete().verify()
}
@Test
fun resolveNullableNotRequiredWithoutParameter() {
var result = resolver.resolveArgument(nullableParamNotRequired, bindingContext, exchange)
StepVerifier.create(result).expectComplete().verify()
}
@Test
fun resolveNonNullableRequiredWithParameter() {
exchange.request.queryParams.set("name", "123")
var result = resolver.resolveArgument(nonNullableParamRequired, bindingContext, exchange)
StepVerifier.create(result).expectNext("123").expectComplete().verify()
}
@Test
fun resolveNonNullableRequiredWithoutParameter() {
var result = resolver.resolveArgument(nonNullableParamRequired, bindingContext, exchange)
StepVerifier.create(result).expectError(ServerWebInputException::class.java).verify()
}
@Test
fun resolveNonNullableNotRequiredWithParameter() {
exchange.request.queryParams.set("name", "123")
var result = resolver.resolveArgument(nonNullableParamNotRequired, bindingContext, exchange)
StepVerifier.create(result).expectNext("123").expectComplete().verify()
}
@Test
fun resolveNonNullableNotRequiredWithoutParameter() {
var result = resolver.resolveArgument(nonNullableParamNotRequired, bindingContext, exchange)
StepVerifier.create(result).expectComplete().verify()
}
@Suppress("unused_parameter")
fun handle(
@RequestParam("name") nullableParamRequired: String?,
@RequestParam("name", required = false) nullableParamNotRequired: String?,
@RequestParam("name") nonNullableParamRequired: String,
@RequestParam("name", required = false) nonNullableParamNotRequired: String) {
}
}

View File

@ -100,7 +100,7 @@ public abstract class AbstractNamedValueMethodArgumentResolver implements Handle
if (namedValueInfo.defaultValue != null) {
arg = resolveStringValue(namedValueInfo.defaultValue);
}
else if (namedValueInfo.required && !nestedParameter.isOptional() && !nestedParameter.isNullable()) {
else if (namedValueInfo.required && !nestedParameter.isOptional()) {
handleMissingValue(namedValueInfo.name, nestedParameter, webRequest);
}
arg = handleNullValue(namedValueInfo.name, arg, nestedParameter.getNestedParameterType());

View File

@ -0,0 +1,218 @@
package org.springframework.web.method.annotation
import org.junit.Assert.assertEquals
import org.junit.Assert.assertNull
import org.junit.Before
import org.junit.Test
import org.springframework.core.MethodParameter
import org.springframework.core.annotation.SynthesizingMethodParameter
import org.springframework.core.convert.support.DefaultConversionService
import org.springframework.http.HttpMethod
import org.springframework.http.MediaType
import org.springframework.mock.web.test.MockHttpServletRequest
import org.springframework.mock.web.test.MockHttpServletResponse
import org.springframework.mock.web.test.MockMultipartFile
import org.springframework.mock.web.test.MockMultipartHttpServletRequest
import org.springframework.util.ReflectionUtils
import org.springframework.web.bind.MissingServletRequestParameterException
import org.springframework.web.bind.annotation.RequestParam
import org.springframework.web.bind.support.ConfigurableWebBindingInitializer
import org.springframework.web.bind.support.DefaultDataBinderFactory
import org.springframework.web.bind.support.WebDataBinderFactory
import org.springframework.web.context.request.NativeWebRequest
import org.springframework.web.context.request.ServletWebRequest
import org.springframework.web.multipart.MultipartFile
import org.springframework.web.multipart.support.MissingServletRequestPartException
/**
* Kotlin test fixture for [RequestParamMethodArgumentResolver].
*
* @author Sebastien Deleuze
*/
class RequestParamMethodArgumentResolverKotlinTests {
lateinit var resolver: RequestParamMethodArgumentResolver
lateinit var webRequest: NativeWebRequest
lateinit var binderFactory: WebDataBinderFactory
lateinit var request: MockHttpServletRequest
lateinit var nullableParamRequired: MethodParameter
lateinit var nullableParamNotRequired: MethodParameter
lateinit var nonNullableParamRequired: MethodParameter
lateinit var nonNullableParamNotRequired: MethodParameter
lateinit var nullableMultipartParamRequired: MethodParameter
lateinit var nullableMultipartParamNotRequired: MethodParameter
lateinit var nonNullableMultipartParamRequired: MethodParameter
lateinit var nonNullableMultipartParamNotRequired: MethodParameter
@Before
fun setup() {
resolver = RequestParamMethodArgumentResolver(null, true)
request = MockHttpServletRequest()
val initializer = ConfigurableWebBindingInitializer()
initializer.conversionService = DefaultConversionService()
binderFactory = DefaultDataBinderFactory(initializer)
webRequest = ServletWebRequest(request, MockHttpServletResponse())
val method = ReflectionUtils.findMethod(javaClass, "handle", String::class.java,
String::class.java, String::class.java, String::class.java,
MultipartFile::class.java, MultipartFile::class.java,
MultipartFile::class.java, MultipartFile::class.java)
nullableParamRequired = SynthesizingMethodParameter(method, 0)
nullableParamNotRequired = SynthesizingMethodParameter(method, 1)
nonNullableParamRequired = SynthesizingMethodParameter(method, 2)
nonNullableParamNotRequired = SynthesizingMethodParameter(method, 3)
nullableMultipartParamRequired = SynthesizingMethodParameter(method, 4)
nullableMultipartParamNotRequired = SynthesizingMethodParameter(method, 5)
nonNullableMultipartParamRequired = SynthesizingMethodParameter(method, 6)
nonNullableMultipartParamNotRequired = SynthesizingMethodParameter(method, 7)
}
@Test
fun resolveNullableRequiredWithParameter() {
request.addParameter("name", "123")
var result = resolver.resolveArgument(nullableParamRequired, null, webRequest, binderFactory)
assertEquals("123", result)
}
@Test
fun resolveNullableRequiredWithoutParameter() {
var result = resolver.resolveArgument(nullableParamRequired, null, webRequest, binderFactory)
assertNull(result)
}
@Test
fun resolveNullableNotRequiredWithParameter() {
request.addParameter("name", "123")
var result = resolver.resolveArgument(nullableParamNotRequired, null, webRequest, binderFactory)
assertEquals("123", result)
}
@Test
fun resolveNullableNotRequiredWithoutParameter() {
var result = resolver.resolveArgument(nullableParamNotRequired, null, webRequest, binderFactory)
assertNull(result)
}
@Test
fun resolveNonNullableRequiredWithParameter() {
request.addParameter("name", "123")
var result = resolver.resolveArgument(nonNullableParamRequired, null, webRequest, binderFactory)
assertEquals("123", result)
}
@Test(expected = MissingServletRequestParameterException::class)
fun resolveNonNullableRequiredWithoutParameter() {
resolver.resolveArgument(nonNullableParamRequired, null, webRequest, binderFactory)
}
@Test
fun resolveNonNullableNotRequiredWithParameter() {
request.addParameter("name", "123")
var result = resolver.resolveArgument(nonNullableParamNotRequired, null, webRequest, binderFactory)
assertEquals("123", result)
}
@Test(expected = TypeCastException::class)
fun resolveNonNullableNotRequiredWithoutParameter() {
resolver.resolveArgument(nonNullableParamNotRequired, null, webRequest, binderFactory) as String
}
@Test
fun resolveNullableRequiredWithMultipartParameter() {
val request = MockMultipartHttpServletRequest()
val expected = MockMultipartFile("mfile", "Hello World".toByteArray())
request.addFile(expected)
webRequest = ServletWebRequest(request)
var result = resolver.resolveArgument(nullableMultipartParamRequired, null, webRequest, binderFactory)
assertEquals(expected, result)
}
@Test
fun resolveNullableRequiredWithoutMultipartParameter() {
request.method = HttpMethod.POST.name
request.contentType = MediaType.MULTIPART_FORM_DATA_VALUE
var result = resolver.resolveArgument(nullableMultipartParamRequired, null, webRequest, binderFactory)
assertNull(result)
}
@Test
fun resolveNullableNotRequiredWithMultipartParameter() {
val request = MockMultipartHttpServletRequest()
val expected = MockMultipartFile("mfile", "Hello World".toByteArray())
request.addFile(expected)
webRequest = ServletWebRequest(request)
var result = resolver.resolveArgument(nullableMultipartParamNotRequired, null, webRequest, binderFactory)
assertEquals(expected, result)
}
@Test
fun resolveNullableNotRequiredWithoutMultipartParameter() {
request.method = HttpMethod.POST.name
request.contentType = MediaType.MULTIPART_FORM_DATA_VALUE
var result = resolver.resolveArgument(nullableMultipartParamNotRequired, null, webRequest, binderFactory)
assertNull(result)
}
@Test
fun resolveNonNullableRequiredWithMultipartParameter() {
val request = MockMultipartHttpServletRequest()
val expected = MockMultipartFile("mfile", "Hello World".toByteArray())
request.addFile(expected)
webRequest = ServletWebRequest(request)
var result = resolver.resolveArgument(nonNullableMultipartParamRequired, null, webRequest, binderFactory)
assertEquals(expected, result)
}
@Test(expected = MissingServletRequestPartException::class)
fun resolveNonNullableRequiredWithoutMultipartParameter() {
request.method = HttpMethod.POST.name
request.contentType = MediaType.MULTIPART_FORM_DATA_VALUE
resolver.resolveArgument(nonNullableMultipartParamRequired, null, webRequest, binderFactory)
}
@Test
fun resolveNonNullableNotRequiredWithMultipartParameter() {
val request = MockMultipartHttpServletRequest()
val expected = MockMultipartFile("mfile", "Hello World".toByteArray())
request.addFile(expected)
webRequest = ServletWebRequest(request)
var result = resolver.resolveArgument(nonNullableMultipartParamNotRequired, null, webRequest, binderFactory)
assertEquals(expected, result)
}
@Test(expected = TypeCastException::class)
fun resolveNonNullableNotRequiredWithoutMultipartParameter() {
request.method = HttpMethod.POST.name
request.contentType = MediaType.MULTIPART_FORM_DATA_VALUE
resolver.resolveArgument(nonNullableMultipartParamNotRequired, null, webRequest, binderFactory) as MultipartFile
}
@Suppress("unused_parameter")
fun handle(
@RequestParam("name") nullableParamRequired: String?,
@RequestParam("name", required = false) nullableParamNotRequired: String?,
@RequestParam("name") nonNullableParamRequired: String,
@RequestParam("name", required = false) nonNullableParamNotRequired: String,
@RequestParam("mfile") nullableMultipartParamRequired: MultipartFile?,
@RequestParam("mfile", required = false) nullableMultipartParamNotRequired: MultipartFile?,
@RequestParam("mfile") nonNullableMultipartParamRequired: MultipartFile,
@RequestParam("mfile", required = false) nonNullableMultipartParamNotRequired: MultipartFile) {
}
}

View File

@ -113,8 +113,7 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM
HttpServletRequest servletRequest = request.getNativeRequest(HttpServletRequest.class);
RequestPart requestPart = parameter.getParameterAnnotation(RequestPart.class);
boolean isRequired = ((requestPart == null || requestPart.required()) && !parameter.isOptional() &&
!parameter.isNullable());
boolean isRequired = ((requestPart == null || requestPart.required()) && !parameter.isOptional());
String name = getPartName(parameter, requestPart);
parameter = parameter.nestedIfOptional();
@ -157,7 +156,7 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM
throw new MissingServletRequestPartException(name);
}
}
if (parameter.isOptional()) {
if (parameter.getParameterType() == Optional.class) {
if (arg == null || (arg instanceof Collection && ((Collection) arg).isEmpty()) ||
(arg instanceof Object[] && ((Object[]) arg).length == 0)) {
arg = Optional.empty();