Add missing HandshakeInterceptor for STOMP endpoints

Issue: SPR-11845
This commit is contained in:
Rossen Stoyanchev 2014-07-15 13:28:50 -04:00
parent 6d6cc0ecec
commit 4dd5c274a0
3 changed files with 48 additions and 11 deletions

View File

@ -17,6 +17,7 @@
package org.springframework.web.socket.config.annotation;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
/**
* A contract for configuring a STOMP over WebSocket endpoint.
@ -36,4 +37,9 @@ public interface StompWebSocketEndpointRegistration {
*/
StompWebSocketEndpointRegistration setHandshakeHandler(HandshakeHandler handshakeHandler);
/**
* Configure the HandshakeInterceptor's to use.
*/
StompWebSocketEndpointRegistration addInterceptors(HandshakeInterceptor... interceptors);
}

View File

@ -23,11 +23,14 @@ import org.springframework.util.MultiValueMap;
import org.springframework.web.HttpRequestHandler;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler;
import org.springframework.web.socket.sockjs.SockJsService;
import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler;
import java.util.Arrays;
/**
* An abstract base class class for configuring STOMP over WebSocket/SockJS endpoints.
*
@ -44,6 +47,8 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE
private HandshakeHandler handshakeHandler;
private HandshakeInterceptor[] interceptors;
private StompSockJsServiceRegistration registration;
@ -58,9 +63,6 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE
this.sockJsTaskScheduler = sockJsTaskScheduler;
}
/**
* Provide a custom or pre-configured {@link HandshakeHandler}.
*/
@Override
public StompWebSocketEndpointRegistration setHandshakeHandler(HandshakeHandler handshakeHandler) {
Assert.notNull(handshakeHandler, "'handshakeHandler' must not be null");
@ -68,12 +70,22 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE
return this;
}
/**
* Enable SockJS fallback options.
*/
@Override
public StompWebSocketEndpointRegistration addInterceptors(HandshakeInterceptor... interceptors) {
this.interceptors = interceptors;
return this;
}
protected HandshakeInterceptor[] getInterceptors() {
return this.interceptors;
}
@Override
public SockJsServiceRegistration withSockJS() {
this.registration = new StompSockJsServiceRegistration(this.sockJsTaskScheduler);
if (this.interceptors != null) {
this.registration.setInterceptors(this.interceptors);
}
if (this.handshakeHandler != null) {
WebSocketTransportHandler transportHandler = new WebSocketTransportHandler(this.handshakeHandler);
this.registration.setTransportHandlerOverrides(transportHandler);
@ -93,9 +105,16 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE
}
else {
for (String path : this.paths) {
WebSocketHttpRequestHandler handler = (this.handshakeHandler != null) ?
new WebSocketHttpRequestHandler(this.webSocketHandler, this.handshakeHandler) :
new WebSocketHttpRequestHandler(this.webSocketHandler);
WebSocketHttpRequestHandler handler;
if (this.handshakeHandler != null) {
handler = new WebSocketHttpRequestHandler(this.webSocketHandler, this.handshakeHandler);
}
else {
handler = new WebSocketHttpRequestHandler(this.webSocketHandler);
}
if (this.interceptors != null) {
handler.setHandshakeInterceptors(Arrays.asList(this.interceptors));
}
mappings.add(handler, path);
}
}

View File

@ -29,7 +29,9 @@ import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.MultiValueMap;
import org.springframework.web.HttpRequestHandler;
import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor;
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler;
import org.springframework.web.socket.sockjs.transport.TransportHandler;
@ -38,6 +40,8 @@ import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsServ
import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler;
import static org.junit.Assert.*;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.mock;
/**
@ -73,12 +77,15 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
}
@Test
public void customHandshakeHandler() {
public void handshakeHandlerAndInterceptors() {
WebMvcStompWebSocketEndpointRegistration registration =
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
DefaultHandshakeHandler handshakeHandler = new DefaultHandshakeHandler();
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
registration.setHandshakeHandler(handshakeHandler);
registration.addInterceptors(interceptor);
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
assertEquals(1, mappings.size());
@ -89,15 +96,19 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
WebSocketHttpRequestHandler requestHandler = (WebSocketHttpRequestHandler) entry.getKey();
assertNotNull(requestHandler.getWebSocketHandler());
assertSame(handshakeHandler, requestHandler.getHandshakeHandler());
assertEquals(Arrays.asList(interceptor), requestHandler.getHandshakeInterceptors());
}
@Test
public void customHandshakeHandlerPassedToSockJsService() {
public void handshakeHandlerAndInterceptorsWithSockJsService() {
WebMvcStompWebSocketEndpointRegistration registration =
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
DefaultHandshakeHandler handshakeHandler = new DefaultHandshakeHandler();
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
registration.setHandshakeHandler(handshakeHandler);
registration.addInterceptors(interceptor);
registration.withSockJS();
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
@ -115,6 +126,7 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
Map<TransportType, TransportHandler> handlers = sockJsService.getTransportHandlers();
WebSocketTransportHandler transportHandler = (WebSocketTransportHandler) handlers.get(TransportType.WEBSOCKET);
assertSame(handshakeHandler, transportHandler.getHandshakeHandler());
assertEquals(Arrays.asList(interceptor), sockJsService.getHandshakeInterceptors());
}
}