MockMvc re-uses response instance on async dispatch

MockMvc now properly detects the presence of an AsyncContext and
re-uses the response instance used to start it.

This commit also includes a minor fix in
ResponseBodyEmitterReturnValueHandler to ensure it does not disable
ETag related content buffering for reactive return values that do not
result in streaming (e.g. single value or collections).

Issue: SPR-16067
This commit is contained in:
Rossen Stoyanchev 2017-10-17 16:57:35 -04:00
parent 94c4a7f941
commit cd634633d8
3 changed files with 91 additions and 32 deletions

View File

@ -18,9 +18,13 @@ package org.springframework.test.web.servlet;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import javax.servlet.AsyncContext;
import javax.servlet.DispatcherType; import javax.servlet.DispatcherType;
import javax.servlet.Filter; import javax.servlet.Filter;
import javax.servlet.ServletContext; import javax.servlet.ServletContext;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;
import org.springframework.beans.Mergeable; import org.springframework.beans.Mergeable;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
@ -135,24 +139,35 @@ public final class MockMvc {
} }
MockHttpServletRequest request = requestBuilder.buildRequest(this.servletContext); MockHttpServletRequest request = requestBuilder.buildRequest(this.servletContext);
MockHttpServletResponse response = new MockHttpServletResponse();
AsyncContext asyncContext = request.getAsyncContext();
MockHttpServletResponse mockResponse;
HttpServletResponse servletResponse;
if (asyncContext != null) {
servletResponse = (HttpServletResponse) asyncContext.getResponse();
mockResponse = unwrapResponseIfNecessary(servletResponse);
}
else {
mockResponse = new MockHttpServletResponse();
servletResponse = mockResponse;
}
if (requestBuilder instanceof SmartRequestBuilder) { if (requestBuilder instanceof SmartRequestBuilder) {
request = ((SmartRequestBuilder) requestBuilder).postProcessRequest(request); request = ((SmartRequestBuilder) requestBuilder).postProcessRequest(request);
} }
final MvcResult mvcResult = new DefaultMvcResult(request, response); final MvcResult mvcResult = new DefaultMvcResult(request, mockResponse);
request.setAttribute(MVC_RESULT_ATTRIBUTE, mvcResult); request.setAttribute(MVC_RESULT_ATTRIBUTE, mvcResult);
RequestAttributes previousAttributes = RequestContextHolder.getRequestAttributes(); RequestAttributes previousAttributes = RequestContextHolder.getRequestAttributes();
RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(request, response)); RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(request, servletResponse));
MockFilterChain filterChain = new MockFilterChain(this.servlet, this.filters); MockFilterChain filterChain = new MockFilterChain(this.servlet, this.filters);
filterChain.doFilter(request, response); filterChain.doFilter(request, servletResponse);
if (DispatcherType.ASYNC.equals(request.getDispatcherType()) && if (DispatcherType.ASYNC.equals(request.getDispatcherType()) &&
request.getAsyncContext() != null & !request.isAsyncStarted()) { asyncContext != null & !request.isAsyncStarted()) {
request.getAsyncContext().complete(); asyncContext.complete();
} }
applyDefaultResultActions(mvcResult); applyDefaultResultActions(mvcResult);
@ -176,6 +191,14 @@ public final class MockMvc {
}; };
} }
private MockHttpServletResponse unwrapResponseIfNecessary(ServletResponse servletResponse) {
while (servletResponse instanceof HttpServletResponseWrapper) {
servletResponse = ((HttpServletResponseWrapper) servletResponse).getResponse();
}
Assert.isInstanceOf(MockHttpServletResponse.class, servletResponse);
return (MockHttpServletResponse) servletResponse;
}
private void applyDefaultResultActions(MvcResult mvcResult) throws Exception { private void applyDefaultResultActions(MvcResult mvcResult) throws Exception {

View File

@ -18,6 +18,7 @@ package org.springframework.test.web.servlet.samples.standalone;
import java.io.IOException; import java.io.IOException;
import java.security.Principal; import java.security.Principal;
import java.util.concurrent.CompletableFuture;
import javax.servlet.Filter; import javax.servlet.Filter;
import javax.servlet.FilterChain; import javax.servlet.FilterChain;
import javax.servlet.ServletException; import javax.servlet.ServletException;
@ -29,22 +30,34 @@ import javax.validation.Valid;
import org.junit.Test; import org.junit.Test;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Controller; import org.springframework.stereotype.Controller;
import org.springframework.test.web.Person; import org.springframework.test.web.Person;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult;
import org.springframework.validation.Errors; import org.springframework.validation.Errors;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMethod; import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.ResponseBody;
import org.springframework.web.filter.OncePerRequestFilter; import org.springframework.web.filter.OncePerRequestFilter;
import org.springframework.web.filter.ShallowEtagHeaderFilter;
import org.springframework.web.servlet.ModelAndView; import org.springframework.web.servlet.ModelAndView;
import org.springframework.web.servlet.mvc.support.RedirectAttributes; import org.springframework.web.servlet.mvc.support.RedirectAttributes;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.asyncDispatch;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.setup.MockMvcBuilders.*; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.flash;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.model;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.request;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
import static org.springframework.test.web.servlet.setup.MockMvcBuilders.standaloneSetup;
/** /**
* Tests with {@link Filter}'s. * Tests with {@link Filter}'s.
*
* @author Rob Winch * @author Rob Winch
*/ */
public class FilterTests { public class FilterTests {
@ -107,10 +120,29 @@ public class FilterTests {
.andExpect(model().attribute("principal", WrappingRequestResponseFilter.PRINCIPAL_NAME)); .andExpect(model().attribute("principal", WrappingRequestResponseFilter.PRINCIPAL_NAME));
} }
@Test // SPR-16067
public void filterWrapsRequestResponseWithAsyncDispatch() throws Exception {
MockMvc mockMvc = standaloneSetup(new PersonController())
.addFilters(new ShallowEtagHeaderFilter())
.build();
MvcResult mvcResult = mockMvc.perform(get("/persons/1").accept(MediaType.APPLICATION_JSON))
.andExpect(request().asyncStarted())
.andExpect(request().asyncResult(new Person("Lukas")))
.andReturn();
mockMvc.perform(asyncDispatch(mvcResult))
.andExpect(status().isOk())
.andExpect(header().longValue("Content-Length", 53))
.andExpect(header().string("ETag", "\"0e37becb4f0c90709cb2e1efcc61eaa00\""))
.andExpect(content().string("{\"name\":\"Lukas\",\"someDouble\":0.0,\"someBoolean\":false}"));
}
@Controller @Controller
private static class PersonController { private static class PersonController {
@RequestMapping(value="/persons", method=RequestMethod.POST)
@PostMapping(path="/persons")
public String save(@Valid Person person, Errors errors, RedirectAttributes redirectAttrs) { public String save(@Valid Person person, Errors errors, RedirectAttributes redirectAttrs) {
if (errors.hasErrors()) { if (errors.hasErrors()) {
return "person/add"; return "person/add";
@ -120,18 +152,25 @@ public class FilterTests {
return "redirect:/person/{id}"; return "redirect:/person/{id}";
} }
@RequestMapping(value="/user") @PostMapping("/user")
public ModelAndView user(Principal principal) { public ModelAndView user(Principal principal) {
return new ModelAndView("user/view", "principal", principal.getName()); return new ModelAndView("user/view", "principal", principal.getName());
} }
@RequestMapping(value="/forward") @GetMapping("/forward")
public String forward() { public String forward() {
return "forward:/persons"; return "forward:/persons";
} }
@GetMapping("persons/{id}")
@ResponseBody
public CompletableFuture<Person> getPerson() {
return CompletableFuture.completedFuture(new Person("Lukas"));
}
} }
private class ContinueFilter extends OncePerRequestFilter { private class ContinueFilter extends OncePerRequestFilter {
@Override @Override
protected void doFilterInternal(HttpServletRequest request, protected void doFilterInternal(HttpServletRequest request,
HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
@ -144,28 +183,25 @@ public class FilterTests {
public static final String PRINCIPAL_NAME = "WrapRequestResponseFilterPrincipal"; public static final String PRINCIPAL_NAME = "WrapRequestResponseFilterPrincipal";
@Override @Override
protected void doFilterInternal(HttpServletRequest request, protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response,
HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { FilterChain filterChain) throws ServletException, IOException {
filterChain.doFilter(new HttpServletRequestWrapper(request) { filterChain.doFilter(new HttpServletRequestWrapper(request) {
@Override @Override
public Principal getUserPrincipal() { public Principal getUserPrincipal() {
return new Principal() { return () -> PRINCIPAL_NAME;
@Override
public String getName() {
return PRINCIPAL_NAME;
}
};
} }
}, new HttpServletResponseWrapper(response)); }, new HttpServletResponseWrapper(response));
} }
} }
private class RedirectFilter extends OncePerRequestFilter { private class RedirectFilter extends OncePerRequestFilter {
@Override @Override
protected void doFilterInternal(HttpServletRequest request, protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response,
HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { FilterChain filterChain) throws ServletException, IOException {
response.sendRedirect("/login"); response.sendRedirect("/login");
} }

View File

@ -140,23 +140,23 @@ public class ResponseBodyEmitterReturnValueHandler implements HandlerMethodRetur
ServletRequest request = webRequest.getNativeRequest(ServletRequest.class); ServletRequest request = webRequest.getNativeRequest(ServletRequest.class);
Assert.state(request != null, "No ServletRequest"); Assert.state(request != null, "No ServletRequest");
ShallowEtagHeaderFilter.disableContentCaching(request);
ResponseBodyEmitter emitter; ResponseBodyEmitter emitter;
if (returnValue instanceof ResponseBodyEmitter) { if (returnValue instanceof ResponseBodyEmitter) {
emitter = (ResponseBodyEmitter) returnValue; emitter = (ResponseBodyEmitter) returnValue;
} }
else { else {
emitter = this.reactiveHandler.handleValue(returnValue, returnType, mavContainer, webRequest); emitter = this.reactiveHandler.handleValue(returnValue, returnType, mavContainer, webRequest);
}
if (emitter == null) { if (emitter == null) {
// Not streaming..
return; return;
} }
}
emitter.extendResponse(outputMessage); emitter.extendResponse(outputMessage);
// At this point we know we're streaming..
ShallowEtagHeaderFilter.disableContentCaching(request);
// Commit the response and wrap to ignore further header changes // Commit the response and wrap to ignore further header changes
outputMessage.getBody(); outputMessage.getBody();
outputMessage.flush(); outputMessage.flush();