Improve HttpSessionHandshakeInterceptor
Use explicit flag whether to copy all attributes.
This commit is contained in:
parent
de11cd8791
commit
3056301015
|
|
@ -17,6 +17,7 @@
|
|||
package org.springframework.web.socket.server.support;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.Enumeration;
|
||||
import java.util.Map;
|
||||
import javax.servlet.http.HttpSession;
|
||||
|
|
@ -30,8 +31,11 @@ import org.springframework.web.socket.WebSocketSession;
|
|||
import org.springframework.web.socket.server.HandshakeInterceptor;
|
||||
|
||||
/**
|
||||
* An interceptor to copy HTTP session attributes into the map of "handshake attributes"
|
||||
* made available through {@link WebSocketSession#getAttributes()}.
|
||||
* An interceptor to copy information from the HTTP session to the "handshake
|
||||
* attributes" map to made available via{@link WebSocketSession#getAttributes()}.
|
||||
*
|
||||
* <p>Copies a subset or all HTTP session attributes and/or the HTTP session id
|
||||
* under the key {@link #HTTP_SESSION_ID_ATTR_NAME}.
|
||||
*
|
||||
* @author Rossen Stoyanchev
|
||||
* @since 4.0
|
||||
|
|
@ -44,32 +48,67 @@ public class HttpSessionHandshakeInterceptor implements HandshakeInterceptor {
|
|||
*/
|
||||
public static final String HTTP_SESSION_ID_ATTR_NAME = "HTTP.SESSION.ID";
|
||||
|
||||
|
||||
private final Collection<String> attributeNames;
|
||||
|
||||
private boolean copyHttpSessionId;
|
||||
private boolean copyAllAttributes;
|
||||
|
||||
private boolean copyHttpSessionId = true;
|
||||
|
||||
|
||||
/**
|
||||
* A constructor for copying all available HTTP session attributes.
|
||||
* Default constructor for copying all HTTP session attributes and the HTTP
|
||||
* session id.
|
||||
* @see #setCopyAllAttributes
|
||||
* @see #setCopyHttpSessionId
|
||||
*/
|
||||
public HttpSessionHandshakeInterceptor() {
|
||||
this(null);
|
||||
this.attributeNames = Collections.emptyList();
|
||||
this.copyAllAttributes = true;
|
||||
}
|
||||
|
||||
/**
|
||||
* A constructor for copying a subset of HTTP session attributes.
|
||||
* @param attributeNames the HTTP session attributes to copy
|
||||
* Constructor for copying specific HTTP session attributes and the HTTP
|
||||
* session id.
|
||||
* @param attributeNames session attributes to copy
|
||||
* @see #setCopyAllAttributes
|
||||
* @see #setCopyHttpSessionId
|
||||
*/
|
||||
public HttpSessionHandshakeInterceptor(Collection<String> attributeNames) {
|
||||
this.attributeNames = attributeNames;
|
||||
this.attributeNames = Collections.unmodifiableCollection(attributeNames);
|
||||
this.copyAllAttributes = false;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Return the configured attribute names to copy (read-only).
|
||||
*/
|
||||
public Collection<String> getAttributeNames() {
|
||||
return this.attributeNames;
|
||||
}
|
||||
|
||||
/**
|
||||
* When set to "true", the HTTP session id is copied to the WebSocket
|
||||
* handshake attributes, and is subsequently available via
|
||||
* {@link org.springframework.web.socket.WebSocketSession#getAttributes()}
|
||||
* Whether to copy all attributes from the HTTP session. If set to "true" any
|
||||
* explicitly configured attribute names are ignored.
|
||||
* <p>By default this is set to either "true" or "false" depending on which
|
||||
* constructor was used (default or with attribute names respectively).
|
||||
* @param copyAllAttributes whether to copy all attributes
|
||||
*/
|
||||
public void setCopyAllAttributes(boolean copyAllAttributes) {
|
||||
this.copyAllAttributes = copyAllAttributes;
|
||||
}
|
||||
|
||||
/**
|
||||
* Whether to copy all HTTP session attributes.
|
||||
*/
|
||||
public boolean isCopyAllAttributes() {
|
||||
return this.copyAllAttributes;
|
||||
}
|
||||
|
||||
/**
|
||||
* Whether the HTTP session id should be copied to the handshake attributes
|
||||
* under the key {@link #HTTP_SESSION_ID_ATTR_NAME}.
|
||||
* <p>By default this is "false".
|
||||
* <p>By default this is "true".
|
||||
* @param copyHttpSessionId whether to copy the HTTP session id.
|
||||
*/
|
||||
public void setCopyHttpSessionId(boolean copyHttpSessionId) {
|
||||
|
|
@ -88,25 +127,30 @@ public class HttpSessionHandshakeInterceptor implements HandshakeInterceptor {
|
|||
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
|
||||
WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
|
||||
|
||||
if (request instanceof ServletServerHttpRequest) {
|
||||
ServletServerHttpRequest servletRequest = (ServletServerHttpRequest) request;
|
||||
HttpSession session = servletRequest.getServletRequest().getSession(false);
|
||||
if (session != null) {
|
||||
Enumeration<String> names = session.getAttributeNames();
|
||||
while (names.hasMoreElements()) {
|
||||
String name = names.nextElement();
|
||||
if (CollectionUtils.isEmpty(this.attributeNames) || this.attributeNames.contains(name)) {
|
||||
attributes.put(name, session.getAttribute(name));
|
||||
}
|
||||
}
|
||||
if (isCopyHttpSessionId()) {
|
||||
attributes.put(HTTP_SESSION_ID_ATTR_NAME, session.getId());
|
||||
HttpSession session = getSession(request);
|
||||
if (session != null) {
|
||||
if (isCopyHttpSessionId()) {
|
||||
attributes.put(HTTP_SESSION_ID_ATTR_NAME, session.getId());
|
||||
}
|
||||
Enumeration<String> names = session.getAttributeNames();
|
||||
while (names.hasMoreElements()) {
|
||||
String name = names.nextElement();
|
||||
if (isCopyAllAttributes() || getAttributeNames().contains(name)) {
|
||||
attributes.put(name, session.getAttribute(name));
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private HttpSession getSession(ServerHttpRequest request) {
|
||||
if (request instanceof ServletServerHttpRequest) {
|
||||
ServletServerHttpRequest servletRequest = (ServletServerHttpRequest) request;
|
||||
return servletRequest.getServletRequest().getSession(false);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response,
|
||||
WebSocketHandler wsHandler, Exception ex) {
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ import java.util.Set;
|
|||
import org.junit.Test;
|
||||
import org.mockito.Mockito;
|
||||
import org.springframework.mock.web.test.MockHttpSession;
|
||||
import org.springframework.mock.web.test.MockServletContext;
|
||||
import org.springframework.web.socket.AbstractHttpRequestTests;
|
||||
import org.springframework.web.socket.WebSocketHandler;
|
||||
|
||||
|
|
@ -38,28 +39,29 @@ public class HttpSessionHandshakeInterceptorTests extends AbstractHttpRequestTes
|
|||
|
||||
|
||||
@Test
|
||||
public void copyAllAttributes() throws Exception {
|
||||
|
||||
public void defaultConstructor() throws Exception {
|
||||
Map<String, Object> attributes = new HashMap<String, Object>();
|
||||
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
|
||||
|
||||
this.servletRequest.setSession(new MockHttpSession(null, "123"));
|
||||
this.servletRequest.getSession().setAttribute("foo", "bar");
|
||||
this.servletRequest.getSession().setAttribute("bar", "baz");
|
||||
|
||||
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
|
||||
interceptor.beforeHandshake(this.request, this.response, wsHandler, attributes);
|
||||
|
||||
assertEquals(2, attributes.size());
|
||||
assertEquals(3, attributes.size());
|
||||
assertEquals("bar", attributes.get("foo"));
|
||||
assertEquals("baz", attributes.get("bar"));
|
||||
assertEquals("123", attributes.get(HttpSessionHandshakeInterceptor.HTTP_SESSION_ID_ATTR_NAME));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void copySelectedAttributes() throws Exception {
|
||||
|
||||
public void constructorWithAttributeNames() throws Exception {
|
||||
Map<String, Object> attributes = new HashMap<String, Object>();
|
||||
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
|
||||
|
||||
this.servletRequest.setSession(new MockHttpSession(null, "123"));
|
||||
this.servletRequest.getSession().setAttribute("foo", "bar");
|
||||
this.servletRequest.getSession().setAttribute("bar", "baz");
|
||||
|
||||
|
|
@ -67,29 +69,46 @@ public class HttpSessionHandshakeInterceptorTests extends AbstractHttpRequestTes
|
|||
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor(names);
|
||||
interceptor.beforeHandshake(this.request, this.response, wsHandler, attributes);
|
||||
|
||||
assertEquals(2, attributes.size());
|
||||
assertEquals("bar", attributes.get("foo"));
|
||||
assertEquals("123", attributes.get(HttpSessionHandshakeInterceptor.HTTP_SESSION_ID_ATTR_NAME));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doNotCopyHttpSessionId() throws Exception {
|
||||
Map<String, Object> attributes = new HashMap<String, Object>();
|
||||
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
|
||||
|
||||
this.servletRequest.setSession(new MockHttpSession(null, "123"));
|
||||
this.servletRequest.getSession().setAttribute("foo", "bar");
|
||||
|
||||
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
|
||||
interceptor.setCopyHttpSessionId(false);
|
||||
interceptor.beforeHandshake(this.request, this.response, wsHandler, attributes);
|
||||
|
||||
assertEquals(1, attributes.size());
|
||||
assertEquals("bar", attributes.get("foo"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void copyHttpSessionId() throws Exception {
|
||||
|
||||
@Test
|
||||
public void doNotCopyAttributes() throws Exception {
|
||||
Map<String, Object> attributes = new HashMap<String, Object>();
|
||||
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
|
||||
|
||||
this.servletRequest.setSession(new MockHttpSession(null, "foo"));
|
||||
this.servletRequest.setSession(new MockHttpSession(null, "123"));
|
||||
this.servletRequest.getSession().setAttribute("foo", "bar");
|
||||
|
||||
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
|
||||
interceptor.setCopyHttpSessionId(true);
|
||||
interceptor.setCopyAllAttributes(false);
|
||||
interceptor.beforeHandshake(this.request, this.response, wsHandler, attributes);
|
||||
|
||||
assertEquals(1, attributes.size());
|
||||
assertEquals("foo", attributes.get(HttpSessionHandshakeInterceptor.HTTP_SESSION_ID_ATTR_NAME));
|
||||
assertEquals("123", attributes.get(HttpSessionHandshakeInterceptor.HTTP_SESSION_ID_ATTR_NAME));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doNotCauseSessionCreation() throws Exception {
|
||||
|
||||
Map<String, Object> attributes = new HashMap<String, Object>();
|
||||
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue