diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/standard/UndertowRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/standard/UndertowRequestUpgradeStrategy.java index 679313dcfbd..86854c4a025 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/standard/UndertowRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/standard/UndertowRequestUpgradeStrategy.java @@ -17,6 +17,7 @@ package org.springframework.web.socket.server.standard; import java.lang.reflect.Constructor; +import java.lang.reflect.Method; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -48,57 +49,84 @@ import io.undertow.websockets.jsr.handshake.HandshakeUtil; import io.undertow.websockets.jsr.handshake.JsrHybi07Handshake; import io.undertow.websockets.jsr.handshake.JsrHybi08Handshake; import io.undertow.websockets.jsr.handshake.JsrHybi13Handshake; +import io.undertow.websockets.spi.WebSocketHttpExchange; import org.xnio.StreamConnection; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.util.ClassUtils; +import org.springframework.util.ReflectionUtils; import org.springframework.web.socket.server.HandshakeFailureException; /** * A {@link org.springframework.web.socket.server.RequestUpgradeStrategy} for use * with WildFly and its underlying Undertow web server. * - *

Compatible with Undertow 1.0, 1.1, 1.2 - as included in WildFly 8.x and 9.0. + *

Compatible with Undertow 1.0 to 1.3 - as included in WildFly 8.x, 9 and 10. * * @author Rossen Stoyanchev + * @author Brian Clozel + * @author Juergen Hoeller * @since 4.0.1 */ public class UndertowRequestUpgradeStrategy extends AbstractStandardUpgradeStrategy { private static final Constructor exchangeConstructor; + private static final boolean exchangeConstructorWithPeerConnections; + private static final Constructor endpointConstructor; - private static final boolean undertow10Present; + private static final boolean endpointConstructorWithEndpointFactory; - private static final boolean undertow11Present; + private static final Method getBufferPoolMethod; + + private static final Method createChannelMethod; static { - Class exchangeType = ServletWebSocketHttpExchange.class; - Class[] exchangeParamTypes = new Class[] {HttpServletRequest.class, HttpServletResponse.class, Set.class}; - if (ClassUtils.hasConstructor(exchangeType, exchangeParamTypes)) { - exchangeConstructor = ClassUtils.getConstructorIfAvailable(exchangeType, exchangeParamTypes); - undertow10Present = false; - } - else { - exchangeParamTypes = new Class[] {HttpServletRequest.class, HttpServletResponse.class}; - exchangeConstructor = ClassUtils.getConstructorIfAvailable(exchangeType, exchangeParamTypes); - undertow10Present = true; - } + try { + Class exchangeType = ServletWebSocketHttpExchange.class; + Class[] exchangeParamTypes = + new Class[] {HttpServletRequest.class, HttpServletResponse.class, Set.class}; + Constructor exchangeCtor = + ClassUtils.getConstructorIfAvailable(exchangeType, exchangeParamTypes); + if (exchangeCtor != null) { + // Undertow 1.1+ + exchangeConstructor = exchangeCtor; + exchangeConstructorWithPeerConnections = true; + } + else { + // Undertow 1.0 + exchangeParamTypes = new Class[] {HttpServletRequest.class, HttpServletResponse.class}; + exchangeConstructor = exchangeType.getConstructor(exchangeParamTypes); + exchangeConstructorWithPeerConnections = false; + } - Class endpointType = ConfiguredServerEndpoint.class; - Class[] endpointParamTypes = new Class[] {ServerEndpointConfig.class, InstanceFactory.class, - PathTemplate.class, EncodingFactory.class, AnnotatedEndpointFactory.class}; - if (ClassUtils.hasConstructor(endpointType, endpointParamTypes)) { - endpointConstructor = ClassUtils.getConstructorIfAvailable(endpointType, endpointParamTypes); - undertow11Present = true; + Class endpointType = ConfiguredServerEndpoint.class; + Class[] endpointParamTypes = new Class[] {ServerEndpointConfig.class, InstanceFactory.class, + PathTemplate.class, EncodingFactory.class, AnnotatedEndpointFactory.class}; + Constructor endpointCtor = + ClassUtils.getConstructorIfAvailable(endpointType, endpointParamTypes); + if (endpointCtor != null) { + // Undertow 1.1+ + endpointConstructor = endpointCtor; + endpointConstructorWithEndpointFactory = true; + } + else { + // Undertow 1.0 + endpointParamTypes = new Class[] {ServerEndpointConfig.class, InstanceFactory.class, + PathTemplate.class, EncodingFactory.class}; + endpointConstructor = endpointType.getConstructor(endpointParamTypes); + endpointConstructorWithEndpointFactory = false; + } + + // Adapting between different Pool API types in Undertow 1.0-1.2 vs 1.3 + getBufferPoolMethod = WebSocketHttpExchange.class.getMethod("getBufferPool"); + createChannelMethod = Handshake.class.getMethod("createChannel", + WebSocketHttpExchange.class, StreamConnection.class, getBufferPoolMethod.getReturnType()); } - else { - endpointParamTypes = new Class[] {ServerEndpointConfig.class, InstanceFactory.class, - PathTemplate.class, EncodingFactory.class}; - endpointConstructor = ClassUtils.getConstructorIfAvailable(endpointType, endpointParamTypes); - undertow11Present = false; + catch (Throwable ex) { + throw new IllegalStateException("Incompatible Undertow API version", ex); } } @@ -113,11 +141,11 @@ public class UndertowRequestUpgradeStrategy extends AbstractStandardUpgradeStrat public UndertowRequestUpgradeStrategy() { - if (undertow10Present) { - this.peerConnections = null; + if (exchangeConstructorWithPeerConnections) { + this.peerConnections = Collections.newSetFromMap(new ConcurrentHashMap()); } else { - this.peerConnections = Collections.newSetFromMap(new ConcurrentHashMap()); + this.peerConnections = null; } } @@ -149,7 +177,9 @@ public class UndertowRequestUpgradeStrategy extends AbstractStandardUpgradeStrat exchange.upgradeChannel(new HttpUpgradeListener() { @Override public void handleUpgrade(StreamConnection connection, HttpServerExchange serverExchange) { - WebSocketChannel channel = handshake.createChannel(exchange, connection, exchange.getBufferPool()); + Object bufferPool = ReflectionUtils.invokeMethod(getBufferPoolMethod, exchange); + WebSocketChannel channel = (WebSocketChannel) ReflectionUtils.invokeMethod( + createChannelMethod, handshake, exchange, connection, bufferPool); if (peerConnections != null) { peerConnections.add(channel); } @@ -202,7 +232,7 @@ public class UndertowRequestUpgradeStrategy extends AbstractStandardUpgradeStrat Collections., List>>emptyMap(), Collections., List>>emptyMap()); try { - return (undertow11Present ? + return (endpointConstructorWithEndpointFactory ? endpointConstructor.newInstance(endpointRegistration, new EndpointInstanceFactory(endpoint), null, encodingFactory, null) : endpointConstructor.newInstance(endpointRegistration,