Enrich WebSocketHandler context

Closes gh-26210
This commit is contained in:
Rossen Stoyanchev 2020-12-15 21:33:50 +00:00
parent 83c19cd60e
commit a11d1c8510
9 changed files with 162 additions and 49 deletions

View File

@ -0,0 +1,64 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.reactive.socket.adapter;
import java.util.List;
import reactor.core.publisher.Mono;
import reactor.util.context.ContextView;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketSession;
/**
* {@link WebSocketHandler} decorator that enriches the context of the target handler.
*
* @author Rossen Stoyanchev
* @since 5.3.3
*/
public final class ContextWebSocketHandler implements WebSocketHandler {
private final WebSocketHandler delegate;
private final ContextView contextView;
private ContextWebSocketHandler(WebSocketHandler delegate, ContextView contextView) {
this.delegate = delegate;
this.contextView = contextView;
}
@Override
public List<String> getSubProtocols() {
return this.delegate.getSubProtocols();
}
@Override
public Mono<Void> handle(WebSocketSession session) {
return this.delegate.handle(session).contextWrite(this.contextView);
}
/**
* Return the given handler, decorated to insert the given context, or the
* same handler instance when the context is empty.
*/
public static WebSocketHandler decorate(WebSocketHandler handler, ContextView contextView) {
return (!contextView.isEmpty() ? new ContextWebSocketHandler(handler, contextView) : handler);
}
}

View File

