diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java
index 1df1142accb..443f7e5a859 100644
--- a/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java
+++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java
@@ -51,8 +51,6 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
// SiMP header names
- public static final String CONNECT_MESSAGE_HEADER = "simpConnectMessage";
-
public static final String DESTINATION_HEADER = "simpDestination";
public static final String MESSAGE_TYPE_HEADER = "simpMessageType";
@@ -65,6 +63,11 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
public static final String USER_HEADER = "simpUser";
+ public static final String CONNECT_MESSAGE_HEADER = "simpConnectMessage";
+
+ public static final String HEART_BEAT_HEADER = "simpHeartbeat";
+
+
/**
* For internal use.
*
The original destination used by a client when subscribing. Such a
@@ -262,4 +265,8 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
return (Principal) headers.get(USER_HEADER);
}
+ public static long[] getHeartbeat(Map headers) {
+ return (long[]) headers.get(HEART_BEAT_HEADER);
+ }
+
}
diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageType.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageType.java
index d0f9f8ca627..7def4d82ecd 100644
--- a/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageType.java
+++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageType.java
@@ -29,14 +29,14 @@ public enum SimpMessageType {
CONNECT_ACK,
- HEARTBEAT,
-
MESSAGE,
SUBSCRIBE,
UNSUBSCRIBE,
+ HEARTBEAT,
+
DISCONNECT,
DISCONNECT_ACK,
diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandler.java
index 6c4a94b7e6a..74ebafea5ad 100644
--- a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandler.java
+++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandler.java
@@ -16,7 +16,11 @@
package org.springframework.messaging.simp.broker;
+import java.security.Principal;
import java.util.Collection;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ScheduledFuture;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
@@ -27,6 +31,7 @@ import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.messaging.support.MessageHeaderInitializer;
+import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.Assert;
import org.springframework.util.MultiValueMap;
import org.springframework.util.PathMatcher;
@@ -43,10 +48,18 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler {
private static final byte[] EMPTY_PAYLOAD = new byte[0];
+ private final Map sessions = new ConcurrentHashMap();
+
private SubscriptionRegistry subscriptionRegistry;
private PathMatcher pathMatcher;
+ private TaskScheduler taskScheduler;
+
+ private long[] heartbeatValue;
+
+ private ScheduledFuture> heartbeatFuture;
+
private MessageHeaderInitializer headerInitializer;
@@ -100,6 +113,49 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler {
initPathMatcherToUse();
}
+ /**
+ * Configure the {@link org.springframework.scheduling.TaskScheduler} to
+ * use for providing heartbeat support. Setting this property also sets the
+ * {@link #setHeartbeatValue heartbeatValue} to "10000, 10000".
+ * By default this is not set.
+ * @since 4.2
+ */
+ public void setTaskScheduler(TaskScheduler taskScheduler) {
+ Assert.notNull(taskScheduler);
+ this.taskScheduler = taskScheduler;
+ if (this.heartbeatValue == null) {
+ this.heartbeatValue = new long[] {10000, 10000};
+ }
+ }
+
+ /**
+ * Return the configured TaskScheduler.
+ */
+ public TaskScheduler getTaskScheduler() {
+ return this.taskScheduler;
+ }
+
+ /**
+ * Configure the value for the heart-beat settings. The first number
+ * represents how often the server will write or send a heartbeat.
+ * The second is how often the client should write. 0 means no heartbeats.
+ *
By default this is set to "0, 0" unless the {@link #setTaskScheduler
+ * taskScheduler} in which case the default becomes "10000,10000"
+ * (in milliseconds).
+ * @since 4.2
+ */
+ public void setHeartbeatValue(long[] heartbeat) {
+ Assert.notNull(heartbeat);
+ this.heartbeatValue = heartbeat;
+ }
+
+ /**
+ * The configured value for the heart-beat settings.
+ */
+ public long[] getHeartbeatValue() {
+ return this.heartbeatValue;
+ }
+
/**
* Configure a {@link MessageHeaderInitializer} to apply to the headers
* of all messages sent to the client outbound channel.
@@ -120,11 +176,37 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler {
@Override
public void startInternal() {
publishBrokerAvailableEvent();
+ if (getTaskScheduler() != null) {
+ long interval = initHeartbeatTaskDelay();
+ if (interval > 0) {
+ this.heartbeatFuture = this.taskScheduler.scheduleWithFixedDelay(new HeartbeatTask(), interval);
+ }
+ }
+ else {
+ Assert.isTrue(getHeartbeatValue() == null ||
+ (getHeartbeatValue()[0] == 0 && getHeartbeatValue()[1] == 0),
+ "Heartbeat values configured but no TaskScheduler is provided.");
+ }
+ }
+
+ private long initHeartbeatTaskDelay() {
+ if (getHeartbeatValue() == null) {
+ return 0;
+ }
+ else if (getHeartbeatValue()[0] > 0 && getHeartbeatValue()[1] > 0) {
+ return Math.min(getHeartbeatValue()[0], getHeartbeatValue()[1]);
+ }
+ else {
+ return (getHeartbeatValue()[0] > 0 ? getHeartbeatValue()[0] : getHeartbeatValue()[1]);
+ }
}
@Override
public void stopInternal() {
publishBrokerUnavailableEvent();
+ if (this.heartbeatFuture != null) {
+ this.heartbeatFuture.cancel(true);
+ }
}
@Override
@@ -133,6 +215,9 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler {
SimpMessageType messageType = SimpMessageHeaderAccessor.getMessageType(headers);
String destination = SimpMessageHeaderAccessor.getDestination(headers);
String sessionId = SimpMessageHeaderAccessor.getSessionId(headers);
+ Principal user = SimpMessageHeaderAccessor.getUser(headers);
+
+ updateSessionReadTime(sessionId);
if (!checkDestinationPrefix(destination)) {
return;
@@ -150,23 +235,21 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler {
}
else if (SimpMessageType.CONNECT.equals(messageType)) {
logMessage(message);
+ long[] clientHeartbeat = SimpMessageHeaderAccessor.getHeartbeat(headers);
+ long[] serverHeartbeat = getHeartbeatValue();
+ this.sessions.put(sessionId, new SessionInfo(sessionId, user, clientHeartbeat, serverHeartbeat));
SimpMessageHeaderAccessor connectAck = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK);
initHeaders(connectAck);
connectAck.setSessionId(sessionId);
connectAck.setUser(SimpMessageHeaderAccessor.getUser(headers));
connectAck.setHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER, message);
+ connectAck.setHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER, serverHeartbeat);
Message messageOut = MessageBuilder.createMessage(EMPTY_PAYLOAD, connectAck.getMessageHeaders());
getClientOutboundChannel().send(messageOut);
}
else if (SimpMessageType.DISCONNECT.equals(messageType)) {
logMessage(message);
- this.subscriptionRegistry.unregisterAllSubscriptions(sessionId);
- SimpMessageHeaderAccessor disconnectAck = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT_ACK);
- initHeaders(disconnectAck);
- disconnectAck.setSessionId(sessionId);
- disconnectAck.setUser(SimpMessageHeaderAccessor.getUser(headers));
- Message messageOut = MessageBuilder.createMessage(EMPTY_PAYLOAD, disconnectAck.getMessageHeaders());
- getClientOutboundChannel().send(messageOut);
+ handleDisconnect(sessionId, user);
}
else if (SimpMessageType.SUBSCRIBE.equals(messageType)) {
logMessage(message);
@@ -178,6 +261,15 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler {
}
}
+ private void updateSessionReadTime(String sessionId) {
+ if (sessionId != null) {
+ SessionInfo info = this.sessions.get(sessionId);
+ if (info != null) {
+ info.setLastReadTime(System.currentTimeMillis());
+ }
+ }
+ }
+
private void logMessage(Message> message) {
if (logger.isDebugEnabled()) {
SimpMessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, SimpMessageHeaderAccessor.class);
@@ -192,11 +284,23 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler {
}
}
+ private void handleDisconnect(String sessionId, Principal user) {
+ this.sessions.remove(sessionId);
+ this.subscriptionRegistry.unregisterAllSubscriptions(sessionId);
+ SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT_ACK);
+ accessor.setSessionId(sessionId);
+ accessor.setUser(user);
+ initHeaders(accessor);
+ Message message = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders());
+ getClientOutboundChannel().send(message);
+ }
+
protected void sendMessageToSubscribers(String destination, Message> message) {
MultiValueMap subscriptions = this.subscriptionRegistry.findSubscriptions(message);
if (!subscriptions.isEmpty() && logger.isDebugEnabled()) {
logger.debug("Broadcasting to " + subscriptions.size() + " sessions.");
}
+ long now = System.currentTimeMillis();
for (String sessionId : subscriptions.keySet()) {
for (String subscriptionId : subscriptions.get(sessionId)) {
SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE);
@@ -212,6 +316,12 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler {
catch (Throwable ex) {
logger.error("Failed to send " + message, ex);
}
+ finally {
+ SessionInfo info = this.sessions.get(sessionId);
+ if (info != null) {
+ info.setLastWriteTime(now);
+ }
+ }
}
}
}
@@ -221,4 +331,93 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler {
return "SimpleBroker[" + this.subscriptionRegistry + "]";
}
+
+ private static class SessionInfo {
+
+ /* STOMP spec: receiver SHOULD take into account an error margin */
+ private static final long HEARTBEAT_MULTIPLIER = 3;
+
+
+ private final String sessiondId;
+
+ private final Principal user;
+
+ private final long readInterval;
+
+ private final long writeInterval;
+
+ private volatile long lastReadTime;
+
+ private volatile long lastWriteTime;
+
+
+ public SessionInfo(String sessiondId, Principal user, long[] clientHeartbeat, long[] serverHeartbeat) {
+ this.sessiondId = sessiondId;
+ this.user = user;
+ if (clientHeartbeat != null && serverHeartbeat != null) {
+ this.readInterval = (clientHeartbeat[0] > 0 && serverHeartbeat[1] > 0 ?
+ Math.max(clientHeartbeat[0], serverHeartbeat[1]) * HEARTBEAT_MULTIPLIER : 0);
+ this.writeInterval = (clientHeartbeat[1] > 0 && serverHeartbeat[0] > 0 ?
+ Math.max(clientHeartbeat[1], serverHeartbeat[0]) : 0);
+ }
+ else {
+ this.readInterval = 0;
+ this.writeInterval = 0;
+ }
+ this.lastReadTime = this.lastWriteTime = System.currentTimeMillis();
+ }
+
+ public String getSessiondId() {
+ return this.sessiondId;
+ }
+
+ public Principal getUser() {
+ return this.user;
+ }
+
+ public long getReadInterval() {
+ return this.readInterval;
+ }
+
+ public long getWriteInterval() {
+ return this.writeInterval;
+ }
+
+ public long getLastReadTime() {
+ return this.lastReadTime;
+ }
+
+ public void setLastReadTime(long lastReadTime) {
+ this.lastReadTime = lastReadTime;
+ }
+
+ public long getLastWriteTime() {
+ return this.lastWriteTime;
+ }
+
+ public void setLastWriteTime(long lastWriteTime) {
+ this.lastWriteTime = lastWriteTime;
+ }
+ }
+
+ private class HeartbeatTask implements Runnable {
+
+ @Override
+ public void run() {
+ long now = System.currentTimeMillis();
+ for (SessionInfo info : sessions.values()) {
+ if (info.getReadInterval() > 0 && (now - info.getLastReadTime()) > info.getReadInterval()) {
+ handleDisconnect(info.getSessiondId(), info.getUser());
+ }
+ if (info.getWriteInterval() > 0 && (now - info.getLastWriteTime()) > info.getWriteInterval()) {
+ SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.HEARTBEAT);
+ accessor.setSessionId(info.getSessiondId());
+ accessor.setUser(info.getUser());
+ initHeaders(accessor);
+ MessageHeaders headers = accessor.getMessageHeaders();
+ getClientOutboundChannel().send(MessageBuilder.createMessage(EMPTY_PAYLOAD, headers));
+ }
+ }
+ }
+ }
}
diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/SimpleBrokerRegistration.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/SimpleBrokerRegistration.java
index df0a5207617..ab240affc09 100644
--- a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/SimpleBrokerRegistration.java
+++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/SimpleBrokerRegistration.java
@@ -19,6 +19,7 @@ package org.springframework.messaging.simp.config;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.messaging.simp.broker.SimpleBrokerMessageHandler;
+import org.springframework.scheduling.TaskScheduler;
/**
* Registration class for configuring a {@link SimpleBrokerMessageHandler}.
@@ -28,14 +29,54 @@ import org.springframework.messaging.simp.broker.SimpleBrokerMessageHandler;
*/
public class SimpleBrokerRegistration extends AbstractBrokerRegistration {
+ private TaskScheduler taskScheduler;
+
+ private long[] heartbeat;
+
+
public SimpleBrokerRegistration(SubscribableChannel inChannel, MessageChannel outChannel, String[] prefixes) {
super(inChannel, outChannel, prefixes);
}
+
+ /**
+ * Configure the {@link org.springframework.scheduling.TaskScheduler} to
+ * use for providing heartbeat support. Setting this property also sets the
+ * {@link #setHeartbeatValue heartbeatValue} to "10000, 10000".
+ * By default this is not set.
+ * @since 4.2
+ */
+ public SimpleBrokerRegistration setTaskScheduler(TaskScheduler taskScheduler) {
+ this.taskScheduler = taskScheduler;
+ return this;
+ }
+
+ /**
+ * Configure the value for the heartbeat settings. The first number
+ * represents how often the server will write or send a heartbeat.
+ * The second is how often the client should write. 0 means no heartbeats.
+ *
By default this is set to "0, 0" unless the {@link #setTaskScheduler
+ * taskScheduler} in which case the default becomes "10000,10000"
+ * (in milliseconds).
+ * @since 4.2
+ */
+ public SimpleBrokerRegistration setHeartbeatValue(long[] heartbeat) {
+ this.heartbeat = heartbeat;
+ return this;
+ }
+
+
@Override
protected SimpleBrokerMessageHandler getMessageHandler(SubscribableChannel brokerChannel) {
- return new SimpleBrokerMessageHandler(getClientInboundChannel(),
+ SimpleBrokerMessageHandler handler = new SimpleBrokerMessageHandler(getClientInboundChannel(),
getClientOutboundChannel(), brokerChannel, getDestinationPrefixes());
+ if (this.taskScheduler != null) {
+ handler.setTaskScheduler(this.taskScheduler);
+ }
+ if (this.heartbeat != null) {
+ handler.setHeartbeatValue(this.heartbeat);
+ }
+ return handler;
}
}
diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/DefaultStompSession.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/DefaultStompSession.java
index 2d22bc97170..1fd99869e34 100644
--- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/DefaultStompSession.java
+++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/DefaultStompSession.java
@@ -389,7 +389,7 @@ public class DefaultStompSession implements ConnectionHandlingStompSession {
}
}
else if (StompCommand.CONNECTED.equals(command)) {
- initHeartbeats(stompHeaders);
+ initHeartbeatTasks(stompHeaders);
this.sessionFuture.set(this);
this.sessionHandler.afterConnected(this, stompHeaders);
}
@@ -420,20 +420,18 @@ public class DefaultStompSession implements ConnectionHandlingStompSession {
handler.handleFrame(stompHeaders, object);
}
- private void initHeartbeats(StompHeaders connectedHeaders) {
- long clientRead = this.connectHeaders.getHeartbeat()[0];
- long serverWrite = connectedHeaders.getHeartbeat()[1];
-
- if (clientRead > 0 && serverWrite > 0) {
- long interval = Math.max(clientRead, serverWrite);
+ private void initHeartbeatTasks(StompHeaders connectedHeaders) {
+ long[] connect = this.connectHeaders.getHeartbeat();
+ long[] connected = connectedHeaders.getHeartbeat();
+ if (connect == null || connected == null) {
+ return;
+ }
+ if (connect[0] > 0 && connected[1] > 0) {
+ long interval = Math.max(connect[0], connected[1]);
this.connection.onWriteInactivity(new WriteInactivityTask(), interval);
}
-
- long clientWrite = this.connectHeaders.getHeartbeat()[1];
- long serverRead = connectedHeaders.getHeartbeat()[0];
-
- if (clientWrite > 0 && serverRead > 0) {
- final long interval = Math.max(clientWrite, serverRead) * HEARTBEAT_MULTIPLIER;
+ if (connect[1] > 0 && connected[0] > 0) {
+ final long interval = Math.max(connect[1], connected[0]) * HEARTBEAT_MULTIPLIER;
this.connection.onReadInactivity(new ReadInactivityTask(), interval);
}
}
diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandlerTests.java
index 3a820e5a04a..354ff40bce0 100644
--- a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandlerTests.java
+++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandlerTests.java
@@ -16,12 +16,15 @@
package org.springframework.messaging.simp.broker;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertTrue;
-import static org.mockito.Mockito.times;
-import static org.mockito.Mockito.verify;
+import static org.junit.Assert.*;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.*;
+import java.security.Principal;
import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.ScheduledFuture;
import org.junit.Before;
import org.junit.Test;
@@ -29,13 +32,16 @@ import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
+
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
+import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.TestPrincipal;
import org.springframework.messaging.support.MessageBuilder;
+import org.springframework.scheduling.TaskScheduler;
/**
* Unit tests for SimpleBrokerMessageHandler.
@@ -43,6 +49,7 @@ import org.springframework.messaging.support.MessageBuilder;
* @author Rossen Stoyanchev
* @since 4.0
*/
+@SuppressWarnings("unchecked")
public class SimpleBrokerMessageHandlerTests {
private SimpleBrokerMessageHandler messageHandler;
@@ -56,6 +63,9 @@ public class SimpleBrokerMessageHandlerTests {
@Mock
private SubscribableChannel brokerChannel;
+ @Mock
+ private TaskScheduler taskScheduler;
+
@Captor
ArgumentCaptor> messageCaptor;
@@ -133,11 +143,11 @@ public class SimpleBrokerMessageHandlerTests {
@Test
public void connect() {
- String sess1 = "sess1";
-
this.messageHandler.start();
- Message connectMessage = createConnectMessage(sess1);
+ String id = "sess1";
+ Message connectMessage = createConnectMessage(id, new TestPrincipal("joe"), null);
+ this.messageHandler.setTaskScheduler(this.taskScheduler);
this.messageHandler.handleMessage(connectMessage);
verify(this.clientOutboundChannel, times(1)).send(this.messageCaptor.capture());
@@ -145,10 +155,150 @@ public class SimpleBrokerMessageHandlerTests {
SimpMessageHeaderAccessor connectAckHeaders = SimpMessageHeaderAccessor.wrap(connectAckMessage);
assertEquals(connectMessage, connectAckHeaders.getHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER));
- assertEquals(sess1, connectAckHeaders.getSessionId());
+ assertEquals(id, connectAckHeaders.getSessionId());
assertEquals("joe", connectAckHeaders.getUser().getName());
+ assertArrayEquals(new long[] {10000, 10000},
+ SimpMessageHeaderAccessor.getHeartbeat(connectAckHeaders.getMessageHeaders()));
}
+ @Test
+ public void heartbeatValueWithAndWithoutTaskScheduler() throws Exception {
+
+ assertNull(this.messageHandler.getHeartbeatValue());
+
+ this.messageHandler.setTaskScheduler(this.taskScheduler);
+
+ assertNotNull(this.messageHandler.getHeartbeatValue());
+ assertArrayEquals(new long[] {10000, 10000}, this.messageHandler.getHeartbeatValue());
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void startWithHeartbeatValueWithoutTaskScheduler() throws Exception {
+ this.messageHandler.setHeartbeatValue(new long[] {10000, 10000});
+ this.messageHandler.start();
+ }
+
+ @SuppressWarnings("unchecked")
+ @Test
+ public void startAndStopWithHeartbeatValue() throws Exception {
+
+ ScheduledFuture future = mock(ScheduledFuture.class);
+ when(this.taskScheduler.scheduleWithFixedDelay(any(Runnable.class), eq(15000L))).thenReturn(future);
+
+ this.messageHandler.setTaskScheduler(this.taskScheduler);
+ this.messageHandler.setHeartbeatValue(new long[] {15000, 16000});
+ this.messageHandler.start();
+
+ verify(this.taskScheduler).scheduleWithFixedDelay(any(Runnable.class), eq(15000L));
+ verifyNoMoreInteractions(this.taskScheduler, future);
+
+ this.messageHandler.stop();
+
+ verify(future).cancel(true);
+ verifyNoMoreInteractions(future);
+ }
+
+ @SuppressWarnings("unchecked")
+ @Test
+ public void startWithOneZeroHeartbeatValue() throws Exception {
+
+ this.messageHandler.setTaskScheduler(this.taskScheduler);
+ this.messageHandler.setHeartbeatValue(new long[] {0, 10000});
+ this.messageHandler.start();
+
+ verify(this.taskScheduler).scheduleWithFixedDelay(any(Runnable.class), eq(10000L));
+ }
+
+ @Test
+ public void readInactivity() throws Exception {
+
+ this.messageHandler.setHeartbeatValue(new long[] {0, 1});
+ this.messageHandler.setTaskScheduler(this.taskScheduler);
+ this.messageHandler.start();
+
+ ArgumentCaptor taskCaptor = ArgumentCaptor.forClass(Runnable.class);
+ verify(this.taskScheduler).scheduleWithFixedDelay(taskCaptor.capture(), eq(1L));
+ Runnable heartbeatTask = taskCaptor.getValue();
+ assertNotNull(heartbeatTask);
+
+ String id = "sess1";
+ TestPrincipal user = new TestPrincipal("joe");
+ Message connectMessage = createConnectMessage(id, user, new long[] {1, 0});
+ this.messageHandler.handleMessage(connectMessage);
+
+ Thread.sleep(10);
+ heartbeatTask.run();
+
+ verify(this.clientOutboundChannel, atLeast(2)).send(this.messageCaptor.capture());
+ List> messages = this.messageCaptor.getAllValues();
+ assertEquals(2, messages.size());
+
+ MessageHeaders headers = messages.get(0).getHeaders();
+ assertEquals(SimpMessageType.CONNECT_ACK, headers.get(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER));
+ headers = messages.get(1).getHeaders();
+ assertEquals(SimpMessageType.DISCONNECT_ACK, headers.get(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER));
+ assertEquals(id, headers.get(SimpMessageHeaderAccessor.SESSION_ID_HEADER));
+ assertEquals(user, headers.get(SimpMessageHeaderAccessor.USER_HEADER));
+ }
+
+ @Test
+ public void writeInactivity() throws Exception {
+
+ this.messageHandler.setHeartbeatValue(new long[] {1, 0});
+ this.messageHandler.setTaskScheduler(this.taskScheduler);
+ this.messageHandler.start();
+
+ ArgumentCaptor taskCaptor = ArgumentCaptor.forClass(Runnable.class);
+ verify(this.taskScheduler).scheduleWithFixedDelay(taskCaptor.capture(), eq(1L));
+ Runnable heartbeatTask = taskCaptor.getValue();
+ assertNotNull(heartbeatTask);
+
+ String id = "sess1";
+ TestPrincipal user = new TestPrincipal("joe");
+ Message connectMessage = createConnectMessage(id, user, new long[] {0, 1});
+ this.messageHandler.handleMessage(connectMessage);
+
+ Thread.sleep(10);
+ heartbeatTask.run();
+
+ verify(this.clientOutboundChannel, times(2)).send(this.messageCaptor.capture());
+ List> messages = this.messageCaptor.getAllValues();
+ assertEquals(2, messages.size());
+
+ MessageHeaders headers = messages.get(0).getHeaders();
+ assertEquals(SimpMessageType.CONNECT_ACK, headers.get(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER));
+ headers = messages.get(1).getHeaders();
+ assertEquals(SimpMessageType.HEARTBEAT, headers.get(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER));
+ assertEquals(id, headers.get(SimpMessageHeaderAccessor.SESSION_ID_HEADER));
+ assertEquals(user, headers.get(SimpMessageHeaderAccessor.USER_HEADER));
+ }
+
+ @Test
+ public void readWriteIntervalCalculation() throws Exception {
+
+ this.messageHandler.setHeartbeatValue(new long[] {1, 1});
+ this.messageHandler.setTaskScheduler(this.taskScheduler);
+ this.messageHandler.start();
+
+ ArgumentCaptor taskCaptor = ArgumentCaptor.forClass(Runnable.class);
+ verify(this.taskScheduler).scheduleWithFixedDelay(taskCaptor.capture(), eq(1L));
+ Runnable heartbeatTask = taskCaptor.getValue();
+ assertNotNull(heartbeatTask);
+
+ String id = "sess1";
+ TestPrincipal user = new TestPrincipal("joe");
+ Message connectMessage = createConnectMessage(id, user, new long[] {10000, 10000});
+ this.messageHandler.handleMessage(connectMessage);
+
+ Thread.sleep(10);
+ heartbeatTask.run();
+
+ verify(this.clientOutboundChannel, times(1)).send(this.messageCaptor.capture());
+ List> messages = this.messageCaptor.getAllValues();
+ assertEquals(1, messages.size());
+ assertEquals(SimpMessageType.CONNECT_ACK,
+ messages.get(0).getHeaders().get(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER));
+ }
private Message createSubscriptionMessage(String sessionId, String subcriptionId, String destination) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.SUBSCRIBE);
@@ -158,17 +308,18 @@ public class SimpleBrokerMessageHandlerTests {
return MessageBuilder.createMessage("", headers.getMessageHeaders());
}
- private Message createConnectMessage(String sessionId) {
- SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT);
- headers.setSessionId(sessionId);
- headers.setUser(new TestPrincipal("joe"));
- return MessageBuilder.createMessage("", headers.getMessageHeaders());
+ private Message createConnectMessage(String sessionId, Principal user, long[] heartbeat) {
+ SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT);
+ accessor.setSessionId(sessionId);
+ accessor.setUser(user);
+ accessor.setHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER, heartbeat);
+ return MessageBuilder.createMessage("", accessor.getMessageHeaders());
}
private Message createMessage(String destination, String payload) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE);
headers.setDestination(destination);
- return MessageBuilder.createMessage("", headers.getMessageHeaders());
+ return MessageBuilder.createMessage(payload, headers.getMessageHeaders());
}
private boolean messageCaptured(String sessionId, String subcriptionId, String destination) {
diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java
index db7d530083f..950de748e0e 100644
--- a/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java
+++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java
@@ -320,6 +320,14 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser {
String pathMatcherRef = messageBrokerElement.getAttribute("path-matcher");
brokerDef.getPropertyValues().add("pathMatcher", new RuntimeBeanReference(pathMatcherRef));
}
+ if (simpleBrokerElem.hasAttribute("scheduler")) {
+ String scheduler = simpleBrokerElem.getAttribute("scheduler");
+ brokerDef.getPropertyValues().add("taskScheduler", new RuntimeBeanReference(scheduler));
+ }
+ if (simpleBrokerElem.hasAttribute("heartbeat")) {
+ String heartbeatValue = simpleBrokerElem.getAttribute("heartbeat");
+ brokerDef.getPropertyValues().add("heartbeatValue", heartbeatValue);
+ }
}
else if (brokerRelayElem != null) {
String prefix = brokerRelayElem.getAttribute("prefix");
diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java
index 4f38ec66c12..a39d91bee25 100644
--- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java
+++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java
@@ -34,6 +34,7 @@ import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.ApplicationEventPublisherAware;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
+import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.simp.SimpAttributes;
import org.springframework.messaging.simp.SimpAttributesContextHolder;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
@@ -233,17 +234,18 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
StompHeaderAccessor headerAccessor =
MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
- if (logger.isTraceEnabled()) {
- logger.trace("From client: " + headerAccessor.getShortLogMessage(message.getPayload()));
- }
-
headerAccessor.setSessionId(session.getId());
headerAccessor.setSessionAttributes(session.getAttributes());
headerAccessor.setUser(session.getPrincipal());
+ headerAccessor.setHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER, headerAccessor.getHeartbeat());
if (!detectImmutableMessageInterceptor(outputChannel)) {
headerAccessor.setImmutable();
}
+ if (logger.isTraceEnabled()) {
+ logger.trace("From client: " + headerAccessor.getShortLogMessage(message.getPayload()));
+ }
+
if (StompCommand.CONNECT.equals(headerAccessor.getCommand())) {
this.stats.incrementConnectCount();
}
@@ -401,13 +403,17 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
}
else if (accessor instanceof SimpMessageHeaderAccessor) {
stompAccessor = StompHeaderAccessor.wrap(message);
- if (SimpMessageType.CONNECT_ACK.equals(stompAccessor.getMessageType())) {
+ SimpMessageType messageType = SimpMessageHeaderAccessor.getMessageType(message.getHeaders());
+ if (SimpMessageType.CONNECT_ACK.equals(messageType)) {
stompAccessor = convertConnectAcktoStompConnected(stompAccessor);
}
- else if (SimpMessageType.DISCONNECT_ACK.equals(stompAccessor.getMessageType())) {
+ else if (SimpMessageType.DISCONNECT_ACK.equals(messageType)) {
stompAccessor = StompHeaderAccessor.create(StompCommand.ERROR);
stompAccessor.setMessage("Session closed.");
}
+ else if (SimpMessageType.HEARTBEAT.equals(messageType)) {
+ stompAccessor = StompHeaderAccessor.createForHeartbeat();
+ }
else if (stompAccessor.getCommand() == null || StompCommand.SEND.equals(stompAccessor.getCommand())) {
stompAccessor.updateStompCommandAsServerMessage();
}
@@ -429,23 +435,21 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
Message> message = (Message>) connectAckHeaders.getHeader(name);
Assert.notNull(message, "Original STOMP CONNECT not found in " + connectAckHeaders);
StompHeaderAccessor connectHeaders = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
- String version;
+ StompHeaderAccessor connectedHeaders = StompHeaderAccessor.create(StompCommand.CONNECTED);
Set acceptVersions = connectHeaders.getAcceptVersion();
if (acceptVersions.contains("1.2")) {
- version = "1.2";
+ connectedHeaders.setVersion("1.2");
}
else if (acceptVersions.contains("1.1")) {
- version = "1.1";
+ connectedHeaders.setVersion("1.1");
}
- else if (acceptVersions.isEmpty()) {
- version = null;
- }
- else {
+ else if (!acceptVersions.isEmpty()) {
throw new IllegalArgumentException("Unsupported STOMP version '" + acceptVersions + "'");
}
- StompHeaderAccessor connectedHeaders = StompHeaderAccessor.create(StompCommand.CONNECTED);
- connectedHeaders.setVersion(version);
- connectedHeaders.setHeartbeat(0, 0); // not supported
+ long[] heartbeat = (long[]) connectAckHeaders.getHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER);
+ if (heartbeat != null) {
+ connectedHeaders.setHeartbeat(heartbeat[0], heartbeat[1]);
+ }
return connectedHeaders;
}
diff --git a/spring-websocket/src/main/resources/META-INF/spring.schemas b/spring-websocket/src/main/resources/META-INF/spring.schemas
index 5bcbadcb44b..4cd992cdf66 100644
--- a/spring-websocket/src/main/resources/META-INF/spring.schemas
+++ b/spring-websocket/src/main/resources/META-INF/spring.schemas
@@ -1,3 +1,3 @@
http\://www.springframework.org/schema/websocket/spring-websocket-4.0.xsd=org/springframework/web/socket/config/spring-websocket-4.0.xsd
http\://www.springframework.org/schema/websocket/spring-websocket-4.1.xsd=org/springframework/web/socket/config/spring-websocket-4.1.xsd
-http\://www.springframework.org/schema/websocket/spring-websocket.xsd=org/springframework/web/socket/config/spring-websocket-4.1.xsd
+http\://www.springframework.org/schema/websocket/spring-websocket.xsd=org/springframework/web/socket/config/spring-websocket-4.2.xsd
diff --git a/spring-websocket/src/main/resources/org/springframework/web/socket/config/spring-websocket-4.2.xsd b/spring-websocket/src/main/resources/org/springframework/web/socket/config/spring-websocket-4.2.xsd
new file mode 100644
index 00000000000..8c1945a9688
--- /dev/null
+++ b/spring-websocket/src/main/resources/org/springframework/web/socket/config/spring-websocket-4.2.xsd
@@ -0,0 +1,896 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java
index 55bde4374b3..20a574aa1b5 100644
--- a/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java
+++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2002-2014 the original author or authors.
+ * Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@ package org.springframework.web.socket.config;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collection;
import java.util.List;
import org.hamcrest.Matchers;
@@ -180,8 +181,10 @@ public class MessageBrokerBeanDefinitionParserTests {
SimpleBrokerMessageHandler brokerMessageHandler = this.appContext.getBean(SimpleBrokerMessageHandler.class);
assertNotNull(brokerMessageHandler);
- assertEquals(Arrays.asList("/topic", "/queue"),
- new ArrayList(brokerMessageHandler.getDestinationPrefixes()));
+ Collection prefixes = brokerMessageHandler.getDestinationPrefixes();
+ assertEquals(Arrays.asList("/topic", "/queue"), new ArrayList(prefixes));
+ assertNotNull(brokerMessageHandler.getTaskScheduler());
+ assertArrayEquals(new long[] {15000, 15000}, brokerMessageHandler.getHeartbeatValue());
List> subscriberTypes =
Arrays.>asList(SimpAnnotationMethodMessageHandler.class,
diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java
index e4bdd1db52d..ac219f4fb31 100644
--- a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java
+++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java
@@ -17,6 +17,7 @@
package org.springframework.web.socket.config.annotation;
import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
import java.util.ArrayList;
import java.util.List;
@@ -37,6 +38,7 @@ import org.springframework.messaging.handler.annotation.SendTo;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.annotation.SubscribeMapping;
import org.springframework.messaging.simp.broker.SimpleBrokerMessageHandler;
+import org.springframework.messaging.simp.config.MessageBrokerRegistry;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.simp.user.UserDestinationMessageHandler;
@@ -44,6 +46,7 @@ import org.springframework.messaging.support.AbstractSubscribableChannel;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.ExecutorSubscribableChannel;
import org.springframework.messaging.support.ImmutableMessageChannelInterceptor;
+import org.springframework.scheduling.TaskScheduler;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.stereotype.Controller;
import org.springframework.web.servlet.HandlerMapping;
@@ -149,14 +152,18 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
}
@Test
- public void messageBrokerSockJsTaskScheduler() {
+ public void taskScheduler() {
ApplicationContext config = createConfig(TestChannelConfig.class, TestConfigurer.class);
- ThreadPoolTaskScheduler taskScheduler =
- config.getBean("messageBrokerSockJsTaskScheduler", ThreadPoolTaskScheduler.class);
+ String name = "messageBrokerSockJsTaskScheduler";
+ ThreadPoolTaskScheduler taskScheduler = config.getBean(name, ThreadPoolTaskScheduler.class);
ScheduledThreadPoolExecutor executor = taskScheduler.getScheduledThreadPoolExecutor();
assertEquals(Runtime.getRuntime().availableProcessors(), executor.getCorePoolSize());
assertTrue(executor.getRemoveOnCancelPolicy());
+
+ SimpleBrokerMessageHandler handler = config.getBean(SimpleBrokerMessageHandler.class);
+ assertNotNull(handler.getTaskScheduler());
+ assertArrayEquals(new long[] {15000, 15000}, handler.getHeartbeatValue());
}
@Test
@@ -200,6 +207,7 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
}
+ @SuppressWarnings("unused")
@Controller
static class TestController {
@@ -215,6 +223,7 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
}
}
+ @SuppressWarnings("unused")
@Configuration
static class TestConfigurer extends AbstractWebSocketMessageBrokerConfigurer {
@@ -234,6 +243,13 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
registration.setSendTimeLimit(25 * 1000);
registration.setSendBufferSizeLimit(1024 * 1024);
}
+
+ @Override
+ public void configureMessageBroker(MessageBrokerRegistry registry) {
+ registry.enableSimpleBroker()
+ .setTaskScheduler(mock(TaskScheduler.class))
+ .setHeartbeatValue(new long[] {15000, 15000});
+ }
}
@Configuration
diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java
index c8ab1f0b3bf..5e9dd10d8d0 100644
--- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java
+++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java
@@ -16,21 +16,12 @@
package org.springframework.web.socket.messaging;
-import static org.hamcrest.Matchers.is;
-import static org.junit.Assert.assertArrayEquals;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertThat;
-import static org.junit.Assert.assertTrue;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.reset;
-import static org.mockito.Mockito.times;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.verifyNoMoreInteractions;
-import static org.mockito.Mockito.verifyZeroInteractions;
+import static org.hamcrest.Matchers.*;
+import static org.junit.Assert.*;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.*;
-import java.nio.ByteBuffer;
+import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
@@ -42,6 +33,7 @@ import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
+
import org.springframework.context.ApplicationEvent;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.PayloadApplicationEvent;
@@ -53,7 +45,6 @@ import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.TestPrincipal;
import org.springframework.messaging.simp.stomp.StompCommand;
-import org.springframework.messaging.simp.stomp.StompDecoder;
import org.springframework.messaging.simp.stomp.StompEncoder;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.simp.user.DefaultUserSessionRegistry;
@@ -103,7 +94,7 @@ public class StompSubProtocolHandlerTests {
}
@Test
- public void handleMessageToClientConnected() {
+ public void handleMessageToClientWithConnectedFrame() {
UserSessionRegistry registry = new DefaultUserSessionRegistry();
this.protocolHandler.setUserSessionRegistry(registry);
@@ -120,7 +111,7 @@ public class StompSubProtocolHandlerTests {
}
@Test
- public void handleMessageToClientConnectedUniqueUserName() {
+ public void handleMessageToClientWithDestinationUserNameProvider() {
this.session.setPrincipal(new UniqueUser("joe"));
@@ -140,47 +131,197 @@ public class StompSubProtocolHandlerTests {
}
@Test
- public void handleMessageToClientConnectedWithHeartbeats() {
+ public void handleMessageToClientWithSimpConnectAck() {
- SockJsSession sockJsSession = Mockito.mock(SockJsSession.class);
+ StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECT);
+ accessor.setHeartbeat(10000, 10000);
+ accessor.setAcceptVersion("1.0,1.1");
+ Message> connectMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders());
- StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECTED);
- headers.setHeartbeat(0,10);
- Message message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
- this.protocolHandler.handleMessageToClient(sockJsSession, message);
+ SimpMessageHeaderAccessor ackAccessor = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK);
+ ackAccessor.setHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER, connectMessage);
+ ackAccessor.setHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER, new long[] {15000, 15000});
+ Message ackMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, ackAccessor.getMessageHeaders());
+ this.protocolHandler.handleMessageToClient(this.session, ackMessage);
- verify(sockJsSession).disableHeartbeat();
+ assertEquals(1, this.session.getSentMessages().size());
+ TextMessage actual = (TextMessage) this.session.getSentMessages().get(0);
+ assertEquals("CONNECTED\n" + "version:1.1\n" + "heart-beat:15000,15000\n" +
+ "user-name:joe\n" + "\n" + "\u0000", actual.getPayload());
}
@Test
- public void handleMessageToClientConnectAck() {
+ public void handleMessageToClientWithSimpHeartbeat() {
- StompHeaderAccessor connectHeaders = StompHeaderAccessor.create(StompCommand.CONNECT);
- connectHeaders.setHeartbeat(10000, 10000);
- connectHeaders.setAcceptVersion("1.0,1.1");
- Message> connectMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, connectHeaders.getMessageHeaders());
-
- SimpMessageHeaderAccessor connectAckHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK);
- connectAckHeaders.setHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER, connectMessage);
- Message connectAckMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, connectAckHeaders.getMessageHeaders());
-
- this.protocolHandler.handleMessageToClient(this.session, connectAckMessage);
-
- verifyNoMoreInteractions(this.channel);
-
- // Check CONNECTED reply
+ SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.HEARTBEAT);
+ accessor.setSessionId("s1");
+ accessor.setUser(new TestPrincipal("joe"));
+ Message ackMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders());
+ this.protocolHandler.handleMessageToClient(this.session, ackMessage);
assertEquals(1, this.session.getSentMessages().size());
- TextMessage textMessage = (TextMessage) this.session.getSentMessages().get(0);
+ TextMessage actual = (TextMessage) this.session.getSentMessages().get(0);
+ assertEquals("\n", actual.getPayload());
+ }
- List> messages = new StompDecoder().decode(ByteBuffer.wrap(textMessage.getPayload().getBytes()));
- assertEquals(1, messages.size());
- StompHeaderAccessor replyHeaders = StompHeaderAccessor.wrap(messages.get(0));
+ @Test
+ public void handleMessageToClientWithHeartbeatSuppressingSockJsHeartbeat() throws IOException {
- assertEquals(StompCommand.CONNECTED, replyHeaders.getCommand());
- assertEquals("1.1", replyHeaders.getVersion());
- assertArrayEquals(new long[] {0, 0}, replyHeaders.getHeartbeat());
- assertEquals("joe", replyHeaders.getNativeHeader("user-name").get(0));
+ SockJsSession sockJsSession = Mockito.mock(SockJsSession.class);
+ StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECTED);
+ accessor.setHeartbeat(0, 10);
+ Message message = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders());
+ this.protocolHandler.handleMessageToClient(sockJsSession, message);
+
+ verify(sockJsSession).getPrincipal();
+ verify(sockJsSession).disableHeartbeat();
+ verify(sockJsSession).sendMessage(any(WebSocketMessage.class));
+ verifyNoMoreInteractions(sockJsSession);
+
+ sockJsSession = Mockito.mock(SockJsSession.class);
+ accessor = StompHeaderAccessor.create(StompCommand.CONNECTED);
+ accessor.setHeartbeat(0, 0);
+ message = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders());
+ this.protocolHandler.handleMessageToClient(sockJsSession, message);
+
+ verify(sockJsSession).getPrincipal();
+ verify(sockJsSession).sendMessage(any(WebSocketMessage.class));
+ verifyNoMoreInteractions(sockJsSession);
+ }
+
+ @Test
+ public void handleMessageToClientWithUserDestination() {
+
+ StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.MESSAGE);
+ headers.setMessageId("mess0");
+ headers.setSubscriptionId("sub0");
+ headers.setDestination("/queue/foo-user123");
+ headers.setNativeHeader(StompHeaderAccessor.ORIGINAL_DESTINATION, "/user/queue/foo");
+ Message message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
+ this.protocolHandler.handleMessageToClient(this.session, message);
+
+ assertEquals(1, this.session.getSentMessages().size());
+ WebSocketMessage> textMessage = this.session.getSentMessages().get(0);
+ assertTrue(((String) textMessage.getPayload()).contains("destination:/user/queue/foo\n"));
+ assertFalse(((String) textMessage.getPayload()).contains(SimpMessageHeaderAccessor.ORIGINAL_DESTINATION));
+ }
+
+ // SPR-12475
+
+ @Test
+ public void handleMessageToClientWithBinaryWebSocketMessage() {
+
+ StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.MESSAGE);
+ headers.setMessageId("mess0");
+ headers.setSubscriptionId("sub0");
+ headers.setContentType(MimeTypeUtils.APPLICATION_OCTET_STREAM);
+ headers.setDestination("/queue/foo");
+
+ // Non-empty payload
+
+ byte[] payload = new byte[1];
+ Message message = MessageBuilder.createMessage(payload, headers.getMessageHeaders());
+ this.protocolHandler.handleMessageToClient(this.session, message);
+
+ assertEquals(1, this.session.getSentMessages().size());
+ WebSocketMessage> webSocketMessage = this.session.getSentMessages().get(0);
+ assertTrue(webSocketMessage instanceof BinaryMessage);
+
+ // Empty payload
+
+ payload = EMPTY_PAYLOAD;
+ message = MessageBuilder.createMessage(payload, headers.getMessageHeaders());
+ this.protocolHandler.handleMessageToClient(this.session, message);
+
+ assertEquals(2, this.session.getSentMessages().size());
+ webSocketMessage = this.session.getSentMessages().get(1);
+ assertTrue(webSocketMessage instanceof TextMessage);
+ }
+
+ @Test
+ public void handleMessageFromClient() {
+
+ TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.CONNECT).headers(
+ "login:guest", "passcode:guest", "accept-version:1.1,1.0", "heart-beat:10000,10000").build();
+
+ this.protocolHandler.afterSessionStarted(this.session, this.channel);
+ this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel);
+
+ verify(this.channel).send(this.messageCaptor.capture());
+ Message> actual = this.messageCaptor.getValue();
+ assertNotNull(actual);
+
+ assertEquals("s1", SimpMessageHeaderAccessor.getSessionId(actual.getHeaders()));
+ assertNotNull(SimpMessageHeaderAccessor.getSessionAttributes(actual.getHeaders()));
+ assertNotNull(SimpMessageHeaderAccessor.getUser(actual.getHeaders()));
+ assertEquals("joe", SimpMessageHeaderAccessor.getUser(actual.getHeaders()).getName());
+ assertNotNull(SimpMessageHeaderAccessor.getHeartbeat(actual.getHeaders()));
+ assertArrayEquals(new long[] {10000, 10000}, SimpMessageHeaderAccessor.getHeartbeat(actual.getHeaders()));
+
+ StompHeaderAccessor stompAccessor = StompHeaderAccessor.wrap(actual);
+ assertEquals(StompCommand.CONNECT, stompAccessor.getCommand());
+ assertEquals("guest", stompAccessor.getLogin());
+ assertEquals("guest", stompAccessor.getPasscode());
+ assertArrayEquals(new long[] {10000, 10000}, stompAccessor.getHeartbeat());
+ assertEquals(new HashSet<>(Arrays.asList("1.1","1.0")), stompAccessor.getAcceptVersion());
+ assertEquals(0, this.session.getSentMessages().size());
+ }
+
+ @Test
+ public void handleMessageFromClientWithImmutableMessageInterceptor() {
+ AtomicReference mutable = new AtomicReference<>();
+ ExecutorSubscribableChannel channel = new ExecutorSubscribableChannel();
+ channel.addInterceptor(new ChannelInterceptorAdapter() {
+ @Override
+ public Message> preSend(Message> message, MessageChannel channel) {
+ mutable.set(MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class).isMutable());
+ return message;
+ }
+ });
+ channel.addInterceptor(new ImmutableMessageChannelInterceptor());
+
+ StompSubProtocolHandler handler = new StompSubProtocolHandler();
+ handler.afterSessionStarted(this.session, channel);
+
+ TextMessage message = StompTextMessageBuilder.create(StompCommand.CONNECT).build();
+ handler.handleMessageFromClient(this.session, message, channel);
+ assertNotNull(mutable.get());
+ assertTrue(mutable.get());
+ }
+
+ @Test
+ public void handleMessageFromClientWithoutImmutableMessageInterceptor() {
+ AtomicReference mutable = new AtomicReference<>();
+ ExecutorSubscribableChannel channel = new ExecutorSubscribableChannel();
+ channel.addInterceptor(new ChannelInterceptorAdapter() {
+ @Override
+ public Message> preSend(Message> message, MessageChannel channel) {
+ mutable.set(MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class).isMutable());
+ return message;
+ }
+ });
+
+ StompSubProtocolHandler handler = new StompSubProtocolHandler();
+ handler.afterSessionStarted(this.session, channel);
+
+ TextMessage message = StompTextMessageBuilder.create(StompCommand.CONNECT).build();
+ handler.handleMessageFromClient(this.session, message, channel);
+ assertNotNull(mutable.get());
+ assertFalse(mutable.get());
+ }
+
+ @Test
+ public void handleMessageFromClientWithInvalidStompCommand() {
+
+ TextMessage textMessage = new TextMessage("FOO\n\n\0");
+
+ this.protocolHandler.afterSessionStarted(this.session, this.channel);
+ this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel);
+
+ verifyZeroInteractions(this.channel);
+ assertEquals(1, this.session.getSentMessages().size());
+ TextMessage actual = (TextMessage) this.session.getSentMessages().get(0);
+ assertTrue(actual.getPayload().startsWith("ERROR"));
}
@Test
@@ -262,137 +403,6 @@ public class StompSubProtocolHandlerTests {
assertEquals("joe", accessor.getUser().getName());
}
- @Test
- public void handleMessageToClientUserDestination() {
-
- StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.MESSAGE);
- headers.setMessageId("mess0");
- headers.setSubscriptionId("sub0");
- headers.setDestination("/queue/foo-user123");
- headers.setNativeHeader(StompHeaderAccessor.ORIGINAL_DESTINATION, "/user/queue/foo");
- Message message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
- this.protocolHandler.handleMessageToClient(this.session, message);
-
- assertEquals(1, this.session.getSentMessages().size());
- WebSocketMessage> textMessage = this.session.getSentMessages().get(0);
- assertTrue(((String) textMessage.getPayload()).contains("destination:/user/queue/foo\n"));
- assertFalse(((String) textMessage.getPayload()).contains(SimpMessageHeaderAccessor.ORIGINAL_DESTINATION));
- }
-
- // SPR-12475
-
- @Test
- public void handleMessageToClientBinaryWebSocketMessage() {
-
- StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.MESSAGE);
- headers.setMessageId("mess0");
- headers.setSubscriptionId("sub0");
- headers.setContentType(MimeTypeUtils.APPLICATION_OCTET_STREAM);
- headers.setDestination("/queue/foo");
-
- // Non-empty payload
-
- byte[] payload = new byte[1];
- Message message = MessageBuilder.createMessage(payload, headers.getMessageHeaders());
- this.protocolHandler.handleMessageToClient(this.session, message);
-
- assertEquals(1, this.session.getSentMessages().size());
- WebSocketMessage> webSocketMessage = this.session.getSentMessages().get(0);
- assertTrue(webSocketMessage instanceof BinaryMessage);
-
- // Empty payload
-
- payload = EMPTY_PAYLOAD;
- message = MessageBuilder.createMessage(payload, headers.getMessageHeaders());
- this.protocolHandler.handleMessageToClient(this.session, message);
-
- assertEquals(2, this.session.getSentMessages().size());
- webSocketMessage = this.session.getSentMessages().get(1);
- assertTrue(webSocketMessage instanceof TextMessage);
- }
-
- @Test
- public void handleMessageFromClient() {
-
- TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.CONNECT).headers(
- "login:guest", "passcode:guest", "accept-version:1.1,1.0", "heart-beat:10000,10000").build();
-
- this.protocolHandler.afterSessionStarted(this.session, this.channel);
- this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel);
-
- verify(this.channel).send(this.messageCaptor.capture());
- Message> actual = this.messageCaptor.getValue();
- assertNotNull(actual);
-
- StompHeaderAccessor headers = StompHeaderAccessor.wrap(actual);
- assertEquals(StompCommand.CONNECT, headers.getCommand());
- assertEquals("s1", headers.getSessionId());
- assertNotNull(headers.getSessionAttributes());
- assertEquals("joe", headers.getUser().getName());
- assertEquals("guest", headers.getLogin());
- assertEquals("guest", headers.getPasscode());
- assertArrayEquals(new long[] {10000, 10000}, headers.getHeartbeat());
- assertEquals(new HashSet<>(Arrays.asList("1.1","1.0")), headers.getAcceptVersion());
-
- assertEquals(0, this.session.getSentMessages().size());
- }
-
- @Test
- public void handleMessageFromClientWithImmutableMessageInterceptor() {
- AtomicReference mutable = new AtomicReference<>();
- ExecutorSubscribableChannel channel = new ExecutorSubscribableChannel();
- channel.addInterceptor(new ChannelInterceptorAdapter() {
- @Override
- public Message> preSend(Message> message, MessageChannel channel) {
- mutable.set(MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class).isMutable());
- return message;
- }
- });
- channel.addInterceptor(new ImmutableMessageChannelInterceptor());
-
- StompSubProtocolHandler handler = new StompSubProtocolHandler();
- handler.afterSessionStarted(this.session, channel);
-
- TextMessage message = StompTextMessageBuilder.create(StompCommand.CONNECT).build();
- handler.handleMessageFromClient(this.session, message, channel);
- assertNotNull(mutable.get());
- assertTrue(mutable.get());
- }
-
- @Test
- public void handleMessageFromClientWithoutImmutableMessageInterceptor() {
- AtomicReference mutable = new AtomicReference<>();
- ExecutorSubscribableChannel channel = new ExecutorSubscribableChannel();
- channel.addInterceptor(new ChannelInterceptorAdapter() {
- @Override
- public Message> preSend(Message> message, MessageChannel channel) {
- mutable.set(MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class).isMutable());
- return message;
- }
- });
-
- StompSubProtocolHandler handler = new StompSubProtocolHandler();
- handler.afterSessionStarted(this.session, channel);
-
- TextMessage message = StompTextMessageBuilder.create(StompCommand.CONNECT).build();
- handler.handleMessageFromClient(this.session, message, channel);
- assertNotNull(mutable.get());
- assertFalse(mutable.get());
- }
- @Test
- public void handleMessageFromClientInvalidStompCommand() {
-
- TextMessage textMessage = new TextMessage("FOO\n\n\0");
-
- this.protocolHandler.afterSessionStarted(this.session, this.channel);
- this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel);
-
- verifyZeroInteractions(this.channel);
- assertEquals(1, this.session.getSentMessages().size());
- TextMessage actual = (TextMessage) this.session.getSentMessages().get(0);
- assertTrue(actual.getPayload().startsWith("ERROR"));
- }
-
@Test
public void webSocketScope() {
@@ -421,10 +431,10 @@ public class StompSubProtocolHandlerTests {
TextMessage textMessage = new TextMessage(new StompEncoder().encode(message));
this.protocolHandler.handleMessageFromClient(this.session, textMessage, testChannel);
- assertEquals(Collections.emptyList(), session.getSentMessages());
+ assertEquals(Collections.>emptyList(), session.getSentMessages());
this.protocolHandler.afterSessionEnded(this.session, CloseStatus.BAD_DATA, testChannel);
- assertEquals(Collections.emptyList(), session.getSentMessages());
+ assertEquals(Collections.>emptyList(), this.session.getSentMessages());
verify(runnable, times(1)).run();
}
diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/WebSocketStompClientIntegrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/WebSocketStompClientIntegrationTests.java
index d589d43e8ca..9cddb8bad09 100644
--- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/WebSocketStompClientIntegrationTests.java
+++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/WebSocketStompClientIntegrationTests.java
@@ -175,6 +175,14 @@ public class WebSocketStompClientIntegrationTests {
received.add((String) payload);
}
});
+ try {
+ // Delay send since server processes concurrently
+ // Ideally order should be preserved or receipts supported (simple broker)
+ Thread.sleep(500);
+ }
+ catch (InterruptedException ex) {
+ logger.error(ex);
+ }
session.send(this.topic, this.payload);
}
diff --git a/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-broker-simple.xml b/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-broker-simple.xml
index 4405fe1864d..ffb1fdd7d6b 100644
--- a/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-broker-simple.xml
+++ b/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-broker-simple.xml
@@ -32,7 +32,7 @@
-
+
@@ -42,5 +42,6 @@
+