diff --git a/deps/rabbitmq_stomp/src/rabbit_stomp_processor.erl b/deps/rabbitmq_stomp/src/rabbit_stomp_processor.erl index 5e01ab771f..7add748dc6 100644 --- a/deps/rabbitmq_stomp/src/rabbit_stomp_processor.erl +++ b/deps/rabbitmq_stomp/src/rabbit_stomp_processor.erl @@ -69,26 +69,29 @@ terminate(_Reason, State) -> handle_cast({"CONNECT", Frame}, State = #state{channel = none}) -> {ok, DefaultVHost} = application:get_env(rabbit, default_vhost), - with_error_unwrapping( - fun() -> + process_request( + fun(StateN) -> do_login(rabbit_stomp_frame:header(Frame, "login"), rabbit_stomp_frame:header(Frame, "passcode"), rabbit_stomp_frame:header(Frame, "virtual-host", binary_to_list(DefaultVHost)), - State) - end, State); + StateN) + end, + fun(StateM) -> StateM end, + State); handle_cast(_Request, State = #state{channel = none}) -> - {noreply, send_error("Illegal command", - "You must log in using CONNECT first\n", - State)}; + error("Illegal command", "You must log in using CONNECT first\n", State); handle_cast({Command, Frame}, State) -> - ensure_receipt(Frame, State), - with_error_unwrapping( - fun() -> - handle_frame(Command, Frame, State) - end, State). + process_request( + fun(StateN) -> + handle_frame(Command, Frame, StateN) + end, + fun(StateM) -> + ensure_receipt(Frame, StateM) + end, + State). handle_info(#'basic.consume_ok'{}, State) -> {noreply, State}; @@ -96,6 +99,30 @@ handle_info({Delivery = #'basic.deliver'{}, #amqp_msg{props = Props, payload = Payload}}, State) -> {noreply, send_delivery(Delivery, Props, Payload, State)}. +process_request(ProcessFun, SuccessFun, State) -> + Res = case catch ProcessFun(State) of + {'EXIT', + {{server_initiated_close, ReplyCode, Explanation}, _}} -> + explain_amqp_death(ReplyCode, Explanation, State); + {'EXIT', Reason} -> + priv_error("Processing error", "Processing error\n", + Reason, State); + Result -> + Result + end, + case Res of + {ok, Frame, NewState} -> + case Frame of + none -> ok; + _ -> send_frame(Frame, NewState) + end, + {noreply, SuccessFun(NewState)}; + {error, Message, Detail, NewState} -> + {noreply, send_error(Message, Detail, NewState)}; + {stop, R, State} -> + {stop, R, State} + end. + %%---------------------------------------------------------------------------- %% Frame handlers %%---------------------------------------------------------------------------- @@ -122,16 +149,14 @@ handle_frame("UNSUBSCRIBE", Frame, State = #state{subscriptions = Subs}) -> end, if ConsumerTag == missing -> - {noreply, send_error("Missing destination or id", - "UNSUBSCRIBE must include a 'destination' " - "or 'id' header\n", - State)}; + error("Missing destination or id", + "UNSUBSCRIBE must include a 'destination' or 'id' header\n", + State); true -> - {noreply, - send_method(#'basic.cancel'{consumer_tag = ConsumerTag, - nowait = true}, - State#state{subscriptions = - dict:erase(ConsumerTag, Subs)})} + ok(send_method(#'basic.cancel'{consumer_tag = ConsumerTag, + nowait = true}, + State#state{subscriptions = + dict:erase(ConsumerTag, Subs)})) end; handle_frame("SEND", Frame, State) -> @@ -155,19 +180,17 @@ handle_frame("ACK", Frame, State = #state{session_id = SessionId, State); no -> amqp_channel:call(SubChannel, Method), - {noreply, State} + ok(State) end; _ -> - {noreply, - send_error( - "Invalid message-id", - "ACK must include a valid 'message-id' header\n", - State)} + error("Invalid message-id", + "ACK must include a valid 'message-id' header\n", + State) end; not_found -> - {noreply, send_error("Missing message-id", - "ACK must include a 'message-id' header\n", - State)} + error("Missing message-id", + "ACK must include a 'message-id' header\n", + State) end; handle_frame("BEGIN", Frame, State) -> @@ -180,9 +203,9 @@ handle_frame("ABORT", Frame, State) -> transactional_action(Frame, "ABORT", fun abort_transaction/2, State); handle_frame(Command, _Frame, State) -> - {noreply, send_error("Bad command", - "Could not interpret command " ++ Command ++ "\n", - State)}. + error("Bad command", + "Could not interpret command " ++ Command ++ "\n", + State). %%---------------------------------------------------------------------------- %% Internal helpers for processing frames callbacks @@ -195,23 +218,23 @@ with_destination(Command, Frame, State, Fun) -> {ok, Destination} -> Fun(Destination, DestHdr, Frame, State); {error, {invalid_destination, Type, Content}} -> - {noreply, send_error("Invalid destination", - "'~s' is not a valid ~p destination\n", - [Content, Type], - State)}; + error("Invalid destination", + "'~s' is not a valid ~p destination\n", + [Content, Type], + State); {error, {unknown_destination, Content}} -> - {noreply, send_error("Unknown destination", - "'~s' is not a valid destination.\n" ++ - "Valid exchange types are: " ++ - "/exchange, /topic or /queue.\n", - [Content], - State)} + error("Unknown destination", + "'~s' is not a valid destination.\n" ++ + "Valid destination types are: " ++ + "/exchange, /topic or /queue.\n", + [Content], + State) end; not_found -> - {noreply, send_error("Missing destination", - "~p must include a 'destination' header\n", - [Command], - State)} + error("Missing destination", + "~p must include a 'destination' header\n", + [Command], + State) end. do_login({ok, Login}, {ok, Passcode}, VirtualHost, State) -> @@ -222,15 +245,12 @@ do_login({ok, Login}, {ok, Passcode}, VirtualHost, State) -> virtual_host = list_to_binary(VirtualHost)}), {ok, Channel} = amqp_connection:open_channel(Connection), SessionId = rabbit_guid:string_guid("session"), - {noreply, send_frame("CONNECTED", - [{"session", SessionId}], - "", - State#state{session_id = SessionId, - channel = Channel, - connection = Connection})}; + ok("CONNECTED",[{"session", SessionId}], "", + State#state{session_id = SessionId, + channel = Channel, + connection = Connection}); do_login(_, _, _, State) -> - {noreply, send_error("Bad CONNECT", "Missing login or passcode header(s)\n", - State)}. + error("Bad CONNECT", "Missing login or passcode header(s)\n", State). do_subscribe(Destination, DestHdr, Frame, State = #state{subscriptions = Subs, @@ -245,7 +265,8 @@ do_subscribe(Destination, DestHdr, Frame, prefetch_count = 1, global = false}), Channel1; - _ -> MainChannel + _ -> + MainChannel end, AckMode = rabbit_stomp_util:ack_mode(Frame), @@ -266,9 +287,8 @@ do_subscribe(Destination, DestHdr, Frame, ExchangeAndKey = rabbit_stomp_util:parse_routing_information(Destination), ok = ensure_queue_binding(Queue, ExchangeAndKey, Channel), - {noreply, - State#state{subscriptions = - dict:store(ConsumerTag, {DestHdr, Channel}, Subs)}}. + ok(State#state{subscriptions = + dict:store(ConsumerTag, {DestHdr, Channel}, Subs)}). do_send(Destination, _DestHdr, Frame = #stomp_frame{body_iolist = BodyFragments}, @@ -292,7 +312,7 @@ do_send(Destination, _DestHdr, {Method, Props, BodyFragments}, State); no -> - {noreply, send_method(Method, Props, BodyFragments, State)} + ok(send_method(Method, Props, BodyFragments, State)) end. ensure_receipt(Frame, State) -> @@ -308,11 +328,11 @@ send_delivery(Delivery = #'basic.deliver'{consumer_tag = ConsumerTag}, {Destination, _SubChannel} = dict:fetch(ConsumerTag, Subs), send_frame( - "MESSAGE", - rabbit_stomp_util:message_headers(Destination, SessionId, - Delivery, Properties), - Body, - State). + "MESSAGE", + rabbit_stomp_util:message_headers(Destination, SessionId, + Delivery, Properties), + Body, + State). send_method(Method, State = #state{channel = Channel}) -> amqp_channel:call(Channel, Method), @@ -320,25 +340,11 @@ send_method(Method, State = #state{channel = Channel}) -> send_method(Method, Properties, BodyFragments, State = #state{channel = Channel}) -> - amqp_channel:call(Channel,Method, #amqp_msg{ + amqp_channel:call(Channel, Method, #amqp_msg{ props = Properties, payload = lists:reverse(BodyFragments)}), State. -with_error_unwrapping(Fun, State) -> - case catch Fun() of - {'EXIT', - {{server_initiated_close, ReplyCode, Explanation}, _}} -> - {noreply, - explain_amqp_death(ReplyCode, Explanation, State)}; - {'EXIT', Reason} -> - {noreply, - send_priv_error("Processing error", "Processing error\n", - Reason, State)}; - Result -> - Result - end. - shutdown_channel_and_connection(State = #state{channel = Channel, connection = Connection, subscriptions = Subs}) -> @@ -374,32 +380,32 @@ transactional_action(Frame, Name, Fun, State) -> {yes, Transaction} -> Fun(Transaction, State); no -> - {noreply, send_error("Missing transaction", - Name ++ " must include a 'transaction' header\n", - State)} + error("Missing transaction", + Name ++ " must include a 'transaction' header\n", + State) end. with_transaction(Transaction, State, Fun) -> case get({transaction, Transaction}) of undefined -> - {noreply, send_error("Bad transaction", - "Invalid transaction identifier: ~p\n", - [Transaction], - State)}; + error("Bad transaction", + "Invalid transaction identifier: ~p\n", + [Transaction], + State); Actions -> Fun(Actions, State) end. begin_transaction(Transaction, State) -> put({transaction, Transaction}, []), - {noreply, State}. + ok(State). extend_transaction(Transaction, Action, State0) -> with_transaction( Transaction, State0, fun (Actions, State) -> put({transaction, Transaction}, [Action | Actions]), - {noreply, State} + ok(State) end). commit_transaction(Transaction, State0) -> @@ -410,7 +416,7 @@ commit_transaction(Transaction, State0) -> State, Actions), erase({transaction, Transaction}), - {noreply, FinalState} + ok(State) end). abort_transaction(Transaction, State0) -> @@ -418,7 +424,7 @@ abort_transaction(Transaction, State0) -> Transaction, State0, fun (_Actions, State) -> erase({transaction, Transaction}), - {noreply, State} + ok(State) end). perform_transaction_action({Method}, State) -> @@ -473,6 +479,41 @@ ensure_queue_binding(Queue, {Exchange, RoutingKey}, Channel) -> exchange = list_to_binary(Exchange), routing_key = list_to_binary(RoutingKey)}), ok. +%%---------------------------------------------------------------------------- +%% Success/error handling +%%---------------------------------------------------------------------------- + +ok(State) -> + {ok, none, State}. + +ok(Command, Headers, BodyFragments, State) -> + {ok, #stomp_frame{command = Command, + headers = Headers, + body_iolist = BodyFragments}, State}. + +explain_amqp_death(ReplyCode, Explanation, State) -> + ErrorName = ?PROTOCOL:amqp_exception(ReplyCode), + error(atom_to_list(ErrorName), "~s\n", + [Explanation], State). + +error(Message, Detail, State) -> + priv_error(Message, Detail, none, State). + +error(Message, Format, Args, State) -> + priv_error(Message, Format, Args, none, State). + +priv_error(Message, Detail, ServerPrivateDetail, State) -> + error_logger:error_msg("STOMP error frame sent:~n" ++ + "Message: ~p~n" ++ + "Detail: ~p~n" ++ + "Server private detail: ~p~n", + [Message, Detail, ServerPrivateDetail]), + {error, Message, Detail, State}. + +priv_error(Message, Format, Args, ServerPrivateDetail, State) -> + priv_error(Message, lists:flatten(io_lib:format(Format, Args)), + ServerPrivateDetail, State). + %%---------------------------------------------------------------------------- %% Frame sending utilities @@ -499,30 +540,10 @@ send_frame(Frame, State = #state{socket = Sock}) -> State end. -explain_amqp_death(ReplyCode, Explanation, State) -> - ErrorName = ?PROTOCOL:amqp_exception(ReplyCode), - send_error(atom_to_list(ErrorName), "~s\n", - [Explanation], State). - send_error(Message, Detail, State) -> - send_priv_error(Message, Detail, none, State). - -send_error(Message, Format, Args, State) -> - send_priv_error(Message, Format, Args, none, State). - -send_priv_error(Message, Detail, ServerPrivateDetail, State) -> - error_logger:error_msg("STOMP error frame sent:~n" ++ - "Message: ~p~n" ++ - "Detail: ~p~n" ++ - "Server private detail: ~p~n", - [Message, Detail, ServerPrivateDetail]), send_frame("ERROR", [{"message", Message}, {"content-type", "text/plain"}], Detail, State). -send_priv_error(Message, Format, Args, ServerPrivateDetail, State) -> - send_priv_error(Message, lists:flatten(io_lib:format(Format, Args)), - ServerPrivateDetail, State). - %%---------------------------------------------------------------------------- %% Skeleton gen_server callbacks %%----------------------------------------------------------------------------