Record trace with response status of 500 following unhandled exception

Previously, if the filter chain threw an unhandled exception,
WebRequestTraceFilter would record a trace with a response status of
200. This occurred because response.getStatus() would return 200 as
the container had not yet caught the exception and mapped it to an
error response.

This commit updates WebRequestTraceFilter to align its behaviour with
MetricsFilter. It now assumes that the response status will be a 500
and only updates that to the status of the response if the call to the
filter chain returns successfully.

To avoid making a breaking change to the signature of the protected
enhanceTrace method, an HttpServletResponseWrapper is used to include
the correct status in the trace.

Closes gh-5331
This commit is contained in:
Andy Wilkinson 2016-04-18 16:10:26 +01:00
parent 2e54078083
commit 9210029109
2 changed files with 55 additions and 1 deletions

View File

@ -29,6 +29,7 @@ import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;
import javax.servlet.http.HttpSession;
import org.apache.commons.logging.Log;
@ -37,6 +38,7 @@ import org.apache.commons.logging.LogFactory;
import org.springframework.boot.actuate.trace.TraceProperties.Include;
import org.springframework.boot.autoconfigure.web.ErrorAttributes;
import org.springframework.core.Ordered;
import org.springframework.http.HttpStatus;
import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.filter.OncePerRequestFilter;
@ -108,11 +110,14 @@ public class WebRequestTraceFilter extends OncePerRequestFilter implements Order
throws ServletException, IOException {
Map<String, Object> trace = getTrace(request);
logTrace(request, trace);
int status = HttpStatus.INTERNAL_SERVER_ERROR.value();
try {
filterChain.doFilter(request, response);
status = response.getStatus();
}
finally {
enhanceTrace(trace, response);
enhanceTrace(trace, status == response.getStatus() ? response
: new CustomStatusResponseWrapper(response, status));
this.repository.add(trace);
}
}
@ -214,4 +219,21 @@ public class WebRequestTraceFilter extends OncePerRequestFilter implements Order
this.errorAttributes = errorAttributes;
}
private static final class CustomStatusResponseWrapper
extends HttpServletResponseWrapper {
private final int status;
private CustomStatusResponseWrapper(HttpServletResponse response, int status) {
super(response);
this.status = status;
}
@Override
public int getStatus() {
return this.status;
}
}
}

View File

@ -37,8 +37,12 @@ import org.springframework.boot.autoconfigure.web.DefaultErrorAttributes;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@ -183,4 +187,32 @@ public class WebRequestTraceFilterTests {
assertEquals("Foo", map.get("message").toString());
}
@Test
@SuppressWarnings("unchecked")
public void filterHas500ResponseStatusWhenExceptionIsThrown()
throws ServletException, IOException {
MockHttpServletRequest request = new MockHttpServletRequest("GET", "/foo");
MockHttpServletResponse response = new MockHttpServletResponse();
try {
this.filter.doFilterInternal(request, response, new FilterChain() {
@Override
public void doFilter(ServletRequest request, ServletResponse response)
throws IOException, ServletException {
throw new RuntimeException();
}
});
fail("Exception was swallowed");
}
catch (RuntimeException ex) {
Map<String, Object> headers = (Map<String, Object>) this.repository.findAll()
.iterator().next().getInfo().get("headers");
Map<String, Object> responseHeaders = (Map<String, Object>) headers
.get("response");
assertThat((String) responseHeaders.get("status"), is(equalTo("500")));
}
}
}