Improve CORS handling in AbstractSockJsService
After this change, AbstractSockJsService does not add CORS headers if the response already contains an "Access-Control-Allow-Origin" header. Essentially it backs off assuming CORS headers are handled centrally e.g. through a Filter. In order to support this, the ServletServerHttpResponse now returns an instance of HttpHeaders that also provides access to headers already present in the HttpServletResponse. Issue: SPR-11443
This commit is contained in:
parent
cf3b2b1a4d
commit
49d7bda722
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2013 the original author or authors.
|
||||
* Copyright 2002-2014 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -158,8 +158,8 @@ public class HttpHeaders implements MultiValueMap<String, String>, Serializable
|
|||
List<MediaType> result = (value != null) ? MediaType.parseMediaTypes(value) : Collections.<MediaType>emptyList();
|
||||
|
||||
// Some containers parse 'Accept' into multiple values
|
||||
if ((result.size() == 1) && (headers.get(ACCEPT).size() > 1)) {
|
||||
value = StringUtils.collectionToCommaDelimitedString(headers.get(ACCEPT));
|
||||
if ((result.size() == 1) && (get(ACCEPT).size() > 1)) {
|
||||
value = StringUtils.collectionToCommaDelimitedString(get(ACCEPT));
|
||||
result = MediaType.parseMediaTypes(value);
|
||||
}
|
||||
|
||||
|
|
|
@ -18,6 +18,8 @@ package org.springframework.http.server;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.io.OutputStream;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
|
@ -26,18 +28,25 @@ import javax.servlet.http.HttpServletResponse;
|
|||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.ClassUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* {@link ServerHttpResponse} implementation that is based on a {@link HttpServletResponse}.
|
||||
*
|
||||
* @author Arjen Poutsma
|
||||
* @author Rossen Stoyanchev
|
||||
* @since 3.0
|
||||
*/
|
||||
public class ServletServerHttpResponse implements ServerHttpResponse {
|
||||
|
||||
private static final boolean servlet3Present =
|
||||
ClassUtils.isPresent("javax.servlet.AsyncContext", ServletServerHttpResponse.class.getClassLoader());
|
||||
|
||||
|
||||
private final HttpServletResponse servletResponse;
|
||||
|
||||
private final HttpHeaders headers = new HttpHeaders();
|
||||
private final HttpHeaders headers;
|
||||
|
||||
private boolean headersWritten = false;
|
||||
|
||||
|
@ -49,6 +58,7 @@ public class ServletServerHttpResponse implements ServerHttpResponse {
|
|||
public ServletServerHttpResponse(HttpServletResponse servletResponse) {
|
||||
Assert.notNull(servletResponse, "'servletResponse' must not be null");
|
||||
this.servletResponse = servletResponse;
|
||||
this.headers = (servlet3Present ? new ServletResponseHttpHeaders() : new HttpHeaders());
|
||||
}
|
||||
|
||||
|
||||
|
@ -105,4 +115,56 @@ public class ServletServerHttpResponse implements ServerHttpResponse {
|
|||
this.headersWritten = true;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extends HttpHeaders with the ability to look up headers already present in
|
||||
* the underlying HttpServletResponse.
|
||||
*
|
||||
* The intent is merely to expose what is available through the HttpServletResponse
|
||||
* i.e. the ability to look up specific header values by name. All other
|
||||
* map-related operations (e.g. iteration, removal, etc) apply only to values
|
||||
* added directly through HttpHeaders methods.
|
||||
*
|
||||
* @since 4.0.3
|
||||
*/
|
||||
private class ServletResponseHttpHeaders extends HttpHeaders {
|
||||
|
||||
private static final long serialVersionUID = 3410708522401046302L;
|
||||
|
||||
@Override
|
||||
public String getFirst(String headerName) {
|
||||
String value = servletResponse.getHeader(headerName);
|
||||
if (value != null) {
|
||||
return value;
|
||||
}
|
||||
else {
|
||||
return super.getFirst(headerName);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> get(Object key) {
|
||||
|
||||
Assert.isInstanceOf(String.class, key, "key must be a String-based header name");
|
||||
Collection<String> values1 = servletResponse.getHeaders((String) key);
|
||||
boolean isEmpty1 = CollectionUtils.isEmpty(values1);
|
||||
|
||||
List<String> values2 = super.get(key);
|
||||
boolean isEmpty2 = CollectionUtils.isEmpty(values2);
|
||||
|
||||
if (isEmpty1 && isEmpty2) {
|
||||
return null;
|
||||
}
|
||||
|
||||
List<String> values = new ArrayList<String>();
|
||||
if (!isEmpty1) {
|
||||
values.addAll(values1);
|
||||
}
|
||||
if (!isEmpty2) {
|
||||
values.addAll(values2);
|
||||
}
|
||||
return values;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2012 the original author or authors.
|
||||
* Copyright 2002-2014 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -17,6 +17,7 @@
|
|||
package org.springframework.http.server;
|
||||
|
||||
import java.nio.charset.Charset;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import org.junit.Before;
|
||||
|
@ -71,6 +72,19 @@ public class ServletServerHttpResponseTests {
|
|||
assertEquals("Invalid Content-Type", "UTF-8", mockResponse.getCharacterEncoding());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void getHeadersFromHttpServletResponse() {
|
||||
|
||||
String headerName = "Access-Control-Allow-Origin";
|
||||
String headerValue = "localhost:8080";
|
||||
|
||||
this.mockResponse.addHeader(headerName, headerValue);
|
||||
this.response = new ServletServerHttpResponse(this.mockResponse);
|
||||
|
||||
assertEquals(headerValue, this.response.getHeaders().getFirst(headerName));
|
||||
assertEquals(Arrays.asList(headerValue), this.response.getHeaders().get(headerName));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void getBody() throws Exception {
|
||||
byte[] content = "Hello World".getBytes("UTF-8");
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2013 the original author or authors.
|
||||
* Copyright 2002-2014 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -28,6 +28,7 @@ import java.util.concurrent.TimeUnit;
|
|||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.http.InvalidMediaTypeException;
|
||||
|
@ -352,22 +353,32 @@ public abstract class AbstractSockJsService implements SockJsService {
|
|||
|
||||
|
||||
protected void addCorsHeaders(ServerHttpRequest request, ServerHttpResponse response, HttpMethod... httpMethods) {
|
||||
String origin = request.getHeaders().getFirst("origin");
|
||||
|
||||
HttpHeaders requestHeaders = request.getHeaders();
|
||||
HttpHeaders responseHeaders = response.getHeaders();
|
||||
|
||||
// Perhaps a CORS Filter has already added this?
|
||||
if (!CollectionUtils.isEmpty(responseHeaders.get("Access-Control-Allow-Origin"))) {
|
||||
logger.debug("Skip adding CORS headers, response already contains \"Access-Control-Allow-Origin\"");
|
||||
return;
|
||||
}
|
||||
|
||||
String origin = requestHeaders.getFirst("origin");
|
||||
origin = ((origin == null) || origin.equals("null")) ? "*" : origin;
|
||||
|
||||
response.getHeaders().add("Access-Control-Allow-Origin", origin);
|
||||
response.getHeaders().add("Access-Control-Allow-Credentials", "true");
|
||||
responseHeaders.add("Access-Control-Allow-Origin", origin);
|
||||
responseHeaders.add("Access-Control-Allow-Credentials", "true");
|
||||
|
||||
List<String> accessControllerHeaders = request.getHeaders().get("Access-Control-Request-Headers");
|
||||
List<String> accessControllerHeaders = requestHeaders.get("Access-Control-Request-Headers");
|
||||
if (accessControllerHeaders != null) {
|
||||
for (String header : accessControllerHeaders) {
|
||||
response.getHeaders().add("Access-Control-Allow-Headers", header);
|
||||
responseHeaders.add("Access-Control-Allow-Headers", header);
|
||||
}
|
||||
}
|
||||
|
||||
if (!ObjectUtils.isEmpty(httpMethods)) {
|
||||
response.getHeaders().add("Access-Control-Allow-Methods", StringUtils.arrayToDelimitedString(httpMethods, ", "));
|
||||
response.getHeaders().add("Access-Control-Max-Age", String.valueOf(ONE_YEAR));
|
||||
responseHeaders.add("Access-Control-Allow-Methods", StringUtils.arrayToDelimitedString(httpMethods, ", "));
|
||||
responseHeaders.add("Access-Control-Max-Age", String.valueOf(ONE_YEAR));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -54,25 +54,25 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
|
|||
public void validateRequest() throws Exception {
|
||||
|
||||
this.service.setWebSocketEnabled(false);
|
||||
handleRequest("GET", "/echo/server/session/websocket", HttpStatus.NOT_FOUND);
|
||||
resetResponseAndHandleRequest("GET", "/echo/server/session/websocket", HttpStatus.NOT_FOUND);
|
||||
|
||||
this.service.setWebSocketEnabled(true);
|
||||
handleRequest("GET", "/echo/server/session/websocket", HttpStatus.OK);
|
||||
resetResponseAndHandleRequest("GET", "/echo/server/session/websocket", HttpStatus.OK);
|
||||
|
||||
handleRequest("GET", "/echo//", HttpStatus.NOT_FOUND);
|
||||
handleRequest("GET", "/echo///", HttpStatus.NOT_FOUND);
|
||||
handleRequest("GET", "/echo/other", HttpStatus.NOT_FOUND);
|
||||
handleRequest("GET", "/echo//service/websocket", HttpStatus.NOT_FOUND);
|
||||
handleRequest("GET", "/echo/server//websocket", HttpStatus.NOT_FOUND);
|
||||
handleRequest("GET", "/echo/server/session/", HttpStatus.NOT_FOUND);
|
||||
handleRequest("GET", "/echo/s.erver/session/websocket", HttpStatus.NOT_FOUND);
|
||||
handleRequest("GET", "/echo/server/s.ession/websocket", HttpStatus.NOT_FOUND);
|
||||
resetResponseAndHandleRequest("GET", "/echo//", HttpStatus.NOT_FOUND);
|
||||
resetResponseAndHandleRequest("GET", "/echo///", HttpStatus.NOT_FOUND);
|
||||
resetResponseAndHandleRequest("GET", "/echo/other", HttpStatus.NOT_FOUND);
|
||||
resetResponseAndHandleRequest("GET", "/echo//service/websocket", HttpStatus.NOT_FOUND);
|
||||
resetResponseAndHandleRequest("GET", "/echo/server//websocket", HttpStatus.NOT_FOUND);
|
||||
resetResponseAndHandleRequest("GET", "/echo/server/session/", HttpStatus.NOT_FOUND);
|
||||
resetResponseAndHandleRequest("GET", "/echo/s.erver/session/websocket", HttpStatus.NOT_FOUND);
|
||||
resetResponseAndHandleRequest("GET", "/echo/server/s.ession/websocket", HttpStatus.NOT_FOUND);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void handleInfoGet() throws Exception {
|
||||
|
||||
handleRequest("GET", "/echo/info", HttpStatus.OK);
|
||||
resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK);
|
||||
|
||||
assertEquals("application/json;charset=UTF-8", this.servletResponse.getContentType());
|
||||
assertEquals("*", this.servletResponse.getHeader("Access-Control-Allow-Origin"));
|
||||
|
@ -86,19 +86,32 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
|
|||
|
||||
this.service.setSessionCookieNeeded(false);
|
||||
this.service.setWebSocketEnabled(false);
|
||||
handleRequest("GET", "/echo/info", HttpStatus.OK);
|
||||
resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK);
|
||||
|
||||
body = this.servletResponse.getContentAsString();
|
||||
assertEquals(",\"origins\":[\"*:*\"],\"cookie_needed\":false,\"websocket\":false}",
|
||||
body.substring(body.indexOf(',')));
|
||||
}
|
||||
|
||||
// SPR-11443
|
||||
|
||||
@Test
|
||||
public void handleInfoGetCorsFilter() throws Exception {
|
||||
|
||||
// Simulate scenario where Filter would have already set CORS headers
|
||||
this.servletResponse.setHeader("Access-Control-Allow-Origin", "foobar:123");
|
||||
|
||||
handleRequest("GET", "/echo/info", HttpStatus.OK);
|
||||
|
||||
assertEquals("foobar:123", this.servletResponse.getHeader("Access-Control-Allow-Origin"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void handleInfoOptions() throws Exception {
|
||||
|
||||
this.servletRequest.addHeader("Access-Control-Request-Headers", "Last-Modified");
|
||||
|
||||
handleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT);
|
||||
resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT);
|
||||
this.response.flush();
|
||||
|
||||
assertEquals("*", this.servletResponse.getHeader("Access-Control-Allow-Origin"));
|
||||
|
@ -111,7 +124,7 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
|
|||
@Test
|
||||
public void handleIframeRequest() throws Exception {
|
||||
|
||||
handleRequest("GET", "/echo/iframe.html", HttpStatus.OK);
|
||||
resetResponseAndHandleRequest("GET", "/echo/iframe.html", HttpStatus.OK);
|
||||
|
||||
assertEquals("text/html;charset=UTF-8", this.servletResponse.getContentType());
|
||||
assertTrue(this.servletResponse.getContentAsString().startsWith("<!DOCTYPE html>\n"));
|
||||
|
@ -125,39 +138,42 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
|
|||
|
||||
this.servletRequest.addHeader("If-None-Match", "\"0da1ed070012f304e47b83c81c48ad620\"");
|
||||
|
||||
handleRequest("GET", "/echo/iframe.html", HttpStatus.NOT_MODIFIED);
|
||||
resetResponseAndHandleRequest("GET", "/echo/iframe.html", HttpStatus.NOT_MODIFIED);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void handleRawWebSocketRequest() throws Exception {
|
||||
|
||||
handleRequest("GET", "/echo", HttpStatus.OK);
|
||||
resetResponseAndHandleRequest("GET", "/echo", HttpStatus.OK);
|
||||
assertEquals("Welcome to SockJS!\n", this.servletResponse.getContentAsString());
|
||||
|
||||
handleRequest("GET", "/echo/websocket", HttpStatus.OK);
|
||||
resetResponseAndHandleRequest("GET", "/echo/websocket", HttpStatus.OK);
|
||||
assertNull("Raw WebSocket should not open a SockJS session", this.service.sessionId);
|
||||
assertSame(this.handler, this.service.handler);
|
||||
}
|
||||
|
||||
|
||||
private void handleRequest(String httpMethod, String uri, HttpStatus httpStatus) throws IOException {
|
||||
resetResponse();
|
||||
setRequest(httpMethod, uri);
|
||||
String sockJsPath = uri.substring("/echo".length());
|
||||
this.service.handleRequest(this.request, this.response, sockJsPath, this.handler);
|
||||
|
||||
assertEquals(httpStatus.value(), this.servletResponse.getStatus());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void handleEmptyContentType() throws Exception {
|
||||
|
||||
servletRequest.setContentType("");
|
||||
handleRequest("GET", "/echo/info", HttpStatus.OK);
|
||||
resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK);
|
||||
|
||||
assertEquals("Invalid/empty content should have been ignored", 200, this.servletResponse.getStatus());
|
||||
}
|
||||
|
||||
private void resetResponseAndHandleRequest(String httpMethod, String uri, HttpStatus httpStatus) throws IOException {
|
||||
resetResponse();
|
||||
handleRequest(httpMethod, uri, httpStatus);
|
||||
}
|
||||
|
||||
private void handleRequest(String httpMethod, String uri, HttpStatus httpStatus) throws IOException {
|
||||
setRequest(httpMethod, uri);
|
||||
String sockJsPath = uri.substring("/echo".length());
|
||||
this.service.handleRequest(this.request, this.response, sockJsPath, this.handler);
|
||||
|
||||
assertEquals(httpStatus.value(), this.servletResponse.getStatus());
|
||||
}
|
||||
|
||||
|
||||
private static class TestSockJsService extends AbstractSockJsService {
|
||||
|
||||
|
|
Loading…
Reference in New Issue