Support splitting STOMP messages in WebSocketStompClient

See gh-31970
This commit is contained in:
injae-kim 2024-01-08 00:45:54 +09:00 committed by rstoyanchev
parent bf014ef18b
commit 76d00d78db
6 changed files with 580 additions and 10 deletions

View File

@ -105,5 +105,20 @@ it handle ERROR frames in addition to the `handleException` callback for
exceptions from the handling of messages and `handleTransportError` for
transport-level errors including `ConnectionLostException`.
You can also use `setInboundMessageSizeLimit(limit)` and `setOutboundMessageSizeLimit(limit)`
to limit the maximum size of inbound and outbound message size.
When outbound message size exceeds `outboundMessageSizeLimit`, message is split into multiple incomplete frames.
Then receiver buffers these incomplete frames and reassemble to complete message.
When inbound message size exceeds `inboundMessageSizeLimit`, throw `StompConversionException`.
The default value of in&outboundMessageSizeLimit is `64KB`.
[source,java,indent=0,subs="verbatim,quotes"]
----
WebSocketClient webSocketClient = new StandardWebSocketClient();
WebSocketStompClient stompClient = new WebSocketStompClient(webSocketClient);
stompClient.setInboundMessageSizeLimit(64 * 1024); // 64KB
stompClient.setOutboundMessageSizeLimit(64 * 1024); // 64KB
----

View File

@ -0,0 +1,68 @@
/*
* Copyright 2024-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.stomp;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.springframework.util.Assert;
/**
* An extension of {@link org.springframework.messaging.simp.stomp.StompEncoder}
* that splits the STOMP message to multiple incomplete STOMP frames
* when the encoded bytes length exceeds {@link SplittingStompEncoder#bufferSizeLimit}.
*
* @author Injae Kim
* @since 6.2
* @see StompEncoder
*/
public class SplittingStompEncoder {
private final StompEncoder encoder;
private final int bufferSizeLimit;
public SplittingStompEncoder(StompEncoder encoder, int bufferSizeLimit) {
Assert.notNull(encoder, "StompEncoder is required");
Assert.isTrue(bufferSizeLimit > 0, "Buffer size limit must be greater than 0");
this.encoder = encoder;
this.bufferSizeLimit = bufferSizeLimit;
}
/**
* Encodes the given payload and headers into a list of one or more {@code byte[]}s.
* @param headers the headers
* @param payload the payload
* @return the list of one or more encoded messages
*/
public List<byte[]> encode(Map<String, Object> headers, byte[] payload) {
byte[] result = this.encoder.encode(headers, payload);
int length = result.length;
if (length <= this.bufferSizeLimit) {
return List.of(result);
}
List<byte[]> frames = new ArrayList<>();
for (int i = 0; i < length; i += this.bufferSizeLimit) {
frames.add(Arrays.copyOfRange(result, i, Math.min(i + this.bufferSizeLimit, length)));
}
return frames;
}
}

View File

