Enrich WebSocketHandler context

Closes gh-26210
This commit is contained in:
Rossen Stoyanchev 2020-12-15 21:33:50 +00:00
parent 83c19cd60e
commit a11d1c8510
9 changed files with 162 additions and 49 deletions

View File

@ -0,0 +1,64 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.reactive.socket.adapter;
import java.util.List;
import reactor.core.publisher.Mono;
import reactor.util.context.ContextView;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketSession;
/**
* {@link WebSocketHandler} decorator that enriches the context of the target handler.
*
* @author Rossen Stoyanchev
* @since 5.3.3
*/
public final class ContextWebSocketHandler implements WebSocketHandler {
private final WebSocketHandler delegate;
private final ContextView contextView;
private ContextWebSocketHandler(WebSocketHandler delegate, ContextView contextView) {
this.delegate = delegate;
this.contextView = contextView;
}
@Override
public List<String> getSubProtocols() {
return this.delegate.getSubProtocols();
}
@Override
public Mono<Void> handle(WebSocketSession session) {
return this.delegate.handle(session).contextWrite(this.contextView);
}
/**
* Return the given handler, decorated to insert the given context, or the
* same handler instance when the context is empty.
*/
public static WebSocketHandler decorate(WebSocketHandler handler, ContextView contextView) {
return (!contextView.isEmpty() ? new ContextWebSocketHandler(handler, contextView) : handler);
}
}

View File

