diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/standard/StandardWebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/standard/StandardWebSocketSession.java index d6f188a70d..5e2328264b 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/standard/StandardWebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/standard/StandardWebSocketSession.java @@ -54,7 +54,9 @@ public class StandardWebSocketSession extends AbstractWebSocketSession private final InetSocketAddress remoteAddress; - private final Principal user; + private Principal user; + + private String acceptedProtocol; private List extensions; @@ -136,8 +138,7 @@ public class StandardWebSocketSession extends AbstractWebSocketSession @Override public String getAcceptedProtocol() { checkNativeSessionInitialized(); - String protocol = getNativeSession().getNegotiatedSubprotocol(); - return StringUtils.isEmpty(protocol)? null : protocol; + return this.acceptedProtocol; } @Override @@ -182,6 +183,15 @@ public class StandardWebSocketSession extends AbstractWebSocketSession return (getNativeSession() != null && getNativeSession().isOpen()); } + @Override + public void initializeNativeSession(Session session) { + super.initializeNativeSession(session); + if(this.user == null) { + this.user = session.getUserPrincipal(); + } + this.acceptedProtocol = session.getNegotiatedSubprotocol(); + } + @Override protected void sendTextMessage(TextMessage message) throws IOException { getNativeSession().getBasicRemote().sendText(message.getPayload(), message.isLast()); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java index 5aaf1127ab..2836c94e63 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java @@ -16,12 +16,14 @@ package org.springframework.web.socket.messaging; +import java.io.IOException; import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; import java.util.List; +import javax.websocket.Session; import org.junit.Before; import org.junit.Test; @@ -46,6 +48,7 @@ import org.springframework.messaging.support.MessageBuilder; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketMessage; +import org.springframework.web.socket.adapter.standard.StandardWebSocketSession; import org.springframework.web.socket.handler.TestWebSocketSession; import org.springframework.web.socket.sockjs.transport.SockJsSession; @@ -287,6 +290,23 @@ public class StompSubProtocolHandlerTests { assertTrue(actual.getPayload().startsWith("ERROR")); } + // SPR-11621 + + @Test + public void availableSessionFieldsAfterSessionEnded() throws IOException { + Session nativeSession = Mockito.mock(Session.class); + when(nativeSession.getId()).thenReturn("1"); + when(nativeSession.getUserPrincipal()).thenReturn(new org.springframework.web.socket.handler.TestPrincipal("test")); + when(nativeSession.getNegotiatedSubprotocol()).thenReturn("v12.sToMp"); + StandardWebSocketSession standardWebsocketSession = new StandardWebSocketSession(null, null, null, null); + standardWebsocketSession.initializeNativeSession(nativeSession); + this.protocolHandler.afterSessionStarted(standardWebsocketSession, this.channel); + standardWebsocketSession.close(CloseStatus.NORMAL); + this.protocolHandler.afterSessionEnded(standardWebsocketSession, CloseStatus.NORMAL, this.channel); + assertEquals("v12.sToMp", standardWebsocketSession.getAcceptedProtocol()); + assertNotNull(standardWebsocketSession.getPrincipal()); + } + private static class UniqueUser extends TestPrincipal implements DestinationUserNameProvider { diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java index b021199e18..a374e17b30 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java @@ -21,14 +21,19 @@ import java.util.Arrays; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; -import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.springframework.messaging.MessageChannel; import org.springframework.messaging.SubscribableChannel; -import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.adapter.standard.StandardWebSocketSession; import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator; +import org.springframework.web.socket.handler.TestPrincipal; import org.springframework.web.socket.handler.TestWebSocketSession; +import javax.websocket.Session; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; import static org.mockito.Mockito.*; /** @@ -54,6 +59,9 @@ public class SubProtocolWebSocketHandlerTests { @Mock SubscribableChannel outClientChannel; + @Mock + private Session nativeSession; + @Before public void setup() { @@ -142,4 +150,21 @@ public class SubProtocolWebSocketHandlerTests { this.webSocketHandler.afterConnectionEstablished(session); } + // SPR-11621 + + @Test + public void availableSessionFieldsafterConnectionClosed() throws Exception { + this.webSocketHandler.setDefaultProtocolHandler(stompHandler); + when(nativeSession.getId()).thenReturn("1"); + when(nativeSession.getUserPrincipal()).thenReturn(new TestPrincipal("test")); + when(nativeSession.getNegotiatedSubprotocol()).thenReturn("v12.sToMp"); + StandardWebSocketSession standardWebsocketSession = new StandardWebSocketSession(null, null, null, null); + standardWebsocketSession.initializeNativeSession(this.nativeSession); + this.webSocketHandler.afterConnectionEstablished(standardWebsocketSession); + standardWebsocketSession.close(CloseStatus.NORMAL); + this.webSocketHandler.afterConnectionClosed(standardWebsocketSession, CloseStatus.NORMAL); + assertEquals("v12.sToMp", standardWebsocketSession.getAcceptedProtocol()); + assertNotNull(standardWebsocketSession.getPrincipal()); + } + }