diff --git a/build.gradle b/build.gradle index 68821626078..8a52c65c8bd 100644 --- a/build.gradle +++ b/build.gradle @@ -484,6 +484,7 @@ project("spring-websocket") { optional("org.eclipse.jetty.websocket:websocket-client:9.0.3.v20130506") optional("com.fasterxml.jackson.core:jackson-databind:2.2.0") // required for SockJS support currently optional("reactor:reactor-core:1.0.0.BUILD-SNAPSHOT") + optional("reactor:reactor-tcp:1.0.0.BUILD-SNAPSHOT") } repositories { diff --git a/spring-websocket/src/main/java/org/springframework/web/stomp/StompHeaders.java b/spring-websocket/src/main/java/org/springframework/web/stomp/StompHeaders.java index 38f07c5f9e0..206bc2f147d 100644 --- a/spring-websocket/src/main/java/org/springframework/web/stomp/StompHeaders.java +++ b/spring-websocket/src/main/java/org/springframework/web/stomp/StompHeaders.java @@ -39,6 +39,8 @@ public class StompHeaders implements MultiValueMap, Serializable private static final long serialVersionUID = 1L; + // TODO: separate client from server headers so they can't be mixed + // Client private static final String ACCEPT_VERSION = "accept-version"; @@ -56,6 +58,8 @@ public class StompHeaders implements MultiValueMap, Serializable private static final String VERSION = "version"; + private static final String MESSAGE = "message"; + // Client and Server private static final String ACK = "ack"; @@ -163,6 +167,14 @@ public class StompHeaders implements MultiValueMap, Serializable set(SUBSCRIPTION, id); } + public String getMessage() { + return getFirst(MESSAGE); + } + + public void setMessage(String id) { + set(MESSAGE, id); + } + // MultiValueMap methods diff --git a/spring-websocket/src/main/java/org/springframework/web/stomp/StompMessage.java b/spring-websocket/src/main/java/org/springframework/web/stomp/StompMessage.java index c23067744ba..421dde2a173 100644 --- a/spring-websocket/src/main/java/org/springframework/web/stomp/StompMessage.java +++ b/spring-websocket/src/main/java/org/springframework/web/stomp/StompMessage.java @@ -34,6 +34,8 @@ public class StompMessage { private final byte[] payload; + private String sessionId; + public StompMessage(StompCommand command, StompHeaders headers, byte[] payload) { this.command = command; @@ -60,6 +62,14 @@ public class StompMessage { return this.payload; } + public void setSessionId(String sessionId) { + this.sessionId = sessionId; + } + + public String getStompSessionId() { + return this.sessionId; + } + @Override public String toString() { return "StompMessage [headers=" + this.headers + ", payload=" + new String(this.payload) + "]"; diff --git a/spring-websocket/src/main/java/org/springframework/web/stomp/StompSession.java b/spring-websocket/src/main/java/org/springframework/web/stomp/StompSession.java index ee2b635889c..b334a0e60b9 100644 --- a/spring-websocket/src/main/java/org/springframework/web/stomp/StompSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/stomp/StompSession.java @@ -27,8 +27,10 @@ public interface StompSession { String getId(); + /** + * If the message is a STOMP ERROR message, the session will also be closed. + * + */ void sendMessage(StompMessage message) throws IOException; - void close() throws Exception; - } diff --git a/spring-websocket/src/main/java/org/springframework/web/stomp/adapter/StompMessageProcessor.java b/spring-websocket/src/main/java/org/springframework/web/stomp/adapter/StompMessageProcessor.java index 45799a87b58..8d2af8e888b 100644 --- a/spring-websocket/src/main/java/org/springframework/web/stomp/adapter/StompMessageProcessor.java +++ b/spring-websocket/src/main/java/org/springframework/web/stomp/adapter/StompMessageProcessor.java @@ -16,8 +16,6 @@ package org.springframework.web.stomp.adapter; -import java.io.IOException; - import org.springframework.web.stomp.StompMessage; import org.springframework.web.stomp.StompSession; @@ -28,6 +26,8 @@ import org.springframework.web.stomp.StompSession; */ public interface StompMessageProcessor { - void processMessage(StompSession stompSession, StompMessage message) throws IOException; + void processMessage(StompSession stompSession, StompMessage message); + + void processConnectionClosed(StompSession stompSession); } diff --git a/spring-websocket/src/main/java/org/springframework/web/stomp/adapter/StompWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/web/stomp/adapter/StompWebSocketHandler.java index 019d9c03c48..837ae9bf578 100644 --- a/spring-websocket/src/main/java/org/springframework/web/stomp/adapter/StompWebSocketHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/stomp/adapter/StompWebSocketHandler.java @@ -23,6 +23,8 @@ import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.adapter.TextWebSocketHandlerAdapter; +import org.springframework.web.stomp.StompCommand; +import org.springframework.web.stomp.StompHeaders; import org.springframework.web.stomp.StompMessage; import org.springframework.web.stomp.StompSession; import org.springframework.web.stomp.support.StompMessageConverter; @@ -54,18 +56,43 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { } @Override - protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception { + protected void handleTextMessage(WebSocketSession session, TextMessage message) { StompSession stompSession = this.sessions.get(session.getId()); Assert.notNull(stompSession, "No STOMP session for WebSocket session id=" + session.getId()); - StompMessage stompMessage = this.messageConverter.toStompMessage(message.getPayload()); - this.messageProcessor.processMessage(stompSession, stompMessage); + try { + StompMessage stompMessage = this.messageConverter.toStompMessage(message.getPayload()); + stompMessage.setSessionId(stompSession.getId()); + + // TODO: validate size limits + // http://stomp.github.io/stomp-specification-1.2.html#Size_Limits + + this.messageProcessor.processMessage(stompSession, stompMessage); + + // TODO: send RECEIPT message if incoming message has "receipt" header + // http://stomp.github.io/stomp-specification-1.2.html#Header_receipt + + } + catch (Throwable error) { + StompHeaders headers = new StompHeaders(); + headers.setMessage(error.getMessage()); + StompMessage errorMessage = new StompMessage(StompCommand.ERROR, headers); + try { + stompSession.sendMessage(errorMessage); + } + catch (Throwable t) { + // ignore + } + } } @Override public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception { - this.sessions.remove(session.getId()); + StompSession stompSession = this.sessions.remove(session.getId()); + if (stompSession != null) { + this.messageProcessor.processConnectionClosed(stompSession); + } } } diff --git a/spring-websocket/src/main/java/org/springframework/web/stomp/adapter/WebSocketStompSession.java b/spring-websocket/src/main/java/org/springframework/web/stomp/adapter/WebSocketStompSession.java index 3e34752bc60..fa71c7b7722 100644 --- a/spring-websocket/src/main/java/org/springframework/web/stomp/adapter/WebSocketStompSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/stomp/adapter/WebSocketStompSession.java @@ -19,8 +19,10 @@ package org.springframework.web.stomp.adapter; import java.io.IOException; import org.springframework.util.Assert; +import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.stomp.StompCommand; import org.springframework.web.stomp.StompMessage; import org.springframework.web.stomp.StompSession; import org.springframework.web.stomp.support.StompMessageConverter; @@ -53,19 +55,19 @@ public class WebSocketStompSession implements StompSession { @Override public void sendMessage(StompMessage message) throws IOException { + Assert.notNull(this.webSocketSession, "Cannot send message without active session"); - byte[] bytes = this.messageConverter.fromStompMessage(message); - this.webSocketSession.sendMessage(new TextMessage(new String(bytes, StompMessage.CHARSET))); - } - public void sessionClosed() { - this.webSocketSession = null; - } - - @Override - public void close() throws Exception { - this.webSocketSession.close(); - this.webSocketSession = null; + try { + byte[] bytes = this.messageConverter.fromStompMessage(message); + this.webSocketSession.sendMessage(new TextMessage(new String(bytes, StompMessage.CHARSET))); + } + finally { + if (StompCommand.ERROR.equals(message.getCommand())) { + this.webSocketSession.close(CloseStatus.PROTOCOL_ERROR); + this.webSocketSession = null; + } + } } } \ No newline at end of file diff --git a/spring-websocket/src/main/java/org/springframework/web/stomp/server/ReactorServerStompMessageProcessor.java b/spring-websocket/src/main/java/org/springframework/web/stomp/server/ReactorServerStompMessageProcessor.java index 027bf8e1ee0..3e9dd5c8a63 100644 --- a/spring-websocket/src/main/java/org/springframework/web/stomp/server/ReactorServerStompMessageProcessor.java +++ b/spring-websocket/src/main/java/org/springframework/web/stomp/server/ReactorServerStompMessageProcessor.java @@ -24,7 +24,7 @@ import java.util.concurrent.ConcurrentHashMap; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; import org.springframework.web.stomp.StompCommand; import org.springframework.web.stomp.StompException; import org.springframework.web.stomp.StompHeaders; @@ -37,7 +37,6 @@ import reactor.core.Reactor; import reactor.fn.Consumer; import reactor.fn.Event; import reactor.fn.Registration; -import reactor.fn.Tuple; /** * @author Gary Russell @@ -59,35 +58,66 @@ public class ReactorServerStompMessageProcessor implements StompMessageProcessor this.reactor = reactor; } - public void processMessage(StompSession session, StompMessage message) throws IOException { + public void processMessage(StompSession session, StompMessage message) { - StompCommand command = message.getCommand(); - Assert.notNull(command, "STOMP command not found"); - - if (StompCommand.CONNECT.equals(command) || StompCommand.STOMP.equals(command)) { - connect(session, message); + try { + StompCommand command = message.getCommand(); + if (StompCommand.CONNECT.equals(command) || StompCommand.STOMP.equals(command)) { + connect(session, message); + } + else if (StompCommand.SUBSCRIBE.equals(command)) { + subscribe(session, message); + } + else if (StompCommand.UNSUBSCRIBE.equals(command)) { + unsubscribe(session, message); + } + else if (StompCommand.SEND.equals(command)) { + send(session, message); + } + else if (StompCommand.DISCONNECT.equals(command)) { + disconnect(session, message); + } + else if (StompCommand.ACK.equals(command) || StompCommand.NACK.equals(command)) { + // TODO + logger.warn("Ignoring " + command + ". It is not supported yet."); + } + else if (StompCommand.BEGIN.equals(command) || StompCommand.COMMIT.equals(command) || StompCommand.ABORT.equals(command)) { + // TODO + logger.warn("Ignoring " + command + ". It is not supported yet."); + } + else { + sendErrorMessage(session, "Invalid STOMP command " + command); + } } - else if (StompCommand.SUBSCRIBE.equals(command)) { - subscribe(session, message); - } - else if (StompCommand.UNSUBSCRIBE.equals(command)) { - unsubscribe(session, message); - } - else if (StompCommand.SEND.equals(command)) { - send(session, message); - } - else if (StompCommand.DISCONNECT.equals(command)) { - disconnect(session); - } - else { - throw new IllegalStateException("Unexpected command: " + command); + catch (Throwable t) { + handleError(session, t); } } - protected void connect(StompSession session, StompMessage connectMessage) throws IOException { + private void handleError(final StompSession session, Throwable t) { + logger.error("Terminating STOMP session due to failure to send message: ", t); + sendErrorMessage(session, t.getMessage()); + if (removeSubscriptions(session.getId())) { + // TODO: send error event and including exception info + } + } + + private void sendErrorMessage(StompSession session, String errorText) { + StompHeaders headers = new StompHeaders(); + headers.setMessage(errorText); + StompMessage errorMessage = new StompMessage(StompCommand.ERROR, headers); + try { + session.sendMessage(errorMessage); + } + catch (Throwable t) { + // ignore + } + } + + protected void connect(StompSession session, StompMessage stompMessage) throws IOException { StompHeaders headers = new StompHeaders(); - Set acceptVersions = connectMessage.getHeaders().getAcceptVersion(); + Set acceptVersions = stompMessage.getHeaders().getAcceptVersion(); if (acceptVersions.contains("1.2")) { headers.setVersion("1.2"); } @@ -105,16 +135,19 @@ public class ReactorServerStompMessageProcessor implements StompMessageProcessor // TODO: security - this.reactor.notify(StompCommand.CONNECT, Fn.event(session.getId())); - session.sendMessage(new StompMessage(StompCommand.CONNECTED, headers)); + + this.reactor.notify(StompCommand.CONNECT, Fn.event(stompMessage)); } - protected void subscribe(final StompSession session, StompMessage message) { + protected void subscribe(final StompSession session, StompMessage stompMessage) { - final String subscription = message.getHeaders().getId(); + final String subscription = stompMessage.getHeaders().getId(); String replyToKey = StompCommand.SUBSCRIBE + ":" + session.getId() + ":" + subscription; + // TODO: extract and remember "ack" mode + // http://stomp.github.io/stomp-specification-1.2.html#SUBSCRIBE_ack_Header + if (logger.isTraceEnabled()) { logger.trace("Adding subscription with replyToKey=" + replyToKey); } @@ -126,17 +159,19 @@ public class ReactorServerStompMessageProcessor implements StompMessageProcessor try { session.sendMessage(event.getData()); } - catch (IOException e) { - // TODO: stomp error, close session, websocket close status - ReactorServerStompMessageProcessor.this.removeSubscriptions(session.getId()); - e.printStackTrace(); + catch (Throwable t) { + handleError(session, t); } } }); addSubscription(session.getId(), registration); - this.reactor.notify(StompCommand.SUBSCRIBE, Fn.event(Tuple.of(session.getId(), message), replyToKey)); + this.reactor.notify(StompCommand.SUBSCRIBE, Fn.event(stompMessage, replyToKey)); + + // TODO: need a way to communicate back if subscription was successfully created or + // not in which case an ERROR should be sent back and close the connection + // http://stomp.github.io/stomp-specification-1.2.html#SUBSCRIBE } private void addSubscription(String sessionId, Registration registration) { @@ -148,28 +183,39 @@ public class ReactorServerStompMessageProcessor implements StompMessageProcessor list.add(registration); } - protected void unsubscribe(StompSession session, StompMessage message) { - this.reactor.notify(StompCommand.UNSUBSCRIBE, Fn.event(Tuple.of(session.getId(), message))); + protected void unsubscribe(StompSession session, StompMessage stompMessage) { + this.reactor.notify(StompCommand.UNSUBSCRIBE, Fn.event(stompMessage)); } - protected void send(StompSession session, StompMessage message) { - this.reactor.notify(StompCommand.SEND, Fn.event(Tuple.of(session.getId(), message))); + protected void send(StompSession session, StompMessage stompMessage) { + this.reactor.notify(StompCommand.SEND, Fn.event(stompMessage)); } - protected void disconnect(StompSession session) { + protected void disconnect(StompSession session, StompMessage stompMessage) { String sessionId = session.getId(); removeSubscriptions(sessionId); - this.reactor.notify(StompCommand.DISCONNECT, Fn.event(sessionId)); + this.reactor.notify(StompCommand.DISCONNECT, Fn.event(stompMessage)); } - private void removeSubscriptions(String sessionId) { + private boolean removeSubscriptions(String sessionId) { List> registrations = this.subscriptionsBySession.remove(sessionId); + if (CollectionUtils.isEmpty(registrations)) { + return false; + } if (logger.isTraceEnabled()) { logger.trace("Cancelling " + registrations.size() + " subscriptions for session=" + sessionId); } for (Registration registration : registrations) { registration.cancel(); } + return true; + } + + @Override + public void processConnectionClosed(StompSession session) { + if (removeSubscriptions(session.getId())) { + // TODO: this implies abnormal closure from the underlying transport (no DISCONNECT) .. send an error event + } } } diff --git a/spring-websocket/src/main/java/org/springframework/web/stomp/server/SimpleStompReactorService.java b/spring-websocket/src/main/java/org/springframework/web/stomp/server/SimpleStompReactorService.java index 9d22507769e..294962d1436 100644 --- a/spring-websocket/src/main/java/org/springframework/web/stomp/server/SimpleStompReactorService.java +++ b/spring-websocket/src/main/java/org/springframework/web/stomp/server/SimpleStompReactorService.java @@ -32,7 +32,6 @@ import reactor.core.Reactor; import reactor.fn.Consumer; import reactor.fn.Event; import reactor.fn.Registration; -import reactor.fn.Tuple2; /** @@ -75,14 +74,12 @@ public class SimpleStompReactorService { } - private final class SubscribeConsumer implements Consumer>> { + private final class SubscribeConsumer implements Consumer> { @Override - public void accept(Event> event) { + public void accept(Event event) { - String sessionId = event.getData().getT1(); - StompMessage message = event.getData().getT2(); - final Object replyToKey = event.getReplyTo(); + StompMessage message = event.getData(); if (logger.isDebugEnabled()) { logger.debug("Subscribe " + message); @@ -97,19 +94,19 @@ public class SimpleStompReactorService { StompHeaders headers = new StompHeaders(); headers.setDestination(inMessage.getHeaders().getDestination()); StompMessage outMessage = new StompMessage(StompCommand.MESSAGE, headers, inMessage.getPayload()); - SimpleStompReactorService.this.reactor.notify(replyToKey, Fn.event(outMessage)); + SimpleStompReactorService.this.reactor.notify(event.getReplyTo(), Fn.event(outMessage)); } }); - addSubscription(sessionId, registration); + addSubscription(message.getStompSessionId(), registration); } } - private final class SendConsumer implements Consumer>> { + private final class SendConsumer implements Consumer> { @Override - public void accept(Event> event) { - StompMessage message = event.getData().getT2(); + public void accept(Event event) { + StompMessage message = event.getData(); logger.debug("Message received: " + message); String destination = message.getHeaders().getDestination();