Feat_18016:
Moving all group management logic that occurs during user authentication for LDAP, Oauth, and trusted header paths into a single location in order to have an approach that's standard across all three auth paths.
This commit is contained in:
parent
3f71fa641f
commit
0dbb8b7fa8
|
@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
- 🗄️ DISKANN index type support for Milvus vector database with configurable maximum degree and search list size parameters. [#17770](https://github.com/open-webui/open-webui/pull/17770), [Docs:Commit](https://github.com/open-webui/docs/commit/cec50ab4d4b659558ca1ccd4b5e6fc024f05fb83)
|
||||
- 🔄 Various improvements were implemented across the frontend and backend to enhance performance, stability, and security.
|
||||
- 🌐 Translations for Chinese (Simplified & Traditional) and Bosnian (Latin) were enhanced and expanded.
|
||||
- 🔧 Standarized the authentication group management for OAuth, LDAP and Trusted Header authentication. Added ability to create groups via trusted header authenticaiton. [#18016](https://github.com/open-webui/open-webui/discussions/18016)
|
||||
|
||||
### Fixed
|
||||
|
||||
|
|
|
@ -545,19 +545,21 @@ ENABLE_OAUTH_ROLE_MANAGEMENT = PersistentConfig(
|
|||
os.environ.get("ENABLE_OAUTH_ROLE_MANAGEMENT", "False").lower() == "true",
|
||||
)
|
||||
|
||||
## Deprecated use ENABLE_GROUP_MANAGEMENT
|
||||
ENABLE_OAUTH_GROUP_MANAGEMENT = PersistentConfig(
|
||||
"ENABLE_OAUTH_GROUP_MANAGEMENT",
|
||||
"oauth.enable_group_mapping",
|
||||
os.environ.get("ENABLE_OAUTH_GROUP_MANAGEMENT", "False").lower() == "true",
|
||||
)
|
||||
|
||||
## Deprecated use ENALBE_GROUP_CREATION
|
||||
ENABLE_OAUTH_GROUP_CREATION = PersistentConfig(
|
||||
"ENABLE_OAUTH_GROUP_CREATION",
|
||||
"oauth.enable_group_creation",
|
||||
os.environ.get("ENABLE_OAUTH_GROUP_CREATION", "False").lower() == "true",
|
||||
)
|
||||
|
||||
|
||||
## Deprecated use BLOCKED_GROUPS value instead
|
||||
OAUTH_BLOCKED_GROUPS = PersistentConfig(
|
||||
"OAUTH_BLOCKED_GROUPS",
|
||||
"oauth.blocked_groups",
|
||||
|
@ -3526,12 +3528,14 @@ LDAP_CIPHERS = PersistentConfig(
|
|||
)
|
||||
|
||||
# For LDAP Group Management
|
||||
## Deprecated use ENABLE_GROUP_MANAGEMENT
|
||||
ENABLE_LDAP_GROUP_MANAGEMENT = PersistentConfig(
|
||||
"ENABLE_LDAP_GROUP_MANAGEMENT",
|
||||
"ldap.group.enable_management",
|
||||
os.environ.get("ENABLE_LDAP_GROUP_MANAGEMENT", "False").lower() == "true",
|
||||
)
|
||||
|
||||
## Deprecated use ENABLE_GROUP_CREATION
|
||||
ENABLE_LDAP_GROUP_CREATION = PersistentConfig(
|
||||
"ENABLE_LDAP_GROUP_CREATION",
|
||||
"ldap.group.enable_creation",
|
||||
|
@ -3543,3 +3547,23 @@ LDAP_ATTRIBUTE_FOR_GROUPS = PersistentConfig(
|
|||
"ldap.server.attribute_for_groups",
|
||||
os.environ.get("LDAP_ATTRIBUTE_FOR_GROUPS", "memberOf"),
|
||||
)
|
||||
|
||||
# Generic Group Management
|
||||
ENABLE_GROUP_MANAGEMENT = PersistentConfig(
|
||||
"ENABLE_GROUP_MANAGEMENT",
|
||||
"auth.enable_group_management",
|
||||
os.environ.get("ENABLE_GROUP_MANAGEMENT", "False").lower() == "true",
|
||||
)
|
||||
|
||||
ENABLE_GROUP_CREATION = PersistentConfig(
|
||||
"ENABLE_GROUP_CREATION",
|
||||
"auth.enable_group_creation",
|
||||
os.environ.get("ENABLE_GROUP_CREATION", "False").lower() == "true",
|
||||
)
|
||||
|
||||
BLOCKED_GROUPS = PersistentConfig(
|
||||
"BLOCKED_GROUPS",
|
||||
"auth.blocked_groups",
|
||||
os.environ.get("BLOCKED_GROUPS", "[]"),
|
||||
)
|
||||
|
||||
|
|
|
@ -482,6 +482,7 @@ from open_webui.utils.oauth import (
|
|||
decrypt_data,
|
||||
OAuthClientInformationFull,
|
||||
)
|
||||
from open_webui.utils.group import GroupManager
|
||||
from open_webui.utils.security_headers import SecurityHeadersMiddleware
|
||||
from open_webui.utils.redis import get_redis_connection
|
||||
|
||||
|
@ -619,6 +620,9 @@ app.state.oauth_manager = oauth_manager
|
|||
oauth_client_manager = OAuthClientManager(app)
|
||||
app.state.oauth_client_manager = oauth_client_manager
|
||||
|
||||
group_manager = GroupManager()
|
||||
app.state.group_manager = group_manager
|
||||
|
||||
app.state.instance_id = None
|
||||
app.state.config = AppConfig(
|
||||
redis_url=REDIS_URL,
|
||||
|
|
|
@ -425,16 +425,9 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
|||
and ENABLE_LDAP_GROUP_MANAGEMENT
|
||||
and user_groups
|
||||
):
|
||||
if ENABLE_LDAP_GROUP_CREATION:
|
||||
Groups.create_groups_by_group_names(user.id, user_groups)
|
||||
|
||||
try:
|
||||
Groups.sync_groups_by_group_names(user.id, user_groups)
|
||||
log.info(
|
||||
f"Successfully synced groups for user {user.id}: {user_groups}"
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Failed to sync groups for user {user.id}: {e}")
|
||||
request.app.state.group_manager.sync_user_groups(given_groups=user_groups, user=user,
|
||||
default_permissions=request.app.state.config.USER_PERMISSIONS,
|
||||
enable_group_creation=ENABLE_LDAP_GROUP_CREATION)
|
||||
|
||||
return {
|
||||
"token": token,
|
||||
|
@ -488,7 +481,8 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
|||
group_names = [name.strip() for name in group_names if name.strip()]
|
||||
|
||||
if group_names:
|
||||
Groups.sync_groups_by_group_names(user.id, group_names)
|
||||
request.app.state.group_manager.sync_user_groups(given_groups=group_names, user=user,
|
||||
default_permissions=request.app.state.config.USER_PERMISSIONS)
|
||||
|
||||
elif WEBUI_AUTH == False:
|
||||
admin_email = "admin@localhost"
|
||||
|
|
|
@ -0,0 +1,440 @@
|
|||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
import json
|
||||
|
||||
from open_webui.utils.group import (
|
||||
BlockedGroupMatcher,
|
||||
GroupManager,
|
||||
is_in_blocked_groups,
|
||||
)
|
||||
from open_webui.models.groups import GroupModel, GroupForm, GroupUpdateForm
|
||||
|
||||
|
||||
class TestBlockedGroupMatcher:
|
||||
def test_exact_match(self):
|
||||
matcher = BlockedGroupMatcher(["admin", "root", "superuser"])
|
||||
|
||||
assert matcher.is_blocked("admin") is True
|
||||
assert matcher.is_blocked("root") is True
|
||||
assert matcher.is_blocked("superuser") is True
|
||||
assert matcher.is_blocked("user") is False
|
||||
assert matcher.is_blocked("Admin") is False
|
||||
|
||||
def test_wildcard_match(self):
|
||||
matcher = BlockedGroupMatcher(["test-*", "dev_*", "?temp"])
|
||||
|
||||
assert matcher.is_blocked("test-group") is True
|
||||
assert matcher.is_blocked("test-123") is True
|
||||
assert matcher.is_blocked("dev_team") is True
|
||||
assert matcher.is_blocked("atemp") is True
|
||||
assert matcher.is_blocked("1temp") is True
|
||||
assert matcher.is_blocked("temp") is False
|
||||
assert matcher.is_blocked("test") is False
|
||||
|
||||
def test_regex_match(self):
|
||||
matcher = BlockedGroupMatcher([r"^admin-\d+$", r"test\[.*\]"])
|
||||
|
||||
assert matcher.is_blocked("admin-123") is True
|
||||
assert matcher.is_blocked("admin-001") is True
|
||||
assert matcher.is_blocked("test[beta]") is True
|
||||
assert matcher.is_blocked("admin-") is False
|
||||
assert matcher.is_blocked("admin-abc") is False
|
||||
|
||||
def test_invalid_regex_fallback(self):
|
||||
matcher = BlockedGroupMatcher(["test[invalid", "valid-*"])
|
||||
|
||||
assert matcher.is_blocked("test[invalid") is False
|
||||
assert matcher.is_blocked("valid-group") is True
|
||||
|
||||
def test_empty_patterns(self):
|
||||
matcher = BlockedGroupMatcher([])
|
||||
|
||||
assert matcher.is_blocked("any-group") is False
|
||||
|
||||
def test_none_and_empty_string_patterns(self):
|
||||
matcher = BlockedGroupMatcher(["", None, "valid"])
|
||||
|
||||
assert matcher.is_blocked("valid") is True
|
||||
assert matcher.is_blocked("") is False
|
||||
|
||||
def test_mixed_patterns(self):
|
||||
matcher = BlockedGroupMatcher([
|
||||
"exact-match",
|
||||
"wildcard-*",
|
||||
r"^regex-\d+$",
|
||||
])
|
||||
|
||||
assert matcher.is_blocked("exact-match") is True
|
||||
assert matcher.is_blocked("wildcard-test") is True
|
||||
assert matcher.is_blocked("regex-123") is True
|
||||
assert matcher.is_blocked("other") is False
|
||||
|
||||
def test_is_regex_pattern(self):
|
||||
assert BlockedGroupMatcher._is_regex_pattern("^test$") is True
|
||||
assert BlockedGroupMatcher._is_regex_pattern("test[abc]") is True
|
||||
assert BlockedGroupMatcher._is_regex_pattern("test(a|b)") is True
|
||||
assert BlockedGroupMatcher._is_regex_pattern("test{2,3}") is True
|
||||
assert BlockedGroupMatcher._is_regex_pattern("test+") is True
|
||||
assert BlockedGroupMatcher._is_regex_pattern("test\\d") is True
|
||||
|
||||
assert BlockedGroupMatcher._is_regex_pattern("test-*") is False
|
||||
assert BlockedGroupMatcher._is_regex_pattern("simple") is False
|
||||
|
||||
|
||||
class TestGroupManager:
|
||||
@pytest.fixture
|
||||
def manager(self):
|
||||
return GroupManager()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user(self):
|
||||
user = Mock()
|
||||
user.id = "user-123"
|
||||
user.name = "Test User"
|
||||
return user
|
||||
|
||||
@pytest.fixture
|
||||
def mock_group(self):
|
||||
group = Mock(spec=GroupModel)
|
||||
group.id = "group-1"
|
||||
group.name = "test-group"
|
||||
group.description = "Test Group"
|
||||
group.user_ids = []
|
||||
group.permissions = {"read": True}
|
||||
return group
|
||||
|
||||
@patch("open_webui.utils.group.auth_manager_config")
|
||||
def test_parse_blocked_groups_config_empty(self, mock_config, manager):
|
||||
mock_config.BLOCKED_GROUPS = "[]"
|
||||
mock_config.OAUTH_BLOCKED_GROUPS = "[]"
|
||||
|
||||
result = manager._parse_blocked_groups_config()
|
||||
|
||||
assert isinstance(result, BlockedGroupMatcher)
|
||||
assert result.is_blocked("any-group") is False
|
||||
|
||||
@patch("open_webui.utils.group.auth_manager_config")
|
||||
def test_parse_blocked_groups_config_valid(self, mock_config, manager):
|
||||
mock_config.BLOCKED_GROUPS = '["admin", "root"]'
|
||||
mock_config.OAUTH_BLOCKED_GROUPS = '["oauth-blocked"]'
|
||||
|
||||
result = manager._parse_blocked_groups_config()
|
||||
|
||||
assert isinstance(result, BlockedGroupMatcher)
|
||||
assert result.is_blocked("admin") is True
|
||||
assert result.is_blocked("root") is True
|
||||
assert result.is_blocked("oauth-blocked") is True
|
||||
|
||||
@patch("open_webui.utils.group.auth_manager_config")
|
||||
def test_parse_blocked_groups_config_invalid_json(self, mock_config, manager):
|
||||
mock_config.BLOCKED_GROUPS = "invalid json"
|
||||
mock_config.OAUTH_BLOCKED_GROUPS = "[]"
|
||||
|
||||
result = manager._parse_blocked_groups_config()
|
||||
|
||||
assert isinstance(result, BlockedGroupMatcher)
|
||||
assert result.is_blocked("any-group") is False
|
||||
|
||||
@patch("open_webui.utils.group.auth_manager_config")
|
||||
def test_parse_blocked_groups_config_wrong_type(self, mock_config, manager):
|
||||
mock_config.BLOCKED_GROUPS = '{"not": "a list"}'
|
||||
mock_config.OAUTH_BLOCKED_GROUPS = "[]"
|
||||
|
||||
result = manager._parse_blocked_groups_config()
|
||||
|
||||
assert isinstance(result, BlockedGroupMatcher)
|
||||
|
||||
@patch("open_webui.utils.group.Users.get_super_admin_user")
|
||||
def test_determine_creator_id_admin_exists(self, mock_get_admin, manager, mock_user):
|
||||
admin = Mock()
|
||||
admin.id = "admin-456"
|
||||
mock_get_admin.return_value = admin
|
||||
|
||||
result = manager._determine_creator_id(mock_user)
|
||||
|
||||
assert result == "admin-456"
|
||||
mock_get_admin.assert_called_once()
|
||||
|
||||
@patch("open_webui.utils.group.Users.get_super_admin_user")
|
||||
def test_determine_creator_id_no_admin(self, mock_get_admin, manager, mock_user):
|
||||
mock_get_admin.return_value = None
|
||||
|
||||
result = manager._determine_creator_id(mock_user)
|
||||
|
||||
assert result == "user-123"
|
||||
|
||||
@patch("open_webui.utils.group.Groups.insert_new_group")
|
||||
def test_create_group_success(self, mock_insert, manager):
|
||||
created_group = Mock()
|
||||
created_group.id = "new-group-id"
|
||||
mock_insert.return_value = created_group
|
||||
|
||||
result = manager._create_group("new-group", "creator-id", {"read": True})
|
||||
|
||||
assert result is True
|
||||
mock_insert.assert_called_once()
|
||||
call_args = mock_insert.call_args
|
||||
assert call_args[0][0] == "creator-id"
|
||||
assert isinstance(call_args[0][1], GroupForm)
|
||||
|
||||
@patch("open_webui.utils.group.Groups.insert_new_group")
|
||||
def test_create_group_failure(self, mock_insert, manager):
|
||||
mock_insert.return_value = None
|
||||
|
||||
result = manager._create_group("new-group", "creator-id", {"read": True})
|
||||
|
||||
assert result is False
|
||||
|
||||
@patch("open_webui.utils.group.Groups.insert_new_group")
|
||||
def test_create_group_exception(self, mock_insert, manager):
|
||||
mock_insert.side_effect = Exception("Database error")
|
||||
|
||||
result = manager._create_group("new-group", "creator-id", {"read": True})
|
||||
|
||||
assert result is False
|
||||
|
||||
@patch("open_webui.utils.group.Groups.get_groups")
|
||||
def test_ensure_groups_exist_creation_disabled(
|
||||
self, mock_get_groups, manager, mock_user
|
||||
):
|
||||
existing_groups = [Mock(name="existing")]
|
||||
mock_get_groups.return_value = existing_groups
|
||||
|
||||
result = manager._ensure_groups_exist(False, ["new-group"], {}, mock_user)
|
||||
|
||||
assert result == existing_groups
|
||||
assert mock_get_groups.call_count == 1
|
||||
|
||||
@patch("open_webui.utils.group.Groups.get_groups")
|
||||
@patch("open_webui.utils.group.Users.get_super_admin_user")
|
||||
@patch("open_webui.utils.group.Groups.insert_new_group")
|
||||
def test_ensure_groups_exist_creates_missing(
|
||||
self, mock_insert, mock_get_admin, mock_get_groups, manager, mock_user
|
||||
):
|
||||
existing = Mock()
|
||||
existing.name = "existing-group"
|
||||
mock_get_groups.side_effect = [[existing], [existing, Mock(name="new-group")]]
|
||||
|
||||
admin = Mock()
|
||||
admin.id = "admin-id"
|
||||
mock_get_admin.return_value = admin
|
||||
|
||||
created_group = Mock()
|
||||
created_group.id = "new-id"
|
||||
mock_insert.return_value = created_group
|
||||
|
||||
result = manager._ensure_groups_exist(
|
||||
True, ["existing-group", "new-group"], {"read": True}, mock_user
|
||||
)
|
||||
|
||||
assert mock_insert.call_count == 1
|
||||
assert mock_get_groups.call_count == 2
|
||||
|
||||
@patch("open_webui.utils.group.Groups.update_group_by_id")
|
||||
def test_update_group_membership(self, mock_update, manager, mock_group):
|
||||
manager._update_group_membership(mock_group, ["user-1", "user-2"], {"read": True})
|
||||
|
||||
mock_update.assert_called_once()
|
||||
call_args = mock_update.call_args
|
||||
assert call_args[1]["id"] == "group-1"
|
||||
assert isinstance(call_args[1]["form_data"], GroupUpdateForm)
|
||||
|
||||
@patch("open_webui.utils.group.Groups.update_group_by_id")
|
||||
def test_update_group_membership_no_permissions(
|
||||
self, mock_update, manager, mock_group
|
||||
):
|
||||
mock_group.permissions = None
|
||||
|
||||
manager._update_group_membership(mock_group, ["user-1"], {"default": True})
|
||||
|
||||
mock_update.assert_called_once()
|
||||
|
||||
def test_sync_group_memberships_no_given_groups(self, manager, mock_user, mock_group):
|
||||
manager._sync_group_memberships(
|
||||
mock_user,
|
||||
[mock_group],
|
||||
[mock_group],
|
||||
[],
|
||||
BlockedGroupMatcher([]),
|
||||
{},
|
||||
)
|
||||
|
||||
@patch("open_webui.utils.group.Groups.update_group_by_id")
|
||||
def test_sync_group_memberships_remove_user(
|
||||
self, mock_update, manager, mock_user, mock_group
|
||||
):
|
||||
mock_group.user_ids = ["user-123", "other-user"]
|
||||
current_groups = [mock_group]
|
||||
available_groups = [mock_group]
|
||||
given_groups = []
|
||||
|
||||
manager._sync_group_memberships(
|
||||
mock_user,
|
||||
current_groups,
|
||||
available_groups,
|
||||
["other-group"],
|
||||
BlockedGroupMatcher([]),
|
||||
{"read": True},
|
||||
)
|
||||
|
||||
mock_update.assert_called_once()
|
||||
|
||||
@patch("open_webui.utils.group.Groups.update_group_by_id")
|
||||
def test_sync_group_memberships_add_user(
|
||||
self, mock_update, manager, mock_user, mock_group
|
||||
):
|
||||
mock_group.user_ids = ["other-user"]
|
||||
current_groups = []
|
||||
available_groups = [mock_group]
|
||||
given_groups = ["test-group"]
|
||||
|
||||
manager._sync_group_memberships(
|
||||
mock_user,
|
||||
current_groups,
|
||||
available_groups,
|
||||
given_groups,
|
||||
BlockedGroupMatcher([]),
|
||||
{"read": True},
|
||||
)
|
||||
|
||||
mock_update.assert_called_once()
|
||||
|
||||
@patch("open_webui.utils.group.Groups.update_group_by_id")
|
||||
def test_sync_group_memberships_blocked_group(
|
||||
self, mock_update, manager, mock_user, mock_group
|
||||
):
|
||||
mock_group.user_ids = ["other-user"]
|
||||
current_groups = []
|
||||
available_groups = [mock_group]
|
||||
given_groups = ["test-group"]
|
||||
blocked_matcher = BlockedGroupMatcher(["test-group"])
|
||||
|
||||
manager._sync_group_memberships(
|
||||
mock_user,
|
||||
current_groups,
|
||||
available_groups,
|
||||
given_groups,
|
||||
blocked_matcher,
|
||||
{"read": True},
|
||||
)
|
||||
|
||||
mock_update.assert_not_called()
|
||||
|
||||
@patch("open_webui.utils.group.auth_manager_config")
|
||||
@patch("open_webui.utils.group.Groups.get_groups_by_member_id")
|
||||
@patch("open_webui.utils.group.Groups.get_groups")
|
||||
def test_sync_user_groups_integration(
|
||||
self, mock_get_groups, mock_get_member_groups, mock_config, manager, mock_user, mock_group
|
||||
):
|
||||
mock_config.ENABLE_GROUP_MANAGEMENT = True
|
||||
mock_config.BLOCKED_GROUPS = "[]"
|
||||
mock_config.OAUTH_BLOCKED_GROUPS = "[]"
|
||||
mock_config.ENABLE_GROUP_CREATION = False
|
||||
|
||||
mock_get_member_groups.return_value = []
|
||||
mock_get_groups.return_value = [mock_group]
|
||||
|
||||
manager.sync_user_groups(["test-group"], mock_user, {"read": True}, False)
|
||||
|
||||
mock_get_member_groups.assert_called_once_with(mock_user.id)
|
||||
mock_get_groups.assert_called()
|
||||
|
||||
@patch("open_webui.utils.group.auth_manager_config")
|
||||
@patch("open_webui.utils.group.Groups.get_groups_by_member_id")
|
||||
def test_sync_user_groups_disabled(
|
||||
self, mock_get_member_groups, mock_config, manager, mock_user
|
||||
):
|
||||
mock_config.ENABLE_GROUP_MANAGEMENT = False
|
||||
|
||||
manager.sync_user_groups(["test-group"], mock_user, {"read": True})
|
||||
|
||||
mock_get_member_groups.assert_not_called()
|
||||
|
||||
def test_log_group_sync_status(self, manager, mock_user, mock_group):
|
||||
group1 = Mock(spec=GroupModel)
|
||||
group1.name = "group1"
|
||||
group2 = Mock(spec=GroupModel)
|
||||
group2.name = "group2"
|
||||
|
||||
manager._log_group_sync_status(
|
||||
["external-group1", "external-group2"],
|
||||
[group1],
|
||||
[group1, group2],
|
||||
)
|
||||
|
||||
@patch("open_webui.utils.group.Groups.update_group_by_id")
|
||||
def test_sync_group_memberships_user_already_in_group(
|
||||
self, mock_update, manager, mock_user, mock_group
|
||||
):
|
||||
mock_group.user_ids = ["user-123"]
|
||||
current_groups = [mock_group]
|
||||
available_groups = [mock_group]
|
||||
given_groups = ["test-group"]
|
||||
|
||||
manager._sync_group_memberships(
|
||||
mock_user,
|
||||
current_groups,
|
||||
available_groups,
|
||||
given_groups,
|
||||
BlockedGroupMatcher([]),
|
||||
{"read": True},
|
||||
)
|
||||
|
||||
mock_update.assert_not_called()
|
||||
|
||||
@patch("open_webui.utils.group.Groups.update_group_by_id")
|
||||
def test_sync_group_memberships_group_not_available(
|
||||
self, mock_update, manager, mock_user, mock_group
|
||||
):
|
||||
current_groups = []
|
||||
available_groups = []
|
||||
given_groups = ["non-existent-group"]
|
||||
|
||||
manager._sync_group_memberships(
|
||||
mock_user,
|
||||
current_groups,
|
||||
available_groups,
|
||||
given_groups,
|
||||
BlockedGroupMatcher([]),
|
||||
{"read": True},
|
||||
)
|
||||
|
||||
mock_update.assert_not_called()
|
||||
|
||||
@patch("open_webui.utils.group.Groups.update_group_by_id")
|
||||
def test_sync_group_memberships_remove_blocked_group(
|
||||
self, mock_update, manager, mock_user, mock_group
|
||||
):
|
||||
mock_group.user_ids = ["user-123"]
|
||||
mock_group.name = "blocked-group"
|
||||
current_groups = [mock_group]
|
||||
available_groups = [mock_group]
|
||||
given_groups = ["other-group"]
|
||||
blocked_matcher = BlockedGroupMatcher(["blocked-group"])
|
||||
|
||||
manager._sync_group_memberships(
|
||||
mock_user,
|
||||
current_groups,
|
||||
available_groups,
|
||||
given_groups,
|
||||
blocked_matcher,
|
||||
{"read": True},
|
||||
)
|
||||
|
||||
mock_update.assert_not_called()
|
||||
|
||||
|
||||
class TestIsInBlockedGroups:
|
||||
def test_backward_compatibility(self):
|
||||
groups = ["admin", "test-*", r"^dev-\d+$"]
|
||||
|
||||
assert is_in_blocked_groups("admin", groups) is True
|
||||
assert is_in_blocked_groups("test-123", groups) is True
|
||||
assert is_in_blocked_groups("dev-456", groups) is True
|
||||
assert is_in_blocked_groups("user", groups) is False
|
||||
|
||||
def test_empty_list(self):
|
||||
assert is_in_blocked_groups("any-group", []) is False
|
||||
|
||||
def test_none_list(self):
|
||||
assert is_in_blocked_groups("any-group", None) is False
|
|
@ -0,0 +1,293 @@
|
|||
import logging
|
||||
import sys
|
||||
import json
|
||||
import re
|
||||
import fnmatch
|
||||
|
||||
from open_webui.models.users import Users
|
||||
from open_webui.models.groups import Groups, GroupModel, GroupUpdateForm, GroupForm
|
||||
from open_webui.config import (
|
||||
ENABLE_GROUP_MANAGEMENT,
|
||||
ENABLE_GROUP_CREATION,
|
||||
BLOCKED_GROUPS,
|
||||
OAUTH_BLOCKED_GROUPS,
|
||||
AppConfig,
|
||||
)
|
||||
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
auth_manager_config = AppConfig()
|
||||
auth_manager_config.ENABLE_GROUP_MANAGEMENT = ENABLE_GROUP_MANAGEMENT
|
||||
auth_manager_config.ENABLE_GROUP_CREATION = ENABLE_GROUP_CREATION
|
||||
auth_manager_config.BLOCKED_GROUPS = BLOCKED_GROUPS
|
||||
auth_manager_config.OAUTH_BLOCKED_GROUPS = OAUTH_BLOCKED_GROUPS
|
||||
|
||||
|
||||
class BlockedGroupMatcher:
|
||||
"""Compiled matcher for blocked group patterns with optimized matching."""
|
||||
|
||||
def __init__(self, patterns: list[str]):
|
||||
self.exact_matches = set()
|
||||
self.wildcards = []
|
||||
self.regexes = []
|
||||
|
||||
for pattern in patterns:
|
||||
if not pattern:
|
||||
continue
|
||||
|
||||
if self._is_regex_pattern(pattern):
|
||||
try:
|
||||
self.regexes.append(re.compile(pattern))
|
||||
except re.error as e:
|
||||
log.warning(f"Invalid regex pattern '{pattern}': {e}")
|
||||
elif "*" in pattern or "?" in pattern:
|
||||
self.wildcards.append(pattern)
|
||||
else:
|
||||
self.exact_matches.add(pattern)
|
||||
|
||||
@staticmethod
|
||||
def _is_regex_pattern(pattern: str) -> bool:
|
||||
"""Check if pattern contains regex-specific characters."""
|
||||
return any(
|
||||
c in pattern
|
||||
for c in ["^", "$", "[", "]", "(", ")", "{", "}", "+", "\\", "|"]
|
||||
)
|
||||
|
||||
def is_blocked(self, group_name: str) -> bool:
|
||||
"""Check if group name matches any blocked pattern."""
|
||||
if group_name in self.exact_matches:
|
||||
return True
|
||||
|
||||
for pattern in self.wildcards:
|
||||
if fnmatch.fnmatch(group_name, pattern):
|
||||
return True
|
||||
|
||||
for regex in self.regexes:
|
||||
if regex.search(group_name):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class GroupManager:
|
||||
"""Manages group synchronization for users with OAuth/SCIM integration."""
|
||||
|
||||
def sync_user_groups(
|
||||
self,
|
||||
given_groups: list[str],
|
||||
user,
|
||||
default_permissions: dict,
|
||||
enable_group_creation = auth_manager_config.ENABLE_GROUP_CREATION,
|
||||
) -> None:
|
||||
"""
|
||||
Synchronize user's group memberships based on external groups.
|
||||
|
||||
Args:
|
||||
given_groups: List of group names from external auth provider
|
||||
user: User object to synchronize groups for
|
||||
default_permissions: Default permissions for newly created groups
|
||||
enable_group_creation: if false given_groups that do not already exist will not be created
|
||||
"""
|
||||
if auth_manager_config.ENABLE_GROUP_MANAGEMENT:
|
||||
log.debug("Running Group management")
|
||||
|
||||
blocked_matcher = self._parse_blocked_groups_config()
|
||||
|
||||
user_current_groups = Groups.get_groups_by_member_id(user.id)
|
||||
all_available_groups = self._ensure_groups_exist(
|
||||
enable_group_creation, given_groups, default_permissions, user
|
||||
)
|
||||
|
||||
self._log_group_sync_status(given_groups, user_current_groups, all_available_groups)
|
||||
|
||||
self._sync_group_memberships(
|
||||
user,
|
||||
user_current_groups,
|
||||
all_available_groups,
|
||||
given_groups,
|
||||
blocked_matcher,
|
||||
default_permissions,
|
||||
)
|
||||
|
||||
def _parse_blocked_groups_config(self) -> BlockedGroupMatcher:
|
||||
"""Parse and combine blocked groups from configuration."""
|
||||
blocked_groups = []
|
||||
|
||||
try:
|
||||
parsed = json.loads(auth_manager_config.BLOCKED_GROUPS) if auth_manager_config.BLOCKED_GROUPS else []
|
||||
if isinstance(parsed, list):
|
||||
blocked_groups = parsed
|
||||
else:
|
||||
log.warning(f"BLOCKED_GROUPS is not a list: {type(parsed)}")
|
||||
except json.JSONDecodeError as e:
|
||||
log.error(f"Invalid JSON in BLOCKED_GROUPS: {e}")
|
||||
|
||||
# Support the legacy OAUTH_BLOCKED_GROUPS value. This should be phased out for the generic BLOCKED_GROUPS value
|
||||
try:
|
||||
oauth_blocked = json.loads(auth_manager_config.OAUTH_BLOCKED_GROUPS) if auth_manager_config.OAUTH_BLOCKED_GROUPS else []
|
||||
if isinstance(oauth_blocked, list):
|
||||
blocked_groups.extend(oauth_blocked)
|
||||
except json.JSONDecodeError as e:
|
||||
log.error(f"Invalid JSON in OAUTH_BLOCKED_GROUPS: {e}")
|
||||
|
||||
return BlockedGroupMatcher(blocked_groups)
|
||||
|
||||
def _ensure_groups_exist(
|
||||
self,
|
||||
enable_group_creation: bool,
|
||||
given_groups: list[str],
|
||||
default_permissions: dict,
|
||||
user
|
||||
) -> list[GroupModel]:
|
||||
"""Create missing groups if creation is enabled."""
|
||||
if not enable_group_creation:
|
||||
return Groups.get_groups()
|
||||
|
||||
all_available_groups = Groups.get_groups()
|
||||
existing_names = {g.name for g in all_available_groups}
|
||||
|
||||
log.debug("Checking for missing groups to create...")
|
||||
creator_id = self._determine_creator_id(user)
|
||||
groups_created = False
|
||||
|
||||
for group_name in given_groups:
|
||||
if group_name not in existing_names:
|
||||
if self._create_group(group_name, creator_id, default_permissions):
|
||||
groups_created = True
|
||||
existing_names.add(group_name)
|
||||
|
||||
if groups_created:
|
||||
all_available_groups = Groups.get_groups()
|
||||
log.debug("Refreshed list of all available groups after creation.")
|
||||
|
||||
return all_available_groups
|
||||
|
||||
def _determine_creator_id(self, fallback_user) -> str:
|
||||
"""Get admin user ID or fallback to provided user."""
|
||||
admin_user = Users.get_super_admin_user()
|
||||
creator_id = admin_user.id if admin_user else fallback_user.id
|
||||
log.debug(f"Using creator ID {creator_id} for potential group creation")
|
||||
return creator_id
|
||||
|
||||
def _create_group(
|
||||
self, group_name: str, creator_id: str, default_permissions: dict
|
||||
) -> bool:
|
||||
"""Create a single group."""
|
||||
log.info(f"Group '{group_name}' not found. Creating group...")
|
||||
try:
|
||||
new_group_form = GroupForm(
|
||||
name=group_name,
|
||||
description=f"Group '{group_name}' created automatically.",
|
||||
permissions=default_permissions,
|
||||
user_ids=[],
|
||||
)
|
||||
created_group = Groups.insert_new_group(creator_id, new_group_form)
|
||||
if created_group:
|
||||
log.info(
|
||||
f"Successfully created group '{group_name}' with ID {created_group.id}"
|
||||
)
|
||||
return True
|
||||
else:
|
||||
log.error(f"Failed to create group '{group_name}'")
|
||||
return False
|
||||
except Exception as e:
|
||||
log.error(f"Error creating group '{group_name}': {e}")
|
||||
return False
|
||||
|
||||
def _log_group_sync_status(
|
||||
self,
|
||||
given_groups: list[str],
|
||||
user_current_groups: list[GroupModel],
|
||||
all_available_groups: list[GroupModel],
|
||||
) -> None:
|
||||
"""Log current state of group synchronization."""
|
||||
log.debug(f"Given user groups: {given_groups}")
|
||||
log.debug(f"User's current groups: {[g.name for g in user_current_groups]}")
|
||||
log.debug(
|
||||
f"All groups available in OpenWebUI: {[g.name for g in all_available_groups]}"
|
||||
)
|
||||
|
||||
def _sync_group_memberships(
|
||||
self,
|
||||
user,
|
||||
user_current_groups: list[GroupModel],
|
||||
all_available_groups: list[GroupModel],
|
||||
given_groups: list[str],
|
||||
blocked_matcher: BlockedGroupMatcher,
|
||||
default_permissions: dict,
|
||||
) -> None:
|
||||
"""Efficiently sync group memberships in a single pass."""
|
||||
if not given_groups:
|
||||
return
|
||||
|
||||
given_groups_set = set(given_groups)
|
||||
current_groups_map = {g.name: g for g in user_current_groups}
|
||||
available_groups_map = {g.name: g for g in all_available_groups}
|
||||
|
||||
for group_name, group_model in current_groups_map.items():
|
||||
if (
|
||||
group_name not in given_groups_set
|
||||
and not blocked_matcher.is_blocked(group_name)
|
||||
):
|
||||
log.debug(
|
||||
f"Removing user from group {group_name} as it is no longer in their groups"
|
||||
)
|
||||
user_ids = [uid for uid in group_model.user_ids if uid != user.id]
|
||||
self._update_group_membership(
|
||||
group_model, user_ids, default_permissions
|
||||
)
|
||||
|
||||
for group_name in given_groups_set:
|
||||
if (
|
||||
group_name not in current_groups_map
|
||||
and group_name in available_groups_map
|
||||
and not blocked_matcher.is_blocked(group_name)
|
||||
):
|
||||
log.debug(
|
||||
f"Adding user to group {group_name} as it was found in their given groups"
|
||||
)
|
||||
group_model = available_groups_map[group_name]
|
||||
user_ids = group_model.user_ids + [user.id]
|
||||
self._update_group_membership(
|
||||
group_model, user_ids, default_permissions
|
||||
)
|
||||
|
||||
def _update_group_membership(
|
||||
self, group_model: GroupModel, user_ids: list[str], default_permissions: dict
|
||||
) -> None:
|
||||
"""Update group membership with given user IDs."""
|
||||
permissions = group_model.permissions or default_permissions
|
||||
|
||||
update_form = GroupUpdateForm(
|
||||
name=group_model.name,
|
||||
description=group_model.description,
|
||||
permissions=permissions,
|
||||
user_ids=user_ids,
|
||||
)
|
||||
Groups.update_group_by_id(
|
||||
id=group_model.id, form_data=update_form, overwrite=False
|
||||
)
|
||||
|
||||
|
||||
def is_in_blocked_groups(group_name: str, groups: list) -> bool:
|
||||
"""
|
||||
Check if a group name matches any blocked pattern.
|
||||
Supports exact matches, shell-style wildcards (*, ?), and regex patterns.
|
||||
|
||||
Deprecated: Use BlockedGroupMatcher.is_blocked() instead for better performance.
|
||||
|
||||
Args:
|
||||
group_name: The group name to check
|
||||
groups: List of patterns to match against
|
||||
|
||||
Returns:
|
||||
True if the group is blocked, False otherwise
|
||||
"""
|
||||
if not groups:
|
||||
return False
|
||||
|
||||
matcher = BlockedGroupMatcher(groups)
|
||||
return matcher.is_blocked(group_name)
|
|
@ -10,8 +10,6 @@ from datetime import datetime, timedelta
|
|||
|
||||
import re
|
||||
import fnmatch
|
||||
import time
|
||||
import secrets
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
|
||||
|
@ -31,7 +29,6 @@ from open_webui.models.oauth_sessions import OAuthSessions
|
|||
from open_webui.models.users import Users
|
||||
|
||||
|
||||
from open_webui.models.groups import Groups, GroupModel, GroupUpdateForm, GroupForm
|
||||
from open_webui.config import (
|
||||
DEFAULT_USER_ROLE,
|
||||
ENABLE_OAUTH_SIGNUP,
|
||||
|
@ -873,12 +870,6 @@ class OAuthManager:
|
|||
log.debug("Running OAUTH Group management")
|
||||
oauth_claim = auth_manager_config.OAUTH_GROUPS_CLAIM
|
||||
|
||||
try:
|
||||
blocked_groups = json.loads(auth_manager_config.OAUTH_BLOCKED_GROUPS)
|
||||
except Exception as e:
|
||||
log.exception(f"Error loading OAUTH_BLOCKED_GROUPS: {e}")
|
||||
blocked_groups = []
|
||||
|
||||
user_oauth_groups = []
|
||||
# Nested claim search for groups claim
|
||||
if oauth_claim:
|
||||
|
@ -894,121 +885,8 @@ class OAuthManager:
|
|||
else:
|
||||
user_oauth_groups = []
|
||||
|
||||
user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id)
|
||||
all_available_groups: list[GroupModel] = Groups.get_groups()
|
||||
|
||||
# Create groups if they don't exist and creation is enabled
|
||||
if auth_manager_config.ENABLE_OAUTH_GROUP_CREATION:
|
||||
log.debug("Checking for missing groups to create...")
|
||||
all_group_names = {g.name for g in all_available_groups}
|
||||
groups_created = False
|
||||
# Determine creator ID: Prefer admin, fallback to current user if no admin exists
|
||||
admin_user = Users.get_super_admin_user()
|
||||
creator_id = admin_user.id if admin_user else user.id
|
||||
log.debug(f"Using creator ID {creator_id} for potential group creation.")
|
||||
|
||||
for group_name in user_oauth_groups:
|
||||
if group_name not in all_group_names:
|
||||
log.info(
|
||||
f"Group '{group_name}' not found via OAuth claim. Creating group..."
|
||||
)
|
||||
try:
|
||||
new_group_form = GroupForm(
|
||||
name=group_name,
|
||||
description=f"Group '{group_name}' created automatically via OAuth.",
|
||||
permissions=default_permissions, # Use default permissions from function args
|
||||
user_ids=[], # Start with no users, user will be added later by subsequent logic
|
||||
)
|
||||
# Use determined creator ID (admin or fallback to current user)
|
||||
created_group = Groups.insert_new_group(
|
||||
creator_id, new_group_form
|
||||
)
|
||||
if created_group:
|
||||
log.info(
|
||||
f"Successfully created group '{group_name}' with ID {created_group.id} using creator ID {creator_id}"
|
||||
)
|
||||
groups_created = True
|
||||
# Add to local set to prevent duplicate creation attempts in this run
|
||||
all_group_names.add(group_name)
|
||||
else:
|
||||
log.error(
|
||||
f"Failed to create group '{group_name}' via OAuth."
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error creating group '{group_name}' via OAuth: {e}")
|
||||
|
||||
# Refresh the list of all available groups if any were created
|
||||
if groups_created:
|
||||
all_available_groups = Groups.get_groups()
|
||||
log.debug("Refreshed list of all available groups after creation.")
|
||||
|
||||
log.debug(f"Oauth Groups claim: {oauth_claim}")
|
||||
log.debug(f"User oauth groups: {user_oauth_groups}")
|
||||
log.debug(f"User's current groups: {[g.name for g in user_current_groups]}")
|
||||
log.debug(
|
||||
f"All groups available in OpenWebUI: {[g.name for g in all_available_groups]}"
|
||||
)
|
||||
|
||||
# Remove groups that user is no longer a part of
|
||||
for group_model in user_current_groups:
|
||||
if (
|
||||
user_oauth_groups
|
||||
and group_model.name not in user_oauth_groups
|
||||
and not is_in_blocked_groups(group_model.name, blocked_groups)
|
||||
):
|
||||
# Remove group from user
|
||||
log.debug(
|
||||
f"Removing user from group {group_model.name} as it is no longer in their oauth groups"
|
||||
)
|
||||
|
||||
user_ids = group_model.user_ids
|
||||
user_ids = [i for i in user_ids if i != user.id]
|
||||
|
||||
# In case a group is created, but perms are never assigned to the group by hitting "save"
|
||||
group_permissions = group_model.permissions
|
||||
if not group_permissions:
|
||||
group_permissions = default_permissions
|
||||
|
||||
update_form = GroupUpdateForm(
|
||||
name=group_model.name,
|
||||
description=group_model.description,
|
||||
permissions=group_permissions,
|
||||
user_ids=user_ids,
|
||||
)
|
||||
Groups.update_group_by_id(
|
||||
id=group_model.id, form_data=update_form, overwrite=False
|
||||
)
|
||||
|
||||
# Add user to new groups
|
||||
for group_model in all_available_groups:
|
||||
if (
|
||||
user_oauth_groups
|
||||
and group_model.name in user_oauth_groups
|
||||
and not any(gm.name == group_model.name for gm in user_current_groups)
|
||||
and not is_in_blocked_groups(group_model.name, blocked_groups)
|
||||
):
|
||||
# Add user to group
|
||||
log.debug(
|
||||
f"Adding user to group {group_model.name} as it was found in their oauth groups"
|
||||
)
|
||||
|
||||
user_ids = group_model.user_ids
|
||||
user_ids.append(user.id)
|
||||
|
||||
# In case a group is created, but perms are never assigned to the group by hitting "save"
|
||||
group_permissions = group_model.permissions
|
||||
if not group_permissions:
|
||||
group_permissions = default_permissions
|
||||
|
||||
update_form = GroupUpdateForm(
|
||||
name=group_model.name,
|
||||
description=group_model.description,
|
||||
permissions=group_permissions,
|
||||
user_ids=user_ids,
|
||||
)
|
||||
Groups.update_group_by_id(
|
||||
id=group_model.id, form_data=update_form, overwrite=False
|
||||
)
|
||||
self.app.state.group_manager.sync_user_groups(given_groups=user_oauth_groups, user=user,
|
||||
default_permissions=default_permissions, enable_group_creation=auth_manager_config.ENABLE_OAUTH_GROUP_CREATION)
|
||||
|
||||
async def _process_picture_url(
|
||||
self, picture_url: str, access_token: str = None
|
||||
|
|
Loading…
Reference in New Issue