Remove custom handling of byte[] in DefaultStompSession

Closes gh-23358
This commit is contained in:
Rossen Stoyanchev 2019-07-30 10:52:13 +01:00
parent 50a909908c
commit 1d92755cc7
2 changed files with 50 additions and 10 deletions

View File

@ -256,12 +256,9 @@ public class DefaultStompSession implements ConnectionHandlingStompSession {
private Message<byte[]> createMessage(StompHeaderAccessor accessor, @Nullable Object payload) {
accessor.updateSimpMessageHeadersFromStompHeaders();
Message<byte[]> message;
if (payload == null) {
if (isEmpty(payload)) {
message = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders());
}
else if (payload instanceof byte[]) {
message = MessageBuilder.createMessage((byte[]) payload, accessor.getMessageHeaders());
}
else {
message = (Message<byte[]>) getMessageConverter().toMessage(payload, accessor.getMessageHeaders());
accessor.updateStompHeadersFromSimpMessageHeaders();
@ -274,6 +271,11 @@ public class DefaultStompSession implements ConnectionHandlingStompSession {
return message;
}
private boolean isEmpty(@Nullable Object payload) {
return payload == null || StringUtils.isEmpty(payload) ||
(payload instanceof byte[] && ((byte[]) payload).length == 0);
}
private void execute(Message<byte[]> message) {
if (logger.isTraceEnabled()) {
StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -17,6 +17,7 @@
package org.springframework.messaging.simp.stomp;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Date;
import java.util.Map;
import java.util.concurrent.ScheduledFuture;
@ -34,6 +35,8 @@ import org.mockito.MockitoAnnotations;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageDeliveryException;
import org.springframework.messaging.converter.ByteArrayMessageConverter;
import org.springframework.messaging.converter.CompositeMessageConverter;
import org.springframework.messaging.converter.MessageConversionException;
import org.springframework.messaging.converter.StringMessageConverter;
import org.springframework.messaging.simp.stomp.StompSession.Receiptable;
@ -46,10 +49,23 @@ import org.springframework.util.MimeType;
import org.springframework.util.MimeTypeUtils;
import org.springframework.util.concurrent.SettableListenableFuture;
import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.*;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.notNull;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.same;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
/**
* Unit tests for {@link DefaultStompSession}.
@ -82,7 +98,9 @@ public class DefaultStompSessionTests {
this.sessionHandler = mock(StompSessionHandler.class);
this.connectHeaders = new StompHeaders();
this.session = new DefaultStompSession(this.sessionHandler, this.connectHeaders);
this.session.setMessageConverter(new StringMessageConverter());
this.session.setMessageConverter(
new CompositeMessageConverter(
Arrays.asList(new StringMessageConverter(), new ByteArrayMessageConverter())));
SettableListenableFuture<Void> future = new SettableListenableFuture<>();
future.set(null);
@ -110,7 +128,7 @@ public class DefaultStompSessionTests {
@Test // SPR-16844
public void afterConnectedWithSpecificVersion() {
assertFalse(this.session.isConnected());
this.connectHeaders.setAcceptVersion(new String[] {"1.1"});
this.connectHeaders.setAcceptVersion("1.1");
this.session.afterConnected(this.connection);
@ -388,6 +406,26 @@ public class DefaultStompSessionTests {
assertEquals("my-receipt", accessor.getReceipt());
}
@Test // gh-23358
public void sendByteArray() {
this.session.afterConnected(this.connection);
assertTrue(this.session.isConnected());
String destination = "/topic/foo";
String payload = "sample payload";
this.session.send(destination, payload.getBytes(StandardCharsets.UTF_8));
Message<byte[]> message = this.messageCaptor.getValue();
StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
StompHeaders stompHeaders = StompHeaders.readOnlyStompHeaders(accessor.getNativeHeaders());
assertEquals(stompHeaders.toString(), 2, stompHeaders.size());
assertEquals(destination, stompHeaders.getDestination());
assertEquals(MimeTypeUtils.APPLICATION_OCTET_STREAM, stompHeaders.getContentType());
assertEquals(payload, new String(message.getPayload(), StandardCharsets.UTF_8));
}
@Test
public void sendWithConversionException() {
this.session.afterConnected(this.connection);