Handle Content-Length in ShallowEtagHeaderFilter more robustly

This commit ensures that setting the Content-Length through
setHeader("Content-Length", x") has the same effect as calling
setContentLength in the ShallowEtagHeaderFilter. It also filters out
Content-Type headers similarly to Content-Length.

Closes gh-32039
This commit is contained in:
Arjen Poutsma 2024-01-18 15:32:01 +01:00
parent b8b31ff8a1
commit 375e0e6827
2 changed files with 142 additions and 6 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -21,6 +21,10 @@ import java.io.InputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import jakarta.servlet.ServletOutputStream;
import jakarta.servlet.WriteListener;
@ -55,6 +59,9 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper {
@Nullable
private Integer contentLength;
@Nullable
private String contentType;
/**
* Create a new ContentCachingResponseWrapper for the given servlet response.
@ -139,6 +146,122 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper {
this.contentLength = lenInt;
}
@Override
public void setContentType(String type) {
this.contentType = type;
}
@Override
@Nullable
public String getContentType() {
return this.contentType;
}
@Override
public boolean containsHeader(String name) {
if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
return this.contentLength != null;
}
else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
return this.contentType != null;
}
else {
return super.containsHeader(name);
}
}
@Override
public void setHeader(String name, String value) {
if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
this.contentLength = Integer.valueOf(value);
}
else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
this.contentType = value;
}
else {
super.setHeader(name, value);
}
}
@Override
public void addHeader(String name, String value) {
if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
this.contentLength = Integer.valueOf(value);
}
else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
this.contentType = value;
}
else {
super.addHeader(name, value);
}
}
@Override
public void setIntHeader(String name, int value) {
if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
this.contentLength = Integer.valueOf(value);
}
else {
super.setIntHeader(name, value);
}
}
@Override
public void addIntHeader(String name, int value) {
if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
this.contentLength = Integer.valueOf(value);
}
else {
super.addIntHeader(name, value);
}
}
@Override
@Nullable
public String getHeader(String name) {
if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
return (this.contentLength != null) ? this.contentLength.toString() : null;
}
else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
return this.contentType;
}
else {
return super.getHeader(name);
}
}
@Override
public Collection<String> getHeaders(String name) {
if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
return this.contentLength != null ? Collections.singleton(this.contentLength.toString()) :
Collections.emptySet();
}
else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
return this.contentType != null ? Collections.singleton(this.contentType) : Collections.emptySet();
}
else {
return super.getHeaders(name);
}
}
@Override
public Collection<String> getHeaderNames() {
Collection<String> headerNames = super.getHeaderNames();
if (this.contentLength != null || this.contentType != null) {
List<String> result = new ArrayList<>(headerNames);
if (this.contentLength != null) {
result.add(HttpHeaders.CONTENT_LENGTH);
}
if (this.contentType != null) {
result.add(HttpHeaders.CONTENT_TYPE);
}
return result;
}
else {
return headerNames;
}
}
@Override
public void setBufferSize(int size) {
if (size > this.content.size()) {
@ -197,11 +320,17 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper {
protected void copyBodyToResponse(boolean complete) throws IOException {
if (this.content.size() > 0) {
HttpServletResponse rawResponse = (HttpServletResponse) getResponse();
if ((complete || this.contentLength != null) && !rawResponse.isCommitted()) {
if (rawResponse.getHeader(HttpHeaders.TRANSFER_ENCODING) == null) {
rawResponse.setContentLength(complete ? this.content.size() : this.contentLength);
if (!rawResponse.isCommitted()) {
if (complete || this.contentLength != null) {
if (rawResponse.getHeader(HttpHeaders.TRANSFER_ENCODING) == null) {
rawResponse.setContentLength(complete ? this.content.size() : this.contentLength);
}
this.contentLength = null;
}
if (complete || this.contentType != null) {
rawResponse.setContentType(this.contentType);
this.contentType = null;
}
this.contentLength = null;
}
this.content.writeTo(rawResponse.getOutputStream());
this.content.reset();

View File

@ -23,6 +23,7 @@ import jakarta.servlet.FilterChain;
import jakarta.servlet.http.HttpServletResponse;
import org.junit.jupiter.api.Test;
import org.springframework.http.MediaType;
import org.springframework.util.FileCopyUtils;
import org.springframework.web.testfixture.servlet.MockHttpServletRequest;
import org.springframework.web.testfixture.servlet.MockHttpServletResponse;
@ -68,6 +69,7 @@ class ShallowEtagHeaderFilterTests {
FilterChain filterChain = (filterRequest, filterResponse) -> {
assertThat(filterRequest).as("Invalid request passed").isEqualTo(request);
((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK);
filterResponse.setContentType(MediaType.TEXT_PLAIN_VALUE);
FileCopyUtils.copy(responseBody, filterResponse.getOutputStream());
};
filter.doFilter(request, response, filterChain);
@ -75,6 +77,7 @@ class ShallowEtagHeaderFilterTests {
assertThat(response.getStatus()).as("Invalid status").isEqualTo(200);
assertThat(response.getHeader("ETag")).as("Invalid ETag").isEqualTo("\"0b10a8db164e0754105b7a99be72e3fe5\"");
assertThat(response.getContentLength()).as("Invalid Content-Length header").isGreaterThan(0);
assertThat(response.getContentType()).as("Invalid Content-Type header").isEqualTo(MediaType.TEXT_PLAIN_VALUE);
assertThat(response.getContentAsByteArray()).as("Invalid content").isEqualTo(responseBody);
}
@ -88,6 +91,7 @@ class ShallowEtagHeaderFilterTests {
FilterChain filterChain = (filterRequest, filterResponse) -> {
assertThat(filterRequest).as("Invalid request passed").isEqualTo(request);
((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK);
filterResponse.setContentType(MediaType.TEXT_PLAIN_VALUE);
FileCopyUtils.copy(responseBody, filterResponse.getOutputStream());
};
filter.doFilter(request, response, filterChain);
@ -95,6 +99,7 @@ class ShallowEtagHeaderFilterTests {
assertThat(response.getStatus()).as("Invalid status").isEqualTo(200);
assertThat(response.getHeader("ETag")).as("Invalid ETag").isEqualTo("W/\"0b10a8db164e0754105b7a99be72e3fe5\"");
assertThat(response.getContentLength()).as("Invalid Content-Length header").isGreaterThan(0);
assertThat(response.getContentType()).as("Invalid Content-Type header").isEqualTo(MediaType.TEXT_PLAIN_VALUE);
assertThat(response.getContentAsByteArray()).as("Invalid content").isEqualTo(responseBody);
}
@ -108,14 +113,16 @@ class ShallowEtagHeaderFilterTests {
FilterChain filterChain = (filterRequest, filterResponse) -> {
assertThat(filterRequest).as("Invalid request passed").isEqualTo(request);
byte[] responseBody = "Hello World".getBytes(StandardCharsets.UTF_8);
FileCopyUtils.copy(responseBody, filterResponse.getOutputStream());
filterResponse.setContentLength(responseBody.length);
filterResponse.setContentType(MediaType.TEXT_PLAIN_VALUE);
FileCopyUtils.copy(responseBody, filterResponse.getOutputStream());
};
filter.doFilter(request, response, filterChain);
assertThat(response.getStatus()).as("Invalid status").isEqualTo(304);
assertThat(response.getHeader("ETag")).as("Invalid ETag").isEqualTo("\"0b10a8db164e0754105b7a99be72e3fe5\"");
assertThat(response.containsHeader("Content-Length")).as("Response has Content-Length header").isFalse();
assertThat(response.containsHeader("Content-Type")).as("Response has Content-Type header").isFalse();
byte[] expecteds = new byte[0];
assertThat(response.getContentAsByteArray()).as("Invalid content").isEqualTo(expecteds);
}