reworked tests and introduced /exchange and /queue destinations
This commit is contained in:
parent
f7d4308997
commit
3f9a113873
|
|
@ -1,5 +1,6 @@
|
|||
syntax: regexp
|
||||
\.beam$
|
||||
\.pyc$
|
||||
^dist/
|
||||
^build/
|
||||
^ebin/
|
||||
|
|
|
|||
|
|
@ -314,7 +314,8 @@ receipt_if_necessary(Frame, State) ->
|
|||
end.
|
||||
|
||||
send_method(Method, State = #state{channel = Channel}) ->
|
||||
amqp_channel:call(Channel, Method),
|
||||
Res = amqp_channel:call(Channel, Method),
|
||||
io:format("Res: ~p~n", [Res]),
|
||||
State.
|
||||
|
||||
send_method(Method, Properties, BodyFragments,
|
||||
|
|
@ -420,6 +421,9 @@ process_command("SEND",
|
|||
{ok, DestHeader} ->
|
||||
{ok, Destination} =
|
||||
rabbit_stomp_destination_parser:parse_destination(DestHeader),
|
||||
|
||||
{ok, _Q} = create_queue_if_needed(send, Destination, State),
|
||||
|
||||
Props = #'P_basic'{
|
||||
content_type = BinH("content-type", <<"text/plain">>),
|
||||
content_encoding = BinH("content-encoding", undefined),
|
||||
|
|
@ -430,12 +434,15 @@ process_command("SEND",
|
|||
message_id = BinH("amqp-message-id", undefined),
|
||||
headers = [longstr_field(K, V) ||
|
||||
{"X-" ++ K, V} <- Headers]},
|
||||
|
||||
{Exchange, RoutingKey} = parse_routing_information(Destination),
|
||||
|
||||
Method = #'basic.publish'{
|
||||
exchange = list_to_binary(Exchange),
|
||||
routing_key = list_to_binary(RoutingKey),
|
||||
mandatory = false,
|
||||
immediate = false},
|
||||
|
||||
case transactional(Frame) of
|
||||
{yes, Transaction} ->
|
||||
extend_transaction(Transaction, {Method, Props, BodyFragments},
|
||||
|
|
@ -496,7 +503,7 @@ process_command("SUBSCRIBE",
|
|||
{ok, Str} ->
|
||||
list_to_binary("T_" ++ Str);
|
||||
not_found ->
|
||||
list_to_binary("Q_" ++ Queue)
|
||||
list_to_binary("Q_" ++ DestHeader)
|
||||
end,
|
||||
|
||||
amqp_channel:subscribe(Channel,
|
||||
|
|
@ -506,7 +513,7 @@ process_command("SUBSCRIBE",
|
|||
no_local = false,
|
||||
no_ack = (AckMode == auto),
|
||||
exclusive = false},
|
||||
self()),
|
||||
self()),
|
||||
|
||||
ok = bind_queue_if_needed(subscribe, Queue, Destination, State),
|
||||
|
||||
|
|
@ -529,6 +536,7 @@ process_command("UNSUBSCRIBE", Frame, State = #state{subscriptions = Subs}) ->
|
|||
missing
|
||||
end
|
||||
end,
|
||||
io:format("~p~n", [ConsumerTag]),
|
||||
if
|
||||
ConsumerTag == missing ->
|
||||
{ok, send_error("Missing destination or id",
|
||||
|
|
@ -537,7 +545,7 @@ process_command("UNSUBSCRIBE", Frame, State = #state{subscriptions = Subs}) ->
|
|||
State)};
|
||||
true ->
|
||||
{ok, send_method(#'basic.cancel'{consumer_tag = ConsumerTag,
|
||||
nowait = true},
|
||||
nowait = true},
|
||||
State#state{subscriptions =
|
||||
dict:erase(ConsumerTag, Subs)})}
|
||||
end;
|
||||
|
|
@ -549,16 +557,30 @@ process_command(Command, _Frame, State) ->
|
|||
parse_routing_information({exchange, {Name, undefined}}) ->
|
||||
{Name, ""};
|
||||
parse_routing_information({exchange, {Name, Pattern}}) ->
|
||||
{Name, Pattern}.
|
||||
{Name, Pattern};
|
||||
parse_routing_information({queue, Name}) ->
|
||||
{"", Name}.
|
||||
|
||||
|
||||
create_queue_if_needed(subscribe, {exchange, _},
|
||||
State = #state{channel = Channel}) ->
|
||||
#'queue.declare_ok'{queue = Queue} =
|
||||
amqp_channel:call(Channel, #'queue.declare'{auto_delete = true}),
|
||||
{ok, Queue};
|
||||
create_queue_if_needed(send, {exchange, _}, State) ->
|
||||
{ok, undefined};
|
||||
create_queue_if_needed(_, {queue, Name},
|
||||
State = #state{channel = Channel}) ->
|
||||
Queue = list_to_binary(Name),
|
||||
#'queue.declare_ok'{queue = Queue} =
|
||||
amqp_channel:call(Channel,
|
||||
#'queue.declare'{durable = true,
|
||||
queue = Queue}),
|
||||
{ok, Queue}.
|
||||
|
||||
|
||||
bind_queue_if_needed(subscribe, Queue, {exchange, {Name, Pattern}},
|
||||
State = #state{channel = Channel}) ->
|
||||
State = #state{channel = Channel}) ->
|
||||
RoutingKey = case Pattern of
|
||||
undefined -> "";
|
||||
_ -> Pattern
|
||||
|
|
@ -569,6 +591,9 @@ bind_queue_if_needed(subscribe, Queue, {exchange, {Name, Pattern}},
|
|||
queue = Queue,
|
||||
exchange = list_to_binary(Name),
|
||||
routing_key = list_to_binary(RoutingKey)}),
|
||||
ok;
|
||||
bind_queue_if_needed(_Method, _Queue, {queue, _}, _State) ->
|
||||
%% rely on default binding for /queue
|
||||
ok.
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,27 @@
|
|||
import unittest
|
||||
import stomp
|
||||
import base
|
||||
import time
|
||||
|
||||
class TestAck(base.BaseTest):
|
||||
|
||||
def test_ack_client(self):
|
||||
d = "/exchange/amq.direct/test"
|
||||
|
||||
# subscribe and send message
|
||||
self.listener.reset()
|
||||
self.conn.subscribe(destination=d, ack='client')
|
||||
self.conn.send("test", destination=d)
|
||||
self.assertTrue(self.listener.await(3), "initial message not received")
|
||||
self.assertEquals(1, len(self.listener.messages))
|
||||
|
||||
# disconnect with no ack
|
||||
self.conn.disconnect()
|
||||
|
||||
# now reconnect
|
||||
self.listener.reset()
|
||||
conn2 = self.createConnection()
|
||||
conn2.subscribe(destination=d, ack='client')
|
||||
self.assertTrue(self.listener.await())
|
||||
self.assertEquals(1, len(self.listener.messages))
|
||||
|
||||
|
|
@ -0,0 +1,72 @@
|
|||
import unittest
|
||||
import stomp
|
||||
import sys
|
||||
import threading
|
||||
|
||||
|
||||
class BaseTest(unittest.TestCase):
|
||||
|
||||
def createConnection(self):
|
||||
conn = stomp.Connection(user="guest", passcode="guest")
|
||||
conn.start()
|
||||
conn.connect()
|
||||
return conn
|
||||
|
||||
def setUp(self):
|
||||
self.conn = self.createConnection()
|
||||
self.listener = WaitableListener()
|
||||
self.conn.set_listener('', self.listener)
|
||||
|
||||
def tearDown(self):
|
||||
if self.conn.is_connected():
|
||||
self.conn.stop()
|
||||
|
||||
def simple_test_send_rec(self, dest, route = None):
|
||||
self.listener.reset()
|
||||
|
||||
self.conn.subscribe(destination=dest)
|
||||
self.conn.send("foo", destination=dest)
|
||||
|
||||
self.assertTrue(self.listener.await(), "Timeout, no message received")
|
||||
|
||||
# assert no errors
|
||||
if len(self.listener.errors) > 0:
|
||||
self.fail(self.listener.errors[0]['message'])
|
||||
|
||||
# check header content
|
||||
msg = self.listener.messages[0]
|
||||
self.assertEquals("foo", msg['message'])
|
||||
self.assertEquals(dest, msg['headers']['destination'])
|
||||
|
||||
|
||||
class WaitableListener(object):
|
||||
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
self.errors = []
|
||||
self.receipts = []
|
||||
self.event = threading.Event()
|
||||
|
||||
|
||||
def on_receipt(self, headers, message):
|
||||
self.receipt.append({'message' : message, 'headers' : headers})
|
||||
self.event.set()
|
||||
|
||||
def on_error(self, headers, message):
|
||||
self.errors.append({'message' : message, 'headers' : headers})
|
||||
self.event.set()
|
||||
|
||||
def on_message(self, headers, message):
|
||||
print message
|
||||
self.messages.append({'message' : message, 'headers' : headers})
|
||||
self.event.set()
|
||||
|
||||
def reset(self):
|
||||
self.messages = []
|
||||
self.errors = []
|
||||
self.event.clear()
|
||||
|
||||
def await(self, timeout=10):
|
||||
self.event.wait(timeout)
|
||||
return self.event.is_set()
|
||||
|
||||
|
|
@ -0,0 +1,76 @@
|
|||
import unittest
|
||||
import stomp
|
||||
import base
|
||||
|
||||
class TestExchange(base.BaseTest):
|
||||
|
||||
|
||||
def test_amq_direct(self):
|
||||
self.__test_exchange_send_rec("amq.direct", "route")
|
||||
|
||||
def test_amq_topic(self):
|
||||
self.__test_exchange_send_rec("amq.topic", "route")
|
||||
|
||||
def test_amq_fanout(self):
|
||||
self.__test_exchange_send_rec("amq.fanout", "route")
|
||||
|
||||
def test_amq_fanout_no_route(self):
|
||||
self.__test_exchange_send_rec("amq.fanout")
|
||||
|
||||
def test_invalid_exchange(self):
|
||||
self.listener.reset()
|
||||
self.conn.subscribe(destination="/exchange/does.not.exist")
|
||||
self.listener.await()
|
||||
self.assertEquals(1, len(self.listener.errors))
|
||||
err = self.listener.errors[0]
|
||||
self.assertEquals("not_found", err['headers']['message'])
|
||||
self.assertEquals("no exchange 'does.not.exist' in vhost '/'\n", err['message'])
|
||||
|
||||
def __test_exchange_send_rec(self, exchange, route = None):
|
||||
dest = "/exchange/" + exchange
|
||||
if route != None:
|
||||
dest += "/" + route
|
||||
|
||||
self.simple_test_send_rec(dest)
|
||||
|
||||
class TestQueue(base.BaseTest):
|
||||
|
||||
def test_send_receive(self):
|
||||
d = '/queue/test'
|
||||
self.simple_test_send_rec(d)
|
||||
|
||||
def test_send_receive_in_other_conn(self):
|
||||
d = '/queue/test2'
|
||||
|
||||
# send
|
||||
self.conn.send("hello", destination=d)
|
||||
|
||||
# now receive
|
||||
conn2 = self.createConnection()
|
||||
try:
|
||||
listener2 = base.WaitableListener()
|
||||
conn2.set_listener('', listener2)
|
||||
|
||||
conn2.subscribe(destination=d)
|
||||
self.assertTrue(listener2.await(10), "no receive")
|
||||
finally:
|
||||
conn2.stop()
|
||||
|
||||
def test_send_receive_in_other_conn_with_disconnect(self):
|
||||
d = '/queue/test3'
|
||||
|
||||
# send
|
||||
self.conn.send("hello thar", destination=d)
|
||||
self.conn.stop()
|
||||
|
||||
# now receive
|
||||
conn2 = self.createConnection()
|
||||
try:
|
||||
listener2 = base.WaitableListener()
|
||||
conn2.set_listener('', listener2)
|
||||
|
||||
conn2.subscribe(destination=d)
|
||||
self.assertTrue(listener2.await(5), "no receive")
|
||||
finally:
|
||||
conn2.stop()
|
||||
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
import unittest
|
||||
import stomp
|
||||
import base
|
||||
import time
|
||||
|
||||
class TestLifecycle(base.BaseTest):
|
||||
|
||||
def test_unsubscribe_destination(self):
|
||||
d = "/exchange/amq.fanout"
|
||||
|
||||
# subscribe and send message
|
||||
self.listener.reset()
|
||||
self.conn.subscribe(destination=d)
|
||||
self.conn.send("test", destination=d)
|
||||
self.assertTrue(self.listener.await())
|
||||
self.assertEquals(1, len(self.listener.messages))
|
||||
|
||||
# unsubscribe and send now
|
||||
self.listener.reset()
|
||||
self.conn.unsubscribe(destination=d)
|
||||
self.conn.send("test", destination=d)
|
||||
self.assertFalse(self.listener.await(3),
|
||||
"UNSUBSCRIBE failed, still receiving messages")
|
||||
|
||||
def test_unsubscribe_id(self):
|
||||
''' Test UNSUBSCRIBE command with id parameter'''
|
||||
d = "/exchange/amq.fanout"
|
||||
|
||||
# subscribe and send message
|
||||
self.listener.reset()
|
||||
self.conn.subscribe(destination=d, id="test")
|
||||
self.conn.send("test", destination=d)
|
||||
self.assertTrue(self.listener.await())
|
||||
self.assertEquals(1, len(self.listener.messages))
|
||||
|
||||
# unsubscribe and send now
|
||||
self.listener.reset()
|
||||
self.conn.unsubscribe(id="test")
|
||||
self.conn.send("test", destination=d)
|
||||
self.assertFalse(self.listener.await(3),
|
||||
"UNSUBSCRIBE failed, still receiving messages")
|
||||
|
||||
def test_disconnect(self):
|
||||
''' Run DISCONNECT command '''
|
||||
self.conn.disconnect()
|
||||
self.assertFalse(self.conn.is_connected())
|
||||
|
||||
|
|
@ -0,0 +1,229 @@
|
|||
'''
|
||||
Few tests for a rabbitmq-stomp adaptor. They intend to increase code coverage
|
||||
of the erlang stomp code.
|
||||
'''
|
||||
import unittest
|
||||
import re
|
||||
import socket
|
||||
import functools
|
||||
import time
|
||||
import sys
|
||||
|
||||
def connect(cnames):
|
||||
''' Decorator that creates stomp connections and issues CONNECT '''
|
||||
cmd=('CONNECT\n'
|
||||
'prefetch: 0\n'
|
||||
'login:guest\n'
|
||||
'passcode:guest\n'
|
||||
'\n'
|
||||
'\n\0')
|
||||
resp = ('CONNECTED\n'
|
||||
'session:(.*)\n'
|
||||
'\n\x00')
|
||||
def w(m):
|
||||
@functools.wraps(m)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
for cname in cnames:
|
||||
sd = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sd.settimeout(3)
|
||||
sd.connect((self.host, self.port))
|
||||
sd.sendall(cmd)
|
||||
self.match(resp, sd.recv(4096))
|
||||
setattr(self, cname, sd)
|
||||
try:
|
||||
r = m(self, *args, **kwargs)
|
||||
finally:
|
||||
for cname in cnames:
|
||||
try:
|
||||
getattr(self, cname).close()
|
||||
except IOError:
|
||||
pass
|
||||
return r
|
||||
return wrapper
|
||||
return w
|
||||
|
||||
|
||||
class TestParsing(unittest.TestCase):
|
||||
host='127.0.0.1'
|
||||
port=61613
|
||||
|
||||
|
||||
def match(self, pattern, data):
|
||||
''' helper: try to match 'pattern' regexp with 'data' string.
|
||||
Fail testif they don't match.
|
||||
'''
|
||||
matched = re.match(pattern, data)
|
||||
if matched:
|
||||
return matched.groups()
|
||||
self.assertTrue(False, 'No match:\n%r\n%r' % (pattern, data) )
|
||||
|
||||
@connect(['cd'])
|
||||
def test_newline_after_nul(self):
|
||||
self.cd.sendall('\n'
|
||||
'SUBSCRIBE\n'
|
||||
'destination:/exchange/amq.fanout\n'
|
||||
'\n\x00\n'
|
||||
'SEND\n'
|
||||
'destination:/exchange/amq.fanout\n\n'
|
||||
'hello\n\x00\n')
|
||||
resp = ('MESSAGE\n'
|
||||
'destination:/exchange/amq.fanout\n'
|
||||
'message-id:session-(.*)\n'
|
||||
'content-type:text/plain\n'
|
||||
'content-length:6\n'
|
||||
'\n'
|
||||
'hello\n\0')
|
||||
self.match(resp, self.cd.recv(4096))
|
||||
|
||||
|
||||
@connect(['cd'])
|
||||
def test_newline_after_nul_and_leading_nul(self):
|
||||
self.cd.sendall('\n'
|
||||
'\x00SUBSCRIBE\n'
|
||||
'destination:/exchange/amq.fanout\n'
|
||||
'\n\x00\n'
|
||||
'\x00SEND\n'
|
||||
'destination:/exchange/amq.fanout\n'
|
||||
'\nhello\n\x00\n')
|
||||
resp = ('MESSAGE\n'
|
||||
'destination:/exchange/amq.fanout\n'
|
||||
'message-id:session-(.*)\n'
|
||||
'content-type:text/plain\n'
|
||||
'content-length:6\n'
|
||||
'\n'
|
||||
'hello\n\0')
|
||||
self.match(resp, self.cd.recv(4096))
|
||||
|
||||
@connect(['cd'])
|
||||
def test_bad_command(self):
|
||||
''' Trigger an error message. '''
|
||||
self.cd.sendall('WRONGCOMMAND\n'
|
||||
'destination:a\n'
|
||||
'exchange:amq.fanout\n'
|
||||
'\n\0')
|
||||
resp = ('ERROR\n'
|
||||
'message:Bad command\n'
|
||||
'content-type:text/plain\n'
|
||||
'content-length:41\n'
|
||||
'\n'
|
||||
'Could not interpret command WRONGCOMMAND\n'
|
||||
'\0')
|
||||
self.match(resp, self.cd.recv(4096))
|
||||
|
||||
@connect(['sd', 'cd1', 'cd2'])
|
||||
def test_broadcast(self):
|
||||
''' Single message should be delivered to two consumers:
|
||||
amq.topic --routing_key--> first_queue --> first_connection
|
||||
\--routing_key--> second_queue--> second_connection
|
||||
'''
|
||||
subscribe=( 'SUBSCRIBE\n'
|
||||
'id: XsKNhAf\n'
|
||||
'destination:/exchange/amq.topic/da9d4779\n'
|
||||
'\n\0')
|
||||
for cd in [self.cd1, self.cd2]:
|
||||
cd.sendall(subscribe)
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
self.sd.sendall('SEND\n'
|
||||
'destination:/exchange/amq.topic/da9d4779\n'
|
||||
'\n'
|
||||
'message'
|
||||
'\n\0')
|
||||
|
||||
resp=('MESSAGE\n'
|
||||
'destination:/exchange/amq.topic/da9d4779\n'
|
||||
'message-id:(.*)\n'
|
||||
'content-type:text/plain\n'
|
||||
'subscription:(.*)\n'
|
||||
'content-length:8\n'
|
||||
'\n'
|
||||
'message'
|
||||
'\n\x00')
|
||||
for cd in [self.cd1, self.cd2]:
|
||||
self.match(resp, cd.recv(4096))
|
||||
|
||||
|
||||
@connect(['sd', 'cd1', 'cd2'])
|
||||
def test_roundrobin(self):
|
||||
''' Two messages should be delivered to two consumers using round robin:
|
||||
amq.topic --routing_key--> single_queue --> first_connection
|
||||
\---> second_connection
|
||||
'''
|
||||
messages = ['message1', 'message2']
|
||||
subscribe=(
|
||||
'SUBSCRIBE\n'
|
||||
'id: sTXtc\n'
|
||||
'destination:/exchange/amq.topic/yAoXMwiF\n'
|
||||
'\n\0')
|
||||
for cd in [self.cd1, self.cd2]:
|
||||
cd.sendall(subscribe)
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
for msg in messages:
|
||||
self.sd.sendall('SEND\n'
|
||||
'destination:/exchange/amq.topic/yAoXMwiF\n'
|
||||
'\n'
|
||||
'%s'
|
||||
'\n\0' % msg)
|
||||
|
||||
resp=('MESSAGE\n'
|
||||
'destination:/exchange/amq.topic/yAoXMwiF\n'
|
||||
'message-id:.*\n'
|
||||
'content-type:text/plain\n'
|
||||
'subscription:.*\n'
|
||||
'content-length:.\n'
|
||||
'\n'
|
||||
'(.*)'
|
||||
'\n\x00')
|
||||
|
||||
recv_messages = [self.match(resp, cd.recv(4096))[0] \
|
||||
for cd in [self.cd1, self.cd2]]
|
||||
self.assertTrue(sorted(messages) == sorted(recv_messages), \
|
||||
'%r != %r ' % (messages, recv_messages))
|
||||
|
||||
|
||||
|
||||
@connect(['cd'])
|
||||
def test_huge_message(self):
|
||||
''' Test sending/receiving huge (92MB) message. '''
|
||||
subscribe=( 'SUBSCRIBE\n'
|
||||
'id: xxx\n'
|
||||
'destination:/exchange/amq.topic/test_huge_message\n'
|
||||
'\n\0')
|
||||
self.cd.sendall(subscribe)
|
||||
|
||||
# Instead of 92MB, let's use 16, so that the test can finish in
|
||||
# reasonable time.
|
||||
##message = 'x' * 1024*1024*92
|
||||
message = 'x' * 1024*1024*16
|
||||
|
||||
self.cd.sendall('SEND\n'
|
||||
'destination:/exchange/amq.topic/test_huge_message\n'
|
||||
'\n'
|
||||
'%s'
|
||||
'\0' % message)
|
||||
|
||||
resp=('MESSAGE\n'
|
||||
'destination:/exchange/amq.topic/test_huge_message\n'
|
||||
'message-id:(.*)\n'
|
||||
'content-type:text/plain\n'
|
||||
'subscription:(.*)\n'
|
||||
'content-length:%i\n'
|
||||
'\n'
|
||||
'%s(.*)'
|
||||
% (len(message), message[:8000]) )
|
||||
|
||||
recv = []
|
||||
s = 0
|
||||
while len(recv) < 1 or recv[-1][-1] != '\0':
|
||||
buf = self.cd.recv(4096*16)
|
||||
s += len(buf)
|
||||
recv.append( buf )
|
||||
buf = ''.join(recv)
|
||||
|
||||
# matching 100MB regexp is way too expensive.
|
||||
self.match(resp, buf[:8192])
|
||||
self.assertEqual(len(buf) > len(message), True)
|
||||
|
||||
|
|
@ -0,0 +1,35 @@
|
|||
"""
|
||||
This provides basic connectivity to a message broker supporting the 'stomp' protocol.
|
||||
At the moment ACK, SEND, SUBSCRIBE, UNSUBSCRIBE, BEGIN, ABORT, COMMIT, CONNECT and DISCONNECT operations
|
||||
are supported.
|
||||
|
||||
See the project page for more information.
|
||||
|
||||
Meta-Data
|
||||
---------
|
||||
Author: Jason R Briggs
|
||||
License: http://www.apache.org/licenses/LICENSE-2.0
|
||||
Start Date: 2005/12/01
|
||||
Last Revision Date: $Date: 2008/09/11 00:16 $
|
||||
Project Page: http://www.briggs.net.nz/log/projects/stomp.py
|
||||
|
||||
Notes/Attribution
|
||||
-----------------
|
||||
* uuid method courtesy of Carl Free Jr:
|
||||
http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/213761
|
||||
* patch from Andreas Schobel
|
||||
* patches from Julian Scheid of Rising Sun Pictures (http://open.rsp.com.au)
|
||||
* patch from Fernando
|
||||
* patches from Eugene Strulyov
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(0, os.path.split(__file__)[0])
|
||||
|
||||
import connect, listener, exception
|
||||
|
||||
__version__ = __version__ = (3, 0, 2)
|
||||
Connection = connect.Connection
|
||||
ConnectionListener = listener.ConnectionListener
|
||||
StatsListener = listener.StatsListener
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
import sys
|
||||
|
||||
#
|
||||
# Functions for backwards compatibility
|
||||
#
|
||||
|
||||
def get_func_argcount(func):
|
||||
"""
|
||||
Return the argument count for a function
|
||||
"""
|
||||
if sys.hexversion > 0x03000000:
|
||||
return func.__code__.co_argcount
|
||||
else:
|
||||
return func.func_code.co_argcount
|
||||
|
||||
def input_prompt(prompt):
|
||||
"""
|
||||
Get user input
|
||||
"""
|
||||
if sys.hexversion > 0x03000000:
|
||||
return input(prompt)
|
||||
else:
|
||||
return raw_input(prompt)
|
||||
|
||||
def join(chars):
|
||||
if sys.hexversion > 0x03000000:
|
||||
return bytes('', 'UTF-8').join(chars).decode('UTF-8')
|
||||
else:
|
||||
return ''.join(chars)
|
||||
|
||||
def socksend(conn, msg):
|
||||
if sys.hexversion > 0x03000000:
|
||||
conn.send(msg.encode())
|
||||
else:
|
||||
conn.send(msg)
|
||||
|
||||
|
||||
def getheader(headers, key):
|
||||
if sys.hexversion > 0x03000000:
|
||||
return headers[key]
|
||||
else:
|
||||
return headers.getheader(key)
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
Stomp Bridge for Oracle AQ
|
||||
==========================
|
||||
|
||||
This provides a STOMP bridging mechanism to Oracle AQ (Advanced Messaging).
|
||||
|
||||
Before using, you'll need to grant various privileges to your user:
|
||||
|
||||
GRANT RESOURCE TO <UserName>;
|
||||
GRANT CONNECT TO <UserName>;
|
||||
GRANT EXECUTE ANY PROCEDURE TO <UserName>;
|
||||
GRANT aq_administrator_role TO <UserName>;
|
||||
GRANT aq_user_role TO <UserName>;
|
||||
GRANT EXECUTE ON dbms_aqadm TO <UserName>;
|
||||
GRANT EXECUTE ON dbms_aq TO <UserName>;
|
||||
GRANT EXECUTE ON dbms_aqin TO <UserName>;
|
||||
|
||||
You will also need to create a sequence:
|
||||
|
||||
CREATE SEQUENCE stomp_client_id_seq
|
||||
/
|
||||
|
||||
Startup the Oracle Stomp Server, by running the following command (from the root directory of this project):
|
||||
|
||||
bridge/oracleaq.py -D localhost -B 1521 -I xe -U test -W test -N localhost -T 8888
|
||||
|
||||
Run:
|
||||
|
||||
bridge/oracleaq.py --help
|
||||
|
||||
to see the make up of the command line arguments.
|
||||
|
||||
|
|
@ -0,0 +1,97 @@
|
|||
import os
|
||||
import socket
|
||||
import sys
|
||||
import threading
|
||||
import uuid
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
from stomp import utils, backward
|
||||
|
||||
class StompServer(threading.Thread):
|
||||
def __init__(self, listen_host_and_port, connection_class):
|
||||
threading.Thread.__init__(self)
|
||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True)
|
||||
self.socket.bind(listen_host_and_port)
|
||||
self.socket.listen(1)
|
||||
print('Listening for STOMP connections on %s:%s' % listen_host_and_port)
|
||||
self.running = True
|
||||
self.connections = [ ]
|
||||
self.connection_class = connection_class
|
||||
|
||||
def notify(self, queue, msg_id):
|
||||
for conn in self.connections:
|
||||
conn.notify(queue, msg_id)
|
||||
|
||||
def add_connection(self, conn):
|
||||
self.connections.append(conn)
|
||||
|
||||
def remove_connection(self, conn):
|
||||
pos = self.connections.index(conn)
|
||||
if pos >= 0:
|
||||
del self.connections[pos]
|
||||
|
||||
def shutdown(self):
|
||||
pass
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
while self.running:
|
||||
conn, addr = self.socket.accept()
|
||||
conn = self.connection_class(self, conn, addr)
|
||||
self.add_connection(conn)
|
||||
conn.start()
|
||||
finally:
|
||||
for conn in self.connections:
|
||||
conn.shutdown()
|
||||
self.shutdown()
|
||||
|
||||
|
||||
class StompConnection(threading.Thread):
|
||||
def __init__(self, server, conn, addr):
|
||||
threading.Thread.__init__(self)
|
||||
self.server = server
|
||||
self.conn = conn
|
||||
self.addr = addr
|
||||
self.running = True
|
||||
self.id = str(uuid.uuid4())
|
||||
|
||||
def send_error(self, msg):
|
||||
self.send('ERROR\nmessage: %s\n\n' % msg)
|
||||
|
||||
def send(self, msg):
|
||||
if not msg.endswith('\x00'):
|
||||
msg = msg + '\x00'
|
||||
backward.socksend(self.conn, msg)
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
data = []
|
||||
while self.running:
|
||||
c = self.conn.recv(1)
|
||||
if c == '' or len(c) == 0:
|
||||
break
|
||||
data.append(c)
|
||||
if ord(c) == 0:
|
||||
frame = backward.join(data)
|
||||
print(frame)
|
||||
(frame_type, headers, body) = utils.parse_frame(frame)
|
||||
method = 'handle_%s' % frame_type
|
||||
print('Method = %s' % method)
|
||||
if hasattr(self, method):
|
||||
getattr(self, method)(headers, body)
|
||||
else:
|
||||
self.send_error('invalid command %s' % frame_type)
|
||||
data = []
|
||||
except Exception:
|
||||
_, e, tb = sys.exc_info()
|
||||
print(e)
|
||||
import traceback
|
||||
traceback.print_tb(tb)
|
||||
self.server.remove_connection(self)
|
||||
self.shutdown()
|
||||
|
||||
def shutdown(self):
|
||||
self.conn.close()
|
||||
self.running = False
|
||||
|
|
@ -0,0 +1,373 @@
|
|||
#! /usr/bin/env python
|
||||
|
||||
import cx_Oracle
|
||||
from optparse import OptionParser
|
||||
import re
|
||||
import sys
|
||||
try:
|
||||
from SocketServer import ThreadingMixIn, ThreadingTCPServer, BaseRequestHandler
|
||||
except ImportError:
|
||||
from socketserver import ThreadingMixIn, ThreadingTCPServer, BaseRequestHandler
|
||||
try:
|
||||
from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer
|
||||
except ImportError:
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
import threading
|
||||
|
||||
from bridge import StompServer, StompConnection
|
||||
|
||||
from stomp import utils, backward
|
||||
|
||||
global QUEUE_TABLE
|
||||
QUEUE_TABLE = 'STOMP_MSG_QUEUE'
|
||||
|
||||
SETUP_SQLS = ['''CREATE OR REPLACE PROCEDURE stomp_enq(queue_name in varchar2, msg in varchar2, props in varchar2) AS
|
||||
enqueue_options dbms_aq.enqueue_options_t;
|
||||
message_properties dbms_aq.message_properties_t;
|
||||
message_handle RAW(16);
|
||||
BEGIN
|
||||
message_properties.user_property := sys.anyData.convertVarchar2(props);
|
||||
dbms_aq.enqueue(queue_name => queue_name, enqueue_options => enqueue_options, message_properties => message_properties, payload => utl_raw.cast_to_raw(msg), msgid => message_handle);
|
||||
END;''',
|
||||
'''CREATE OR REPLACE PROCEDURE stomp_sub(qn in varchar2, subscriber_name in varchar2, notification_address in varchar2) AS
|
||||
BEGIN
|
||||
dbms_aqadm.add_subscriber(queue_name => qn, subscriber => sys.aq$_agent(subscriber_name, null, null));
|
||||
dbms_aq.register(sys.aq$_reg_info_list(sys.aq$_reg_info(qn || ':' || subscriber_name, DBMS_AQ.NAMESPACE_AQ, notification_address, HEXTORAW('FF')) ), 1);
|
||||
END;''',
|
||||
'''CREATE OR REPLACE PROCEDURE stomp_unsub(qn in varchar2, subscriber_name in varchar2, notification_address in varchar2) AS
|
||||
subscriber_count int;
|
||||
BEGIN
|
||||
dbms_aq.unregister(sys.aq$_reg_info_list(sys.aq$_reg_info(qn || ':' || subscriber_name, dbms_aq.namespace_aq, notification_address, HEXTORAW('FF')) ), 1);
|
||||
dbms_aqadm.remove_subscriber(queue_name => qn, subscriber => sys.aq$_agent(subscriber_name, null, null));
|
||||
select count(*) into subscriber_count from user_queue_subscribers where queue_name = qn;
|
||||
IF subscriber_count = 0 THEN
|
||||
dbms_aqadm.stop_queue(qn);
|
||||
dbms_aqadm.drop_queue(qn);
|
||||
END IF;
|
||||
END;''',
|
||||
'''CREATE OR REPLACE FUNCTION getvarchar2(anydata_p in sys.anydata) return varchar2 is
|
||||
x number;
|
||||
thevarchar2 varchar2(4000);
|
||||
BEGIN
|
||||
IF anydata_p IS NULL THEN
|
||||
return '';
|
||||
ELSE
|
||||
x := anydata_p.getvarchar2(thevarchar2);
|
||||
return thevarchar2;
|
||||
END IF;
|
||||
END;''']
|
||||
|
||||
DEST_RE = re.compile(r'<destination>"[^"]*"."([^"]*)"</destination>')
|
||||
MSGID_RE = re.compile(r'<message_id>([^<]*)</message_id>')
|
||||
|
||||
class NotificationHandler(BaseHTTPRequestHandler):
|
||||
'''
|
||||
Handler for message notifications from Oracle
|
||||
'''
|
||||
def do_POST(self):
|
||||
try:
|
||||
length = backward.getheader(self.headers, 'Content-Length')
|
||||
s = self.rfile.read(int(length))
|
||||
s = s.decode('UTF-8')
|
||||
queue = DEST_RE.search(s).group(1)
|
||||
msg_id = MSGID_RE.search(s).group(1).lstrip().rstrip()
|
||||
self.send_response(200)
|
||||
self.end_headers()
|
||||
self.wfile.write("OK".encode())
|
||||
if msg_id not in self.server.msg_ids:
|
||||
self.server.notify(queue, msg_id)
|
||||
self.server.msg_ids.append(msg_id)
|
||||
if len(self.server.msg_ids) > 100:
|
||||
del self.server.msg_ids[0]
|
||||
except Exception:
|
||||
_, e, tb = sys.exc_info()
|
||||
import traceback
|
||||
traceback.print_tb(tb)
|
||||
|
||||
class NotificationListener(ThreadingMixIn, HTTPServer, threading.Thread):
|
||||
def __init__(self, notify, host_and_port):
|
||||
HTTPServer.__init__(self, host_and_port, NotificationHandler)
|
||||
threading.Thread.__init__(self)
|
||||
self.setDaemon(True)
|
||||
self.notify = notify
|
||||
self.msg_ids = [ ]
|
||||
print('Listening for Oracle Notifications on %s:%s' % host_and_port)
|
||||
|
||||
def serve_forever(self):
|
||||
self.stop_serving = False
|
||||
while not self.stop_serving:
|
||||
self.handle_request()
|
||||
|
||||
def run(self):
|
||||
self.serve_forever()
|
||||
|
||||
def stop(self):
|
||||
self.stop_serving = True
|
||||
|
||||
class OracleStompConnection(StompConnection):
|
||||
def __init__(self, server, conn, addr):
|
||||
StompConnection.__init__(self, server, conn, addr)
|
||||
self.dbconn = cx_Oracle.connect('%s/%s@//%s:%s/%s' % (server.username, server.passwd, server.oracle_host_and_port[0], server.oracle_host_and_port[1], server.db))
|
||||
print("Connected to Oracle")
|
||||
self.client_id = self.__get_client_id()
|
||||
print("Client Id %s" % self.client_id)
|
||||
self.queues = {}
|
||||
self.transactions = {}
|
||||
self.semaphore = threading.BoundedSemaphore(1)
|
||||
|
||||
def __get_client_id(self):
|
||||
cursor = self.dbconn.cursor()
|
||||
cursor.execute('SELECT stomp_client_id_seq.nextval FROM dual')
|
||||
row = cursor.fetchone()
|
||||
return 's%s' % row[0]
|
||||
|
||||
def __is_created(self, cursor, destination):
|
||||
if destination in self.queues:
|
||||
return True
|
||||
else:
|
||||
cursor.execute('SELECT COUNT(*) FROM user_queues WHERE name = UPPER(:queue)', queue = destination)
|
||||
row = cursor.fetchone()
|
||||
return row[0] > 0
|
||||
|
||||
def __create(self, cursor, destination):
|
||||
cursor.callproc('DBMS_AQADM.CREATE_QUEUE', [], { 'queue_name' : destination, 'queue_table' : QUEUE_TABLE })
|
||||
cursor.callproc('DBMS_AQADM.START_QUEUE', [], { 'queue_name' : destination })
|
||||
|
||||
def __sanitise(self, headers):
|
||||
if 'destination' in headers:
|
||||
dest = headers['destination'].replace('/', '_')
|
||||
if dest.startswith('_'):
|
||||
dest = dest[1:]
|
||||
headers['destination'] = dest.upper()
|
||||
|
||||
def __get_notification_address(self):
|
||||
return 'http://%s:%s' % (self.server.notification_host_and_port[0], self.server.notification_host_and_port[1])
|
||||
|
||||
def __commit_or_rollback(self, headers, commit = True):
|
||||
self.__sanitise(headers)
|
||||
if 'transaction' not in headers:
|
||||
self.send_error('Transaction identifier is required')
|
||||
return
|
||||
transaction_id = headers['transaction']
|
||||
if transaction_id not in self.transactions:
|
||||
self.send_error('Transaction %s does not exist' % transaction_id)
|
||||
return
|
||||
else:
|
||||
if commit:
|
||||
for (method, headers, body) in self.transactions[transaction_id]:
|
||||
getattr(self, method)(headers, body)
|
||||
del self.transactions[transaction_id]
|
||||
|
||||
def __save(self, command, headers, body):
|
||||
transaction_id = headers['transaction']
|
||||
if transaction_id not in self.transactions:
|
||||
self.send_error('No such transaction %s' % transaction_id)
|
||||
else:
|
||||
del headers['transaction']
|
||||
self.transactions[transaction_id].append((command, headers, body))
|
||||
|
||||
def notify(self, queue, msg_id):
|
||||
if queue in self.queues.keys():
|
||||
self.semaphore.acquire()
|
||||
cursor = self.dbconn.cursor()
|
||||
try:
|
||||
cursor.execute('SELECT user_data, getvarchar2(user_prop) AS user_props FROM %s WHERE msgid = :msgid' % QUEUE_TABLE, msgid = msg_id)
|
||||
row = cursor.fetchone()
|
||||
headers = utils.parse_headers(row[1].split('\n'))
|
||||
headers['destination'] = self.queues[queue]
|
||||
headers['message-id'] = msg_id
|
||||
hdr = [ ]
|
||||
for key, val in headers.items():
|
||||
hdr.append('%s:%s' % (key, val))
|
||||
msg = row[0].read().decode('UTF-8')
|
||||
self.send('MESSAGE\n%s\n\n%s' % ('\n'.join(hdr), msg))
|
||||
except Exception:
|
||||
_, e, tb = sys.exc_info()
|
||||
import traceback
|
||||
traceback.print_tb(tb)
|
||||
print(e)
|
||||
finally:
|
||||
cursor.close()
|
||||
self.semaphore.release()
|
||||
|
||||
def handle_ACK(self, headers, body):
|
||||
self.__sanitise(headers)
|
||||
if 'transaction' in headers:
|
||||
self.__save('handle_ACK', headers, body)
|
||||
else:
|
||||
# FIXME
|
||||
self.send_error('Not currently supported')
|
||||
|
||||
|
||||
def handle_BEGIN(self, headers, body):
|
||||
if 'transaction' not in headers:
|
||||
self.send_error('Transaction identifier is required')
|
||||
return
|
||||
transaction_id = headers['transaction']
|
||||
if transaction_id in self.transactions:
|
||||
self.send_error('Transaction %s already started' % transaction_id)
|
||||
return
|
||||
else:
|
||||
self.transactions[transaction_id] = [ ]
|
||||
|
||||
def handle_COMMIT(self, headers, body):
|
||||
self.__commit_or_rollback(headers)
|
||||
|
||||
def handle_ABORT(self, headers, body):
|
||||
self.__commit_or_rollback(headers, False)
|
||||
|
||||
def handle_CONNECT(self, headers, body):
|
||||
self.send('CONNECTED\nsession: %s\n\n' % self.id)
|
||||
|
||||
def handle_DISCONNECT(self, headers, body):
|
||||
self.shutdown()
|
||||
|
||||
def handle_SUBSCRIBE(self, headers, body):
|
||||
self.semaphore.acquire()
|
||||
cursor = self.dbconn.cursor()
|
||||
try:
|
||||
orig_qn = headers['destination']
|
||||
self.__sanitise(headers)
|
||||
if not self.__is_created(cursor, headers['destination']):
|
||||
self.__create(cursor, headers['destination'])
|
||||
try:
|
||||
cursor.callproc('stomp_sub', [headers['destination'], self.client_id, self.__get_notification_address()])
|
||||
self.queues[headers['destination']] = orig_qn
|
||||
except Exception:
|
||||
_, e, _ = sys.exc_info()
|
||||
print(e)
|
||||
finally:
|
||||
cursor.close()
|
||||
self.semaphore.release()
|
||||
|
||||
def handle_UNSUBSCRIBE(self, headers, body):
|
||||
self.semaphore.acquire()
|
||||
cursor = self.dbconn.cursor()
|
||||
try:
|
||||
self.__sanitise(headers)
|
||||
try:
|
||||
cursor.callproc('stomp_unsub', [headers['destination'], self.client_id, self.__get_notification_address()])
|
||||
except:
|
||||
pass
|
||||
if headers['destination'] in self.queues.keys():
|
||||
del self.queues[headers['destination']]
|
||||
finally:
|
||||
cursor.close()
|
||||
self.semaphore.release()
|
||||
|
||||
def handle_SEND(self, headers, body):
|
||||
self.__sanitise(headers)
|
||||
if 'transaction' in headers:
|
||||
self.__save('handle_SEND', headers, body)
|
||||
else:
|
||||
self.semaphore.acquire()
|
||||
cursor = self.dbconn.cursor()
|
||||
try:
|
||||
if not self.__is_created(cursor, headers['destination']):
|
||||
self.__create(cursor, headers['destination'])
|
||||
hdr = [ ]
|
||||
for key, val in headers.items():
|
||||
hdr.append('%s:%s\n' % (key, val))
|
||||
cursor.callproc('stomp_enq', [headers['destination'], body.rstrip(), ''.join(hdr)])
|
||||
self.dbconn.commit()
|
||||
except Exception:
|
||||
_, e, tb = sys.exc_info()
|
||||
import traceback
|
||||
traceback.print_tb(tb)
|
||||
print(e)
|
||||
finally:
|
||||
cursor.close()
|
||||
self.semaphore.release()
|
||||
|
||||
def shutdown(self):
|
||||
self.running = False
|
||||
self.semaphore.acquire()
|
||||
for queue in list(self.queues.keys()):
|
||||
self.handle_UNSUBSCRIBE({'destination' : queue}, '')
|
||||
self.dbconn.close()
|
||||
self.semaphore.release()
|
||||
StompConnection.shutdown(self)
|
||||
|
||||
|
||||
class OracleStompServer(StompServer):
|
||||
def __init__(self, listen_host_and_port, oracle_host_and_port, username, passwd, db, notification_host_and_port):
|
||||
StompServer.__init__(self, listen_host_and_port, OracleStompConnection)
|
||||
self.oracle_host_and_port = oracle_host_and_port
|
||||
self.username = username
|
||||
self.passwd = passwd
|
||||
self.db = db
|
||||
self.notification_host_and_port = notification_host_and_port
|
||||
|
||||
#
|
||||
# setup
|
||||
#
|
||||
dbconn = cx_Oracle.connect('%s/%s@//%s:%s/%s' % (username, passwd, oracle_host_and_port[0], oracle_host_and_port[1], db))
|
||||
cursor = dbconn.cursor()
|
||||
for sql in SETUP_SQLS:
|
||||
cursor.execute(sql)
|
||||
cursor.execute('SELECT COUNT(*) FROM user_queue_tables WHERE queue_table = :queue', queue = QUEUE_TABLE)
|
||||
row = cursor.fetchone()
|
||||
if row[0] == 0:
|
||||
cursor.callproc('DBMS_AQADM.CREATE_QUEUE_TABLE', [], {'queue_table' : QUEUE_TABLE, 'queue_payload_type' : 'raw', 'multiple_consumers' : True})
|
||||
cursor.close()
|
||||
dbconn.close()
|
||||
|
||||
#
|
||||
# Oracle notification listener
|
||||
#
|
||||
self.listener = NotificationListener(self.notify, self.notification_host_and_port)
|
||||
self.listener.start()
|
||||
|
||||
def shutdown(self):
|
||||
print('OracleStompServer shutdown')
|
||||
|
||||
|
||||
def main():
|
||||
parser = OptionParser()
|
||||
|
||||
parser.add_option('-P', '--port', type = int, dest = 'port', default = 61613,
|
||||
help = 'Port to listen for STOMP connections. Defaults to 61613, if not specified.')
|
||||
parser.add_option('-D', '--dbhost', type = 'string', dest = 'db_host', default = None,
|
||||
help = 'Oracle hostname to connect to')
|
||||
parser.add_option('-B', '--dbport', type = 'int', dest = 'db_port', default = None,
|
||||
help = 'Oracle port to connect to')
|
||||
parser.add_option('-I', '--dbinst', type = 'string', dest = 'db_inst', default = None,
|
||||
help = 'Oracle database instance (for example "xe")')
|
||||
parser.add_option('-U', '--user', type = 'string', dest = 'db_user', default = None,
|
||||
help = 'Username for the database connection')
|
||||
parser.add_option('-W', '--passwd', type = 'string', dest = 'db_passwd', default = None,
|
||||
help = 'Password for the database connection')
|
||||
parser.add_option('-N', '--nhost', type = 'string', dest = 'notification_host',
|
||||
help = 'IP address (i.e. this machine) which is listening for Oracle AQ notifications.')
|
||||
parser.add_option('-T', '--nport', type = 'int', dest = 'notification_port',
|
||||
help = 'Port which is listening for Oracle AQ notifications.')
|
||||
|
||||
(options, args) = parser.parse_args()
|
||||
|
||||
if not options.db_host:
|
||||
parser.error("Database hostname (-D) is required")
|
||||
|
||||
if not options.db_port:
|
||||
parser.error("Database port (-B) is required")
|
||||
|
||||
if not options.db_inst:
|
||||
parser.error("Database instance (-I) is required")
|
||||
|
||||
if not options.db_user:
|
||||
parser.error("Database user (-U) is required")
|
||||
|
||||
if not options.db_passwd:
|
||||
parser.error("Database password (-W) is required")
|
||||
|
||||
if not options.notification_host:
|
||||
parser.error("Notification host or IP (-N) is required")
|
||||
|
||||
if not options.notification_port:
|
||||
parser.error("Notification port (-T) is required")
|
||||
|
||||
server = OracleStompServer(('', options.port), (options.db_host, options.db_port), options.db_user, options.db_passwd, options.db_inst, (options.notification_host, options.notification_port))
|
||||
server.start()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
@ -0,0 +1,434 @@
|
|||
import base64
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
from optparse import OptionParser
|
||||
|
||||
from connect import Connection
|
||||
from listener import ConnectionListener, StatsListener
|
||||
from exception import NotConnectedException
|
||||
from backward import input_prompt
|
||||
|
||||
def sysout(msg, end='\n'):
|
||||
sys.stdout.write(str(msg) + end)
|
||||
|
||||
def get_commands():
|
||||
"""
|
||||
Return a list of commands available on a \link StompCLI \endlink (the command line interface
|
||||
to stomp.py)
|
||||
"""
|
||||
commands = [ ]
|
||||
for f in dir(StompCLI):
|
||||
if f.startswith('_') or f.startswith('on_') or f == 'c':
|
||||
continue
|
||||
else:
|
||||
commands.append(f)
|
||||
return commands
|
||||
|
||||
|
||||
class StompCLI(ConnectionListener):
|
||||
"""
|
||||
A command line interface to the stomp.py client. See \link stomp::internal::connect::Connection \endlink
|
||||
for more information on establishing a connection to a stomp server.
|
||||
"""
|
||||
def __init__(self, host='localhost', port=61613, user='', passcode=''):
|
||||
self.conn = Connection([(host, port)], user, passcode)
|
||||
self.conn.set_listener('', self)
|
||||
self.conn.start()
|
||||
self.__commands = get_commands()
|
||||
self.transaction_id = None
|
||||
|
||||
def __print_async(self, frame_type, headers, body):
|
||||
"""
|
||||
Utility function to print a message and setup the command prompt
|
||||
for the next input
|
||||
"""
|
||||
sysout("\r \r", end='')
|
||||
sysout(frame_type)
|
||||
for header_key in headers.keys():
|
||||
sysout('%s: %s' % (header_key, headers[header_key]))
|
||||
sysout('')
|
||||
sysout(body)
|
||||
sysout('> ', end='')
|
||||
sys.stdout.flush()
|
||||
|
||||
def on_connecting(self, host_and_port):
|
||||
"""
|
||||
\see ConnectionListener::on_connecting
|
||||
"""
|
||||
self.conn.connect(wait=True)
|
||||
|
||||
def on_disconnected(self):
|
||||
"""
|
||||
\see ConnectionListener::on_disconnected
|
||||
"""
|
||||
sysout("lost connection")
|
||||
|
||||
def on_message(self, headers, body):
|
||||
"""
|
||||
\see ConnectionListener::on_message
|
||||
|
||||
Special case: if the header 'filename' is present, the content is written out
|
||||
as a file
|
||||
"""
|
||||
if 'filename' in headers:
|
||||
content = base64.b64decode(body.encode())
|
||||
if os.path.exists(headers['filename']):
|
||||
fname = '%s.%s' % (headers['filename'], int(time.time()))
|
||||
else:
|
||||
fname = headers['filename']
|
||||
f = open(fname, 'wb')
|
||||
f.write(content)
|
||||
f.close()
|
||||
self.__print_async("MESSAGE", headers, "Saved file: %s" % fname)
|
||||
else:
|
||||
self.__print_async("MESSAGE", headers, body)
|
||||
|
||||
def on_error(self, headers, body):
|
||||
"""
|
||||
\see ConnectionListener::on_error
|
||||
"""
|
||||
self.__print_async("ERROR", headers, body)
|
||||
|
||||
def on_receipt(self, headers, body):
|
||||
"""
|
||||
\see ConnectionListener::on_receipt
|
||||
"""
|
||||
self.__print_async("RECEIPT", headers, body)
|
||||
|
||||
def on_connected(self, headers, body):
|
||||
"""
|
||||
\see ConnectionListener::on_connected
|
||||
"""
|
||||
self.__print_async("CONNECTED", headers, body)
|
||||
|
||||
def ack(self, args):
|
||||
"""
|
||||
Usage:
|
||||
ack <message-id>
|
||||
|
||||
Required Parameters:
|
||||
message-id - the id of the message being acknowledged
|
||||
|
||||
Description:
|
||||
The command 'ack' is used to acknowledge consumption of a message from a subscription using client
|
||||
acknowledgment. When a client has issued a 'subscribe' with the ack flag set to client, any messages
|
||||
received from that destination will not be considered to have been consumed (by the server) until
|
||||
the message has been acknowledged.
|
||||
"""
|
||||
if len(args) < 2:
|
||||
sysout("Expecting: ack <message-id>")
|
||||
elif not self.transaction_id:
|
||||
self.conn.ack(headers = { 'message-id' : args[1] })
|
||||
else:
|
||||
self.conn.ack(headers = { 'message-id' : args[1] }, transaction=self.transaction_id)
|
||||
|
||||
def abort(self, args):
|
||||
"""
|
||||
Usage:
|
||||
abort
|
||||
|
||||
Description:
|
||||
Roll back a transaction in progress.
|
||||
"""
|
||||
if not self.transaction_id:
|
||||
sysout("Not currently in a transaction")
|
||||
else:
|
||||
self.conn.abort(transaction = self.transaction_id)
|
||||
self.transaction_id = None
|
||||
|
||||
def begin(self, args):
|
||||
"""
|
||||
Usage:
|
||||
begin
|
||||
|
||||
Description:
|
||||
Start a transaction. Transactions in this case apply to sending and acknowledging -
|
||||
any messages sent or acknowledged during a transaction will be handled atomically based on the
|
||||
transaction.
|
||||
"""
|
||||
if self.transaction_id:
|
||||
sysout("Currently in a transaction (%s)" % self.transaction_id)
|
||||
else:
|
||||
self.transaction_id = self.conn.begin()
|
||||
sysout('Transaction id: %s' % self.transaction_id)
|
||||
|
||||
def commit(self, args):
|
||||
"""
|
||||
Usage:
|
||||
commit
|
||||
|
||||
Description:
|
||||
Commit a transaction in progress.
|
||||
"""
|
||||
if not self.transaction_id:
|
||||
sysout("Not currently in a transaction")
|
||||
else:
|
||||
sysout('Committing %s' % self.transaction_id)
|
||||
self.conn.commit(transaction=self.transaction_id)
|
||||
self.transaction_id = None
|
||||
|
||||
def disconnect(self, args):
|
||||
"""
|
||||
Usage:
|
||||
disconnect
|
||||
|
||||
Description:
|
||||
Gracefully disconnect from the server.
|
||||
"""
|
||||
try:
|
||||
self.conn.disconnect()
|
||||
except NotConnectedException:
|
||||
pass # ignore if no longer connected
|
||||
|
||||
def send(self, args):
|
||||
"""
|
||||
Usage:
|
||||
send <destination> <message>
|
||||
|
||||
Required Parameters:
|
||||
destination - where to send the message
|
||||
message - the content to send
|
||||
|
||||
Description:
|
||||
Sends a message to a destination in the messaging system.
|
||||
"""
|
||||
if len(args) < 3:
|
||||
sysout('Expecting: send <destination> <message>')
|
||||
elif not self.transaction_id:
|
||||
self.conn.send(destination=args[1], message=' '.join(args[2:]))
|
||||
else:
|
||||
self.conn.send(destination=args[1], message=' '.join(args[2:]), transaction=self.transaction_id)
|
||||
|
||||
def sendreply(self, args):
|
||||
"""
|
||||
Usage:
|
||||
sendreply <destination> <correlation-id> <message>
|
||||
|
||||
Required Parameters:
|
||||
destination - where to send the message
|
||||
correlation-id - the correlating identifier to send with the response
|
||||
message - the content to send
|
||||
|
||||
Description:
|
||||
Sends a reply message to a destination in the messaging system.
|
||||
"""
|
||||
if len(args) < 4:
|
||||
sysout('expecting: sendreply <destination> <correlation-id> <message>')
|
||||
else:
|
||||
self.conn.send(destination=args[1], message="%s\n" % ' '.join(args[3:]), headers={'correlation-id': args[2]})
|
||||
|
||||
def sendfile(self, args):
|
||||
"""
|
||||
Usage:
|
||||
sendfile <destination> <filename>
|
||||
|
||||
Required Parameters:
|
||||
destination - where to send the message
|
||||
filename - the file to send
|
||||
|
||||
Description:
|
||||
Sends a file to a destination in the messaging system.
|
||||
"""
|
||||
if len(args) < 3:
|
||||
sysout('Expecting: sendfile <destination> <filename>')
|
||||
elif not os.path.exists(args[2]):
|
||||
sysout('File %s does not exist' % args[2])
|
||||
else:
|
||||
s = open(args[2], mode='rb').read()
|
||||
msg = base64.b64encode(s).decode()
|
||||
if not self.transaction_id:
|
||||
self.conn.send(destination=args[1], message=msg, filename=args[2])
|
||||
else:
|
||||
self.conn.send(destination=args[1], message=msg, filename=args[2], transaction=self.transaction_id)
|
||||
|
||||
def subscribe(self, args):
|
||||
"""
|
||||
Usage:
|
||||
subscribe <destination> [ack]
|
||||
|
||||
Required Parameters:
|
||||
destination - the name to subscribe to
|
||||
|
||||
Optional Parameters:
|
||||
ack - how to handle acknowledgements for a message; either automatically (auto) or manually (client)
|
||||
|
||||
Description:
|
||||
Register to listen to a given destination. Like send, the subscribe command requires a destination
|
||||
header indicating which destination to subscribe to. The ack parameter is optional, and defaults to
|
||||
auto.
|
||||
"""
|
||||
if len(args) < 2:
|
||||
sysout('Expecting: subscribe <destination> [ack]')
|
||||
elif len(args) > 2:
|
||||
sysout('Subscribing to "%s" with acknowledge set to "%s"' % (args[1], args[2]))
|
||||
self.conn.subscribe(destination=args[1], ack=args[2])
|
||||
else:
|
||||
sysout('Subscribing to "%s" with auto acknowledge' % args[1])
|
||||
self.conn.subscribe(destination=args[1], ack='auto')
|
||||
|
||||
def unsubscribe(self, args):
|
||||
"""
|
||||
Usage:
|
||||
unsubscribe <destination>
|
||||
|
||||
Required Parameters:
|
||||
destination - the name to unsubscribe from
|
||||
|
||||
Description:
|
||||
Remove an existing subscription - so that the client no longer receive messages from that destination.
|
||||
"""
|
||||
if len(args) < 2:
|
||||
sysout('Expecting: unsubscribe <destination>')
|
||||
else:
|
||||
sysout('Unsubscribing from "%s"' % args[1])
|
||||
self.conn.unsubscribe(destination=args[1])
|
||||
|
||||
def stats(self, args):
|
||||
"""
|
||||
Usage:
|
||||
stats [on|off]
|
||||
|
||||
Description:
|
||||
Record statistics on messages sent, received, errors, etc. If no argument (on|off) is specified,
|
||||
dump the current statistics.
|
||||
"""
|
||||
if len(args) < 2:
|
||||
stats = self.conn.get_listener('stats')
|
||||
if stats:
|
||||
sysout(stats)
|
||||
else:
|
||||
sysout('No stats available')
|
||||
elif args[1] == 'on':
|
||||
self.conn.set_listener('stats', StatsListener())
|
||||
elif args[1] == 'off':
|
||||
self.conn.remove_listener('stats')
|
||||
else:
|
||||
sysout('Expecting: stats [on|off]')
|
||||
|
||||
def run(self, args):
|
||||
"""
|
||||
Usage:
|
||||
run <filename>
|
||||
|
||||
Description:
|
||||
Execute commands in a specified file
|
||||
"""
|
||||
if len(args) == 1:
|
||||
sysout("Expecting: run <filename>")
|
||||
elif not os.path.exists(args[1]):
|
||||
sysout("File %s was not found" % args[1])
|
||||
else:
|
||||
filecommands = open(args[1]).read().split('\n')
|
||||
for x in range(len(filecommands)):
|
||||
split = filecommands[x].split()
|
||||
if len(split) < 1:
|
||||
continue
|
||||
elif split[0] in self.__commands:
|
||||
getattr(self, split[0])(split)
|
||||
else:
|
||||
sysout('Unrecognized command "%s" at line %s' % (split[0], x))
|
||||
break
|
||||
|
||||
def help(self, args):
|
||||
"""
|
||||
Usage:
|
||||
help [command]
|
||||
|
||||
Description:
|
||||
Display info on a specified command, or a list of available commands
|
||||
"""
|
||||
if len(args) == 1:
|
||||
sysout('Usage: help <command>, where command is one of the following:')
|
||||
sysout(' ')
|
||||
for f in self.__commands:
|
||||
sysout('%s ' % f, end='')
|
||||
sysout('')
|
||||
return
|
||||
elif not hasattr(self, args[1]):
|
||||
sysout('There is no command "%s"' % args[1])
|
||||
return
|
||||
|
||||
func = getattr(self, args[1])
|
||||
if hasattr(func, '__doc__') and getattr(func, '__doc__') is not None:
|
||||
sysout(func.__doc__)
|
||||
else:
|
||||
sysout('There is no help for command "%s"' % args[1])
|
||||
man = help
|
||||
|
||||
def version(self, args):
|
||||
sysout('Stomp.py Version %s.%s' % internal.__version__)
|
||||
ver = version
|
||||
|
||||
def quit(self, args):
|
||||
pass
|
||||
exit = quit
|
||||
|
||||
|
||||
def main():
|
||||
commands = get_commands()
|
||||
|
||||
parser = OptionParser()
|
||||
|
||||
parser.add_option('-H', '--host', type = 'string', dest = 'host', default = 'localhost',
|
||||
help = 'Hostname or IP to connect to. Defaults to localhost if not specified.')
|
||||
parser.add_option('-P', '--port', type = int, dest = 'port', default = 61613,
|
||||
help = 'Port providing stomp protocol connections. Defaults to 61613 if not specified.')
|
||||
parser.add_option('-U', '--user', type = 'string', dest = 'user', default = None,
|
||||
help = 'Username for the connection')
|
||||
parser.add_option('-W', '--password', type = 'string', dest = 'password', default = None,
|
||||
help = 'Password for the connection')
|
||||
parser.add_option('-F', '--file', type = 'string', dest = 'filename',
|
||||
help = 'File containing commands to be executed, instead of prompting from the command prompt.')
|
||||
|
||||
(options, args) = parser.parse_args()
|
||||
|
||||
st = StompCLI(options.host, options.port, options.user, options.password)
|
||||
try:
|
||||
if not options.filename:
|
||||
# If the readline module is available, make command input easier
|
||||
try:
|
||||
import readline
|
||||
def stomp_completer(text, state):
|
||||
for command in commands[state:]:
|
||||
if command.startswith(text):
|
||||
return "%s " % command
|
||||
return None
|
||||
|
||||
readline.parse_and_bind("tab: complete")
|
||||
readline.set_completer(stomp_completer)
|
||||
readline.set_completer_delims("")
|
||||
except ImportError:
|
||||
pass # ignore unavailable readline module
|
||||
|
||||
while True:
|
||||
line = input_prompt("\r> ")
|
||||
if not line or line.lstrip().rstrip() == '':
|
||||
continue
|
||||
line = line.lstrip().rstrip()
|
||||
if line.startswith('quit') or line.startswith('exit') or line.startswith('disconnect'):
|
||||
break
|
||||
split = line.split()
|
||||
command = split[0]
|
||||
if command in commands:
|
||||
getattr(st, command)(split)
|
||||
else:
|
||||
sysout('Unrecognized command')
|
||||
else:
|
||||
st.run(['run', options.filename])
|
||||
except EOFError:
|
||||
pass
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
st.disconnect(None)
|
||||
|
||||
|
||||
|
||||
#
|
||||
# command line testing
|
||||
#
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
|
|
@ -0,0 +1,683 @@
|
|||
import functools
|
||||
import math
|
||||
import random
|
||||
import re
|
||||
import socket
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import types
|
||||
import uuid
|
||||
import xml.dom.minidom
|
||||
|
||||
try:
|
||||
from cStringIO import StringIO
|
||||
except ImportError:
|
||||
from io import StringIO
|
||||
|
||||
try:
|
||||
import ssl
|
||||
from ssl import SSLError
|
||||
except ImportError: # python version < 2.6 without the backported ssl module
|
||||
ssl = None
|
||||
class SSLError:
|
||||
pass
|
||||
|
||||
import exception
|
||||
import listener
|
||||
import utils
|
||||
import backward
|
||||
|
||||
import logging
|
||||
import logging.config
|
||||
try:
|
||||
logging.config.fileConfig('stomp.log.conf')
|
||||
except:
|
||||
pass
|
||||
log = logging.getLogger('stomp.py')
|
||||
if not log:
|
||||
log = utils.DevNullLogger()
|
||||
|
||||
|
||||
class Connection(object):
|
||||
"""
|
||||
Represents a STOMP client connection.
|
||||
"""
|
||||
|
||||
# ========= PRIVATE MEMBERS =========
|
||||
|
||||
# List of all host names (unqualified, fully-qualified, and IP
|
||||
# addresses) that refer to the local host (both loopback interface
|
||||
# and external interfaces). This is used for determining
|
||||
# preferred targets.
|
||||
__localhost_names = [ "localhost", "127.0.0.1" ]
|
||||
|
||||
try:
|
||||
__localhost_names.append(socket.gethostbyname(socket.gethostname()))
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
__localhost_names.append(socket.gethostname())
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
__localhost_names.append(socket.getfqdn(socket.gethostname()))
|
||||
except:
|
||||
pass
|
||||
|
||||
#
|
||||
# Used to parse the STOMP "content-length" header lines,
|
||||
#
|
||||
__content_length_re = re.compile('^content-length[:]\\s*(?P<value>[0-9]+)', re.MULTILINE)
|
||||
|
||||
|
||||
def __init__(self,
|
||||
host_and_ports = [ ('localhost', 61613) ],
|
||||
user = None,
|
||||
passcode = None,
|
||||
prefer_localhost = True,
|
||||
try_loopback_connect = True,
|
||||
reconnect_sleep_initial = 0.1,
|
||||
reconnect_sleep_increase = 0.5,
|
||||
reconnect_sleep_jitter = 0.1,
|
||||
reconnect_sleep_max = 60.0,
|
||||
reconnect_attempts_max = 3,
|
||||
use_ssl = False,
|
||||
ssl_key_file = None,
|
||||
ssl_cert_file = None,
|
||||
ssl_ca_certs = None,
|
||||
ssl_cert_validator = None):
|
||||
"""
|
||||
Initialize and start this connection.
|
||||
|
||||
\param host_and_ports
|
||||
a list of (host, port) tuples.
|
||||
|
||||
\param prefer_localhost
|
||||
if True and the local host is mentioned in the (host,
|
||||
port) tuples, try to connect to this first
|
||||
|
||||
\param try_loopback_connect
|
||||
if True and the local host is found in the host
|
||||
tuples, try connecting to it using loopback interface
|
||||
(127.0.0.1)
|
||||
|
||||
\param reconnect_sleep_initial
|
||||
initial delay in seconds to wait before reattempting
|
||||
to establish a connection if connection to any of the
|
||||
hosts fails.
|
||||
|
||||
\param reconnect_sleep_increase
|
||||
factor by which the sleep delay is increased after
|
||||
each connection attempt. For example, 0.5 means
|
||||
to wait 50% longer than before the previous attempt,
|
||||
1.0 means wait twice as long, and 0.0 means keep
|
||||
the delay constant.
|
||||
|
||||
\param reconnect_sleep_max
|
||||
maximum delay between connection attempts, regardless
|
||||
of the reconnect_sleep_increase.
|
||||
|
||||
\param reconnect_sleep_jitter
|
||||
random additional time to wait (as a percentage of
|
||||
the time determined using the previous parameters)
|
||||
between connection attempts in order to avoid
|
||||
stampeding. For example, a value of 0.1 means to wait
|
||||
an extra 0%-10% (randomly determined) of the delay
|
||||
calculated using the previous three parameters.
|
||||
|
||||
\param reconnect_attempts_max
|
||||
maximum attempts to reconnect
|
||||
|
||||
\param use_ssl
|
||||
connect using SSL to the socket. This wraps the
|
||||
socket in a SSL connection. The constructor will
|
||||
raise an exception if you ask for SSL, but it can't
|
||||
find the SSL module.
|
||||
|
||||
\param ssl_cert_file
|
||||
the path to a X509 certificate
|
||||
|
||||
\param ssl_key_file
|
||||
the path to a X509 key file
|
||||
|
||||
\param ssl_ca_certs
|
||||
the path to the a file containing CA certificates
|
||||
to validate the server against. If this is not set,
|
||||
server side certificate validation is not done.
|
||||
|
||||
\param ssl_cert_validator
|
||||
function which performs extra validation on the client
|
||||
certificate, for example checking the returned
|
||||
certificate has a commonName attribute equal to the
|
||||
hostname (to avoid man in the middle attacks)
|
||||
|
||||
The signature is:
|
||||
(OK, err_msg) = validation_function(cert, hostname)
|
||||
|
||||
where OK is a boolean, and cert is a certificate structure
|
||||
as returned by ssl.SSLSocket.getpeercert()
|
||||
"""
|
||||
|
||||
sorted_host_and_ports = []
|
||||
sorted_host_and_ports.extend(host_and_ports)
|
||||
|
||||
#
|
||||
# If localhost is preferred, make sure all (host, port) tuples that refer to the local host come first in the list
|
||||
#
|
||||
if prefer_localhost:
|
||||
sorted_host_and_ports.sort(key = self.is_localhost)
|
||||
|
||||
#
|
||||
# If the user wishes to attempt connecting to local ports using the loopback interface, for each (host, port) tuple
|
||||
# referring to a local host, add an entry with the host name replaced by 127.0.0.1 if it doesn't exist already
|
||||
#
|
||||
loopback_host_and_ports = []
|
||||
if try_loopback_connect:
|
||||
for host_and_port in sorted_host_and_ports:
|
||||
if self.is_localhost(host_and_port) == 1:
|
||||
port = host_and_port[1]
|
||||
if (not ("127.0.0.1", port) in sorted_host_and_ports
|
||||
and not ("localhost", port) in sorted_host_and_ports):
|
||||
loopback_host_and_ports.append(("127.0.0.1", port))
|
||||
|
||||
#
|
||||
# Assemble the final, possibly sorted list of (host, port) tuples
|
||||
#
|
||||
self.__host_and_ports = []
|
||||
self.__host_and_ports.extend(loopback_host_and_ports)
|
||||
self.__host_and_ports.extend(sorted_host_and_ports)
|
||||
|
||||
self.__recvbuf = ''
|
||||
|
||||
self.__listeners = {}
|
||||
|
||||
self.__reconnect_sleep_initial = reconnect_sleep_initial
|
||||
self.__reconnect_sleep_increase = reconnect_sleep_increase
|
||||
self.__reconnect_sleep_jitter = reconnect_sleep_jitter
|
||||
self.__reconnect_sleep_max = reconnect_sleep_max
|
||||
self.__reconnect_attempts_max = reconnect_attempts_max
|
||||
|
||||
self.__connect_headers = {}
|
||||
if user is not None and passcode is not None:
|
||||
self.__connect_headers['login'] = user
|
||||
self.__connect_headers['passcode'] = passcode
|
||||
|
||||
self.__socket = None
|
||||
self.__socket_semaphore = threading.BoundedSemaphore(1)
|
||||
self.__current_host_and_port = None
|
||||
|
||||
self.__receiver_thread_exit_condition = threading.Condition()
|
||||
self.__receiver_thread_exited = False
|
||||
|
||||
self.blocking = None
|
||||
|
||||
if use_ssl and not ssl:
|
||||
raise Exception("SSL connection requested, but SSL library not found.")
|
||||
self.__ssl = use_ssl
|
||||
self.__ssl_cert_file = ssl_cert_file
|
||||
self.__ssl_key_file = ssl_key_file
|
||||
self.__ssl_ca_certs = ssl_ca_certs
|
||||
self.__ssl_cert_validator = ssl_cert_validator
|
||||
|
||||
def is_localhost(self, host_and_port):
|
||||
"""
|
||||
Return true if the specified host+port is a member of the 'localhost' list of hosts
|
||||
"""
|
||||
(host, port) = host_and_port
|
||||
if host in Connection.__localhost_names:
|
||||
return 1
|
||||
else:
|
||||
return 2
|
||||
|
||||
#
|
||||
# Manage the connection
|
||||
#
|
||||
|
||||
def start(self):
|
||||
"""
|
||||
Start the connection. This should be called after all
|
||||
listeners have been registered. If this method is not called,
|
||||
no frames will be received by the connection.
|
||||
"""
|
||||
self.__running = True
|
||||
self.__attempt_connection()
|
||||
thread = threading.Thread(None, self.__receiver_loop)
|
||||
thread.start()
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
Stop the connection. This is equivalent to calling
|
||||
disconnect() but will do a clean shutdown by waiting for the
|
||||
receiver thread to exit.
|
||||
"""
|
||||
self.disconnect()
|
||||
|
||||
self.__receiver_thread_exit_condition.acquire()
|
||||
if not self.__receiver_thread_exited:
|
||||
self.__receiver_thread_exit_condition.wait()
|
||||
self.__receiver_thread_exit_condition.release()
|
||||
|
||||
def get_host_and_port(self):
|
||||
"""
|
||||
Return a (host, port) tuple indicating which STOMP host and
|
||||
port is currently connected, or None if there is currently no
|
||||
connection.
|
||||
"""
|
||||
return self.__current_host_and_port
|
||||
|
||||
def is_connected(self):
|
||||
"""
|
||||
Return true if the socket managed by this connection is connected
|
||||
"""
|
||||
try:
|
||||
return self.__socket is not None and self.__socket.getsockname()[1] != 0
|
||||
except socket.error:
|
||||
return False
|
||||
|
||||
#
|
||||
# Manage objects listening to incoming frames
|
||||
#
|
||||
|
||||
def set_listener(self, name, listener):
|
||||
"""
|
||||
Set a named listener on this connection
|
||||
|
||||
\see listener::ConnectionListener
|
||||
|
||||
\param name the name of the listener
|
||||
\param listener the listener object
|
||||
"""
|
||||
self.__listeners[name] = listener
|
||||
|
||||
def remove_listener(self, name):
|
||||
"""
|
||||
Remove a listener according to the specified name
|
||||
|
||||
\param name the name of the listener to remove
|
||||
"""
|
||||
del self.__listeners[name]
|
||||
|
||||
def get_listener(self, name):
|
||||
"""
|
||||
Return a named listener
|
||||
|
||||
\param name the listener to return
|
||||
"""
|
||||
if name in self.__listeners:
|
||||
return self.__listeners[name]
|
||||
else:
|
||||
return None
|
||||
|
||||
#
|
||||
# STOMP transmissions
|
||||
#
|
||||
|
||||
def subscribe(self, headers={}, **keyword_headers):
|
||||
"""
|
||||
Send a SUBSCRIBE frame to subscribe to a queue
|
||||
"""
|
||||
self.__send_frame_helper('SUBSCRIBE', '', utils.merge_headers([headers, keyword_headers]), [ 'destination' ])
|
||||
|
||||
def unsubscribe(self, headers={}, **keyword_headers):
|
||||
"""
|
||||
Send an UNSUBSCRIBE frame to unsubscribe from a queue
|
||||
"""
|
||||
self.__send_frame_helper('UNSUBSCRIBE', '', utils.merge_headers([headers, keyword_headers]), [ ('destination', 'id') ])
|
||||
|
||||
def send(self, message='', headers={}, **keyword_headers):
|
||||
"""
|
||||
Send a message (SEND) frame
|
||||
"""
|
||||
if '\x00' in message:
|
||||
content_length_headers = {'content-length': len(message)}
|
||||
else:
|
||||
content_length_headers = {}
|
||||
self.__send_frame_helper('SEND', message, utils.merge_headers([headers,
|
||||
keyword_headers,
|
||||
content_length_headers]), [ 'destination' ])
|
||||
self.__notify('send', headers, message)
|
||||
|
||||
def ack(self, headers={}, **keyword_headers):
|
||||
"""
|
||||
Send an ACK frame, to acknowledge receipt of a message
|
||||
"""
|
||||
self.__send_frame_helper('ACK', '', utils.merge_headers([headers, keyword_headers]), [ 'message-id' ])
|
||||
|
||||
def begin(self, headers={}, **keyword_headers):
|
||||
"""
|
||||
Send a BEGIN frame to start a transaction
|
||||
"""
|
||||
use_headers = utils.merge_headers([headers, keyword_headers])
|
||||
if not 'transaction' in use_headers.keys():
|
||||
use_headers['transaction'] = str(uuid.uuid4())
|
||||
self.__send_frame_helper('BEGIN', '', use_headers, [ 'transaction' ])
|
||||
return use_headers['transaction']
|
||||
|
||||
def abort(self, headers={}, **keyword_headers):
|
||||
"""
|
||||
Send an ABORT frame to rollback a transaction
|
||||
"""
|
||||
self.__send_frame_helper('ABORT', '', utils.merge_headers([headers, keyword_headers]), [ 'transaction' ])
|
||||
|
||||
def commit(self, headers={}, **keyword_headers):
|
||||
"""
|
||||
Send a COMMIT frame to commit a transaction (send pending messages)
|
||||
"""
|
||||
self.__send_frame_helper('COMMIT', '', utils.merge_headers([headers, keyword_headers]), [ 'transaction' ])
|
||||
|
||||
def connect(self, headers={}, **keyword_headers):
|
||||
"""
|
||||
Send a CONNECT frame to start a connection
|
||||
"""
|
||||
if 'wait' in keyword_headers and keyword_headers['wait']:
|
||||
while not self.is_connected(): time.sleep(0.1)
|
||||
del keyword_headers['wait']
|
||||
self.__send_frame_helper('CONNECT', '', utils.merge_headers([self.__connect_headers, headers, keyword_headers]), [ ])
|
||||
|
||||
def disconnect(self, headers={}, **keyword_headers):
|
||||
"""
|
||||
Send a DISCONNECT frame to finish a connection
|
||||
"""
|
||||
self.__send_frame_helper('DISCONNECT', '', utils.merge_headers([self.__connect_headers, headers, keyword_headers]), [ ])
|
||||
self.__running = False
|
||||
if self.__socket is not None:
|
||||
if self.__ssl:
|
||||
#
|
||||
# Even though we don't want to use the socket, unwrap is the only API method which does a proper SSL shutdown
|
||||
#
|
||||
try:
|
||||
self.__socket = self.__socket.unwrap()
|
||||
except Exception:
|
||||
#
|
||||
# unwrap seems flaky on Win with the backported ssl mod, so catch any exception and log it
|
||||
#
|
||||
_, e, _ = sys.exc_info()
|
||||
log.warn(e)
|
||||
elif hasattr(socket, 'SHUT_RDWR'):
|
||||
self.__socket.shutdown(socket.SHUT_RDWR)
|
||||
#
|
||||
# split this into a separate check, because sometimes the socket is nulled between shutdown and this call
|
||||
#
|
||||
if self.__socket is not None:
|
||||
self.__socket.close()
|
||||
self.__current_host_and_port = None
|
||||
|
||||
def __convert_dict(self, payload):
|
||||
"""
|
||||
Encode a python dictionary as a <map>...</map> structure.
|
||||
"""
|
||||
xmlStr = "<map>\n"
|
||||
for key in payload:
|
||||
xmlStr += "<entry>\n"
|
||||
xmlStr += "<string>%s</string>" % key
|
||||
xmlStr += "<string>%s</string>" % payload[key]
|
||||
xmlStr += "</entry>\n"
|
||||
xmlStr += "</map>"
|
||||
return xmlStr
|
||||
|
||||
def __send_frame_helper(self, command, payload, headers, required_header_keys):
|
||||
"""
|
||||
Helper function for sending a frame after verifying that a
|
||||
given set of headers are present.
|
||||
|
||||
\param command the command to send
|
||||
|
||||
\param payload the frame's payload
|
||||
|
||||
\param headers a dictionary containing the frame's headers
|
||||
|
||||
\param required_header_keys a sequence enumerating all
|
||||
required header keys. If an element in this sequence is itself
|
||||
a tuple, that tuple is taken as a list of alternatives, one of
|
||||
which must be present.
|
||||
|
||||
\throws ArgumentError if one of the required header keys is
|
||||
not present in the header map.
|
||||
"""
|
||||
for required_header_key in required_header_keys:
|
||||
if type(required_header_key) == tuple:
|
||||
found_alternative = False
|
||||
for alternative in required_header_key:
|
||||
if alternative in headers.keys():
|
||||
found_alternative = True
|
||||
if not found_alternative:
|
||||
raise KeyError("Command %s requires one of the following headers: %s" % (command, str(required_header_key)))
|
||||
elif not required_header_key in headers.keys():
|
||||
raise KeyError("Command %s requires header %r" % (command, required_header_key))
|
||||
self.__send_frame(command, headers, payload)
|
||||
|
||||
def __send_frame(self, command, headers={}, payload=''):
|
||||
"""
|
||||
Send a STOMP frame.
|
||||
|
||||
\param command the frame command
|
||||
|
||||
\param headers a map of headers (key-val pairs)
|
||||
|
||||
\param payload the message payload
|
||||
"""
|
||||
if type(payload) == dict:
|
||||
headers["transformation"] = "jms-map-xml"
|
||||
payload = self.__convert_dict(payload)
|
||||
|
||||
if self.__socket is not None:
|
||||
try:
|
||||
frame = [ command + '\n' ]
|
||||
for key, val in headers.items():
|
||||
frame.append('%s:%s\n' % (key, val))
|
||||
frame.append('\n')
|
||||
if payload:
|
||||
frame.append(payload)
|
||||
frame.append('\x00')
|
||||
frame = ''.join(frame)
|
||||
self.__socket_semaphore.acquire()
|
||||
try:
|
||||
self.__socket.sendall(frame.encode())
|
||||
finally:
|
||||
self.__socket_semaphore.release()
|
||||
except Exception:
|
||||
_, e, _ = sys.exc_info()
|
||||
print(e)
|
||||
log.debug("Sent frame: type=%s, headers=%r, body=%r" % (command, headers, payload))
|
||||
else:
|
||||
raise exception.NotConnectedException()
|
||||
|
||||
def __notify(self, frame_type, headers=None, body=None):
|
||||
"""
|
||||
Utility function for notifying listeners of incoming and outgoing messages
|
||||
|
||||
\param frame_type the type of message
|
||||
|
||||
\param headers the map of headers associated with the message
|
||||
|
||||
\param body the content of the message
|
||||
"""
|
||||
for listener in self.__listeners.values():
|
||||
if not hasattr(listener, 'on_%s' % frame_type):
|
||||
log.debug('listener %s has no method on_%s' % (listener, frame_type))
|
||||
continue
|
||||
|
||||
if frame_type == 'connecting':
|
||||
listener.on_connecting(self.__current_host_and_port)
|
||||
continue
|
||||
|
||||
notify_func = getattr(listener, 'on_%s' % frame_type)
|
||||
params = backward.get_func_argcount(notify_func)
|
||||
if params >= 3:
|
||||
notify_func(headers, body)
|
||||
elif params == 2:
|
||||
notify_func(headers)
|
||||
else:
|
||||
notify_func()
|
||||
|
||||
def __receiver_loop(self):
|
||||
"""
|
||||
Main loop listening for incoming data.
|
||||
"""
|
||||
try:
|
||||
try:
|
||||
threading.currentThread().setName("StompReceiver")
|
||||
while self.__running:
|
||||
log.debug('starting receiver loop')
|
||||
|
||||
if self.__socket is None:
|
||||
break
|
||||
|
||||
try:
|
||||
try:
|
||||
self.__notify('connecting')
|
||||
|
||||
while self.__running:
|
||||
frames = self.__read()
|
||||
|
||||
for frame in frames:
|
||||
(frame_type, headers, body) = utils.parse_frame(frame)
|
||||
log.debug("Received frame: result=%r, headers=%r, body=%r" % (frame_type, headers, body))
|
||||
frame_type = frame_type.lower()
|
||||
if frame_type in [ 'connected', 'message', 'receipt', 'error' ]:
|
||||
self.__notify(frame_type, headers, body)
|
||||
else:
|
||||
log.warning('Unknown response frame type: "%s" (frame length was %d)' % (frame_type, len(frame)))
|
||||
finally:
|
||||
try:
|
||||
self.__socket.close()
|
||||
except:
|
||||
pass # ignore errors when attempting to close socket
|
||||
self.__socket = None
|
||||
self.__current_host_and_port = None
|
||||
except exception.ConnectionClosedException:
|
||||
if self.__running:
|
||||
log.error("Lost connection")
|
||||
self.__notify('disconnected')
|
||||
#
|
||||
# Clear out any half-received messages after losing connection
|
||||
#
|
||||
self.__recvbuf = ''
|
||||
continue
|
||||
else:
|
||||
break
|
||||
except:
|
||||
log.exception("An unhandled exception was encountered in the stomp receiver loop")
|
||||
|
||||
finally:
|
||||
self.__receiver_thread_exit_condition.acquire()
|
||||
self.__receiver_thread_exited = True
|
||||
self.__receiver_thread_exit_condition.notifyAll()
|
||||
self.__receiver_thread_exit_condition.release()
|
||||
|
||||
def __read(self):
|
||||
"""
|
||||
Read the next frame(s) from the socket.
|
||||
"""
|
||||
fastbuf = StringIO()
|
||||
while self.__running:
|
||||
try:
|
||||
c = self.__socket.recv(1024)
|
||||
c = c.decode()
|
||||
except Exception:
|
||||
_, e, _ = sys.exc_info()
|
||||
c = ''
|
||||
if len(c) == 0:
|
||||
raise exception.ConnectionClosedException
|
||||
fastbuf.write(c)
|
||||
if '\x00' in c:
|
||||
break
|
||||
self.__recvbuf += fastbuf.getvalue()
|
||||
fastbuf.close()
|
||||
result = []
|
||||
|
||||
if len(self.__recvbuf) > 0 and self.__running:
|
||||
while True:
|
||||
pos = self.__recvbuf.find('\x00')
|
||||
|
||||
if pos >= 0:
|
||||
frame = self.__recvbuf[0:pos]
|
||||
preamble_end = frame.find('\n\n')
|
||||
if preamble_end >= 0:
|
||||
content_length_match = Connection.__content_length_re.search(frame[0:preamble_end])
|
||||
if content_length_match:
|
||||
content_length = int(content_length_match.group('value'))
|
||||
content_offset = preamble_end + 2
|
||||
frame_size = content_offset + content_length
|
||||
if frame_size > len(frame):
|
||||
#
|
||||
# Frame contains NUL bytes, need to read more
|
||||
#
|
||||
if frame_size < len(self.__recvbuf):
|
||||
pos = frame_size
|
||||
frame = self.__recvbuf[0:pos]
|
||||
else:
|
||||
#
|
||||
# Haven't read enough data yet, exit loop and wait for more to arrive
|
||||
#
|
||||
break
|
||||
result.append(frame)
|
||||
self.__recvbuf = self.__recvbuf[pos+1:]
|
||||
else:
|
||||
break
|
||||
return result
|
||||
|
||||
def __attempt_connection(self):
|
||||
"""
|
||||
Try connecting to the (host, port) tuples specified at construction time.
|
||||
"""
|
||||
sleep_exp = 1
|
||||
connect_count = 0
|
||||
while self.__running and self.__socket is None and connect_count < self.__reconnect_attempts_max:
|
||||
for host_and_port in self.__host_and_ports:
|
||||
try:
|
||||
log.debug("Attempting connection to host %s, port %s" % host_and_port)
|
||||
self.__socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
if self.__ssl: # wrap socket
|
||||
if self.__ssl_ca_certs:
|
||||
cert_validation = ssl.CERT_REQUIRED
|
||||
else:
|
||||
cert_validation = ssl.CERT_NONE
|
||||
self.__socket = ssl.wrap_socket(self.__socket, keyfile = self.__ssl_key_file,
|
||||
certfile = self.__ssl_cert_file, cert_reqs = cert_validation,
|
||||
ca_certs=self.__ssl_ca_certs, ssl_version = ssl.PROTOCOL_SSLv3)
|
||||
self.__socket.settimeout(None)
|
||||
if self.blocking is not None:
|
||||
self.__socket.setblocking(self.blocking)
|
||||
self.__socket.connect(host_and_port)
|
||||
|
||||
#
|
||||
# Validate server cert
|
||||
#
|
||||
if self.__ssl and self.__ssl_cert_validator:
|
||||
cert = self.__socket.getpeercert()
|
||||
(ok, errmsg) = apply(self.__ssl_cert_validator, (cert, host_and_port[0]))
|
||||
if not ok:
|
||||
raise SSLError("Server certificate validation failed: %s" % errmsg)
|
||||
|
||||
self.__current_host_and_port = host_and_port
|
||||
log.info("Established connection to host %s, port %s" % host_and_port)
|
||||
break
|
||||
except socket.error:
|
||||
self.__socket = None
|
||||
if isinstance(sys.exc_info()[1], tuple):
|
||||
exc = sys.exc_info()[1][1]
|
||||
else:
|
||||
exc = sys.exc_info()[1]
|
||||
connect_count += 1
|
||||
print(exc)
|
||||
log.warning("Could not connect to host %s, port %s: %s" % (host_and_port[0], host_and_port[1], exc))
|
||||
|
||||
if self.__socket is None:
|
||||
sleep_duration = (min(self.__reconnect_sleep_max,
|
||||
((self.__reconnect_sleep_initial / (1.0 + self.__reconnect_sleep_increase))
|
||||
* math.pow(1.0 + self.__reconnect_sleep_increase, sleep_exp)))
|
||||
* (1.0 + random.random() * self.__reconnect_sleep_jitter))
|
||||
sleep_end = time.time() + sleep_duration
|
||||
log.debug("Sleeping for %.1f seconds before attempting reconnect" % sleep_duration)
|
||||
while self.__running and time.time() < sleep_end:
|
||||
time.sleep(0.2)
|
||||
|
||||
if sleep_duration < self.__reconnect_sleep_max:
|
||||
sleep_exp += 1
|
||||
|
||||
if not self.__socket:
|
||||
raise exception.ReconnectFailedException
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
class ConnectionClosedException(Exception):
|
||||
"""
|
||||
Raised in the receiver thread when the connection has been closed
|
||||
by the server.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class NotConnectedException(Exception):
|
||||
"""
|
||||
Raised by Connection.__send_frame when there is currently no server
|
||||
connection.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ReconnectFailedException(Exception):
|
||||
"""
|
||||
Raised by Connection.__attempt_connection when reconnection attempts
|
||||
have exceeded Connection.__reconnect_attempts_max.
|
||||
"""
|
||||
|
|
@ -0,0 +1,133 @@
|
|||
class ConnectionListener(object):
|
||||
"""
|
||||
This class should be used as a base class for objects registered
|
||||
using Connection.set_listener().
|
||||
"""
|
||||
def on_connecting(self, host_and_port):
|
||||
"""
|
||||
Called by the STOMP connection once a TCP/IP connection to the
|
||||
STOMP server has been established or re-established. Note that
|
||||
at this point, no connection has been established on the STOMP
|
||||
protocol level. For this, you need to invoke the "connect"
|
||||
method on the connection.
|
||||
|
||||
\param host_and_port a tuple containing the host name and port
|
||||
number to which the connection has been established.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_connected(self, headers, body):
|
||||
"""
|
||||
Called by the STOMP connection when a CONNECTED frame is
|
||||
received, that is after a connection has been established or
|
||||
re-established.
|
||||
|
||||
\param headers a dictionary containing all headers sent by the
|
||||
server as key/value pairs.
|
||||
|
||||
\param body the frame's payload. This is usually empty for
|
||||
CONNECTED frames.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_disconnected(self):
|
||||
"""
|
||||
Called by the STOMP connection when a TCP/IP connection to the
|
||||
STOMP server has been lost. No messages should be sent via
|
||||
the connection until it has been reestablished.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_message(self, headers, body):
|
||||
"""
|
||||
Called by the STOMP connection when a MESSAGE frame is
|
||||
received.
|
||||
|
||||
\param headers a dictionary containing all headers sent by the
|
||||
server as key/value pairs.
|
||||
|
||||
\param body the frame's payload - the message body.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_receipt(self, headers, body):
|
||||
"""
|
||||
Called by the STOMP connection when a RECEIPT frame is
|
||||
received, sent by the server if requested by the client using
|
||||
the 'receipt' header.
|
||||
|
||||
\param headers a dictionary containing all headers sent by the
|
||||
server as key/value pairs.
|
||||
|
||||
\param body the frame's payload. This is usually empty for
|
||||
RECEIPT frames.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_error(self, headers, body):
|
||||
"""
|
||||
Called by the STOMP connection when an ERROR frame is
|
||||
received.
|
||||
|
||||
\param headers a dictionary containing all headers sent by the
|
||||
server as key/value pairs.
|
||||
|
||||
\param body the frame's payload - usually a detailed error
|
||||
description.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_send(self, headers, body):
|
||||
"""
|
||||
Called by the STOMP connection when it is in the process of sending a message
|
||||
|
||||
\param headers a dictionary containing the headers that will be sent with this message
|
||||
|
||||
\param body the message payload
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class StatsListener(ConnectionListener):
|
||||
"""
|
||||
A connection listener for recording statistics on messages sent and received.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.errors = 0
|
||||
self.connections = 0
|
||||
self.messages_recd = 0
|
||||
self.messages_sent = 0
|
||||
|
||||
def on_error(self, headers, message):
|
||||
"""
|
||||
\see ConnectionListener::on_error
|
||||
"""
|
||||
self.errors += 1
|
||||
|
||||
def on_connecting(self, host_and_port):
|
||||
"""
|
||||
\see ConnectionListener::on_connecting
|
||||
"""
|
||||
self.connections += 1
|
||||
|
||||
def on_message(self, headers, message):
|
||||
"""
|
||||
\see ConnectionListener::on_message
|
||||
"""
|
||||
self.messages_recd += 1
|
||||
|
||||
def on_send(self, headers, message):
|
||||
"""
|
||||
\see ConnectionListener::on_send
|
||||
"""
|
||||
self.messages_sent += 1
|
||||
|
||||
def __str__(self):
|
||||
"""
|
||||
Return a string containing the current statistics (messages sent and received,
|
||||
errors, etc)
|
||||
"""
|
||||
return '''Connections: %s
|
||||
Messages sent: %s
|
||||
Messages received: %s
|
||||
Errors: %s''' % (self.connections, self.messages_sent, self.messages_recd, self.errors)
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
import os
|
||||
import sys
|
||||
sys.path.insert(0, os.path.split(__file__)[0])
|
||||
|
||||
__all__ = [ 'basictest', 'ssltest', 'transtest', 'rabbitmqtest', 'threadingtest' ]
|
||||
|
|
@ -0,0 +1,32 @@
|
|||
import time
|
||||
import unittest
|
||||
|
||||
import stomp
|
||||
|
||||
import testlistener
|
||||
|
||||
|
||||
class TestBasicSend(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def testbasic(self):
|
||||
conn = stomp.Connection([('127.0.0.2', 61613), ('localhost', 61613)])
|
||||
listener = testlistener.TestListener()
|
||||
conn.set_listener('', listener)
|
||||
conn.start()
|
||||
conn.connect(wait=True)
|
||||
conn.subscribe(destination='/queue/test', ack='auto')
|
||||
|
||||
conn.send('this is a test', destination='/queue/test')
|
||||
|
||||
time.sleep(3)
|
||||
conn.disconnect()
|
||||
|
||||
self.assert_(listener.connections == 1, 'should have received 1 connection acknowledgement')
|
||||
self.assert_(listener.messages == 1, 'should have received 1 message')
|
||||
self.assert_(listener.errors == 0, 'should not have received any errors')
|
||||
|
||||
suite = unittest.TestLoader().loadTestsFromTestCase(TestBasicSend)
|
||||
unittest.TextTestRunner(verbosity=2).run(suite)
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
import time
|
||||
import unittest
|
||||
|
||||
import stomp
|
||||
|
||||
from . import testlistener
|
||||
|
||||
|
||||
class TestRabbitMQSend(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def testbasic(self):
|
||||
conn = stomp.Connection([('0.0.0.0', 61613), ('127.0.0.1', 61613)], 'guest', 'guest')
|
||||
listener = testlistener.TestListener()
|
||||
conn.set_listener('', listener)
|
||||
conn.start()
|
||||
conn.connect(wait=True)
|
||||
conn.subscribe(destination='/queue/test', ack='auto')
|
||||
|
||||
conn.send('this is a test', destination='/queue/test')
|
||||
|
||||
time.sleep(2)
|
||||
conn.disconnect()
|
||||
|
||||
self.assert_(listener.connections == 1, 'should have received 1 connection acknowledgement')
|
||||
self.assert_(listener.messages == 1, 'should have received 1 message')
|
||||
self.assert_(listener.errors == 0, 'should not have received any errors')
|
||||
|
||||
|
||||
suite = unittest.TestLoader().loadTestsFromTestCase(TestRabbitMQSend)
|
||||
unittest.TextTestRunner(verbosity=2).run(suite)
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
import time
|
||||
import unittest
|
||||
|
||||
import stomp
|
||||
|
||||
import testlistener
|
||||
|
||||
|
||||
class TestSSLSend(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def testsslbasic(self):
|
||||
conn = stomp.Connection([('127.0.0.1', 61612), ('localhost', 61612)], use_ssl = True)
|
||||
listener = testlistener.TestListener()
|
||||
conn.set_listener('', listener)
|
||||
conn.start()
|
||||
conn.connect(wait=True)
|
||||
conn.subscribe(destination='/queue/test', ack='auto')
|
||||
|
||||
conn.send('this is a test', destination='/queue/test')
|
||||
|
||||
time.sleep(3)
|
||||
conn.disconnect()
|
||||
|
||||
self.assert_(listener.connections == 1, 'should have received 1 connection acknowledgement')
|
||||
self.assert_(listener.messages == 1, 'should have received 1 message')
|
||||
self.assert_(listener.errors == 0, 'should not have received any errors')
|
||||
|
||||
|
||||
suite = unittest.TestLoader().loadTestsFromTestCase(TestSSLSend)
|
||||
unittest.TextTestRunner(verbosity=2).run(suite)
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
from stomp import ConnectionListener
|
||||
|
||||
class TestListener(ConnectionListener):
|
||||
def __init__(self):
|
||||
self.errors = 0
|
||||
self.connections = 0
|
||||
self.messages = 0
|
||||
|
||||
def on_error(self, headers, message):
|
||||
print('received an error %s' % message)
|
||||
self.errors = self.errors + 1
|
||||
|
||||
def on_connecting(self, host_and_port):
|
||||
print('connecting %s %s' % host_and_port)
|
||||
self.connections = self.connections + 1
|
||||
|
||||
def on_message(self, headers, message):
|
||||
print('received a message %s' % message)
|
||||
self.messages = self.messages + 1
|
||||
|
|
@ -0,0 +1,124 @@
|
|||
try:
|
||||
from queue import Queue, Empty, Full
|
||||
except ImportError:
|
||||
from Queue import Queue, Empty, Full
|
||||
import threading
|
||||
import sys
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import stomp
|
||||
|
||||
import testlistener
|
||||
|
||||
class MQ(object):
|
||||
def __init__(self):
|
||||
self.connection = stomp.Connection([('localhost', 61613)])
|
||||
self.connection.set_listener('', None)
|
||||
self.connection.start()
|
||||
self.connection.connect(wait=True)
|
||||
|
||||
def send(self, topic, msg, persistent='true', retry=False):
|
||||
self.connection.send(destination="/topic/%s" % topic, message=msg,
|
||||
persistent=persistent)
|
||||
mq = MQ()
|
||||
|
||||
|
||||
class TestThreading(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
"""Test that mq sends don't wedge their threads.
|
||||
|
||||
Starts a number of sender threads, and runs for a set amount of
|
||||
time. Each thread sends messages as fast as it can, and after each
|
||||
send, pops from a Queue. Meanwhile, the Queue is filled with one
|
||||
marker per second. If the Queue fills, the test fails, as that
|
||||
indicates that all threads are no longer emptying the queue, and thus
|
||||
must be wedged in their send() calls.
|
||||
|
||||
"""
|
||||
self.Q = Queue(10)
|
||||
self.Cmd = Queue()
|
||||
self.Error = Queue()
|
||||
self.clients = 20
|
||||
self.threads = []
|
||||
self.runfor = 20
|
||||
for i in range(0, self.clients):
|
||||
t = threading.Thread(name="client %s" % i,
|
||||
target=self.make_sender(i))
|
||||
t.setDaemon(1)
|
||||
self.threads.append(t)
|
||||
|
||||
def tearDown(self):
|
||||
for t in self.threads:
|
||||
if not t.isAlive:
|
||||
print("thread", t, "died")
|
||||
self.Cmd.put('stop')
|
||||
for t in self.threads:
|
||||
t.join()
|
||||
print()
|
||||
print()
|
||||
errs = []
|
||||
while 1:
|
||||
try:
|
||||
errs.append(self.Error.get(block=False))
|
||||
except Empty:
|
||||
break
|
||||
print("Dead threads:", len(errs), "of", self.clients)
|
||||
etype = {}
|
||||
for ec, ev, tb in errs:
|
||||
if ec in etype:
|
||||
etype[ec] = etype[ec] + 1
|
||||
else:
|
||||
etype[ec] = 1
|
||||
for k in sorted(etype.keys()):
|
||||
print("%s: %s" % (k, etype[k]))
|
||||
mq.connection.disconnect()
|
||||
|
||||
def make_sender(self, i):
|
||||
Q = self.Q
|
||||
Cmd = self.Cmd
|
||||
Error = self.Error
|
||||
def send(i=i, Q=Q, Cmd=Cmd, Error=Error):
|
||||
counter = 0
|
||||
print("%s starting" % i)
|
||||
try:
|
||||
while 1:
|
||||
# print "%s sending %s" % (i, counter)
|
||||
try:
|
||||
mq.send('testclientwedge',
|
||||
'Message %s:%s' % (i, counter))
|
||||
except:
|
||||
Error.put(sys.exc_info())
|
||||
# thread will die
|
||||
raise
|
||||
else:
|
||||
# print "%s sent %s" % (i, counter)
|
||||
try:
|
||||
Q.get(block=False)
|
||||
except Empty:
|
||||
pass
|
||||
try:
|
||||
if Cmd.get(block=False):
|
||||
break
|
||||
except Empty:
|
||||
pass
|
||||
counter +=1
|
||||
finally:
|
||||
print("final", i, counter)
|
||||
return send
|
||||
|
||||
def test_threads_dont_wedge(self):
|
||||
for t in self.threads:
|
||||
t.start()
|
||||
start = time.time()
|
||||
while time.time() - start < self.runfor:
|
||||
try:
|
||||
self.Q.put(1, False)
|
||||
time.sleep(1.0)
|
||||
except Full:
|
||||
assert False, "Failed: 'request' queue filled up"
|
||||
print("passed")
|
||||
|
||||
suite = unittest.TestLoader().loadTestsFromTestCase(TestThreading)
|
||||
unittest.TextTestRunner(verbosity=2).run(suite)
|
||||
|
|
@ -0,0 +1,60 @@
|
|||
import time
|
||||
import unittest
|
||||
|
||||
import stomp
|
||||
|
||||
import testlistener
|
||||
|
||||
|
||||
class TestTrans(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
conn = stomp.Connection([('127.0.0.2', 61613), ('localhost', 61613)])
|
||||
listener = testlistener.TestListener()
|
||||
conn.set_listener('', listener)
|
||||
conn.start()
|
||||
conn.connect(wait=True)
|
||||
self.conn = conn
|
||||
self.listener = listener
|
||||
|
||||
def tearDown(self):
|
||||
self.conn.disconnect()
|
||||
|
||||
def testcommit(self):
|
||||
self.conn.subscribe(destination='/queue/test', ack='auto')
|
||||
trans_id = self.conn.begin()
|
||||
self.conn.send('this is a test1', destination='/queue/test', transaction=trans_id)
|
||||
self.conn.send('this is a test2', destination='/queue/test', transaction=trans_id)
|
||||
self.conn.send('this is a test3', destination='/queue/test', transaction=trans_id)
|
||||
|
||||
time.sleep(3)
|
||||
|
||||
self.assert_(self.listener.connections == 1, 'should have received 1 connection acknowledgement')
|
||||
self.assert_(self.listener.messages == 0, 'should not have received any messages')
|
||||
|
||||
self.conn.commit(transaction = trans_id)
|
||||
time.sleep(3)
|
||||
|
||||
self.assert_(self.listener.messages == 3, 'should have received 3 messages')
|
||||
self.assert_(self.listener.errors == 0, 'should not have received any errors')
|
||||
|
||||
def testabort(self):
|
||||
self.conn.subscribe(destination='/queue/test', ack='auto')
|
||||
trans_id = self.conn.begin()
|
||||
self.conn.send('this is a test1', destination='/queue/test', transaction=trans_id)
|
||||
self.conn.send('this is a test2', destination='/queue/test', transaction=trans_id)
|
||||
self.conn.send('this is a test3', destination='/queue/test', transaction=trans_id)
|
||||
|
||||
time.sleep(3)
|
||||
|
||||
self.assert_(self.listener.connections == 1, 'should have received 1 connection acknowledgement')
|
||||
self.assert_(self.listener.messages == 0, 'should not have received any messages')
|
||||
|
||||
self.conn.abort(transaction = trans_id)
|
||||
time.sleep(3)
|
||||
|
||||
self.assert_(self.listener.messages == 0, 'should not have received any messages')
|
||||
self.assert_(self.listener.errors == 0, 'should not have received any errors')
|
||||
|
||||
suite = unittest.TestLoader().loadTestsFromTestCase(TestTrans)
|
||||
unittest.TextTestRunner(verbosity=2).run(suite)
|
||||
|
|
@ -0,0 +1,133 @@
|
|||
import hashlib
|
||||
import time
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import xml
|
||||
|
||||
#
|
||||
# Used to parse STOMP header lines in the format "key:value",
|
||||
#
|
||||
HEADER_LINE_RE = re.compile('(?P<key>[^:]+)[:](?P<value>.*)')
|
||||
|
||||
|
||||
class DevNullLogger(object):
|
||||
"""
|
||||
Dummy logging class for environments without the logging module
|
||||
"""
|
||||
def log(self, msg):
|
||||
"""
|
||||
Log a message (print to console)
|
||||
"""
|
||||
print(msg)
|
||||
|
||||
def devnull(self, msg):
|
||||
"""
|
||||
Dump a message (i.e. send to /dev/null)
|
||||
"""
|
||||
pass
|
||||
|
||||
debug = devnull
|
||||
info = devnull
|
||||
warning = log
|
||||
error = log
|
||||
critical = log
|
||||
exception = log
|
||||
|
||||
def isEnabledFor(self, lvl):
|
||||
"""
|
||||
Always return False
|
||||
"""
|
||||
return False
|
||||
|
||||
def parse_headers(lines, offset=0):
|
||||
headers = {}
|
||||
for header_line in lines[offset:]:
|
||||
header_match = HEADER_LINE_RE.match(header_line)
|
||||
if header_match:
|
||||
headers[header_match.group('key')] = header_match.group('value')
|
||||
return headers
|
||||
|
||||
def parse_frame(frame):
|
||||
"""
|
||||
Parse a STOMP frame into a (frame_type, headers, body) tuple,
|
||||
where frame_type is the frame type as a string (e.g. MESSAGE),
|
||||
headers is a map containing all header key/value pairs, and
|
||||
body is a string containing the frame's payload.
|
||||
"""
|
||||
preamble_end = frame.find('\n\n')
|
||||
preamble = frame[0:preamble_end]
|
||||
preamble_lines = preamble.split('\n')
|
||||
body = frame[preamble_end+2:]
|
||||
|
||||
# Skip any leading newlines
|
||||
first_line = 0
|
||||
while first_line < len(preamble_lines) and len(preamble_lines[first_line]) == 0:
|
||||
first_line += 1
|
||||
|
||||
# Extract frame type
|
||||
frame_type = preamble_lines[first_line]
|
||||
|
||||
# Put headers into a key/value map
|
||||
headers = parse_headers(preamble_lines, first_line + 1)
|
||||
|
||||
if 'transformation' in headers:
|
||||
body = transform(body, headers['transformation'])
|
||||
|
||||
return (frame_type, headers, body)
|
||||
|
||||
def transform(body, trans_type):
|
||||
"""
|
||||
Perform body transformation. Currently, the only supported transformation is
|
||||
'jms-map-xml', which converts a map into python dictionary. This can be extended
|
||||
to support other transformation types.
|
||||
|
||||
The body has the following format:
|
||||
<map>
|
||||
<entry>
|
||||
<string>name</string>
|
||||
<string>Dejan</string>
|
||||
</entry>
|
||||
<entry>
|
||||
<string>city</string>
|
||||
<string>Belgrade</string>
|
||||
</entry>
|
||||
</map>
|
||||
|
||||
(see http://docs.codehaus.org/display/STOMP/Stomp+v1.1+Ideas)
|
||||
|
||||
\param body the content of a message
|
||||
|
||||
\param trans_type the type transformation
|
||||
"""
|
||||
if trans_type != 'jms-map-xml':
|
||||
return body
|
||||
|
||||
try:
|
||||
entries = {}
|
||||
doc = xml.dom.minidom.parseString(body)
|
||||
rootElem = doc.documentElement
|
||||
for entryElem in rootElem.getElementsByTagName("entry"):
|
||||
pair = []
|
||||
for node in entryElem.childNodes:
|
||||
if not isinstance(node, xml.dom.minidom.Element): continue
|
||||
pair.append(node.firstChild.nodeValue)
|
||||
assert len(pair) == 2
|
||||
entries[pair[0]] = pair[1]
|
||||
return entries
|
||||
except Exception:
|
||||
_, e, _ = sys.exc_info()
|
||||
#
|
||||
# unable to parse message. return original
|
||||
#
|
||||
return body
|
||||
|
||||
def merge_headers(header_map_list):
|
||||
"""
|
||||
Helper function for combining multiple header maps into one.
|
||||
"""
|
||||
headers = {}
|
||||
for header_map in header_map_list:
|
||||
for header_key in header_map.keys():
|
||||
headers[header_key] = header_map[header_key]
|
||||
return headers
|
||||
|
|
@ -1,379 +1,27 @@
|
|||
#!/usr/bin/env python
|
||||
'''
|
||||
Few tests for a rabbitmq-stomp adaptor. They intend to increase code coverage
|
||||
of the erlang stomp code.
|
||||
'''
|
||||
|
||||
import unittest
|
||||
import re
|
||||
import socket
|
||||
import functools
|
||||
import time
|
||||
import sys
|
||||
import parsing, destinations, lifecycle
|
||||
import logging
|
||||
|
||||
def connect(cnames):
|
||||
''' Decorator that creates stomp connections and issues CONNECT '''
|
||||
cmd=('CONNECT\n'
|
||||
'prefetch: 0\n'
|
||||
'login:guest\n'
|
||||
'passcode:guest\n'
|
||||
'\n'
|
||||
'\n\0')
|
||||
resp = ('CONNECTED\n'
|
||||
'session:(.*)\n'
|
||||
'\n\x00')
|
||||
def w(m):
|
||||
@functools.wraps(m)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
for cname in cnames:
|
||||
sd = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sd.settimeout(3)
|
||||
sd.connect((self.host, self.port))
|
||||
sd.sendall(cmd)
|
||||
self.match(resp, sd.recv(4096))
|
||||
setattr(self, cname, sd)
|
||||
try:
|
||||
r = m(self, *args, **kwargs)
|
||||
finally:
|
||||
for cname in cnames:
|
||||
try:
|
||||
getattr(self, cname).close()
|
||||
except IOError:
|
||||
pass
|
||||
return r
|
||||
return wrapper
|
||||
return w
|
||||
def run_unittests():
|
||||
emodules = ['parsing', 'destinations', 'lifecycle', 'transactions', 'ack']
|
||||
modules = ['destinations']
|
||||
|
||||
suite = unittest.TestSuite()
|
||||
for m in modules:
|
||||
mod = __import__(m)
|
||||
for name in dir(mod):
|
||||
obj = getattr(mod, name)
|
||||
if name.startswith("Test") and issubclass(obj, unittest.TestCase):
|
||||
suite.addTest(unittest.TestLoader().loadTestsFromTestCase(obj))
|
||||
|
||||
|
||||
class TestConnected(unittest.TestCase):
|
||||
host='127.0.0.1'
|
||||
port=61613
|
||||
|
||||
def match(self, pattern, data):
|
||||
''' helper: try to match 'pattern' regexp with 'data' string.
|
||||
Fail testif they don't match.
|
||||
'''
|
||||
matched = re.match(pattern, data)
|
||||
if matched:
|
||||
return matched.groups()
|
||||
self.assertTrue(False, 'No match:\n%r\n%r' % (pattern, data) )
|
||||
|
||||
|
||||
@connect(['cd'])
|
||||
def test_newline_after_nul(self):
|
||||
self.cd.sendall('\n'
|
||||
'SUBSCRIBE\n'
|
||||
'destination:/exchange/amq.fanout\n'
|
||||
'\n\x00\n'
|
||||
'SEND\n'
|
||||
'destination:/exchange/amq.fanout\n'
|
||||
'\nhello\n\x00\n')
|
||||
resp = ('MESSAGE\n'
|
||||
'destination:/exchange/amq.fanout\n'
|
||||
'message-id:session-(.*)\n'
|
||||
'content-type:text/plain\n'
|
||||
'content-length:6\n'
|
||||
'\n'
|
||||
'hello\n\0')
|
||||
self.match(resp, self.cd.recv(4096))
|
||||
|
||||
|
||||
@connect(['cd'])
|
||||
def test_newline_after_nul_and_leading_nul(self):
|
||||
self.cd.sendall('\n'
|
||||
'\x00SUBSCRIBE\n'
|
||||
'destination:/exchange/amq.fanout\n'
|
||||
'\n\x00\n'
|
||||
'\x00SEND\n'
|
||||
'destination:/exchange/amq.fanout\n'
|
||||
'\nhello\n\x00\n')
|
||||
resp = ('MESSAGE\n'
|
||||
'destination:/exchange/amq.fanout\n'
|
||||
'message-id:session-(.*)\n'
|
||||
'content-type:text/plain\n'
|
||||
'content-length:6\n'
|
||||
'\n'
|
||||
'hello\n\0')
|
||||
self.match(resp, self.cd.recv(4096))
|
||||
|
||||
|
||||
@connect(['cd'])
|
||||
def test_subscribe_present_exchange(self):
|
||||
''' Just send a valid message '''
|
||||
self.cd.sendall('SUBSCRIBE\n'
|
||||
'destination:/exchange/amq.fanout\n'
|
||||
'\n\x00'
|
||||
'SEND\n'
|
||||
'destination:/exchange/amq.fanout\n'
|
||||
'\nhello\n\x00')
|
||||
resp = ('MESSAGE\n'
|
||||
'destination:/exchange/amq.fanout\n'
|
||||
'message-id:session-(.*)\n'
|
||||
'content-type:text/plain\n'
|
||||
'content-length:6\n'
|
||||
'\n'
|
||||
'hello\n\0')
|
||||
self.match(resp, self.cd.recv(4096))
|
||||
|
||||
|
||||
@connect(['cd'])
|
||||
def test_subscribe_missing_exchange(self):
|
||||
''' Just send a message to a wrong exchange'''
|
||||
self.cd.sendall('SUBSCRIBE\n'
|
||||
'destination:/exchange/foo\n'
|
||||
'\n\x00'
|
||||
'SEND\n'
|
||||
'destination:/exchange/foo\n'
|
||||
'\nhello\n\x00')
|
||||
resp = ('ERROR\n'
|
||||
'message:not_found\n'
|
||||
'content-type:text/plain\n'
|
||||
'content-length:31\n'
|
||||
'\n'
|
||||
"no exchange 'foo' in vhost '/'\n\x00")
|
||||
self.match(resp, self.cd.recv(4096))
|
||||
|
||||
|
||||
@connect(['cd'])
|
||||
def test_bad_command(self):
|
||||
''' Trigger an error message. '''
|
||||
self.cd.sendall('WRONGCOMMAND\n'
|
||||
'destination:/exchange/amq.fanout\n'
|
||||
'\n\0')
|
||||
resp = ('ERROR\n'
|
||||
'message:Bad command\n'
|
||||
'content-type:text/plain\n'
|
||||
'content-length:41\n'
|
||||
'\n'
|
||||
'Could not interpret command WRONGCOMMAND\n'
|
||||
'\0')
|
||||
self.match(resp, self.cd.recv(4096))
|
||||
|
||||
|
||||
@connect(['cd'])
|
||||
def test_unsubscribe_destination(self):
|
||||
''' Test UNSUBSCRIBE command with destination parameter '''
|
||||
self.cd.sendall('SUBSCRIBE\n'
|
||||
'destination:/exchange/amq.fanout\n'
|
||||
'\n\0'
|
||||
'UNSUBSCRIBE\n'
|
||||
'receipt: 1\n'
|
||||
'destination:/exchange/amq.fanout\n'
|
||||
'\n\0')
|
||||
resp= ('RECEIPT\n'
|
||||
'receipt-id:1\n'
|
||||
'\n\x00')
|
||||
self.match(resp, self.cd.recv(4096))
|
||||
|
||||
|
||||
@connect(['cd'])
|
||||
def test_unsubscribe_id(self):
|
||||
''' Test UNSUBSCRIBE command with id parameter'''
|
||||
self.cd.sendall('SUBSCRIBE\n'
|
||||
'id: 123\n'
|
||||
'destination:/exchange/amq.fanout\n'
|
||||
'\n\0'
|
||||
'UNSUBSCRIBE\n'
|
||||
'receipt: 1\n'
|
||||
'id: 123\n'
|
||||
'\n\0')
|
||||
resp= ('RECEIPT\n'
|
||||
'receipt-id:1\n'
|
||||
'\n\x00')
|
||||
self.match(resp, self.cd.recv(4096))
|
||||
|
||||
|
||||
@connect(['sd', 'cd1', 'cd2'])
|
||||
def test_broadcast(self):
|
||||
''' Single message should be delivered to two consumers:
|
||||
amq.topic --routing_key--> first_queue --> first_connection
|
||||
\--routing_key--> second_queue--> second_connection
|
||||
'''
|
||||
subscribe=( 'SUBSCRIBE\n'
|
||||
'id: XsKNhAf\n'
|
||||
'destination:/exchange/amq.topic/da9d4779\n'
|
||||
'\n\0')
|
||||
for cd in [self.cd1, self.cd2]:
|
||||
cd.sendall(subscribe)
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
self.sd.sendall('SEND\n'
|
||||
'destination:/exchange/amq.topic/da9d4779\n'
|
||||
'\n'
|
||||
'message'
|
||||
'\n\0')
|
||||
|
||||
resp=('MESSAGE\n'
|
||||
'destination:/exchange/amq.topic/da9d4779\n'
|
||||
'message-id:(.*)\n'
|
||||
'content-type:text/plain\n'
|
||||
'subscription:(.*)\n'
|
||||
'content-length:8\n'
|
||||
'\n'
|
||||
'message'
|
||||
'\n\x00')
|
||||
for cd in [self.cd1, self.cd2]:
|
||||
self.match(resp, cd.recv(4096))
|
||||
|
||||
|
||||
@connect(['sd', 'cd1', 'cd2'])
|
||||
def test_roundrobin(self):
|
||||
''' Two messages should be delivered to two consumers using round robin:
|
||||
amq.topic --routing_key--> single_queue --> first_connection
|
||||
\---> second_connection
|
||||
'''
|
||||
messages = ['message1', 'message2']
|
||||
subscribe=(
|
||||
'SUBSCRIBE\n'
|
||||
'id: sTXtc\n'
|
||||
'destination:/exchange/amq.topic/yAoXMwiF\n'
|
||||
'\n\0')
|
||||
for cd in [self.cd1, self.cd2]:
|
||||
cd.sendall(subscribe)
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
for msg in messages:
|
||||
self.sd.sendall('SEND\n'
|
||||
'destination:/exchange/amq.topic/yAoXMwiF\n'
|
||||
'\n'
|
||||
'%s'
|
||||
'\n\0' % msg)
|
||||
|
||||
resp=('MESSAGE\n'
|
||||
'destination:/exchange/amq.topic/yAoXMwiF\n'
|
||||
'message-id:.*\n'
|
||||
'content-type:text/plain\n'
|
||||
'subscription:.*\n'
|
||||
'content-length:.\n'
|
||||
'\n'
|
||||
'(.*)'
|
||||
'\n\x00')
|
||||
|
||||
recv_messages = [self.match(resp, cd.recv(4096))[0] \
|
||||
for cd in [self.cd1, self.cd2]]
|
||||
self.assertTrue(sorted(messages) == sorted(recv_messages), \
|
||||
'%r != %r ' % (messages, recv_messages))
|
||||
|
||||
|
||||
@connect(['cd'])
|
||||
def test_disconnect(self):
|
||||
''' Run DISCONNECT command '''
|
||||
self.cd.sendall('DISCONNECT\n'
|
||||
'receipt: 1\n'
|
||||
'\n\0')
|
||||
resp= ('RECEIPT\n'
|
||||
'receipt-id:1\n'
|
||||
'\n\x00')
|
||||
self.match(resp, self.cd.recv(4096))
|
||||
|
||||
|
||||
@connect(['cd'])
|
||||
def test_ack_commit(self):
|
||||
''' Run ACK and COMMIT commands '''
|
||||
self.cd.sendall('BEGIN\n'
|
||||
'transaction: abc\n'
|
||||
'\n\0'
|
||||
'SUBSCRIBE\n'
|
||||
'destination:/exchange/amq.fanout\n'
|
||||
'ack: client\n'
|
||||
'\n\0'
|
||||
'SEND\n'
|
||||
'destination:/exchange/amq.fanout\n'
|
||||
'\n'
|
||||
'hello\n\0')
|
||||
resp = ('MESSAGE\n'
|
||||
'destination:/exchange/amq.fanout\n'
|
||||
'message-id:(.*)\n'
|
||||
'content-type:text/plain\n'
|
||||
'content-length:6\n'
|
||||
'\n'
|
||||
'hello')
|
||||
ack = self.match(resp, self.cd.recv(4096))[0]
|
||||
self.cd.sendall('ACK\n'
|
||||
'message-id: %s\n'
|
||||
'transaction: abc\n'
|
||||
'receipt: 1\n'
|
||||
'\n\0' % (ack,))
|
||||
resp= ('RECEIPT\n'
|
||||
'receipt-id:1\n'
|
||||
'\n\x00')
|
||||
self.match(resp, self.cd.recv(4096))
|
||||
self.cd.sendall('COMMIT\n'
|
||||
'transaction: abc\n'
|
||||
'\n\0')
|
||||
|
||||
|
||||
@connect(['cd'])
|
||||
def test_abort(self):
|
||||
''' Run ABORT command '''
|
||||
self.cd.sendall('BEGIN\n'
|
||||
'transaction: abc\n'
|
||||
'\n\0'
|
||||
'ABORT\n'
|
||||
'transaction: abc\n'
|
||||
'receipt: 1\n'
|
||||
'\n\0')
|
||||
resp= ('RECEIPT\n'
|
||||
'receipt-id:1\n'
|
||||
'\n\x00')
|
||||
self.match(resp, self.cd.recv(4096))
|
||||
|
||||
|
||||
@connect(['cd'])
|
||||
def test_huge_message(self):
|
||||
''' Test sending/receiving huge (92MB) message. '''
|
||||
subscribe=( 'SUBSCRIBE\n'
|
||||
'id: xxx\n'
|
||||
'destination:/exchange/amq.topic/test_huge_message\n'
|
||||
'\n\0')
|
||||
self.cd.sendall(subscribe)
|
||||
|
||||
# Instead of 92MB, let's use 16, so that the test can finish in
|
||||
# reasonable time.
|
||||
##message = 'x' * 1024*1024*92
|
||||
message = 'x' * 1024*1024*16
|
||||
|
||||
self.cd.sendall('SEND\n'
|
||||
'destination:/exchange/amq.topic/test_huge_message\n'
|
||||
'\n'
|
||||
'%s'
|
||||
'\0' % message)
|
||||
|
||||
resp=('MESSAGE\n'
|
||||
'destination:/exchange/amq.topic/test_huge_message\n'
|
||||
'message-id:(.*)\n'
|
||||
'content-type:text/plain\n'
|
||||
'subscription:(.*)\n'
|
||||
'content-length:%i\n'
|
||||
'\n'
|
||||
'%s(.*)'
|
||||
% (len(message), message[:8000]) )
|
||||
|
||||
recv = []
|
||||
s = 0
|
||||
while len(recv) < 1 or recv[-1][-1] != '\0':
|
||||
buf = self.cd.recv(4096*16)
|
||||
s += len(buf)
|
||||
recv.append( buf )
|
||||
buf = ''.join(recv)
|
||||
|
||||
# matching 100MB regexp is way too expensive.
|
||||
self.match(resp, buf[:8192])
|
||||
self.assertEqual(len(buf) > len(message), True)
|
||||
|
||||
|
||||
|
||||
|
||||
def run_unittests(g):
|
||||
for t in [t for t in g.keys()
|
||||
if (t.startswith('Test') and issubclass(g[t], unittest.TestCase)) ]:
|
||||
suite = unittest.TestLoader().loadTestsFromTestCase(g[t])
|
||||
ts = unittest.TextTestRunner().run(suite)
|
||||
if ts.errors or ts.failures:
|
||||
sys.exit(1)
|
||||
ts = unittest.TextTestRunner().run(unittest.TestSuite(suite))
|
||||
if ts.errors or ts.failures:
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_unittests(globals())
|
||||
run_unittests()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,52 @@
|
|||
import unittest
|
||||
import stomp
|
||||
import base
|
||||
import time
|
||||
|
||||
class TestTransactions(base.BaseTest):
|
||||
|
||||
def test_tx_commit(self):
|
||||
''' Test TX with a COMMIT and ensure messages are delivered '''
|
||||
d = "/exchange/amq.fanout"
|
||||
tx = "test.tx"
|
||||
|
||||
self.listener.reset()
|
||||
self.conn.subscribe(destination=d)
|
||||
self.conn.begin(transaction=tx)
|
||||
self.conn.send("hello!", destination=d, transaction=tx)
|
||||
self.conn.send("again!", destination=d)
|
||||
|
||||
## should see the second message
|
||||
self.assertTrue(self.listener.await(3))
|
||||
self.assertEquals(1, len(self.listener.messages))
|
||||
self.assertEquals("again!", self.listener.messages[0]['message'])
|
||||
|
||||
## now look for the first message
|
||||
self.listener.reset()
|
||||
self.conn.commit(transaction=tx)
|
||||
self.assertTrue(self.listener.await(3))
|
||||
self.assertEquals(1, len(self.listener.messages), "Missing committed message")
|
||||
self.assertEquals("hello!", self.listener.messages[0]['message'])
|
||||
|
||||
def test_tx_abort(self):
|
||||
''' Test TX with an ABORT and ensure messages are discarded '''
|
||||
d = "/exchange/amq.fanout"
|
||||
tx = "test.tx"
|
||||
|
||||
self.listener.reset()
|
||||
self.conn.subscribe(destination=d)
|
||||
self.conn.begin(transaction=tx)
|
||||
self.conn.send("hello!", destination=d, transaction=tx)
|
||||
self.conn.send("again!", destination=d)
|
||||
|
||||
## should see the second message
|
||||
self.assertTrue(self.listener.await(3))
|
||||
self.assertEquals(1, len(self.listener.messages))
|
||||
self.assertEquals("again!", self.listener.messages[0]['message'])
|
||||
|
||||
## now look for the first message to be discarded
|
||||
self.listener.reset()
|
||||
self.conn.abort(transaction=tx)
|
||||
self.assertFalse(self.listener.await(3))
|
||||
self.assertEquals(0, len(self.listener.messages), "Unexpected committed message")
|
||||
|
||||
Loading…
Reference in New Issue