Add handshake request handler abstraction

This commit is contained in:
Rossen Stoyanchev 2013-03-27 14:17:47 -04:00
parent cdd7d7bd88
commit 715018fe75
11 changed files with 349 additions and 20 deletions

View File

@ -0,0 +1,174 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.websocket;
import java.io.IOException;
import java.nio.charset.Charset;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import javax.xml.bind.DatatypeConverter;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.web.util.UriComponentsBuilder;
/**
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public abstract class AbstractHandshakeRequestHandler implements HandshakeRequestHandler {
private static final String GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
protected Log logger = LogFactory.getLog(getClass());
private List<String> protocols;
public void setProtocols(String... protocols) {
this.protocols = Arrays.asList(protocols);
}
public String[] getProtocols() {
return this.protocols.toArray(new String[this.protocols.size()]);
}
@Override
public boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response) throws Exception {
logger.debug("Starting handshake for " + request.getURI());
if (!HttpMethod.GET.equals(request.getMethod())) {
response.setStatusCode(HttpStatus.METHOD_NOT_ALLOWED);
response.getHeaders().setAllow(Collections.singleton(HttpMethod.GET));
logger.debug("Only HTTP GET is allowed, current method is " + request.getMethod());
return false;
}
if (!validateUpgradeHeader(request, response)) {
return false;
}
if (!validateConnectHeader(request, response)) {
return false;
}
if (!validateWebSocketVersion(request, response)) {
return false;
}
if (!validateOrigin(request, response)) {
return false;
}
String wsKey = request.getHeaders().getSecWebSocketKey();
if (wsKey == null) {
logger.debug("Missing \"Sec-WebSocket-Key\" header");
response.setStatusCode(HttpStatus.BAD_REQUEST);
return false;
}
String protocol = selectProtocol(request.getHeaders().getSecWebSocketProtocol());
// TODO: request.getHeaders().getSecWebSocketExtensions())
response.setStatusCode(HttpStatus.SWITCHING_PROTOCOLS);
response.getHeaders().setUpgrade("WebSocket");
response.getHeaders().setConnection("Upgrade");
response.getHeaders().setSecWebSocketProtocol(protocol);
response.getHeaders().setSecWebSocketAccept(getWebSocketKeyHash(wsKey));
// TODO: response.getHeaders().setSecWebSocketExtensions(extensions);
logger.debug("Successfully negotiated WebSocket handshake");
// TODO: surely there is a better way to flush the headers
response.getBody();
doHandshakeInternal(request, response, protocol);
return true;
}
protected boolean validateUpgradeHeader(ServerHttpRequest request, ServerHttpResponse response) throws IOException {
if (!"WebSocket".equalsIgnoreCase(request.getHeaders().getUpgrade())) {
response.setStatusCode(HttpStatus.BAD_REQUEST);
response.getBody().write("Can \"Upgrade\" only to \"websocket\".".getBytes("UTF-8"));
logger.debug("Invalid Upgrade header " + request.getHeaders().getUpgrade());
return false;
}
return true;
}
protected boolean validateConnectHeader(ServerHttpRequest request, ServerHttpResponse response) throws IOException {
if (!request.getHeaders().getConnection().contains("Upgrade")) {
response.setStatusCode(HttpStatus.BAD_REQUEST);
response.getBody().write("\"Connection\" must be \"upgrade\".".getBytes("UTF-8"));
logger.debug("Invalid Connection header " + request.getHeaders().getConnection());
return false;
}
return true;
}
protected boolean validateWebSocketVersion(ServerHttpRequest request, ServerHttpResponse response) {
if (!"13".equals(request.getHeaders().getSecWebSocketVersion())) {
response.setStatusCode(HttpStatus.UPGRADE_REQUIRED);
response.getHeaders().set("Sec-WebSocket-Version", "13");
logger.debug("WebSocket version not supported " + request.getHeaders().get("Sec-WebSocket-Version"));
return false;
}
return true;
}
protected boolean validateOrigin(ServerHttpRequest request, ServerHttpResponse response) {
String origin = request.getHeaders().getOrigin();
if (origin != null) {
UriComponentsBuilder originUriBuilder = UriComponentsBuilder.fromHttpUrl(origin);
// TODO
// Check scheme, port, and host against list of configured origins (allow wild cards in the host?)
// Another strategy might be to match current request's scheme/port/host
// response.setStatusCode(HttpStatus.FORBIDDEN);
// return false;
}
return true;
}
protected String selectProtocol(List<String> requestedProtocols) {
if (requestedProtocols != null) {
for (String p : requestedProtocols) {
if (this.protocols.contains(p)) {
return p;
}
}
}
return null;
}
private String getWebSocketKeyHash(String key) throws NoSuchAlgorithmException {
MessageDigest digest = MessageDigest.getInstance("SHA1");
byte[] bytes = digest.digest((key + GUID).getBytes(Charset.forName("ISO-8859-1")));
return DatatypeConverter.printBase64Binary(bytes);
}
protected abstract void doHandshakeInternal(ServerHttpRequest request, ServerHttpResponse response, String protocol)
throws Exception;
}

View File

@ -0,0 +1,33 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.websocket;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
/**
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public interface HandshakeRequestHandler {
boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response) throws Exception;
}

View File

@ -21,6 +21,7 @@ package org.springframework.websocket;
/**
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public interface Session {

View File

@ -22,6 +22,7 @@ import java.io.InputStream;
/**
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public interface WebSocketHandler {

View File

@ -21,6 +21,7 @@ import java.io.InputStream;
/**
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class WebSocketHandlerAdapter implements WebSocketHandler {

View File

@ -0,0 +1,54 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.websocket.servlet;
import javax.servlet.ServletContext;
import org.apache.tomcat.websocket.server.WsServerContainer;
import org.springframework.web.context.ServletContextAware;
import org.springframework.websocket.support.ServerEndpointPostProcessor;
/**
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class ServletServerEndpointPostProcessor extends ServerEndpointPostProcessor implements ServletContextAware {
private ServletContext servletContext;
@Override
public void setServletContext(ServletContext servletContext) {
this.servletContext = servletContext;
}
public ServletContext getServletContext() {
return servletContext;
}
@Override
public void afterPropertiesSet() throws Exception {
// TODO: remove hard dependency on Tomcat (see Tomcat's WsListener)
WsServerContainer sc = WsServerContainer.getServerContainer();
sc.setServletContext(this.servletContext);
super.afterPropertiesSet();
}
}

View File

@ -0,0 +1,81 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.websocket.servlet;
import java.lang.reflect.Method;
import java.util.Collections;
import javax.servlet.http.HttpServletRequest;
import javax.websocket.Endpoint;
import javax.websocket.server.ServerEndpointConfig;
import org.apache.tomcat.websocket.server.WsHandshakeRequest;
import org.apache.tomcat.websocket.server.WsHttpUpgradeHandler;
import org.apache.tomcat.websocket.server.WsServerContainer;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.util.Assert;
import org.springframework.util.ReflectionUtils;
import org.springframework.websocket.AbstractHandshakeRequestHandler;
import org.springframework.websocket.WebSocketHandler;
import org.springframework.websocket.support.ServerEndpointRegistration;
import org.springframework.websocket.support.StandardWebSocketHandlerAdapter;
/**
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class TomcatHandshakeRequestHandler extends AbstractHandshakeRequestHandler {
private final Endpoint endpoint;
private final ServerEndpointConfig endpointConfig;
public TomcatHandshakeRequestHandler(WebSocketHandler webSocketHandler) {
this.endpoint = new StandardWebSocketHandlerAdapter(webSocketHandler);
this.endpointConfig = new ServerEndpointRegistration("/shouldnotmatter", this.endpoint);
}
public TomcatHandshakeRequestHandler(Endpoint endpoint) {
this.endpoint = endpoint;
this.endpointConfig = new ServerEndpointRegistration("/shouldnotmatter", this.endpoint);
}
@Override
public void doHandshakeInternal(ServerHttpRequest request, ServerHttpResponse response, String protocol) throws Exception {
logger.debug("Upgrading HTTP request");
Assert.isTrue(request instanceof ServletServerHttpRequest);
HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest();
WsHandshakeRequest wsRequest = new WsHandshakeRequest(servletRequest);
Method method = ReflectionUtils.findMethod(WsHandshakeRequest.class, "finished");
ReflectionUtils.makeAccessible(method);
method.invoke(wsRequest);
WsHttpUpgradeHandler wsHandler = servletRequest.upgrade(WsHttpUpgradeHandler.class);
wsHandler.preInit(this.endpoint, this.endpointConfig, WsServerContainer.getServerContainer(),
wsRequest, protocol, Collections.<String, String> emptyMap(), servletRequest.isSecure());
}
}

View File

@ -15,7 +15,6 @@
*/
package org.springframework.websocket.support;
import javax.servlet.ServletContext;
import javax.websocket.DeploymentException;
import javax.websocket.server.ServerContainer;
import javax.websocket.server.ServerContainerProvider;
@ -23,12 +22,10 @@ import javax.websocket.server.ServerEndpointConfig;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.tomcat.websocket.server.WsServerContainer;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.util.Assert;
import org.springframework.web.context.ServletContextAware;
/**
* BeanPostProcessor that registers {@link javax.websocket.server.ServerEndpointConfig}
@ -38,7 +35,7 @@ import org.springframework.web.context.ServletContextAware;
* @author Rossen Stoyanchev
* @since 4.0
*/
public class ServerEndpointPostProcessor implements ServletContextAware, BeanPostProcessor, InitializingBean {
public class ServerEndpointPostProcessor implements BeanPostProcessor, InitializingBean {
private static Log logger = LogFactory.getLog(ServerEndpointPostProcessor.class);
@ -48,8 +45,6 @@ public class ServerEndpointPostProcessor implements ServletContextAware, BeanPos
private Integer maxBinaryMessageBufferSize;
private ServletContext servletContext;
/**
* If this property set it is in turn used to configure
@ -87,18 +82,8 @@ public class ServerEndpointPostProcessor implements ServletContextAware, BeanPos
return this.maxBinaryMessageBufferSize;
}
@Override
public void setServletContext(ServletContext servletContext) {
this.servletContext = servletContext;
}
public ServletContext getServletContext() {
return servletContext;
}
@Override
public void afterPropertiesSet() throws Exception {
ServerContainer serverContainer = ServerContainerProvider.getServerContainer();
Assert.notNull(serverContainer, "javax.websocket.server.ServerContainer not available");
@ -111,10 +96,6 @@ public class ServerEndpointPostProcessor implements ServletContextAware, BeanPos
if (this.maxBinaryMessageBufferSize != null) {
serverContainer.setDefaultMaxBinaryMessageBufferSize(this.maxBinaryMessageBufferSize);
}
// TODO: this is necessary but only done on Tomcat
WsServerContainer sc = WsServerContainer.getServerContainer();
sc.setServletContext(this.servletContext);
}
@Override

View File

@ -49,6 +49,7 @@ import org.springframework.websocket.WebSocketHandler;
* registered with a Java WebSocket runtime at startup.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class ServerEndpointRegistration implements ServerEndpointConfig, BeanFactoryAware {

View File

@ -24,6 +24,7 @@ import org.springframework.websocket.Session;
/**
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class StandardSessionAdapter implements Session {

View File

@ -34,6 +34,7 @@ import org.springframework.websocket.WebSocketHandler;
/**
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class StandardWebSocketHandlerAdapter extends Endpoint {