diff --git a/spring-messaging/src/main/java/org/springframework/messaging/support/MessageBuilder.java b/spring-messaging/src/main/java/org/springframework/messaging/support/MessageBuilder.java index e722bafc8a..64d75f2d0f 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/support/MessageBuilder.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/support/MessageBuilder.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -41,23 +41,23 @@ public final class MessageBuilder { private final T payload; @Nullable - private final Message originalMessage; + private final Message providedMessage; private MessageHeaderAccessor headerAccessor; - private MessageBuilder(Message originalMessage) { - Assert.notNull(originalMessage, "Message must not be null"); - this.payload = originalMessage.getPayload(); - this.originalMessage = originalMessage; - this.headerAccessor = new MessageHeaderAccessor(originalMessage); + private MessageBuilder(Message providedMessage) { + Assert.notNull(providedMessage, "Message must not be null"); + this.payload = providedMessage.getPayload(); + this.providedMessage = providedMessage; + this.headerAccessor = new MessageHeaderAccessor(providedMessage); } private MessageBuilder(T payload, MessageHeaderAccessor accessor) { Assert.notNull(payload, "Payload must not be null"); Assert.notNull(accessor, "MessageHeaderAccessor must not be null"); this.payload = payload; - this.originalMessage = null; + this.providedMessage = null; this.headerAccessor = accessor; } @@ -99,6 +99,7 @@ public final class MessageBuilder { this.headerAccessor.removeHeaders(headerPatterns); return this; } + /** * Remove the value for the given header name. */ @@ -148,11 +149,17 @@ public final class MessageBuilder { @SuppressWarnings("unchecked") public Message build() { - if (this.originalMessage != null && !this.headerAccessor.isModified()) { - return this.originalMessage; + if (this.providedMessage != null && !this.headerAccessor.isModified()) { + return this.providedMessage; } MessageHeaders headersToUse = this.headerAccessor.toMessageHeaders(); if (this.payload instanceof Throwable) { + if (this.providedMessage != null && this.providedMessage instanceof ErrorMessage) { + Message message = ((ErrorMessage) this.providedMessage).getOriginalMessage(); + if (message != null) { + return (Message) new ErrorMessage((Throwable) this.payload, headersToUse, message); + } + } return (Message) new ErrorMessage((Throwable) this.payload, headersToUse); } else { @@ -165,6 +172,9 @@ public final class MessageBuilder { * Create a builder for a new {@link Message} instance pre-populated with all of the * headers copied from the provided message. The payload of the provided Message will * also be used as the payload for the new message. + *

If the provided message is an {@link ErrorMessage}, the + * {@link ErrorMessage#getOriginalMessage() originalMessage} it contains, will be + * passed on to new instance. * @param message the Message from which the payload and all headers will be copied */ public static MessageBuilder fromMessage(Message message) { diff --git a/spring-messaging/src/test/java/org/springframework/messaging/support/MessageBuilderTests.java b/spring-messaging/src/test/java/org/springframework/messaging/support/MessageBuilderTests.java index e0a0b8ccb0..c1e2df81ea 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/support/MessageBuilderTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/support/MessageBuilderTests.java @@ -16,6 +16,7 @@ package org.springframework.messaging.support; +import java.util.Collections; import java.util.Date; import java.util.HashMap; import java.util.Map; @@ -107,6 +108,19 @@ public class MessageBuilderTests { assertThat(message2.getHeaders().get("foo")).isEqualTo("bar"); } + @Test // gh-23417 + public void createErrorMessageFromErrorMessage() { + Message source = MessageBuilder.withPayload("test").setHeader("foo", "bar").build(); + RuntimeException ex = new RuntimeException(); + ErrorMessage errorMessage1 = new ErrorMessage(ex, Collections.singletonMap("baz", "42"), source); + Message errorMessage2 = MessageBuilder.fromMessage(errorMessage1).build(); + assertThat(errorMessage2).isExactlyInstanceOf(ErrorMessage.class); + ErrorMessage actual = (ErrorMessage) errorMessage2; + assertThat(actual.getPayload()).isSameAs(ex); + assertThat(actual.getHeaders().get("baz")).isEqualTo("42"); + assertThat(actual.getOriginalMessage()).isSameAs(source); + } + @Test public void createIdRegenerated() { Message message1 = MessageBuilder.withPayload("test") @@ -119,20 +133,20 @@ public class MessageBuilderTests { @Test public void testRemove() { Message message1 = MessageBuilder.withPayload(1) - .setHeader("foo", "bar").build(); + .setHeader("foo", "bar").build(); Message message2 = MessageBuilder.fromMessage(message1) - .removeHeader("foo") - .build(); + .removeHeader("foo") + .build(); assertThat(message2.getHeaders().containsKey("foo")).isFalse(); } @Test public void testSettingToNullRemoves() { Message message1 = MessageBuilder.withPayload(1) - .setHeader("foo", "bar").build(); + .setHeader("foo", "bar").build(); Message message2 = MessageBuilder.fromMessage(message1) - .setHeader("foo", null) - .build(); + .setHeader("foo", null) + .build(); assertThat(message2.getHeaders().containsKey("foo")).isFalse(); } @@ -192,7 +206,7 @@ public class MessageBuilderTests { assertThatIllegalStateException().isThrownBy(() -> accessor.setHeader("foo", "bar")) - .withMessageContaining("Already immutable"); + .withMessageContaining("Already immutable"); assertThat(MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class)).isSameAs(accessor); }