diff --git a/build.gradle b/build.gradle index e56d3060d1..3bb431d2ac 100644 --- a/build.gradle +++ b/build.gradle @@ -515,13 +515,24 @@ project("spring-websocket") { compile(project(":spring-core")) compile(project(":spring-context")) compile(project(":spring-web")) - optional("javax.websocket:javax.websocket-api:1.0-b14") + + optional("org.apache.tomcat:tomcat-servlet-api:8.0-SNAPSHOT") // TODO: replace with "javax.servlet:javax.servlet-api" + optional("org.apache.tomcat:tomcat-websocket-api:8.0-SNAPSHOT") // TODO: replace with "javax.websocket:javax.websocket-api" + + optional("org.apache.tomcat:tomcat-websocket:8.0-SNAPSHOT") { + exclude group: "org.apache.tomcat", module: "tomcat-websocket-api" + exclude group: "org.apache.tomcat", module: "tomcat-servlet-api" + } + + optional("org.eclipse.jetty:jetty-websocket:8.1.10.v20130312") + optional("org.glassfish.tyrus:tyrus-websocket-core:1.0-SNAPSHOT") } repositories { maven { url "http://repo.springsource.org/libs-release" } - maven { url "https://repository.apache.org" } // tomcat-websocket snapshot maven { url "https://maven.java.net/content/groups/public/" } // javax.websocket-* + maven { url "https://repository.apache.org/content/repositories/snapshots" } // tomcat-websocket snapshots + maven { url "https://maven.java.net/content/repositories/snapshots" } // tyrus/glassfish snapshots } } diff --git a/spring-web/src/main/java/org/springframework/http/HttpHeaders.java b/spring-web/src/main/java/org/springframework/http/HttpHeaders.java index 5d480d5e4b..e972709572 100644 --- a/spring-web/src/main/java/org/springframework/http/HttpHeaders.java +++ b/spring-web/src/main/java/org/springframework/http/HttpHeaders.java @@ -17,14 +17,10 @@ package org.springframework.http; import java.io.Serializable; - import java.net.URI; - import java.nio.charset.Charset; - import java.text.ParseException; import java.text.SimpleDateFormat; - import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -40,6 +36,7 @@ import java.util.Set; import java.util.TimeZone; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; import org.springframework.util.LinkedCaseInsensitiveMap; import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; @@ -71,6 +68,8 @@ public class HttpHeaders implements MultiValueMap, Serializable private static final String CACHE_CONTROL = "Cache-Control"; + private static final String CONNECTION = "Connection"; + private static final String CONTENT_DISPOSITION = "Content-Disposition"; private static final String CONTENT_LENGTH = "Content-Length"; @@ -91,8 +90,22 @@ public class HttpHeaders implements MultiValueMap, Serializable private static final String LOCATION = "Location"; + private static final String ORIGIN = "Origin"; + + private static final String SEC_WEBSOCKET_ACCEPT = "Sec-WebSocket-Accept"; + + private static final String SEC_WEBSOCKET_EXTENSIONS = "Sec-WebSocket-Extensions"; + + private static final String SEC_WEBSOCKET_KEY = "Sec-WebSocket-Key"; + + private static final String SEC_WEBSOCKET_PROTOCOL = "Sec-WebSocket-Protocol"; + + private static final String SEC_WEBSOCKET_VERSION = "Sec-WebSocket-Version"; + private static final String PRAGMA = "Pragma"; + private static final String UPGARDE = "Upgrade"; + private static final String[] DATE_FORMATS = new String[] { "EEE, dd MMM yyyy HH:mm:ss zzz", @@ -251,6 +264,30 @@ public class HttpHeaders implements MultiValueMap, Serializable return getFirst(CACHE_CONTROL); } + /** + * Sets the (new) value of the {@code Connection} header. + * @param connection the value of the header + */ + public void setConnection(String connection) { + set(CONNECTION, connection); + } + + /** + * Sets the (new) value of the {@code Connection} header. + * @param connection the value of the header + */ + public void setConnection(List connection) { + set(CONNECTION, toCommaDelimitedString(connection)); + } + + /** + * Returns the value of the {@code Connection} header. + * @return the value of the header + */ + public List getConnection() { + return getFirstValueAsList(CONNECTION); + } + /** * Sets the (new) value of the {@code Content-Disposition} header for {@code form-data}. * @param name the control name @@ -393,15 +430,19 @@ public class HttpHeaders implements MultiValueMap, Serializable * @param ifNoneMatchList the new value of the header */ public void setIfNoneMatch(List ifNoneMatchList) { + set(IF_NONE_MATCH, toCommaDelimitedString(ifNoneMatchList)); + } + + private String toCommaDelimitedString(List list) { StringBuilder builder = new StringBuilder(); - for (Iterator iterator = ifNoneMatchList.iterator(); iterator.hasNext();) { + for (Iterator iterator = list.iterator(); iterator.hasNext();) { String ifNoneMatch = iterator.next(); builder.append(ifNoneMatch); if (iterator.hasNext()) { builder.append(", "); } } - set(IF_NONE_MATCH, builder.toString()); + return builder.toString(); } /** @@ -409,9 +450,13 @@ public class HttpHeaders implements MultiValueMap, Serializable * @return the header value */ public List getIfNoneMatch() { + return getFirstValueAsList(IF_NONE_MATCH); + } + + private List getFirstValueAsList(String header) { List result = new ArrayList(); - String value = getFirst(IF_NONE_MATCH); + String value = getFirst(header); if (value != null) { String[] tokens = value.split(",\\s*"); for (String token : tokens) { @@ -457,6 +502,130 @@ public class HttpHeaders implements MultiValueMap, Serializable return (value != null ? URI.create(value) : null); } + /** + * Sets the (new) value of the {@code Origin} header. + * @param origin the value of the header + */ + public void setOrigin(String origin) { + set(ORIGIN, origin); + } + + /** + * Returns the value of the {@code Origin} header. + * @return the value of the header + */ + public String getOrigin() { + return getFirst(ORIGIN); + } + + /** + * Sets the (new) value of the {@code Sec-WebSocket-Accept} header. + * @param secWebSocketAccept the value of the header + */ + public void setSecWebSocketAccept(String secWebSocketAccept) { + set(SEC_WEBSOCKET_ACCEPT, secWebSocketAccept); + } + + /** + * Returns the value of the {@code Sec-WebSocket-Accept} header. + * @return the value of the header + */ + public String getSecWebSocketAccept() { + return getFirst(SEC_WEBSOCKET_ACCEPT); + } + + /** + * Returns the value of the {@code Sec-WebSocket-Extensions} header. + * @return the value of the header + */ + public List getSecWebSocketExtensions() { + List values = get(SEC_WEBSOCKET_EXTENSIONS); + if (CollectionUtils.isEmpty(values)) { + return Collections.emptyList(); + } + else if (values.size() == 1) { + return getFirstValueAsList(SEC_WEBSOCKET_EXTENSIONS); + } + else { + return values; + } + } + + /** + * Sets the (new) value of the {@code Sec-WebSocket-Extensions} header. + * @param secWebSocketExtensions the value of the header + */ + public void setSecWebSocketExtensions(List secWebSocketExtensions) { + set(SEC_WEBSOCKET_EXTENSIONS, toCommaDelimitedString(secWebSocketExtensions)); + } + + /** + * Sets the (new) value of the {@code Sec-WebSocket-Key} header. + * @param secWebSocketKey the value of the header + */ + public void setSecWebSocketKey(String secWebSocketKey) { + set(SEC_WEBSOCKET_KEY, secWebSocketKey); + } + + /** + * Returns the value of the {@code Sec-WebSocket-Key} header. + * @return the value of the header + */ + public String getSecWebSocketKey() { + return getFirst(SEC_WEBSOCKET_KEY); + } + + /** + * Sets the (new) value of the {@code Sec-WebSocket-Protocol} header. + * @param secWebSocketProtocol the value of the header + */ + public void setSecWebSocketProtocol(String secWebSocketProtocol) { + if (secWebSocketProtocol != null) { + set(SEC_WEBSOCKET_PROTOCOL, secWebSocketProtocol); + } + } + + /** + * Sets the (new) value of the {@code Sec-WebSocket-Protocol} header. + * @param secWebSocketProtocols the value of the header + */ + public void setSecWebSocketProtocol(List secWebSocketProtocols) { + set(SEC_WEBSOCKET_PROTOCOL, toCommaDelimitedString(secWebSocketProtocols)); + } + + /** + * Returns the value of the {@code Sec-WebSocket-Key} header. + * @return the value of the header + */ + public List getSecWebSocketProtocol() { + List values = get(SEC_WEBSOCKET_PROTOCOL); + if (CollectionUtils.isEmpty(values)) { + return Collections.emptyList(); + } + else if (values.size() == 1) { + return getFirstValueAsList(SEC_WEBSOCKET_PROTOCOL); + } + else { + return values; + } + } + + /** + * Sets the (new) value of the {@code Sec-WebSocket-Version} header. + * @param secWebSocketKey the value of the header + */ + public void setSecWebSocketVersion(String secWebSocketVersion) { + set(SEC_WEBSOCKET_VERSION, secWebSocketVersion); + } + + /** + * Returns the value of the {@code Sec-WebSocket-Version} header. + * @return the value of the header + */ + public String getSecWebSocketVersion() { + return getFirst(SEC_WEBSOCKET_VERSION); + } + /** * Sets the (new) value of the {@code Pragma} header. * @param pragma the value of the header @@ -473,6 +642,22 @@ public class HttpHeaders implements MultiValueMap, Serializable return getFirst(PRAGMA); } + /** + * Sets the (new) value of the {@code Upgrade} header. + * @param upgrade the value of the header + */ + public void setUpgrade(String upgrade) { + set(UPGARDE, upgrade); + } + + /** + * Returns the value of the {@code Upgrade} header. + * @return the value of the header + */ + public String getUpgrade() { + return getFirst(UPGARDE); + } + // Utility methods private long getFirstDate(String headerName) { diff --git a/spring-websocket/src/main/java/org/springframework/websocket/HandshakeRequestHandler.java b/spring-websocket/src/main/java/org/springframework/websocket/Session.java similarity index 74% rename from spring-websocket/src/main/java/org/springframework/websocket/HandshakeRequestHandler.java rename to spring-websocket/src/main/java/org/springframework/websocket/Session.java index 35324cfd6a..c7ccb12650 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/HandshakeRequestHandler.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/Session.java @@ -16,17 +16,16 @@ package org.springframework.websocket; -import org.springframework.http.server.ServerHttpRequest; -import org.springframework.http.server.ServerHttpResponse; /** * * @author Rossen Stoyanchev */ -public interface HandshakeRequestHandler { +public interface Session { + void sendText(String text) throws Exception; - boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response); + void close(int code, String reason) throws Exception; } diff --git a/spring-websocket/src/main/java/org/springframework/websocket/WebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/websocket/WebSocketHandler.java new file mode 100644 index 0000000000..b4bfb6c1e4 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/websocket/WebSocketHandler.java @@ -0,0 +1,38 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.websocket; + +import java.io.InputStream; + + +/** + * + * @author Rossen Stoyanchev + */ +public interface WebSocketHandler { + + void newSession(Session session) throws Exception; + + void handleTextMessage(Session session, String message) throws Exception; + + void handleBinaryMessage(Session session, InputStream message) throws Exception; + + void handleException(Session session, Throwable exception); + + void sessionClosed(Session session, int statusCode, String reason) throws Exception; + +} diff --git a/spring-websocket/src/main/java/org/springframework/websocket/WebSocketHandlerAdapter.java b/spring-websocket/src/main/java/org/springframework/websocket/WebSocketHandlerAdapter.java new file mode 100644 index 0000000000..a1f4154dc3 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/websocket/WebSocketHandlerAdapter.java @@ -0,0 +1,47 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.websocket; + +import java.io.InputStream; + +/** + * + * @author Rossen Stoyanchev + */ +public class WebSocketHandlerAdapter implements WebSocketHandler { + + @Override + public void newSession(Session session) throws Exception { + } + + @Override + public void handleTextMessage(Session session, String message) throws Exception { + } + + @Override + public void handleBinaryMessage(Session session, InputStream message) throws Exception { + } + + @Override + public void handleException(Session session, Throwable exception) { + } + + @Override + public void sessionClosed(Session session, int statusCode, String reason) throws Exception { + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/websocket/support/ServerEndpointPostProcessor.java b/spring-websocket/src/main/java/org/springframework/websocket/support/ServerEndpointPostProcessor.java new file mode 100644 index 0000000000..fc08356b07 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/websocket/support/ServerEndpointPostProcessor.java @@ -0,0 +1,141 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.websocket.support; + +import javax.servlet.ServletContext; +import javax.websocket.DeploymentException; +import javax.websocket.server.ServerContainer; +import javax.websocket.server.ServerContainerProvider; +import javax.websocket.server.ServerEndpointConfig; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.tomcat.websocket.server.WsServerContainer; +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.beans.factory.config.BeanPostProcessor; +import org.springframework.util.Assert; +import org.springframework.web.context.ServletContextAware; + +/** + * BeanPostProcessor that registers {@link javax.websocket.server.ServerEndpointConfig} + * beans with a standard Java WebSocket runtime and also configures the underlying + * {@link javax.websocket.server.ServerContainer}. + * + * @author Rossen Stoyanchev + * @since 4.0 + */ +public class ServerEndpointPostProcessor implements ServletContextAware, BeanPostProcessor, InitializingBean { + + private static Log logger = LogFactory.getLog(ServerEndpointPostProcessor.class); + + private Long maxSessionIdleTimeout; + + private Integer maxTextMessageBufferSize; + + private Integer maxBinaryMessageBufferSize; + + private ServletContext servletContext; + + + /** + * If this property set it is in turn used to configure + * {@link ServerContainer#setDefaultMaxSessionIdleTimeout(long)}. + */ + public void setMaxSessionIdleTimeout(long maxSessionIdleTimeout) { + this.maxSessionIdleTimeout = maxSessionIdleTimeout; + } + + public Long getMaxSessionIdleTimeout() { + return this.maxSessionIdleTimeout; + } + + /** + * If this property set it is in turn used to configure + * {@link ServerContainer#setDefaultMaxTextMessageBufferSize(int)} + */ + public void setMaxTextMessageBufferSize(int maxTextMessageBufferSize) { + this.maxTextMessageBufferSize = maxTextMessageBufferSize; + } + + public Integer getMaxTextMessageBufferSize() { + return this.maxTextMessageBufferSize; + } + + /** + * If this property set it is in turn used to configure + * {@link ServerContainer#setDefaultMaxBinaryMessageBufferSize(int)}. + */ + public void setMaxBinaryMessageBufferSize(int maxBinaryMessageBufferSize) { + this.maxBinaryMessageBufferSize = maxBinaryMessageBufferSize; + } + + public Integer getMaxBinaryMessageBufferSize() { + return this.maxBinaryMessageBufferSize; + } + + @Override + public void setServletContext(ServletContext servletContext) { + this.servletContext = servletContext; + } + + public ServletContext getServletContext() { + return servletContext; + } + + @Override + public void afterPropertiesSet() throws Exception { + + ServerContainer serverContainer = ServerContainerProvider.getServerContainer(); + Assert.notNull(serverContainer, "javax.websocket.server.ServerContainer not available"); + + if (this.maxSessionIdleTimeout != null) { + serverContainer.setDefaultMaxSessionIdleTimeout(this.maxSessionIdleTimeout); + } + if (this.maxTextMessageBufferSize != null) { + serverContainer.setDefaultMaxTextMessageBufferSize(this.maxTextMessageBufferSize); + } + if (this.maxBinaryMessageBufferSize != null) { + serverContainer.setDefaultMaxBinaryMessageBufferSize(this.maxBinaryMessageBufferSize); + } + + // TODO: this is necessary but only done on Tomcat + WsServerContainer sc = WsServerContainer.getServerContainer(); + sc.setServletContext(this.servletContext); + } + + @Override + public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException { + if (bean instanceof ServerEndpointConfig) { + ServerEndpointConfig sec = (ServerEndpointConfig) bean; + ServerContainer serverContainer = ServerContainerProvider.getServerContainer(); + try { + logger.debug("Registering javax.websocket.Endpoint for path " + sec.getPath()); + serverContainer.addEndpoint(sec); + } + catch (DeploymentException e) { + throw new IllegalStateException("Failed to deploy Endpoint " + bean, e); + } + } + return bean; + } + + @Override + public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException { + return bean; + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/websocket/support/ServerEndpointRegistration.java b/spring-websocket/src/main/java/org/springframework/websocket/support/ServerEndpointRegistration.java new file mode 100644 index 0000000000..22ee961ea6 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/websocket/support/ServerEndpointRegistration.java @@ -0,0 +1,204 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.websocket.support; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import javax.websocket.Decoder; +import javax.websocket.Encoder; +import javax.websocket.Endpoint; +import javax.websocket.Extension; +import javax.websocket.HandshakeResponse; +import javax.websocket.server.HandshakeRequest; +import javax.websocket.server.ServerEndpointConfig; + +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.BeanFactoryAware; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.websocket.WebSocketHandler; + + +/** + * An implementation of {@link javax.websocket.server.ServerEndpointConfig} that also + * holds the target {@link javax.websocket.Endpoint} as a reference or a bean name. + * The target can also be {@link org.springframework.websocket.WebSocketHandler}, in + * which case it will be adapted via {@link StandardWebSocketHandlerAdapter}. + * + *

+ * Beans of this type are detected by {@link ServerEndpointPostProcessor} and + * registered with a Java WebSocket runtime at startup. + * + * @author Rossen Stoyanchev + */ +public class ServerEndpointRegistration implements ServerEndpointConfig, BeanFactoryAware { + + private final String path; + + private final Object bean; + + private List subprotocols = new ArrayList(); + + private List extensions = new ArrayList(); + + private Map userProperties = new HashMap(); + + private BeanFactory beanFactory; + + private final Configurator configurator = new Configurator() {}; + + + public ServerEndpointRegistration(String path, String beanName) { + Assert.hasText(path, "path must not be empty"); + Assert.notNull(beanName, "beanName is required"); + this.path = path; + this.bean = beanName; + } + + public ServerEndpointRegistration(String path, Object bean) { + Assert.hasText(path, "path must not be empty"); + Assert.notNull(bean, "bean is required"); + this.path = path; + this.bean = bean; + } + + @Override + public String getPath() { + return this.path; + } + + @SuppressWarnings("unchecked") + @Override + public Class getEndpointClass() { + Class beanClass = this.bean.getClass(); + if (beanClass.equals(String.class)) { + beanClass = this.beanFactory.getType((String) this.bean); + } + beanClass = ClassUtils.getUserClass(beanClass); + if (WebSocketHandler.class.isAssignableFrom(beanClass)) { + return StandardWebSocketHandlerAdapter.class; + } + else { + return (Class) beanClass; + } + } + + protected Endpoint getEndpoint() { + Object bean = this.bean; + if (this.bean instanceof String) { + bean = this.beanFactory.getBean((String) this.bean); + } + if (bean instanceof WebSocketHandler) { + return new StandardWebSocketHandlerAdapter((WebSocketHandler) bean); + } + else { + return (Endpoint) bean; + } + } + + @Override + public List getSubprotocols() { + return this.subprotocols; + } + + public void setSubprotocols(List subprotocols) { + this.subprotocols = subprotocols; + } + + @Override + public List getExtensions() { + return this.extensions; + } + + public void setExtensions(List extensions) { + // TODO: verify against ServerContainer.getInstalledExtensions() + this.extensions = extensions; + } + + @Override + public Map getUserProperties() { + return this.userProperties; + } + + public void setUserProperties(Map userProperties) { + this.userProperties = userProperties; + } + + @Override + public List> getEncoders() { + return Collections.emptyList(); + } + + @Override + public List> getDecoders() { + return Collections.emptyList(); + } + + @Override + public void setBeanFactory(BeanFactory beanFactory) throws BeansException { + this.beanFactory = beanFactory; + } + + @Override + public Configurator getConfigurator() { + return new Configurator() { + @SuppressWarnings("unchecked") + @Override + public T getEndpointInstance(Class clazz) throws InstantiationException { + return (T) ServerEndpointRegistration.this.getEndpoint(); + } + @Override + public void modifyHandshake(ServerEndpointConfig sec, HandshakeRequest request, HandshakeResponse response) { + ServerEndpointRegistration.this.modifyHandshake(request, response); + } + @Override + public boolean checkOrigin(String originHeaderValue) { + return ServerEndpointRegistration.this.checkOrigin(originHeaderValue); + } + @Override + public String getNegotiatedSubprotocol(List supported, List requested) { + return ServerEndpointRegistration.this.selectSubProtocol(requested); + } + @Override + public List getNegotiatedExtensions(List installed, List requested) { + return ServerEndpointRegistration.this.selectExtensions(requested); + } + }; + } + + protected void modifyHandshake(HandshakeRequest request, HandshakeResponse response) { + this.configurator.modifyHandshake(this, request, response); + } + + protected boolean checkOrigin(String originHeaderValue) { + return this.configurator.checkOrigin(originHeaderValue); + } + + protected String selectSubProtocol(List requested) { + return this.configurator.getNegotiatedSubprotocol(getSubprotocols(), requested); + } + + protected List selectExtensions(List requested) { + return this.configurator.getNegotiatedExtensions(getExtensions(), requested); + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/websocket/support/StandardSessionAdapter.java b/spring-websocket/src/main/java/org/springframework/websocket/support/StandardSessionAdapter.java new file mode 100644 index 0000000000..e4f8d64501 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/websocket/support/StandardSessionAdapter.java @@ -0,0 +1,50 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.websocket.support; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.websocket.Session; + + +/** + * + * @author Rossen Stoyanchev + */ +public class StandardSessionAdapter implements Session { + + private static Log logger = LogFactory.getLog(StandardSessionAdapter.class); + + private javax.websocket.Session sourceSession; + + + public StandardSessionAdapter(javax.websocket.Session sourceSession) { + this.sourceSession = sourceSession; + } + + @Override + public void sendText(String text) throws Exception { + logger.trace("Sending text message: " + text); + this.sourceSession.getBasicRemote().sendText(text); + } + + @Override + public void close(int code, String reason) throws Exception { + this.sourceSession = null; + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/websocket/support/StandardWebSocketHandlerAdapter.java b/spring-websocket/src/main/java/org/springframework/websocket/support/StandardWebSocketHandlerAdapter.java new file mode 100644 index 0000000000..0e1c99eabc --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/websocket/support/StandardWebSocketHandlerAdapter.java @@ -0,0 +1,131 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.websocket.support; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import javax.websocket.CloseReason; +import javax.websocket.Endpoint; +import javax.websocket.EndpointConfig; +import javax.websocket.MessageHandler; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.util.Assert; +import org.springframework.websocket.Session; +import org.springframework.websocket.WebSocketHandler; + + +/** + * + * @author Rossen Stoyanchev + */ +public class StandardWebSocketHandlerAdapter extends Endpoint { + + private static Log logger = LogFactory.getLog(StandardWebSocketHandlerAdapter.class); + + private final WebSocketHandler webSocketHandler; + + private final Map sessionMap = new ConcurrentHashMap(); + + + public StandardWebSocketHandlerAdapter(WebSocketHandler webSocketHandler) { + this.webSocketHandler = webSocketHandler; + } + + @Override + public void onOpen(javax.websocket.Session sourceSession, EndpointConfig config) { + logger.debug("New WebSocket session: " + sourceSession); + try { + Session session = new StandardSessionAdapter(sourceSession); + this.sessionMap.put(sourceSession.getId(), session); + sourceSession.addMessageHandler(new StandardMessageHandler(sourceSession.getId())); + this.webSocketHandler.newSession(session); + } + catch (Throwable ex) { + // TODO + logger.error("Error while processing new session", ex); + } + } + + @Override + public void onClose(javax.websocket.Session sourceSession, CloseReason closeReason) { + String id = sourceSession.getId(); + if (logger.isDebugEnabled()) { + logger.debug("Closing session: " + sourceSession + ", " + closeReason); + } + try { + Session session = getSession(id); + this.sessionMap.remove(id); + int code = closeReason.getCloseCode().getCode(); + String reason = closeReason.getReasonPhrase(); + session.close(code, reason); + this.webSocketHandler.sessionClosed(session, code, reason); + } + catch (Throwable ex) { + // TODO + logger.error("Error while processing session closing", ex); + } + } + + @Override + public void onError(javax.websocket.Session sourceSession, Throwable exception) { + logger.error("Error for WebSocket session: " + sourceSession.getId(), exception); + try { + Session session = getSession(sourceSession.getId()); + this.webSocketHandler.handleException(session, exception); + } + catch (Throwable ex) { + // TODO + logger.error("Failed to handle error", ex); + } + } + + private Session getSession(String sourceSessionId) { + Session session = this.sessionMap.get(sourceSessionId); + Assert.notNull(session, "No session"); + return session; + } + + + private class StandardMessageHandler implements MessageHandler.Whole { + + private final String sourceSessionId; + + public StandardMessageHandler(String sourceSessionId) { + this.sourceSessionId = sourceSessionId; + } + + @Override + public void onMessage(String message) { + if (logger.isTraceEnabled()) { + logger.trace("Message for session [" + this.sourceSessionId + "]: " + message); + } + try { + Session session = getSession(this.sourceSessionId); + StandardWebSocketHandlerAdapter.this.webSocketHandler.handleTextMessage(session, message); + } + catch (Throwable ex) { + // TODO + logger.error("Error while processing message", ex); + } + } + + } + +}