Update JettyWebSocketSession

Ensure the JettyWebSocket session can return the Principal and accepted
WebSocket sub-protocol even after the session is closed.

Issue: SPR-11621
This commit is contained in:
Rossen Stoyanchev 2014-04-01 13:13:55 -04:00
parent a805f12374
commit 73ecbc047c
6 changed files with 257 additions and 50 deletions

View File

@ -50,7 +50,9 @@ public class JettyWebSocketSession extends AbstractWebSocketSession<Session> {
private List<WebSocketExtension> extensions;
private final Principal user;
private Principal user;
private String acceptedProtocol;
/**
@ -105,7 +107,7 @@ public class JettyWebSocketSession extends AbstractWebSocketSession<Session> {
return this.user;
}
checkNativeSessionInitialized();
return getNativeSession().getUpgradeRequest().getUserPrincipal();
return (isOpen() ? getNativeSession().getUpgradeRequest().getUserPrincipal() : null);
}
@Override
@ -123,7 +125,7 @@ public class JettyWebSocketSession extends AbstractWebSocketSession<Session> {
@Override
public String getAcceptedProtocol() {
checkNativeSessionInitialized();
return getNativeSession().getUpgradeResponse().getAcceptedSubProtocol();
return this.acceptedProtocol;
}
@Override
@ -168,6 +170,15 @@ public class JettyWebSocketSession extends AbstractWebSocketSession<Session> {
return ((getNativeSession() != null) && getNativeSession().isOpen());
}
@Override
public void initializeNativeSession(Session session) {
super.initializeNativeSession(session);
if (this.user == null) {
this.user = session.getUpgradeRequest().getUserPrincipal();
}
this.acceptedProtocol = session.getUpgradeResponse().getAcceptedSubProtocol();
}
@Override
protected void sendTextMessage(TextMessage message) throws IOException {
getNativeSession().getRemote().sendString(message.getPayload());

View File

@ -122,7 +122,7 @@ public class StandardWebSocketSession extends AbstractWebSocketSession<Session>
return this.user;
}
checkNativeSessionInitialized();
return getNativeSession().getUserPrincipal();
return (isOpen() ? getNativeSession().getUserPrincipal() : null);
}
@Override

View File

@ -0,0 +1,132 @@
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.adapter.jetty;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.eclipse.jetty.websocket.api.UpgradeResponse;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mockito;
import org.springframework.web.socket.handler.TestPrincipal;
import java.util.HashMap;
import java.util.Map;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
/**
* Unit tests for {@link org.springframework.web.socket.adapter.jetty.JettyWebSocketSession}.
*
* @author Rossen Stoyanchev
*/
public class JettyWebSocketSessionTests {
private Map<String,Object> attributes;
@Before
public void setup() {
this.attributes = new HashMap<>();
}
@Test
public void getPrincipalWithConstructorArg() {
TestPrincipal user = new TestPrincipal("joe");
JettyWebSocketSession session = new JettyWebSocketSession(attributes, user);
assertSame(user, session.getPrincipal());
}
@Test
public void getPrincipalFromNativeSession() {
TestPrincipal user = new TestPrincipal("joe");
UpgradeRequest request = Mockito.mock(UpgradeRequest.class);
when(request.getUserPrincipal()).thenReturn(user);
UpgradeResponse response = Mockito.mock(UpgradeResponse.class);
when(response.getAcceptedSubProtocol()).thenReturn(null);
Session nativeSession = Mockito.mock(Session.class);
when(nativeSession.getUpgradeRequest()).thenReturn(request);
when(nativeSession.getUpgradeResponse()).thenReturn(response);
JettyWebSocketSession session = new JettyWebSocketSession(attributes);
session.initializeNativeSession(nativeSession);
reset(nativeSession);
assertSame(user, session.getPrincipal());
verifyNoMoreInteractions(nativeSession);
}
@Test
public void getPrincipalNotAvailable() {
UpgradeRequest request = Mockito.mock(UpgradeRequest.class);
when(request.getUserPrincipal()).thenReturn(null);
UpgradeResponse response = Mockito.mock(UpgradeResponse.class);
when(response.getAcceptedSubProtocol()).thenReturn(null);
Session nativeSession = Mockito.mock(Session.class);
when(nativeSession.getUpgradeRequest()).thenReturn(request);
when(nativeSession.getUpgradeResponse()).thenReturn(response);
JettyWebSocketSession session = new JettyWebSocketSession(attributes);
session.initializeNativeSession(nativeSession);
reset(nativeSession);
assertNull(session.getPrincipal());
verify(nativeSession).isOpen();
verifyNoMoreInteractions(nativeSession);
}
@Test
public void getAcceptedProtocol() {
String protocol = "foo";
UpgradeRequest request = Mockito.mock(UpgradeRequest.class);
when(request.getUserPrincipal()).thenReturn(null);
UpgradeResponse response = Mockito.mock(UpgradeResponse.class);
when(response.getAcceptedSubProtocol()).thenReturn(protocol);
Session nativeSession = Mockito.mock(Session.class);
when(nativeSession.getUpgradeRequest()).thenReturn(request);
when(nativeSession.getUpgradeResponse()).thenReturn(response);
JettyWebSocketSession session = new JettyWebSocketSession(attributes);
session.initializeNativeSession(nativeSession);
reset(nativeSession);
assertSame(protocol, session.getAcceptedProtocol());
verifyNoMoreInteractions(nativeSession);
}
}

View File

@ -0,0 +1,110 @@
/* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.adapter.standard;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mockito;
import org.springframework.http.HttpHeaders;
import org.springframework.web.socket.handler.TestPrincipal;
import javax.websocket.Session;
import java.util.HashMap;
import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
/**
* Unit tests for {@link org.springframework.web.socket.adapter.standard.StandardWebSocketSession}.
*
* @author Rossen Stoyanchev
*/
public class StandardWebSocketSessionTests {
private HttpHeaders headers;
private Map<String,Object> attributes;
@Before
public void setup() {
this.headers = new HttpHeaders();
this.attributes = new HashMap<>();
}
@Test
public void getPrincipalWithConstructorArg() {
TestPrincipal user = new TestPrincipal("joe");
StandardWebSocketSession session = new StandardWebSocketSession(this.headers, this.attributes, null, null, user);
assertSame(user, session.getPrincipal());
}
@Test
public void getPrincipalWithNativeSession() {
TestPrincipal user = new TestPrincipal("joe");
Session nativeSession = Mockito.mock(Session.class);
when(nativeSession.getUserPrincipal()).thenReturn(user);
StandardWebSocketSession session = new StandardWebSocketSession(this.headers, this.attributes, null, null);
session.initializeNativeSession(nativeSession);
assertSame(user, session.getPrincipal());
}
@Test
public void getPrincipalNone() {
Session nativeSession = Mockito.mock(Session.class);
when(nativeSession.getUserPrincipal()).thenReturn(null);
StandardWebSocketSession session = new StandardWebSocketSession(this.headers, this.attributes, null, null);
session.initializeNativeSession(nativeSession);
reset(nativeSession);
assertNull(session.getPrincipal());
verify(nativeSession).isOpen();
verifyNoMoreInteractions(nativeSession);
}
@Test
public void getAcceptedProtocol() {
String protocol = "foo";
Session nativeSession = Mockito.mock(Session.class);
when(nativeSession.getNegotiatedSubprotocol()).thenReturn(protocol);
StandardWebSocketSession session = new StandardWebSocketSession(this.headers, this.attributes, null, null);
session.initializeNativeSession(nativeSession);
reset(nativeSession);
assertEquals(protocol, session.getAcceptedProtocol());
verifyNoMoreInteractions(nativeSession);
}
}

View File

@ -23,7 +23,6 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import javax.websocket.Session;
import org.junit.Before;
import org.junit.Test;
@ -48,7 +47,6 @@ import org.springframework.messaging.support.MessageBuilder;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.adapter.standard.StandardWebSocketSession;
import org.springframework.web.socket.handler.TestWebSocketSession;
import org.springframework.web.socket.sockjs.transport.SockJsSession;
@ -290,23 +288,6 @@ public class StompSubProtocolHandlerTests {
assertTrue(actual.getPayload().startsWith("ERROR"));
}
// SPR-11621
@Test
public void availableSessionFieldsAfterSessionEnded() throws IOException {
Session nativeSession = Mockito.mock(Session.class);
when(nativeSession.getId()).thenReturn("1");
when(nativeSession.getUserPrincipal()).thenReturn(new org.springframework.web.socket.handler.TestPrincipal("test"));
when(nativeSession.getNegotiatedSubprotocol()).thenReturn("v12.sToMp");
StandardWebSocketSession standardWebsocketSession = new StandardWebSocketSession(null, null, null, null);
standardWebsocketSession.initializeNativeSession(nativeSession);
this.protocolHandler.afterSessionStarted(standardWebsocketSession, this.channel);
standardWebsocketSession.close(CloseStatus.NORMAL);
this.protocolHandler.afterSessionEnded(standardWebsocketSession, CloseStatus.NORMAL, this.channel);
assertEquals("v12.sToMp", standardWebsocketSession.getAcceptedProtocol());
assertNotNull(standardWebsocketSession.getPrincipal());
}
private static class UniqueUser extends TestPrincipal implements DestinationUserNameProvider {

View File

@ -24,16 +24,9 @@ import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.adapter.standard.StandardWebSocketSession;
import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator;
import org.springframework.web.socket.handler.TestPrincipal;
import org.springframework.web.socket.handler.TestWebSocketSession;
import javax.websocket.Session;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.mockito.Mockito.*;
/**
@ -59,9 +52,6 @@ public class SubProtocolWebSocketHandlerTests {
@Mock
SubscribableChannel outClientChannel;
@Mock
private Session nativeSession;
@Before
public void setup() {
@ -150,21 +140,4 @@ public class SubProtocolWebSocketHandlerTests {
this.webSocketHandler.afterConnectionEstablished(session);
}
// SPR-11621
@Test
public void availableSessionFieldsafterConnectionClosed() throws Exception {
this.webSocketHandler.setDefaultProtocolHandler(stompHandler);
when(nativeSession.getId()).thenReturn("1");
when(nativeSession.getUserPrincipal()).thenReturn(new TestPrincipal("test"));
when(nativeSession.getNegotiatedSubprotocol()).thenReturn("v12.sToMp");
StandardWebSocketSession standardWebsocketSession = new StandardWebSocketSession(null, null, null, null);
standardWebsocketSession.initializeNativeSession(this.nativeSession);
this.webSocketHandler.afterConnectionEstablished(standardWebsocketSession);
standardWebsocketSession.close(CloseStatus.NORMAL);
this.webSocketHandler.afterConnectionClosed(standardWebsocketSession, CloseStatus.NORMAL);
assertEquals("v12.sToMp", standardWebsocketSession.getAcceptedProtocol());
assertNotNull(standardWebsocketSession.getPrincipal());
}
}