Allow hook to associate user with WebSocket session

This change adds a protected method to DefaultHandshakeHandler to
determine the user for the WebSocket session. By default it's
implemeted to obtain it from the request.

Issue: SPR-11228
This commit is contained in:
Rossen Stoyanchev 2014-01-13 16:36:41 -05:00
parent 6265bc1df7
commit a5c3143512
8 changed files with 72 additions and 19 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2013 the original author or authors.
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -50,7 +50,7 @@ public class JettyWebSocketSession extends AbstractWebSocketSession<Session> {
private List<WebSocketExtension> extensions;
private final Principal principal;
private final Principal user;
/**
@ -68,13 +68,13 @@ public class JettyWebSocketSession extends AbstractWebSocketSession<Session> {
*
* @param handshakeAttributes attributes from the HTTP handshake to make available
* through the WebSocket session
* @param principal the user associated with the session; can be left
* @param user the user associated with the session; can be left
* {@code null} in which case, we'll fallback on the user available via
* {@link org.eclipse.jetty.websocket.api.Session#getUpgradeRequest()}
*/
public JettyWebSocketSession(Map<String, Object> handshakeAttributes, Principal principal) {
public JettyWebSocketSession(Map<String, Object> handshakeAttributes, Principal user) {
super(handshakeAttributes);
this.principal = principal;
this.user = user;
}
@ -103,8 +103,8 @@ public class JettyWebSocketSession extends AbstractWebSocketSession<Session> {
@Override
public Principal getPrincipal() {
if (this.principal != null) {
return this.principal;
if (this.user != null) {
return this.user;
}
checkNativeSessionInitialized();
return getNativeSession().getUpgradeRequest().getUserPrincipal();

View File

@ -54,11 +54,14 @@ public class StandardWebSocketSession extends AbstractWebSocketSession<Session>
private final InetSocketAddress remoteAddress;
private final Principal user;
private List<WebSocketExtension> extensions;
/**
* Class constructor.
*
* @param headers the headers of the handshake request
* @param handshakeAttributes attributes from the HTTP handshake to make available
* through the WebSocket session
@ -67,11 +70,30 @@ public class StandardWebSocketSession extends AbstractWebSocketSession<Session>
*/
public StandardWebSocketSession(HttpHeaders headers, Map<String, Object> handshakeAttributes,
InetSocketAddress localAddress, InetSocketAddress remoteAddress) {
this(headers, handshakeAttributes, localAddress, remoteAddress, null);
}
/**
* Class constructor that associates a user with the WebSocket session.
*
* @param headers the headers of the handshake request
* @param handshakeAttributes attributes from the HTTP handshake to make available
* through the WebSocket session
* @param localAddress the address on which the request was received
* @param remoteAddress the address of the remote client
* @param user the user associated with the session; can be left
* {@code null} in which case, we'll fallback on the user available via
*/
public StandardWebSocketSession(HttpHeaders headers, Map<String, Object> handshakeAttributes,
InetSocketAddress localAddress, InetSocketAddress remoteAddress, Principal user) {
super(handshakeAttributes);
headers = (headers != null) ? headers : new HttpHeaders();
this.handshakeHeaders = HttpHeaders.readOnlyHttpHeaders(headers);
this.localAddress = localAddress;
this.remoteAddress = remoteAddress;
this.user = user;
}
@Override
@ -93,6 +115,9 @@ public class StandardWebSocketSession extends AbstractWebSocketSession<Session>
@Override
public Principal getPrincipal() {
if (this.user != null) {
return this.user;
}
checkNativeSessionInitialized();
return getNativeSession().getUserPrincipal();
}

View File

@ -16,6 +16,7 @@
package org.springframework.web.socket.server;
import java.security.Principal;
import java.util.List;
import java.util.Map;
@ -49,6 +50,7 @@ public interface RequestUpgradeStrategy {
* @param response the current response
* @param selectedProtocol the selected sub-protocol, if any
* @param selectedExtensions the selected WebSocket protocol extensions
* @param user the user to associate with the WebSocket session
* @param wsHandler the handler for WebSocket messages
* @param attributes handshake request specific attributes to be set on the WebSocket
* session via {@link org.springframework.web.socket.server.HandshakeInterceptor}
@ -60,7 +62,7 @@ public interface RequestUpgradeStrategy {
* handshake request.
*/
void upgrade(ServerHttpRequest request, ServerHttpResponse response,
String selectedProtocol, List<WebSocketExtension> selectedExtensions,
String selectedProtocol, List<WebSocketExtension> selectedExtensions, Principal user,
WebSocketHandler wsHandler, Map<String, Object> attributes) throws HandshakeFailureException;
}

View File

@ -17,6 +17,7 @@
package org.springframework.web.socket.server.jetty;
import java.io.IOException;
import java.security.Principal;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@ -130,7 +131,7 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy {
@Override
public void upgrade(ServerHttpRequest request, ServerHttpResponse response,
String selectedProtocol, List<WebSocketExtension> selectedExtensions,
String selectedProtocol, List<WebSocketExtension> selectedExtensions, Principal user,
WebSocketHandler wsHandler, Map<String, Object> attributes) throws HandshakeFailureException {
Assert.isInstanceOf(ServletServerHttpRequest.class, request);
@ -141,7 +142,7 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy {
Assert.isTrue(this.factory.isUpgradeRequest(servletRequest, servletResponse), "Not a WebSocket handshake");
JettyWebSocketSession session = new JettyWebSocketSession(attributes, request.getPrincipal());
JettyWebSocketSession session = new JettyWebSocketSession(attributes, user);
JettyWebSocketHandlerAdapter handlerAdapter = new JettyWebSocketHandlerAdapter(wsHandler, session);
WebSocketHandlerContainer container =

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2013 the original author or authors.
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -17,6 +17,7 @@
package org.springframework.web.socket.server.standard;
import java.net.InetSocketAddress;
import java.security.Principal;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@ -86,15 +87,15 @@ public abstract class AbstractStandardUpgradeStrategy implements RequestUpgradeS
@Override
public void upgrade(ServerHttpRequest request, ServerHttpResponse response,
String selectedProtocol, List<WebSocketExtension> selectedExtensions,
WebSocketHandler wsHandler, Map<String, Object> attributes) throws HandshakeFailureException {
String selectedProtocol, List<WebSocketExtension> selectedExtensions, Principal user,
WebSocketHandler wsHandler, Map<String, Object> attrs) throws HandshakeFailureException {
HttpHeaders headers = request.getHeaders();
InetSocketAddress localAddr = request.getLocalAddress();
InetSocketAddress remoteAddr = request.getRemoteAddress();
StandardWebSocketSession session = new StandardWebSocketSession(headers, attributes, localAddr, remoteAddr);
StandardWebSocketSession session = new StandardWebSocketSession(headers, attrs, localAddr, remoteAddr, user);
StandardWebSocketHandlerAdapter endpoint = new StandardWebSocketHandlerAdapter(wsHandler, session);
List<Extension> extensions = new ArrayList<Extension>();

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2013 the original author or authors.
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -17,6 +17,7 @@
package org.springframework.web.socket.server.support;
import java.io.IOException;
import java.security.Principal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
@ -199,11 +200,13 @@ public class DefaultHandshakeHandler implements HandshakeHandler {
List<WebSocketExtension> supported = this.requestUpgradeStrategy.getSupportedExtensions(request);
List<WebSocketExtension> extensions = filterRequestedExtensions(request, requested, supported);
Principal user = determineUser(request, wsHandler, attributes);
if (logger.isDebugEnabled()) {
logger.debug("Upgrading request, sub-protocol=" + subProtocol + ", extensions=" + extensions);
}
this.requestUpgradeStrategy.upgrade(request, response, subProtocol, extensions, wsHandler, attributes);
this.requestUpgradeStrategy.upgrade(request, response, subProtocol, extensions, user, wsHandler, attributes);
return true;
}
@ -326,4 +329,25 @@ public class DefaultHandshakeHandler implements HandshakeHandler {
return requested;
}
/**
* A method that can be used to associate a user with the WebSocket session
* in the process of being established. The default implementation calls
* {@link org.springframework.http.server.ServerHttpRequest#getPrincipal()}
* <p>
* Sub-classes can provide custom logic for associating a user with a session,
* for example for assigning a name to anonymous users (i.e. not fully
* authenticated).
*
* @param request the handshake request
* @param wsHandler the WebSocket handler that will handle messages
* @param attributes handshake attributes to pass to the WebSocket session
*
* @return the user for the WebSocket session or {@code null}
*/
protected Principal determineUser(ServerHttpRequest request, WebSocketHandler wsHandler,
Map<String, Object> attributes) {
return request.getPrincipal();
}
}

View File

@ -77,7 +77,7 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests {
this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes);
verify(this.upgradeStrategy).upgrade(this.request, this.response,
"STOMP", Collections.<WebSocketExtension>emptyList(), handler, attributes);
"STOMP", Collections.<WebSocketExtension>emptyList(), null, handler, attributes);
}
@Test
@ -99,7 +99,7 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests {
this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes);
verify(this.upgradeStrategy).upgrade(this.request, this.response,
"v11.stomp", Collections.<WebSocketExtension>emptyList(), handler, attributes);
"v11.stomp", Collections.<WebSocketExtension>emptyList(), null, handler, attributes);
}
@Test
@ -121,7 +121,7 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests {
this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes);
verify(this.upgradeStrategy).upgrade(this.request, this.response,
null, Collections.<WebSocketExtension>emptyList(), handler, attributes);
null, Collections.<WebSocketExtension>emptyList(), null, handler, attributes);
}