more consistent error reporting on failed handshake; heartbeating is now started between connection.tune and .tune_ok; the network connection follows the closing protocol if it fails to receive connection.open_ok

This commit is contained in:
Vlad Alexandru Ionescu 2010-11-04 22:18:37 +00:00
parent 426f4bafc2
commit 4b6831ab72
6 changed files with 148 additions and 101 deletions

View File

@ -74,21 +74,26 @@ info_keys() ->
connect(AmqpParams, SIF, _ChMgr, State) ->
try do_connect(AmqpParams, SIF, State) of
Return -> Return
catch
exit:#amqp_error{name = access_refused} -> {error, auth_failure};
_:Reason -> {error, Reason}
catch _:Reason -> {error, Reason}
end.
do_connect(#amqp_params{username = User, password = Pass, virtual_host = VHost},
SIF, State) ->
case lists:keymember(rabbit, 1, application:which_applications()) of
true -> rabbit_access_control:user_pass_login(User, Pass),
rabbit_access_control:check_vhost_access(
#user{username = User}, VHost),
{ok, Collector} = SIF(),
{ok, rabbit_reader:server_properties(), 0,
State#state{user = User,
vhost = VHost,
collector = Collector}};
false -> {error, broker_not_found_in_vm}
end.
true -> ok;
false -> exit(broker_not_found_in_vm)
end,
try rabbit_access_control:user_pass_login(User, Pass) of
_ -> ok
catch exit:#amqp_error{name = access_refused} -> exit(auth_failure)
end,
try rabbit_access_control:check_vhost_access(
#user{username = User}, VHost) of
_ -> ok
catch exit:#amqp_error{name = access_refused} -> exit(access_refused)
end,
{ok, Collector} = SIF(),
{ok, {rabbit_reader:server_properties(), 0,
State#state{user = User,
vhost = VHost,
collector = Collector}}}.

View File

