This commit is contained in:
Simon 2025-10-07 14:46:08 -04:00 committed by GitHub
commit 8b3555b544
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 702 additions and 234 deletions

View File

@ -103,7 +103,7 @@
-type value() :: string().
-type header() :: {Field :: field(), Value :: value()}.
-type headers() :: [header()].
-type body() :: string() | binary().
-type body() :: iodata().
-type ssl_options() :: [ssl:tls_client_option()].

View File

@ -10,15 +10,20 @@
%% API exports
-export([
get/2, get/3,
get/2, get/3, get/4,
put/4, put/5,
post/4,
refresh_credentials/0,
request/5, request/6, request/7,
set_credentials/2,
has_credentials/0,
parse_uri/1,
set_region/1,
ensure_imdsv2_token_valid/0,
api_get_request/2
api_get_request/2,
status_text/1,
open_connection/1, open_connection/2,
close_connection/1
]).
%% gen-server exports
@ -40,23 +45,33 @@
-include("rabbitmq_aws.hrl").
-include_lib("kernel/include/logger.hrl").
%% Types for new concurrent API
-type connection_handle() :: {pid(), credential_context()}.
-type credential_context() :: #{
access_key => access_key(),
secret_access_key => secret_access_key(),
security_token => security_token(),
region => region(),
service => string()
}.
%%====================================================================
%% exported wrapper functions
%%====================================================================
-spec get(
Service :: string(),
ServiceOrHandle :: string() | connection_handle(),
Path :: path()
) -> result().
%% @doc Perform a HTTP GET request to the AWS API for the specified service. The
%% response will automatically be decoded if it is either in JSON, or XML
%% format.
%% @end
get(Service, Path) ->
get(Service, Path, []).
get(ServiceOrHandle, Path) ->
get(ServiceOrHandle, Path, []).
-spec get(
Service :: string(),
ServiceOrHandle :: string() | connection_handle(),
Path :: path(),
Headers :: headers()
) -> result().
@ -64,11 +79,14 @@ get(Service, Path) ->
%% response will automatically be decoded if it is either in JSON or XML
%% format.
%% @end
get(Service, Path, Headers) ->
request(Service, get, Path, "", Headers).
get(ServiceOrHandle, Path, Headers) ->
get(ServiceOrHandle, Path, Headers, []).
get(Service, Path, Headers, Options) ->
request(Service, get, Path, <<>>, Headers, Options).
-spec post(
Service :: string(),
ServiceOrHandle :: string() | connection_handle(),
Path :: path(),
Body :: body(),
Headers :: headers()
@ -77,8 +95,27 @@ get(Service, Path, Headers) ->
%% response will automatically be decoded if it is either in JSON or XML
%% format.
%% @end
post(Service, Path, Body, Headers) ->
request(Service, post, Path, Body, Headers).
post(ServiceOrHandle, Path, Body, Headers) ->
post(ServiceOrHandle, Path, Body, Headers, []).
post(Service, Path, Body, Headers, Options) ->
request(Service, post, Path, Body, Headers, Options).
-spec put(
ServiceOrHandle :: string() | connection_handle(),
Path :: path(),
Body :: body(),
Headers :: headers()
) -> result().
%% @doc Perform a HTTP Post request to the AWS API for the specified service. The
%% response will automatically be decoded if it is either in JSON or XML
%% format.
%% @end
put(ServiceOrHandle, Path, Body, Headers) ->
put(ServiceOrHandle, Path, Body, Headers, []).
put(Service, Path, Body, Headers, Options) ->
request(Service, put, Path, Body, Headers, Options).
-spec refresh_credentials() -> ok | error.
%% @doc Manually refresh the credentials from the environment, filesystem or EC2 Instance Metadata Service.
@ -86,6 +123,46 @@ post(Service, Path, Body, Headers) ->
refresh_credentials() ->
gen_server:call(rabbitmq_aws, refresh_credentials).
%%====================================================================
%% New Concurrent API Functions
%%====================================================================
%% Open a connection and return handle for direct use
-spec open_connection(Service :: string()) -> {ok, connection_handle()} | {error, term()}.
open_connection(Service) ->
open_connection(Service, []).
-spec open_connection(Service :: string(), Options :: list()) ->
{ok, connection_handle()} | {error, term()}.
open_connection(Service, Options) ->
gen_server:call(?MODULE, {open_direct_connection, Service, Options}).
%% Close a direct connection
-spec close_connection(Handle :: connection_handle()) -> ok.
close_connection({GunPid, _CredContext}) ->
gun:close(GunPid).
-spec direct_request(
Handle :: connection_handle(),
Method :: method(),
Path :: path(),
Body :: body(),
Headers :: headers(),
Options :: list()
) -> result().
direct_request({GunPid, CredContext}, Method, Path, Body, Headers, Options) ->
#{service := Service, region := Region} = CredContext,
% Build URI for signing
Host = endpoint_host(Region, Service),
URI = create_uri(Host, Path),
BodyHash = proplists:get_value(payload_hash, Options),
% Sign headers directly (no gen_server call)
SignedHeaders = sign_headers_with_context(
CredContext, Method, URI, Headers, Body, BodyHash
),
% Make Gun request directly
direct_gun_request(GunPid, Method, Path, SignedHeaders, Body, Options).
-spec refresh_credentials(state()) -> ok | error.
%% @doc Manually refresh the credentials from the environment, filesystem or EC2 Instance Metadata Service.
%% @end
@ -107,7 +184,7 @@ refresh_credentials(State) ->
%% format.
%% @end
request(Service, Method, Path, Body, Headers) ->
gen_server:call(rabbitmq_aws, {request, Service, Method, Headers, Path, Body, [], undefined}).
request(Service, Method, Path, Body, Headers, []).
-spec request(
Service :: string(),
@ -122,12 +199,10 @@ request(Service, Method, Path, Body, Headers) ->
%% format.
%% @end
request(Service, Method, Path, Body, Headers, HTTPOptions) ->
gen_server:call(
rabbitmq_aws, {request, Service, Method, Headers, Path, Body, HTTPOptions, undefined}
).
request(Service, Method, Path, Body, Headers, HTTPOptions, undefined).
-spec request(
Service :: string(),
ServiceOrHandle :: string() | connection_handle(),
Method :: method(),
Path :: path(),
Body :: body(),
@ -140,6 +215,10 @@ request(Service, Method, Path, Body, Headers, HTTPOptions) ->
%% of services such as DynamoDB. The response will automatically be decoded
%% if it is either in JSON or XML format.
%% @end
request({GunPid, _CredContext} = Handle, Method, Path, Body, Headers, HTTPOptions, _) when
is_pid(GunPid)
->
direct_request(Handle, Method, Path, Body, Headers, HTTPOptions);
request(Service, Method, Path, Body, Headers, HTTPOptions, Endpoint) ->
gen_server:call(
rabbitmq_aws, {request, Service, Method, Headers, Path, Body, HTTPOptions, Endpoint}
@ -186,9 +265,10 @@ start_link() ->
-spec init(list()) -> {ok, state()}.
init([]) ->
{ok, _} = application:ensure_all_started(gun),
{ok, #state{}}.
terminate(_, _) ->
terminate(_, _State) ->
ok.
code_change(_, _, State) ->
@ -211,6 +291,18 @@ handle_msg({request, Service, Method, Headers, Path, Body, Options, Host}, State
State, Service, Method, Headers, Path, Body, Options, Host
),
{reply, Response, NewState};
handle_msg({open_direct_connection, Service, Options}, State) ->
case ensure_credentials_valid_internal(State) of
{ok, ValidState} ->
case create_direct_connection(ValidState, Service, Options) of
{ok, Handle} ->
{reply, {ok, Handle}, ValidState};
{error, Reason} ->
{reply, {error, Reason}, ValidState}
end;
{error, Reason} ->
{reply, {error, Reason}, State}
end;
handle_msg(get_state, State) ->
{reply, {ok, State}, State};
handle_msg(refresh_credentials, State) ->
@ -282,6 +374,8 @@ endpoint_tld(_Other) ->
%% @end
format_response({ok, {{_Version, 200, _Message}, Headers, Body}}) ->
{ok, {Headers, maybe_decode_body(get_content_type(Headers), Body)}};
format_response({ok, {{_Version, 206, _Message}, Headers, Body}}) ->
{ok, {Headers, maybe_decode_body(get_content_type(Headers), Body)}};
format_response({ok, {{_Version, StatusCode, Message}, Headers, Body}}) when StatusCode >= 400 ->
{error, Message, {Headers, maybe_decode_body(get_content_type(Headers), Body)}};
format_response({error, Reason}) ->
@ -293,9 +387,9 @@ format_response({error, Reason}) ->
%% @end
get_content_type(Headers) ->
Value =
case proplists:get_value("content-type", Headers, undefined) of
case proplists:get_value(<<"content-type">>, Headers, undefined) of
undefined ->
proplists:get_value("Content-Type", Headers, "text/xml");
proplists:get_value(<<"Content-Type">>, Headers, "text/xml");
Other ->
Other
end,
@ -368,6 +462,8 @@ local_time() ->
list() | body().
%% @doc Attempt to decode the response body by its MIME
%% @end
maybe_decode_body(_, <<>>) ->
<<>>;
maybe_decode_body({"application", "x-amz-json-1.0"}, Body) ->
rabbitmq_aws_json:decode(Body);
maybe_decode_body({"application", "json"}, Body) ->
@ -380,6 +476,8 @@ maybe_decode_body(_ContentType, Body) ->
-spec parse_content_type(ContentType :: string()) -> {Type :: string(), Subtype :: string()}.
%% @doc parse a content type string returning a tuple of type/subtype
%% @end
parse_content_type(ContentType) when is_binary(ContentType) ->
parse_content_type(binary_to_list(ContentType));
parse_content_type(ContentType) ->
Parts = string:tokens(ContentType, ";"),
[Type, Subtype] = string:tokens(lists:nth(1, Parts), "/"),
@ -480,15 +578,13 @@ perform_request_creds_expired(true, State, _, _, _, _, _, _, _) ->
perform_request_with_creds(State, Service, Method, Headers, Path, Body, Options, Host) ->
URI = endpoint(State, Host, Service, Path),
SignedHeaders = sign_headers(State, Service, Method, URI, Headers, Body),
ContentType = proplists:get_value("content-type", SignedHeaders, undefined),
perform_request_with_creds(State, Method, URI, SignedHeaders, ContentType, Body, Options).
perform_request_with_creds(State, Method, URI, SignedHeaders, Body, Options).
-spec perform_request_with_creds(
State :: state(),
Method :: method(),
URI :: string(),
Headers :: headers(),
ContentType :: string() | undefined,
Body :: body(),
Options :: http_options()
) ->
@ -496,14 +592,12 @@ perform_request_with_creds(State, Service, Method, Headers, Path, Body, Options,
%% @doc Once it is validated that there are credentials to try and that they have not
%% expired, perform the request and return the response.
%% @end
perform_request_with_creds(State, Method, URI, Headers, undefined, "", Options0) ->
Options1 = ensure_timeout(Options0),
Response = httpc:request(Method, {URI, Headers}, Options1, []),
{format_response(Response), State};
perform_request_with_creds(State, Method, URI, Headers, ContentType, Body, Options0) ->
Options1 = ensure_timeout(Options0),
Response = httpc:request(Method, {URI, Headers, ContentType, Body}, Options1, []),
{format_response(Response), State}.
perform_request_with_creds(State, Method, URI, Headers, "", Options0) ->
Response = gun_request(Method, URI, Headers, <<>>, Options0),
{Response, State};
perform_request_with_creds(State, Method, URI, Headers, Body, Options0) ->
Response = gun_request(Method, URI, Headers, Body, Options0),
{Response, State}.
-spec perform_request_creds_error(State :: state()) ->
{result_error(), NewState :: state()}.
@ -513,22 +607,6 @@ perform_request_with_creds(State, Method, URI, Headers, ContentType, Body, Optio
perform_request_creds_error(State) ->
{{error, {credentials, State#state.error}}, State}.
%% @doc Ensure that the timeout option is set and greater than 0 and less
%% than about 1/2 of the default gen_server:call timeout. This gives
%% enough time for a long connect and request phase to succeed.
%% @end
-spec ensure_timeout(Options :: http_options()) -> http_options().
ensure_timeout(Options) ->
case proplists:get_value(timeout, Options) of
undefined ->
Options ++ [{timeout, ?DEFAULT_HTTP_TIMEOUT}];
Value when is_integer(Value) andalso Value >= 0 andalso Value =< ?DEFAULT_HTTP_TIMEOUT ->
Options;
_ ->
Options1 = proplists:delete(timeout, Options),
Options1 ++ [{timeout, ?DEFAULT_HTTP_TIMEOUT}]
end.
-spec sign_headers(
State :: state(),
Service :: string(),
@ -648,3 +726,207 @@ api_get_request_with_retries(Service, Path, Retries, WaitTimeBetweenRetries) ->
timer:sleep(WaitTimeBetweenRetries),
api_get_request_with_retries(Service, Path, Retries - 1, WaitTimeBetweenRetries)
end.
%% Gun HTTP client functions
gun_request(Method, URI, Headers, Body, Options) ->
{Host, Port, Path} = parse_uri(URI),
GunPid = create_gun_connection(Host, Port, Options),
Reply = direct_gun_request(GunPid, Method, Path, Headers, Body, Options),
gun:close(GunPid),
Reply.
do_gun_request(ConnPid, get, Path, Headers, _Body) ->
gun:get(ConnPid, Path, Headers);
do_gun_request(ConnPid, post, Path, Headers, Body) ->
gun:post(ConnPid, Path, Headers, Body, #{});
do_gun_request(ConnPid, put, Path, Headers, Body) ->
gun:put(ConnPid, Path, Headers, Body, #{});
do_gun_request(ConnPid, head, Path, Headers, _Body) ->
gun:head(ConnPid, Path, Headers, #{});
do_gun_request(ConnPid, delete, Path, Headers, _Body) ->
gun:delete(ConnPid, Path, Headers, #{});
do_gun_request(ConnPid, patch, Path, Headers, Body) ->
gun:patch(ConnPid, Path, Headers, Body, #{});
do_gun_request(ConnPid, options, Path, Headers, _Body) ->
gun:options(ConnPid, Path, Headers, #{}).
create_gun_connection(Host, Port, Options) ->
% Map HTTP version to Gun protocols, always include http as fallback
HttpVersion = proplists:get_value(version, Options, "HTTP/1.1"),
Protocols =
case HttpVersion of
"HTTP/2" -> [http2, http];
"HTTP/2.0" -> [http2, http];
"HTTP/1.1" -> [http];
"HTTP/1.0" -> [http];
% Default: try HTTP/2, fallback to HTTP/1.1
_ -> [http2, http]
end,
ConnectTimeout = proplists:get_value(connect_timeout, Options, infinity),
Opts = #{
transport =>
if
Port == 443 -> tls;
true -> tcp
end,
protocols => Protocols,
connect_timeout => ConnectTimeout
},
case gun:open(Host, Port, Opts) of
{ok, ConnPid} ->
case gun:await_up(ConnPid, ConnectTimeout) of
{ok, _Protocol} ->
ConnPid;
{error, Reason} ->
gun:close(ConnPid),
error({gun_connection_failed, Reason})
end;
{error, Reason} ->
error({gun_open_failed, Reason})
end.
create_uri(Host, Path) when is_list(Path) ->
"https://" ++ Host ++ Path;
create_uri(Host, {Bucket, Key}) ->
"https://" ++ Bucket ++ "." ++ Host ++ "/" ++ Key.
parse_uri(URI) ->
case string:split(URI, "://", leading) of
[Scheme, Rest] ->
case string:split(Rest, "/", leading) of
[HostPort] ->
{Host, Port} = parse_host_port(HostPort, Scheme),
{Host, Port, "/"};
[HostPort, Path] ->
{Host, Port} = parse_host_port(HostPort, Scheme),
{Host, Port, "/" ++ Path}
end
end.
parse_host_port(HostPort, Scheme) ->
DefaultPort =
case Scheme of
"https" -> 443;
"http" -> 80;
% Fallback to HTTPS
_ -> 443
end,
case string:split(HostPort, ":", trailing) of
[Host] ->
{Host, DefaultPort};
[Host, PortStr] ->
{Host, list_to_integer(PortStr)}
end.
status_text(200) -> "OK";
status_text(206) -> "Partial Content";
status_text(400) -> "Bad Request";
status_text(401) -> "Unauthorized";
status_text(403) -> "Forbidden";
status_text(404) -> "Not Found";
status_text(416) -> "Range Not Satisfiable";
status_text(500) -> "Internal Server Error";
status_text(Code) -> integer_to_list(Code).
%%====================================================================
%% New Concurrent API Helper Functions
%%====================================================================
%% Create a direct connection handle
-spec create_direct_connection(State :: state(), Service :: string(), Options :: list()) ->
{ok, connection_handle()} | {error, term()}.
create_direct_connection(State, Service, Options) ->
Region = State#state.region,
Host = endpoint_host(Region, Service),
Port = 443,
GunPid = create_gun_connection(Host, Port, Options),
CredContext = #{
access_key => State#state.access_key,
secret_access_key => State#state.secret_access_key,
security_token => State#state.security_token,
region => Region,
service => Service
},
{ok, {GunPid, CredContext}}.
%% Sign headers using credential context (no gen_server state needed)
-spec sign_headers_with_context(
CredContext :: credential_context(),
Method :: method(),
URI :: string(),
Headers :: headers(),
Body :: body(),
BodyHash :: iodata()
) -> headers().
sign_headers_with_context(CredContext, Method, URI, Headers, Body, BodyHash) ->
#{
access_key := AccessKey,
secret_access_key := SecretKey,
security_token := SecurityToken,
region := Region,
service := Service
} = CredContext,
rabbitmq_aws_sign:headers(
#request{
access_key = AccessKey,
secret_access_key = SecretKey,
security_token = SecurityToken,
region = Region,
service = Service,
method = Method,
uri = URI,
headers = Headers,
body = Body
},
BodyHash
).
%% Direct Gun request (extracted from existing gun_request function)
-spec direct_gun_request(
GunPid :: pid(),
Method :: method(),
Path :: path(),
Headers :: headers(),
Body :: body(),
Options :: list()
) -> result().
direct_gun_request(GunPid, Method, {_, Path}, Headers, Body, Options) ->
direct_gun_request(GunPid, Method, [$/ | Path], Headers, Body, Options);
direct_gun_request(GunPid, Method, Path, Headers, Body, Options) ->
HeadersBin = lists:map(
fun({Key, Value}) ->
{list_to_binary(Key), list_to_binary(Value)}
end,
Headers
),
Timeout = proplists:get_value(timeout, Options, ?DEFAULT_HTTP_TIMEOUT),
Response =
try
StreamRef = do_gun_request(GunPid, Method, Path, HeadersBin, Body),
case gun:await(GunPid, StreamRef, Timeout) of
{response, fin, Status, RespHeaders} ->
{ok, {{http_version, Status, status_text(Status)}, RespHeaders, <<>>}};
{response, nofin, Status, RespHeaders} ->
{ok, RespBody} = gun:await_body(GunPid, StreamRef, Timeout),
{ok, {{http_version, Status, status_text(Status)}, RespHeaders, RespBody}};
{error, Reason} ->
{error, Reason}
end
catch
_:Error ->
{error, Error}
end,
format_response(Response).
%% Internal credential validation (extracted from existing logic)
-spec ensure_credentials_valid_internal(State :: state()) -> {ok, state()} | {error, term()}.
ensure_credentials_valid_internal(State) ->
case has_credentials(State) of
true ->
case expired_credentials(State#state.expiration) of
false -> {ok, State};
true -> load_credentials(State)
end;
false ->
load_credentials(State)
end.

View File

@ -629,9 +629,14 @@ maybe_get_role_from_instance_metadata() ->
%% @doc Parse the response from the Availability Zone query to the
%% Instance Metadata service, returning the Region if successful.
%% end.
parse_az_response({error, _}) -> {error, undefined};
parse_az_response({ok, {{_, 200, _}, _, Body}}) -> {ok, region_from_availability_zone(Body)};
parse_az_response({ok, {{_, _, _}, _, _}}) -> {error, undefined}.
parse_az_response({error, _}) ->
{error, undefined};
parse_az_response({ok, {{_, 200, _}, _, Body}}) when is_binary(Body) ->
{ok, region_from_availability_zone(binary_to_list(Body))};
parse_az_response({ok, {{_, 200, _}, _, Body}}) ->
{ok, region_from_availability_zone(Body)};
parse_az_response({ok, {{_, _, _}, _, _}}) ->
{error, undefined}.
-spec parse_body_response(httpc_result()) ->
{ok, Value :: string()} | {error, Reason :: atom()}.
@ -640,8 +645,9 @@ parse_az_response({ok, {{_, _, _}, _, _}}) -> {error, undefined}.
%% end.
parse_body_response({error, _}) ->
{error, undefined};
parse_body_response({ok, {{_, 200, _}, _, Body}}) ->
{ok, Body};
parse_body_response({ok, {{_, 200, _}, _, Body}}) when is_binary(Body) ->
{ok, binary_to_list(Body)};
parse_body_response({ok, {{_, 200, _}, _, Body}}) when is_list(Body) -> {ok, Body};
parse_body_response({ok, {{_, 401, _}, _, _}}) ->
?LOG_ERROR(
get_instruction_on_instance_metadata_error(
@ -678,12 +684,47 @@ parse_credentials_response({ok, {{_, 200, _}, _, Body}}) ->
%% @end
perform_http_get_instance_metadata(URL) ->
?LOG_DEBUG("Querying instance metadata service: ~tp", [URL]),
httpc:request(
get,
{URL, instance_metadata_request_headers()},
[{timeout, ?DEFAULT_HTTP_TIMEOUT}],
[]
).
% Parse metadata service URL
{Host, Port, Path} = rabbitmq_aws:parse_uri(URL),
% Simple Gun connection for metadata service
% HTTP only, no TLS
Opts = #{transport => tcp, protocols => [http]},
case gun:open(Host, Port, Opts) of
{ok, ConnPid} ->
case gun:await_up(ConnPid, 5000) of
{ok, _Protocol} ->
Headers = instance_metadata_request_headers(),
StreamRef = gun:get(ConnPid, Path, Headers),
Result =
case gun:await(ConnPid, StreamRef, ?DEFAULT_HTTP_TIMEOUT) of
{response, fin, Status, RespHeaders} ->
{ok, {
{http_version, Status, rabbitmq_aws:status_text(Status)},
RespHeaders,
<<>>
}};
{response, nofin, Status, RespHeaders} ->
{ok, Body} = gun:await_body(
ConnPid, StreamRef, ?DEFAULT_HTTP_TIMEOUT
),
{ok, {
{http_version, Status, rabbitmq_aws:status_text(Status)},
RespHeaders,
Body
}};
{error, Reason} ->
{error, Reason}
end,
gun:close(ConnPid),
Result;
{error, Reason} ->
gun:close(ConnPid),
{error, Reason}
end;
{error, Reason} ->
{error, Reason}
end.
-spec get_instruction_on_instance_metadata_error(string()) -> string().
%% @doc Return error message on failures related to EC2 Instance Metadata Service with a reference to AWS document.
@ -742,29 +783,77 @@ region_from_availability_zone(Value) ->
load_imdsv2_token() ->
TokenUrl = imdsv2_token_url(),
?LOG_INFO("Attempting to obtain EC2 IMDSv2 token from ~tp ...", [TokenUrl]),
case
httpc:request(
put,
{TokenUrl, [{?METADATA_TOKEN_TTL_HEADER, integer_to_list(?METADATA_TOKEN_TTL_SECONDS)}]},
[{timeout, ?DEFAULT_HTTP_TIMEOUT}],
[]
)
of
{ok, {{_, 200, _}, _, Value}} ->
?LOG_DEBUG("Successfully obtained EC2 IMDSv2 token."),
Value;
{error, {{_, 400, _}, _, _}} ->
?LOG_WARNING(
"Failed to obtain EC2 IMDSv2 token: Missing or Invalid Parameters The PUT request is not valid."
),
undefined;
Other ->
% Parse metadata service URL
{Host, Port, Path} = rabbitmq_aws:parse_uri(TokenUrl),
% Simple Gun connection for metadata service
% HTTP only, no TLS
Opts = #{transport => tcp, protocols => [http]},
case gun:open(Host, Port, Opts) of
{ok, ConnPid} ->
case gun:await_up(ConnPid, 5000) of
{ok, _Protocol} ->
% PUT request with IMDSv2 token TTL header
Headers = [
{?METADATA_TOKEN_TTL_HEADER, integer_to_list(?METADATA_TOKEN_TTL_SECONDS)}
],
StreamRef = gun:put(ConnPid, Path, Headers, <<>>),
Result =
case gun:await(ConnPid, StreamRef, ?DEFAULT_HTTP_TIMEOUT) of
{response, fin, 200, _RespHeaders} ->
?LOG_DEBUG("Successfully obtained EC2 IMDSv2 token."),
% Empty body for fin response
<<>>;
{response, nofin, 200, _RespHeaders} ->
{ok, Body} = gun:await_body(
ConnPid, StreamRef, ?DEFAULT_HTTP_TIMEOUT
),
?LOG_DEBUG("Successfully obtained EC2 IMDSv2 token."),
binary_to_list(Body);
{response, _, 400, _RespHeaders} ->
?LOG_WARNING(
"Failed to obtain EC2 IMDSv2 token: Missing or Invalid Parameters The PUT request is not valid."
),
undefined;
{error, Reason} ->
?LOG_WARNING(
get_instruction_on_instance_metadata_error(
"Failed to obtain EC2 IMDSv2 token: ~tp. "
"Falling back to EC2 IMDSv1 for now. It is recommended to use EC2 IMDSv2."
),
[Reason]
),
undefined;
Other ->
?LOG_WARNING(
get_instruction_on_instance_metadata_error(
"Failed to obtain EC2 IMDSv2 token: ~tp. "
"Falling back to EC2 IMDSv1 for now. It is recommended to use EC2 IMDSv2."
),
[Other]
),
undefined
end,
gun:close(ConnPid),
Result;
{error, Reason} ->
gun:close(ConnPid),
?LOG_WARNING(
get_instruction_on_instance_metadata_error(
"Failed to connect for EC2 IMDSv2 token: ~tp. "
"Falling back to EC2 IMDSv1 for now. It is recommended to use EC2 IMDSv2."
),
[Reason]
),
undefined
end;
{error, Reason} ->
?LOG_WARNING(
get_instruction_on_instance_metadata_error(
"Failed to obtain EC2 IMDSv2 token: ~tp. "
"Failed to open connection for EC2 IMDSv2 token: ~tp. "
"Falling back to EC2 IMDSv1 for now. It is recommended to use EC2 IMDSv2."
),
[Other]
[Reason]
),
undefined
end.

View File

@ -8,7 +8,7 @@
-module(rabbitmq_aws_sign).
%% API
-export([headers/1, request_hash/5]).
-export([headers/1, headers/2, request_hash/5]).
%% Export all for unit tests
-ifdef(TEST).
@ -24,13 +24,19 @@
%% @doc Create the signed request headers
%% end
headers(Request) ->
headers(Request, undefined).
headers(Request, undefined) ->
headers(Request, sha256(Request#request.body));
headers(Request, PayloadHash) ->
RequestTimestamp = local_time(),
PayloadHash = sha256(Request#request.body),
URI = rabbitmq_aws_urilib:parse(Request#request.uri),
{_, Host, _} = URI#uri.authority,
BodyLength = iolist_size(Request#request.body),
Headers = append_headers(
RequestTimestamp,
length(Request#request.body),
BodyLength,
PayloadHash,
Host,
Request#request.security_token,
@ -41,7 +47,7 @@ headers(Request) ->
URI#uri.path,
URI#uri.query,
Headers,
Request#request.body
PayloadHash
),
AuthValue = authorization(
Request#request.access_key,
@ -202,11 +208,11 @@ query_string(QueryArgs) -> rabbitmq_aws_urilib:build_query_string(lists:keysort(
Path :: path(),
QArgs :: query_args(),
Headers :: headers(),
Payload :: string()
PayloadHash :: string()
) -> string().
%% @doc Create the request hash value
%% @end
request_hash(Method, Path, QArgs, Headers, Payload) ->
request_hash(Method, Path, QArgs, Headers, PayloadHash) ->
RawPath =
case string:slice(Path, 0, 1) of
"/" -> Path;
@ -220,7 +226,7 @@ request_hash(Method, Path, QArgs, Headers, Payload) ->
query_string(QArgs),
canonical_headers(Headers),
signed_headers(Headers),
sha256(Payload)
PayloadHash
],
"\n"
),
@ -236,7 +242,7 @@ request_hash(Method, Path, QArgs, Headers, Payload) ->
scope(AMZDate, Region, Service) ->
string:join([AMZDate, Region, Service, "aws4_request"], "/").
-spec sha256(Value :: string()) -> string().
-spec sha256(Value :: iodata()) -> string().
%% @doc Return the SHA-256 hash for the specified value.
%% @end
sha256(Value) ->

View File

@ -11,6 +11,8 @@
-include_lib("xmerl/include/xmerl.hrl").
-spec parse(Value :: string() | binary()) -> list().
parse(Value) when is_binary(Value) ->
parse(binary_to_list(Value));
parse(Value) ->
{Element, _} = xmerl_scan:string(Value),
parse_node(Element).

View File

@ -120,10 +120,10 @@ credentials_test_() ->
{
foreach,
fun() ->
meck:new(httpc),
meck:new(rabbitmq_aws),
meck:new(gun, []),
meck:new(rabbitmq_aws, [passthrough]),
reset_environment(),
[httpc, rabbitmq_aws]
[gun, rabbitmq_aws]
end,
fun meck:unload/1,
[
@ -222,13 +222,26 @@ credentials_test_() ->
{"from instance metadata service", fun() ->
CredsBody =
"{\n \"Code\" : \"Success\",\n \"LastUpdated\" : \"2016-03-31T21:51:49Z\",\n \"Type\" : \"AWS-HMAC\",\n \"AccessKeyId\" : \"ASIAIMAFAKEACCESSKEY\",\n \"SecretAccessKey\" : \"2+t64tZZVaz0yp0x1G23ZRYn+FAKEyVALUEs/4qh\",\n \"Token\" : \"FAKE//////////wEAK/TOKEN/VALUE=\",\n \"Expiration\" : \"2016-04-01T04:13:28Z\"\n}",
meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end),
meck:expect(gun, close, fun(_) -> ok end),
meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end),
meck:sequence(gun, get, 3, [stream_ref1, stream_ref2]),
meck:sequence(
httpc,
request,
4,
gun,
await,
3,
[
{ok, {{protocol, 200, message}, headers, "Bob"}},
{ok, {{protocol, 200, message}, headers, CredsBody}}
{response, nofin, 200, headers},
{response, nofin, 200, headers}
]
),
meck:sequence(
gun,
await_body,
3,
[
{ok, <<"Bob">>},
{ok, list_to_binary(CredsBody)}
]
),
meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, undefined),
@ -239,41 +252,59 @@ credentials_test_() ->
end},
{"with instance metadata service role error", fun() ->
meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, undefined),
meck:expect(httpc, request, 4, {error, timeout}),
meck:expect(gun, open, fun(_, _, _) -> {error, timeout} end),
?assertEqual({error, undefined}, rabbitmq_aws_config:credentials())
end},
{"with instance metadata service role http error", fun() ->
meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, undefined),
meck:expect(
httpc,
request,
4,
{ok, {{protocol, 500, message}, headers, "Internal Server Error"}}
),
meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end),
meck:expect(gun, close, fun(_) -> ok end),
meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end),
meck:expect(gun, get, fun(_, _, _) -> stream_ref end),
meck:expect(gun, await, fun(_, _, _) -> {response, nofin, 500, headers} end),
meck:expect(gun, await_body, fun(_, _, _) -> {ok, <<"Internal Server Error">>} end),
?assertEqual({error, undefined}, rabbitmq_aws_config:credentials())
end},
{"with instance metadata service credentials error", fun() ->
meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, undefined),
meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end),
meck:expect(gun, close, fun(_) -> ok end),
meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end),
meck:sequence(gun, get, 3, [stream_ref1, stream_ref2]),
meck:sequence(
httpc,
request,
4,
gun,
await,
3,
[
{ok, {{protocol, 200, message}, headers, "Bob"}},
{response, nofin, 200, headers},
{error, timeout}
]
),
meck:expect(gun, await_body, fun(_, _, _) -> {ok, <<"Bob">>} end),
?assertEqual({error, undefined}, rabbitmq_aws_config:credentials())
end},
{"with instance metadata service credentials not found", fun() ->
meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, undefined),
meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end),
meck:expect(gun, close, fun(_) -> ok end),
meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end),
meck:sequence(gun, get, 3, [stream_ref1, stream_ref2]),
meck:sequence(
httpc,
request,
4,
gun,
await,
3,
[
{ok, {{protocol, 200, message}, headers, "Bob"}},
{ok, {{protocol, 404, message}, headers, "File Not Found"}}
{response, nofin, 200, headers},
{response, nofin, 404, headers}
]
),
meck:sequence(
gun,
await_body,
3,
[
{ok, <<"Bob">>},
{ok, <<"File Not Found">>}
]
),
?assertEqual({error, undefined}, rabbitmq_aws_config:credentials())
@ -357,10 +388,10 @@ region_test_() ->
{
foreach,
fun() ->
meck:new(httpc),
meck:new(rabbitmq_aws),
meck:new(gun, []),
meck:new(rabbitmq_aws, [passthrough]),
reset_environment(),
[httpc, rabbitmq_aws]
[gun, rabbitmq_aws]
end,
fun meck:unload/1,
[
@ -383,12 +414,12 @@ region_test_() ->
end},
{"from instance metadata service", fun() ->
meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, undefined),
meck:expect(
httpc,
request,
4,
{ok, {{protocol, 200, message}, headers, "us-west-1a"}}
),
meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end),
meck:expect(gun, close, fun(_) -> ok end),
meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end),
meck:expect(gun, get, fun(_, _, _) -> stream_ref end),
meck:expect(gun, await, fun(_, _, _) -> {response, nofin, 200, headers} end),
meck:expect(gun, await_body, fun(_, _, _) -> {ok, <<"us-west-1a">>} end),
?assertEqual({ok, "us-west-1"}, rabbitmq_aws_config:region())
end},
{"full lookup failure", fun() ->
@ -397,12 +428,12 @@ region_test_() ->
end},
{"http error failure", fun() ->
meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, undefined),
meck:expect(
httpc,
request,
4,
{ok, {{protocol, 500, message}, headers, "Internal Server Error"}}
),
meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end),
meck:expect(gun, close, fun(_) -> ok end),
meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end),
meck:expect(gun, get, fun(_, _, _) -> stream_ref end),
meck:expect(gun, await, fun(_, _, _) -> {response, nofin, 500, headers} end),
meck:expect(gun, await_body, fun(_, _, _) -> {ok, <<"Internal Server Error">>} end),
?assertEqual({ok, ?DEFAULT_REGION}, rabbitmq_aws_config:region())
end}
]
@ -412,32 +443,41 @@ instance_id_test_() ->
{
foreach,
fun() ->
meck:new(httpc),
meck:new(rabbitmq_aws),
meck:new(gun, []),
meck:new(rabbitmq_aws, [passthrough]),
reset_environment(),
[httpc, rabbitmq_aws]
[gun, rabbitmq_aws]
end,
fun meck:unload/1,
[
{"get instance id successfully", fun() ->
meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, undefined),
meck:expect(
httpc, request, 4, {ok, {{protocol, 200, message}, headers, "instance-id"}}
),
meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end),
meck:expect(gun, close, fun(_) -> ok end),
meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end),
meck:expect(gun, get, fun(_, _, _) -> stream_ref end),
meck:expect(gun, await, fun(_, _, _) -> {response, nofin, 200, headers} end),
meck:expect(gun, await_body, fun(_, _, _) -> {ok, <<"instance-id">>} end),
?assertEqual({ok, "instance-id"}, rabbitmq_aws_config:instance_id())
end},
{"getting instance id is rejected with invalid token error", fun() ->
meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, "invalid"),
meck:expect(
httpc, request, 4, {error, {{protocol, 401, message}, headers, "Invalid token"}}
),
meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end),
meck:expect(gun, close, fun(_) -> ok end),
meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end),
meck:expect(gun, get, fun(_, _, _) -> stream_ref end),
meck:expect(gun, await, fun(_, _, _) -> {response, nofin, 401, headers} end),
meck:expect(gun, await_body, fun(_, _, _) -> {ok, <<"Invalid token">>} end),
?assertEqual({error, undefined}, rabbitmq_aws_config:instance_id())
end},
{"getting instance id is rejected with access denied error", fun() ->
meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, "expired token"),
meck:expect(
httpc, request, 4, {error, {{protocol, 403, message}, headers, "access denied"}}
),
meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end),
meck:expect(gun, close, fun(_) -> ok end),
meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end),
meck:expect(gun, get, fun(_, _, _) -> stream_ref end),
meck:expect(gun, await, fun(_, _, _) -> {response, nofin, 403, headers} end),
meck:expect(gun, await_body, fun(_, _, _) -> {ok, <<"access denied">>} end),
?assertEqual({error, undefined}, rabbitmq_aws_config:instance_id())
end}
]
@ -447,36 +487,34 @@ load_imdsv2_token_test_() ->
{
foreach,
fun() ->
meck:new(httpc),
[httpc]
meck:new(gun, []),
[gun]
end,
fun meck:unload/1,
[
{"fail to get imdsv2 token - timeout", fun() ->
meck:expect(httpc, request, 4, {error, timeout}),
meck:expect(gun, open, fun(_, _, _) -> {error, timeout} end),
?assertEqual(undefined, rabbitmq_aws_config:load_imdsv2_token())
end},
{"fail to get imdsv2 token - PUT request is not valid", fun() ->
meck:expect(
httpc,
request,
4,
{error, {
{protocol, 400, messge},
headers,
"Missing or Invalid Parameters The PUT request is not valid."
}}
),
meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end),
meck:expect(gun, close, fun(_) -> ok end),
meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end),
meck:expect(gun, put, fun(_, _, _, _) -> stream_ref end),
meck:expect(gun, await, fun(_, _, _) -> {response, nofin, 400, headers} end),
meck:expect(gun, await_body, fun(_, _, _) ->
{ok, <<"Missing or Invalid Parameters The PUT request is not valid.">>}
end),
?assertEqual(undefined, rabbitmq_aws_config:load_imdsv2_token())
end},
{"successfully get imdsv2 token from instance metadata service", fun() ->
IMDSv2Token = "super_secret_token_value",
meck:sequence(
httpc,
request,
4,
[{ok, {{protocol, 200, message}, headers, IMDSv2Token}}]
),
meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end),
meck:expect(gun, close, fun(_) -> ok end),
meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end),
meck:expect(gun, put, fun(_, _, _, _) -> stream_ref end),
meck:expect(gun, await, fun(_, _, _) -> {response, nofin, 200, headers} end),
meck:expect(gun, await_body, fun(_, _, _) -> {ok, list_to_binary(IMDSv2Token)} end),
?assertEqual(IMDSv2Token, rabbitmq_aws_config:load_imdsv2_token())
end}
]
@ -486,7 +524,7 @@ maybe_imdsv2_token_headers_test_() ->
{
foreach,
fun() ->
meck:new(rabbitmq_aws),
meck:new(rabbitmq_aws, [passthrough]),
[rabbitmq_aws]
end,
fun meck:unload/1,
@ -516,7 +554,7 @@ reset_environment() ->
"AWS_SHARED_CREDENTIALS_FILE",
"bad_credentials.ini"
),
meck:expect(httpc, request, 4, {error, timeout}).
meck:expect(gun, open, fun(_, _, _) -> {error, timeout} end).
setup_test_config_env_var() ->
setup_test_file_with_env_var("AWS_CONFIG_FILE", "test_aws_config.ini").

