diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/AbstractWebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/AbstractWebSocketSession.java index 469b0150ca3..a5e78b6ae01 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/AbstractWebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/AbstractWebSocketSession.java @@ -42,15 +42,13 @@ public abstract class AbstractWebSocketSession implements NativeWebSocketSess protected static final Log logger = LogFactory.getLog(NativeWebSocketSession.class); + private final Map attributes = new ConcurrentHashMap<>(); private T nativeSession; - private final Map attributes = new ConcurrentHashMap<>(); - /** * Create a new instance and associate the given attributes with it. - * * @param attributes attributes from the HTTP handshake to associate with the WebSocket * session; the provided attributes are copied, the original map is not used. */ @@ -83,7 +81,7 @@ public abstract class AbstractWebSocketSession implements NativeWebSocketSess } public void initializeNativeSession(T session) { - Assert.notNull(session, "session must not be null"); + Assert.notNull(session, "WebSocket session must not be null"); this.nativeSession = session; } @@ -125,6 +123,7 @@ public abstract class AbstractWebSocketSession implements NativeWebSocketSess protected abstract void sendPongMessage(PongMessage message) throws IOException; + @Override public final void close() throws IOException { close(CloseStatus.NORMAL); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSession.java index bc6030f8e5b..4062db5af1d 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSession.java @@ -45,25 +45,45 @@ import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.adapter.AbstractWebSocketSession; /** - * A {@link WebSocketSession} for use with the Jetty 9 WebSocket API. + * A {@link WebSocketSession} for use with the Jetty 9.3/9.4 WebSocket API. * * @author Phillip Webb * @author Rossen Stoyanchev * @author Brian Clozel + * @author Juergen Hoeller * @since 4.0 */ public class JettyWebSocketSession extends AbstractWebSocketSession { // As of Jetty 9.4, UpgradeRequest and UpgradeResponse are interfaces instead of classes - private static final boolean isJetty94; + private static final boolean directInterfaceCalls; private static Method getUpgradeRequest; private static Method getUpgradeResponse; private static Method getRequestURI; private static Method getHeaders; + private static Method getUserPrincipal; private static Method getAcceptedSubProtocol; private static Method getExtensions; - private static Method getUserPrincipal; + + static { + directInterfaceCalls = UpgradeRequest.class.isInterface(); + if (!directInterfaceCalls) { + try { + getUpgradeRequest = Session.class.getMethod("getUpgradeRequest"); + getUpgradeResponse = Session.class.getMethod("getUpgradeResponse"); + getRequestURI = UpgradeRequest.class.getMethod("getRequestURI"); + getHeaders = UpgradeRequest.class.getMethod("getHeaders"); + getUserPrincipal = UpgradeRequest.class.getMethod("getUserPrincipal"); + getAcceptedSubProtocol = UpgradeResponse.class.getMethod("getAcceptedSubProtocol"); + getExtensions = UpgradeResponse.class.getMethod("getExtensions"); + } + catch (NoSuchMethodException ex) { + throw new IllegalStateException("Incompatible Jetty API", ex); + } + } + } + private String id; @@ -77,27 +97,9 @@ public class JettyWebSocketSession extends AbstractWebSocketSession { private Principal user; - static { - isJetty94 = UpgradeRequest.class.isInterface(); - if (!isJetty94) { - try { - getUpgradeRequest = Session.class.getMethod("getUpgradeRequest"); - getUpgradeResponse = Session.class.getMethod("getUpgradeResponse"); - getRequestURI = UpgradeRequest.class.getMethod("getRequestURI"); - getHeaders = UpgradeRequest.class.getMethod("getHeaders"); - getAcceptedSubProtocol = UpgradeResponse.class.getMethod("getAcceptedSubProtocol"); - getExtensions = UpgradeResponse.class.getMethod("getExtensions"); - getUserPrincipal = UpgradeRequest.class.getMethod("getUserPrincipal"); - } - catch (NoSuchMethodException ex) { - throw new IllegalStateException("Incompatible Jetty API", ex); - } - } - } /** * Create a new {@link JettyWebSocketSession} instance. - * * @param attributes attributes from the HTTP handshake to associate with the WebSocket session */ public JettyWebSocketSession(Map attributes) { @@ -106,11 +108,10 @@ public class JettyWebSocketSession extends AbstractWebSocketSession { /** * Create a new {@link JettyWebSocketSession} instance associated with the given user. - * * @param attributes attributes from the HTTP handshake to associate with the WebSocket * session; the provided attributes are copied, the original map is not used. - * @param user the user associated with the session; if {@code null} we'll fallback on the user - * available via {@link org.eclipse.jetty.websocket.api.Session#getUpgradeRequest()} + * @param user the user associated with the session; if {@code null} we'll fallback on the + * user available via {@link org.eclipse.jetty.websocket.api.Session#getUpgradeRequest()} */ public JettyWebSocketSession(Map attributes, Principal user) { super(attributes); @@ -191,23 +192,49 @@ public class JettyWebSocketSession extends AbstractWebSocketSession { @Override public boolean isOpen() { - return ((getNativeSession() != null) && getNativeSession().isOpen()); + return (getNativeSession() != null && getNativeSession().isOpen()); } + @Override public void initializeNativeSession(Session session) { super.initializeNativeSession(session); - if (isJetty94) { - initializeJetty94Session(session); + if (directInterfaceCalls) { + initializeJettySessionDirectly(session); } else { - initializeJettySession(session); + initializeJettySessionReflectively(session); + } + } + + private void initializeJettySessionDirectly(Session session) { + this.id = ObjectUtils.getIdentityHexString(getNativeSession()); + this.uri = session.getUpgradeRequest().getRequestURI(); + + this.headers = new HttpHeaders(); + this.headers.putAll(session.getUpgradeRequest().getHeaders()); + this.headers = HttpHeaders.readOnlyHttpHeaders(headers); + + this.acceptedProtocol = session.getUpgradeResponse().getAcceptedSubProtocol(); + + List source = session.getUpgradeResponse().getExtensions(); + if (source != null) { + this.extensions = new ArrayList<>(source.size()); + for (ExtensionConfig ec : source) { + this.extensions.add(new WebSocketExtension(ec.getName(), ec.getParameters())); + } + } + else { + this.extensions = new ArrayList<>(0); + } + + if (this.user == null) { + this.user = session.getUpgradeRequest().getUserPrincipal(); } } @SuppressWarnings("unchecked") - private void initializeJettySession(Session session) { - + private void initializeJettySessionReflectively(Session session) { Object request = ReflectionUtils.invokeMethod(getUpgradeRequest, session); Object response = ReflectionUtils.invokeMethod(getUpgradeResponse, session); @@ -236,31 +263,6 @@ public class JettyWebSocketSession extends AbstractWebSocketSession { } } - private void initializeJetty94Session(Session session) { - this.id = ObjectUtils.getIdentityHexString(getNativeSession()); - this.uri = session.getUpgradeRequest().getRequestURI(); - - this.headers = new HttpHeaders(); - this.headers.putAll(session.getUpgradeRequest().getHeaders()); - this.headers = HttpHeaders.readOnlyHttpHeaders(headers); - - this.acceptedProtocol = session.getUpgradeResponse().getAcceptedSubProtocol(); - - List source = session.getUpgradeResponse().getExtensions(); - if (source != null) { - this.extensions = new ArrayList<>(source.size()); - for (ExtensionConfig ec : source) { - this.extensions.add(new WebSocketExtension(ec.getName(), ec.getParameters())); - } - } - else { - this.extensions = new ArrayList<>(0); - } - - if (this.user == null) { - this.user = session.getUpgradeRequest().getUserPrincipal(); - } - } @Override protected void sendTextMessage(TextMessage message) throws IOException { @@ -287,7 +289,7 @@ public class JettyWebSocketSession extends AbstractWebSocketSession { return getNativeSession().getRemote(); } catch (WebSocketException ex) { - throw new IOException("Unable to obtain RemoteEndpoint in session=" + getId(), ex); + throw new IOException("Unable to obtain RemoteEndpoint in session " + getId(), ex); } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/jetty/JettyRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/jetty/JettyRequestUpgradeStrategy.java index 9bf3260f4e9..937c5bdd7bf 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/jetty/JettyRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/jetty/JettyRequestUpgradeStrategy.java @@ -21,7 +21,6 @@ import java.security.Principal; import java.util.ArrayList; import java.util.List; import java.util.Map; - import javax.servlet.ServletContext; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -60,75 +59,59 @@ import org.springframework.web.socket.server.RequestUpgradeStrategy; * @author Phillip Webb * @author Rossen Stoyanchev * @author Brian Clozel + * @author Juergen Hoeller * @since 4.0 */ -public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Lifecycle, ServletContextAware { +public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, ServletContextAware, Lifecycle { private static final ThreadLocal wsContainerHolder = new NamedThreadLocal<>("WebSocket Handler Container"); - // Actually 9.3.15+ - private static boolean isJetty94 = ClassUtils.hasConstructor(WebSocketServerFactory.class, ServletContext.class); + private final WebSocketServerFactoryAdapter factoryAdapter = + (ClassUtils.hasConstructor(WebSocketServerFactory.class, ServletContext.class) ? + new ModernJettyWebSocketServerFactoryAdapter() : new LegacyJettyWebSocketServerFactoryAdapter()); - private WebSocketServerFactoryAdapter factoryAdapter; + private ServletContext servletContext; + + private volatile boolean running = false; private volatile List supportedExtensions; - protected ServletContext servletContext; - - private volatile boolean running = false; /** * Default constructor that creates {@link WebSocketServerFactory} through * its default constructor thus using a default {@link WebSocketPolicy}. */ public JettyRequestUpgradeStrategy() { - this(WebSocketPolicy.newServerPolicy()); + this.factoryAdapter.setPolicy(WebSocketPolicy.newServerPolicy()); } /** - * A constructor accepting a {@link WebSocketPolicy} - * to be used when creating the {@link WebSocketServerFactory} instance. - * @since 4.3 + * A constructor accepting a {@link WebSocketPolicy} to be used when + * creating the {@link WebSocketServerFactory} instance. + * @param policy the policy to use + * @since 4.3.5 */ - public JettyRequestUpgradeStrategy(WebSocketPolicy webSocketPolicy) { - this.factoryAdapter = isJetty94 ? new Jetty94WebSocketServerFactoryAdapter() - : new JettyWebSocketServerFactoryAdapter(); - this.factoryAdapter.setWebSocketPolicy(webSocketPolicy); + public JettyRequestUpgradeStrategy(WebSocketPolicy policy) { + Assert.notNull(policy, "WebSocketPolicy must not be null"); + this.factoryAdapter.setPolicy(policy); } - @Override - public String[] getSupportedVersions() { - return new String[] {String.valueOf(HandshakeRFC6455.VERSION)}; + /** + * A constructor accepting a {@link WebSocketServerFactory}. + * @param factory the pre-configured factory to use + */ + public JettyRequestUpgradeStrategy(WebSocketServerFactory factory) { + Assert.notNull(factory, "WebSocketServerFactory must not be null"); + this.factoryAdapter.setFactory(factory); } - @Override - public List getSupportedExtensions(ServerHttpRequest request) { - if (this.supportedExtensions == null) { - this.supportedExtensions = getWebSocketExtensions(); - } - return this.supportedExtensions; - } - - private List getWebSocketExtensions() { - List result = new ArrayList<>(); - for (String name : this.factoryAdapter.getFactory().getExtensionFactory().getExtensionNames()) { - result.add(new WebSocketExtension(name)); - } - return result; - } @Override public void setServletContext(ServletContext servletContext) { this.servletContext = servletContext; } - @Override - public boolean isRunning() { - return this.running; - } - - @Override public void start() { if (!isRunning()) { @@ -136,7 +119,7 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life try { this.factoryAdapter.start(); } - catch (Exception ex) { + catch (Throwable ex) { throw new IllegalStateException("Unable to start Jetty WebSocketServerFactory", ex); } } @@ -149,12 +132,39 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life this.running = false; this.factoryAdapter.stop(); } - catch (Exception ex) { + catch (Throwable ex) { throw new IllegalStateException("Unable to stop Jetty WebSocketServerFactory", ex); } } } + @Override + public boolean isRunning() { + return this.running; + } + + + @Override + public String[] getSupportedVersions() { + return new String[] { String.valueOf(HandshakeRFC6455.VERSION) }; + } + + @Override + public List getSupportedExtensions(ServerHttpRequest request) { + if (this.supportedExtensions == null) { + this.supportedExtensions = buildWebSocketExtensions(); + } + return this.supportedExtensions; + } + + private List buildWebSocketExtensions() { + List result = new ArrayList<>(); + for (String name : this.factoryAdapter.getFactory().getExtensionFactory().getExtensionNames()) { + result.add(new WebSocketExtension(name)); + } + return result; + } + @Override public void upgrade(ServerHttpRequest request, ServerHttpResponse response, String selectedProtocol, List selectedExtensions, Principal user, @@ -197,7 +207,9 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life private final List extensionConfigs; - public WebSocketHandlerContainer(JettyWebSocketHandlerAdapter handler, String protocol, List extensions) { + public WebSocketHandlerContainer( + JettyWebSocketHandlerAdapter handler, String protocol, List extensions) { + this.handler = handler; this.selectedProtocol = protocol; if (CollectionUtils.isEmpty(extensions)) { @@ -224,21 +236,29 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life } } + private static abstract class WebSocketServerFactoryAdapter { - protected WebSocketServerFactory factory; + private WebSocketPolicy policy; - protected WebSocketPolicy webSocketPolicy; + private WebSocketServerFactory factory; + + public void setPolicy(WebSocketPolicy policy) { + this.policy = policy; + } + + public void setFactory(WebSocketServerFactory factory) { + this.factory = factory; + } public WebSocketServerFactory getFactory() { - return factory; + return this.factory; } - public void setWebSocketPolicy(WebSocketPolicy webSocketPolicy) { - this.webSocketPolicy = webSocketPolicy; - } - - protected void configureFactory() { + public void start() throws Exception { + if (this.factory == null) { + this.factory = createFactory(this.policy); + } this.factory.setCreator(new WebSocketCreator() { @Override public Object createWebSocket(ServletUpgradeRequest request, ServletUpgradeResponse response) { @@ -249,43 +269,60 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life return container.getHandler(); } }); + startFactory(this.factory); } - abstract void start() throws Exception; + public void stop() throws Exception { + if (this.factory != null) { + stopFactory(this.factory); + } + } - abstract void stop() throws Exception; + protected abstract WebSocketServerFactory createFactory(WebSocketPolicy policy) throws Exception; + + protected abstract void startFactory(WebSocketServerFactory factory) throws Exception; + + protected abstract void stopFactory(WebSocketServerFactory factory) throws Exception; } - private class JettyWebSocketServerFactoryAdapter extends WebSocketServerFactoryAdapter { + + // Jetty 9.3.15+ + private class ModernJettyWebSocketServerFactoryAdapter extends WebSocketServerFactoryAdapter { @Override - void start() throws Exception { - this.factory = WebSocketServerFactory.class.getConstructor(WebSocketPolicy.class) - .newInstance(this.webSocketPolicy); - configureFactory(); - WebSocketServerFactory.class.getMethod("init", ServletContext.class) - .invoke(this.factory, servletContext); - } - - @Override - void stop() throws Exception { - WebSocketServerFactory.class.getMethod("cleanup").invoke(this.factory); - } - } - - private class Jetty94WebSocketServerFactoryAdapter extends WebSocketServerFactoryAdapter { - - @Override - void start() throws Exception { + protected WebSocketServerFactory createFactory(WebSocketPolicy policy) throws Exception { servletContext.setAttribute(DecoratedObjectFactory.ATTR, new DecoratedObjectFactory()); - this.factory = new WebSocketServerFactory(servletContext, this.webSocketPolicy); - configureFactory(); - this.factory.start(); + return new WebSocketServerFactory(servletContext, policy); } @Override - void stop() throws Exception { - this.factory.stop(); + protected void startFactory(WebSocketServerFactory factory) throws Exception { + factory.start(); + } + + @Override + protected void stopFactory(WebSocketServerFactory factory) throws Exception { + factory.stop(); + } + } + + + // Jetty <9.3.15 + private class LegacyJettyWebSocketServerFactoryAdapter extends WebSocketServerFactoryAdapter { + + @Override + protected WebSocketServerFactory createFactory(WebSocketPolicy policy) throws Exception { + return WebSocketServerFactory.class.getConstructor(WebSocketPolicy.class).newInstance(policy); + } + + @Override + protected void startFactory(WebSocketServerFactory factory) throws Exception { + WebSocketServerFactory.class.getMethod("init", ServletContext.class).invoke(factory, servletContext); + } + + @Override + protected void stopFactory(WebSocketServerFactory factory) throws Exception { + WebSocketServerFactory.class.getMethod("cleanup").invoke(factory); } }