Only reset response buffer for error handling

Prior to this commit, `DispatcherServlet` would completely reset the
response (status, headers and body) before handling errors within Spring
MVC. This can cause unintended consequences when Servlet Filters added
response headers before the error happened. Such response headers might
be still required in case of error handling.

This commit changes the complete reset of the response to only resetting
the response buffer, if possible.

Closes gh-31154
See gh-31104
This commit is contained in:
Brian Clozel 2023-09-08 18:47:21 +02:00
parent 88ee8fc92f
commit 0f945873a3
2 changed files with 61 additions and 57 deletions

View File

@ -1338,9 +1338,10 @@ public class DispatcherServlet extends FrameworkServlet {
// Success and error responses may use different content types
request.removeAttribute(HandlerMapping.PRODUCIBLE_MEDIA_TYPES_ATTRIBUTE);
// Reset the response if the response is not committed already
// Reset the response body buffer if the response is not committed already,
// leaving the response headers in place.
try {
response.reset();
response.resetBuffer();
}
catch (IllegalStateException illegalStateException) {
// the response is already committed, leave it to exception handlers anyway

View File

@ -120,7 +120,7 @@ public class DispatcherServletTests {
@Test
public void configuredDispatcherServlets() {
void configuredDispatcherServlets() {
assertThat((simpleDispatcherServlet.getNamespace())).as("Correct namespace")
.isEqualTo("simple" + FrameworkServlet.DEFAULT_NAMESPACE_SUFFIX);
assertThat((FrameworkServlet.SERVLET_CONTEXT_PREFIX + "simple")).as("Correct attribute")
@ -139,7 +139,7 @@ public class DispatcherServletTests {
}
@Test
public void invalidRequest() throws Exception {
void invalidRequest() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(getServletContext(), "GET", "/invalid.do");
MockHttpServletResponse response = new MockHttpServletResponse();
simpleDispatcherServlet.service(request, response);
@ -148,7 +148,7 @@ public class DispatcherServletTests {
}
@Test
public void requestHandledEvent() throws Exception {
void requestHandledEvent() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(getServletContext(), "GET", "/locale.do");
MockHttpServletResponse response = new MockHttpServletResponse();
complexDispatcherServlet.service(request, response);
@ -159,7 +159,7 @@ public class DispatcherServletTests {
}
@Test
public void publishEventsOff() throws Exception {
void publishEventsOff() throws Exception {
complexDispatcherServlet.setPublishEvents(false);
MockHttpServletRequest request = new MockHttpServletRequest(getServletContext(), "GET", "/locale.do");
MockHttpServletResponse response = new MockHttpServletResponse();
@ -171,7 +171,7 @@ public class DispatcherServletTests {
}
@Test
public void parameterizableViewController() throws Exception {
void parameterizableViewController() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(getServletContext(), "GET", "/view.do");
request.addUserRole("role1");
MockHttpServletResponse response = new MockHttpServletResponse();
@ -180,7 +180,7 @@ public class DispatcherServletTests {
}
@Test
public void handlerInterceptorSuppressesView() throws Exception {
void handlerInterceptorSuppressesView() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(getServletContext(), "GET", "/view.do");
request.addUserRole("role1");
request.addParameter("noView", "true");
@ -190,7 +190,7 @@ public class DispatcherServletTests {
}
@Test
public void localeRequest() throws Exception {
void localeRequest() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(getServletContext(), "GET", "/locale.do");
request.addPreferredLocale(Locale.CANADA);
MockHttpServletResponse response = new MockHttpServletResponse();
@ -200,7 +200,7 @@ public class DispatcherServletTests {
}
@Test
public void unknownRequest() throws Exception {
void unknownRequest() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(getServletContext(), "GET", "/unknown.do");
MockHttpServletResponse response = new MockHttpServletResponse();
complexDispatcherServlet.service(request, response);
@ -209,7 +209,7 @@ public class DispatcherServletTests {
}
@Test
public void anotherLocaleRequest() throws Exception {
void anotherLocaleRequest() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(getServletContext(), "GET", "/locale.do;abc=def");
request.addPreferredLocale(Locale.CANADA);
request.addUserRole("role1");
@ -230,7 +230,7 @@ public class DispatcherServletTests {
}
@Test
public void existingMultipartRequest() throws Exception {
void existingMultipartRequest() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(getServletContext(), "GET", "/locale.do;abc=def");
request.addPreferredLocale(Locale.CANADA);
request.addUserRole("role1");
@ -246,7 +246,7 @@ public class DispatcherServletTests {
}
@Test
public void existingMultipartRequestButWrapped() throws Exception {
void existingMultipartRequestButWrapped() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(getServletContext(), "GET", "/locale.do;abc=def");
request.addPreferredLocale(Locale.CANADA);
request.addUserRole("role1");
@ -262,7 +262,7 @@ public class DispatcherServletTests {
}
@Test
public void multipartResolutionFailed() throws Exception {
void multipartResolutionFailed() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(getServletContext(), "GET", "/locale.do;abc=def");
request.addPreferredLocale(Locale.CANADA);
request.addUserRole("role1");
@ -276,7 +276,7 @@ public class DispatcherServletTests {
}
@Test
public void handlerInterceptorAbort() throws Exception {
void handlerInterceptorAbort() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(getServletContext(), "GET", "/locale.do");
request.addParameter("abort", "true");
request.addPreferredLocale(Locale.CANADA);
@ -293,7 +293,7 @@ public class DispatcherServletTests {
}
@Test
public void modelAndViewDefiningException() throws Exception {
void modelAndViewDefiningException() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(getServletContext(), "GET", "/locale.do");
request.addPreferredLocale(Locale.CANADA);
request.addUserRole("role1");
@ -305,7 +305,7 @@ public class DispatcherServletTests {
}
@Test
public void simpleMappingExceptionResolverWithSpecificHandler1() throws Exception {
void simpleMappingExceptionResolverWithSpecificHandler1() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(getServletContext(), "GET", "/locale.do");
request.addPreferredLocale(Locale.CANADA);
request.addUserRole("role1");
@ -319,7 +319,7 @@ public class DispatcherServletTests {
}
@Test
public void simpleMappingExceptionResolverWithSpecificHandler2() throws Exception {
void simpleMappingExceptionResolverWithSpecificHandler2() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(getServletContext(), "GET", "/locale.do");
request.addPreferredLocale(Locale.CANADA);
request.addUserRole("role1");
@ -332,7 +332,7 @@ public class DispatcherServletTests {
}
@Test
public void simpleMappingExceptionResolverWithAllHandlers1() throws Exception {
void simpleMappingExceptionResolverWithAllHandlers1() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(getServletContext(), "GET", "/loc.do");
request.addPreferredLocale(Locale.CANADA);
request.addUserRole("role1");
@ -346,7 +346,7 @@ public class DispatcherServletTests {
}
@Test
public void simpleMappingExceptionResolverWithAllHandlers2() throws Exception {
void simpleMappingExceptionResolverWithAllHandlers2() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(getServletContext(), "GET", "/loc.do");
request.addPreferredLocale(Locale.CANADA);
request.addUserRole("role1");
@ -359,7 +359,7 @@ public class DispatcherServletTests {
}
@Test
public void simpleMappingExceptionResolverWithDefaultErrorView() throws Exception {
void simpleMappingExceptionResolverWithDefaultErrorView() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(getServletContext(), "GET", "/locale.do");
request.addPreferredLocale(Locale.CANADA);
request.addUserRole("role1");
@ -372,7 +372,7 @@ public class DispatcherServletTests {
}
@Test
public void localeChangeInterceptor1() throws Exception {
void localeChangeInterceptor1() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(getServletContext(), "GET", "/locale.do");
request.addPreferredLocale(Locale.GERMAN);
request.addUserRole("role2");
@ -385,7 +385,7 @@ public class DispatcherServletTests {
}
@Test
public void localeChangeInterceptor2() throws Exception {
void localeChangeInterceptor2() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(getServletContext(), "GET", "/locale.do");
request.addPreferredLocale(Locale.GERMAN);
request.addUserRole("role2");
@ -397,7 +397,7 @@ public class DispatcherServletTests {
}
@Test
public void themeChangeInterceptor1() throws Exception {
void themeChangeInterceptor1() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(getServletContext(), "GET", "/locale.do");
request.addPreferredLocale(Locale.CANADA);
request.addUserRole("role1");
@ -410,7 +410,7 @@ public class DispatcherServletTests {
}
@Test
public void themeChangeInterceptor2() throws Exception {
void themeChangeInterceptor2() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(getServletContext(), "GET", "/locale.do");
request.addPreferredLocale(Locale.CANADA);
request.addUserRole("role1");
@ -422,7 +422,7 @@ public class DispatcherServletTests {
}
@Test
public void notAuthorized() throws Exception {
void notAuthorized() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(getServletContext(), "GET", "/locale.do");
request.addPreferredLocale(Locale.CANADA);
MockHttpServletResponse response = new MockHttpServletResponse();
@ -431,7 +431,7 @@ public class DispatcherServletTests {
}
@Test
public void headMethodWithExplicitHandling() throws Exception {
void headMethodWithExplicitHandling() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(getServletContext(), "HEAD", "/head.do");
MockHttpServletResponse response = new MockHttpServletResponse();
complexDispatcherServlet.service(request, response);
@ -444,7 +444,7 @@ public class DispatcherServletTests {
}
@Test
public void headMethodWithNoBodyResponse() throws Exception {
void headMethodWithNoBodyResponse() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(getServletContext(), "HEAD", "/body.do");
MockHttpServletResponse response = new MockHttpServletResponse();
complexDispatcherServlet.service(request, response);
@ -457,7 +457,7 @@ public class DispatcherServletTests {
}
@Test
public void notDetectAllHandlerMappings() throws ServletException, IOException {
void notDetectAllHandlerMappings() throws ServletException, IOException {
DispatcherServlet complexDispatcherServlet = new DispatcherServlet();
complexDispatcherServlet.setContextClass(ComplexWebApplicationContext.class);
complexDispatcherServlet.setNamespace("test");
@ -471,7 +471,7 @@ public class DispatcherServletTests {
}
@Test
public void handlerNotMappedWithAutodetect() throws ServletException, IOException {
void handlerNotMappedWithAutodetect() throws ServletException, IOException {
DispatcherServlet complexDispatcherServlet = new DispatcherServlet();
// no parent
complexDispatcherServlet.setContextClass(ComplexWebApplicationContext.class);
@ -485,7 +485,7 @@ public class DispatcherServletTests {
}
@Test
public void detectHandlerMappingFromParent() throws ServletException, IOException {
void detectHandlerMappingFromParent() throws ServletException, IOException {
// create a parent context that includes a mapping
StaticWebApplicationContext parent = new StaticWebApplicationContext();
parent.setServletContext(getServletContext());
@ -515,7 +515,7 @@ public class DispatcherServletTests {
}
@Test
public void detectAllHandlerAdapters() throws ServletException, IOException {
void detectAllHandlerAdapters() throws ServletException, IOException {
DispatcherServlet complexDispatcherServlet = new DispatcherServlet();
complexDispatcherServlet.setContextClass(ComplexWebApplicationContext.class);
complexDispatcherServlet.setNamespace("test");
@ -535,7 +535,7 @@ public class DispatcherServletTests {
}
@Test
public void notDetectAllHandlerAdapters() throws ServletException, IOException {
void notDetectAllHandlerAdapters() throws ServletException, IOException {
DispatcherServlet complexDispatcherServlet = new DispatcherServlet();
complexDispatcherServlet.setContextClass(ComplexWebApplicationContext.class);
complexDispatcherServlet.setNamespace("test");
@ -560,7 +560,7 @@ public class DispatcherServletTests {
}
@Test
public void notDetectAllHandlerExceptionResolvers() throws ServletException, IOException {
void notDetectAllHandlerExceptionResolvers() throws ServletException, IOException {
DispatcherServlet complexDispatcherServlet = new DispatcherServlet();
complexDispatcherServlet.setContextClass(ComplexWebApplicationContext.class);
complexDispatcherServlet.setNamespace("test");
@ -575,7 +575,7 @@ public class DispatcherServletTests {
}
@Test
public void notDetectAllViewResolvers() throws ServletException, IOException {
void notDetectAllViewResolvers() throws ServletException, IOException {
DispatcherServlet complexDispatcherServlet = new DispatcherServlet();
complexDispatcherServlet.setContextClass(ComplexWebApplicationContext.class);
complexDispatcherServlet.setNamespace("test");
@ -590,7 +590,7 @@ public class DispatcherServletTests {
}
@Test
public void throwExceptionIfNoHandlerFound() throws ServletException, IOException {
void throwExceptionIfNoHandlerFound() throws ServletException, IOException {
DispatcherServlet complexDispatcherServlet = new DispatcherServlet();
complexDispatcherServlet.setContextClass(SimpleWebApplicationContext.class);
complexDispatcherServlet.setNamespace("test");
@ -606,7 +606,7 @@ public class DispatcherServletTests {
// SPR-12984
@Test
public void noHandlerFoundExceptionMessage() {
void noHandlerFoundExceptionMessage() {
HttpHeaders headers = new HttpHeaders();
headers.add("foo", "bar");
NoHandlerFoundException ex = new NoHandlerFoundException("GET", "/foo", headers);
@ -615,7 +615,7 @@ public class DispatcherServletTests {
}
@Test
public void cleanupAfterIncludeWithRemove() throws ServletException, IOException {
void cleanupAfterIncludeWithRemove() throws ServletException, IOException {
MockHttpServletRequest request = new MockHttpServletRequest(getServletContext(), "GET", "/main.do");
MockHttpServletResponse response = new MockHttpServletResponse();
@ -635,7 +635,7 @@ public class DispatcherServletTests {
}
@Test
public void cleanupAfterIncludeWithRestore() throws ServletException, IOException {
void cleanupAfterIncludeWithRestore() throws ServletException, IOException {
MockHttpServletRequest request = new MockHttpServletRequest(getServletContext(), "GET", "/main.do");
MockHttpServletResponse response = new MockHttpServletResponse();
@ -655,7 +655,7 @@ public class DispatcherServletTests {
}
@Test
public void noCleanupAfterInclude() throws ServletException, IOException {
void noCleanupAfterInclude() throws ServletException, IOException {
MockHttpServletRequest request = new MockHttpServletRequest(getServletContext(), "GET", "/main.do");
MockHttpServletResponse response = new MockHttpServletResponse();
@ -676,7 +676,7 @@ public class DispatcherServletTests {
}
@Test
public void servletHandlerAdapter() throws Exception {
void servletHandlerAdapter() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(getServletContext(), "GET", "/servlet.do");
MockHttpServletResponse response = new MockHttpServletResponse();
complexDispatcherServlet.service(request, response);
@ -690,7 +690,7 @@ public class DispatcherServletTests {
}
@Test
public void withNoView() throws Exception {
void withNoView() throws Exception {
MockServletContext servletContext = new MockServletContext();
MockHttpServletRequest request = new MockHttpServletRequest(servletContext, "GET", "/noview.do");
MockHttpServletResponse response = new MockHttpServletResponse();
@ -700,7 +700,7 @@ public class DispatcherServletTests {
}
@Test
public void withNoViewNested() throws Exception {
void withNoViewNested() throws Exception {
MockServletContext servletContext = new MockServletContext();
MockHttpServletRequest request = new MockHttpServletRequest(servletContext, "GET", "/noview/simple.do");
MockHttpServletResponse response = new MockHttpServletResponse();
@ -710,7 +710,7 @@ public class DispatcherServletTests {
}
@Test
public void withNoViewAndSamePath() throws Exception {
void withNoViewAndSamePath() throws Exception {
InternalResourceViewResolver vr = (InternalResourceViewResolver) complexDispatcherServlet
.getWebApplicationContext().getBean("viewResolver2");
vr.setSuffix("");
@ -724,7 +724,7 @@ public class DispatcherServletTests {
}
@Test // gh-26318
public void parsedRequestPathIsRestoredOnForward() throws Exception {
void parsedRequestPathIsRestoredOnForward() throws Exception {
AnnotationConfigWebApplicationContext context = new AnnotationConfigWebApplicationContext();
context.register(PathPatternParserConfig.class);
DispatcherServlet servlet = new DispatcherServlet(context);
@ -745,7 +745,7 @@ public class DispatcherServletTests {
}
@Test
public void dispatcherServletRefresh() throws ServletException {
void dispatcherServletRefresh() throws ServletException {
MockServletContext servletContext = new MockServletContext("org/springframework/web/context");
DispatcherServlet servlet = new DispatcherServlet();
@ -776,7 +776,7 @@ public class DispatcherServletTests {
}
@Test
public void dispatcherServletContextRefresh() throws ServletException {
void dispatcherServletContextRefresh() throws ServletException {
MockServletContext servletContext = new MockServletContext("org/springframework/web/context");
DispatcherServlet servlet = new DispatcherServlet();
@ -807,7 +807,7 @@ public class DispatcherServletTests {
}
@Test
public void environmentOperations() {
void environmentOperations() {
DispatcherServlet servlet = new DispatcherServlet();
ConfigurableEnvironment defaultEnv = servlet.getEnvironment();
assertThat(defaultEnv).isNotNull();
@ -828,7 +828,7 @@ public class DispatcherServletTests {
}
@Test
public void allowedOptionsIncludesPatchMethod() throws Exception {
void allowedOptionsIncludesPatchMethod() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(getServletContext(), "OPTIONS", "/foo");
MockHttpServletResponse response = spy(new MockHttpServletResponse());
DispatcherServlet servlet = new DispatcherServlet();
@ -839,7 +839,7 @@ public class DispatcherServletTests {
}
@Test
public void contextInitializers() throws Exception {
void contextInitializers() throws Exception {
DispatcherServlet servlet = new DispatcherServlet();
servlet.setContextClass(SimpleWebApplicationContext.class);
servlet.setContextInitializers(new TestWebContextInitializer(), new OtherWebContextInitializer());
@ -849,7 +849,7 @@ public class DispatcherServletTests {
}
@Test
public void contextInitializerClasses() throws Exception {
void contextInitializerClasses() throws Exception {
DispatcherServlet servlet = new DispatcherServlet();
servlet.setContextClass(SimpleWebApplicationContext.class);
servlet.setContextInitializerClasses(
@ -860,7 +860,7 @@ public class DispatcherServletTests {
}
@Test
public void globalInitializerClasses() throws Exception {
void globalInitializerClasses() throws Exception {
DispatcherServlet servlet = new DispatcherServlet();
servlet.setContextClass(SimpleWebApplicationContext.class);
getServletContext().setInitParameter(ContextLoader.GLOBAL_INITIALIZER_CLASSES_PARAM,
@ -871,7 +871,7 @@ public class DispatcherServletTests {
}
@Test
public void mixedInitializerClasses() throws Exception {
void mixedInitializerClasses() throws Exception {
DispatcherServlet servlet = new DispatcherServlet();
servlet.setContextClass(SimpleWebApplicationContext.class);
getServletContext().setInitParameter(ContextLoader.GLOBAL_INITIALIZER_CLASSES_PARAM,
@ -883,7 +883,7 @@ public class DispatcherServletTests {
}
@Test
public void webDavMethod() throws Exception {
void webDavMethod() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(getServletContext(), "PROPFIND", "/body.do");
MockHttpServletResponse response = new MockHttpServletResponse();
complexDispatcherServlet.service(request, response);
@ -891,7 +891,7 @@ public class DispatcherServletTests {
}
@Test
public void shouldResetResponseIfNotCommitted() throws Exception {
void shouldResetResponseBufferIfNotCommitted() throws Exception {
StaticWebApplicationContext context = new StaticWebApplicationContext();
context.setServletContext(getServletContext());
context.registerSingleton("/error", ErrorController.class);
@ -903,11 +903,12 @@ public class DispatcherServletTests {
assertThatThrownBy(() -> servlet.service(request, response)).isInstanceOf(ServletException.class)
.hasCauseInstanceOf(IllegalArgumentException.class);
assertThat(response.getContentAsByteArray()).isEmpty();
assertThat(response.getStatus()).isEqualTo(200);
assertThat(response.getStatus()).isEqualTo(400);
assertThat(response.getHeader("Test-Header")).isEqualTo("spring");
}
@Test
public void shouldAttemptToResetResponseIfCommitted() throws Exception {
void shouldAttemptToResetResponseBufferIfCommitted() throws Exception {
StaticWebApplicationContext context = new StaticWebApplicationContext();
context.setServletContext(getServletContext());
context.registerSingleton("/error", ErrorController.class);
@ -921,6 +922,7 @@ public class DispatcherServletTests {
.hasCauseInstanceOf(IllegalArgumentException.class);
assertThat(response.getContentAsByteArray()).isNotEmpty();
assertThat(response.getStatus()).isEqualTo(400);
assertThat(response.getHeader("Test-Header")).isEqualTo("spring");
}
@ -974,6 +976,7 @@ public class DispatcherServletTests {
@Override
public ModelAndView handleRequest(HttpServletRequest request, HttpServletResponse response) throws Exception {
response.setStatus(400);
response.setHeader("Test-Header", "spring");
if (request.getAttribute("commit") != null) {
response.flushBuffer();
}