Do not flush response buffer in ErrorPageFilter if request is async
Previously, the ErrorPageFilter would always flush the response buffer, irrespective of the request being asynchronous. This could lead to a response being committed prematurely, preventing, for example, headers being set by subsequent processing. This commit updates ErrorPageFilter so that in the success case (status < 400) the response buffer is only flushed if the request is not async (determined by calling request.isAsyncStarted()). If an exception's been thrown or the status is >= 400 the response buffer is always flushed. Fixes #1316
This commit is contained in:
parent
6f8d4778ad
commit
bacbff1fbf
|
@ -56,7 +56,7 @@ import org.springframework.web.filter.OncePerRequestFilter;
|
||||||
@Component
|
@Component
|
||||||
@Order(Ordered.HIGHEST_PRECEDENCE)
|
@Order(Ordered.HIGHEST_PRECEDENCE)
|
||||||
class ErrorPageFilter extends AbstractConfigurableEmbeddedServletContainer implements
|
class ErrorPageFilter extends AbstractConfigurableEmbeddedServletContainer implements
|
||||||
Filter, NonEmbeddedServletContainerFactory {
|
Filter, NonEmbeddedServletContainerFactory {
|
||||||
|
|
||||||
private static Log logger = LogFactory.getLog(ErrorPageFilter.class);
|
private static Log logger = LogFactory.getLog(ErrorPageFilter.class);
|
||||||
|
|
||||||
|
@ -109,18 +109,21 @@ class ErrorPageFilter extends AbstractConfigurableEmbeddedServletContainer imple
|
||||||
int status = wrapped.getStatus();
|
int status = wrapped.getStatus();
|
||||||
if (status >= 400) {
|
if (status >= 400) {
|
||||||
handleErrorStatus(request, response, status, wrapped.getMessage());
|
handleErrorStatus(request, response, status, wrapped.getMessage());
|
||||||
|
response.flushBuffer();
|
||||||
|
}
|
||||||
|
else if (!request.isAsyncStarted()) {
|
||||||
|
response.flushBuffer();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
catch (Throwable ex) {
|
catch (Throwable ex) {
|
||||||
handleException(request, response, wrapped, ex);
|
handleException(request, response, wrapped, ex);
|
||||||
|
response.flushBuffer();
|
||||||
}
|
}
|
||||||
response.flushBuffer();
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void handleErrorStatus(HttpServletRequest request,
|
private void handleErrorStatus(HttpServletRequest request,
|
||||||
HttpServletResponse response, int status, String message)
|
HttpServletResponse response, int status, String message)
|
||||||
throws ServletException, IOException {
|
throws ServletException, IOException {
|
||||||
String errorPath = getErrorPath(this.statuses, status);
|
String errorPath = getErrorPath(this.statuses, status);
|
||||||
if (errorPath == null) {
|
if (errorPath == null) {
|
||||||
response.sendError(status, message);
|
response.sendError(status, message);
|
||||||
|
@ -132,7 +135,7 @@ class ErrorPageFilter extends AbstractConfigurableEmbeddedServletContainer imple
|
||||||
|
|
||||||
private void handleException(HttpServletRequest request,
|
private void handleException(HttpServletRequest request,
|
||||||
HttpServletResponse response, ErrorWrapperResponse wrapped, Throwable ex)
|
HttpServletResponse response, ErrorWrapperResponse wrapped, Throwable ex)
|
||||||
throws IOException, ServletException {
|
throws IOException, ServletException {
|
||||||
Class<?> type = ex.getClass();
|
Class<?> type = ex.getClass();
|
||||||
String errorPath = getErrorPath(type);
|
String errorPath = getErrorPath(type);
|
||||||
if (errorPath == null) {
|
if (errorPath == null) {
|
||||||
|
|
|
@ -34,6 +34,7 @@ import org.springframework.mock.web.MockHttpServletRequest;
|
||||||
import org.springframework.mock.web.MockHttpServletResponse;
|
import org.springframework.mock.web.MockHttpServletResponse;
|
||||||
|
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
import static org.junit.Assert.assertFalse;
|
||||||
import static org.junit.Assert.assertNotNull;
|
import static org.junit.Assert.assertNotNull;
|
||||||
import static org.junit.Assert.assertThat;
|
import static org.junit.Assert.assertThat;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.Assert.assertTrue;
|
||||||
|
@ -59,6 +60,7 @@ public class ErrorPageFilterTests {
|
||||||
assertThat(this.chain.getRequest(), equalTo((ServletRequest) this.request));
|
assertThat(this.chain.getRequest(), equalTo((ServletRequest) this.request));
|
||||||
assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getResponse(),
|
assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getResponse(),
|
||||||
equalTo((ServletResponse) this.response));
|
equalTo((ServletResponse) this.response));
|
||||||
|
assertTrue(this.response.isCommitted());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -79,6 +81,7 @@ public class ErrorPageFilterTests {
|
||||||
equalTo((ServletResponse) this.response));
|
equalTo((ServletResponse) this.response));
|
||||||
assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getStatus(),
|
assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getStatus(),
|
||||||
equalTo(400));
|
equalTo(400));
|
||||||
|
assertTrue(this.response.isCommitted());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -97,6 +100,7 @@ public class ErrorPageFilterTests {
|
||||||
equalTo((ServletResponse) this.response));
|
equalTo((ServletResponse) this.response));
|
||||||
assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getStatus(),
|
assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getStatus(),
|
||||||
equalTo(400));
|
equalTo(400));
|
||||||
|
assertTrue(this.response.isCommitted());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -199,6 +203,62 @@ public class ErrorPageFilterTests {
|
||||||
equalTo((Object) "BAD"));
|
equalTo((Object) "BAD"));
|
||||||
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE),
|
assertThat(this.request.getAttribute(RequestDispatcher.ERROR_EXCEPTION_TYPE),
|
||||||
equalTo((Object) IllegalStateException.class.getName()));
|
equalTo((Object) IllegalStateException.class.getName()));
|
||||||
|
assertTrue(this.response.isCommitted());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void responseIsNotCommitedWhenRequestIsAsync() throws Exception {
|
||||||
|
this.request.setAsyncStarted(true);
|
||||||
|
|
||||||
|
this.filter.doFilter(this.request, this.response, this.chain);
|
||||||
|
|
||||||
|
assertThat(this.chain.getRequest(), equalTo((ServletRequest) this.request));
|
||||||
|
assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getResponse(),
|
||||||
|
equalTo((ServletResponse) this.response));
|
||||||
|
assertFalse(this.response.isCommitted());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void responseIsCommitedWhenRequestIsAsyncAndExceptionIsThrown()
|
||||||
|
throws Exception {
|
||||||
|
this.filter.addErrorPages(new ErrorPage("/error"));
|
||||||
|
this.request.setAsyncStarted(true);
|
||||||
|
this.chain = new MockFilterChain() {
|
||||||
|
@Override
|
||||||
|
public void doFilter(ServletRequest request, ServletResponse response)
|
||||||
|
throws IOException, ServletException {
|
||||||
|
super.doFilter(request, response);
|
||||||
|
throw new RuntimeException("BAD");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
this.filter.doFilter(this.request, this.response, this.chain);
|
||||||
|
|
||||||
|
assertThat(this.chain.getRequest(), equalTo((ServletRequest) this.request));
|
||||||
|
assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getResponse(),
|
||||||
|
equalTo((ServletResponse) this.response));
|
||||||
|
assertTrue(this.response.isCommitted());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void responseIsCommitedWhenRequestIsAsyncAndStatusIs400Plus() throws Exception {
|
||||||
|
this.filter.addErrorPages(new ErrorPage("/error"));
|
||||||
|
this.request.setAsyncStarted(true);
|
||||||
|
this.chain = new MockFilterChain() {
|
||||||
|
@Override
|
||||||
|
public void doFilter(ServletRequest request, ServletResponse response)
|
||||||
|
throws IOException, ServletException {
|
||||||
|
super.doFilter(request, response);
|
||||||
|
((HttpServletResponse) response).sendError(400, "BAD");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
this.filter.doFilter(this.request, this.response, this.chain);
|
||||||
|
|
||||||
|
assertThat(this.chain.getRequest(), equalTo((ServletRequest) this.request));
|
||||||
|
assertThat(((HttpServletResponseWrapper) this.chain.getResponse()).getResponse(),
|
||||||
|
equalTo((ServletResponse) this.response));
|
||||||
|
assertTrue(this.response.isCommitted());
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue