Minor refactoring + polish

- RxNettyWebSocketSession filters out WebSocketCloseFrame again
- add before/afterHandshake helper methods in WebSocketClientSupport
- log request headers on server and response headers on client
- polish 400 request handling in HandshakeWebSocketService
This commit is contained in:
Rossen Stoyanchev 2016-12-22 16:14:33 -05:00
parent d64d9ab370
commit 3719f75d3b
5 changed files with 74 additions and 62 deletions

View File

@ -17,8 +17,10 @@
package org.springframework.web.reactive.socket.adapter;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrameAggregator;
import io.reactivex.netty.protocol.http.ws.WebSocketConnection;
@ -45,8 +47,8 @@ import org.springframework.web.reactive.socket.WebSocketSession;
public class RxNettyWebSocketSession extends NettyWebSocketSessionSupport<WebSocketConnection> {
/**
* The name of the {@link WebSocketFrameAggregator} inserted by
* {@link #aggregateFrames(Channel, String)}.
* The {@code ChannelHandler} name to use when inserting a
* {@link WebSocketFrameAggregator} in the channel pipeline.
*/
public static final String FRAME_AGGREGATOR_NAME = "websocket-frame-aggregator";
@ -70,18 +72,21 @@ public class RxNettyWebSocketSession extends NettyWebSocketSessionSupport<WebSoc
logger.trace("WebSocketFrameAggregator already registered.");
return this;
}
ChannelHandlerContext context = pipeline.context(frameDecoderName);
Assert.notNull(context, "WebSocketFrameDecoder not found: " + frameDecoderName);
WebSocketFrameAggregator aggregator = new WebSocketFrameAggregator(DEFAULT_FRAME_MAX_SIZE);
pipeline.addAfter(context.name(), FRAME_AGGREGATOR_NAME, aggregator);
ChannelHandlerContext frameDecoder = pipeline.context(frameDecoderName);
Assert.notNull(frameDecoder, "WebSocketFrameDecoder not found: " + frameDecoderName);
ChannelHandler frameAggregator = new WebSocketFrameAggregator(DEFAULT_FRAME_MAX_SIZE);
pipeline.addAfter(frameDecoder.name(), FRAME_AGGREGATOR_NAME, frameAggregator);
return this;
}
@Override
public Flux<WebSocketMessage> receive() {
Observable<WebSocketMessage> observable = getDelegate().getInput().map(super::toMessage);
return Flux.from(RxReactiveStreams.toPublisher(observable));
Observable<WebSocketMessage> messages = getDelegate()
.getInput()
.filter(frame -> !(frame instanceof CloseWebSocketFrame))
.map(super::toMessage);
return Flux.from(RxReactiveStreams.toPublisher(messages));
}
@Override

View File

