Add Glassfish request upgrade strategy

This commit is contained in:
Rossen Stoyanchev 2013-04-10 09:14:38 -04:00
parent 6bd6311214
commit 592da431a8
6 changed files with 252 additions and 66 deletions

View File

@ -525,10 +525,10 @@ project("spring-websocket") {
exclude group: "org.apache.tomcat", module: "tomcat-servlet-api"
}
optional("org.eclipse.jetty:jetty-websocket:8.1.10.v20130312")
optional("org.glassfish.tyrus:tyrus-websocket-core:1.0-SNAPSHOT")
optional("org.glassfish.tyrus:tyrus-container-servlet:1.0-SNAPSHOT")
optional("com.fasterxml.jackson.core:jackson-databind:2.0.1")
optional("com.fasterxml.jackson.core:jackson-databind:2.0.1") // required for SockJS support currently
}

View File

@ -37,6 +37,7 @@ import org.springframework.http.HttpStatus;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponentsBuilder;
import org.springframework.websocket.WebSocketHandler;
@ -56,7 +57,7 @@ public abstract class AbstractHandshakeHandler implements HandshakeHandler, Bean
private final Class<? extends WebSocketHandler> handlerClass;
private List<String> protocols;
private List<String> supportedProtocols;
private AutowireCapableBeanFactory beanFactory;
@ -73,12 +74,12 @@ public abstract class AbstractHandshakeHandler implements HandshakeHandler, Bean
this.handlerClass = handlerClass;
}
public void setProtocols(String... protocols) {
this.protocols = Arrays.asList(protocols);
public void setSupportedProtocols(String... protocols) {
this.supportedProtocols = Arrays.asList(protocols);
}
public String[] getProtocols() {
return this.protocols.toArray(new String[this.protocols.size()]);
public String[] getSupportedProtocols() {
return this.supportedProtocols.toArray(new String[this.supportedProtocols.size()]);
}
@Override
@ -109,16 +110,20 @@ public abstract class AbstractHandshakeHandler implements HandshakeHandler, Bean
logger.debug("Only HTTP GET is allowed, current method is " + request.getMethod());
return false;
}
if (!validateUpgradeHeader(request, response)) {
if (!"WebSocket".equalsIgnoreCase(request.getHeaders().getUpgrade())) {
handleInvalidUpgradeHeader(request, response);
return false;
}
if (!validateConnectHeader(request, response)) {
if (!request.getHeaders().getConnection().contains("Upgrade")) {
handleInvalidConnectHeader(request, response);
return false;
}
if (!validateWebSocketVersion(request, response)) {
if (!isWebSocketVersionSupported(request)) {
handleWebSocketVersionNotSupported(request, response);
return false;
}
if (!validateOrigin(request, response)) {
if (!isValidOrigin(request)) {
response.setStatusCode(HttpStatus.FORBIDDEN);
return false;
}
String wsKey = request.getHeaders().getSecWebSocketKey();
@ -127,8 +132,9 @@ public abstract class AbstractHandshakeHandler implements HandshakeHandler, Bean
response.setStatusCode(HttpStatus.BAD_REQUEST);
return false;
}
String protocol = selectProtocol(request.getHeaders().getSecWebSocketProtocol());
// TODO: request.getHeaders().getSecWebSocketExtensions())
// TODO: select extensions
response.setStatusCode(HttpStatus.SWITCHING_PROTOCOLS);
response.getHeaders().setUpgrade("WebSocket");
@ -139,7 +145,7 @@ public abstract class AbstractHandshakeHandler implements HandshakeHandler, Bean
logger.debug("Successfully negotiated WebSocket handshake");
// TODO: surely there is a better way to flush the headers
// TODO: surely there is a better way to flush headers
response.getBody();
doHandshakeInternal(request, response, protocol);
@ -150,46 +156,46 @@ public abstract class AbstractHandshakeHandler implements HandshakeHandler, Bean
protected abstract void doHandshakeInternal(ServerHttpRequest request, ServerHttpResponse response,
String protocol) throws Exception;
protected boolean validateUpgradeHeader(ServerHttpRequest request, ServerHttpResponse response) throws IOException {
if (!"WebSocket".equalsIgnoreCase(request.getHeaders().getUpgrade())) {
response.setStatusCode(HttpStatus.BAD_REQUEST);
response.getBody().write("Can \"Upgrade\" only to \"websocket\".".getBytes("UTF-8"));
logger.debug("Invalid Upgrade header " + request.getHeaders().getUpgrade());
return false;
}
return true;
protected void handleInvalidUpgradeHeader(ServerHttpRequest request, ServerHttpResponse response) throws IOException {
logger.debug("Invalid Upgrade header " + request.getHeaders().getUpgrade());
response.setStatusCode(HttpStatus.BAD_REQUEST);
response.getBody().write("Can \"Upgrade\" only to \"websocket\".".getBytes("UTF-8"));
}
protected boolean validateConnectHeader(ServerHttpRequest request, ServerHttpResponse response) throws IOException {
if (!request.getHeaders().getConnection().contains("Upgrade")) {
response.setStatusCode(HttpStatus.BAD_REQUEST);
response.getBody().write("\"Connection\" must be \"upgrade\".".getBytes("UTF-8"));
logger.debug("Invalid Connection header " + request.getHeaders().getConnection());
return false;
}
return true;
protected void handleInvalidConnectHeader(ServerHttpRequest request, ServerHttpResponse response) throws IOException {
logger.debug("Invalid Connection header " + request.getHeaders().getConnection());
response.setStatusCode(HttpStatus.BAD_REQUEST);
response.getBody().write("\"Connection\" must be \"upgrade\".".getBytes("UTF-8"));
}
protected boolean validateWebSocketVersion(ServerHttpRequest request, ServerHttpResponse response) {
if (!"13".equals(request.getHeaders().getSecWebSocketVersion())) {
response.setStatusCode(HttpStatus.UPGRADE_REQUIRED);
response.getHeaders().set("Sec-WebSocket-Version", "13");
logger.debug("WebSocket version not supported " + request.getHeaders().get("Sec-WebSocket-Version"));
return false;
protected boolean isWebSocketVersionSupported(ServerHttpRequest request) {
String requestedVersion = request.getHeaders().getSecWebSocketVersion();
for (String supportedVersion : getSupportedVerions()) {
if (supportedVersion.equals(requestedVersion)) {
return true;
}
}
return true;
return false;
}
protected boolean validateOrigin(ServerHttpRequest request, ServerHttpResponse response) {
protected String[] getSupportedVerions() {
return new String[] { "13" };
}
protected void handleWebSocketVersionNotSupported(ServerHttpRequest request, ServerHttpResponse response) {
logger.debug("WebSocket version not supported " + request.getHeaders().get("Sec-WebSocket-Version"));
response.setStatusCode(HttpStatus.UPGRADE_REQUIRED);
response.getHeaders().setSecWebSocketVersion(StringUtils.arrayToCommaDelimitedString(getSupportedVerions()));
}
protected boolean isValidOrigin(ServerHttpRequest request) {
String origin = request.getHeaders().getOrigin();
if (origin != null) {
UriComponentsBuilder originUriBuilder = UriComponentsBuilder.fromHttpUrl(origin);
// UriComponentsBuilder uriBuilder = UriComponentsBuilder.fromHttpUrl(origin);
// TODO
// Check scheme, port, and host against list of configured origins (allow wild cards in the host?)
// Another strategy might be to match current request's scheme/port/host
// response.setStatusCode(HttpStatus.FORBIDDEN);
// A simple strategy checks against the current request's scheme/port/host
// Or match scheme, port, and host against configured allowed origins (wild cards for hosts?)
// return false;
}
return true;
@ -197,9 +203,9 @@ public abstract class AbstractHandshakeHandler implements HandshakeHandler, Bean
protected String selectProtocol(List<String> requestedProtocols) {
if (requestedProtocols != null) {
for (String p : requestedProtocols) {
if (this.protocols.contains(p)) {
return p;
for (String protocol : requestedProtocols) {
if (this.supportedProtocols.contains(protocol)) {
return protocol;
}
}
}

View File

@ -37,45 +37,56 @@ import org.springframework.websocket.support.WebSocketHandlerEndpoint;
public class EndpointHandshakeHandler extends AbstractHandshakeHandler {
private static final boolean tomcatWebSocketPresent = ClassUtils.isPresent(
"org.apache.tomcat.websocket.server.WsHandshakeRequest", EndpointHandshakeHandler.class.getClassLoader());
"org.apache.tomcat.websocket.server.WsHttpUpgradeHandler", EndpointHandshakeHandler.class.getClassLoader());
private final WebSocketRequestUpgradeStrategy upgradeStrategy;
private static final boolean glassfishWebSocketPresent = ClassUtils.isPresent(
"org.glassfish.tyrus.servlet.TyrusHttpUpgradeHandler", EndpointHandshakeHandler.class.getClassLoader());
private final EndpointRequestUpgradeStrategy upgradeStrategy;
public EndpointHandshakeHandler(WebSocketHandler webSocketHandler) {
super(webSocketHandler);
this.upgradeStrategy = createRequestUpgradeStrategy();
this.upgradeStrategy = createUpgradeStrategy();
}
public EndpointHandshakeHandler(Class<? extends WebSocketHandler> handlerClass) {
super(handlerClass);
this.upgradeStrategy = createRequestUpgradeStrategy();
this.upgradeStrategy = createUpgradeStrategy();
}
private static WebSocketRequestUpgradeStrategy createRequestUpgradeStrategy() {
private static EndpointRequestUpgradeStrategy createUpgradeStrategy() {
String className;
if (tomcatWebSocketPresent) {
className = "org.springframework.websocket.server.endpoint.TomcatRequestUpgradeStrategy";
className = "org.springframework.websocket.server.endpoint.support.TomcatRequestUpgradeStrategy";
}
else if (glassfishWebSocketPresent) {
className = "org.springframework.websocket.server.endpoint.support.GlassfishRequestUpgradeStrategy";
}
else {
throw new IllegalStateException("No suitable EndpointRequestUpgradeStrategy");
}
try {
Class<?> clazz = ClassUtils.forName(className, EndpointHandshakeHandler.class.getClassLoader());
return (WebSocketRequestUpgradeStrategy) BeanUtils.instantiateClass(clazz.getConstructor());
return (EndpointRequestUpgradeStrategy) BeanUtils.instantiateClass(clazz.getConstructor());
}
catch (Throwable t) {
throw new IllegalStateException("Failed to instantiate " + className, t);
}
}
@Override
protected String[] getSupportedVerions() {
return this.upgradeStrategy.getSupportedVersions();
}
@Override
public void doHandshakeInternal(ServerHttpRequest request, ServerHttpResponse response, String protocol)
throws Exception {
logger.debug("Upgrading HTTP request");
Endpoint endpoint = new WebSocketHandlerEndpoint(getWebSocketHandler());
this.upgradeStrategy.upgrade(request, response, protocol, new EndpointRegistration("/dummy", endpoint));
this.upgradeStrategy.upgrade(request, response, protocol, endpoint);
}
}

View File

@ -16,6 +16,8 @@
package org.springframework.websocket.server.endpoint;
import javax.websocket.Endpoint;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
@ -26,12 +28,14 @@ import org.springframework.http.server.ServerHttpResponse;
* @author Rossen Stoyanchev
* @since 4.0
*/
public interface WebSocketRequestUpgradeStrategy {
public interface EndpointRequestUpgradeStrategy {
String[] getSupportedVersions();
/**
* Invoked after the handshake checks have been performed and succeeded.
*/
void upgrade(ServerHttpRequest request, ServerHttpResponse response, String protocol,
EndpointRegistration registration) throws Exception;
void upgrade(ServerHttpRequest request, ServerHttpResponse response, String protocol, Endpoint endpoint)
throws Exception;
}

View File

@ -0,0 +1,152 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.websocket.server.endpoint.support;
import java.lang.reflect.Constructor;
import java.net.URI;
import java.util.Random;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;
import javax.websocket.Endpoint;
import org.glassfish.tyrus.core.ComponentProviderService;
import org.glassfish.tyrus.core.EndpointWrapper;
import org.glassfish.tyrus.core.ErrorCollector;
import org.glassfish.tyrus.core.RequestContext;
import org.glassfish.tyrus.server.TyrusEndpoint;
import org.glassfish.tyrus.servlet.TyrusHttpUpgradeHandler;
import org.glassfish.tyrus.websockets.Connection;
import org.glassfish.tyrus.websockets.Version;
import org.glassfish.tyrus.websockets.WebSocketEngine;
import org.glassfish.tyrus.websockets.WebSocketEngine.WebSocketHolderListener;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.http.server.ServletServerHttpResponse;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.websocket.server.endpoint.EndpointRegistration;
import org.springframework.websocket.server.endpoint.EndpointRequestUpgradeStrategy;
/**
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class GlassfishRequestUpgradeStrategy implements EndpointRequestUpgradeStrategy {
private final static Random random = new Random();
@Override
public String[] getSupportedVersions() {
return StringUtils.commaDelimitedListToStringArray(Version.getSupportedWireProtocolVersions());
}
@Override
public void upgrade(ServerHttpRequest request, ServerHttpResponse response, String protocol,
Endpoint endpoint) throws Exception {
Assert.isTrue(request instanceof ServletServerHttpRequest);
HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest();
Assert.isTrue(response instanceof ServletServerHttpResponse);
HttpServletResponse servletResponse = ((ServletServerHttpResponse) response).getServletResponse();
servletResponse = new AlreadyUpgradedResponseWrapper(servletResponse);
TyrusEndpoint tyrusEndpoint = createTyrusEndpoint(servletRequest, endpoint);
WebSocketEngine.getEngine().register(tyrusEndpoint);
try {
if (!performUpgrade(servletRequest, servletResponse, request.getHeaders(), tyrusEndpoint)) {
throw new IllegalStateException("Failed to upgrade HttpServletRequest");
}
}
finally {
WebSocketEngine.getEngine().unregister(tyrusEndpoint);
}
}
private boolean performUpgrade(HttpServletRequest request, HttpServletResponse response,
HttpHeaders headers, TyrusEndpoint tyrusEndpoint) throws Exception {
final TyrusHttpUpgradeHandler upgradeHandler = request.upgrade(TyrusHttpUpgradeHandler.class);
Connection connection = createConnection(upgradeHandler, response);
RequestContext wsRequest = RequestContext.Builder.create()
.requestURI(URI.create(tyrusEndpoint.getPath())).requestPath(tyrusEndpoint.getPath())
.connection(connection).secure(request.isSecure()).build();
for (String header : headers.keySet()) {
wsRequest.getHeaders().put(header, headers.get(header));
}
return WebSocketEngine.getEngine().upgrade(connection, wsRequest, new WebSocketHolderListener() {
@Override
public void onWebSocketHolder(WebSocketEngine.WebSocketHolder webSocketHolder) {
upgradeHandler.setWebSocketHolder(webSocketHolder);
}
});
}
private TyrusEndpoint createTyrusEndpoint(HttpServletRequest request, Endpoint endpoint) {
// Use randomized path
String requestUri = request.getRequestURI();
String randomValue = String.valueOf(random.nextLong());
String endpointPath = requestUri.endsWith("/") ? requestUri + randomValue : requestUri + "/" + randomValue;
EndpointRegistration endpointConfig = new EndpointRegistration(endpointPath, endpoint);
return new TyrusEndpoint(new EndpointWrapper(endpoint, endpointConfig,
ComponentProviderService.create(), null, "/", new ErrorCollector(),
endpointConfig.getConfigurator()));
}
private Connection createConnection(TyrusHttpUpgradeHandler handler, HttpServletResponse response) throws Exception {
String name = "org.glassfish.tyrus.servlet.ConnectionImpl";
Class<?> clazz = ClassUtils.forName(name, GlassfishRequestUpgradeStrategy.class.getClassLoader());
Constructor<?> constructor = clazz.getDeclaredConstructor(TyrusHttpUpgradeHandler.class, HttpServletResponse.class);
ReflectionUtils.makeAccessible(constructor);
return (Connection) constructor.newInstance(handler, response);
}
private static class AlreadyUpgradedResponseWrapper extends HttpServletResponseWrapper {
public AlreadyUpgradedResponseWrapper(HttpServletResponse response) {
super(response);
}
@Override
public void setStatus(int sc) {
Assert.isTrue(sc == HttpStatus.SWITCHING_PROTOCOLS.value(), "Unexpected status code " + sc);
}
@Override
public void addHeader(String name, String value) {
// ignore
}
}
}

View File

@ -14,12 +14,14 @@
* limitations under the License.
*/
package org.springframework.websocket.server.endpoint;
package org.springframework.websocket.server.endpoint.support;
import java.lang.reflect.Method;
import java.util.Collections;
import javax.servlet.http.HttpServletRequest;
import javax.websocket.Endpoint;
import javax.websocket.server.ServerEndpointConfig;
import org.apache.tomcat.websocket.server.WsHandshakeRequest;
import org.apache.tomcat.websocket.server.WsHttpUpgradeHandler;
@ -29,6 +31,8 @@ import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.util.Assert;
import org.springframework.util.ReflectionUtils;
import org.springframework.websocket.server.endpoint.EndpointRegistration;
import org.springframework.websocket.server.endpoint.EndpointRequestUpgradeStrategy;
/**
@ -36,26 +40,35 @@ import org.springframework.util.ReflectionUtils;
* @author Rossen Stoyanchev
* @since 4.0
*/
public class TomcatRequestUpgradeStrategy implements WebSocketRequestUpgradeStrategy {
public class TomcatRequestUpgradeStrategy implements EndpointRequestUpgradeStrategy {
@Override
public String[] getSupportedVersions() {
return new String[] { "13" };
}
@Override
public void upgrade(ServerHttpRequest request, ServerHttpResponse response, String protocol,
EndpointRegistration registration) throws Exception {
Endpoint endpoint) throws Exception {
Assert.isTrue(request instanceof ServletServerHttpRequest);
HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest();
WsHttpUpgradeHandler wsHandler = servletRequest.upgrade(WsHttpUpgradeHandler.class);
WsHttpUpgradeHandler upgradeHandler = servletRequest.upgrade(WsHttpUpgradeHandler.class);
WsHandshakeRequest wsRequest = new WsHandshakeRequest(servletRequest);
WsHandshakeRequest webSocketRequest = new WsHandshakeRequest(servletRequest);
Method method = ReflectionUtils.findMethod(WsHandshakeRequest.class, "finished");
ReflectionUtils.makeAccessible(method);
method.invoke(wsRequest);
method.invoke(webSocketRequest);
wsHandler.preInit(registration.getEndpoint(), registration,
WsServerContainer.getServerContainer(), wsRequest, protocol,
Collections.<String, String> emptyMap(), servletRequest.isSecure());
// TODO: use ServletContext attribute when Tomcat is updated
WsServerContainer serverContainer = WsServerContainer.getServerContainer();
ServerEndpointConfig endpointConfig = new EndpointRegistration("/shouldntmatter", endpoint);
upgradeHandler.preInit(endpoint, endpointConfig, serverContainer, webSocketRequest,
protocol, Collections.<String, String> emptyMap(), servletRequest.isSecure());
}
}