This commit is contained in:
Rossen Stoyanchev 2016-05-31 09:29:24 -04:00
parent 7e95cd8b4e
commit 431a50823f
2 changed files with 15 additions and 27 deletions

View File

@ -74,7 +74,6 @@ public class StreamingResponseBodyReturnValueHandler implements HandlerMethodRet
returnValue = responseEntity.getBody(); returnValue = responseEntity.getBody();
if (returnValue == null) { if (returnValue == null) {
mavContainer.setRequestHandled(true); mavContainer.setRequestHandled(true);
// Ensure headers are flushed
outputMessage.flush(); outputMessage.flush();
return; return;
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2015 the original author or authors. * Copyright 2002-2016 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -15,12 +15,9 @@
*/ */
package org.springframework.web.servlet.mvc.method.annotation; package org.springframework.web.servlet.mvc.method.annotation;
import static org.junit.Assert.*;
import java.io.IOException;
import java.io.OutputStream;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.nio.charset.Charset; import java.nio.charset.Charset;
import java.util.Collections;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
@ -29,7 +26,6 @@ import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.springframework.core.MethodParameter; import org.springframework.core.MethodParameter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.ResponseEntity; import org.springframework.http.ResponseEntity;
import org.springframework.mock.web.test.MockHttpServletRequest; import org.springframework.mock.web.test.MockHttpServletRequest;
import org.springframework.mock.web.test.MockHttpServletResponse; import org.springframework.mock.web.test.MockHttpServletResponse;
@ -40,6 +36,10 @@ import org.springframework.web.context.request.async.StandardServletAsyncWebRequ
import org.springframework.web.context.request.async.WebAsyncUtils; import org.springframework.web.context.request.async.WebAsyncUtils;
import org.springframework.web.method.support.ModelAndViewContainer; import org.springframework.web.method.support.ModelAndViewContainer;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
/** /**
* Unit tests for * Unit tests for
@ -59,8 +59,6 @@ public class StreamingResponseBodyReturnValueHandlerTests {
private MockHttpServletResponse response; private MockHttpServletResponse response;
private HttpHeaders headers = new HttpHeaders();
@Before @Before
public void setUp() throws Exception { public void setUp() throws Exception {
@ -72,13 +70,12 @@ public class StreamingResponseBodyReturnValueHandlerTests {
this.response = new MockHttpServletResponse(); this.response = new MockHttpServletResponse();
this.webRequest = new ServletWebRequest(this.request, this.response); this.webRequest = new ServletWebRequest(this.request, this.response);
this.headers.add("foo", "bar");
AsyncWebRequest asyncWebRequest = new StandardServletAsyncWebRequest(this.request, this.response); AsyncWebRequest asyncWebRequest = new StandardServletAsyncWebRequest(this.request, this.response);
WebAsyncUtils.getAsyncManager(this.webRequest).setAsyncWebRequest(asyncWebRequest); WebAsyncUtils.getAsyncManager(this.webRequest).setAsyncWebRequest(asyncWebRequest);
this.request.setAsyncSupported(true); this.request.setAsyncSupported(true);
} }
@Test @Test
public void supportsReturnType() throws Exception { public void supportsReturnType() throws Exception {
assertTrue(this.handler.supportsReturnType(returnType(TestController.class, "handle"))); assertTrue(this.handler.supportsReturnType(returnType(TestController.class, "handle")));
@ -93,13 +90,9 @@ public class StreamingResponseBodyReturnValueHandlerTests {
CountDownLatch latch = new CountDownLatch(1); CountDownLatch latch = new CountDownLatch(1);
MethodParameter returnType = returnType(TestController.class, "handle"); MethodParameter returnType = returnType(TestController.class, "handle");
StreamingResponseBody streamingBody = new StreamingResponseBody() { StreamingResponseBody streamingBody = outputStream -> {
outputStream.write("foo".getBytes(Charset.forName("UTF-8")));
@Override latch.countDown();
public void writeTo(OutputStream outputStream) throws IOException {
outputStream.write("foo".getBytes(Charset.forName("UTF-8")));
latch.countDown();
}
}; };
this.handler.handleReturnValue(streamingBody, returnType, this.mavContainer, this.webRequest); this.handler.handleReturnValue(streamingBody, returnType, this.mavContainer, this.webRequest);
@ -116,13 +109,9 @@ public class StreamingResponseBodyReturnValueHandlerTests {
MethodParameter returnType = returnType(TestController.class, "handleResponseEntity"); MethodParameter returnType = returnType(TestController.class, "handleResponseEntity");
ResponseEntity<StreamingResponseBody> emitter = ResponseEntity.ok().header("foo", "bar") ResponseEntity<StreamingResponseBody> emitter = ResponseEntity.ok().header("foo", "bar")
.body(new StreamingResponseBody() { .body(outputStream -> {
outputStream.write("foo".getBytes(Charset.forName("UTF-8")));
@Override latch.countDown();
public void writeTo(OutputStream outputStream) throws IOException {
outputStream.write("foo".getBytes(Charset.forName("UTF-8")));
latch.countDown();
}
}); });
this.handler.handleReturnValue(emitter, returnType, this.mavContainer, this.webRequest); this.handler.handleReturnValue(emitter, returnType, this.mavContainer, this.webRequest);
@ -147,11 +136,11 @@ public class StreamingResponseBodyReturnValueHandlerTests {
@Test @Test
public void responseEntityWithHeadersAndNoContent() throws Exception { public void responseEntityWithHeadersAndNoContent() throws Exception {
ResponseEntity<?> emitter = ResponseEntity.noContent().header("foo", "bar").build();
MethodParameter returnType = returnType(TestController.class, "handleResponseEntity"); MethodParameter returnType = returnType(TestController.class, "handleResponseEntity");
ResponseEntity<?> emitter = ResponseEntity.noContent().headers(headers).build();
this.handler.handleReturnValue(emitter, returnType, this.mavContainer, this.webRequest); this.handler.handleReturnValue(emitter, returnType, this.mavContainer, this.webRequest);
assertEquals(this.response.getHeaders("foo"), this.headers.get("foo")); assertEquals(Collections.singletonList("bar"), this.response.getHeaders("foo"));
} }
private MethodParameter returnType(Class<?> clazz, String methodName) throws NoSuchMethodException { private MethodParameter returnType(Class<?> clazz, String methodName) throws NoSuchMethodException {