From b07d02f1bf08a8e50deb3274007b4094171ed609 Mon Sep 17 00:00:00 2001 From: Arjen Poutsma Date: Mon, 22 Mar 2010 10:23:39 +0000 Subject: [PATCH] SPR-7018 - Support for HttpEntity in @MVC --- .../AnnotationMethodHandlerAdapter.java | 33 +++++++++- .../ServletAnnotationControllerTests.java | 36 ++++++++++ .../support/HandlerMethodInvoker.java | 66 +++++++++++++++---- 3 files changed, 121 insertions(+), 14 deletions(-) diff --git a/org.springframework.web.servlet/src/main/java/org/springframework/web/servlet/mvc/annotation/AnnotationMethodHandlerAdapter.java b/org.springframework.web.servlet/src/main/java/org/springframework/web/servlet/mvc/annotation/AnnotationMethodHandlerAdapter.java index 79e36bb4fdd..ea05dbf6497 100644 --- a/org.springframework.web.servlet/src/main/java/org/springframework/web/servlet/mvc/annotation/AnnotationMethodHandlerAdapter.java +++ b/org.springframework.web.servlet/src/main/java/org/springframework/web/servlet/mvc/annotation/AnnotationMethodHandlerAdapter.java @@ -55,6 +55,8 @@ import org.springframework.core.LocalVariableTableParameterNameDiscoverer; import org.springframework.core.Ordered; import org.springframework.core.ParameterNameDiscoverer; import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; import org.springframework.http.HttpInputMessage; import org.springframework.http.HttpOutputMessage; import org.springframework.http.HttpStatus; @@ -797,6 +799,10 @@ public class AnnotationMethodHandlerAdapter extends WebContentGenerator handleResponseBody(returnValue, webRequest); return null; } + if (returnValue instanceof HttpEntity) { + handleHttpEntityResponse((HttpEntity) returnValue, webRequest); + return null; + } if (returnValue instanceof ModelAndView) { ModelAndView mav = (ModelAndView) returnValue; @@ -839,7 +845,6 @@ public class AnnotationMethodHandlerAdapter extends WebContentGenerator } } - @SuppressWarnings("unchecked") private void handleResponseBody(Object returnValue, ServletWebRequest webRequest) throws ServletException, IOException { if (returnValue == null) { @@ -847,12 +852,35 @@ public class AnnotationMethodHandlerAdapter extends WebContentGenerator } HttpInputMessage inputMessage = new ServletServerHttpRequest(webRequest.getRequest()); + HttpOutputMessage outputMessage = new ServletServerHttpResponse(webRequest.getResponse()); + + writeWithMessageConverters(returnValue, inputMessage, outputMessage); + } + + private void handleHttpEntityResponse(HttpEntity responseEntity, ServletWebRequest webRequest) + throws ServletException, IOException { + if (responseEntity == null) { + return; + } + HttpInputMessage inputMessage = new ServletServerHttpRequest(webRequest.getRequest()); + HttpOutputMessage outputMessage = new ServletServerHttpResponse(webRequest.getResponse()); + + HttpHeaders entityHeaders = responseEntity.getHeaders(); + if (!entityHeaders.isEmpty()) { + outputMessage.getHeaders().putAll(entityHeaders); + } + writeWithMessageConverters(responseEntity.getBody(), inputMessage, outputMessage); + } + + @SuppressWarnings("unchecked") + private void writeWithMessageConverters(Object returnValue, + HttpInputMessage inputMessage, HttpOutputMessage outputMessage) + throws IOException, HttpMediaTypeNotAcceptableException { List acceptedMediaTypes = inputMessage.getHeaders().getAccept(); if (acceptedMediaTypes.isEmpty()) { acceptedMediaTypes = Collections.singletonList(MediaType.ALL); } MediaType.sortBySpecificity(acceptedMediaTypes); - HttpOutputMessage outputMessage = new ServletServerHttpResponse(webRequest.getResponse()); Class returnValueType = returnValue.getClass(); List allSupportedMediaTypes = new ArrayList(); if (getMessageConverters() != null) { @@ -879,6 +907,7 @@ public class AnnotationMethodHandlerAdapter extends WebContentGenerator } throw new HttpMediaTypeNotAcceptableException(allSupportedMediaTypes); } + } diff --git a/org.springframework.web.servlet/src/test/java/org/springframework/web/servlet/mvc/annotation/ServletAnnotationControllerTests.java b/org.springframework.web.servlet/src/test/java/org/springframework/web/servlet/mvc/annotation/ServletAnnotationControllerTests.java index adfb8dbf568..38547d30546 100644 --- a/org.springframework.web.servlet/src/test/java/org/springframework/web/servlet/mvc/annotation/ServletAnnotationControllerTests.java +++ b/org.springframework.web.servlet/src/test/java/org/springframework/web/servlet/mvc/annotation/ServletAnnotationControllerTests.java @@ -18,6 +18,7 @@ package org.springframework.web.servlet.mvc.annotation; import java.io.IOException; import java.io.Serializable; +import java.io.UnsupportedEncodingException; import java.io.Writer; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; @@ -64,6 +65,7 @@ import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.beans.propertyeditors.CustomDateEditor; import org.springframework.context.annotation.AnnotationConfigUtils; import org.springframework.core.MethodParameter; +import org.springframework.http.HttpEntity; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpInputMessage; import org.springframework.http.HttpOutputMessage; @@ -1155,6 +1157,24 @@ public class ServletAnnotationControllerTests { assertEquals("Invalid response status code", HttpServletResponse.SC_BAD_REQUEST, response.getStatus()); } + @Test + public void httpEntity() throws ServletException, IOException { + initServlet(HttpEntityController.class); + + MockHttpServletRequest request = new MockHttpServletRequest("PUT", "/handle"); + String requestBody = "Hello World"; + request.setContent(requestBody.getBytes("UTF-8")); + request.addHeader("Content-Type", "text/plain; charset=utf-8"); + request.addHeader("Accept", "text/*, */*"); + request.addHeader("MyRequestHeader", "MyValue"); + MockHttpServletResponse response = new MockHttpServletResponse(); + servlet.service(request, response); + assertEquals(200, response.getStatus()); + assertEquals(requestBody, response.getContentAsString()); + assertEquals("MyValue", response.getHeader("MyResponseHeader")); + } + + /* * See SPR-6877 */ @@ -2502,5 +2522,21 @@ public class ServletAnnotationControllerTests { } } + @Controller + public static class HttpEntityController { + + @RequestMapping("/handle") + public HttpEntity handle(HttpEntity requestEntity) throws UnsupportedEncodingException { + assertNotNull(requestEntity); + assertEquals("MyValue", requestEntity.getHeaders().getFirst("MyRequestHeader")); + String requestBody = new String(requestEntity.getBody(), "UTF-8"); + assertEquals("Hello World", requestBody); + + HttpHeaders responseHeaders = new HttpHeaders(); + responseHeaders.set("MyResponseHeader", "MyValue"); + return new HttpEntity(requestBody, responseHeaders); + } + } + } diff --git a/org.springframework.web/src/main/java/org/springframework/web/bind/annotation/support/HandlerMethodInvoker.java b/org.springframework.web/src/main/java/org/springframework/web/bind/annotation/support/HandlerMethodInvoker.java index 1d043498fc8..7475d18036a 100644 --- a/org.springframework.web/src/main/java/org/springframework/web/bind/annotation/support/HandlerMethodInvoker.java +++ b/org.springframework.web/src/main/java/org/springframework/web/bind/annotation/support/HandlerMethodInvoker.java @@ -17,8 +17,12 @@ package org.springframework.web.bind.annotation.support; import java.lang.annotation.Annotation; +import java.lang.reflect.Array; +import java.lang.reflect.GenericArrayType; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -40,12 +44,14 @@ import org.springframework.core.GenericTypeResolver; import org.springframework.core.MethodParameter; import org.springframework.core.ParameterNameDiscoverer; import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.http.HttpEntity; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpInputMessage; import org.springframework.http.MediaType; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.ui.ExtendedModelMap; import org.springframework.ui.Model; +import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; @@ -192,7 +198,7 @@ public class HandlerMethodInvoker { boolean required = false; String defaultValue = null; boolean validate = false; - int found = 0; + int annotationsFound = 0; Annotation[] paramAnns = methodParam.getParameterAnnotations(); for (Annotation paramAnn : paramAnns) { @@ -201,35 +207,35 @@ public class HandlerMethodInvoker { paramName = requestParam.value(); required = requestParam.required(); defaultValue = parseDefaultValueAttribute(requestParam.defaultValue()); - found++; + annotationsFound++; } else if (RequestHeader.class.isInstance(paramAnn)) { RequestHeader requestHeader = (RequestHeader) paramAnn; headerName = requestHeader.value(); required = requestHeader.required(); defaultValue = parseDefaultValueAttribute(requestHeader.defaultValue()); - found++; + annotationsFound++; } else if (RequestBody.class.isInstance(paramAnn)) { requestBodyFound = true; - found++; + annotationsFound++; } else if (CookieValue.class.isInstance(paramAnn)) { CookieValue cookieValue = (CookieValue) paramAnn; cookieName = cookieValue.value(); required = cookieValue.required(); defaultValue = parseDefaultValueAttribute(cookieValue.defaultValue()); - found++; + annotationsFound++; } else if (PathVariable.class.isInstance(paramAnn)) { PathVariable pathVar = (PathVariable) paramAnn; pathVarName = pathVar.value(); - found++; + annotationsFound++; } else if (ModelAttribute.class.isInstance(paramAnn)) { ModelAttribute attr = (ModelAttribute) paramAnn; attrName = attr.value(); - found++; + annotationsFound++; } else if (Value.class.isInstance(paramAnn)) { defaultValue = ((Value) paramAnn).value(); @@ -239,12 +245,12 @@ public class HandlerMethodInvoker { } } - if (found > 1) { + if (annotationsFound > 1) { throw new IllegalStateException("Handler parameter annotations are exclusive choices - " + "do not specify more than one such annotation on the same parameter: " + handlerMethod); } - if (found == 0) { + if (annotationsFound == 0) { Object argValue = resolveCommonArgument(methodParam, webRequest); if (argValue != WebArgumentResolver.UNRESOLVED) { args[i] = argValue; @@ -260,6 +266,9 @@ public class HandlerMethodInvoker { else if (SessionStatus.class.isAssignableFrom(paramType)) { args[i] = this.sessionStatus; } + else if (HttpEntity.class.isAssignableFrom(paramType)) { + args[i] = resolveHttpEntityRequest(methodParam, webRequest, handler); + } else if (Errors.class.isAssignableFrom(paramType)) { throw new IllegalStateException("Errors/BindingResult argument declared " + "without preceding model attribute. Check your handler method signature!"); @@ -527,12 +536,22 @@ public class HandlerMethodInvoker { /** * Resolves the given {@link RequestBody @RequestBody} annotation. */ - @SuppressWarnings("unchecked") protected Object resolveRequestBody(MethodParameter methodParam, NativeWebRequest webRequest, Object handler) throws Exception { + return readWithMessageConverters(methodParam, createHttpInputMessage(webRequest), methodParam.getParameterType()); + } + @SuppressWarnings("unchecked") + private HttpEntity resolveHttpEntityRequest(MethodParameter methodParam, NativeWebRequest webRequest, Object handler) + throws Exception { HttpInputMessage inputMessage = createHttpInputMessage(webRequest); - Class paramType = methodParam.getParameterType(); + Class paramType = getHttpEntityType(methodParam); + Object body = readWithMessageConverters(methodParam, inputMessage, paramType); + return new HttpEntity(body, inputMessage.getHeaders()); + } + + private Object readWithMessageConverters(MethodParameter methodParam, HttpInputMessage inputMessage, Class paramType) + throws Exception{ MediaType contentType = inputMessage.getHeaders().getContentType(); if (contentType == null) { StringBuilder builder = new StringBuilder(ClassUtils.getShortName(methodParam.getParameterType())); @@ -542,8 +561,9 @@ public class HandlerMethodInvoker { builder.append(paramName); } throw new HttpMediaTypeNotSupportedException( - "Cannot extract @RequestBody parameter (" + builder.toString() + "): no Content-Type found"); + "Cannot extract parameter (" + builder.toString() + "): no Content-Type found"); } + List allSupportedMediaTypes = new ArrayList(); if (this.messageConverters != null) { for (HttpMessageConverter messageConverter : this.messageConverters) { @@ -560,6 +580,28 @@ public class HandlerMethodInvoker { throw new HttpMediaTypeNotSupportedException(contentType, allSupportedMediaTypes); } + private Class getHttpEntityType(MethodParameter methodParam) { + Assert.isAssignable(HttpEntity.class, methodParam.getParameterType()); + ParameterizedType type = (ParameterizedType) methodParam.getGenericParameterType(); + if (type.getActualTypeArguments().length == 1) { + Type typeArgument = type.getActualTypeArguments()[0]; + if (typeArgument instanceof Class) { + return (Class) typeArgument; + } + else if (typeArgument instanceof GenericArrayType) { + Type componentType = ((GenericArrayType) typeArgument).getGenericComponentType(); + if (componentType instanceof Class) { + // Surely, there should be a nicer way to do this + Object array = Array.newInstance((Class) componentType, 0); + return array.getClass(); + } + } + } + throw new IllegalArgumentException( + "HttpEntity parameter (" + methodParam.getParameterName() + ") is not parameterized"); + + } + private Object resolveCookieValue(String cookieName, boolean required, String defaultValue, MethodParameter methodParam, NativeWebRequest webRequest, Object handlerForInitBinderCall) throws Exception {