@ -16,6 +16,7 @@
package org.springframework.web.reactive.socket.client;
import java.io.IOException;
import java.net.URI;
import org.apache.commons.logging.Log;
@ -33,6 +34,7 @@ import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.adapter.ContextWebSocketHandler;
import org.springframework.web.reactive.socket.adapter.JettyWebSocketHandlerAdapter;
import org.springframework.web.reactive.socket.adapter.JettyWebSocketSession;
@ -137,18 +139,23 @@ public class JettyWebSocketClient implements WebSocketClient, Lifecycle {
private Mono<Void> executeInternal(URI url, HttpHeaders headers, WebSocketHandler handler) {
Sinks.Empty<Void> completionSink = Sinks.empty();
return Mono.fromCallable(
() -> {
if (logger.isDebugEnabled()) {
logger.debug("Connecting to " + url);
}
Object jettyHandler = createHandler(url, handler, completionSink);
ClientUpgradeRequest request = new ClientUpgradeRequest();
request.setSubProtocols(handler.getSubProtocols());
UpgradeListener upgradeListener = new DefaultUpgradeListener(headers);
return this.jettyClient.connect(jettyHandler, url, request, upgradeListener);
})
.then(completionSink.asMono());
return Mono.deferContextual(contextView -> {
if (logger.isDebugEnabled()) {
logger.debug("Connecting to " + url);
}
Object jettyHandler = createHandler(
url, ContextWebSocketHandler.decorate(handler, contextView), completionSink);
ClientUpgradeRequest request = new ClientUpgradeRequest();
request.setSubProtocols(handler.getSubProtocols());
UpgradeListener upgradeListener = new DefaultUpgradeListener(headers);
try {
this.jettyClient.connect(jettyHandler, url, request, upgradeListener);
return completionSink.asMono();
}
catch (IOException ex) {
return Mono.error(ex);
}
});
}
private Object createHandler(URI url, WebSocketHandler handler, Sinks.Empty<Void> completion) {

View File

@ -39,6 +39,7 @@ import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.adapter.ContextWebSocketHandler;
import org.springframework.web.reactive.socket.adapter.StandardWebSocketHandlerAdapter;
import org.springframework.web.reactive.socket.adapter.StandardWebSocketSession;
@ -95,20 +96,26 @@ public class StandardWebSocketClient implements WebSocketClient {
}
private Mono<Void> executeInternal(URI url, HttpHeaders requestHeaders, WebSocketHandler handler) {
Sinks.Empty<Void> completionSink = Sinks.empty();
return Mono.fromCallable(
() -> {
Sinks.Empty<Void> completion = Sinks.empty();
return Mono.deferContextual(
contextView -> {
if (logger.isDebugEnabled()) {
logger.debug("Connecting to " + url);
}
List<String> protocols = handler.getSubProtocols();
DefaultConfigurator configurator = new DefaultConfigurator(requestHeaders);
Endpoint endpoint = createEndpoint(url, handler, completionSink, configurator);
Endpoint endpoint = createEndpoint(
url, ContextWebSocketHandler.decorate(handler, contextView), completion, configurator);
ClientEndpointConfig config = createEndpointConfig(configurator, protocols);
return this.webSocketContainer.connectToServer(endpoint, config, url);
try {
this.webSocketContainer.connectToServer(endpoint, config, url);
return completion.asMono();
}
catch (Exception ex) {
return Mono.error(ex);
}
})
.subscribeOn(Schedulers.boundedElastic()) // connectToServer is blocking
.then(completionSink.asMono());
.subscribeOn(Schedulers.boundedElastic()); // connectToServer is blocking
}
private StandardWebSocketHandlerAdapter createEndpoint(URI url, WebSocketHandler handler,

View File

@ -42,6 +42,7 @@ import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.adapter.ContextWebSocketHandler;
import org.springframework.web.reactive.socket.adapter.UndertowWebSocketHandlerAdapter;
import org.springframework.web.reactive.socket.adapter.UndertowWebSocketSession;
@ -154,9 +155,9 @@ public class UndertowWebSocketClient implements WebSocketClient {
}
private Mono<Void> executeInternal(URI url, HttpHeaders headers, WebSocketHandler handler) {
Sinks.Empty<Void> completionSink = Sinks.empty();
return Mono.fromCallable(
() -> {
Sinks.Empty<Void> completion = Sinks.empty();
return Mono.deferContextual(
contextView -> {
if (logger.isDebugEnabled()) {
logger.debug("Connecting to " + url);
}
@ -164,21 +165,22 @@ public class UndertowWebSocketClient implements WebSocketClient {
ConnectionBuilder builder = createConnectionBuilder(url);
DefaultNegotiation negotiation = new DefaultNegotiation(protocols, headers, builder);
builder.setClientNegotiation(negotiation);
return builder.connect().addNotifier(
builder.connect().addNotifier(
new IoFuture.HandlingNotifier<WebSocketChannel, Object>() {
@Override
public void handleDone(WebSocketChannel channel, Object attachment) {
handleChannel(url, handler, completionSink, negotiation, channel);
handleChannel(url, ContextWebSocketHandler.decorate(handler, contextView),
completion, negotiation, channel);
}
@Override
public void handleFailed(IOException ex, Object attachment) {
// Ignore result: can't overflow, ok if not first or no one listens
completionSink.tryEmitError(
completion.tryEmitError(
new IllegalStateException("Failed to connect to " + url, ex));
}
}, null);
})
.then(completionSink.asMono());
return completion.asMono();
});
}
/**

View File

@ -16,6 +16,7 @@
package org.springframework.web.reactive.socket.server.upgrade;
import java.io.IOException;
import java.util.function.Supplier;
import javax.servlet.ServletContext;
@ -39,6 +40,7 @@ import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.adapter.ContextWebSocketHandler;
import org.springframework.web.reactive.socket.adapter.JettyWebSocketHandlerAdapter;
import org.springframework.web.reactive.socket.adapter.JettyWebSocketSession;
import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy;
@ -152,9 +154,6 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life
HandshakeInfo handshakeInfo = handshakeInfoFactory.get();
DataBufferFactory factory = response.bufferFactory();
JettyWebSocketHandlerAdapter adapter = new JettyWebSocketHandlerAdapter(
handler, session -> new JettyWebSocketSession(session, handshakeInfo, factory));
startLazily(servletRequest);
Assert.state(this.factory != null, "No WebSocketServerFactory available");
@ -163,15 +162,22 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life
// Trigger WebFlux preCommit actions and upgrade
return exchange.getResponse().setComplete()
.then(Mono.fromCallable(() -> {
.then(Mono.deferContextual(contextView -> {
JettyWebSocketHandlerAdapter adapter = new JettyWebSocketHandlerAdapter(
ContextWebSocketHandler.decorate(handler, contextView),
session -> new JettyWebSocketSession(session, handshakeInfo, factory));
try {
adapterHolder.set(new WebSocketHandlerContainer(adapter, subProtocol));
this.factory.acceptWebSocket(servletRequest, servletResponse);
}
catch (IOException ex) {
return Mono.error(ex);
}
finally {
adapterHolder.remove();
}
return null;
return Mono.empty();
}));
}

View File

@ -38,6 +38,7 @@ import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.adapter.ContextWebSocketHandler;
import org.springframework.web.reactive.socket.adapter.StandardWebSocketHandlerAdapter;
import org.springframework.web.reactive.socket.adapter.TomcatWebSocketSession;
import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy;
@ -137,20 +138,26 @@ public class TomcatRequestUpgradeStrategy implements RequestUpgradeStrategy {
HandshakeInfo handshakeInfo = handshakeInfoFactory.get();
DataBufferFactory bufferFactory = response.bufferFactory();
Endpoint endpoint = new StandardWebSocketHandlerAdapter(
handler, session -> new TomcatWebSocketSession(session, handshakeInfo, bufferFactory));
String requestURI = servletRequest.getRequestURI();
DefaultServerEndpointConfig config = new DefaultServerEndpointConfig(requestURI, endpoint);
config.setSubprotocols(subProtocol != null ?
Collections.singletonList(subProtocol) : Collections.emptyList());
// Trigger WebFlux preCommit actions and upgrade
return exchange.getResponse().setComplete()
.then(Mono.fromCallable(() -> {
.then(Mono.deferContextual(contextView -> {
Endpoint endpoint = new StandardWebSocketHandlerAdapter(
ContextWebSocketHandler.decorate(handler, contextView),
session -> new TomcatWebSocketSession(session, handshakeInfo, bufferFactory));
String requestURI = servletRequest.getRequestURI();
DefaultServerEndpointConfig config = new DefaultServerEndpointConfig(requestURI, endpoint);
config.setSubprotocols(subProtocol != null ?
Collections.singletonList(subProtocol) : Collections.emptyList());
WsServerContainer container = getContainer(servletRequest);
container.doUpgrade(servletRequest, servletResponse, config, Collections.emptyMap());
return null;
try {
container.doUpgrade(servletRequest, servletResponse, config, Collections.emptyMap());
}
catch (Exception ex) {
return Mono.error(ex);
}
return Mono.empty();
}));
}

View File

@ -37,6 +37,7 @@ import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.lang.Nullable;
import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.adapter.ContextWebSocketHandler;
import org.springframework.web.reactive.socket.adapter.UndertowWebSocketHandlerAdapter;
import org.springframework.web.reactive.socket.adapter.UndertowWebSocketSession;
import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy;
@ -67,10 +68,18 @@ public class UndertowRequestUpgradeStrategy implements RequestUpgradeStrategy {
// Trigger WebFlux preCommit actions and upgrade
return exchange.getResponse().setComplete()
.then(Mono.fromCallable(() -> {
DefaultCallback callback = new DefaultCallback(handshakeInfo, handler, bufferFactory);
new WebSocketProtocolHandshakeHandler(handshakes, callback).handleRequest(httpExchange);
return null;
.then(Mono.deferContextual(contextView -> {
DefaultCallback callback = new DefaultCallback(
handshakeInfo,
ContextWebSocketHandler.decorate(handler, contextView),
bufferFactory);
try {
new WebSocketProtocolHandshakeHandler(handshakes, callback).handleRequest(httpExchange);
}
catch (Exception ex) {
return Mono.error(ex);
}
return Mono.empty();
}));
}

View File

@ -43,6 +43,7 @@ import org.springframework.context.annotation.AnnotationConfigApplicationContext
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.server.reactive.HttpHandler;
import org.springframework.web.filter.reactive.ServerWebExchangeContextFilter;
import org.springframework.web.reactive.DispatcherHandler;
import org.springframework.web.reactive.socket.client.JettyWebSocketClient;
import org.springframework.web.reactive.socket.client.ReactorNettyWebSocketClient;
@ -57,6 +58,7 @@ import org.springframework.web.reactive.socket.server.upgrade.JettyRequestUpgrad
import org.springframework.web.reactive.socket.server.upgrade.ReactorNettyRequestUpgradeStrategy;
import org.springframework.web.reactive.socket.server.upgrade.TomcatRequestUpgradeStrategy;
import org.springframework.web.reactive.socket.server.upgrade.UndertowRequestUpgradeStrategy;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.adapter.WebHttpHandlerBuilder;
import org.springframework.web.testfixture.http.server.reactive.bootstrap.HttpServer;
import org.springframework.web.testfixture.http.server.reactive.bootstrap.JettyHttpServer;
@ -165,6 +167,11 @@ abstract class AbstractWebSocketIntegrationTests {
@Configuration
static class DispatcherConfig {
@Bean
public WebFilter contextFilter() {
return new ServerWebExchangeContextFilter();
}
@Bean
public DispatcherHandler webHandler() {
return new DispatcherHandler();

View File

@ -33,6 +33,7 @@ import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.HttpHeaders;
import org.springframework.http.ResponseCookie;
import org.springframework.web.filter.reactive.ServerWebExchangeContextFilter;
import org.springframework.web.reactive.HandlerMapping;
import org.springframework.web.reactive.handler.SimpleUrlHandlerMapping;
import org.springframework.web.reactive.socket.client.WebSocketClient;
@ -216,8 +217,11 @@ class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
@Override
public Mono<Void> handle(WebSocketSession session) {
// Use retain() for Reactor Netty
return session.send(session.receive().doOnNext(WebSocketMessage::retain));
return Mono.deferContextual(contextView -> {
String key = ServerWebExchangeContextFilter.EXCHANGE_CONTEXT_ATTRIBUTE;
assertThat(contextView.getOrEmpty(key).orElse(null)).isNotNull();
return session.send(session.receive().doOnNext(WebSocketMessage::retain));
});
}
}