| 
									
										
										
										
											2024-11-15 10:35:14 +08:00
										 |  |  | import json | 
					
						
							|  |  |  | import logging | 
					
						
							|  |  |  | import time | 
					
						
							|  |  |  | from typing import Optional | 
					
						
							|  |  |  | import uuid | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-10 16:54:13 +08:00
										 |  |  | from open_webui.internal.db import Base, get_db | 
					
						
							| 
									
										
										
										
											2024-11-15 10:35:14 +08:00
										 |  |  | from open_webui.env import SRC_LOG_LEVELS | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-10 16:54:13 +08:00
										 |  |  | from open_webui.models.files import FileMetadataResponse | 
					
						
							| 
									
										
										
										
											2024-11-15 10:35:14 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from pydantic import BaseModel, ConfigDict | 
					
						
							| 
									
										
										
										
											2024-11-18 03:53:51 +08:00
										 |  |  | from sqlalchemy import BigInteger, Column, String, Text, JSON, func | 
					
						
							| 
									
										
										
										
											2024-11-15 10:35:14 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | log = logging.getLogger(__name__) | 
					
						
							|  |  |  | log.setLevel(SRC_LOG_LEVELS["MODELS"]) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #################### | 
					
						
							|  |  |  | # UserGroup DB Schema | 
					
						
							|  |  |  | #################### | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
