From c89325b9ca1d3d30bbe8a32beed91dffa8d836bb Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Fri, 25 Jul 2014 17:08:11 -0400 Subject: [PATCH] Add WebLogicRequestUpgradeStrategy This change creates an AbstractTyrusRequestUpgradeStrategy shared between the WebLogic and GlassFish sub-classes. The version of Tyrus is lowered to 1.3.5 to match the version used in WebLogic (12.1.3) and that in turn requires a little extra effort in the base AbstractTyrusRequestUpgradeStrategy to make up for changes that have taken place from Tyrus 1.3.5 to 1.7. Issue: SPR-11293 --- build.gradle | 8 +- .../AbstractTyrusRequestUpgradeStrategy.java | 286 ++++++++++++++++++ .../GlassFishRequestUpgradeStrategy.java | 194 +++--------- .../WebLogicRequestUpgradeStrategy.java | 197 ++++++++++++ .../support/DefaultHandshakeHandler.java | 16 +- 5 files changed, 544 insertions(+), 157 deletions(-) create mode 100644 spring-websocket/src/main/java/org/springframework/web/socket/server/standard/AbstractTyrusRequestUpgradeStrategy.java create mode 100644 spring-websocket/src/main/java/org/springframework/web/socket/server/standard/WebLogicRequestUpgradeStrategy.java diff --git a/build.gradle b/build.gradle index 715f1c5c3eb..9fdc732111d 100644 --- a/build.gradle +++ b/build.gradle @@ -684,10 +684,10 @@ project("spring-websocket") { exclude group: "org.apache.tomcat", module: "tomcat-websocket-api" exclude group: "org.apache.tomcat", module: "tomcat-servlet-api" } - optional("org.glassfish.tyrus:tyrus-spi:1.7") - optional("org.glassfish.tyrus:tyrus-core:1.7") - optional("org.glassfish.tyrus:tyrus-server:1.7") - optional("org.glassfish.tyrus:tyrus-container-servlet:1.7") + optional("org.glassfish.tyrus:tyrus-spi:1.3.5") + optional("org.glassfish.tyrus:tyrus-core:1.3.5") + optional("org.glassfish.tyrus:tyrus-server:1.3.5") + optional("org.glassfish.tyrus:tyrus-container-servlet:1.3.5") optional("org.eclipse.jetty:jetty-webapp:${jettyVersion}") { exclude group: "javax.servlet", module: "javax.servlet" } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/standard/AbstractTyrusRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/standard/AbstractTyrusRequestUpgradeStrategy.java new file mode 100644 index 00000000000..56a51a19319 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/standard/AbstractTyrusRequestUpgradeStrategy.java @@ -0,0 +1,286 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.server.standard; + +import org.glassfish.tyrus.core.ComponentProviderService; +import org.glassfish.tyrus.core.RequestContext; +import org.glassfish.tyrus.core.TyrusEndpoint; +import org.glassfish.tyrus.core.TyrusEndpointWrapper; +import org.glassfish.tyrus.core.TyrusUpgradeResponse; +import org.glassfish.tyrus.core.TyrusWebSocketEngine; +import org.glassfish.tyrus.core.Version; +import org.glassfish.tyrus.core.WebSocketApplication; +import org.glassfish.tyrus.server.TyrusServerContainer; +import org.glassfish.tyrus.spi.WebSocketEngine.UpgradeInfo; +import org.springframework.beans.DirectFieldAccessor; +import org.springframework.http.HttpHeaders; +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServerHttpResponse; +import org.springframework.util.ReflectionUtils; +import org.springframework.util.StringUtils; +import org.springframework.web.socket.WebSocketExtension; +import org.springframework.web.socket.server.HandshakeFailureException; + +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.websocket.DeploymentException; +import javax.websocket.Endpoint; +import javax.websocket.EndpointConfig; +import javax.websocket.Extension; +import javax.websocket.WebSocketContainer; +import java.io.IOException; +import java.lang.reflect.Constructor; +import java.lang.reflect.Method; +import java.net.URI; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Random; + +/** + * An base class for WebSocket servers using Tyrus. + * + *

Works with Tyrus 1.3.5 (WebLogic 12.1.3) and Tyrus 1.7 (GlassFish 4.0.1). + * + * @author Rossen Stoyanchev + * @since 4.1 + * @see Project Tyrus + */ +public abstract class AbstractTyrusRequestUpgradeStrategy extends AbstractStandardUpgradeStrategy { + + private static final Random random = new Random(); + + + private final ComponentProviderService componentProvider = ComponentProviderService.create(); + + + @Override + public String[] getSupportedVersions() { + return StringUtils.commaDelimitedListToStringArray(Version.getSupportedWireProtocolVersions()); + } + + protected List getInstalledExtensions(WebSocketContainer container) { + try { + return super.getInstalledExtensions(container); + } + catch (UnsupportedOperationException e) { + return new ArrayList(); + } + } + + protected abstract TyrusEndpointHelper getEndpointHelper(); + + + @Override + public void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response, + String subProtocol, List extensions, Endpoint endpoint) throws HandshakeFailureException { + + HttpServletRequest servletRequest = getHttpServletRequest(request); + HttpServletResponse servletResponse = getHttpServletResponse(response); + + TyrusServerContainer serverContainer = (TyrusServerContainer) getContainer(servletRequest); + TyrusWebSocketEngine engine = (TyrusWebSocketEngine) serverContainer.getWebSocketEngine(); + Object tyrusEndpoint = null; + + try { + // Shouldn't matter for processing but must be unique + String path = "/" + random.nextLong(); + tyrusEndpoint = createTyrusEndpoint(endpoint, path, subProtocol, extensions, serverContainer, engine); + getEndpointHelper().register(engine, tyrusEndpoint); + + HttpHeaders headers = request.getHeaders(); + RequestContext requestContext = createRequestContext(servletRequest, path, headers); + TyrusUpgradeResponse upgradeResponse = new TyrusUpgradeResponse(); + UpgradeInfo upgradeInfo = engine.upgrade(requestContext, upgradeResponse); + + switch (upgradeInfo.getStatus()) { + case SUCCESS: + if (logger.isTraceEnabled()) { + logger.trace("Successful upgrade: " + upgradeResponse.getHeaders()); + } + handleSuccess(servletRequest, servletResponse, upgradeInfo, upgradeResponse); + break; + case HANDSHAKE_FAILED: + // Should never happen + throw new HandshakeFailureException("Unexpected handshake failure: " + request.getURI()); + case NOT_APPLICABLE: + // Should never happen + throw new HandshakeFailureException("Unexpected handshake mapping failure: " + request.getURI()); + } + } + catch (Exception ex) { + throw new HandshakeFailureException("Error during handshake: " + request.getURI(), ex); + } + finally { + if (tyrusEndpoint != null) { + getEndpointHelper().unregister(engine, tyrusEndpoint); + } + } + } + + protected abstract void handleSuccess(HttpServletRequest request, HttpServletResponse response, + UpgradeInfo upgradeInfo, TyrusUpgradeResponse upgradeResponse) throws IOException, ServletException; + + private Object createTyrusEndpoint(Endpoint endpoint, String endpointPath, String protocol, + List extensions, WebSocketContainer container, TyrusWebSocketEngine engine) + throws DeploymentException { + + ServerEndpointRegistration endpointConfig = new ServerEndpointRegistration(endpointPath, endpoint); + endpointConfig.setSubprotocols(Arrays.asList(protocol)); + endpointConfig.setExtensions(extensions); + return getEndpointHelper().createdEndpoint(endpointConfig, this.componentProvider, container, engine); + } + + private RequestContext createRequestContext(HttpServletRequest request, String endpointPath, HttpHeaders headers) { + RequestContext context = + RequestContext.Builder.create() + .requestURI(URI.create(endpointPath)) + .userPrincipal(request.getUserPrincipal()) + .secure(request.isSecure()) + // .remoteAddr(request.getRemoteAddr()) # Not available in 1.3.5 + .build(); + for (String header : headers.keySet()) { + context.getHeaders().put(header, headers.get(header)); + } + return context; + } + + + /** + * Helps with the creation, registration, and un-registration of endpoints. + */ + protected interface TyrusEndpointHelper { + + Object createdEndpoint(ServerEndpointRegistration registration, ComponentProviderService provider, + WebSocketContainer container, TyrusWebSocketEngine engine) throws DeploymentException; + + void register(TyrusWebSocketEngine engine, Object endpoint); + + void unregister(TyrusWebSocketEngine engine, Object endpoint); + + } + + protected static class Tyrus17EndpointHelper implements TyrusEndpointHelper { + + private static final Constructor constructor; + + private static final Method registerMethod; + + private static final Method unRegisterMethod; + + static { + try { + constructor = getEndpointConstructor(); + registerMethod = TyrusWebSocketEngine.class.getDeclaredMethod("register", TyrusEndpointWrapper.class); + unRegisterMethod = TyrusWebSocketEngine.class.getDeclaredMethod("unregister", TyrusEndpointWrapper.class); + ReflectionUtils.makeAccessible(registerMethod); + } + catch (Exception ex) { + throw new IllegalStateException("No compatible Tyrus version found", ex); + } + } + + private static Constructor getEndpointConstructor() { + for (Constructor current : TyrusEndpointWrapper.class.getConstructors()) { + Class[] types = current.getParameterTypes(); + if (types[0].equals(Endpoint.class) && types[1].equals(EndpointConfig.class)) { + return current; + } + } + throw new IllegalStateException("No compatible Tyrus version found"); + } + + + @Override + public Object createdEndpoint(ServerEndpointRegistration registration, ComponentProviderService provider, + WebSocketContainer container, TyrusWebSocketEngine engine) throws DeploymentException { + + DirectFieldAccessor accessor = new DirectFieldAccessor(engine); + Object sessionListener = accessor.getPropertyValue("sessionListener"); + Object clusterContext = accessor.getPropertyValue("clusterContext"); + try { + return constructor.newInstance(registration.getEndpoint(), registration, provider, container, + "/", registration.getConfigurator(), sessionListener, clusterContext, null); + } + catch (Exception ex) { + throw new HandshakeFailureException("Failed to register " + registration, ex); + } + } + + @Override + public void register(TyrusWebSocketEngine engine, Object endpoint) { + try { + registerMethod.invoke(engine, endpoint); + } + catch (Exception ex) { + throw new HandshakeFailureException("Failed to register " + endpoint, ex); + } + } + + @Override + public void unregister(TyrusWebSocketEngine engine, Object endpoint) { + try { + unRegisterMethod.invoke(engine, endpoint); + } + catch (Exception ex) { + throw new HandshakeFailureException("Failed to unregister " + endpoint, ex); + } + } + } + + protected static class Tyrus135EndpointHelper implements TyrusEndpointHelper { + + private static final Method registerMethod; + + static { + try { + registerMethod = TyrusWebSocketEngine.class.getDeclaredMethod("register", WebSocketApplication.class); + ReflectionUtils.makeAccessible(registerMethod); + } + catch (Exception ex) { + throw new IllegalStateException("No compatible Tyrus version found", ex); + } + } + + @Override + public Object createdEndpoint(ServerEndpointRegistration registration, ComponentProviderService provider, + WebSocketContainer container, TyrusWebSocketEngine engine) throws DeploymentException { + + TyrusEndpointWrapper endpointWrapper = new TyrusEndpointWrapper(registration.getEndpoint(), + registration, provider, container, "/", registration.getConfigurator()); + + return new TyrusEndpoint(endpointWrapper); + } + + @Override + public void register(TyrusWebSocketEngine engine, Object endpoint) { + try { + registerMethod.invoke(engine, endpoint); + } + catch (Exception ex) { + throw new HandshakeFailureException("Failed to register " + endpoint, ex); + } + } + + @Override + public void unregister(TyrusWebSocketEngine engine, Object endpoint) { + engine.unregister((TyrusEndpoint) endpoint); + } + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/standard/GlassFishRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/standard/GlassFishRequestUpgradeStrategy.java index f3bb35d96be..651692da159 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/standard/GlassFishRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/standard/GlassFishRequestUpgradeStrategy.java @@ -16,186 +16,84 @@ package org.springframework.web.socket.server.standard; -import java.lang.reflect.Constructor; -import java.lang.reflect.Method; -import java.net.URI; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Map; -import java.util.Random; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import javax.websocket.DeploymentException; -import javax.websocket.Endpoint; -import javax.websocket.Extension; -import javax.websocket.WebSocketContainer; - -import org.glassfish.tyrus.core.ComponentProviderService; -import org.glassfish.tyrus.core.RequestContext; -import org.glassfish.tyrus.core.TyrusEndpointWrapper; import org.glassfish.tyrus.core.TyrusUpgradeResponse; -import org.glassfish.tyrus.core.TyrusWebSocketEngine; import org.glassfish.tyrus.core.Utils; -import org.glassfish.tyrus.core.Version; -import org.glassfish.tyrus.core.cluster.ClusterContext; -import org.glassfish.tyrus.core.monitoring.EndpointEventListener; -import org.glassfish.tyrus.server.TyrusServerContainer; import org.glassfish.tyrus.servlet.TyrusHttpUpgradeHandler; import org.glassfish.tyrus.spi.WebSocketEngine.UpgradeInfo; - import org.glassfish.tyrus.spi.Writer; -import org.springframework.http.HttpHeaders; -import org.springframework.http.server.ServerHttpRequest; -import org.springframework.http.server.ServerHttpResponse; import org.springframework.util.ReflectionUtils; -import org.springframework.util.StringUtils; -import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.server.HandshakeFailureException; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.io.IOException; +import java.lang.reflect.Constructor; +import java.util.List; +import java.util.Map; + /** - * A WebSocket request upgrade strategy for GlassFish 4.0.1 and beyond. + * A WebSocket {@code RequestUpgradeStrategy} for GlassFish 4.0.1 and beyond. * * @author Rossen Stoyanchev * @author Juergen Hoeller * @author Michael Irwin * @since 4.0 */ -public class GlassFishRequestUpgradeStrategy extends AbstractStandardUpgradeStrategy { +public class GlassFishRequestUpgradeStrategy extends AbstractTyrusRequestUpgradeStrategy { - private static final Random random = new Random(); - - private static final Constructor tyrusServletWriterConstructor; - - private static final Method endpointRegistrationMethod; - - static { - try { - ClassLoader classLoader = GlassFishRequestUpgradeStrategy.class.getClassLoader(); - Class type = classLoader.loadClass("org.glassfish.tyrus.servlet.TyrusServletWriter"); - tyrusServletWriterConstructor = type.getDeclaredConstructor(TyrusHttpUpgradeHandler.class); - ReflectionUtils.makeAccessible(tyrusServletWriterConstructor); - - Class endpointType = TyrusEndpointWrapper.class; - endpointRegistrationMethod = TyrusWebSocketEngine.class.getDeclaredMethod("register", endpointType); - ReflectionUtils.makeAccessible(endpointRegistrationMethod); - } - catch (Exception ex) { - throw new IllegalStateException("No compatible Tyrus version found", ex); - } - } - - private final ComponentProviderService componentProviderService = ComponentProviderService.create(); + private static final TyrusEndpointHelper endpointHelper = new Tyrus17EndpointHelper(); + private static final GlassFishServletWriterHelper servletWriterHelper = new GlassFishServletWriterHelper(); @Override - public String[] getSupportedVersions() { - return StringUtils.commaDelimitedListToStringArray(Version.getSupportedWireProtocolVersions()); - } - - protected List getInstalledExtensions(WebSocketContainer container) { - try { - return super.getInstalledExtensions(container); - } - catch (UnsupportedOperationException e) { - return new ArrayList(); - } + protected TyrusEndpointHelper getEndpointHelper() { + return endpointHelper; } @Override - public void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response, - String subProtocol, List extensions, Endpoint endpoint) throws HandshakeFailureException { + protected void handleSuccess(HttpServletRequest request, HttpServletResponse response, + UpgradeInfo upgradeInfo, TyrusUpgradeResponse upgradeResponse) throws IOException, ServletException { - HttpServletRequest servletRequest = getHttpServletRequest(request); - HttpServletResponse servletResponse = getHttpServletResponse(response); + TyrusHttpUpgradeHandler handler = request.upgrade(TyrusHttpUpgradeHandler.class); + Writer servletWriter = servletWriterHelper.newInstance(handler); + handler.preInit(upgradeInfo, servletWriter, request.getUserPrincipal() != null); - TyrusServerContainer serverContainer = (TyrusServerContainer) getContainer(servletRequest); - TyrusWebSocketEngine engine = (TyrusWebSocketEngine) serverContainer.getWebSocketEngine(); - TyrusEndpointWrapper tyrusEndpoint = null; + response.setStatus(upgradeResponse.getStatus()); + for (Map.Entry> entry : upgradeResponse.getHeaders().entrySet()) { + response.addHeader(entry.getKey(), Utils.getHeaderFromList(entry.getValue())); + } + response.flushBuffer(); + } - try { - tyrusEndpoint = createTyrusEndpoint(endpoint, subProtocol, extensions, serverContainer); - endpointRegistrationMethod.invoke(engine, tyrusEndpoint); - String endpointPath = tyrusEndpoint.getEndpointPath(); - HttpHeaders headers = request.getHeaders(); + /** + * Helps to create and invoke {@code org.glassfish.tyrus.servlet.TyrusServletWriter}. + */ + private static class GlassFishServletWriterHelper { - RequestContext requestContext = createRequestContext(servletRequest, endpointPath, headers); - TyrusUpgradeResponse upgradeResponse = new TyrusUpgradeResponse(); - UpgradeInfo upgradeInfo = engine.upgrade(requestContext, upgradeResponse); + private static final Constructor constructor; - switch (upgradeInfo.getStatus()) { - case SUCCESS: - TyrusHttpUpgradeHandler handler = servletRequest.upgrade(TyrusHttpUpgradeHandler.class); - Writer servletWriter = createTyrusServletWriter(handler); - handler.preInit(upgradeInfo, servletWriter, servletRequest.getUserPrincipal() != null); - servletResponse.setStatus(upgradeResponse.getStatus()); - for (Map.Entry> entry : upgradeResponse.getHeaders().entrySet()) { - servletResponse.addHeader(entry.getKey(), Utils.getHeaderFromList(entry.getValue())); - } - servletResponse.flushBuffer(); - if (logger.isTraceEnabled()) { - logger.trace("Successful upgrade uri=" + servletRequest.getRequestURI() + - ", response headers=" + upgradeResponse.getHeaders()); - } - break; - case HANDSHAKE_FAILED: - // Should never happen - throw new HandshakeFailureException("Unexpected handshake failure: " + request.getURI()); - case NOT_APPLICABLE: - // Should never happen - throw new HandshakeFailureException("Unexpected handshake mapping failure: " + request.getURI()); + static { + try { + ClassLoader classLoader = GlassFishRequestUpgradeStrategy.class.getClassLoader(); + Class type = classLoader.loadClass("org.glassfish.tyrus.servlet.TyrusServletWriter"); + constructor = type.getDeclaredConstructor(TyrusHttpUpgradeHandler.class); + ReflectionUtils.makeAccessible(constructor); + } + catch (Exception ex) { + throw new IllegalStateException("No compatible Tyrus version found", ex); } } - catch (Exception ex) { - throw new HandshakeFailureException("Error during handshake: " + request.getURI(), ex); - } - finally { - if (tyrusEndpoint != null) { - engine.unregister(tyrusEndpoint); + + private Writer newInstance(TyrusHttpUpgradeHandler handler) { + try { + return (Writer) constructor.newInstance(handler); + } + catch (Exception ex) { + throw new HandshakeFailureException("Failed to instantiate TyrusServletWriter", ex); } - } - } - - private TyrusEndpointWrapper createTyrusEndpoint(Endpoint endpoint, String protocol, - List extensions, WebSocketContainer container) throws DeploymentException { - - // Shouldn't matter for processing but must be unique - String endpointPath = "/" + random.nextLong(); - - ServerEndpointRegistration endpointConfig = new ServerEndpointRegistration(endpointPath, endpoint); - endpointConfig.setSubprotocols(Arrays.asList(protocol)); - endpointConfig.setExtensions(extensions); - - TyrusEndpointWrapper.SessionListener sessionListener = new TyrusEndpointWrapper.SessionListener() {}; - ClusterContext clusterContext = null; - EndpointEventListener eventListener = EndpointEventListener.NO_OP; - - return new TyrusEndpointWrapper(endpoint, endpointConfig, this.componentProviderService, - container, "/", endpointConfig.getConfigurator(), sessionListener, clusterContext, eventListener); - } - - private RequestContext createRequestContext(HttpServletRequest request, String endpointPath, HttpHeaders headers) { - RequestContext context = - RequestContext.Builder.create() - .requestURI(URI.create(endpointPath)) - .userPrincipal(request.getUserPrincipal()) - .secure(request.isSecure()) - .remoteAddr(request.getRemoteAddr()) - .build(); - for (String header : headers.keySet()) { - context.getHeaders().put(header, headers.get(header)); - } - return context; - } - - private Writer createTyrusServletWriter(TyrusHttpUpgradeHandler handler) { - try { - return (Writer) tyrusServletWriterConstructor.newInstance(handler); - } - catch (Exception ex) { - throw new HandshakeFailureException("Failed to instantiate TyrusServletWriter", ex); } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/standard/WebLogicRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/standard/WebLogicRequestUpgradeStrategy.java new file mode 100644 index 00000000000..c7ff8038fc3 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/standard/WebLogicRequestUpgradeStrategy.java @@ -0,0 +1,197 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.server.standard; + +import org.glassfish.tyrus.core.TyrusUpgradeResponse; +import org.glassfish.tyrus.core.Utils; +import org.glassfish.tyrus.spi.Connection; +import org.glassfish.tyrus.spi.WebSocketEngine.UpgradeInfo; +import org.glassfish.tyrus.spi.Writer; +import org.springframework.beans.BeanWrapper; +import org.springframework.beans.BeanWrapperImpl; +import org.springframework.util.ReflectionUtils; +import org.springframework.web.socket.server.HandshakeFailureException; + +import javax.servlet.AsyncContext; +import javax.servlet.ServletContext; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletRequestWrapper; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.websocket.CloseReason; +import java.io.IOException; +import java.lang.reflect.Constructor; +import java.lang.reflect.Method; +import java.util.List; +import java.util.Map; + +/** + * A WebSocket {@code RequestUpgradeStrategy} for WebLogic 12.1.3. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public class WebLogicRequestUpgradeStrategy extends AbstractTyrusRequestUpgradeStrategy { + + private static final TyrusEndpointHelper endpointHelper = new Tyrus135EndpointHelper(); + + private static final TyrusMuxableWebSocketHelper webSocketHelper = new TyrusMuxableWebSocketHelper(); + + private static final WebLogicServletWriterHelper servletWriterHelper = new WebLogicServletWriterHelper(); + + + @Override + protected TyrusEndpointHelper getEndpointHelper() { + return endpointHelper; + } + + + @Override + protected void handleSuccess(HttpServletRequest request, HttpServletResponse response, + UpgradeInfo upgradeInfo, TyrusUpgradeResponse upgradeResponse) throws IOException, ServletException { + + response.setStatus(upgradeResponse.getStatus()); + for (Map.Entry> entry : upgradeResponse.getHeaders().entrySet()) { + response.addHeader(entry.getKey(), Utils.getHeaderFromList(entry.getValue())); + } + + AsyncContext asyncContext = request.startAsync(); + asyncContext.setTimeout(-1L); + + Object nativeRequest = getNativeRequest(request); + BeanWrapper beanWrapper = new BeanWrapperImpl(nativeRequest); + Object httpSocket = beanWrapper.getPropertyValue("connection.connectionHandler.rawConnection"); + Object webSocket = webSocketHelper.newInstance(httpSocket); + webSocketHelper.upgrade(webSocket, httpSocket, request.getServletContext()); + + response.flushBuffer(); + + boolean isProtected = request.getUserPrincipal() != null; + Writer servletWriter = servletWriterHelper.newInstance(response, webSocket, isProtected); + Connection connection = upgradeInfo.createConnection(servletWriter, noOpCloseListener); + new BeanWrapperImpl(webSocket).setPropertyValue("connection", connection); + new BeanWrapperImpl(servletWriter).setPropertyValue("connection", connection); + webSocketHelper.registerForReadEvent(webSocket); + } + + private static Object getNativeRequest(ServletRequest request) { + while ((request instanceof ServletRequestWrapper)) { + request = ((ServletRequestWrapper) request).getRequest(); + } + return request; + } + + + private static final Connection.CloseListener noOpCloseListener = new Connection.CloseListener() { + + @Override + public void close(CloseReason reason) { + } + }; + + /** + * Helps to create and invoke {@code weblogic.servlet.internal.MuxableSocketHTTP}. + */ + private static class TyrusMuxableWebSocketHelper { + + public static final Class webSocketType; + + private static final Constructor webSocketConstructor; + + private static final Method webSocketUpgradeMethod; + + private static final Method webSocketReadEventMethod; + + static { + try { + ClassLoader classLoader = WebLogicRequestUpgradeStrategy.class.getClassLoader(); + Class socketType = classLoader.loadClass("weblogic.socket.MuxableSocket"); + Class httpSocketType = classLoader.loadClass("weblogic.servlet.internal.MuxableSocketHTTP"); + + webSocketType = classLoader.loadClass("weblogic.websocket.tyrus.TyrusMuxableWebSocket"); + webSocketConstructor = webSocketType.getDeclaredConstructor(httpSocketType); + webSocketUpgradeMethod = webSocketType.getMethod("upgrade", socketType, ServletContext.class); + webSocketReadEventMethod = webSocketType.getMethod("registerForReadEvent"); + } + catch (Exception ex) { + throw new IllegalStateException("No compatible WebSocket version found", ex); + } + } + + private Object newInstance(Object httpSocket) { + try { + return webSocketConstructor.newInstance(httpSocket); + } + catch (Exception ex) { + throw new HandshakeFailureException("Failed to create TyrusMuxableWebSocket", ex); + } + } + + private void upgrade(Object webSocket, Object httpSocket, ServletContext servletContext) { + try { + webSocketUpgradeMethod.invoke(webSocket, httpSocket, servletContext); + } + catch (Exception ex) { + throw new HandshakeFailureException("Failed to upgrade TyrusMuxableWebSocket", ex); + } + } + + private void registerForReadEvent(Object webSocket) { + try { + webSocketReadEventMethod.invoke(webSocket); + } + catch (Exception ex) { + throw new HandshakeFailureException("Failed to register WebSocket for read event.", ex); + } + } + } + + /** + * Helps to create and invoke {@code weblogic.websocket.tyrus.TyrusServletWriter}. + */ + private static class WebLogicServletWriterHelper { + + private static final Constructor constructor; + + static { + try { + ClassLoader loader = WebLogicRequestUpgradeStrategy.class.getClassLoader(); + Class type = loader.loadClass("weblogic.websocket.tyrus.TyrusServletWriter"); + Class listenerType = loader.loadClass("weblogic.websocket.tyrus.TyrusServletWriter$CloseListener"); + Class webSocketType = TyrusMuxableWebSocketHelper.webSocketType; + Class responseType = HttpServletResponse.class; + constructor = type.getDeclaredConstructor(webSocketType, responseType, listenerType, boolean.class); + ReflectionUtils.makeAccessible(constructor); + } + catch (Exception ex) { + throw new IllegalStateException("No compatible WebSocket version found", ex); + } + } + + + private Writer newInstance(HttpServletResponse response, Object webSocket, boolean isProtected) { + try { + return (Writer) constructor.newInstance(webSocket, response, null, isProtected); + } + catch (Exception ex) { + throw new HandshakeFailureException("Failed to create TyrusServletWriter", ex); + } + } + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/DefaultHandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/DefaultHandshakeHandler.java index 194d939c367..62457fe2342 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/DefaultHandshakeHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/DefaultHandshakeHandler.java @@ -62,9 +62,6 @@ public class DefaultHandshakeHandler implements HandshakeHandler { protected Log logger = LogFactory.getLog(getClass()); - private static final boolean glassFishWsPresent = ClassUtils.isPresent( - "org.glassfish.tyrus.servlet.TyrusHttpUpgradeHandler", DefaultHandshakeHandler.class.getClassLoader()); - private static final boolean jettyWsPresent = ClassUtils.isPresent( "org.eclipse.jetty.websocket.server.WebSocketServerFactory", DefaultHandshakeHandler.class.getClassLoader()); @@ -74,6 +71,12 @@ public class DefaultHandshakeHandler implements HandshakeHandler { private static final boolean undertowWsPresent = ClassUtils.isPresent( "io.undertow.websockets.jsr.ServerWebSocketContainer", DefaultHandshakeHandler.class.getClassLoader()); + private static final boolean glassFishWsPresent = ClassUtils.isPresent( + "org.glassfish.tyrus.servlet.TyrusHttpUpgradeHandler", DefaultHandshakeHandler.class.getClassLoader()); + + private static final boolean webLogicWsPresent = ClassUtils.isPresent( + "weblogic.websocket.tyrus.TyrusServletWriter", DefaultHandshakeHandler.class.getClassLoader()); + private final RequestUpgradeStrategy requestUpgradeStrategy; @@ -97,11 +100,14 @@ public class DefaultHandshakeHandler implements HandshakeHandler { else if (tomcatWsPresent) { className = "org.springframework.web.socket.server.standard.TomcatRequestUpgradeStrategy"; } + else if (undertowWsPresent) { + className = "org.springframework.web.socket.server.standard.UndertowRequestUpgradeStrategy"; + } else if (glassFishWsPresent) { className = "org.springframework.web.socket.server.standard.GlassFishRequestUpgradeStrategy"; } - else if (undertowWsPresent) { - className = "org.springframework.web.socket.server.standard.UndertowRequestUpgradeStrategy"; + else if (webLogicWsPresent) { + className = "org.springframework.web.socket.server.standard.WebLogicRequestUpgradeStrategy"; } else { throw new IllegalStateException("No suitable default RequestUpgradeStrategy found");