diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java index e9d4b6c9420..5e6edc07e3a 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java @@ -85,6 +85,8 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE */ public static final String CONNECTED_USER_HEADER = "user-name"; + private static final String[] SUPPORTED_VERSIONS = {"1.2", "1.1", "1.0"}; + private static final Log logger = LogFactory.getLog(StompSubProtocolHandler.class); private static final byte[] EMPTY_PAYLOAD = new byte[0]; @@ -524,15 +526,12 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE if (connectHeaders != null) { Set acceptVersions = connectHeaders.getAcceptVersion(); - if (acceptVersions.contains("1.2")) { - connectedHeaders.setVersion("1.2"); - } - else if (acceptVersions.contains("1.1")) { - connectedHeaders.setVersion("1.1"); - } - else if (!acceptVersions.isEmpty()) { - throw new IllegalArgumentException("Unsupported STOMP version '" + acceptVersions + "'"); - } + connectedHeaders.setVersion( + Arrays.stream(SUPPORTED_VERSIONS) + .filter(acceptVersions::contains) + .findAny() + .orElseThrow(() -> new IllegalArgumentException( + "Unsupported STOMP version '" + acceptVersions + "'"))); } long[] heartbeat = (long[]) connectAckHeaders.getHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java index 418daa29121..3a9be4f4d31 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java @@ -126,7 +126,7 @@ public class StompSubProtocolHandlerTests { StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECT); accessor.setHeartbeat(10000, 10000); - accessor.setAcceptVersion("1.0,1.1"); + accessor.setAcceptVersion("1.0,1.1,1.2"); Message connectMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders()); SimpMessageHeaderAccessor ackAccessor = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK); @@ -137,7 +137,7 @@ public class StompSubProtocolHandlerTests { assertEquals(1, this.session.getSentMessages().size()); TextMessage actual = (TextMessage) this.session.getSentMessages().get(0); - assertEquals("CONNECTED\n" + "version:1.1\n" + "heart-beat:15000,15000\n" + + assertEquals("CONNECTED\n" + "version:1.2\n" + "heart-beat:15000,15000\n" + "user-name:joe\n" + "\n" + "\u0000", actual.getPayload()); } @@ -146,7 +146,7 @@ public class StompSubProtocolHandlerTests { StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECT); accessor.setHeartbeat(10000, 10000); - accessor.setAcceptVersion("1.0,1.1"); + accessor.setAcceptVersion("1.0"); Message connectMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders()); SimpMessageHeaderAccessor ackAccessor = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK); @@ -156,7 +156,7 @@ public class StompSubProtocolHandlerTests { assertEquals(1, this.session.getSentMessages().size()); TextMessage actual = (TextMessage) this.session.getSentMessages().get(0); - assertEquals("CONNECTED\n" + "version:1.1\n" + "heart-beat:0,0\n" + + assertEquals("CONNECTED\n" + "version:1.0\n" + "heart-beat:0,0\n" + "user-name:joe\n" + "\n" + "\u0000", actual.getPayload()); }