Make WebSocket fields available after it is closed

Update some native WebSocket session getters to return basic
information after it is closed. It is required for example in
SubProtocolWebSocketHandler#afterConnectionEstablished() or
StompSubProtocolHandler#afterSessionStarted().

Issue: SPR-11621
This commit is contained in:
Sebastien Deleuze 2014-03-31 16:35:40 +02:00 committed by Rossen Stoyanchev
parent ea762b2c74
commit a805f12374
3 changed files with 60 additions and 5 deletions

View File

@ -54,7 +54,9 @@ public class StandardWebSocketSession extends AbstractWebSocketSession<Session>
private final InetSocketAddress remoteAddress; private final InetSocketAddress remoteAddress;
private final Principal user; private Principal user;
private String acceptedProtocol;
private List<WebSocketExtension> extensions; private List<WebSocketExtension> extensions;
@ -136,8 +138,7 @@ public class StandardWebSocketSession extends AbstractWebSocketSession<Session>
@Override @Override
public String getAcceptedProtocol() { public String getAcceptedProtocol() {
checkNativeSessionInitialized(); checkNativeSessionInitialized();
String protocol = getNativeSession().getNegotiatedSubprotocol(); return this.acceptedProtocol;
return StringUtils.isEmpty(protocol)? null : protocol;
} }
@Override @Override
@ -182,6 +183,15 @@ public class StandardWebSocketSession extends AbstractWebSocketSession<Session>
return (getNativeSession() != null && getNativeSession().isOpen()); return (getNativeSession() != null && getNativeSession().isOpen());
} }
@Override
public void initializeNativeSession(Session session) {
super.initializeNativeSession(session);
if(this.user == null) {
this.user = session.getUserPrincipal();
}
this.acceptedProtocol = session.getNegotiatedSubprotocol();
}
@Override @Override
protected void sendTextMessage(TextMessage message) throws IOException { protected void sendTextMessage(TextMessage message) throws IOException {
getNativeSession().getBasicRemote().sendText(message.getPayload(), message.isLast()); getNativeSession().getBasicRemote().sendText(message.getPayload(), message.isLast());

View File

@ -16,12 +16,14 @@
package org.springframework.web.socket.messaging; package org.springframework.web.socket.messaging;
import java.io.IOException;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import javax.websocket.Session;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
@ -46,6 +48,7 @@ import org.springframework.messaging.support.MessageBuilder;
import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.adapter.standard.StandardWebSocketSession;
import org.springframework.web.socket.handler.TestWebSocketSession; import org.springframework.web.socket.handler.TestWebSocketSession;
import org.springframework.web.socket.sockjs.transport.SockJsSession; import org.springframework.web.socket.sockjs.transport.SockJsSession;
@ -287,6 +290,23 @@ public class StompSubProtocolHandlerTests {
assertTrue(actual.getPayload().startsWith("ERROR")); assertTrue(actual.getPayload().startsWith("ERROR"));
} }
// SPR-11621
@Test
public void availableSessionFieldsAfterSessionEnded() throws IOException {
Session nativeSession = Mockito.mock(Session.class);
when(nativeSession.getId()).thenReturn("1");
when(nativeSession.getUserPrincipal()).thenReturn(new org.springframework.web.socket.handler.TestPrincipal("test"));
when(nativeSession.getNegotiatedSubprotocol()).thenReturn("v12.sToMp");
StandardWebSocketSession standardWebsocketSession = new StandardWebSocketSession(null, null, null, null);
standardWebsocketSession.initializeNativeSession(nativeSession);
this.protocolHandler.afterSessionStarted(standardWebsocketSession, this.channel);
standardWebsocketSession.close(CloseStatus.NORMAL);
this.protocolHandler.afterSessionEnded(standardWebsocketSession, CloseStatus.NORMAL, this.channel);
assertEquals("v12.sToMp", standardWebsocketSession.getAcceptedProtocol());
assertNotNull(standardWebsocketSession.getPrincipal());
}
private static class UniqueUser extends TestPrincipal implements DestinationUserNameProvider { private static class UniqueUser extends TestPrincipal implements DestinationUserNameProvider {

View File

@ -21,14 +21,19 @@ import java.util.Arrays;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations; import org.mockito.MockitoAnnotations;
import org.springframework.messaging.MessageChannel; import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.SubscribableChannel;
import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.adapter.standard.StandardWebSocketSession;
import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator; import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator;
import org.springframework.web.socket.handler.TestPrincipal;
import org.springframework.web.socket.handler.TestWebSocketSession; import org.springframework.web.socket.handler.TestWebSocketSession;
import javax.websocket.Session;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.mockito.Mockito.*; import static org.mockito.Mockito.*;
/** /**
@ -54,6 +59,9 @@ public class SubProtocolWebSocketHandlerTests {
@Mock @Mock
SubscribableChannel outClientChannel; SubscribableChannel outClientChannel;
@Mock
private Session nativeSession;
@Before @Before
public void setup() { public void setup() {
@ -142,4 +150,21 @@ public class SubProtocolWebSocketHandlerTests {
this.webSocketHandler.afterConnectionEstablished(session); this.webSocketHandler.afterConnectionEstablished(session);
} }
// SPR-11621
@Test
public void availableSessionFieldsafterConnectionClosed() throws Exception {
this.webSocketHandler.setDefaultProtocolHandler(stompHandler);
when(nativeSession.getId()).thenReturn("1");
when(nativeSession.getUserPrincipal()).thenReturn(new TestPrincipal("test"));
when(nativeSession.getNegotiatedSubprotocol()).thenReturn("v12.sToMp");
StandardWebSocketSession standardWebsocketSession = new StandardWebSocketSession(null, null, null, null);
standardWebsocketSession.initializeNativeSession(this.nativeSession);
this.webSocketHandler.afterConnectionEstablished(standardWebsocketSession);
standardWebsocketSession.close(CloseStatus.NORMAL);
this.webSocketHandler.afterConnectionClosed(standardWebsocketSession, CloseStatus.NORMAL);
assertEquals("v12.sToMp", standardWebsocketSession.getAcceptedProtocol());
assertNotNull(standardWebsocketSession.getPrincipal());
}
} }