diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpAttributes.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpAttributes.java new file mode 100644 index 00000000000..8a9a14e272b --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpAttributes.java @@ -0,0 +1,209 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.messaging.simp; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageHeaders; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +import java.util.Map; + +/** + * A wrapper class for access to attributes associated with a SiMP session + * (e.g. WebSocket session). + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public class SimpAttributes { + + private static Log logger = LogFactory.getLog(SimpAttributes.class); + + private static final String className = SimpAttributes.class.getName(); + + /** Key for the mutex session attribute */ + public static final String SESSION_MUTEX_NAME = className + ".MUTEX"; + + /** Key set after the session is completed */ + public static final String SESSION_COMPLETED_NAME = className + ".COMPLETED"; + + /** Prefix for the name of session attributes used to store destruction callbacks. */ + public static final String DESTRUCTION_CALLBACK_NAME_PREFIX = className + ".DESTRUCTION_CALLBACK."; + + + private final String sessionId; + + private final Map attributes; + + + /** + * Constructor wrapping the given session attributes map. + * + * @param sessionId the id of the associated session + * @param attributes the attributes + */ + public SimpAttributes(String sessionId, Map attributes) { + Assert.notNull(sessionId, "'sessionId' is required"); + Assert.notNull(attributes, "'attributes' is required"); + this.sessionId = sessionId; + this.attributes = attributes; + } + + + /** + * Extract the SiMP session attributes from the given message, wrap them in + * a {@link SimpAttributes} instance. + * + * @param message the message to extract session attributes from + */ + public static SimpAttributes fromMessage(Message message) { + Assert.notNull(message); + MessageHeaders headers = message.getHeaders(); + String sessionId = SimpMessageHeaderAccessor.getSessionId(headers); + Map sessionAttributes = SimpMessageHeaderAccessor.getSessionAttributes(headers); + if (sessionId == null || sessionAttributes == null) { + throw new IllegalStateException( + "Message does not contain SiMP session id or attributes: " + message); + } + return new SimpAttributes(sessionId, sessionAttributes); + } + + /** + * Return the value for the attribute of the given name, if any. + * + * @param name the name of the attribute + * @return the current attribute value, or {@code null} if not found + */ + public Object getAttribute(String name) { + return this.attributes.get(name); + } + + /** + * Set the value with the given name replacing an existing value (if any). + * + * @param name the name of the attribute + * @param value the value for the attribute + */ + public void setAttribute(String name, Object value) { + this.attributes.put(name, value); + } + + /** + * Remove the attribute of the given name, if it exists. + * + *

Also removes the registered destruction callback for the specified + * attribute, if any. However it does not execute the callback. + * It is assumed the removed object will continue to be used and destroyed + * independently at the appropriate time. + * + * @param name the name of the attribute + */ + public void removeAttribute(String name) { + this.attributes.remove(name); + removeDestructionCallback(name); + } + + /** + * Retrieve the names of all attributes. + * + * @return the attribute names as String array, never {@code null} + */ + public String[] getAttributeNames() { + return StringUtils.toStringArray(this.attributes.keySet()); + } + + /** + * Register a callback to execute on destruction of the specified attribute. + * The callback is executed when the session is closed. + * + * @param name the name of the attribute to register the callback for + * @param callback the destruction callback to be executed + */ + public void registerDestructionCallback(String name, Runnable callback) { + synchronized (getSessionMutex()) { + if (isSessionCompleted()) { + throw new IllegalStateException("Session id=" + getSessionId() + " already completed"); + } + this.attributes.put(DESTRUCTION_CALLBACK_NAME_PREFIX + name, callback); + } + } + + private void removeDestructionCallback(String name) { + synchronized (getSessionMutex()) { + this.attributes.remove(DESTRUCTION_CALLBACK_NAME_PREFIX + name); + } + } + + /** + * Return an id for the associated session. + * + * @return the session id as String (never {@code null}) + */ + public String getSessionId() { + return this.sessionId; + } + + /** + * Expose the object to synchronize on for the underlying session. + * + * @return the session mutex to use (never {@code null}) + */ + public Object getSessionMutex() { + Object mutex = this.attributes.get(SESSION_MUTEX_NAME); + if (mutex == null) { + mutex = this.attributes; + } + return mutex; + } + + /** + * Whether the {@link #sessionCompleted()} was already invoked. + */ + public boolean isSessionCompleted() { + return (this.attributes.get(SESSION_COMPLETED_NAME) != null); + } + + /** + * Invoked when the session is completed. Executed completion callbacks. + */ + public void sessionCompleted() { + synchronized (getSessionMutex()) { + if (!isSessionCompleted()) { + executeDestructionCallbacks(); + this.attributes.put(SESSION_COMPLETED_NAME, Boolean.TRUE); + } + } + } + + private void executeDestructionCallbacks() { + for (Map.Entry entry : this.attributes.entrySet()) { + if (entry.getKey().startsWith(DESTRUCTION_CALLBACK_NAME_PREFIX)) { + try { + ((Runnable) entry.getValue()).run(); + } + catch (Throwable t) { + if (logger.isErrorEnabled()) { + logger.error("Uncaught error in session attribute destruction callback", t); + } + } + } + } + } + +} diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpAttributesContextHolder.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpAttributesContextHolder.java new file mode 100644 index 00000000000..659f1143cab --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpAttributesContextHolder.java @@ -0,0 +1,92 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.messaging.simp; + +import org.springframework.core.NamedThreadLocal; +import org.springframework.messaging.Message; + + +/** + * Holder class to expose SiMP attributes associated with a session (e.g. WebSocket) + * in the form of a thread-bound {@link SimpAttributes} object. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public abstract class SimpAttributesContextHolder { + + private static final ThreadLocal attributesHolder = + new NamedThreadLocal("SiMP session attributes"); + + + /** + * Reset the SimpAttributes for the current thread. + */ + public static void resetAttributes() { + attributesHolder.remove(); + } + + /** + * Bind the given SimpAttributes to the current thread, + * + * @param attributes the RequestAttributes to expose + */ + public static void setAttributes(SimpAttributes attributes) { + if (attributes != null) { + attributesHolder.set(attributes); + } + else { + resetAttributes(); + } + } + + /** + * Extract the SiMP session attributes from the given message, wrap them in + * a {@link SimpAttributes} instance and bind it to the current thread, + * + * @param message the message to extract session attributes from + */ + public static void setAttributesFromMessage(Message message) { + setAttributes(SimpAttributes.fromMessage(message)); + } + + /** + * Return the SimpAttributes currently bound to the thread. + * + * @return the attributes or {@code null} if not bound + */ + public static SimpAttributes getAttributes() { + return attributesHolder.get(); + } + + /** + * Return the SimpAttributes currently bound to the thread or raise an + * {@link java.lang.IllegalStateException} if none are bound.. + * + * @return the attributes, never {@code null} + * @throws java.lang.IllegalStateException if attributes are not bound + */ + public static SimpAttributes currentAttributes() throws IllegalStateException { + SimpAttributes attributes = getAttributes(); + if (attributes == null) { + throw new IllegalStateException("No thread-bound SimpAttributes found. " + + "Your code is probably not processing a client message and executing in " + + "message-handling methods invoked by the SimpAnnotationMethodMessageHandler?"); + } + return attributes; + } + +} diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpSessionScope.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpSessionScope.java new file mode 100644 index 00000000000..7c919317c0b --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpSessionScope.java @@ -0,0 +1,81 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.simp; + +import org.springframework.beans.factory.ObjectFactory; +import org.springframework.beans.factory.config.Scope; + +/** + * A {@link Scope} implementation exposing the attributes of a SiMP session + * (e.g. WebSocket session). + * + *

Relies on a thread-bound {@link SimpAttributes} instance exported by + * {@link org.springframework.messaging.simp.annotation.support.SimpAnnotationMethodMessageHandler + * SimpAnnotationMethodMessageHandler}. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public class SimpSessionScope implements Scope { + + + @Override + public Object get(String name, ObjectFactory objectFactory) { + SimpAttributes simpAttributes = SimpAttributesContextHolder.currentAttributes(); + Object value = simpAttributes.getAttribute(name); + if (value != null) { + return value; + } + synchronized (simpAttributes.getSessionMutex()) { + value = simpAttributes.getAttribute(name); + if (value == null) { + value = objectFactory.getObject(); + simpAttributes.setAttribute(name, value); + } + return value; + } + } + + @Override + public Object remove(String name) { + SimpAttributes simpAttributes = SimpAttributesContextHolder.currentAttributes(); + synchronized (simpAttributes.getSessionMutex()) { + Object value = simpAttributes.getAttribute(name); + if (value != null) { + simpAttributes.removeAttribute(name); + return value; + } else { + return null; + } + } + } + + @Override + public void registerDestructionCallback(String name, Runnable callback) { + SimpAttributesContextHolder.currentAttributes().registerDestructionCallback(name, callback); + } + + @Override + public Object resolveContextualObject(String key) { + return null; + } + + @Override + public String getConversationId() { + return SimpAttributesContextHolder.currentAttributes().getSessionId(); + } +} diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SimpAnnotationMethodMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SimpAnnotationMethodMessageHandler.java index c73c94dd1f4..f3194c917f2 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SimpAnnotationMethodMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SimpAnnotationMethodMessageHandler.java @@ -52,6 +52,8 @@ import org.springframework.messaging.handler.invocation.AbstractExceptionHandler import org.springframework.messaging.handler.invocation.AbstractMethodMessageHandler; import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver; import org.springframework.messaging.handler.invocation.HandlerMethodReturnValueHandler; +import org.springframework.messaging.simp.SimpAttributes; +import org.springframework.messaging.simp.SimpAttributesContextHolder; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageMappingInfo; import org.springframework.messaging.simp.SimpMessageSendingOperations; @@ -390,7 +392,13 @@ public class SimpAnnotationMethodMessageHandler extends AbstractMethodMessageHan accessor.setHeader(DestinationVariableMethodArgumentResolver.DESTINATION_TEMPLATE_VARIABLES_HEADER, vars); } - super.handleMatch(mapping, handlerMethod, lookupDestination, message); + try { + SimpAttributesContextHolder.setAttributesFromMessage(message); + super.handleMatch(mapping, handlerMethod, lookupDestination, message); + } + finally { + SimpAttributesContextHolder.resetAttributes(); + } } @Override diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/SimpAttributesContextHolderTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/SimpAttributesContextHolderTests.java new file mode 100644 index 00000000000..689a5790895 --- /dev/null +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/SimpAttributesContextHolderTests.java @@ -0,0 +1,132 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.simp; + +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.springframework.messaging.Message; +import org.springframework.messaging.support.GenericMessage; +import org.springframework.messaging.support.MessageBuilder; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import static org.hamcrest.MatcherAssert.*; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; +import static org.hamcrest.Matchers.sameInstance; +import static org.hamcrest.Matchers.startsWith; + +/** + * Unit tests for + * {@link org.springframework.messaging.simp.SimpAttributesContextHolder}. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public class SimpAttributesContextHolderTests { + + private SimpAttributes simpAttributes; + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + + @Before + public void setUp() { + Map map = new ConcurrentHashMap<>(); + this.simpAttributes = new SimpAttributes("session1", map); + } + + @After + public void tearDown() { + SimpAttributesContextHolder.resetAttributes(); + } + + + @Test + public void resetAttributes() { + SimpAttributesContextHolder.setAttributes(this.simpAttributes); + assertThat(SimpAttributesContextHolder.getAttributes(), sameInstance(this.simpAttributes)); + + SimpAttributesContextHolder.resetAttributes(); + assertThat(SimpAttributesContextHolder.getAttributes(), nullValue()); + } + + @Test + public void getAttributes() { + assertThat(SimpAttributesContextHolder.getAttributes(), nullValue()); + + SimpAttributesContextHolder.setAttributes(this.simpAttributes); + assertThat(SimpAttributesContextHolder.getAttributes(), sameInstance(this.simpAttributes)); + } + + @Test + public void setAttributes() { + SimpAttributesContextHolder.setAttributes(this.simpAttributes); + assertThat(SimpAttributesContextHolder.getAttributes(), sameInstance(this.simpAttributes)); + + SimpAttributesContextHolder.setAttributes(null); + assertThat(SimpAttributesContextHolder.getAttributes(), nullValue()); + } + + @Test + public void setAttributesFromMessage() { + + String sessionId = "session1"; + ConcurrentHashMap map = new ConcurrentHashMap<>(); + + SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.create(); + headerAccessor.setSessionId(sessionId); + headerAccessor.setSessionAttributes(map); + Message message = MessageBuilder.createMessage("", headerAccessor.getMessageHeaders()); + + SimpAttributesContextHolder.setAttributesFromMessage(message); + + SimpAttributes attrs = SimpAttributesContextHolder.getAttributes(); + assertThat(attrs, notNullValue()); + assertThat(attrs.getSessionId(), is(sessionId)); + + attrs.setAttribute("name1", "value1"); + assertThat(map.get("name1"), is("value1")); + } + + @Test + public void setAttributesFromMessageWithMissingHeaders() { + this.thrown.expect(IllegalStateException.class); + this.thrown.expectMessage(startsWith("Message does not contain SiMP session id or attributes")); + SimpAttributesContextHolder.setAttributesFromMessage(new GenericMessage("")); + } + + @Test + public void currentAttributes() { + SimpAttributesContextHolder.setAttributes(this.simpAttributes); + assertThat(SimpAttributesContextHolder.currentAttributes(), sameInstance(this.simpAttributes)); + } + + @Test + public void currentAttributesNone() { + this.thrown.expect(IllegalStateException.class); + this.thrown.expectMessage(startsWith("No thread-bound SimpAttributes found")); + SimpAttributesContextHolder.currentAttributes(); + } + +} \ No newline at end of file diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/SimpAttributesTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/SimpAttributesTests.java new file mode 100644 index 00000000000..48d20c9c7fa --- /dev/null +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/SimpAttributesTests.java @@ -0,0 +1,138 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.simp; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.mockito.Mockito; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import static org.mockito.Mockito.*; +import static org.junit.Assert.assertThat; +import static org.hamcrest.Matchers.*; + +/** + * Unit tests for + * {@link org.springframework.messaging.simp.SimpAttributes}. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public class SimpAttributesTests { + + private SimpAttributes simpAttributes; + + private Map map; + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + + @Before + public void setup() { + this.map = new ConcurrentHashMap<>(); + this.simpAttributes = new SimpAttributes("session1", this.map); + } + + + @Test + public void getAttribute() { + this.simpAttributes.setAttribute("name1", "value1"); + + assertThat(this.simpAttributes.getAttribute("name1"), is("value1")); + assertThat(this.simpAttributes.getAttribute("name2"), nullValue()); + } + + @Test + public void getAttributeNames() { + this.simpAttributes.setAttribute("name1", "value1"); + this.simpAttributes.setAttribute("name2", "value1"); + this.simpAttributes.setAttribute("name3", "value1"); + + assertThat(this.simpAttributes.getAttributeNames(), arrayContainingInAnyOrder("name1", "name2", "name3")); + } + + @Test + public void registerDestructionCallback() { + Runnable callback = Mockito.mock(Runnable.class); + this.simpAttributes.registerDestructionCallback("name1", callback); + + assertThat(this.simpAttributes.getAttribute( + SimpAttributes.DESTRUCTION_CALLBACK_NAME_PREFIX + "name1"), sameInstance(callback)); + } + + @Test + public void registerDestructionCallbackAfterSessionCompleted() { + this.simpAttributes.sessionCompleted(); + this.thrown.expect(IllegalStateException.class); + this.thrown.expectMessage(containsString("already completed")); + this.simpAttributes.registerDestructionCallback("name1", Mockito.mock(Runnable.class)); + } + + @Test + public void removeDestructionCallback() { + Runnable callback1 = Mockito.mock(Runnable.class); + Runnable callback2 = Mockito.mock(Runnable.class); + this.simpAttributes.registerDestructionCallback("name1", callback1); + this.simpAttributes.registerDestructionCallback("name2", callback2); + + assertThat(this.simpAttributes.getAttributeNames().length, is(2)); + } + + @Test + public void getSessionMutex() { + assertThat(this.simpAttributes.getSessionMutex(), sameInstance(this.map)); + } + + @Test + public void getSessionMutexExplicit() { + Object mutex = new Object(); + this.simpAttributes.setAttribute(SimpAttributes.SESSION_MUTEX_NAME, mutex); + + assertThat(this.simpAttributes.getSessionMutex(), sameInstance(mutex)); + } + + @Test + public void sessionCompleted() { + Runnable callback1 = Mockito.mock(Runnable.class); + Runnable callback2 = Mockito.mock(Runnable.class); + this.simpAttributes.registerDestructionCallback("name1", callback1); + this.simpAttributes.registerDestructionCallback("name2", callback2); + + this.simpAttributes.sessionCompleted(); + + verify(callback1, times(1)).run(); + verify(callback2, times(1)).run(); + } + + @Test + public void sessionCompletedIsIdempotent() { + Runnable callback1 = Mockito.mock(Runnable.class); + this.simpAttributes.registerDestructionCallback("name1", callback1); + + this.simpAttributes.sessionCompleted(); + this.simpAttributes.sessionCompleted(); + this.simpAttributes.sessionCompleted(); + + verify(callback1, times(1)).run(); + } + +} \ No newline at end of file diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/SimpSessionScopeTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/SimpSessionScopeTests.java new file mode 100644 index 00000000000..6644c1033bb --- /dev/null +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/SimpSessionScopeTests.java @@ -0,0 +1,105 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.simp; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mockito; +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.ObjectFactory; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import static org.junit.Assert.assertThat; +import static org.hamcrest.Matchers.*; +import static org.mockito.Mockito.*; + +/** + * Unit tests for {@link org.springframework.messaging.simp.SimpSessionScope}. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public class SimpSessionScopeTests { + + private SimpSessionScope scope; + + private ObjectFactory objectFactory; + + private SimpAttributes simpAttributes; + + + @Before + public void setUp() { + this.scope = new SimpSessionScope(); + this.objectFactory = Mockito.mock(ObjectFactory.class); + this.simpAttributes = new SimpAttributes("session1", new ConcurrentHashMap<>()); + SimpAttributesContextHolder.setAttributes(this.simpAttributes); + } + + @After + public void tearDown() { + SimpAttributesContextHolder.resetAttributes(); + } + + @Test + public void get() { + this.simpAttributes.setAttribute("name", "value"); + Object actual = this.scope.get("name", this.objectFactory); + + assertThat(actual, is("value")); + } + + @Test + public void getWithObjectFactory() { + when(this.objectFactory.getObject()).thenReturn("value"); + Object actual = this.scope.get("name", this.objectFactory); + + assertThat(actual, is("value")); + assertThat(this.simpAttributes.getAttribute("name"), is("value")); + } + + @Test + public void remove() { + this.simpAttributes.setAttribute("name", "value"); + + Object removed = this.scope.remove("name"); + assertThat(removed, is("value")); + assertThat(this.simpAttributes.getAttribute("name"), nullValue()); + + removed = this.scope.remove("name"); + assertThat(removed, nullValue()); + } + + @Test + public void registerDestructionCallback() { + Runnable runnable = Mockito.mock(Runnable.class); + this.scope.registerDestructionCallback("name", runnable); + + this.simpAttributes.sessionCompleted(); + verify(runnable, times(1)).run(); + } + + @Test + public void getSessionId() { + assertThat(this.scope.getConversationId(), is("session1")); + } + + +} diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SimpAnnotationMethodMessageHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SimpAnnotationMethodMessageHandlerTests.java index 559892273ca..7aca52c7bd4 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SimpAnnotationMethodMessageHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SimpAnnotationMethodMessageHandlerTests.java @@ -18,6 +18,7 @@ package org.springframework.messaging.simp.annotation.support; import java.util.LinkedHashMap; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import org.junit.Before; import org.junit.Test; @@ -28,6 +29,8 @@ import org.springframework.messaging.MessageChannel; import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.handler.annotation.*; import org.springframework.messaging.handler.annotation.support.MethodArgumentNotValidException; +import org.springframework.messaging.simp.SimpAttributes; +import org.springframework.messaging.simp.SimpAttributesContextHolder; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageSendingOperations; import org.springframework.messaging.simp.SimpMessageType; @@ -39,6 +42,7 @@ import org.springframework.validation.Errors; import org.springframework.validation.Validator; import org.springframework.validation.annotation.Validated; +import static org.hamcrest.Matchers.is; import static org.junit.Assert.*; /** @@ -69,7 +73,7 @@ public class SimpAnnotationMethodMessageHandlerTests { this.messageHandler.afterPropertiesSet(); testController = new TestController(); - this.messageHandler.registerHandler(testController); + this.messageHandler.registerHandler(this.testController); } @@ -77,6 +81,8 @@ public class SimpAnnotationMethodMessageHandlerTests { @Test public void headerArgumentResolution() { SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(); + headers.setSessionId("session1"); + headers.setSessionAttributes(new ConcurrentHashMap<>()); headers.setDestination("/pre/headers"); headers.setHeader("foo", "bar"); Message message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); @@ -90,6 +96,8 @@ public class SimpAnnotationMethodMessageHandlerTests { @Test public void messageMappingDestinationVariableResolution() { SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(); + headers.setSessionId("session1"); + headers.setSessionAttributes(new ConcurrentHashMap<>()); headers.setDestination("/pre/message/bar/value"); Message message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); this.messageHandler.handleMessage(message); @@ -102,6 +110,8 @@ public class SimpAnnotationMethodMessageHandlerTests { @Test public void subscribeEventDestinationVariableResolution() { SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.SUBSCRIBE); + headers.setSessionId("session1"); + headers.setSessionAttributes(new ConcurrentHashMap<>()); headers.setDestination("/pre/sub/bar/value"); Message message = MessageBuilder.withPayload(new byte[0]).copyHeaders(headers.toMap()).build(); this.messageHandler.handleMessage(message); @@ -114,6 +124,8 @@ public class SimpAnnotationMethodMessageHandlerTests { @Test public void simpleBinding() { SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(); + headers.setSessionId("session1"); + headers.setSessionAttributes(new ConcurrentHashMap<>()); headers.setDestination("/pre/binding/id/12"); Message message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); this.messageHandler.handleMessage(message); @@ -126,12 +138,28 @@ public class SimpAnnotationMethodMessageHandlerTests { @Test public void validationError() { SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(); + headers.setSessionId("session1"); + headers.setSessionAttributes(new ConcurrentHashMap<>()); headers.setDestination("/pre/validation/payload"); Message message = MessageBuilder.withPayload(TEST_INVALID_VALUE.getBytes()).setHeaders(headers).build(); this.messageHandler.handleMessage(message); assertEquals("handleValidationException", this.testController.method); } + @Test + public void simpScope() { + ConcurrentHashMap map = new ConcurrentHashMap<>(); + map.put("name", "value"); + SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(); + headers.setSessionId("session1"); + headers.setSessionAttributes(map); + headers.setDestination("/pre/scope"); + Message message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); + this.messageHandler.handleMessage(message); + + assertEquals("scope", this.testController.method); + } + private static class TestSimpAnnotationMethodMessageHandler extends SimpAnnotationMethodMessageHandler { @@ -195,6 +223,13 @@ public class SimpAnnotationMethodMessageHandlerTests { public void handleValidationException() { this.method = "handleValidationException"; } + + @MessageMapping("/scope") + public void scope() { + SimpAttributes simpAttributes = SimpAttributesContextHolder.currentAttributes(); + assertThat(simpAttributes.getAttribute("name"), is("value")); + this.method = "scope"; + } } private static class StringTestValidator implements Validator { diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java index 747e7cf948d..3e5fa408c87 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java @@ -20,6 +20,7 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import org.hamcrest.Matchers; import org.junit.Before; @@ -141,6 +142,7 @@ public class MessageBrokerConfigurationTests { StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SUBSCRIBE); headers.setSessionId("sess1"); + headers.setSessionAttributes(new ConcurrentHashMap<>()); headers.setSubscriptionId("subs1"); headers.setDestination("/foo"); Message message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); @@ -230,6 +232,8 @@ public class MessageBrokerConfigurationTests { this.simpleBrokerContext.getBean(SimpAnnotationMethodMessageHandler.class); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND); + headers.setSessionId("sess1"); + headers.setSessionAttributes(new ConcurrentHashMap<>()); headers.setDestination("/foo"); Message message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java index c79df74ef94..1b78a8d2dfb 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java @@ -17,8 +17,12 @@ package org.springframework.web.socket.config; import java.util.Arrays; +import java.util.Collections; import java.util.List; +import java.util.Map; +import org.springframework.beans.factory.config.CustomScopeConfigurer; +import org.springframework.messaging.simp.SimpSessionScope; import org.w3c.dom.Element; import org.springframework.beans.MutablePropertyValues; @@ -167,6 +171,11 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser { registerUserDestinationMessageHandler(clientInChannel, clientOutChannel, brokerChannel, userDestinationResolver, parserCxt, source); + Map scopeMap = Collections.singletonMap("websocket", new SimpSessionScope()); + RootBeanDefinition scopeConfigurerDef = new RootBeanDefinition(CustomScopeConfigurer.class); + scopeConfigurerDef.getPropertyValues().add("scopes", scopeMap); + registerBeanDefByName("webSocketScopeConfigurer", scopeConfigurerDef, parserCxt, source); + parserCxt.popAndRegisterContainingComponent(); return null; diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/EnableWebSocket.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/EnableWebSocket.java index 15722bdf2e1..b381a93db65 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/EnableWebSocket.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/EnableWebSocket.java @@ -58,6 +58,6 @@ import org.springframework.context.annotation.Import; @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.TYPE) @Documented -@Import(DelegatingWebSocketConfiguration.class) +@Import({DelegatingWebSocketConfiguration.class, WebSocketScopeConfiguration.class}) public @interface EnableWebSocket { } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupport.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupport.java index a0ab2dafc15..38520a47a2c 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupport.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupport.java @@ -16,7 +16,9 @@ package org.springframework.web.socket.config.annotation; +import org.springframework.beans.factory.config.CustomScopeConfigurer; import org.springframework.context.annotation.Bean; +import org.springframework.messaging.simp.SimpSessionScope; import org.springframework.messaging.simp.config.AbstractMessageBrokerConfiguration; import org.springframework.messaging.simp.user.UserSessionRegistry; import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; @@ -24,6 +26,8 @@ import org.springframework.web.servlet.HandlerMapping; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler; +import java.util.Collections; + /** * Extends {@link AbstractMessageBrokerConfiguration} and adds configuration for * receiving and responding to STOMP messages from WebSocket clients. @@ -75,6 +79,8 @@ public abstract class WebSocketMessageBrokerConfigurationSupport extends Abstrac protected void configureWebSocketTransport(WebSocketTransportRegistration registry) { } + protected abstract void registerStompEndpoints(StompEndpointRegistry registry); + /** * The default TaskScheduler to use if none is configured via * {@link SockJsServiceRegistration#setTaskScheduler(org.springframework.scheduling.TaskScheduler)}, i.e. @@ -100,6 +106,11 @@ public abstract class WebSocketMessageBrokerConfigurationSupport extends Abstrac return scheduler; } - protected abstract void registerStompEndpoints(StompEndpointRegistry registry); + @Bean + public static CustomScopeConfigurer webSocketScopeConfigurer() { + CustomScopeConfigurer configurer = new CustomScopeConfigurer(); + configurer.setScopes(Collections.singletonMap("websocket", new SimpSessionScope())); + return configurer; + } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketScopeConfiguration.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketScopeConfiguration.java new file mode 100644 index 00000000000..66b75824d3a --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketScopeConfiguration.java @@ -0,0 +1,39 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.config.annotation; + +import org.springframework.beans.factory.config.CustomScopeConfigurer; +import org.springframework.context.annotation.Bean; +import org.springframework.messaging.simp.SimpSessionScope; + +import java.util.Collections; + +/** + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public class WebSocketScopeConfiguration { + + @Bean + public CustomScopeConfigurer webSocketScopeConfigurer() { + CustomScopeConfigurer configurer = new CustomScopeConfigurer(); + configurer.setScopes(Collections.singletonMap("websocket", new SimpSessionScope())); + return configurer; + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java index bc363289114..c82ebf89282 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java @@ -33,6 +33,8 @@ import org.springframework.context.ApplicationEventPublisher; import org.springframework.context.ApplicationEventPublisherAware; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.simp.SimpAttributes; +import org.springframework.messaging.simp.SimpAttributesContextHolder; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.stomp.BufferingStompDecoder; @@ -221,7 +223,13 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE publishEvent(new SessionConnectEvent(this, message)); } - outputChannel.send(message); + try { + SimpAttributesContextHolder.setAttributesFromMessage(message); + outputChannel.send(message); + } + finally { + SimpAttributesContextHolder.resetAttributes(); + } } catch (Throwable ex) { logger.error("Terminating STOMP session due to failure to send message", ex); @@ -420,22 +428,33 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE this.userSessionRegistry.unregisterSessionId(userName, session.getId()); } - if (logger.isDebugEnabled()) { - logger.debug("WebSocket session ended, sending DISCONNECT message to broker"); + if (this.eventPublisher != null) { + publishEvent(new SessionDisconnectEvent(this, session.getId(), closeStatus)); } + Message message = createDisconnectMessage(session); + SimpAttributes simpAttributes = SimpAttributes.fromMessage(message); + try { + if (logger.isDebugEnabled()) { + logger.debug("WebSocket session ended, sending DISCONNECT message to broker"); + } + SimpAttributesContextHolder.setAttributes(simpAttributes); + outputChannel.send(message); + } + finally { + SimpAttributesContextHolder.resetAttributes(); + simpAttributes.sessionCompleted(); + } + } + + private Message createDisconnectMessage(WebSocketSession session) { StompHeaderAccessor headerAccessor = StompHeaderAccessor.create(StompCommand.DISCONNECT); if (getHeaderInitializer() != null) { getHeaderInitializer().initHeaders(headerAccessor); } headerAccessor.setSessionId(session.getId()); - Message message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headerAccessor.getMessageHeaders()); - - if (this.eventPublisher != null) { - publishEvent(new SessionDisconnectEvent(this, session.getId(), closeStatus)); - } - - outputChannel.send(message); + headerAccessor.setSessionAttributes(session.getAttributes()); + return MessageBuilder.createMessage(EMPTY_PAYLOAD, headerAccessor.getMessageHeaders()); } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java index 5625a6940df..416188a6afa 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java @@ -26,6 +26,7 @@ import org.junit.Test; import org.springframework.beans.DirectFieldAccessor; import org.springframework.beans.factory.NoSuchBeanDefinitionException; +import org.springframework.beans.factory.config.CustomScopeConfigurer; import org.springframework.beans.factory.xml.XmlBeanDefinitionReader; import org.springframework.core.io.ClassPathResource; import org.springframework.messaging.MessageHandler; @@ -175,6 +176,8 @@ public class MessageBrokerBeanDefinitionParserTests { catch (NoSuchBeanDefinitionException ex) { // expected } + + assertNotNull(this.appContext.getBean("webSocketScopeConfigurer", CustomScopeConfigurer.class)); } @Test diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java index 77625f2515f..506f6bb225f 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java @@ -31,6 +31,8 @@ import org.springframework.context.ApplicationEvent; import org.springframework.context.ApplicationEventPublisher; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.simp.SimpAttributes; +import org.springframework.messaging.simp.SimpAttributesContextHolder; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.TestPrincipal; @@ -48,6 +50,7 @@ import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.handler.TestWebSocketSession; import org.springframework.web.socket.sockjs.transport.SockJsSession; +import static org.hamcrest.Matchers.is; import static org.junit.Assert.*; import static org.mockito.Mockito.*; @@ -289,6 +292,41 @@ public class StompSubProtocolHandlerTests { assertTrue(actual.getPayload().startsWith("ERROR")); } + @Test + public void webSocketScope() { + + Runnable runnable = Mockito.mock(Runnable.class); + SimpAttributes simpAttributes = new SimpAttributes(this.session.getId(), this.session.getAttributes()); + simpAttributes.setAttribute("name", "value"); + simpAttributes.registerDestructionCallback("name", runnable); + + MessageChannel testChannel = new MessageChannel() { + @Override + public boolean send(Message message) { + SimpAttributes simpAttributes = SimpAttributesContextHolder.currentAttributes(); + assertThat(simpAttributes.getAttribute("name"), is("value")); + return true; + } + @Override + public boolean send(Message message, long timeout) { + return false; + } + }; + + this.protocolHandler.afterSessionStarted(this.session, this.channel); + + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); + Message message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders()); + TextMessage textMessage = new TextMessage(new StompEncoder().encode(message)); + + this.protocolHandler.handleMessageFromClient(this.session, textMessage, testChannel); + assertEquals(Collections.emptyList(), session.getSentMessages()); + + this.protocolHandler.afterSessionEnded(this.session, CloseStatus.BAD_DATA, testChannel); + assertEquals(Collections.emptyList(), session.getSentMessages()); + verify(runnable, times(1)).run(); + } + private static class UniqueUser extends TestPrincipal implements DestinationUserNameProvider { diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompWebSocketIntegrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompWebSocketIntegrationTests.java index 7220456811b..f795dcc1f35 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompWebSocketIntegrationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompWebSocketIntegrationTests.java @@ -34,6 +34,8 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.ComponentScan; import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Scope; +import org.springframework.context.annotation.ScopedProxyMode; import org.springframework.messaging.handler.annotation.MessageExceptionHandler; import org.springframework.messaging.handler.annotation.MessageMapping; import org.springframework.messaging.simp.annotation.SendToUser; @@ -97,11 +99,11 @@ public class StompWebSocketIntegrationTests extends AbstractWebSocketIntegration @Test public void sendMessageToControllerAndReceiveReplyViaTopic() throws Exception { - TextMessage message1 = create(StompCommand.SUBSCRIBE).headers( - "id:subs1", "destination:/topic/increment").build(); + TextMessage message1 = create(StompCommand.SUBSCRIBE) + .headers("id:subs1", "destination:/topic/increment").build(); - TextMessage message2 = create(StompCommand.SEND).headers( - "destination:/app/increment").body("5").build(); + TextMessage message2 = create(StompCommand.SEND) + .headers("destination:/app/increment").body("5").build(); TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(1, message1, message2); WebSocketSession session = doHandshake(clientHandler, "/ws").get(); @@ -181,6 +183,37 @@ public class StompWebSocketIntegrationTests extends AbstractWebSocketIntegration } } + @Test + public void webSocketScope() throws Exception { + + TextMessage message1 = create(StompCommand.SUBSCRIBE) + .headers("id:subs1", "destination:/topic/scopedBeanValue").build(); + + TextMessage message2 = create(StompCommand.SEND) + .headers("destination:/app/scopedBeanValue").build(); + + TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(1, message1, message2); + WebSocketSession session = doHandshake(clientHandler, "/ws").get(); + + try { + assertTrue(clientHandler.latch.await(2, TimeUnit.SECONDS)); + + String payload = clientHandler.actual.get(0).getPayload(); + assertTrue(payload.startsWith("MESSAGE\n")); + assertTrue(payload.contains("destination:/topic/scopedBeanValue\n")); + assertTrue(payload.endsWith("\"55\"\0")); + } + finally { + session.close(); + } + } + + + @Target({ElementType.TYPE}) + @Retention(RetentionPolicy.RUNTIME) + @Controller + private @interface IntegrationTestController { + } @IntegrationTestController static class SimpleController { @@ -218,6 +251,42 @@ public class StompWebSocketIntegrationTests extends AbstractWebSocketIntegration } } + @IntegrationTestController + static class ScopedBeanController { + + private final ScopedBean scopedBean; + + @Autowired + public ScopedBeanController(ScopedBean scopedBean) { + this.scopedBean = scopedBean; + } + + @MessageMapping(value="/scopedBeanValue") + public String getValue() { + return this.scopedBean.getValue(); + } + } + + + static interface ScopedBean { + + String getValue(); + } + + static class ScopedBeanImpl implements ScopedBean { + + private final String value; + + public ScopedBeanImpl(String value) { + this.value = value; + } + + @Override + public String getValue() { + return this.value; + } + } + private static class TestClientWebSocketHandler extends TextWebSocketHandler { @@ -251,7 +320,8 @@ public class StompWebSocketIntegrationTests extends AbstractWebSocketIntegration } @Configuration - @ComponentScan(basePackageClasses=StompWebSocketIntegrationTests.class, + @ComponentScan( + basePackageClasses=StompWebSocketIntegrationTests.class, useDefaultFilters=false, includeFilters=@ComponentScan.Filter(IntegrationTestController.class)) static class TestMessageBrokerConfigurer extends AbstractWebSocketMessageBrokerConfigurer { @@ -269,6 +339,12 @@ public class StompWebSocketIntegrationTests extends AbstractWebSocketIntegration configurer.setApplicationDestinationPrefixes("/app"); configurer.enableSimpleBroker("/topic", "/queue"); } + + @Bean + @Scope(value="websocket", proxyMode=ScopedProxyMode.INTERFACES) + public ScopedBean scopedBean() { + return new ScopedBeanImpl("55"); + } } @Configuration @@ -287,10 +363,4 @@ public class StompWebSocketIntegrationTests extends AbstractWebSocketIntegration } } - @Target({ElementType.TYPE}) - @Retention(RetentionPolicy.RUNTIME) - @Controller - private @interface IntegrationTestController { - } - } diff --git a/src/asciidoc/index.adoc b/src/asciidoc/index.adoc index 08b88c0b701..3e6bb8f8eba 100644 --- a/src/asciidoc/index.adoc +++ b/src/asciidoc/index.adoc @@ -38140,7 +38140,7 @@ the message being handled through the `@SendToUser` annotation: [subs="verbatim,quotes"] ---- @Controller -public class MyController { +public class PortfolioController { @MessageMapping("/trade") @SendToUser("/queue/position-updates") @@ -38270,6 +38270,77 @@ implement their own reconnect logic. +[[websocket-stomp-websocket-scope]] +==== WebSocket Scope + +Each WebSocket session has a map of attributes. The map is attached as a header to +inbound client messages and may be accessed from a controller method, for example: + +[source,java,indent=0] +[subs="verbatim,quotes"] +---- +@Controller +public class MyController { + + @MessageMapping("/action") + public void handle(SimpMessageHeaderAccessor headerAccessor) { + Map attrs = headerAccessor.getSessionAttributes(); + // ... + } +} +---- + +It is also possible to declare a Spring-managed bean in the `"websocket"` scope. +WebSocket-scoped beans can be injected into controllers and any channel interceptors +registered on the "clientInboundChannel". Those are typically singletons and live +longer than any individual WebSocket session. Therefore you will need to use a +scope proxy mode for WebSocket-scoped beans: + +[source,java,indent=0] +[subs="verbatim,quotes"] +---- +@Component +@Scope(value="websocket", proxyMode = ScopedProxyMode.TARGET_CLASS) +public class MyBean { + + @PostConstruct + public void init() { + // Invoked after dependencies injected + } + + // ... + + @PreDestroy + public void destroy() { + // Invoked when the WebSocket session ends + } +} + +@Controller +public class MyController { + + private final MyBean myBean; + + @Autowired + public MyController(MyBean myBean) { + this.myBean = myBean; + } + + @MessageMapping("/action") + public void handle() { + // this.myBean from the current WebSocket session + } +} +---- + +As with any custom scope, Spring initializes a new MyBean instance the first +time it is accessed from the controller and stores the instance in the WebSocket +session attributes. The same instance is returned subsequently until the session +ends. WebSocket-scoped beans will have all Spring lifecycle methods invoked as +shown in the examples above. + + + [[websocket-stomp-configuration-performance]] ==== Configuration and Performance