Mark response as complete before WebSocket upgrade

Prior to this commit, some WebSocket `RequestUpgradeStrategy` reactive
implementations would prevent the application from writing HTTP headers
and cookies to the response.

For Reactor Netty and Undertow, handling the upgrade and starting the
WebSocket communication marks the response status and headers as sent
and the application cannot update HTTP response headers after that.

This commit ensures that the `RequestUpgradeStrategy` implementations
mark the responses as "complete", so that headers are written before we
delegate to the server implementation.

Fixes gh-24475
This commit is contained in:
Brian Clozel 2020-02-13 21:34:37 +01:00
parent 4cbc61abfc
commit 13f23dc32b
3 changed files with 64 additions and 24 deletions

View File

@ -103,15 +103,15 @@ public class ReactorNettyRequestUpgradeStrategy implements RequestUpgradeStrateg
HttpServerResponse reactorResponse = getNativeResponse(response);
HandshakeInfo handshakeInfo = handshakeInfoFactory.get();
NettyDataBufferFactory bufferFactory = (NettyDataBufferFactory) response.bufferFactory();
return reactorResponse.sendWebsocket(subProtocol, this.maxFramePayloadLength, this.handlePing,
(in, out) -> {
ReactorNettyWebSocketSession session =
new ReactorNettyWebSocketSession(
in, out, handshakeInfo, bufferFactory, this.maxFramePayloadLength);
URI uri = exchange.getRequest().getURI();
return handler.handle(session).checkpoint(uri + " [ReactorNettyRequestUpgradeStrategy]");
});
return response.setComplete()
.then(Mono.defer(() -> reactorResponse.sendWebsocket(subProtocol, this.maxFramePayloadLength, this.handlePing,
(in, out) -> {
ReactorNettyWebSocketSession session =
new ReactorNettyWebSocketSession(
in, out, handshakeInfo, bufferFactory, this.maxFramePayloadLength);
URI uri = exchange.getRequest().getURI();
return handler.handle(session).checkpoint(uri + " [ReactorNettyRequestUpgradeStrategy]");
})));
}
private static HttpServerResponse getNativeResponse(ServerHttpResponse response) {

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -43,10 +43,11 @@ import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy;
import org.springframework.web.server.ServerWebExchange;
/**
* A {@link RequestUpgradeStrategy} for use with Undertow.
*
* A {@link RequestUpgradeStrategy} for use with Undertow.
*
* @author Violeta Georgieva
* @author Rossen Stoyanchev
* @author Brian Clozel
* @since 5.0
*/
public class UndertowRequestUpgradeStrategy implements RequestUpgradeStrategy {
@ -63,16 +64,12 @@ public class UndertowRequestUpgradeStrategy implements RequestUpgradeStrategy {
HandshakeInfo handshakeInfo = handshakeInfoFactory.get();
DataBufferFactory bufferFactory = exchange.getResponse().bufferFactory();
try {
DefaultCallback callback = new DefaultCallback(handshakeInfo, handler, bufferFactory);
new WebSocketProtocolHandshakeHandler(handshakes, callback).handleRequest(httpExchange);
}
catch (Exception ex) {
return Mono.error(ex);
}
return Mono.empty();
return exchange.getResponse().setComplete()
.then(Mono.fromCallable(() -> {
DefaultCallback callback = new DefaultCallback(handshakeInfo, handler, bufferFactory);
new WebSocketProtocolHandshakeHandler(handshakes, callback).handleRequest(httpExchange);
return null;
}));
}
private static HttpServerExchange getNativeRequest(ServerHttpRequest request) {

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -33,9 +33,11 @@ import reactor.core.publisher.ReplayProcessor;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.HttpHeaders;
import org.springframework.http.ResponseCookie;
import org.springframework.web.reactive.HandlerMapping;
import org.springframework.web.reactive.handler.SimpleUrlHandlerMapping;
import org.springframework.web.reactive.socket.client.WebSocketClient;
import org.springframework.web.server.WebFilter;
import org.springframework.web.testfixture.http.server.reactive.bootstrap.HttpServer;
import static org.assertj.core.api.Assertions.assertThat;
@ -45,6 +47,7 @@ import static org.assertj.core.api.Assertions.assertThat;
*
* @author Rossen Stoyanchev
* @author Sam Brannen
* @author Brian Clozel
*/
class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
@ -91,6 +94,7 @@ class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
public List<String> getSubProtocols() {
return Collections.singletonList(protocol);
}
@Override
public Mono<Void> handle(WebSocketSession session) {
infoRef.set(session.getHandshakeInfo());
@ -138,12 +142,31 @@ class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
.doOnNext(s -> logger.debug("inbound " + s))
.then()
.doFinally(signalType ->
logger.debug("Completed with: " + signalType)
logger.debug("Completed with: " + signalType)
);
})
.block(TIMEOUT);
}
@ParameterizedWebSocketTest
void cookie(WebSocketClient client, HttpServer server, Class<?> serverConfigClass) throws Exception {
startServer(client, server, serverConfigClass);
MonoProcessor<Object> output = MonoProcessor.create();
AtomicReference<String> cookie = new AtomicReference<>();
this.client.execute(getUrl("/cookie"),
session -> {
cookie.set(session.getHandshakeInfo().getHeaders().getFirst("Set-Cookie"));
return session.receive()
.map(WebSocketMessage::getPayloadAsText)
.subscribeWith(output)
.then();
})
.block(TIMEOUT);
assertThat(output.block(TIMEOUT)).isEqualTo("cookie");
assertThat(cookie.get()).isEqualTo("project=spring");
}
@Configuration
static class WebConfig {
@ -155,8 +178,19 @@ class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
map.put("/sub-protocol", new SubProtocolWebSocketHandler());
map.put("/custom-header", new CustomHeaderHandler());
map.put("/close", new SessionClosingHandler());
map.put("/cookie", new CookieHandler());
return new SimpleUrlHandlerMapping(map);
}
@Bean
public WebFilter cookieWebFilter() {
return (exchange, chain) -> {
if (exchange.getRequest().getPath().value().startsWith("/cookie")) {
exchange.getResponse().addCookie(ResponseCookie.from("project", "spring").build());
}
return chain.filter(exchange);
};
}
}
@ -209,4 +243,13 @@ class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
}
}
private static class CookieHandler implements WebSocketHandler {
@Override
public Mono<Void> handle(WebSocketSession session) {
WebSocketMessage message = session.textMessage("cookie");
return session.send(Mono.just(message));
}
}
}