Improve HttpSessionHandshakeInterceptor

Use explicit flag whether to copy all attributes.
This commit is contained in:
Rossen Stoyanchev 2014-10-14 09:01:09 -04:00
parent de11cd8791
commit 3056301015
2 changed files with 99 additions and 36 deletions

View File

@ -17,6 +17,7 @@
package org.springframework.web.socket.server.support;
import java.util.Collection;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Map;
import javax.servlet.http.HttpSession;
@ -30,8 +31,11 @@ import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.server.HandshakeInterceptor;
/**
* An interceptor to copy HTTP session attributes into the map of "handshake attributes"
* made available through {@link WebSocketSession#getAttributes()}.
* An interceptor to copy information from the HTTP session to the "handshake
* attributes" map to made available via{@link WebSocketSession#getAttributes()}.
*
* <p>Copies a subset or all HTTP session attributes and/or the HTTP session id
* under the key {@link #HTTP_SESSION_ID_ATTR_NAME}.
*
* @author Rossen Stoyanchev
* @since 4.0
@ -44,32 +48,67 @@ public class HttpSessionHandshakeInterceptor implements HandshakeInterceptor {
*/
public static final String HTTP_SESSION_ID_ATTR_NAME = "HTTP.SESSION.ID";
private final Collection<String> attributeNames;
private boolean copyHttpSessionId;
private boolean copyAllAttributes;
private boolean copyHttpSessionId = true;
/**
* A constructor for copying all available HTTP session attributes.
* Default constructor for copying all HTTP session attributes and the HTTP
* session id.
* @see #setCopyAllAttributes
* @see #setCopyHttpSessionId
*/
public HttpSessionHandshakeInterceptor() {
this(null);
this.attributeNames = Collections.emptyList();
this.copyAllAttributes = true;
}
/**
* A constructor for copying a subset of HTTP session attributes.
* @param attributeNames the HTTP session attributes to copy
* Constructor for copying specific HTTP session attributes and the HTTP
* session id.
* @param attributeNames session attributes to copy
* @see #setCopyAllAttributes
* @see #setCopyHttpSessionId
*/
public HttpSessionHandshakeInterceptor(Collection<String> attributeNames) {
this.attributeNames = attributeNames;
this.attributeNames = Collections.unmodifiableCollection(attributeNames);
this.copyAllAttributes = false;
}
/**
* Return the configured attribute names to copy (read-only).
*/
public Collection<String> getAttributeNames() {
return this.attributeNames;
}
/**
* When set to "true", the HTTP session id is copied to the WebSocket
* handshake attributes, and is subsequently available via
* {@link org.springframework.web.socket.WebSocketSession#getAttributes()}
* Whether to copy all attributes from the HTTP session. If set to "true" any
* explicitly configured attribute names are ignored.
* <p>By default this is set to either "true" or "false" depending on which
* constructor was used (default or with attribute names respectively).
* @param copyAllAttributes whether to copy all attributes
*/
public void setCopyAllAttributes(boolean copyAllAttributes) {
this.copyAllAttributes = copyAllAttributes;
}
/**
* Whether to copy all HTTP session attributes.
*/
public boolean isCopyAllAttributes() {
return this.copyAllAttributes;
}
/**
* Whether the HTTP session id should be copied to the handshake attributes
* under the key {@link #HTTP_SESSION_ID_ATTR_NAME}.
* <p>By default this is "false".
* <p>By default this is "true".
* @param copyHttpSessionId whether to copy the HTTP session id.
*/
public void setCopyHttpSessionId(boolean copyHttpSessionId) {
@ -88,25 +127,30 @@ public class HttpSessionHandshakeInterceptor implements HandshakeInterceptor {
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
if (request instanceof ServletServerHttpRequest) {
ServletServerHttpRequest servletRequest = (ServletServerHttpRequest) request;
HttpSession session = servletRequest.getServletRequest().getSession(false);
if (session != null) {
Enumeration<String> names = session.getAttributeNames();
while (names.hasMoreElements()) {
String name = names.nextElement();
if (CollectionUtils.isEmpty(this.attributeNames) || this.attributeNames.contains(name)) {
attributes.put(name, session.getAttribute(name));
}
}
if (isCopyHttpSessionId()) {
attributes.put(HTTP_SESSION_ID_ATTR_NAME, session.getId());
HttpSession session = getSession(request);
if (session != null) {
if (isCopyHttpSessionId()) {
attributes.put(HTTP_SESSION_ID_ATTR_NAME, session.getId());
}
Enumeration<String> names = session.getAttributeNames();
while (names.hasMoreElements()) {
String name = names.nextElement();
if (isCopyAllAttributes() || getAttributeNames().contains(name)) {
attributes.put(name, session.getAttribute(name));
}
}
}
return true;
}
private HttpSession getSession(ServerHttpRequest request) {
if (request instanceof ServletServerHttpRequest) {
ServletServerHttpRequest servletRequest = (ServletServerHttpRequest) request;
return servletRequest.getServletRequest().getSession(false);
}
return null;
}
@Override
public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler wsHandler, Exception ex) {

View File

@ -24,6 +24,7 @@ import java.util.Set;
import org.junit.Test;
import org.mockito.Mockito;
import org.springframework.mock.web.test.MockHttpSession;
import org.springframework.mock.web.test.MockServletContext;
import org.springframework.web.socket.AbstractHttpRequestTests;
import org.springframework.web.socket.WebSocketHandler;
@ -38,28 +39,29 @@ public class HttpSessionHandshakeInterceptorTests extends AbstractHttpRequestTes
@Test
public void copyAllAttributes() throws Exception {
public void defaultConstructor() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
this.servletRequest.setSession(new MockHttpSession(null, "123"));
this.servletRequest.getSession().setAttribute("foo", "bar");
this.servletRequest.getSession().setAttribute("bar", "baz");
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
interceptor.beforeHandshake(this.request, this.response, wsHandler, attributes);
assertEquals(2, attributes.size());
assertEquals(3, attributes.size());
assertEquals("bar", attributes.get("foo"));
assertEquals("baz", attributes.get("bar"));
assertEquals("123", attributes.get(HttpSessionHandshakeInterceptor.HTTP_SESSION_ID_ATTR_NAME));
}
@Test
public void copySelectedAttributes() throws Exception {
public void constructorWithAttributeNames() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
this.servletRequest.setSession(new MockHttpSession(null, "123"));
this.servletRequest.getSession().setAttribute("foo", "bar");
this.servletRequest.getSession().setAttribute("bar", "baz");
@ -67,29 +69,46 @@ public class HttpSessionHandshakeInterceptorTests extends AbstractHttpRequestTes
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor(names);
interceptor.beforeHandshake(this.request, this.response, wsHandler, attributes);
assertEquals(2, attributes.size());
assertEquals("bar", attributes.get("foo"));
assertEquals("123", attributes.get(HttpSessionHandshakeInterceptor.HTTP_SESSION_ID_ATTR_NAME));
}
@Test
public void doNotCopyHttpSessionId() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
this.servletRequest.setSession(new MockHttpSession(null, "123"));
this.servletRequest.getSession().setAttribute("foo", "bar");
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
interceptor.setCopyHttpSessionId(false);
interceptor.beforeHandshake(this.request, this.response, wsHandler, attributes);
assertEquals(1, attributes.size());
assertEquals("bar", attributes.get("foo"));
}
@Test
public void copyHttpSessionId() throws Exception {
@Test
public void doNotCopyAttributes() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
this.servletRequest.setSession(new MockHttpSession(null, "foo"));
this.servletRequest.setSession(new MockHttpSession(null, "123"));
this.servletRequest.getSession().setAttribute("foo", "bar");
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
interceptor.setCopyHttpSessionId(true);
interceptor.setCopyAllAttributes(false);
interceptor.beforeHandshake(this.request, this.response, wsHandler, attributes);
assertEquals(1, attributes.size());
assertEquals("foo", attributes.get(HttpSessionHandshakeInterceptor.HTTP_SESSION_ID_ATTR_NAME));
assertEquals("123", attributes.get(HttpSessionHandshakeInterceptor.HTTP_SESSION_ID_ATTR_NAME));
}
@Test
public void doNotCauseSessionCreation() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);