View File

@ -46,7 +46,13 @@ init_test_() ->
]}.
terminate_test() ->
?assertEqual(ok, rabbitmq_aws:terminate(foo, bar)).
?assertEqual(
ok,
rabbitmq_aws:terminate(
foo,
{state, undefined, undefined, undefined, undefined, "us-west-3", undefined, test_result}
)
).
code_change_test() ->
?assertEqual({ok, {state, denial}}, rabbitmq_aws:code_change(foo, bar, {state, denial})).
@ -133,9 +139,11 @@ format_response_test_() ->
{"ok", fun() ->
Response =
{ok, {
{"HTTP/1.1", 200, "Ok"}, [{"Content-Type", "text/xml"}], "<test>Value</test>"
{"HTTP/1.1", 200, "Ok"},
[{<<"Content-Type">>, <<"text/xml">>}],
"<test>Value</test>"
}},
Expectation = {ok, {[{"Content-Type", "text/xml"}], [{"test", "Value"}]}},
Expectation = {ok, {[{<<"Content-Type">>, <<"text/xml">>}], [{"test", "Value"}]}},
?assertEqual(Expectation, rabbitmq_aws:format_response(Response))
end},
{"error", fun() ->
@ -161,8 +169,8 @@ gen_server_call_test_() ->
os:putenv("AWS_DEFAULT_REGION", "us-west-3"),
os:putenv("AWS_ACCESS_KEY_ID", "Sésame"),
os:putenv("AWS_SECRET_ACCESS_KEY", "ouvre-toi"),
meck:new(httpc, []),
[httpc]
meck:new(gun, []),
[gun]
end,
fun(Mods) ->
meck:unload(Mods),
@ -186,31 +194,41 @@ gen_server_call_test_() ->
Body = "",
Options = [],
Host = undefined,
meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end),
meck:expect(gun, close, fun(_) -> ok end),
meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end),
meck:expect(
httpc,
request,
fun(
get,
{"https://ec2.us-east-1.amazonaws.com/?Action=DescribeTags&Version=2015-10-01",
_Headers},
_Options,
[]
) ->
{ok, {
{"HTTP/1.0", 200, "OK"},
[{"content-type", "application/json"}],
"{\"pass\": true}"
}}
gun,
get,
fun(_Pid, _Path, _Headers) -> nofin end
),
%% {ok, {{"HTTP/1.0", 200, "OK"}, [{"content-type", "application/json"}], "{\"pass\": true}"}}
%% end),
meck:expect(
gun,
await,
fun(_Pid, _, _) ->
{response, nofin, 200, [{<<"content-type">>, <<"application/json">>}]}
end
),
meck:expect(
gun,
await_body,
fun(_Pid, _, _) -> {ok, <<"{\"pass\": true}">>} end
),
%% {ok, {{"HTTP/1.0", 200, "OK"}, [{"content-type", "application/json"}], "{\"pass\": true}"}}
%% end),
Expectation =
{reply, {ok, {[{"content-type", "application/json"}], [{"pass", true}]}},
{reply,
{ok,
{[{<<"content-type">>, <<"application/json">>}], [{"pass", true}]}},
State},
Result = rabbitmq_aws:handle_call(
{request, Service, Method, Headers, Path, Body, Options, Host}, eunit, State
),
?assertEqual(Expectation, Result),
meck:validate(httpc)
meck:validate(gun)
end
},
{
@ -388,9 +406,9 @@ perform_request_test_() ->
{
foreach,
fun() ->
meck:new(httpc, []),
meck:new(gun, []),
meck:new(rabbitmq_aws_config, []),
[httpc, rabbitmq_aws_config]
[gun, rabbitmq_aws_config]
end,
fun meck:unload/1,
[
@ -411,33 +429,37 @@ perform_request_test_() ->
Host = undefined,
ExpectURI =
"https://ec2.us-east-1.amazonaws.com/?Action=DescribeTags&Version=2015-10-01",
meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end),
meck:expect(gun, close, fun(_) -> ok end),
meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end),
meck:expect(
httpc,
request,
fun(get, {URI, _Headers}, _Options, []) ->
case URI of
ExpectURI ->
{ok, {
{"HTTP/1.0", 200, "OK"},
[{"content-type", "application/json"}],
"{\"pass\": true}"
}};
_ ->
{ok,
{{"HTTP/1.0", 400, "RequestFailure",
[{"content-type", "application/json"}],
"{\"pass\": false}"}}}
end
gun,
get,
fun(_Pid, "/?Action=DescribeTags&Version=2015-10-01", _Headers) -> nofin end
),
meck:expect(
gun,
await,
fun(_Pid, _, _) ->
{response, nofin, 200, [{<<"content-type">>, <<"application/json">>}]}
end
),
meck:expect(
gun,
await_body,
fun(_Pid, _, _) -> {ok, <<"{\"pass\": true}">>} end
),
Expectation = {
{ok, {[{"content-type", "application/json"}], [{"pass", true}]}}, State
{ok, {[{<<"content-type">>, <<"application/json">>}], [{"pass", true}]}},
State
},
Result = rabbitmq_aws:perform_request(
State, Service, Method, Headers, Path, Body, Options, Host
),
?assertEqual(Expectation, Result),
meck:validate(httpc)
meck:validate(gun)
end
},
{
@ -451,19 +473,11 @@ perform_request_test_() ->
Body = "",
Options = [],
Host = undefined,
meck:expect(httpc, request, fun(get, {_URI, _Headers}, _Options, []) ->
{ok, {
{"HTTP/1.0", 400, "RequestFailure"},
[{"content-type", "application/json"}],
"{\"pass\": false}"
}}
end),
Expectation = {{error, {credentials, State#state.error}}, State},
Result = rabbitmq_aws:perform_request(
State, Service, Method, Headers, Path, Body, Options, Host
),
?assertEqual(Expectation, Result),
meck:validate(httpc)
?assertEqual(Expectation, Result)
end
},
{
@ -554,9 +568,9 @@ api_get_request_test_() ->
{
foreach,
fun() ->
meck:new(httpc, []),
meck:new(gun, []),
meck:new(rabbitmq_aws_config, []),
[httpc, rabbitmq_aws_config]
[gun, rabbitmq_aws_config]
end,
fun meck:unload/1,
[
@ -567,23 +581,34 @@ api_get_request_test_() ->
region = "us-east-1",
expiration = {{3016, 4, 1}, {12, 0, 0}}
},
meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end),
meck:expect(gun, close, fun(_) -> ok end),
meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end),
meck:expect(
httpc,
request,
4,
{ok, {
{"HTTP/1.0", 200, "OK"},
[{"content-type", "application/json"}],
"{\"data\": \"value\"}"
}}
gun,
get,
fun(_Pid, _Path, _Headers) -> nofin end
),
meck:expect(
gun,
await,
fun(_Pid, _, _) ->
{response, nofin, 200, [{<<"content-type">>, <<"application/json">>}]}
end
),
meck:expect(
gun,
await_body,
fun(_Pid, _, _) -> {ok, <<"{\"data\": \"value\"}">>} end
),
{ok, Pid} = rabbitmq_aws:start_link(),
rabbitmq_aws:set_region("us-east-1"),
rabbitmq_aws:set_credentials(State),
Result = rabbitmq_aws:api_get_request("AWS", "API"),
ok = gen_server:stop(Pid),
?assertEqual({ok, [{"data", "value"}]}, Result),
meck:validate(httpc)
meck:validate(gun)
end},
{"AWS service API request failed - credentials", fun() ->
meck:expect(rabbitmq_aws_config, credentials, 0, {error, undefined}),
@ -600,14 +625,27 @@ api_get_request_test_() ->
region = "us-east-1",
expiration = {{3016, 4, 1}, {12, 0, 0}}
},
meck:expect(httpc, request, 4, {error, "network error"}),
meck:expect(gun, open, fun(_, _, _) -> {ok, spawn(fun() -> ok end)} end),
meck:expect(gun, close, fun(_) -> ok end),
meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end),
meck:expect(
gun,
get,
fun(_Pid, _Path, _Headers) -> nofin end
),
meck:expect(
gun,
await,
fun(_Pid, _, _) -> {error, "network error"} end
),
{ok, Pid} = rabbitmq_aws:start_link(),
rabbitmq_aws:set_region("us-east-1"),
rabbitmq_aws:set_credentials(State),
Result = rabbitmq_aws:api_get_request_with_retries("AWS", "API", 3, 1),
ok = gen_server:stop(Pid),
?assertEqual({error, "AWS service is unavailable"}, Result),
meck:validate(httpc)
meck:validate(gun)
end},
{"AWS service API request succeeded after a transient error", fun() ->
State = #state{
@ -616,22 +654,35 @@ api_get_request_test_() ->
region = "us-east-1",
expiration = {{3016, 4, 1}, {12, 0, 0}}
},
meck:expect(gun, open, fun(_, _, _) -> {ok, spawn(fun() -> ok end)} end),
meck:expect(gun, close, fun(_) -> ok end),
meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end),
meck:expect(
httpc,
request,
4,
gun,
get,
fun(_Pid, _Path, _Headers) -> nofin end
),
%% meck:expect(gun, get, 3, meck:seq(
%% fun(_Pid, _Path, _Headers) -> {error, "network errors"} end),
meck:expect(
gun,
await,
3,
meck:seq([
{error, "network error"},
{ok, {
{"HTTP/1.0", 500, "OK"},
[{"content-type", "application/json"}],
"{\"error\": \"server error\"}"
}},
{ok, {
{"HTTP/1.0", 200, "OK"},
[{"content-type", "application/json"}],
"{\"data\": \"value\"}"
}}
{response, nofin, 500, [{<<"content-type">>, <<"application/json">>}]},
{response, nofin, 200, [{<<"content-type">>, <<"application/json">>}]}
])
),
meck:expect(
gun,
await_body,
3,
meck:seq([
{ok, <<"{\"error\": \"server error\"}">>},
{ok, <<"{\"data\": \"value\"}">>}
])
),
{ok, Pid} = rabbitmq_aws:start_link(),
@ -640,7 +691,7 @@ api_get_request_test_() ->
Result = rabbitmq_aws:api_get_request_with_retries("AWS", "API", 3, 1),
ok = gen_server:stop(Pid),
?assertEqual({ok, [{"data", "value"}]}, Result),
meck:validate(httpc)
meck:validate(gun)
end}
]
}.