From f477c1653d2d8198bede242f4ea4d4a599727bcc Mon Sep 17 00:00:00 2001 From: rstoyanchev Date: Mon, 3 Feb 2025 15:28:27 +0000 Subject: [PATCH] Allow WebSocket over HTTP CONNECT Closes gh-34044 --- .../server/support/HandshakeWebSocketService.java | 10 +++++++--- .../server/support/AbstractHandshakeHandler.java | 13 +++++++++---- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java index 603d8e6183..c54f38d9bc 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,6 +22,7 @@ import java.security.Principal; import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.function.Predicate; import java.util.stream.Collectors; @@ -66,6 +67,9 @@ import org.springframework.web.server.ServerWebInputException; */ public class HandshakeWebSocketService implements WebSocketService, Lifecycle { + // For WebSocket upgrades in HTTP/2 (see RFC 8441) + private static final HttpMethod CONNECT_METHOD = HttpMethod.valueOf("CONNECT"); + private static final String SEC_WEBSOCKET_KEY = "Sec-WebSocket-Key"; private static final String SEC_WEBSOCKET_PROTOCOL = "Sec-WebSocket-Protocol"; @@ -201,9 +205,9 @@ public class HandshakeWebSocketService implements WebSocketService, Lifecycle { HttpMethod method = request.getMethod(); HttpHeaders headers = request.getHeaders(); - if (HttpMethod.GET != method) { + if (HttpMethod.GET != method && CONNECT_METHOD != method) { return Mono.error(new MethodNotAllowedException( - request.getMethod(), Collections.singleton(HttpMethod.GET))); + request.getMethod(), Set.of(HttpMethod.GET, CONNECT_METHOD))); } if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractHandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractHandshakeHandler.java index a32692c8af..acde43c3cc 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractHandshakeHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractHandshakeHandler.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2024 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,6 +25,7 @@ import java.util.Collections; import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Set; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -77,6 +78,9 @@ import org.springframework.web.socket.server.standard.WebSphereRequestUpgradeStr */ public abstract class AbstractHandshakeHandler implements HandshakeHandler, Lifecycle { + // For WebSocket upgrades in HTTP/2 (see RFC 8441) + private static final HttpMethod CONNECT_METHOD = HttpMethod.valueOf("CONNECT"); + private static final boolean tomcatWsPresent; private static final boolean jettyWsPresent; @@ -210,11 +214,12 @@ public abstract class AbstractHandshakeHandler implements HandshakeHandler, Life logger.trace("Processing request " + request.getURI() + " with headers=" + headers); } try { - if (HttpMethod.GET != request.getMethod()) { + HttpMethod httpMethod = request.getMethod(); + if (HttpMethod.GET != httpMethod && CONNECT_METHOD != httpMethod) { response.setStatusCode(HttpStatus.METHOD_NOT_ALLOWED); - response.getHeaders().setAllow(Collections.singleton(HttpMethod.GET)); + response.getHeaders().setAllow(Set.of(HttpMethod.GET, CONNECT_METHOD)); if (logger.isErrorEnabled()) { - logger.error("Handshake failed due to unexpected HTTP method: " + request.getMethod()); + logger.error("Handshake failed due to unexpected HTTP method: " + httpMethod); } return false; }