Consistently extend WebSocketHandlerAdapterSupport

The WebSocketHandler adapters for all runtimes now extend
WebSocketHandlerAdapterSupport, which now also exposes
a shared DataBufferFactory property initialized from the response.

Issue: SPR-14527
This commit is contained in:
Rossen Stoyanchev 2016-12-12 17:54:24 -05:00
parent 5829e1c141
commit d6895aa098
10 changed files with 205 additions and 228 deletions

View File

@ -32,9 +32,8 @@ import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.util.Assert;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.web.reactive.socket.CloseStatus;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketMessage;
@ -48,21 +47,17 @@ import org.springframework.web.reactive.socket.WebSocketMessage.Type;
* @since 5.0
*/
@WebSocket
public class JettyWebSocketHandlerAdapter {
public class JettyWebSocketHandlerAdapter extends WebSocketHandlerAdapterSupport {
private static final ByteBuffer EMPTY_PAYLOAD = ByteBuffer.wrap(new byte[0]);
private final WebSocketHandler delegate;
private JettyWebSocketSession session;
private final DataBufferFactory bufferFactory = new DefaultDataBufferFactory(false);
public JettyWebSocketHandlerAdapter(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler delegate) {
public JettyWebSocketHandlerAdapter(WebSocketHandler delegate) {
Assert.notNull("WebSocketHandler is required");
this.delegate = delegate;
super(request, response, delegate);
}
@ -71,7 +66,7 @@ public class JettyWebSocketHandlerAdapter {
this.session = new JettyWebSocketSession(session);
HandlerResultSubscriber subscriber = new HandlerResultSubscriber();
this.delegate.handle(this.session).subscribe(subscriber);
getDelegate().handle(this.session).subscribe(subscriber);
}
@OnWebSocketMessage
@ -105,15 +100,15 @@ public class JettyWebSocketHandlerAdapter {
private <T> WebSocketMessage toMessage(Type type, T message) {
if (Type.TEXT.equals(type)) {
byte[] bytes = ((String) message).getBytes(StandardCharsets.UTF_8);
DataBuffer buffer = this.bufferFactory.wrap(bytes);
DataBuffer buffer = getBufferFactory().wrap(bytes);
return WebSocketMessage.create(Type.TEXT, buffer);
}
else if (Type.BINARY.equals(type)) {
DataBuffer buffer = this.bufferFactory.wrap((ByteBuffer) message);
DataBuffer buffer = getBufferFactory().wrap((ByteBuffer) message);
return WebSocketMessage.create(Type.BINARY, buffer);
}
else if (Type.PONG.equals(type)) {
DataBuffer buffer = this.bufferFactory.wrap((ByteBuffer) message);
DataBuffer buffer = getBufferFactory().wrap((ByteBuffer) message);
return WebSocketMessage.create(Type.PONG, buffer);
}
else {

View File

@ -21,10 +21,8 @@ import org.reactivestreams.Publisher;
import reactor.ipc.netty.http.HttpInbound;
import reactor.ipc.netty.http.HttpOutbound;
import org.springframework.core.io.buffer.NettyDataBufferFactory;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.WebSocketHandler;
/**
@ -38,22 +36,13 @@ public class ReactorNettyWebSocketHandlerAdapter extends WebSocketHandlerAdapter
implements BiFunction<HttpInbound, HttpOutbound, Publisher<Void>> {
private final NettyDataBufferFactory bufferFactory;
public ReactorNettyWebSocketHandlerAdapter(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler handler) {
super(request, handler);
Assert.notNull("'response' is required");
this.bufferFactory = (NettyDataBufferFactory) response.bufferFactory();
super(request, response, handler);
}
public NettyDataBufferFactory getBufferFactory() {
return this.bufferFactory;
}
@Override
public Publisher<Void> apply(HttpInbound inbound, HttpOutbound outbound) {
ReactorNettyWebSocketSession session =

View File

@ -20,10 +20,8 @@ import reactor.core.publisher.Mono;
import rx.Observable;
import rx.RxReactiveStreams;
import org.springframework.core.io.buffer.NettyDataBufferFactory;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.WebSocketHandler;
/**
@ -36,22 +34,14 @@ import org.springframework.web.reactive.socket.WebSocketHandler;
public class RxNettyWebSocketHandlerAdapter extends WebSocketHandlerAdapterSupport
implements io.reactivex.netty.protocol.http.ws.server.WebSocketHandler {
private final NettyDataBufferFactory bufferFactory;
public RxNettyWebSocketHandlerAdapter(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler handler) {
super(request, handler);
Assert.notNull("'response' is required");
this.bufferFactory = (NettyDataBufferFactory) response.bufferFactory();
super(request, response, handler);
}
public NettyDataBufferFactory getBufferFactory() {
return this.bufferFactory;
}
@Override
public Observable<Void> handle(WebSocketConnection conn) {
RxNettyWebSocketSession session = new RxNettyWebSocketSession(conn, getUri(), getBufferFactory());

View File

@ -28,9 +28,8 @@ import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.util.Assert;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.web.reactive.socket.CloseStatus;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketMessage;
@ -43,76 +42,84 @@ import org.springframework.web.reactive.socket.WebSocketMessage.Type;
* @author Violeta Georgieva
* @since 5.0
*/
public class TomcatWebSocketHandlerAdapter extends Endpoint {
private final WebSocketHandler delegate;
public class TomcatWebSocketHandlerAdapter extends WebSocketHandlerAdapterSupport {
private TomcatWebSocketSession session;
private final DataBufferFactory bufferFactory = new DefaultDataBufferFactory(false);
public TomcatWebSocketHandlerAdapter(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler delegate) {
public TomcatWebSocketHandlerAdapter(WebSocketHandler delegate) {
Assert.notNull("WebSocketHandler is required");
this.delegate = delegate;
super(request, response, delegate);
}
@Override
public void onOpen(Session session, EndpointConfig config) {
this.session = new TomcatWebSocketSession(session);
session.addMessageHandler(String.class, message -> {
WebSocketMessage webSocketMessage = toMessage(message);
this.session.handleMessage(webSocketMessage.getType(), webSocketMessage);
});
session.addMessageHandler(ByteBuffer.class, message -> {
WebSocketMessage webSocketMessage = toMessage(message);
this.session.handleMessage(webSocketMessage.getType(), webSocketMessage);
});
session.addMessageHandler(PongMessage.class, message -> {
WebSocketMessage webSocketMessage = toMessage(message);
this.session.handleMessage(webSocketMessage.getType(), webSocketMessage);
});
HandlerResultSubscriber resultSubscriber = new HandlerResultSubscriber();
this.delegate.handle(this.session).subscribe(resultSubscriber);
public Endpoint getEndpoint() {
return new StandardEndpoint();
}
private <T> WebSocketMessage toMessage(T message) {
if (message instanceof String) {
byte[] bytes = ((String) message).getBytes(StandardCharsets.UTF_8);
return WebSocketMessage.create(Type.TEXT, this.bufferFactory.wrap(bytes));
}
else if (message instanceof ByteBuffer) {
DataBuffer buffer = this.bufferFactory.wrap((ByteBuffer) message);
return WebSocketMessage.create(Type.BINARY, buffer);
}
else if (message instanceof PongMessage) {
DataBuffer buffer = this.bufferFactory.wrap(((PongMessage) message).getApplicationData());
return WebSocketMessage.create(Type.PONG, buffer);
}
else {
throw new IllegalArgumentException("Unexpected message type: " + message);
}
private TomcatWebSocketSession getSession() {
return this.session;
}
@Override
public void onClose(Session session, CloseReason reason) {
if (this.session != null) {
int code = reason.getCloseCode().getCode();
this.session.handleClose(new CloseStatus(code, reason.getReasonPhrase()));
private class StandardEndpoint extends Endpoint {
@Override
public void onOpen(Session session, EndpointConfig config) {
TomcatWebSocketHandlerAdapter.this.session = new TomcatWebSocketSession(session);
session.addMessageHandler(String.class, message -> {
WebSocketMessage webSocketMessage = toMessage(message);
getSession().handleMessage(webSocketMessage.getType(), webSocketMessage);
});
session.addMessageHandler(ByteBuffer.class, message -> {
WebSocketMessage webSocketMessage = toMessage(message);
getSession().handleMessage(webSocketMessage.getType(), webSocketMessage);
});
session.addMessageHandler(PongMessage.class, message -> {
WebSocketMessage webSocketMessage = toMessage(message);
getSession().handleMessage(webSocketMessage.getType(), webSocketMessage);
});
HandlerResultSubscriber resultSubscriber = new HandlerResultSubscriber();
getDelegate().handle(TomcatWebSocketHandlerAdapter.this.session).subscribe(resultSubscriber);
}
private <T> WebSocketMessage toMessage(T message) {
if (message instanceof String) {
byte[] bytes = ((String) message).getBytes(StandardCharsets.UTF_8);
return WebSocketMessage.create(Type.TEXT, getBufferFactory().wrap(bytes));
}
else if (message instanceof ByteBuffer) {
DataBuffer buffer = getBufferFactory().wrap((ByteBuffer) message);
return WebSocketMessage.create(Type.BINARY, buffer);
}
else if (message instanceof PongMessage) {
DataBuffer buffer = getBufferFactory().wrap(((PongMessage) message).getApplicationData());
return WebSocketMessage.create(Type.PONG, buffer);
}
else {
throw new IllegalArgumentException("Unexpected message type: " + message);
}
}
@Override
public void onClose(Session session, CloseReason reason) {
if (getSession() != null) {
int code = reason.getCloseCode().getCode();
getSession().handleClose(new CloseStatus(code, reason.getReasonPhrase()));
}
}
@Override
public void onError(Session session, Throwable exception) {
if (getSession() != null) {
getSession().handleError(exception);
}
}
}
@Override
public void onError(Session session, Throwable exception) {
if (this.session != null) {
this.session.handleError(exception);
}
}
private final class HandlerResultSubscriber implements Subscriber<Void> {
@Override
@ -127,15 +134,15 @@ public class TomcatWebSocketHandlerAdapter extends Endpoint {
@Override
public void onError(Throwable ex) {
if (session != null) {
session.close(new CloseStatus(CloseStatus.SERVER_ERROR.getCode(), ex.getMessage()));
if (getSession() != null) {
getSession().close(new CloseStatus(CloseStatus.SERVER_ERROR.getCode(), ex.getMessage()));
}
}
@Override
public void onComplete() {
if (session != null) {
session.close();
if (getSession() != null) {
getSession().close();
}
}
}

View File

@ -16,22 +16,9 @@
package org.springframework.web.reactive.socket.adapter;
import java.net.URISyntaxException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.CloseStatus;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketMessage;
import org.springframework.web.reactive.socket.WebSocketMessage.Type;
import io.undertow.websockets.WebSocketConnectionCallback;
import io.undertow.websockets.core.AbstractReceiveListener;
import io.undertow.websockets.core.BufferedBinaryMessage;
@ -39,6 +26,16 @@ import io.undertow.websockets.core.BufferedTextMessage;
import io.undertow.websockets.core.CloseMessage;
import io.undertow.websockets.core.WebSocketChannel;
import io.undertow.websockets.spi.WebSocketHttpExchange;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.web.reactive.socket.CloseStatus;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketMessage;
import org.springframework.web.reactive.socket.WebSocketMessage.Type;
/**
* Undertow {@code WebSocketHandler} implementation adapting and
@ -47,36 +44,27 @@ import io.undertow.websockets.spi.WebSocketHttpExchange;
* @author Violeta Georgieva
* @since 5.0
*/
public class UndertowWebSocketHandlerAdapter implements WebSocketConnectionCallback {
private final WebSocketHandler delegate;
public class UndertowWebSocketHandlerAdapter extends WebSocketHandlerAdapterSupport
implements WebSocketConnectionCallback {
private UndertowWebSocketSession session;
private final DataBufferFactory bufferFactory = new DefaultDataBufferFactory(false);
public UndertowWebSocketHandlerAdapter(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler delegate) {
public UndertowWebSocketHandlerAdapter(WebSocketHandler delegate) {
Assert.notNull("WebSocketHandler is required");
this.delegate = delegate;
super(request, response, delegate);
}
@Override
public void onConnect(WebSocketHttpExchange exchange, WebSocketChannel channel) {
try {
this.session = new UndertowWebSocketSession(channel);
}
catch (URISyntaxException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
this.session = new UndertowWebSocketSession(channel, getUri());
channel.getReceiveSetter().set(new UndertowReceiveListener());
channel.resumeReceives();
HandlerResultSubscriber resultSubscriber = new HandlerResultSubscriber();
this.delegate.handle(this.session).subscribe(resultSubscriber);
getDelegate().handle(this.session).subscribe(resultSubscriber);
}
@ -114,14 +102,14 @@ public class UndertowWebSocketHandlerAdapter implements WebSocketConnectionCallb
private <T> WebSocketMessage toMessage(Type type, T message) {
if (Type.TEXT.equals(type)) {
byte[] bytes = ((String) message).getBytes(StandardCharsets.UTF_8);
return WebSocketMessage.create(Type.TEXT, bufferFactory.wrap(bytes));
return WebSocketMessage.create(Type.TEXT, getBufferFactory().wrap(bytes));
}
else if (Type.BINARY.equals(type)) {
DataBuffer buffer = bufferFactory.allocateBuffer().write((ByteBuffer[]) message);
DataBuffer buffer = getBufferFactory().allocateBuffer().write((ByteBuffer[]) message);
return WebSocketMessage.create(Type.BINARY, buffer);
}
else if (Type.PONG.equals(type)) {
DataBuffer buffer = bufferFactory.allocateBuffer().write((ByteBuffer[]) message);
DataBuffer buffer = getBufferFactory().allocateBuffer().write((ByteBuffer[]) message);
return WebSocketMessage.create(Type.PONG, buffer);
}
else {

View File

@ -18,7 +18,6 @@ package org.springframework.web.reactive.socket.adapter;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
@ -43,8 +42,8 @@ import org.springframework.web.reactive.socket.WebSocketSession;
public class UndertowWebSocketSession extends AbstractListenerWebSocketSession<WebSocketChannel> {
public UndertowWebSocketSession(WebSocketChannel channel) throws URISyntaxException {
super(channel, ObjectUtils.getIdentityHexString(channel), new URI(channel.getUrl()));
public UndertowWebSocketSession(WebSocketChannel channel, URI url) {
super(channel, ObjectUtils.getIdentityHexString(channel), url);
}

View File

@ -17,12 +17,15 @@ package org.springframework.web.reactive.socket.adapter;
import java.net.URI;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.WebSocketHandler;
/**
* Base class for {@link WebSocketHandler} implementations.
* Base class for {@link WebSocketHandler} adapters to underlying WebSocket
* handler APIs.
*
* @author Rossen Stoyanchev
* @since 5.0
@ -33,21 +36,32 @@ public abstract class WebSocketHandlerAdapterSupport {
private final WebSocketHandler delegate;
private final DataBufferFactory bufferFactory;
protected WebSocketHandlerAdapterSupport(ServerHttpRequest request, WebSocketHandler handler) {
Assert.notNull("'request' is required");
Assert.notNull("'handler' handler is required");
protected WebSocketHandlerAdapterSupport(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler handler) {
Assert.notNull("ServerHttpRequest is required");
Assert.notNull("ServerHttpResponse is required");
Assert.notNull("WebSocketHandler handler is required");
this.uri = request.getURI();
this.bufferFactory = response.bufferFactory();
this.delegate = handler;
}
public URI getUri() {
protected URI getUri() {
return this.uri;
}
public WebSocketHandler getDelegate() {
protected WebSocketHandler getDelegate() {
return this.delegate;
}
@SuppressWarnings("unchecked")
protected <T extends DataBufferFactory> T getBufferFactory() {
return (T) this.bufferFactory;
}
}

View File

@ -17,16 +17,14 @@
package org.springframework.web.reactive.socket.server.upgrade;
import java.io.IOException;
import javax.servlet.ServletContext;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.eclipse.jetty.util.DecoratedObjectFactory;
import org.eclipse.jetty.websocket.server.WebSocketServerFactory;
import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest;
import org.eclipse.jetty.websocket.servlet.ServletUpgradeResponse;
import org.eclipse.jetty.websocket.servlet.WebSocketCreator;
import reactor.core.publisher.Mono;
import org.springframework.context.Lifecycle;
import org.springframework.core.NamedThreadLocal;
import org.springframework.http.server.reactive.ServerHttpRequest;
@ -39,8 +37,6 @@ import org.springframework.web.reactive.socket.adapter.JettyWebSocketHandlerAdap
import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
/**
* A {@link RequestUpgradeStrategy} for use with Jetty.
*
@ -52,43 +48,13 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life
private static final ThreadLocal<JettyWebSocketHandlerAdapter> wsContainerHolder =
new NamedThreadLocal<>("Jetty WebSocketHandler Adapter");
private WebSocketServerFactory factory;
private ServletContext servletContext;
private volatile boolean running = false;
@Override
public Mono<Void> upgrade(ServerWebExchange exchange, WebSocketHandler webSocketHandler) {
JettyWebSocketHandlerAdapter adapter =
new JettyWebSocketHandlerAdapter(webSocketHandler);
HttpServletRequest servletRequest = getHttpServletRequest(exchange.getRequest());
HttpServletResponse servletResponse = getHttpServletResponse(exchange.getResponse());
if (this.servletContext == null) {
this.servletContext = servletRequest.getServletContext();
servletContext.setAttribute(DecoratedObjectFactory.ATTR, new DecoratedObjectFactory());
}
try {
start();
Assert.isTrue(this.factory.isUpgradeRequest(servletRequest, servletResponse), "Not a WebSocket handshake");
wsContainerHolder.set(adapter);
this.factory.acceptWebSocket(servletRequest, servletResponse);
}
catch (IOException ex) {
return Mono.error(ex);
}
finally {
wsContainerHolder.remove();
}
return Mono.empty();
}
@Override
public void start() {
@ -96,16 +62,10 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life
this.running = true;
try {
this.factory = new WebSocketServerFactory(this.servletContext);
this.factory.setCreator(new WebSocketCreator() {
@Override
public Object createWebSocket(ServletUpgradeRequest req,
ServletUpgradeResponse resp) {
JettyWebSocketHandlerAdapter adapter = wsContainerHolder.get();
Assert.state(adapter != null, "Expected JettyWebSocketHandlerAdapter");
return adapter;
}
this.factory.setCreator((request, response) -> {
JettyWebSocketHandlerAdapter adapter = wsContainerHolder.get();
Assert.state(adapter != null, "Expected JettyWebSocketHandlerAdapter");
return adapter;
});
this.factory.start();
}
@ -133,12 +93,46 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life
return this.running;
}
private final HttpServletRequest getHttpServletRequest(ServerHttpRequest request) {
@Override
public Mono<Void> upgrade(ServerWebExchange exchange, WebSocketHandler handler) {
ServerHttpRequest request = exchange.getRequest();
ServerHttpResponse response = exchange.getResponse();
JettyWebSocketHandlerAdapter adapter = new JettyWebSocketHandlerAdapter(request, response, handler);
HttpServletRequest servletRequest = getHttpServletRequest(request);
HttpServletResponse servletResponse = getHttpServletResponse(response);
if (this.servletContext == null) {
this.servletContext = servletRequest.getServletContext();
this.servletContext.setAttribute(DecoratedObjectFactory.ATTR, new DecoratedObjectFactory());
}
try {
start();
Assert.isTrue(this.factory.isUpgradeRequest(
servletRequest, servletResponse), "Not a WebSocket handshake");
wsContainerHolder.set(adapter);
this.factory.acceptWebSocket(servletRequest, servletResponse);
}
catch (IOException ex) {
return Mono.error(ex);
}
finally {
wsContainerHolder.remove();
}
return Mono.empty();
}
private HttpServletRequest getHttpServletRequest(ServerHttpRequest request) {
Assert.isTrue(request instanceof ServletServerHttpRequest);
return ((ServletServerHttpRequest) request).getServletRequest();
}
private final HttpServletResponse getHttpServletResponse(ServerHttpResponse response) {
private HttpServletResponse getHttpServletResponse(ServerHttpResponse response) {
Assert.isTrue(response instanceof ServletServerHttpResponse);
return ((ServletServerHttpResponse) response).getServletResponse();
}

View File

@ -24,6 +24,8 @@ import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.websocket.Endpoint;
import javax.websocket.server.ServerEndpointConfig;
import org.apache.tomcat.websocket.server.WsServerContainer;
import org.springframework.http.server.reactive.ServerHttpRequest;
@ -50,45 +52,46 @@ public class TomcatRequestUpgradeStrategy implements RequestUpgradeStrategy {
@Override
public Mono<Void> upgrade(ServerWebExchange exchange, WebSocketHandler webSocketHandler){
public Mono<Void> upgrade(ServerWebExchange exchange, WebSocketHandler handler){
TomcatWebSocketHandlerAdapter endpoint =
new TomcatWebSocketHandlerAdapter(webSocketHandler);
ServerHttpRequest request = exchange.getRequest();
ServerHttpResponse response = exchange.getResponse();
Endpoint endpoint = new TomcatWebSocketHandlerAdapter(request, response, handler).getEndpoint();
HttpServletRequest servletRequest = getHttpServletRequest(exchange.getRequest());
HttpServletResponse servletResponse = getHttpServletResponse(exchange.getResponse());
HttpServletRequest servletRequest = getHttpServletRequest(request);
HttpServletResponse servletResponse = getHttpServletResponse(response);
Map<String, String> pathParams = Collections.<String, String> emptyMap();
ServerEndpointRegistration sec =
new ServerEndpointRegistration(servletRequest.getRequestURI(), endpoint);
String requestURI = servletRequest.getRequestURI();
ServerEndpointConfig config = new ServerEndpointRegistration(requestURI, endpoint);
try {
getContainer(servletRequest).doUpgrade(servletRequest, servletResponse,
sec, pathParams);
WsServerContainer container = getContainer(servletRequest);
container.doUpgrade(servletRequest, servletResponse, config, Collections.emptyMap());
}
catch (ServletException | IOException e) {
return Mono.error(e);
catch (ServletException | IOException ex) {
return Mono.error(ex);
}
return Mono.empty();
}
private WsServerContainer getContainer(HttpServletRequest request) {
ServletContext servletContext = request.getServletContext();
Object container = servletContext.getAttribute(SERVER_CONTAINER_ATTR);
Assert.notNull(container, "No '" + SERVER_CONTAINER_ATTR + "' ServletContext attribute. " +
"Are you running in a Servlet container that supports JSR-356?");
Assert.isTrue(container instanceof WsServerContainer);
return (WsServerContainer) container;
}
private final HttpServletRequest getHttpServletRequest(ServerHttpRequest request) {
private HttpServletRequest getHttpServletRequest(ServerHttpRequest request) {
Assert.isTrue(request instanceof ServletServerHttpRequest);
return ((ServletServerHttpRequest) request).getServletRequest();
}
private final HttpServletResponse getHttpServletResponse(ServerHttpResponse response) {
private HttpServletResponse getHttpServletResponse(ServerHttpResponse response) {
Assert.isTrue(response instanceof ServletServerHttpResponse);
return ((ServletServerHttpResponse) response).getServletResponse();
}
private WsServerContainer getContainer(HttpServletRequest request) {
ServletContext servletContext = request.getServletContext();
Object container = servletContext.getAttribute(SERVER_CONTAINER_ATTR);
Assert.notNull(container,
"No 'javax.websocket.server.ServerContainer' ServletContext attribute. " +
"Are you running in a Servlet container that supports JSR-356?");
Assert.isTrue(container instanceof WsServerContainer);
return (WsServerContainer) container;
}
}

View File

@ -17,6 +17,7 @@
package org.springframework.web.reactive.socket.server.upgrade;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.http.server.reactive.UndertowServerHttpRequest;
import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.WebSocketHandler;
@ -25,6 +26,7 @@ import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy;
import org.springframework.web.server.ServerWebExchange;
import io.undertow.server.HttpServerExchange;
import io.undertow.websockets.WebSocketConnectionCallback;
import io.undertow.websockets.WebSocketProtocolHandshakeHandler;
import reactor.core.publisher.Mono;
@ -37,27 +39,23 @@ import reactor.core.publisher.Mono;
public class UndertowRequestUpgradeStrategy implements RequestUpgradeStrategy {
@Override
public Mono<Void> upgrade(ServerWebExchange exchange,
WebSocketHandler webSocketHandler) {
public Mono<Void> upgrade(ServerWebExchange exchange, WebSocketHandler handler) {
UndertowWebSocketHandlerAdapter callback =
new UndertowWebSocketHandlerAdapter(webSocketHandler);
ServerHttpRequest request = exchange.getRequest();
ServerHttpResponse response = exchange.getResponse();
WebSocketConnectionCallback callback = new UndertowWebSocketHandlerAdapter(request, response, handler);
Assert.isTrue(request instanceof UndertowServerHttpRequest);
HttpServerExchange httpExchange = ((UndertowServerHttpRequest) request).getUndertowExchange();
WebSocketProtocolHandshakeHandler handler =
new WebSocketProtocolHandshakeHandler(callback);
try {
handler.handleRequest(getUndertowExchange(exchange.getRequest()));
new WebSocketProtocolHandshakeHandler(callback).handleRequest(httpExchange);
}
catch (Exception e) {
return Mono.error(e);
catch (Exception ex) {
return Mono.error(ex);
}
return Mono.empty();
}
private final HttpServerExchange getUndertowExchange(ServerHttpRequest request) {
Assert.isTrue(request instanceof UndertowServerHttpRequest);
return ((UndertowServerHttpRequest) request).getUndertowExchange();
}
}