Merge pull request #149 from rabbitmq/gh-132

Ensure maximum message id value is used when saving to process state
This commit is contained in:
Michael Klishin 2017-11-09 12:27:14 +02:00 committed by GitHub
commit c1872605ff
1 changed files with 22 additions and 19 deletions

View File

@ -158,7 +158,7 @@ process_request(?PUBACK,
true -> true ->
Tag = gb_trees:get(MessageId, Awaiting), Tag = gb_trees:get(MessageId, Awaiting),
amqp_channel:cast(Channel, #'basic.ack'{ delivery_tag = Tag }), amqp_channel:cast(Channel, #'basic.ack'{ delivery_tag = Tag }),
{ok, PState #proc_state{ awaiting_ack = gb_trees:delete( MessageId, Awaiting)}} {ok, PState#proc_state{ awaiting_ack = gb_trees:delete(MessageId, Awaiting) }}
end; end;
process_request(?PUBLISH, process_request(?PUBLISH,
@ -192,13 +192,14 @@ process_request(?PUBLISH,
process_request(?SUBSCRIBE, process_request(?SUBSCRIBE,
#mqtt_frame{ #mqtt_frame{
variable = #mqtt_frame_subscribe{ variable = #mqtt_frame_subscribe{
message_id = MessageId, message_id = SubscribeMsgId,
topic_table = Topics}, topic_table = Topics},
payload = undefined}, payload = undefined},
#proc_state{channels = {Channel, _}, #proc_state{channels = {Channel, _},
exchange = Exchange, exchange = Exchange,
retainer_pid = RPid, retainer_pid = RPid,
send_fun = SendFun } = PState0) -> send_fun = SendFun,
message_id = StateMsgId } = PState0) ->
check_subscribe_or_die(Topics, fun() -> check_subscribe_or_die(Topics, fun() ->
{QosResponse, PState1} = {QosResponse, PState1} =
lists:foldl(fun (#mqtt_topic{name = TopicName, lists:foldl(fun (#mqtt_topic{name = TopicName,
@ -218,17 +219,18 @@ process_request(?SUBSCRIBE,
end, {[], PState0}, Topics), end, {[], PState0}, Topics),
SendFun(#mqtt_frame{fixed = #mqtt_frame_fixed{type = ?SUBACK}, SendFun(#mqtt_frame{fixed = #mqtt_frame_fixed{type = ?SUBACK},
variable = #mqtt_frame_suback{ variable = #mqtt_frame_suback{
message_id = MessageId, message_id = SubscribeMsgId,
qos_table = QosResponse}}, PState1), qos_table = QosResponse}}, PState1),
%% we may need to send up to length(Topics) messages. %% we may need to send up to length(Topics) messages.
%% if QoS is > 0 then we need to generate a message id, %% if QoS is > 0 then we need to generate a message id,
%% and increment the counter. %% and increment the counter.
StartMsgId = safe_max_id(SubscribeMsgId, StateMsgId),
N = lists:foldl(fun (Topic, Acc) -> N = lists:foldl(fun (Topic, Acc) ->
case maybe_send_retained_message(RPid, Topic, Acc, PState1) of case maybe_send_retained_message(RPid, Topic, Acc, PState1) of
{true, X} -> Acc + X; {true, X} -> Acc + X;
false -> Acc false -> Acc
end end
end, MessageId, Topics), end, StartMsgId, Topics),
{ok, PState1#proc_state{message_id = N}} {ok, PState1#proc_state{message_id = N}}
end, PState0); end, PState0);
@ -274,8 +276,6 @@ process_request(?PINGREQ, #mqtt_frame{}, #proc_state{ send_fun = SendFun } = PSt
process_request(?DISCONNECT, #mqtt_frame{}, PState) -> process_request(?DISCONNECT, #mqtt_frame{}, PState) ->
{stop, PState}. {stop, PState}.
%%----------------------------------------------------------------------------
hand_off_to_retainer(RetainerPid, Topic, #mqtt_msg{payload = <<"">>}) -> hand_off_to_retainer(RetainerPid, Topic, #mqtt_msg{payload = <<"">>}) ->
rabbit_mqtt_retainer:clear(RetainerPid, Topic), rabbit_mqtt_retainer:clear(RetainerPid, Topic),
ok; ok;
@ -350,11 +350,10 @@ amqp_callback({#'basic.deliver'{ consumer_tag = ConsumerTag,
{?QOS_0, ?QOS_0} -> {?QOS_0, ?QOS_0} ->
{ok, PState}; {ok, PState};
{?QOS_1, ?QOS_1} -> {?QOS_1, ?QOS_1} ->
{ok, Awaiting1 = gb_trees:insert(MsgId, DeliveryTag, Awaiting),
next_msg_id( PState1 = PState#proc_state{ awaiting_ack = Awaiting1 },
PState #proc_state{ PState2 = next_msg_id(PState1),
awaiting_ack = {ok, PState2};
gb_trees:insert(MsgId, DeliveryTag, Awaiting)})};
{?QOS_0, ?QOS_1} -> {?QOS_0, ?QOS_1} ->
amqp_channel:cast( amqp_channel:cast(
Channel, #'basic.ack'{ delivery_tag = DeliveryTag }), Channel, #'basic.ack'{ delivery_tag = DeliveryTag }),
@ -395,10 +394,17 @@ delivery_dup({#'basic.deliver'{ redelivered = Redelivered },
{bool, Dup} -> Redelivered orelse Dup {bool, Dup} -> Redelivered orelse Dup
end. end.
next_msg_id(PState = #proc_state{ message_id = 16#ffff }) -> ensure_valid_mqtt_message_id(Id) when Id >= 16#ffff ->
PState #proc_state{ message_id = 1 }; 1;
next_msg_id(PState = #proc_state{ message_id = MsgId }) -> ensure_valid_mqtt_message_id(Id) ->
PState #proc_state{ message_id = MsgId + 1 }. Id.
safe_max_id(Id0, Id1) ->
ensure_valid_mqtt_message_id(erlang:max(Id0, Id1)).
next_msg_id(PState = #proc_state{ message_id = MsgId0 }) ->
MsgId1 = ensure_valid_mqtt_message_id(MsgId0 + 1),
PState#proc_state{ message_id = MsgId1 }.
%% decide at which qos level to deliver based on subscription %% decide at which qos level to deliver based on subscription
%% and the message publish qos level. non-MQTT publishes are %% and the message publish qos level. non-MQTT publishes are
@ -438,8 +444,6 @@ session_present(Channel, ClientId) ->
_ -> false _ -> false
end. end.
%%----------------------------------------------------------------------------
make_will_msg(#mqtt_frame_connect{ will_flag = false }) -> make_will_msg(#mqtt_frame_connect{ will_flag = false }) ->
undefined; undefined;
make_will_msg(#mqtt_frame_connect{ will_retain = Retain, make_will_msg(#mqtt_frame_connect{ will_retain = Retain,
@ -776,7 +780,6 @@ human_readable_mqtt_version(_) ->
"N/A". "N/A".
send_client(Frame, #proc_state{ socket = Sock }) -> send_client(Frame, #proc_state{ socket = Sock }) ->
%log(info, "MQTT sending frame ~p ~n", [Frame]),
rabbit_net:port_command(Sock, rabbit_mqtt_frame:serialise(Frame)). rabbit_net:port_command(Sock, rabbit_mqtt_frame:serialise(Frame)).
close_connection(PState = #proc_state{ connection = undefined }) -> close_connection(PState = #proc_state{ connection = undefined }) ->