diff --git a/deps/rabbitmq_stomp/src/rabbit_stomp_processor.erl b/deps/rabbitmq_stomp/src/rabbit_stomp_processor.erl index a361ff83c9..1f36c0346b 100644 --- a/deps/rabbitmq_stomp/src/rabbit_stomp_processor.erl +++ b/deps/rabbitmq_stomp/src/rabbit_stomp_processor.erl @@ -134,7 +134,7 @@ handle_frame("DISCONNECT", _Frame, State) -> handle_frame("SUBSCRIBE", Frame, State) -> with_destination("SUBSCRIBE", Frame, State, fun do_subscribe/4); -handle_frame("UNSUBSCRIBE", Frame, State = #state{subscriptions = Subs}) -> +handle_frame("UNSUBSCRIBE", Frame, State) -> ConsumerTag = case rabbit_stomp_frame:header(Frame, "id") of {ok, IdStr} -> list_to_binary("T_" ++ IdStr); @@ -147,17 +147,7 @@ handle_frame("UNSUBSCRIBE", Frame, State = #state{subscriptions = Subs}) -> missing end end, - if - ConsumerTag == missing -> - error("Missing destination or id", - "UNSUBSCRIBE must include a 'destination' or 'id' header\n", - State); - true -> - ok(send_method(#'basic.cancel'{consumer_tag = ConsumerTag, - nowait = true}, - State#state{subscriptions = - dict:erase(ConsumerTag, Subs)})) - end; + cancel_subscription(ConsumerTag, State); handle_frame("SEND", Frame, State) -> with_destination("SEND", Frame, State, fun do_send/4); @@ -211,6 +201,25 @@ handle_frame(Command, _Frame, State) -> %% Internal helpers for processing frames callbacks %%---------------------------------------------------------------------------- +cancel_subscription(missing, State) -> + error("Missing destination or id", + "UNSUBSCRIBE must include a 'destination' or 'id' header\n", + State); + +cancel_subscription(ConsumerTag, State = #state{subscriptions = Subs}) -> + case dict:find(ConsumerTag, Subs) of + error -> + error("No subscription found", + "UNSUBSCRIBE must refer to an existing subscription\n", + State); + {ok, {_DestHdr, Channel}} -> + ok(send_method(#'basic.cancel'{consumer_tag = ConsumerTag, + nowait = true}, + Channel, + State#state{subscriptions = + dict:erase(ConsumerTag, Subs)})) + end. + with_destination(Command, Frame, State, Fun) -> case rabbit_stomp_frame:header(Frame, "destination") of {ok, DestHdr} -> @@ -340,10 +349,13 @@ send_delivery(Delivery = #'basic.deliver'{consumer_tag = ConsumerTag}, State) end. -send_method(Method, State = #state{channel = Channel}) -> +send_method(Method, Channel, State) -> amqp_channel:call(Channel, Method), State. +send_method(Method, State = #state{channel = Channel}) -> + send_method(Method, Channel, State). + send_method(Method, Properties, BodyFragments, State = #state{channel = Channel}) -> amqp_channel:call(Channel, Method, #amqp_msg{ @@ -519,10 +531,10 @@ priv_error(Message, Detail, ServerPrivateDetail, State) -> {error, Message, Detail, State}. priv_error(Message, Format, Args, ServerPrivateDetail, State) -> - priv_error(Message, format_detail(Format, Args), + priv_error(Message, format_message(Format, Args), ServerPrivateDetail, State). -format_detail(Format, Args) -> +format_message(Format, Args) -> lists:flatten(io_lib:format(Format, Args)). %%---------------------------------------------------------------------------- %% Frame sending utilities @@ -554,7 +566,7 @@ send_error(Message, Detail, State) -> {"content-type", "text/plain"}], Detail, State). send_error(Message, Format, Args, State) -> - send_error(Message, format_detail(Format, Args), State). + send_error(Message, format_message(Format, Args), State). %%---------------------------------------------------------------------------- %% Skeleton gen_server callbacks diff --git a/deps/rabbitmq_stomp/test/base.py b/deps/rabbitmq_stomp/test/base.py index 2a1c8deb77..9db861c651 100644 --- a/deps/rabbitmq_stomp/test/base.py +++ b/deps/rabbitmq_stomp/test/base.py @@ -48,28 +48,55 @@ class BaseTest(unittest.TestCase): self.assertEquals("foo", msg['message']) self.assertEquals(dest, msg['headers']['destination']) + def assertListener(self, errMsg, numMsgs=0, numErrs=0, numRcts=0, timeout=3): + if numMsgs + numErrs + numRcts > 0: + self.assertTrue(self.listener.await(timeout), errMsg + " (#awaiting)") + else: + self.assertFalse(self.listener.await(timeout), errMsg + " (#awaiting)") + self.assertEquals(numMsgs, len(self.listener.messages), errMsg + " (#messages)") + self.assertEquals(numErrs, len(self.listener.errors), errMsg + " (#errors)") + self.assertEquals(numRcts, len(self.listener.receipts), errMsg + " (#receipts)") + + def assertListenerAfter(self, verb, errMsg="", numMsgs=0, numErrs=0, numRcts=0, timeout=3): + num = numMsgs + numErrs + numRcts + self.listener.reset(num if num>0 else 1) + verb() + self.assertListener(errMsg=errMsg, numMsgs=numMsgs, numErrs=numErrs, numRcts=numRcts, timeout=timeout) class WaitableListener(object): def __init__(self): + self.debug = False + if self.debug: + print '(listener) init' self.messages = [] self.errors = [] self.receipts = [] self.latch = Latch(1) def on_receipt(self, headers, message): + if self.debug: + print '(on_message) message:', message, 'headers:', headers self.receipts.append({'message' : message, 'headers' : headers}) self.latch.countdown() def on_error(self, headers, message): + if self.debug: + print '(on_message) message:', message, 'headers:', headers self.errors.append({'message' : message, 'headers' : headers}) self.latch.countdown() def on_message(self, headers, message): + if self.debug: + print '(on_message) message:', message, 'headers:', headers self.messages.append({'message' : message, 'headers' : headers}) self.latch.countdown() def reset(self,count=1): + if self.debug: + print '(reset listener) #messages:', len(self.messages), + print '#errors', len(self.errors), + print '#receipts', len(self.receipts), 'Now expecting:', count self.messages = [] self.errors = [] self.receipts = [] diff --git a/deps/rabbitmq_stomp/test/lifecycle.py b/deps/rabbitmq_stomp/test/lifecycle.py index 5fe32e9e43..2dfbf782d7 100644 --- a/deps/rabbitmq_stomp/test/lifecycle.py +++ b/deps/rabbitmq_stomp/test/lifecycle.py @@ -5,45 +5,46 @@ import time class TestLifecycle(base.BaseTest): - def test_unsubscribe_destination(self): - d = "/exchange/amq.fanout" + def test_unsubscribe_exchange_destination(self): + ''' Test UNSUBSCRIBE command with exchange''' + self.unsub_test(self.sub_and_send("/exchange/amq.fanout")) - # subscribe and send message - self.listener.reset() - self.conn.subscribe(destination=d) - self.conn.send("test", destination=d) - self.assertTrue(self.listener.await()) - self.assertEquals(1, len(self.listener.messages)) - self.assertEquals(0, len(self.listener.errors)) + def test_unsubscribe_queue_destination(self): + ''' Test UNSUBSCRIBE command with queue''' + self.unsub_test(self.sub_and_send("/queue/unsub01")) - # unsubscribe and send now - self.listener.reset() - self.conn.unsubscribe(destination=d) - self.conn.send("test", destination=d) - self.assertFalse(self.listener.await(3), - "UNSUBSCRIBE failed, still receiving messages") + def test_unsubscribe_exchange_id(self): + ''' Test UNSUBSCRIBE command with exchange by id''' + self.unsub_test(self.subid_and_send("/exchange/amq.fanout", "exchid")) - def test_unsubscribe_id(self): - ''' Test UNSUBSCRIBE command with id parameter''' - d = "/exchange/amq.fanout" - - # subscribe and send message - self.listener.reset() - self.conn.subscribe(destination=d, id="test") - self.conn.send("test", destination=d) - self.assertTrue(self.listener.await()) - self.assertEquals(1, len(self.listener.messages)) - self.assertEquals(0, len(self.listener.errors)) - - # unsubscribe and send now - self.listener.reset() - self.conn.unsubscribe(id="test") - self.conn.send("test", destination=d) - self.assertFalse(self.listener.await(3), - "UNSUBSCRIBE failed, still receiving messages") + def test_unsubscribe_queue_id(self): + ''' Test UNSUBSCRIBE command with queue by id''' + self.unsub_test(self.subid_and_send("/queue/unsub02", "queid")) def test_disconnect(self): ''' Run DISCONNECT command ''' self.conn.disconnect() self.assertFalse(self.conn.is_connected()) + def unsub_test(self, verbs): + subverb, unsubverb = verbs + self.assertListenerAfter(subverb, + numMsgs=1, errMsg="FAILED to subscribe and send") + self.assertListenerAfter(unsubverb, + errMsg="Still receiving messages") + + def subid_and_send(self, dest, subid): + def subfun(): + self.conn.subscribe(destination=dest, id=subid) + self.conn.send("test", destination=dest) + def unsubfun(): + self.conn.unsubscribe(id=subid) + return subfun, unsubfun + + def sub_and_send(self, dest): + def subfun(): + self.conn.subscribe(destination=dest) + self.conn.send("test", destination=dest) + def unsubfun(): + self.conn.unsubscribe(destination=dest) + return subfun, unsubfun