Return AbstractSubscribableChannel from @Bean methods

Declare SubscribableChannel @Beans in
WebSocketMessageBrokerConfigurationSupport as
AbstractSubscribableChannel to avoid the need for casting when
registering interceptors.

Issue: SPR-11065
This commit is contained in:
Rossen Stoyanchev 2013-11-07 20:47:03 -05:00
parent 0340cc5f03
commit 6a18daea33
3 changed files with 76 additions and 90 deletions

View File

@ -16,24 +16,15 @@
package org.springframework.messaging.simp.config; package org.springframework.messaging.simp.config;
import java.util.ArrayList;
import java.util.List;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.messaging.Message; import org.springframework.messaging.Message;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.messaging.handler.websocket.SubProtocolWebSocketHandler; import org.springframework.messaging.handler.websocket.SubProtocolWebSocketHandler;
import org.springframework.messaging.simp.SimpMessageSendingOperations; import org.springframework.messaging.simp.SimpMessageSendingOperations;
import org.springframework.messaging.simp.SimpMessagingTemplate; import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.messaging.simp.handler.*; import org.springframework.messaging.simp.handler.*;
import org.springframework.messaging.simp.handler.SimpAnnotationMethodMessageHandler; import org.springframework.messaging.support.channel.AbstractSubscribableChannel;
import org.springframework.messaging.support.channel.ExecutorSubscribableChannel; import org.springframework.messaging.support.channel.ExecutorSubscribableChannel;
import org.springframework.messaging.support.converter.ByteArrayMessageConverter; import org.springframework.messaging.support.converter.*;
import org.springframework.messaging.support.converter.CompositeMessageConverter;
import org.springframework.messaging.support.converter.DefaultContentTypeResolver;
import org.springframework.messaging.support.converter.MappingJackson2MessageConverter;
import org.springframework.messaging.support.converter.MessageConverter;
import org.springframework.messaging.support.converter.StringMessageConverter;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.util.ClassUtils; import org.springframework.util.ClassUtils;
@ -44,6 +35,9 @@ import org.springframework.web.servlet.handler.AbstractHandlerMapping;
import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.config.SockJsServiceRegistration; import org.springframework.web.socket.server.config.SockJsServiceRegistration;
import java.util.ArrayList;
import java.util.List;
/** /**
* Configuration support for broker-backed messaging over WebSocket using a higher-level * Configuration support for broker-backed messaging over WebSocket using a higher-level
@ -118,12 +112,12 @@ public abstract class WebSocketMessageBrokerConfigurationSupport {
} }
@Bean @Bean
public SubscribableChannel webSocketRequestChannel() { public AbstractSubscribableChannel webSocketRequestChannel() {
return new ExecutorSubscribableChannel(webSocketChannelExecutor()); return new ExecutorSubscribableChannel(webSocketChannelExecutor());
} }
@Bean @Bean
public SubscribableChannel webSocketResponseChannel() { public AbstractSubscribableChannel webSocketResponseChannel() {
return new ExecutorSubscribableChannel(webSocketChannelExecutor()); return new ExecutorSubscribableChannel(webSocketChannelExecutor());
} }
@ -209,7 +203,7 @@ public abstract class WebSocketMessageBrokerConfigurationSupport {
} }
@Bean @Bean
public SubscribableChannel brokerChannel() { public AbstractSubscribableChannel brokerChannel() {
return new ExecutorSubscribableChannel(); // synchronous return new ExecutorSubscribableChannel(); // synchronous
} }

View File

@ -16,19 +16,13 @@
package org.springframework.messaging.simp.config; package org.springframework.messaging.simp.config;
import java.util.List;
import java.util.Map;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import org.springframework.context.annotation.AnnotationConfigApplicationContext; import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.messaging.Message; import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHandler; import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.messaging.handler.annotation.MessageMapping; import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.handler.annotation.SendTo; import org.springframework.messaging.handler.annotation.SendTo;
import org.springframework.messaging.handler.websocket.SubProtocolWebSocketHandler; import org.springframework.messaging.handler.websocket.SubProtocolWebSocketHandler;
@ -43,6 +37,8 @@ import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor; import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.simp.stomp.StompTextMessageBuilder; import org.springframework.messaging.simp.stomp.StompTextMessageBuilder;
import org.springframework.messaging.support.MessageBuilder; import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.channel.AbstractSubscribableChannel;
import org.springframework.messaging.support.channel.ExecutorSubscribableChannel;
import org.springframework.messaging.support.converter.CompositeMessageConverter; import org.springframework.messaging.support.converter.CompositeMessageConverter;
import org.springframework.messaging.support.converter.DefaultContentTypeResolver; import org.springframework.messaging.support.converter.DefaultContentTypeResolver;
import org.springframework.stereotype.Controller; import org.springframework.stereotype.Controller;
@ -52,9 +48,11 @@ import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping;
import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.support.TestWebSocketSession; import org.springframework.web.socket.support.TestWebSocketSession;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.*; import static org.junit.Assert.*;
import static org.mockito.Matchers.*;
import static org.mockito.Mockito.*;
/** /**
@ -95,27 +93,20 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@Test @Test
public void webSocketRequestChannel() { public void webSocketRequestChannel() {
SubscribableChannel channel = this.cxtSimpleBroker.getBean("webSocketRequestChannel", SubscribableChannel.class); TestChannel channel = this.cxtSimpleBroker.getBean("webSocketRequestChannel", TestChannel.class);
List<MessageHandler> handlers = channel.handlers;
ArgumentCaptor<MessageHandler> captor = ArgumentCaptor.forClass(MessageHandler.class); assertEquals(3, handlers.size());
verify(channel, times(3)).subscribe(captor.capture()); assertTrue(handlers.contains(cxtSimpleBroker.getBean(SimpAnnotationMethodMessageHandler.class)));
assertTrue(handlers.contains(cxtSimpleBroker.getBean(UserDestinationMessageHandler.class)));
List<MessageHandler> values = captor.getAllValues(); assertTrue(handlers.contains(cxtSimpleBroker.getBean(SimpleBrokerMessageHandler.class)));
assertEquals(3, values.size());
assertTrue(values.contains(cxtSimpleBroker.getBean(SimpAnnotationMethodMessageHandler.class)));
assertTrue(values.contains(cxtSimpleBroker.getBean(UserDestinationMessageHandler.class)));
assertTrue(values.contains(cxtSimpleBroker.getBean(SimpleBrokerMessageHandler.class)));
} }
@Test @Test
public void webSocketRequestChannelWithStompBroker() { public void webSocketRequestChannelWithStompBroker() {
SubscribableChannel channel = this.cxtStompBroker.getBean("webSocketRequestChannel", SubscribableChannel.class); TestChannel channel = this.cxtStompBroker.getBean("webSocketRequestChannel", TestChannel.class);
List<MessageHandler> values = channel.handlers;
ArgumentCaptor<MessageHandler> captor = ArgumentCaptor.forClass(MessageHandler.class);
verify(channel, times(3)).subscribe(captor.capture());
List<MessageHandler> values = captor.getAllValues();
assertEquals(3, values.size()); assertEquals(3, values.size());
assertTrue(values.contains(cxtStompBroker.getBean(SimpAnnotationMethodMessageHandler.class))); assertTrue(values.contains(cxtStompBroker.getBean(SimpAnnotationMethodMessageHandler.class)));
assertTrue(values.contains(cxtStompBroker.getBean(UserDestinationMessageHandler.class))); assertTrue(values.contains(cxtStompBroker.getBean(UserDestinationMessageHandler.class)));
@ -125,16 +116,13 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@Test @Test
public void webSocketRequestChannelSendMessage() throws Exception { public void webSocketRequestChannelSendMessage() throws Exception {
SubscribableChannel channel = this.cxtSimpleBroker.getBean("webSocketRequestChannel", SubscribableChannel.class); TestChannel channel = this.cxtSimpleBroker.getBean("webSocketRequestChannel", TestChannel.class);
SubProtocolWebSocketHandler webSocketHandler = this.cxtSimpleBroker.getBean(SubProtocolWebSocketHandler.class); SubProtocolWebSocketHandler webSocketHandler = this.cxtSimpleBroker.getBean(SubProtocolWebSocketHandler.class);
TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.SEND).headers("destination:/foo").build(); TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.SEND).headers("destination:/foo").build();
webSocketHandler.handleMessage(new TestWebSocketSession(), textMessage); webSocketHandler.handleMessage(new TestWebSocketSession(), textMessage);
ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class); Message<?> message = channel.messages.get(0);
verify(channel).send(captor.capture());
Message message = captor.getValue();
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
assertEquals(SimpMessageType.MESSAGE, headers.getMessageType()); assertEquals(SimpMessageType.MESSAGE, headers.getMessageType());
@ -143,15 +131,17 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@Test @Test
public void webSocketResponseChannel() { public void webSocketResponseChannel() {
SubscribableChannel channel = this.cxtSimpleBroker.getBean("webSocketResponseChannel", SubscribableChannel.class); TestChannel channel = this.cxtSimpleBroker.getBean("webSocketResponseChannel", TestChannel.class);
verify(channel).subscribe(any(SubProtocolWebSocketHandler.class)); List<MessageHandler> values = channel.handlers;
verifyNoMoreInteractions(channel);
assertEquals(1, values.size());
assertTrue(values.get(0) instanceof SubProtocolWebSocketHandler);
} }
@Test @Test
public void webSocketResponseChannelUsedByAnnotatedMethod() { public void webSocketResponseChannelUsedByAnnotatedMethod() {
SubscribableChannel channel = this.cxtSimpleBroker.getBean("webSocketResponseChannel", SubscribableChannel.class); TestChannel channel = this.cxtSimpleBroker.getBean("webSocketResponseChannel", TestChannel.class);
SimpAnnotationMethodMessageHandler messageHandler = this.cxtSimpleBroker.getBean(SimpAnnotationMethodMessageHandler.class); SimpAnnotationMethodMessageHandler messageHandler = this.cxtSimpleBroker.getBean(SimpAnnotationMethodMessageHandler.class);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SUBSCRIBE); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SUBSCRIBE);
@ -160,12 +150,9 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
headers.setDestination("/foo"); headers.setDestination("/foo");
Message<?> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); Message<?> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
when(channel.send(any(Message.class))).thenReturn(true);
messageHandler.handleMessage(message); messageHandler.handleMessage(message);
ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class); message = channel.messages.get(0);
verify(channel).send(captor.capture());
message = captor.getValue();
headers = StompHeaderAccessor.wrap(message); headers = StompHeaderAccessor.wrap(message);
assertEquals(SimpMessageType.MESSAGE, headers.getMessageType()); assertEquals(SimpMessageType.MESSAGE, headers.getMessageType());
@ -175,7 +162,7 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@Test @Test
public void webSocketResponseChannelUsedBySimpleBroker() { public void webSocketResponseChannelUsedBySimpleBroker() {
SubscribableChannel channel = this.cxtSimpleBroker.getBean("webSocketResponseChannel", SubscribableChannel.class); TestChannel channel = this.cxtSimpleBroker.getBean("webSocketResponseChannel", TestChannel.class);
SimpleBrokerMessageHandler broker = this.cxtSimpleBroker.getBean(SimpleBrokerMessageHandler.class); SimpleBrokerMessageHandler broker = this.cxtSimpleBroker.getBean(SimpleBrokerMessageHandler.class);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SUBSCRIBE); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SUBSCRIBE);
@ -193,12 +180,9 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
message = MessageBuilder.withPayload("bar".getBytes()).setHeaders(headers).build(); message = MessageBuilder.withPayload("bar".getBytes()).setHeaders(headers).build();
// message // message
when(channel.send(any(Message.class))).thenReturn(true);
broker.handleMessage(message); broker.handleMessage(message);
ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class); message = channel.messages.get(0);
verify(channel).send(captor.capture());
message = captor.getValue();
headers = StompHeaderAccessor.wrap(message); headers = StompHeaderAccessor.wrap(message);
assertEquals(SimpMessageType.MESSAGE, headers.getMessageType()); assertEquals(SimpMessageType.MESSAGE, headers.getMessageType());
@ -208,45 +192,36 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@Test @Test
public void brokerChannel() { public void brokerChannel() {
SubscribableChannel channel = this.cxtSimpleBroker.getBean("brokerChannel", SubscribableChannel.class); TestChannel channel = this.cxtSimpleBroker.getBean("brokerChannel", TestChannel.class);
List<MessageHandler> handlers = channel.handlers;
ArgumentCaptor<MessageHandler> captor = ArgumentCaptor.forClass(MessageHandler.class); assertEquals(2, handlers.size());
verify(channel, times(2)).subscribe(captor.capture()); assertTrue(handlers.contains(cxtSimpleBroker.getBean(UserDestinationMessageHandler.class)));
assertTrue(handlers.contains(cxtSimpleBroker.getBean(SimpleBrokerMessageHandler.class)));
List<MessageHandler> values = captor.getAllValues();
assertEquals(2, values.size());
assertTrue(values.contains(cxtSimpleBroker.getBean(UserDestinationMessageHandler.class)));
assertTrue(values.contains(cxtSimpleBroker.getBean(SimpleBrokerMessageHandler.class)));
} }
@Test @Test
public void brokerChannelWithStompBroker() { public void brokerChannelWithStompBroker() {
SubscribableChannel channel = this.cxtStompBroker.getBean("brokerChannel", SubscribableChannel.class); TestChannel channel = this.cxtStompBroker.getBean("brokerChannel", TestChannel.class);
List<MessageHandler> handlers = channel.handlers;
ArgumentCaptor<MessageHandler> captor = ArgumentCaptor.forClass(MessageHandler.class); assertEquals(2, handlers.size());
verify(channel, times(2)).subscribe(captor.capture()); assertTrue(handlers.contains(cxtStompBroker.getBean(UserDestinationMessageHandler.class)));
assertTrue(handlers.contains(cxtStompBroker.getBean(StompBrokerRelayMessageHandler.class)));
List<MessageHandler> values = captor.getAllValues();
assertEquals(2, values.size());
assertTrue(values.contains(cxtStompBroker.getBean(UserDestinationMessageHandler.class)));
assertTrue(values.contains(cxtStompBroker.getBean(StompBrokerRelayMessageHandler.class)));
} }
@Test @Test
public void brokerChannelUsedByAnnotatedMethod() { public void brokerChannelUsedByAnnotatedMethod() {
SubscribableChannel channel = this.cxtSimpleBroker.getBean("brokerChannel", SubscribableChannel.class); TestChannel channel = this.cxtSimpleBroker.getBean("brokerChannel", TestChannel.class);
SimpAnnotationMethodMessageHandler messageHandler = this.cxtSimpleBroker.getBean(SimpAnnotationMethodMessageHandler.class); SimpAnnotationMethodMessageHandler messageHandler = this.cxtSimpleBroker.getBean(SimpAnnotationMethodMessageHandler.class);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND);
headers.setDestination("/foo"); headers.setDestination("/foo");
Message<?> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); Message<?> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
when(channel.send(any(Message.class))).thenReturn(true);
messageHandler.handleMessage(message); messageHandler.handleMessage(message);
ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class); message = channel.messages.get(0);
verify(channel).send(captor.capture());
message = captor.getValue();
headers = StompHeaderAccessor.wrap(message); headers = StompHeaderAccessor.wrap(message);
assertEquals(SimpMessageType.MESSAGE, headers.getMessageType()); assertEquals(SimpMessageType.MESSAGE, headers.getMessageType());
@ -256,7 +231,7 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@Test @Test
public void brokerChannelUsedByUserDestinationMessageHandler() { public void brokerChannelUsedByUserDestinationMessageHandler() {
SubscribableChannel channel = this.cxtSimpleBroker.getBean("brokerChannel", SubscribableChannel.class); TestChannel channel = this.cxtSimpleBroker.getBean("brokerChannel", TestChannel.class);
UserDestinationMessageHandler messageHandler = this.cxtSimpleBroker.getBean(UserDestinationMessageHandler.class); UserDestinationMessageHandler messageHandler = this.cxtSimpleBroker.getBean(UserDestinationMessageHandler.class);
this.cxtSimpleBroker.getBean(UserSessionRegistry.class).registerSessionId("joe", "s1"); this.cxtSimpleBroker.getBean(UserSessionRegistry.class).registerSessionId("joe", "s1");
@ -265,12 +240,9 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
headers.setDestination("/user/joe/foo"); headers.setDestination("/user/joe/foo");
Message<?> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); Message<?> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
when(channel.send(any(Message.class))).thenReturn(true);
messageHandler.handleMessage(message); messageHandler.handleMessage(message);
ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class); message = channel.messages.get(0);
verify(channel).send(captor.capture());
message = captor.getValue();
headers = StompHeaderAccessor.wrap(message); headers = StompHeaderAccessor.wrap(message);
assertEquals(SimpMessageType.MESSAGE, headers.getMessageType()); assertEquals(SimpMessageType.MESSAGE, headers.getMessageType());
@ -340,19 +312,39 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@Override @Override
@Bean @Bean
public SubscribableChannel webSocketRequestChannel() { public AbstractSubscribableChannel webSocketRequestChannel() {
return Mockito.mock(SubscribableChannel.class); return new TestChannel();
} }
@Override @Override
@Bean @Bean
public SubscribableChannel webSocketResponseChannel() { public AbstractSubscribableChannel webSocketResponseChannel() {
return Mockito.mock(SubscribableChannel.class); return new TestChannel();
} }
@Override @Override
public SubscribableChannel brokerChannel() { public AbstractSubscribableChannel brokerChannel() {
return Mockito.mock(SubscribableChannel.class); return new TestChannel();
}
}
private static class TestChannel extends ExecutorSubscribableChannel {
private final List<MessageHandler> handlers = new ArrayList<>();
private final List<Message<?>> messages = new ArrayList<>();
@Override
public boolean subscribeInternal(MessageHandler handler) {
this.handlers.add(handler);
return super.subscribeInternal(handler);
}
@Override
public boolean sendInternal(Message<?> message, long timeout) {
this.messages.add(message);
return true;
} }
} }

View File

@ -34,7 +34,6 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.ComponentScan; import org.springframework.context.annotation.ComponentScan;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.messaging.handler.annotation.MessageExceptionHandler; import org.springframework.messaging.handler.annotation.MessageExceptionHandler;
import org.springframework.messaging.handler.annotation.MessageMapping; import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.simp.config.DelegatingWebSocketMessageBrokerConfiguration; import org.springframework.messaging.simp.config.DelegatingWebSocketMessageBrokerConfiguration;
@ -42,6 +41,7 @@ import org.springframework.messaging.simp.config.MessageBrokerConfigurer;
import org.springframework.messaging.simp.config.StompEndpointRegistry; import org.springframework.messaging.simp.config.StompEndpointRegistry;
import org.springframework.messaging.simp.config.WebSocketMessageBrokerConfigurer; import org.springframework.messaging.simp.config.WebSocketMessageBrokerConfigurer;
import org.springframework.messaging.simp.stomp.StompCommand; import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.support.channel.AbstractSubscribableChannel;
import org.springframework.messaging.support.channel.ExecutorSubscribableChannel; import org.springframework.messaging.support.channel.ExecutorSubscribableChannel;
import org.springframework.stereotype.Controller; import org.springframework.stereotype.Controller;
import org.springframework.web.socket.AbstractWebSocketIntegrationTests; import org.springframework.web.socket.AbstractWebSocketIntegrationTests;
@ -227,13 +227,13 @@ public class SimpAnnotationMethodIntegrationTests extends AbstractWebSocketInteg
@Override @Override
@Bean @Bean
public SubscribableChannel webSocketRequestChannel() { public AbstractSubscribableChannel webSocketRequestChannel() {
return new ExecutorSubscribableChannel(); // synchronous return new ExecutorSubscribableChannel(); // synchronous
} }
@Override @Override
@Bean @Bean
public SubscribableChannel webSocketResponseChannel() { public AbstractSubscribableChannel webSocketResponseChannel() {
return new ExecutorSubscribableChannel(); // synchronous return new ExecutorSubscribableChannel(); // synchronous
} }
} }