diff --git a/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/trace/WebRequestTraceFilter.java b/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/trace/WebRequestTraceFilter.java index 9ddf8ab7093..b9b01be9e10 100644 --- a/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/trace/WebRequestTraceFilter.java +++ b/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/trace/WebRequestTraceFilter.java @@ -29,6 +29,7 @@ import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpServletResponseWrapper; import javax.servlet.http.HttpSession; import org.apache.commons.logging.Log; @@ -37,6 +38,7 @@ import org.apache.commons.logging.LogFactory; import org.springframework.boot.actuate.trace.TraceProperties.Include; import org.springframework.boot.autoconfigure.web.ErrorAttributes; import org.springframework.core.Ordered; +import org.springframework.http.HttpStatus; import org.springframework.web.context.request.ServletRequestAttributes; import org.springframework.web.filter.OncePerRequestFilter; @@ -108,11 +110,14 @@ public class WebRequestTraceFilter extends OncePerRequestFilter implements Order throws ServletException, IOException { Map trace = getTrace(request); logTrace(request, trace); + int status = HttpStatus.INTERNAL_SERVER_ERROR.value(); try { filterChain.doFilter(request, response); + status = response.getStatus(); } finally { - enhanceTrace(trace, response); + enhanceTrace(trace, status == response.getStatus() ? response + : new CustomStatusResponseWrapper(response, status)); this.repository.add(trace); } } @@ -214,4 +219,21 @@ public class WebRequestTraceFilter extends OncePerRequestFilter implements Order this.errorAttributes = errorAttributes; } + private static final class CustomStatusResponseWrapper + extends HttpServletResponseWrapper { + + private final int status; + + private CustomStatusResponseWrapper(HttpServletResponse response, int status) { + super(response); + this.status = status; + } + + @Override + public int getStatus() { + return this.status; + } + + } + } diff --git a/spring-boot-actuator/src/test/java/org/springframework/boot/actuate/trace/WebRequestTraceFilterTests.java b/spring-boot-actuator/src/test/java/org/springframework/boot/actuate/trace/WebRequestTraceFilterTests.java index a69ee7c349a..365e2b67d20 100644 --- a/spring-boot-actuator/src/test/java/org/springframework/boot/actuate/trace/WebRequestTraceFilterTests.java +++ b/spring-boot-actuator/src/test/java/org/springframework/boot/actuate/trace/WebRequestTraceFilterTests.java @@ -37,8 +37,12 @@ import org.springframework.boot.autoconfigure.web.DefaultErrorAttributes; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -183,4 +187,32 @@ public class WebRequestTraceFilterTests { assertEquals("Foo", map.get("message").toString()); } + @Test + @SuppressWarnings("unchecked") + public void filterHas500ResponseStatusWhenExceptionIsThrown() + throws ServletException, IOException { + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/foo"); + MockHttpServletResponse response = new MockHttpServletResponse(); + + try { + this.filter.doFilterInternal(request, response, new FilterChain() { + + @Override + public void doFilter(ServletRequest request, ServletResponse response) + throws IOException, ServletException { + throw new RuntimeException(); + } + + }); + fail("Exception was swallowed"); + } + catch (RuntimeException ex) { + Map headers = (Map) this.repository.findAll() + .iterator().next().getInfo().get("headers"); + Map responseHeaders = (Map) headers + .get("response"); + assertThat((String) responseHeaders.get("status"), is(equalTo("500"))); + } + } + }