diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/endpoint/ServerEndpointExporter.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/endpoint/ServerEndpointExporter.java index 798387a4d2..f9df84aa99 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/endpoint/ServerEndpointExporter.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/endpoint/ServerEndpointExporter.java @@ -35,7 +35,6 @@ import org.springframework.beans.factory.config.BeanPostProcessor; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextAware; import org.springframework.util.Assert; -import org.springframework.util.ClassUtils; import org.springframework.util.ReflectionUtils; /** @@ -58,9 +57,6 @@ import org.springframework.util.ReflectionUtils; */ public class ServerEndpointExporter implements InitializingBean, BeanPostProcessor, ApplicationContextAware { - private static final boolean isServletApiPresent = - ClassUtils.isPresent("javax.servlet.ServletContext", ServerEndpointExporter.class.getClassLoader()); - private static Log logger = LogFactory.getLog(ServerEndpointExporter.class); @@ -103,20 +99,25 @@ public class ServerEndpointExporter implements InitializingBean, BeanPostProcess } protected ServerContainer getServerContainer() { - if (isServletApiPresent) { - try { - Method getter = ReflectionUtils.findMethod(this.applicationContext.getClass(), "getServletContext"); - Object servletContext = getter.invoke(this.applicationContext); - Method attrMethod = ReflectionUtils.findMethod(servletContext.getClass(), "getAttribute", String.class); - return (ServerContainer) attrMethod.invoke(servletContext, "javax.websocket.server.ServerContainer"); - } - catch (Exception ex) { - throw new IllegalStateException( - "Failed to get javax.websocket.server.ServerContainer via ServletContext attribute", ex); - } + Class servletContextClass; + try { + servletContextClass = Class.forName("javax.servlet.ServletContext"); + } + catch (Throwable e) { + return null; + } + + try { + Method getter = ReflectionUtils.findMethod(this.applicationContext.getClass(), "getServletContext"); + Object servletContext = getter.invoke(this.applicationContext); + Method attrMethod = ReflectionUtils.findMethod(servletContextClass, "getAttribute", String.class); + return (ServerContainer) attrMethod.invoke(servletContext, "javax.websocket.server.ServerContainer"); + } + catch (Exception ex) { + throw new IllegalStateException( + "Failed to get javax.websocket.server.ServerContainer via ServletContext attribute", ex); } - return null; } @Override