@ -78,7 +78,7 @@ public class StompDecoder {
* Decodes one or more STOMP frames from the given {@code ByteBuffer} into a
* list of {@link Message Messages}. If the input buffer contains partial STOMP frame
* content, or additional content with a partial STOMP frame, the buffer is
* reset and {@code null} is returned.
* reset and an empty list is returned.
* @param byteBuffer the buffer to decode the STOMP frame from
* @return the decoded messages, or an empty list if none
* @throws StompConversionException raised in case of decoding issues

View File

@ -0,0 +1,382 @@
/*
* Copyright 2024-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.stomp;
import java.io.ByteArrayOutputStream;
import java.util.Arrays;
import java.util.List;
import org.junit.jupiter.api.Test;
import org.springframework.messaging.Message;
import org.springframework.messaging.support.MessageBuilder;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
/**
* Unit tests for {@link SplittingStompEncoder}.
*
* @author Injae Kim
* @since 6.2
*/
public class SplittingStompEncoderTests {
private final StompEncoder STOMP_ENCODER = new StompEncoder();
private static final int DEFAULT_MESSAGE_MAX_SIZE = 64 * 1024;
@Test
public void encodeFrameWithNoHeadersAndNoBody() {
SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, DEFAULT_MESSAGE_MAX_SIZE);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT);
Message<byte[]> frame = MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders());
List<byte[]> actual = encoder.encode(frame.getHeaders(), frame.getPayload());
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
actual.forEach(outputStream::writeBytes);
assertThat(outputStream.toString()).isEqualTo("DISCONNECT\n\n\0");
assertThat(actual.size()).isOne();
}
@Test
public void encodeFrameWithNoHeadersAndNoBodySplitTwoFrames() {
SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, 7);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT);
Message<byte[]> frame = MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders());
List<byte[]> actual = encoder.encode(frame.getHeaders(), frame.getPayload());
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
actual.forEach(outputStream::writeBytes);
assertThat(outputStream.toString()).isEqualTo("DISCONNECT\n\n\0");
assertThat(actual.size()).isEqualTo(2);
assertThat(actual.get(0)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 0, 7));
assertThat(actual.get(1)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 7, outputStream.size()));
}
@Test
public void encodeFrameWithNoHeadersAndNoBodySplitMultipleFrames() {
SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, 3);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT);
Message<byte[]> frame = MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders());
List<byte[]> actual = encoder.encode(frame.getHeaders(), frame.getPayload());
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
actual.forEach(outputStream::writeBytes);
assertThat(outputStream.toString()).isEqualTo("DISCONNECT\n\n\0");
assertThat(actual.size()).isEqualTo(5);
assertThat(actual.get(0)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 0, 3));
assertThat(actual.get(1)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 3, 6));
assertThat(actual.get(2)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 6, 9));
assertThat(actual.get(3)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 9, 12));
assertThat(actual.get(4)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 12, outputStream.size()));
}
@Test
public void encodeFrameWithHeaders() {
SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, DEFAULT_MESSAGE_MAX_SIZE);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT);
headers.setAcceptVersion("1.2");
headers.setHost("github.org");
Message<byte[]> frame = MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders());
List<byte[]> actual = encoder.encode(frame.getHeaders(), frame.getPayload());
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
actual.forEach(outputStream::writeBytes);
String actualString = outputStream.toString();
assertThat("CONNECT\naccept-version:1.2\nhost:github.org\n\n\0".equals(actualString) ||
"CONNECT\nhost:github.org\naccept-version:1.2\n\n\0".equals(actualString)).isTrue();
assertThat(actual.size()).isOne();
}
@Test
public void encodeFrameWithHeadersSplitTwoFrames() {
SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, 30);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT);
headers.setAcceptVersion("1.2");
headers.setHost("github.org");
Message<byte[]> frame = MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders());
List<byte[]> actual = encoder.encode(frame.getHeaders(), frame.getPayload());
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
actual.forEach(outputStream::writeBytes);
String actualString = outputStream.toString();
assertThat("CONNECT\naccept-version:1.2\nhost:github.org\n\n\0".equals(actualString) ||
"CONNECT\nhost:github.org\naccept-version:1.2\n\n\0".equals(actualString)).isTrue();
assertThat(actual.size()).isEqualTo(2);
assertThat(actual.get(0)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 0, 30));
assertThat(actual.get(1)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 30, outputStream.size()));
}
@Test
public void encodeFrameWithHeadersSplitMultipleFrames() {
SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, 10);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT);
headers.setAcceptVersion("1.2");
headers.setHost("github.org");
Message<byte[]> frame = MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders());
List<byte[]> actual = encoder.encode(frame.getHeaders(), frame.getPayload());
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
actual.forEach(outputStream::writeBytes);
String actualString = outputStream.toString();
assertThat("CONNECT\naccept-version:1.2\nhost:github.org\n\n\0".equals(actualString) ||
"CONNECT\nhost:github.org\naccept-version:1.2\n\n\0".equals(actualString)).isTrue();
assertThat(actual.size()).isEqualTo(5);
assertThat(actual.get(0)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 0, 10));
assertThat(actual.get(1)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 10, 20));
assertThat(actual.get(2)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 20, 30));
assertThat(actual.get(3)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 30, 40));
assertThat(actual.get(4)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 40, outputStream.size()));
}
@Test
public void encodeFrameWithHeadersThatShouldBeEscaped() {
SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, DEFAULT_MESSAGE_MAX_SIZE);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT);
headers.addNativeHeader("a:\r\n\\b", "alpha:bravo\r\n\\");
Message<byte[]> frame = MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders());
List<byte[]> actual = encoder.encode(frame.getHeaders(), frame.getPayload());
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
actual.forEach(outputStream::writeBytes);
assertThat(outputStream.toString()).isEqualTo("DISCONNECT\na\\c\\r\\n\\\\b:alpha\\cbravo\\r\\n\\\\\n\n\0");
assertThat(actual.size()).isOne();
}
@Test
public void encodeFrameWithHeadersThatShouldBeEscapedSplitTwoFrames() {
SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, 30);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT);
headers.addNativeHeader("a:\r\n\\b", "alpha:bravo\r\n\\");
Message<byte[]> frame = MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders());
List<byte[]> actual = encoder.encode(frame.getHeaders(), frame.getPayload());
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
actual.forEach(outputStream::writeBytes);
assertThat(outputStream.toString()).isEqualTo("DISCONNECT\na\\c\\r\\n\\\\b:alpha\\cbravo\\r\\n\\\\\n\n\0");
assertThat(actual.size()).isEqualTo(2);
assertThat(actual.get(0)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 0, 30));
assertThat(actual.get(1)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 30, outputStream.size()));
}
@Test
public void encodeFrameWithHeadersThatShouldBeEscapedSplitMultipleFrames() {
SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, 10);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT);
headers.addNativeHeader("a:\r\n\\b", "alpha:bravo\r\n\\");
Message<byte[]> frame = MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders());
List<byte[]> actual = encoder.encode(frame.getHeaders(), frame.getPayload());
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
actual.forEach(outputStream::writeBytes);
String actualString = outputStream.toString();
assertThat(outputStream.toString()).isEqualTo("DISCONNECT\na\\c\\r\\n\\\\b:alpha\\cbravo\\r\\n\\\\\n\n\0");
assertThat(actual.size()).isEqualTo(5);
assertThat(actual.get(0)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 0, 10));
assertThat(actual.get(1)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 10, 20));
assertThat(actual.get(2)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 20, 30));
assertThat(actual.get(3)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 30, 40));
assertThat(actual.get(4)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 40, outputStream.size()));
}
@Test
public void encodeFrameWithHeadersBody() {
SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, DEFAULT_MESSAGE_MAX_SIZE);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND);
headers.addNativeHeader("a", "alpha");
Message<byte[]> frame = MessageBuilder.createMessage(
"Message body".getBytes(), headers.getMessageHeaders());
List<byte[]> actual = encoder.encode(frame.getHeaders(), frame.getPayload());
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
actual.forEach(outputStream::writeBytes);
assertThat(outputStream.toString()).isEqualTo("SEND\na:alpha\ncontent-length:12\n\nMessage body\0");
assertThat(actual.size()).isOne();
}
@Test
public void encodeFrameWithHeadersBodySplitTwoFrames() {
SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, 30);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND);
headers.addNativeHeader("a", "alpha");
Message<byte[]> frame = MessageBuilder.createMessage(
"Message body".getBytes(), headers.getMessageHeaders());
List<byte[]> actual = encoder.encode(frame.getHeaders(), frame.getPayload());
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
actual.forEach(outputStream::writeBytes);
assertThat(outputStream.toString()).isEqualTo("SEND\na:alpha\ncontent-length:12\n\nMessage body\0");
assertThat(actual.size()).isEqualTo(2);
assertThat(actual.get(0)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 0, 30));
assertThat(actual.get(1)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 30, outputStream.size()));
}
@Test
public void encodeFrameWithHeadersBodySplitMultipleFrames() {
SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, 10);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND);
headers.addNativeHeader("a", "alpha");
Message<byte[]> frame = MessageBuilder.createMessage(
"Message body".getBytes(), headers.getMessageHeaders());
List<byte[]> actual = encoder.encode(frame.getHeaders(), frame.getPayload());
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
actual.forEach(outputStream::writeBytes);
assertThat(outputStream.toString()).isEqualTo("SEND\na:alpha\ncontent-length:12\n\nMessage body\0");
assertThat(actual.size()).isEqualTo(5);
assertThat(actual.get(0)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 0, 10));
assertThat(actual.get(1)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 10, 20));
assertThat(actual.get(2)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 20, 30));
assertThat(actual.get(3)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 30, 40));
assertThat(actual.get(4)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 40, outputStream.size()));
}
@Test
public void encodeFrameWithContentLengthPresent() {
SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, DEFAULT_MESSAGE_MAX_SIZE);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND);
headers.setContentLength(12);
Message<byte[]> frame = MessageBuilder.createMessage(
"Message body".getBytes(), headers.getMessageHeaders());
List<byte[]> actual = encoder.encode(frame.getHeaders(), frame.getPayload());
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
actual.forEach(outputStream::writeBytes);
assertThat(outputStream.toString()).isEqualTo("SEND\ncontent-length:12\n\nMessage body\0");
assertThat(actual.size()).isOne();
}
@Test
public void encodeFrameWithContentLengthPresentSplitTwoFrames() {
SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, 20);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND);
headers.setContentLength(12);
Message<byte[]> frame = MessageBuilder.createMessage(
"Message body".getBytes(), headers.getMessageHeaders());
List<byte[]> actual = encoder.encode(frame.getHeaders(), frame.getPayload());
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
actual.forEach(outputStream::writeBytes);
assertThat(outputStream.toString()).isEqualTo("SEND\ncontent-length:12\n\nMessage body\0");
assertThat(actual.size()).isEqualTo(2);
assertThat(actual.get(0)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 0, 20));
assertThat(actual.get(1)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 20, outputStream.size()));
}
@Test
public void encodeFrameWithContentLengthPresentSplitMultipleFrames() {
SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, 10);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND);
headers.setContentLength(12);
Message<byte[]> frame = MessageBuilder.createMessage(
"Message body".getBytes(), headers.getMessageHeaders());
List<byte[]> actual = encoder.encode(frame.getHeaders(), frame.getPayload());
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
actual.forEach(outputStream::writeBytes);
assertThat(outputStream.toString()).isEqualTo("SEND\ncontent-length:12\n\nMessage body\0");
assertThat(actual.size()).isEqualTo(4);
assertThat(actual.get(0)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 0, 10));
assertThat(actual.get(1)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 10, 20));
assertThat(actual.get(2)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 20, 30));
assertThat(actual.get(3)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 30, outputStream.size()));
}
@Test
public void sameLengthAndBufferSizeLimit() {
SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, 44);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND);
headers.addNativeHeader("a", "1234");
Message<byte[]> frame = MessageBuilder.createMessage(
"Message body".getBytes(), headers.getMessageHeaders());
List<byte[]> actual = encoder.encode(frame.getHeaders(), frame.getPayload());
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
actual.forEach(outputStream::writeBytes);
assertThat(outputStream.toString()).isEqualTo("SEND\na:1234\ncontent-length:12\n\nMessage body\0");
assertThat(actual.size()).isOne();
assertThat(outputStream.toByteArray().length).isEqualTo(44);
}
@Test
public void lengthAndBufferSizeLimitExactlySplitTwoFrames() {
SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, 22);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND);
headers.addNativeHeader("a", "1234");
Message<byte[]> frame = MessageBuilder.createMessage(
"Message body".getBytes(), headers.getMessageHeaders());
List<byte[]> actual = encoder.encode(frame.getHeaders(), frame.getPayload());
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
actual.forEach(outputStream::writeBytes);
assertThat(outputStream.toString()).isEqualTo("SEND\na:1234\ncontent-length:12\n\nMessage body\0");
assertThat(actual.size()).isEqualTo(2);
assertThat(outputStream.toByteArray().length).isEqualTo(44);
assertThat(actual.get(0)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 0, 22));
assertThat(actual.get(1)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 22, 44));
}
@Test
public void lengthAndBufferSizeLimitExactlySplitMultipleFrames() {
SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, 11);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND);
headers.addNativeHeader("a", "1234");
Message<byte[]> frame = MessageBuilder.createMessage(
"Message body".getBytes(), headers.getMessageHeaders());
List<byte[]> actual = encoder.encode(frame.getHeaders(), frame.getPayload());
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
actual.forEach(outputStream::writeBytes);
assertThat(outputStream.toString()).isEqualTo("SEND\na:1234\ncontent-length:12\n\nMessage body\0");
assertThat(actual.size()).isEqualTo(4);
assertThat(outputStream.toByteArray().length).isEqualTo(44);
assertThat(actual.get(0)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 0, 11));
assertThat(actual.get(1)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 11, 22));
assertThat(actual.get(2)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 22, 33));
assertThat(actual.get(3)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 33, 44));
}
@Test
public void bufferSizeLimitShouldBePositive() {
assertThatThrownBy(() -> new SplittingStompEncoder(STOMP_ENCODER, 0))
.isInstanceOf(IllegalArgumentException.class);
assertThatThrownBy(() -> new SplittingStompEncoder(STOMP_ENCODER, -1))
.isInstanceOf(IllegalArgumentException.class);
}
}

