Minor refactoring + polish
- RxNettyWebSocketSession filters out WebSocketCloseFrame again - add before/afterHandshake helper methods in WebSocketClientSupport - log request headers on server and response headers on client - polish 400 request handling in HandshakeWebSocketService
This commit is contained in:
parent
d64d9ab370
commit
3719f75d3b
|
|
@ -17,8 +17,10 @@
|
|||
package org.springframework.web.reactive.socket.adapter;
|
||||
|
||||
import io.netty.channel.Channel;
|
||||
import io.netty.channel.ChannelHandler;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelPipeline;
|
||||
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketFrameAggregator;
|
||||
import io.reactivex.netty.protocol.http.ws.WebSocketConnection;
|
||||
|
|
@ -45,8 +47,8 @@ import org.springframework.web.reactive.socket.WebSocketSession;
|
|||
public class RxNettyWebSocketSession extends NettyWebSocketSessionSupport<WebSocketConnection> {
|
||||
|
||||
/**
|
||||
* The name of the {@link WebSocketFrameAggregator} inserted by
|
||||
* {@link #aggregateFrames(Channel, String)}.
|
||||
* The {@code ChannelHandler} name to use when inserting a
|
||||
* {@link WebSocketFrameAggregator} in the channel pipeline.
|
||||
*/
|
||||
public static final String FRAME_AGGREGATOR_NAME = "websocket-frame-aggregator";
|
||||
|
||||
|
|
@ -70,18 +72,21 @@ public class RxNettyWebSocketSession extends NettyWebSocketSessionSupport<WebSoc
|
|||
logger.trace("WebSocketFrameAggregator already registered.");
|
||||
return this;
|
||||
}
|
||||
ChannelHandlerContext context = pipeline.context(frameDecoderName);
|
||||
Assert.notNull(context, "WebSocketFrameDecoder not found: " + frameDecoderName);
|
||||
WebSocketFrameAggregator aggregator = new WebSocketFrameAggregator(DEFAULT_FRAME_MAX_SIZE);
|
||||
pipeline.addAfter(context.name(), FRAME_AGGREGATOR_NAME, aggregator);
|
||||
ChannelHandlerContext frameDecoder = pipeline.context(frameDecoderName);
|
||||
Assert.notNull(frameDecoder, "WebSocketFrameDecoder not found: " + frameDecoderName);
|
||||
ChannelHandler frameAggregator = new WebSocketFrameAggregator(DEFAULT_FRAME_MAX_SIZE);
|
||||
pipeline.addAfter(frameDecoder.name(), FRAME_AGGREGATOR_NAME, frameAggregator);
|
||||
return this;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public Flux<WebSocketMessage> receive() {
|
||||
Observable<WebSocketMessage> observable = getDelegate().getInput().map(super::toMessage);
|
||||
return Flux.from(RxReactiveStreams.toPublisher(observable));
|
||||
Observable<WebSocketMessage> messages = getDelegate()
|
||||
.getInput()
|
||||
.filter(frame -> !(frame instanceof CloseWebSocketFrame))
|
||||
.map(super::toMessage);
|
||||
return Flux.from(RxReactiveStreams.toPublisher(messages));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@
|
|||
package org.springframework.web.reactive.socket.client;
|
||||
|
||||
import java.net.URI;
|
||||
import java.util.Optional;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
import io.netty.buffer.ByteBufAllocator;
|
||||
|
|
@ -61,7 +60,7 @@ public class ReactorNettyWebSocketClient extends WebSocketClientSupport implemen
|
|||
@Override
|
||||
public Mono<Void> execute(URI url, HttpHeaders headers, WebSocketHandler handler) {
|
||||
|
||||
String[] protocols = getSubProtocols(headers, handler);
|
||||
String[] protocols = beforeHandshake(url, headers, handler);
|
||||
// TODO: https://github.com/reactor/reactor-netty/issues/20
|
||||
|
||||
return this.httpClient
|
||||
|
|
@ -71,9 +70,7 @@ public class ReactorNettyWebSocketClient extends WebSocketClientSupport implemen
|
|||
})
|
||||
.then(response -> {
|
||||
HttpHeaders responseHeaders = getResponseHeaders(response);
|
||||
String protocol = responseHeaders.getFirst(SEC_WEBSOCKET_PROTOCOL);
|
||||
HandshakeInfo info = new HandshakeInfo(url, responseHeaders, Mono.empty(),
|
||||
Optional.ofNullable(protocol));
|
||||
HandshakeInfo info = afterHandshake(url, response.status().code(), responseHeaders);
|
||||
|
||||
ByteBufAllocator allocator = response.channel().alloc();
|
||||
NettyDataBufferFactory factory = new NettyDataBufferFactory(allocator);
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ import java.util.Collections;
|
|||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.function.Function;
|
||||
import javax.net.ssl.SSLContext;
|
||||
import javax.net.ssl.SSLEngine;
|
||||
|
|
@ -44,7 +43,6 @@ import org.springframework.http.HttpHeaders;
|
|||
import org.springframework.util.ObjectUtils;
|
||||
import org.springframework.web.reactive.socket.HandshakeInfo;
|
||||
import org.springframework.web.reactive.socket.WebSocketHandler;
|
||||
import org.springframework.web.reactive.socket.WebSocketSession;
|
||||
import org.springframework.web.reactive.socket.adapter.RxNettyWebSocketSession;
|
||||
|
||||
/**
|
||||
|
|
@ -105,7 +103,10 @@ public class RxNettyWebSocketClient extends WebSocketClientSupport implements We
|
|||
}
|
||||
|
||||
private Observable<Void> connectInternal(URI url, HttpHeaders headers, WebSocketHandler handler) {
|
||||
return createRequest(url, headers, handler)
|
||||
|
||||
String[] protocols = beforeHandshake(url, headers, handler);
|
||||
|
||||
return createRequest(url, headers, protocols)
|
||||
.flatMap(response -> {
|
||||
Observable<WebSocketConnection> conn = response.getWebSocketConnection();
|
||||
return Observable.zip(Observable.just(response), conn, Tuples::of);
|
||||
|
|
@ -113,8 +114,7 @@ public class RxNettyWebSocketClient extends WebSocketClientSupport implements We
|
|||
.flatMap(tuple -> {
|
||||
WebSocketResponse<ByteBuf> response = tuple.getT1();
|
||||
HttpHeaders responseHeaders = getResponseHeaders(response);
|
||||
Optional<String> protocol = Optional.ofNullable(response.getAcceptedSubProtocol());
|
||||
HandshakeInfo info = new HandshakeInfo(url, responseHeaders, Mono.empty(), protocol);
|
||||
HandshakeInfo info = afterHandshake(url, response.getStatus().code(), responseHeaders);
|
||||
|
||||
ByteBufAllocator allocator = response.unsafeNettyChannel().alloc();
|
||||
NettyDataBufferFactory factory = new NettyDataBufferFactory(allocator);
|
||||
|
|
@ -128,7 +128,7 @@ public class RxNettyWebSocketClient extends WebSocketClientSupport implements We
|
|||
});
|
||||
}
|
||||
|
||||
private WebSocketRequest<ByteBuf> createRequest(URI url, HttpHeaders headers, WebSocketHandler handler) {
|
||||
private WebSocketRequest<ByteBuf> createRequest(URI url, HttpHeaders headers, String[] protocols) {
|
||||
|
||||
String query = url.getRawQuery();
|
||||
String requestUrl = url.getRawPath() + (query != null ? "?" + query : "");
|
||||
|
|
@ -138,7 +138,6 @@ public class RxNettyWebSocketClient extends WebSocketClientSupport implements We
|
|||
.setHeaders(toObjectValueMap(headers))
|
||||
.requestWebSocketUpgrade();
|
||||
|
||||
String[] protocols = getSubProtocols(headers, handler);
|
||||
if (!ObjectUtils.isEmpty(protocols)) {
|
||||
request = request.requestSubProtocols(protocols);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,8 +15,16 @@
|
|||
*/
|
||||
package org.springframework.web.reactive.socket.client;
|
||||
|
||||
import java.net.URI;
|
||||
import java.util.Optional;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.util.StringUtils;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.web.reactive.socket.HandshakeInfo;
|
||||
import org.springframework.web.reactive.socket.WebSocketHandler;
|
||||
|
||||
/**
|
||||
|
|
@ -30,11 +38,23 @@ public class WebSocketClientSupport {
|
|||
protected static final String SEC_WEBSOCKET_PROTOCOL = "Sec-WebSocket-Protocol";
|
||||
|
||||
|
||||
protected String[] getSubProtocols(HttpHeaders headers, WebSocketHandler handler) {
|
||||
String value = headers.getFirst(SEC_WEBSOCKET_PROTOCOL);
|
||||
return (value != null ?
|
||||
StringUtils.commaDelimitedListToStringArray(value) :
|
||||
handler.getSubProtocols());
|
||||
protected final Log logger = LogFactory.getLog(getClass());
|
||||
|
||||
|
||||
protected String[] beforeHandshake(URI url, HttpHeaders headers, WebSocketHandler handler) {
|
||||
if (logger.isDebugEnabled()) {
|
||||
logger.debug("Executing handshake to " + url);
|
||||
}
|
||||
return handler.getSubProtocols();
|
||||
}
|
||||
|
||||
protected HandshakeInfo afterHandshake(URI url, int statusCode, HttpHeaders headers) {
|
||||
Assert.isTrue(statusCode == 101);
|
||||
if (logger.isDebugEnabled()) {
|
||||
logger.debug("Handshake response: " + url + ", " + headers);
|
||||
}
|
||||
String protocol = headers.getFirst(SEC_WEBSOCKET_PROTOCOL);
|
||||
return new HandshakeInfo(url, headers, Mono.empty(), Optional.ofNullable(protocol));
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -25,10 +25,9 @@ import org.apache.commons.logging.LogFactory;
|
|||
import reactor.core.publisher.Mono;
|
||||
|
||||
import org.springframework.context.Lifecycle;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.http.server.reactive.ServerHttpRequest;
|
||||
import org.springframework.http.server.reactive.ServerHttpResponse;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.ClassUtils;
|
||||
import org.springframework.util.ReflectionUtils;
|
||||
|
|
@ -38,6 +37,7 @@ import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy;
|
|||
import org.springframework.web.reactive.socket.server.WebSocketService;
|
||||
import org.springframework.web.server.MethodNotAllowedException;
|
||||
import org.springframework.web.server.ServerWebExchange;
|
||||
import org.springframework.web.server.ServerWebInputException;
|
||||
|
||||
/**
|
||||
* {@code WebSocketService} implementation that handles a WebSocket HTTP
|
||||
|
|
@ -179,53 +179,44 @@ public class HandshakeWebSocketService implements WebSocketService, Lifecycle {
|
|||
public Mono<Void> handleRequest(ServerWebExchange exchange, WebSocketHandler handler) {
|
||||
|
||||
ServerHttpRequest request = exchange.getRequest();
|
||||
ServerHttpResponse response = exchange.getResponse();
|
||||
HttpMethod method = request.getMethod();
|
||||
HttpHeaders headers = request.getHeaders();
|
||||
|
||||
if (logger.isTraceEnabled()) {
|
||||
logger.trace("Processing " + request.getMethod() + " " + request.getURI());
|
||||
if (logger.isDebugEnabled()) {
|
||||
logger.debug("Handling " + request.getURI() + " with headers: " + headers);
|
||||
}
|
||||
|
||||
if (HttpMethod.GET != request.getMethod()) {
|
||||
return Mono.error(new MethodNotAllowedException(
|
||||
request.getMethod().name(), Collections.singleton("GET")));
|
||||
if (HttpMethod.GET != method) {
|
||||
return Mono.error(new MethodNotAllowedException(method.name(), Collections.singleton("GET")));
|
||||
}
|
||||
|
||||
if (!isWebSocketUpgrade(request)) {
|
||||
response.setStatusCode(HttpStatus.BAD_REQUEST);
|
||||
return response.setComplete();
|
||||
if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) {
|
||||
return handleBadRequest("Invalid 'Upgrade' header: " + headers);
|
||||
}
|
||||
|
||||
Optional<String> subProtocol = selectSubProtocol(request, handler);
|
||||
|
||||
return getUpgradeStrategy().upgrade(exchange, handler, subProtocol);
|
||||
}
|
||||
|
||||
private boolean isWebSocketUpgrade(ServerHttpRequest request) {
|
||||
if (!"WebSocket".equalsIgnoreCase(request.getHeaders().getUpgrade())) {
|
||||
if (logger.isErrorEnabled()) {
|
||||
logger.error("Invalid 'Upgrade' header: " + request.getHeaders());
|
||||
}
|
||||
return false;
|
||||
}
|
||||
List<String> connectionValue = request.getHeaders().getConnection();
|
||||
List<String> connectionValue = headers.getConnection();
|
||||
if (!connectionValue.contains("Upgrade") && !connectionValue.contains("upgrade")) {
|
||||
if (logger.isErrorEnabled()) {
|
||||
logger.error("Invalid 'Connection' header: " + request.getHeaders());
|
||||
}
|
||||
return false;
|
||||
return handleBadRequest("Invalid 'Connection' header: " + headers);
|
||||
}
|
||||
String key = request.getHeaders().getFirst(SEC_WEBSOCKET_KEY);
|
||||
|
||||
String key = headers.getFirst(SEC_WEBSOCKET_KEY);
|
||||
if (key == null) {
|
||||
if (logger.isErrorEnabled()) {
|
||||
logger.error("Missing \"Sec-WebSocket-Key\" header");
|
||||
}
|
||||
return false;
|
||||
return handleBadRequest("Missing \"Sec-WebSocket-Key\" header");
|
||||
}
|
||||
return true;
|
||||
|
||||
Optional<String> protocol = selectProtocol(headers, handler);
|
||||
return this.upgradeStrategy.upgrade(exchange, handler, protocol);
|
||||
}
|
||||
|
||||
private Optional<String> selectSubProtocol(ServerHttpRequest request, WebSocketHandler handler) {
|
||||
String protocolHeader = request.getHeaders().getFirst(SEC_WEBSOCKET_PROTOCOL);
|
||||
private Mono<Void> handleBadRequest(String reason) {
|
||||
if (logger.isDebugEnabled()) {
|
||||
logger.debug(reason);
|
||||
}
|
||||
return Mono.error(new ServerWebInputException(reason));
|
||||
}
|
||||
|
||||
private Optional<String> selectProtocol(HttpHeaders headers, WebSocketHandler handler) {
|
||||
String protocolHeader = headers.getFirst(SEC_WEBSOCKET_PROTOCOL);
|
||||
if (protocolHeader == null) {
|
||||
return Optional.empty();
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue