Use tcp-reactor in StompRelayMessageHandler

This commit is contained in:
Rossen Stoyanchev 2013-06-13 21:48:51 -04:00
parent 78d1063e37
commit 641aaf4b6a
3 changed files with 94 additions and 130 deletions

View File

@ -482,13 +482,15 @@ project("spring-websocket") {
}
optional("org.eclipse.jetty.websocket:websocket-server:9.0.3.v20130506")
optional("org.eclipse.jetty.websocket:websocket-client:9.0.3.v20130506")
optional("com.fasterxml.jackson.core:jackson-databind:2.2.0") // required for SockJS support currently
optional("reactor:reactor-core:1.0.0.BUILD-SNAPSHOT")
optional("com.fasterxml.jackson.core:jackson-databind:2.2.0") // currently needed for SockJS support
optional("reactor:reactor-core:1.0.0.BUILD-SNAPSHOT") // STOMP message processing
optional("reactor:reactor-tcp:1.0.0.BUILD-SNAPSHOT") // STOMP relay to message broker
}
repositories {
maven { url "https://repository.apache.org/content/repositories/snapshots" } // tomcat-websocket-* snapshots
maven { url "https://maven.java.net/content/repositories/releases" } // javax.websocket, tyrus
mavenLocal() // temporary workaround for locally installed (latest) reactor
maven { url 'http://repo.springsource.org/snapshot' } // reactor
}
}

View File

@ -25,6 +25,8 @@ import org.springframework.http.MediaType;
*/
public class ContentTypeNotSupportedException extends Exception {
private static final long serialVersionUID = -3597879520747071896L;
private final MediaType mediaType;
private final Class<?> sourceOrTargetType;

View File

@ -16,21 +16,12 @@
package org.springframework.web.messaging.stomp.support;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.nio.charset.Charset;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import javax.net.SocketFactory;
import org.springframework.core.task.TaskExecutor;
import org.springframework.http.MediaType;
import org.springframework.messaging.GenericMessage;
import org.springframework.messaging.Message;
@ -45,6 +36,15 @@ import org.springframework.web.messaging.service.AbstractPubSubMessageHandler;
import org.springframework.web.messaging.stomp.StompCommand;
import org.springframework.web.messaging.stomp.StompHeaders;
import reactor.core.Environment;
import reactor.core.Promise;
import reactor.fn.Consumer;
import reactor.tcp.TcpClient;
import reactor.tcp.TcpConnection;
import reactor.tcp.encoding.DelimitedCodec;
import reactor.tcp.encoding.StandardCodecs;
import reactor.tcp.netty.NettyTcpClient;
/**
* @author Rossen Stoyanchev
@ -57,19 +57,22 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler
private MessageConverter payloadConverter;
private final TaskExecutor taskExecutor;
private final TcpClient<String, String> tcpClient;
private Map<String, RelaySession> relaySessions = new ConcurrentHashMap<String, RelaySession>();
private final Map<String, TcpConnection<String, String>> connections =
new ConcurrentHashMap<String, TcpConnection<String, String>>();
/**
* @param executor
*/
public StompRelayPubSubMessageHandler(SubscribableChannel publishChannel, MessageChannel clientChannel,
TaskExecutor executor) {
public StompRelayPubSubMessageHandler(SubscribableChannel publishChannel, MessageChannel clientChannel) {
super(publishChannel, clientChannel);
this.taskExecutor = executor; // For now, a naive way to manage socket reading
this.tcpClient = new TcpClient.Spec<String, String>(NettyTcpClient.class)
.using(new Environment())
.codec(new DelimitedCodec<String, String>((byte) 0, StandardCodecs.STRING_CODEC))
.connect("127.0.0.1", 61613)
.get();
this.payloadConverter = new CompositeMessageConverter(null);
}
@ -84,34 +87,52 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler
}
@Override
public void handleConnect(Message<?> message) {
public void handleConnect(final Message<?> message) {
String sessionId = (String) message.getHeaders().get(PubSubHeaders.SESSION_ID);
final String sessionId = (String) message.getHeaders().get(PubSubHeaders.SESSION_ID);
RelaySession session = new RelaySession();
this.relaySessions.put(sessionId, session);
Promise<TcpConnection<String, String>> promise = this.tcpClient.open();
try {
Socket socket = SocketFactory.getDefault().createSocket("127.0.0.1", 61613);
session.setSocket(socket);
promise.onSuccess(new Consumer<TcpConnection<String,String>>() {
@Override
public void accept(TcpConnection<String, String> connection) {
connections.put(sessionId, connection);
forwardMessage(message, StompCommand.CONNECT);
}
});
forwardMessage(message, StompCommand.CONNECT);
promise.consume(new Consumer<TcpConnection<String,String>>() {
@Override
public void accept(TcpConnection<String, String> connection) {
connection.in().consume(new Consumer<String>() {
@Override
public void accept(String stompFrame) {
if (stompFrame.isEmpty()) {
// TODO: why are we getting empty frames?
return;
}
Message<byte[]> message = stompMessageConverter.toMessage(stompFrame, sessionId);
getClientChannel().send(message);
}
});
}
});
// TODO: ATM no way to detect closed socket
// StompHeaders stompHeaders = StompHeaders.create(StompCommand.ERROR);
// stompHeaders.setMessage("Socket closed, STOMP session=" + sessionId);
// stompHeaders.setSessionId(sessionId);
// Message<byte[]> errorMessage = new GenericMessage<byte[]>(new byte[0], stompHeaders.toMessageHeaders());
// getClientChannel().send(errorMessage);
RelayReadTask readTask = new RelayReadTask(sessionId, session);
this.taskExecutor.execute(readTask);
}
catch (Throwable t) {
t.printStackTrace();
clearRelaySession(sessionId);
}
}
private void forwardMessage(Message<?> message, StompCommand command) {
StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(message.getHeaders());
String sessionId = stompHeaders.getSessionId();
RelaySession session = StompRelayPubSubMessageHandler.this.relaySessions.get(sessionId);
Assert.notNull(session, "RelaySession not found");
byte[] bytesToWrite;
try {
stompHeaders.setStompCommandIfNotSet(StompCommand.SEND);
@ -119,34 +140,48 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler
MediaType contentType = stompHeaders.getContentType();
byte[] payload = this.payloadConverter.convertToPayload(message.getPayload(), contentType);
Message<byte[]> byteMessage = new GenericMessage<byte[]>(payload, stompHeaders.toMessageHeaders());
bytesToWrite = this.stompMessageConverter.fromMessage(byteMessage);
}
catch (Throwable ex) {
logger.error("Failed to forward message " + message, ex);
return;
}
TcpConnection<String, String> connection = getConnection(sessionId);
Assert.notNull(connection, "TCP connection to message broker not found, sessionId=" + sessionId);
try {
if (logger.isTraceEnabled()) {
logger.trace("Forwarding STOMP " + stompHeaders.getStompCommand() + " message");
}
byte[] bytes = this.stompMessageConverter.fromMessage(byteMessage);
session.getOutputStream().write(bytes);
session.getOutputStream().flush();
connection.out().accept(new String(bytesToWrite, Charset.forName("UTF-8")));
}
catch (Exception ex) {
logger.error("Couldn't forward message " + message, ex);
clearRelaySession(sessionId);
}
}
private void clearRelaySession(String stompSessionId) {
RelaySession relaySession = this.relaySessions.remove(stompSessionId);
if (relaySession != null) {
// TODO: raise failure event so client session can be closed
catch (Throwable ex) {
logger.error("Could not get TCP connection " + sessionId, ex);
try {
relaySession.getSocket().close();
if (connection != null) {
connection.close();
}
}
catch (IOException e) {
catch (Throwable t) {
// ignore
}
}
}
private TcpConnection<String, String> getConnection(String sessionId) {
TcpConnection<String, String> connection = this.connections.get(sessionId);
if (connection == null) {
try {
Thread.sleep(1000);
}
catch (InterruptedException e) {
return null;
}
}
connection = this.connections.get(sessionId);
return connection;
}
@Override
public void handlePublish(Message<?> message) {
forwardMessage(message, StompCommand.SEND);
@ -174,6 +209,8 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler
forwardMessage(message, command);
}
// TODO:
/* @Override
public void handleClientConnectionClosed(String sessionId) {
if (logger.isDebugEnabled()) {
@ -183,81 +220,4 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler
}
*/
private final static class RelaySession {
private Socket socket;
private InputStream inputStream;
private OutputStream outputStream;
public void setSocket(Socket socket) throws IOException {
this.socket = socket;
this.inputStream = new BufferedInputStream(socket.getInputStream());
this.outputStream = new BufferedOutputStream(socket.getOutputStream());
}
public Socket getSocket() {
return this.socket;
}
public InputStream getInputStream() {
return this.inputStream;
}
public OutputStream getOutputStream() {
return this.outputStream;
}
}
private final class RelayReadTask implements Runnable {
private final String sessionId;
private final RelaySession session;
private RelayReadTask(String sessionId, RelaySession session) {
this.sessionId = sessionId;
this.session = session;
}
@Override
public void run() {
try {
ByteArrayOutputStream out = new ByteArrayOutputStream();
while (!session.getSocket().isClosed()) {
int b = session.getInputStream().read();
if (b == -1) {
break;
}
else if (b == 0x00) {
byte[] bytes = out.toByteArray();
Message<byte[]> message = stompMessageConverter.toMessage(bytes, sessionId);
getClientChannel().send(message);
out.reset();
}
else {
out.write(b);
}
}
logger.debug("Socket closed, STOMP session=" + sessionId);
sendErrorMessage("Lost connection");
}
catch (IOException e) {
logger.error("Socket error: " + e.getMessage());
clearRelaySession(sessionId);
sendErrorMessage("Lost connection");
}
}
private void sendErrorMessage(String message) {
StompHeaders stompHeaders = StompHeaders.create(StompCommand.ERROR);
stompHeaders.setMessage(message);
stompHeaders.setSessionId(this.sessionId);
Message<byte[]> errorMessage = new GenericMessage<byte[]>(new byte[0], stompHeaders.toMessageHeaders());
getClientChannel().send(errorMessage);
}
}
}