@ -16,6 +16,7 @@
package org.springframework.web.reactive.socket.client; package org.springframework.web.reactive.socket.client;
import java.io.IOException;
import java.net.URI; import java.net.URI;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
@ -33,6 +34,7 @@ import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.web.reactive.socket.HandshakeInfo; import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler; import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.adapter.ContextWebSocketHandler;
import org.springframework.web.reactive.socket.adapter.JettyWebSocketHandlerAdapter; import org.springframework.web.reactive.socket.adapter.JettyWebSocketHandlerAdapter;
import org.springframework.web.reactive.socket.adapter.JettyWebSocketSession; import org.springframework.web.reactive.socket.adapter.JettyWebSocketSession;
@ -137,18 +139,23 @@ public class JettyWebSocketClient implements WebSocketClient, Lifecycle {
private Mono<Void> executeInternal(URI url, HttpHeaders headers, WebSocketHandler handler) { private Mono<Void> executeInternal(URI url, HttpHeaders headers, WebSocketHandler handler) {
Sinks.Empty<Void> completionSink = Sinks.empty(); Sinks.Empty<Void> completionSink = Sinks.empty();
return Mono.fromCallable( return Mono.deferContextual(contextView -> {
() -> { if (logger.isDebugEnabled()) {
if (logger.isDebugEnabled()) { logger.debug("Connecting to " + url);
logger.debug("Connecting to " + url); }
} Object jettyHandler = createHandler(
Object jettyHandler = createHandler(url, handler, completionSink); url, ContextWebSocketHandler.decorate(handler, contextView), completionSink);
ClientUpgradeRequest request = new ClientUpgradeRequest(); ClientUpgradeRequest request = new ClientUpgradeRequest();
request.setSubProtocols(handler.getSubProtocols()); request.setSubProtocols(handler.getSubProtocols());
UpgradeListener upgradeListener = new DefaultUpgradeListener(headers); UpgradeListener upgradeListener = new DefaultUpgradeListener(headers);
return this.jettyClient.connect(jettyHandler, url, request, upgradeListener); try {
}) this.jettyClient.connect(jettyHandler, url, request, upgradeListener);
.then(completionSink.asMono()); return completionSink.asMono();
}
catch (IOException ex) {
return Mono.error(ex);
}
});
} }
private Object createHandler(URI url, WebSocketHandler handler, Sinks.Empty<Void> completion) { private Object createHandler(URI url, WebSocketHandler handler, Sinks.Empty<Void> completion) {

View File

@ -39,6 +39,7 @@ import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.web.reactive.socket.HandshakeInfo; import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler; import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.adapter.ContextWebSocketHandler;
import org.springframework.web.reactive.socket.adapter.StandardWebSocketHandlerAdapter; import org.springframework.web.reactive.socket.adapter.StandardWebSocketHandlerAdapter;
import org.springframework.web.reactive.socket.adapter.StandardWebSocketSession; import org.springframework.web.reactive.socket.adapter.StandardWebSocketSession;
@ -95,20 +96,26 @@ public class StandardWebSocketClient implements WebSocketClient {
} }
private Mono<Void> executeInternal(URI url, HttpHeaders requestHeaders, WebSocketHandler handler) { private Mono<Void> executeInternal(URI url, HttpHeaders requestHeaders, WebSocketHandler handler) {
Sinks.Empty<Void> completionSink = Sinks.empty(); Sinks.Empty<Void> completion = Sinks.empty();
return Mono.fromCallable( return Mono.deferContextual(
() -> { contextView -> {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("Connecting to " + url); logger.debug("Connecting to " + url);
} }
List<String> protocols = handler.getSubProtocols(); List<String> protocols = handler.getSubProtocols();
DefaultConfigurator configurator = new DefaultConfigurator(requestHeaders); DefaultConfigurator configurator = new DefaultConfigurator(requestHeaders);
Endpoint endpoint = createEndpoint(url, handler, completionSink, configurator); Endpoint endpoint = createEndpoint(
url, ContextWebSocketHandler.decorate(handler, contextView), completion, configurator);
ClientEndpointConfig config = createEndpointConfig(configurator, protocols); ClientEndpointConfig config = createEndpointConfig(configurator, protocols);
return this.webSocketContainer.connectToServer(endpoint, config, url); try {
this.webSocketContainer.connectToServer(endpoint, config, url);
return completion.asMono();
}
catch (Exception ex) {
return Mono.error(ex);
}
}) })
.subscribeOn(Schedulers.boundedElastic()) // connectToServer is blocking .subscribeOn(Schedulers.boundedElastic()); // connectToServer is blocking
.then(completionSink.asMono());
} }
private StandardWebSocketHandlerAdapter createEndpoint(URI url, WebSocketHandler handler, private StandardWebSocketHandlerAdapter createEndpoint(URI url, WebSocketHandler handler,

View File

@ -42,6 +42,7 @@ import org.springframework.lang.Nullable;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.HandshakeInfo; import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler; import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.adapter.ContextWebSocketHandler;
import org.springframework.web.reactive.socket.adapter.UndertowWebSocketHandlerAdapter; import org.springframework.web.reactive.socket.adapter.UndertowWebSocketHandlerAdapter;
import org.springframework.web.reactive.socket.adapter.UndertowWebSocketSession; import org.springframework.web.reactive.socket.adapter.UndertowWebSocketSession;
@ -154,9 +155,9 @@ public class UndertowWebSocketClient implements WebSocketClient {
} }
private Mono<Void> executeInternal(URI url, HttpHeaders headers, WebSocketHandler handler) { private Mono<Void> executeInternal(URI url, HttpHeaders headers, WebSocketHandler handler) {
Sinks.Empty<Void> completionSink = Sinks.empty(); Sinks.Empty<Void> completion = Sinks.empty();
return Mono.fromCallable( return Mono.deferContextual(
() -> { contextView -> {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("Connecting to " + url); logger.debug("Connecting to " + url);
} }
@ -164,21 +165,22 @@ public class UndertowWebSocketClient implements WebSocketClient {
ConnectionBuilder builder = createConnectionBuilder(url); ConnectionBuilder builder = createConnectionBuilder(url);
DefaultNegotiation negotiation = new DefaultNegotiation(protocols, headers, builder); DefaultNegotiation negotiation = new DefaultNegotiation(protocols, headers, builder);
builder.setClientNegotiation(negotiation); builder.setClientNegotiation(negotiation);
return builder.connect().addNotifier( builder.connect().addNotifier(
new IoFuture.HandlingNotifier<WebSocketChannel, Object>() { new IoFuture.HandlingNotifier<WebSocketChannel, Object>() {
@Override @Override
public void handleDone(WebSocketChannel channel, Object attachment) { public void handleDone(WebSocketChannel channel, Object attachment) {
handleChannel(url, handler, completionSink, negotiation, channel); handleChannel(url, ContextWebSocketHandler.decorate(handler, contextView),
completion, negotiation, channel);
} }
@Override @Override
public void handleFailed(IOException ex, Object attachment) { public void handleFailed(IOException ex, Object attachment) {
// Ignore result: can't overflow, ok if not first or no one listens // Ignore result: can't overflow, ok if not first or no one listens
completionSink.tryEmitError( completion.tryEmitError(
new IllegalStateException("Failed to connect to " + url, ex)); new IllegalStateException("Failed to connect to " + url, ex));
} }
}, null); }, null);
}) return completion.asMono();
.then(completionSink.asMono()); });
} }
/** /**

View File

@ -16,6 +16,7 @@
package org.springframework.web.reactive.socket.server.upgrade; package org.springframework.web.reactive.socket.server.upgrade;
import java.io.IOException;
import java.util.function.Supplier; import java.util.function.Supplier;
import javax.servlet.ServletContext; import javax.servlet.ServletContext;
@ -39,6 +40,7 @@ import org.springframework.lang.Nullable;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.HandshakeInfo; import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler; import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.adapter.ContextWebSocketHandler;
import org.springframework.web.reactive.socket.adapter.JettyWebSocketHandlerAdapter; import org.springframework.web.reactive.socket.adapter.JettyWebSocketHandlerAdapter;
import org.springframework.web.reactive.socket.adapter.JettyWebSocketSession; import org.springframework.web.reactive.socket.adapter.JettyWebSocketSession;
import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy; import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy;
@ -152,9 +154,6 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life
HandshakeInfo handshakeInfo = handshakeInfoFactory.get(); HandshakeInfo handshakeInfo = handshakeInfoFactory.get();
DataBufferFactory factory = response.bufferFactory(); DataBufferFactory factory = response.bufferFactory();
JettyWebSocketHandlerAdapter adapter = new JettyWebSocketHandlerAdapter(
handler, session -> new JettyWebSocketSession(session, handshakeInfo, factory));
startLazily(servletRequest); startLazily(servletRequest);
Assert.state(this.factory != null, "No WebSocketServerFactory available"); Assert.state(this.factory != null, "No WebSocketServerFactory available");
@ -163,15 +162,22 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life
// Trigger WebFlux preCommit actions and upgrade // Trigger WebFlux preCommit actions and upgrade
return exchange.getResponse().setComplete() return exchange.getResponse().setComplete()
.then(Mono.fromCallable(() -> { .then(Mono.deferContextual(contextView -> {
JettyWebSocketHandlerAdapter adapter = new JettyWebSocketHandlerAdapter(
ContextWebSocketHandler.decorate(handler, contextView),
session -> new JettyWebSocketSession(session, handshakeInfo, factory));
try { try {
adapterHolder.set(new WebSocketHandlerContainer(adapter, subProtocol)); adapterHolder.set(new WebSocketHandlerContainer(adapter, subProtocol));
this.factory.acceptWebSocket(servletRequest, servletResponse); this.factory.acceptWebSocket(servletRequest, servletResponse);
} }
catch (IOException ex) {
return Mono.error(ex);
}
finally { finally {
adapterHolder.remove(); adapterHolder.remove();
} }
return null; return Mono.empty();
})); }));
} }

View File

@ -38,6 +38,7 @@ import org.springframework.lang.Nullable;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.HandshakeInfo; import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler; import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.adapter.ContextWebSocketHandler;
import org.springframework.web.reactive.socket.adapter.StandardWebSocketHandlerAdapter; import org.springframework.web.reactive.socket.adapter.StandardWebSocketHandlerAdapter;
import org.springframework.web.reactive.socket.adapter.TomcatWebSocketSession; import org.springframework.web.reactive.socket.adapter.TomcatWebSocketSession;
import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy; import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy;
@ -137,20 +138,26 @@ public class TomcatRequestUpgradeStrategy implements RequestUpgradeStrategy {
HandshakeInfo handshakeInfo = handshakeInfoFactory.get(); HandshakeInfo handshakeInfo = handshakeInfoFactory.get();
DataBufferFactory bufferFactory = response.bufferFactory(); DataBufferFactory bufferFactory = response.bufferFactory();
Endpoint endpoint = new StandardWebSocketHandlerAdapter(
handler, session -> new TomcatWebSocketSession(session, handshakeInfo, bufferFactory));
String requestURI = servletRequest.getRequestURI();
DefaultServerEndpointConfig config = new DefaultServerEndpointConfig(requestURI, endpoint);
config.setSubprotocols(subProtocol != null ?
Collections.singletonList(subProtocol) : Collections.emptyList());
// Trigger WebFlux preCommit actions and upgrade // Trigger WebFlux preCommit actions and upgrade
return exchange.getResponse().setComplete() return exchange.getResponse().setComplete()
.then(Mono.fromCallable(() -> { .then(Mono.deferContextual(contextView -> {
Endpoint endpoint = new StandardWebSocketHandlerAdapter(
ContextWebSocketHandler.decorate(handler, contextView),
session -> new TomcatWebSocketSession(session, handshakeInfo, bufferFactory));
String requestURI = servletRequest.getRequestURI();
DefaultServerEndpointConfig config = new DefaultServerEndpointConfig(requestURI, endpoint);
config.setSubprotocols(subProtocol != null ?
Collections.singletonList(subProtocol) : Collections.emptyList());
WsServerContainer container = getContainer(servletRequest); WsServerContainer container = getContainer(servletRequest);
container.doUpgrade(servletRequest, servletResponse, config, Collections.emptyMap()); try {
return null; container.doUpgrade(servletRequest, servletResponse, config, Collections.emptyMap());
}
catch (Exception ex) {
return Mono.error(ex);
}
return Mono.empty();
})); }));
} }

View File

@ -37,6 +37,7 @@ import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import org.springframework.web.reactive.socket.HandshakeInfo; import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler; import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.adapter.ContextWebSocketHandler;
import org.springframework.web.reactive.socket.adapter.UndertowWebSocketHandlerAdapter; import org.springframework.web.reactive.socket.adapter.UndertowWebSocketHandlerAdapter;
import org.springframework.web.reactive.socket.adapter.UndertowWebSocketSession; import org.springframework.web.reactive.socket.adapter.UndertowWebSocketSession;
import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy; import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy;
@ -67,10 +68,18 @@ public class UndertowRequestUpgradeStrategy implements RequestUpgradeStrategy {
// Trigger WebFlux preCommit actions and upgrade // Trigger WebFlux preCommit actions and upgrade
return exchange.getResponse().setComplete() return exchange.getResponse().setComplete()
.then(Mono.fromCallable(() -> { .then(Mono.deferContextual(contextView -> {
DefaultCallback callback = new DefaultCallback(handshakeInfo, handler, bufferFactory); DefaultCallback callback = new DefaultCallback(
new WebSocketProtocolHandshakeHandler(handshakes, callback).handleRequest(httpExchange); handshakeInfo,
return null; ContextWebSocketHandler.decorate(handler, contextView),
bufferFactory);
try {
new WebSocketProtocolHandshakeHandler(handshakes, callback).handleRequest(httpExchange);
}
catch (Exception ex) {
return Mono.error(ex);
}
return Mono.empty();
})); }));
} }

View File

@ -43,6 +43,7 @@ import org.springframework.context.annotation.AnnotationConfigApplicationContext
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.http.server.reactive.HttpHandler; import org.springframework.http.server.reactive.HttpHandler;
import org.springframework.web.filter.reactive.ServerWebExchangeContextFilter;
import org.springframework.web.reactive.DispatcherHandler; import org.springframework.web.reactive.DispatcherHandler;
import org.springframework.web.reactive.socket.client.JettyWebSocketClient; import org.springframework.web.reactive.socket.client.JettyWebSocketClient;
import org.springframework.web.reactive.socket.client.ReactorNettyWebSocketClient; import org.springframework.web.reactive.socket.client.ReactorNettyWebSocketClient;
@ -57,6 +58,7 @@ import org.springframework.web.reactive.socket.server.upgrade.JettyRequestUpgrad
import org.springframework.web.reactive.socket.server.upgrade.ReactorNettyRequestUpgradeStrategy; import org.springframework.web.reactive.socket.server.upgrade.ReactorNettyRequestUpgradeStrategy;
import org.springframework.web.reactive.socket.server.upgrade.TomcatRequestUpgradeStrategy; import org.springframework.web.reactive.socket.server.upgrade.TomcatRequestUpgradeStrategy;
import org.springframework.web.reactive.socket.server.upgrade.UndertowRequestUpgradeStrategy; import org.springframework.web.reactive.socket.server.upgrade.UndertowRequestUpgradeStrategy;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.adapter.WebHttpHandlerBuilder; import org.springframework.web.server.adapter.WebHttpHandlerBuilder;
import org.springframework.web.testfixture.http.server.reactive.bootstrap.HttpServer; import org.springframework.web.testfixture.http.server.reactive.bootstrap.HttpServer;
import org.springframework.web.testfixture.http.server.reactive.bootstrap.JettyHttpServer; import org.springframework.web.testfixture.http.server.reactive.bootstrap.JettyHttpServer;
@ -165,6 +167,11 @@ abstract class AbstractWebSocketIntegrationTests {
@Configuration @Configuration
static class DispatcherConfig { static class DispatcherConfig {
@Bean
public WebFilter contextFilter() {
return new ServerWebExchangeContextFilter();
}
@Bean @Bean
public DispatcherHandler webHandler() { public DispatcherHandler webHandler() {
return new DispatcherHandler(); return new DispatcherHandler();

View File

@ -33,6 +33,7 @@ import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.ResponseCookie; import org.springframework.http.ResponseCookie;
import org.springframework.web.filter.reactive.ServerWebExchangeContextFilter;
import org.springframework.web.reactive.HandlerMapping; import org.springframework.web.reactive.HandlerMapping;
import org.springframework.web.reactive.handler.SimpleUrlHandlerMapping; import org.springframework.web.reactive.handler.SimpleUrlHandlerMapping;
import org.springframework.web.reactive.socket.client.WebSocketClient; import org.springframework.web.reactive.socket.client.WebSocketClient;
@ -216,8 +217,11 @@ class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
@Override @Override
public Mono<Void> handle(WebSocketSession session) { public Mono<Void> handle(WebSocketSession session) {
// Use retain() for Reactor Netty return Mono.deferContextual(contextView -> {
return session.send(session.receive().doOnNext(WebSocketMessage::retain)); String key = ServerWebExchangeContextFilter.EXCHANGE_CONTEXT_ATTRIBUTE;
assertThat(contextView.getOrEmpty(key).orElse(null)).isNotNull();
return session.send(session.receive().doOnNext(WebSocketMessage::retain));
});
} }
} }