diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptor.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptor.java
index 53aa9699ca9..a76bdbcda23 100644
--- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptor.java
+++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptor.java
@@ -17,6 +17,7 @@
package org.springframework.web.socket.server.support;
import java.util.Collection;
+import java.util.Collections;
import java.util.Enumeration;
import java.util.Map;
import javax.servlet.http.HttpSession;
@@ -30,8 +31,11 @@ import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.server.HandshakeInterceptor;
/**
- * An interceptor to copy HTTP session attributes into the map of "handshake attributes"
- * made available through {@link WebSocketSession#getAttributes()}.
+ * An interceptor to copy information from the HTTP session to the "handshake
+ * attributes" map to made available via{@link WebSocketSession#getAttributes()}.
+ *
+ *
Copies a subset or all HTTP session attributes and/or the HTTP session id
+ * under the key {@link #HTTP_SESSION_ID_ATTR_NAME}.
*
* @author Rossen Stoyanchev
* @since 4.0
@@ -44,32 +48,67 @@ public class HttpSessionHandshakeInterceptor implements HandshakeInterceptor {
*/
public static final String HTTP_SESSION_ID_ATTR_NAME = "HTTP.SESSION.ID";
+
private final Collection attributeNames;
- private boolean copyHttpSessionId;
+ private boolean copyAllAttributes;
+
+ private boolean copyHttpSessionId = true;
/**
- * A constructor for copying all available HTTP session attributes.
+ * Default constructor for copying all HTTP session attributes and the HTTP
+ * session id.
+ * @see #setCopyAllAttributes
+ * @see #setCopyHttpSessionId
*/
public HttpSessionHandshakeInterceptor() {
- this(null);
+ this.attributeNames = Collections.emptyList();
+ this.copyAllAttributes = true;
}
/**
- * A constructor for copying a subset of HTTP session attributes.
- * @param attributeNames the HTTP session attributes to copy
+ * Constructor for copying specific HTTP session attributes and the HTTP
+ * session id.
+ * @param attributeNames session attributes to copy
+ * @see #setCopyAllAttributes
+ * @see #setCopyHttpSessionId
*/
public HttpSessionHandshakeInterceptor(Collection attributeNames) {
- this.attributeNames = attributeNames;
+ this.attributeNames = Collections.unmodifiableCollection(attributeNames);
+ this.copyAllAttributes = false;
+ }
+
+
+ /**
+ * Return the configured attribute names to copy (read-only).
+ */
+ public Collection getAttributeNames() {
+ return this.attributeNames;
}
/**
- * When set to "true", the HTTP session id is copied to the WebSocket
- * handshake attributes, and is subsequently available via
- * {@link org.springframework.web.socket.WebSocketSession#getAttributes()}
+ * Whether to copy all attributes from the HTTP session. If set to "true" any
+ * explicitly configured attribute names are ignored.
+ * By default this is set to either "true" or "false" depending on which
+ * constructor was used (default or with attribute names respectively).
+ * @param copyAllAttributes whether to copy all attributes
+ */
+ public void setCopyAllAttributes(boolean copyAllAttributes) {
+ this.copyAllAttributes = copyAllAttributes;
+ }
+
+ /**
+ * Whether to copy all HTTP session attributes.
+ */
+ public boolean isCopyAllAttributes() {
+ return this.copyAllAttributes;
+ }
+
+ /**
+ * Whether the HTTP session id should be copied to the handshake attributes
* under the key {@link #HTTP_SESSION_ID_ATTR_NAME}.
- *
By default this is "false".
+ *
By default this is "true".
* @param copyHttpSessionId whether to copy the HTTP session id.
*/
public void setCopyHttpSessionId(boolean copyHttpSessionId) {
@@ -88,25 +127,30 @@ public class HttpSessionHandshakeInterceptor implements HandshakeInterceptor {
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler wsHandler, Map attributes) throws Exception {
- if (request instanceof ServletServerHttpRequest) {
- ServletServerHttpRequest servletRequest = (ServletServerHttpRequest) request;
- HttpSession session = servletRequest.getServletRequest().getSession(false);
- if (session != null) {
- Enumeration names = session.getAttributeNames();
- while (names.hasMoreElements()) {
- String name = names.nextElement();
- if (CollectionUtils.isEmpty(this.attributeNames) || this.attributeNames.contains(name)) {
- attributes.put(name, session.getAttribute(name));
- }
- }
- if (isCopyHttpSessionId()) {
- attributes.put(HTTP_SESSION_ID_ATTR_NAME, session.getId());
+ HttpSession session = getSession(request);
+ if (session != null) {
+ if (isCopyHttpSessionId()) {
+ attributes.put(HTTP_SESSION_ID_ATTR_NAME, session.getId());
+ }
+ Enumeration names = session.getAttributeNames();
+ while (names.hasMoreElements()) {
+ String name = names.nextElement();
+ if (isCopyAllAttributes() || getAttributeNames().contains(name)) {
+ attributes.put(name, session.getAttribute(name));
}
}
}
return true;
}
+ private HttpSession getSession(ServerHttpRequest request) {
+ if (request instanceof ServletServerHttpRequest) {
+ ServletServerHttpRequest servletRequest = (ServletServerHttpRequest) request;
+ return servletRequest.getServletRequest().getSession(false);
+ }
+ return null;
+ }
+
@Override
public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler wsHandler, Exception ex) {
diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptorTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptorTests.java
index 8ab3162c729..c52a8f879cc 100644
--- a/spring-websocket/src/test/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptorTests.java
+++ b/spring-websocket/src/test/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptorTests.java
@@ -24,6 +24,7 @@ import java.util.Set;
import org.junit.Test;
import org.mockito.Mockito;
import org.springframework.mock.web.test.MockHttpSession;
+import org.springframework.mock.web.test.MockServletContext;
import org.springframework.web.socket.AbstractHttpRequestTests;
import org.springframework.web.socket.WebSocketHandler;
@@ -38,28 +39,29 @@ public class HttpSessionHandshakeInterceptorTests extends AbstractHttpRequestTes
@Test
- public void copyAllAttributes() throws Exception {
-
+ public void defaultConstructor() throws Exception {
Map attributes = new HashMap();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
+ this.servletRequest.setSession(new MockHttpSession(null, "123"));
this.servletRequest.getSession().setAttribute("foo", "bar");
this.servletRequest.getSession().setAttribute("bar", "baz");
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
interceptor.beforeHandshake(this.request, this.response, wsHandler, attributes);
- assertEquals(2, attributes.size());
+ assertEquals(3, attributes.size());
assertEquals("bar", attributes.get("foo"));
assertEquals("baz", attributes.get("bar"));
+ assertEquals("123", attributes.get(HttpSessionHandshakeInterceptor.HTTP_SESSION_ID_ATTR_NAME));
}
@Test
- public void copySelectedAttributes() throws Exception {
-
+ public void constructorWithAttributeNames() throws Exception {
Map attributes = new HashMap();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
+ this.servletRequest.setSession(new MockHttpSession(null, "123"));
this.servletRequest.getSession().setAttribute("foo", "bar");
this.servletRequest.getSession().setAttribute("bar", "baz");
@@ -67,29 +69,46 @@ public class HttpSessionHandshakeInterceptorTests extends AbstractHttpRequestTes
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor(names);
interceptor.beforeHandshake(this.request, this.response, wsHandler, attributes);
+ assertEquals(2, attributes.size());
+ assertEquals("bar", attributes.get("foo"));
+ assertEquals("123", attributes.get(HttpSessionHandshakeInterceptor.HTTP_SESSION_ID_ATTR_NAME));
+ }
+
+ @Test
+ public void doNotCopyHttpSessionId() throws Exception {
+ Map attributes = new HashMap();
+ WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
+
+ this.servletRequest.setSession(new MockHttpSession(null, "123"));
+ this.servletRequest.getSession().setAttribute("foo", "bar");
+
+ HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
+ interceptor.setCopyHttpSessionId(false);
+ interceptor.beforeHandshake(this.request, this.response, wsHandler, attributes);
+
assertEquals(1, attributes.size());
assertEquals("bar", attributes.get("foo"));
}
- @Test
- public void copyHttpSessionId() throws Exception {
+ @Test
+ public void doNotCopyAttributes() throws Exception {
Map attributes = new HashMap();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
- this.servletRequest.setSession(new MockHttpSession(null, "foo"));
+ this.servletRequest.setSession(new MockHttpSession(null, "123"));
+ this.servletRequest.getSession().setAttribute("foo", "bar");
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
- interceptor.setCopyHttpSessionId(true);
+ interceptor.setCopyAllAttributes(false);
interceptor.beforeHandshake(this.request, this.response, wsHandler, attributes);
assertEquals(1, attributes.size());
- assertEquals("foo", attributes.get(HttpSessionHandshakeInterceptor.HTTP_SESSION_ID_ATTR_NAME));
+ assertEquals("123", attributes.get(HttpSessionHandshakeInterceptor.HTTP_SESSION_ID_ATTR_NAME));
}
@Test
public void doNotCauseSessionCreation() throws Exception {
-
Map attributes = new HashMap();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);