Add permission tests for stream management

This commit is contained in:
Arnaud Cogoluègnes 2020-11-23 16:32:14 +01:00
parent 0d2bb8edb0
commit ed12ae0bbf
No known key found for this signature in database
GPG Key ID: D5C8C4DFAD43AFA8
7 changed files with 184 additions and 23 deletions

1
.gitignore vendored
View File

@ -40,6 +40,7 @@
!/deps/rabbitmq_shovel_management/
!/deps/rabbitmq_stomp/
!/deps/rabbitmq_stream/
!/deps/rabbitmq_stream_management/
!/deps/rabbitmq_top/
!/deps/rabbitmq_tracing/
!/deps/rabbitmq_trust_store/

View File

@ -15,6 +15,7 @@
is_authorized_user/4,
is_authorized_monitor/2, is_authorized_policies/2,
is_authorized_vhost_visible/2,
is_authorized_vhost_visible_for_monitoring/2,
is_authorized_global_parameters/2]).
-export([bad_request/3, bad_request_exception/4, internal_server_error/4,
@ -117,6 +118,15 @@ is_authorized_vhost_visible(ReqData, Context) ->
is_admin(Tags) orelse user_matches_vhost_visible(ReqData, User)
end).
is_authorized_vhost_visible_for_monitoring(ReqData, Context) ->
is_authorized(ReqData, Context,
<<"User not authorised to access virtual host">>,
fun(#user{tags = Tags} = User) ->
is_admin(Tags)
orelse is_monitor(Tags)
orelse user_matches_vhost_visible(ReqData, User)
end).
disable_stats(ReqData) ->
MgmtOnly = case qs_val(<<"disable_stats">>, ReqData) of
<<"true">> -> true;

View File

@ -13,6 +13,7 @@
-export([init/2, to_json/2, content_types_provided/2, resource_exists/2, is_authorized/2]).
-include_lib("rabbitmq_management_agent/include/rabbit_mgmt_records.hrl").
-include_lib("amqp_client/include/amqp_client.hrl").
dispatcher() -> [{"/stream/connections/:vhost", ?MODULE, []}].
@ -40,7 +41,7 @@ to_json(ReqData, Context) ->
end.
is_authorized(ReqData, Context) ->
rabbit_mgmt_util:is_authorized_vhost(ReqData, Context).
rabbit_mgmt_util:is_authorized_vhost_visible_for_monitoring(ReqData, Context).
augmented(ReqData, Context) ->
rabbit_mgmt_util:filter_conn_ch_list(

View File

@ -64,6 +64,22 @@ end_per_testcase(Testcase, Config) ->
%% -------------------------------------------------------------------
stream_management(Config) ->
UserManagement = <<"user-management">>,
UserMonitoring = <<"user-monitoring">>,
Vhost1 = <<"vh1">>,
Vhost2 = <<"vh2">>,
rabbit_ct_broker_helpers:add_user(Config, UserManagement),
rabbit_ct_broker_helpers:set_user_tags(Config, 0, UserManagement, [management]),
rabbit_ct_broker_helpers:add_user(Config, UserMonitoring),
rabbit_ct_broker_helpers:set_user_tags(Config, 0, UserMonitoring, [monitoring]),
rabbit_ct_broker_helpers:add_vhost(Config, Vhost1),
rabbit_ct_broker_helpers:add_vhost(Config, Vhost2),
rabbit_ct_broker_helpers:set_full_permissions(Config, UserManagement, Vhost1),
rabbit_ct_broker_helpers:set_full_permissions(Config, UserMonitoring, Vhost1),
rabbit_ct_broker_helpers:set_full_permissions(Config, <<"guest">>, Vhost1),
rabbit_ct_broker_helpers:set_full_permissions(Config, <<"guest">>, Vhost2),
StreamPortNode = get_stream_port(Config),
ManagementPortNode = get_management_port(Config),
DataDir = rabbit_ct_helpers:get_config(Config, data_dir),

View File

@ -0,0 +1,3 @@
/build/
/lib/
/target/

View File

@ -29,12 +29,14 @@ import com.rabbitmq.stream.impl.Client.ClientParameters;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.UUID;
import java.util.concurrent.Callable;
import java.util.concurrent.atomic.AtomicBoolean;
import okhttp3.Credentials;
import java.util.stream.Collectors;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
@ -45,25 +47,23 @@ import org.junit.jupiter.api.extension.ExtendWith;
@ExtendWith(TestUtils.StreamTestInfrastructureExtension.class)
public class HttpTest {
static OkHttpClient httpClient =
new OkHttpClient.Builder()
.authenticator(
(route, response) ->
response
.request()
.newBuilder()
.header("Authorization", Credentials.basic("guest", "guest"))
.build())
.build();
static OkHttpClient httpClient = httpClient("guest");
static Gson gson = new GsonBuilder().create();
ClientFactory cf;
static OkHttpClient httpClient(String usernamePassword) {
return new OkHttpClient.Builder()
.authenticator(TestUtils.authenticator(usernamePassword))
.build();
}
static String get(String endpoint) throws IOException {
Request request =
new Request.Builder()
.url("http://localhost:" + TestUtils.managementPort() + "/api" + endpoint)
.build();
try (Response response = httpClient.newCall(request).execute()) {
return get(httpClient, endpoint);
}
static String get(OkHttpClient client, String endpoint) throws IOException {
Request request = new Request.Builder().url(url(endpoint)).build();
try (Response response = client.newCall(request).execute()) {
if (!response.isSuccessful()) throw new IOException("Unexpected code " + response);
String body = response.body().string();
@ -71,12 +71,12 @@ public class HttpTest {
}
}
static String url(String endpoint) {
return "http://localhost:" + TestUtils.managementPort() + "/api" + endpoint;
}
static void delete(String endpoint) throws IOException {
Request request =
new Request.Builder()
.delete()
.url("http://localhost:" + TestUtils.managementPort() + "/api" + endpoint)
.build();
Request request = new Request.Builder().delete().url(url(endpoint)).build();
try (Response response = httpClient.newCall(request).execute()) {
if (!response.isSuccessful()) throw new IOException("Unexpected code " + response);
}
@ -91,7 +91,7 @@ public class HttpTest {
}
@Test
void http() throws Exception {
void connections() throws Exception {
Callable<List<Map<String, Object>>> request = () -> toMaps(get("/stream/connections"));
int initialCount = request.call().size();
String connectionProvidedName = UUID.randomUUID().toString();
@ -134,5 +134,120 @@ public class HttpTest {
waitUntil(() -> closed.get());
assertThatThrownBy(() -> cRequest.call()).isInstanceOf(IOException.class);
waitUntil(() -> request.call().size() == initialCount);
}
@Test
void permissions() throws Exception {
String[][] vhostsUsers =
new String[][] {
{"/", "guest"},
{"vh1", "user-management"},
{"vh1", "user-management"},
{"vh2", "guest"},
{"vh2", "guest"},
};
List<Client> clients =
Arrays.stream(vhostsUsers)
.map(
vhostUser ->
cf.get(
new ClientParameters()
.virtualHost(vhostUser[0])
.username(vhostUser[1])
.password(vhostUser[1])))
.collect(Collectors.toList());
Callable<List<Map<String, Object>>> allConnectionsRequest =
() -> toMaps(get("/stream/connections"));
int initialCount = allConnectionsRequest.call().size();
waitUntil(() -> allConnectionsRequest.call().size() == initialCount + 5);
String vhost1ConnectionName =
toMaps(get("/stream/connections/vh1")).stream()
.filter(c -> "vh1".equals(c.get("vhost")))
.map(c -> c.get("name").toString())
.findFirst()
.get();
String vhost2ConnectionName =
toMaps(get("/stream/connections/vh2")).stream()
.filter(c -> "vh2".equals(c.get("vhost")))
.map(c -> c.get("name").toString())
.findFirst()
.get();
class TestConfiguration {
final String user;
final Map<String, Integer> connectionRequests;
final Map<String, Boolean> vhostConnections;
TestConfiguration(String user, Object[] connectionRequests, Object... vhostConnections) {
this.user = user;
this.connectionRequests = new LinkedHashMap<>(connectionRequests.length / 2);
for (int i = 0; i < connectionRequests.length; i = i + 2) {
this.connectionRequests.put(
connectionRequests[i].toString(), (Integer) connectionRequests[i + 1]);
}
this.vhostConnections = new LinkedHashMap<>();
for (int i = 0; i < vhostConnections.length; i = i + 2) {
this.vhostConnections.put(
vhostConnections[i].toString(), (Boolean) vhostConnections[i + 1]);
}
}
}
TestConfiguration[] testConfigurations =
new TestConfiguration[] {
new TestConfiguration(
"guest",
new Object[] {"", 5, "/%2f", 1, "/vh1", 2, "/vh2", 2},
"vh1/" + vhost1ConnectionName,
true,
"vh2/" + vhost2ConnectionName,
true),
new TestConfiguration(
"user-monitoring",
new Object[] {"", 5, "/%2f", 1, "/vh1", 2, "/vh2", 2},
"vh1/" + vhost1ConnectionName,
true,
"vh2/" + vhost2ConnectionName,
true),
new TestConfiguration(
"user-management",
new Object[] {"", 2, "/%2f", -1, "/vh1", 2, "/vh2", -1},
"vh1/" + vhost1ConnectionName,
true,
"vh2/" + vhost2ConnectionName,
false)
};
for (TestConfiguration configuration : testConfigurations) {
OkHttpClient client = httpClient(configuration.user);
for (Entry<String, Integer> request : configuration.connectionRequests.entrySet()) {
if (request.getValue() >= 0) {
System.out.println(request.getKey());
assertThat(toMaps(get(client, "/stream/connections" + request.getKey())))
.hasSize(request.getValue());
} else {
assertThatThrownBy(() -> toMaps(get(client, "/stream/connections" + request.getKey())))
.hasMessageContaining("401");
}
}
for (Entry<String, Boolean> request : configuration.vhostConnections.entrySet()) {
if (request.getValue()) {
Condition<Object> connNameCondition =
new Condition<>(
e -> request.getKey().endsWith(e.toString()), "connection name must match");
assertThat(toMap(get(client, "/stream/connections/" + request.getKey())))
.hasEntrySatisfying("name", connNameCondition);
} else {
assertThatThrownBy(() -> toMap(get(client, "/stream/connections/" + request.getKey())))
.hasMessageContaining("401");
}
}
}
clients.forEach(client -> client.close());
waitUntil(() -> allConnectionsRequest.call().size() == initialCount);
}
}

View File

@ -28,6 +28,8 @@ import java.time.Duration;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import okhttp3.Authenticator;
import okhttp3.Credentials;
import org.junit.jupiter.api.extension.AfterAllCallback;
import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.BeforeAllCallback;
@ -187,4 +189,17 @@ public class TestUtils {
}
}
}
static Authenticator authenticator(String usernamePassword) {
return (route, response) -> {
if (response.request().header("Authorization") != null) {
return null; // Give up, we've already attempted to authenticate.
}
return response
.request()
.newBuilder()
.header("Authorization", Credentials.basic(usernamePassword, usernamePassword))
.build();
};
}
}