Return AbstractSubscribableChannel from @Bean methods

Declare SubscribableChannel @Beans in
WebSocketMessageBrokerConfigurationSupport as
AbstractSubscribableChannel to avoid the need for casting when
registering interceptors.

Issue: SPR-11065
This commit is contained in:
Rossen Stoyanchev 2013-11-07 20:47:03 -05:00
parent 0340cc5f03
commit 6a18daea33
3 changed files with 76 additions and 90 deletions

View File

@ -16,24 +16,15 @@
package org.springframework.messaging.simp.config;
import java.util.ArrayList;
import java.util.List;
import org.springframework.context.annotation.Bean;
import org.springframework.messaging.Message;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.messaging.handler.websocket.SubProtocolWebSocketHandler;
import org.springframework.messaging.simp.SimpMessageSendingOperations;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.messaging.simp.handler.*;
import org.springframework.messaging.simp.handler.SimpAnnotationMethodMessageHandler;
import org.springframework.messaging.support.channel.AbstractSubscribableChannel;
import org.springframework.messaging.support.channel.ExecutorSubscribableChannel;
import org.springframework.messaging.support.converter.ByteArrayMessageConverter;
import org.springframework.messaging.support.converter.CompositeMessageConverter;
import org.springframework.messaging.support.converter.DefaultContentTypeResolver;
import org.springframework.messaging.support.converter.MappingJackson2MessageConverter;
import org.springframework.messaging.support.converter.MessageConverter;
import org.springframework.messaging.support.converter.StringMessageConverter;
import org.springframework.messaging.support.converter.*;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.util.ClassUtils;
@ -44,6 +35,9 @@ import org.springframework.web.servlet.handler.AbstractHandlerMapping;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.config.SockJsServiceRegistration;
import java.util.ArrayList;
import java.util.List;
/**
* Configuration support for broker-backed messaging over WebSocket using a higher-level
@ -118,12 +112,12 @@ public abstract class WebSocketMessageBrokerConfigurationSupport {
}
@Bean
public SubscribableChannel webSocketRequestChannel() {
public AbstractSubscribableChannel webSocketRequestChannel() {
return new ExecutorSubscribableChannel(webSocketChannelExecutor());
}
@Bean
public SubscribableChannel webSocketResponseChannel() {
public AbstractSubscribableChannel webSocketResponseChannel() {
return new ExecutorSubscribableChannel(webSocketChannelExecutor());
}
@ -209,7 +203,7 @@ public abstract class WebSocketMessageBrokerConfigurationSupport {
}
@Bean
public SubscribableChannel brokerChannel() {
public AbstractSubscribableChannel brokerChannel() {
return new ExecutorSubscribableChannel(); // synchronous
}

View File

@ -16,19 +16,13 @@
package org.springframework.messaging.simp.config;
import java.util.List;
import java.util.Map;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.handler.annotation.SendTo;
import org.springframework.messaging.handler.websocket.SubProtocolWebSocketHandler;
@ -43,6 +37,8 @@ import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.simp.stomp.StompTextMessageBuilder;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.channel.AbstractSubscribableChannel;
import org.springframework.messaging.support.channel.ExecutorSubscribableChannel;
import org.springframework.messaging.support.converter.CompositeMessageConverter;
import org.springframework.messaging.support.converter.DefaultContentTypeResolver;
import org.springframework.stereotype.Controller;
@ -52,9 +48,11 @@ import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.support.TestWebSocketSession;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.*;
import static org.mockito.Matchers.*;
import static org.mockito.Mockito.*;
/**
@ -95,27 +93,20 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@Test
public void webSocketRequestChannel() {
SubscribableChannel channel = this.cxtSimpleBroker.getBean("webSocketRequestChannel", SubscribableChannel.class);
TestChannel channel = this.cxtSimpleBroker.getBean("webSocketRequestChannel", TestChannel.class);
List<MessageHandler> handlers = channel.handlers;
ArgumentCaptor<MessageHandler> captor = ArgumentCaptor.forClass(MessageHandler.class);
verify(channel, times(3)).subscribe(captor.capture());
List<MessageHandler> values = captor.getAllValues();
assertEquals(3, values.size());
assertTrue(values.contains(cxtSimpleBroker.getBean(SimpAnnotationMethodMessageHandler.class)));
assertTrue(values.contains(cxtSimpleBroker.getBean(UserDestinationMessageHandler.class)));
assertTrue(values.contains(cxtSimpleBroker.getBean(SimpleBrokerMessageHandler.class)));
assertEquals(3, handlers.size());
assertTrue(handlers.contains(cxtSimpleBroker.getBean(SimpAnnotationMethodMessageHandler.class)));
assertTrue(handlers.contains(cxtSimpleBroker.getBean(UserDestinationMessageHandler.class)));
assertTrue(handlers.contains(cxtSimpleBroker.getBean(SimpleBrokerMessageHandler.class)));
}
@Test
public void webSocketRequestChannelWithStompBroker() {
SubscribableChannel channel = this.cxtStompBroker.getBean("webSocketRequestChannel", SubscribableChannel.class);
TestChannel channel = this.cxtStompBroker.getBean("webSocketRequestChannel", TestChannel.class);
List<MessageHandler> values = channel.handlers;
ArgumentCaptor<MessageHandler> captor = ArgumentCaptor.forClass(MessageHandler.class);
verify(channel, times(3)).subscribe(captor.capture());
List<MessageHandler> values = captor.getAllValues();
assertEquals(3, values.size());
assertTrue(values.contains(cxtStompBroker.getBean(SimpAnnotationMethodMessageHandler.class)));
assertTrue(values.contains(cxtStompBroker.getBean(UserDestinationMessageHandler.class)));
@ -125,16 +116,13 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@Test
public void webSocketRequestChannelSendMessage() throws Exception {
SubscribableChannel channel = this.cxtSimpleBroker.getBean("webSocketRequestChannel", SubscribableChannel.class);
TestChannel channel = this.cxtSimpleBroker.getBean("webSocketRequestChannel", TestChannel.class);
SubProtocolWebSocketHandler webSocketHandler = this.cxtSimpleBroker.getBean(SubProtocolWebSocketHandler.class);
TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.SEND).headers("destination:/foo").build();
webSocketHandler.handleMessage(new TestWebSocketSession(), textMessage);
ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class);
verify(channel).send(captor.capture());
Message message = captor.getValue();
Message<?> message = channel.messages.get(0);
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
assertEquals(SimpMessageType.MESSAGE, headers.getMessageType());
@ -143,15 +131,17 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@Test
public void webSocketResponseChannel() {
SubscribableChannel channel = this.cxtSimpleBroker.getBean("webSocketResponseChannel", SubscribableChannel.class);
verify(channel).subscribe(any(SubProtocolWebSocketHandler.class));
verifyNoMoreInteractions(channel);
TestChannel channel = this.cxtSimpleBroker.getBean("webSocketResponseChannel", TestChannel.class);
List<MessageHandler> values = channel.handlers;
assertEquals(1, values.size());
assertTrue(values.get(0) instanceof SubProtocolWebSocketHandler);
}
@Test
public void webSocketResponseChannelUsedByAnnotatedMethod() {
SubscribableChannel channel = this.cxtSimpleBroker.getBean("webSocketResponseChannel", SubscribableChannel.class);
TestChannel channel = this.cxtSimpleBroker.getBean("webSocketResponseChannel", TestChannel.class);
SimpAnnotationMethodMessageHandler messageHandler = this.cxtSimpleBroker.getBean(SimpAnnotationMethodMessageHandler.class);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SUBSCRIBE);
@ -160,12 +150,9 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
headers.setDestination("/foo");
Message<?> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
when(channel.send(any(Message.class))).thenReturn(true);
messageHandler.handleMessage(message);
ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class);
verify(channel).send(captor.capture());
message = captor.getValue();
message = channel.messages.get(0);
headers = StompHeaderAccessor.wrap(message);
assertEquals(SimpMessageType.MESSAGE, headers.getMessageType());
@ -175,7 +162,7 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@Test
public void webSocketResponseChannelUsedBySimpleBroker() {
SubscribableChannel channel = this.cxtSimpleBroker.getBean("webSocketResponseChannel", SubscribableChannel.class);
TestChannel channel = this.cxtSimpleBroker.getBean("webSocketResponseChannel", TestChannel.class);
SimpleBrokerMessageHandler broker = this.cxtSimpleBroker.getBean(SimpleBrokerMessageHandler.class);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SUBSCRIBE);
@ -193,12 +180,9 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
message = MessageBuilder.withPayload("bar".getBytes()).setHeaders(headers).build();
// message
when(channel.send(any(Message.class))).thenReturn(true);
broker.handleMessage(message);
ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class);
verify(channel).send(captor.capture());
message = captor.getValue();
message = channel.messages.get(0);
headers = StompHeaderAccessor.wrap(message);
assertEquals(SimpMessageType.MESSAGE, headers.getMessageType());
@ -208,45 +192,36 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@Test
public void brokerChannel() {
SubscribableChannel channel = this.cxtSimpleBroker.getBean("brokerChannel", SubscribableChannel.class);
TestChannel channel = this.cxtSimpleBroker.getBean("brokerChannel", TestChannel.class);
List<MessageHandler> handlers = channel.handlers;
ArgumentCaptor<MessageHandler> captor = ArgumentCaptor.forClass(MessageHandler.class);
verify(channel, times(2)).subscribe(captor.capture());
List<MessageHandler> values = captor.getAllValues();
assertEquals(2, values.size());
assertTrue(values.contains(cxtSimpleBroker.getBean(UserDestinationMessageHandler.class)));
assertTrue(values.contains(cxtSimpleBroker.getBean(SimpleBrokerMessageHandler.class)));
assertEquals(2, handlers.size());
assertTrue(handlers.contains(cxtSimpleBroker.getBean(UserDestinationMessageHandler.class)));
assertTrue(handlers.contains(cxtSimpleBroker.getBean(SimpleBrokerMessageHandler.class)));
}
@Test
public void brokerChannelWithStompBroker() {
SubscribableChannel channel = this.cxtStompBroker.getBean("brokerChannel", SubscribableChannel.class);
TestChannel channel = this.cxtStompBroker.getBean("brokerChannel", TestChannel.class);
List<MessageHandler> handlers = channel.handlers;
ArgumentCaptor<MessageHandler> captor = ArgumentCaptor.forClass(MessageHandler.class);
verify(channel, times(2)).subscribe(captor.capture());
List<MessageHandler> values = captor.getAllValues();
assertEquals(2, values.size());
assertTrue(values.contains(cxtStompBroker.getBean(UserDestinationMessageHandler.class)));
assertTrue(values.contains(cxtStompBroker.getBean(StompBrokerRelayMessageHandler.class)));
assertEquals(2, handlers.size());
assertTrue(handlers.contains(cxtStompBroker.getBean(UserDestinationMessageHandler.class)));
assertTrue(handlers.contains(cxtStompBroker.getBean(StompBrokerRelayMessageHandler.class)));
}
@Test
public void brokerChannelUsedByAnnotatedMethod() {
SubscribableChannel channel = this.cxtSimpleBroker.getBean("brokerChannel", SubscribableChannel.class);
TestChannel channel = this.cxtSimpleBroker.getBean("brokerChannel", TestChannel.class);
SimpAnnotationMethodMessageHandler messageHandler = this.cxtSimpleBroker.getBean(SimpAnnotationMethodMessageHandler.class);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND);
headers.setDestination("/foo");
Message<?> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
when(channel.send(any(Message.class))).thenReturn(true);
messageHandler.handleMessage(message);
ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class);
verify(channel).send(captor.capture());
message = captor.getValue();
message = channel.messages.get(0);
headers = StompHeaderAccessor.wrap(message);
assertEquals(SimpMessageType.MESSAGE, headers.getMessageType());
@ -256,7 +231,7 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@Test
public void brokerChannelUsedByUserDestinationMessageHandler() {
SubscribableChannel channel = this.cxtSimpleBroker.getBean("brokerChannel", SubscribableChannel.class);
TestChannel channel = this.cxtSimpleBroker.getBean("brokerChannel", TestChannel.class);
UserDestinationMessageHandler messageHandler = this.cxtSimpleBroker.getBean(UserDestinationMessageHandler.class);
this.cxtSimpleBroker.getBean(UserSessionRegistry.class).registerSessionId("joe", "s1");
@ -265,12 +240,9 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
headers.setDestination("/user/joe/foo");
Message<?> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
when(channel.send(any(Message.class))).thenReturn(true);
messageHandler.handleMessage(message);
ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class);
verify(channel).send(captor.capture());
message = captor.getValue();
message = channel.messages.get(0);
headers = StompHeaderAccessor.wrap(message);
assertEquals(SimpMessageType.MESSAGE, headers.getMessageType());
@ -340,19 +312,39 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@Override
@Bean
public SubscribableChannel webSocketRequestChannel() {
return Mockito.mock(SubscribableChannel.class);
public AbstractSubscribableChannel webSocketRequestChannel() {
return new TestChannel();
}
@Override
@Bean
public SubscribableChannel webSocketResponseChannel() {
return Mockito.mock(SubscribableChannel.class);
public AbstractSubscribableChannel webSocketResponseChannel() {
return new TestChannel();
}
@Override
public SubscribableChannel brokerChannel() {
return Mockito.mock(SubscribableChannel.class);
public AbstractSubscribableChannel brokerChannel() {
return new TestChannel();
}
}
private static class TestChannel extends ExecutorSubscribableChannel {
private final List<MessageHandler> handlers = new ArrayList<>();
private final List<Message<?>> messages = new ArrayList<>();
@Override
public boolean subscribeInternal(MessageHandler handler) {
this.handlers.add(handler);
return super.subscribeInternal(handler);
}
@Override
public boolean sendInternal(Message<?> message, long timeout) {
this.messages.add(message);
return true;
}
}

View File

@ -34,7 +34,6 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.ComponentScan;
import org.springframework.context.annotation.Configuration;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.messaging.handler.annotation.MessageExceptionHandler;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.simp.config.DelegatingWebSocketMessageBrokerConfiguration;
@ -42,6 +41,7 @@ import org.springframework.messaging.simp.config.MessageBrokerConfigurer;
import org.springframework.messaging.simp.config.StompEndpointRegistry;
import org.springframework.messaging.simp.config.WebSocketMessageBrokerConfigurer;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.support.channel.AbstractSubscribableChannel;
import org.springframework.messaging.support.channel.ExecutorSubscribableChannel;
import org.springframework.stereotype.Controller;
import org.springframework.web.socket.AbstractWebSocketIntegrationTests;
@ -227,13 +227,13 @@ public class SimpAnnotationMethodIntegrationTests extends AbstractWebSocketInteg
@Override
@Bean
public SubscribableChannel webSocketRequestChannel() {
public AbstractSubscribableChannel webSocketRequestChannel() {
return new ExecutorSubscribableChannel(); // synchronous
}
@Override
@Bean
public SubscribableChannel webSocketResponseChannel() {
public AbstractSubscribableChannel webSocketResponseChannel() {
return new ExecutorSubscribableChannel(); // synchronous
}
}