Handle SASL authenticate

This commit is contained in:
Arnaud Cogoluègnes 2020-03-24 11:47:50 +01:00
parent e2f11fdafc
commit 22dc72c9d9
1 changed files with 110 additions and 30 deletions

View File

@ -21,10 +21,13 @@
-export([start_link/4]).
-export([init/3]).
-record(connection, {
-include_lib("rabbit_common/include/rabbit.hrl").
-record(stream_connection, {
listen_socket, socket, clusters, data, consumers,
target_subscriptions, credits,
blocked
blocked,
authentication_state, user
}).
-record(consumer, {
@ -45,6 +48,7 @@
-define(COMMAND_METADATA_UPDATE, 7).
-define(COMMAND_METADATA, 8).
-define(COMMAND_SASL_HANDSHAKE, 9).
-define(COMMAND_SASL_AUTHENTICATE, 10).
-define(COMMAND_CREATE_TARGET, 998).
-define(COMMAND_DELETE_TARGET, 999).
@ -57,6 +61,10 @@
-define(RESPONSE_CODE_TARGET_ALREADY_EXISTS, 4).
-define(RESPONSE_CODE_TARGET_DELETED, 5).
-define(RESPONSE_SASL_MECHANISM_NOT_SUPPORTED, 6).
-define(RESPONSE_AUTHENTICATION_FAILURE, 7).
-define(RESPONSE_SASL_ERROR, 8).
-define(RESPONSE_SASL_CHALLENGE, 9).
-define(RESPONSE_SASL_AUTHENTICATION_FAILURE_LOOPBACK, 10).
-define(RESPONSE_FRAME_SIZE, 10). % 2 (key) + 2 (version) + 4 (correlation ID) + 2 (response code)
@ -70,10 +78,11 @@ init(Ref, Transport, _Opts = #{initial_credits := InitialCredits,
rabbit_stream_manager:register(),
Credits = atomics:new(1, [{signed, true}]),
init_credit(Credits, InitialCredits),
State = #connection{socket = Socket, data = none,
State = #stream_connection{socket = Socket, data = none,
clusters = #{},
consumers = #{}, target_subscriptions = #{},
blocked = false, credits = Credits},
blocked = false, credits = Credits,
authentication_state = none, user = none},
Transport:setopts(Socket, [{active, once}]),
listen_loop(Transport, State, #configuration{
@ -97,7 +106,7 @@ has_credits(CreditReference) ->
has_enough_credits_to_unblock(CreditReference, CreditsRequiredForUnblocking) ->
atomics:get(CreditReference, 1) > CreditsRequiredForUnblocking.
listen_loop(Transport, #connection{socket = S, consumers = Consumers,
listen_loop(Transport, #stream_connection{socket = S, consumers = Consumers,
target_subscriptions = TargetSubscriptions, credits = Credits, blocked = Blocked} = State,
#configuration{credits_required_for_unblocking = CreditsRequiredForUnblocking} = Configuration) ->
{OK, Closed, Error} = Transport:messages(),
@ -109,7 +118,7 @@ listen_loop(Transport, #connection{socket = S, consumers = Consumers,
case has_enough_credits_to_unblock(Credits, CreditsRequiredForUnblocking) of
true ->
Transport:setopts(S, [{active, once}]),
State1#connection{blocked = false};
State1#stream_connection{blocked = false};
false ->
State1
end;
@ -119,7 +128,7 @@ listen_loop(Transport, #connection{socket = S, consumers = Consumers,
Transport:setopts(S, [{active, once}]),
State1;
false ->
State1#connection{blocked = true}
State1#stream_connection{blocked = true}
end
end,
listen_loop(Transport, State2, Configuration);
@ -147,7 +156,7 @@ listen_loop(Transport, #connection{socket = S, consumers = Consumers,
case has_enough_credits_to_unblock(Credits, CreditsRequiredForUnblocking) of
true ->
Transport:setopts(S, [{active, once}]),
State#connection{blocked = false};
State#stream_connection{blocked = false};
false ->
State
end;
@ -165,7 +174,7 @@ listen_loop(Transport, #connection{socket = S, consumers = Consumers,
State;
[] ->
error_logger:info_msg("osiris offset event for ~p, but no registered consumers!", [TargetName]),
State#connection{target_subscriptions = maps:remove(TargetName, TargetSubscriptions)};
State#stream_connection{target_subscriptions = maps:remove(TargetName, TargetSubscriptions)};
CorrelationIds when is_list(CorrelationIds) ->
Consumers1 = lists:foldl(fun(CorrelationId, ConsumersAcc) ->
#{CorrelationId := Consumer} = ConsumersAcc,
@ -184,7 +193,7 @@ listen_loop(Transport, #connection{socket = S, consumers = Consumers,
end,
Consumers,
CorrelationIds),
State#connection{consumers = Consumers1}
State#stream_connection{consumers = Consumers1}
end,
listen_loop(Transport, State1, Configuration);
{Closed, S} ->
@ -201,13 +210,13 @@ listen_loop(Transport, #connection{socket = S, consumers = Consumers,
handle_inbound_data(_Transport, State, <<>>) ->
State;
handle_inbound_data(Transport, #connection{data = none} = State, <<Size:32, Frame:Size/binary, Rest/bits>>) ->
handle_inbound_data(Transport, #stream_connection{data = none} = State, <<Size:32, Frame:Size/binary, Rest/bits>>) ->
{State1, Rest1} = handle_frame(Transport, State, Frame, Rest),
handle_inbound_data(Transport, State1, Rest1);
handle_inbound_data(_Transport, #connection{data = none} = State, Data) ->
State#connection{data = Data};
handle_inbound_data(Transport, #connection{data = Leftover} = State, Data) ->
State1 = State#connection{data = none},
handle_inbound_data(_Transport, #stream_connection{data = none} = State, Data) ->
State#stream_connection{data = Data};
handle_inbound_data(Transport, #stream_connection{data = Leftover} = State, Data) ->
State1 = State#stream_connection{data = none},
%% FIXME avoid concatenation to avoid a new binary allocation
%% see osiris_replica:parse_chunk/3
handle_inbound_data(Transport, State1, <<Leftover/binary, Data/binary>>).
@ -226,7 +235,7 @@ generate_publishing_error_details(Acc, <<PublishingId:64, MessageSize:32, _Messa
<<Acc/binary, PublishingId:64, ?RESPONSE_CODE_TARGET_DOES_NOT_EXIST:16>>,
Rest).
handle_frame(Transport, #connection{socket = S, credits = Credits} = State,
handle_frame(Transport, #stream_connection{socket = S, credits = Credits} = State,
<<?COMMAND_PUBLISH:16, ?VERSION_0:16,
TargetSize:16, Target:TargetSize/binary,
MessageCount:32, Messages/binary>>, Rest) ->
@ -242,7 +251,7 @@ handle_frame(Transport, #connection{socket = S, credits = Credits} = State,
sub_credits(Credits, MessageCount),
{State1, Rest}
end;
handle_frame(Transport, #connection{socket = Socket, consumers = Consumers, target_subscriptions = TargetSubscriptions} = State,
handle_frame(Transport, #stream_connection{socket = Socket, consumers = Consumers, target_subscriptions = TargetSubscriptions} = State,
<<?COMMAND_SUBSCRIBE:16, ?VERSION_0:16, CorrelationId:32, SubscriptionId:32, TargetSize:16, Target:TargetSize/binary, Offset:64/unsigned, Credit:16>>, Rest) ->
case lookup_cluster(Target, State) of
cluster_not_found ->
@ -281,10 +290,10 @@ handle_frame(Transport, #connection{socket = Socket, consumers = Consumers, targ
_ ->
TargetSubscriptions#{TargetKey => [SubscriptionId]}
end,
{State1#connection{consumers = Consumers1, target_subscriptions = TargetSubscriptions1}, Rest}
{State1#stream_connection{consumers = Consumers1, target_subscriptions = TargetSubscriptions1}, Rest}
end
end;
handle_frame(Transport, #connection{consumers = Consumers, target_subscriptions = TargetSubscriptions, clusters = Clusters} = State,
handle_frame(Transport, #stream_connection{consumers = Consumers, target_subscriptions = TargetSubscriptions, clusters = Clusters} = State,
<<?COMMAND_UNSUBSCRIBE:16, ?VERSION_0:16, CorrelationId:32, SubscriptionId:32>>, Rest) ->
case subscription_exists(TargetSubscriptions, SubscriptionId) of
false ->
@ -307,12 +316,12 @@ handle_frame(Transport, #connection{consumers = Consumers, target_subscriptions
end,
Consumers1 = maps:remove(SubscriptionId, Consumers),
response_ok(Transport, State, ?COMMAND_SUBSCRIBE, CorrelationId),
{State#connection{consumers = Consumers1,
{State#stream_connection{consumers = Consumers1,
target_subscriptions = TargetSubscriptions1,
clusters = Clusters1
}, Rest}
end;
handle_frame(Transport, #connection{consumers = Consumers} = State,
handle_frame(Transport, #stream_connection{consumers = Consumers} = State,
<<?COMMAND_CREDIT:16, ?VERSION_0:16, SubscriptionId:32, Credit:16>>, Rest) ->
case Consumers of
@ -326,7 +335,7 @@ handle_frame(Transport, #connection{consumers = Consumers} = State,
),
Consumer1 = Consumer#consumer{segment = Segment1, credit = Credit1},
{State#connection{consumers = Consumers#{SubscriptionId => Consumer1}}, Rest};
{State#stream_connection{consumers = Consumers#{SubscriptionId => Consumer1}}, Rest};
_ ->
%% FIXME find a way to tell the client it's crediting an unknown subscription
error_logger:warning_msg("Giving credit to unknown subscription: ~p~n", [SubscriptionId]),
@ -343,7 +352,7 @@ handle_frame(Transport, State,
response(Transport, State, ?COMMAND_CREATE_TARGET, CorrelationId, ?RESPONSE_CODE_TARGET_ALREADY_EXISTS),
{State, Rest}
end;
handle_frame(Transport, #connection{socket = S} = State,
handle_frame(Transport, #stream_connection{socket = S} = State,
<<?COMMAND_DELETE_TARGET:16, ?VERSION_0:16, CorrelationId:32, TargetSize:16, Target:TargetSize/binary>>, Rest) ->
case rabbit_stream_manager:delete(binary_to_list(Target)) of
{ok, deleted} ->
@ -363,7 +372,7 @@ handle_frame(Transport, #connection{socket = S} = State,
response(Transport, State, ?COMMAND_DELETE_TARGET, CorrelationId, ?RESPONSE_CODE_TARGET_DOES_NOT_EXIST),
{State, Rest}
end;
handle_frame(Transport, #connection{socket = S} = State,
handle_frame(Transport, #stream_connection{socket = S} = State,
<<?COMMAND_METADATA:16, ?VERSION_0:16, CorrelationId:32, TargetCount:32, BinaryTargets/binary>>, Rest) ->
%% FIXME: rely only on rabbit_networking to discover the listeners
Nodes = rabbit_mnesia:cluster_nodes(all),
@ -407,7 +416,7 @@ handle_frame(Transport, #connection{socket = S} = State,
FrameSize = byte_size(Frame),
Transport:send(S, <<FrameSize:32, Frame/binary>>),
{State, Rest};
handle_frame(Transport, #connection{socket = S} = State,
handle_frame(Transport, #stream_connection{socket = S} = State,
<<?COMMAND_SASL_HANDSHAKE:16, ?VERSION_0:16, CorrelationId:32>>, Rest) ->
Mechanisms = auth_mechanisms(S),
@ -422,27 +431,94 @@ handle_frame(Transport, #connection{socket = S} = State,
Transport:send(S, [<<FrameSize:32>>, <<Frame/binary>>]),
{State, Rest};
handle_frame(Transport, #stream_connection{socket = S, authentication_state = AuthState0} = State,
<<?COMMAND_SASL_AUTHENTICATE:16, ?VERSION_0:16, CorrelationId:32,
MechanismLength:16, Mechanism:MechanismLength/binary,
SaslBinLength:32, SaslBin:SaslBinLength/binary>>, Rest) ->
%% FIXME handle null value (length = -1) for sasl binary (change the pattern matching)
{State1, Rest1} = case auth_mechanism_to_module(Mechanism, S) of
{ok, AuthMechanism} ->
AuthState = case AuthState0 of
none ->
AuthMechanism:init(S);
AS ->
AS
end,
{S1, FrameFragment} = case AuthMechanism:handle_response(SaslBin, AuthState) of
{refused, _Username, Msg, Args} ->
error_logger:warning_msg(Msg, Args),
%% TODO close connection?
{State, <<?RESPONSE_AUTHENTICATION_FAILURE:16>>};
{protocol_error, Msg, Args} ->
error_logger:warning_msg(Msg, Args),
%% TODO close connection?
{State, <<?RESPONSE_SASL_ERROR:16>>};
{challenge, Challenge, AuthState1} ->
ChallengeSize = byte_size(Challenge),
{State#stream_connection{authentication_state = AuthState1},
<<?RESPONSE_SASL_CHALLENGE:16, ChallengeSize:32, Challenge/binary>>
};
{ok, User = #user{username = Username}} ->
case rabbit_access_control:check_user_loopback(Username, S) of
ok ->
{State#stream_connection{authentication_state = done, user = User},
<<?RESPONSE_CODE_OK:16>>
};
not_allowed ->
error_logger:warning_msg("User '~s' can only connect via localhost~n", [Username]),
%% TODO close connection?
{State, <<?RESPONSE_SASL_AUTHENTICATION_FAILURE_LOOPBACK:16>>}
end
end,
Frame = <<?COMMAND_SASL_AUTHENTICATE:16, ?VERSION_0:16, CorrelationId:32, FrameFragment/binary>>,
frame(Transport, S1, Frame),
{S1, Rest};
{error, _} ->
Frame = <<?COMMAND_SASL_AUTHENTICATE:16, ?VERSION_0:16, CorrelationId:32, ?RESPONSE_SASL_MECHANISM_NOT_SUPPORTED:16>>,
frame(Transport, State, Frame),
{State, Rest}
end,
{State1, Rest1};
handle_frame(_Transport, State, Frame, Rest) ->
error_logger:warning_msg("unknown frame ~p ~p, ignoring.~n", [Frame, Rest]),
{State, Rest}.
auth_mechanisms(Sock) ->
{ok, Configured} = application:get_env(rabbit, auth_mechanisms),
[rabbit_data_coercion:to_binary(Name) || {Name, Module} <- rabbit_registry:lookup_all(auth_mechanism),
Module:should_offer(Sock), lists:member(Name, Configured)].
auth_mechanism_to_module(TypeBin, Sock) ->
case rabbit_registry:binary_to_type(TypeBin) of
{error, not_found} ->
error_logger:warning_msg("Unknown authentication mechanism '~p'~n", [TypeBin]),
{error, not_found};
T ->
case {lists:member(TypeBin, auth_mechanisms(Sock)),
rabbit_registry:lookup_module(auth_mechanism, T)} of
{true, {ok, Module}} ->
{ok, Module};
_ ->
error_logger:warning_msg("Invalid authentication mechanism '~p'~n", [T]),
{error, invalid}
end
end.
extract_target_list(<<>>, Targets) ->
Targets;
extract_target_list(<<Length:16, Target:Length/binary, Rest/binary>>, Targets) ->
extract_target_list(Rest, [Target | Targets]).
clean_state_after_target_deletion(Target, #connection{clusters = Clusters, target_subscriptions = TargetSubscriptions,
clean_state_after_target_deletion(Target, #stream_connection{clusters = Clusters, target_subscriptions = TargetSubscriptions,
consumers = Consumers} = State) ->
TargetAsList = binary_to_list(Target),
case maps:is_key(TargetAsList, TargetSubscriptions) of
true ->
#{TargetAsList := SubscriptionIds} = TargetSubscriptions,
{cleaned, State#connection{
{cleaned, State#stream_connection{
clusters = maps:remove(Target, Clusters),
target_subscriptions = maps:remove(TargetAsList, TargetSubscriptions),
consumers = maps:without(SubscriptionIds, Consumers)
@ -451,14 +527,14 @@ clean_state_after_target_deletion(Target, #connection{clusters = Clusters, targe
{not_cleaned, State}
end.
lookup_cluster(Target, #connection{clusters = Clusters} = State) ->
lookup_cluster(Target, #stream_connection{clusters = Clusters} = State) ->
case maps:get(Target, Clusters, undefined) of
undefined ->
case lookup_cluster_from_manager(Target) of
cluster_not_found ->
cluster_not_found;
ClusterPid ->
{ClusterPid, State#connection{clusters = Clusters#{Target => ClusterPid}}}
{ClusterPid, State#stream_connection{clusters = Clusters#{Target => ClusterPid}}}
end;
ClusterPid ->
{ClusterPid, State}
@ -467,10 +543,14 @@ lookup_cluster(Target, #connection{clusters = Clusters} = State) ->
lookup_cluster_from_manager(Target) ->
rabbit_stream_manager:lookup(Target).
frame(Transport, #stream_connection{socket = S}, Frame) ->
FrameSize = byte_size(Frame),
Transport:send(S, [<<FrameSize:32>>, Frame]).
response_ok(Transport, State, CommandId, CorrelationId) ->
response(Transport, State, CommandId, CorrelationId, ?RESPONSE_CODE_OK).
response(Transport, #connection{socket = S}, CommandId, CorrelationId, ResponseCode) ->
response(Transport, #stream_connection{socket = S}, CommandId, CorrelationId, ResponseCode) ->
Transport:send(S, [<<?RESPONSE_FRAME_SIZE:32, CommandId:16, ?VERSION_0:16>>, <<CorrelationId:32>>, <<ResponseCode:16>>]).
subscription_exists(TargetSubscriptions, SubscriptionId) ->