From 717730b275b779892177c326353390a469a5745a Mon Sep 17 00:00:00 2001 From: Luke Bakken Date: Wed, 8 Nov 2017 17:08:17 -0800 Subject: [PATCH] Ensure maximum message id value is used when saving to process state Fixes #132 --- .../src/rabbit_mqtt_processor.erl | 41 ++++++++++--------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/deps/rabbitmq_mqtt/src/rabbit_mqtt_processor.erl b/deps/rabbitmq_mqtt/src/rabbit_mqtt_processor.erl index 9c2d6a58c0..7e11ddaa25 100644 --- a/deps/rabbitmq_mqtt/src/rabbit_mqtt_processor.erl +++ b/deps/rabbitmq_mqtt/src/rabbit_mqtt_processor.erl @@ -158,7 +158,7 @@ process_request(?PUBACK, true -> Tag = gb_trees:get(MessageId, Awaiting), amqp_channel:cast(Channel, #'basic.ack'{ delivery_tag = Tag }), - {ok, PState #proc_state{ awaiting_ack = gb_trees:delete( MessageId, Awaiting)}} + {ok, PState#proc_state{ awaiting_ack = gb_trees:delete(MessageId, Awaiting) }} end; process_request(?PUBLISH, @@ -192,13 +192,14 @@ process_request(?PUBLISH, process_request(?SUBSCRIBE, #mqtt_frame{ variable = #mqtt_frame_subscribe{ - message_id = MessageId, + message_id = SubscribeMsgId, topic_table = Topics}, payload = undefined}, #proc_state{channels = {Channel, _}, exchange = Exchange, retainer_pid = RPid, - send_fun = SendFun } = PState0) -> + send_fun = SendFun, + message_id = StateMsgId } = PState0) -> check_subscribe_or_die(Topics, fun() -> {QosResponse, PState1} = lists:foldl(fun (#mqtt_topic{name = TopicName, @@ -218,17 +219,18 @@ process_request(?SUBSCRIBE, end, {[], PState0}, Topics), SendFun(#mqtt_frame{fixed = #mqtt_frame_fixed{type = ?SUBACK}, variable = #mqtt_frame_suback{ - message_id = MessageId, + message_id = SubscribeMsgId, qos_table = QosResponse}}, PState1), %% we may need to send up to length(Topics) messages. %% if QoS is > 0 then we need to generate a message id, %% and increment the counter. + StartMsgId = safe_max_id(SubscribeMsgId, StateMsgId), N = lists:foldl(fun (Topic, Acc) -> case maybe_send_retained_message(RPid, Topic, Acc, PState1) of {true, X} -> Acc + X; false -> Acc end - end, MessageId, Topics), + end, StartMsgId, Topics), {ok, PState1#proc_state{message_id = N}} end, PState0); @@ -274,8 +276,6 @@ process_request(?PINGREQ, #mqtt_frame{}, #proc_state{ send_fun = SendFun } = PSt process_request(?DISCONNECT, #mqtt_frame{}, PState) -> {stop, PState}. -%%---------------------------------------------------------------------------- - hand_off_to_retainer(RetainerPid, Topic, #mqtt_msg{payload = <<"">>}) -> rabbit_mqtt_retainer:clear(RetainerPid, Topic), ok; @@ -350,11 +350,10 @@ amqp_callback({#'basic.deliver'{ consumer_tag = ConsumerTag, {?QOS_0, ?QOS_0} -> {ok, PState}; {?QOS_1, ?QOS_1} -> - {ok, - next_msg_id( - PState #proc_state{ - awaiting_ack = - gb_trees:insert(MsgId, DeliveryTag, Awaiting)})}; + Awaiting1 = gb_trees:insert(MsgId, DeliveryTag, Awaiting), + PState1 = PState#proc_state{ awaiting_ack = Awaiting1 }, + PState2 = next_msg_id(PState1), + {ok, PState2}; {?QOS_0, ?QOS_1} -> amqp_channel:cast( Channel, #'basic.ack'{ delivery_tag = DeliveryTag }), @@ -395,10 +394,17 @@ delivery_dup({#'basic.deliver'{ redelivered = Redelivered }, {bool, Dup} -> Redelivered orelse Dup end. -next_msg_id(PState = #proc_state{ message_id = 16#ffff }) -> - PState #proc_state{ message_id = 1 }; -next_msg_id(PState = #proc_state{ message_id = MsgId }) -> - PState #proc_state{ message_id = MsgId + 1 }. +ensure_valid_mqtt_message_id(Id) when Id >= 16#ffff -> + 1; +ensure_valid_mqtt_message_id(Id) -> + Id. + +safe_max_id(Id0, Id1) -> + ensure_valid_mqtt_message_id(erlang:max(Id0, Id1)). + +next_msg_id(PState = #proc_state{ message_id = MsgId0 }) -> + MsgId1 = ensure_valid_mqtt_message_id(MsgId0 + 1), + PState#proc_state{ message_id = MsgId1 }. %% decide at which qos level to deliver based on subscription %% and the message publish qos level. non-MQTT publishes are @@ -438,8 +444,6 @@ session_present(Channel, ClientId) -> _ -> false end. -%%---------------------------------------------------------------------------- - make_will_msg(#mqtt_frame_connect{ will_flag = false }) -> undefined; make_will_msg(#mqtt_frame_connect{ will_retain = Retain, @@ -776,7 +780,6 @@ human_readable_mqtt_version(_) -> "N/A". send_client(Frame, #proc_state{ socket = Sock }) -> - %log(info, "MQTT sending frame ~p ~n", [Frame]), rabbit_net:port_command(Sock, rabbit_mqtt_frame:serialise(Frame)). close_connection(PState = #proc_state{ connection = undefined }) ->