parent
067994712d
commit
6fcc869338
|
|
@ -31,6 +31,7 @@ import javax.servlet.http.HttpServletResponse;
|
|||
|
||||
import org.springframework.http.HttpRequest;
|
||||
import org.springframework.http.server.ServletServerHttpRequest;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import org.springframework.web.util.UriComponents;
|
||||
import org.springframework.web.util.UriComponentsBuilder;
|
||||
|
||||
|
|
@ -92,6 +93,10 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
|
|||
|
||||
private static class ForwardedHeaderRequestWrapper extends HttpServletRequestWrapper {
|
||||
|
||||
public static final Enumeration<String> EMPTY_HEADER_VALUES =
|
||||
Collections.enumeration(Collections.<String>emptyList());
|
||||
|
||||
|
||||
private final String scheme;
|
||||
|
||||
private final boolean secure;
|
||||
|
|
@ -100,9 +105,9 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
|
|||
|
||||
private final int port;
|
||||
|
||||
private final String portInUrl;
|
||||
private final StringBuffer requestUrl;
|
||||
|
||||
private final Map<String, List<String>> headers = new LinkedHashMap<String, List<String>>();
|
||||
private final Map<String, List<String>> headers;
|
||||
|
||||
|
||||
public ForwardedHeaderRequestWrapper(HttpServletRequest request) {
|
||||
|
|
@ -116,51 +121,35 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
|
|||
this.secure = "https".equals(scheme);
|
||||
this.host = uriComponents.getHost();
|
||||
this.port = (port == -1 ? (this.secure ? 443 : 80) : port);
|
||||
this.portInUrl = (port == -1 ? "" : ":" + port);
|
||||
this.requestUrl = initRequestUrl(this.scheme, this.host, port, request.getRequestURI());
|
||||
this.headers = initHeaders(request);
|
||||
}
|
||||
|
||||
private static StringBuffer initRequestUrl(String scheme, String host, int port, String path) {
|
||||
StringBuffer sb = new StringBuffer();
|
||||
sb.append(scheme).append("://").append(host);
|
||||
sb.append(port == -1 ? "" : ":" + port);
|
||||
sb.append(path);
|
||||
return sb;
|
||||
}
|
||||
|
||||
/**
|
||||
* Copy the headers excluding any {@link #FORWARDED_HEADER_NAMES}.
|
||||
*/
|
||||
private static Map<String, List<String>> initHeaders(HttpServletRequest request) {
|
||||
Map<String, List<String>> headers = new LinkedHashMap<String, List<String>>();
|
||||
Enumeration<String> headerNames = request.getHeaderNames();
|
||||
while (headerNames.hasMoreElements()) {
|
||||
String name = headerNames.nextElement();
|
||||
this.headers.put(name, Collections.list(request.getHeaders(name)));
|
||||
headers.put(name, Collections.list(request.getHeaders(name)));
|
||||
}
|
||||
for (String name : FORWARDED_HEADER_NAMES) {
|
||||
this.headers.remove(name);
|
||||
headers.remove(name);
|
||||
}
|
||||
return headers;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public String getHeader(String name) {
|
||||
Map.Entry<String, List<String>> header = getHeaderEntry(name);
|
||||
if (header == null || header.getValue() == null || header.getValue().isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
return header.getValue().get(0);
|
||||
}
|
||||
|
||||
protected Map.Entry<String, List<String>> getHeaderEntry(String name) {
|
||||
for (Map.Entry<String, List<String>> entry : this.headers.entrySet()) {
|
||||
if (entry.getKey().equalsIgnoreCase(name)) {
|
||||
return entry;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Enumeration<String> getHeaderNames() {
|
||||
return Collections.enumeration(this.headers.keySet());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Enumeration<String> getHeaders(String name) {
|
||||
Map.Entry<String, List<String>> header = getHeaderEntry(name);
|
||||
if (header == null || header.getValue() == null) {
|
||||
return Collections.enumeration(Collections.<String>emptyList());
|
||||
}
|
||||
return Collections.enumeration(header.getValue());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getScheme() {
|
||||
return this.scheme;
|
||||
|
|
@ -183,10 +172,26 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
|
|||
|
||||
@Override
|
||||
public StringBuffer getRequestURL() {
|
||||
StringBuffer sb = new StringBuffer();
|
||||
sb.append(this.scheme).append("://").append(this.host).append(this.portInUrl);
|
||||
sb.append(getRequestURI());
|
||||
return sb;
|
||||
return this.requestUrl;
|
||||
}
|
||||
|
||||
// Override header accessors in order to not expose forwarded headers
|
||||
|
||||
@Override
|
||||
public String getHeader(String name) {
|
||||
List<String> value = this.headers.get(name);
|
||||
return (CollectionUtils.isEmpty(value) ? null : value.get(0));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Enumeration<String> getHeaders(String name) {
|
||||
List<String> value = this.headers.get(name);
|
||||
return (CollectionUtils.isEmpty(value) ? EMPTY_HEADER_VALUES : Collections.enumeration(value));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Enumeration<String> getHeaderNames() {
|
||||
return Collections.enumeration(this.headers.keySet());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ public class ForwardedHeaderFilterTests {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void xForwardedHeaders() throws Exception {
|
||||
public void forwardedRequest() throws Exception {
|
||||
MockHttpServletRequest request = new MockHttpServletRequest();
|
||||
request.setScheme("http");
|
||||
request.setServerName("localhost");
|
||||
|
|
@ -65,10 +65,9 @@ public class ForwardedHeaderFilterTests {
|
|||
request.addHeader("foo", "bar");
|
||||
|
||||
MockFilterChain chain = new MockFilterChain(new HttpServlet() {});
|
||||
|
||||
this.filter.doFilter(request, new MockHttpServletResponse(), chain);
|
||||
|
||||
HttpServletRequest actual = (HttpServletRequest) chain.getRequest();
|
||||
|
||||
assertEquals("https://84.198.58.199/mvc-showcase", actual.getRequestURL().toString());
|
||||
assertEquals("https", actual.getScheme());
|
||||
assertEquals("84.198.58.199", actual.getServerName());
|
||||
|
|
|
|||
Loading…
Reference in New Issue