Improve access to WebSocketSession fields

Ensure the Standard- and the JettyWebSocketSession can return the
Principal and accepted WebSocket sub-protocol even after the
session is closed.

Issue: SPR-11621
This commit is contained in:
Rossen Stoyanchev 2014-03-31 16:35:40 +02:00
parent e21c47d4ce
commit 7b014eaa55
9 changed files with 280 additions and 12 deletions

View File

@ -50,7 +50,9 @@ public class JettyWebSocketSession extends AbstractWebSocketSession<Session> {
private List<WebSocketExtension> extensions;
private final Principal user;
private Principal user;
private String acceptedProtocol;
/**
@ -105,7 +107,7 @@ public class JettyWebSocketSession extends AbstractWebSocketSession<Session> {
return this.user;
}
checkNativeSessionInitialized();
return getNativeSession().getUpgradeRequest().getUserPrincipal();
return (isOpen() ? getNativeSession().getUpgradeRequest().getUserPrincipal() : null);
}
@Override
@ -123,7 +125,7 @@ public class JettyWebSocketSession extends AbstractWebSocketSession<Session> {
@Override
public String getAcceptedProtocol() {
checkNativeSessionInitialized();
return getNativeSession().getUpgradeResponse().getAcceptedSubProtocol();
return this.acceptedProtocol;
}
@Override
@ -168,6 +170,15 @@ public class JettyWebSocketSession extends AbstractWebSocketSession<Session> {
return ((getNativeSession() != null) && getNativeSession().isOpen());
}
@Override
public void initializeNativeSession(Session session) {
super.initializeNativeSession(session);
if (this.user == null) {
this.user = session.getUpgradeRequest().getUserPrincipal();
}
this.acceptedProtocol = session.getUpgradeResponse().getAcceptedSubProtocol();
}
@Override
protected void sendTextMessage(TextMessage message) throws IOException {
getNativeSession().getRemote().sendString(message.getPayload());

View File

@ -54,7 +54,9 @@ public class StandardWebSocketSession extends AbstractWebSocketSession<Session>
private final InetSocketAddress remoteAddress;
private final Principal user;
private Principal user;
private String acceptedProtocol;
private List<WebSocketExtension> extensions;
@ -120,7 +122,7 @@ public class StandardWebSocketSession extends AbstractWebSocketSession<Session>
return this.user;
}
checkNativeSessionInitialized();
return getNativeSession().getUserPrincipal();
return (isOpen() ? getNativeSession().getUserPrincipal() : null);
}
@Override
@ -136,8 +138,7 @@ public class StandardWebSocketSession extends AbstractWebSocketSession<Session>
@Override
public String getAcceptedProtocol() {
checkNativeSessionInitialized();
String protocol = getNativeSession().getNegotiatedSubprotocol();
return StringUtils.isEmpty(protocol)? null : protocol;
return this.acceptedProtocol;
}
@Override
@ -182,6 +183,15 @@ public class StandardWebSocketSession extends AbstractWebSocketSession<Session>
return (getNativeSession() != null && getNativeSession().isOpen());
}
@Override
public void initializeNativeSession(Session session) {
super.initializeNativeSession(session);
if(this.user == null) {
this.user = session.getUserPrincipal();
}
this.acceptedProtocol = session.getNegotiatedSubprotocol();
}
@Override
protected void sendTextMessage(TextMessage message) throws IOException {
getNativeSession().getBasicRemote().sendText(message.getPayload(), message.isLast());

View File

@ -14,11 +14,14 @@
* limitations under the License.
*/
package org.springframework.web.socket.adapter;
package org.springframework.web.socket.adapter.jetty;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.eclipse.jetty.websocket.api.UpgradeResponse;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mockito;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.adapter.jetty.JettyWebSocketHandlerAdapter;
@ -45,6 +48,9 @@ public class JettyWebSocketHandlerAdapterTests {
@Before
public void setup() {
this.session = mock(Session.class);
when(this.session.getUpgradeRequest()).thenReturn(Mockito.mock(UpgradeRequest.class));
when(this.session.getUpgradeResponse()).thenReturn(Mockito.mock(UpgradeResponse.class));
this.webSocketHandler = mock(WebSocketHandler.class);
this.webSocketSession = new JettyWebSocketSession(null, null);
this.adapter = new JettyWebSocketHandlerAdapter(this.webSocketHandler, this.webSocketSession);

View File

@ -0,0 +1,132 @@
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.adapter.jetty;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.eclipse.jetty.websocket.api.UpgradeResponse;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mockito;
import org.springframework.web.socket.handler.TestPrincipal;
import java.util.HashMap;
import java.util.Map;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
/**
* Unit tests for {@link org.springframework.web.socket.adapter.jetty.JettyWebSocketSession}.
*
* @author Rossen Stoyanchev
*/
public class JettyWebSocketSessionTests {
private Map<String,Object> attributes;
@Before
public void setup() {
this.attributes = new HashMap<>();
}
@Test
public void getPrincipalWithConstructorArg() {
TestPrincipal user = new TestPrincipal("joe");
JettyWebSocketSession session = new JettyWebSocketSession(attributes, user);
assertSame(user, session.getPrincipal());
}
@Test
public void getPrincipalFromNativeSession() {
TestPrincipal user = new TestPrincipal("joe");
UpgradeRequest request = Mockito.mock(UpgradeRequest.class);
when(request.getUserPrincipal()).thenReturn(user);
UpgradeResponse response = Mockito.mock(UpgradeResponse.class);
when(response.getAcceptedSubProtocol()).thenReturn(null);
Session nativeSession = Mockito.mock(Session.class);
when(nativeSession.getUpgradeRequest()).thenReturn(request);
when(nativeSession.getUpgradeResponse()).thenReturn(response);
JettyWebSocketSession session = new JettyWebSocketSession(attributes);
session.initializeNativeSession(nativeSession);
reset(nativeSession);
assertSame(user, session.getPrincipal());
verifyNoMoreInteractions(nativeSession);
}
@Test
public void getPrincipalNotAvailable() {
UpgradeRequest request = Mockito.mock(UpgradeRequest.class);
when(request.getUserPrincipal()).thenReturn(null);
UpgradeResponse response = Mockito.mock(UpgradeResponse.class);
when(response.getAcceptedSubProtocol()).thenReturn(null);
Session nativeSession = Mockito.mock(Session.class);
when(nativeSession.getUpgradeRequest()).thenReturn(request);
when(nativeSession.getUpgradeResponse()).thenReturn(response);
JettyWebSocketSession session = new JettyWebSocketSession(attributes);
session.initializeNativeSession(nativeSession);
reset(nativeSession);
assertNull(session.getPrincipal());
verify(nativeSession).isOpen();
verifyNoMoreInteractions(nativeSession);
}
@Test
public void getAcceptedProtocol() {
String protocol = "foo";
UpgradeRequest request = Mockito.mock(UpgradeRequest.class);
when(request.getUserPrincipal()).thenReturn(null);
UpgradeResponse response = Mockito.mock(UpgradeResponse.class);
when(response.getAcceptedSubProtocol()).thenReturn(protocol);
Session nativeSession = Mockito.mock(Session.class);
when(nativeSession.getUpgradeRequest()).thenReturn(request);
when(nativeSession.getUpgradeResponse()).thenReturn(response);
JettyWebSocketSession session = new JettyWebSocketSession(attributes);
session.initializeNativeSession(nativeSession);
reset(nativeSession);
assertSame(protocol, session.getAcceptedProtocol());
verifyNoMoreInteractions(nativeSession);
}
}

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
package org.springframework.web.socket.adapter;
package org.springframework.web.socket.adapter.standard;
import java.nio.ByteBuffer;
import javax.websocket.DecodeException;

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
package org.springframework.web.socket.adapter;
package org.springframework.web.socket.adapter.standard;
import javax.websocket.CloseReason;
import javax.websocket.CloseReason.CloseCodes;

View File

@ -0,0 +1,110 @@
/* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.adapter.standard;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mockito;
import org.springframework.http.HttpHeaders;
import org.springframework.web.socket.handler.TestPrincipal;
import javax.websocket.Session;
import java.util.HashMap;
import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
/**
* Unit tests for {@link org.springframework.web.socket.adapter.standard.StandardWebSocketSession}.
*
* @author Rossen Stoyanchev
*/
public class StandardWebSocketSessionTests {
private HttpHeaders headers;
private Map<String,Object> attributes;
@Before
public void setup() {
this.headers = new HttpHeaders();
this.attributes = new HashMap<>();
}
@Test
public void getPrincipalWithConstructorArg() {
TestPrincipal user = new TestPrincipal("joe");
StandardWebSocketSession session = new StandardWebSocketSession(this.headers, this.attributes, null, null, user);
assertSame(user, session.getPrincipal());
}
@Test
public void getPrincipalWithNativeSession() {
TestPrincipal user = new TestPrincipal("joe");
Session nativeSession = Mockito.mock(Session.class);
when(nativeSession.getUserPrincipal()).thenReturn(user);
StandardWebSocketSession session = new StandardWebSocketSession(this.headers, this.attributes, null, null);
session.initializeNativeSession(nativeSession);
assertSame(user, session.getPrincipal());
}
@Test
public void getPrincipalNone() {
Session nativeSession = Mockito.mock(Session.class);
when(nativeSession.getUserPrincipal()).thenReturn(null);
StandardWebSocketSession session = new StandardWebSocketSession(this.headers, this.attributes, null, null);
session.initializeNativeSession(nativeSession);
reset(nativeSession);
assertNull(session.getPrincipal());
verify(nativeSession).isOpen();
verifyNoMoreInteractions(nativeSession);
}
@Test
public void getAcceptedProtocol() {
String protocol = "foo";
Session nativeSession = Mockito.mock(Session.class);
when(nativeSession.getNegotiatedSubprotocol()).thenReturn(protocol);
StandardWebSocketSession session = new StandardWebSocketSession(this.headers, this.attributes, null, null);
session.initializeNativeSession(nativeSession);
reset(nativeSession);
assertEquals(protocol, session.getAcceptedProtocol());
verifyNoMoreInteractions(nativeSession);
}
}

View File

@ -16,6 +16,7 @@
package org.springframework.web.socket.messaging;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;

View File

@ -21,11 +21,9 @@ import java.util.Arrays;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator;
import org.springframework.web.socket.handler.TestWebSocketSession;