Add Glassfish request upgrade strategy
This commit is contained in:
parent
6bd6311214
commit
592da431a8
|
@ -525,10 +525,10 @@ project("spring-websocket") {
|
||||||
exclude group: "org.apache.tomcat", module: "tomcat-servlet-api"
|
exclude group: "org.apache.tomcat", module: "tomcat-servlet-api"
|
||||||
}
|
}
|
||||||
|
|
||||||
optional("org.eclipse.jetty:jetty-websocket:8.1.10.v20130312")
|
|
||||||
optional("org.glassfish.tyrus:tyrus-websocket-core:1.0-SNAPSHOT")
|
optional("org.glassfish.tyrus:tyrus-websocket-core:1.0-SNAPSHOT")
|
||||||
|
optional("org.glassfish.tyrus:tyrus-container-servlet:1.0-SNAPSHOT")
|
||||||
|
|
||||||
optional("com.fasterxml.jackson.core:jackson-databind:2.0.1")
|
optional("com.fasterxml.jackson.core:jackson-databind:2.0.1") // required for SockJS support currently
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -37,6 +37,7 @@ import org.springframework.http.HttpStatus;
|
||||||
import org.springframework.http.server.ServerHttpRequest;
|
import org.springframework.http.server.ServerHttpRequest;
|
||||||
import org.springframework.http.server.ServerHttpResponse;
|
import org.springframework.http.server.ServerHttpResponse;
|
||||||
import org.springframework.util.Assert;
|
import org.springframework.util.Assert;
|
||||||
|
import org.springframework.util.StringUtils;
|
||||||
import org.springframework.web.util.UriComponentsBuilder;
|
import org.springframework.web.util.UriComponentsBuilder;
|
||||||
import org.springframework.websocket.WebSocketHandler;
|
import org.springframework.websocket.WebSocketHandler;
|
||||||
|
|
||||||
|
@ -56,7 +57,7 @@ public abstract class AbstractHandshakeHandler implements HandshakeHandler, Bean
|
||||||
|
|
||||||
private final Class<? extends WebSocketHandler> handlerClass;
|
private final Class<? extends WebSocketHandler> handlerClass;
|
||||||
|
|
||||||
private List<String> protocols;
|
private List<String> supportedProtocols;
|
||||||
|
|
||||||
private AutowireCapableBeanFactory beanFactory;
|
private AutowireCapableBeanFactory beanFactory;
|
||||||
|
|
||||||
|
@ -73,12 +74,12 @@ public abstract class AbstractHandshakeHandler implements HandshakeHandler, Bean
|
||||||
this.handlerClass = handlerClass;
|
this.handlerClass = handlerClass;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setProtocols(String... protocols) {
|
public void setSupportedProtocols(String... protocols) {
|
||||||
this.protocols = Arrays.asList(protocols);
|
this.supportedProtocols = Arrays.asList(protocols);
|
||||||
}
|
}
|
||||||
|
|
||||||
public String[] getProtocols() {
|
public String[] getSupportedProtocols() {
|
||||||
return this.protocols.toArray(new String[this.protocols.size()]);
|
return this.supportedProtocols.toArray(new String[this.supportedProtocols.size()]);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -109,16 +110,20 @@ public abstract class AbstractHandshakeHandler implements HandshakeHandler, Bean
|
||||||
logger.debug("Only HTTP GET is allowed, current method is " + request.getMethod());
|
logger.debug("Only HTTP GET is allowed, current method is " + request.getMethod());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!validateUpgradeHeader(request, response)) {
|
if (!"WebSocket".equalsIgnoreCase(request.getHeaders().getUpgrade())) {
|
||||||
|
handleInvalidUpgradeHeader(request, response);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!validateConnectHeader(request, response)) {
|
if (!request.getHeaders().getConnection().contains("Upgrade")) {
|
||||||
|
handleInvalidConnectHeader(request, response);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!validateWebSocketVersion(request, response)) {
|
if (!isWebSocketVersionSupported(request)) {
|
||||||
|
handleWebSocketVersionNotSupported(request, response);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!validateOrigin(request, response)) {
|
if (!isValidOrigin(request)) {
|
||||||
|
response.setStatusCode(HttpStatus.FORBIDDEN);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
String wsKey = request.getHeaders().getSecWebSocketKey();
|
String wsKey = request.getHeaders().getSecWebSocketKey();
|
||||||
|
@ -127,8 +132,9 @@ public abstract class AbstractHandshakeHandler implements HandshakeHandler, Bean
|
||||||
response.setStatusCode(HttpStatus.BAD_REQUEST);
|
response.setStatusCode(HttpStatus.BAD_REQUEST);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
String protocol = selectProtocol(request.getHeaders().getSecWebSocketProtocol());
|
String protocol = selectProtocol(request.getHeaders().getSecWebSocketProtocol());
|
||||||
// TODO: request.getHeaders().getSecWebSocketExtensions())
|
// TODO: select extensions
|
||||||
|
|
||||||
response.setStatusCode(HttpStatus.SWITCHING_PROTOCOLS);
|
response.setStatusCode(HttpStatus.SWITCHING_PROTOCOLS);
|
||||||
response.getHeaders().setUpgrade("WebSocket");
|
response.getHeaders().setUpgrade("WebSocket");
|
||||||
|
@ -139,7 +145,7 @@ public abstract class AbstractHandshakeHandler implements HandshakeHandler, Bean
|
||||||
|
|
||||||
logger.debug("Successfully negotiated WebSocket handshake");
|
logger.debug("Successfully negotiated WebSocket handshake");
|
||||||
|
|
||||||
// TODO: surely there is a better way to flush the headers
|
// TODO: surely there is a better way to flush headers
|
||||||
response.getBody();
|
response.getBody();
|
||||||
|
|
||||||
doHandshakeInternal(request, response, protocol);
|
doHandshakeInternal(request, response, protocol);
|
||||||
|
@ -150,46 +156,46 @@ public abstract class AbstractHandshakeHandler implements HandshakeHandler, Bean
|
||||||
protected abstract void doHandshakeInternal(ServerHttpRequest request, ServerHttpResponse response,
|
protected abstract void doHandshakeInternal(ServerHttpRequest request, ServerHttpResponse response,
|
||||||
String protocol) throws Exception;
|
String protocol) throws Exception;
|
||||||
|
|
||||||
protected boolean validateUpgradeHeader(ServerHttpRequest request, ServerHttpResponse response) throws IOException {
|
|
||||||
if (!"WebSocket".equalsIgnoreCase(request.getHeaders().getUpgrade())) {
|
protected void handleInvalidUpgradeHeader(ServerHttpRequest request, ServerHttpResponse response) throws IOException {
|
||||||
response.setStatusCode(HttpStatus.BAD_REQUEST);
|
logger.debug("Invalid Upgrade header " + request.getHeaders().getUpgrade());
|
||||||
response.getBody().write("Can \"Upgrade\" only to \"websocket\".".getBytes("UTF-8"));
|
response.setStatusCode(HttpStatus.BAD_REQUEST);
|
||||||
logger.debug("Invalid Upgrade header " + request.getHeaders().getUpgrade());
|
response.getBody().write("Can \"Upgrade\" only to \"websocket\".".getBytes("UTF-8"));
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protected boolean validateConnectHeader(ServerHttpRequest request, ServerHttpResponse response) throws IOException {
|
protected void handleInvalidConnectHeader(ServerHttpRequest request, ServerHttpResponse response) throws IOException {
|
||||||
if (!request.getHeaders().getConnection().contains("Upgrade")) {
|
logger.debug("Invalid Connection header " + request.getHeaders().getConnection());
|
||||||
response.setStatusCode(HttpStatus.BAD_REQUEST);
|
response.setStatusCode(HttpStatus.BAD_REQUEST);
|
||||||
response.getBody().write("\"Connection\" must be \"upgrade\".".getBytes("UTF-8"));
|
response.getBody().write("\"Connection\" must be \"upgrade\".".getBytes("UTF-8"));
|
||||||
logger.debug("Invalid Connection header " + request.getHeaders().getConnection());
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protected boolean validateWebSocketVersion(ServerHttpRequest request, ServerHttpResponse response) {
|
protected boolean isWebSocketVersionSupported(ServerHttpRequest request) {
|
||||||
if (!"13".equals(request.getHeaders().getSecWebSocketVersion())) {
|
String requestedVersion = request.getHeaders().getSecWebSocketVersion();
|
||||||
response.setStatusCode(HttpStatus.UPGRADE_REQUIRED);
|
for (String supportedVersion : getSupportedVerions()) {
|
||||||
response.getHeaders().set("Sec-WebSocket-Version", "13");
|
if (supportedVersion.equals(requestedVersion)) {
|
||||||
logger.debug("WebSocket version not supported " + request.getHeaders().get("Sec-WebSocket-Version"));
|
return true;
|
||||||
return false;
|
}
|
||||||
}
|
}
|
||||||
return true;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected boolean validateOrigin(ServerHttpRequest request, ServerHttpResponse response) {
|
protected String[] getSupportedVerions() {
|
||||||
|
return new String[] { "13" };
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void handleWebSocketVersionNotSupported(ServerHttpRequest request, ServerHttpResponse response) {
|
||||||
|
logger.debug("WebSocket version not supported " + request.getHeaders().get("Sec-WebSocket-Version"));
|
||||||
|
response.setStatusCode(HttpStatus.UPGRADE_REQUIRED);
|
||||||
|
response.getHeaders().setSecWebSocketVersion(StringUtils.arrayToCommaDelimitedString(getSupportedVerions()));
|
||||||
|
}
|
||||||
|
|
||||||
|
protected boolean isValidOrigin(ServerHttpRequest request) {
|
||||||
String origin = request.getHeaders().getOrigin();
|
String origin = request.getHeaders().getOrigin();
|
||||||
if (origin != null) {
|
if (origin != null) {
|
||||||
UriComponentsBuilder originUriBuilder = UriComponentsBuilder.fromHttpUrl(origin);
|
// UriComponentsBuilder uriBuilder = UriComponentsBuilder.fromHttpUrl(origin);
|
||||||
|
|
||||||
// TODO
|
// TODO
|
||||||
// Check scheme, port, and host against list of configured origins (allow wild cards in the host?)
|
// A simple strategy checks against the current request's scheme/port/host
|
||||||
// Another strategy might be to match current request's scheme/port/host
|
// Or match scheme, port, and host against configured allowed origins (wild cards for hosts?)
|
||||||
|
|
||||||
// response.setStatusCode(HttpStatus.FORBIDDEN);
|
|
||||||
// return false;
|
// return false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
|
@ -197,9 +203,9 @@ public abstract class AbstractHandshakeHandler implements HandshakeHandler, Bean
|
||||||
|
|
||||||
protected String selectProtocol(List<String> requestedProtocols) {
|
protected String selectProtocol(List<String> requestedProtocols) {
|
||||||
if (requestedProtocols != null) {
|
if (requestedProtocols != null) {
|
||||||
for (String p : requestedProtocols) {
|
for (String protocol : requestedProtocols) {
|
||||||
if (this.protocols.contains(p)) {
|
if (this.supportedProtocols.contains(protocol)) {
|
||||||
return p;
|
return protocol;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,45 +37,56 @@ import org.springframework.websocket.support.WebSocketHandlerEndpoint;
|
||||||
public class EndpointHandshakeHandler extends AbstractHandshakeHandler {
|
public class EndpointHandshakeHandler extends AbstractHandshakeHandler {
|
||||||
|
|
||||||
private static final boolean tomcatWebSocketPresent = ClassUtils.isPresent(
|
private static final boolean tomcatWebSocketPresent = ClassUtils.isPresent(
|
||||||
"org.apache.tomcat.websocket.server.WsHandshakeRequest", EndpointHandshakeHandler.class.getClassLoader());
|
"org.apache.tomcat.websocket.server.WsHttpUpgradeHandler", EndpointHandshakeHandler.class.getClassLoader());
|
||||||
|
|
||||||
private final WebSocketRequestUpgradeStrategy upgradeStrategy;
|
private static final boolean glassfishWebSocketPresent = ClassUtils.isPresent(
|
||||||
|
"org.glassfish.tyrus.servlet.TyrusHttpUpgradeHandler", EndpointHandshakeHandler.class.getClassLoader());
|
||||||
|
|
||||||
|
private final EndpointRequestUpgradeStrategy upgradeStrategy;
|
||||||
|
|
||||||
|
|
||||||
public EndpointHandshakeHandler(WebSocketHandler webSocketHandler) {
|
public EndpointHandshakeHandler(WebSocketHandler webSocketHandler) {
|
||||||
super(webSocketHandler);
|
super(webSocketHandler);
|
||||||
this.upgradeStrategy = createRequestUpgradeStrategy();
|
this.upgradeStrategy = createUpgradeStrategy();
|
||||||
}
|
}
|
||||||
|
|
||||||
public EndpointHandshakeHandler(Class<? extends WebSocketHandler> handlerClass) {
|
public EndpointHandshakeHandler(Class<? extends WebSocketHandler> handlerClass) {
|
||||||
super(handlerClass);
|
super(handlerClass);
|
||||||
this.upgradeStrategy = createRequestUpgradeStrategy();
|
this.upgradeStrategy = createUpgradeStrategy();
|
||||||
}
|
}
|
||||||
|
|
||||||
private static WebSocketRequestUpgradeStrategy createRequestUpgradeStrategy() {
|
private static EndpointRequestUpgradeStrategy createUpgradeStrategy() {
|
||||||
String className;
|
String className;
|
||||||
if (tomcatWebSocketPresent) {
|
if (tomcatWebSocketPresent) {
|
||||||
className = "org.springframework.websocket.server.endpoint.TomcatRequestUpgradeStrategy";
|
className = "org.springframework.websocket.server.endpoint.support.TomcatRequestUpgradeStrategy";
|
||||||
|
}
|
||||||
|
else if (glassfishWebSocketPresent) {
|
||||||
|
className = "org.springframework.websocket.server.endpoint.support.GlassfishRequestUpgradeStrategy";
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
throw new IllegalStateException("No suitable EndpointRequestUpgradeStrategy");
|
throw new IllegalStateException("No suitable EndpointRequestUpgradeStrategy");
|
||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
Class<?> clazz = ClassUtils.forName(className, EndpointHandshakeHandler.class.getClassLoader());
|
Class<?> clazz = ClassUtils.forName(className, EndpointHandshakeHandler.class.getClassLoader());
|
||||||
return (WebSocketRequestUpgradeStrategy) BeanUtils.instantiateClass(clazz.getConstructor());
|
return (EndpointRequestUpgradeStrategy) BeanUtils.instantiateClass(clazz.getConstructor());
|
||||||
}
|
}
|
||||||
catch (Throwable t) {
|
catch (Throwable t) {
|
||||||
throw new IllegalStateException("Failed to instantiate " + className, t);
|
throw new IllegalStateException("Failed to instantiate " + className, t);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected String[] getSupportedVerions() {
|
||||||
|
return this.upgradeStrategy.getSupportedVersions();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void doHandshakeInternal(ServerHttpRequest request, ServerHttpResponse response, String protocol)
|
public void doHandshakeInternal(ServerHttpRequest request, ServerHttpResponse response, String protocol)
|
||||||
throws Exception {
|
throws Exception {
|
||||||
|
|
||||||
logger.debug("Upgrading HTTP request");
|
logger.debug("Upgrading HTTP request");
|
||||||
Endpoint endpoint = new WebSocketHandlerEndpoint(getWebSocketHandler());
|
Endpoint endpoint = new WebSocketHandlerEndpoint(getWebSocketHandler());
|
||||||
this.upgradeStrategy.upgrade(request, response, protocol, new EndpointRegistration("/dummy", endpoint));
|
this.upgradeStrategy.upgrade(request, response, protocol, endpoint);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
|
@ -16,6 +16,8 @@
|
||||||
|
|
||||||
package org.springframework.websocket.server.endpoint;
|
package org.springframework.websocket.server.endpoint;
|
||||||
|
|
||||||
|
import javax.websocket.Endpoint;
|
||||||
|
|
||||||
import org.springframework.http.server.ServerHttpRequest;
|
import org.springframework.http.server.ServerHttpRequest;
|
||||||
import org.springframework.http.server.ServerHttpResponse;
|
import org.springframework.http.server.ServerHttpResponse;
|
||||||
|
|
||||||
|
@ -26,12 +28,14 @@ import org.springframework.http.server.ServerHttpResponse;
|
||||||
* @author Rossen Stoyanchev
|
* @author Rossen Stoyanchev
|
||||||
* @since 4.0
|
* @since 4.0
|
||||||
*/
|
*/
|
||||||
public interface WebSocketRequestUpgradeStrategy {
|
public interface EndpointRequestUpgradeStrategy {
|
||||||
|
|
||||||
|
String[] getSupportedVersions();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Invoked after the handshake checks have been performed and succeeded.
|
* Invoked after the handshake checks have been performed and succeeded.
|
||||||
*/
|
*/
|
||||||
void upgrade(ServerHttpRequest request, ServerHttpResponse response, String protocol,
|
void upgrade(ServerHttpRequest request, ServerHttpResponse response, String protocol, Endpoint endpoint)
|
||||||
EndpointRegistration registration) throws Exception;
|
throws Exception;
|
||||||
|
|
||||||
}
|
}
|
|
@ -0,0 +1,152 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2002-2013 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.springframework.websocket.server.endpoint.support;
|
||||||
|
|
||||||
|
import java.lang.reflect.Constructor;
|
||||||
|
import java.net.URI;
|
||||||
|
import java.util.Random;
|
||||||
|
|
||||||
|
import javax.servlet.http.HttpServletRequest;
|
||||||
|
import javax.servlet.http.HttpServletResponse;
|
||||||
|
import javax.servlet.http.HttpServletResponseWrapper;
|
||||||
|
import javax.websocket.Endpoint;
|
||||||
|
|
||||||
|
import org.glassfish.tyrus.core.ComponentProviderService;
|
||||||
|
import org.glassfish.tyrus.core.EndpointWrapper;
|
||||||
|
import org.glassfish.tyrus.core.ErrorCollector;
|
||||||
|
import org.glassfish.tyrus.core.RequestContext;
|
||||||
|
import org.glassfish.tyrus.server.TyrusEndpoint;
|
||||||
|
import org.glassfish.tyrus.servlet.TyrusHttpUpgradeHandler;
|
||||||
|
import org.glassfish.tyrus.websockets.Connection;
|
||||||
|
import org.glassfish.tyrus.websockets.Version;
|
||||||
|
import org.glassfish.tyrus.websockets.WebSocketEngine;
|
||||||
|
import org.glassfish.tyrus.websockets.WebSocketEngine.WebSocketHolderListener;
|
||||||
|
import org.springframework.http.HttpHeaders;
|
||||||
|
import org.springframework.http.HttpStatus;
|
||||||
|
import org.springframework.http.server.ServerHttpRequest;
|
||||||
|
import org.springframework.http.server.ServerHttpResponse;
|
||||||
|
import org.springframework.http.server.ServletServerHttpRequest;
|
||||||
|
import org.springframework.http.server.ServletServerHttpResponse;
|
||||||
|
import org.springframework.util.Assert;
|
||||||
|
import org.springframework.util.ClassUtils;
|
||||||
|
import org.springframework.util.ReflectionUtils;
|
||||||
|
import org.springframework.util.StringUtils;
|
||||||
|
import org.springframework.websocket.server.endpoint.EndpointRegistration;
|
||||||
|
import org.springframework.websocket.server.endpoint.EndpointRequestUpgradeStrategy;
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @author Rossen Stoyanchev
|
||||||
|
* @since 4.0
|
||||||
|
*/
|
||||||
|
public class GlassfishRequestUpgradeStrategy implements EndpointRequestUpgradeStrategy {
|
||||||
|
|
||||||
|
private final static Random random = new Random();
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String[] getSupportedVersions() {
|
||||||
|
return StringUtils.commaDelimitedListToStringArray(Version.getSupportedWireProtocolVersions());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void upgrade(ServerHttpRequest request, ServerHttpResponse response, String protocol,
|
||||||
|
Endpoint endpoint) throws Exception {
|
||||||
|
|
||||||
|
Assert.isTrue(request instanceof ServletServerHttpRequest);
|
||||||
|
HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest();
|
||||||
|
|
||||||
|
Assert.isTrue(response instanceof ServletServerHttpResponse);
|
||||||
|
HttpServletResponse servletResponse = ((ServletServerHttpResponse) response).getServletResponse();
|
||||||
|
servletResponse = new AlreadyUpgradedResponseWrapper(servletResponse);
|
||||||
|
|
||||||
|
TyrusEndpoint tyrusEndpoint = createTyrusEndpoint(servletRequest, endpoint);
|
||||||
|
WebSocketEngine.getEngine().register(tyrusEndpoint);
|
||||||
|
|
||||||
|
try {
|
||||||
|
if (!performUpgrade(servletRequest, servletResponse, request.getHeaders(), tyrusEndpoint)) {
|
||||||
|
throw new IllegalStateException("Failed to upgrade HttpServletRequest");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
finally {
|
||||||
|
WebSocketEngine.getEngine().unregister(tyrusEndpoint);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private boolean performUpgrade(HttpServletRequest request, HttpServletResponse response,
|
||||||
|
HttpHeaders headers, TyrusEndpoint tyrusEndpoint) throws Exception {
|
||||||
|
|
||||||
|
final TyrusHttpUpgradeHandler upgradeHandler = request.upgrade(TyrusHttpUpgradeHandler.class);
|
||||||
|
|
||||||
|
Connection connection = createConnection(upgradeHandler, response);
|
||||||
|
|
||||||
|
RequestContext wsRequest = RequestContext.Builder.create()
|
||||||
|
.requestURI(URI.create(tyrusEndpoint.getPath())).requestPath(tyrusEndpoint.getPath())
|
||||||
|
.connection(connection).secure(request.isSecure()).build();
|
||||||
|
|
||||||
|
for (String header : headers.keySet()) {
|
||||||
|
wsRequest.getHeaders().put(header, headers.get(header));
|
||||||
|
}
|
||||||
|
|
||||||
|
return WebSocketEngine.getEngine().upgrade(connection, wsRequest, new WebSocketHolderListener() {
|
||||||
|
@Override
|
||||||
|
public void onWebSocketHolder(WebSocketEngine.WebSocketHolder webSocketHolder) {
|
||||||
|
upgradeHandler.setWebSocketHolder(webSocketHolder);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
private TyrusEndpoint createTyrusEndpoint(HttpServletRequest request, Endpoint endpoint) {
|
||||||
|
|
||||||
|
// Use randomized path
|
||||||
|
String requestUri = request.getRequestURI();
|
||||||
|
String randomValue = String.valueOf(random.nextLong());
|
||||||
|
String endpointPath = requestUri.endsWith("/") ? requestUri + randomValue : requestUri + "/" + randomValue;
|
||||||
|
|
||||||
|
EndpointRegistration endpointConfig = new EndpointRegistration(endpointPath, endpoint);
|
||||||
|
|
||||||
|
return new TyrusEndpoint(new EndpointWrapper(endpoint, endpointConfig,
|
||||||
|
ComponentProviderService.create(), null, "/", new ErrorCollector(),
|
||||||
|
endpointConfig.getConfigurator()));
|
||||||
|
}
|
||||||
|
|
||||||
|
private Connection createConnection(TyrusHttpUpgradeHandler handler, HttpServletResponse response) throws Exception {
|
||||||
|
String name = "org.glassfish.tyrus.servlet.ConnectionImpl";
|
||||||
|
Class<?> clazz = ClassUtils.forName(name, GlassfishRequestUpgradeStrategy.class.getClassLoader());
|
||||||
|
Constructor<?> constructor = clazz.getDeclaredConstructor(TyrusHttpUpgradeHandler.class, HttpServletResponse.class);
|
||||||
|
ReflectionUtils.makeAccessible(constructor);
|
||||||
|
return (Connection) constructor.newInstance(handler, response);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
private static class AlreadyUpgradedResponseWrapper extends HttpServletResponseWrapper {
|
||||||
|
|
||||||
|
public AlreadyUpgradedResponseWrapper(HttpServletResponse response) {
|
||||||
|
super(response);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setStatus(int sc) {
|
||||||
|
Assert.isTrue(sc == HttpStatus.SWITCHING_PROTOCOLS.value(), "Unexpected status code " + sc);
|
||||||
|
}
|
||||||
|
@Override
|
||||||
|
public void addHeader(String name, String value) {
|
||||||
|
// ignore
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -14,12 +14,14 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.springframework.websocket.server.endpoint;
|
package org.springframework.websocket.server.endpoint.support;
|
||||||
|
|
||||||
import java.lang.reflect.Method;
|
import java.lang.reflect.Method;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
|
||||||
import javax.servlet.http.HttpServletRequest;
|
import javax.servlet.http.HttpServletRequest;
|
||||||
|
import javax.websocket.Endpoint;
|
||||||
|
import javax.websocket.server.ServerEndpointConfig;
|
||||||
|
|
||||||
import org.apache.tomcat.websocket.server.WsHandshakeRequest;
|
import org.apache.tomcat.websocket.server.WsHandshakeRequest;
|
||||||
import org.apache.tomcat.websocket.server.WsHttpUpgradeHandler;
|
import org.apache.tomcat.websocket.server.WsHttpUpgradeHandler;
|
||||||
|
@ -29,6 +31,8 @@ import org.springframework.http.server.ServerHttpResponse;
|
||||||
import org.springframework.http.server.ServletServerHttpRequest;
|
import org.springframework.http.server.ServletServerHttpRequest;
|
||||||
import org.springframework.util.Assert;
|
import org.springframework.util.Assert;
|
||||||
import org.springframework.util.ReflectionUtils;
|
import org.springframework.util.ReflectionUtils;
|
||||||
|
import org.springframework.websocket.server.endpoint.EndpointRegistration;
|
||||||
|
import org.springframework.websocket.server.endpoint.EndpointRequestUpgradeStrategy;
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -36,26 +40,35 @@ import org.springframework.util.ReflectionUtils;
|
||||||
* @author Rossen Stoyanchev
|
* @author Rossen Stoyanchev
|
||||||
* @since 4.0
|
* @since 4.0
|
||||||
*/
|
*/
|
||||||
public class TomcatRequestUpgradeStrategy implements WebSocketRequestUpgradeStrategy {
|
public class TomcatRequestUpgradeStrategy implements EndpointRequestUpgradeStrategy {
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String[] getSupportedVersions() {
|
||||||
|
return new String[] { "13" };
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void upgrade(ServerHttpRequest request, ServerHttpResponse response, String protocol,
|
public void upgrade(ServerHttpRequest request, ServerHttpResponse response, String protocol,
|
||||||
EndpointRegistration registration) throws Exception {
|
Endpoint endpoint) throws Exception {
|
||||||
|
|
||||||
Assert.isTrue(request instanceof ServletServerHttpRequest);
|
Assert.isTrue(request instanceof ServletServerHttpRequest);
|
||||||
HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest();
|
HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest();
|
||||||
|
|
||||||
WsHttpUpgradeHandler wsHandler = servletRequest.upgrade(WsHttpUpgradeHandler.class);
|
WsHttpUpgradeHandler upgradeHandler = servletRequest.upgrade(WsHttpUpgradeHandler.class);
|
||||||
|
|
||||||
WsHandshakeRequest wsRequest = new WsHandshakeRequest(servletRequest);
|
WsHandshakeRequest webSocketRequest = new WsHandshakeRequest(servletRequest);
|
||||||
Method method = ReflectionUtils.findMethod(WsHandshakeRequest.class, "finished");
|
Method method = ReflectionUtils.findMethod(WsHandshakeRequest.class, "finished");
|
||||||
ReflectionUtils.makeAccessible(method);
|
ReflectionUtils.makeAccessible(method);
|
||||||
method.invoke(wsRequest);
|
method.invoke(webSocketRequest);
|
||||||
|
|
||||||
wsHandler.preInit(registration.getEndpoint(), registration,
|
// TODO: use ServletContext attribute when Tomcat is updated
|
||||||
WsServerContainer.getServerContainer(), wsRequest, protocol,
|
WsServerContainer serverContainer = WsServerContainer.getServerContainer();
|
||||||
Collections.<String, String> emptyMap(), servletRequest.isSecure());
|
|
||||||
|
ServerEndpointConfig endpointConfig = new EndpointRegistration("/shouldntmatter", endpoint);
|
||||||
|
|
||||||
|
upgradeHandler.preInit(endpoint, endpointConfig, serverContainer, webSocketRequest,
|
||||||
|
protocol, Collections.<String, String> emptyMap(), servletRequest.isSecure());
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
Loading…
Reference in New Issue