Synchronized lazy start in JettyRequestUpgradeStrategy

Issue: SPR-14527
This commit is contained in:
Rossen Stoyanchev 2016-12-12 21:01:49 -05:00
parent 75422787b6
commit 885e76bdd0
3 changed files with 56 additions and 38 deletions

View File

@ -37,6 +37,7 @@ import org.springframework.web.reactive.socket.adapter.JettyWebSocketHandlerAdap
import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy; import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
/** /**
* A {@link RequestUpgradeStrategy} for use with Jetty. * A {@link RequestUpgradeStrategy} for use with Jetty.
* *
@ -45,52 +46,58 @@ import org.springframework.web.server.ServerWebExchange;
*/ */
public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Lifecycle { public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Lifecycle {
private static final ThreadLocal<JettyWebSocketHandlerAdapter> wsContainerHolder = private static final ThreadLocal<JettyWebSocketHandlerAdapter> adapterHolder =
new NamedThreadLocal<>("Jetty WebSocketHandler Adapter"); new NamedThreadLocal<>("JettyWebSocketHandlerAdapter");
private WebSocketServerFactory factory; private WebSocketServerFactory factory;
private ServletContext servletContext; private ServletContext servletContext;
private volatile boolean running = false; private boolean running = false;
private final Object lifecycleMonitor = new Object();
@Override @Override
public void start() { public void start() {
if (!isRunning() && this.servletContext != null) { synchronized (this.lifecycleMonitor) {
this.running = true; if (!isRunning() && this.servletContext != null) {
try { this.running = true;
this.factory = new WebSocketServerFactory(this.servletContext); try {
this.factory.setCreator((request, response) -> { this.factory = new WebSocketServerFactory(this.servletContext);
JettyWebSocketHandlerAdapter adapter = wsContainerHolder.get(); this.factory.setCreator((request, response) -> adapterHolder.get());
Assert.state(adapter != null, "Expected JettyWebSocketHandlerAdapter"); this.factory.start();
return adapter; }
}); catch (Exception ex) {
this.factory.start(); throw new IllegalStateException("Unable to start WebSocketServerFactory", ex);
} }
catch (Exception ex) {
throw new IllegalStateException("Unable to start Jetty WebSocketServerFactory", ex);
} }
} }
} }
@Override @Override
public void stop() { public void stop() {
if (isRunning()) { synchronized (this.lifecycleMonitor) {
this.running = false; if (isRunning()) {
try { try {
this.factory.stop(); this.factory.stop();
} }
catch (Exception ex) { catch (Exception ex) {
throw new IllegalStateException("Unable to stop Jetty WebSocketServerFactory", ex); throw new IllegalStateException("Failed to stop WebSocketServerFactory", ex);
}
finally {
this.running = false;
}
} }
} }
} }
@Override @Override
public boolean isRunning() { public boolean isRunning() {
return this.running; synchronized (this.lifecycleMonitor) {
return this.running;
}
} }
@Override @Override
@ -103,25 +110,20 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life
HttpServletRequest servletRequest = getHttpServletRequest(request); HttpServletRequest servletRequest = getHttpServletRequest(request);
HttpServletResponse servletResponse = getHttpServletResponse(response); HttpServletResponse servletResponse = getHttpServletResponse(response);
if (this.servletContext == null) { startLazily(servletRequest);
this.servletContext = servletRequest.getServletContext();
this.servletContext.setAttribute(DecoratedObjectFactory.ATTR, new DecoratedObjectFactory()); boolean isUpgrade = this.factory.isUpgradeRequest(servletRequest, servletResponse);
} Assert.isTrue(isUpgrade, "Not a WebSocket handshake");
try { try {
start(); adapterHolder.set(adapter);
Assert.isTrue(this.factory.isUpgradeRequest(
servletRequest, servletResponse), "Not a WebSocket handshake");
wsContainerHolder.set(adapter);
this.factory.acceptWebSocket(servletRequest, servletResponse); this.factory.acceptWebSocket(servletRequest, servletResponse);
} }
catch (IOException ex) { catch (IOException ex) {
return Mono.error(ex); return Mono.error(ex);
} }
finally { finally {
wsContainerHolder.remove(); adapterHolder.remove();
} }
return Mono.empty(); return Mono.empty();
@ -137,4 +139,17 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life
return ((ServletServerHttpResponse) response).getServletResponse(); return ((ServletServerHttpResponse) response).getServletResponse();
} }
private void startLazily(HttpServletRequest request) {
if (this.servletContext != null) {
return;
}
synchronized (this.lifecycleMonitor) {
if (this.servletContext == null) {
this.servletContext = request.getServletContext();
this.servletContext.setAttribute(DecoratedObjectFactory.ATTR, new DecoratedObjectFactory());
start();
}
}
}
} }

View File

@ -66,7 +66,7 @@ public abstract class AbstractWebSocketHandlerIntegrationTests {
public Class<?> handlerAdapterConfigClass; public Class<?> handlerAdapterConfigClass;
@Parameters @Parameters(name = "server [{0}]")
public static Object[][] arguments() { public static Object[][] arguments() {
File base = new File(System.getProperty("java.io.tmpdir")); File base = new File(System.getProperty("java.io.tmpdir"));
return new Object[][] { return new Object[][] {

View File

@ -20,7 +20,6 @@ import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame; import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.reactivex.netty.protocol.http.client.HttpClient; import io.reactivex.netty.protocol.http.client.HttpClient;
@ -66,7 +65,11 @@ public class BasicWebSocketHandlerIntegrationTests extends AbstractWebSocketHand
.mergeWith(conn.getInput()) .mergeWith(conn.getInput())
) )
.take(10) .take(10)
.map(frame -> frame.content().toString(StandardCharsets.UTF_8)) .map(frame -> {
String text = frame.content().toString(StandardCharsets.UTF_8);
frame.release();
return text;
})
.toList().toBlocking().first(); .toList().toBlocking().first();
List<String> expected = messages.toList().toBlocking().first(); List<String> expected = messages.toList().toBlocking().first();
assertEquals(expected, actual); assertEquals(expected, actual);