diff --git a/spring-messaging/src/main/java/org/springframework/messaging/core/GenericMessagingTemplate.java b/spring-messaging/src/main/java/org/springframework/messaging/core/GenericMessagingTemplate.java index d270c416112..d8d4a9b0214 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/core/GenericMessagingTemplate.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/core/GenericMessagingTemplate.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -40,36 +40,50 @@ import org.springframework.util.Assert; * * @author Mark Fisher * @author Rossen Stoyanchev + * @author Gary Russell * @since 4.0 */ public class GenericMessagingTemplate extends AbstractDestinationResolvingMessagingTemplate implements BeanFactoryAware { + public static final String DEFAULT_SEND_TIMEOUT_HEADER = "sendTimeout"; + + public static final String DEFAULT_RECEIVE_TIMEOUT_HEADER = "receiveTimeout"; + private volatile long sendTimeout = -1; private volatile long receiveTimeout = -1; + private String sendTimeoutHeader = DEFAULT_SEND_TIMEOUT_HEADER; + + private String receiveTimeoutHeader = DEFAULT_RECEIVE_TIMEOUT_HEADER; + private volatile boolean throwExceptionOnLateReply = false; /** - * Configure the timeout value to use for send operations. + * Configure the default timeout value to use for send operations. + * May be overridden for individual messages. * @param sendTimeout the send timeout in milliseconds + * @see #setSendTimeoutHeader(String) */ public void setSendTimeout(long sendTimeout) { this.sendTimeout = sendTimeout; } /** - * Return the configured send operation timeout value. + * Return the configured default send operation timeout value. */ public long getSendTimeout() { return this.sendTimeout; } /** - * Configure the timeout value to use for receive operations. + * Configure the default timeout value to use for receive operations. + * May be overridden for individual messages when using sendAndReceive + * operations. * @param receiveTimeout the receive timeout in milliseconds + * @see #setReceiveTimeoutHeader(String) */ public void setReceiveTimeout(long receiveTimeout) { this.receiveTimeout = receiveTimeout; @@ -82,6 +96,47 @@ public class GenericMessagingTemplate extends AbstractDestinationResolvingMessag return this.receiveTimeout; } + + /** + * Set the name of the header used to determine the send timeout (if present). + * Default {@value #DEFAULT_SEND_TIMEOUT_HEADER}. + * The header is removed before sending the message to avoid propagation. + * @param sendTimeoutHeader the sendTimeoutHeader to set + * @since 5.0 + */ + public void setSendTimeoutHeader(String sendTimeoutHeader) { + Assert.notNull(sendTimeoutHeader, "'sendTimeoutHeader' cannot be null"); + this.sendTimeoutHeader = sendTimeoutHeader; + } + + /** + * @return the configured sendTimeoutHeader. + * @since 5.0 + */ + public String getSendTimeoutHeader() { + return sendTimeoutHeader; + } + + /** + * Set the name of the header used to determine the send timeout (if present). + * Default {@value #DEFAULT_RECEIVE_TIMEOUT_HEADER}. + * The header is removed before sending the message to avoid propagation. + * @param receiveTimeoutHeader the receiveTimeoutHeader to set + * @since 5.0 + */ + public void setReceiveTimeoutHeader(String receiveTimeoutHeader) { + Assert.notNull(receiveTimeoutHeader, "'receiveTimeoutHeader' cannot be null"); + this.receiveTimeoutHeader = receiveTimeoutHeader; + } + + /** + * @return the configured receiveTimeoutHeader + * @since 5.0 + */ + public String getReceiveTimeoutHeader() { + return receiveTimeoutHeader; + } + /** * Whether the thread sending a reply should have an exception raised if the * receiving thread isn't going to receive the reply either because it timed out, @@ -101,18 +156,30 @@ public class GenericMessagingTemplate extends AbstractDestinationResolvingMessag setDestinationResolver(new BeanFactoryMessageChannelDestinationResolver(beanFactory)); } - @Override protected final void doSend(MessageChannel channel, Message message) { + doSend(channel, message, sendTimeout(message)); + } + + protected final void doSend(MessageChannel channel, Message message, long timeout) { Assert.notNull(channel, "MessageChannel is required"); + Message messageToSend = message; MessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class); if (accessor != null && accessor.isMutable()) { + accessor.removeHeader(this.sendTimeoutHeader); + accessor.removeHeader(this.receiveTimeoutHeader); accessor.setImmutable(); } + else if (message.getHeaders().containsKey(this.sendTimeoutHeader) + || message.getHeaders().containsKey(this.receiveTimeoutHeader)) { + messageToSend = MessageBuilder.fromMessage(message) + .setHeader(this.sendTimeoutHeader, null) + .setHeader(this.receiveTimeoutHeader, null) + .build(); + } - long timeout = this.sendTimeout; - boolean sent = (timeout >= 0 ? channel.send(message, timeout) : channel.send(message)); + boolean sent = (timeout >= 0 ? channel.send(messageToSend, timeout) : channel.send(messageToSend)); if (!sent) { throw new MessageDeliveryException(message, @@ -122,10 +189,13 @@ public class GenericMessagingTemplate extends AbstractDestinationResolvingMessag @Override protected final Message doReceive(MessageChannel channel) { + return doReceive(channel, this.receiveTimeout); + } + + protected final Message doReceive(MessageChannel channel, long timeout) { Assert.notNull(channel, "MessageChannel is required"); Assert.state(channel instanceof PollableChannel, "A PollableChannel is required to receive messages"); - long timeout = this.receiveTimeout; Message message = (timeout >= 0 ? ((PollableChannel) channel).receive(timeout) : ((PollableChannel) channel).receive()); @@ -142,20 +212,25 @@ public class GenericMessagingTemplate extends AbstractDestinationResolvingMessag Object originalReplyChannelHeader = requestMessage.getHeaders().getReplyChannel(); Object originalErrorChannelHeader = requestMessage.getHeaders().getErrorChannel(); - TemporaryReplyChannel tempReplyChannel = new TemporaryReplyChannel(); - requestMessage = MessageBuilder.fromMessage(requestMessage).setReplyChannel(tempReplyChannel). - setErrorChannel(tempReplyChannel).build(); + long sendTimeout = sendTimeout(requestMessage); + long receiveTimeout = receiveTimeout(requestMessage); + + TemporaryReplyChannel tempReplyChannel = new TemporaryReplyChannel(this.throwExceptionOnLateReply); + requestMessage = MessageBuilder.fromMessage(requestMessage).setReplyChannel(tempReplyChannel) + .setHeader(this.sendTimeoutHeader, null) + .setHeader(this.receiveTimeoutHeader, null) + .setErrorChannel(tempReplyChannel).build(); try { - doSend(channel, requestMessage); + doSend(channel, requestMessage, sendTimeout); } catch (RuntimeException ex) { tempReplyChannel.setSendFailed(true); throw ex; } - Message replyMessage = this.doReceive(tempReplyChannel); - if (replyMessage != null) { + Message replyMessage = this.doReceive(tempReplyChannel, receiveTimeout); + if (replyMessage != null && (originalReplyChannelHeader!= null || originalErrorChannelHeader != null)) { replyMessage = MessageBuilder.fromMessage(replyMessage) .setHeader(MessageHeaders.REPLY_CHANNEL, originalReplyChannelHeader) .setHeader(MessageHeaders.ERROR_CHANNEL, originalErrorChannelHeader) @@ -165,16 +240,39 @@ public class GenericMessagingTemplate extends AbstractDestinationResolvingMessag return replyMessage; } + private long sendTimeout(Message requestMessage) { + Long sendTimeout = headerToLong(requestMessage.getHeaders().get(this.sendTimeoutHeader)); + return sendTimeout == null ? this.sendTimeout : sendTimeout; + } + + private long receiveTimeout(Message requestMessage) { + Long receiveTimeout = headerToLong(requestMessage.getHeaders().get(this.receiveTimeoutHeader)); + return receiveTimeout == null ? this.receiveTimeout : receiveTimeout; + } + + private Long headerToLong(Object headerValue) { + if (headerValue instanceof Number) { + return ((Number) headerValue).longValue(); + } + else if(headerValue instanceof String) { + return Long.parseLong((String) headerValue); + } + else { + return null; + } + } /** * A temporary channel for receiving a single reply message. */ - private class TemporaryReplyChannel implements PollableChannel { + private static final class TemporaryReplyChannel implements PollableChannel { private final Log logger = LogFactory.getLog(TemporaryReplyChannel.class); private final CountDownLatch replyLatch = new CountDownLatch(1); + private final boolean throwExceptionOnLateReply; + private volatile Message replyMessage; private volatile boolean hasReceived; @@ -183,6 +281,10 @@ public class GenericMessagingTemplate extends AbstractDestinationResolvingMessag private volatile boolean hasSendFailed; + TemporaryReplyChannel(boolean throwExceptionOnLateReply) { + this.throwExceptionOnLateReply = throwExceptionOnLateReply; + } + public void setSendFailed(boolean hasSendError) { this.hasSendFailed = hasSendError; } @@ -195,12 +297,12 @@ public class GenericMessagingTemplate extends AbstractDestinationResolvingMessag @Override public Message receive(long timeout) { try { - if (GenericMessagingTemplate.this.receiveTimeout < 0) { + if (timeout < 0) { this.replyLatch.await(); this.hasReceived = true; } else { - if (this.replyLatch.await(GenericMessagingTemplate.this.receiveTimeout, TimeUnit.MILLISECONDS)) { + if (this.replyLatch.await(timeout, TimeUnit.MILLISECONDS)) { this.hasReceived = true; } else { @@ -241,7 +343,7 @@ public class GenericMessagingTemplate extends AbstractDestinationResolvingMessag if (logger.isWarnEnabled()) { logger.warn(errorDescription + ":" + message); } - if (GenericMessagingTemplate.this.throwExceptionOnLateReply) { + if (this.throwExceptionOnLateReply) { throw new MessageDeliveryException(message, errorDescription); } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/core/GenericMessagingTemplateTests.java b/spring-messaging/src/test/java/org/springframework/messaging/core/GenericMessagingTemplateTests.java index 8152e395a56..9cf5ec64e54 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/core/GenericMessagingTemplateTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/core/GenericMessagingTemplateTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,7 +23,6 @@ import java.util.concurrent.atomic.AtomicReference; import org.junit.Before; import org.junit.Test; - import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; import org.springframework.messaging.MessageDeliveryException; @@ -34,15 +33,19 @@ import org.springframework.messaging.StubMessageChannel; import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.support.ExecutorSubscribableChannel; import org.springframework.messaging.support.GenericMessage; +import org.springframework.messaging.support.MessageBuilder; import org.springframework.messaging.support.MessageHeaderAccessor; import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; import static org.junit.Assert.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; /** * Unit tests for {@link GenericMessagingTemplate}. * * @author Rossen Stoyanchev + * @author Gary Russell */ public class GenericMessagingTemplateTests { @@ -63,6 +66,43 @@ public class GenericMessagingTemplateTests { this.executor.afterPropertiesSet(); } + @Test + public void sendWithTimeout() { + SubscribableChannel channel = mock(SubscribableChannel.class); + final AtomicReference> sent = new AtomicReference<>(); + doAnswer(invocation -> { + sent.set(invocation.getArgument(0)); + return true; + }).when(channel).send(any(Message.class), eq(30_000L)); + Message message = MessageBuilder.withPayload("request") + .setHeader(GenericMessagingTemplate.DEFAULT_SEND_TIMEOUT_HEADER, 30_000L) + .setHeader(GenericMessagingTemplate.DEFAULT_RECEIVE_TIMEOUT_HEADER, 1L) + .build(); + this.template.send(channel, message); + verify(channel).send(any(Message.class), eq(30_000L)); + assertNotNull(sent.get()); + assertFalse(sent.get().getHeaders().containsKey(GenericMessagingTemplate.DEFAULT_SEND_TIMEOUT_HEADER)); + assertFalse(sent.get().getHeaders().containsKey(GenericMessagingTemplate.DEFAULT_RECEIVE_TIMEOUT_HEADER)); + } + + @Test + public void sendWithTimeoutMutable() { + SubscribableChannel channel = mock(SubscribableChannel.class); + final AtomicReference> sent = new AtomicReference<>(); + doAnswer(invocation -> { + sent.set(invocation.getArgument(0)); + return true; + }).when(channel).send(any(Message.class), eq(30_000L)); + MessageHeaderAccessor accessor = new MessageHeaderAccessor(); + accessor.setLeaveMutable(true); + Message message = new GenericMessage<>("request", accessor.getMessageHeaders()); + accessor.setHeader(GenericMessagingTemplate.DEFAULT_SEND_TIMEOUT_HEADER, 30_000L); + this.template.send(channel, message); + verify(channel).send(any(Message.class), eq(30_000L)); + assertNotNull(sent.get()); + assertFalse(sent.get().getHeaders().containsKey(GenericMessagingTemplate.DEFAULT_SEND_TIMEOUT_HEADER)); + assertFalse(sent.get().getHeaders().containsKey(GenericMessagingTemplate.DEFAULT_RECEIVE_TIMEOUT_HEADER)); + } @Test public void sendAndReceive() { @@ -85,41 +125,118 @@ public class GenericMessagingTemplateTests { final CountDownLatch latch = new CountDownLatch(1); this.template.setReceiveTimeout(1); + this.template.setSendTimeout(30_000L); this.template.setThrowExceptionOnLateReply(true); - SubscribableChannel channel = new ExecutorSubscribableChannel(this.executor); - channel.subscribe(new MessageHandler() { - @Override - public void handleMessage(Message message) throws MessagingException { - try { - Thread.sleep(500); - MessageChannel replyChannel = (MessageChannel) message.getHeaders().getReplyChannel(); - replyChannel.send(new GenericMessage<>("response")); - failure.set(new IllegalStateException("Expected exception")); - } - catch (InterruptedException e) { - failure.set(e); - } - catch (MessageDeliveryException ex) { - String expected = "Reply message received but the receiving thread has exited due to a timeout"; - String actual = ex.getMessage(); - if (!expected.equals(actual)) { - failure.set(new IllegalStateException("Unexpected error: '" + actual + "'")); - } - } - finally { - latch.countDown(); - } - } - }); + SubscribableChannel channel = mock(SubscribableChannel.class); + MessageHandler handler = createLateReplier(latch, failure); + doAnswer(invocation -> { + this.executor.execute(() -> { + handler.handleMessage(invocation.getArgument(0)); + }); + return true; + }).when(channel).send(any(Message.class), anyLong()); assertNull(this.template.convertSendAndReceive(channel, "request", String.class)); - assertTrue(latch.await(1000, TimeUnit.MILLISECONDS)); + assertTrue(latch.await(10_000, TimeUnit.MILLISECONDS)); Throwable ex = failure.get(); if (ex != null) { throw new AssertionError(ex); } + verify(channel).send(any(Message.class), eq(30_000L)); + } + + @Test + public void sendAndReceiveVariableTimeout() throws InterruptedException { + final AtomicReference failure = new AtomicReference(); + final CountDownLatch latch = new CountDownLatch(1); + + this.template.setSendTimeout(20_000); + this.template.setReceiveTimeout(10_000); + this.template.setThrowExceptionOnLateReply(true); + + SubscribableChannel channel = mock(SubscribableChannel.class); + MessageHandler handler = createLateReplier(latch, failure); + doAnswer(invocation -> { + this.executor.execute(() -> { + handler.handleMessage(invocation.getArgument(0)); + }); + return true; + }).when(channel).send(any(Message.class), anyLong()); + + Message message = MessageBuilder.withPayload("request") + .setHeader(GenericMessagingTemplate.DEFAULT_SEND_TIMEOUT_HEADER, 30_000L) + .setHeader(GenericMessagingTemplate.DEFAULT_RECEIVE_TIMEOUT_HEADER, 1L) + .build(); + assertNull(this.template.sendAndReceive(channel, message)); + assertTrue(latch.await(10_000, TimeUnit.MILLISECONDS)); + + Throwable ex = failure.get(); + if (ex != null) { + throw new AssertionError(ex); + } + verify(channel).send(any(Message.class), eq(30_000L)); + } + + @Test + public void sendAndReceiveVariableTimeoutCustomHeaders() throws InterruptedException { + final AtomicReference failure = new AtomicReference(); + final CountDownLatch latch = new CountDownLatch(1); + + this.template.setSendTimeout(20_000); + this.template.setReceiveTimeout(10_000); + this.template.setThrowExceptionOnLateReply(true); + this.template.setSendTimeoutHeader("sto"); + this.template.setReceiveTimeoutHeader("rto"); + + SubscribableChannel channel = mock(SubscribableChannel.class); + MessageHandler handler = createLateReplier(latch, failure); + doAnswer(invocation -> { + this.executor.execute(() -> { + handler.handleMessage(invocation.getArgument(0)); + }); + return true; + }).when(channel).send(any(Message.class), anyLong()); + + Message message = MessageBuilder.withPayload("request") + .setHeader("sto", 30_000L) + .setHeader("rto", 1L) + .build(); + assertNull(this.template.sendAndReceive(channel, message)); + assertTrue(latch.await(10_000, TimeUnit.MILLISECONDS)); + + Throwable ex = failure.get(); + if (ex != null) { + throw new AssertionError(ex); + } + verify(channel).send(any(Message.class), eq(30_000L)); + } + + private MessageHandler createLateReplier(final CountDownLatch latch, final AtomicReference failure) { + MessageHandler handler = message -> { + try { + Thread.sleep(500); + MessageChannel replyChannel = (MessageChannel) message.getHeaders().getReplyChannel(); + replyChannel.send(new GenericMessage<>("response")); + failure.set(new IllegalStateException("Expected exception")); + } + catch (InterruptedException e) { + failure.set(e); + } + catch (MessageDeliveryException ex) { + String expected = "Reply message received but the receiving thread has exited due to a timeout"; + String actual = ex.getMessage(); + if (!expected.equals(actual)) { + failure.set(new IllegalStateException( + "Unexpected error: '" + actual + "'")); + } + } + finally { + latch.countDown(); + } + }; + return handler; } @Test