diff --git a/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletResponse.java b/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletResponse.java index a48ae62901f..62137f0e66f 100644 --- a/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletResponse.java +++ b/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletResponse.java @@ -742,6 +742,13 @@ public class MockHttpServletResponse implements HttpServletResponse { super.flush(); setCommitted(true); } + + @Override + public void close() { + super.flush(); + super.close(); + setCommitted(true); + } } } diff --git a/spring-test/src/test/java/org/springframework/mock/web/MockHttpServletResponseTests.java b/spring-test/src/test/java/org/springframework/mock/web/MockHttpServletResponseTests.java index 0b5ca7eaaf8..078e689eb82 100644 --- a/spring-test/src/test/java/org/springframework/mock/web/MockHttpServletResponseTests.java +++ b/spring-test/src/test/java/org/springframework/mock/web/MockHttpServletResponseTests.java @@ -227,6 +227,16 @@ public class MockHttpServletResponseTests { assertEquals(1, response.getContentAsByteArray().length); } + @Test // SPR-16683 + public void servletWriterCommittedOnWriterClose() throws IOException { + assertFalse(response.isCommitted()); + response.getWriter().write("X"); + assertFalse(response.isCommitted()); + response.getWriter().close(); + assertTrue(response.isCommitted()); + assertEquals(1, response.getContentAsByteArray().length); + } + @Test public void servletWriterAutoFlushedForString() throws IOException { response.getWriter().write("X"); diff --git a/spring-web/src/test/java/org/springframework/mock/web/test/MockHttpServletResponse.java b/spring-web/src/test/java/org/springframework/mock/web/test/MockHttpServletResponse.java index 727faafd0c7..48001545c0e 100644 --- a/spring-web/src/test/java/org/springframework/mock/web/test/MockHttpServletResponse.java +++ b/spring-web/src/test/java/org/springframework/mock/web/test/MockHttpServletResponse.java @@ -40,6 +40,7 @@ import javax.servlet.http.HttpServletResponse; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.LinkedCaseInsensitiveMap; import org.springframework.util.StringUtils; @@ -72,6 +73,7 @@ public class MockHttpServletResponse implements HttpServletResponse { private boolean writerAccessAllowed = true; + @Nullable private String characterEncoding = WebUtils.DEFAULT_CHARACTER_ENCODING; private boolean charset = false; @@ -80,10 +82,12 @@ public class MockHttpServletResponse implements HttpServletResponse { private final ServletOutputStream outputStream = new ResponseServletOutputStream(this.content); + @Nullable private PrintWriter writer; private long contentLength = 0; + @Nullable private String contentType; private int bufferSize = 4096; @@ -103,8 +107,10 @@ public class MockHttpServletResponse implements HttpServletResponse { private int status = HttpServletResponse.SC_OK; + @Nullable private String errorMessage; + @Nullable private String forwardedUrl; private final List includedUrls = new ArrayList<>(); @@ -170,6 +176,7 @@ public class MockHttpServletResponse implements HttpServletResponse { } @Override + @Nullable public String getCharacterEncoding() { return this.characterEncoding; } @@ -221,7 +228,7 @@ public class MockHttpServletResponse implements HttpServletResponse { } @Override - public void setContentType(String contentType) { + public void setContentType(@Nullable String contentType) { this.contentType = contentType; if (contentType != null) { try { @@ -244,6 +251,7 @@ public class MockHttpServletResponse implements HttpServletResponse { } @Override + @Nullable public String getContentType() { return this.contentType; } @@ -352,6 +360,7 @@ public class MockHttpServletResponse implements HttpServletResponse { return this.cookies.toArray(new Cookie[0]); } + @Nullable public Cookie getCookie(String name) { Assert.notNull(name, "Cookie name must not be null"); for (Cookie cookie : this.cookies) { @@ -387,6 +396,7 @@ public class MockHttpServletResponse implements HttpServletResponse { * @return the associated header value, or {@code null} if none */ @Override + @Nullable public String getHeader(String name) { HeaderValueHolder header = HeaderValueHolder.getByName(this.headers, name); return (header != null ? header.getStringValue() : null); @@ -417,6 +427,7 @@ public class MockHttpServletResponse implements HttpServletResponse { * @param name the name of the header * @return the associated header value, or {@code null} if none */ + @Nullable public Object getHeaderValue(String name) { HeaderValueHolder header = HeaderValueHolder.getByName(this.headers, name); return (header != null ? header.getValue() : null); @@ -495,6 +506,7 @@ public class MockHttpServletResponse implements HttpServletResponse { setCommitted(true); } + @Nullable public String getRedirectedUrl() { return getHeader(HttpHeaders.LOCATION); } @@ -625,6 +637,7 @@ public class MockHttpServletResponse implements HttpServletResponse { return this.status; } + @Nullable public String getErrorMessage() { return this.errorMessage; } @@ -634,21 +647,23 @@ public class MockHttpServletResponse implements HttpServletResponse { // Methods for MockRequestDispatcher //--------------------------------------------------------------------- - public void setForwardedUrl(String forwardedUrl) { + public void setForwardedUrl(@Nullable String forwardedUrl) { this.forwardedUrl = forwardedUrl; } + @Nullable public String getForwardedUrl() { return this.forwardedUrl; } - public void setIncludedUrl(String includedUrl) { + public void setIncludedUrl(@Nullable String includedUrl) { this.includedUrls.clear(); if (includedUrl != null) { this.includedUrls.add(includedUrl); } } + @Nullable public String getIncludedUrl() { int count = this.includedUrls.size(); Assert.state(count <= 1, @@ -702,7 +717,7 @@ public class MockHttpServletResponse implements HttpServletResponse { } @Override - public void write(char buf[], int off, int len) { + public void write(char[] buf, int off, int len) { super.write(buf, off, len); super.flush(); setCommittedIfBufferSizeExceeded(); @@ -727,6 +742,13 @@ public class MockHttpServletResponse implements HttpServletResponse { super.flush(); setCommitted(true); } + + @Override + public void close() { + super.flush(); + super.close(); + setCommitted(true); + } } }