From 8200601acebfac7e768942acf9f2f6688c19fc38 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Thu, 25 Apr 2013 18:23:16 -0400 Subject: [PATCH] Tighten up exception handling strategy WebSocketHandler implementations: - methods must deal with exceptions locally - uncaught runtime exceptions are handled by ending the session - transport errors (websocket engine) are passed into handleError WebSocketSession methods may raise IOException SockJS implementation of WebSocketHandler: - delegate SockJS transport errors into handleError - stop runtime exceptions from user WebSocketHandler and end session SockJsServce and TransportHandlers: - raise IOException or TransportErrorException HandshakeHandler: - raise IOException --- .../sockjs/AbstractSockJsSession.java | 84 +++++++++++++--- .../server/AbstractServerSockJsSession.java | 16 ++-- .../sockjs/server/AbstractSockJsService.java | 14 +-- .../sockjs/server/SockJsService.java | 4 +- .../server/TransportErrorException.java | 55 +++++++++++ .../sockjs/server/TransportHandler.java | 2 +- .../server/support/DefaultSockJsService.java | 7 +- ...AbstractHttpReceivingTransportHandler.java | 28 ++++-- .../AbstractHttpSendingTransportHandler.java | 34 +++---- .../AbstractHttpServerSockJsSession.java | 95 ++++++++++++------- .../AbstractStreamingTransportHandler.java | 62 ------------ .../EventSourceTransportHandler.java | 20 ++-- .../transport/HtmlFileTransportHandler.java | 49 ++++++---- .../JsonpPollingTransportHandler.java | 19 ++-- .../transport/JsonpTransportHandler.java | 13 ++- .../transport/PollingServerSockJsSession.java | 6 +- .../transport/SockJsWebSocketHandler.java | 83 ++++++++-------- .../StreamingServerSockJsSession.java | 22 ++++- .../transport/WebSocketTransportHandler.java | 16 +++- .../XhrStreamingTransportHandler.java | 26 +++-- .../websocket/BinaryMessageHandler.java | 3 +- .../websocket/TextMessageHandler.java | 3 +- .../websocket/WebSocketHandler.java | 6 +- .../websocket/WebSocketHandlerAdapter.java | 6 +- .../websocket/WebSocketSession.java | 8 +- .../endpoint/WebSocketHandlerEndpoint.java | 40 ++++++-- .../server/DefaultHandshakeHandler.java | 15 ++- .../websocket/server/HandshakeHandler.java | 4 +- .../server/RequestUpgradeStrategy.java | 4 +- .../AbstractEndpointUpgradeStrategy.java | 6 +- .../GlassfishRequestUpgradeStrategy.java | 31 ++++-- .../support/JettyRequestUpgradeStrategy.java | 91 ++++++++++-------- 32 files changed, 556 insertions(+), 316 deletions(-) create mode 100644 spring-websocket/src/main/java/org/springframework/sockjs/server/TransportErrorException.java delete mode 100644 spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractStreamingTransportHandler.java diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/AbstractSockJsSession.java b/spring-websocket/src/main/java/org/springframework/sockjs/AbstractSockJsSession.java index 24c5e84cbb6..ee8157e5b64 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/AbstractSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/AbstractSockJsSession.java @@ -16,6 +16,7 @@ package org.springframework.sockjs; +import java.io.IOException; import java.net.URI; import org.apache.commons.logging.Log; @@ -121,10 +122,15 @@ public abstract class AbstractSockJsSession implements WebSocketSession { this.timeLastActive = System.currentTimeMillis(); } - public void delegateConnectionEstablished() throws Exception { + public void delegateConnectionEstablished() { this.state = State.OPEN; initHandler(); - this.handler.afterConnectionEstablished(this); + try { + this.handler.afterConnectionEstablished(this); + } + catch (Throwable ex) { + tryCloseWithError(ex, null); + } } private void initHandler() { @@ -133,14 +139,61 @@ public abstract class AbstractSockJsSession implements WebSocketSession { this.handler = (TextMessageHandler) webSocketHandler; } - public void delegateMessages(String[] messages) throws Exception { - for (String message : messages) { - this.handler.handleTextMessage(new TextMessage(message), this); + /** + * Close due to unhandled runtime error from WebSocketHandler. + * @param closeStatus TODO + */ + private void tryCloseWithError(Throwable ex, CloseStatus closeStatus) { + logger.error("Unhandled error for " + this, ex); + try { + closeStatus = (closeStatus != null) ? closeStatus : CloseStatus.SERVER_ERROR; + close(closeStatus); + } + catch (Throwable t) { + destroyHandler(); } } - public void delegateError(Throwable ex) throws Exception { - this.handler.handleError(ex, this); + private void destroyHandler() { + try { + if (this.handler != null) { + this.handlerProvider.destroy(this.handler); + } + } + catch (Throwable t) { + logger.warn("Error while destroying handler", t); + } + finally { + this.handler = null; + } + } + + /** + * Close due to error arising from SockJS transport handling. + */ + protected void tryCloseWithSockJsTransportError(Throwable ex, CloseStatus closeStatus) { + delegateError(ex); + tryCloseWithError(ex, closeStatus); + } + + public void delegateMessages(String[] messages) { + try { + for (String message : messages) { + this.handler.handleTextMessage(new TextMessage(message), this); + } + } + catch (Throwable ex) { + tryCloseWithError(ex, null); + } + } + + public void delegateError(Throwable ex) { + try { + this.handler.handleTransportError(ex, this); + } + catch (Throwable t) { + tryCloseWithError(t, null); + } } /** @@ -149,7 +202,7 @@ public abstract class AbstractSockJsSession implements WebSocketSession { * {@link TextMessageHandler}. This is in contrast to {@link #close()} that pro-actively * closes the connection. */ - public final void delegateConnectionClosed(CloseStatus status) throws Exception { + public final void delegateConnectionClosed(CloseStatus status) { if (!isClosed()) { if (logger.isDebugEnabled()) { logger.debug(this + " was closed, " + status); @@ -159,7 +212,12 @@ public abstract class AbstractSockJsSession implements WebSocketSession { } finally { this.state = State.CLOSED; - this.handler.afterConnectionClosed(status, this); + try { + this.handler.afterConnectionClosed(status, this); + } + finally { + destroyHandler(); + } } } } @@ -171,7 +229,7 @@ public abstract class AbstractSockJsSession implements WebSocketSession { * {@inheritDoc} *

Performs cleanup and notifies the {@link SockJsHandler}. */ - public final void close() throws Exception { + public final void close() throws IOException { close(CloseStatus.NORMAL); } @@ -179,7 +237,7 @@ public abstract class AbstractSockJsSession implements WebSocketSession { * {@inheritDoc} *

Performs cleanup and notifies the {@link SockJsHandler}. */ - public final void close(CloseStatus status) throws Exception { + public final void close(CloseStatus status) throws IOException { if (!isClosed()) { if (logger.isDebugEnabled()) { logger.debug("Closing " + this + ", " + status); @@ -193,13 +251,13 @@ public abstract class AbstractSockJsSession implements WebSocketSession { this.handler.afterConnectionClosed(status, this); } finally { - this.handlerProvider.destroy(this.handler); + destroyHandler(); } } } } - protected abstract void closeInternal(CloseStatus status) throws Exception; + protected abstract void closeInternal(CloseStatus status) throws IOException; @Override diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractServerSockJsSession.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractServerSockJsSession.java index a39e5d67897..13ff66c27af 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractServerSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractServerSockJsSession.java @@ -56,13 +56,13 @@ public abstract class AbstractServerSockJsSession extends AbstractSockJsSession return this.sockJsConfig; } - public final synchronized void sendMessage(WebSocketMessage message) throws Exception { - Assert.isTrue(!isClosed(), "Cannot send a message, session has been closed"); + public final synchronized void sendMessage(WebSocketMessage message) throws IOException { + Assert.isTrue(!isClosed(), "Cannot send a message when session is closed"); Assert.isInstanceOf(TextMessage.class, message, "Expected text message: " + message); sendMessageInternal(((TextMessage) message).getPayload()); } - protected abstract void sendMessageInternal(String message) throws Exception; + protected abstract void sendMessageInternal(String message) throws IOException; @Override @@ -72,7 +72,7 @@ public abstract class AbstractServerSockJsSession extends AbstractSockJsSession } @Override - public final synchronized void closeInternal(CloseStatus status) throws Exception { + public final synchronized void closeInternal(CloseStatus status) throws IOException { if (isActive()) { // TODO: deliver messages "in flight" before sending close frame try { @@ -89,13 +89,13 @@ public abstract class AbstractServerSockJsSession extends AbstractSockJsSession } // TODO: close status/reason - protected abstract void disconnect(CloseStatus status) throws Exception; + protected abstract void disconnect(CloseStatus status) throws IOException; /** * For internal use within a TransportHandler and the (TransportHandler-specific) * session sub-class. */ - protected void writeFrame(SockJsFrame frame) throws Exception { + protected void writeFrame(SockJsFrame frame) throws IOException { if (logger.isTraceEnabled()) { logger.trace("Preparing to write " + frame); } @@ -115,7 +115,7 @@ public abstract class AbstractServerSockJsSession extends AbstractSockJsSession catch (Throwable ex) { logger.warn("Terminating connection due to failure to send message: " + ex.getMessage()); close(); - throw new NestedSockJsRuntimeException("Failed to write frame " + frame, ex); + throw new NestedSockJsRuntimeException("Failed to write " + frame, ex); } } @@ -140,7 +140,7 @@ public abstract class AbstractServerSockJsSession extends AbstractSockJsSession try { sendHeartbeat(); } - catch (Exception e) { + catch (Throwable t) { // ignore } } diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractSockJsService.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractSockJsService.java index e44c44cec0d..755c8ee4b6d 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractSockJsService.java @@ -201,7 +201,8 @@ public abstract class AbstractSockJsService implements SockJsService, SockJsConf * @throws Exception */ public final void handleRequest(ServerHttpRequest request, ServerHttpResponse response, - String sockJsPath, HandlerProvider handler) throws Exception { + String sockJsPath, HandlerProvider handler) + throws IOException, TransportErrorException { logger.debug(request.getMethod() + " [" + sockJsPath + "]"); @@ -255,10 +256,11 @@ public abstract class AbstractSockJsService implements SockJsService, SockJsConf } protected abstract void handleRawWebSocketRequest(ServerHttpRequest request, ServerHttpResponse response, - HandlerProvider handler) throws Exception; + HandlerProvider handler) throws IOException; protected abstract void handleTransportRequest(ServerHttpRequest request, ServerHttpResponse response, - String sessionId, TransportType transportType, HandlerProvider handler) throws Exception; + String sessionId, TransportType transportType, HandlerProvider handler) + throws IOException, TransportErrorException; protected boolean validateRequest(String serverId, String sessionId, String transport) { @@ -321,7 +323,7 @@ public abstract class AbstractSockJsService implements SockJsService, SockJsConf private interface SockJsRequestHandler { - void handle(ServerHttpRequest request, ServerHttpResponse response) throws Exception; + void handle(ServerHttpRequest request, ServerHttpResponse response) throws IOException; } private static final Random random = new Random(); @@ -331,7 +333,7 @@ public abstract class AbstractSockJsService implements SockJsService, SockJsConf private static final String INFO_CONTENT = "{\"entropy\":%s,\"origins\":[\"*:*\"],\"cookie_needed\":%s,\"websocket\":%s}"; - public void handle(ServerHttpRequest request, ServerHttpResponse response) throws Exception { + public void handle(ServerHttpRequest request, ServerHttpResponse response) throws IOException { if (HttpMethod.GET.equals(request.getMethod())) { @@ -376,7 +378,7 @@ public abstract class AbstractSockJsService implements SockJsService, SockJsConf "\n" + ""; - public void handle(ServerHttpRequest request, ServerHttpResponse response) throws Exception { + public void handle(ServerHttpRequest request, ServerHttpResponse response) throws IOException { if (!HttpMethod.GET.equals(request.getMethod())) { sendMethodNotAllowed(response, Arrays.asList(HttpMethod.GET)); diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/SockJsService.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/SockJsService.java index eeef0913d19..8df23d21342 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/SockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/SockJsService.java @@ -16,6 +16,8 @@ package org.springframework.sockjs.server; +import java.io.IOException; + import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.websocket.HandlerProvider; @@ -31,6 +33,6 @@ public interface SockJsService { void handleRequest(ServerHttpRequest request, ServerHttpResponse response, String sockJsPath, - HandlerProvider handler) throws Exception; + HandlerProvider handler) throws IOException, TransportErrorException; } diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/TransportErrorException.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/TransportErrorException.java new file mode 100644 index 00000000000..90eb5e81702 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/TransportErrorException.java @@ -0,0 +1,55 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.sockjs.server; + +import org.springframework.core.NestedRuntimeException; +import org.springframework.websocket.WebSocketHandler; + + +/** + * Raised when a TransportHandler fails during request processing. + * + *

If the underlying exception occurs while sending messages to the client, + * the session will have been closed and the {@link WebSocketHandler} notified. + * + *

If the underlying exception occurs while processing an incoming HTTP request + * including posted messages, the session will remain open. Only the incoming + * request is rejected. + * + * @author Rossen Stoyanchev + * @since 4.0 + */ +@SuppressWarnings("serial") +public class TransportErrorException extends NestedRuntimeException { + + private final String sockJsSessionId; + + public TransportErrorException(String msg, Throwable cause, String sockJsSessionId) { + super(msg, cause); + this.sockJsSessionId = sockJsSessionId; + } + + public String getSockJsSessionId() { + return sockJsSessionId; + } + + @Override + public String getMessage() { + return "Transport error for SockJS session id=" + this.sockJsSessionId + ", " + super.getMessage(); + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/TransportHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/TransportHandler.java index 1d0eedb7664..6364a1c83ff 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/TransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/TransportHandler.java @@ -32,6 +32,6 @@ public interface TransportHandler { TransportType getTransportType(); void handleRequest(ServerHttpRequest request, ServerHttpResponse response, - HandlerProvider handler, AbstractSockJsSession session) throws Exception; + HandlerProvider handler, AbstractSockJsSession session) throws TransportErrorException; } diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/support/DefaultSockJsService.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/support/DefaultSockJsService.java index 3d389353e95..925ba0305a9 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/support/DefaultSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/support/DefaultSockJsService.java @@ -15,6 +15,7 @@ */ package org.springframework.sockjs.server.support; +import java.io.IOException; import java.util.Arrays; import java.util.Collection; import java.util.Collections; @@ -37,6 +38,7 @@ import org.springframework.sockjs.SockJsSessionFactory; import org.springframework.sockjs.server.AbstractSockJsService; import org.springframework.sockjs.server.ConfigurableTransportHandler; import org.springframework.sockjs.server.SockJsService; +import org.springframework.sockjs.server.TransportErrorException; import org.springframework.sockjs.server.TransportHandler; import org.springframework.sockjs.server.TransportType; import org.springframework.sockjs.server.transport.EventSourceTransportHandler; @@ -140,7 +142,7 @@ public class DefaultSockJsService extends AbstractSockJsService { @Override protected void handleRawWebSocketRequest(ServerHttpRequest request, ServerHttpResponse response, - HandlerProvider handler) throws Exception { + HandlerProvider handler) throws IOException { if (isWebSocketEnabled()) { TransportHandler transportHandler = this.transportHandlers.get(TransportType.WEBSOCKET); @@ -157,7 +159,8 @@ public class DefaultSockJsService extends AbstractSockJsService { @Override protected void handleTransportRequest(ServerHttpRequest request, ServerHttpResponse response, - String sessionId, TransportType transportType, HandlerProvider handler) throws Exception { + String sessionId, TransportType transportType, HandlerProvider handler) + throws IOException, TransportErrorException { TransportHandler transportHandler = this.transportHandlers.get(transportType); diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpReceivingTransportHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpReceivingTransportHandler.java index d3928730a48..99a59a723fd 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpReceivingTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpReceivingTransportHandler.java @@ -26,6 +26,7 @@ import org.springframework.http.MediaType; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.sockjs.AbstractSockJsSession; +import org.springframework.sockjs.server.TransportErrorException; import org.springframework.sockjs.server.TransportHandler; import org.springframework.websocket.HandlerProvider; import org.springframework.websocket.WebSocketHandler; @@ -54,7 +55,8 @@ public abstract class AbstractHttpReceivingTransportHandler implements Transport @Override public final void handleRequest(ServerHttpRequest request, ServerHttpResponse response, - HandlerProvider webSocketHandler, AbstractSockJsSession session) throws Exception { + HandlerProvider webSocketHandler, AbstractSockJsSession session) + throws TransportErrorException { if (session == null) { response.setStatusCode(HttpStatus.NOT_FOUND); @@ -65,20 +67,22 @@ public abstract class AbstractHttpReceivingTransportHandler implements Transport } protected void handleRequestInternal(ServerHttpRequest request, ServerHttpResponse response, - AbstractSockJsSession session) throws Exception { + AbstractSockJsSession session) throws TransportErrorException { String[] messages = null; try { messages = readMessages(request); } catch (JsonMappingException ex) { - response.setStatusCode(HttpStatus.INTERNAL_SERVER_ERROR); - response.getBody().write("Payload expected.".getBytes("UTF-8")); + sendInternalServerError(response, "Payload expected.", session.getId()); return; } catch (IOException ex) { - response.setStatusCode(HttpStatus.INTERNAL_SERVER_ERROR); - response.getBody().write("Broken JSON encoding.".getBytes("UTF-8")); + sendInternalServerError(response, "Broken JSON encoding.", session.getId()); + return; + } + catch (Throwable t) { + sendInternalServerError(response, "Failed to process messages", session.getId()); return; } @@ -92,6 +96,18 @@ public abstract class AbstractHttpReceivingTransportHandler implements Transport response.getHeaders().setContentType(new MediaType("text", "plain", Charset.forName("UTF-8"))); } + protected void sendInternalServerError(ServerHttpResponse response, String error, + String sessionId) throws TransportErrorException { + + try { + response.setStatusCode(HttpStatus.INTERNAL_SERVER_ERROR); + response.getBody().write(error.getBytes("UTF-8")); + } + catch (Throwable t) { + throw new TransportErrorException("Failed to send error message to client", t, sessionId); + } + } + protected abstract String[] readMessages(ServerHttpRequest request) throws IOException; protected abstract HttpStatus getResponseStatus(); diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpSendingTransportHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpSendingTransportHandler.java index e2218715b01..e2aa777fb54 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpSendingTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpSendingTransportHandler.java @@ -27,6 +27,7 @@ import org.springframework.sockjs.SockJsSessionFactory; import org.springframework.sockjs.server.ConfigurableTransportHandler; import org.springframework.sockjs.server.SockJsConfiguration; import org.springframework.sockjs.server.SockJsFrame; +import org.springframework.sockjs.server.TransportErrorException; import org.springframework.sockjs.server.SockJsFrame.FrameFormat; import org.springframework.websocket.HandlerProvider; import org.springframework.websocket.WebSocketHandler; @@ -56,7 +57,8 @@ public abstract class AbstractHttpSendingTransportHandler @Override public final void handleRequest(ServerHttpRequest request, ServerHttpResponse response, - HandlerProvider webSocketHandler, AbstractSockJsSession session) throws Exception { + HandlerProvider webSocketHandler, AbstractSockJsSession session) + throws TransportErrorException { // Set content type before writing response.getHeaders().setContentType(getContentType()); @@ -66,30 +68,28 @@ public abstract class AbstractHttpSendingTransportHandler } protected void handleRequestInternal(ServerHttpRequest request, ServerHttpResponse response, - AbstractHttpServerSockJsSession httpServerSession) throws Exception, IOException { + AbstractHttpServerSockJsSession httpServerSession) throws TransportErrorException { if (httpServerSession.isNew()) { - handleNewSession(request, response, httpServerSession); + logger.debug("Opening " + getTransportType() + " connection"); + httpServerSession.setInitialRequest(request, response, getFrameFormat(request)); } - else if (httpServerSession.isActive()) { - logger.debug("another " + getTransportType() + " connection still open: " + httpServerSession); - httpServerSession.writeFrame(response, SockJsFrame.closeFrameAnotherConnectionOpen()); + else if (!httpServerSession.isActive()) { + logger.debug("starting " + getTransportType() + " async request"); + httpServerSession.setLongPollingRequest(request, response, getFrameFormat(request)); } else { - logger.debug("starting " + getTransportType() + " async request"); - httpServerSession.setCurrentRequest(request, response, getFrameFormat(request)); + try { + logger.debug("another " + getTransportType() + " connection still open: " + httpServerSession); + SockJsFrame closeFrame = SockJsFrame.closeFrameAnotherConnectionOpen(); + response.getBody().write(getFrameFormat(request).format(closeFrame).getContentBytes()); + } + catch (IOException e) { + throw new TransportErrorException("Failed to send SockJS close frame", e, httpServerSession.getId()); + } } } - protected void handleNewSession(ServerHttpRequest request, ServerHttpResponse response, - AbstractHttpServerSockJsSession session) throws Exception { - - logger.debug("Opening " + getTransportType() + " connection"); - session.setFrameFormat(getFrameFormat(request)); - session.writeFrame(response, SockJsFrame.openFrame()); - session.delegateConnectionEstablished(); - } - protected abstract MediaType getContentType(); protected abstract FrameFormat getFrameFormat(ServerHttpRequest request); diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpServerSockJsSession.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpServerSockJsSession.java index 72aa97e7c11..7e6742a1e1c 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpServerSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpServerSockJsSession.java @@ -25,8 +25,8 @@ import org.springframework.http.server.ServerHttpResponse; import org.springframework.sockjs.server.AbstractServerSockJsSession; import org.springframework.sockjs.server.SockJsConfiguration; import org.springframework.sockjs.server.SockJsFrame; +import org.springframework.sockjs.server.TransportErrorException; import org.springframework.sockjs.server.SockJsFrame.FrameFormat; -import org.springframework.sockjs.server.TransportHandler; import org.springframework.util.Assert; import org.springframework.websocket.CloseStatus; import org.springframework.websocket.HandlerProvider; @@ -55,32 +55,64 @@ public abstract class AbstractHttpServerSockJsSession extends AbstractServerSock super(sessionId, sockJsConfig, handler); } - public void setFrameFormat(FrameFormat frameFormat) { - this.frameFormat = frameFormat; + public synchronized void setInitialRequest(ServerHttpRequest request, ServerHttpResponse response, + FrameFormat frameFormat) throws TransportErrorException { + + try { + udpateRequest(request, response, frameFormat); + writePrelude(); + writeFrame(SockJsFrame.openFrame()); + } + catch (Throwable t) { + tryCloseWithSockJsTransportError(t, null); + throw new TransportErrorException("Failed open SockJS session", t, getId()); + } + delegateConnectionEstablished(); } - public synchronized void setCurrentRequest(ServerHttpRequest request, ServerHttpResponse response, - FrameFormat frameFormat) throws Exception { + protected void writePrelude() throws IOException { + } - if (isClosed()) { - logger.debug("connection already closed"); - writeFrame(response, SockJsFrame.closeFrameGoAway()); - return; + public synchronized void setLongPollingRequest(ServerHttpRequest request, ServerHttpResponse response, + FrameFormat frameFormat) throws TransportErrorException { + + try { + udpateRequest(request, response, frameFormat); + + if (isClosed()) { + logger.debug("connection already closed"); + try { + writeFrame(SockJsFrame.closeFrameGoAway()); + } + catch (IOException ex) { + throw new TransportErrorException("Failed to send SockJS close frame", ex, getId()); + } + return; + } + + this.asyncRequest.setTimeout(-1); + this.asyncRequest.startAsync(); + + scheduleHeartbeat(); + tryFlushCache(); } + catch (Throwable t) { + tryCloseWithSockJsTransportError(t, null); + throw new TransportErrorException("Failed to start long running request and flush messages", t, getId()); + } + } + private void udpateRequest(ServerHttpRequest request, ServerHttpResponse response, FrameFormat frameFormat) { + Assert.notNull(request, "expected request"); + Assert.notNull(response, "expected response"); + Assert.notNull(frameFormat, "expected frameFormat"); Assert.isInstanceOf(AsyncServerHttpRequest.class, request, "Expected AsyncServerHttpRequest"); - this.asyncRequest = (AsyncServerHttpRequest) request; - this.asyncRequest.setTimeout(-1); - this.asyncRequest.startAsync(); - this.response = response; this.frameFormat = frameFormat; - - scheduleHeartbeat(); - tryFlushCache(); } + public synchronized boolean isActive() { return ((this.asyncRequest != null) && (!this.asyncRequest.isAsyncCompleted())); } @@ -89,18 +121,20 @@ public abstract class AbstractHttpServerSockJsSession extends AbstractServerSock return this.messageCache; } + protected ServerHttpRequest getRequest() { + return this.asyncRequest; + } + protected ServerHttpResponse getResponse() { return this.response; } - protected final synchronized void sendMessageInternal(String message) throws Exception { - // assert close() was not called - // threads: TH-Session-Endpoint or any other thread + protected final synchronized void sendMessageInternal(String message) throws IOException { this.messageCache.add(message); tryFlushCache(); } - private void tryFlushCache() throws Exception { + private void tryFlushCache() throws IOException { if (isActive() && !getMessageCache().isEmpty()) { logger.trace("Flushing messages"); flushCache(); @@ -110,7 +144,7 @@ public abstract class AbstractHttpServerSockJsSession extends AbstractServerSock /** * Only called if the connection is currently active */ - protected abstract void flushCache() throws Exception; + protected abstract void flushCache() throws IOException; @Override protected void disconnect(CloseStatus status) { @@ -133,21 +167,12 @@ public abstract class AbstractHttpServerSockJsSession extends AbstractServerSock protected synchronized void writeFrameInternal(SockJsFrame frame) throws IOException { if (isActive()) { - writeFrame(this.response, frame); + frame = this.frameFormat.format(frame); + if (logger.isTraceEnabled()) { + logger.trace("Writing " + frame); + } + this.response.getBody().write(frame.getContentBytes()); } } - /** - * This method may be called by a {@link TransportHandler} to write a frame - * even when the connection is not active, as long as a valid OutputStream - * is provided. - */ - public void writeFrame(ServerHttpResponse response, SockJsFrame frame) throws IOException { - frame = this.frameFormat.format(frame); - if (logger.isTraceEnabled()) { - logger.trace("Writing " + frame); - } - response.getBody().write(frame.getContentBytes()); - } - } diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractStreamingTransportHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractStreamingTransportHandler.java deleted file mode 100644 index ed626a29615..00000000000 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractStreamingTransportHandler.java +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Copyright 2002-2013 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.sockjs.server.transport; - -import java.io.IOException; - -import org.springframework.http.server.ServerHttpRequest; -import org.springframework.http.server.ServerHttpResponse; -import org.springframework.util.Assert; -import org.springframework.websocket.HandlerProvider; -import org.springframework.websocket.WebSocketHandler; - - -/** - * TODO - * - * @author Rossen Stoyanchev - * @since 4.0 - */ -public abstract class AbstractStreamingTransportHandler extends AbstractHttpSendingTransportHandler { - - - @Override - public StreamingServerSockJsSession createSession(String sessionId, HandlerProvider handler) { - Assert.notNull(getSockJsConfig(), "This transport requires SockJsConfiguration"); - return new StreamingServerSockJsSession(sessionId, getSockJsConfig(), handler); - } - - @Override - public void handleRequestInternal(ServerHttpRequest request, ServerHttpResponse response, - AbstractHttpServerSockJsSession session) throws Exception { - - writePrelude(request, response); - super.handleRequestInternal(request, response, session); - } - - protected abstract void writePrelude(ServerHttpRequest request, ServerHttpResponse response) - throws IOException; - - @Override - protected void handleNewSession(ServerHttpRequest request, ServerHttpResponse response, - AbstractHttpServerSockJsSession session) throws IOException, Exception { - - super.handleNewSession(request, response, session); - - session.setCurrentRequest(request, response, getFrameFormat(request)); - } - -} \ No newline at end of file diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/EventSourceTransportHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/EventSourceTransportHandler.java index ae8efbc6670..700fd34706a 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/EventSourceTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/EventSourceTransportHandler.java @@ -20,10 +20,12 @@ import java.nio.charset.Charset; import org.springframework.http.MediaType; import org.springframework.http.server.ServerHttpRequest; -import org.springframework.http.server.ServerHttpResponse; import org.springframework.sockjs.server.SockJsFrame.DefaultFrameFormat; import org.springframework.sockjs.server.SockJsFrame.FrameFormat; import org.springframework.sockjs.server.TransportType; +import org.springframework.util.Assert; +import org.springframework.websocket.HandlerProvider; +import org.springframework.websocket.WebSocketHandler; /** @@ -32,7 +34,7 @@ import org.springframework.sockjs.server.TransportType; * @author Rossen Stoyanchev * @since 4.0 */ -public class EventSourceTransportHandler extends AbstractStreamingTransportHandler { +public class EventSourceTransportHandler extends AbstractHttpSendingTransportHandler { @Override @@ -46,10 +48,16 @@ public class EventSourceTransportHandler extends AbstractStreamingTransportHandl } @Override - protected void writePrelude(ServerHttpRequest request, ServerHttpResponse response) throws IOException { - response.getBody().write('\r'); - response.getBody().write('\n'); - response.flush(); + public StreamingServerSockJsSession createSession(String sessionId, HandlerProvider handler) { + Assert.notNull(getSockJsConfig(), "This transport requires SockJsConfiguration"); + return new StreamingServerSockJsSession(sessionId, getSockJsConfig(), handler) { + @Override + protected void writePrelude() throws IOException { + getResponse().getBody().write('\r'); + getResponse().getBody().write('\n'); + getResponse().flush(); + } + }; } @Override diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/HtmlFileTransportHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/HtmlFileTransportHandler.java index eb752cd302c..ccd029d0c15 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/HtmlFileTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/HtmlFileTransportHandler.java @@ -24,9 +24,13 @@ import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.sockjs.server.SockJsFrame.DefaultFrameFormat; import org.springframework.sockjs.server.SockJsFrame.FrameFormat; +import org.springframework.sockjs.server.TransportErrorException; import org.springframework.sockjs.server.TransportType; +import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.web.util.JavaScriptUtils; +import org.springframework.websocket.HandlerProvider; +import org.springframework.websocket.WebSocketHandler; /** @@ -35,7 +39,7 @@ import org.springframework.web.util.JavaScriptUtils; * @author Rossen Stoyanchev * @since 4.0 */ -public class HtmlFileTransportHandler extends AbstractStreamingTransportHandler { +public class HtmlFileTransportHandler extends AbstractHttpSendingTransportHandler { private static final String PARTIAL_HTML_CONTENT; @@ -77,27 +81,40 @@ public class HtmlFileTransportHandler extends AbstractStreamingTransportHandler } @Override - public void handleRequestInternal(ServerHttpRequest request, ServerHttpResponse response, - AbstractHttpServerSockJsSession session) throws Exception { + public StreamingServerSockJsSession createSession(String sessionId, HandlerProvider handler) { + Assert.notNull(getSockJsConfig(), "This transport requires SockJsConfiguration"); - String callback = request.getQueryParams().getFirst("c"); - if (! StringUtils.hasText(callback)) { - response.setStatusCode(HttpStatus.INTERNAL_SERVER_ERROR); - response.getBody().write("\"callback\" parameter required".getBytes("UTF-8")); - return; - } - super.handleRequestInternal(request, response, session); + return new StreamingServerSockJsSession(sessionId, getSockJsConfig(), handler) { + + @Override + protected void writePrelude() throws IOException { + // we already validated the parameter.. + String callback = getRequest().getQueryParams().getFirst("c"); + + String html = String.format(PARTIAL_HTML_CONTENT, callback); + getResponse().getBody().write(html.getBytes("UTF-8")); + getResponse().flush(); + } + }; } @Override - protected void writePrelude(ServerHttpRequest request, ServerHttpResponse response) throws IOException { + public void handleRequestInternal(ServerHttpRequest request, ServerHttpResponse response, + AbstractHttpServerSockJsSession session) throws TransportErrorException { - // we already validated the parameter.. - String callback = request.getQueryParams().getFirst("c"); + try { + String callback = request.getQueryParams().getFirst("c"); + if (! StringUtils.hasText(callback)) { + response.setStatusCode(HttpStatus.INTERNAL_SERVER_ERROR); + response.getBody().write("\"callback\" parameter required".getBytes("UTF-8")); + return; + } + } + catch (Throwable t) { + throw new TransportErrorException("Failed to send error to client", t, session.getId()); + } - String html = String.format(PARTIAL_HTML_CONTENT, callback); - response.getBody().write(html.getBytes("UTF-8")); - response.flush(); + super.handleRequestInternal(request, response, session); } @Override diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/JsonpPollingTransportHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/JsonpPollingTransportHandler.java index 3f5854d30bb..24c754e89dd 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/JsonpPollingTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/JsonpPollingTransportHandler.java @@ -23,6 +23,7 @@ import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.sockjs.server.SockJsFrame; import org.springframework.sockjs.server.SockJsFrame.FrameFormat; +import org.springframework.sockjs.server.TransportErrorException; import org.springframework.sockjs.server.TransportType; import org.springframework.util.Assert; import org.springframework.util.StringUtils; @@ -58,14 +59,20 @@ public class JsonpPollingTransportHandler extends AbstractHttpSendingTransportHa @Override public void handleRequestInternal(ServerHttpRequest request, ServerHttpResponse response, - AbstractHttpServerSockJsSession session) throws Exception { + AbstractHttpServerSockJsSession session) throws TransportErrorException { - String callback = request.getQueryParams().getFirst("c"); - if (! StringUtils.hasText(callback)) { - response.setStatusCode(HttpStatus.INTERNAL_SERVER_ERROR); - response.getBody().write("\"callback\" parameter required".getBytes("UTF-8")); - return; + try { + String callback = request.getQueryParams().getFirst("c"); + if (! StringUtils.hasText(callback)) { + response.setStatusCode(HttpStatus.INTERNAL_SERVER_ERROR); + response.getBody().write("\"callback\" parameter required".getBytes("UTF-8")); + return; + } } + catch (Throwable t) { + throw new TransportErrorException("Failed to send error to client", t, session.getId()); + } + super.handleRequestInternal(request, response, session); } diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/JsonpTransportHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/JsonpTransportHandler.java index 8d39e2d125a..5cbd6288e29 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/JsonpTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/JsonpTransportHandler.java @@ -22,6 +22,7 @@ import org.springframework.http.MediaType; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.sockjs.AbstractSockJsSession; +import org.springframework.sockjs.server.TransportErrorException; import org.springframework.sockjs.server.TransportType; public class JsonpTransportHandler extends AbstractHttpReceivingTransportHandler { @@ -34,19 +35,23 @@ public class JsonpTransportHandler extends AbstractHttpReceivingTransportHandler @Override public void handleRequestInternal(ServerHttpRequest request, ServerHttpResponse response, - AbstractSockJsSession sockJsSession) throws Exception { + AbstractSockJsSession sockJsSession) throws TransportErrorException { if (MediaType.APPLICATION_FORM_URLENCODED.equals(request.getHeaders().getContentType())) { if (request.getQueryParams().getFirst("d") == null) { - response.setStatusCode(HttpStatus.INTERNAL_SERVER_ERROR); - response.getBody().write("Payload expected.".getBytes("UTF-8")); + sendInternalServerError(response, "Payload expected.", sockJsSession.getId()); return; } } super.handleRequestInternal(request, response, sockJsSession); - response.getBody().write("ok".getBytes("UTF-8")); + try { + response.getBody().write("ok".getBytes("UTF-8")); + } + catch (Throwable t) { + throw new TransportErrorException("Failed to write response body", t, sockJsSession.getId()); + } } @Override diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/PollingServerSockJsSession.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/PollingServerSockJsSession.java index 86851ab9d59..743c3694875 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/PollingServerSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/PollingServerSockJsSession.java @@ -15,6 +15,8 @@ */ package org.springframework.sockjs.server.transport; +import java.io.IOException; + import org.springframework.sockjs.server.SockJsConfiguration; import org.springframework.sockjs.server.SockJsFrame; import org.springframework.websocket.HandlerProvider; @@ -30,7 +32,7 @@ public class PollingServerSockJsSession extends AbstractHttpServerSockJsSession } @Override - protected void flushCache() throws Exception { + protected void flushCache() throws IOException { cancelHeartbeat(); String[] messages = getMessageCache().toArray(new String[getMessageCache().size()]); getMessageCache().clear(); @@ -38,7 +40,7 @@ public class PollingServerSockJsSession extends AbstractHttpServerSockJsSession } @Override - protected void writeFrame(SockJsFrame frame) throws Exception { + protected void writeFrame(SockJsFrame frame) throws IOException { super.writeFrame(frame); resetRequest(); } diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/SockJsWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/SockJsWebSocketHandler.java index 07b6cf34577..691571a72a8 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/SockJsWebSocketHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/SockJsWebSocketHandler.java @@ -19,9 +19,6 @@ package org.springframework.sockjs.server.transport; import java.io.IOException; import java.util.concurrent.atomic.AtomicInteger; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; -import org.springframework.sockjs.AbstractSockJsSession; import org.springframework.sockjs.server.AbstractServerSockJsSession; import org.springframework.sockjs.server.SockJsConfiguration; import org.springframework.sockjs.server.SockJsFrame; @@ -38,7 +35,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; /** - * A wrapper around a {@link WebSocketHandler} instance that parses and adds SockJS + * A wrapper around a {@link WebSocketHandler} instance that parses as well as adds SockJS * messages frames as well as sends SockJS heartbeat messages. * * @author Rossen Stoyanchev @@ -46,13 +43,11 @@ import com.fasterxml.jackson.databind.ObjectMapper; */ public class SockJsWebSocketHandler implements TextMessageHandler { - private static final Log logger = LogFactory.getLog(SockJsWebSocketHandler.class); - private final SockJsConfiguration sockJsConfig; private final HandlerProvider handlerProvider; - private AbstractSockJsSession session; + private WebSocketServerSockJsSession sockJsSession; private final AtomicInteger sessionCount = new AtomicInteger(0); @@ -72,38 +67,25 @@ public class SockJsWebSocketHandler implements TextMessageHandler { } @Override - public void afterConnectionEstablished(WebSocketSession wsSession) throws Exception { + public void afterConnectionEstablished(WebSocketSession wsSession) { Assert.isTrue(this.sessionCount.compareAndSet(0, 1), "Unexpected connection"); - this.session = new WebSocketServerSockJsSession(wsSession, getSockJsConfig()); + this.sockJsSession = new WebSocketServerSockJsSession(getSockJsSessionId(wsSession), getSockJsConfig()); + this.sockJsSession.initWebSocketSession(wsSession); } @Override - public void handleTextMessage(TextMessage message, WebSocketSession wsSession) throws Exception { - String payload = message.getPayload(); - if (StringUtils.isEmpty(payload)) { - logger.trace("Ignoring empty message"); - return; - } - String[] messages; - try { - messages = this.objectMapper.readValue(payload, String[].class); - } - catch (IOException e) { - logger.error("Broken data received. Terminating WebSocket connection abruptly", e); - wsSession.close(); - return; - } - this.session.delegateMessages(messages); + public void handleTextMessage(TextMessage message, WebSocketSession wsSession) { + this.sockJsSession.handleMessage(message, wsSession); } @Override - public void afterConnectionClosed(CloseStatus status, WebSocketSession wsSession) throws Exception { - this.session.delegateConnectionClosed(status); + public void afterConnectionClosed(CloseStatus status, WebSocketSession wsSession) { + this.sockJsSession.delegateConnectionClosed(status); } @Override - public void handleError(Throwable exception, WebSocketSession webSocketSession) throws Exception { - this.session.delegateError(exception); + public void handleTransportError(Throwable exception, WebSocketSession webSocketSession) { + this.sockJsSession.delegateError(exception); } private static String getSockJsSessionId(WebSocketSession wsSession) { @@ -117,16 +99,23 @@ public class SockJsWebSocketHandler implements TextMessageHandler { private class WebSocketServerSockJsSession extends AbstractServerSockJsSession { - private final WebSocketSession wsSession; + private WebSocketSession wsSession; - public WebSocketServerSockJsSession(WebSocketSession wsSession, SockJsConfiguration sockJsConfig) - throws Exception { + public WebSocketServerSockJsSession(String sessionId, SockJsConfiguration config) { + super(sessionId, config, SockJsWebSocketHandler.this.handlerProvider); + } - super(getSockJsSessionId(wsSession), sockJsConfig, SockJsWebSocketHandler.this.handlerProvider); + public void initWebSocketSession(WebSocketSession wsSession) { this.wsSession = wsSession; - TextMessage message = new TextMessage(SockJsFrame.openFrame().getContent()); - this.wsSession.sendMessage(message); + try { + TextMessage message = new TextMessage(SockJsFrame.openFrame().getContent()); + this.wsSession.sendMessage(message); + } + catch (IOException ex) { + tryCloseWithSockJsTransportError(ex, null); + return; + } scheduleHeartbeat(); delegateConnectionEstablished(); } @@ -136,15 +125,33 @@ public class SockJsWebSocketHandler implements TextMessageHandler { return this.wsSession.isOpen(); } + public void handleMessage(TextMessage message, WebSocketSession wsSession) { + String payload = message.getPayload(); + if (StringUtils.isEmpty(payload)) { + logger.trace("Ignoring empty message"); + return; + } + String[] messages; + try { + messages = objectMapper.readValue(payload, String[].class); + } + catch (IOException ex) { + logger.error("Broken data received. Terminating WebSocket connection abruptly", ex); + tryCloseWithSockJsTransportError(ex, CloseStatus.BAD_DATA); + return; + } + delegateMessages(messages); + } + @Override - public void sendMessageInternal(String message) throws Exception { + public void sendMessageInternal(String message) throws IOException { cancelHeartbeat(); writeFrame(SockJsFrame.messageFrame(message)); scheduleHeartbeat(); } @Override - protected void writeFrameInternal(SockJsFrame frame) throws Exception { + protected void writeFrameInternal(SockJsFrame frame) throws IOException { if (logger.isTraceEnabled()) { logger.trace("Write " + frame); } @@ -153,7 +160,7 @@ public class SockJsWebSocketHandler implements TextMessageHandler { } @Override - protected void disconnect(CloseStatus status) throws Exception { + protected void disconnect(CloseStatus status) throws IOException { this.wsSession.close(status); } } diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/StreamingServerSockJsSession.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/StreamingServerSockJsSession.java index 094fd115e98..7e4a3c83fb6 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/StreamingServerSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/StreamingServerSockJsSession.java @@ -17,9 +17,12 @@ package org.springframework.sockjs.server.transport; import java.io.IOException; +import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.sockjs.server.SockJsConfiguration; import org.springframework.sockjs.server.SockJsFrame; +import org.springframework.sockjs.server.SockJsFrame.FrameFormat; +import org.springframework.sockjs.server.TransportErrorException; import org.springframework.websocket.HandlerProvider; import org.springframework.websocket.WebSocketHandler; @@ -35,7 +38,15 @@ public class StreamingServerSockJsSession extends AbstractHttpServerSockJsSessio super(sessionId, sockJsConfig, handler); } - protected void flushCache() throws Exception { + @Override + public synchronized void setInitialRequest(ServerHttpRequest request, ServerHttpResponse response, + FrameFormat frameFormat) throws TransportErrorException { + + super.setInitialRequest(request, response, frameFormat); + super.setLongPollingRequest(request, response, frameFormat); + } + + protected void flushCache() throws IOException { cancelHeartbeat(); @@ -68,9 +79,12 @@ public class StreamingServerSockJsSession extends AbstractHttpServerSockJsSessio } @Override - public void writeFrame(ServerHttpResponse response, SockJsFrame frame) throws IOException { - super.writeFrame(response, frame); - response.flush(); + protected synchronized void writeFrameInternal(SockJsFrame frame) throws IOException { + if (isActive()) { + super.writeFrameInternal(frame); + getResponse().flush(); + } } + } diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/WebSocketTransportHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/WebSocketTransportHandler.java index 361c8b3c14c..8f615c913dc 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/WebSocketTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/WebSocketTransportHandler.java @@ -16,11 +16,14 @@ package org.springframework.sockjs.server.transport; +import java.io.IOException; + import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.sockjs.AbstractSockJsSession; import org.springframework.sockjs.server.ConfigurableTransportHandler; import org.springframework.sockjs.server.SockJsConfiguration; +import org.springframework.sockjs.server.TransportErrorException; import org.springframework.sockjs.server.TransportHandler; import org.springframework.sockjs.server.TransportType; import org.springframework.util.Assert; @@ -62,17 +65,22 @@ public class WebSocketTransportHandler implements ConfigurableTransportHandler, @Override public void handleRequest(ServerHttpRequest request, ServerHttpResponse response, - HandlerProvider handler, AbstractSockJsSession session) throws Exception { + HandlerProvider handler, AbstractSockJsSession session) throws TransportErrorException { - WebSocketHandler sockJsWrapper = new SockJsWebSocketHandler(this.sockJsConfig, handler); - this.handshakeHandler.doHandshake(request, response, new SimpleHandlerProvider(sockJsWrapper)); + try { + WebSocketHandler sockJsWrapper = new SockJsWebSocketHandler(this.sockJsConfig, handler); + this.handshakeHandler.doHandshake(request, response, new SimpleHandlerProvider(sockJsWrapper)); + } + catch (Throwable t) { + throw new TransportErrorException("Failed to start handshake request", t, session.getId()); + } } // HandshakeHandler methods @Override public boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, - HandlerProvider handler) throws Exception { + HandlerProvider handler) throws IOException { return this.handshakeHandler.doHandshake(request, response, handler); } diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/XhrStreamingTransportHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/XhrStreamingTransportHandler.java index cf407442849..9e6c7457a48 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/XhrStreamingTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/XhrStreamingTransportHandler.java @@ -20,10 +20,12 @@ import java.nio.charset.Charset; import org.springframework.http.MediaType; import org.springframework.http.server.ServerHttpRequest; -import org.springframework.http.server.ServerHttpResponse; import org.springframework.sockjs.server.SockJsFrame.DefaultFrameFormat; import org.springframework.sockjs.server.SockJsFrame.FrameFormat; import org.springframework.sockjs.server.TransportType; +import org.springframework.util.Assert; +import org.springframework.websocket.HandlerProvider; +import org.springframework.websocket.WebSocketHandler; /** @@ -32,7 +34,7 @@ import org.springframework.sockjs.server.TransportType; * @author Rossen Stoyanchev * @since 4.0 */ -public class XhrStreamingTransportHandler extends AbstractStreamingTransportHandler { +public class XhrStreamingTransportHandler extends AbstractHttpSendingTransportHandler { @Override @@ -46,12 +48,20 @@ public class XhrStreamingTransportHandler extends AbstractStreamingTransportHand } @Override - protected void writePrelude(ServerHttpRequest request, ServerHttpResponse response) throws IOException { - for (int i=0; i < 2048; i++) { - response.getBody().write('h'); - } - response.getBody().write('\n'); - response.flush(); + public StreamingServerSockJsSession createSession(String sessionId, HandlerProvider handler) { + Assert.notNull(getSockJsConfig(), "This transport requires SockJsConfiguration"); + + return new StreamingServerSockJsSession(sessionId, getSockJsConfig(), handler) { + + @Override + protected void writePrelude() throws IOException { + for (int i=0; i < 2048; i++) { + getResponse().getBody().write('h'); + } + getResponse().getBody().write('\n'); + getResponse().flush(); + } + }; } @Override diff --git a/spring-websocket/src/main/java/org/springframework/websocket/BinaryMessageHandler.java b/spring-websocket/src/main/java/org/springframework/websocket/BinaryMessageHandler.java index 50940e34daf..bac882087fd 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/BinaryMessageHandler.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/BinaryMessageHandler.java @@ -29,7 +29,6 @@ public interface BinaryMessageHandler extends WebSocketHandler { /** * Handle an incoming binary message. */ - void handleBinaryMessage(BinaryMessage message, WebSocketSession session) - throws Exception; + void handleBinaryMessage(BinaryMessage message, WebSocketSession session); } diff --git a/spring-websocket/src/main/java/org/springframework/websocket/TextMessageHandler.java b/spring-websocket/src/main/java/org/springframework/websocket/TextMessageHandler.java index ea073eaf3d4..fe79622b4b2 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/TextMessageHandler.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/TextMessageHandler.java @@ -34,7 +34,6 @@ public interface TextMessageHandler extends WebSocketHandler { /** * Handle an incoming text message. */ - void handleTextMessage(TextMessage message, WebSocketSession session) - throws Exception; + void handleTextMessage(TextMessage message, WebSocketSession session); } diff --git a/spring-websocket/src/main/java/org/springframework/websocket/WebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/websocket/WebSocketHandler.java index a4972eeb466..da4b841d2cd 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/WebSocketHandler.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/WebSocketHandler.java @@ -29,16 +29,16 @@ public interface WebSocketHandler { /** * A new WebSocket connection has been opened and is ready to be used. */ - void afterConnectionEstablished(WebSocketSession session) throws Exception; + void afterConnectionEstablished(WebSocketSession session); /** * A WebSocket connection has been closed. */ - void afterConnectionClosed(CloseStatus closeStatus, WebSocketSession session) throws Exception; + void afterConnectionClosed(CloseStatus closeStatus, WebSocketSession session); /** * TODO */ - void handleError(Throwable exception, WebSocketSession session) throws Exception; + void handleTransportError(Throwable exception, WebSocketSession session); } diff --git a/spring-websocket/src/main/java/org/springframework/websocket/WebSocketHandlerAdapter.java b/spring-websocket/src/main/java/org/springframework/websocket/WebSocketHandlerAdapter.java index 4914f3767a3..4d9e5259d96 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/WebSocketHandlerAdapter.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/WebSocketHandlerAdapter.java @@ -26,15 +26,15 @@ package org.springframework.websocket; public class WebSocketHandlerAdapter implements WebSocketHandler { @Override - public void afterConnectionEstablished(WebSocketSession session) throws Exception { + public void afterConnectionEstablished(WebSocketSession session) { } @Override - public void afterConnectionClosed(CloseStatus status, WebSocketSession session) throws Exception { + public void afterConnectionClosed(CloseStatus status, WebSocketSession session) { } @Override - public void handleError(Throwable exception, WebSocketSession session) { + public void handleTransportError(Throwable exception, WebSocketSession session) { } diff --git a/spring-websocket/src/main/java/org/springframework/websocket/WebSocketSession.java b/spring-websocket/src/main/java/org/springframework/websocket/WebSocketSession.java index 2c9d4355f52..483c6f72e81 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/WebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/WebSocketSession.java @@ -16,10 +16,10 @@ package org.springframework.websocket; +import java.io.IOException; import java.net.URI; - /** * Allows sending messages over a WebSocket connection as well as closing it. * @@ -52,7 +52,7 @@ public interface WebSocketSession { * Send a WebSocket message either {@link TextMessage} or * {@link BinaryMessage}. */ - void sendMessage(WebSocketMessage message) throws Exception; + void sendMessage(WebSocketMessage message) throws IOException; /** * Close the WebSocket connection with status 1000, i.e. equivalent to: @@ -60,11 +60,11 @@ public interface WebSocketSession { * session.close(CloseStatus.NORMAL); * */ - void close() throws Exception; + void close() throws IOException; /** * Close the WebSocket connection with the given close status. */ - void close(CloseStatus status) throws Exception; + void close(CloseStatus status) throws IOException; } diff --git a/spring-websocket/src/main/java/org/springframework/websocket/endpoint/WebSocketHandlerEndpoint.java b/spring-websocket/src/main/java/org/springframework/websocket/endpoint/WebSocketHandlerEndpoint.java index 1b151164865..ce9f0736ba9 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/endpoint/WebSocketHandlerEndpoint.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/endpoint/WebSocketHandlerEndpoint.java @@ -110,7 +110,34 @@ public class WebSocketHandlerEndpoint extends Endpoint { this.handler.afterConnectionEstablished(this.webSocketSession); } catch (Throwable ex) { - onError(session, ex); + tryCloseWithError(ex); + } + } + + private void tryCloseWithError(Throwable ex) { + logger.error("Unhandled error for " + this.webSocketSession, ex); + if (this.webSocketSession.isOpen()) { + try { + this.webSocketSession.close(CloseStatus.SERVER_ERROR); + } + catch (Throwable t) { + destroyHandler(); + } + } + } + + private void destroyHandler() { + try { + if (this.handler != null) { + this.handlerProvider.destroy(this.handler); + } + } + catch (Throwable t) { + logger.warn("Error while destroying handler", t); + } + finally { + this.webSocketSession = null; + this.handler = null; } } @@ -123,7 +150,7 @@ public class WebSocketHandlerEndpoint extends Endpoint { ((TextMessageHandler) handler).handleTextMessage(textMessage, this.webSocketSession); } catch (Throwable ex) { - onError(session, ex); + tryCloseWithError(ex); } } @@ -136,7 +163,7 @@ public class WebSocketHandlerEndpoint extends Endpoint { ((BinaryMessageHandler) handler).handleBinaryMessage(binaryMessage, this.webSocketSession); } catch (Throwable ex) { - onError(session, ex); + tryCloseWithError(ex); } } @@ -150,7 +177,7 @@ public class WebSocketHandlerEndpoint extends Endpoint { this.handler.afterConnectionClosed(closeStatus, this.webSocketSession); } catch (Throwable ex) { - onError(session, ex); + logger.error("Unhandled error for " + this.webSocketSession, ex); } finally { this.handlerProvider.destroy(this.handler); @@ -161,11 +188,10 @@ public class WebSocketHandlerEndpoint extends Endpoint { public void onError(javax.websocket.Session session, Throwable exception) { logger.error("Error for WebSocket session id=" + session.getId(), exception); try { - this.handler.handleError(exception, this.webSocketSession); + this.handler.handleTransportError(exception, this.webSocketSession); } catch (Throwable ex) { - // TODO: close the session? - logger.error("Failed to handle error", ex); + tryCloseWithError(ex); } } diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/DefaultHandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/websocket/server/DefaultHandshakeHandler.java index 6ae721819ce..79463681ea7 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/server/DefaultHandshakeHandler.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/server/DefaultHandshakeHandler.java @@ -88,7 +88,7 @@ public class DefaultHandshakeHandler implements HandshakeHandler { @Override public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, - HandlerProvider handler) throws Exception { + HandlerProvider handler) throws IOException { logger.debug("Starting handshake for " + request.getURI()); @@ -199,10 +199,15 @@ public class DefaultHandshakeHandler implements HandshakeHandler { return null; } - private String getWebSocketKeyHash(String key) throws NoSuchAlgorithmException { - MessageDigest digest = MessageDigest.getInstance("SHA1"); - byte[] bytes = digest.digest((key + GUID).getBytes(Charset.forName("ISO-8859-1"))); - return DatatypeConverter.printBase64Binary(bytes); + private String getWebSocketKeyHash(String key) { + try { + MessageDigest digest = MessageDigest.getInstance("SHA1"); + byte[] bytes = digest.digest((key + GUID).getBytes(Charset.forName("ISO-8859-1"))); + return DatatypeConverter.printBase64Binary(bytes); + } + catch (NoSuchAlgorithmException ex) { + throw new IllegalStateException("Failed to generate value for Sec-WebSocket-Key header", ex); + } } diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/HandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/websocket/server/HandshakeHandler.java index 61f0e42b332..1989f61ad08 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/server/HandshakeHandler.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/server/HandshakeHandler.java @@ -16,6 +16,8 @@ package org.springframework.websocket.server; +import java.io.IOException; + import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.websocket.HandlerProvider; @@ -32,6 +34,6 @@ public interface HandshakeHandler { boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, - HandlerProvider handler) throws Exception; + HandlerProvider handler) throws IOException; } diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/RequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/websocket/server/RequestUpgradeStrategy.java index fadf15c99e0..359cf9dc474 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/server/RequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/server/RequestUpgradeStrategy.java @@ -16,6 +16,8 @@ package org.springframework.websocket.server; +import java.io.IOException; + import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.websocket.HandlerProvider; @@ -43,7 +45,7 @@ public interface RequestUpgradeStrategy { * @param handler the handler for WebSocket messages */ void upgrade(ServerHttpRequest request, ServerHttpResponse response, String selectedProtocol, - HandlerProvider handlerProvider) throws Exception; + HandlerProvider handlerProvider) throws IOException; // FIXME how to indicate failure to upgrade? } diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/support/AbstractEndpointUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/websocket/server/support/AbstractEndpointUpgradeStrategy.java index f7884838d32..ef53d8e237b 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/server/support/AbstractEndpointUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/server/support/AbstractEndpointUpgradeStrategy.java @@ -16,6 +16,8 @@ package org.springframework.websocket.server.support; +import java.io.IOException; + import javax.websocket.Endpoint; import org.apache.commons.logging.Log; @@ -42,7 +44,7 @@ public abstract class AbstractEndpointUpgradeStrategy implements RequestUpgradeS @Override public void upgrade(ServerHttpRequest request, ServerHttpResponse response, - String protocol, HandlerProvider handler) throws Exception { + String protocol, HandlerProvider handler) throws IOException { upgradeInternal(request, response, protocol, adaptWebSocketHandler(handler)); } @@ -52,6 +54,6 @@ public abstract class AbstractEndpointUpgradeStrategy implements RequestUpgradeS } protected abstract void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response, - String selectedProtocol, Endpoint endpoint) throws Exception; + String selectedProtocol, Endpoint endpoint) throws IOException; } diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/support/GlassfishRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/websocket/server/support/GlassfishRequestUpgradeStrategy.java index e462e65dd74..a5ca56c5b50 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/server/support/GlassfishRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/server/support/GlassfishRequestUpgradeStrategy.java @@ -16,6 +16,7 @@ package org.springframework.websocket.server.support; +import java.io.IOException; import java.lang.reflect.Constructor; import java.net.URI; import java.util.Arrays; @@ -24,6 +25,7 @@ import java.util.Random; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponseWrapper; +import javax.websocket.DeploymentException; import javax.websocket.Endpoint; import org.glassfish.tyrus.core.ComponentProviderService; @@ -67,7 +69,7 @@ public class GlassfishRequestUpgradeStrategy extends AbstractEndpointUpgradeStra @Override public void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response, - String selectedProtocol, Endpoint endpoint) throws Exception { + String selectedProtocol, Endpoint endpoint) throws IOException { Assert.isTrue(request instanceof ServletServerHttpRequest); HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); @@ -78,7 +80,13 @@ public class GlassfishRequestUpgradeStrategy extends AbstractEndpointUpgradeStra TyrusEndpoint tyrusEndpoint = createTyrusEndpoint(servletRequest, endpoint, selectedProtocol); WebSocketEngine engine = WebSocketEngine.getEngine(); - engine.register(tyrusEndpoint); + + try { + engine.register(tyrusEndpoint); + } + catch (DeploymentException ex) { + throw new IllegalStateException("Failed to deploy endpoint in Glassfish", ex); + } try { if (!performUpgrade(servletRequest, servletResponse, request.getHeaders(), tyrusEndpoint)) { @@ -91,7 +99,7 @@ public class GlassfishRequestUpgradeStrategy extends AbstractEndpointUpgradeStra } private boolean performUpgrade(HttpServletRequest request, HttpServletResponse response, - HttpHeaders headers, TyrusEndpoint tyrusEndpoint) throws Exception { + HttpHeaders headers, TyrusEndpoint tyrusEndpoint) throws IOException { final TyrusHttpUpgradeHandler upgradeHandler = request.upgrade(TyrusHttpUpgradeHandler.class); @@ -128,12 +136,17 @@ public class GlassfishRequestUpgradeStrategy extends AbstractEndpointUpgradeStra endpointConfig.getConfigurator())); } - private Connection createConnection(TyrusHttpUpgradeHandler handler, HttpServletResponse response) throws Exception { - String name = "org.glassfish.tyrus.servlet.ConnectionImpl"; - Class clazz = ClassUtils.forName(name, GlassfishRequestUpgradeStrategy.class.getClassLoader()); - Constructor constructor = clazz.getDeclaredConstructor(TyrusHttpUpgradeHandler.class, HttpServletResponse.class); - ReflectionUtils.makeAccessible(constructor); - return (Connection) constructor.newInstance(handler, response); + private Connection createConnection(TyrusHttpUpgradeHandler handler, HttpServletResponse response) { + try { + String name = "org.glassfish.tyrus.servlet.ConnectionImpl"; + Class clazz = ClassUtils.forName(name, GlassfishRequestUpgradeStrategy.class.getClassLoader()); + Constructor constructor = clazz.getDeclaredConstructor(TyrusHttpUpgradeHandler.class, HttpServletResponse.class); + ReflectionUtils.makeAccessible(constructor); + return (Connection) constructor.newInstance(handler, response); + } + catch (Exception ex) { + throw new IllegalStateException("Failed to instantiate Glassfish connection", ex); + } } diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/support/JettyRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/websocket/server/support/JettyRequestUpgradeStrategy.java index 6be0568cb8b..c313f9cdc08 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/server/support/JettyRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/server/support/JettyRequestUpgradeStrategy.java @@ -18,6 +18,7 @@ package org.springframework.websocket.server.support; import java.io.IOException; import java.net.URI; +import java.util.concurrent.atomic.AtomicInteger; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -101,7 +102,7 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { @Override public void upgrade(ServerHttpRequest request, ServerHttpResponse response, String selectedProtocol, HandlerProvider handlerProvider) - throws Exception { + throws IOException { Assert.isInstanceOf(ServletServerHttpRequest.class, request); Assert.isInstanceOf(ServletServerHttpResponse.class, response); upgrade(((ServletServerHttpRequest) request).getServletRequest(), @@ -111,7 +112,7 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { private void upgrade(HttpServletRequest request, HttpServletResponse response, String selectedProtocol, final HandlerProvider handlerProvider) - throws Exception { + throws IOException { request.setAttribute(HANDLER_PROVIDER, handlerProvider); Assert.state(factory.isUpgradeRequest(request, response), "Not a suitable WebSocket upgrade request"); Assert.state(factory.acceptWebSocket(request, response), "Unable to accept WebSocket"); @@ -129,6 +130,8 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { private WebSocketSession session; + private final AtomicInteger sessionCount = new AtomicInteger(0); + public WebSocketHandlerAdapter(HandlerProvider provider) { Assert.notNull(provider, "Provider must not be null"); @@ -139,31 +142,53 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { @Override public void onWebSocketConnect(Session session) { - Assert.state(this.session == null, "WebSocket already open"); + + Assert.isTrue(this.sessionCount.compareAndSet(0, 1), "Unexpected connection"); + + this.session = new WebSocketSessionAdapter(session); + if (logger.isDebugEnabled()) { + logger.debug("Connection established, WebSocket session id=" + + this.session.getId() + ", uri=" + this.session.getURI()); + } + this.handler = this.provider.getHandler(); + try { - this.session = new WebSocketSessionAdapter(session); - if (logger.isDebugEnabled()) { - logger.debug("Connection established, WebSocket session id=" - + this.session.getId() + ", uri=" + this.session.getURI()); - } - this.handler = this.provider.getHandler(); this.handler.afterConnectionEstablished(this.session); } - catch (Exception ex) { + catch (Throwable ex) { + tryCloseWithError(ex); + } + } + + private void tryCloseWithError(Throwable ex) { + logger.error("Unhandled error for " + this.session, ex); + if (this.session.isOpen()) { try { - // FIXME revisit after error handling - onWebSocketError(ex); + this.session.close(CloseStatus.SERVER_ERROR); } - finally { - this.session = null; - this.handler = null; + catch (Throwable t) { + destroyHandler(); } } } + private void destroyHandler() { + try { + if (this.handler != null) { + this.provider.destroy(this.handler); + } + } + catch (Throwable t) { + logger.warn("Error while destroying handler", t); + } + finally { + this.session = null; + this.handler = null; + } + } + @Override public void onWebSocketClose(int statusCode, String reason) { - Assert.state(this.session != null, "WebSocket not open"); try { CloseStatus closeStatus = new CloseStatus(statusCode, reason); if (logger.isDebugEnabled()) { @@ -172,19 +197,11 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { } this.handler.afterConnectionClosed(closeStatus, this.session); } - catch (Exception ex) { - onWebSocketError(ex); + catch (Throwable ex) { + logger.error("Unhandled error for " + this.session, ex); } finally { - try { - if (this.handler != null) { - this.provider.destroy(this.handler); - } - } - finally { - this.session = null; - this.handler = null; - } + destroyHandler(); } } @@ -200,8 +217,8 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { ((TextMessageHandler) this.handler).handleTextMessage(message, this.session); } } - catch(Exception ex) { - ex.printStackTrace(); //FIXME + catch(Throwable ex) { + tryCloseWithError(ex); } } @@ -218,20 +235,18 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { this.session); } } - catch(Exception ex) { - ex.printStackTrace(); //FIXME + catch(Throwable ex) { + tryCloseWithError(ex); } } @Override public void onWebSocketError(Throwable cause) { try { - this.handler.handleError(cause, this.session); + this.handler.handleTransportError(cause, this.session); } catch (Throwable ex) { - // FIXME exceptions - logger.error("Error for WebSocket session id=" + this.session.getId(), - cause); + tryCloseWithError(ex); } } } @@ -271,7 +286,7 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { } @Override - public void sendMessage(WebSocketMessage message) throws Exception { + public void sendMessage(WebSocketMessage message) throws IOException { if (message instanceof BinaryMessage) { sendMessage((BinaryMessage) message); } @@ -283,11 +298,11 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { } } - private void sendMessage(BinaryMessage message) throws Exception { + private void sendMessage(BinaryMessage message) throws IOException { this.session.getRemote().sendBytes(message.getPayload()); } - private void sendMessage(TextMessage message) throws Exception { + private void sendMessage(TextMessage message) throws IOException { this.session.getRemote().sendString(message.getPayload()); }