diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/standard/UndertowRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/standard/UndertowRequestUpgradeStrategy.java index fb94633a68b..57fc9c38673 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/standard/UndertowRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/standard/UndertowRequestUpgradeStrategy.java @@ -16,9 +16,12 @@ package org.springframework.web.socket.server.standard; +import java.lang.reflect.Constructor; import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.websocket.Decoder; @@ -34,9 +37,6 @@ import io.undertow.servlet.websockets.ServletWebSocketHttpExchange; import io.undertow.websockets.core.WebSocketChannel; import io.undertow.websockets.core.WebSocketVersion; import io.undertow.websockets.core.protocol.Handshake; -import io.undertow.websockets.core.protocol.version07.Hybi07Handshake; -import io.undertow.websockets.core.protocol.version08.Hybi08Handshake; -import io.undertow.websockets.core.protocol.version13.Hybi13Handshake; import io.undertow.websockets.jsr.ConfiguredServerEndpoint; import io.undertow.websockets.jsr.EncodingFactory; import io.undertow.websockets.jsr.EndpointSessionHandler; @@ -45,6 +45,7 @@ import io.undertow.websockets.jsr.handshake.HandshakeUtil; import io.undertow.websockets.jsr.handshake.JsrHybi07Handshake; import io.undertow.websockets.jsr.handshake.JsrHybi08Handshake; import io.undertow.websockets.jsr.handshake.JsrHybi13Handshake; +import org.springframework.util.ClassUtils; import org.xnio.StreamConnection; import org.springframework.http.server.ServerHttpRequest; @@ -61,16 +62,47 @@ import org.springframework.web.socket.server.HandshakeFailureException; */ public class UndertowRequestUpgradeStrategy extends AbstractStandardUpgradeStrategy { - private final String[] supportedVersions = new String[] { + private static final Constructor exchangeConstructor; + + private static final boolean undertow10Present; + + static { + Class type = ServletWebSocketHttpExchange.class; + Class[] paramTypes = new Class[] {HttpServletRequest.class, HttpServletResponse.class, Set.class}; + if (ClassUtils.hasConstructor(type, paramTypes)) { + exchangeConstructor = ClassUtils.getConstructorIfAvailable(type, paramTypes); + undertow10Present = false; + } + else { + paramTypes = new Class[] {HttpServletRequest.class, HttpServletResponse.class}; + exchangeConstructor = ClassUtils.getConstructorIfAvailable(type, paramTypes); + undertow10Present = true; + } + } + + private static final String[] supportedVersions = new String[] { WebSocketVersion.V13.toHttpHeaderValue(), WebSocketVersion.V08.toHttpHeaderValue(), WebSocketVersion.V07.toHttpHeaderValue() }; + private Set peerConnections; + + + public UndertowRequestUpgradeStrategy() { + if (undertow10Present) { + this.peerConnections = null; + } + else { + this.peerConnections = Collections.newSetFromMap(new ConcurrentHashMap()); + } + } + + @Override public String[] getSupportedVersions() { - return this.supportedVersions; + return supportedVersions; } @Override @@ -80,7 +112,7 @@ public class UndertowRequestUpgradeStrategy extends AbstractStandardUpgradeStrat HttpServletRequest servletRequest = getHttpServletRequest(request); HttpServletResponse servletResponse = getHttpServletResponse(response); - final ServletWebSocketHttpExchange exchange = new ServletWebSocketHttpExchange(servletRequest, servletResponse); + final ServletWebSocketHttpExchange exchange = createHttpExchange(servletRequest, servletResponse); exchange.putAttachment(HandshakeUtil.PATH_PARAMS, Collections.emptyMap()); ServerWebSocketContainer wsContainer = (ServerWebSocketContainer) getContainer(servletRequest); @@ -95,6 +127,9 @@ public class UndertowRequestUpgradeStrategy extends AbstractStandardUpgradeStrat @Override public void handleUpgrade(StreamConnection connection, HttpServerExchange serverExchange) { WebSocketChannel channel = handshake.createChannel(exchange, connection, exchange.getBufferPool()); + if (peerConnections != null) { + peerConnections.add(channel); + } endpointSessionHandler.onConnect(exchange, channel); } }); @@ -102,6 +137,17 @@ public class UndertowRequestUpgradeStrategy extends AbstractStandardUpgradeStrat handshake.handshake(exchange); } + private ServletWebSocketHttpExchange createHttpExchange(HttpServletRequest request, HttpServletResponse response) { + try { + return (this.peerConnections != null ? + exchangeConstructor.newInstance(request, response, this.peerConnections) : + exchangeConstructor.newInstance(request, response)); + } + catch (Exception ex) { + throw new HandshakeFailureException("Failed to instantiate ServletWebSocketHttpExchange", ex); + } + } + private Handshake getHandshakeToUse(ServletWebSocketHttpExchange exchange, ConfiguredServerEndpoint endpoint) { Handshake handshake = new JsrHybi13Handshake(endpoint); if (handshake.matches(exchange)) {