diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSession.java index 8eb57594c3..b780cfb2de 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSession.java @@ -50,7 +50,9 @@ public class JettyWebSocketSession extends AbstractWebSocketSession { private List extensions; - private final Principal user; + private Principal user; + + private String acceptedProtocol; /** @@ -105,7 +107,7 @@ public class JettyWebSocketSession extends AbstractWebSocketSession { return this.user; } checkNativeSessionInitialized(); - return getNativeSession().getUpgradeRequest().getUserPrincipal(); + return (isOpen() ? getNativeSession().getUpgradeRequest().getUserPrincipal() : null); } @Override @@ -123,7 +125,7 @@ public class JettyWebSocketSession extends AbstractWebSocketSession { @Override public String getAcceptedProtocol() { checkNativeSessionInitialized(); - return getNativeSession().getUpgradeResponse().getAcceptedSubProtocol(); + return this.acceptedProtocol; } @Override @@ -168,6 +170,15 @@ public class JettyWebSocketSession extends AbstractWebSocketSession { return ((getNativeSession() != null) && getNativeSession().isOpen()); } + @Override + public void initializeNativeSession(Session session) { + super.initializeNativeSession(session); + if (this.user == null) { + this.user = session.getUpgradeRequest().getUserPrincipal(); + } + this.acceptedProtocol = session.getUpgradeResponse().getAcceptedSubProtocol(); + } + @Override protected void sendTextMessage(TextMessage message) throws IOException { getNativeSession().getRemote().sendString(message.getPayload()); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/standard/StandardWebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/standard/StandardWebSocketSession.java index 5e2328264b..1ce37bc3d5 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/standard/StandardWebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/standard/StandardWebSocketSession.java @@ -122,7 +122,7 @@ public class StandardWebSocketSession extends AbstractWebSocketSession return this.user; } checkNativeSessionInitialized(); - return getNativeSession().getUserPrincipal(); + return (isOpen() ? getNativeSession().getUserPrincipal() : null); } @Override diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSessionTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSessionTests.java new file mode 100644 index 0000000000..74b4273217 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSessionTests.java @@ -0,0 +1,132 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.adapter.jetty; + +import org.eclipse.jetty.websocket.api.Session; +import org.eclipse.jetty.websocket.api.UpgradeRequest; +import org.eclipse.jetty.websocket.api.UpgradeResponse; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mockito; +import org.springframework.web.socket.handler.TestPrincipal; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +/** + * Unit tests for {@link org.springframework.web.socket.adapter.jetty.JettyWebSocketSession}. + * + * @author Rossen Stoyanchev + */ +public class JettyWebSocketSessionTests { + + private Map attributes; + + + @Before + public void setup() { + this.attributes = new HashMap<>(); + } + + + @Test + public void getPrincipalWithConstructorArg() { + TestPrincipal user = new TestPrincipal("joe"); + JettyWebSocketSession session = new JettyWebSocketSession(attributes, user); + + assertSame(user, session.getPrincipal()); + } + + @Test + public void getPrincipalFromNativeSession() { + + TestPrincipal user = new TestPrincipal("joe"); + + UpgradeRequest request = Mockito.mock(UpgradeRequest.class); + when(request.getUserPrincipal()).thenReturn(user); + + UpgradeResponse response = Mockito.mock(UpgradeResponse.class); + when(response.getAcceptedSubProtocol()).thenReturn(null); + + Session nativeSession = Mockito.mock(Session.class); + when(nativeSession.getUpgradeRequest()).thenReturn(request); + when(nativeSession.getUpgradeResponse()).thenReturn(response); + + JettyWebSocketSession session = new JettyWebSocketSession(attributes); + session.initializeNativeSession(nativeSession); + + reset(nativeSession); + + assertSame(user, session.getPrincipal()); + verifyNoMoreInteractions(nativeSession); + } + + @Test + public void getPrincipalNotAvailable() { + + UpgradeRequest request = Mockito.mock(UpgradeRequest.class); + when(request.getUserPrincipal()).thenReturn(null); + + UpgradeResponse response = Mockito.mock(UpgradeResponse.class); + when(response.getAcceptedSubProtocol()).thenReturn(null); + + Session nativeSession = Mockito.mock(Session.class); + when(nativeSession.getUpgradeRequest()).thenReturn(request); + when(nativeSession.getUpgradeResponse()).thenReturn(response); + + JettyWebSocketSession session = new JettyWebSocketSession(attributes); + session.initializeNativeSession(nativeSession); + + reset(nativeSession); + + assertNull(session.getPrincipal()); + verify(nativeSession).isOpen(); + verifyNoMoreInteractions(nativeSession); + } + + @Test + public void getAcceptedProtocol() { + + String protocol = "foo"; + + UpgradeRequest request = Mockito.mock(UpgradeRequest.class); + when(request.getUserPrincipal()).thenReturn(null); + + UpgradeResponse response = Mockito.mock(UpgradeResponse.class); + when(response.getAcceptedSubProtocol()).thenReturn(protocol); + + Session nativeSession = Mockito.mock(Session.class); + when(nativeSession.getUpgradeRequest()).thenReturn(request); + when(nativeSession.getUpgradeResponse()).thenReturn(response); + + JettyWebSocketSession session = new JettyWebSocketSession(attributes); + session.initializeNativeSession(nativeSession); + + reset(nativeSession); + + assertSame(protocol, session.getAcceptedProtocol()); + verifyNoMoreInteractions(nativeSession); + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/adapter/standard/StandardWebSocketSessionTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/adapter/standard/StandardWebSocketSessionTests.java new file mode 100644 index 0000000000..ef2cf002be --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/adapter/standard/StandardWebSocketSessionTests.java @@ -0,0 +1,110 @@ +/* Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.adapter.standard; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mockito; +import org.springframework.http.HttpHeaders; +import org.springframework.web.socket.handler.TestPrincipal; + +import javax.websocket.Session; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +/** + * Unit tests for {@link org.springframework.web.socket.adapter.standard.StandardWebSocketSession}. + * + * @author Rossen Stoyanchev + */ +public class StandardWebSocketSessionTests { + + private HttpHeaders headers; + + private Map attributes; + + + @Before + public void setup() { + this.headers = new HttpHeaders(); + this.attributes = new HashMap<>(); + } + + + @Test + public void getPrincipalWithConstructorArg() { + TestPrincipal user = new TestPrincipal("joe"); + StandardWebSocketSession session = new StandardWebSocketSession(this.headers, this.attributes, null, null, user); + + assertSame(user, session.getPrincipal()); + } + + @Test + public void getPrincipalWithNativeSession() { + + TestPrincipal user = new TestPrincipal("joe"); + + Session nativeSession = Mockito.mock(Session.class); + when(nativeSession.getUserPrincipal()).thenReturn(user); + + StandardWebSocketSession session = new StandardWebSocketSession(this.headers, this.attributes, null, null); + session.initializeNativeSession(nativeSession); + + assertSame(user, session.getPrincipal()); + } + + @Test + public void getPrincipalNone() { + + Session nativeSession = Mockito.mock(Session.class); + when(nativeSession.getUserPrincipal()).thenReturn(null); + + StandardWebSocketSession session = new StandardWebSocketSession(this.headers, this.attributes, null, null); + session.initializeNativeSession(nativeSession); + + reset(nativeSession); + + assertNull(session.getPrincipal()); + verify(nativeSession).isOpen(); + verifyNoMoreInteractions(nativeSession); + } + + @Test + public void getAcceptedProtocol() { + + String protocol = "foo"; + + Session nativeSession = Mockito.mock(Session.class); + when(nativeSession.getNegotiatedSubprotocol()).thenReturn(protocol); + + StandardWebSocketSession session = new StandardWebSocketSession(this.headers, this.attributes, null, null); + session.initializeNativeSession(nativeSession); + + reset(nativeSession); + + assertEquals(protocol, session.getAcceptedProtocol()); + verifyNoMoreInteractions(nativeSession); + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java index 2836c94e63..426a0599ab 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java @@ -23,7 +23,6 @@ import java.util.Arrays; import java.util.Collections; import java.util.HashSet; import java.util.List; -import javax.websocket.Session; import org.junit.Before; import org.junit.Test; @@ -48,7 +47,6 @@ import org.springframework.messaging.support.MessageBuilder; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketMessage; -import org.springframework.web.socket.adapter.standard.StandardWebSocketSession; import org.springframework.web.socket.handler.TestWebSocketSession; import org.springframework.web.socket.sockjs.transport.SockJsSession; @@ -290,23 +288,6 @@ public class StompSubProtocolHandlerTests { assertTrue(actual.getPayload().startsWith("ERROR")); } - // SPR-11621 - - @Test - public void availableSessionFieldsAfterSessionEnded() throws IOException { - Session nativeSession = Mockito.mock(Session.class); - when(nativeSession.getId()).thenReturn("1"); - when(nativeSession.getUserPrincipal()).thenReturn(new org.springframework.web.socket.handler.TestPrincipal("test")); - when(nativeSession.getNegotiatedSubprotocol()).thenReturn("v12.sToMp"); - StandardWebSocketSession standardWebsocketSession = new StandardWebSocketSession(null, null, null, null); - standardWebsocketSession.initializeNativeSession(nativeSession); - this.protocolHandler.afterSessionStarted(standardWebsocketSession, this.channel); - standardWebsocketSession.close(CloseStatus.NORMAL); - this.protocolHandler.afterSessionEnded(standardWebsocketSession, CloseStatus.NORMAL, this.channel); - assertEquals("v12.sToMp", standardWebsocketSession.getAcceptedProtocol()); - assertNotNull(standardWebsocketSession.getPrincipal()); - } - private static class UniqueUser extends TestPrincipal implements DestinationUserNameProvider { diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java index a374e17b30..31cf82ac76 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java @@ -24,16 +24,9 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.springframework.messaging.MessageChannel; import org.springframework.messaging.SubscribableChannel; -import org.springframework.web.socket.CloseStatus; -import org.springframework.web.socket.adapter.standard.StandardWebSocketSession; import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator; -import org.springframework.web.socket.handler.TestPrincipal; import org.springframework.web.socket.handler.TestWebSocketSession; -import javax.websocket.Session; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; import static org.mockito.Mockito.*; /** @@ -59,9 +52,6 @@ public class SubProtocolWebSocketHandlerTests { @Mock SubscribableChannel outClientChannel; - @Mock - private Session nativeSession; - @Before public void setup() { @@ -150,21 +140,4 @@ public class SubProtocolWebSocketHandlerTests { this.webSocketHandler.afterConnectionEstablished(session); } - // SPR-11621 - - @Test - public void availableSessionFieldsafterConnectionClosed() throws Exception { - this.webSocketHandler.setDefaultProtocolHandler(stompHandler); - when(nativeSession.getId()).thenReturn("1"); - when(nativeSession.getUserPrincipal()).thenReturn(new TestPrincipal("test")); - when(nativeSession.getNegotiatedSubprotocol()).thenReturn("v12.sToMp"); - StandardWebSocketSession standardWebsocketSession = new StandardWebSocketSession(null, null, null, null); - standardWebsocketSession.initializeNativeSession(this.nativeSession); - this.webSocketHandler.afterConnectionEstablished(standardWebsocketSession); - standardWebsocketSession.close(CloseStatus.NORMAL); - this.webSocketHandler.afterConnectionClosed(standardWebsocketSession, CloseStatus.NORMAL); - assertEquals("v12.sToMp", standardWebsocketSession.getAcceptedProtocol()); - assertNotNull(standardWebsocketSession.getPrincipal()); - } - }