View File

@ -35,6 +35,7 @@ import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
import org.springframework.messaging.simp.stomp.BufferingStompDecoder;
import org.springframework.messaging.simp.stomp.ConnectionHandlingStompSession;
import org.springframework.messaging.simp.stomp.SplittingStompEncoder;
import org.springframework.messaging.simp.stomp.StompClientSupport;
import org.springframework.messaging.simp.stomp.StompDecoder;
import org.springframework.messaging.simp.stomp.StompEncoder;
@ -67,15 +68,23 @@ import org.springframework.web.util.UriComponentsBuilder;
* SockJsClient}.
*
* @author Rossen Stoyanchev
* @author Injae Kim
* @since 4.2
*/
public class WebSocketStompClient extends StompClientSupport implements SmartLifecycle {
private static final Log logger = LogFactory.getLog(WebSocketStompClient.class);
/**
* The default max size for in&outbound STOMP message.
*/
private static final int DEFAULT_MESSAGE_MAX_SIZE = 64 * 1024;
private final WebSocketClient webSocketClient;
private int inboundMessageSizeLimit = 64 * 1024;
private int inboundMessageSizeLimit = DEFAULT_MESSAGE_MAX_SIZE;
private int outboundMessageSizeLimit = DEFAULT_MESSAGE_MAX_SIZE;
private boolean autoStartup = true;
@ -122,7 +131,7 @@ public class WebSocketStompClient extends StompClientSupport implements SmartLif
* Since a STOMP message can be received in multiple WebSocket messages,
* buffering may be required and this property determines the maximum buffer
* size per message.
* <p>By default this is set to 64 * 1024 (64K).
* <p>By default this is set to 64 * 1024 (64K), see {@link WebSocketStompClient#DEFAULT_MESSAGE_MAX_SIZE}.
*/
public void setInboundMessageSizeLimit(int inboundMessageSizeLimit) {
this.inboundMessageSizeLimit = inboundMessageSizeLimit;
@ -135,6 +144,25 @@ public class WebSocketStompClient extends StompClientSupport implements SmartLif
return this.inboundMessageSizeLimit;
}
/**
* Configure the maximum size allowed for outbound STOMP message.
* If STOMP message's size exceeds {@link WebSocketStompClient#outboundMessageSizeLimit},
* STOMP message is split into multiple frames.
* <p>By default this is set to 64 * 1024 (64K), see {@link WebSocketStompClient#DEFAULT_MESSAGE_MAX_SIZE}.
* @since 6.2
*/
public void setOutboundMessageSizeLimit(int outboundMessageSizeLimit) {
this.outboundMessageSizeLimit = outboundMessageSizeLimit;
}
/**
* Get the configured outbound message buffer size in bytes.
* @since 6.2
*/
public int getOutboundMessageSizeLimit() {
return this.outboundMessageSizeLimit;
}
/**
* Set whether to auto-start the contained WebSocketClient when the Spring
* context has been refreshed.
@ -373,7 +401,8 @@ public class WebSocketStompClient extends StompClientSupport implements SmartLif
private final TcpConnectionHandler<byte[]> stompSession;
private final StompWebSocketMessageCodec codec = new StompWebSocketMessageCodec(getInboundMessageSizeLimit());
private final StompWebSocketMessageCodec codec =
new StompWebSocketMessageCodec(getInboundMessageSizeLimit(),getOutboundMessageSizeLimit());
@Nullable
private volatile WebSocketSession session;
@ -450,7 +479,9 @@ public class WebSocketStompClient extends StompClientSupport implements SmartLif
try {
WebSocketSession session = this.session;
Assert.state(session != null, "No WebSocketSession available");
session.sendMessage(this.codec.encode(message, session.getClass()));
for (WebSocketMessage<?> webSocketMessage : this.codec.encode(message, session.getClass())) {
session.sendMessage(webSocketMessage);
}
future.complete(null);
}
catch (Throwable ex) {
@ -561,8 +592,11 @@ public class WebSocketStompClient extends StompClientSupport implements SmartLif
private final BufferingStompDecoder bufferingDecoder;
public StompWebSocketMessageCodec(int messageSizeLimit) {
this.bufferingDecoder = new BufferingStompDecoder(DECODER, messageSizeLimit);
private final SplittingStompEncoder splittingEncoder;
public StompWebSocketMessageCodec(int inboundMessageSizeLimit, int outboundMessageSizeLimit) {
this.bufferingDecoder = new BufferingStompDecoder(DECODER, inboundMessageSizeLimit);
this.splittingEncoder = new SplittingStompEncoder(ENCODER, outboundMessageSizeLimit);
}
public List<Message<byte[]>> decode(WebSocketMessage<?> webSocketMessage) {
@ -588,17 +622,21 @@ public class WebSocketStompClient extends StompClientSupport implements SmartLif
return result;
}
public WebSocketMessage<?> encode(Message<byte[]> message, Class<? extends WebSocketSession> sessionType) {
public List<WebSocketMessage<?>> encode(Message<byte[]> message, Class<? extends WebSocketSession> sessionType) {
StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
Assert.notNull(accessor, "No StompHeaderAccessor available");
byte[] payload = message.getPayload();
byte[] bytes = ENCODER.encode(accessor.getMessageHeaders(), payload);
List<byte[]> frames = splittingEncoder.encode(accessor.getMessageHeaders(), payload);
boolean useBinary = (payload.length > 0 &&
!(SockJsSession.class.isAssignableFrom(sessionType)) &&
MimeTypeUtils.APPLICATION_OCTET_STREAM.isCompatibleWith(accessor.getContentType()));
return (useBinary ? new BinaryMessage(bytes) : new TextMessage(bytes));
List<WebSocketMessage<?>> messages = new ArrayList<>();
for (byte[] frame : frames) {
messages.add(useBinary ? new BinaryMessage(frame) : new TextMessage(frame));
}
return messages;
}
}

View File

@ -65,6 +65,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions;
* Tests for {@link WebSocketStompClient}.
*
* @author Rossen Stoyanchev
* @author Injae Kim
*/
@MockitoSettings(strictness = Strictness.LENIENT)
class WebSocketStompClientTests {
@ -211,6 +212,29 @@ class WebSocketStompClientTests {
assertThat(textMessage.getPayload()).isEqualTo("SEND\ndestination:/topic/foo\ncontent-length:7\n\npayload\0");
}
@Test
void sendWebSocketMessageExceedOutboundMessageSizeLimit() throws Exception {
stompClient.setOutboundMessageSizeLimit(30);
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.SEND);
accessor.setDestination("/topic/foo");
byte[] payload = "payload".getBytes(StandardCharsets.UTF_8);
getTcpConnection().sendAsync(MessageBuilder.createMessage(payload, accessor.getMessageHeaders()));
ArgumentCaptor<TextMessage> textMessageCaptor = ArgumentCaptor.forClass(TextMessage.class);
verify(this.webSocketSession, times(2)).sendMessage(textMessageCaptor.capture());
TextMessage textMessage = textMessageCaptor.getAllValues().get(0);
assertThat(textMessage).isNotNull();
assertThat(textMessage.getPayload()).isEqualTo("SEND\ndestination:/topic/foo\nco");
assertThat(textMessage.getPayload().getBytes().length).isEqualTo(30);
textMessage = textMessageCaptor.getAllValues().get(1);
assertThat(textMessage).isNotNull();
assertThat(textMessage.getPayload()).isEqualTo("ntent-length:7\n\npayload\0");
assertThat(textMessage.getPayload().getBytes().length).isEqualTo(24);
}
@Test
void sendWebSocketBinary() throws Exception {
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.SEND);
@ -228,6 +252,49 @@ class WebSocketStompClientTests {
.isEqualTo("SEND\ndestination:/b\ncontent-type:application/octet-stream\ncontent-length:7\n\npayload\0");
}
@Test
void sendWebSocketBinaryExceedOutboundMessageSizeLimit() throws Exception {
stompClient.setOutboundMessageSizeLimit(50);
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.SEND);
accessor.setDestination("/b");
accessor.setContentType(MimeTypeUtils.APPLICATION_OCTET_STREAM);
byte[] payload = "payload".getBytes(StandardCharsets.UTF_8);
getTcpConnection().sendAsync(MessageBuilder.createMessage(payload, accessor.getMessageHeaders()));
ArgumentCaptor<BinaryMessage> binaryMessageCaptor = ArgumentCaptor.forClass(BinaryMessage.class);
verify(this.webSocketSession, times(2)).sendMessage(binaryMessageCaptor.capture());
BinaryMessage binaryMessage = binaryMessageCaptor.getAllValues().get(0);
assertThat(binaryMessage).isNotNull();
assertThat(new String(binaryMessage.getPayload().array(), StandardCharsets.UTF_8))
.isEqualTo("SEND\ndestination:/b\ncontent-type:application/octet");
assertThat(binaryMessage.getPayload().array().length).isEqualTo(50);
binaryMessage = binaryMessageCaptor.getAllValues().get(1);
assertThat(binaryMessage).isNotNull();
assertThat(new String(binaryMessage.getPayload().array(), StandardCharsets.UTF_8))
.isEqualTo("-stream\ncontent-length:7\n\npayload\0");
assertThat(binaryMessage.getPayload().array().length).isEqualTo(34);
}
@Test
void reassembleReceivedIFragmentedFrames() throws Exception {
WebSocketHandler handler = connect();
handler.handleMessage(this.webSocketSession, new TextMessage("SEND\ndestination:/topic/foo\nco"));
handler.handleMessage(this.webSocketSession, new TextMessage("ntent-length:7\n\npayload\0"));
ArgumentCaptor<Message> receiveMessageCaptor = ArgumentCaptor.forClass(Message.class);
verify(this.stompSession).handleMessage(receiveMessageCaptor.capture());
Message<byte[]> receiveMessage = receiveMessageCaptor.getValue();
assertThat(receiveMessage).isNotNull();
StompHeaderAccessor headers = StompHeaderAccessor.wrap(receiveMessage);
assertThat(headers.toNativeHeaderMap()).hasSize(2);
assertThat(headers.getContentLength()).isEqualTo(7);
assertThat(headers.getDestination()).isEqualTo("/topic/foo");
assertThat(new String(receiveMessage.getPayload())).isEqualTo("payload");
}
@Test
void heartbeatDefaultValue() {
WebSocketStompClient stompClient = new WebSocketStompClient(mock());