class Group(Base):
    """SQLAlchemy declarative model for the ``group`` table."""

    __tablename__ = "group"

    # Primary key; GroupTable fills it with a uuid4 string.
    id = Column(Text, unique=True, primary_key=True)
    # Id of the user who created the group.
    user_id = Column(Text)

    # Human-readable name and description.
    name = Column(Text)
    description = Column(Text)

    # Free-form JSON payloads; their schema is not defined in this module.
    data = Column(JSON, nullable=True)
    meta = Column(JSON, nullable=True)

    # JSON permissions object — structure not visible here, TODO confirm.
    permissions = Column(JSON, nullable=True)
    # JSON array of member user ids; may be NULL, so readers should guard for it.
    user_ids = Column(JSON, nullable=True)

    # Unix epoch seconds (written from int(time.time())).
    created_at = Column(BigInteger)
    updated_at = Column(BigInteger)
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class GroupModel(BaseModel): | 
					
						
							|  |  |  |     model_config = ConfigDict(from_attributes=True) | 
					
						
							|  |  |  |     id: str | 
					
						
							|  |  |  |     user_id: str | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     name: str | 
					
						
							|  |  |  |     description: str | 
					
						
							| 
									
										
										
										
											2024-11-17 06:56:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     data: Optional[dict] = None | 
					
						
							| 
									
										
										
										
											2024-11-15 10:35:14 +08:00
										 |  |  |     meta: Optional[dict] = None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     permissions: Optional[dict] = None | 
					
						
							|  |  |  |     user_ids: list[str] = [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     created_at: int  # timestamp in epoch | 
					
						
							|  |  |  |     updated_at: int  # timestamp in epoch | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #################### | 
					
						
							|  |  |  | # Forms | 
					
						
							|  |  |  | #################### | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
class GroupResponse(BaseModel):
    """Response shape for a group; mirrors ``GroupModel`` with ``permissions`` listed earlier.

    Field declaration order is significant (it is the serialized field order).
    """

    id: str
    user_id: str  # id of the user who created the group
    name: str
    description: str
    permissions: Optional[dict] = None
    data: Optional[dict] = None
    meta: Optional[dict] = None
    user_ids: list[str] = []  # member user ids
    created_at: int  # timestamp in epoch
    updated_at: int  # timestamp in epoch
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
class GroupForm(BaseModel):
    """Input for creating a group.

    Dumped with ``exclude_none=True`` on insert, so an omitted ``permissions``
    falls back to the model default (``None``).
    """

    name: str
    description: str
    permissions: Optional[dict] = None
					
						
							| 
									
										
										
										
											2024-11-15 10:35:14 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
class GroupUpdateForm(GroupForm):
    """Input for updating a group; adds an optional replacement member list.

    ``user_ids=None`` is dropped by ``model_dump(exclude_none=True)`` during the
    update, which leaves the stored member list unchanged.
    """

    user_ids: Optional[list[str]] = None
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class GroupTable: | 
					
						
							|  |  |  |     def insert_new_group( | 
					
						
							|  |  |  |         self, user_id: str, form_data: GroupForm | 
					
						
							|  |  |  |     ) -> Optional[GroupModel]: | 
					
						
							|  |  |  |         with get_db() as db: | 
					
						
							|  |  |  |             group = GroupModel( | 
					
						
							|  |  |  |                 **{ | 
					
						
							| 
									
										
										
										
											2025-01-18 04:03:24 +08:00
										 |  |  |                     **form_data.model_dump(exclude_none=True), | 
					
						
							| 
									
										
										
										
											2024-11-15 10:35:14 +08:00
										 |  |  |                     "id": str(uuid.uuid4()), | 
					
						
							|  |  |  |                     "user_id": user_id, | 
					
						
							|  |  |  |                     "created_at": int(time.time()), | 
					
						
							|  |  |  |                     "updated_at": int(time.time()), | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             try: | 
					
						
							| 
									
										
										
										
											2024-11-15 10:37:29 +08:00
										 |  |  |                 result = Group(**group.model_dump()) | 
					
						
							| 
									
										
										
										
											2024-11-15 10:35:14 +08:00
										 |  |  |                 db.add(result) | 
					
						
							|  |  |  |                 db.commit() | 
					
						
							|  |  |  |                 db.refresh(result) | 
					
						
							|  |  |  |                 if result: | 
					
						
							|  |  |  |                     return GroupModel.model_validate(result) | 
					
						
							|  |  |  |                 else: | 
					
						
							|  |  |  |                     return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             except Exception: | 
					
						
							|  |  |  |                 return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_groups(self) -> list[GroupModel]: | 
					
						
							|  |  |  |         with get_db() as db: | 
					
						
							|  |  |  |             return [ | 
					
						
							|  |  |  |                 GroupModel.model_validate(group) | 
					
						
							| 
									
										
										
										
											2024-11-15 10:37:29 +08:00
										 |  |  |                 for group in db.query(Group).order_by(Group.updated_at.desc()).all() | 
					
						
							| 
									
										
										
										
											2024-11-15 10:35:14 +08:00
										 |  |  |             ] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-15 17:29:07 +08:00
										 |  |  |     def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]: | 
					
						
							|  |  |  |         with get_db() as db: | 
					
						
							|  |  |  |             return [ | 
					
						
							|  |  |  |                 GroupModel.model_validate(group) | 
					
						
							|  |  |  |                 for group in db.query(Group) | 
					
						
							| 
									
										
										
										
											2024-11-18 22:39:27 +08:00
										 |  |  |                 .filter( | 
					
						
							|  |  |  |                     func.json_array_length(Group.user_ids) > 0 | 
					
						
							|  |  |  |                 )  # Ensure array exists | 
					
						
							|  |  |  |                 .filter( | 
					
						
							|  |  |  |                     Group.user_ids.cast(String).like(f'%"{user_id}"%') | 
					
						
							|  |  |  |                 )  # String-based check | 
					
						
							| 
									
										
										
										
											2024-11-15 17:29:07 +08:00
										 |  |  |                 .order_by(Group.updated_at.desc()) | 
					
						
							|  |  |  |                 .all() | 
					
						
							|  |  |  |             ] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-15 10:35:14 +08:00
										 |  |  |     def get_group_by_id(self, id: str) -> Optional[GroupModel]: | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             with get_db() as db: | 
					
						
							| 
									
										
										
										
											2024-11-15 10:37:29 +08:00
										 |  |  |                 group = db.query(Group).filter_by(id=id).first() | 
					
						
							| 
									
										
										
										
											2024-11-15 10:35:14 +08:00
										 |  |  |                 return GroupModel.model_validate(group) if group else None | 
					
						
							|  |  |  |         except Exception: | 
					
						
							|  |  |  |             return None | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-25 15:53:25 +08:00
										 |  |  |     def get_group_user_ids_by_id(self, id: str) -> Optional[str]: | 
					
						
							|  |  |  |         group = self.get_group_by_id(id) | 
					
						
							|  |  |  |         if group: | 
					
						
							|  |  |  |             return group.user_ids | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             return None | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-15 10:35:14 +08:00
										 |  |  |     def update_group_by_id( | 
					
						
							|  |  |  |         self, id: str, form_data: GroupUpdateForm, overwrite: bool = False | 
					
						
							|  |  |  |     ) -> Optional[GroupModel]: | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             with get_db() as db: | 
					
						
							| 
									
										
										
										
											2024-11-15 10:37:29 +08:00
										 |  |  |                 db.query(Group).filter_by(id=id).update( | 
					
						
							| 
									
										
										
										
											2024-11-15 10:35:14 +08:00
										 |  |  |                     { | 
					
						
							|  |  |  |                         **form_data.model_dump(exclude_none=True), | 
					
						
							|  |  |  |                         "updated_at": int(time.time()), | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |                 db.commit() | 
					
						
							|  |  |  |                 return self.get_group_by_id(id=id) | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             log.exception(e) | 
					
						
							|  |  |  |             return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def delete_group_by_id(self, id: str) -> bool: | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             with get_db() as db: | 
					
						
							| 
									
										
										
										
											2024-11-15 10:37:29 +08:00
										 |  |  |                 db.query(Group).filter_by(id=id).delete() | 
					
						
							| 
									
										
										
										
											2024-11-15 10:35:14 +08:00
										 |  |  |                 db.commit() | 
					
						
							|  |  |  |                 return True | 
					
						
							|  |  |  |         except Exception: | 
					
						
							|  |  |  |             return False | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def delete_all_groups(self) -> bool: | 
					
						
							|  |  |  |         with get_db() as db: | 
					
						
							|  |  |  |             try: | 
					
						
							| 
									
										
										
										
											2024-11-15 10:37:29 +08:00
										 |  |  |                 db.query(Group).delete() | 
					
						
							| 
									
										
										
										
											2024-11-15 10:35:14 +08:00
										 |  |  |                 db.commit() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 return True | 
					
						
							|  |  |  |             except Exception: | 
					
						
							|  |  |  |                 return False | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-21 15:20:47 +08:00
										 |  |  |     def remove_user_from_all_groups(self, user_id: str) -> bool: | 
					
						
							|  |  |  |         with get_db() as db: | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 groups = self.get_groups_by_member_id(user_id) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 for group in groups: | 
					
						
							|  |  |  |                     group.user_ids.remove(user_id) | 
					
						
							|  |  |  |                     db.query(Group).filter_by(id=group.id).update( | 
					
						
							|  |  |  |                         { | 
					
						
							|  |  |  |                             "user_ids": group.user_ids, | 
					
						
							|  |  |  |                             "updated_at": int(time.time()), | 
					
						
							|  |  |  |                         } | 
					
						
							|  |  |  |                     ) | 
					
						
							|  |  |  |                     db.commit() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 return True | 
					
						
							|  |  |  |             except Exception: | 
					
						
							|  |  |  |                 return False | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-06-12 00:40:19 +08:00
										 |  |  |     def create_groups_by_group_names( | 
					
						
							| 
									
										
										
										
											2025-05-25 03:17:12 +08:00
										 |  |  |         self, user_id: str, group_names: list[str] | 
					
						
							| 
									
										
										
										
											2025-06-12 00:40:19 +08:00
										 |  |  |     ) -> list[GroupModel]: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # check for existing groups | 
					
						
							|  |  |  |         existing_groups = self.get_groups() | 
					
						
							|  |  |  |         existing_group_names = {group.name for group in existing_groups} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         new_groups = [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         with get_db() as db: | 
					
						
							|  |  |  |             for group_name in group_names: | 
					
						
							|  |  |  |                 if group_name not in existing_group_names: | 
					
						
							|  |  |  |                     new_group = GroupModel( | 
					
						
							|  |  |  |                         id=str(uuid.uuid4()), | 
					
						
							|  |  |  |                         user_id=user_id, | 
					
						
							|  |  |  |                         name=group_name, | 
					
						
							|  |  |  |                         description="", | 
					
						
							|  |  |  |                         created_at=int(time.time()), | 
					
						
							|  |  |  |                         updated_at=int(time.time()), | 
					
						
							|  |  |  |                     ) | 
					
						
							|  |  |  |                     try: | 
					
						
							|  |  |  |                         result = Group(**new_group.model_dump()) | 
					
						
							|  |  |  |                         db.add(result) | 
					
						
							|  |  |  |                         db.commit() | 
					
						
							|  |  |  |                         db.refresh(result) | 
					
						
							|  |  |  |                         new_groups.append(GroupModel.model_validate(result)) | 
					
						
							|  |  |  |                     except Exception as e: | 
					
						
							|  |  |  |                         log.exception(e) | 
					
						
							|  |  |  |                         continue | 
					
						
							|  |  |  |             return new_groups | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def sync_groups_by_group_names(self, user_id: str, group_names: list[str]) -> bool: | 
					
						
							| 
									
										
										
										
											2025-05-25 03:17:12 +08:00
										 |  |  |         with get_db() as db: | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 groups = db.query(Group).filter(Group.name.in_(group_names)).all() | 
					
						
							|  |  |  |                 group_ids = [group.id for group in groups] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 # Remove user from groups not in the new list | 
					
						
							|  |  |  |                 existing_groups = self.get_groups_by_member_id(user_id) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 for group in existing_groups: | 
					
						
							|  |  |  |                     if group.id not in group_ids: | 
					
						
							|  |  |  |                         group.user_ids.remove(user_id) | 
					
						
							|  |  |  |                         db.query(Group).filter_by(id=group.id).update( | 
					
						
							|  |  |  |                             { | 
					
						
							|  |  |  |                                 "user_ids": group.user_ids, | 
					
						
							|  |  |  |                                 "updated_at": int(time.time()), | 
					
						
							|  |  |  |                             } | 
					
						
							|  |  |  |                         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 # Add user to new groups | 
					
						
							|  |  |  |                 for group in groups: | 
					
						
							|  |  |  |                     if user_id not in group.user_ids: | 
					
						
							|  |  |  |                         group.user_ids.append(user_id) | 
					
						
							|  |  |  |                         db.query(Group).filter_by(id=group.id).update( | 
					
						
							|  |  |  |                             { | 
					
						
							|  |  |  |                                 "user_ids": group.user_ids, | 
					
						
							|  |  |  |                                 "updated_at": int(time.time()), | 
					
						
							|  |  |  |                             } | 
					
						
							|  |  |  |                         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 db.commit() | 
					
						
							|  |  |  |                 return True | 
					
						
							|  |  |  |             except Exception as e: | 
					
						
							|  |  |  |                 log.exception(e) | 
					
						
							|  |  |  |                 return False | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-15 10:35:14 +08:00
										 |  |  | 
 | 
					
						
# Shared module-level instance; GroupTable keeps no per-instance state
# (it defines no __init__), so one instance serves all callers.
Groups = GroupTable()