@ -16,7 +16,6 @@
package org.springframework.web.reactive.socket.client;
import java.net.URI;
import java.util.Optional;
import java.util.function.Consumer;
import io.netty.buffer.ByteBufAllocator;
@ -61,7 +60,7 @@ public class ReactorNettyWebSocketClient extends WebSocketClientSupport implemen
@Override
public Mono<Void> execute(URI url, HttpHeaders headers, WebSocketHandler handler) {
String[] protocols = getSubProtocols(headers, handler);
String[] protocols = beforeHandshake(url, headers, handler);
// TODO: https://github.com/reactor/reactor-netty/issues/20
return this.httpClient
@ -71,9 +70,7 @@ public class ReactorNettyWebSocketClient extends WebSocketClientSupport implemen
})
.then(response -> {
HttpHeaders responseHeaders = getResponseHeaders(response);
String protocol = responseHeaders.getFirst(SEC_WEBSOCKET_PROTOCOL);
HandshakeInfo info = new HandshakeInfo(url, responseHeaders, Mono.empty(),
Optional.ofNullable(protocol));
HandshakeInfo info = afterHandshake(url, response.status().code(), responseHeaders);
ByteBufAllocator allocator = response.channel().alloc();
NettyDataBufferFactory factory = new NettyDataBufferFactory(allocator);

View File

@ -22,7 +22,6 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
@ -44,7 +43,6 @@ import org.springframework.http.HttpHeaders;
import org.springframework.util.ObjectUtils;
import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketSession;
import org.springframework.web.reactive.socket.adapter.RxNettyWebSocketSession;
/**
@ -105,7 +103,10 @@ public class RxNettyWebSocketClient extends WebSocketClientSupport implements We
}
private Observable<Void> connectInternal(URI url, HttpHeaders headers, WebSocketHandler handler) {
return createRequest(url, headers, handler)
String[] protocols = beforeHandshake(url, headers, handler);
return createRequest(url, headers, protocols)
.flatMap(response -> {
Observable<WebSocketConnection> conn = response.getWebSocketConnection();
return Observable.zip(Observable.just(response), conn, Tuples::of);
@ -113,8 +114,7 @@ public class RxNettyWebSocketClient extends WebSocketClientSupport implements We
.flatMap(tuple -> {
WebSocketResponse<ByteBuf> response = tuple.getT1();
HttpHeaders responseHeaders = getResponseHeaders(response);
Optional<String> protocol = Optional.ofNullable(response.getAcceptedSubProtocol());
HandshakeInfo info = new HandshakeInfo(url, responseHeaders, Mono.empty(), protocol);
HandshakeInfo info = afterHandshake(url, response.getStatus().code(), responseHeaders);
ByteBufAllocator allocator = response.unsafeNettyChannel().alloc();
NettyDataBufferFactory factory = new NettyDataBufferFactory(allocator);
@ -128,7 +128,7 @@ public class RxNettyWebSocketClient extends WebSocketClientSupport implements We
});
}
private WebSocketRequest<ByteBuf> createRequest(URI url, HttpHeaders headers, WebSocketHandler handler) {
private WebSocketRequest<ByteBuf> createRequest(URI url, HttpHeaders headers, String[] protocols) {
String query = url.getRawQuery();
String requestUrl = url.getRawPath() + (query != null ? "?" + query : "");
@ -138,7 +138,6 @@ public class RxNettyWebSocketClient extends WebSocketClientSupport implements We
.setHeaders(toObjectValueMap(headers))
.requestWebSocketUpgrade();
String[] protocols = getSubProtocols(headers, handler);
if (!ObjectUtils.isEmpty(protocols)) {
request = request.requestSubProtocols(protocols);
}

View File

@ -15,8 +15,16 @@
*/
package org.springframework.web.reactive.socket.client;
import java.net.URI;
import java.util.Optional;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import reactor.core.publisher.Mono;
import org.springframework.http.HttpHeaders;
import org.springframework.util.StringUtils;
import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler;
/**
@ -30,11 +38,23 @@ public class WebSocketClientSupport {
protected static final String SEC_WEBSOCKET_PROTOCOL = "Sec-WebSocket-Protocol";
protected String[] getSubProtocols(HttpHeaders headers, WebSocketHandler handler) {
String value = headers.getFirst(SEC_WEBSOCKET_PROTOCOL);
return (value != null ?
StringUtils.commaDelimitedListToStringArray(value) :
handler.getSubProtocols());
protected final Log logger = LogFactory.getLog(getClass());
protected String[] beforeHandshake(URI url, HttpHeaders headers, WebSocketHandler handler) {
if (logger.isDebugEnabled()) {
logger.debug("Executing handshake to " + url);
}
return handler.getSubProtocols();
}
protected HandshakeInfo afterHandshake(URI url, int statusCode, HttpHeaders headers) {
Assert.isTrue(statusCode == 101);
if (logger.isDebugEnabled()) {
logger.debug("Handshake response: " + url + ", " + headers);
}
String protocol = headers.getFirst(SEC_WEBSOCKET_PROTOCOL);
return new HandshakeInfo(url, headers, Mono.empty(), Optional.ofNullable(protocol));
}
}

View File

@ -25,10 +25,9 @@ import org.apache.commons.logging.LogFactory;
import reactor.core.publisher.Mono;
import org.springframework.context.Lifecycle;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils;
@ -38,6 +37,7 @@ import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy;
import org.springframework.web.reactive.socket.server.WebSocketService;
import org.springframework.web.server.MethodNotAllowedException;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.ServerWebInputException;
/**
* {@code WebSocketService} implementation that handles a WebSocket HTTP
@ -179,53 +179,44 @@ public class HandshakeWebSocketService implements WebSocketService, Lifecycle {
public Mono<Void> handleRequest(ServerWebExchange exchange, WebSocketHandler handler) {
ServerHttpRequest request = exchange.getRequest();
ServerHttpResponse response = exchange.getResponse();
HttpMethod method = request.getMethod();
HttpHeaders headers = request.getHeaders();
if (logger.isTraceEnabled()) {
logger.trace("Processing " + request.getMethod() + " " + request.getURI());
if (logger.isDebugEnabled()) {
logger.debug("Handling " + request.getURI() + " with headers: " + headers);
}
if (HttpMethod.GET != request.getMethod()) {
return Mono.error(new MethodNotAllowedException(
request.getMethod().name(), Collections.singleton("GET")));
if (HttpMethod.GET != method) {
return Mono.error(new MethodNotAllowedException(method.name(), Collections.singleton("GET")));
}
if (!isWebSocketUpgrade(request)) {
response.setStatusCode(HttpStatus.BAD_REQUEST);
return response.setComplete();
if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) {
return handleBadRequest("Invalid 'Upgrade' header: " + headers);
}
Optional<String> subProtocol = selectSubProtocol(request, handler);
return getUpgradeStrategy().upgrade(exchange, handler, subProtocol);
}
private boolean isWebSocketUpgrade(ServerHttpRequest request) {
if (!"WebSocket".equalsIgnoreCase(request.getHeaders().getUpgrade())) {
if (logger.isErrorEnabled()) {
logger.error("Invalid 'Upgrade' header: " + request.getHeaders());
}
return false;
}
List<String> connectionValue = request.getHeaders().getConnection();
List<String> connectionValue = headers.getConnection();
if (!connectionValue.contains("Upgrade") && !connectionValue.contains("upgrade")) {
if (logger.isErrorEnabled()) {
logger.error("Invalid 'Connection' header: " + request.getHeaders());
}
return false;
return handleBadRequest("Invalid 'Connection' header: " + headers);
}
String key = request.getHeaders().getFirst(SEC_WEBSOCKET_KEY);
String key = headers.getFirst(SEC_WEBSOCKET_KEY);
if (key == null) {
if (logger.isErrorEnabled()) {
logger.error("Missing \"Sec-WebSocket-Key\" header");
}
return false;
return handleBadRequest("Missing \"Sec-WebSocket-Key\" header");
}
return true;
Optional<String> protocol = selectProtocol(headers, handler);
return this.upgradeStrategy.upgrade(exchange, handler, protocol);
}
private Optional<String> selectSubProtocol(ServerHttpRequest request, WebSocketHandler handler) {
String protocolHeader = request.getHeaders().getFirst(SEC_WEBSOCKET_PROTOCOL);
private Mono<Void> handleBadRequest(String reason) {
if (logger.isDebugEnabled()) {
logger.debug(reason);
}
return Mono.error(new ServerWebInputException(reason));
}
private Optional<String> selectProtocol(HttpHeaders headers, WebSocketHandler handler) {
String protocolHeader = headers.getFirst(SEC_WEBSOCKET_PROTOCOL);
if (protocolHeader == null) {
return Optional.empty();
}