From 3056301015bc96c32d29d68b50d1fc09615512cd Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Tue, 14 Oct 2014 09:01:09 -0400 Subject: [PATCH] Improve HttpSessionHandshakeInterceptor Use explicit flag whether to copy all attributes. --- .../HttpSessionHandshakeInterceptor.java | 94 ++++++++++++++----- .../HttpSessionHandshakeInterceptorTests.java | 41 +++++--- 2 files changed, 99 insertions(+), 36 deletions(-) diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptor.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptor.java index 53aa9699ca9..a76bdbcda23 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptor.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptor.java @@ -17,6 +17,7 @@ package org.springframework.web.socket.server.support; import java.util.Collection; +import java.util.Collections; import java.util.Enumeration; import java.util.Map; import javax.servlet.http.HttpSession; @@ -30,8 +31,11 @@ import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.server.HandshakeInterceptor; /** - * An interceptor to copy HTTP session attributes into the map of "handshake attributes" - * made available through {@link WebSocketSession#getAttributes()}. + * An interceptor to copy information from the HTTP session to the "handshake + * attributes" map to made available via{@link WebSocketSession#getAttributes()}. + * + *

Copies a subset or all HTTP session attributes and/or the HTTP session id + * under the key {@link #HTTP_SESSION_ID_ATTR_NAME}. * * @author Rossen Stoyanchev * @since 4.0 @@ -44,32 +48,67 @@ public class HttpSessionHandshakeInterceptor implements HandshakeInterceptor { */ public static final String HTTP_SESSION_ID_ATTR_NAME = "HTTP.SESSION.ID"; + private final Collection attributeNames; - private boolean copyHttpSessionId; + private boolean copyAllAttributes; + + private boolean copyHttpSessionId = true; /** - * A constructor for copying all available HTTP session attributes. + * Default constructor for copying all HTTP session attributes and the HTTP + * session id. + * @see #setCopyAllAttributes + * @see #setCopyHttpSessionId */ public HttpSessionHandshakeInterceptor() { - this(null); + this.attributeNames = Collections.emptyList(); + this.copyAllAttributes = true; } /** - * A constructor for copying a subset of HTTP session attributes. - * @param attributeNames the HTTP session attributes to copy + * Constructor for copying specific HTTP session attributes and the HTTP + * session id. + * @param attributeNames session attributes to copy + * @see #setCopyAllAttributes + * @see #setCopyHttpSessionId */ public HttpSessionHandshakeInterceptor(Collection attributeNames) { - this.attributeNames = attributeNames; + this.attributeNames = Collections.unmodifiableCollection(attributeNames); + this.copyAllAttributes = false; + } + + + /** + * Return the configured attribute names to copy (read-only). + */ + public Collection getAttributeNames() { + return this.attributeNames; } /** - * When set to "true", the HTTP session id is copied to the WebSocket - * handshake attributes, and is subsequently available via - * {@link org.springframework.web.socket.WebSocketSession#getAttributes()} + * Whether to copy all attributes from the HTTP session. If set to "true" any + * explicitly configured attribute names are ignored. + *

By default this is set to either "true" or "false" depending on which + * constructor was used (default or with attribute names respectively). + * @param copyAllAttributes whether to copy all attributes + */ + public void setCopyAllAttributes(boolean copyAllAttributes) { + this.copyAllAttributes = copyAllAttributes; + } + + /** + * Whether to copy all HTTP session attributes. + */ + public boolean isCopyAllAttributes() { + return this.copyAllAttributes; + } + + /** + * Whether the HTTP session id should be copied to the handshake attributes * under the key {@link #HTTP_SESSION_ID_ATTR_NAME}. - *

By default this is "false". + *

By default this is "true". * @param copyHttpSessionId whether to copy the HTTP session id. */ public void setCopyHttpSessionId(boolean copyHttpSessionId) { @@ -88,25 +127,30 @@ public class HttpSessionHandshakeInterceptor implements HandshakeInterceptor { public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map attributes) throws Exception { - if (request instanceof ServletServerHttpRequest) { - ServletServerHttpRequest servletRequest = (ServletServerHttpRequest) request; - HttpSession session = servletRequest.getServletRequest().getSession(false); - if (session != null) { - Enumeration names = session.getAttributeNames(); - while (names.hasMoreElements()) { - String name = names.nextElement(); - if (CollectionUtils.isEmpty(this.attributeNames) || this.attributeNames.contains(name)) { - attributes.put(name, session.getAttribute(name)); - } - } - if (isCopyHttpSessionId()) { - attributes.put(HTTP_SESSION_ID_ATTR_NAME, session.getId()); + HttpSession session = getSession(request); + if (session != null) { + if (isCopyHttpSessionId()) { + attributes.put(HTTP_SESSION_ID_ATTR_NAME, session.getId()); + } + Enumeration names = session.getAttributeNames(); + while (names.hasMoreElements()) { + String name = names.nextElement(); + if (isCopyAllAttributes() || getAttributeNames().contains(name)) { + attributes.put(name, session.getAttribute(name)); } } } return true; } + private HttpSession getSession(ServerHttpRequest request) { + if (request instanceof ServletServerHttpRequest) { + ServletServerHttpRequest servletRequest = (ServletServerHttpRequest) request; + return servletRequest.getServletRequest().getSession(false); + } + return null; + } + @Override public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception ex) { diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptorTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptorTests.java index 8ab3162c729..c52a8f879cc 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptorTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptorTests.java @@ -24,6 +24,7 @@ import java.util.Set; import org.junit.Test; import org.mockito.Mockito; import org.springframework.mock.web.test.MockHttpSession; +import org.springframework.mock.web.test.MockServletContext; import org.springframework.web.socket.AbstractHttpRequestTests; import org.springframework.web.socket.WebSocketHandler; @@ -38,28 +39,29 @@ public class HttpSessionHandshakeInterceptorTests extends AbstractHttpRequestTes @Test - public void copyAllAttributes() throws Exception { - + public void defaultConstructor() throws Exception { Map attributes = new HashMap(); WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); + this.servletRequest.setSession(new MockHttpSession(null, "123")); this.servletRequest.getSession().setAttribute("foo", "bar"); this.servletRequest.getSession().setAttribute("bar", "baz"); HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor(); interceptor.beforeHandshake(this.request, this.response, wsHandler, attributes); - assertEquals(2, attributes.size()); + assertEquals(3, attributes.size()); assertEquals("bar", attributes.get("foo")); assertEquals("baz", attributes.get("bar")); + assertEquals("123", attributes.get(HttpSessionHandshakeInterceptor.HTTP_SESSION_ID_ATTR_NAME)); } @Test - public void copySelectedAttributes() throws Exception { - + public void constructorWithAttributeNames() throws Exception { Map attributes = new HashMap(); WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); + this.servletRequest.setSession(new MockHttpSession(null, "123")); this.servletRequest.getSession().setAttribute("foo", "bar"); this.servletRequest.getSession().setAttribute("bar", "baz"); @@ -67,29 +69,46 @@ public class HttpSessionHandshakeInterceptorTests extends AbstractHttpRequestTes HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor(names); interceptor.beforeHandshake(this.request, this.response, wsHandler, attributes); + assertEquals(2, attributes.size()); + assertEquals("bar", attributes.get("foo")); + assertEquals("123", attributes.get(HttpSessionHandshakeInterceptor.HTTP_SESSION_ID_ATTR_NAME)); + } + + @Test + public void doNotCopyHttpSessionId() throws Exception { + Map attributes = new HashMap(); + WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); + + this.servletRequest.setSession(new MockHttpSession(null, "123")); + this.servletRequest.getSession().setAttribute("foo", "bar"); + + HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor(); + interceptor.setCopyHttpSessionId(false); + interceptor.beforeHandshake(this.request, this.response, wsHandler, attributes); + assertEquals(1, attributes.size()); assertEquals("bar", attributes.get("foo")); } - @Test - public void copyHttpSessionId() throws Exception { + @Test + public void doNotCopyAttributes() throws Exception { Map attributes = new HashMap(); WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); - this.servletRequest.setSession(new MockHttpSession(null, "foo")); + this.servletRequest.setSession(new MockHttpSession(null, "123")); + this.servletRequest.getSession().setAttribute("foo", "bar"); HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor(); - interceptor.setCopyHttpSessionId(true); + interceptor.setCopyAllAttributes(false); interceptor.beforeHandshake(this.request, this.response, wsHandler, attributes); assertEquals(1, attributes.size()); - assertEquals("foo", attributes.get(HttpSessionHandshakeInterceptor.HTTP_SESSION_ID_ATTR_NAME)); + assertEquals("123", attributes.get(HttpSessionHandshakeInterceptor.HTTP_SESSION_ID_ATTR_NAME)); } @Test public void doNotCauseSessionCreation() throws Exception { - Map attributes = new HashMap(); WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);