Support RFC 8441 upgrades over HTTP/2 CONNECT

See gh-34362

Signed-off-by: Jared Wiltshire <jazdw@users.noreply.github.com>
This commit is contained in:
Jared Wiltshire 2025-02-03 16:45:47 -07:00 committed by rstoyanchev
parent d59991fcc9
commit 49f9b40fba
3 changed files with 32 additions and 26 deletions

View File

@ -205,23 +205,25 @@ public class HandshakeWebSocketService implements WebSocketService, Lifecycle {
HttpMethod method = request.getMethod(); HttpMethod method = request.getMethod();
HttpHeaders headers = request.getHeaders(); HttpHeaders headers = request.getHeaders();
if (HttpMethod.GET != method && CONNECT_METHOD != method) { if (HttpMethod.GET != method && !CONNECT_METHOD.equals(method)) {
return Mono.error(new MethodNotAllowedException( return Mono.error(new MethodNotAllowedException(
request.getMethod(), Set.of(HttpMethod.GET, CONNECT_METHOD))); request.getMethod(), Set.of(HttpMethod.GET, CONNECT_METHOD)));
} }
if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) { if (HttpMethod.GET == method) {
return handleBadRequest(exchange, "Invalid 'Upgrade' header: " + headers); if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) {
} return handleBadRequest(exchange, "Invalid 'Upgrade' header: " + headers);
}
List<String> connectionValue = headers.getConnection(); List<String> connectionValue = headers.getConnection();
if (!connectionValue.contains("Upgrade") && !connectionValue.contains("upgrade")) { if (!connectionValue.contains("Upgrade") && !connectionValue.contains("upgrade")) {
return handleBadRequest(exchange, "Invalid 'Connection' header: " + headers); return handleBadRequest(exchange, "Invalid 'Connection' header: " + headers);
} }
String key = headers.getFirst(SEC_WEBSOCKET_KEY); String key = headers.getFirst(SEC_WEBSOCKET_KEY);
if (key == null) { if (key == null) {
return handleBadRequest(exchange, "Missing \"Sec-WebSocket-Key\" header"); return handleBadRequest(exchange, "Missing \"Sec-WebSocket-Key\" header");
}
} }
String protocol = selectProtocol(headers, handler); String protocol = selectProtocol(headers, handler);

View File

@ -151,7 +151,7 @@ public class WebSocketHttpHeaders extends HttpHeaders {
} }
/** /**
* Returns the value of the {@code Sec-WebSocket-Key} header. * Returns the value of the {@code Sec-WebSocket-Protocol} header.
* @return the value of the header * @return the value of the header
*/ */
public List<String> getSecWebSocketProtocol() { public List<String> getSecWebSocketProtocol() {

View File

@ -215,7 +215,7 @@ public abstract class AbstractHandshakeHandler implements HandshakeHandler, Life
} }
try { try {
HttpMethod httpMethod = request.getMethod(); HttpMethod httpMethod = request.getMethod();
if (HttpMethod.GET != httpMethod && CONNECT_METHOD != httpMethod) { if (HttpMethod.GET != httpMethod && !CONNECT_METHOD.equals(httpMethod)) {
response.setStatusCode(HttpStatus.METHOD_NOT_ALLOWED); response.setStatusCode(HttpStatus.METHOD_NOT_ALLOWED);
response.getHeaders().setAllow(Set.of(HttpMethod.GET, CONNECT_METHOD)); response.getHeaders().setAllow(Set.of(HttpMethod.GET, CONNECT_METHOD));
if (logger.isErrorEnabled()) { if (logger.isErrorEnabled()) {
@ -223,13 +223,15 @@ public abstract class AbstractHandshakeHandler implements HandshakeHandler, Life
} }
return false; return false;
} }
if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) { if (HttpMethod.GET == httpMethod) {
handleInvalidUpgradeHeader(request, response); if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) {
return false; handleInvalidUpgradeHeader(request, response);
} return false;
if (!headers.getConnection().contains("Upgrade") && !headers.getConnection().contains("upgrade")) { }
handleInvalidConnectHeader(request, response); if (!headers.getConnection().contains("Upgrade") && !headers.getConnection().contains("upgrade")) {
return false; handleInvalidConnectHeader(request, response);
return false;
}
} }
if (!isWebSocketVersionSupported(headers)) { if (!isWebSocketVersionSupported(headers)) {
handleWebSocketVersionNotSupported(request, response); handleWebSocketVersionNotSupported(request, response);
@ -239,13 +241,15 @@ public abstract class AbstractHandshakeHandler implements HandshakeHandler, Life
response.setStatusCode(HttpStatus.FORBIDDEN); response.setStatusCode(HttpStatus.FORBIDDEN);
return false; return false;
} }
String wsKey = headers.getSecWebSocketKey(); if (HttpMethod.GET == httpMethod) {
if (wsKey == null) { String wsKey = headers.getSecWebSocketKey();
if (logger.isErrorEnabled()) { if (wsKey == null) {
logger.error("Missing \"Sec-WebSocket-Key\" header"); if (logger.isErrorEnabled()) {
logger.error("Missing \"Sec-WebSocket-Key\" header");
}
response.setStatusCode(HttpStatus.BAD_REQUEST);
return false;
} }
response.setStatusCode(HttpStatus.BAD_REQUEST);
return false;
} }
} }
catch (IOException ex) { catch (IOException ex) {