diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/adapter/ContextWebSocketHandler.java b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/adapter/ContextWebSocketHandler.java new file mode 100644 index 0000000000..816dc147bd --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/adapter/ContextWebSocketHandler.java @@ -0,0 +1,64 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.reactive.socket.adapter; + +import java.util.List; + +import reactor.core.publisher.Mono; +import reactor.util.context.ContextView; + +import org.springframework.web.reactive.socket.WebSocketHandler; +import org.springframework.web.reactive.socket.WebSocketSession; + +/** + * {@link WebSocketHandler} decorator that enriches the context of the target handler. + * + * @author Rossen Stoyanchev + * @since 5.3.3 + */ +public final class ContextWebSocketHandler implements WebSocketHandler { + + private final WebSocketHandler delegate; + + private final ContextView contextView; + + + private ContextWebSocketHandler(WebSocketHandler delegate, ContextView contextView) { + this.delegate = delegate; + this.contextView = contextView; + } + + + @Override + public List getSubProtocols() { + return this.delegate.getSubProtocols(); + } + + @Override + public Mono handle(WebSocketSession session) { + return this.delegate.handle(session).contextWrite(this.contextView); + } + + + /** + * Return the given handler, decorated to insert the given context, or the + * same handler instance when the context is empty. + */ + public static WebSocketHandler decorate(WebSocketHandler handler, ContextView contextView) { + return (!contextView.isEmpty() ? new ContextWebSocketHandler(handler, contextView) : handler); + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/client/JettyWebSocketClient.java b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/client/JettyWebSocketClient.java index 0dbe7b0ba1..1074837cbe 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/client/JettyWebSocketClient.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/client/JettyWebSocketClient.java @@ -16,6 +16,7 @@ package org.springframework.web.reactive.socket.client; +import java.io.IOException; import java.net.URI; import org.apache.commons.logging.Log; @@ -33,6 +34,7 @@ import org.springframework.core.io.buffer.DefaultDataBufferFactory; import org.springframework.http.HttpHeaders; import org.springframework.web.reactive.socket.HandshakeInfo; import org.springframework.web.reactive.socket.WebSocketHandler; +import org.springframework.web.reactive.socket.adapter.ContextWebSocketHandler; import org.springframework.web.reactive.socket.adapter.JettyWebSocketHandlerAdapter; import org.springframework.web.reactive.socket.adapter.JettyWebSocketSession; @@ -137,18 +139,23 @@ public class JettyWebSocketClient implements WebSocketClient, Lifecycle { private Mono executeInternal(URI url, HttpHeaders headers, WebSocketHandler handler) { Sinks.Empty completionSink = Sinks.empty(); - return Mono.fromCallable( - () -> { - if (logger.isDebugEnabled()) { - logger.debug("Connecting to " + url); - } - Object jettyHandler = createHandler(url, handler, completionSink); - ClientUpgradeRequest request = new ClientUpgradeRequest(); - request.setSubProtocols(handler.getSubProtocols()); - UpgradeListener upgradeListener = new DefaultUpgradeListener(headers); - return this.jettyClient.connect(jettyHandler, url, request, upgradeListener); - }) - .then(completionSink.asMono()); + return Mono.deferContextual(contextView -> { + if (logger.isDebugEnabled()) { + logger.debug("Connecting to " + url); + } + Object jettyHandler = createHandler( + url, ContextWebSocketHandler.decorate(handler, contextView), completionSink); + ClientUpgradeRequest request = new ClientUpgradeRequest(); + request.setSubProtocols(handler.getSubProtocols()); + UpgradeListener upgradeListener = new DefaultUpgradeListener(headers); + try { + this.jettyClient.connect(jettyHandler, url, request, upgradeListener); + return completionSink.asMono(); + } + catch (IOException ex) { + return Mono.error(ex); + } + }); } private Object createHandler(URI url, WebSocketHandler handler, Sinks.Empty completion) { diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/client/StandardWebSocketClient.java b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/client/StandardWebSocketClient.java index 0ce6922789..4bf98585a0 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/client/StandardWebSocketClient.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/client/StandardWebSocketClient.java @@ -39,6 +39,7 @@ import org.springframework.core.io.buffer.DefaultDataBufferFactory; import org.springframework.http.HttpHeaders; import org.springframework.web.reactive.socket.HandshakeInfo; import org.springframework.web.reactive.socket.WebSocketHandler; +import org.springframework.web.reactive.socket.adapter.ContextWebSocketHandler; import org.springframework.web.reactive.socket.adapter.StandardWebSocketHandlerAdapter; import org.springframework.web.reactive.socket.adapter.StandardWebSocketSession; @@ -95,20 +96,26 @@ public class StandardWebSocketClient implements WebSocketClient { } private Mono executeInternal(URI url, HttpHeaders requestHeaders, WebSocketHandler handler) { - Sinks.Empty completionSink = Sinks.empty(); - return Mono.fromCallable( - () -> { + Sinks.Empty completion = Sinks.empty(); + return Mono.deferContextual( + contextView -> { if (logger.isDebugEnabled()) { logger.debug("Connecting to " + url); } List protocols = handler.getSubProtocols(); DefaultConfigurator configurator = new DefaultConfigurator(requestHeaders); - Endpoint endpoint = createEndpoint(url, handler, completionSink, configurator); + Endpoint endpoint = createEndpoint( + url, ContextWebSocketHandler.decorate(handler, contextView), completion, configurator); ClientEndpointConfig config = createEndpointConfig(configurator, protocols); - return this.webSocketContainer.connectToServer(endpoint, config, url); + try { + this.webSocketContainer.connectToServer(endpoint, config, url); + return completion.asMono(); + } + catch (Exception ex) { + return Mono.error(ex); + } }) - .subscribeOn(Schedulers.boundedElastic()) // connectToServer is blocking - .then(completionSink.asMono()); + .subscribeOn(Schedulers.boundedElastic()); // connectToServer is blocking } private StandardWebSocketHandlerAdapter createEndpoint(URI url, WebSocketHandler handler, diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/client/UndertowWebSocketClient.java b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/client/UndertowWebSocketClient.java index 03563efdc8..272e04f345 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/client/UndertowWebSocketClient.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/client/UndertowWebSocketClient.java @@ -42,6 +42,7 @@ import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.web.reactive.socket.HandshakeInfo; import org.springframework.web.reactive.socket.WebSocketHandler; +import org.springframework.web.reactive.socket.adapter.ContextWebSocketHandler; import org.springframework.web.reactive.socket.adapter.UndertowWebSocketHandlerAdapter; import org.springframework.web.reactive.socket.adapter.UndertowWebSocketSession; @@ -154,9 +155,9 @@ public class UndertowWebSocketClient implements WebSocketClient { } private Mono executeInternal(URI url, HttpHeaders headers, WebSocketHandler handler) { - Sinks.Empty completionSink = Sinks.empty(); - return Mono.fromCallable( - () -> { + Sinks.Empty completion = Sinks.empty(); + return Mono.deferContextual( + contextView -> { if (logger.isDebugEnabled()) { logger.debug("Connecting to " + url); } @@ -164,21 +165,22 @@ public class UndertowWebSocketClient implements WebSocketClient { ConnectionBuilder builder = createConnectionBuilder(url); DefaultNegotiation negotiation = new DefaultNegotiation(protocols, headers, builder); builder.setClientNegotiation(negotiation); - return builder.connect().addNotifier( + builder.connect().addNotifier( new IoFuture.HandlingNotifier() { @Override public void handleDone(WebSocketChannel channel, Object attachment) { - handleChannel(url, handler, completionSink, negotiation, channel); + handleChannel(url, ContextWebSocketHandler.decorate(handler, contextView), + completion, negotiation, channel); } @Override public void handleFailed(IOException ex, Object attachment) { // Ignore result: can't overflow, ok if not first or no one listens - completionSink.tryEmitError( + completion.tryEmitError( new IllegalStateException("Failed to connect to " + url, ex)); } }, null); - }) - .then(completionSink.asMono()); + return completion.asMono(); + }); } /** diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/upgrade/JettyRequestUpgradeStrategy.java b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/upgrade/JettyRequestUpgradeStrategy.java index f4a4882e7b..7d5b008100 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/upgrade/JettyRequestUpgradeStrategy.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/upgrade/JettyRequestUpgradeStrategy.java @@ -16,6 +16,7 @@ package org.springframework.web.reactive.socket.server.upgrade; +import java.io.IOException; import java.util.function.Supplier; import javax.servlet.ServletContext; @@ -39,6 +40,7 @@ import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.web.reactive.socket.HandshakeInfo; import org.springframework.web.reactive.socket.WebSocketHandler; +import org.springframework.web.reactive.socket.adapter.ContextWebSocketHandler; import org.springframework.web.reactive.socket.adapter.JettyWebSocketHandlerAdapter; import org.springframework.web.reactive.socket.adapter.JettyWebSocketSession; import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy; @@ -152,9 +154,6 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life HandshakeInfo handshakeInfo = handshakeInfoFactory.get(); DataBufferFactory factory = response.bufferFactory(); - JettyWebSocketHandlerAdapter adapter = new JettyWebSocketHandlerAdapter( - handler, session -> new JettyWebSocketSession(session, handshakeInfo, factory)); - startLazily(servletRequest); Assert.state(this.factory != null, "No WebSocketServerFactory available"); @@ -163,15 +162,22 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life // Trigger WebFlux preCommit actions and upgrade return exchange.getResponse().setComplete() - .then(Mono.fromCallable(() -> { + .then(Mono.deferContextual(contextView -> { + JettyWebSocketHandlerAdapter adapter = new JettyWebSocketHandlerAdapter( + ContextWebSocketHandler.decorate(handler, contextView), + session -> new JettyWebSocketSession(session, handshakeInfo, factory)); + try { adapterHolder.set(new WebSocketHandlerContainer(adapter, subProtocol)); this.factory.acceptWebSocket(servletRequest, servletResponse); } + catch (IOException ex) { + return Mono.error(ex); + } finally { adapterHolder.remove(); } - return null; + return Mono.empty(); })); } diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/upgrade/TomcatRequestUpgradeStrategy.java b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/upgrade/TomcatRequestUpgradeStrategy.java index ad56c0370f..a9d84b1644 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/upgrade/TomcatRequestUpgradeStrategy.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/upgrade/TomcatRequestUpgradeStrategy.java @@ -38,6 +38,7 @@ import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.web.reactive.socket.HandshakeInfo; import org.springframework.web.reactive.socket.WebSocketHandler; +import org.springframework.web.reactive.socket.adapter.ContextWebSocketHandler; import org.springframework.web.reactive.socket.adapter.StandardWebSocketHandlerAdapter; import org.springframework.web.reactive.socket.adapter.TomcatWebSocketSession; import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy; @@ -137,20 +138,26 @@ public class TomcatRequestUpgradeStrategy implements RequestUpgradeStrategy { HandshakeInfo handshakeInfo = handshakeInfoFactory.get(); DataBufferFactory bufferFactory = response.bufferFactory(); - Endpoint endpoint = new StandardWebSocketHandlerAdapter( - handler, session -> new TomcatWebSocketSession(session, handshakeInfo, bufferFactory)); - - String requestURI = servletRequest.getRequestURI(); - DefaultServerEndpointConfig config = new DefaultServerEndpointConfig(requestURI, endpoint); - config.setSubprotocols(subProtocol != null ? - Collections.singletonList(subProtocol) : Collections.emptyList()); - // Trigger WebFlux preCommit actions and upgrade return exchange.getResponse().setComplete() - .then(Mono.fromCallable(() -> { + .then(Mono.deferContextual(contextView -> { + Endpoint endpoint = new StandardWebSocketHandlerAdapter( + ContextWebSocketHandler.decorate(handler, contextView), + session -> new TomcatWebSocketSession(session, handshakeInfo, bufferFactory)); + + String requestURI = servletRequest.getRequestURI(); + DefaultServerEndpointConfig config = new DefaultServerEndpointConfig(requestURI, endpoint); + config.setSubprotocols(subProtocol != null ? + Collections.singletonList(subProtocol) : Collections.emptyList()); + WsServerContainer container = getContainer(servletRequest); - container.doUpgrade(servletRequest, servletResponse, config, Collections.emptyMap()); - return null; + try { + container.doUpgrade(servletRequest, servletResponse, config, Collections.emptyMap()); + } + catch (Exception ex) { + return Mono.error(ex); + } + return Mono.empty(); })); } diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/upgrade/UndertowRequestUpgradeStrategy.java b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/upgrade/UndertowRequestUpgradeStrategy.java index d57fb1b948..fa60bcdb71 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/upgrade/UndertowRequestUpgradeStrategy.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/upgrade/UndertowRequestUpgradeStrategy.java @@ -37,6 +37,7 @@ import org.springframework.http.server.reactive.ServerHttpRequestDecorator; import org.springframework.lang.Nullable; import org.springframework.web.reactive.socket.HandshakeInfo; import org.springframework.web.reactive.socket.WebSocketHandler; +import org.springframework.web.reactive.socket.adapter.ContextWebSocketHandler; import org.springframework.web.reactive.socket.adapter.UndertowWebSocketHandlerAdapter; import org.springframework.web.reactive.socket.adapter.UndertowWebSocketSession; import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy; @@ -67,10 +68,18 @@ public class UndertowRequestUpgradeStrategy implements RequestUpgradeStrategy { // Trigger WebFlux preCommit actions and upgrade return exchange.getResponse().setComplete() - .then(Mono.fromCallable(() -> { - DefaultCallback callback = new DefaultCallback(handshakeInfo, handler, bufferFactory); - new WebSocketProtocolHandshakeHandler(handshakes, callback).handleRequest(httpExchange); - return null; + .then(Mono.deferContextual(contextView -> { + DefaultCallback callback = new DefaultCallback( + handshakeInfo, + ContextWebSocketHandler.decorate(handler, contextView), + bufferFactory); + try { + new WebSocketProtocolHandshakeHandler(handshakes, callback).handleRequest(httpExchange); + } + catch (Exception ex) { + return Mono.error(ex); + } + return Mono.empty(); })); } diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/socket/AbstractWebSocketIntegrationTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/socket/AbstractWebSocketIntegrationTests.java index 8d0e78cf2a..494107f6c9 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/socket/AbstractWebSocketIntegrationTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/socket/AbstractWebSocketIntegrationTests.java @@ -43,6 +43,7 @@ import org.springframework.context.annotation.AnnotationConfigApplicationContext import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.web.filter.reactive.ServerWebExchangeContextFilter; import org.springframework.web.reactive.DispatcherHandler; import org.springframework.web.reactive.socket.client.JettyWebSocketClient; import org.springframework.web.reactive.socket.client.ReactorNettyWebSocketClient; @@ -57,6 +58,7 @@ import org.springframework.web.reactive.socket.server.upgrade.JettyRequestUpgrad import org.springframework.web.reactive.socket.server.upgrade.ReactorNettyRequestUpgradeStrategy; import org.springframework.web.reactive.socket.server.upgrade.TomcatRequestUpgradeStrategy; import org.springframework.web.reactive.socket.server.upgrade.UndertowRequestUpgradeStrategy; +import org.springframework.web.server.WebFilter; import org.springframework.web.server.adapter.WebHttpHandlerBuilder; import org.springframework.web.testfixture.http.server.reactive.bootstrap.HttpServer; import org.springframework.web.testfixture.http.server.reactive.bootstrap.JettyHttpServer; @@ -165,6 +167,11 @@ abstract class AbstractWebSocketIntegrationTests { @Configuration static class DispatcherConfig { + @Bean + public WebFilter contextFilter() { + return new ServerWebExchangeContextFilter(); + } + @Bean public DispatcherHandler webHandler() { return new DispatcherHandler(); diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/socket/WebSocketIntegrationTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/socket/WebSocketIntegrationTests.java index 072cfddd23..2b0380aea9 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/socket/WebSocketIntegrationTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/socket/WebSocketIntegrationTests.java @@ -33,6 +33,7 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.http.HttpHeaders; import org.springframework.http.ResponseCookie; +import org.springframework.web.filter.reactive.ServerWebExchangeContextFilter; import org.springframework.web.reactive.HandlerMapping; import org.springframework.web.reactive.handler.SimpleUrlHandlerMapping; import org.springframework.web.reactive.socket.client.WebSocketClient; @@ -216,8 +217,11 @@ class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { @Override public Mono handle(WebSocketSession session) { - // Use retain() for Reactor Netty - return session.send(session.receive().doOnNext(WebSocketMessage::retain)); + return Mono.deferContextual(contextView -> { + String key = ServerWebExchangeContextFilter.EXCHANGE_CONTEXT_ATTRIBUTE; + assertThat(contextView.getOrEmpty(key).orElse(null)).isNotNull(); + return session.send(session.receive().doOnNext(WebSocketMessage::retain)); + }); } }