Add MessageHolder

MessageHolder holds the currently processed message in a ThreadLocal,
which allows PubSubMessageBuilder to automatically add a session id
to messages to be sent.
This commit is contained in:
Rossen Stoyanchev 2013-06-19 15:30:23 -04:00
parent 5cfc59d76d
commit 44db0f815a
7 changed files with 150 additions and 77 deletions

View File

@ -28,10 +28,21 @@ import org.springframework.messaging.SubscribableChannel;
@SuppressWarnings("rawtypes")
public interface PubSubChannelRegistry<M extends Message, H extends MessageHandler<M>> {
/**
* A channel for messaging arriving from clients.
*/
SubscribableChannel<M, H> getClientInputChannel();
/**
* A channel for sending direct messages to a client. The client must be have
* previously subscribed to the destination of the message.
*/
SubscribableChannel<M, H> getClientOutputChannel();
/**
* A channel for broadcasting messages through a message broker.
*/
SubscribableChannel<M, H> getMessageBrokerChannel();
}

View File

@ -43,6 +43,7 @@ import org.springframework.web.messaging.annotation.SubscribeEvent;
import org.springframework.web.messaging.annotation.UnsubscribeEvent;
import org.springframework.web.messaging.converter.MessageConverter;
import org.springframework.web.messaging.service.AbstractPubSubMessageHandler;
import org.springframework.web.messaging.support.MessageHolder;
import org.springframework.web.messaging.support.PubSubHeaderAccesssor;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.method.HandlerMethodSelector;
@ -197,6 +198,8 @@ public class AnnotationPubSubMessageHandler<M extends Message> extends AbstractP
invocableHandlerMethod.setMessageMethodArgumentResolvers(this.argumentResolvers);
try {
MessageHolder.setMessage(message);
Object value = invocableHandlerMethod.invoke(message);
MethodParameter returnType = handlerMethod.getReturnType();
@ -205,12 +208,14 @@ public class AnnotationPubSubMessageHandler<M extends Message> extends AbstractP
}
this.returnValueHandlers.handleReturnValue(value, returnType, message);
}
catch (Throwable e) {
// TODO: send error message, or add @ExceptionHandler-like capability
e.printStackTrace();
}
finally {
MessageHolder.reset();
}
}
protected HandlerMethod getHandlerMethod(String destination, Map<MappingInfo, HandlerMethod> handlerMethods) {

View File

@ -20,8 +20,6 @@ import org.springframework.core.MethodParameter;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.util.Assert;
import org.springframework.web.messaging.support.PubSubHeaderAccesssor;
import org.springframework.web.messaging.support.SessionMessageChannel;
/**
@ -46,9 +44,7 @@ public class MessageChannelArgumentResolver<M extends Message> implements Argume
@Override
public Object resolveArgument(MethodParameter parameter, M message) throws Exception {
Assert.notNull(this.messageBrokerChannel, "messageBrokerChannel is required");
final String sessionId = PubSubHeaderAccesssor.wrap(message).getSessionId();
return new SessionMessageChannel<M>(this.messageBrokerChannel, sessionId);
return this.messageBrokerChannel;
}
}

View File

@ -16,6 +16,8 @@
package org.springframework.web.messaging.service.method;
import java.util.Map;
import org.springframework.core.MethodParameter;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
@ -68,33 +70,30 @@ public class MessageReturnValueHandler<M extends Message> implements ReturnValue
return;
}
returnMessage = updateReturnMessage(returnMessage, message);
returnMessage = processReturnMessage(returnMessage, message);
this.clientChannel.send(returnMessage);
}
protected M updateReturnMessage(M returnMessage, M message) {
protected M processReturnMessage(M returnMessage, M message) {
PubSubHeaderAccesssor headers = PubSubHeaderAccesssor.wrap(message);
String sessionId = headers.getSessionId();
String subscriptionId = headers.getSubscriptionId();
Assert.notNull(subscriptionId, "No subscription id: " + message);
Assert.notNull(headers.getSubscriptionId(), "No subscription id: " + message);
PubSubHeaderAccesssor returnHeaders = PubSubHeaderAccesssor.wrap(returnMessage);
returnHeaders.setSessionId(sessionId);
returnHeaders.setSubscriptionId(subscriptionId);
returnHeaders.setSessionId(headers.getSessionId());
returnHeaders.setSubscriptionId(headers.getSubscriptionId());
if (returnHeaders.getDestination() == null) {
returnHeaders.setDestination(headers.getDestination());
}
return createMessage(returnHeaders, returnMessage.getPayload());
return createMessage(returnMessage.getPayload(), returnHeaders.toHeaders());
}
@SuppressWarnings("unchecked")
private M createMessage(PubSubHeaderAccesssor returnHeaders, Object payload) {
return (M) MessageBuilder.withPayload(payload).copyHeaders(returnHeaders.toHeaders()).build();
private M createMessage(Object payload, Map<String, Object> headers) {
return (M) MessageBuilder.withPayload(payload).copyHeaders(headers).build();
}
}

View File

@ -0,0 +1,45 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.messaging.support;
import org.springframework.core.NamedThreadLocal;
import org.springframework.messaging.Message;
/**
* @author Rossen Stoyanchev
* @since 4.0
*/
public class MessageHolder {
private static final NamedThreadLocal<Message<?>> messageHolder =
new NamedThreadLocal<Message<?>>("Current message");
public static void setMessage(Message<?> message) {
messageHolder.set(message);
}
public static Message<?> getMessage() {
return messageHolder.get();
}
public static void reset() {
messageHolder.remove();
}
}

View File

@ -0,0 +1,77 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.messaging.support;
import org.springframework.http.MediaType;
import org.springframework.messaging.Message;
import org.springframework.messaging.support.MessageBuilder;
import reactor.util.Assert;
/**
* @author Rossen Stoyanchev
* @since 4.0
*/
public class PubSubMessageBuilder<T> {
private final PubSubHeaderAccesssor headers = PubSubHeaderAccesssor.create();
private final T payload;
private PubSubMessageBuilder(T payload) {
Assert.notNull(payload, "<T> is required");
this.payload = payload;
}
public static <T> PubSubMessageBuilder<T> withPayload(T payload) {
return new PubSubMessageBuilder<T>(payload);
}
public PubSubMessageBuilder<T> destination(String destination) {
Assert.notNull(destination, "destination is required");
this.headers.setDestination(destination);
return this;
}
public PubSubMessageBuilder<T> contentType(MediaType contentType) {
Assert.notNull(contentType, "contentType is required");
this.headers.setContentType(contentType);
return this;
}
public PubSubMessageBuilder<T> contentType(String contentType) {
Assert.notNull(contentType, "contentType is required");
this.headers.setContentType(MediaType.parseMediaType(contentType));
return this;
}
public Message<T> build() {
Message<?> message = MessageHolder.getMessage();
if (message != null) {
String sessionId = PubSubHeaderAccesssor.wrap(message).getSessionId();
this.headers.setSessionId(sessionId);
}
return MessageBuilder.withPayload(this.payload).copyHeaders(this.headers.toHeaders()).build();
}
}

View File

@ -1,60 +0,0 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.messaging.support;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.support.MessageBuilder;
import reactor.util.Assert;
/**
* @author Rossen Stoyanchev
* @since 4.0
*/
@SuppressWarnings("rawtypes")
public class SessionMessageChannel<M extends Message> implements MessageChannel<M> {
private MessageChannel<M> delegate;
private final String sessionId;
public SessionMessageChannel(MessageChannel<M> delegate, String sessionId) {
Assert.notNull(delegate, "delegate is required");
Assert.notNull(sessionId, "sessionId is required");
this.sessionId = sessionId;
this.delegate = delegate;
}
@Override
public boolean send(M message) {
return send(message, -1);
}
@Override
public boolean send(M message, long timeout) {
PubSubHeaderAccesssor headers = PubSubHeaderAccesssor.wrap(message);
headers.setSessionId(this.sessionId);
Object payload = message.getPayload();
@SuppressWarnings("unchecked")
M messageToSend = (M) MessageBuilder.withPayload(payload).copyHeaders(headers.toHeaders()).build();
this.delegate.send(messageToSend);
return true;
}
}