diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSession.java index 8eb57594c38..b780cfb2de7 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSession.java @@ -50,7 +50,9 @@ public class JettyWebSocketSession extends AbstractWebSocketSession { private List extensions; - private final Principal user; + private Principal user; + + private String acceptedProtocol; /** @@ -105,7 +107,7 @@ public class JettyWebSocketSession extends AbstractWebSocketSession { return this.user; } checkNativeSessionInitialized(); - return getNativeSession().getUpgradeRequest().getUserPrincipal(); + return (isOpen() ? getNativeSession().getUpgradeRequest().getUserPrincipal() : null); } @Override @@ -123,7 +125,7 @@ public class JettyWebSocketSession extends AbstractWebSocketSession { @Override public String getAcceptedProtocol() { checkNativeSessionInitialized(); - return getNativeSession().getUpgradeResponse().getAcceptedSubProtocol(); + return this.acceptedProtocol; } @Override @@ -168,6 +170,15 @@ public class JettyWebSocketSession extends AbstractWebSocketSession { return ((getNativeSession() != null) && getNativeSession().isOpen()); } + @Override + public void initializeNativeSession(Session session) { + super.initializeNativeSession(session); + if (this.user == null) { + this.user = session.getUpgradeRequest().getUserPrincipal(); + } + this.acceptedProtocol = session.getUpgradeResponse().getAcceptedSubProtocol(); + } + @Override protected void sendTextMessage(TextMessage message) throws IOException { getNativeSession().getRemote().sendString(message.getPayload()); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/standard/StandardWebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/standard/StandardWebSocketSession.java index d6f188a70d7..1ce37bc3d5e 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/standard/StandardWebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/standard/StandardWebSocketSession.java @@ -54,7 +54,9 @@ public class StandardWebSocketSession extends AbstractWebSocketSession private final InetSocketAddress remoteAddress; - private final Principal user; + private Principal user; + + private String acceptedProtocol; private List extensions; @@ -120,7 +122,7 @@ public class StandardWebSocketSession extends AbstractWebSocketSession return this.user; } checkNativeSessionInitialized(); - return getNativeSession().getUserPrincipal(); + return (isOpen() ? getNativeSession().getUserPrincipal() : null); } @Override @@ -136,8 +138,7 @@ public class StandardWebSocketSession extends AbstractWebSocketSession @Override public String getAcceptedProtocol() { checkNativeSessionInitialized(); - String protocol = getNativeSession().getNegotiatedSubprotocol(); - return StringUtils.isEmpty(protocol)? null : protocol; + return this.acceptedProtocol; } @Override @@ -182,6 +183,15 @@ public class StandardWebSocketSession extends AbstractWebSocketSession return (getNativeSession() != null && getNativeSession().isOpen()); } + @Override + public void initializeNativeSession(Session session) { + super.initializeNativeSession(session); + if(this.user == null) { + this.user = session.getUserPrincipal(); + } + this.acceptedProtocol = session.getNegotiatedSubprotocol(); + } + @Override protected void sendTextMessage(TextMessage message) throws IOException { getNativeSession().getBasicRemote().sendText(message.getPayload(), message.isLast()); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/adapter/JettyWebSocketHandlerAdapterTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketHandlerAdapterTests.java similarity index 86% rename from spring-websocket/src/test/java/org/springframework/web/socket/adapter/JettyWebSocketHandlerAdapterTests.java rename to spring-websocket/src/test/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketHandlerAdapterTests.java index 964fc967f88..d4d55e822ef 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/adapter/JettyWebSocketHandlerAdapterTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketHandlerAdapterTests.java @@ -14,11 +14,14 @@ * limitations under the License. */ -package org.springframework.web.socket.adapter; +package org.springframework.web.socket.adapter.jetty; import org.eclipse.jetty.websocket.api.Session; +import org.eclipse.jetty.websocket.api.UpgradeRequest; +import org.eclipse.jetty.websocket.api.UpgradeResponse; import org.junit.Before; import org.junit.Test; +import org.mockito.Mockito; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.adapter.jetty.JettyWebSocketHandlerAdapter; @@ -45,6 +48,9 @@ public class JettyWebSocketHandlerAdapterTests { @Before public void setup() { this.session = mock(Session.class); + when(this.session.getUpgradeRequest()).thenReturn(Mockito.mock(UpgradeRequest.class)); + when(this.session.getUpgradeResponse()).thenReturn(Mockito.mock(UpgradeResponse.class)); + this.webSocketHandler = mock(WebSocketHandler.class); this.webSocketSession = new JettyWebSocketSession(null, null); this.adapter = new JettyWebSocketHandlerAdapter(this.webSocketHandler, this.webSocketSession); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSessionTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSessionTests.java new file mode 100644 index 00000000000..74b4273217f --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSessionTests.java @@ -0,0 +1,132 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.adapter.jetty; + +import org.eclipse.jetty.websocket.api.Session; +import org.eclipse.jetty.websocket.api.UpgradeRequest; +import org.eclipse.jetty.websocket.api.UpgradeResponse; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mockito; +import org.springframework.web.socket.handler.TestPrincipal; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +/** + * Unit tests for {@link org.springframework.web.socket.adapter.jetty.JettyWebSocketSession}. + * + * @author Rossen Stoyanchev + */ +public class JettyWebSocketSessionTests { + + private Map attributes; + + + @Before + public void setup() { + this.attributes = new HashMap<>(); + } + + + @Test + public void getPrincipalWithConstructorArg() { + TestPrincipal user = new TestPrincipal("joe"); + JettyWebSocketSession session = new JettyWebSocketSession(attributes, user); + + assertSame(user, session.getPrincipal()); + } + + @Test + public void getPrincipalFromNativeSession() { + + TestPrincipal user = new TestPrincipal("joe"); + + UpgradeRequest request = Mockito.mock(UpgradeRequest.class); + when(request.getUserPrincipal()).thenReturn(user); + + UpgradeResponse response = Mockito.mock(UpgradeResponse.class); + when(response.getAcceptedSubProtocol()).thenReturn(null); + + Session nativeSession = Mockito.mock(Session.class); + when(nativeSession.getUpgradeRequest()).thenReturn(request); + when(nativeSession.getUpgradeResponse()).thenReturn(response); + + JettyWebSocketSession session = new JettyWebSocketSession(attributes); + session.initializeNativeSession(nativeSession); + + reset(nativeSession); + + assertSame(user, session.getPrincipal()); + verifyNoMoreInteractions(nativeSession); + } + + @Test + public void getPrincipalNotAvailable() { + + UpgradeRequest request = Mockito.mock(UpgradeRequest.class); + when(request.getUserPrincipal()).thenReturn(null); + + UpgradeResponse response = Mockito.mock(UpgradeResponse.class); + when(response.getAcceptedSubProtocol()).thenReturn(null); + + Session nativeSession = Mockito.mock(Session.class); + when(nativeSession.getUpgradeRequest()).thenReturn(request); + when(nativeSession.getUpgradeResponse()).thenReturn(response); + + JettyWebSocketSession session = new JettyWebSocketSession(attributes); + session.initializeNativeSession(nativeSession); + + reset(nativeSession); + + assertNull(session.getPrincipal()); + verify(nativeSession).isOpen(); + verifyNoMoreInteractions(nativeSession); + } + + @Test + public void getAcceptedProtocol() { + + String protocol = "foo"; + + UpgradeRequest request = Mockito.mock(UpgradeRequest.class); + when(request.getUserPrincipal()).thenReturn(null); + + UpgradeResponse response = Mockito.mock(UpgradeResponse.class); + when(response.getAcceptedSubProtocol()).thenReturn(protocol); + + Session nativeSession = Mockito.mock(Session.class); + when(nativeSession.getUpgradeRequest()).thenReturn(request); + when(nativeSession.getUpgradeResponse()).thenReturn(response); + + JettyWebSocketSession session = new JettyWebSocketSession(attributes); + session.initializeNativeSession(nativeSession); + + reset(nativeSession); + + assertSame(protocol, session.getAcceptedProtocol()); + verifyNoMoreInteractions(nativeSession); + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/adapter/ConvertingEncoderDecoderSupportTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/adapter/standard/ConvertingEncoderDecoderSupportTests.java similarity index 99% rename from spring-websocket/src/test/java/org/springframework/web/socket/adapter/ConvertingEncoderDecoderSupportTests.java rename to spring-websocket/src/test/java/org/springframework/web/socket/adapter/standard/ConvertingEncoderDecoderSupportTests.java index f08b2bb2648..24a4ddc151d 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/adapter/ConvertingEncoderDecoderSupportTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/adapter/standard/ConvertingEncoderDecoderSupportTests.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.web.socket.adapter; +package org.springframework.web.socket.adapter.standard; import java.nio.ByteBuffer; import javax.websocket.DecodeException; diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/adapter/StandardWebSocketHandlerAdapterTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/adapter/standard/StandardWebSocketHandlerAdapterTests.java similarity index 97% rename from spring-websocket/src/test/java/org/springframework/web/socket/adapter/StandardWebSocketHandlerAdapterTests.java rename to spring-websocket/src/test/java/org/springframework/web/socket/adapter/standard/StandardWebSocketHandlerAdapterTests.java index 19366b6c9c1..fb4ea6b3c9e 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/adapter/StandardWebSocketHandlerAdapterTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/adapter/standard/StandardWebSocketHandlerAdapterTests.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.web.socket.adapter; +package org.springframework.web.socket.adapter.standard; import javax.websocket.CloseReason; import javax.websocket.CloseReason.CloseCodes; diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/adapter/standard/StandardWebSocketSessionTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/adapter/standard/StandardWebSocketSessionTests.java new file mode 100644 index 00000000000..ef2cf002beb --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/adapter/standard/StandardWebSocketSessionTests.java @@ -0,0 +1,110 @@ +/* Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.adapter.standard; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mockito; +import org.springframework.http.HttpHeaders; +import org.springframework.web.socket.handler.TestPrincipal; + +import javax.websocket.Session; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +/** + * Unit tests for {@link org.springframework.web.socket.adapter.standard.StandardWebSocketSession}. + * + * @author Rossen Stoyanchev + */ +public class StandardWebSocketSessionTests { + + private HttpHeaders headers; + + private Map attributes; + + + @Before + public void setup() { + this.headers = new HttpHeaders(); + this.attributes = new HashMap<>(); + } + + + @Test + public void getPrincipalWithConstructorArg() { + TestPrincipal user = new TestPrincipal("joe"); + StandardWebSocketSession session = new StandardWebSocketSession(this.headers, this.attributes, null, null, user); + + assertSame(user, session.getPrincipal()); + } + + @Test + public void getPrincipalWithNativeSession() { + + TestPrincipal user = new TestPrincipal("joe"); + + Session nativeSession = Mockito.mock(Session.class); + when(nativeSession.getUserPrincipal()).thenReturn(user); + + StandardWebSocketSession session = new StandardWebSocketSession(this.headers, this.attributes, null, null); + session.initializeNativeSession(nativeSession); + + assertSame(user, session.getPrincipal()); + } + + @Test + public void getPrincipalNone() { + + Session nativeSession = Mockito.mock(Session.class); + when(nativeSession.getUserPrincipal()).thenReturn(null); + + StandardWebSocketSession session = new StandardWebSocketSession(this.headers, this.attributes, null, null); + session.initializeNativeSession(nativeSession); + + reset(nativeSession); + + assertNull(session.getPrincipal()); + verify(nativeSession).isOpen(); + verifyNoMoreInteractions(nativeSession); + } + + @Test + public void getAcceptedProtocol() { + + String protocol = "foo"; + + Session nativeSession = Mockito.mock(Session.class); + when(nativeSession.getNegotiatedSubprotocol()).thenReturn(protocol); + + StandardWebSocketSession session = new StandardWebSocketSession(this.headers, this.attributes, null, null); + session.initializeNativeSession(nativeSession); + + reset(nativeSession); + + assertEquals(protocol, session.getAcceptedProtocol()); + verifyNoMoreInteractions(nativeSession); + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java index 5aaf1127abf..426a0599ab4 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java @@ -16,6 +16,7 @@ package org.springframework.web.socket.messaging; +import java.io.IOException; import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java index b021199e189..31cf82ac766 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java @@ -21,11 +21,9 @@ import java.util.Arrays; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; -import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.springframework.messaging.MessageChannel; import org.springframework.messaging.SubscribableChannel; -import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator; import org.springframework.web.socket.handler.TestWebSocketSession;