Polishing

See gh-29408
This commit is contained in:
Sam Brannen 2022-11-02 12:13:30 +01:00
parent 5245327962
commit 723e09c164
14 changed files with 247 additions and 312 deletions

View File

@ -98,18 +98,19 @@ public class SockJsServiceRegistration {
/** /**
* Transports with no native cross-domain communication (e.g. "eventsource", * Transports with no native cross-domain communication (e.g. "eventsource",
* "htmlfile") must get a simple page from the "foreign" domain in an invisible * "htmlfile") must get a simple page from the "foreign" domain in an invisible
* iframe so that code in the iframe can run from a domain local to the SockJS * {@code iframe} so that code in the {@code iframe} can run from a domain
* server. Since the iframe needs to load the SockJS javascript client library, * local to the SockJS server. Since the {@code iframe} needs to load the
* this property allows specifying where to load it from. * SockJS JavaScript client library, this property allows specifying where to
* load it from.
* <p>By default this is set to point to * <p>By default this is set to point to
* "<a href="https://cdn.jsdelivr.net/sockjs/1.0.0/sockjs.min.js">sockjs.min.js</a>". However, it can * <a href="https://cdn.jsdelivr.net/sockjs/1.0.0/sockjs.min.js">"https://cdn.jsdelivr.net/sockjs/1.0.0/sockjs.min.js"</a>.
* also be set to point to a URL served by the application. * However, it can also be set to point to a URL served by the application.
* <p>Note that it's possible to specify a relative URL in which case the URL * <p>Note that it's possible to specify a relative URL in which case the URL
* must be relative to the iframe URL. For example assuming a SockJS endpoint * must be relative to the {@code iframe} URL. For example assuming a SockJS endpoint
* mapped to "/sockjs", and resulting iframe URL "/sockjs/iframe.html", then * mapped to "/sockjs", and resulting {@code iframe} URL "/sockjs/iframe.html", then
* the relative URL must start with "../../" to traverse up to the location * the relative URL must start with "../../" to traverse up to the location
* above the SockJS mapping. In case of a prefix-based Servlet mapping one more * above the SockJS mapping. In case of a prefix-based Servlet mapping one more
* traversal may be needed. * traversals may be needed.
*/ */
public SockJsServiceRegistration setClientLibraryUrl(String clientLibraryUrl) { public SockJsServiceRegistration setClientLibraryUrl(String clientLibraryUrl) {
this.clientLibraryUrl = clientLibraryUrl; this.clientLibraryUrl = clientLibraryUrl;

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2021 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -145,18 +145,19 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
/** /**
* Transports with no native cross-domain communication (e.g. "eventsource", * Transports with no native cross-domain communication (e.g. "eventsource",
* "htmlfile") must get a simple page from the "foreign" domain in an invisible * "htmlfile") must get a simple page from the "foreign" domain in an invisible
* iframe so that code in the iframe can run from a domain local to the SockJS * {@code iframe} so that code in the {@code iframe} can run from a domain
* server. Since the iframe needs to load the SockJS javascript client library, * local to the SockJS server. Since the {@code iframe} needs to load the
* this property allows specifying where to load it from. * SockJS JavaScript client library, this property allows specifying where to
* load it from.
* <p>By default this is set to point to * <p>By default this is set to point to
* "<a href="https://cdn.jsdelivr.net/sockjs/1.0.0/sockjs.min.js">sockjs.min.js</a>". * <a href="https://cdn.jsdelivr.net/sockjs/1.0.0/sockjs.min.js">"https://cdn.jsdelivr.net/sockjs/1.0.0/sockjs.min.js"</a>.
* However, it can also be set to point to a URL served by the application. * However, it can also be set to point to a URL served by the application.
* <p>Note that it's possible to specify a relative URL in which case the URL * <p>Note that it's possible to specify a relative URL in which case the URL
* must be relative to the iframe URL. For example assuming a SockJS endpoint * must be relative to the {@code iframe} URL. For example assuming a SockJS endpoint
* mapped to "/sockjs", and resulting iframe URL "/sockjs/iframe.html", then * mapped to "/sockjs", and resulting {@code iframe} URL "/sockjs/iframe.html", then
* the relative URL must start with "../../" to traverse up to the location * the relative URL must start with "../../" to traverse up to the location
* above the SockJS mapping. In case of a prefix-based Servlet mapping one more * above the SockJS mapping. In case of a prefix-based Servlet mapping one more
* traversal may be needed. * traversals may be needed.
*/ */
public void setSockJsClientLibraryUrl(String clientLibraryUrl) { public void setSockJsClientLibraryUrl(String clientLibraryUrl) {
this.clientLibraryUrl = clientLibraryUrl; this.clientLibraryUrl = clientLibraryUrl;
@ -613,24 +614,23 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
private class IframeHandler implements SockJsRequestHandler { private class IframeHandler implements SockJsRequestHandler {
private static final String IFRAME_CONTENT = private static final String IFRAME_CONTENT = """
""" <!DOCTYPE html>
<!DOCTYPE html> <html>
<html> <head>
<head> <meta http-equiv="X-UA-Compatible" content="IE=edge" />
<meta http-equiv="X-UA-Compatible" content="IE=edge" /> <meta http-equiv="Content-Type" content="text/html; charset=UTF-8" />
<meta http-equiv="Content-Type" content="text/html; charset=UTF-8" /> <script>
<script> document.domain = document.domain;
document.domain = document.domain; _sockjs_onload = function(){SockJS.bootstrap_iframe();};
_sockjs_onload = function(){SockJS.bootstrap_iframe();}; </script>
</script> <script src="%s"></script>
<script src="%s"></script> </head>
</head> <body>
<body> <h2>Don't panic!</h2>
<h2>Don't panic!</h2> <p>This is a SockJS hidden iframe. It's used for cross domain magic.</p>
<p>This is a SockJS hidden iframe. It's used for cross domain magic.</p> </body>
</body> </html>""";
</html>""";
@Override @Override
public void handle(ServerHttpRequest request, ServerHttpResponse response) throws IOException { public void handle(ServerHttpRequest request, ServerHttpResponse response) throws IOException {

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2021 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -59,26 +59,22 @@ public class HtmlFileTransportHandler extends AbstractHttpSendingTransportHandle
static { static {
StringBuilder sb = new StringBuilder( StringBuilder sb = new StringBuilder("""
""" <!DOCTYPE html>
<!DOCTYPE html> <html><head>
<html><head> <meta http-equiv="X-UA-Compatible" content="IE=edge" />
<meta http-equiv="X-UA-Compatible" content="IE=edge" /> <meta http-equiv="Content-Type" content="text/html; charset=UTF-8" />
<meta http-equiv="Content-Type" content="text/html; charset=UTF-8" /> </head><body><h2>Don't panic!</h2>
</head><body><h2>Don't panic!</h2> <script>
<script> document.domain = document.domain;
document.domain = document.domain; var c = parent.%s;
var c = parent.%s; c.start();
c.start(); function p(d) {c.message(d);};
function p(d) {c.message(d);}; window.onload = function() {c.stop();};
window.onload = function() {c.stop();}; </script>""");
</script>"""
);
while (sb.length() < MINIMUM_PARTIAL_HTML_CONTENT_LENGTH) { sb.append(" ".repeat(MINIMUM_PARTIAL_HTML_CONTENT_LENGTH - sb.length()));
sb.append(' '); PARTIAL_HTML_CONTENT = sb.append('\n').toString();
}
PARTIAL_HTML_CONTENT = sb.toString();
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2018 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -45,7 +45,7 @@ public abstract class AbstractHttpRequestTests {
@BeforeEach @BeforeEach
public void setup() { protected void setup() {
resetRequestAndResponse(); resetRequestAndResponse();
} }

View File

@ -19,7 +19,6 @@ package org.springframework.web.socket.config.annotation;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.messaging.MessageChannel; import org.springframework.messaging.MessageChannel;
@ -47,67 +46,62 @@ import static org.mockito.Mockito.mock;
* *
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
*/ */
public class WebMvcStompWebSocketEndpointRegistrationTests { class WebMvcStompWebSocketEndpointRegistrationTests {
private SubProtocolWebSocketHandler handler; private final SubProtocolWebSocketHandler handler =
new SubProtocolWebSocketHandler(mock(MessageChannel.class), mock(SubscribableChannel.class));
private TaskScheduler scheduler; private final TaskScheduler scheduler = mock(TaskScheduler.class);
@BeforeEach
public void setup() {
this.handler = new SubProtocolWebSocketHandler(mock(MessageChannel.class), mock(SubscribableChannel.class));
this.scheduler = mock(TaskScheduler.class);
}
@Test @Test
public void minimalRegistration() { void minimalRegistration() {
WebMvcStompWebSocketEndpointRegistration registration = WebMvcStompWebSocketEndpointRegistration registration =
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler); new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings(); MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
assertThat(mappings.size()).isEqualTo(1); assertThat(mappings).hasSize(1);
Map.Entry<HttpRequestHandler, List<String>> entry = mappings.entrySet().iterator().next(); Map.Entry<HttpRequestHandler, List<String>> entry = mappings.entrySet().iterator().next();
assertThat(((WebSocketHttpRequestHandler) entry.getKey()).getWebSocketHandler()).isNotNull(); assertThat(((WebSocketHttpRequestHandler) entry.getKey()).getWebSocketHandler()).isNotNull();
assertThat(((WebSocketHttpRequestHandler) entry.getKey()).getHandshakeInterceptors().size()).isEqualTo(1); assertThat(((WebSocketHttpRequestHandler) entry.getKey()).getHandshakeInterceptors()).hasSize(1);
assertThat(entry.getValue()).isEqualTo(List.of("/foo")); assertThat(entry.getValue()).containsExactly("/foo");
} }
@Test @Test
public void allowedOrigins() { void allowedOrigins() {
WebMvcStompWebSocketEndpointRegistration registration = WebMvcStompWebSocketEndpointRegistration registration =
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler); new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
registration.setAllowedOrigins(); registration.setAllowedOrigins();
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings(); MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
assertThat(mappings.size()).isEqualTo(1); assertThat(mappings).hasSize(1);
HttpRequestHandler handler = mappings.entrySet().iterator().next().getKey(); HttpRequestHandler handler = mappings.entrySet().iterator().next().getKey();
WebSocketHttpRequestHandler wsHandler = (WebSocketHttpRequestHandler) handler; WebSocketHttpRequestHandler wsHandler = (WebSocketHttpRequestHandler) handler;
assertThat(wsHandler.getWebSocketHandler()).isNotNull(); assertThat(wsHandler.getWebSocketHandler()).isNotNull();
assertThat(wsHandler.getHandshakeInterceptors().size()).isEqualTo(1); assertThat(wsHandler.getHandshakeInterceptors()).hasSize(1);
assertThat(wsHandler.getHandshakeInterceptors().get(0).getClass()).isEqualTo(OriginHandshakeInterceptor.class); assertThat(wsHandler.getHandshakeInterceptors().get(0).getClass()).isEqualTo(OriginHandshakeInterceptor.class);
} }
@Test @Test
public void sameOrigin() { void sameOrigin() {
WebMvcStompWebSocketEndpointRegistration registration = new WebMvcStompWebSocketEndpointRegistration( WebMvcStompWebSocketEndpointRegistration registration = new WebMvcStompWebSocketEndpointRegistration(
new String[] {"/foo"}, this.handler, this.scheduler); new String[] {"/foo"}, this.handler, this.scheduler);
registration.setAllowedOrigins(); registration.setAllowedOrigins();
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings(); MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
assertThat(mappings.size()).isEqualTo(1); assertThat(mappings).hasSize(1);
HttpRequestHandler handler = mappings.entrySet().iterator().next().getKey(); HttpRequestHandler handler = mappings.entrySet().iterator().next().getKey();
WebSocketHttpRequestHandler wsHandler = (WebSocketHttpRequestHandler) handler; WebSocketHttpRequestHandler wsHandler = (WebSocketHttpRequestHandler) handler;
assertThat(wsHandler.getWebSocketHandler()).isNotNull(); assertThat(wsHandler.getWebSocketHandler()).isNotNull();
assertThat(wsHandler.getHandshakeInterceptors().size()).isEqualTo(1); assertThat(wsHandler.getHandshakeInterceptors()).hasSize(1);
assertThat(wsHandler.getHandshakeInterceptors().get(0).getClass()).isEqualTo(OriginHandshakeInterceptor.class); assertThat(wsHandler.getHandshakeInterceptors().get(0).getClass()).isEqualTo(OriginHandshakeInterceptor.class);
} }
@Test @Test
public void allowedOriginsWithSockJsService() { void allowedOriginsWithSockJsService() {
WebMvcStompWebSocketEndpointRegistration registration = WebMvcStompWebSocketEndpointRegistration registration =
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler); new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
@ -115,7 +109,7 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
registration.setAllowedOrigins(origin).withSockJS(); registration.setAllowedOrigins(origin).withSockJS();
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings(); MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
assertThat(mappings.size()).isEqualTo(1); assertThat(mappings).hasSize(1);
SockJsHttpRequestHandler requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey(); SockJsHttpRequestHandler requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey();
assertThat(requestHandler.getSockJsService()).isNotNull(); assertThat(requestHandler.getSockJsService()).isNotNull();
DefaultSockJsService sockJsService = (DefaultSockJsService)requestHandler.getSockJsService(); DefaultSockJsService sockJsService = (DefaultSockJsService)requestHandler.getSockJsService();
@ -126,7 +120,7 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler); new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
registration.withSockJS().setAllowedOrigins(origin); registration.withSockJS().setAllowedOrigins(origin);
mappings = registration.getMappings(); mappings = registration.getMappings();
assertThat(mappings.size()).isEqualTo(1); assertThat(mappings).hasSize(1);
requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey(); requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey();
assertThat(requestHandler.getSockJsService()).isNotNull(); assertThat(requestHandler.getSockJsService()).isNotNull();
sockJsService = (DefaultSockJsService)requestHandler.getSockJsService(); sockJsService = (DefaultSockJsService)requestHandler.getSockJsService();
@ -135,7 +129,7 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
} }
@Test @Test
public void allowedOriginPatterns() { void allowedOriginPatterns() {
WebMvcStompWebSocketEndpointRegistration registration = WebMvcStompWebSocketEndpointRegistration registration =
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler); new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
@ -143,7 +137,7 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
registration.setAllowedOriginPatterns(origin).withSockJS(); registration.setAllowedOriginPatterns(origin).withSockJS();
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings(); MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
assertThat(mappings.size()).isEqualTo(1); assertThat(mappings).hasSize(1);
SockJsHttpRequestHandler requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey(); SockJsHttpRequestHandler requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey();
assertThat(requestHandler.getSockJsService()).isNotNull(); assertThat(requestHandler.getSockJsService()).isNotNull();
DefaultSockJsService sockJsService = (DefaultSockJsService)requestHandler.getSockJsService(); DefaultSockJsService sockJsService = (DefaultSockJsService)requestHandler.getSockJsService();
@ -153,7 +147,7 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler); new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
registration.withSockJS().setAllowedOriginPatterns(origin); registration.withSockJS().setAllowedOriginPatterns(origin);
mappings = registration.getMappings(); mappings = registration.getMappings();
assertThat(mappings.size()).isEqualTo(1); assertThat(mappings).hasSize(1);
requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey(); requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey();
assertThat(requestHandler.getSockJsService()).isNotNull(); assertThat(requestHandler.getSockJsService()).isNotNull();
sockJsService = (DefaultSockJsService)requestHandler.getSockJsService(); sockJsService = (DefaultSockJsService)requestHandler.getSockJsService();
@ -161,14 +155,14 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
} }
@Test // SPR-12283 @Test // SPR-12283
public void disableCorsWithSockJsService() { void disableCorsWithSockJsService() {
WebMvcStompWebSocketEndpointRegistration registration = WebMvcStompWebSocketEndpointRegistration registration =
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler); new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
registration.withSockJS().setSuppressCors(true); registration.withSockJS().setSuppressCors(true);
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings(); MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
assertThat(mappings.size()).isEqualTo(1); assertThat(mappings).hasSize(1);
SockJsHttpRequestHandler requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey(); SockJsHttpRequestHandler requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey();
assertThat(requestHandler.getSockJsService()).isNotNull(); assertThat(requestHandler.getSockJsService()).isNotNull();
DefaultSockJsService sockJsService = (DefaultSockJsService)requestHandler.getSockJsService(); DefaultSockJsService sockJsService = (DefaultSockJsService)requestHandler.getSockJsService();
@ -176,7 +170,7 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
} }
@Test @Test
public void handshakeHandlerAndInterceptor() { void handshakeHandlerAndInterceptor() {
WebMvcStompWebSocketEndpointRegistration registration = WebMvcStompWebSocketEndpointRegistration registration =
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler); new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
@ -186,21 +180,21 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
registration.setHandshakeHandler(handshakeHandler).addInterceptors(interceptor); registration.setHandshakeHandler(handshakeHandler).addInterceptors(interceptor);
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings(); MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
assertThat(mappings.size()).isEqualTo(1); assertThat(mappings).hasSize(1);
Map.Entry<HttpRequestHandler, List<String>> entry = mappings.entrySet().iterator().next(); Map.Entry<HttpRequestHandler, List<String>> entry = mappings.entrySet().iterator().next();
assertThat(entry.getValue()).isEqualTo(List.of("/foo")); assertThat(entry.getValue()).containsExactly("/foo");
WebSocketHttpRequestHandler requestHandler = (WebSocketHttpRequestHandler) entry.getKey(); WebSocketHttpRequestHandler requestHandler = (WebSocketHttpRequestHandler) entry.getKey();
assertThat(requestHandler.getWebSocketHandler()).isNotNull(); assertThat(requestHandler.getWebSocketHandler()).isNotNull();
assertThat(requestHandler.getHandshakeHandler()).isSameAs(handshakeHandler); assertThat(requestHandler.getHandshakeHandler()).isSameAs(handshakeHandler);
assertThat(requestHandler.getHandshakeInterceptors().size()).isEqualTo(2); assertThat(requestHandler.getHandshakeInterceptors()).hasSize(2);
assertThat(requestHandler.getHandshakeInterceptors().get(0)).isEqualTo(interceptor); assertThat(requestHandler.getHandshakeInterceptors().get(0)).isEqualTo(interceptor);
assertThat(requestHandler.getHandshakeInterceptors().get(1).getClass()).isEqualTo(OriginHandshakeInterceptor.class); assertThat(requestHandler.getHandshakeInterceptors().get(1).getClass()).isEqualTo(OriginHandshakeInterceptor.class);
} }
@Test @Test
public void handshakeHandlerAndInterceptorWithAllowedOrigins() { void handshakeHandlerAndInterceptorWithAllowedOrigins() {
WebMvcStompWebSocketEndpointRegistration registration = WebMvcStompWebSocketEndpointRegistration registration =
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler); new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
@ -210,21 +204,21 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
registration.setHandshakeHandler(handshakeHandler).addInterceptors(interceptor).setAllowedOrigins(origin); registration.setHandshakeHandler(handshakeHandler).addInterceptors(interceptor).setAllowedOrigins(origin);
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings(); MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
assertThat(mappings.size()).isEqualTo(1); assertThat(mappings).hasSize(1);
Map.Entry<HttpRequestHandler, List<String>> entry = mappings.entrySet().iterator().next(); Map.Entry<HttpRequestHandler, List<String>> entry = mappings.entrySet().iterator().next();
assertThat(entry.getValue()).isEqualTo(List.of("/foo")); assertThat(entry.getValue()).containsExactly("/foo");
WebSocketHttpRequestHandler requestHandler = (WebSocketHttpRequestHandler) entry.getKey(); WebSocketHttpRequestHandler requestHandler = (WebSocketHttpRequestHandler) entry.getKey();
assertThat(requestHandler.getWebSocketHandler()).isNotNull(); assertThat(requestHandler.getWebSocketHandler()).isNotNull();
assertThat(requestHandler.getHandshakeHandler()).isSameAs(handshakeHandler); assertThat(requestHandler.getHandshakeHandler()).isSameAs(handshakeHandler);
assertThat(requestHandler.getHandshakeInterceptors().size()).isEqualTo(2); assertThat(requestHandler.getHandshakeInterceptors()).hasSize(2);
assertThat(requestHandler.getHandshakeInterceptors().get(0)).isEqualTo(interceptor); assertThat(requestHandler.getHandshakeInterceptors().get(0)).isEqualTo(interceptor);
assertThat(requestHandler.getHandshakeInterceptors().get(1).getClass()).isEqualTo(OriginHandshakeInterceptor.class); assertThat(requestHandler.getHandshakeInterceptors().get(1).getClass()).isEqualTo(OriginHandshakeInterceptor.class);
} }
@Test @Test
public void handshakeHandlerInterceptorWithSockJsService() { void handshakeHandlerInterceptorWithSockJsService() {
WebMvcStompWebSocketEndpointRegistration registration = WebMvcStompWebSocketEndpointRegistration registration =
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler); new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
@ -234,10 +228,10 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
registration.setHandshakeHandler(handshakeHandler).addInterceptors(interceptor).withSockJS(); registration.setHandshakeHandler(handshakeHandler).addInterceptors(interceptor).withSockJS();
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings(); MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
assertThat(mappings.size()).isEqualTo(1); assertThat(mappings).hasSize(1);
Map.Entry<HttpRequestHandler, List<String>> entry = mappings.entrySet().iterator().next(); Map.Entry<HttpRequestHandler, List<String>> entry = mappings.entrySet().iterator().next();
assertThat(entry.getValue()).isEqualTo(List.of("/foo/**")); assertThat(entry.getValue()).containsExactly("/foo/**");
SockJsHttpRequestHandler requestHandler = (SockJsHttpRequestHandler) entry.getKey(); SockJsHttpRequestHandler requestHandler = (SockJsHttpRequestHandler) entry.getKey();
assertThat(requestHandler.getWebSocketHandler()).isNotNull(); assertThat(requestHandler.getWebSocketHandler()).isNotNull();
@ -248,13 +242,13 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
Map<TransportType, TransportHandler> handlers = sockJsService.getTransportHandlers(); Map<TransportType, TransportHandler> handlers = sockJsService.getTransportHandlers();
WebSocketTransportHandler transportHandler = (WebSocketTransportHandler) handlers.get(TransportType.WEBSOCKET); WebSocketTransportHandler transportHandler = (WebSocketTransportHandler) handlers.get(TransportType.WEBSOCKET);
assertThat(transportHandler.getHandshakeHandler()).isSameAs(handshakeHandler); assertThat(transportHandler.getHandshakeHandler()).isSameAs(handshakeHandler);
assertThat(sockJsService.getHandshakeInterceptors().size()).isEqualTo(2); assertThat(sockJsService.getHandshakeInterceptors()).hasSize(2);
assertThat(sockJsService.getHandshakeInterceptors().get(0)).isEqualTo(interceptor); assertThat(sockJsService.getHandshakeInterceptors().get(0)).isEqualTo(interceptor);
assertThat(sockJsService.getHandshakeInterceptors().get(1).getClass()).isEqualTo(OriginHandshakeInterceptor.class); assertThat(sockJsService.getHandshakeInterceptors().get(1).getClass()).isEqualTo(OriginHandshakeInterceptor.class);
} }
@Test @Test
public void handshakeHandlerInterceptorWithSockJsServiceAndAllowedOrigins() { void handshakeHandlerInterceptorWithSockJsServiceAndAllowedOrigins() {
WebMvcStompWebSocketEndpointRegistration registration = WebMvcStompWebSocketEndpointRegistration registration =
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler); new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
@ -266,10 +260,10 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
.addInterceptors(interceptor).setAllowedOrigins(origin).withSockJS(); .addInterceptors(interceptor).setAllowedOrigins(origin).withSockJS();
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings(); MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
assertThat(mappings.size()).isEqualTo(1); assertThat(mappings).hasSize(1);
Map.Entry<HttpRequestHandler, List<String>> entry = mappings.entrySet().iterator().next(); Map.Entry<HttpRequestHandler, List<String>> entry = mappings.entrySet().iterator().next();
assertThat(entry.getValue()).isEqualTo(List.of("/foo/**")); assertThat(entry.getValue()).containsExactly("/foo/**");
SockJsHttpRequestHandler requestHandler = (SockJsHttpRequestHandler) entry.getKey(); SockJsHttpRequestHandler requestHandler = (SockJsHttpRequestHandler) entry.getKey();
assertThat(requestHandler.getWebSocketHandler()).isNotNull(); assertThat(requestHandler.getWebSocketHandler()).isNotNull();
@ -280,7 +274,7 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
Map<TransportType, TransportHandler> handlers = sockJsService.getTransportHandlers(); Map<TransportType, TransportHandler> handlers = sockJsService.getTransportHandlers();
WebSocketTransportHandler transportHandler = (WebSocketTransportHandler) handlers.get(TransportType.WEBSOCKET); WebSocketTransportHandler transportHandler = (WebSocketTransportHandler) handlers.get(TransportType.WEBSOCKET);
assertThat(transportHandler.getHandshakeHandler()).isSameAs(handshakeHandler); assertThat(transportHandler.getHandshakeHandler()).isSameAs(handshakeHandler);
assertThat(sockJsService.getHandshakeInterceptors().size()).isEqualTo(2); assertThat(sockJsService.getHandshakeInterceptors()).hasSize(2);
assertThat(sockJsService.getHandshakeInterceptors().get(0)).isEqualTo(interceptor); assertThat(sockJsService.getHandshakeInterceptors().get(0)).isEqualTo(interceptor);
assertThat(sockJsService.getHandshakeInterceptors().get(1).getClass()).isEqualTo(OriginHandshakeInterceptor.class); assertThat(sockJsService.getHandshakeInterceptors().get(1).getClass()).isEqualTo(OriginHandshakeInterceptor.class);
assertThat(sockJsService.getAllowedOrigins().contains(origin)).isTrue(); assertThat(sockJsService.getAllowedOrigins().contains(origin)).isTrue();

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -88,7 +88,7 @@ public class StompSubProtocolHandlerTests {
@BeforeEach @BeforeEach
public void setup() { void setup() {
this.protocolHandler = new StompSubProtocolHandler(); this.protocolHandler = new StompSubProtocolHandler();
this.channel = Mockito.mock(MessageChannel.class); this.channel = Mockito.mock(MessageChannel.class);
this.messageCaptor = ArgumentCaptor.forClass(Message.class); this.messageCaptor = ArgumentCaptor.forClass(Message.class);
@ -101,24 +101,22 @@ public class StompSubProtocolHandlerTests {
} }
@Test @Test
public void handleMessageToClientWithConnectedFrame() { void handleMessageToClientWithConnectedFrame() {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECTED); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECTED);
Message<byte[]> message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders()); Message<byte[]> message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
this.protocolHandler.handleMessageToClient(this.session, message); this.protocolHandler.handleMessageToClient(this.session, message);
assertThat(this.session.getSentMessages().size()).isEqualTo(1); assertThat(this.session.getSentMessages().size()).isEqualTo(1);
WebSocketMessage<?> textMessage = this.session.getSentMessages().get(0); WebSocketMessage<?> textMessage = this.session.getSentMessages().get(0);
assertThat(textMessage.getPayload()).isEqualTo((""" assertThat(textMessage.getPayload()).isEqualTo("""
CONNECTED CONNECTED
user-name:joe user-name:joe
\u0000""")); \u0000""");
} }
@Test @Test
public void handleMessageToClientWithDestinationUserNameProvider() { void handleMessageToClientWithDestinationUserNameProvider() {
this.session.setPrincipal(new UniqueUser("joe")); this.session.setPrincipal(new UniqueUser("joe"));
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECTED); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECTED);
@ -127,16 +125,15 @@ public class StompSubProtocolHandlerTests {
assertThat(this.session.getSentMessages().size()).isEqualTo(1); assertThat(this.session.getSentMessages().size()).isEqualTo(1);
WebSocketMessage<?> textMessage = this.session.getSentMessages().get(0); WebSocketMessage<?> textMessage = this.session.getSentMessages().get(0);
assertThat(textMessage.getPayload()).isEqualTo((""" assertThat(textMessage.getPayload()).isEqualTo("""
CONNECTED CONNECTED
user-name:joe user-name:joe
\u0000""")); \u0000""");
} }
@Test @Test
public void handleMessageToClientWithSimpConnectAck() { void handleMessageToClientWithSimpConnectAck() {
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECT); StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECT);
accessor.setHeartbeat(10000, 10000); accessor.setHeartbeat(10000, 10000);
accessor.setAcceptVersion("1.0,1.1,1.2"); accessor.setAcceptVersion("1.0,1.1,1.2");
@ -150,18 +147,17 @@ public class StompSubProtocolHandlerTests {
assertThat(this.session.getSentMessages().size()).isEqualTo(1); assertThat(this.session.getSentMessages().size()).isEqualTo(1);
TextMessage actual = (TextMessage) this.session.getSentMessages().get(0); TextMessage actual = (TextMessage) this.session.getSentMessages().get(0);
assertThat(actual.getPayload()).isEqualTo((""" assertThat(actual.getPayload()).isEqualTo("""
CONNECTED CONNECTED
version:1.2 version:1.2
heart-beat:15000,15000 heart-beat:15000,15000
user-name:joe user-name:joe
\u0000""")); \u0000""");
} }
@Test @Test
public void handleMessageToClientWithSimpConnectAckDefaultHeartBeat() { void handleMessageToClientWithSimpConnectAckDefaultHeartBeat() {
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECT); StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECT);
accessor.setHeartbeat(10000, 10000); accessor.setHeartbeat(10000, 10000);
accessor.setAcceptVersion("1.0"); accessor.setAcceptVersion("1.0");
@ -174,18 +170,17 @@ public class StompSubProtocolHandlerTests {
assertThat(this.session.getSentMessages().size()).isEqualTo(1); assertThat(this.session.getSentMessages().size()).isEqualTo(1);
TextMessage actual = (TextMessage) this.session.getSentMessages().get(0); TextMessage actual = (TextMessage) this.session.getSentMessages().get(0);
assertThat(actual.getPayload()).isEqualTo((""" assertThat(actual.getPayload()).isEqualTo("""
CONNECTED CONNECTED
version:1.0 version:1.0
heart-beat:0,0 heart-beat:0,0
user-name:joe user-name:joe
\u0000""")); \u0000""");
} }
@Test @Test
public void handleMessageToClientWithSimpDisconnectAck() { void handleMessageToClientWithSimpDisconnectAck() {
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.DISCONNECT); StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.DISCONNECT);
Message<?> connectMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders()); Message<?> connectMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders());
@ -196,17 +191,16 @@ public class StompSubProtocolHandlerTests {
assertThat(this.session.getSentMessages().size()).isEqualTo(1); assertThat(this.session.getSentMessages().size()).isEqualTo(1);
TextMessage actual = (TextMessage) this.session.getSentMessages().get(0); TextMessage actual = (TextMessage) this.session.getSentMessages().get(0);
assertThat(actual.getPayload()).isEqualTo((""" assertThat(actual.getPayload()).isEqualTo("""
ERROR ERROR
message:Session closed. message:Session closed.
content-length:0 content-length:0
\u0000""")); \u0000""");
} }
@Test @Test
public void handleMessageToClientWithSimpDisconnectAckAndReceipt() { void handleMessageToClientWithSimpDisconnectAckAndReceipt() {
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.DISCONNECT); StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.DISCONNECT);
accessor.setReceipt("message-123"); accessor.setReceipt("message-123");
Message<?> connectMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders()); Message<?> connectMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders());
@ -218,16 +212,15 @@ public class StompSubProtocolHandlerTests {
assertThat(this.session.getSentMessages().size()).isEqualTo(1); assertThat(this.session.getSentMessages().size()).isEqualTo(1);
TextMessage actual = (TextMessage) this.session.getSentMessages().get(0); TextMessage actual = (TextMessage) this.session.getSentMessages().get(0);
assertThat(actual.getPayload()).isEqualTo((""" assertThat(actual.getPayload()).isEqualTo("""
RECEIPT RECEIPT
receipt-id:message-123 receipt-id:message-123
\u0000""")); \u0000""");
} }
@Test @Test
public void handleMessageToClientWithSimpHeartbeat() { void handleMessageToClientWithSimpHeartbeat() {
SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.HEARTBEAT); SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.HEARTBEAT);
accessor.setSessionId("s1"); accessor.setSessionId("s1");
accessor.setUser(new TestPrincipal("joe")); accessor.setUser(new TestPrincipal("joe"));
@ -240,8 +233,7 @@ public class StompSubProtocolHandlerTests {
} }
@Test @Test
public void handleMessageToClientWithHeartbeatSuppressingSockJsHeartbeat() throws IOException { void handleMessageToClientWithHeartbeatSuppressingSockJsHeartbeat() throws IOException {
SockJsSession sockJsSession = Mockito.mock(SockJsSession.class); SockJsSession sockJsSession = Mockito.mock(SockJsSession.class);
given(sockJsSession.getId()).willReturn("s1"); given(sockJsSession.getId()).willReturn("s1");
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECTED); StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECTED);
@ -269,8 +261,7 @@ public class StompSubProtocolHandlerTests {
} }
@Test @Test
public void handleMessageToClientWithUserDestination() { void handleMessageToClientWithUserDestination() {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.MESSAGE); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.MESSAGE);
headers.setMessageId("mess0"); headers.setMessageId("mess0");
headers.setSubscriptionId("sub0"); headers.setSubscriptionId("sub0");
@ -288,8 +279,7 @@ public class StompSubProtocolHandlerTests {
// SPR-12475 // SPR-12475
@Test @Test
public void handleMessageToClientWithBinaryWebSocketMessage() { void handleMessageToClientWithBinaryWebSocketMessage() {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.MESSAGE); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.MESSAGE);
headers.setMessageId("mess0"); headers.setMessageId("mess0");
headers.setSubscriptionId("sub0"); headers.setSubscriptionId("sub0");
@ -318,8 +308,7 @@ public class StompSubProtocolHandlerTests {
} }
@Test @Test
public void handleMessageFromClient() { void handleMessageFromClient() {
TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.STOMP).headers( TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.STOMP).headers(
"login:guest", "passcode:guest", "accept-version:1.1,1.0", "heart-beat:10000,10000").build(); "login:guest", "passcode:guest", "accept-version:1.1,1.0", "heart-beat:10000,10000").build();
@ -347,7 +336,7 @@ public class StompSubProtocolHandlerTests {
} }
@Test @Test
public void handleMessageFromClientWithImmutableMessageInterceptor() { void handleMessageFromClientWithImmutableMessageInterceptor() {
AtomicReference<Boolean> mutable = new AtomicReference<>(); AtomicReference<Boolean> mutable = new AtomicReference<>();
ExecutorSubscribableChannel channel = new ExecutorSubscribableChannel(); ExecutorSubscribableChannel channel = new ExecutorSubscribableChannel();
channel.addInterceptor(new ChannelInterceptor() { channel.addInterceptor(new ChannelInterceptor() {
@ -369,7 +358,7 @@ public class StompSubProtocolHandlerTests {
} }
@Test @Test
public void handleMessageFromClientWithoutImmutableMessageInterceptor() { void handleMessageFromClientWithoutImmutableMessageInterceptor() {
AtomicReference<Boolean> mutable = new AtomicReference<>(); AtomicReference<Boolean> mutable = new AtomicReference<>();
ExecutorSubscribableChannel channel = new ExecutorSubscribableChannel(); ExecutorSubscribableChannel channel = new ExecutorSubscribableChannel();
channel.addInterceptor(new ChannelInterceptor() { channel.addInterceptor(new ChannelInterceptor() {
@ -390,7 +379,7 @@ public class StompSubProtocolHandlerTests {
} }
@Test // SPR-14690 @Test // SPR-14690
public void handleMessageFromClientWithTokenAuthentication() { void handleMessageFromClientWithTokenAuthentication() {
ExecutorSubscribableChannel channel = new ExecutorSubscribableChannel(); ExecutorSubscribableChannel channel = new ExecutorSubscribableChannel();
channel.addInterceptor(new AuthenticationInterceptor("__pete__@gmail.com")); channel.addInterceptor(new AuthenticationInterceptor("__pete__@gmail.com"));
channel.addInterceptor(new ImmutableMessageChannelInterceptor()); channel.addInterceptor(new ImmutableMessageChannelInterceptor());
@ -416,17 +405,15 @@ public class StompSubProtocolHandlerTests {
assertThat(this.session.getSentMessages()).hasSize(1); assertThat(this.session.getSentMessages()).hasSize(1);
WebSocketMessage<?> textMessage = this.session.getSentMessages().get(0); WebSocketMessage<?> textMessage = this.session.getSentMessages().get(0);
assertThat(textMessage.getPayload()) assertThat(textMessage.getPayload()).isEqualTo("""
.isEqualTo(""" CONNECTED
CONNECTED user-name:__pete__@gmail.com
user-name:__pete__@gmail.com
\u0000"""); \u0000""");
} }
@Test @Test
public void handleMessageFromClientWithInvalidStompCommand() { void handleMessageFromClientWithInvalidStompCommand() {
TextMessage textMessage = new TextMessage("FOO\n\n\0"); TextMessage textMessage = new TextMessage("FOO\n\n\0");
this.protocolHandler.afterSessionStarted(this.session, this.channel); this.protocolHandler.afterSessionStarted(this.session, this.channel);
@ -439,8 +426,7 @@ public class StompSubProtocolHandlerTests {
} }
@Test @Test
public void eventPublication() { void eventPublication() {
TestPublisher publisher = new TestPublisher(); TestPublisher publisher = new TestPublisher();
this.protocolHandler.setApplicationEventPublisher(publisher); this.protocolHandler.setApplicationEventPublisher(publisher);
@ -476,8 +462,7 @@ public class StompSubProtocolHandlerTests {
} }
@Test @Test
public void eventPublicationWithExceptions() { void eventPublicationWithExceptions() {
ApplicationEventPublisher publisher = mock(ApplicationEventPublisher.class); ApplicationEventPublisher publisher = mock(ApplicationEventPublisher.class);
this.protocolHandler.setApplicationEventPublisher(publisher); this.protocolHandler.setApplicationEventPublisher(publisher);
@ -500,11 +485,11 @@ public class StompSubProtocolHandlerTests {
assertThat(this.session.getSentMessages().size()).isEqualTo(1); assertThat(this.session.getSentMessages().size()).isEqualTo(1);
textMessage = (TextMessage) this.session.getSentMessages().get(0); textMessage = (TextMessage) this.session.getSentMessages().get(0);
assertThat(textMessage.getPayload()).isEqualTo((""" assertThat(textMessage.getPayload()).isEqualTo("""
CONNECTED CONNECTED
user-name:joe user-name:joe
\u0000""")); \u0000""");
this.protocolHandler.afterSessionEnded(this.session, CloseStatus.BAD_DATA, this.channel); this.protocolHandler.afterSessionEnded(this.session, CloseStatus.BAD_DATA, this.channel);
@ -518,8 +503,7 @@ public class StompSubProtocolHandlerTests {
} }
@Test @Test
public void webSocketScope() { void webSocketScope() {
Runnable runnable = Mockito.mock(Runnable.class); Runnable runnable = Mockito.mock(Runnable.class);
SimpAttributes simpAttributes = new SimpAttributes(this.session.getId(), this.session.getAttributes()); SimpAttributes simpAttributes = new SimpAttributes(this.session.getId(), this.session.getAttributes());
simpAttributes.setAttribute("name", "value"); simpAttributes.setAttribute("name", "value");
@ -610,4 +594,5 @@ public class StompSubProtocolHandlerTests {
return message; return message;
} }
} }
} }

View File

@ -20,7 +20,6 @@ import java.io.ByteArrayInputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.net.URI; import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Arrays; import java.util.Arrays;
import java.util.Queue; import java.util.Queue;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
@ -54,6 +53,7 @@ import org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec;
import org.springframework.web.socket.sockjs.frame.SockJsFrame; import org.springframework.web.socket.sockjs.frame.SockJsFrame;
import org.springframework.web.socket.sockjs.transport.TransportType; import org.springframework.web.socket.sockjs.transport.TransportType;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
@ -66,7 +66,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions;
* *
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
*/ */
public class RestTemplateXhrTransportTests { class RestTemplateXhrTransportTests {
private static final Jackson2SockJsMessageCodec CODEC = new Jackson2SockJsMessageCodec(); private static final Jackson2SockJsMessageCodec CODEC = new Jackson2SockJsMessageCodec();
@ -74,7 +74,7 @@ public class RestTemplateXhrTransportTests {
@Test @Test
public void connectReceiveAndClose() throws Exception { void connectReceiveAndClose() throws Exception {
String body = """ String body = """
o o
a["foo"] a["foo"]
@ -89,12 +89,13 @@ public class RestTemplateXhrTransportTests {
} }
@Test @Test
public void connectReceiveAndCloseWithPrelude() throws Exception { void connectReceiveAndCloseWithPrelude() throws Exception {
StringBuilder sb = new StringBuilder(2048); String prelude = "h".repeat(2048);
for (int i = 0; i < 2048; i++) { String body = """
sb.append('h'); %s
} o
String body = sb + "\n" + "o\n" + "a[\"foo\"]\n" + "c[3000,\"Go away!\"]"; a["foo"]
c[3000,"Go away!"]""".formatted(prelude);
ClientHttpResponse response = response(HttpStatus.OK, body); ClientHttpResponse response = response(HttpStatus.OK, body);
connect(response); connect(response);
@ -105,16 +106,19 @@ public class RestTemplateXhrTransportTests {
} }
@Test @Test
public void connectReceiveAndCloseWithStompFrame() throws Exception { void connectReceiveAndCloseWithStompFrame() throws Exception {
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.SEND); StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.SEND);
accessor.setDestination("/destination"); accessor.setDestination("/destination");
MessageHeaders headers = accessor.getMessageHeaders(); MessageHeaders headers = accessor.getMessageHeaders();
Message<byte[]> message = MessageBuilder.createMessage("body".getBytes(StandardCharsets.UTF_8), headers); Message<byte[]> message = MessageBuilder.createMessage("body".getBytes(UTF_8), headers);
byte[] bytes = new StompEncoder().encode(message); byte[] bytes = new StompEncoder().encode(message);
TextMessage textMessage = new TextMessage(bytes); TextMessage textMessage = new TextMessage(bytes);
SockJsFrame frame = SockJsFrame.messageFrame(new Jackson2SockJsMessageCodec(), textMessage.getPayload()); SockJsFrame frame = SockJsFrame.messageFrame(new Jackson2SockJsMessageCodec(), textMessage.getPayload());
String body = "o\n" + frame.getContent() + "\n" + "c[3000,\"Go away!\"]"; String body = """
o
%s
c[3000,"Go away!"]""".formatted(frame.getContent());
ClientHttpResponse response = response(HttpStatus.OK, body); ClientHttpResponse response = response(HttpStatus.OK, body);
connect(response); connect(response);
@ -126,7 +130,7 @@ public class RestTemplateXhrTransportTests {
@Test @Test
@SuppressWarnings("deprecation") @SuppressWarnings("deprecation")
public void connectFailure() throws Exception { void connectFailure() throws Exception {
final HttpServerErrorException expected = new HttpServerErrorException(HttpStatus.INTERNAL_SERVER_ERROR); final HttpServerErrorException expected = new HttpServerErrorException(HttpStatus.INTERNAL_SERVER_ERROR);
RestOperations restTemplate = mock(RestOperations.class); RestOperations restTemplate = mock(RestOperations.class);
given(restTemplate.execute((URI) any(), eq(HttpMethod.POST), any(), any())).willThrow(expected); given(restTemplate.execute((URI) any(), eq(HttpMethod.POST), any(), any())).willThrow(expected);
@ -149,7 +153,7 @@ public class RestTemplateXhrTransportTests {
} }
@Test @Test
public void errorResponseStatus() throws Exception { void errorResponseStatus() throws Exception {
connect(response(HttpStatus.OK, "o\n"), response(HttpStatus.INTERNAL_SERVER_ERROR, "Oops")); connect(response(HttpStatus.OK, "o\n"), response(HttpStatus.INTERNAL_SERVER_ERROR, "Oops"));
verify(this.webSocketHandler).afterConnectionEstablished(any()); verify(this.webSocketHandler).afterConnectionEstablished(any());
@ -159,7 +163,7 @@ public class RestTemplateXhrTransportTests {
} }
@Test @Test
public void responseClosedAfterDisconnected() throws Exception { void responseClosedAfterDisconnected() throws Exception {
String body = """ String body = """
o o
c[3000,"Go away!"] c[3000,"Go away!"]
@ -205,7 +209,7 @@ public class RestTemplateXhrTransportTests {
} }
private InputStream getInputStream(String content) { private InputStream getInputStream(String content) {
byte[] bytes = content.getBytes(StandardCharsets.UTF_8); byte[] bytes = content.getBytes(UTF_8);
return new ByteArrayInputStream(bytes); return new ByteArrayInputStream(bytes);
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -22,7 +22,6 @@ import java.util.Collections;
import jakarta.servlet.ServletOutputStream; import jakarta.servlet.ServletOutputStream;
import jakarta.servlet.http.HttpServletResponse; import jakarta.servlet.http.HttpServletResponse;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
@ -49,23 +48,13 @@ import static org.mockito.Mockito.verify;
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
* @author Sebastien Deleuze * @author Sebastien Deleuze
*/ */
public class SockJsServiceTests extends AbstractHttpRequestTests { class SockJsServiceTests extends AbstractHttpRequestTests {
private TestSockJsService service; private final TestSockJsService service = new TestSockJsService(new ThreadPoolTaskScheduler());
private WebSocketHandler handler;
@Override
@BeforeEach
public void setup() {
super.setup();
this.service = new TestSockJsService(new ThreadPoolTaskScheduler());
}
@Test @Test
public void validateRequest() { void validateRequest() {
this.service.setWebSocketEnabled(false); this.service.setWebSocketEnabled(false);
resetResponseAndHandleRequest("GET", "/echo/server/session/websocket", HttpStatus.NOT_FOUND); resetResponseAndHandleRequest("GET", "/echo/server/session/websocket", HttpStatus.NOT_FOUND);
@ -84,7 +73,7 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
} }
@Test @Test
public void handleInfoGet() throws IOException { void handleInfoGet() throws IOException {
resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK); resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK);
assertThat(this.servletResponse.getContentType()).isEqualTo("application/json;charset=UTF-8"); assertThat(this.servletResponse.getContentType()).isEqualTo("application/json;charset=UTF-8");
@ -113,7 +102,7 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
} }
@Test // SPR-12226 and SPR-12660 @Test // SPR-12226 and SPR-12660
public void handleInfoGetWithOrigin() throws IOException { void handleInfoGetWithOrigin() throws IOException {
this.servletRequest.setServerName("mydomain2.example"); this.servletRequest.setServerName("mydomain2.example");
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain2.example"); this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain2.example");
resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK); resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK);
@ -140,7 +129,7 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
} }
@Test // SPR-11443 @Test // SPR-11443
public void handleInfoGetCorsFilter() { void handleInfoGetCorsFilter() {
// Simulate scenario where Filter would have already set CORS headers // Simulate scenario where Filter would have already set CORS headers
this.servletResponse.setHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN, "foobar:123"); this.servletResponse.setHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN, "foobar:123");
@ -150,7 +139,7 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
} }
@Test // SPR-11919 @Test // SPR-11919
public void handleInfoGetWildflyNPE() throws IOException { void handleInfoGetWildflyNPE() throws IOException {
HttpServletResponse mockResponse = mock(HttpServletResponse.class); HttpServletResponse mockResponse = mock(HttpServletResponse.class);
ServletOutputStream ous = mock(ServletOutputStream.class); ServletOutputStream ous = mock(ServletOutputStream.class);
given(mockResponse.getHeaders(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)).willThrow(NullPointerException.class); given(mockResponse.getHeaders(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)).willThrow(NullPointerException.class);
@ -163,7 +152,7 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
} }
@Test // SPR-12660 @Test // SPR-12660
public void handleInfoOptions() { void handleInfoOptions() {
this.servletRequest.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "Last-Modified"); this.servletRequest.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "Last-Modified");
resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT); resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT);
assertThat(this.service.getCorsConfiguration(this.servletRequest)).isNull(); assertThat(this.service.getCorsConfiguration(this.servletRequest)).isNull();
@ -174,7 +163,7 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
} }
@Test // SPR-12226 and SPR-12660 @Test // SPR-12226 and SPR-12660
public void handleInfoOptionsWithAllowedOrigin() { void handleInfoOptionsWithAllowedOrigin() {
this.servletRequest.setServerName("mydomain2.example"); this.servletRequest.setServerName("mydomain2.example");
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain2.example"); this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain2.example");
this.servletRequest.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); this.servletRequest.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
@ -196,7 +185,7 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
} }
@Test // SPR-16304 @Test // SPR-16304
public void handleInfoOptionsWithForbiddenOrigin() { void handleInfoOptionsWithForbiddenOrigin() {
this.servletRequest.setServerName("mydomain3.com"); this.servletRequest.setServerName("mydomain3.com");
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "https://mydomain2.example"); this.servletRequest.addHeader(HttpHeaders.ORIGIN, "https://mydomain2.example");
this.servletRequest.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); this.servletRequest.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
@ -212,7 +201,7 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
} }
@Test // SPR-12283 @Test // SPR-12283
public void handleInfoOptionsWithOriginAndCorsHeadersDisabled() { void handleInfoOptionsWithOriginAndCorsHeadersDisabled() {
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "https://mydomain2.example"); this.servletRequest.addHeader(HttpHeaders.ORIGIN, "https://mydomain2.example");
this.service.setAllowedOriginPatterns(Collections.singletonList("*")); this.service.setAllowedOriginPatterns(Collections.singletonList("*"));
this.service.setSuppressCors(true); this.service.setSuppressCors(true);
@ -233,7 +222,7 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
} }
@Test @Test
public void handleIframeRequest() throws IOException { void handleIframeRequest() throws IOException {
resetResponseAndHandleRequest("GET", "/echo/iframe.html", HttpStatus.OK); resetResponseAndHandleRequest("GET", "/echo/iframe.html", HttpStatus.OK);
assertThat(this.servletResponse.getContentType()).isEqualTo("text/html;charset=UTF-8"); assertThat(this.servletResponse.getContentType()).isEqualTo("text/html;charset=UTF-8");
@ -244,23 +233,22 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
} }
@Test @Test
public void handleIframeRequestNotModified() { void handleIframeRequestNotModified() {
this.servletRequest.addHeader("If-None-Match", "\"096aaf2482e2a85effc0ab65a61993ae0\""); this.servletRequest.addHeader("If-None-Match", "\"096aaf2482e2a85effc0ab65a61993ae0\"");
resetResponseAndHandleRequest("GET", "/echo/iframe.html", HttpStatus.NOT_MODIFIED); resetResponseAndHandleRequest("GET", "/echo/iframe.html", HttpStatus.NOT_MODIFIED);
} }
@Test @Test
public void handleRawWebSocketRequest() throws IOException { void handleRawWebSocketRequest() throws IOException {
resetResponseAndHandleRequest("GET", "/echo", HttpStatus.OK); resetResponseAndHandleRequest("GET", "/echo", HttpStatus.OK);
assertThat(this.servletResponse.getContentAsString()).isEqualTo("Welcome to SockJS!\n"); assertThat(this.servletResponse.getContentAsString()).isEqualTo("Welcome to SockJS!\n");
resetResponseAndHandleRequest("GET", "/echo/websocket", HttpStatus.OK); resetResponseAndHandleRequest("GET", "/echo/websocket", HttpStatus.OK);
assertThat(this.service.sessionId).as("Raw WebSocket should not open a SockJS session").isNull(); assertThat(this.service.sessionId).as("Raw WebSocket should not open a SockJS session").isNull();
assertThat(this.service.handler).isSameAs(this.handler);
} }
@Test @Test
public void handleEmptyContentType() { void handleEmptyContentType() {
this.servletRequest.setContentType(""); this.servletRequest.setContentType("");
resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK); resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK);
@ -276,7 +264,7 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
private void handleRequest(String httpMethod, String uri, HttpStatus httpStatus) { private void handleRequest(String httpMethod, String uri, HttpStatus httpStatus) {
setRequest(httpMethod, uri); setRequest(httpMethod, uri);
String sockJsPath = uri.substring("/echo".length()); String sockJsPath = uri.substring("/echo".length());
this.service.handleRequest(this.request, this.response, sockJsPath, this.handler); this.service.handleRequest(this.request, this.response, sockJsPath, null);
assertThat(this.servletResponse.getStatus()).isEqualTo(httpStatus.value()); assertThat(this.servletResponse.getStatus()).isEqualTo(httpStatus.value());
} }
@ -286,27 +274,20 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
private String sessionId; private String sessionId;
@SuppressWarnings("unused")
private String transport;
private WebSocketHandler handler; TestSockJsService(TaskScheduler scheduler) {
public TestSockJsService(TaskScheduler scheduler) {
super(scheduler); super(scheduler);
} }
@Override @Override
protected void handleRawWebSocketRequest(ServerHttpRequest req, ServerHttpResponse res, protected void handleRawWebSocketRequest(ServerHttpRequest req, ServerHttpResponse res,
WebSocketHandler handler) throws IOException { WebSocketHandler handler) throws IOException {
this.handler = handler;
} }
@Override @Override
protected void handleTransportRequest(ServerHttpRequest req, ServerHttpResponse res, WebSocketHandler handler, protected void handleTransportRequest(ServerHttpRequest req, ServerHttpResponse res, WebSocketHandler handler,
String sessionId, String transport) throws SockJsException { String sessionId, String transport) throws SockJsException {
this.sessionId = sessionId; this.sessionId = sessionId;
this.transport = transport;
this.handler = handler;
} }
} }

View File

@ -16,8 +16,6 @@
package org.springframework.web.socket.sockjs.transport.handler; package org.springframework.web.socket.sockjs.transport.handler;
import java.nio.charset.StandardCharsets;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.web.socket.AbstractHttpRequestTests; import org.springframework.web.socket.AbstractHttpRequestTests;
@ -28,6 +26,7 @@ import org.springframework.web.socket.sockjs.transport.session.AbstractSockJsSes
import org.springframework.web.socket.sockjs.transport.session.StubSockJsServiceConfig; import org.springframework.web.socket.sockjs.transport.session.StubSockJsServiceConfig;
import org.springframework.web.socket.sockjs.transport.session.TestHttpSockJsSession; import org.springframework.web.socket.sockjs.transport.session.TestHttpSockJsSession;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
@ -37,41 +36,41 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
/** /**
* Test fixture for {@link AbstractHttpReceivingTransportHandler} and subclasses * Test fixture for {@link AbstractHttpReceivingTransportHandler} and
* {@link XhrReceivingTransportHandler}. * {@link XhrReceivingTransportHandler}.
* *
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
*/ */
public class HttpReceivingTransportHandlerTests extends AbstractHttpRequestTests { class HttpReceivingTransportHandlerTests extends AbstractHttpRequestTests {
@Test @Test
public void readMessagesXhr() throws Exception { void readMessagesXhr() throws Exception {
this.servletRequest.setContent("[\"x\"]".getBytes(StandardCharsets.UTF_8)); this.servletRequest.setContent("[\"x\"]".getBytes(UTF_8));
handleRequest(new XhrReceivingTransportHandler()); handleRequest(new XhrReceivingTransportHandler());
assertThat(this.servletResponse.getStatus()).isEqualTo(204); assertThat(this.servletResponse.getStatus()).isEqualTo(204);
} }
@Test @Test
public void readMessagesBadContent() throws Exception { void readMessagesBadContent() throws Exception {
this.servletRequest.setContent("".getBytes(StandardCharsets.UTF_8)); this.servletRequest.setContent("".getBytes(UTF_8));
handleRequestAndExpectFailure(); handleRequestAndExpectFailure();
this.servletRequest.setContent("[\"x]".getBytes(StandardCharsets.UTF_8)); this.servletRequest.setContent("[\"x]".getBytes(UTF_8));
handleRequestAndExpectFailure(); handleRequestAndExpectFailure();
} }
@Test @Test
public void readMessagesNoSession() throws Exception { void readMessagesNoSession() throws Exception {
WebSocketHandler webSocketHandler = mock(WebSocketHandler.class); WebSocketHandler webSocketHandler = mock(WebSocketHandler.class);
assertThatIllegalArgumentException().isThrownBy(() -> assertThatIllegalArgumentException().isThrownBy(() ->
new XhrReceivingTransportHandler().handleRequest(this.request, this.response, webSocketHandler, null)); new XhrReceivingTransportHandler().handleRequest(this.request, this.response, webSocketHandler, null));
} }
@Test @Test
public void delegateMessageException() throws Exception { void delegateMessageException() throws Exception {
StubSockJsServiceConfig sockJsConfig = new StubSockJsServiceConfig(); StubSockJsServiceConfig sockJsConfig = new StubSockJsServiceConfig();
this.servletRequest.setContent("[\"x\"]".getBytes(StandardCharsets.UTF_8)); this.servletRequest.setContent("[\"x\"]".getBytes(UTF_8));
WebSocketHandler wsHandler = mock(WebSocketHandler.class); WebSocketHandler wsHandler = mock(WebSocketHandler.class);
TestHttpSockJsSession session = new TestHttpSockJsSession("1", sockJsConfig, wsHandler, null); TestHttpSockJsSession session = new TestHttpSockJsSession("1", sockJsConfig, wsHandler, null);

View File

@ -40,7 +40,7 @@ import static org.mockito.Mockito.verify;
* *
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
*/ */
public class HttpSendingTransportHandlerTests extends AbstractHttpRequestTests { class HttpSendingTransportHandlerTests extends AbstractHttpRequestTests {
private WebSocketHandler webSocketHandler; private WebSocketHandler webSocketHandler;
@ -51,7 +51,7 @@ public class HttpSendingTransportHandlerTests extends AbstractHttpRequestTests
@Override @Override
@BeforeEach @BeforeEach
public void setup() { protected void setup() {
super.setup(); super.setup();
this.webSocketHandler = mock(WebSocketHandler.class); this.webSocketHandler = mock(WebSocketHandler.class);
@ -65,7 +65,7 @@ public class HttpSendingTransportHandlerTests extends AbstractHttpRequestTests
@Test @Test
public void handleRequestXhr() throws Exception { void handleRequestXhr() throws Exception {
XhrPollingTransportHandler transportHandler = new XhrPollingTransportHandler(); XhrPollingTransportHandler transportHandler = new XhrPollingTransportHandler();
transportHandler.initialize(this.sockJsConfig); transportHandler.initialize(this.sockJsConfig);
@ -91,7 +91,7 @@ public class HttpSendingTransportHandlerTests extends AbstractHttpRequestTests
} }
@Test @Test
public void handleRequestXhrStreaming() throws Exception { void handleRequestXhrStreaming() throws Exception {
XhrStreamingTransportHandler transportHandler = new XhrStreamingTransportHandler(); XhrStreamingTransportHandler transportHandler = new XhrStreamingTransportHandler();
transportHandler.initialize(this.sockJsConfig); transportHandler.initialize(this.sockJsConfig);
AbstractSockJsSession session = transportHandler.createSession("1", this.webSocketHandler, null); AbstractSockJsSession session = transportHandler.createSession("1", this.webSocketHandler, null);
@ -104,7 +104,7 @@ public class HttpSendingTransportHandlerTests extends AbstractHttpRequestTests
} }
@Test @Test
public void htmlFileTransport() throws Exception { void htmlFileTransport() throws Exception {
HtmlFileTransportHandler transportHandler = new HtmlFileTransportHandler(); HtmlFileTransportHandler transportHandler = new HtmlFileTransportHandler();
transportHandler.initialize(this.sockJsConfig); transportHandler.initialize(this.sockJsConfig);
StreamingSockJsSession session = transportHandler.createSession("1", this.webSocketHandler, null); StreamingSockJsSession session = transportHandler.createSession("1", this.webSocketHandler, null);
@ -126,7 +126,7 @@ public class HttpSendingTransportHandlerTests extends AbstractHttpRequestTests
} }
@Test @Test
public void eventSourceTransport() throws Exception { void eventSourceTransport() throws Exception {
EventSourceTransportHandler transportHandler = new EventSourceTransportHandler(); EventSourceTransportHandler transportHandler = new EventSourceTransportHandler();
transportHandler.initialize(this.sockJsConfig); transportHandler.initialize(this.sockJsConfig);
StreamingSockJsSession session = transportHandler.createSession("1", this.webSocketHandler, null); StreamingSockJsSession session = transportHandler.createSession("1", this.webSocketHandler, null);
@ -139,7 +139,7 @@ public class HttpSendingTransportHandlerTests extends AbstractHttpRequestTests
} }
@Test @Test
public void frameFormats() throws Exception { void frameFormats() throws Exception {
this.servletRequest.setQueryString("c=callback"); this.servletRequest.setQueryString("c=callback");
this.servletRequest.addParameter("c", "callback"); this.servletRequest.addParameter("c", "callback");

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -29,25 +29,20 @@ import static org.mockito.Mockito.mock;
* *
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
*/ */
public abstract class AbstractSockJsSessionTests<S extends AbstractSockJsSession> { abstract class AbstractSockJsSessionTests<S extends AbstractSockJsSession> {
protected WebSocketHandler webSocketHandler; protected WebSocketHandler webSocketHandler = mock(WebSocketHandler.class);
protected StubSockJsServiceConfig sockJsConfig; protected TaskScheduler taskScheduler = mock(TaskScheduler.class);
protected TaskScheduler taskScheduler; protected StubSockJsServiceConfig sockJsConfig = new StubSockJsServiceConfig();
protected S session; protected S session;
@BeforeEach @BeforeEach
public void setUp() { protected void setUp() {
this.webSocketHandler = mock(WebSocketHandler.class);
this.taskScheduler = mock(TaskScheduler.class);
this.sockJsConfig = new StubSockJsServiceConfig();
this.sockJsConfig.setTaskScheduler(this.taskScheduler); this.sockJsConfig.setTaskScheduler(this.taskScheduler);
this.session = initSockJsSession(); this.session = initSockJsSession();
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -44,17 +44,17 @@ import static org.mockito.Mockito.verifyNoMoreInteractions;
* *
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
*/ */
public class HttpSockJsSessionTests extends AbstractSockJsSessionTests<TestAbstractHttpSockJsSession> { class HttpSockJsSessionTests extends AbstractSockJsSessionTests<TestAbstractHttpSockJsSession> {
protected ServerHttpRequest request; protected MockHttpServletRequest servletRequest = new MockHttpServletRequest();
protected ServerHttpResponse response; protected MockHttpServletResponse servletResponse = new MockHttpServletResponse();
protected MockHttpServletRequest servletRequest; protected ServerHttpRequest request = new ServletServerHttpRequest(this.servletRequest);
protected MockHttpServletResponse servletResponse; protected ServerHttpResponse response = new ServletServerHttpResponse(this.servletResponse);
private SockJsFrameFormat frameFormat; private SockJsFrameFormat frameFormat = new DefaultSockJsFrameFormat("%s");
@Override @Override
@ -63,23 +63,14 @@ public class HttpSockJsSessionTests extends AbstractSockJsSessionTests<TestAbstr
} }
@BeforeEach @BeforeEach
public void setup() { @Override
protected void setUp() {
super.setUp(); super.setUp();
this.frameFormat = new DefaultSockJsFrameFormat("%s");
this.servletResponse = new MockHttpServletResponse();
this.response = new ServletServerHttpResponse(this.servletResponse);
this.servletRequest = new MockHttpServletRequest();
this.servletRequest.setAsyncSupported(true); this.servletRequest.setAsyncSupported(true);
this.request = new ServletServerHttpRequest(this.servletRequest);
} }
@Test @Test
public void handleInitialRequest() throws Exception { void handleInitialRequest() throws Exception {
this.session.handleInitialRequest(this.request, this.response, this.frameFormat); this.session.handleInitialRequest(this.request, this.response, this.frameFormat);
assertThat(this.servletResponse.getContentAsString()).isEqualTo("hhh\no"); assertThat(this.servletResponse.getContentAsString()).isEqualTo("hhh\no");
@ -89,8 +80,7 @@ public class HttpSockJsSessionTests extends AbstractSockJsSessionTests<TestAbstr
} }
@Test @Test
public void handleSuccessiveRequest() throws Exception { void handleSuccessiveRequest() throws Exception {
this.session.getMessageCache().add("x"); this.session.getMessageCache().add("x");
this.session.handleSuccessiveRequest(this.request, this.response, this.frameFormat); this.session.handleSuccessiveRequest(this.request, this.response, this.frameFormat);
@ -112,7 +102,7 @@ public class HttpSockJsSessionTests extends AbstractSockJsSessionTests<TestAbstr
private boolean heartbeatScheduled; private boolean heartbeatScheduled;
public TestAbstractHttpSockJsSession(SockJsServiceConfig config, WebSocketHandler handler, TestAbstractHttpSockJsSession(SockJsServiceConfig config, WebSocketHandler handler,
Map<String, Object> attributes) { Map<String, Object> attributes) {
super("1", config, handler, attributes); super("1", config, handler, attributes);
@ -123,15 +113,15 @@ public class HttpSockJsSessionTests extends AbstractSockJsSessionTests<TestAbstr
return "hhh\n".getBytes(); return "hhh\n".getBytes();
} }
public boolean wasCacheFlushed() { boolean wasCacheFlushed() {
return this.cacheFlushed; return this.cacheFlushed;
} }
public boolean wasHeartbeatScheduled() { boolean wasHeartbeatScheduled() {
return this.heartbeatScheduled; return this.heartbeatScheduled;
} }
public void setExceptionOnWriteFrame(IOException exceptionOnWriteFrame) { void setExceptionOnWriteFrame(IOException exceptionOnWriteFrame) {
this.exceptionOnWriteFrame = exceptionOnWriteFrame; this.exceptionOnWriteFrame = exceptionOnWriteFrame;
} }

View File

@ -44,8 +44,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions;
* *
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
*/ */
public class SockJsSessionTests extends AbstractSockJsSessionTests<TestSockJsSession> { class SockJsSessionTests extends AbstractSockJsSessionTests<TestSockJsSession> {
@Override @Override
protected TestSockJsSession initSockJsSession() { protected TestSockJsSession initSockJsSession() {
@ -54,7 +53,7 @@ public class SockJsSessionTests extends AbstractSockJsSessionTests<TestSockJsSes
@Test @Test
public void getTimeSinceLastActive() throws Exception { void getTimeSinceLastActive() throws Exception {
Thread.sleep(1); Thread.sleep(1);
long time1 = this.session.getTimeSinceLastActive(); long time1 = this.session.getTimeSinceLastActive();
@ -77,7 +76,7 @@ public class SockJsSessionTests extends AbstractSockJsSessionTests<TestSockJsSes
} }
@Test @Test
public void delegateConnectionEstablished() throws Exception { void delegateConnectionEstablished() throws Exception {
assertNew(); assertNew();
this.session.delegateConnectionEstablished(); this.session.delegateConnectionEstablished();
assertOpen(); assertOpen();
@ -85,15 +84,14 @@ public class SockJsSessionTests extends AbstractSockJsSessionTests<TestSockJsSes
} }
@Test @Test
public void delegateError() throws Exception { void delegateError() throws Exception {
Exception ex = new Exception(); Exception ex = new Exception();
this.session.delegateError(ex); this.session.delegateError(ex);
verify(this.webSocketHandler).handleTransportError(this.session, ex); verify(this.webSocketHandler).handleTransportError(this.session, ex);
} }
@Test @Test
public void delegateMessages() throws Exception { void delegateMessages() throws Exception {
String msg1 = "message 1"; String msg1 = "message 1";
String msg2 = "message 2"; String msg2 = "message 2";
@ -105,8 +103,7 @@ public class SockJsSessionTests extends AbstractSockJsSessionTests<TestSockJsSes
} }
@Test @Test
public void delegateMessagesWithError() throws Exception { void delegateMessagesWithError() throws Exception {
TestSockJsSession session = new TestSockJsSession("1", this.sockJsConfig, TestSockJsSession session = new TestSockJsSession("1", this.sockJsConfig,
new ExceptionWebSocketHandlerDecorator(this.webSocketHandler), Collections.emptyMap()); new ExceptionWebSocketHandlerDecorator(this.webSocketHandler), Collections.emptyMap());
@ -127,8 +124,7 @@ public class SockJsSessionTests extends AbstractSockJsSessionTests<TestSockJsSes
} }
@Test // gh-23828 @Test // gh-23828
public void delegateMessagesEmptyAfterConnectionClosed() throws Exception { void delegateMessagesEmptyAfterConnectionClosed() throws Exception {
TestSockJsSession session = new TestSockJsSession("1", this.sockJsConfig, TestSockJsSession session = new TestSockJsSession("1", this.sockJsConfig,
new ExceptionWebSocketHandlerDecorator(this.webSocketHandler), Collections.emptyMap()); new ExceptionWebSocketHandlerDecorator(this.webSocketHandler), Collections.emptyMap());
@ -144,7 +140,7 @@ public class SockJsSessionTests extends AbstractSockJsSessionTests<TestSockJsSes
} }
@Test @Test
public void delegateConnectionClosed() throws Exception { void delegateConnectionClosed() throws Exception {
this.session.delegateConnectionEstablished(); this.session.delegateConnectionEstablished();
this.session.delegateConnectionClosed(CloseStatus.GOING_AWAY); this.session.delegateConnectionClosed(CloseStatus.GOING_AWAY);
@ -154,7 +150,7 @@ public class SockJsSessionTests extends AbstractSockJsSessionTests<TestSockJsSes
} }
@Test @Test
public void closeWhenNotOpen() throws Exception { void closeWhenNotOpen() throws Exception {
assertNew(); assertNew();
this.session.close(); this.session.close();
@ -172,7 +168,7 @@ public class SockJsSessionTests extends AbstractSockJsSessionTests<TestSockJsSes
} }
@Test @Test
public void closeWhenNotActive() throws Exception { void closeWhenNotActive() throws Exception {
this.session.delegateConnectionEstablished(); this.session.delegateConnectionEstablished();
assertOpen(); assertOpen();
@ -183,7 +179,7 @@ public class SockJsSessionTests extends AbstractSockJsSessionTests<TestSockJsSes
} }
@Test @Test
public void close() throws Exception { void close() throws Exception {
this.session.delegateConnectionEstablished(); this.session.delegateConnectionEstablished();
assertOpen(); assertOpen();
@ -202,7 +198,7 @@ public class SockJsSessionTests extends AbstractSockJsSessionTests<TestSockJsSes
} }
@Test @Test
public void closeWithWriteFrameExceptions() throws Exception { void closeWithWriteFrameExceptions() throws Exception {
this.session.setExceptionOnWrite(new IOException()); this.session.setExceptionOnWrite(new IOException());
this.session.delegateConnectionEstablished(); this.session.delegateConnectionEstablished();
@ -214,7 +210,7 @@ public class SockJsSessionTests extends AbstractSockJsSessionTests<TestSockJsSes
} }
@Test @Test
public void closeWithWebSocketHandlerExceptions() throws Exception { void closeWithWebSocketHandlerExceptions() throws Exception {
willThrow(new Exception()).given(this.webSocketHandler).afterConnectionClosed(this.session, CloseStatus.NORMAL); willThrow(new Exception()).given(this.webSocketHandler).afterConnectionClosed(this.session, CloseStatus.NORMAL);
this.session.delegateConnectionEstablished(); this.session.delegateConnectionEstablished();
@ -226,7 +222,7 @@ public class SockJsSessionTests extends AbstractSockJsSessionTests<TestSockJsSes
} }
@Test @Test
public void tryCloseWithWebSocketHandlerExceptions() throws Exception { void tryCloseWithWebSocketHandlerExceptions() throws Exception {
this.session.delegateConnectionEstablished(); this.session.delegateConnectionEstablished();
this.session.setActive(true); this.session.setActive(true);
this.session.tryCloseWithSockJsTransportError(new Exception(), CloseStatus.BAD_DATA); this.session.tryCloseWithSockJsTransportError(new Exception(), CloseStatus.BAD_DATA);
@ -236,7 +232,7 @@ public class SockJsSessionTests extends AbstractSockJsSessionTests<TestSockJsSes
} }
@Test @Test
public void writeFrame() { void writeFrame() {
this.session.writeFrame(SockJsFrame.openFrame()); this.session.writeFrame(SockJsFrame.openFrame());
assertThat(this.session.getSockJsFramesWritten().size()).isEqualTo(1); assertThat(this.session.getSockJsFramesWritten().size()).isEqualTo(1);
@ -244,7 +240,7 @@ public class SockJsSessionTests extends AbstractSockJsSessionTests<TestSockJsSes
} }
@Test @Test
public void writeFrameIoException() throws Exception { void writeFrameIoException() throws Exception {
this.session.setExceptionOnWrite(new IOException()); this.session.setExceptionOnWrite(new IOException());
this.session.delegateConnectionEstablished(); this.session.delegateConnectionEstablished();
@ -255,7 +251,7 @@ public class SockJsSessionTests extends AbstractSockJsSessionTests<TestSockJsSes
} }
@Test @Test
public void sendHeartbeat() { void sendHeartbeat() {
this.session.setActive(true); this.session.setActive(true);
this.session.sendHeartbeat(); this.session.sendHeartbeat();
@ -267,7 +263,7 @@ public class SockJsSessionTests extends AbstractSockJsSessionTests<TestSockJsSes
} }
@Test @Test
public void scheduleHeartbeatNotActive() { void scheduleHeartbeatNotActive() {
this.session.setActive(false); this.session.setActive(false);
this.session.scheduleHeartbeat(); this.session.scheduleHeartbeat();
@ -275,7 +271,7 @@ public class SockJsSessionTests extends AbstractSockJsSessionTests<TestSockJsSes
} }
@Test @Test
public void sendHeartbeatWhenDisabled() { void sendHeartbeatWhenDisabled() {
this.session.disableHeartbeat(); this.session.disableHeartbeat();
this.session.setActive(true); this.session.setActive(true);
this.session.sendHeartbeat(); this.session.sendHeartbeat();
@ -284,7 +280,7 @@ public class SockJsSessionTests extends AbstractSockJsSessionTests<TestSockJsSes
} }
@Test @Test
public void scheduleAndCancelHeartbeat() { void scheduleAndCancelHeartbeat() {
ScheduledFuture<?> task = mock(ScheduledFuture.class); ScheduledFuture<?> task = mock(ScheduledFuture.class);
willReturn(task).given(this.taskScheduler).schedule(any(Runnable.class), any(Instant.class)); willReturn(task).given(this.taskScheduler).schedule(any(Runnable.class), any(Instant.class));

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -18,8 +18,6 @@ package org.springframework.web.socket.sockjs.transport.session;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -46,13 +44,14 @@ import static org.mockito.Mockito.verifyNoMoreInteractions;
* *
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
*/ */
public class WebSocketServerSockJsSessionTests extends AbstractSockJsSessionTests<TestWebSocketServerSockJsSession> { class WebSocketServerSockJsSessionTests extends AbstractSockJsSessionTests<TestWebSocketServerSockJsSession> {
private TestWebSocketSession webSocketSession; private TestWebSocketSession webSocketSession;
@BeforeEach @BeforeEach
public void setup() { @Override
protected void setUp() {
super.setUp(); super.setUp();
this.webSocketSession = new TestWebSocketSession(); this.webSocketSession = new TestWebSocketSession();
this.webSocketSession.setOpen(true); this.webSocketSession.setOpen(true);
@ -60,12 +59,12 @@ public class WebSocketServerSockJsSessionTests extends AbstractSockJsSessionTest
@Override @Override
protected TestWebSocketServerSockJsSession initSockJsSession() { protected TestWebSocketServerSockJsSession initSockJsSession() {
return new TestWebSocketServerSockJsSession(this.sockJsConfig, this.webSocketHandler, return new TestWebSocketServerSockJsSession(this.sockJsConfig, this.webSocketHandler, Map.of());
Collections.<String, Object>emptyMap());
} }
@Test @Test
public void isActive() throws Exception { void isActive() throws Exception {
assertThat(this.session.isActive()).isFalse(); assertThat(this.session.isActive()).isFalse();
this.session.initializeDelegateSession(this.webSocketSession); this.session.initializeDelegateSession(this.webSocketSession);
@ -76,17 +75,17 @@ public class WebSocketServerSockJsSessionTests extends AbstractSockJsSessionTest
} }
@Test @Test
public void afterSessionInitialized() throws Exception { void afterSessionInitialized() throws Exception {
this.session.initializeDelegateSession(this.webSocketSession); this.session.initializeDelegateSession(this.webSocketSession);
assertThat(this.webSocketSession.getSentMessages()).isEqualTo(Collections.singletonList(new TextMessage("o"))); assertThat(this.webSocketSession.getSentMessages()).containsExactly(new TextMessage("o"));
assertThat(this.session.heartbeatSchedulingEvents).isEqualTo(List.of("schedule")); assertThat(this.session.heartbeatSchedulingEvents).containsExactly("schedule");
verify(this.webSocketHandler).afterConnectionEstablished(this.session); verify(this.webSocketHandler).afterConnectionEstablished(this.session);
verifyNoMoreInteractions(this.taskScheduler, this.webSocketHandler); verifyNoMoreInteractions(this.taskScheduler, this.webSocketHandler);
} }
@Test @Test
@SuppressWarnings("resource") @SuppressWarnings("resource")
public void afterSessionInitializedOpenFrameFirst() throws Exception { void afterSessionInitializedOpenFrameFirst() throws Exception {
TextWebSocketHandler handler = new TextWebSocketHandler() { TextWebSocketHandler handler = new TextWebSocketHandler() {
@Override @Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception { public void afterConnectionEstablished(WebSocketSession session) throws Exception {
@ -95,19 +94,17 @@ public class WebSocketServerSockJsSessionTests extends AbstractSockJsSessionTest
}; };
TestWebSocketServerSockJsSession session = new TestWebSocketServerSockJsSession(this.sockJsConfig, handler, null); TestWebSocketServerSockJsSession session = new TestWebSocketServerSockJsSession(this.sockJsConfig, handler, null);
session.initializeDelegateSession(this.webSocketSession); session.initializeDelegateSession(this.webSocketSession);
List<TextMessage> expected = Arrays.asList(new TextMessage("o"), new TextMessage("a[\"go go\"]")); assertThat(this.webSocketSession.getSentMessages()).containsExactly(new TextMessage("o"), new TextMessage("a[\"go go\"]"));
assertThat(this.webSocketSession.getSentMessages()).isEqualTo(expected);
} }
@Test @Test
public void handleMessageEmptyPayload() throws Exception { void handleMessageEmptyPayload() throws Exception {
this.session.handleMessage(new TextMessage(""), this.webSocketSession); this.session.handleMessage(new TextMessage(""), this.webSocketSession);
verifyNoMoreInteractions(this.webSocketHandler); verifyNoMoreInteractions(this.webSocketHandler);
} }
@Test @Test
public void handleMessage() throws Exception { void handleMessage() throws Exception {
TextMessage message = new TextMessage("[\"x\"]"); TextMessage message = new TextMessage("[\"x\"]");
this.session.handleMessage(message, this.webSocketSession); this.session.handleMessage(message, this.webSocketSession);
@ -116,7 +113,7 @@ public class WebSocketServerSockJsSessionTests extends AbstractSockJsSessionTest
} }
@Test @Test
public void handleMessageBadData() throws Exception { void handleMessageBadData() throws Exception {
TextMessage message = new TextMessage("[\"x]"); TextMessage message = new TextMessage("[\"x]");
this.session.handleMessage(message, this.webSocketSession); this.session.handleMessage(message, this.webSocketSession);
@ -126,19 +123,16 @@ public class WebSocketServerSockJsSessionTests extends AbstractSockJsSessionTest
} }
@Test @Test
public void sendMessageInternal() throws Exception { void sendMessageInternal() throws Exception {
this.session.initializeDelegateSession(this.webSocketSession); this.session.initializeDelegateSession(this.webSocketSession);
this.session.sendMessageInternal("x"); this.session.sendMessageInternal("x");
assertThat(this.webSocketSession.getSentMessages()).isEqualTo(Arrays.asList(new TextMessage("o"), new TextMessage("a[\"x\"]"))); assertThat(this.webSocketSession.getSentMessages()).containsExactly(new TextMessage("o"), new TextMessage("a[\"x\"]"));
assertThat(this.session.heartbeatSchedulingEvents).containsExactly("schedule", "cancel", "schedule");
assertThat(this.session.heartbeatSchedulingEvents).isEqualTo(Arrays.asList("schedule", "cancel", "schedule"));
} }
@Test @Test
public void disconnect() throws Exception { void disconnect() throws Exception {
this.session.initializeDelegateSession(this.webSocketSession); this.session.initializeDelegateSession(this.webSocketSession);
this.session.close(CloseStatus.NOT_ACCEPTABLE); this.session.close(CloseStatus.NOT_ACCEPTABLE);
@ -150,7 +144,7 @@ public class WebSocketServerSockJsSessionTests extends AbstractSockJsSessionTest
private final List<String> heartbeatSchedulingEvents = new ArrayList<>(); private final List<String> heartbeatSchedulingEvents = new ArrayList<>();
public TestWebSocketServerSockJsSession(SockJsServiceConfig config, WebSocketHandler handler, TestWebSocketServerSockJsSession(SockJsServiceConfig config, WebSocketHandler handler,
Map<String, Object> attributes) { Map<String, Object> attributes) {
super("1", config, handler, attributes); super("1", config, handler, attributes);