From 13f23dc32ba0f17cc6d26e068c09e0b2578c6ffc Mon Sep 17 00:00:00 2001 From: Brian Clozel Date: Thu, 13 Feb 2020 21:34:37 +0100 Subject: [PATCH] Mark response as complete before WebSocket upgrade Prior to this commit, some WebSocket `RequestUpgradeStrategy` reactive implementations would prevent the application from writing HTTP headers and cookies to the response. For Reactor Netty and Undertow, handling the upgrade and starting the WebSocket communication marks the response status and headers as sent and the application cannot update HTTP response headers after that. This commit ensures that the `RequestUpgradeStrategy` implementations mark the responses as "complete", so that headers are written before we delegate to the server implementation. Fixes gh-24475 --- .../ReactorNettyRequestUpgradeStrategy.java | 18 +++---- .../UndertowRequestUpgradeStrategy.java | 23 ++++----- .../socket/WebSocketIntegrationTests.java | 47 ++++++++++++++++++- 3 files changed, 64 insertions(+), 24 deletions(-) diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/upgrade/ReactorNettyRequestUpgradeStrategy.java b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/upgrade/ReactorNettyRequestUpgradeStrategy.java index f7d6575b10..f29c4f6356 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/upgrade/ReactorNettyRequestUpgradeStrategy.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/upgrade/ReactorNettyRequestUpgradeStrategy.java @@ -103,15 +103,15 @@ public class ReactorNettyRequestUpgradeStrategy implements RequestUpgradeStrateg HttpServerResponse reactorResponse = getNativeResponse(response); HandshakeInfo handshakeInfo = handshakeInfoFactory.get(); NettyDataBufferFactory bufferFactory = (NettyDataBufferFactory) response.bufferFactory(); - - return reactorResponse.sendWebsocket(subProtocol, this.maxFramePayloadLength, this.handlePing, - (in, out) -> { - ReactorNettyWebSocketSession session = - new ReactorNettyWebSocketSession( - in, out, handshakeInfo, bufferFactory, this.maxFramePayloadLength); - URI uri = exchange.getRequest().getURI(); - return handler.handle(session).checkpoint(uri + " [ReactorNettyRequestUpgradeStrategy]"); - }); + return response.setComplete() + .then(Mono.defer(() -> reactorResponse.sendWebsocket(subProtocol, this.maxFramePayloadLength, this.handlePing, + (in, out) -> { + ReactorNettyWebSocketSession session = + new ReactorNettyWebSocketSession( + in, out, handshakeInfo, bufferFactory, this.maxFramePayloadLength); + URI uri = exchange.getRequest().getURI(); + return handler.handle(session).checkpoint(uri + " [ReactorNettyRequestUpgradeStrategy]"); + }))); } private static HttpServerResponse getNativeResponse(ServerHttpResponse response) { diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/upgrade/UndertowRequestUpgradeStrategy.java b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/upgrade/UndertowRequestUpgradeStrategy.java index 30c06ed4f4..8307c8eed0 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/upgrade/UndertowRequestUpgradeStrategy.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/upgrade/UndertowRequestUpgradeStrategy.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -43,10 +43,11 @@ import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy; import org.springframework.web.server.ServerWebExchange; /** -* A {@link RequestUpgradeStrategy} for use with Undertow. - * + * A {@link RequestUpgradeStrategy} for use with Undertow. + * * @author Violeta Georgieva * @author Rossen Stoyanchev + * @author Brian Clozel * @since 5.0 */ public class UndertowRequestUpgradeStrategy implements RequestUpgradeStrategy { @@ -63,16 +64,12 @@ public class UndertowRequestUpgradeStrategy implements RequestUpgradeStrategy { HandshakeInfo handshakeInfo = handshakeInfoFactory.get(); DataBufferFactory bufferFactory = exchange.getResponse().bufferFactory(); - - try { - DefaultCallback callback = new DefaultCallback(handshakeInfo, handler, bufferFactory); - new WebSocketProtocolHandshakeHandler(handshakes, callback).handleRequest(httpExchange); - } - catch (Exception ex) { - return Mono.error(ex); - } - - return Mono.empty(); + return exchange.getResponse().setComplete() + .then(Mono.fromCallable(() -> { + DefaultCallback callback = new DefaultCallback(handshakeInfo, handler, bufferFactory); + new WebSocketProtocolHandshakeHandler(handshakes, callback).handleRequest(httpExchange); + return null; + })); } private static HttpServerExchange getNativeRequest(ServerHttpRequest request) { diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/socket/WebSocketIntegrationTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/socket/WebSocketIntegrationTests.java index 0e970b4965..4ac544d83e 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/socket/WebSocketIntegrationTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/socket/WebSocketIntegrationTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -33,9 +33,11 @@ import reactor.core.publisher.ReplayProcessor; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.http.HttpHeaders; +import org.springframework.http.ResponseCookie; import org.springframework.web.reactive.HandlerMapping; import org.springframework.web.reactive.handler.SimpleUrlHandlerMapping; import org.springframework.web.reactive.socket.client.WebSocketClient; +import org.springframework.web.server.WebFilter; import org.springframework.web.testfixture.http.server.reactive.bootstrap.HttpServer; import static org.assertj.core.api.Assertions.assertThat; @@ -45,6 +47,7 @@ import static org.assertj.core.api.Assertions.assertThat; * * @author Rossen Stoyanchev * @author Sam Brannen + * @author Brian Clozel */ class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { @@ -91,6 +94,7 @@ class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { public List getSubProtocols() { return Collections.singletonList(protocol); } + @Override public Mono handle(WebSocketSession session) { infoRef.set(session.getHandshakeInfo()); @@ -138,12 +142,31 @@ class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { .doOnNext(s -> logger.debug("inbound " + s)) .then() .doFinally(signalType -> - logger.debug("Completed with: " + signalType) + logger.debug("Completed with: " + signalType) ); }) .block(TIMEOUT); } + @ParameterizedWebSocketTest + void cookie(WebSocketClient client, HttpServer server, Class serverConfigClass) throws Exception { + startServer(client, server, serverConfigClass); + + MonoProcessor output = MonoProcessor.create(); + AtomicReference cookie = new AtomicReference<>(); + this.client.execute(getUrl("/cookie"), + session -> { + cookie.set(session.getHandshakeInfo().getHeaders().getFirst("Set-Cookie")); + return session.receive() + .map(WebSocketMessage::getPayloadAsText) + .subscribeWith(output) + .then(); + }) + .block(TIMEOUT); + assertThat(output.block(TIMEOUT)).isEqualTo("cookie"); + assertThat(cookie.get()).isEqualTo("project=spring"); + } + @Configuration static class WebConfig { @@ -155,8 +178,19 @@ class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { map.put("/sub-protocol", new SubProtocolWebSocketHandler()); map.put("/custom-header", new CustomHeaderHandler()); map.put("/close", new SessionClosingHandler()); + map.put("/cookie", new CookieHandler()); return new SimpleUrlHandlerMapping(map); } + + @Bean + public WebFilter cookieWebFilter() { + return (exchange, chain) -> { + if (exchange.getRequest().getPath().value().startsWith("/cookie")) { + exchange.getResponse().addCookie(ResponseCookie.from("project", "spring").build()); + } + return chain.filter(exchange); + }; + } } @@ -209,4 +243,13 @@ class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { } } + private static class CookieHandler implements WebSocketHandler { + + @Override + public Mono handle(WebSocketSession session) { + WebSocketMessage message = session.textMessage("cookie"); + return session.send(Mono.just(message)); + } + } + }