diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/RxNettyWebSocketSession.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/RxNettyWebSocketSession.java index 0d436fa6cde..d352831af00 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/RxNettyWebSocketSession.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/adapter/RxNettyWebSocketSession.java @@ -17,8 +17,10 @@ package org.springframework.web.reactive.socket.adapter; import io.netty.channel.Channel; +import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPipeline; +import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; import io.netty.handler.codec.http.websocketx.WebSocketFrame; import io.netty.handler.codec.http.websocketx.WebSocketFrameAggregator; import io.reactivex.netty.protocol.http.ws.WebSocketConnection; @@ -45,8 +47,8 @@ import org.springframework.web.reactive.socket.WebSocketSession; public class RxNettyWebSocketSession extends NettyWebSocketSessionSupport { /** - * The name of the {@link WebSocketFrameAggregator} inserted by - * {@link #aggregateFrames(Channel, String)}. + * The {@code ChannelHandler} name to use when inserting a + * {@link WebSocketFrameAggregator} in the channel pipeline. */ public static final String FRAME_AGGREGATOR_NAME = "websocket-frame-aggregator"; @@ -70,18 +72,21 @@ public class RxNettyWebSocketSession extends NettyWebSocketSessionSupport receive() { - Observable observable = getDelegate().getInput().map(super::toMessage); - return Flux.from(RxReactiveStreams.toPublisher(observable)); + Observable messages = getDelegate() + .getInput() + .filter(frame -> !(frame instanceof CloseWebSocketFrame)) + .map(super::toMessage); + return Flux.from(RxReactiveStreams.toPublisher(messages)); } @Override diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/ReactorNettyWebSocketClient.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/ReactorNettyWebSocketClient.java index cd75064a559..8c28741af27 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/ReactorNettyWebSocketClient.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/ReactorNettyWebSocketClient.java @@ -16,7 +16,6 @@ package org.springframework.web.reactive.socket.client; import java.net.URI; -import java.util.Optional; import java.util.function.Consumer; import io.netty.buffer.ByteBufAllocator; @@ -61,7 +60,7 @@ public class ReactorNettyWebSocketClient extends WebSocketClientSupport implemen @Override public Mono execute(URI url, HttpHeaders headers, WebSocketHandler handler) { - String[] protocols = getSubProtocols(headers, handler); + String[] protocols = beforeHandshake(url, headers, handler); // TODO: https://github.com/reactor/reactor-netty/issues/20 return this.httpClient @@ -71,9 +70,7 @@ public class ReactorNettyWebSocketClient extends WebSocketClientSupport implemen }) .then(response -> { HttpHeaders responseHeaders = getResponseHeaders(response); - String protocol = responseHeaders.getFirst(SEC_WEBSOCKET_PROTOCOL); - HandshakeInfo info = new HandshakeInfo(url, responseHeaders, Mono.empty(), - Optional.ofNullable(protocol)); + HandshakeInfo info = afterHandshake(url, response.status().code(), responseHeaders); ByteBufAllocator allocator = response.channel().alloc(); NettyDataBufferFactory factory = new NettyDataBufferFactory(allocator); diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/RxNettyWebSocketClient.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/RxNettyWebSocketClient.java index a39e43cce00..139bb9688cf 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/RxNettyWebSocketClient.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/RxNettyWebSocketClient.java @@ -22,7 +22,6 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.function.Function; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; @@ -44,7 +43,6 @@ import org.springframework.http.HttpHeaders; import org.springframework.util.ObjectUtils; import org.springframework.web.reactive.socket.HandshakeInfo; import org.springframework.web.reactive.socket.WebSocketHandler; -import org.springframework.web.reactive.socket.WebSocketSession; import org.springframework.web.reactive.socket.adapter.RxNettyWebSocketSession; /** @@ -105,7 +103,10 @@ public class RxNettyWebSocketClient extends WebSocketClientSupport implements We } private Observable connectInternal(URI url, HttpHeaders headers, WebSocketHandler handler) { - return createRequest(url, headers, handler) + + String[] protocols = beforeHandshake(url, headers, handler); + + return createRequest(url, headers, protocols) .flatMap(response -> { Observable conn = response.getWebSocketConnection(); return Observable.zip(Observable.just(response), conn, Tuples::of); @@ -113,8 +114,7 @@ public class RxNettyWebSocketClient extends WebSocketClientSupport implements We .flatMap(tuple -> { WebSocketResponse response = tuple.getT1(); HttpHeaders responseHeaders = getResponseHeaders(response); - Optional protocol = Optional.ofNullable(response.getAcceptedSubProtocol()); - HandshakeInfo info = new HandshakeInfo(url, responseHeaders, Mono.empty(), protocol); + HandshakeInfo info = afterHandshake(url, response.getStatus().code(), responseHeaders); ByteBufAllocator allocator = response.unsafeNettyChannel().alloc(); NettyDataBufferFactory factory = new NettyDataBufferFactory(allocator); @@ -128,7 +128,7 @@ public class RxNettyWebSocketClient extends WebSocketClientSupport implements We }); } - private WebSocketRequest createRequest(URI url, HttpHeaders headers, WebSocketHandler handler) { + private WebSocketRequest createRequest(URI url, HttpHeaders headers, String[] protocols) { String query = url.getRawQuery(); String requestUrl = url.getRawPath() + (query != null ? "?" + query : ""); @@ -138,7 +138,6 @@ public class RxNettyWebSocketClient extends WebSocketClientSupport implements We .setHeaders(toObjectValueMap(headers)) .requestWebSocketUpgrade(); - String[] protocols = getSubProtocols(headers, handler); if (!ObjectUtils.isEmpty(protocols)) { request = request.requestSubProtocols(protocols); } diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/WebSocketClientSupport.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/WebSocketClientSupport.java index dbdcf5e382f..e1b7108d7a5 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/WebSocketClientSupport.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/client/WebSocketClientSupport.java @@ -15,8 +15,16 @@ */ package org.springframework.web.reactive.socket.client; +import java.net.URI; +import java.util.Optional; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import reactor.core.publisher.Mono; + import org.springframework.http.HttpHeaders; -import org.springframework.util.StringUtils; +import org.springframework.util.Assert; +import org.springframework.web.reactive.socket.HandshakeInfo; import org.springframework.web.reactive.socket.WebSocketHandler; /** @@ -30,11 +38,23 @@ public class WebSocketClientSupport { protected static final String SEC_WEBSOCKET_PROTOCOL = "Sec-WebSocket-Protocol"; - protected String[] getSubProtocols(HttpHeaders headers, WebSocketHandler handler) { - String value = headers.getFirst(SEC_WEBSOCKET_PROTOCOL); - return (value != null ? - StringUtils.commaDelimitedListToStringArray(value) : - handler.getSubProtocols()); + protected final Log logger = LogFactory.getLog(getClass()); + + + protected String[] beforeHandshake(URI url, HttpHeaders headers, WebSocketHandler handler) { + if (logger.isDebugEnabled()) { + logger.debug("Executing handshake to " + url); + } + return handler.getSubProtocols(); + } + + protected HandshakeInfo afterHandshake(URI url, int statusCode, HttpHeaders headers) { + Assert.isTrue(statusCode == 101); + if (logger.isDebugEnabled()) { + logger.debug("Handshake response: " + url + ", " + headers); + } + String protocol = headers.getFirst(SEC_WEBSOCKET_PROTOCOL); + return new HandshakeInfo(url, headers, Mono.empty(), Optional.ofNullable(protocol)); } } \ No newline at end of file diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java index 5aa4ab355d9..174d20fce37 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java @@ -25,10 +25,9 @@ import org.apache.commons.logging.LogFactory; import reactor.core.publisher.Mono; import org.springframework.context.Lifecycle; +import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; -import org.springframework.http.HttpStatus; import org.springframework.http.server.reactive.ServerHttpRequest; -import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.util.ReflectionUtils; @@ -38,6 +37,7 @@ import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy; import org.springframework.web.reactive.socket.server.WebSocketService; import org.springframework.web.server.MethodNotAllowedException; import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.ServerWebInputException; /** * {@code WebSocketService} implementation that handles a WebSocket HTTP @@ -179,53 +179,44 @@ public class HandshakeWebSocketService implements WebSocketService, Lifecycle { public Mono handleRequest(ServerWebExchange exchange, WebSocketHandler handler) { ServerHttpRequest request = exchange.getRequest(); - ServerHttpResponse response = exchange.getResponse(); + HttpMethod method = request.getMethod(); + HttpHeaders headers = request.getHeaders(); - if (logger.isTraceEnabled()) { - logger.trace("Processing " + request.getMethod() + " " + request.getURI()); + if (logger.isDebugEnabled()) { + logger.debug("Handling " + request.getURI() + " with headers: " + headers); } - if (HttpMethod.GET != request.getMethod()) { - return Mono.error(new MethodNotAllowedException( - request.getMethod().name(), Collections.singleton("GET"))); + if (HttpMethod.GET != method) { + return Mono.error(new MethodNotAllowedException(method.name(), Collections.singleton("GET"))); } - if (!isWebSocketUpgrade(request)) { - response.setStatusCode(HttpStatus.BAD_REQUEST); - return response.setComplete(); + if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) { + return handleBadRequest("Invalid 'Upgrade' header: " + headers); } - Optional subProtocol = selectSubProtocol(request, handler); - - return getUpgradeStrategy().upgrade(exchange, handler, subProtocol); - } - - private boolean isWebSocketUpgrade(ServerHttpRequest request) { - if (!"WebSocket".equalsIgnoreCase(request.getHeaders().getUpgrade())) { - if (logger.isErrorEnabled()) { - logger.error("Invalid 'Upgrade' header: " + request.getHeaders()); - } - return false; - } - List connectionValue = request.getHeaders().getConnection(); + List connectionValue = headers.getConnection(); if (!connectionValue.contains("Upgrade") && !connectionValue.contains("upgrade")) { - if (logger.isErrorEnabled()) { - logger.error("Invalid 'Connection' header: " + request.getHeaders()); - } - return false; + return handleBadRequest("Invalid 'Connection' header: " + headers); } - String key = request.getHeaders().getFirst(SEC_WEBSOCKET_KEY); + + String key = headers.getFirst(SEC_WEBSOCKET_KEY); if (key == null) { - if (logger.isErrorEnabled()) { - logger.error("Missing \"Sec-WebSocket-Key\" header"); - } - return false; + return handleBadRequest("Missing \"Sec-WebSocket-Key\" header"); } - return true; + + Optional protocol = selectProtocol(headers, handler); + return this.upgradeStrategy.upgrade(exchange, handler, protocol); } - private Optional selectSubProtocol(ServerHttpRequest request, WebSocketHandler handler) { - String protocolHeader = request.getHeaders().getFirst(SEC_WEBSOCKET_PROTOCOL); + private Mono handleBadRequest(String reason) { + if (logger.isDebugEnabled()) { + logger.debug(reason); + } + return Mono.error(new ServerWebInputException(reason)); + } + + private Optional selectProtocol(HttpHeaders headers, WebSocketHandler handler) { + String protocolHeader = headers.getFirst(SEC_WEBSOCKET_PROTOCOL); if (protocolHeader == null) { return Optional.empty(); }