Minor refactoring + polish

- RxNettyWebSocketSession filters out WebSocketCloseFrame again
- add before/afterHandshake helper methods in WebSocketClientSupport
- log request headers on server and response headers on client
- polish 400 request handling in HandshakeWebSocketService
This commit is contained in:
Rossen Stoyanchev 2016-12-22 16:14:33 -05:00
parent d64d9ab370
commit 3719f75d3b
5 changed files with 74 additions and 62 deletions

View File

@ -17,8 +17,10 @@
package org.springframework.web.reactive.socket.adapter; package org.springframework.web.reactive.socket.adapter;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPipeline;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame; import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrameAggregator; import io.netty.handler.codec.http.websocketx.WebSocketFrameAggregator;
import io.reactivex.netty.protocol.http.ws.WebSocketConnection; import io.reactivex.netty.protocol.http.ws.WebSocketConnection;
@ -45,8 +47,8 @@ import org.springframework.web.reactive.socket.WebSocketSession;
public class RxNettyWebSocketSession extends NettyWebSocketSessionSupport<WebSocketConnection> { public class RxNettyWebSocketSession extends NettyWebSocketSessionSupport<WebSocketConnection> {
/** /**
* The name of the {@link WebSocketFrameAggregator} inserted by * The {@code ChannelHandler} name to use when inserting a
* {@link #aggregateFrames(Channel, String)}. * {@link WebSocketFrameAggregator} in the channel pipeline.
*/ */
public static final String FRAME_AGGREGATOR_NAME = "websocket-frame-aggregator"; public static final String FRAME_AGGREGATOR_NAME = "websocket-frame-aggregator";
@ -70,18 +72,21 @@ public class RxNettyWebSocketSession extends NettyWebSocketSessionSupport<WebSoc
logger.trace("WebSocketFrameAggregator already registered."); logger.trace("WebSocketFrameAggregator already registered.");
return this; return this;
} }
ChannelHandlerContext context = pipeline.context(frameDecoderName); ChannelHandlerContext frameDecoder = pipeline.context(frameDecoderName);
Assert.notNull(context, "WebSocketFrameDecoder not found: " + frameDecoderName); Assert.notNull(frameDecoder, "WebSocketFrameDecoder not found: " + frameDecoderName);
WebSocketFrameAggregator aggregator = new WebSocketFrameAggregator(DEFAULT_FRAME_MAX_SIZE); ChannelHandler frameAggregator = new WebSocketFrameAggregator(DEFAULT_FRAME_MAX_SIZE);
pipeline.addAfter(context.name(), FRAME_AGGREGATOR_NAME, aggregator); pipeline.addAfter(frameDecoder.name(), FRAME_AGGREGATOR_NAME, frameAggregator);
return this; return this;
} }
@Override @Override
public Flux<WebSocketMessage> receive() { public Flux<WebSocketMessage> receive() {
Observable<WebSocketMessage> observable = getDelegate().getInput().map(super::toMessage); Observable<WebSocketMessage> messages = getDelegate()
return Flux.from(RxReactiveStreams.toPublisher(observable)); .getInput()
.filter(frame -> !(frame instanceof CloseWebSocketFrame))
.map(super::toMessage);
return Flux.from(RxReactiveStreams.toPublisher(messages));
} }
@Override @Override

View File

@ -16,7 +16,6 @@
package org.springframework.web.reactive.socket.client; package org.springframework.web.reactive.socket.client;
import java.net.URI; import java.net.URI;
import java.util.Optional;
import java.util.function.Consumer; import java.util.function.Consumer;
import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.ByteBufAllocator;
@ -61,7 +60,7 @@ public class ReactorNettyWebSocketClient extends WebSocketClientSupport implemen
@Override @Override
public Mono<Void> execute(URI url, HttpHeaders headers, WebSocketHandler handler) { public Mono<Void> execute(URI url, HttpHeaders headers, WebSocketHandler handler) {
String[] protocols = getSubProtocols(headers, handler); String[] protocols = beforeHandshake(url, headers, handler);
// TODO: https://github.com/reactor/reactor-netty/issues/20 // TODO: https://github.com/reactor/reactor-netty/issues/20
return this.httpClient return this.httpClient
@ -71,9 +70,7 @@ public class ReactorNettyWebSocketClient extends WebSocketClientSupport implemen
}) })
.then(response -> { .then(response -> {
HttpHeaders responseHeaders = getResponseHeaders(response); HttpHeaders responseHeaders = getResponseHeaders(response);
String protocol = responseHeaders.getFirst(SEC_WEBSOCKET_PROTOCOL); HandshakeInfo info = afterHandshake(url, response.status().code(), responseHeaders);
HandshakeInfo info = new HandshakeInfo(url, responseHeaders, Mono.empty(),
Optional.ofNullable(protocol));
ByteBufAllocator allocator = response.channel().alloc(); ByteBufAllocator allocator = response.channel().alloc();
NettyDataBufferFactory factory = new NettyDataBufferFactory(allocator); NettyDataBufferFactory factory = new NettyDataBufferFactory(allocator);

View File

@ -22,7 +22,6 @@ import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional;
import java.util.function.Function; import java.util.function.Function;
import javax.net.ssl.SSLContext; import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngine;
@ -44,7 +43,6 @@ import org.springframework.http.HttpHeaders;
import org.springframework.util.ObjectUtils; import org.springframework.util.ObjectUtils;
import org.springframework.web.reactive.socket.HandshakeInfo; import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler; import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketSession;
import org.springframework.web.reactive.socket.adapter.RxNettyWebSocketSession; import org.springframework.web.reactive.socket.adapter.RxNettyWebSocketSession;
/** /**
@ -105,7 +103,10 @@ public class RxNettyWebSocketClient extends WebSocketClientSupport implements We
} }
private Observable<Void> connectInternal(URI url, HttpHeaders headers, WebSocketHandler handler) { private Observable<Void> connectInternal(URI url, HttpHeaders headers, WebSocketHandler handler) {
return createRequest(url, headers, handler)
String[] protocols = beforeHandshake(url, headers, handler);
return createRequest(url, headers, protocols)
.flatMap(response -> { .flatMap(response -> {
Observable<WebSocketConnection> conn = response.getWebSocketConnection(); Observable<WebSocketConnection> conn = response.getWebSocketConnection();
return Observable.zip(Observable.just(response), conn, Tuples::of); return Observable.zip(Observable.just(response), conn, Tuples::of);
@ -113,8 +114,7 @@ public class RxNettyWebSocketClient extends WebSocketClientSupport implements We
.flatMap(tuple -> { .flatMap(tuple -> {
WebSocketResponse<ByteBuf> response = tuple.getT1(); WebSocketResponse<ByteBuf> response = tuple.getT1();
HttpHeaders responseHeaders = getResponseHeaders(response); HttpHeaders responseHeaders = getResponseHeaders(response);
Optional<String> protocol = Optional.ofNullable(response.getAcceptedSubProtocol()); HandshakeInfo info = afterHandshake(url, response.getStatus().code(), responseHeaders);
HandshakeInfo info = new HandshakeInfo(url, responseHeaders, Mono.empty(), protocol);
ByteBufAllocator allocator = response.unsafeNettyChannel().alloc(); ByteBufAllocator allocator = response.unsafeNettyChannel().alloc();
NettyDataBufferFactory factory = new NettyDataBufferFactory(allocator); NettyDataBufferFactory factory = new NettyDataBufferFactory(allocator);
@ -128,7 +128,7 @@ public class RxNettyWebSocketClient extends WebSocketClientSupport implements We
}); });
} }
private WebSocketRequest<ByteBuf> createRequest(URI url, HttpHeaders headers, WebSocketHandler handler) { private WebSocketRequest<ByteBuf> createRequest(URI url, HttpHeaders headers, String[] protocols) {
String query = url.getRawQuery(); String query = url.getRawQuery();
String requestUrl = url.getRawPath() + (query != null ? "?" + query : ""); String requestUrl = url.getRawPath() + (query != null ? "?" + query : "");
@ -138,7 +138,6 @@ public class RxNettyWebSocketClient extends WebSocketClientSupport implements We
.setHeaders(toObjectValueMap(headers)) .setHeaders(toObjectValueMap(headers))
.requestWebSocketUpgrade(); .requestWebSocketUpgrade();
String[] protocols = getSubProtocols(headers, handler);
if (!ObjectUtils.isEmpty(protocols)) { if (!ObjectUtils.isEmpty(protocols)) {
request = request.requestSubProtocols(protocols); request = request.requestSubProtocols(protocols);
} }

View File

@ -15,8 +15,16 @@
*/ */
package org.springframework.web.reactive.socket.client; package org.springframework.web.reactive.socket.client;
import java.net.URI;
import java.util.Optional;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import reactor.core.publisher.Mono;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.util.StringUtils; import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler; import org.springframework.web.reactive.socket.WebSocketHandler;
/** /**
@ -30,11 +38,23 @@ public class WebSocketClientSupport {
protected static final String SEC_WEBSOCKET_PROTOCOL = "Sec-WebSocket-Protocol"; protected static final String SEC_WEBSOCKET_PROTOCOL = "Sec-WebSocket-Protocol";
protected String[] getSubProtocols(HttpHeaders headers, WebSocketHandler handler) { protected final Log logger = LogFactory.getLog(getClass());
String value = headers.getFirst(SEC_WEBSOCKET_PROTOCOL);
return (value != null ?
StringUtils.commaDelimitedListToStringArray(value) : protected String[] beforeHandshake(URI url, HttpHeaders headers, WebSocketHandler handler) {
handler.getSubProtocols()); if (logger.isDebugEnabled()) {
logger.debug("Executing handshake to " + url);
}
return handler.getSubProtocols();
}
protected HandshakeInfo afterHandshake(URI url, int statusCode, HttpHeaders headers) {
Assert.isTrue(statusCode == 101);
if (logger.isDebugEnabled()) {
logger.debug("Handshake response: " + url + ", " + headers);
}
String protocol = headers.getFirst(SEC_WEBSOCKET_PROTOCOL);
return new HandshakeInfo(url, headers, Mono.empty(), Optional.ofNullable(protocol));
} }
} }

View File

@ -25,10 +25,9 @@ import org.apache.commons.logging.LogFactory;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import org.springframework.context.Lifecycle; import org.springframework.context.Lifecycle;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.ClassUtils; import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils; import org.springframework.util.ReflectionUtils;
@ -38,6 +37,7 @@ import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy;
import org.springframework.web.reactive.socket.server.WebSocketService; import org.springframework.web.reactive.socket.server.WebSocketService;
import org.springframework.web.server.MethodNotAllowedException; import org.springframework.web.server.MethodNotAllowedException;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.ServerWebInputException;
/** /**
* {@code WebSocketService} implementation that handles a WebSocket HTTP * {@code WebSocketService} implementation that handles a WebSocket HTTP
@ -179,53 +179,44 @@ public class HandshakeWebSocketService implements WebSocketService, Lifecycle {
public Mono<Void> handleRequest(ServerWebExchange exchange, WebSocketHandler handler) { public Mono<Void> handleRequest(ServerWebExchange exchange, WebSocketHandler handler) {
ServerHttpRequest request = exchange.getRequest(); ServerHttpRequest request = exchange.getRequest();
ServerHttpResponse response = exchange.getResponse(); HttpMethod method = request.getMethod();
HttpHeaders headers = request.getHeaders();
if (logger.isTraceEnabled()) { if (logger.isDebugEnabled()) {
logger.trace("Processing " + request.getMethod() + " " + request.getURI()); logger.debug("Handling " + request.getURI() + " with headers: " + headers);
} }
if (HttpMethod.GET != request.getMethod()) { if (HttpMethod.GET != method) {
return Mono.error(new MethodNotAllowedException( return Mono.error(new MethodNotAllowedException(method.name(), Collections.singleton("GET")));
request.getMethod().name(), Collections.singleton("GET")));
} }
if (!isWebSocketUpgrade(request)) { if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) {
response.setStatusCode(HttpStatus.BAD_REQUEST); return handleBadRequest("Invalid 'Upgrade' header: " + headers);
return response.setComplete();
} }
Optional<String> subProtocol = selectSubProtocol(request, handler); List<String> connectionValue = headers.getConnection();
return getUpgradeStrategy().upgrade(exchange, handler, subProtocol);
}
private boolean isWebSocketUpgrade(ServerHttpRequest request) {
if (!"WebSocket".equalsIgnoreCase(request.getHeaders().getUpgrade())) {
if (logger.isErrorEnabled()) {
logger.error("Invalid 'Upgrade' header: " + request.getHeaders());
}
return false;
}
List<String> connectionValue = request.getHeaders().getConnection();
if (!connectionValue.contains("Upgrade") && !connectionValue.contains("upgrade")) { if (!connectionValue.contains("Upgrade") && !connectionValue.contains("upgrade")) {
if (logger.isErrorEnabled()) { return handleBadRequest("Invalid 'Connection' header: " + headers);
logger.error("Invalid 'Connection' header: " + request.getHeaders());
}
return false;
}
String key = request.getHeaders().getFirst(SEC_WEBSOCKET_KEY);
if (key == null) {
if (logger.isErrorEnabled()) {
logger.error("Missing \"Sec-WebSocket-Key\" header");
}
return false;
}
return true;
} }
private Optional<String> selectSubProtocol(ServerHttpRequest request, WebSocketHandler handler) { String key = headers.getFirst(SEC_WEBSOCKET_KEY);
String protocolHeader = request.getHeaders().getFirst(SEC_WEBSOCKET_PROTOCOL); if (key == null) {
return handleBadRequest("Missing \"Sec-WebSocket-Key\" header");
}
Optional<String> protocol = selectProtocol(headers, handler);
return this.upgradeStrategy.upgrade(exchange, handler, protocol);
}
private Mono<Void> handleBadRequest(String reason) {
if (logger.isDebugEnabled()) {
logger.debug(reason);
}
return Mono.error(new ServerWebInputException(reason));
}
private Optional<String> selectProtocol(HttpHeaders headers, WebSocketHandler handler) {
String protocolHeader = headers.getFirst(SEC_WEBSOCKET_PROTOCOL);
if (protocolHeader == null) { if (protocolHeader == null) {
return Optional.empty(); return Optional.empty();
} }