diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationSupport.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationSupport.java index 30e376fda38..696a19bf7c6 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationSupport.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationSupport.java @@ -16,24 +16,15 @@ package org.springframework.messaging.simp.config; -import java.util.ArrayList; -import java.util.List; - import org.springframework.context.annotation.Bean; import org.springframework.messaging.Message; -import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.handler.websocket.SubProtocolWebSocketHandler; import org.springframework.messaging.simp.SimpMessageSendingOperations; import org.springframework.messaging.simp.SimpMessagingTemplate; import org.springframework.messaging.simp.handler.*; -import org.springframework.messaging.simp.handler.SimpAnnotationMethodMessageHandler; +import org.springframework.messaging.support.channel.AbstractSubscribableChannel; import org.springframework.messaging.support.channel.ExecutorSubscribableChannel; -import org.springframework.messaging.support.converter.ByteArrayMessageConverter; -import org.springframework.messaging.support.converter.CompositeMessageConverter; -import org.springframework.messaging.support.converter.DefaultContentTypeResolver; -import org.springframework.messaging.support.converter.MappingJackson2MessageConverter; -import org.springframework.messaging.support.converter.MessageConverter; -import org.springframework.messaging.support.converter.StringMessageConverter; +import org.springframework.messaging.support.converter.*; import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; import org.springframework.util.ClassUtils; @@ -44,6 +35,9 @@ import org.springframework.web.servlet.handler.AbstractHandlerMapping; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.server.config.SockJsServiceRegistration; +import java.util.ArrayList; +import java.util.List; + /** * Configuration support for broker-backed messaging over WebSocket using a higher-level @@ -118,12 +112,12 @@ public abstract class WebSocketMessageBrokerConfigurationSupport { } @Bean - public SubscribableChannel webSocketRequestChannel() { + public AbstractSubscribableChannel webSocketRequestChannel() { return new ExecutorSubscribableChannel(webSocketChannelExecutor()); } @Bean - public SubscribableChannel webSocketResponseChannel() { + public AbstractSubscribableChannel webSocketResponseChannel() { return new ExecutorSubscribableChannel(webSocketChannelExecutor()); } @@ -209,7 +203,7 @@ public abstract class WebSocketMessageBrokerConfigurationSupport { } @Bean - public SubscribableChannel brokerChannel() { + public AbstractSubscribableChannel brokerChannel() { return new ExecutorSubscribableChannel(); // synchronous } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationSupportTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationSupportTests.java index bf01420b390..8103cc36d95 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationSupportTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationSupportTests.java @@ -16,19 +16,13 @@ package org.springframework.messaging.simp.config; -import java.util.List; -import java.util.Map; - import org.junit.Before; import org.junit.Test; -import org.mockito.ArgumentCaptor; -import org.mockito.Mockito; import org.springframework.context.annotation.AnnotationConfigApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.messaging.Message; import org.springframework.messaging.MessageHandler; -import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.handler.annotation.MessageMapping; import org.springframework.messaging.handler.annotation.SendTo; import org.springframework.messaging.handler.websocket.SubProtocolWebSocketHandler; @@ -43,6 +37,8 @@ import org.springframework.messaging.simp.stomp.StompCommand; import org.springframework.messaging.simp.stomp.StompHeaderAccessor; import org.springframework.messaging.simp.stomp.StompTextMessageBuilder; import org.springframework.messaging.support.MessageBuilder; +import org.springframework.messaging.support.channel.AbstractSubscribableChannel; +import org.springframework.messaging.support.channel.ExecutorSubscribableChannel; import org.springframework.messaging.support.converter.CompositeMessageConverter; import org.springframework.messaging.support.converter.DefaultContentTypeResolver; import org.springframework.stereotype.Controller; @@ -52,9 +48,11 @@ import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.support.TestWebSocketSession; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + import static org.junit.Assert.*; -import static org.mockito.Matchers.*; -import static org.mockito.Mockito.*; /** @@ -95,27 +93,20 @@ public class WebSocketMessageBrokerConfigurationSupportTests { @Test public void webSocketRequestChannel() { - SubscribableChannel channel = this.cxtSimpleBroker.getBean("webSocketRequestChannel", SubscribableChannel.class); + TestChannel channel = this.cxtSimpleBroker.getBean("webSocketRequestChannel", TestChannel.class); + List handlers = channel.handlers; - ArgumentCaptor captor = ArgumentCaptor.forClass(MessageHandler.class); - verify(channel, times(3)).subscribe(captor.capture()); - - List values = captor.getAllValues(); - assertEquals(3, values.size()); - - assertTrue(values.contains(cxtSimpleBroker.getBean(SimpAnnotationMethodMessageHandler.class))); - assertTrue(values.contains(cxtSimpleBroker.getBean(UserDestinationMessageHandler.class))); - assertTrue(values.contains(cxtSimpleBroker.getBean(SimpleBrokerMessageHandler.class))); + assertEquals(3, handlers.size()); + assertTrue(handlers.contains(cxtSimpleBroker.getBean(SimpAnnotationMethodMessageHandler.class))); + assertTrue(handlers.contains(cxtSimpleBroker.getBean(UserDestinationMessageHandler.class))); + assertTrue(handlers.contains(cxtSimpleBroker.getBean(SimpleBrokerMessageHandler.class))); } @Test public void webSocketRequestChannelWithStompBroker() { - SubscribableChannel channel = this.cxtStompBroker.getBean("webSocketRequestChannel", SubscribableChannel.class); + TestChannel channel = this.cxtStompBroker.getBean("webSocketRequestChannel", TestChannel.class); + List values = channel.handlers; - ArgumentCaptor captor = ArgumentCaptor.forClass(MessageHandler.class); - verify(channel, times(3)).subscribe(captor.capture()); - - List values = captor.getAllValues(); assertEquals(3, values.size()); assertTrue(values.contains(cxtStompBroker.getBean(SimpAnnotationMethodMessageHandler.class))); assertTrue(values.contains(cxtStompBroker.getBean(UserDestinationMessageHandler.class))); @@ -125,16 +116,13 @@ public class WebSocketMessageBrokerConfigurationSupportTests { @Test public void webSocketRequestChannelSendMessage() throws Exception { - SubscribableChannel channel = this.cxtSimpleBroker.getBean("webSocketRequestChannel", SubscribableChannel.class); + TestChannel channel = this.cxtSimpleBroker.getBean("webSocketRequestChannel", TestChannel.class); SubProtocolWebSocketHandler webSocketHandler = this.cxtSimpleBroker.getBean(SubProtocolWebSocketHandler.class); TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.SEND).headers("destination:/foo").build(); webSocketHandler.handleMessage(new TestWebSocketSession(), textMessage); - ArgumentCaptor captor = ArgumentCaptor.forClass(Message.class); - verify(channel).send(captor.capture()); - - Message message = captor.getValue(); + Message message = channel.messages.get(0); StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); assertEquals(SimpMessageType.MESSAGE, headers.getMessageType()); @@ -143,15 +131,17 @@ public class WebSocketMessageBrokerConfigurationSupportTests { @Test public void webSocketResponseChannel() { - SubscribableChannel channel = this.cxtSimpleBroker.getBean("webSocketResponseChannel", SubscribableChannel.class); - verify(channel).subscribe(any(SubProtocolWebSocketHandler.class)); - verifyNoMoreInteractions(channel); + TestChannel channel = this.cxtSimpleBroker.getBean("webSocketResponseChannel", TestChannel.class); + List values = channel.handlers; + + assertEquals(1, values.size()); + assertTrue(values.get(0) instanceof SubProtocolWebSocketHandler); } @Test public void webSocketResponseChannelUsedByAnnotatedMethod() { - SubscribableChannel channel = this.cxtSimpleBroker.getBean("webSocketResponseChannel", SubscribableChannel.class); + TestChannel channel = this.cxtSimpleBroker.getBean("webSocketResponseChannel", TestChannel.class); SimpAnnotationMethodMessageHandler messageHandler = this.cxtSimpleBroker.getBean(SimpAnnotationMethodMessageHandler.class); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SUBSCRIBE); @@ -160,12 +150,9 @@ public class WebSocketMessageBrokerConfigurationSupportTests { headers.setDestination("/foo"); Message message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); - when(channel.send(any(Message.class))).thenReturn(true); messageHandler.handleMessage(message); - ArgumentCaptor captor = ArgumentCaptor.forClass(Message.class); - verify(channel).send(captor.capture()); - message = captor.getValue(); + message = channel.messages.get(0); headers = StompHeaderAccessor.wrap(message); assertEquals(SimpMessageType.MESSAGE, headers.getMessageType()); @@ -175,7 +162,7 @@ public class WebSocketMessageBrokerConfigurationSupportTests { @Test public void webSocketResponseChannelUsedBySimpleBroker() { - SubscribableChannel channel = this.cxtSimpleBroker.getBean("webSocketResponseChannel", SubscribableChannel.class); + TestChannel channel = this.cxtSimpleBroker.getBean("webSocketResponseChannel", TestChannel.class); SimpleBrokerMessageHandler broker = this.cxtSimpleBroker.getBean(SimpleBrokerMessageHandler.class); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SUBSCRIBE); @@ -193,12 +180,9 @@ public class WebSocketMessageBrokerConfigurationSupportTests { message = MessageBuilder.withPayload("bar".getBytes()).setHeaders(headers).build(); // message - when(channel.send(any(Message.class))).thenReturn(true); broker.handleMessage(message); - ArgumentCaptor captor = ArgumentCaptor.forClass(Message.class); - verify(channel).send(captor.capture()); - message = captor.getValue(); + message = channel.messages.get(0); headers = StompHeaderAccessor.wrap(message); assertEquals(SimpMessageType.MESSAGE, headers.getMessageType()); @@ -208,45 +192,36 @@ public class WebSocketMessageBrokerConfigurationSupportTests { @Test public void brokerChannel() { - SubscribableChannel channel = this.cxtSimpleBroker.getBean("brokerChannel", SubscribableChannel.class); + TestChannel channel = this.cxtSimpleBroker.getBean("brokerChannel", TestChannel.class); + List handlers = channel.handlers; - ArgumentCaptor captor = ArgumentCaptor.forClass(MessageHandler.class); - verify(channel, times(2)).subscribe(captor.capture()); - - List values = captor.getAllValues(); - assertEquals(2, values.size()); - assertTrue(values.contains(cxtSimpleBroker.getBean(UserDestinationMessageHandler.class))); - assertTrue(values.contains(cxtSimpleBroker.getBean(SimpleBrokerMessageHandler.class))); + assertEquals(2, handlers.size()); + assertTrue(handlers.contains(cxtSimpleBroker.getBean(UserDestinationMessageHandler.class))); + assertTrue(handlers.contains(cxtSimpleBroker.getBean(SimpleBrokerMessageHandler.class))); } @Test public void brokerChannelWithStompBroker() { - SubscribableChannel channel = this.cxtStompBroker.getBean("brokerChannel", SubscribableChannel.class); + TestChannel channel = this.cxtStompBroker.getBean("brokerChannel", TestChannel.class); + List handlers = channel.handlers; - ArgumentCaptor captor = ArgumentCaptor.forClass(MessageHandler.class); - verify(channel, times(2)).subscribe(captor.capture()); - - List values = captor.getAllValues(); - assertEquals(2, values.size()); - assertTrue(values.contains(cxtStompBroker.getBean(UserDestinationMessageHandler.class))); - assertTrue(values.contains(cxtStompBroker.getBean(StompBrokerRelayMessageHandler.class))); + assertEquals(2, handlers.size()); + assertTrue(handlers.contains(cxtStompBroker.getBean(UserDestinationMessageHandler.class))); + assertTrue(handlers.contains(cxtStompBroker.getBean(StompBrokerRelayMessageHandler.class))); } @Test public void brokerChannelUsedByAnnotatedMethod() { - SubscribableChannel channel = this.cxtSimpleBroker.getBean("brokerChannel", SubscribableChannel.class); + TestChannel channel = this.cxtSimpleBroker.getBean("brokerChannel", TestChannel.class); SimpAnnotationMethodMessageHandler messageHandler = this.cxtSimpleBroker.getBean(SimpAnnotationMethodMessageHandler.class); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND); headers.setDestination("/foo"); Message message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); - when(channel.send(any(Message.class))).thenReturn(true); messageHandler.handleMessage(message); - ArgumentCaptor captor = ArgumentCaptor.forClass(Message.class); - verify(channel).send(captor.capture()); - message = captor.getValue(); + message = channel.messages.get(0); headers = StompHeaderAccessor.wrap(message); assertEquals(SimpMessageType.MESSAGE, headers.getMessageType()); @@ -256,7 +231,7 @@ public class WebSocketMessageBrokerConfigurationSupportTests { @Test public void brokerChannelUsedByUserDestinationMessageHandler() { - SubscribableChannel channel = this.cxtSimpleBroker.getBean("brokerChannel", SubscribableChannel.class); + TestChannel channel = this.cxtSimpleBroker.getBean("brokerChannel", TestChannel.class); UserDestinationMessageHandler messageHandler = this.cxtSimpleBroker.getBean(UserDestinationMessageHandler.class); this.cxtSimpleBroker.getBean(UserSessionRegistry.class).registerSessionId("joe", "s1"); @@ -265,12 +240,9 @@ public class WebSocketMessageBrokerConfigurationSupportTests { headers.setDestination("/user/joe/foo"); Message message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); - when(channel.send(any(Message.class))).thenReturn(true); messageHandler.handleMessage(message); - ArgumentCaptor captor = ArgumentCaptor.forClass(Message.class); - verify(channel).send(captor.capture()); - message = captor.getValue(); + message = channel.messages.get(0); headers = StompHeaderAccessor.wrap(message); assertEquals(SimpMessageType.MESSAGE, headers.getMessageType()); @@ -340,19 +312,39 @@ public class WebSocketMessageBrokerConfigurationSupportTests { @Override @Bean - public SubscribableChannel webSocketRequestChannel() { - return Mockito.mock(SubscribableChannel.class); + public AbstractSubscribableChannel webSocketRequestChannel() { + return new TestChannel(); } @Override @Bean - public SubscribableChannel webSocketResponseChannel() { - return Mockito.mock(SubscribableChannel.class); + public AbstractSubscribableChannel webSocketResponseChannel() { + return new TestChannel(); } @Override - public SubscribableChannel brokerChannel() { - return Mockito.mock(SubscribableChannel.class); + public AbstractSubscribableChannel brokerChannel() { + return new TestChannel(); + } + } + + private static class TestChannel extends ExecutorSubscribableChannel { + + private final List handlers = new ArrayList<>(); + + private final List> messages = new ArrayList<>(); + + + @Override + public boolean subscribeInternal(MessageHandler handler) { + this.handlers.add(handler); + return super.subscribeInternal(handler); + } + + @Override + public boolean sendInternal(Message message, long timeout) { + this.messages.add(message); + return true; } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/SimpAnnotationMethodIntegrationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/SimpAnnotationMethodIntegrationTests.java index b993a64e5f6..96b95ae3e6b 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/SimpAnnotationMethodIntegrationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/SimpAnnotationMethodIntegrationTests.java @@ -34,7 +34,6 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.ComponentScan; import org.springframework.context.annotation.Configuration; -import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.handler.annotation.MessageExceptionHandler; import org.springframework.messaging.handler.annotation.MessageMapping; import org.springframework.messaging.simp.config.DelegatingWebSocketMessageBrokerConfiguration; @@ -42,6 +41,7 @@ import org.springframework.messaging.simp.config.MessageBrokerConfigurer; import org.springframework.messaging.simp.config.StompEndpointRegistry; import org.springframework.messaging.simp.config.WebSocketMessageBrokerConfigurer; import org.springframework.messaging.simp.stomp.StompCommand; +import org.springframework.messaging.support.channel.AbstractSubscribableChannel; import org.springframework.messaging.support.channel.ExecutorSubscribableChannel; import org.springframework.stereotype.Controller; import org.springframework.web.socket.AbstractWebSocketIntegrationTests; @@ -227,13 +227,13 @@ public class SimpAnnotationMethodIntegrationTests extends AbstractWebSocketInteg @Override @Bean - public SubscribableChannel webSocketRequestChannel() { + public AbstractSubscribableChannel webSocketRequestChannel() { return new ExecutorSubscribableChannel(); // synchronous } @Override @Bean - public SubscribableChannel webSocketResponseChannel() { + public AbstractSubscribableChannel webSocketResponseChannel() { return new ExecutorSubscribableChannel(); // synchronous } }