Refactored lifecyclke tests

This commit is contained in:
Steve Powell 2011-01-17 17:55:27 +00:00
commit e09807c6c4
3 changed files with 89 additions and 49 deletions

View File

@ -134,7 +134,7 @@ handle_frame("DISCONNECT", _Frame, State) ->
handle_frame("SUBSCRIBE", Frame, State) ->
with_destination("SUBSCRIBE", Frame, State, fun do_subscribe/4);
handle_frame("UNSUBSCRIBE", Frame, State = #state{subscriptions = Subs}) ->
handle_frame("UNSUBSCRIBE", Frame, State) ->
ConsumerTag = case rabbit_stomp_frame:header(Frame, "id") of
{ok, IdStr} ->
list_to_binary("T_" ++ IdStr);
@ -147,17 +147,7 @@ handle_frame("UNSUBSCRIBE", Frame, State = #state{subscriptions = Subs}) ->
missing
end
end,
if
ConsumerTag == missing ->
error("Missing destination or id",
"UNSUBSCRIBE must include a 'destination' or 'id' header\n",
State);
true ->
ok(send_method(#'basic.cancel'{consumer_tag = ConsumerTag,
nowait = true},
State#state{subscriptions =
dict:erase(ConsumerTag, Subs)}))
end;
cancel_subscription(ConsumerTag, State);
handle_frame("SEND", Frame, State) ->
with_destination("SEND", Frame, State, fun do_send/4);
@ -211,6 +201,25 @@ handle_frame(Command, _Frame, State) ->
%% Internal helpers for processing frames callbacks
%%----------------------------------------------------------------------------
cancel_subscription(missing, State) ->
error("Missing destination or id",
"UNSUBSCRIBE must include a 'destination' or 'id' header\n",
State);
cancel_subscription(ConsumerTag, State = #state{subscriptions = Subs}) ->
case dict:find(ConsumerTag, Subs) of
error ->
error("No subscription found",
"UNSUBSCRIBE must refer to an existing subscription\n",
State);
{ok, {_DestHdr, Channel}} ->
ok(send_method(#'basic.cancel'{consumer_tag = ConsumerTag,
nowait = true},
Channel,
State#state{subscriptions =
dict:erase(ConsumerTag, Subs)}))
end.
with_destination(Command, Frame, State, Fun) ->
case rabbit_stomp_frame:header(Frame, "destination") of
{ok, DestHdr} ->
@ -340,10 +349,13 @@ send_delivery(Delivery = #'basic.deliver'{consumer_tag = ConsumerTag},
State)
end.
send_method(Method, State = #state{channel = Channel}) ->
send_method(Method, Channel, State) ->
amqp_channel:call(Channel, Method),
State.
send_method(Method, State = #state{channel = Channel}) ->
send_method(Method, Channel, State).
send_method(Method, Properties, BodyFragments,
State = #state{channel = Channel}) ->
amqp_channel:call(Channel, Method, #amqp_msg{
@ -519,10 +531,10 @@ priv_error(Message, Detail, ServerPrivateDetail, State) ->
{error, Message, Detail, State}.
priv_error(Message, Format, Args, ServerPrivateDetail, State) ->
priv_error(Message, format_detail(Format, Args),
priv_error(Message, format_message(Format, Args),
ServerPrivateDetail, State).
format_detail(Format, Args) ->
format_message(Format, Args) ->
lists:flatten(io_lib:format(Format, Args)).
%%----------------------------------------------------------------------------
%% Frame sending utilities
@ -554,7 +566,7 @@ send_error(Message, Detail, State) ->
{"content-type", "text/plain"}], Detail, State).
send_error(Message, Format, Args, State) ->
send_error(Message, format_detail(Format, Args), State).
send_error(Message, format_message(Format, Args), State).
%%----------------------------------------------------------------------------
%% Skeleton gen_server callbacks

View File

@ -48,28 +48,55 @@ class BaseTest(unittest.TestCase):
self.assertEquals("foo", msg['message'])
self.assertEquals(dest, msg['headers']['destination'])
def assertListener(self, errMsg, numMsgs=0, numErrs=0, numRcts=0, timeout=3):
if numMsgs + numErrs + numRcts > 0:
self.assertTrue(self.listener.await(timeout), errMsg + " (#awaiting)")
else:
self.assertFalse(self.listener.await(timeout), errMsg + " (#awaiting)")
self.assertEquals(numMsgs, len(self.listener.messages), errMsg + " (#messages)")
self.assertEquals(numErrs, len(self.listener.errors), errMsg + " (#errors)")
self.assertEquals(numRcts, len(self.listener.receipts), errMsg + " (#receipts)")
def assertListenerAfter(self, verb, errMsg="", numMsgs=0, numErrs=0, numRcts=0, timeout=3):
num = numMsgs + numErrs + numRcts
self.listener.reset(num if num>0 else 1)
verb()
self.assertListener(errMsg=errMsg, numMsgs=numMsgs, numErrs=numErrs, numRcts=numRcts, timeout=timeout)
class WaitableListener(object):
def __init__(self):
self.debug = False
if self.debug:
print '(listener) init'
self.messages = []
self.errors = []
self.receipts = []
self.latch = Latch(1)
def on_receipt(self, headers, message):
if self.debug:
print '(on_message) message:', message, 'headers:', headers
self.receipts.append({'message' : message, 'headers' : headers})
self.latch.countdown()
def on_error(self, headers, message):
if self.debug:
print '(on_message) message:', message, 'headers:', headers
self.errors.append({'message' : message, 'headers' : headers})
self.latch.countdown()
def on_message(self, headers, message):
if self.debug:
print '(on_message) message:', message, 'headers:', headers
self.messages.append({'message' : message, 'headers' : headers})
self.latch.countdown()
def reset(self,count=1):
if self.debug:
print '(reset listener) #messages:', len(self.messages),
print '#errors', len(self.errors),
print '#receipts', len(self.receipts), 'Now expecting:', count
self.messages = []
self.errors = []
self.receipts = []

View File

@ -5,45 +5,46 @@ import time
class TestLifecycle(base.BaseTest):
def test_unsubscribe_destination(self):
d = "/exchange/amq.fanout"
def test_unsubscribe_exchange_destination(self):
''' Test UNSUBSCRIBE command with exchange'''
self.unsub_test(self.sub_and_send("/exchange/amq.fanout"))
# subscribe and send message
self.listener.reset()
self.conn.subscribe(destination=d)
self.conn.send("test", destination=d)
self.assertTrue(self.listener.await())
self.assertEquals(1, len(self.listener.messages))
self.assertEquals(0, len(self.listener.errors))
def test_unsubscribe_queue_destination(self):
''' Test UNSUBSCRIBE command with queue'''
self.unsub_test(self.sub_and_send("/queue/unsub01"))
# unsubscribe and send now
self.listener.reset()
self.conn.unsubscribe(destination=d)
self.conn.send("test", destination=d)
self.assertFalse(self.listener.await(3),
"UNSUBSCRIBE failed, still receiving messages")
def test_unsubscribe_exchange_id(self):
''' Test UNSUBSCRIBE command with exchange by id'''
self.unsub_test(self.subid_and_send("/exchange/amq.fanout", "exchid"))
def test_unsubscribe_id(self):
''' Test UNSUBSCRIBE command with id parameter'''
d = "/exchange/amq.fanout"
# subscribe and send message
self.listener.reset()
self.conn.subscribe(destination=d, id="test")
self.conn.send("test", destination=d)
self.assertTrue(self.listener.await())
self.assertEquals(1, len(self.listener.messages))
self.assertEquals(0, len(self.listener.errors))
# unsubscribe and send now
self.listener.reset()
self.conn.unsubscribe(id="test")
self.conn.send("test", destination=d)
self.assertFalse(self.listener.await(3),
"UNSUBSCRIBE failed, still receiving messages")
def test_unsubscribe_queue_id(self):
''' Test UNSUBSCRIBE command with queue by id'''
self.unsub_test(self.subid_and_send("/queue/unsub02", "queid"))
def test_disconnect(self):
''' Run DISCONNECT command '''
self.conn.disconnect()
self.assertFalse(self.conn.is_connected())
def unsub_test(self, verbs):
subverb, unsubverb = verbs
self.assertListenerAfter(subverb,
numMsgs=1, errMsg="FAILED to subscribe and send")
self.assertListenerAfter(unsubverb,
errMsg="Still receiving messages")
def subid_and_send(self, dest, subid):
def subfun():
self.conn.subscribe(destination=dest, id=subid)
self.conn.send("test", destination=dest)
def unsubfun():
self.conn.unsubscribe(id=subid)
return subfun, unsubfun
def sub_and_send(self, dest):
def subfun():
self.conn.subscribe(destination=dest)
self.conn.send("test", destination=dest)
def unsubfun():
self.conn.unsubscribe(destination=dest)
return subfun, unsubfun