diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/upgrade/ReactorNettyRequestUpgradeStrategy.java b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/upgrade/ReactorNettyRequestUpgradeStrategy.java index f7d6575b10..f29c4f6356 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/upgrade/ReactorNettyRequestUpgradeStrategy.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/upgrade/ReactorNettyRequestUpgradeStrategy.java @@ -103,15 +103,15 @@ public class ReactorNettyRequestUpgradeStrategy implements RequestUpgradeStrateg HttpServerResponse reactorResponse = getNativeResponse(response); HandshakeInfo handshakeInfo = handshakeInfoFactory.get(); NettyDataBufferFactory bufferFactory = (NettyDataBufferFactory) response.bufferFactory(); - - return reactorResponse.sendWebsocket(subProtocol, this.maxFramePayloadLength, this.handlePing, - (in, out) -> { - ReactorNettyWebSocketSession session = - new ReactorNettyWebSocketSession( - in, out, handshakeInfo, bufferFactory, this.maxFramePayloadLength); - URI uri = exchange.getRequest().getURI(); - return handler.handle(session).checkpoint(uri + " [ReactorNettyRequestUpgradeStrategy]"); - }); + return response.setComplete() + .then(Mono.defer(() -> reactorResponse.sendWebsocket(subProtocol, this.maxFramePayloadLength, this.handlePing, + (in, out) -> { + ReactorNettyWebSocketSession session = + new ReactorNettyWebSocketSession( + in, out, handshakeInfo, bufferFactory, this.maxFramePayloadLength); + URI uri = exchange.getRequest().getURI(); + return handler.handle(session).checkpoint(uri + " [ReactorNettyRequestUpgradeStrategy]"); + }))); } private static HttpServerResponse getNativeResponse(ServerHttpResponse response) { diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/upgrade/UndertowRequestUpgradeStrategy.java b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/upgrade/UndertowRequestUpgradeStrategy.java index 30c06ed4f4..8307c8eed0 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/upgrade/UndertowRequestUpgradeStrategy.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/upgrade/UndertowRequestUpgradeStrategy.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -43,10 +43,11 @@ import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy; import org.springframework.web.server.ServerWebExchange; /** -* A {@link RequestUpgradeStrategy} for use with Undertow. - * + * A {@link RequestUpgradeStrategy} for use with Undertow. + * * @author Violeta Georgieva * @author Rossen Stoyanchev + * @author Brian Clozel * @since 5.0 */ public class UndertowRequestUpgradeStrategy implements RequestUpgradeStrategy { @@ -63,16 +64,12 @@ public class UndertowRequestUpgradeStrategy implements RequestUpgradeStrategy { HandshakeInfo handshakeInfo = handshakeInfoFactory.get(); DataBufferFactory bufferFactory = exchange.getResponse().bufferFactory(); - - try { - DefaultCallback callback = new DefaultCallback(handshakeInfo, handler, bufferFactory); - new WebSocketProtocolHandshakeHandler(handshakes, callback).handleRequest(httpExchange); - } - catch (Exception ex) { - return Mono.error(ex); - } - - return Mono.empty(); + return exchange.getResponse().setComplete() + .then(Mono.fromCallable(() -> { + DefaultCallback callback = new DefaultCallback(handshakeInfo, handler, bufferFactory); + new WebSocketProtocolHandshakeHandler(handshakes, callback).handleRequest(httpExchange); + return null; + })); } private static HttpServerExchange getNativeRequest(ServerHttpRequest request) { diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/socket/WebSocketIntegrationTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/socket/WebSocketIntegrationTests.java index 0e970b4965..4ac544d83e 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/socket/WebSocketIntegrationTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/socket/WebSocketIntegrationTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -33,9 +33,11 @@ import reactor.core.publisher.ReplayProcessor; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.http.HttpHeaders; +import org.springframework.http.ResponseCookie; import org.springframework.web.reactive.HandlerMapping; import org.springframework.web.reactive.handler.SimpleUrlHandlerMapping; import org.springframework.web.reactive.socket.client.WebSocketClient; +import org.springframework.web.server.WebFilter; import org.springframework.web.testfixture.http.server.reactive.bootstrap.HttpServer; import static org.assertj.core.api.Assertions.assertThat; @@ -45,6 +47,7 @@ import static org.assertj.core.api.Assertions.assertThat; * * @author Rossen Stoyanchev * @author Sam Brannen + * @author Brian Clozel */ class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { @@ -91,6 +94,7 @@ class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { public List getSubProtocols() { return Collections.singletonList(protocol); } + @Override public Mono handle(WebSocketSession session) { infoRef.set(session.getHandshakeInfo()); @@ -138,12 +142,31 @@ class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { .doOnNext(s -> logger.debug("inbound " + s)) .then() .doFinally(signalType -> - logger.debug("Completed with: " + signalType) + logger.debug("Completed with: " + signalType) ); }) .block(TIMEOUT); } + @ParameterizedWebSocketTest + void cookie(WebSocketClient client, HttpServer server, Class serverConfigClass) throws Exception { + startServer(client, server, serverConfigClass); + + MonoProcessor output = MonoProcessor.create(); + AtomicReference cookie = new AtomicReference<>(); + this.client.execute(getUrl("/cookie"), + session -> { + cookie.set(session.getHandshakeInfo().getHeaders().getFirst("Set-Cookie")); + return session.receive() + .map(WebSocketMessage::getPayloadAsText) + .subscribeWith(output) + .then(); + }) + .block(TIMEOUT); + assertThat(output.block(TIMEOUT)).isEqualTo("cookie"); + assertThat(cookie.get()).isEqualTo("project=spring"); + } + @Configuration static class WebConfig { @@ -155,8 +178,19 @@ class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { map.put("/sub-protocol", new SubProtocolWebSocketHandler()); map.put("/custom-header", new CustomHeaderHandler()); map.put("/close", new SessionClosingHandler()); + map.put("/cookie", new CookieHandler()); return new SimpleUrlHandlerMapping(map); } + + @Bean + public WebFilter cookieWebFilter() { + return (exchange, chain) -> { + if (exchange.getRequest().getPath().value().startsWith("/cookie")) { + exchange.getResponse().addCookie(ResponseCookie.from("project", "spring").build()); + } + return chain.filter(exchange); + }; + } } @@ -209,4 +243,13 @@ class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { } } + private static class CookieHandler implements WebSocketHandler { + + @Override + public Mono handle(WebSocketSession session) { + WebSocketMessage message = session.textMessage("cookie"); + return session.send(Mono.just(message)); + } + } + }