Type spec improvements in rabbit_auth_backend_oauth2

This commit is contained in:
Michael Klishin 2024-11-28 15:51:47 -05:00
parent d6366a3c65
commit 301b79c470
No known key found for this signature in database
GPG Key ID: 44BF2725475205B2
2 changed files with 52 additions and 19 deletions

View File

@ -22,6 +22,8 @@
%% End of Key JWT fields
-type raw_jwt_token() :: binary() | #{binary() => any()}.
-type decoded_jwt_token() :: #{binary() => any()}.
-record(internal_oauth_provider, {
id :: oauth_provider_id(),

View File

@ -58,6 +58,11 @@ description() ->
%%--------------------------------------------------------------------
-spec user_login_authentication(rabbit_types:username(), [term()] | map()) ->
{'ok', rabbit_types:auth_user()} |
{'refused', string(), [any()]} |
{'error', any()}.
user_login_authentication(Username, AuthProps) ->
case authenticate(Username, AuthProps) of
{refused, Msg, Args} = AuthResult ->
@ -67,12 +72,21 @@ user_login_authentication(Username, AuthProps) ->
AuthResult
end.
-spec user_login_authorization(rabbit_types:username(), [term()] | map()) ->
{'ok', any()} |
{'ok', any(), any()} |
{'refused', string(), [any()]} |
{'error', any()}.
user_login_authorization(Username, AuthProps) ->
case authenticate(Username, AuthProps) of
{ok, #auth_user{impl = Impl}} -> {ok, Impl};
Else -> Else
end.
-spec check_vhost_access(AuthUser :: rabbit_types:auth_user(),
VHost :: rabbit_types:vhost(),
AuthzData :: rabbit_types:authz_data()) -> boolean() | {'error', any()}.
check_vhost_access(#auth_user{impl = DecodedTokenFun},
VHost, _AuthzData) ->
with_decoded_token(DecodedTokenFun(),
@ -136,6 +150,11 @@ expiry_timestamp(#auth_user{impl = DecodedTokenFun}) ->
%%--------------------------------------------------------------------
-spec authenticate(Username, Props) -> Result
when Username :: rabbit_types:username(),
Props :: list() | map(),
Result :: {ok, any()} | {refused, list(), list()} | {refused, {error, any()}}.
authenticate(_, AuthProps0) ->
AuthProps = to_map(AuthProps0),
Token = token_from_context(AuthProps),
@ -149,16 +168,7 @@ authenticate(_, AuthProps0) ->
{refused, Err} ->
{refused, "Authentication using an OAuth 2/JWT token failed: ~tp", [Err]};
{ok, DecodedToken} ->
Func = fun(Token0) ->
Username = username_from(
ResourceServer#resource_server.preferred_username_claims,
Token0),
Tags = tags_from(Token0),
{ok, #auth_user{username = Username,
tags = Tags,
impl = fun() -> Token0 end}}
end,
case with_decoded_token(DecodedToken, Func) of
case with_decoded_token(DecodedToken, fun(In) -> auth_user_from_token(In, ResourceServer) end) of
{error, Err} ->
{refused, "Authentication using an OAuth 2/JWT token failed: ~tp", [Err]};
Else ->
@ -166,6 +176,14 @@ authenticate(_, AuthProps0) ->
end
end
end.
-type ok_extracted_auth_user() :: {ok, rabbit_types:auth_user()}.
-type auth_user_extraction_fun() :: fun((decoded_jwt_token()) -> any()).
-spec with_decoded_token(Token, Fun) -> Result
when Token :: decoded_jwt_token(),
Fun :: auth_user_extraction_fun(),
Result :: {ok, any()} | {'error', any()}.
with_decoded_token(DecodedToken, Fun) ->
case validate_token_expiry(DecodedToken) of
ok -> Fun(DecodedToken);
@ -173,6 +191,21 @@ with_decoded_token(DecodedToken, Fun) ->
rabbit_log:error(Msg),
Err
end.
%% This is a helper function used with HOFs that may return errors.
-spec auth_user_from_token(Token, ResourceServer) -> Result
when Token :: decoded_jwt_token(),
ResourceServer :: resource_server(),
Result :: ok_extracted_auth_user().
auth_user_from_token(Token0, ResourceServer) ->
Username = username_from(
ResourceServer#resource_server.preferred_username_claims,
Token0),
Tags = tags_from(Token0),
{ok, #auth_user{username = Username,
tags = Tags,
impl = fun() -> Token0 end}}.
ensure_same_username(PreferredUsernameClaims, CurrentDecodedToken, NewDecodedToken) ->
CurUsername = username_from(PreferredUsernameClaims, CurrentDecodedToken),
case {CurUsername, username_from(PreferredUsernameClaims, NewDecodedToken)} of
@ -188,12 +221,10 @@ validate_token_expiry(#{<<"exp">> := Exp}) when is_integer(Exp) ->
end;
validate_token_expiry(#{}) -> ok.
-spec check_token(binary() | map(), {resource_server(), internal_oauth_provider()}) ->
{'ok', map()} |
{'error', term() }|
{'refused', 'signature_invalid' |
{'error', term()} |
{'invalid_aud', term()}}.
-spec check_token(raw_jwt_token(), {resource_server(), internal_oauth_provider()}) ->
{'ok', decoded_jwt_token()} |
{'error', term() } |
{'refused', 'signature_invalid' | {'error', term()} | {'invalid_aud', term()}}.
check_token(DecodedToken, _) when is_map(DecodedToken) ->
{ok, DecodedToken};
@ -206,7 +237,7 @@ check_token(Token, {ResourceServer, InternalOAuthProvider}) ->
end.
-spec normalize_token_scope(
ResourceServer :: resource_server(), DecodedToken :: map()) -> map().
ResourceServer :: resource_server(), DecodedToken :: decoded_jwt_token()) -> map().
normalize_token_scope(ResourceServer, Payload) ->
Payload0 = maps:map(fun(K, V) ->
case K of
@ -395,7 +426,7 @@ resolve_scope_var(Elem, Token, Vhost) ->
end)
end.
-spec tags_from(map()) -> list(atom()).
-spec tags_from(decoded_jwt_token()) -> list(atom()).
tags_from(DecodedToken) ->
Scopes = maps:get(?SCOPE_JWT_FIELD, DecodedToken, []),
TagScopes = filter_matching_scope_prefix_and_drop_it(Scopes, ?TAG_SCOPE_PREFIX),