@ -22,6 +22,7 @@
%%
%% Contributor(s): Ben Hood <0x6e6562@gmail.com>.
%% @private
-module(amqp_gen_connection).
-include("amqp_client.hrl").
@ -117,7 +118,10 @@ behaviour_info(callbacks) ->
{terminate, 2},
%% connect(AmqpParams, SIF, ChMgr, State) ->
%% {ok, ServerProperties, ChannelMax} | {error, Error}
%% {ok, ConnectParams} | {closing, ConnectParams, AmqpError, Reply} |
%% {error, Error}
%% where
%% ConnectParams = {ServerProperties, ChannelMax, NewState}
{connect, 4},
%% do(Method, State) -> Ignored
@ -169,26 +173,21 @@ init([Mod, Sup, AmqpParams, SIF, SChMF, ExtraParams]) ->
start_channels_manager_fun = SChMF}}.
handle_call(connect, _From,
State = #state{module = Mod,
module_state = MState,
amqp_params = AmqpParams,
start_infrastructure_fun = SIF,
start_channels_manager_fun = SChMF}) ->
State0 = #state{module = Mod,
module_state = MState,
amqp_params = AmqpParams,
start_infrastructure_fun = SIF,
start_channels_manager_fun = SChMF}) ->
{ok, ChMgr} = SChMF(),
State1 = State0#state{channels_manager = ChMgr},
case Mod:connect(AmqpParams, SIF, ChMgr, MState) of
{ok, ServerProperties, ChannelMax, NewMState} ->
if ChannelMax =/= 0 ->
amqp_channels_manager:set_channel_max(ChMgr, ChannelMax);
true ->
ok
end,
{reply, {ok, self()},
State#state{module_state = NewMState,
server_properties = ServerProperties,
channel_max = ChannelMax,
channels_manager = ChMgr}};
{ok, Params} ->
{reply, {ok, self()}, after_connect(Params, State1)};
{closing, Params, #amqp_error{} = AmqpError, Error} ->
server_misbehaved(self(), AmqpError),
{reply, Error, after_connect(Params, State1)};
{error, _} = Error ->
{stop, Error, Error, State}
{stop, Error, Error, State0}
end;
handle_call({command, Command}, From, State = #state{closing = Closing}) ->
case Closing of false -> handle_command(Command, From, State);
@ -199,6 +198,16 @@ handle_call({info, Items}, _From, State) ->
handle_call(info_keys, _From, State = #state{module = Mod}) ->
{reply, ?INFO_KEYS ++ Mod:info_keys(), State}.
after_connect({ServerProperties, ChannelMax, NewMState},
State = #state{channels_manager = ChMgr}) ->
case ChannelMax of 0 -> ok;
_ -> amqp_channels_manager:set_channel_max(ChMgr,
ChannelMax)
end,
State#state{server_properties = ServerProperties,
channel_max = ChannelMax,
module_state = NewMState}.
handle_cast({method, Method, none}, State) ->
handle_method(Method, State);
handle_cast(channels_terminated, State) ->

View File

@ -135,45 +135,34 @@ connect(AmqpParams = #amqp_params{ssl_options = SslOpts,
try_handshake(AmqpParams, SIF, ChMgr, State) ->
try handshake(AmqpParams, SIF, ChMgr, State) of
Return -> Return
catch
exit:socket_closed_unexpectedly = Reason ->
{error, {auth_failure_likely, Reason}};
_:Reason ->
{error, Reason}
catch _:Reason -> {error, Reason}
end.
handshake(AmqpParams, SIF, ChMgr, State0 = #state{sock = Sock}) ->
ok = rabbit_net:send(Sock, ?PROTOCOL_HEADER),
{SHF, State1} = start_infrastructure(SIF, ChMgr, State0),
{ServerProperties, ChannelMax, State2} =
network_handshake(AmqpParams, State1),
start_heartbeat(SHF, State2),
{ok, ServerProperties, ChannelMax, State2}.
network_handshake(AmqpParams, SHF, State1).
start_infrastructure(SIF, ChMgr, State = #state{sock = Sock}) ->
{ok, {_MainReader, _Framing, Writer, SHF}} = SIF(Sock, ChMgr),
{SHF, State#state{writer0 = Writer}}.
network_handshake(AmqpParams, State) ->
Start = handshake_recv(),
#'connection.start'{server_properties = ServerProperties} = Start,
network_handshake(AmqpParams, SHF, State0) ->
Start = #'connection.start'{server_properties = ServerProperties} =
handshake_recv(expecting_start),
ok = check_version(Start),
do2(start_ok(AmqpParams), State),
Tune = handshake_recv(),
TuneOk = negotiate_values(Tune, AmqpParams),
do2(TuneOk, State),
ConnectionOpen =
#'connection.open'{virtual_host = AmqpParams#amqp_params.virtual_host},
do2(ConnectionOpen, State),
#'connection.open_ok'{} = handshake_recv(),
#'connection.tune_ok'{channel_max = ChannelMax,
frame_max = FrameMax,
heartbeat = Heartbeat} = TuneOk,
{ServerProperties, ChannelMax, State#state{heartbeat = Heartbeat,
frame_max = FrameMax}}.
start_heartbeat(SHF, #state{sock = Sock, heartbeat = Heartbeat}) ->
SHF(Sock, Heartbeat).
do2(start_ok(AmqpParams), State0),
Tune = handshake_recv(expecting_tune),
{TuneOk, ChannelMax, State1} = tune(Tune, AmqpParams, SHF, State0),
do2(TuneOk, State1),
do2(#'connection.open'{virtual_host = AmqpParams#amqp_params.virtual_host},
State1),
Params = {ServerProperties, ChannelMax, State1},
case handshake_recv(expecting_open_ok) of
#'connection.open_ok'{} -> {ok, Params};
{closing, #amqp_error{} = AmqpError, Error} -> {closing, Params,
AmqpError, Error}
end.
check_version(#'connection.start'{version_major = ?PROTOCOL_VERSION_MAJOR,
version_minor = ?PROTOCOL_VERSION_MINOR}) ->
@ -185,21 +174,26 @@ check_version(#'connection.start'{version_major = Major,
version_minor = Minor}) ->
exit({protocol_version_mismatch, Major, Minor}).
negotiate_values(#'connection.tune'{channel_max = ServerChannelMax,
frame_max = ServerFrameMax,
heartbeat = ServerHeartbeat},
#amqp_params{channel_max = ClientChannelMax,
frame_max = ClientFrameMax,
heartbeat = ClientHeartbeat}) ->
#'connection.tune_ok'{
channel_max = negotiate_max_value(ClientChannelMax, ServerChannelMax),
frame_max = negotiate_max_value(ClientFrameMax, ServerFrameMax),
heartbeat = negotiate_max_value(ClientHeartbeat, ServerHeartbeat)}.
negotiate_max_value(Client, Server) when Client =:= 0; Server =:= 0 ->
lists:max([Client, Server]);
negotiate_max_value(Client, Server) ->
lists:min([Client, Server]).
tune(#'connection.tune'{channel_max = ServerChannelMax,
frame_max = ServerFrameMax,
heartbeat = ServerHeartbeat},
#amqp_params{channel_max = ClientChannelMax,
frame_max = ClientFrameMax,
heartbeat = ClientHeartbeat},
SHF,
State = #state{sock = Sock}) ->
[ChannelMax, Heartbeat, FrameMax] =
lists:zipwith(fun (Client, Server) when Client =:= 0; Server =:= 0 ->
lists:max([Client, Server]);
(Client, Server) ->
lists:min([Client, Server])
end, [ClientChannelMax, ClientHeartbeat, ClientFrameMax],
[ServerChannelMax, ServerHeartbeat, ServerFrameMax]),
SHF(Sock, Heartbeat),
{#'connection.tune_ok'{channel_max = ChannelMax,
frame_max = FrameMax,
heartbeat = Heartbeat},
ChannelMax, State#state{heartbeat = Heartbeat, frame_max = FrameMax}}.
start_ok(#amqp_params{username = Username,
password = Password,
@ -227,16 +221,46 @@ client_properties(UserProperties) ->
lists:keystore(K, 1, Acc, Tuple)
end, Default, UserProperties).
handshake_recv() ->
handshake_recv(Phase) ->
receive
{'$gen_cast', {method, Method, none}} ->
Method;
case {Phase, Method} of
{expecting_start, #'connection.start'{}} ->
Method;
{expecting_tune, #'connection.tune'{}} ->
Method;
{expecting_open_ok, #'connection.open_ok'{}} ->
Method;
{expecting_open_ok, _} ->
{closing,
#amqp_error{name = command_invalid,
explanation = "was expecting "
"connection.open_ok"},
{error, {unexpected_method, Method, Phase}}};
_ ->
exit({unexpected_method, Method, Phase})
end;
socket_closed ->
exit(socket_closed_unexpectedly);
case Phase of expecting_tune -> exit(auth_failure);
expecting_open_ok -> exit(access_refused);
_ -> exit({socket_closed_unexpectedly,
Phase})
end;
{socket_error, _} = SocketError ->
exit(SocketError);
exit({SocketError, Phase});
timeout ->
exit(heartbeat_timeout);
Other ->
exit({handshake_recv_unexpected_message, Other})
after ?HANDSHAKE_RECEIVE_TIMEOUT ->
exit(handshake_receive_timed_out)
case Phase of
expecting_open_ok ->
{closing,
#amqp_error{name = internal_error,
explanation = "handshake timed out waiting "
"connection.open_ok"},
{error, handshake_receive_timed_out}};
_ ->
exit(handshake_receive_timed_out)
end
end.

View File

@ -97,6 +97,18 @@ bogus_rpc_test() ->
channel_death_test() ->
negative_test_util:channel_death_test(new_connection()).
non_existent_user_test() ->
negative_test_util:non_existent_user_test(fun new_connection/1).
invalid_password_test() ->
negative_test_util:invalid_password_test(fun new_connection/1).
non_existent_vhost_test() ->
negative_test_util:non_existent_vhost_test(fun new_connection/1).
no_permission_test() ->
negative_test_util:no_permission_test(fun new_connection/1).
command_invalid_over_channel_test() ->
negative_test_util:command_invalid_over_channel_test(new_connection()).
@ -105,9 +117,11 @@ command_invalid_over_channel_test() ->
%%---------------------------------------------------------------------------
new_connection() ->
case amqp_connection:start(direct) of
{ok, Conn} -> Conn;
{error, _} = Error -> Error
new_connection(#amqp_params{}).
new_connection(AmqpParams) ->
case amqp_connection:start(direct, AmqpParams) of {ok, Conn} -> Conn;
{error, _} = E -> E
end.
test_coverage() ->

View File

@ -182,25 +182,21 @@ assert_down_with_error(MonitorRef, CodeAtom) ->
exit(did_not_die)
end.
non_existent_user_test() ->
non_existent_user_test(StartConnectionFun) ->
Params = #amqp_params{username = test_util:uuid(),
password = test_util:uuid()},
assert_fail_start_with_params(Params).
?assertMatch({error, auth_failure}, StartConnectionFun(Params)).
invalid_password_test() ->
invalid_password_test(StartConnectionFun) ->
Params = #amqp_params{username = <<"guest">>,
password = test_util:uuid()},
assert_fail_start_with_params(Params).
?assertMatch({error, auth_failure}, StartConnectionFun(Params)).
non_existent_vhost_test() ->
non_existent_vhost_test(StartConnectionFun) ->
Params = #amqp_params{virtual_host = test_util:uuid()},
assert_fail_start_with_params(Params).
?assertMatch({error, access_refused}, StartConnectionFun(Params)).
no_permission_test() ->
no_permission_test(StartConnectionFun) ->
Params = #amqp_params{username = <<"test_user_no_perm">>,
password = <<"test_user_no_perm">>},
assert_fail_start_with_params(Params).
assert_fail_start_with_params(Params) ->
{error, {auth_failure_likely, _}} = amqp_connection:start(network, Params),
ok.
?assertMatch({error, access_refused}, StartConnectionFun(Params)).

View File

@ -108,16 +108,16 @@ hard_error_test() ->
repeat(fun negative_test_util:hard_error_test/1, ?ITERATIONS).
non_existent_user_test() ->
negative_test_util:non_existent_user_test().
negative_test_util:non_existent_user_test(fun new_connection/1).
invalid_password_test() ->
negative_test_util:invalid_password_test().
negative_test_util:invalid_password_test(fun new_connection/1).
non_existent_vhost_test() ->
negative_test_util:non_existent_vhost_test().
negative_test_util:non_existent_vhost_test(fun new_connection/1).
no_permission_test() ->
negative_test_util:no_permission_test().
negative_test_util:no_permission_test(fun new_connection/1).
channel_writer_death_test() ->
negative_test_util:channel_writer_death_test(new_connection()).
@ -147,9 +147,8 @@ new_connection() ->
new_connection(#amqp_params{}).
new_connection(AmqpParams) ->
case amqp_connection:start(network, AmqpParams) of
{ok, Conn} -> Conn;
{error, _Err} = Error -> Error
case amqp_connection:start(network, AmqpParams) of {ok, Conn} -> Conn;
{error, _} = E -> E
end.
test_coverage() ->