From d6895aa09846a0e1e1904d17f00ea53c73fc104f Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Mon, 12 Dec 2016 17:54:24 -0500 Subject: [PATCH] Consistently extend WebSocketHandlerAdapterSupport The WebSocketHandler adapters for all runtimes now extend WebSocketHandlerAdapterSupport, which now also exposes a shared DataBufferFactory property initialized from the response. Issue: SPR-14527 --- .../adapter/JettyWebSocketHandlerAdapter.java | 25 ++-- .../ReactorNettyWebSocketHandlerAdapter.java | 13 +- .../RxNettyWebSocketHandlerAdapter.java | 12 +- .../TomcatWebSocketHandlerAdapter.java | 131 +++++++++--------- .../UndertowWebSocketHandlerAdapter.java | 52 +++---- .../adapter/UndertowWebSocketSession.java | 5 +- .../WebSocketHandlerAdapterSupport.java | 26 +++- .../upgrade/JettyRequestUpgradeStrategy.java | 92 ++++++------ .../upgrade/TomcatRequestUpgradeStrategy.java | 51 +++---- .../UndertowRequestUpgradeStrategy.java | 26 ++-- 10 files changed, 205 insertions(+), 228 deletions(-) diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/JettyWebSocketHandlerAdapter.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/JettyWebSocketHandlerAdapter.java index 4d5c08816f4..a7e9c17c469 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/JettyWebSocketHandlerAdapter.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/JettyWebSocketHandlerAdapter.java @@ -32,9 +32,8 @@ import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import org.springframework.core.io.buffer.DataBuffer; -import org.springframework.core.io.buffer.DataBufferFactory; -import org.springframework.core.io.buffer.DefaultDataBufferFactory; -import org.springframework.util.Assert; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.web.reactive.socket.CloseStatus; import org.springframework.web.reactive.socket.WebSocketHandler; import org.springframework.web.reactive.socket.WebSocketMessage; @@ -48,21 +47,17 @@ import org.springframework.web.reactive.socket.WebSocketMessage.Type; * @since 5.0 */ @WebSocket -public class JettyWebSocketHandlerAdapter { +public class JettyWebSocketHandlerAdapter extends WebSocketHandlerAdapterSupport { private static final ByteBuffer EMPTY_PAYLOAD = ByteBuffer.wrap(new byte[0]); - - private final WebSocketHandler delegate; - private JettyWebSocketSession session; - private final DataBufferFactory bufferFactory = new DefaultDataBufferFactory(false); + public JettyWebSocketHandlerAdapter(ServerHttpRequest request, ServerHttpResponse response, + WebSocketHandler delegate) { - public JettyWebSocketHandlerAdapter(WebSocketHandler delegate) { - Assert.notNull("WebSocketHandler is required"); - this.delegate = delegate; + super(request, response, delegate); } @@ -71,7 +66,7 @@ public class JettyWebSocketHandlerAdapter { this.session = new JettyWebSocketSession(session); HandlerResultSubscriber subscriber = new HandlerResultSubscriber(); - this.delegate.handle(this.session).subscribe(subscriber); + getDelegate().handle(this.session).subscribe(subscriber); } @OnWebSocketMessage @@ -105,15 +100,15 @@ public class JettyWebSocketHandlerAdapter { private WebSocketMessage toMessage(Type type, T message) { if (Type.TEXT.equals(type)) { byte[] bytes = ((String) message).getBytes(StandardCharsets.UTF_8); - DataBuffer buffer = this.bufferFactory.wrap(bytes); + DataBuffer buffer = getBufferFactory().wrap(bytes); return WebSocketMessage.create(Type.TEXT, buffer); } else if (Type.BINARY.equals(type)) { - DataBuffer buffer = this.bufferFactory.wrap((ByteBuffer) message); + DataBuffer buffer = getBufferFactory().wrap((ByteBuffer) message); return WebSocketMessage.create(Type.BINARY, buffer); } else if (Type.PONG.equals(type)) { - DataBuffer buffer = this.bufferFactory.wrap((ByteBuffer) message); + DataBuffer buffer = getBufferFactory().wrap((ByteBuffer) message); return WebSocketMessage.create(Type.PONG, buffer); } else { diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/ReactorNettyWebSocketHandlerAdapter.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/ReactorNettyWebSocketHandlerAdapter.java index faaf90c2b5f..56b6e779941 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/ReactorNettyWebSocketHandlerAdapter.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/ReactorNettyWebSocketHandlerAdapter.java @@ -21,10 +21,8 @@ import org.reactivestreams.Publisher; import reactor.ipc.netty.http.HttpInbound; import reactor.ipc.netty.http.HttpOutbound; -import org.springframework.core.io.buffer.NettyDataBufferFactory; import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.http.server.reactive.ServerHttpResponse; -import org.springframework.util.Assert; import org.springframework.web.reactive.socket.WebSocketHandler; /** @@ -38,22 +36,13 @@ public class ReactorNettyWebSocketHandlerAdapter extends WebSocketHandlerAdapter implements BiFunction> { - private final NettyDataBufferFactory bufferFactory; - - public ReactorNettyWebSocketHandlerAdapter(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler handler) { - super(request, handler); - Assert.notNull("'response' is required"); - this.bufferFactory = (NettyDataBufferFactory) response.bufferFactory(); + super(request, response, handler); } - public NettyDataBufferFactory getBufferFactory() { - return this.bufferFactory; - } - @Override public Publisher apply(HttpInbound inbound, HttpOutbound outbound) { ReactorNettyWebSocketSession session = diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/RxNettyWebSocketHandlerAdapter.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/RxNettyWebSocketHandlerAdapter.java index 772de5c71e9..d9896ad43ca 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/RxNettyWebSocketHandlerAdapter.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/RxNettyWebSocketHandlerAdapter.java @@ -20,10 +20,8 @@ import reactor.core.publisher.Mono; import rx.Observable; import rx.RxReactiveStreams; -import org.springframework.core.io.buffer.NettyDataBufferFactory; import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.http.server.reactive.ServerHttpResponse; -import org.springframework.util.Assert; import org.springframework.web.reactive.socket.WebSocketHandler; /** @@ -36,22 +34,14 @@ import org.springframework.web.reactive.socket.WebSocketHandler; public class RxNettyWebSocketHandlerAdapter extends WebSocketHandlerAdapterSupport implements io.reactivex.netty.protocol.http.ws.server.WebSocketHandler { - private final NettyDataBufferFactory bufferFactory; - public RxNettyWebSocketHandlerAdapter(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler handler) { - super(request, handler); - Assert.notNull("'response' is required"); - this.bufferFactory = (NettyDataBufferFactory) response.bufferFactory(); + super(request, response, handler); } - public NettyDataBufferFactory getBufferFactory() { - return this.bufferFactory; - } - @Override public Observable handle(WebSocketConnection conn) { RxNettyWebSocketSession session = new RxNettyWebSocketSession(conn, getUri(), getBufferFactory()); diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/TomcatWebSocketHandlerAdapter.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/TomcatWebSocketHandlerAdapter.java index f1913a798d3..74b65f7b84b 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/TomcatWebSocketHandlerAdapter.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/TomcatWebSocketHandlerAdapter.java @@ -28,9 +28,8 @@ import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import org.springframework.core.io.buffer.DataBuffer; -import org.springframework.core.io.buffer.DataBufferFactory; -import org.springframework.core.io.buffer.DefaultDataBufferFactory; -import org.springframework.util.Assert; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.web.reactive.socket.CloseStatus; import org.springframework.web.reactive.socket.WebSocketHandler; import org.springframework.web.reactive.socket.WebSocketMessage; @@ -43,76 +42,84 @@ import org.springframework.web.reactive.socket.WebSocketMessage.Type; * @author Violeta Georgieva * @since 5.0 */ -public class TomcatWebSocketHandlerAdapter extends Endpoint { - - private final WebSocketHandler delegate; +public class TomcatWebSocketHandlerAdapter extends WebSocketHandlerAdapterSupport { private TomcatWebSocketSession session; - private final DataBufferFactory bufferFactory = new DefaultDataBufferFactory(false); + public TomcatWebSocketHandlerAdapter(ServerHttpRequest request, ServerHttpResponse response, + WebSocketHandler delegate) { - public TomcatWebSocketHandlerAdapter(WebSocketHandler delegate) { - Assert.notNull("WebSocketHandler is required"); - this.delegate = delegate; + super(request, response, delegate); } - @Override - public void onOpen(Session session, EndpointConfig config) { - this.session = new TomcatWebSocketSession(session); - - session.addMessageHandler(String.class, message -> { - WebSocketMessage webSocketMessage = toMessage(message); - this.session.handleMessage(webSocketMessage.getType(), webSocketMessage); - }); - session.addMessageHandler(ByteBuffer.class, message -> { - WebSocketMessage webSocketMessage = toMessage(message); - this.session.handleMessage(webSocketMessage.getType(), webSocketMessage); - }); - session.addMessageHandler(PongMessage.class, message -> { - WebSocketMessage webSocketMessage = toMessage(message); - this.session.handleMessage(webSocketMessage.getType(), webSocketMessage); - }); - - HandlerResultSubscriber resultSubscriber = new HandlerResultSubscriber(); - this.delegate.handle(this.session).subscribe(resultSubscriber); + public Endpoint getEndpoint() { + return new StandardEndpoint(); } - private WebSocketMessage toMessage(T message) { - if (message instanceof String) { - byte[] bytes = ((String) message).getBytes(StandardCharsets.UTF_8); - return WebSocketMessage.create(Type.TEXT, this.bufferFactory.wrap(bytes)); - } - else if (message instanceof ByteBuffer) { - DataBuffer buffer = this.bufferFactory.wrap((ByteBuffer) message); - return WebSocketMessage.create(Type.BINARY, buffer); - } - else if (message instanceof PongMessage) { - DataBuffer buffer = this.bufferFactory.wrap(((PongMessage) message).getApplicationData()); - return WebSocketMessage.create(Type.PONG, buffer); - } - else { - throw new IllegalArgumentException("Unexpected message type: " + message); - } + private TomcatWebSocketSession getSession() { + return this.session; } - @Override - public void onClose(Session session, CloseReason reason) { - if (this.session != null) { - int code = reason.getCloseCode().getCode(); - this.session.handleClose(new CloseStatus(code, reason.getReasonPhrase())); + + private class StandardEndpoint extends Endpoint { + + @Override + public void onOpen(Session session, EndpointConfig config) { + TomcatWebSocketHandlerAdapter.this.session = new TomcatWebSocketSession(session); + + session.addMessageHandler(String.class, message -> { + WebSocketMessage webSocketMessage = toMessage(message); + getSession().handleMessage(webSocketMessage.getType(), webSocketMessage); + }); + session.addMessageHandler(ByteBuffer.class, message -> { + WebSocketMessage webSocketMessage = toMessage(message); + getSession().handleMessage(webSocketMessage.getType(), webSocketMessage); + }); + session.addMessageHandler(PongMessage.class, message -> { + WebSocketMessage webSocketMessage = toMessage(message); + getSession().handleMessage(webSocketMessage.getType(), webSocketMessage); + }); + + HandlerResultSubscriber resultSubscriber = new HandlerResultSubscriber(); + getDelegate().handle(TomcatWebSocketHandlerAdapter.this.session).subscribe(resultSubscriber); + } + + private WebSocketMessage toMessage(T message) { + if (message instanceof String) { + byte[] bytes = ((String) message).getBytes(StandardCharsets.UTF_8); + return WebSocketMessage.create(Type.TEXT, getBufferFactory().wrap(bytes)); + } + else if (message instanceof ByteBuffer) { + DataBuffer buffer = getBufferFactory().wrap((ByteBuffer) message); + return WebSocketMessage.create(Type.BINARY, buffer); + } + else if (message instanceof PongMessage) { + DataBuffer buffer = getBufferFactory().wrap(((PongMessage) message).getApplicationData()); + return WebSocketMessage.create(Type.PONG, buffer); + } + else { + throw new IllegalArgumentException("Unexpected message type: " + message); + } + } + + @Override + public void onClose(Session session, CloseReason reason) { + if (getSession() != null) { + int code = reason.getCloseCode().getCode(); + getSession().handleClose(new CloseStatus(code, reason.getReasonPhrase())); + } + } + + @Override + public void onError(Session session, Throwable exception) { + if (getSession() != null) { + getSession().handleError(exception); + } } } - @Override - public void onError(Session session, Throwable exception) { - if (this.session != null) { - this.session.handleError(exception); - } - } - - private final class HandlerResultSubscriber implements Subscriber { @Override @@ -127,15 +134,15 @@ public class TomcatWebSocketHandlerAdapter extends Endpoint { @Override public void onError(Throwable ex) { - if (session != null) { - session.close(new CloseStatus(CloseStatus.SERVER_ERROR.getCode(), ex.getMessage())); + if (getSession() != null) { + getSession().close(new CloseStatus(CloseStatus.SERVER_ERROR.getCode(), ex.getMessage())); } } @Override public void onComplete() { - if (session != null) { - session.close(); + if (getSession() != null) { + getSession().close(); } } } diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/UndertowWebSocketHandlerAdapter.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/UndertowWebSocketHandlerAdapter.java index a0a17c3716c..17169806fb5 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/UndertowWebSocketHandlerAdapter.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/UndertowWebSocketHandlerAdapter.java @@ -16,22 +16,9 @@ package org.springframework.web.reactive.socket.adapter; -import java.net.URISyntaxException; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; - -import org.springframework.core.io.buffer.DataBuffer; -import org.springframework.core.io.buffer.DataBufferFactory; -import org.springframework.core.io.buffer.DefaultDataBufferFactory; -import org.springframework.util.Assert; -import org.springframework.web.reactive.socket.CloseStatus; -import org.springframework.web.reactive.socket.WebSocketHandler; -import org.springframework.web.reactive.socket.WebSocketMessage; -import org.springframework.web.reactive.socket.WebSocketMessage.Type; - import io.undertow.websockets.WebSocketConnectionCallback; import io.undertow.websockets.core.AbstractReceiveListener; import io.undertow.websockets.core.BufferedBinaryMessage; @@ -39,6 +26,16 @@ import io.undertow.websockets.core.BufferedTextMessage; import io.undertow.websockets.core.CloseMessage; import io.undertow.websockets.core.WebSocketChannel; import io.undertow.websockets.spi.WebSocketHttpExchange; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.web.reactive.socket.CloseStatus; +import org.springframework.web.reactive.socket.WebSocketHandler; +import org.springframework.web.reactive.socket.WebSocketMessage; +import org.springframework.web.reactive.socket.WebSocketMessage.Type; /** * Undertow {@code WebSocketHandler} implementation adapting and @@ -47,36 +44,27 @@ import io.undertow.websockets.spi.WebSocketHttpExchange; * @author Violeta Georgieva * @since 5.0 */ -public class UndertowWebSocketHandlerAdapter implements WebSocketConnectionCallback { - - private final WebSocketHandler delegate; +public class UndertowWebSocketHandlerAdapter extends WebSocketHandlerAdapterSupport + implements WebSocketConnectionCallback { private UndertowWebSocketSession session; - private final DataBufferFactory bufferFactory = new DefaultDataBufferFactory(false); + public UndertowWebSocketHandlerAdapter(ServerHttpRequest request, ServerHttpResponse response, + WebSocketHandler delegate) { - public UndertowWebSocketHandlerAdapter(WebSocketHandler delegate) { - Assert.notNull("WebSocketHandler is required"); - this.delegate = delegate; + super(request, response, delegate); } @Override public void onConnect(WebSocketHttpExchange exchange, WebSocketChannel channel) { - try { - this.session = new UndertowWebSocketSession(channel); - } - catch (URISyntaxException e) { - // TODO Auto-generated catch block - e.printStackTrace(); - } - + this.session = new UndertowWebSocketSession(channel, getUri()); channel.getReceiveSetter().set(new UndertowReceiveListener()); channel.resumeReceives(); HandlerResultSubscriber resultSubscriber = new HandlerResultSubscriber(); - this.delegate.handle(this.session).subscribe(resultSubscriber); + getDelegate().handle(this.session).subscribe(resultSubscriber); } @@ -114,14 +102,14 @@ public class UndertowWebSocketHandlerAdapter implements WebSocketConnectionCallb private WebSocketMessage toMessage(Type type, T message) { if (Type.TEXT.equals(type)) { byte[] bytes = ((String) message).getBytes(StandardCharsets.UTF_8); - return WebSocketMessage.create(Type.TEXT, bufferFactory.wrap(bytes)); + return WebSocketMessage.create(Type.TEXT, getBufferFactory().wrap(bytes)); } else if (Type.BINARY.equals(type)) { - DataBuffer buffer = bufferFactory.allocateBuffer().write((ByteBuffer[]) message); + DataBuffer buffer = getBufferFactory().allocateBuffer().write((ByteBuffer[]) message); return WebSocketMessage.create(Type.BINARY, buffer); } else if (Type.PONG.equals(type)) { - DataBuffer buffer = bufferFactory.allocateBuffer().write((ByteBuffer[]) message); + DataBuffer buffer = getBufferFactory().allocateBuffer().write((ByteBuffer[]) message); return WebSocketMessage.create(Type.PONG, buffer); } else { diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/UndertowWebSocketSession.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/UndertowWebSocketSession.java index 6d0b5577532..298ccad6961 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/UndertowWebSocketSession.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/UndertowWebSocketSession.java @@ -18,7 +18,6 @@ package org.springframework.web.reactive.socket.adapter; import java.io.IOException; import java.net.URI; -import java.net.URISyntaxException; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; @@ -43,8 +42,8 @@ import org.springframework.web.reactive.socket.WebSocketSession; public class UndertowWebSocketSession extends AbstractListenerWebSocketSession { - public UndertowWebSocketSession(WebSocketChannel channel) throws URISyntaxException { - super(channel, ObjectUtils.getIdentityHexString(channel), new URI(channel.getUrl())); + public UndertowWebSocketSession(WebSocketChannel channel, URI url) { + super(channel, ObjectUtils.getIdentityHexString(channel), url); } diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/WebSocketHandlerAdapterSupport.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/WebSocketHandlerAdapterSupport.java index 97218500122..4a36ee2e0f9 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/WebSocketHandlerAdapterSupport.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/WebSocketHandlerAdapterSupport.java @@ -17,12 +17,15 @@ package org.springframework.web.reactive.socket.adapter; import java.net.URI; +import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.util.Assert; import org.springframework.web.reactive.socket.WebSocketHandler; /** - * Base class for {@link WebSocketHandler} implementations. + * Base class for {@link WebSocketHandler} adapters to underlying WebSocket + * handler APIs. * * @author Rossen Stoyanchev * @since 5.0 @@ -33,21 +36,32 @@ public abstract class WebSocketHandlerAdapterSupport { private final WebSocketHandler delegate; + private final DataBufferFactory bufferFactory; - protected WebSocketHandlerAdapterSupport(ServerHttpRequest request, WebSocketHandler handler) { - Assert.notNull("'request' is required"); - Assert.notNull("'handler' handler is required"); + + protected WebSocketHandlerAdapterSupport(ServerHttpRequest request, ServerHttpResponse response, + WebSocketHandler handler) { + + Assert.notNull("ServerHttpRequest is required"); + Assert.notNull("ServerHttpResponse is required"); + Assert.notNull("WebSocketHandler handler is required"); this.uri = request.getURI(); + this.bufferFactory = response.bufferFactory(); this.delegate = handler; } - public URI getUri() { + protected URI getUri() { return this.uri; } - public WebSocketHandler getDelegate() { + protected WebSocketHandler getDelegate() { return this.delegate; } + @SuppressWarnings("unchecked") + protected T getBufferFactory() { + return (T) this.bufferFactory; + } + } diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/JettyRequestUpgradeStrategy.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/JettyRequestUpgradeStrategy.java index 82295c6e0a5..f898083b586 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/JettyRequestUpgradeStrategy.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/JettyRequestUpgradeStrategy.java @@ -17,16 +17,14 @@ package org.springframework.web.reactive.socket.server.upgrade; import java.io.IOException; - import javax.servlet.ServletContext; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.eclipse.jetty.util.DecoratedObjectFactory; import org.eclipse.jetty.websocket.server.WebSocketServerFactory; -import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest; -import org.eclipse.jetty.websocket.servlet.ServletUpgradeResponse; -import org.eclipse.jetty.websocket.servlet.WebSocketCreator; +import reactor.core.publisher.Mono; + import org.springframework.context.Lifecycle; import org.springframework.core.NamedThreadLocal; import org.springframework.http.server.reactive.ServerHttpRequest; @@ -39,8 +37,6 @@ import org.springframework.web.reactive.socket.adapter.JettyWebSocketHandlerAdap import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy; import org.springframework.web.server.ServerWebExchange; -import reactor.core.publisher.Mono; - /** * A {@link RequestUpgradeStrategy} for use with Jetty. * @@ -52,43 +48,13 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life private static final ThreadLocal wsContainerHolder = new NamedThreadLocal<>("Jetty WebSocketHandler Adapter"); + private WebSocketServerFactory factory; private ServletContext servletContext; private volatile boolean running = false; - @Override - public Mono upgrade(ServerWebExchange exchange, WebSocketHandler webSocketHandler) { - - JettyWebSocketHandlerAdapter adapter = - new JettyWebSocketHandlerAdapter(webSocketHandler); - - HttpServletRequest servletRequest = getHttpServletRequest(exchange.getRequest()); - HttpServletResponse servletResponse = getHttpServletResponse(exchange.getResponse()); - - if (this.servletContext == null) { - this.servletContext = servletRequest.getServletContext(); - servletContext.setAttribute(DecoratedObjectFactory.ATTR, new DecoratedObjectFactory()); - } - - try { - start(); - - Assert.isTrue(this.factory.isUpgradeRequest(servletRequest, servletResponse), "Not a WebSocket handshake"); - - wsContainerHolder.set(adapter); - this.factory.acceptWebSocket(servletRequest, servletResponse); - } - catch (IOException ex) { - return Mono.error(ex); - } - finally { - wsContainerHolder.remove(); - } - - return Mono.empty(); - } @Override public void start() { @@ -96,16 +62,10 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life this.running = true; try { this.factory = new WebSocketServerFactory(this.servletContext); - this.factory.setCreator(new WebSocketCreator() { - - @Override - public Object createWebSocket(ServletUpgradeRequest req, - ServletUpgradeResponse resp) { - JettyWebSocketHandlerAdapter adapter = wsContainerHolder.get(); - Assert.state(adapter != null, "Expected JettyWebSocketHandlerAdapter"); - return adapter; - } - + this.factory.setCreator((request, response) -> { + JettyWebSocketHandlerAdapter adapter = wsContainerHolder.get(); + Assert.state(adapter != null, "Expected JettyWebSocketHandlerAdapter"); + return adapter; }); this.factory.start(); } @@ -133,12 +93,46 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life return this.running; } - private final HttpServletRequest getHttpServletRequest(ServerHttpRequest request) { + @Override + public Mono upgrade(ServerWebExchange exchange, WebSocketHandler handler) { + + ServerHttpRequest request = exchange.getRequest(); + ServerHttpResponse response = exchange.getResponse(); + JettyWebSocketHandlerAdapter adapter = new JettyWebSocketHandlerAdapter(request, response, handler); + + HttpServletRequest servletRequest = getHttpServletRequest(request); + HttpServletResponse servletResponse = getHttpServletResponse(response); + + if (this.servletContext == null) { + this.servletContext = servletRequest.getServletContext(); + this.servletContext.setAttribute(DecoratedObjectFactory.ATTR, new DecoratedObjectFactory()); + } + + try { + start(); + + Assert.isTrue(this.factory.isUpgradeRequest( + servletRequest, servletResponse), "Not a WebSocket handshake"); + + wsContainerHolder.set(adapter); + this.factory.acceptWebSocket(servletRequest, servletResponse); + } + catch (IOException ex) { + return Mono.error(ex); + } + finally { + wsContainerHolder.remove(); + } + + return Mono.empty(); + } + + private HttpServletRequest getHttpServletRequest(ServerHttpRequest request) { Assert.isTrue(request instanceof ServletServerHttpRequest); return ((ServletServerHttpRequest) request).getServletRequest(); } - private final HttpServletResponse getHttpServletResponse(ServerHttpResponse response) { + private HttpServletResponse getHttpServletResponse(ServerHttpResponse response) { Assert.isTrue(response instanceof ServletServerHttpResponse); return ((ServletServerHttpResponse) response).getServletResponse(); } diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/TomcatRequestUpgradeStrategy.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/TomcatRequestUpgradeStrategy.java index 899973892c8..dca97965519 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/TomcatRequestUpgradeStrategy.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/TomcatRequestUpgradeStrategy.java @@ -24,6 +24,8 @@ import javax.servlet.ServletContext; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import javax.websocket.Endpoint; +import javax.websocket.server.ServerEndpointConfig; import org.apache.tomcat.websocket.server.WsServerContainer; import org.springframework.http.server.reactive.ServerHttpRequest; @@ -50,45 +52,46 @@ public class TomcatRequestUpgradeStrategy implements RequestUpgradeStrategy { @Override - public Mono upgrade(ServerWebExchange exchange, WebSocketHandler webSocketHandler){ + public Mono upgrade(ServerWebExchange exchange, WebSocketHandler handler){ - TomcatWebSocketHandlerAdapter endpoint = - new TomcatWebSocketHandlerAdapter(webSocketHandler); + ServerHttpRequest request = exchange.getRequest(); + ServerHttpResponse response = exchange.getResponse(); + Endpoint endpoint = new TomcatWebSocketHandlerAdapter(request, response, handler).getEndpoint(); - HttpServletRequest servletRequest = getHttpServletRequest(exchange.getRequest()); - HttpServletResponse servletResponse = getHttpServletResponse(exchange.getResponse()); + HttpServletRequest servletRequest = getHttpServletRequest(request); + HttpServletResponse servletResponse = getHttpServletResponse(response); - Map pathParams = Collections. emptyMap(); - - ServerEndpointRegistration sec = - new ServerEndpointRegistration(servletRequest.getRequestURI(), endpoint); + String requestURI = servletRequest.getRequestURI(); + ServerEndpointConfig config = new ServerEndpointRegistration(requestURI, endpoint); try { - getContainer(servletRequest).doUpgrade(servletRequest, servletResponse, - sec, pathParams); + WsServerContainer container = getContainer(servletRequest); + container.doUpgrade(servletRequest, servletResponse, config, Collections.emptyMap()); } - catch (ServletException | IOException e) { - return Mono.error(e); + catch (ServletException | IOException ex) { + return Mono.error(ex); } return Mono.empty(); } - private WsServerContainer getContainer(HttpServletRequest request) { - ServletContext servletContext = request.getServletContext(); - Object container = servletContext.getAttribute(SERVER_CONTAINER_ATTR); - Assert.notNull(container, "No '" + SERVER_CONTAINER_ATTR + "' ServletContext attribute. " + - "Are you running in a Servlet container that supports JSR-356?"); - Assert.isTrue(container instanceof WsServerContainer); - return (WsServerContainer) container; - } - - private final HttpServletRequest getHttpServletRequest(ServerHttpRequest request) { + private HttpServletRequest getHttpServletRequest(ServerHttpRequest request) { Assert.isTrue(request instanceof ServletServerHttpRequest); return ((ServletServerHttpRequest) request).getServletRequest(); } - private final HttpServletResponse getHttpServletResponse(ServerHttpResponse response) { + private HttpServletResponse getHttpServletResponse(ServerHttpResponse response) { Assert.isTrue(response instanceof ServletServerHttpResponse); return ((ServletServerHttpResponse) response).getServletResponse(); } + + private WsServerContainer getContainer(HttpServletRequest request) { + ServletContext servletContext = request.getServletContext(); + Object container = servletContext.getAttribute(SERVER_CONTAINER_ATTR); + Assert.notNull(container, + "No 'javax.websocket.server.ServerContainer' ServletContext attribute. " + + "Are you running in a Servlet container that supports JSR-356?"); + Assert.isTrue(container instanceof WsServerContainer); + return (WsServerContainer) container; + } + } diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/UndertowRequestUpgradeStrategy.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/UndertowRequestUpgradeStrategy.java index 7cc22ddda7e..ec13e2b6854 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/UndertowRequestUpgradeStrategy.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/upgrade/UndertowRequestUpgradeStrategy.java @@ -17,6 +17,7 @@ package org.springframework.web.reactive.socket.server.upgrade; import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.http.server.reactive.UndertowServerHttpRequest; import org.springframework.util.Assert; import org.springframework.web.reactive.socket.WebSocketHandler; @@ -25,6 +26,7 @@ import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy; import org.springframework.web.server.ServerWebExchange; import io.undertow.server.HttpServerExchange; +import io.undertow.websockets.WebSocketConnectionCallback; import io.undertow.websockets.WebSocketProtocolHandshakeHandler; import reactor.core.publisher.Mono; @@ -37,27 +39,23 @@ import reactor.core.publisher.Mono; public class UndertowRequestUpgradeStrategy implements RequestUpgradeStrategy { @Override - public Mono upgrade(ServerWebExchange exchange, - WebSocketHandler webSocketHandler) { + public Mono upgrade(ServerWebExchange exchange, WebSocketHandler handler) { - UndertowWebSocketHandlerAdapter callback = - new UndertowWebSocketHandlerAdapter(webSocketHandler); + ServerHttpRequest request = exchange.getRequest(); + ServerHttpResponse response = exchange.getResponse(); + WebSocketConnectionCallback callback = new UndertowWebSocketHandlerAdapter(request, response, handler); + + Assert.isTrue(request instanceof UndertowServerHttpRequest); + HttpServerExchange httpExchange = ((UndertowServerHttpRequest) request).getUndertowExchange(); - WebSocketProtocolHandshakeHandler handler = - new WebSocketProtocolHandshakeHandler(callback); try { - handler.handleRequest(getUndertowExchange(exchange.getRequest())); + new WebSocketProtocolHandshakeHandler(callback).handleRequest(httpExchange); } - catch (Exception e) { - return Mono.error(e); + catch (Exception ex) { + return Mono.error(ex); } return Mono.empty(); } - private final HttpServerExchange getUndertowExchange(ServerHttpRequest request) { - Assert.isTrue(request instanceof UndertowServerHttpRequest); - return ((UndertowServerHttpRequest) request).getUndertowExchange(); - } - }