Improve CORS handling in AbstractSockJsService

After this change, AbstractSockJsService does not add CORS headers if
the response already contains an "Access-Control-Allow-Origin" header.
Essentially it backs off assuming CORS headers are handled centrally
e.g. through a Filter.

In order to support this, the ServletServerHttpResponse now returns an
instance of HttpHeaders that also provides access to headers already
present in the HttpServletResponse.

Issue: SPR-11443
This commit is contained in:
Rossen Stoyanchev 2014-03-05 21:02:03 -05:00
parent cf3b2b1a4d
commit 49d7bda722
5 changed files with 144 additions and 41 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2013 the original author or authors.
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -158,8 +158,8 @@ public class HttpHeaders implements MultiValueMap<String, String>, Serializable
List<MediaType> result = (value != null) ? MediaType.parseMediaTypes(value) : Collections.<MediaType>emptyList();
// Some containers parse 'Accept' into multiple values
if ((result.size() == 1) && (headers.get(ACCEPT).size() > 1)) {
value = StringUtils.collectionToCommaDelimitedString(headers.get(ACCEPT));
if ((result.size() == 1) && (get(ACCEPT).size() > 1)) {
value = StringUtils.collectionToCommaDelimitedString(get(ACCEPT));
result = MediaType.parseMediaTypes(value);
}

View File

@ -18,6 +18,8 @@ package org.springframework.http.server;
import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
@ -26,18 +28,25 @@ import javax.servlet.http.HttpServletResponse;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.CollectionUtils;
/**
* {@link ServerHttpResponse} implementation that is based on a {@link HttpServletResponse}.
*
* @author Arjen Poutsma
* @author Rossen Stoyanchev
* @since 3.0
*/
public class ServletServerHttpResponse implements ServerHttpResponse {
private static final boolean servlet3Present =
ClassUtils.isPresent("javax.servlet.AsyncContext", ServletServerHttpResponse.class.getClassLoader());
private final HttpServletResponse servletResponse;
private final HttpHeaders headers = new HttpHeaders();
private final HttpHeaders headers;
private boolean headersWritten = false;
@ -49,6 +58,7 @@ public class ServletServerHttpResponse implements ServerHttpResponse {
public ServletServerHttpResponse(HttpServletResponse servletResponse) {
Assert.notNull(servletResponse, "'servletResponse' must not be null");
this.servletResponse = servletResponse;
this.headers = (servlet3Present ? new ServletResponseHttpHeaders() : new HttpHeaders());
}
@ -105,4 +115,56 @@ public class ServletServerHttpResponse implements ServerHttpResponse {
this.headersWritten = true;
}
}
/**
* Extends HttpHeaders with the ability to look up headers already present in
* the underlying HttpServletResponse.
*
* The intent is merely to expose what is available through the HttpServletResponse
* i.e. the ability to look up specific header values by name. All other
* map-related operations (e.g. iteration, removal, etc) apply only to values
* added directly through HttpHeaders methods.
*
* @since 4.0.3
*/
private class ServletResponseHttpHeaders extends HttpHeaders {
private static final long serialVersionUID = 3410708522401046302L;
@Override
public String getFirst(String headerName) {
String value = servletResponse.getHeader(headerName);
if (value != null) {
return value;
}
else {
return super.getFirst(headerName);
}
}
@Override
public List<String> get(Object key) {
Assert.isInstanceOf(String.class, key, "key must be a String-based header name");
Collection<String> values1 = servletResponse.getHeaders((String) key);
boolean isEmpty1 = CollectionUtils.isEmpty(values1);
List<String> values2 = super.get(key);
boolean isEmpty2 = CollectionUtils.isEmpty(values2);
if (isEmpty1 && isEmpty2) {
return null;
}
List<String> values = new ArrayList<String>();
if (!isEmpty1) {
values.addAll(values1);
}
if (!isEmpty2) {
values.addAll(values2);
}
return values;
}
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2012 the original author or authors.
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -17,6 +17,7 @@
package org.springframework.http.server;
import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.List;
import org.junit.Before;
@ -71,6 +72,19 @@ public class ServletServerHttpResponseTests {
assertEquals("Invalid Content-Type", "UTF-8", mockResponse.getCharacterEncoding());
}
@Test
public void getHeadersFromHttpServletResponse() {
String headerName = "Access-Control-Allow-Origin";
String headerValue = "localhost:8080";
this.mockResponse.addHeader(headerName, headerValue);
this.response = new ServletServerHttpResponse(this.mockResponse);
assertEquals(headerValue, this.response.getHeaders().getFirst(headerName));
assertEquals(Arrays.asList(headerValue), this.response.getHeaders().get(headerName));
}
@Test
public void getBody() throws Exception {
byte[] content = "Hello World".getBytes("UTF-8");

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2013 the original author or authors.
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -28,6 +28,7 @@ import java.util.concurrent.TimeUnit;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.InvalidMediaTypeException;
@ -352,22 +353,32 @@ public abstract class AbstractSockJsService implements SockJsService {
protected void addCorsHeaders(ServerHttpRequest request, ServerHttpResponse response, HttpMethod... httpMethods) {
String origin = request.getHeaders().getFirst("origin");
HttpHeaders requestHeaders = request.getHeaders();
HttpHeaders responseHeaders = response.getHeaders();
// Perhaps a CORS Filter has already added this?
if (!CollectionUtils.isEmpty(responseHeaders.get("Access-Control-Allow-Origin"))) {
logger.debug("Skip adding CORS headers, response already contains \"Access-Control-Allow-Origin\"");
return;
}
String origin = requestHeaders.getFirst("origin");
origin = ((origin == null) || origin.equals("null")) ? "*" : origin;
response.getHeaders().add("Access-Control-Allow-Origin", origin);
response.getHeaders().add("Access-Control-Allow-Credentials", "true");
responseHeaders.add("Access-Control-Allow-Origin", origin);
responseHeaders.add("Access-Control-Allow-Credentials", "true");
List<String> accessControllerHeaders = request.getHeaders().get("Access-Control-Request-Headers");
List<String> accessControllerHeaders = requestHeaders.get("Access-Control-Request-Headers");
if (accessControllerHeaders != null) {
for (String header : accessControllerHeaders) {
response.getHeaders().add("Access-Control-Allow-Headers", header);
responseHeaders.add("Access-Control-Allow-Headers", header);
}
}
if (!ObjectUtils.isEmpty(httpMethods)) {
response.getHeaders().add("Access-Control-Allow-Methods", StringUtils.arrayToDelimitedString(httpMethods, ", "));
response.getHeaders().add("Access-Control-Max-Age", String.valueOf(ONE_YEAR));
responseHeaders.add("Access-Control-Allow-Methods", StringUtils.arrayToDelimitedString(httpMethods, ", "));
responseHeaders.add("Access-Control-Max-Age", String.valueOf(ONE_YEAR));
}
}

View File

@ -54,25 +54,25 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
public void validateRequest() throws Exception {
this.service.setWebSocketEnabled(false);
handleRequest("GET", "/echo/server/session/websocket", HttpStatus.NOT_FOUND);
resetResponseAndHandleRequest("GET", "/echo/server/session/websocket", HttpStatus.NOT_FOUND);
this.service.setWebSocketEnabled(true);
handleRequest("GET", "/echo/server/session/websocket", HttpStatus.OK);
resetResponseAndHandleRequest("GET", "/echo/server/session/websocket", HttpStatus.OK);
handleRequest("GET", "/echo//", HttpStatus.NOT_FOUND);
handleRequest("GET", "/echo///", HttpStatus.NOT_FOUND);
handleRequest("GET", "/echo/other", HttpStatus.NOT_FOUND);
handleRequest("GET", "/echo//service/websocket", HttpStatus.NOT_FOUND);
handleRequest("GET", "/echo/server//websocket", HttpStatus.NOT_FOUND);
handleRequest("GET", "/echo/server/session/", HttpStatus.NOT_FOUND);
handleRequest("GET", "/echo/s.erver/session/websocket", HttpStatus.NOT_FOUND);
handleRequest("GET", "/echo/server/s.ession/websocket", HttpStatus.NOT_FOUND);
resetResponseAndHandleRequest("GET", "/echo//", HttpStatus.NOT_FOUND);
resetResponseAndHandleRequest("GET", "/echo///", HttpStatus.NOT_FOUND);
resetResponseAndHandleRequest("GET", "/echo/other", HttpStatus.NOT_FOUND);
resetResponseAndHandleRequest("GET", "/echo//service/websocket", HttpStatus.NOT_FOUND);
resetResponseAndHandleRequest("GET", "/echo/server//websocket", HttpStatus.NOT_FOUND);
resetResponseAndHandleRequest("GET", "/echo/server/session/", HttpStatus.NOT_FOUND);
resetResponseAndHandleRequest("GET", "/echo/s.erver/session/websocket", HttpStatus.NOT_FOUND);
resetResponseAndHandleRequest("GET", "/echo/server/s.ession/websocket", HttpStatus.NOT_FOUND);
}
@Test
public void handleInfoGet() throws Exception {
handleRequest("GET", "/echo/info", HttpStatus.OK);
resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK);
assertEquals("application/json;charset=UTF-8", this.servletResponse.getContentType());
assertEquals("*", this.servletResponse.getHeader("Access-Control-Allow-Origin"));
@ -86,19 +86,32 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
this.service.setSessionCookieNeeded(false);
this.service.setWebSocketEnabled(false);
handleRequest("GET", "/echo/info", HttpStatus.OK);
resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK);
body = this.servletResponse.getContentAsString();
assertEquals(",\"origins\":[\"*:*\"],\"cookie_needed\":false,\"websocket\":false}",
body.substring(body.indexOf(',')));
}
// SPR-11443
@Test
public void handleInfoGetCorsFilter() throws Exception {
// Simulate scenario where Filter would have already set CORS headers
this.servletResponse.setHeader("Access-Control-Allow-Origin", "foobar:123");
handleRequest("GET", "/echo/info", HttpStatus.OK);
assertEquals("foobar:123", this.servletResponse.getHeader("Access-Control-Allow-Origin"));
}
@Test
public void handleInfoOptions() throws Exception {
this.servletRequest.addHeader("Access-Control-Request-Headers", "Last-Modified");
handleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT);
resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT);
this.response.flush();
assertEquals("*", this.servletResponse.getHeader("Access-Control-Allow-Origin"));
@ -111,7 +124,7 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
@Test
public void handleIframeRequest() throws Exception {
handleRequest("GET", "/echo/iframe.html", HttpStatus.OK);
resetResponseAndHandleRequest("GET", "/echo/iframe.html", HttpStatus.OK);
assertEquals("text/html;charset=UTF-8", this.servletResponse.getContentType());
assertTrue(this.servletResponse.getContentAsString().startsWith("<!DOCTYPE html>\n"));
@ -125,39 +138,42 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
this.servletRequest.addHeader("If-None-Match", "\"0da1ed070012f304e47b83c81c48ad620\"");
handleRequest("GET", "/echo/iframe.html", HttpStatus.NOT_MODIFIED);
resetResponseAndHandleRequest("GET", "/echo/iframe.html", HttpStatus.NOT_MODIFIED);
}
@Test
public void handleRawWebSocketRequest() throws Exception {
handleRequest("GET", "/echo", HttpStatus.OK);
resetResponseAndHandleRequest("GET", "/echo", HttpStatus.OK);
assertEquals("Welcome to SockJS!\n", this.servletResponse.getContentAsString());
handleRequest("GET", "/echo/websocket", HttpStatus.OK);
resetResponseAndHandleRequest("GET", "/echo/websocket", HttpStatus.OK);
assertNull("Raw WebSocket should not open a SockJS session", this.service.sessionId);
assertSame(this.handler, this.service.handler);
}
private void handleRequest(String httpMethod, String uri, HttpStatus httpStatus) throws IOException {
resetResponse();
setRequest(httpMethod, uri);
String sockJsPath = uri.substring("/echo".length());
this.service.handleRequest(this.request, this.response, sockJsPath, this.handler);
assertEquals(httpStatus.value(), this.servletResponse.getStatus());
}
@Test
public void handleEmptyContentType() throws Exception {
servletRequest.setContentType("");
handleRequest("GET", "/echo/info", HttpStatus.OK);
resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK);
assertEquals("Invalid/empty content should have been ignored", 200, this.servletResponse.getStatus());
}
private void resetResponseAndHandleRequest(String httpMethod, String uri, HttpStatus httpStatus) throws IOException {
resetResponse();
handleRequest(httpMethod, uri, httpStatus);
}
private void handleRequest(String httpMethod, String uri, HttpStatus httpStatus) throws IOException {
setRequest(httpMethod, uri);
String sockJsPath = uri.substring("/echo".length());
this.service.handleRequest(this.request, this.response, sockJsPath, this.handler);
assertEquals(httpStatus.value(), this.servletResponse.getStatus());
}
private static class TestSockJsService extends AbstractSockJsService {