Modify return type of subProtocolWebSocketHandler bean

The @Bean method now returns WebSocketHandler allowing it to be
decorated via WebSocketHandlerDecorator.
This commit is contained in:
Rossen Stoyanchev 2013-09-03 15:24:58 -04:00
parent 542b5b2029
commit 30d2f783a7
5 changed files with 50 additions and 22 deletions

View File

@ -22,11 +22,13 @@ import org.springframework.messaging.handler.websocket.SubProtocolWebSocketHandl
import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.DefaultHandshakeHandler;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.config.SockJsServiceRegistration;
import org.springframework.web.socket.sockjs.SockJsService;
import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler;
import org.springframework.web.socket.support.WebSocketHandlerDecorator;
/**
@ -39,7 +41,7 @@ public abstract class AbstractStompEndpointRegistration<M> implements StompEndpo
private final String[] paths;
private final SubProtocolWebSocketHandler wsHandler;
private final WebSocketHandler wsHandler;
private HandshakeHandler handshakeHandler;
@ -48,7 +50,7 @@ public abstract class AbstractStompEndpointRegistration<M> implements StompEndpo
private final TaskScheduler sockJsTaskScheduler;
public AbstractStompEndpointRegistration(String[] paths, SubProtocolWebSocketHandler webSocketHandler,
public AbstractStompEndpointRegistration(String[] paths, WebSocketHandler webSocketHandler,
TaskScheduler sockJsTaskScheduler) {
Assert.notEmpty(paths, "No paths specified");
@ -115,7 +117,7 @@ public abstract class AbstractStompEndpointRegistration<M> implements StompEndpo
if (handler instanceof DefaultHandshakeHandler) {
DefaultHandshakeHandler defaultHandshakeHandler = (DefaultHandshakeHandler) handler;
if (ObjectUtils.isEmpty(defaultHandshakeHandler.getSupportedProtocols())) {
Set<String> protocols = this.wsHandler.getSupportedProtocols();
Set<String> protocols = findSubProtocolWebSocketHandler(this.wsHandler).getSupportedProtocols();
defaultHandshakeHandler.setSupportedProtocols(protocols.toArray(new String[protocols.size()]));
}
}
@ -123,12 +125,19 @@ public abstract class AbstractStompEndpointRegistration<M> implements StompEndpo
return handler;
}
private static SubProtocolWebSocketHandler findSubProtocolWebSocketHandler(WebSocketHandler webSocketHandler) {
WebSocketHandler actual = (webSocketHandler instanceof WebSocketHandlerDecorator) ?
((WebSocketHandlerDecorator) webSocketHandler).getLastHandler() : webSocketHandler;
Assert.isInstanceOf(SubProtocolWebSocketHandler.class, actual,
"No SubProtocolWebSocketHandler found: " + webSocketHandler);
return (SubProtocolWebSocketHandler) actual;
}
protected abstract void addSockJsServiceMapping(M mappings, SockJsService sockJsService,
SubProtocolWebSocketHandler wsHandler, String pathPattern);
WebSocketHandler wsHandler, String pathPattern);
protected abstract void addWebSocketHandlerMapping(M mappings,
SubProtocolWebSocketHandler wsHandler, HandshakeHandler handshakeHandler, String path);
WebSocketHandler wsHandler, HandshakeHandler handshakeHandler, String path);
private class StompSockJsServiceRegistration extends SockJsServiceRegistration {

View File

@ -16,11 +16,11 @@
package org.springframework.messaging.simp.config;
import org.springframework.messaging.handler.websocket.SubProtocolWebSocketHandler;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.HttpRequestHandler;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
import org.springframework.web.socket.sockjs.SockJsHttpRequestHandler;
@ -38,8 +38,8 @@ public class ServletStompEndpointRegistration
extends AbstractStompEndpointRegistration<MultiValueMap<HttpRequestHandler, String>> {
public ServletStompEndpointRegistration(String[] paths, SubProtocolWebSocketHandler wsHandler,
TaskScheduler sockJsTaskScheduler) {
public ServletStompEndpointRegistration(String[] paths,
WebSocketHandler wsHandler, TaskScheduler sockJsTaskScheduler) {
super(paths, wsHandler, sockJsTaskScheduler);
}
@ -51,7 +51,7 @@ public class ServletStompEndpointRegistration
@Override
protected void addSockJsServiceMapping(MultiValueMap<HttpRequestHandler, String> mappings,
SockJsService sockJsService, SubProtocolWebSocketHandler wsHandler, String pathPattern) {
SockJsService sockJsService, WebSocketHandler wsHandler, String pathPattern) {
SockJsHttpRequestHandler httpHandler = new SockJsHttpRequestHandler(sockJsService, wsHandler);
mappings.add(httpHandler, pathPattern);
@ -59,7 +59,7 @@ public class ServletStompEndpointRegistration
@Override
protected void addWebSocketHandlerMapping(MultiValueMap<HttpRequestHandler, String> mappings,
SubProtocolWebSocketHandler wsHandler, HandshakeHandler handshakeHandler, String path) {
WebSocketHandler wsHandler, HandshakeHandler handshakeHandler, String path) {
WebSocketHttpRequestHandler handler = new WebSocketHttpRequestHandler(wsHandler, handshakeHandler);
mappings.add(handler, path);

View File

@ -30,6 +30,8 @@ import org.springframework.util.MultiValueMap;
import org.springframework.web.HttpRequestHandler;
import org.springframework.web.servlet.handler.AbstractHandlerMapping;
import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.support.WebSocketHandlerDecorator;
/**
@ -40,7 +42,9 @@ import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping;
*/
public class ServletStompEndpointRegistry implements StompEndpointRegistry {
private final SubProtocolWebSocketHandler wsHandler;
private final WebSocketHandler webSocketHandler;
private final SubProtocolWebSocketHandler subProtocolWebSocketHandler;
private final StompProtocolHandler stompHandler;
@ -49,23 +53,36 @@ public class ServletStompEndpointRegistry implements StompEndpointRegistry {
private final TaskScheduler sockJsScheduler;
public ServletStompEndpointRegistry(SubProtocolWebSocketHandler webSocketHandler,
public ServletStompEndpointRegistry(WebSocketHandler webSocketHandler,
MutableUserQueueSuffixResolver userQueueSuffixResolver, TaskScheduler defaultSockJsTaskScheduler) {
Assert.notNull(webSocketHandler);
Assert.notNull(userQueueSuffixResolver);
this.wsHandler = webSocketHandler;
this.webSocketHandler = webSocketHandler;
this.subProtocolWebSocketHandler = findSubProtocolWebSocketHandler(webSocketHandler);
this.stompHandler = new StompProtocolHandler();
this.stompHandler.setUserQueueSuffixResolver(userQueueSuffixResolver);
this.sockJsScheduler = defaultSockJsTaskScheduler;
}
private static SubProtocolWebSocketHandler findSubProtocolWebSocketHandler(WebSocketHandler webSocketHandler) {
WebSocketHandler actual = (webSocketHandler instanceof WebSocketHandlerDecorator) ?
((WebSocketHandlerDecorator) webSocketHandler).getLastHandler() : webSocketHandler;
Assert.isInstanceOf(SubProtocolWebSocketHandler.class, actual,
"No SubProtocolWebSocketHandler found: " + webSocketHandler);
return (SubProtocolWebSocketHandler) actual;
}
@Override
public StompEndpointRegistration addEndpoint(String... paths) {
this.wsHandler.addProtocolHandler(this.stompHandler);
ServletStompEndpointRegistration r = new ServletStompEndpointRegistration(paths, this.wsHandler, this.sockJsScheduler);
this.subProtocolWebSocketHandler.addProtocolHandler(this.stompHandler);
ServletStompEndpointRegistration r = new ServletStompEndpointRegistration(
paths, this.webSocketHandler, this.sockJsScheduler);
this.registrations.add(r);
return r;
}

View File

@ -34,6 +34,7 @@ import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.web.servlet.HandlerMapping;
import org.springframework.web.servlet.handler.AbstractHandlerMapping;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.config.SockJsServiceRegistration;
@ -65,7 +66,7 @@ public abstract class WebSocketMessageBrokerConfigurationSupport {
}
@Bean
public SubProtocolWebSocketHandler subProtocolWebSocketHandler() {
public WebSocketHandler subProtocolWebSocketHandler() {
SubProtocolWebSocketHandler wsHandler = new SubProtocolWebSocketHandler(webSocketRequestChannel());
webSocketResponseChannel().subscribe(wsHandler);
return wsHandler;

View File

@ -25,6 +25,7 @@ import org.mockito.Mockito;
import org.springframework.messaging.handler.websocket.SubProtocolWebSocketHandler;
import org.springframework.messaging.support.channel.ExecutorSubscribableChannel;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.DefaultHandshakeHandler;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.sockjs.SockJsService;
@ -122,13 +123,13 @@ public class AbstractStompEndpointRegistrationTests {
@Override
protected void addSockJsServiceMapping(List<Mapping> mappings, SockJsService sockJsService,
SubProtocolWebSocketHandler wsHandler, String pathPattern) {
WebSocketHandler wsHandler, String pathPattern) {
mappings.add(new Mapping(wsHandler, pathPattern, sockJsService));
}
@Override
protected void addWebSocketHandlerMapping(List<Mapping> mappings, SubProtocolWebSocketHandler wsHandler,
protected void addWebSocketHandlerMapping(List<Mapping> mappings, WebSocketHandler wsHandler,
HandshakeHandler handshakeHandler, String path) {
mappings.add(new Mapping(wsHandler, path, handshakeHandler));
@ -137,7 +138,7 @@ public class AbstractStompEndpointRegistrationTests {
private static class Mapping {
private final SubProtocolWebSocketHandler webSocketHandler;
private final WebSocketHandler webSocketHandler;
private final String path;
@ -145,14 +146,14 @@ public class AbstractStompEndpointRegistrationTests {
private final DefaultSockJsService sockJsService;
public Mapping(SubProtocolWebSocketHandler handler, String path, SockJsService sockJsService) {
public Mapping(WebSocketHandler handler, String path, SockJsService sockJsService) {
this.webSocketHandler = handler;
this.path = path;
this.handshakeHandler = null;
this.sockJsService = (DefaultSockJsService) sockJsService;
}
public Mapping(SubProtocolWebSocketHandler h, String path, HandshakeHandler hh) {
public Mapping(WebSocketHandler h, String path, HandshakeHandler hh) {
this.webSocketHandler = h;
this.path = path;
this.handshakeHandler = hh;