diff --git a/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java index ca28b8e0580..875e5bd3127 100644 --- a/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java @@ -27,6 +27,7 @@ import jakarta.servlet.http.HttpServletResponse; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatusCode; +import org.springframework.http.MediaType; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -118,12 +119,13 @@ public class ServletServerHttpResponse implements ServerHttpResponse { } }); // HttpServletResponse exposes some headers as properties: we should include those if not already present - if (this.servletResponse.getContentType() == null && this.headers.getContentType() != null) { - this.servletResponse.setContentType(this.headers.getContentType().toString()); + MediaType contentTypeHeader = this.headers.getContentType(); + if (this.servletResponse.getContentType() == null && contentTypeHeader != null) { + this.servletResponse.setContentType(contentTypeHeader.toString()); } - if (this.servletResponse.getCharacterEncoding() == null && this.headers.getContentType() != null && - this.headers.getContentType().getCharset() != null) { - this.servletResponse.setCharacterEncoding(this.headers.getContentType().getCharset().name()); + if (this.servletResponse.getCharacterEncoding() == null && contentTypeHeader != null && + contentTypeHeader.getCharset() != null) { + this.servletResponse.setCharacterEncoding(contentTypeHeader.getCharset().name()); } long contentLength = getHeaders().getContentLength(); if (contentLength != -1) { @@ -169,13 +171,15 @@ public class ServletServerHttpResponse implements ServerHttpResponse { } @Override + @Nullable public List get(Object key) { Assert.isInstanceOf(String.class, key, "Key must be a String-based header name"); String headerName = (String) key; if (headerName.equalsIgnoreCase(CONTENT_TYPE)) { // Content-Type is written as an override so don't merge - return Collections.singletonList(getFirst(headerName)); + String value = getFirst(headerName); + return (value != null ? Collections.singletonList(value) : null); } Collection values1 = servletResponse.getHeaders(headerName); diff --git a/spring-web/src/test/java/org/springframework/http/server/ServletServerHttpResponseTests.java b/spring-web/src/test/java/org/springframework/http/server/ServletServerHttpResponseTests.java index efdf1bf691f..34a40016ac0 100644 --- a/spring-web/src/test/java/org/springframework/http/server/ServletServerHttpResponseTests.java +++ b/spring-web/src/test/java/org/springframework/http/server/ServletServerHttpResponseTests.java @@ -75,6 +75,19 @@ class ServletServerHttpResponseTests { assertThat(mockResponse.getCharacterEncoding()).as("Invalid Content-Type").isEqualTo("UTF-8"); } + @Test + void getHeadersWithNoContentType() { + this.response = new ServletServerHttpResponse(this.mockResponse); + assertThat(this.response.getHeaders().get(HttpHeaders.CONTENT_TYPE)).isNull(); + } + + @Test + void getHeadersWithContentType() { + this.mockResponse.setContentType(MediaType.TEXT_PLAIN_VALUE); + this.response = new ServletServerHttpResponse(this.mockResponse); + assertThat(this.response.getHeaders().get(HttpHeaders.CONTENT_TYPE)).containsExactly(MediaType.TEXT_PLAIN_VALUE); + } + @Test void preExistingHeadersFromHttpServletResponse() { String headerName = "Access-Control-Allow-Origin";