diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/SockJsSessionFactory.java b/spring-websocket/src/main/java/org/springframework/sockjs/SockJsSessionFactory.java index f78c2ccf4b..b67dcefb2c 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/SockJsSessionFactory.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/SockJsSessionFactory.java @@ -24,6 +24,6 @@ package org.springframework.sockjs; */ public interface SockJsSessionFactory{ - S createSession(String sessionId); + S createSession(String sessionId, SockJsHandler sockJsHandler); } diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/SockJsSessionSupport.java b/spring-websocket/src/main/java/org/springframework/sockjs/SockJsSessionSupport.java index 6e20a0b131..8c9747dc0a 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/SockJsSessionSupport.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/SockJsSessionSupport.java @@ -49,7 +49,7 @@ public abstract class SockJsSessionSupport implements SockJsSession { */ public SockJsSessionSupport(String sessionId, SockJsHandler sockJsHandler) { Assert.notNull(sessionId, "sessionId is required"); - Assert.notNull(sockJsHandler, "SockJsHandler is required"); + Assert.notNull(sockJsHandler, "sockJsHandler is required"); this.sessionId = sessionId; this.sockJsHandler = sockJsHandler; } @@ -58,10 +58,6 @@ public abstract class SockJsSessionSupport implements SockJsSession { return this.sessionId; } - public SockJsHandler getSockJsHandler() { - return this.sockJsHandler; - } - public boolean isNew() { return State.NEW.equals(this.state); } diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractServerSession.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractServerSession.java index feb7599b80..caa436512b 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractServerSession.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractServerSession.java @@ -42,16 +42,11 @@ public abstract class AbstractServerSession extends SockJsSessionSupport { private ScheduledFuture heartbeatTask; - public AbstractServerSession(String sessionId, SockJsConfiguration sockJsConfig) { - super(sessionId, getSockJsHandler(sockJsConfig)); + public AbstractServerSession(String sessionId, SockJsConfiguration sockJsConfig, SockJsHandler sockJsHandler) { + super(sessionId, sockJsHandler); this.sockJsConfig = sockJsConfig; } - private static SockJsHandler getSockJsHandler(SockJsConfiguration sockJsConfig) { - Assert.notNull(sockJsConfig, "sockJsConfig is required"); - return sockJsConfig.getSockJsHandler(); - } - protected SockJsConfiguration getSockJsConfig() { return this.sockJsConfig; } diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractSockJsService.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractSockJsService.java index 2a81ca6f8c..e6823c0d4c 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractSockJsService.java @@ -32,11 +32,13 @@ import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.scheduling.TaskScheduler; import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; +import org.springframework.sockjs.SockJsHandler; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.DigestUtils; import org.springframework.util.ObjectUtils; import org.springframework.util.StringUtils; +import org.springframework.web.util.UriUtils; /** @@ -45,7 +47,7 @@ import org.springframework.util.StringUtils; * @author Rossen Stoyanchev * @since 4.0 */ -public abstract class AbstractSockJsService implements SockJsConfiguration { +public abstract class AbstractSockJsService implements SockJsService, SockJsConfiguration { protected final Log logger = LogFactory.getLog(getClass()); @@ -169,10 +171,20 @@ public abstract class AbstractSockJsService implements SockJsConfiguration { this.heartbeatScheduler = heartbeatScheduler; } + /** + * The amount of time in milliseconds before a client is considered + * disconnected after not having a receiving connection, i.e. an active + * connection over which the server can send data to the client. + *

+ * The default value is 5000. + */ public void setDisconnectDelay(long disconnectDelay) { this.disconnectDelay = disconnectDelay; } + /** + * Return the amount of time in milliseconds before a client is considered disconnected. + */ public long getDisconnectDelay() { return this.disconnectDelay; } @@ -191,7 +203,7 @@ public abstract class AbstractSockJsService implements SockJsConfiguration { * Whether WebSocket transport is enabled. * @see #setWebSocketsEnabled(boolean) */ - public boolean isWebSocketsEnabled() { + public boolean isWebSocketEnabled() { return this.webSocketsEnabled; } @@ -205,8 +217,8 @@ public abstract class AbstractSockJsService implements SockJsConfiguration { * * @throws Exception */ - public final void handleRequest(ServerHttpRequest request, ServerHttpResponse response, String sockJsPath) - throws Exception { + public final void handleRequest(ServerHttpRequest request, ServerHttpResponse response, + String sockJsPath, SockJsHandler sockJsHandler) throws Exception { logger.debug(request.getMethod() + " [" + sockJsPath + "]"); @@ -217,6 +229,10 @@ public abstract class AbstractSockJsService implements SockJsConfiguration { // Ignore invalid Content-Type (TODO) } + String path = UriUtils.decode(request.getURI().getPath(), "URF-8"); + int index = path.indexOf(this.prefix); + sockJsPath = path.substring(index + this.prefix.length()); + try { if (sockJsPath.equals("") || sockJsPath.equals("/")) { response.getHeaders().setContentType(new MediaType("text", "plain", Charset.forName("UTF-8"))); @@ -232,7 +248,7 @@ public abstract class AbstractSockJsService implements SockJsConfiguration { return; } else if (sockJsPath.equals("/websocket")) { - handleRawWebSocket(request, response); + handleRawWebSocketRequest(request, response, sockJsHandler); return; } @@ -252,18 +268,19 @@ public abstract class AbstractSockJsService implements SockJsConfiguration { return; } - handleTransportRequest(request, response, sessionId, TransportType.fromValue(transport)); + handleTransportRequest(request, response, sessionId, TransportType.fromValue(transport), sockJsHandler); } finally { response.flush(); } } - protected abstract void handleRawWebSocket(ServerHttpRequest request, ServerHttpResponse response) - throws Exception; + protected abstract void handleRawWebSocketRequest(ServerHttpRequest request, ServerHttpResponse response, + SockJsHandler sockJsHandler) throws Exception; protected abstract void handleTransportRequest(ServerHttpRequest request, ServerHttpResponse response, - String sessionId, TransportType transportType) throws Exception; + String sessionId, TransportType transportType, SockJsHandler sockJsHandler) throws Exception; + protected boolean validateRequest(String serverId, String sessionId, String transport) { @@ -278,7 +295,7 @@ public abstract class AbstractSockJsService implements SockJsConfiguration { return false; } - if (!isWebSocketsEnabled() && transport.equals(TransportType.WEBSOCKET.value())) { + if (!isWebSocketEnabled() && transport.equals(TransportType.WEBSOCKET.value())) { logger.debug("Websocket transport is disabled"); return false; } @@ -344,7 +361,7 @@ public abstract class AbstractSockJsService implements SockJsConfiguration { addCorsHeaders(request, response); addNoCacheHeaders(response); - String content = String.format(INFO_CONTENT, random.nextInt(), isJsessionIdCookieRequired(), isWebSocketsEnabled()); + String content = String.format(INFO_CONTENT, random.nextInt(), isJsessionIdCookieRequired(), isWebSocketEnabled()); response.getBody().write(content.getBytes()); } else if (HttpMethod.OPTIONS.equals(request.getMethod())) { diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/TransportHandlerRegistrar.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/ConfigurableTransportHandler.java similarity index 55% rename from spring-websocket/src/main/java/org/springframework/sockjs/server/TransportHandlerRegistrar.java rename to spring-websocket/src/main/java/org/springframework/sockjs/server/ConfigurableTransportHandler.java index f89c2ea81b..8ad26ef6fa 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/TransportHandlerRegistrar.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/ConfigurableTransportHandler.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,28 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.sockjs.server; +import java.util.Collection; + +import org.springframework.sockjs.SockJsHandler; +import org.springframework.websocket.WebSocketHandler; + /** * * @author Rossen Stoyanchev * @since 4.0 */ -public interface TransportHandlerRegistrar { +public interface ConfigurableTransportHandler extends TransportHandler { - void registerTransportHandlers(TransportHandlerRegistry registry, SockJsConfiguration config); + void setSockJsConfiguration(SockJsConfiguration sockJsConfig); + + /** + * Pre-register {@link SockJsHandler} instances so they can be adapted to + * {@link WebSocketHandler} and hence re-used at runtime. + */ + void registerSockJsHandlers(Collection sockJsHandlers); } diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/SockJsConfiguration.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/SockJsConfiguration.java index a122a769e3..4bf1db79e8 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/SockJsConfiguration.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/SockJsConfiguration.java @@ -17,18 +17,14 @@ package org.springframework.sockjs.server; import org.springframework.scheduling.TaskScheduler; import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; -import org.springframework.sockjs.SockJsHandler; - /** - * * * @author Rossen Stoyanchev * @since 4.0 */ public interface SockJsConfiguration { - /** * Streaming transports save responses on the client side and don't free * memory used by delivered messages. Such transports need to recycle the @@ -42,15 +38,6 @@ public interface SockJsConfiguration { */ public int getStreamBytesLimit(); - /** - * The amount of time in milliseconds before a client is considered - * disconnected after not having a receiving connection, i.e. an active - * connection over which the server can send data to the client. - *

- * The default value is 5000. - */ - public long getDisconnectDelay(); - /** * The amount of time in milliseconds when the server has not sent any * messages and after which the server should send a heartbeat frame to the @@ -67,11 +54,4 @@ public interface SockJsConfiguration { */ public TaskScheduler getHeartbeatScheduler(); - /** - * Provides access to the {@link SockJsHandler} that will handle the request. This - * method should be called once per SockJS session. It may return the same or a - * different instance every time it is called. - */ - SockJsHandler getSockJsHandler(); - } diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/handshake/RequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/SockJsService.java similarity index 53% rename from spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/handshake/RequestUpgradeStrategy.java rename to spring-websocket/src/main/java/org/springframework/sockjs/server/SockJsService.java index 2dd09023b7..dfb1e002d4 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/handshake/RequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/SockJsService.java @@ -14,29 +14,34 @@ * limitations under the License. */ -package org.springframework.websocket.server.endpoint.handshake; +package org.springframework.sockjs.server; -import javax.websocket.Endpoint; +import java.util.Collection; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; +import org.springframework.sockjs.SockJsHandler; +import org.springframework.websocket.WebSocketHandler; /** - * A strategy for performing the container-specific steps for upgrading an HTTP request - * as part of a WebSocket handshake. * * @author Rossen Stoyanchev * @since 4.0 */ -public interface RequestUpgradeStrategy { +public interface SockJsService { - String[] getSupportedVersions(); + String getPrefix(); /** - * Invoked after the handshake checks have been performed and succeeded. + * Pre-register {@link SockJsHandler} instances so they can be adapted to + * {@link WebSocketHandler} and hence re-used at runtime when + * {@link #handleRequest(ServerHttpRequest, ServerHttpResponse, String, SockJsHandler) handleRequest} + * is called. */ - void upgrade(ServerHttpRequest request, ServerHttpResponse response, String protocol, Endpoint endpoint) - throws Exception; + void registerSockJsHandlers(Collection sockJsHandlers); + + void handleRequest(ServerHttpRequest request, ServerHttpResponse response, String sockJsPath, + SockJsHandler handler) throws Exception; } diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/TransportHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/TransportHandler.java index 8f4087227c..ea6208b836 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/TransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/TransportHandler.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the toriginal author or authors. + * Copyright 2002-2013 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package org.springframework.sockjs.server; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; +import org.springframework.sockjs.SockJsHandler; import org.springframework.sockjs.SockJsSessionSupport; @@ -29,7 +30,7 @@ public interface TransportHandler { TransportType getTransportType(); - void handleRequest(ServerHttpRequest request, ServerHttpResponse response, SockJsSessionSupport session) - throws Exception; + void handleRequest(ServerHttpRequest request, ServerHttpResponse response, + SockJsHandler sockJsHandler, SockJsSessionSupport session) throws Exception; } diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/TransportHandlerRegistry.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/TransportHandlerRegistry.java deleted file mode 100644 index 0c4d232e9e..0000000000 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/TransportHandlerRegistry.java +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright 2002-2013 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.sockjs.server; - - -/** - * - * @author Rossen Stoyanchev - * @since 4.0 - */ -public interface TransportHandlerRegistry { - - void registerHandler(TransportHandler handler); - -} diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/support/DefaultSockJsService.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/support/DefaultSockJsService.java index 4cd8af9f1b..af0789aa02 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/support/DefaultSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/support/DefaultSockJsService.java @@ -16,16 +16,13 @@ package org.springframework.sockjs.server.support; import java.util.Arrays; +import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import org.springframework.beans.BeansException; -import org.springframework.beans.factory.BeanFactory; -import org.springframework.beans.factory.BeanFactoryAware; import org.springframework.beans.factory.InitializingBean; -import org.springframework.beans.factory.config.AutowireCapableBeanFactory; import org.springframework.http.Cookie; import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; @@ -37,11 +34,21 @@ import org.springframework.sockjs.SockJsHandler; import org.springframework.sockjs.SockJsSessionFactory; import org.springframework.sockjs.SockJsSessionSupport; import org.springframework.sockjs.server.AbstractSockJsService; +import org.springframework.sockjs.server.ConfigurableTransportHandler; import org.springframework.sockjs.server.TransportHandler; -import org.springframework.sockjs.server.TransportHandlerRegistrar; -import org.springframework.sockjs.server.TransportHandlerRegistry; import org.springframework.sockjs.server.TransportType; +import org.springframework.sockjs.server.transport.EventSourceTransportHandler; +import org.springframework.sockjs.server.transport.HtmlFileTransportHandler; +import org.springframework.sockjs.server.transport.JsonpPollingTransportHandler; +import org.springframework.sockjs.server.transport.JsonpTransportHandler; +import org.springframework.sockjs.server.transport.WebSocketSockJsHandlerAdapter; +import org.springframework.sockjs.server.transport.WebSocketTransportHandler; +import org.springframework.sockjs.server.transport.XhrPollingTransportHandler; +import org.springframework.sockjs.server.transport.XhrStreamingTransportHandler; +import org.springframework.sockjs.server.transport.XhrTransportHandler; import org.springframework.util.Assert; +import org.springframework.websocket.WebSocketHandler; +import org.springframework.websocket.server.DefaultHandshakeHandler; import org.springframework.websocket.server.HandshakeHandler; @@ -51,37 +58,22 @@ import org.springframework.websocket.server.HandshakeHandler; * @author Rossen Stoyanchev * @since 4.0 */ -public class DefaultSockJsService extends AbstractSockJsService - implements TransportHandlerRegistry, BeanFactoryAware, InitializingBean { +public class DefaultSockJsService extends AbstractSockJsService implements InitializingBean { - private final Class sockJsHandlerClass; + private final Map transportHandlers = new HashMap(); - private final SockJsHandler sockJsHandler; + private final Map transportHandlerOverrides = new HashMap(); private TaskScheduler sessionTimeoutScheduler; private final Map sessions = new ConcurrentHashMap(); - private final Map transportHandlers = new HashMap(); - - private AutowireCapableBeanFactory beanFactory; + private final Map sockJsHandlers = new HashMap(); - public DefaultSockJsService(String prefix, Class sockJsHandlerClass) { - this(prefix, sockJsHandlerClass, null); - } - - public DefaultSockJsService(String prefix, SockJsHandler sockJsHandler) { - this(prefix, null, sockJsHandler); - } - - private DefaultSockJsService(String prefix, Class handlerClass, SockJsHandler handler) { + public DefaultSockJsService(String prefix) { super(prefix); - Assert.isTrue(((handlerClass != null) || (handler != null)), "A sockJsHandler class or instance is required"); - this.sockJsHandlerClass = handlerClass; - this.sockJsHandler = handler; this.sessionTimeoutScheduler = createScheduler("SockJs-sessionTimeout-"); - new DefaultTransportHandlerRegistrar().registerTransportHandlers(this, this); } /** @@ -98,43 +90,62 @@ public class DefaultSockJsService extends AbstractSockJsService this.sessionTimeoutScheduler = sessionTimeoutScheduler; } - @Override - public void registerHandler(TransportHandler transportHandler) { - Assert.notNull(transportHandler, "transportHandler is required"); - this.transportHandlers.put(transportHandler.getTransportType(), transportHandler); - } - - public void setTransportHandlerRegistrar(TransportHandlerRegistrar registrar) { - Assert.notNull(registrar, "registrar is required"); + public void setTransportHandlers(TransportHandler... handlers) { this.transportHandlers.clear(); - registrar.registerTransportHandlers(this, this); - } - - @Override - public void setBeanFactory(BeanFactory beanFactory) throws BeansException { - if (beanFactory instanceof AutowireCapableBeanFactory) { - this.beanFactory = (AutowireCapableBeanFactory) beanFactory; + for (TransportHandler handler : handlers) { + this.transportHandlers.put(handler.getTransportType(), handler); } } - @Override - public SockJsHandler getSockJsHandler() { - return (this.sockJsHandlerClass != null) ? - this.beanFactory.createBean(this.sockJsHandlerClass) : this.sockJsHandler; + public void setTransportHandlerOverrides(TransportHandler... handlers) { + this.transportHandlerOverrides.clear(); + for (TransportHandler handler : handlers) { + this.transportHandlerOverrides.put(handler.getTransportType(), handler); + } + } + + public void registerSockJsHandlers(Collection sockJsHandlers) { + for (SockJsHandler sockJsHandler : sockJsHandlers) { + if (!this.sockJsHandlers.containsKey(sockJsHandler)) { + this.sockJsHandlers.put(sockJsHandler, adaptSockJsHandler(sockJsHandler)); + } + } + configureTransportHandlers(); + } + + /** + * Adapt the {@link SockJsHandler} to the {@link WebSocketHandler} contract for + * raw WebSocket communication on SockJS path "/websocket". + */ + protected WebSocketSockJsHandlerAdapter adaptSockJsHandler(SockJsHandler sockJsHandler) { + return new WebSocketSockJsHandlerAdapter(this, sockJsHandler); } @Override public void afterPropertiesSet() throws Exception { - if (this.sockJsHandler != null) { - Assert.notNull(this.beanFactory, - "An AutowirecapableBeanFactory is required to initialize SockJS handler instances per request."); + if (this.transportHandlers.isEmpty()) { + if (isWebSocketEnabled() && (this.transportHandlerOverrides.get(TransportType.WEBSOCKET) == null)) { + this.transportHandlers.put(TransportType.WEBSOCKET, + new WebSocketTransportHandler(new DefaultHandshakeHandler())); + } + this.transportHandlers.put(TransportType.XHR, new XhrPollingTransportHandler()); + this.transportHandlers.put(TransportType.XHR_SEND, new XhrTransportHandler()); + this.transportHandlers.put(TransportType.JSONP, new JsonpPollingTransportHandler()); + this.transportHandlers.put(TransportType.JSONP_SEND, new JsonpTransportHandler()); + this.transportHandlers.put(TransportType.XHR_STREAMING, new XhrStreamingTransportHandler()); + this.transportHandlers.put(TransportType.EVENT_SOURCE, new EventSourceTransportHandler()); + this.transportHandlers.put(TransportType.HTML_FILE, new HtmlFileTransportHandler()); } - if (this.transportHandlers.get(TransportType.WEBSOCKET) == null) { - logger.warn("No WebSocket transport handler was registered"); + if (!this.transportHandlerOverrides.isEmpty()) { + for (TransportHandler transportHandler : this.transportHandlerOverrides.values()) { + this.transportHandlers.put(transportHandler.getTransportType(), transportHandler); + } } + configureTransportHandlers(); + this.sessionTimeoutScheduler.scheduleAtFixedRate(new Runnable() { public void run() { try { @@ -162,22 +173,45 @@ public class DefaultSockJsService extends AbstractSockJsService }, getDisconnectDelay()); } - @Override - protected void handleRawWebSocket(ServerHttpRequest request, ServerHttpResponse response) throws Exception { - TransportHandler transportHandler = this.transportHandlers.get(TransportType.WEBSOCKET); - if ((transportHandler != null) && transportHandler instanceof HandshakeHandler) { - HandshakeHandler handshakeHandler = (HandshakeHandler) transportHandler; - handshakeHandler.doHandshake(request, response); - } - else { - logger.debug("No handler found for raw WebSocket messages"); - response.setStatusCode(HttpStatus.NOT_FOUND); + + private void configureTransportHandlers() { + for (TransportHandler h : this.transportHandlers.values()) { + if (h instanceof ConfigurableTransportHandler) { + ((ConfigurableTransportHandler) h).setSockJsConfiguration(this); + if (!this.sockJsHandlers.isEmpty()) { + ((ConfigurableTransportHandler) h).registerSockJsHandlers(this.sockJsHandlers.keySet()); + if (h instanceof HandshakeHandler) { + ((HandshakeHandler) h).registerWebSocketHandlers(this.sockJsHandlers.values()); + } + } + } } } + @Override + protected void handleRawWebSocketRequest(ServerHttpRequest request, ServerHttpResponse response, + SockJsHandler sockJsHandler) throws Exception { + + if (isWebSocketEnabled()) { + TransportHandler transportHandler = this.transportHandlers.get(TransportType.WEBSOCKET); + if (transportHandler != null) { + if (transportHandler instanceof HandshakeHandler) { + WebSocketHandler webSocketHandler = this.sockJsHandlers.get(sockJsHandler); + if (webSocketHandler == null) { + webSocketHandler = adaptSockJsHandler(sockJsHandler); + } + ((HandshakeHandler) transportHandler).doHandshake(request, response, webSocketHandler); + return; + } + } + logger.warn("No handler for raw WebSocket messages"); + } + response.setStatusCode(HttpStatus.NOT_FOUND); + } + @Override protected void handleTransportRequest(ServerHttpRequest request, ServerHttpResponse response, - String sessionId, TransportType transportType) throws Exception { + String sessionId, TransportType transportType, SockJsHandler sockJsHandler) throws Exception { TransportHandler transportHandler = this.transportHandlers.get(transportType); @@ -204,7 +238,7 @@ public class DefaultSockJsService extends AbstractSockJsService return; } - SockJsSessionSupport session = getSockJsSession(sessionId, transportHandler); + SockJsSessionSupport session = getSockJsSession(sessionId, sockJsHandler, transportHandler); if (session != null) { if (transportType.setsNoCacheHeader()) { @@ -215,7 +249,7 @@ public class DefaultSockJsService extends AbstractSockJsService Cookie cookie = request.getCookies().getCookie("JSESSIONID"); String jsid = (cookie != null) ? cookie.getValue() : "dummy"; // TODO: bypass use of Cookie object (causes Jetty to set Expires header) - response.getHeaders().set("Set-Cookie", "JSESSIONID=" + jsid + ";path=/"); // TODO + response.getHeaders().set("Set-Cookie", "JSESSIONID=" + jsid + ";path=/"); } if (transportType.supportsCors()) { @@ -223,10 +257,11 @@ public class DefaultSockJsService extends AbstractSockJsService } } - transportHandler.handleRequest(request, response, session); + transportHandler.handleRequest(request, response, sockJsHandler, session); } - public SockJsSessionSupport getSockJsSession(String sessionId, TransportHandler transportHandler) { + public SockJsSessionSupport getSockJsSession(String sessionId, SockJsHandler sockJsHandler, + TransportHandler transportHandler) { SockJsSessionSupport session = this.sessions.get(sessionId); if (session != null) { @@ -242,7 +277,7 @@ public class DefaultSockJsService extends AbstractSockJsService return session; } logger.debug("Creating new session with session id \"" + sessionId + "\""); - session = (SockJsSessionSupport) sessionFactory.createSession(sessionId); + session = (SockJsSessionSupport) sessionFactory.createSession(sessionId, sockJsHandler); this.sessions.put(sessionId, session); return session; } diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/support/DefaultTransportHandlerRegistrar.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/support/DefaultTransportHandlerRegistrar.java deleted file mode 100644 index fbbeab48a4..0000000000 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/support/DefaultTransportHandlerRegistrar.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright 2002-2013 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.sockjs.server.support; - -import java.lang.reflect.Constructor; - -import org.springframework.beans.BeanUtils; -import org.springframework.sockjs.server.SockJsConfiguration; -import org.springframework.sockjs.server.TransportHandler; -import org.springframework.sockjs.server.TransportHandlerRegistrar; -import org.springframework.sockjs.server.TransportHandlerRegistry; -import org.springframework.sockjs.server.transport.EventSourceTransportHandler; -import org.springframework.sockjs.server.transport.HtmlFileTransportHandler; -import org.springframework.sockjs.server.transport.JsonpPollingTransportHandler; -import org.springframework.sockjs.server.transport.JsonpTransportHandler; -import org.springframework.sockjs.server.transport.XhrPollingTransportHandler; -import org.springframework.sockjs.server.transport.XhrStreamingTransportHandler; -import org.springframework.sockjs.server.transport.XhrTransportHandler; -import org.springframework.util.ClassUtils; - - -/** - * TODO - * - * @author Rossen Stoyanchev - * @since 4.0 - */ -public class DefaultTransportHandlerRegistrar implements TransportHandlerRegistrar { - - private static final boolean standardWebSocketApiPresent = ClassUtils.isPresent( - "javax.websocket.server.ServerEndpointConfig", DefaultTransportHandlerRegistrar.class.getClassLoader()); - - - public void registerTransportHandlers(TransportHandlerRegistry registry, SockJsConfiguration config) { - - if (standardWebSocketApiPresent) { - registry.registerHandler(createEndpointWebSocketTransportHandler(config)); - } - - registry.registerHandler(new XhrPollingTransportHandler(config)); - registry.registerHandler(new XhrTransportHandler()); - - registry.registerHandler(new JsonpPollingTransportHandler(config)); - registry.registerHandler(new JsonpTransportHandler()); - - registry.registerHandler(new XhrStreamingTransportHandler(config)); - registry.registerHandler(new EventSourceTransportHandler(config)); - registry.registerHandler(new HtmlFileTransportHandler(config)); - - } - - private TransportHandler createEndpointWebSocketTransportHandler(SockJsConfiguration config) { - try { - String className = "org.springframework.sockjs.server.transport.EndpointWebSocketTransportHandler"; - Class clazz = ClassUtils.forName(className, DefaultTransportHandlerRegistrar.class.getClassLoader()); - Constructor constructor = clazz.getConstructor(SockJsConfiguration.class); - return (TransportHandler) BeanUtils.instantiateClass(constructor, config); - } - catch (Throwable t) { - throw new IllegalStateException("Failed to instantiate EndpointWebSocketTransportHandler", t); - } - } - -} diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/support/SockJsHttpRequestHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/support/SockJsHttpRequestHandler.java new file mode 100644 index 0000000000..977f9eb8b9 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/support/SockJsHttpRequestHandler.java @@ -0,0 +1,105 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.sockjs.server.support; + +import java.io.IOException; +import java.util.Collections; + +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.BeanFactoryAware; +import org.springframework.http.server.AsyncServletServerHttpRequest; +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServerHttpResponse; +import org.springframework.http.server.ServletServerHttpResponse; +import org.springframework.sockjs.SockJsHandler; +import org.springframework.sockjs.server.SockJsService; +import org.springframework.util.Assert; +import org.springframework.web.HttpRequestHandler; +import org.springframework.web.util.NestedServletException; +import org.springframework.web.util.UrlPathHelper; +import org.springframework.websocket.HandlerProvider; + + +/** + * + * @author Rossen Stoyanchev + * @since 4.0 + */ +public class SockJsHttpRequestHandler implements HttpRequestHandler, BeanFactoryAware { + + private final SockJsService sockJsService; + + private final HandlerProvider handlerProvider; + + private final UrlPathHelper urlPathHelper = new UrlPathHelper(); + + + public SockJsHttpRequestHandler(SockJsService sockJsService, SockJsHandler sockJsHandler) { + Assert.notNull(sockJsService, "sockJsService is required"); + Assert.notNull(sockJsHandler, "sockJsHandler is required"); + this.sockJsService = sockJsService; + this.sockJsService.registerSockJsHandlers(Collections.singleton(sockJsHandler)); + this.handlerProvider = new HandlerProvider(sockJsHandler); + } + + public SockJsHttpRequestHandler(SockJsService sockJsService, Class sockJsHandlerClass) { + Assert.notNull(sockJsService, "sockJsService is required"); + Assert.notNull(sockJsHandlerClass, "sockJsHandlerClass is required"); + this.sockJsService = sockJsService; + this.handlerProvider = new HandlerProvider(sockJsHandlerClass); + } + + public String getMappingPattern() { + return this.sockJsService.getPrefix() + "/**"; + } + + @Override + public void setBeanFactory(BeanFactory beanFactory) throws BeansException { + this.handlerProvider.setBeanFactory(beanFactory); + } + + @Override + public void handleRequest(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + String lookupPath = this.urlPathHelper.getLookupPathForRequest(request); + String prefix = this.sockJsService.getPrefix(); + + Assert.isTrue(lookupPath.startsWith(prefix), + "Request path does not match the prefix of the SockJsService " + prefix); + + String sockJsPath = lookupPath.substring(prefix.length()); + + ServerHttpRequest httpRequest = new AsyncServletServerHttpRequest(request, response); + ServerHttpResponse httpResponse = new ServletServerHttpResponse(response); + + try { + SockJsHandler sockJsHandler = this.handlerProvider.getHandler(); + this.sockJsService.handleRequest(httpRequest, httpResponse, sockJsPath, sockJsHandler); + } + catch (Exception ex) { + // TODO + throw new NestedServletException("SockJS service failure", ex); + } + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/support/SockJsServiceHandlerMapping.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/support/SockJsServiceHandlerMapping.java deleted file mode 100644 index 656324fe5b..0000000000 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/support/SockJsServiceHandlerMapping.java +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Copyright 2002-2013 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.sockjs.server.support; - -import java.io.IOException; -import java.util.Arrays; -import java.util.List; - -import javax.servlet.ServletException; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; -import org.springframework.http.server.AsyncServletServerHttpRequest; -import org.springframework.http.server.ServerHttpRequest; -import org.springframework.http.server.ServerHttpResponse; -import org.springframework.http.server.ServletServerHttpResponse; -import org.springframework.sockjs.server.AbstractSockJsService; -import org.springframework.web.HttpRequestHandler; -import org.springframework.web.servlet.handler.AbstractHandlerMapping; -import org.springframework.web.util.NestedServletException; - -/** - * A Spring MVC HandlerMapping for matching requests to a SockJS services based on the - * {@link AbstractSockJsService#getPrefix() prefix} property of each service. - * - * @author Rossen Stoyanchev - * @since 4.0 - */ -public class SockJsServiceHandlerMapping extends AbstractHandlerMapping { - - private static Log logger = LogFactory.getLog(SockJsServiceHandlerMapping.class); - - private final List sockJsServices; - - - public SockJsServiceHandlerMapping(AbstractSockJsService... sockJsServices) { - this.sockJsServices = Arrays.asList(sockJsServices); - } - - @Override - protected Object getHandlerInternal(HttpServletRequest request) throws Exception { - - String lookupPath = getUrlPathHelper().getLookupPathForRequest(request); - if (logger.isDebugEnabled()) { - logger.debug("Looking for SockJS service match to path " + lookupPath); - } - - for (AbstractSockJsService service : this.sockJsServices) { - if (lookupPath.startsWith(service.getPrefix())) { - if (logger.isDebugEnabled()) { - logger.debug("Matched to " + service); - } - String sockJsPath = lookupPath.substring(service.getPrefix().length()); - return new SockJsServiceHttpRequestHandler(service, sockJsPath); - } - } - - if (logger.isDebugEnabled()) { - logger.debug("Did not find a match"); - } - - return null; - } - - - /** - * {@link HttpRequestHandler} wrapping the invocation of the selected SockJS service. - */ - private static class SockJsServiceHttpRequestHandler implements HttpRequestHandler { - - private final String sockJsPath; - - private final AbstractSockJsService sockJsService; - - - public SockJsServiceHttpRequestHandler(AbstractSockJsService sockJsService, String sockJsPath) { - this.sockJsService = sockJsService; - this.sockJsPath = sockJsPath; - } - - @Override - public void handleRequest(HttpServletRequest request, HttpServletResponse response) - throws ServletException, IOException { - - ServerHttpRequest httpRequest = new AsyncServletServerHttpRequest(request, response); - ServerHttpResponse httpResponse = new ServletServerHttpResponse(response); - - try { - this.sockJsService.handleRequest(httpRequest, httpResponse, this.sockJsPath); - } - catch (Exception ex) { - // TODO - throw new NestedServletException("SockJS service failure", ex); - } - } - } - -} diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpReceivingTransportHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpReceivingTransportHandler.java index 589e9770bd..300adbd981 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpReceivingTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpReceivingTransportHandler.java @@ -25,6 +25,7 @@ import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; +import org.springframework.sockjs.SockJsHandler; import org.springframework.sockjs.SockJsSessionSupport; import org.springframework.sockjs.server.TransportHandler; @@ -51,8 +52,8 @@ public abstract class AbstractHttpReceivingTransportHandler implements Transport } @Override - public final void handleRequest(ServerHttpRequest request, ServerHttpResponse response, SockJsSessionSupport session) - throws Exception { + public final void handleRequest(ServerHttpRequest request, ServerHttpResponse response, + SockJsHandler sockJsHandler, SockJsSessionSupport session) throws Exception { if (session == null) { response.setStatusCode(HttpStatus.NOT_FOUND); diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpSendingTransportHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpSendingTransportHandler.java index f8dbd59366..aff6cea280 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpSendingTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpSendingTransportHandler.java @@ -16,18 +16,20 @@ package org.springframework.sockjs.server.transport; import java.io.IOException; +import java.util.Collection; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.http.MediaType; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; +import org.springframework.sockjs.SockJsHandler; import org.springframework.sockjs.SockJsSessionFactory; import org.springframework.sockjs.SockJsSessionSupport; +import org.springframework.sockjs.server.ConfigurableTransportHandler; import org.springframework.sockjs.server.SockJsConfiguration; import org.springframework.sockjs.server.SockJsFrame; import org.springframework.sockjs.server.SockJsFrame.FrameFormat; -import org.springframework.sockjs.server.TransportHandler; /** * TODO @@ -36,24 +38,30 @@ import org.springframework.sockjs.server.TransportHandler; * @since 4.0 */ public abstract class AbstractHttpSendingTransportHandler - implements TransportHandler, SockJsSessionFactory { + implements ConfigurableTransportHandler, SockJsSessionFactory { protected final Log logger = LogFactory.getLog(this.getClass()); - private final SockJsConfiguration sockJsConfig; + private SockJsConfiguration sockJsConfig; - public AbstractHttpSendingTransportHandler(SockJsConfiguration sockJsConfig) { + @Override + public void setSockJsConfiguration(SockJsConfiguration sockJsConfig) { this.sockJsConfig = sockJsConfig; } - protected SockJsConfiguration getSockJsConfig() { + @Override + public void registerSockJsHandlers(Collection sockJsHandlers) { + // ignore + } + + public SockJsConfiguration getSockJsConfig() { return this.sockJsConfig; } @Override public final void handleRequest(ServerHttpRequest request, ServerHttpResponse response, - SockJsSessionSupport session) throws Exception { + SockJsHandler sockJsHandler, SockJsSessionSupport session) throws Exception { // Set content type before writing response.getHeaders().setContentType(getContentType()); diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpServerSession.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpServerSession.java index 5f657d9470..e03fb4a4c9 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpServerSession.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpServerSession.java @@ -22,6 +22,7 @@ import java.util.concurrent.BlockingQueue; import org.springframework.http.server.AsyncServerHttpRequest; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; +import org.springframework.sockjs.SockJsHandler; import org.springframework.sockjs.server.AbstractServerSession; import org.springframework.sockjs.server.SockJsConfiguration; import org.springframework.sockjs.server.SockJsFrame; @@ -46,8 +47,8 @@ public abstract class AbstractHttpServerSession extends AbstractServerSession { private ServerHttpResponse response; - public AbstractHttpServerSession(String sessionId, SockJsConfiguration sockJsConfig) { - super(sessionId, sockJsConfig); + public AbstractHttpServerSession(String sessionId, SockJsConfiguration sockJsConfig, SockJsHandler sockJsHandler) { + super(sessionId, sockJsConfig, sockJsHandler); } public void setFrameFormat(FrameFormat frameFormat) { diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractSockJsWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractSockJsWebSocketHandler.java index 9f13cd10b1..2f723092fe 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractSockJsWebSocketHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractSockJsWebSocketHandler.java @@ -22,8 +22,10 @@ import java.util.concurrent.ConcurrentHashMap; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.sockjs.SockJsHandler; import org.springframework.sockjs.SockJsSessionSupport; import org.springframework.sockjs.server.SockJsConfiguration; +import org.springframework.util.Assert; import org.springframework.websocket.WebSocketHandler; import org.springframework.websocket.WebSocketSession; @@ -39,18 +41,27 @@ public abstract class AbstractSockJsWebSocketHandler implements WebSocketHandler private final SockJsConfiguration sockJsConfig; + private final SockJsHandler sockJsHandler; + private final Map sessions = new ConcurrentHashMap(); - public AbstractSockJsWebSocketHandler(SockJsConfiguration sockJsConfig) { + public AbstractSockJsWebSocketHandler(SockJsConfiguration sockJsConfig, SockJsHandler sockJsHandler) { + Assert.notNull(sockJsConfig, "sockJsConfig is required"); + Assert.notNull(sockJsHandler, "sockJsHandler is required"); this.sockJsConfig = sockJsConfig; + this.sockJsHandler = sockJsHandler; } protected SockJsConfiguration getSockJsConfig() { return this.sockJsConfig; } + protected SockJsHandler getSockJsHandler() { + return this.sockJsHandler; + } + protected SockJsSessionSupport getSockJsSession(WebSocketSession wsSession) { return this.sessions.get(wsSession); } @@ -62,7 +73,6 @@ public abstract class AbstractSockJsWebSocketHandler implements WebSocketHandler } SockJsSessionSupport session = createSockJsSession(wsSession); this.sessions.put(wsSession, session); - session.connectionInitialized(); } protected abstract SockJsSessionSupport createSockJsSession(WebSocketSession wsSession) throws Exception; diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractStreamingTransportHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractStreamingTransportHandler.java index 89b076d93f..5aa8eb11be 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractStreamingTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractStreamingTransportHandler.java @@ -19,7 +19,8 @@ import java.io.IOException; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; -import org.springframework.sockjs.server.SockJsConfiguration; +import org.springframework.sockjs.SockJsHandler; +import org.springframework.util.Assert; /** @@ -31,13 +32,10 @@ import org.springframework.sockjs.server.SockJsConfiguration; public abstract class AbstractStreamingTransportHandler extends AbstractHttpSendingTransportHandler { - public AbstractStreamingTransportHandler(SockJsConfiguration sockJsConfig) { - super(sockJsConfig); - } - @Override - public StreamingHttpServerSession createSession(String sessionId) { - return new StreamingHttpServerSession(sessionId, getSockJsConfig()); + public StreamingHttpServerSession createSession(String sessionId, SockJsHandler sockJsHandler) { + Assert.notNull(getSockJsConfig(), "This transport requires SockJsConfiguration"); + return new StreamingHttpServerSession(sessionId, getSockJsConfig(), sockJsHandler); } @Override diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractWebSocketTransportHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractWebSocketTransportHandler.java deleted file mode 100644 index ba0a224bf6..0000000000 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractWebSocketTransportHandler.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Copyright 2002-2013 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.sockjs.server.transport; - -import org.springframework.http.server.ServerHttpRequest; -import org.springframework.http.server.ServerHttpResponse; -import org.springframework.sockjs.SockJsSessionSupport; -import org.springframework.sockjs.server.SockJsConfiguration; -import org.springframework.sockjs.server.TransportHandler; -import org.springframework.sockjs.server.TransportType; -import org.springframework.websocket.WebSocketHandler; -import org.springframework.websocket.server.HandshakeHandler; - - -/** - * - * @author Rossen Stoyanchev - * @since 4.0 - */ -public abstract class AbstractWebSocketTransportHandler implements TransportHandler, HandshakeHandler { - - private final HandshakeHandler sockJsHandshakeHandler; - - private final HandshakeHandler handshakeHandler; - - - public AbstractWebSocketTransportHandler(SockJsConfiguration sockJsConfig) { - this.sockJsHandshakeHandler = createHandshakeHandler(new SockJsWebSocketHandler(sockJsConfig)); - this.handshakeHandler = createHandshakeHandler(new WebSocketSockJsHandlerAdapter(sockJsConfig)); - } - - protected abstract HandshakeHandler createHandshakeHandler(WebSocketHandler webSocketHandler); - - @Override - public TransportType getTransportType() { - return TransportType.WEBSOCKET; - } - - @Override - public void handleRequest(ServerHttpRequest request, ServerHttpResponse response, - SockJsSessionSupport session) throws Exception { - - this.sockJsHandshakeHandler.doHandshake(request, response); - } - - @Override - public boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response) throws Exception { - return this.handshakeHandler.doHandshake(request, response); - } - -} diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/EndpointWebSocketTransportHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/EndpointWebSocketTransportHandler.java deleted file mode 100644 index 8d41943c9a..0000000000 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/EndpointWebSocketTransportHandler.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright 2002-2013 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.sockjs.server.transport; - -import org.springframework.sockjs.server.SockJsConfiguration; -import org.springframework.websocket.WebSocketHandler; -import org.springframework.websocket.server.HandshakeHandler; -import org.springframework.websocket.server.endpoint.handshake.EndpointHandshakeHandler; - - -/** - * - * @author Rossen Stoyanchev - * @since 4.0 - */ -public class EndpointWebSocketTransportHandler extends AbstractWebSocketTransportHandler { - - - public EndpointWebSocketTransportHandler(SockJsConfiguration sockJsConfig) { - super(sockJsConfig); - } - - @Override - protected HandshakeHandler createHandshakeHandler(WebSocketHandler webSocketHandler) { - return new EndpointHandshakeHandler(webSocketHandler); - } - -} diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/EventSourceTransportHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/EventSourceTransportHandler.java index 235b6f5de1..ae8efbc667 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/EventSourceTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/EventSourceTransportHandler.java @@ -21,10 +21,9 @@ import java.nio.charset.Charset; import org.springframework.http.MediaType; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; -import org.springframework.sockjs.server.SockJsConfiguration; -import org.springframework.sockjs.server.TransportType; import org.springframework.sockjs.server.SockJsFrame.DefaultFrameFormat; import org.springframework.sockjs.server.SockJsFrame.FrameFormat; +import org.springframework.sockjs.server.TransportType; /** @@ -36,10 +35,6 @@ import org.springframework.sockjs.server.SockJsFrame.FrameFormat; public class EventSourceTransportHandler extends AbstractStreamingTransportHandler { - public EventSourceTransportHandler(SockJsConfiguration sockJsConfig) { - super(sockJsConfig); - } - @Override public TransportType getTransportType() { return TransportType.EVENT_SOURCE; diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/HtmlFileTransportHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/HtmlFileTransportHandler.java index 35e48b238e..d6568fa797 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/HtmlFileTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/HtmlFileTransportHandler.java @@ -22,10 +22,9 @@ import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; -import org.springframework.sockjs.server.SockJsConfiguration; -import org.springframework.sockjs.server.TransportType; import org.springframework.sockjs.server.SockJsFrame.DefaultFrameFormat; import org.springframework.sockjs.server.SockJsFrame.FrameFormat; +import org.springframework.sockjs.server.TransportType; import org.springframework.util.StringUtils; import org.springframework.web.util.JavaScriptUtils; @@ -67,10 +66,6 @@ public class HtmlFileTransportHandler extends AbstractStreamingTransportHandler } - public HtmlFileTransportHandler(SockJsConfiguration sockJsConfig) { - super(sockJsConfig); - } - @Override public TransportType getTransportType() { return TransportType.HTML_FILE; diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/JsonpPollingTransportHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/JsonpPollingTransportHandler.java index 28fe888714..a00c5de6b7 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/JsonpPollingTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/JsonpPollingTransportHandler.java @@ -21,10 +21,11 @@ import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; -import org.springframework.sockjs.server.SockJsConfiguration; +import org.springframework.sockjs.SockJsHandler; import org.springframework.sockjs.server.SockJsFrame; -import org.springframework.sockjs.server.TransportType; import org.springframework.sockjs.server.SockJsFrame.FrameFormat; +import org.springframework.sockjs.server.TransportType; +import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.web.util.JavaScriptUtils; @@ -38,10 +39,6 @@ import org.springframework.web.util.JavaScriptUtils; public class JsonpPollingTransportHandler extends AbstractHttpSendingTransportHandler { - public JsonpPollingTransportHandler(SockJsConfiguration sockJsConfig) { - super(sockJsConfig); - } - @Override public TransportType getTransportType() { return TransportType.JSONP; @@ -53,8 +50,9 @@ public class JsonpPollingTransportHandler extends AbstractHttpSendingTransportHa } @Override - public PollingHttpServerSession createSession(String sessionId) { - return new PollingHttpServerSession(sessionId, getSockJsConfig()); + public PollingHttpServerSession createSession(String sessionId, SockJsHandler sockJsHandler) { + Assert.notNull(getSockJsConfig(), "This transport requires SockJsConfiguration"); + return new PollingHttpServerSession(sessionId, getSockJsConfig(), sockJsHandler); } @Override @@ -67,7 +65,7 @@ public class JsonpPollingTransportHandler extends AbstractHttpSendingTransportHa response.getBody().write("\"callback\" parameter required".getBytes("UTF-8")); return; } - super.handleRequest(request, response, session); + super.handleRequestInternal(request, response, session); } @Override diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/JsonpTransportHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/JsonpTransportHandler.java index 029e5c4ebb..95e90dccc0 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/JsonpTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/JsonpTransportHandler.java @@ -44,7 +44,7 @@ public class JsonpTransportHandler extends AbstractHttpReceivingTransportHandler } } - super.handleRequest(request, response, sockJsSession); + super.handleRequestInternal(request, response, sockJsSession); response.getBody().write("ok".getBytes("UTF-8")); } diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/PollingHttpServerSession.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/PollingHttpServerSession.java index 34302f610f..d48fa31a2c 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/PollingHttpServerSession.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/PollingHttpServerSession.java @@ -17,14 +17,15 @@ package org.springframework.sockjs.server.transport; import java.io.IOException; +import org.springframework.sockjs.SockJsHandler; import org.springframework.sockjs.server.SockJsConfiguration; import org.springframework.sockjs.server.SockJsFrame; public class PollingHttpServerSession extends AbstractHttpServerSession { - public PollingHttpServerSession(String sessionId, SockJsConfiguration sockJsConfig) { - super(sessionId, sockJsConfig); + public PollingHttpServerSession(String sessionId, SockJsConfiguration sockJsConfig, SockJsHandler sockJsHandler) { + super(sessionId, sockJsConfig, sockJsHandler); } @Override diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/SockJsWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/SockJsWebSocketHandler.java index c19818e03d..f90fbdb13e 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/SockJsWebSocketHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/SockJsWebSocketHandler.java @@ -43,8 +43,8 @@ public class SockJsWebSocketHandler extends AbstractSockJsWebSocketHandler { private final ObjectMapper objectMapper = new ObjectMapper(); - public SockJsWebSocketHandler(SockJsConfiguration config) { - super(config); + public SockJsWebSocketHandler(SockJsConfiguration sockJsConfig, SockJsHandler sockJsHandler) { + super(sockJsConfig, sockJsHandler); } @Override @@ -78,8 +78,8 @@ public class SockJsWebSocketHandler extends AbstractSockJsWebSocketHandler { private WebSocketSession webSocketSession; - public WebSocketServerSession(WebSocketSession wsSession, SockJsConfiguration config) throws Exception { - super(String.valueOf(wsSession.hashCode()), config); + public WebSocketServerSession(WebSocketSession wsSession, SockJsConfiguration sockJsConfig) throws Exception { + super(String.valueOf(wsSession.hashCode()), sockJsConfig, getSockJsHandler()); this.webSocketSession = wsSession; this.webSocketSession.sendText(SockJsFrame.openFrame().getContent()); scheduleHeartbeat(); diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/StreamingHttpServerSession.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/StreamingHttpServerSession.java index 94cf47cd9d..71bdf2d9dd 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/StreamingHttpServerSession.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/StreamingHttpServerSession.java @@ -18,6 +18,7 @@ package org.springframework.sockjs.server.transport; import java.io.IOException; import org.springframework.http.server.ServerHttpResponse; +import org.springframework.sockjs.SockJsHandler; import org.springframework.sockjs.server.SockJsConfiguration; import org.springframework.sockjs.server.SockJsFrame; @@ -27,8 +28,8 @@ public class StreamingHttpServerSession extends AbstractHttpServerSession { private int byteCount; - public StreamingHttpServerSession(String sessionId, SockJsConfiguration sockJsConfig) { - super(sessionId, sockJsConfig); + public StreamingHttpServerSession(String sessionId, SockJsConfiguration sockJsConfig, SockJsHandler sockJsHandler) { + super(sessionId, sockJsConfig, sockJsHandler); } protected void flushCache() throws IOException { diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/WebSocketSockJsHandlerAdapter.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/WebSocketSockJsHandlerAdapter.java index 82e94f6788..9ec2a9e386 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/WebSocketSockJsHandlerAdapter.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/WebSocketSockJsHandlerAdapter.java @@ -36,12 +36,12 @@ import org.springframework.websocket.WebSocketSession; public class WebSocketSockJsHandlerAdapter extends AbstractSockJsWebSocketHandler { - public WebSocketSockJsHandlerAdapter(SockJsConfiguration sockJsConfig) { - super(sockJsConfig); + public WebSocketSockJsHandlerAdapter(SockJsConfiguration sockJsConfig, SockJsHandler sockJsHandler) { + super(sockJsConfig, sockJsHandler); } @Override - protected SockJsSessionSupport createSockJsSession(WebSocketSession wsSession) { + protected SockJsSessionSupport createSockJsSession(WebSocketSession wsSession) throws Exception { return new WebSocketSessionAdapter(wsSession); } @@ -51,9 +51,10 @@ public class WebSocketSockJsHandlerAdapter extends AbstractSockJsWebSocketHandle private final WebSocketSession wsSession; - public WebSocketSessionAdapter(WebSocketSession wsSession) { - super(String.valueOf(wsSession.hashCode()), getSockJsConfig().getSockJsHandler()); + public WebSocketSessionAdapter(WebSocketSession wsSession) throws Exception { + super(String.valueOf(wsSession.hashCode()), getSockJsHandler()); this.wsSession = wsSession; + connectionInitialized(); } @Override diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/WebSocketTransportHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/WebSocketTransportHandler.java new file mode 100644 index 0000000000..3fc0a7f9f3 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/WebSocketTransportHandler.java @@ -0,0 +1,125 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.sockjs.server.transport; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServerHttpResponse; +import org.springframework.sockjs.SockJsHandler; +import org.springframework.sockjs.SockJsSessionSupport; +import org.springframework.sockjs.server.ConfigurableTransportHandler; +import org.springframework.sockjs.server.SockJsConfiguration; +import org.springframework.sockjs.server.TransportHandler; +import org.springframework.sockjs.server.TransportType; +import org.springframework.util.Assert; +import org.springframework.websocket.WebSocketHandler; +import org.springframework.websocket.server.HandshakeHandler; + + +/** + * A WebSocket {@link TransportHandler} that delegates to a {@link HandshakeHandler} + * passing a SockJS {@link WebSocketHandler}. Also implements {@link HandshakeHandler} + * directly in support for raw WebSocket communication at SockJS URL "/websocket". + * + * @author Rossen Stoyanchev + * @since 4.0 + */ +public class WebSocketTransportHandler implements ConfigurableTransportHandler, HandshakeHandler { + + private final HandshakeHandler handshakeHandler; + + private SockJsConfiguration sockJsConfig; + + private final Map sockJsHandlers = new HashMap(); + + private final Collection rawWebSocketHandlers = new ArrayList(); + + + public WebSocketTransportHandler(HandshakeHandler handshakeHandler) { + Assert.notNull(handshakeHandler, "handshakeHandler is required"); + this.handshakeHandler = handshakeHandler; + } + + @Override + public TransportType getTransportType() { + return TransportType.WEBSOCKET; + } + + @Override + public void setSockJsConfiguration(SockJsConfiguration sockJsConfig) { + this.sockJsConfig = sockJsConfig; + } + + @Override + public void registerSockJsHandlers(Collection sockJsHandlers) { + this.sockJsHandlers.clear(); + for (SockJsHandler sockJsHandler : sockJsHandlers) { + this.sockJsHandlers.put(sockJsHandler, adaptSockJsHandler(sockJsHandler)); + } + this.handshakeHandler.registerWebSocketHandlers(getAllWebSocketHandlers()); + } + + /** + * Adapt the {@link SockJsHandler} to the {@link WebSocketHandler} contract for + * exchanging SockJS message over WebSocket. + */ + protected WebSocketHandler adaptSockJsHandler(SockJsHandler sockJsHandler) { + return new SockJsWebSocketHandler(this.sockJsConfig, sockJsHandler); + } + + private Collection getAllWebSocketHandlers() { + Set handlers = new HashSet(); + handlers.addAll(this.sockJsHandlers.values()); + handlers.addAll(this.rawWebSocketHandlers); + return handlers; + } + + @Override + public void handleRequest(ServerHttpRequest request, ServerHttpResponse response, + SockJsHandler sockJsHandler, SockJsSessionSupport session) throws Exception { + + WebSocketHandler webSocketHandler = this.sockJsHandlers.get(sockJsHandler); + if (webSocketHandler == null) { + webSocketHandler = adaptSockJsHandler(sockJsHandler); + } + + this.handshakeHandler.doHandshake(request, response, webSocketHandler); + } + + // HandshakeHandler methods + + @Override + public void registerWebSocketHandlers(Collection webSocketHandlers) { + this.rawWebSocketHandlers.clear(); + this.rawWebSocketHandlers.addAll(webSocketHandlers); + this.handshakeHandler.registerWebSocketHandlers(getAllWebSocketHandlers()); + } + + @Override + public boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, + WebSocketHandler webSocketHandler) throws Exception { + + return this.handshakeHandler.doHandshake(request, response, webSocketHandler); + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/XhrPollingTransportHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/XhrPollingTransportHandler.java index 78733099e0..7655e42948 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/XhrPollingTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/XhrPollingTransportHandler.java @@ -19,10 +19,11 @@ import java.nio.charset.Charset; import org.springframework.http.MediaType; import org.springframework.http.server.ServerHttpRequest; -import org.springframework.sockjs.server.SockJsConfiguration; -import org.springframework.sockjs.server.TransportType; +import org.springframework.sockjs.SockJsHandler; import org.springframework.sockjs.server.SockJsFrame.DefaultFrameFormat; import org.springframework.sockjs.server.SockJsFrame.FrameFormat; +import org.springframework.sockjs.server.TransportType; +import org.springframework.util.Assert; /** @@ -34,10 +35,6 @@ import org.springframework.sockjs.server.SockJsFrame.FrameFormat; public class XhrPollingTransportHandler extends AbstractHttpSendingTransportHandler { - public XhrPollingTransportHandler(SockJsConfiguration sockJsConfig) { - super(sockJsConfig); - } - @Override public TransportType getTransportType() { return TransportType.XHR; @@ -53,8 +50,9 @@ public class XhrPollingTransportHandler extends AbstractHttpSendingTransportHand return new DefaultFrameFormat("%s\n"); } - public PollingHttpServerSession createSession(String sessionId) { - return new PollingHttpServerSession(sessionId, getSockJsConfig()); + public PollingHttpServerSession createSession(String sessionId, SockJsHandler sockJsHandler) { + Assert.notNull(getSockJsConfig(), "This transport requires SockJsConfiguration"); + return new PollingHttpServerSession(sessionId, getSockJsConfig(), sockJsHandler); } } diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/XhrStreamingTransportHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/XhrStreamingTransportHandler.java index 0f49b75de3..cf40744284 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/XhrStreamingTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/XhrStreamingTransportHandler.java @@ -21,10 +21,9 @@ import java.nio.charset.Charset; import org.springframework.http.MediaType; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; -import org.springframework.sockjs.server.SockJsConfiguration; -import org.springframework.sockjs.server.TransportType; import org.springframework.sockjs.server.SockJsFrame.DefaultFrameFormat; import org.springframework.sockjs.server.SockJsFrame.FrameFormat; +import org.springframework.sockjs.server.TransportType; /** @@ -36,10 +35,6 @@ import org.springframework.sockjs.server.SockJsFrame.FrameFormat; public class XhrStreamingTransportHandler extends AbstractStreamingTransportHandler { - public XhrStreamingTransportHandler(SockJsConfiguration sockJsConfig) { - super(sockJsConfig); - } - @Override public TransportType getTransportType() { return TransportType.XHR_STREAMING; diff --git a/spring-websocket/src/main/java/org/springframework/websocket/HandlerProvider.java b/spring-websocket/src/main/java/org/springframework/websocket/HandlerProvider.java new file mode 100644 index 0000000000..f8d2d90bd2 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/websocket/HandlerProvider.java @@ -0,0 +1,93 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.websocket; + +import org.apache.commons.logging.Log; +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.BeanFactoryAware; +import org.springframework.beans.factory.config.AutowireCapableBeanFactory; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; + + +/** + * + * @author Rossen Stoyanchev + * @since 4.0 + */ +public class HandlerProvider implements BeanFactoryAware { + + private final T handlerBean; + + private final Class handlerClass; + + private AutowireCapableBeanFactory beanFactory; + + private Log logger; + + + public HandlerProvider(T handlerBean) { + Assert.notNull(handlerBean, "handlerBean is required"); + this.handlerBean = handlerBean; + this.handlerClass = null; + } + + public HandlerProvider(Class handlerClass) { + Assert.notNull(handlerClass, "handlerClass is required"); + this.handlerBean = null; + this.handlerClass = handlerClass; + } + + @Override + public void setBeanFactory(BeanFactory beanFactory) throws BeansException { + if (beanFactory instanceof AutowireCapableBeanFactory) { + this.beanFactory = (AutowireCapableBeanFactory) beanFactory; + } + } + + public void setLogger(Log logger) { + this.logger = logger; + } + + public boolean isSingleton() { + return (this.handlerBean != null); + } + + @SuppressWarnings("unchecked") + public Class getHandlerType() { + if (this.handlerClass != null) { + return this.handlerClass; + } + return (Class) ClassUtils.getUserClass(this.handlerBean.getClass()); + } + + public T getHandler() { + if (this.handlerBean != null) { + if (logger != null && logger.isTraceEnabled()) { + logger.trace("Returning handler singleton " + this.handlerBean); + } + return this.handlerBean; + } + Assert.isTrue(this.beanFactory != null, "BeanFactory is required to initialize handler instances."); + if (logger != null && logger.isTraceEnabled()) { + logger.trace("Creating handler of type " + this.handlerClass); + } + return this.beanFactory.createBean(this.handlerClass); + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/websocket/client/AbstractEndpointConnectionManager.java b/spring-websocket/src/main/java/org/springframework/websocket/client/AbstractEndpointConnectionManager.java index 8e5d3f754a..9a24e04af5 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/client/AbstractEndpointConnectionManager.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/client/AbstractEndpointConnectionManager.java @@ -26,13 +26,9 @@ import javax.websocket.WebSocketContainer; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.springframework.beans.BeansException; -import org.springframework.context.ApplicationContext; -import org.springframework.context.ApplicationContextAware; import org.springframework.context.SmartLifecycle; import org.springframework.core.task.SimpleAsyncTaskExecutor; import org.springframework.core.task.TaskExecutor; -import org.springframework.util.Assert; import org.springframework.web.util.UriComponentsBuilder; @@ -41,14 +37,10 @@ import org.springframework.web.util.UriComponentsBuilder; * @author Rossen Stoyanchev * @since 4.0 */ -public abstract class AbstractEndpointConnectionManager implements ApplicationContextAware, SmartLifecycle { +public abstract class AbstractEndpointConnectionManager implements SmartLifecycle { protected final Log logger = LogFactory.getLog(getClass()); - private final Class endpointClass; - - private final Object endpointBean; - private final URI uri; private boolean autoStartup = false; @@ -59,29 +51,13 @@ public abstract class AbstractEndpointConnectionManager implements ApplicationCo private Session session; - private ApplicationContext applicationContext; - private TaskExecutor taskExecutor = new SimpleAsyncTaskExecutor("EndpointConnectionManager-"); private final Object lifecycleMonitor = new Object(); - public AbstractEndpointConnectionManager(Class endpointClass, String uriTemplate, Object... uriVariables) { - Assert.notNull(endpointClass, "endpointClass is required"); - this.endpointClass = endpointClass; - this.endpointBean = null; - this.uri = initUri(uriTemplate, uriVariables); - } - - public AbstractEndpointConnectionManager(Object endpointBean, String uriTemplate, Object... uriVariables) { - Assert.notNull(endpointBean, "endpointBean is required"); - this.endpointClass = null; - this.endpointBean = endpointBean; - this.uri = initUri(uriTemplate, uriVariables); - } - - private static URI initUri(String uri, Object... uriVariables) { - return UriComponentsBuilder.fromUriString(uri).buildAndExpand(uriVariables).encode().toUri(); + public AbstractEndpointConnectionManager(String uriTemplate, Object... uriVariables) { + this.uri = UriComponentsBuilder.fromUriString(uriTemplate).buildAndExpand(uriVariables).encode().toUri(); } public void setAsyncSendTimeout(long timeoutInMillis) { @@ -137,11 +113,6 @@ public abstract class AbstractEndpointConnectionManager implements ApplicationCo return this.phase; } - @Override - public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { - this.applicationContext = applicationContext; - } - protected URI getUri() { return this.uri; } @@ -150,17 +121,6 @@ public abstract class AbstractEndpointConnectionManager implements ApplicationCo return this.webSocketContainer; } - protected Object getEndpoint() { - if (this.endpointClass != null) { - Assert.notNull(this.applicationContext, - "An ApplicationContext is required to initialize endpoint instances per request."); - return this.applicationContext.getAutowireCapableBeanFactory().createBean(this.endpointClass); - } - else { - return this.endpointBean; - } - } - /** * Auto-connects to the configured {@link #setDefaultUri(URI) default URI}. */ @@ -173,10 +133,10 @@ public abstract class AbstractEndpointConnectionManager implements ApplicationCo synchronized (lifecycleMonitor) { try { logger.info("Connecting to endpoint at URI " + uri); - session = connect(getEndpoint()); + session = connect(); logger.info("Successfully connected"); } - catch (Exception ex) { + catch (Throwable ex) { logger.error("Failed to connect to endpoint at " + uri, ex); } } @@ -186,7 +146,7 @@ public abstract class AbstractEndpointConnectionManager implements ApplicationCo } } - protected abstract Session connect(Object endpoint) throws DeploymentException, IOException; + protected abstract Session connect() throws DeploymentException, IOException; /** * Deactivates the configured message endpoint. diff --git a/spring-websocket/src/main/java/org/springframework/websocket/client/AnnotatedEndpointConnectionManager.java b/spring-websocket/src/main/java/org/springframework/websocket/client/AnnotatedEndpointConnectionManager.java index ccc4c3f4a0..a1005e3507 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/client/AnnotatedEndpointConnectionManager.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/client/AnnotatedEndpointConnectionManager.java @@ -21,25 +21,47 @@ import java.io.IOException; import javax.websocket.DeploymentException; import javax.websocket.Session; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.BeanFactoryAware; +import org.springframework.websocket.HandlerProvider; + /** * * @author Rossen Stoyanchev * @since 4.0 */ -public class AnnotatedEndpointConnectionManager extends AbstractEndpointConnectionManager { +public class AnnotatedEndpointConnectionManager extends AbstractEndpointConnectionManager + implements BeanFactoryAware { + + private static Log logger = LogFactory.getLog(AnnotatedEndpointConnectionManager.class); + + private final HandlerProvider endpointProvider; public AnnotatedEndpointConnectionManager(Class endpointClass, String uriTemplate, Object... uriVariables) { - super(endpointClass, uriTemplate, uriVariables); + super(uriTemplate, uriVariables); + this.endpointProvider = new HandlerProvider(endpointClass); + this.endpointProvider.setLogger(logger); } public AnnotatedEndpointConnectionManager(Object endpointBean, String uriTemplate, Object... uriVariables) { - super(endpointBean, uriTemplate, uriVariables); + super(uriTemplate, uriVariables); + this.endpointProvider = new HandlerProvider(endpointBean); + this.endpointProvider.setLogger(logger); } @Override - protected Session connect(Object endpoint) throws DeploymentException, IOException { + public void setBeanFactory(BeanFactory beanFactory) throws BeansException { + this.endpointProvider.setBeanFactory(beanFactory); + } + + @Override + protected Session connect() throws DeploymentException, IOException { + Object endpoint = this.endpointProvider.getHandler(); return getWebSocketContainer().connectToServer(endpoint, getUri()); } diff --git a/spring-websocket/src/main/java/org/springframework/websocket/client/EndpointConnectionManager.java b/spring-websocket/src/main/java/org/springframework/websocket/client/EndpointConnectionManager.java index 472a9df942..9ca8653bbf 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/client/EndpointConnectionManager.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/client/EndpointConnectionManager.java @@ -29,23 +29,41 @@ import javax.websocket.Endpoint; import javax.websocket.Extension; import javax.websocket.Session; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.BeanFactoryAware; +import org.springframework.util.Assert; +import org.springframework.websocket.HandlerProvider; + /** * * @author Rossen Stoyanchev * @since 4.0 */ -public class EndpointConnectionManager extends AbstractEndpointConnectionManager { +public class EndpointConnectionManager extends AbstractEndpointConnectionManager implements BeanFactoryAware { + + private static Log logger = LogFactory.getLog(EndpointConnectionManager.class); private final ClientEndpointConfig.Builder configBuilder = ClientEndpointConfig.Builder.create(); + private final HandlerProvider endpointProvider; + public EndpointConnectionManager(Class endpointClass, String uriTemplate, Object... uriVariables) { - super(endpointClass, uriTemplate, uriVariables); + super(uriTemplate, uriVariables); + Assert.notNull(endpointClass, "endpointClass is required"); + this.endpointProvider = new HandlerProvider(endpointClass); + this.endpointProvider.setLogger(logger); } public EndpointConnectionManager(Endpoint endpointBean, String uriTemplate, Object... uriVariables) { - super(endpointBean, uriTemplate, uriVariables); + super(uriTemplate, uriVariables); + Assert.notNull(endpointBean, "endpointBean is required"); + this.endpointProvider = new HandlerProvider(endpointBean); + this.endpointProvider.setLogger(logger); } public void setSubProtocols(String... subprotocols) { @@ -69,8 +87,13 @@ public class EndpointConnectionManager extends AbstractEndpointConnectionManager } @Override - protected Session connect(Object endpoint) throws DeploymentException, IOException { - Endpoint typedEndpoint = (Endpoint) endpoint; + public void setBeanFactory(BeanFactory beanFactory) throws BeansException { + this.endpointProvider.setBeanFactory(beanFactory); + } + + @Override + protected Session connect() throws DeploymentException, IOException { + Endpoint typedEndpoint = this.endpointProvider.getHandler(); ClientEndpointConfig endpointConfig = this.configBuilder.build(); return getWebSocketContainer().connectToServer(typedEndpoint, endpointConfig, getUri()); } diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/AbstractHandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/websocket/server/DefaultHandshakeHandler.java similarity index 65% rename from spring-websocket/src/main/java/org/springframework/websocket/server/AbstractHandshakeHandler.java rename to spring-websocket/src/main/java/org/springframework/websocket/server/DefaultHandshakeHandler.java index 7ff08b8097..cf4e4046f6 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/server/AbstractHandshakeHandler.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/server/DefaultHandshakeHandler.java @@ -21,6 +21,7 @@ import java.nio.charset.Charset; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.List; @@ -28,48 +29,53 @@ import javax.xml.bind.DatatypeConverter; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.springframework.beans.BeansException; -import org.springframework.beans.factory.BeanFactory; -import org.springframework.beans.factory.BeanFactoryAware; -import org.springframework.beans.factory.config.AutowireCapableBeanFactory; +import org.springframework.beans.BeanUtils; import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; -import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; import org.springframework.util.StringUtils; +import org.springframework.websocket.WebSocketHandler; /** + * TODO + *

+ * A container-specific {@link RequestUpgradeStrategy} is required since standard Java + * WebSocket currently does not provide a way to initiate a WebSocket handshake. + * Currently available are implementations for Tomcat and Glassfish. * * @author Rossen Stoyanchev * @since 4.0 */ -public abstract class AbstractHandshakeHandler implements HandshakeHandler, BeanFactoryAware { +public class DefaultHandshakeHandler implements HandshakeHandler { private static final String GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; protected Log logger = LogFactory.getLog(getClass()); - private final Object webSocketHandler; - - private final Class webSocketHandlerClass; - private List supportedProtocols; - private AutowireCapableBeanFactory beanFactory; + private RequestUpgradeStrategy requestUpgradeStrategy; - public AbstractHandshakeHandler(Object handler) { - Assert.notNull(handler, "webSocketHandler is required"); - this.webSocketHandler = handler; - this.webSocketHandlerClass = null; + /** + * Default constructor that auto-detects and instantiates a + * {@link RequestUpgradeStrategy} suitable for the runtime container. + * + * @throws IllegalStateException if no {@link RequestUpgradeStrategy} can be found. + */ + public DefaultHandshakeHandler() { + this.requestUpgradeStrategy = new RequestUpgradeStrategyFactory().create(); } - public AbstractHandshakeHandler(Class handlerClass) { - Assert.notNull((handlerClass), "handlerClass is required"); - this.webSocketHandler = null; - this.webSocketHandlerClass = handlerClass; + /** + * A constructor that accepts a runtime specific {@link RequestUpgradeStrategy}. + * @param upgradeStrategy the upgrade strategy + */ + public DefaultHandshakeHandler(RequestUpgradeStrategy upgradeStrategy) { + this.requestUpgradeStrategy = upgradeStrategy; } public void setSupportedProtocols(String... protocols) { @@ -81,24 +87,13 @@ public abstract class AbstractHandshakeHandler implements HandshakeHandler, Bean } @Override - public void setBeanFactory(BeanFactory beanFactory) throws BeansException { - if (beanFactory instanceof AutowireCapableBeanFactory) { - this.beanFactory = (AutowireCapableBeanFactory) beanFactory; - } - } - - protected Object getWebSocketHandler() { - if (this.webSocketHandlerClass != null) { - Assert.notNull(this.beanFactory, "BeanFactory is required for WebSocket handler instances per request."); - return this.beanFactory.createBean(this.webSocketHandlerClass); - } - else { - return this.webSocketHandler; - } + public void registerWebSocketHandlers(Collection handlers) { + this.requestUpgradeStrategy.registerWebSocketHandlers(handlers); } @Override - public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response) throws Exception { + public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, + WebSocketHandler webSocketHandler) throws Exception { logger.debug("Starting handshake for " + request.getURI()); @@ -131,30 +126,29 @@ public abstract class AbstractHandshakeHandler implements HandshakeHandler, Bean return false; } - String protocol = selectProtocol(request.getHeaders().getSecWebSocketProtocol()); + String selectedProtocol = selectProtocol(request.getHeaders().getSecWebSocketProtocol()); // TODO: select extensions + logger.debug("Upgrading HTTP request"); + response.setStatusCode(HttpStatus.SWITCHING_PROTOCOLS); response.getHeaders().setUpgrade("WebSocket"); response.getHeaders().setConnection("Upgrade"); - response.getHeaders().setSecWebSocketProtocol(protocol); + response.getHeaders().setSecWebSocketProtocol(selectedProtocol); response.getHeaders().setSecWebSocketAccept(getWebSocketKeyHash(wsKey)); // TODO: response.getHeaders().setSecWebSocketExtensions(extensions); - logger.debug("Successfully negotiated WebSocket handshake"); + response.flush(); - // TODO: surely there is a better way to flush headers - response.getBody(); + if (logger.isTraceEnabled()) { + logger.trace("Upgrading with " + webSocketHandler); + } - doHandshakeInternal(request, response, protocol); + this.requestUpgradeStrategy.upgrade(request, response, selectedProtocol, webSocketHandler); return true; } - protected abstract void doHandshakeInternal(ServerHttpRequest request, ServerHttpResponse response, - String protocol) throws Exception; - - protected void handleInvalidUpgradeHeader(ServerHttpRequest request, ServerHttpResponse response) throws IOException { logger.debug("Invalid Upgrade header " + request.getHeaders().getUpgrade()); response.setStatusCode(HttpStatus.BAD_REQUEST); @@ -178,7 +172,7 @@ public abstract class AbstractHandshakeHandler implements HandshakeHandler, Bean } protected String[] getSupportedVerions() { - return new String[] { "13" }; + return this.requestUpgradeStrategy.getSupportedVersions(); } protected void handleWebSocketVersionNotSupported(ServerHttpRequest request, ServerHttpResponse response) { @@ -216,4 +210,35 @@ public abstract class AbstractHandshakeHandler implements HandshakeHandler, Bean return DatatypeConverter.printBase64Binary(bytes); } + + private static class RequestUpgradeStrategyFactory { + + private static final boolean tomcatWebSocketPresent = ClassUtils.isPresent( + "org.apache.tomcat.websocket.server.WsHttpUpgradeHandler", DefaultHandshakeHandler.class.getClassLoader()); + + private static final boolean glassfishWebSocketPresent = ClassUtils.isPresent( + "org.glassfish.tyrus.servlet.TyrusHttpUpgradeHandler", DefaultHandshakeHandler.class.getClassLoader()); + + + private RequestUpgradeStrategy create() { + String className; + if (tomcatWebSocketPresent) { + className = "org.springframework.websocket.server.support.TomcatRequestUpgradeStrategy"; + } + else if (glassfishWebSocketPresent) { + className = "org.springframework.websocket.server.support.GlassfishRequestUpgradeStrategy"; + } + else { + throw new IllegalStateException("No suitable " + RequestUpgradeStrategy.class.getSimpleName()); + } + try { + Class clazz = ClassUtils.forName(className, DefaultHandshakeHandler.class.getClassLoader()); + return (RequestUpgradeStrategy) BeanUtils.instantiateClass(clazz.getConstructor()); + } + catch (Throwable t) { + throw new IllegalStateException("Failed to instantiate " + className, t); + } + } + } + } diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/HandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/websocket/server/HandshakeHandler.java index a8da4981ac..8222e7258e 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/server/HandshakeHandler.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/server/HandshakeHandler.java @@ -16,18 +16,39 @@ package org.springframework.websocket.server; +import java.util.Collection; + import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; +import org.springframework.websocket.WebSocketHandler; /** - * Abstraction for integrating a WebSocket implementation some HTTP processing pipeline. + * Contract for processing a WebSocket handshake request. * * @author Rossen Stoyanchev * @since 4.0 */ public interface HandshakeHandler { - boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response) throws Exception; + /** + * Pre-register {@link WebSocketHandler} instances so they can be adapted to the + * underlying runtime and hence re-used at runtime when + * {@link #doHandshake(ServerHttpRequest, ServerHttpResponse, WebSocketHandler) doHandshake} + * is called. + */ + void registerWebSocketHandlers(Collection webSocketHandlers); + + /** + * + * @param request the HTTP request + * @param response the HTTP response + * @param webSocketMessageHandler the handler to process WebSocket messages with + * @return a boolean indicating whether the handshake negotiation was successful + * + * @throws Exception + */ + boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler webSocketHandler) + throws Exception; } diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/RequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/websocket/server/RequestUpgradeStrategy.java new file mode 100644 index 0000000000..d0cf29fa17 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/websocket/server/RequestUpgradeStrategy.java @@ -0,0 +1,57 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.websocket.server; + +import java.util.Collection; + +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServerHttpResponse; +import org.springframework.websocket.WebSocketHandler; + + +/** + * A strategy for performing container-specific steps to upgrade an HTTP request during a + * WebSocket handshake. Intended for use within {@link HandshakeHandler} implementations. + * + * @author Rossen Stoyanchev + * @since 4.0 + */ +public interface RequestUpgradeStrategy { + + /** + * Return the supported WebSocket protocol versions. + */ + String[] getSupportedVersions(); + + /** + * Pre-register {@link WebSocketHandler} instances so they can be adapted to the + * underlying runtime and hence re-used at runtime when + * {@link #upgrade(ServerHttpRequest, ServerHttpResponse, String, WebSocketHandler) + * upgrade} is called. + */ + void registerWebSocketHandlers(Collection webSocketHandlers); + + /** + * Perform runtime specific steps to complete the upgrade. + * Invoked only if the handshake is successful. + * + * @param webSocketHandler the handler for WebSocket messages + */ + void upgrade(ServerHttpRequest request, ServerHttpResponse response, String selectedProtocol, + WebSocketHandler webSocketHandler) throws Exception; + +} diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/EndpointRegistration.java b/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/EndpointRegistration.java index 3c33a31017..d9e36a61fb 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/EndpointRegistration.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/EndpointRegistration.java @@ -35,9 +35,7 @@ import org.springframework.beans.BeansException; import org.springframework.beans.factory.BeanFactory; import org.springframework.beans.factory.BeanFactoryAware; import org.springframework.util.Assert; -import org.springframework.util.ClassUtils; -import org.springframework.web.context.ContextLoader; -import org.springframework.web.context.WebApplicationContext; +import org.springframework.websocket.HandlerProvider; import org.springframework.websocket.endpoint.WebSocketHandlerEndpoint; @@ -60,9 +58,7 @@ public class EndpointRegistration implements ServerEndpointConfig, BeanFactoryAw private final String path; - private final Class endpointClass; - - private final Object endpointBean; + private final HandlerProvider endpointProvider; private List> encoders = new ArrayList>(); @@ -76,8 +72,6 @@ public class EndpointRegistration implements ServerEndpointConfig, BeanFactoryAw private Configurator configurator = new Configurator() {}; - private BeanFactory beanFactory; - /** * Class constructor with the {@code javax.webscoket.Endpoint} class. @@ -87,23 +81,19 @@ public class EndpointRegistration implements ServerEndpointConfig, BeanFactoryAw * @param endpointClass */ public EndpointRegistration(String path, Class endpointClass) { - this(path, endpointClass, null); - } - - public EndpointRegistration(String path, Object bean) { - this(path, null, bean); - } - - public EndpointRegistration(String path, String beanName) { - this(path, null, beanName); - } - - private EndpointRegistration(String path, Class endpointClass, Object bean) { Assert.hasText(path, "path must not be empty"); - Assert.isTrue((endpointClass != null || bean != null), "Neither endpoint class nor endpoint bean provided"); + Assert.notNull(endpointClass, "endpointClass is required"); this.path = path; - this.endpointClass = endpointClass; - this.endpointBean = bean; + this.endpointProvider = new HandlerProvider(endpointClass); + this.endpointProvider.setLogger(logger); + } + + public EndpointRegistration(String path, Endpoint endpointBean) { + Assert.hasText(path, "path must not be empty"); + Assert.notNull(endpointBean, "endpointBean is required"); + this.path = path; + this.endpointProvider = new HandlerProvider(endpointBean); + this.endpointProvider.setLogger(logger); } @Override @@ -111,40 +101,13 @@ public class EndpointRegistration implements ServerEndpointConfig, BeanFactoryAw return this.path; } - @SuppressWarnings("unchecked") @Override public Class getEndpointClass() { - if (this.endpointClass != null) { - return this.endpointClass; - } - Class beanClass = this.endpointBean.getClass(); - if (beanClass.equals(String.class)) { - beanClass = this.beanFactory.getType((String) this.endpointBean); - } - beanClass = ClassUtils.getUserClass(beanClass); - if (Endpoint.class.isAssignableFrom(beanClass)) { - return (Class) beanClass; - } - else { - throw new IllegalStateException("Invalid endpoint bean: must be of type ... TODO "); - } + return this.endpointProvider.getHandlerType(); } public Endpoint getEndpoint() { - if (this.endpointClass != null) { - WebApplicationContext wac = ContextLoader.getCurrentWebApplicationContext(); - if (wac == null) { - String message = "Failed to find the root WebApplicationContext. Was ContextLoaderListener not used?"; - logger.error(message); - throw new IllegalStateException(); - } - return wac.getAutowireCapableBeanFactory().createBean(this.endpointClass); - } - Object bean = this.endpointBean; - if (this.endpointBean instanceof String) { - bean = this.beanFactory.getBean((String) this.endpointBean); - } - return (Endpoint) bean; + return this.endpointProvider.getHandler(); } public void setSubprotocols(List subprotocols) { @@ -194,11 +157,6 @@ public class EndpointRegistration implements ServerEndpointConfig, BeanFactoryAw return this.decoders; } - @Override - public void setBeanFactory(BeanFactory beanFactory) throws BeansException { - this.beanFactory = beanFactory; - } - /** * The {@link Configurator#getEndpointInstance(Class)} method is always ignored. */ @@ -233,4 +191,9 @@ public class EndpointRegistration implements ServerEndpointConfig, BeanFactoryAw }; } + @Override + public void setBeanFactory(BeanFactory beanFactory) throws BeansException { + this.endpointProvider.setBeanFactory(beanFactory); + } + } diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/SpringConfigurator.java b/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/SpringConfigurator.java index 6ed65012e6..1129cf6e66 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/SpringConfigurator.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/SpringConfigurator.java @@ -64,10 +64,10 @@ public class SpringConfigurator extends Configurator { return beans.values().iterator().next(); } else { - // This should never happen .. + // Should not happen .. String message = "Found more than one matching @ServerEndpoint beans of type " + endpointClass; logger.error(message); - throw new IllegalStateException("Found more than one matching beans of type " + endpointClass); + throw new IllegalStateException(message); } } diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/handshake/EndpointHandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/handshake/EndpointHandshakeHandler.java deleted file mode 100644 index 089a9f8efd..0000000000 --- a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/handshake/EndpointHandshakeHandler.java +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Copyright 2002-2013 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.websocket.server.endpoint.handshake; - -import javax.websocket.Endpoint; - -import org.springframework.beans.BeanUtils; -import org.springframework.http.server.ServerHttpRequest; -import org.springframework.http.server.ServerHttpResponse; -import org.springframework.util.ClassUtils; -import org.springframework.websocket.WebSocketHandler; -import org.springframework.websocket.endpoint.WebSocketHandlerEndpoint; -import org.springframework.websocket.server.AbstractHandshakeHandler; -import org.springframework.websocket.server.HandshakeHandler; - - -/** - * A {@link HandshakeHandler} for use with standard Java WebSocket runtimes. A - * container-specific {@link RequestUpgradeStrategy} is required since standard - * Java WebSocket currently does not provide any means of integrating a WebSocket - * handshake into an HTTP request processing pipeline. Currently available are - * implementations for Tomcat and Glassfish. - * - * @author Rossen Stoyanchev - * @since 4.0 - */ -public class EndpointHandshakeHandler extends AbstractHandshakeHandler { - - private final RequestUpgradeStrategy upgradeStrategy; - - - public EndpointHandshakeHandler(Endpoint endpoint) { - super(endpoint); - this.upgradeStrategy = createRequestUpgradeStrategy(); - } - - public EndpointHandshakeHandler(WebSocketHandler webSocketHandler) { - super(webSocketHandler); - this.upgradeStrategy = createRequestUpgradeStrategy(); - } - - public EndpointHandshakeHandler(Class handlerClass) { - super(handlerClass); - this.upgradeStrategy = createRequestUpgradeStrategy(); - } - - protected RequestUpgradeStrategy createRequestUpgradeStrategy() { - return new RequestUpgradeStrategyFactory().create(); - } - - @Override - protected String[] getSupportedVerions() { - return this.upgradeStrategy.getSupportedVersions(); - } - - @Override - public void doHandshakeInternal(ServerHttpRequest request, ServerHttpResponse response, String protocol) - throws Exception { - - logger.debug("Upgrading HTTP request"); - - Object webSocketHandler = getWebSocketHandler(); - - Endpoint endpoint; - if (webSocketHandler instanceof Endpoint) { - endpoint = (Endpoint) webSocketHandler; - } - else if (webSocketHandler instanceof WebSocketHandler) { - endpoint = new WebSocketHandlerEndpoint((WebSocketHandler) webSocketHandler); - } - else { - String className = webSocketHandler.getClass().getName(); - throw new IllegalArgumentException("Unexpected WebSocket handler type: " + className); - } - - this.upgradeStrategy.upgrade(request, response, protocol, endpoint); - } - - - private static class RequestUpgradeStrategyFactory { - - private static final String packageName = EndpointHandshakeHandler.class.getPackage().getName(); - - private static final boolean tomcatWebSocketPresent = ClassUtils.isPresent( - "org.apache.tomcat.websocket.server.WsHttpUpgradeHandler", EndpointHandshakeHandler.class.getClassLoader()); - - private static final boolean glassfishWebSocketPresent = ClassUtils.isPresent( - "org.glassfish.tyrus.servlet.TyrusHttpUpgradeHandler", EndpointHandshakeHandler.class.getClassLoader()); - - - private RequestUpgradeStrategy create() { - String className; - if (tomcatWebSocketPresent) { - className = packageName + ".TomcatRequestUpgradeStrategy"; - } - else if (glassfishWebSocketPresent) { - className = packageName + ".GlassfishRequestUpgradeStrategy"; - } - else { - throw new IllegalStateException("No suitable " + RequestUpgradeStrategy.class.getSimpleName()); - } - try { - Class clazz = ClassUtils.forName(className, EndpointHandshakeHandler.class.getClassLoader()); - return (RequestUpgradeStrategy) BeanUtils.instantiateClass(clazz.getConstructor()); - } - catch (Throwable t) { - throw new IllegalStateException("Failed to instantiate " + className, t); - } - } - } - -} \ No newline at end of file diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/handshake/package-info.java b/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/handshake/package-info.java deleted file mode 100644 index 8adf2395cb..0000000000 --- a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/handshake/package-info.java +++ /dev/null @@ -1,8 +0,0 @@ - -/** - * WebSocket handshake support for use with standard Java WebSocket runtimes including - * container-specific strategies for upgrading the HttpServletRequest. - * - */ -package org.springframework.websocket.server.endpoint.handshake; - diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/package-info.java b/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/package-info.java index ce327f239e..34a043af9c 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/package-info.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/package-info.java @@ -6,7 +6,7 @@ * registering type-based endpoints, * {@link org.springframework.websocket.server.endpoint.SpringConfigurator} for * instantiating annotated endpoints through Spring, and - * {@link org.springframework.websocket.server.endpoint.handshake.EndpointHandshakeHandler} + * {@link org.springframework.websocket.server.support.EndpointHandshakeHandler} * for integrating endpoints into HTTP request processing. */ package org.springframework.websocket.server.endpoint; diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/support/AbstractEndpointUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/websocket/server/support/AbstractEndpointUpgradeStrategy.java new file mode 100644 index 0000000000..1e7f1a2794 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/websocket/server/support/AbstractEndpointUpgradeStrategy.java @@ -0,0 +1,80 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.websocket.server.support; + +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; + +import javax.websocket.Endpoint; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServerHttpResponse; +import org.springframework.websocket.WebSocketHandler; +import org.springframework.websocket.endpoint.WebSocketHandlerEndpoint; +import org.springframework.websocket.server.RequestUpgradeStrategy; + + +/** + * A {@link RequestUpgradeStrategy} that supports WebSocket handlers of type + * {@link WebSocketHandler} as well as {@link javax.websocket.Endpoint}. + * + * @author Rossen Stoyanchev + * @since 4.0 + */ +public abstract class AbstractEndpointUpgradeStrategy implements RequestUpgradeStrategy { + + protected final Log logger = LogFactory.getLog(getClass()); + + private final Map webSocketHandlers = new HashMap(); + + + @Override + public void registerWebSocketHandlers(Collection webSocketHandlers) { + for (WebSocketHandler webSocketHandler : webSocketHandlers) { + if (!this.webSocketHandlers.containsKey(webSocketHandler)) { + this.webSocketHandlers.put(webSocketHandler, adaptWebSocketHandler(webSocketHandler)); + } + } + } + + protected Endpoint adaptWebSocketHandler(WebSocketHandler handler) { + return new WebSocketHandlerEndpoint(handler); + } + + @Override + public void upgrade(ServerHttpRequest request, ServerHttpResponse response, + String protocol, WebSocketHandler webSocketHandler) throws Exception { + + Endpoint endpoint = this.webSocketHandlers.get(webSocketHandler); + if (endpoint == null) { + endpoint = adaptWebSocketHandler(webSocketHandler); + } + + if (logger.isTraceEnabled()) { + logger.trace("Upgrading with " + endpoint); + } + + upgradeInternal(request, response, protocol, endpoint); + } + + protected abstract void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response, + String protocol, Endpoint endpoint) throws Exception; + +} diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/handshake/GlassfishRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/websocket/server/support/GlassfishRequestUpgradeStrategy.java similarity index 95% rename from spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/handshake/GlassfishRequestUpgradeStrategy.java rename to spring-websocket/src/main/java/org/springframework/websocket/server/support/GlassfishRequestUpgradeStrategy.java index 6a23883fb5..1a7619af45 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/handshake/GlassfishRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/server/support/GlassfishRequestUpgradeStrategy.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.websocket.server.endpoint.handshake; +package org.springframework.websocket.server.support; import java.lang.reflect.Constructor; import java.net.URI; @@ -54,7 +54,7 @@ import org.springframework.websocket.server.endpoint.EndpointRegistration; * @author Rossen Stoyanchev * @since 4.0 */ -public class GlassfishRequestUpgradeStrategy implements RequestUpgradeStrategy { +public class GlassfishRequestUpgradeStrategy extends AbstractEndpointUpgradeStrategy { private final static Random random = new Random(); @@ -65,8 +65,8 @@ public class GlassfishRequestUpgradeStrategy implements RequestUpgradeStrategy { } @Override - public void upgrade(ServerHttpRequest request, ServerHttpResponse response, String protocol, - Endpoint endpoint) throws Exception { + public void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response, + String protocol, Endpoint endpoint) throws Exception { Assert.isTrue(request instanceof ServletServerHttpRequest); HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/handshake/TomcatRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/websocket/server/support/TomcatRequestUpgradeStrategy.java similarity index 90% rename from spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/handshake/TomcatRequestUpgradeStrategy.java rename to spring-websocket/src/main/java/org/springframework/websocket/server/support/TomcatRequestUpgradeStrategy.java index b07cd2e0df..ac74c9e5f1 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/handshake/TomcatRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/server/support/TomcatRequestUpgradeStrategy.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.websocket.server.endpoint.handshake; +package org.springframework.websocket.server.support; import java.io.IOException; import java.lang.reflect.Method; @@ -42,7 +42,7 @@ import org.springframework.websocket.server.endpoint.EndpointRegistration; * @author Rossen Stoyanchev * @since 4.0 */ -public class TomcatRequestUpgradeStrategy implements RequestUpgradeStrategy { +public class TomcatRequestUpgradeStrategy extends AbstractEndpointUpgradeStrategy { @Override @@ -51,8 +51,8 @@ public class TomcatRequestUpgradeStrategy implements RequestUpgradeStrategy { } @Override - public void upgrade(ServerHttpRequest request, ServerHttpResponse response, String protocol, - Endpoint endpoint) throws IOException { + public void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response, + String protocol, Endpoint endpoint) throws IOException { Assert.isTrue(request instanceof ServletServerHttpRequest); HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/support/HandshakeHttpRequestHandler.java b/spring-websocket/src/main/java/org/springframework/websocket/server/support/WebSocketHttpRequestHandler.java similarity index 50% rename from spring-websocket/src/main/java/org/springframework/websocket/server/support/HandshakeHttpRequestHandler.java rename to spring-websocket/src/main/java/org/springframework/websocket/server/support/WebSocketHttpRequestHandler.java index 26f5d59ae1..6e641eaba0 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/server/support/HandshakeHttpRequestHandler.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/server/support/WebSocketHttpRequestHandler.java @@ -17,11 +17,15 @@ package org.springframework.websocket.server.support; import java.io.IOException; +import java.util.Collections; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.BeanFactoryAware; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.http.server.ServletServerHttpRequest; @@ -29,24 +33,48 @@ import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.util.Assert; import org.springframework.web.HttpRequestHandler; import org.springframework.web.util.NestedServletException; +import org.springframework.websocket.HandlerProvider; +import org.springframework.websocket.WebSocketHandler; import org.springframework.websocket.server.HandshakeHandler; +import org.springframework.websocket.server.DefaultHandshakeHandler; /** - * A Spring MVC {@link HttpRequestHandler} wrapping the invocation of a WebSocket - * {@link HandshakeHandler}; + * An {@link HttpRequestHandler} that wraps the invocation of a {@link HandshakeHandler}. * * @author Rossen Stoyanchev * @since 4.0 */ -public class HandshakeHttpRequestHandler implements HttpRequestHandler { +public class WebSocketHttpRequestHandler implements HttpRequestHandler, BeanFactoryAware { - private final HandshakeHandler handshakeHandler; + private HandshakeHandler handshakeHandler; + + private final HandlerProvider handlerProvider; - public HandshakeHttpRequestHandler(HandshakeHandler handshakeHandler) { + public WebSocketHttpRequestHandler(WebSocketHandler webSocketHandler) { + Assert.notNull(webSocketHandler, "webSocketHandler is required"); + this.handlerProvider = new HandlerProvider(webSocketHandler); + this.handshakeHandler = new DefaultHandshakeHandler(); + this.handshakeHandler.registerWebSocketHandlers(Collections.singleton(webSocketHandler)); + } + + public WebSocketHttpRequestHandler( Class webSocketHandlerClass) { + Assert.notNull(webSocketHandlerClass, "webSocketHandlerClass is required"); + this.handlerProvider = new HandlerProvider(webSocketHandlerClass); + } + + public void setHandshakeHandler(HandshakeHandler handshakeHandler) { Assert.notNull(handshakeHandler, "handshakeHandler is required"); this.handshakeHandler = handshakeHandler; + if (this.handlerProvider.isSingleton()) { + this.handshakeHandler.registerWebSocketHandlers(Collections.singleton(this.handlerProvider.getHandler())); + } + } + + @Override + public void setBeanFactory(BeanFactory beanFactory) throws BeansException { + this.handlerProvider.setBeanFactory(beanFactory); } @Override @@ -57,12 +85,16 @@ public class HandshakeHttpRequestHandler implements HttpRequestHandler { ServerHttpResponse httpResponse = new ServletServerHttpResponse(response); try { - this.handshakeHandler.doHandshake(httpRequest, httpResponse); + WebSocketHandler webSocketHandler = this.handlerProvider.getHandler(); + this.handshakeHandler.doHandshake(httpRequest, httpResponse, webSocketHandler); } catch (Exception e) { // TODO throw new NestedServletException("HandshakeHandler failure", e); } + finally { + httpResponse.flush(); + } } } diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/support/package-info.java b/spring-websocket/src/main/java/org/springframework/websocket/server/support/package-info.java index c76d943d9f..ac054317dd 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/server/support/package-info.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/server/support/package-info.java @@ -1,6 +1,6 @@ /** - * Server-side support classes for WebSocket applications. + * Server-side support classes including container-specific strategies for upgrading a request. * */ package org.springframework.websocket.server.support;