Re-evaluate stream permissions after secret update

Re-evaluate permissions, cancel publishers and
subscriptions, send metadata update accordingly.

Move record definitions from the stream reader
to a dedicated header file to be able to write
unit tests.

Fixes #10292
This commit is contained in:
Arnaud Cogoluègnes 2024-01-12 16:41:49 +01:00
parent 33c64d06ea
commit bc2a11d1bd
No known key found for this signature in database
GPG Key ID: D5C8C4DFAD43AFA8
8 changed files with 508 additions and 113 deletions

View File

@ -257,7 +257,7 @@ update_state(User = #user{authz_backends = Backends0}, NewState) ->
permission_cache_can_expire(#user{authz_backends = Backends}) ->
lists:any(fun ({Module, _State}) -> Module:state_can_expire() end, Backends).
-spec expiry_timestamp(User :: rabbit_types:user()) -> integer | never.
-spec expiry_timestamp(User :: rabbit_types:user()) -> integer() | never.
expiry_timestamp(User = #user{authz_backends = Modules}) ->
lists:foldl(fun({Module, Impl}, Ts0) ->
case Module:expiry_timestamp(auth_user(User, Impl)) of

View File

@ -126,6 +126,13 @@ rabbitmq_integration_suite(
name = "rabbit_stream_manager_SUITE",
)
rabbitmq_integration_suite(
name = "rabbit_stream_reader_SUITE",
deps = [
"//deps/rabbitmq_stream_common:erlang_app",
],
)
rabbitmq_integration_suite(
name = "rabbit_stream_SUITE",
shard_count = 3,

View File

@ -95,6 +95,7 @@ def all_srcs(name = "all_srcs"):
)
filegroup(
name = "private_hdrs",
srcs = ["src/rabbit_stream_reader.hrl"],
)
filegroup(
name = "srcs",
@ -175,3 +176,13 @@ def test_suite_beam_files(name = "test_suite_beam_files"):
erlc_opts = "//:test_erlc_opts",
deps = ["//deps/rabbit_common:erlang_app"],
)
erlang_bytecode(
name = "rabbit_stream_reader_SUITE_beam_files",
testonly = True,
srcs = ["test/rabbit_stream_reader_SUITE.erl"],
outs = ["test/rabbit_stream_reader_SUITE.beam"],
hdrs = ["src/rabbit_stream_reader.hrl"],
app_name = "rabbitmq_stream",
erlc_opts = "//:test_erlc_opts",
deps = ["//deps/rabbit_common:erlang_app", "//deps/rabbitmq_stream_common:erlang_app"],
)

View File

@ -9,106 +9,21 @@
%% The Original Code is RabbitMQ.
%%
%% The Initial Developer of the Original Code is Pivotal Software, Inc.
%% Copyright (c) 2020-2023 VMware, Inc. or its affiliates. All rights reserved.
%% Copyright (c) 2020-2024 Broadcom. All Rights Reserved.
%% The term Broadcom refers to Broadcom Inc. and/or its subsidiaries. All rights reserved.
%%
-module(rabbit_stream_reader).
-feature(maybe_expr, enable).
-behaviour(gen_statem).
-include_lib("rabbit_common/include/rabbit.hrl").
-include_lib("rabbitmq_stream_common/include/rabbit_stream.hrl").
-include("rabbit_stream_reader.hrl").
-include("rabbit_stream_metrics.hrl").
-type stream() :: binary().
-type publisher_id() :: byte().
-type publisher_reference() :: binary().
-type subscription_id() :: byte().
-type internal_id() :: integer().
-include_lib("rabbitmq_stream_common/include/rabbit_stream.hrl").
-record(publisher,
{publisher_id :: publisher_id(),
stream :: stream(),
reference :: undefined | publisher_reference(),
leader :: pid(),
message_counters :: atomics:atomics_ref(),
%% use to distinguish a stale publisher from a live publisher with the same ID
%% used only for publishers without a reference (dedup off)
internal_id :: internal_id()}).
-record(consumer_configuration,
{socket :: rabbit_net:socket(), %% ranch_transport:socket(),
member_pid :: pid(),
subscription_id :: subscription_id(),
stream :: stream(),
offset :: osiris:offset(),
counters :: atomics:atomics_ref(),
properties :: map(),
active :: boolean()}).
-record(consumer,
{configuration :: #consumer_configuration{},
credit :: non_neg_integer(),
send_limit :: non_neg_integer(),
log :: undefined | osiris_log:state(),
last_listener_offset = undefined :: undefined | osiris:offset()}).
-record(request,
{start :: integer(),
content :: term()}).
-record(stream_connection_state,
{data :: rabbit_stream_core:state(), blocked :: boolean(),
consumers :: #{subscription_id() => #consumer{}}}).
-record(stream_connection,
{name :: binary(),
%% server host
host,
%% client host
peer_host,
%% server port
port,
%% client port
peer_port,
auth_mechanism,
authentication_state :: any(),
connected_at :: integer(),
helper_sup :: pid(),
socket :: rabbit_net:socket(),
publishers ::
#{publisher_id() =>
#publisher{}}, %% FIXME replace with a list (0-255 lookup faster?)
publisher_to_ids ::
#{{stream(), publisher_reference()} => publisher_id()},
stream_leaders :: #{stream() => pid()},
stream_subscriptions :: #{stream() => [subscription_id()]},
credits :: atomics:atomics_ref(),
user :: undefined | #user{},
virtual_host :: undefined | binary(),
connection_step ::
atom(), % tcp_connected, peer_properties_exchanged, authenticating, authenticated, tuning, tuned, opened, failure, closing, closing_done
frame_max :: integer(),
heartbeat :: undefined | integer(),
heartbeater :: any(),
client_properties = #{} :: #{binary() => binary()},
monitors = #{} :: #{reference() => {pid(), stream()}},
stats_timer :: undefined | rabbit_event:state(),
resource_alarm :: boolean(),
send_file_oct ::
atomics:atomics_ref(), % number of bytes sent with send_file (for metrics)
transport :: tcp | ssl,
proxy_socket :: undefined | ranch_transport:socket(),
correlation_id_sequence :: integer(),
outstanding_requests :: #{integer() => #request{}},
deliver_version :: rabbit_stream_core:command_version(),
request_timeout :: pos_integer(),
outstanding_requests_timer :: undefined | erlang:reference(),
filtering_supported :: boolean(),
%% internal sequence used for publishers
internal_sequence = 0 :: integer()}).
-record(configuration,
{initial_credits :: integer(),
credits_required_for_unblocking :: integer(),
frame_max :: integer(),
heartbeat :: integer(),
connection_negotiation_step_timeout :: integer()}).
-record(statem_data,
{transport :: module(),
connection :: #stream_connection{},
@ -184,6 +99,10 @@
tuned/3,
open/3,
close_sent/3]).
-ifdef(TEST).
-export([ensure_token_expiry_timer/2,
evaluate_state_after_secret_update/4]).
-endif.
callback_mode() ->
[state_functions, state_enter].
@ -999,6 +918,11 @@ open(info, check_outstanding_requests,
),
{keep_state, StatemData#statem_data{connection = Connection1}}
end;
open(info, token_expired, #statem_data{connection = Connection}) ->
_ = demonitor_all_streams(Connection),
rabbit_log_connection:info("Forcing stream connection ~tp closing because token expired",
[self()]),
{stop, {shutdown, <<"Token expired">>}};
open(info, {shutdown, Explanation} = Reason,
#statem_data{connection = Connection}) ->
%% rabbitmq_management or rabbitmq_stream_management plugin
@ -1573,8 +1497,11 @@ handle_frame_pre_auth(Transport,
send(Transport, S, Frame),
%% FIXME check if vhost is alive (see rabbit_reader:is_vhost_alive/2)
Connection#stream_connection{connection_step = opened,
virtual_host = VirtualHost}
{_, Conn} = ensure_token_expiry_timer(User,
Connection#stream_connection{connection_step = opened,
virtual_host = VirtualHost}),
Conn
catch
exit:_ ->
F = rabbit_stream_core:frame({response, CorrelationId,
@ -1648,18 +1575,17 @@ handle_frame_post_auth(Transport,
handle_frame_post_auth(Transport,
#stream_connection{user = #user{username = Username} = _User,
socket = S,
socket = Socket,
host = Host,
auth_mechanism = Auth_Mechanism,
authentication_state = AuthState,
resource_alarm = false} =
C1,
State,
resource_alarm = false} = C1,
S1,
{request, CorrelationId,
{sasl_authenticate, NewMechanism, NewSaslBin}}) ->
rabbit_log:debug("Open frame received sasl_authenticate for username '~ts'", [Username]),
Connection1 =
{Connection1, State1} =
case Auth_Mechanism of
{NewMechanism, AuthMechanism} -> %% Mechanism is the same used during the pre-auth phase
{C2, CmdBody} =
@ -1668,7 +1594,7 @@ handle_frame_post_auth(Transport,
rabbit_core_metrics:auth_attempt_failed(Host,
NewUsername,
stream),
auth_fail(NewUsername, Msg, Args, C1, State),
auth_fail(NewUsername, Msg, Args, C1, S1),
rabbit_log_connection:warning(Msg, Args),
{C1#stream_connection{connection_step = failure},
{sasl_authenticate,
@ -1683,7 +1609,7 @@ handle_frame_post_auth(Transport,
rabbit_misc:format(Msg,
Args)}],
C1,
State),
S1),
rabbit_log_connection:warning(Msg, Args),
{C1#stream_connection{connection_step = failure},
{sasl_authenticate, ?RESPONSE_SASL_ERROR, <<>>}};
@ -1702,7 +1628,7 @@ handle_frame_post_auth(Transport,
user_authentication_success,
[],
C1,
State),
S1),
rabbit_log:debug("Successfully updated secret for username '~ts'", [Username]),
{C1#stream_connection{user = NewUser,
authentication_state = done,
@ -1725,20 +1651,24 @@ handle_frame_post_auth(Transport,
Frame =
rabbit_stream_core:frame({response, CorrelationId,
CmdBody}),
send(Transport, S, Frame),
C2;
send(Transport, Socket, Frame),
case CmdBody of
{sasl_authenticate, ?RESPONSE_CODE_OK, _} ->
#stream_connection{user = NewUsr} = C2,
evaluate_state_after_secret_update(Transport, NewUsr, C2, S1);
_ ->
{C2, S1}
end;
{OtherMechanism, _} ->
rabbit_log_connection:warning("User '~ts' cannot change initial auth mechanism '~ts' for '~ts'",
[Username, NewMechanism, OtherMechanism]),
CmdBody =
{sasl_authenticate, ?RESPONSE_SASL_CANNOT_CHANGE_MECHANISM, <<>>},
Frame = rabbit_stream_core:frame({response, CorrelationId, CmdBody}),
send(Transport, S, Frame),
C1#stream_connection{connection_step = failure}
send(Transport, Socket, Frame),
{C1#stream_connection{connection_step = failure}, S1}
end,
{Connection1, State};
{Connection1, State1};
handle_frame_post_auth(Transport,
#stream_connection{user = User,
publishers = Publishers0,
@ -3244,6 +3174,57 @@ request(Content) ->
#request{start = erlang:monotonic_time(millisecond),
content = Content}.
evaluate_state_after_secret_update(Transport,
User,
#stream_connection{socket = Socket,
publishers = Publishers,
stream_subscriptions = Subscriptions} = Conn0,
State0) ->
{_, Conn1} = ensure_token_expiry_timer(User, Conn0),
rabbit_stream_utils:clear_permission_cache(),
PublisherStreams =
lists:foldl(fun(#publisher{stream = Str}, Acc) ->
case rabbit_stream_utils:check_write_permitted(stream_r(Str, Conn0), User) of
ok ->
Acc;
_ ->
Acc#{Str => ok}
end
end, #{}, maps:values(Publishers)),
{SubscriptionStreams, Conn2, State1} =
maps:fold(fun(Str, Subs, {Acc, C0, S0}) ->
case rabbit_stream_utils:check_read_permitted(stream_r(Str, Conn0), User, #{}) of
ok ->
{Acc, C0, S0};
_ ->
{C1, S1} =
lists:foldl(fun(SubId, {Conn, St}) ->
remove_subscription(SubId, Conn, St)
end, {C0, S0}, Subs),
{Acc#{Str => ok}, C1, S1}
end
end, {#{}, Conn1, State0}, Subscriptions),
Streams = maps:merge(PublisherStreams, SubscriptionStreams),
{Conn3, State2} =
case maps:size(Streams) of
0 ->
{Conn2, State1};
_ ->
maps:fold(fun(Str, _, {C0, S0}) ->
{_, C1, S1} = clean_state_after_stream_deletion_or_failure(
undefined, Str, C0, S0),
Command = {metadata_update, Str,
?RESPONSE_CODE_STREAM_NOT_AVAILABLE},
Frame = rabbit_stream_core:frame(Command),
send(Transport, Socket, Frame),
rabbit_global_counters:increase_protocol_counter(stream,
?STREAM_NOT_AVAILABLE,
1),
{C1, S1}
end, {Conn2, State1}, Streams)
end,
{Conn3, State2}.
ensure_outstanding_requests_timer(#stream_connection{
outstanding_requests = Requests,
outstanding_requests_timer = undefined
@ -3265,6 +3246,33 @@ ensure_outstanding_requests_timer(#stream_connection{
ensure_outstanding_requests_timer(C) ->
C.
ensure_token_expiry_timer(User, #stream_connection{token_expiry_timer = Timer} = Conn) ->
TimerRef =
maybe
rabbit_log:debug("Checking token expiry"),
true ?= rabbit_access_control:permission_cache_can_expire(User),
rabbit_log:debug("Token can expire"),
Ts = rabbit_access_control:expiry_timestamp(User),
rabbit_log:debug("Token expiry timestamp: ~tp", [Ts]),
true ?= is_integer(Ts),
Time = (Ts - os:system_time(second)) * 1000,
rabbit_log:debug("Token expires in ~tp ms, setting timer to close connection", [Time]),
true ?= Time > 0,
erlang:send_after(Time, self(), token_expired)
else
false ->
undefined;
{error, _} ->
undefined
end,
Cancel = case Timer of
undefined ->
ok;
_ ->
erlang:cancel_timer(Timer, [{async, false}, {info, true}])
end,
{Cancel, Conn#stream_connection{token_expiry_timer = TimerRef}}.
maybe_unregister_consumer(_, _, false = _Sac, Requests) ->
Requests;
maybe_unregister_consumer(VirtualHost,

View File

@ -0,0 +1,103 @@
%% The contents of this file are subject to the Mozilla Public License
%% at https://www.mozilla.org/en-US/MPL/2.0/
%%
%% Software distributed under the License is distributed on an "AS IS"
%% basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See
%% the License for the specific language governing rights and
%% limitations under the License.
%%
%% The Original Code is RabbitMQ.
%%
%% The Initial Developer of the Original Code is Pivotal Software, Inc.
%% Copyright (c) 2020-2024 Broadcom. All Rights Reserved.
%% The term Broadcom refers to Broadcom Inc. and/or its subsidiaries. All rights reserved.
%%
-include_lib("rabbit_common/include/rabbit.hrl").
-type stream() :: binary().
-type publisher_id() :: byte().
-type publisher_reference() :: binary().
-type subscription_id() :: byte().
-type internal_id() :: integer().
-record(publisher,
{publisher_id :: publisher_id(),
stream :: stream(),
reference :: undefined | publisher_reference(),
leader :: pid(),
message_counters :: atomics:atomics_ref(),
%% use to distinguish a stale publisher from a live publisher with the same ID
%% used only for publishers without a reference (dedup off)
internal_id :: internal_id()}).
-record(consumer_configuration,
{socket :: rabbit_net:socket(), %% ranch_transport:socket(),
member_pid :: pid(),
subscription_id :: subscription_id(),
stream :: stream(),
offset :: osiris:offset(),
counters :: atomics:atomics_ref(),
properties :: map(),
active :: boolean()}).
-record(consumer,
{configuration :: #consumer_configuration{},
credit :: non_neg_integer(),
send_limit :: non_neg_integer(),
log = undefined :: undefined | osiris_log:state(),
last_listener_offset = undefined :: undefined | osiris:offset()}).
-record(request,
{start :: integer(),
content :: term()}).
-record(stream_connection_state,
{data :: rabbit_stream_core:state(), blocked :: boolean(),
consumers :: #{subscription_id() => #consumer{}}}).
-record(stream_connection,
{name :: binary(),
%% server host
host,
%% client host
peer_host,
%% server port
port,
%% client port
peer_port,
auth_mechanism,
authentication_state :: any(),
connected_at :: integer(),
helper_sup :: pid(),
socket :: rabbit_net:socket(),
publishers = #{} :: #{publisher_id() => #publisher{}},
publisher_to_ids = #{} :: #{{stream(), publisher_reference()} => publisher_id()},
stream_leaders = #{} :: #{stream() => pid()},
stream_subscriptions = #{} :: #{stream() => [subscription_id()]},
credits :: atomics:atomics_ref(),
user :: undefined | #user{},
virtual_host :: undefined | binary(),
connection_step ::
atom(), % tcp_connected, peer_properties_exchanged, authenticating, authenticated, tuning, tuned, opened, failure, closing, closing_done
frame_max :: integer(),
heartbeat :: undefined | integer(),
heartbeater :: any(),
client_properties = #{} :: #{binary() => binary()},
monitors = #{} :: #{reference() => {pid(), stream()}},
stats_timer :: undefined | rabbit_event:state(),
resource_alarm :: boolean(),
send_file_oct ::
atomics:atomics_ref(), % number of bytes sent with send_file (for metrics)
transport :: tcp | ssl,
proxy_socket :: undefined | ranch_transport:socket(),
correlation_id_sequence :: integer(),
outstanding_requests :: #{integer() => #request{}},
deliver_version :: rabbit_stream_core:command_version(),
request_timeout :: pos_integer(),
outstanding_requests_timer :: undefined | erlang:reference(),
filtering_supported :: boolean(),
%% internal sequence used for publishers
internal_sequence = 0 :: integer(),
token_expiry_timer = undefined :: undefined | erlang:reference()}).
-record(configuration,
{initial_credits :: integer(),
credits_required_for_unblocking :: integer(),
frame_max :: integer(),
heartbeat :: integer(),
connection_negotiation_step_timeout :: integer()}).

View File

@ -11,7 +11,7 @@
%% The Original Code is RabbitMQ.
%%
%% The Initial Developer of the Original Code is Pivotal Software, Inc.
%% Copyright (c) 2020-2023 VMware, Inc. or its affiliates. All rights reserved.
%% Copyright (c) 2020-2024 VMware, Inc. or its affiliates. All rights reserved.
%%
-module(rabbit_stream_utils).
@ -35,7 +35,8 @@
filter_spec/1,
command_versions/0,
filtering_supported/0,
check_super_stream_management_permitted/4]).
check_super_stream_management_permitted/4,
clear_permission_cache/0]).
-define(MAX_PERMISSION_CACHE_SIZE, 12).
@ -202,6 +203,10 @@ check_resource_access(User, Resource, Perm, Context) ->
end
end.
clear_permission_cache() ->
erase(permission_cache),
ok.
check_configure_permitted(Resource, User) ->
check_resource_access(User, Resource, configure, #{}).

View File

@ -11,7 +11,8 @@
%% The Original Code is RabbitMQ.
%%
%% The Initial Developer of the Original Code is Pivotal Software, Inc.
%% Copyright (c) 2020-2023 VMware, Inc. or its affiliates. All rights reserved.
%% Copyright (c) 2020-2024 Broadcom. All Rights Reserved.
%% The term Broadcom refers to Broadcom Inc. and/or its subsidiaries. All rights reserved.
%%
-module(rabbit_stream_SUITE).
@ -27,6 +28,7 @@
-compile(export_all).
-import(rabbit_stream_core, [frame/1]).
-import(rabbit_ct_broker_helpers, [rpc/5]).
-define(WAIT, 5000).
@ -55,7 +57,9 @@ groups() ->
max_segment_size_bytes_validation,
close_connection_on_consumer_update_timeout,
set_filter_size,
vhost_queue_limit
vhost_queue_limit,
connection_should_be_closed_on_token_expiry,
should_receive_metadata_update_after_update_secret
]},
%% Run `test_global_counters` on its own so the global metrics are
%% initialised to 0 for each testcase
@ -688,6 +692,75 @@ vhost_queue_limit(Config) ->
ok.
connection_should_be_closed_on_token_expiry(Config) ->
rabbit_ct_broker_helpers:setup_meck(Config),
Mod = rabbit_access_control,
ok = rpc(Config, 0, meck, new, [Mod, [no_link, passthrough]]),
ok = rpc(Config, 0, meck, expect, [Mod, check_user_loopback, 2, ok]),
ok = rpc(Config, 0, meck, expect, [Mod, check_vhost_access, 4, ok]),
ok = rpc(Config, 0, meck, expect, [Mod, permission_cache_can_expire, 1, true]),
Expiry = os:system_time(seconds) + 2,
ok = rpc(Config, 0, meck, expect, [Mod, expiry_timestamp, 1, Expiry]),
T = gen_tcp,
Port = get_port(T, Config),
Opts = get_opts(T),
{ok, S} = T:connect("localhost", Port, Opts),
C = rabbit_stream_core:init(0),
test_peer_properties(T, S, C),
test_authenticate(T, S, C),
closed = wait_for_socket_close(T, S, 10),
ok = rpc(Config, 0, meck, unload, [Mod]).
should_receive_metadata_update_after_update_secret(Config) ->
T = gen_tcp,
Port = get_port(T, Config),
Opts = get_opts(T),
{ok, S} = T:connect("localhost", Port, Opts),
C = rabbit_stream_core:init(0),
test_peer_properties(T, S, C),
test_authenticate(T, S, C),
Prefix = atom_to_binary(?FUNCTION_NAME, utf8),
PublishStream = <<Prefix/binary, <<"-publish">>/binary>>,
test_create_stream(T, S, PublishStream, C),
ConsumeStream = <<Prefix/binary, <<"-consume">>/binary>>,
test_create_stream(T, S, ConsumeStream, C),
test_declare_publisher(T, S, 1, PublishStream, C),
test_subscribe(T, S, 1, ConsumeStream, C),
rabbit_ct_broker_helpers:setup_meck(Config),
Mod = rabbit_stream_utils,
ok = rpc(Config, 0, meck, new, [Mod, [no_link, passthrough]]),
ok = rpc(Config, 0, meck, expect, [Mod, check_write_permitted, 2, error]),
ok = rpc(Config, 0, meck, expect, [Mod, check_read_permitted, 3, error]),
C01 = expect_successful_authentication(try_authenticate(T, S, C, <<"PLAIN">>, <<"guest">>, <<"guest">>)),
{Meta1, C02} = receive_commands(T, S, C01),
{metadata_update, Stream1, ?RESPONSE_CODE_STREAM_NOT_AVAILABLE} = Meta1,
{Meta2, C03} = receive_commands(T, S, C02),
{metadata_update, Stream2, ?RESPONSE_CODE_STREAM_NOT_AVAILABLE} = Meta2,
ImpactedStreams = #{Stream1 => ok, Stream2 => ok},
?assert(maps:is_key(PublishStream, ImpactedStreams)),
?assert(maps:is_key(ConsumeStream, ImpactedStreams)),
test_close(T, S, C03),
closed = wait_for_socket_close(T, S, 10),
ok = rpc(Config, 0, meck, unload, [Mod]),
{ok, S2} = T:connect("localhost", Port, Opts),
C2 = rabbit_stream_core:init(0),
test_peer_properties(T, S2, C2),
test_authenticate(T, S2, C2),
test_delete_stream(T, S2, PublishStream, C2, false),
test_delete_stream(T, S2, ConsumeStream, C2, false),
test_close(T, S2, C2),
closed = wait_for_socket_close(T, S2, 10),
ok.
consumer_count(Config) ->
ets_count(Config, ?TABLE_CONSUMER).

View File

@ -0,0 +1,188 @@
%% The contents of this file are subject to the Mozilla Public License
%% at https://www.mozilla.org/en-US/MPL/2.0/
%%
%% Software distributed under the License is distributed on an "AS IS"
%% basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See
%% the License for the specific language governing rights and
%% limitations under the License.
%%
%% The Original Code is RabbitMQ.
%%
%% The Initial Developer of the Original Code is Pivotal Software, Inc.
%% Copyright (c) 2024 Broadcom. All Rights Reserved.
%% The term Broadcom refers to Broadcom Inc. and/or its subsidiaries. All rights reserved.
%%
-module(rabbit_stream_reader_SUITE).
-compile(export_all).
-include_lib("eunit/include/eunit.hrl").
-include_lib("rabbitmq_stream/src/rabbit_stream_reader.hrl").
-include_lib("rabbitmq_stream_common/include/rabbit_stream.hrl").
-import(rabbit_stream_reader, [ensure_token_expiry_timer/2]).
%%%===================================================================
%%% Common Test callbacks
%%%===================================================================
all() ->
[{group, tests}].
%% replicate eunit like test resolution
all_tests() ->
[F
|| {F, _} <- ?MODULE:module_info(functions),
re:run(atom_to_list(F), "_test$") /= nomatch].
groups() ->
[{tests, [], all_tests()}].
init_per_suite(Config) ->
Config.
end_per_suite(_Config) ->
ok.
init_per_group(_Group, Config) ->
Config.
end_per_group(_Group, _Config) ->
ok.
init_per_testcase(_TestCase, Config) ->
Config.
end_per_testcase(_TestCase, _Config) ->
meck:unload(),
ok.
ensure_token_expiry_timer_test(_) ->
ok = meck:new(rabbit_access_control),
meck:expect(rabbit_access_control, permission_cache_can_expire, fun (_) -> false end),
{_, #stream_connection{token_expiry_timer = TR1}} = ensure_token_expiry_timer(#user{}, #stream_connection{}),
?assertEqual(undefined, TR1),
meck:expect(rabbit_access_control, permission_cache_can_expire, fun (_) -> true end),
meck:expect(rabbit_access_control, expiry_timestamp, fun (_) -> never end),
{_, #stream_connection{token_expiry_timer = TR2}} = ensure_token_expiry_timer(#user{}, #stream_connection{}),
?assertEqual(undefined, TR2),
Now = os:system_time(second),
meck:expect(rabbit_access_control, expiry_timestamp, fun (_) -> Now + 60 end),
{_, #stream_connection{token_expiry_timer = TR3}} = ensure_token_expiry_timer(#user{}, #stream_connection{}),
Cancel3 = erlang:cancel_timer(TR3, [{async, false}, {info, true}]),
?assert(is_integer(Cancel3)),
meck:expect(rabbit_access_control, expiry_timestamp, fun (_) -> Now - 60 end),
{_, #stream_connection{token_expiry_timer = TR4}} = ensure_token_expiry_timer(#user{}, #stream_connection{}),
?assertEqual(undefined, TR4),
DummyTRef = erlang:send_after(1_000 * 1_000, self(), dummy),
meck:expect(rabbit_access_control, permission_cache_can_expire, fun (_) -> false end),
{Cancel5, #stream_connection{token_expiry_timer = TR5}} = ensure_token_expiry_timer(#user{},
#stream_connection{token_expiry_timer = DummyTRef}),
?assertEqual(undefined, TR5),
?assert(is_integer(Cancel5)),
ok.
evaluate_state_after_secret_update_test(_) ->
Mod = rabbit_stream_reader,
meck:new(Mod, [passthrough]),
ModUtils = rabbit_stream_utils,
meck:new(ModUtils, [passthrough]),
CheckFun = fun(N) ->
case binary:match(N, <<"ok_">>) of
nomatch ->
error;
_ ->
ok
end
end,
meck:expect(ModUtils, check_write_permitted, fun(#resource{name = N}, _) -> CheckFun(N) end),
meck:expect(ModUtils, check_read_permitted, fun(#resource{name = N}, _, _) -> CheckFun(N) end),
ModAccess = rabbit_access_control,
meck:new(ModAccess),
meck:expect(ModAccess, permission_cache_can_expire, 1, false),
meck:new(rabbit_stream_metrics, [stub_all]),
meck:new(rabbit_global_counters, [stub_all]),
ModTransport = dummy_transport,
meck:new(ModTransport, [non_strict]),
meck:expect(ModTransport, send, 2, ok),
ModLog = osiris_log,
meck:new(ModLog),
meck:expect(ModLog, init, 1, ok),
put(close_log_count, 0),
meck:expect(ModLog, close, fun(_) -> put(close_log_count, get(close_log_count) + 1) end),
ModCore = rabbit_stream_core,
meck:new(ModCore),
put(metadata_update, []),
meck:expect(ModCore, frame, fun(Cmd) -> put(metadata_update, [Cmd | get(metadata_update)]) end),
Publishers = #{0 => #publisher{stream = <<"ok_publish">>},
1 => #publisher{stream = <<"ko_publish">>},
2 => #publisher{stream = <<"ok_publish_consume">>},
3 => #publisher{stream = <<"ko_publish_consume">>}},
Subscriptions = #{<<"ok_consume">> => [0],
<<"ko_consume">> => [1],
<<"ok_publish_consume">> => [2],
<<"ko_publish_consume">> => [3]},
Consumers = #{0 => consumer(<<"ok_consume">>),
1 => consumer(<<"ko_consume">>),
2 => consumer(<<"ok_publish_consume">>),
3 => consumer(<<"ko_publish_consume">>)},
{C1, S1} = Mod:evaluate_state_after_secret_update(ModTransport, #user{},
#stream_connection{publishers = Publishers,
stream_subscriptions = Subscriptions},
#stream_connection_state{consumers = Consumers}),
meck:validate(ModLog),
?assertEqual(2, get(close_log_count)),
erase(close_log_count),
Cmds = get(metadata_update),
?assertEqual(3, length(Cmds)),
?assert(lists:member({metadata_update, <<"ko_publish">>, ?RESPONSE_CODE_STREAM_NOT_AVAILABLE}, Cmds)),
?assert(lists:member({metadata_update, <<"ko_consume">>, ?RESPONSE_CODE_STREAM_NOT_AVAILABLE}, Cmds)),
?assert(lists:member({metadata_update, <<"ko_publish_consume">>, ?RESPONSE_CODE_STREAM_NOT_AVAILABLE}, Cmds)),
erase(metadata_update),
#stream_connection{token_expiry_timer = TRef1,
publishers = Pubs1,
stream_subscriptions = Subs1} = C1,
?assertEqual(undefined, TRef1), %% no expiry set in the mock
?assertEqual(2, maps:size(Pubs1)),
?assertEqual(#publisher{stream = <<"ok_publish">>}, maps:get(0, Pubs1)),
?assertEqual(#publisher{stream = <<"ok_publish_consume">>}, maps:get(2, Pubs1)),
#stream_connection_state{consumers = Cons1} = S1,
?assertEqual([0], maps:get(<<"ok_consume">>, Subs1)),
?assertEqual([2], maps:get(<<"ok_publish_consume">>, Subs1)),
?assertEqual(consumer(<<"ok_consume">>), maps:get(0, Cons1)),
?assertEqual(consumer(<<"ok_publish_consume">>), maps:get(2, Cons1)),
%% making sure the token expiry timer is set if the token expires
meck:expect(ModAccess, permission_cache_can_expire, 1, true),
Now = os:system_time(second),
meck:expect(rabbit_access_control, expiry_timestamp, fun (_) -> Now + 60 end),
{C2, _} = Mod:evaluate_state_after_secret_update(ModTransport, #user{},
#stream_connection{},
#stream_connection_state{}),
#stream_connection{token_expiry_timer = TRef2} = C2,
Cancel2 = erlang:cancel_timer(TRef2, [{async, false}, {info, true}]),
?assert(is_integer(Cancel2)),
ok.
consumer(S) ->
#consumer{configuration = #consumer_configuration{stream = S},
log = osiris_log:init(#{})}.