404 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			404 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
| import time
 | |
| from typing import Optional
 | |
| 
 | |
| from open_webui.internal.db import Base, JSONField, get_db
 | |
| 
 | |
| 
 | |
| from open_webui.models.chats import Chats
 | |
| from open_webui.models.groups import Groups
 | |
| 
 | |
| 
 | |
| from pydantic import BaseModel, ConfigDict
 | |
| from sqlalchemy import BigInteger, Column, String, Text
 | |
| from sqlalchemy import or_
 | |
| 
 | |
| 
 | |
| ####################
 | |
| # User DB Schema
 | |
| ####################
 | |
| 
 | |
| 
 | |
class User(Base):
    """ORM mapping for the ``user`` table."""

    __tablename__ = "user"

    # Identity and profile columns; ``id`` is the primary key.
    id = Column(String, primary_key=True)
    name = Column(String)
    email = Column(String)
    role = Column(String)
    profile_image_url = Column(Text)

    # Integer timestamps. UsersTable.insert_new_user fills all three with
    # int(time.time()).
    last_active_at = Column(BigInteger)
    updated_at = Column(BigInteger)
    created_at = Column(BigInteger)

    # Optional per-user data. ``api_key`` carries a unique constraint;
    # ``settings`` and ``info`` are stored through the project's JSONField.
    api_key = Column(String, nullable=True, unique=True)
    settings = Column(JSONField, nullable=True)
    info = Column(JSONField, nullable=True)

    # Unique per row; looked up by UsersTable.get_user_by_oauth_sub.
    oauth_sub = Column(Text, unique=True)
 | |
| 
 | |
| 
 | |
class UserSettings(BaseModel):
    # Keys other than ``ui`` are accepted and kept on the model
    # (``extra="allow"``).
    model_config = ConfigDict(extra="allow")

    # UI-related settings; UsersTable.get_user_webhook_url_by_id reads
    # ui -> notifications -> webhook_url from the stored settings.
    ui: Optional[dict] = {}
 | |
| 
 | |
| 
 | |
class UserModel(BaseModel):
    # Pydantic mirror of a ``User`` row; ``from_attributes=True`` below lets
    # ``UserModel.model_validate(orm_row)`` read values from ORM attributes.
    id: str
    name: str
    email: str
    role: str = "pending"
    profile_image_url: str

    last_active_at: int  # timestamp in epoch
    updated_at: int  # timestamp in epoch
    created_at: int  # timestamp in epoch

    # Optional fields; all default to None when absent.
    api_key: Optional[str] = None
    settings: Optional[UserSettings] = None
    info: Optional[dict] = None

    oauth_sub: Optional[str] = None

    model_config = ConfigDict(from_attributes=True)
 | |
| 
 | |
| 
 | |
| ####################
 | |
| # Forms
 | |
| ####################
 | |
| 
 | |
| 
 | |
class UserListResponse(BaseModel):
    # Shape of a user listing: the users themselves plus a count.
    # UsersTable.get_users builds a dict with these same two keys.
    users: list[UserModel]
    total: int
 | |
| 
 | |
| 
 | |
class UserResponse(BaseModel):
    # Reduced view of a user: identity, role and avatar only (no timestamps,
    # API key, settings or OAuth subject).
    id: str
    name: str
    email: str
    role: str
    profile_image_url: str
 | |
| 
 | |
| 
 | |
class UserNameResponse(BaseModel):
    # Same fields as UserResponse except that ``email`` is omitted.
    id: str
    name: str
    role: str
    profile_image_url: str
 | |
| 
 | |
| 
 | |
class UserRoleUpdateForm(BaseModel):
    # Payload carrying a user id and a role value; both are required.
    id: str
    role: str
 | |
| 
 | |
| 
 | |
class UserUpdateForm(BaseModel):
    # Payload with profile fields for a user; ``password`` is optional and
    # defaults to None when omitted.
    name: str
    email: str
    profile_image_url: str
    password: Optional[str] = None
 | |
| 
 | |
| 
 | |
class UsersTable:
    """Data-access helper for rows of the ``user`` table.

    Each public method obtains its own database handle via ``get_db()``.
    Lookup and update methods are best-effort: any exception raised while
    talking to the database or validating the row is swallowed and reported
    as ``None`` (or ``False`` for the boolean-returning methods).
    ``insert_new_user``, ``get_users``, ``get_users_by_user_ids``,
    ``get_num_users``, ``get_valid_user_ids`` and ``get_super_admin_user``
    do not catch exceptions.
    """

    # Values accepted for ``filter["order_by"]`` in ``get_users``; each one is
    # the name of a column attribute on ``User``.
    _SORTABLE_FIELDS = (
        "name",
        "email",
        "created_at",
        "last_active_at",
        "updated_at",
        "role",
    )

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _find_one(self, **criteria) -> Optional[UserModel]:
        """Return the first user matching ``criteria`` (``filter_by`` keywords), or None."""
        try:
            with get_db() as db:
                row = db.query(User).filter_by(**criteria).first()
                if row is None:
                    return None
                return UserModel.model_validate(row)
        except Exception:
            return None

    def _update_and_fetch(self, id: str, values: dict) -> Optional[UserModel]:
        """Apply ``values`` to the user with ``id``, commit, and return the re-read row (or None)."""
        try:
            with get_db() as db:
                db.query(User).filter_by(id=id).update(values)
                db.commit()

                row = db.query(User).filter_by(id=id).first()
                if row is None:
                    return None
                return UserModel.model_validate(row)
        except Exception:
            return None

    # ------------------------------------------------------------------
    # Create
    # ------------------------------------------------------------------

    def insert_new_user(
        self,
        id: str,
        name: str,
        email: str,
        profile_image_url: str = "/user.png",
        role: str = "pending",
        oauth_sub: Optional[str] = None,
    ) -> Optional[UserModel]:
        """Insert a new user row and return it as a ``UserModel``.

        ``last_active_at``, ``created_at`` and ``updated_at`` are all set to
        the same current epoch-seconds value. Exceptions raised by the
        database (for example a constraint violation) propagate to the caller.
        """
        # Read the clock once so the three timestamps cannot straddle a
        # second boundary.
        now = int(time.time())
        with get_db() as db:
            user = UserModel(
                id=id,
                name=name,
                email=email,
                role=role,
                profile_image_url=profile_image_url,
                last_active_at=now,
                created_at=now,
                updated_at=now,
                oauth_sub=oauth_sub,
            )
            result = User(**user.model_dump())
            db.add(result)
            db.commit()
            db.refresh(result)
            return user if result else None

    # ------------------------------------------------------------------
    # Single-row lookups
    # ------------------------------------------------------------------

    def get_user_by_id(self, id: str) -> Optional[UserModel]:
        """Return the user whose ``id`` matches, or None if absent or on error."""
        return self._find_one(id=id)

    def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
        """Return the user whose ``api_key`` matches, or None if absent or on error."""
        return self._find_one(api_key=api_key)

    def get_user_by_email(self, email: str) -> Optional[UserModel]:
        """Return the first user whose ``email`` matches, or None if absent or on error."""
        return self._find_one(email=email)

    def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
        """Return the user whose ``oauth_sub`` equals ``sub``, or None if absent or on error."""
        return self._find_one(oauth_sub=sub)

    # ------------------------------------------------------------------
    # Listings
    # ------------------------------------------------------------------

    def get_users(
        self,
        filter: Optional[dict] = None,
        skip: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> dict:
        """Return a page of users plus the number of users matching the filter.

        Args:
            filter: Optional dict with any of the keys ``"query"``
                (case-insensitive substring matched against name or email),
                ``"order_by"`` (one of ``_SORTABLE_FIELDS``) and
                ``"direction"`` (``"asc"`` for ascending; anything else
                sorts descending).
            skip: Number of rows to skip; ignored when falsy.
            limit: Maximum number of rows to return; ignored when falsy.

        Returns:
            A plain dict shaped like ``UserListResponse``:
            ``{"users": [UserModel, ...], "total": int}``. ``total`` counts
            the rows matching ``filter`` before ``skip``/``limit`` are applied.
        """
        options = filter or {}

        with get_db() as db:
            query = db.query(User)

            search_term = options.get("query")
            if search_term:
                pattern = f"%{search_term}%"
                query = query.filter(
                    or_(User.name.ilike(pattern), User.email.ilike(pattern))
                )

            order_by = options.get("order_by")
            if order_by in self._SORTABLE_FIELDS:
                sort_column = getattr(User, order_by)
                if options.get("direction") == "asc":
                    query = query.order_by(sort_column.asc())
                else:
                    query = query.order_by(sort_column.desc())
            else:
                # No (or unrecognised) sort key: fall back to newest-first so
                # that offset/limit pagination has a defined order.
                query = query.order_by(User.created_at.desc())

            # Count before pagination so the total reflects the search
            # filter rather than the whole table or the current page.
            total = query.count()

            if skip:
                query = query.offset(skip)
            if limit:
                query = query.limit(limit)

            return {
                "users": [UserModel.model_validate(row) for row in query.all()],
                "total": total,
            }

    def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserModel]:
        """Return the users whose id is contained in ``user_ids``."""
        with get_db() as db:
            rows = db.query(User).filter(User.id.in_(user_ids)).all()
            return [UserModel.model_validate(row) for row in rows]

    def get_num_users(self) -> Optional[int]:
        """Return the number of rows in the ``user`` table."""
        with get_db() as db:
            return db.query(User).count()

    def get_first_user(self) -> Optional[UserModel]:
        """Return the user with the smallest ``created_at``, or None if absent or on error."""
        try:
            with get_db() as db:
                row = db.query(User).order_by(User.created_at).first()
                if row is None:
                    return None
                return UserModel.model_validate(row)
        except Exception:
            return None

    def get_user_webhook_url_by_id(self, id: str) -> Optional[str]:
        """Return ``settings["ui"]["notifications"]["webhook_url"]`` for the user.

        Returns None when the user does not exist, has no settings, any level
        of that path is missing, or an error occurs.
        """
        try:
            with get_db() as db:
                row = db.query(User).filter_by(id=id).first()
                if row is None or row.settings is None:
                    return None
                return (
                    row.settings.get("ui", {})
                    .get("notifications", {})
                    .get("webhook_url", None)
                )
        except Exception:
            return None

    # ------------------------------------------------------------------
    # Updates
    # ------------------------------------------------------------------

    def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
        """Set the user's ``role`` and return the updated user, or None on failure."""
        return self._update_and_fetch(id, {"role": role})

    def update_user_profile_image_url_by_id(
        self, id: str, profile_image_url: str
    ) -> Optional[UserModel]:
        """Set the user's ``profile_image_url`` and return the updated user, or None on failure."""
        return self._update_and_fetch(id, {"profile_image_url": profile_image_url})

    def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
        """Set ``last_active_at`` to the current epoch seconds and return the updated user."""
        return self._update_and_fetch(id, {"last_active_at": int(time.time())})

    def update_user_oauth_sub_by_id(
        self, id: str, oauth_sub: str
    ) -> Optional[UserModel]:
        """Set the user's ``oauth_sub`` and return the updated user, or None on failure."""
        return self._update_and_fetch(id, {"oauth_sub": oauth_sub})

    def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
        """Apply the column/value pairs in ``updated`` and return the updated user, or None on failure."""
        return self._update_and_fetch(id, updated)

    def update_user_settings_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
        """Shallow-merge ``updated`` into the user's stored settings.

        Top-level keys in ``updated`` replace existing keys of the same name;
        other existing keys are kept. Returns the updated user, or None when
        the user does not exist or an error occurs.
        """
        try:
            with get_db() as db:
                row = db.query(User).filter_by(id=id).first()
                if row is None:
                    return None

                # Build a new dict rather than mutating the value loaded
                # from the row.
                merged = {**(row.settings or {}), **updated}

                db.query(User).filter_by(id=id).update({"settings": merged})
                db.commit()

                row = db.query(User).filter_by(id=id).first()
                if row is None:
                    return None
                return UserModel.model_validate(row)
        except Exception:
            return None

    def update_user_api_key_by_id(self, id: str, api_key: str) -> bool:
        """Set the user's ``api_key``; return True only if exactly one row was updated."""
        try:
            with get_db() as db:
                matched = db.query(User).filter_by(id=id).update({"api_key": api_key})
                db.commit()
                return matched == 1
        except Exception:
            return False

    def get_user_api_key_by_id(self, id: str) -> Optional[str]:
        """Return the user's ``api_key``, or None if the user is absent or on error."""
        try:
            with get_db() as db:
                row = db.query(User).filter_by(id=id).first()
                return None if row is None else row.api_key
        except Exception:
            return None

    # ------------------------------------------------------------------
    # Delete
    # ------------------------------------------------------------------

    def delete_user_by_id(self, id: str) -> bool:
        """Remove the user from all groups, delete their chats, then delete the user row.

        Returns False, leaving the user row in place, when
        ``Chats.delete_chats_by_user_id`` reports failure or any step raises.
        """
        try:
            # Remove User from Groups
            Groups.remove_user_from_all_groups(id)

            # Delete User Chats
            if not Chats.delete_chats_by_user_id(id):
                return False

            # Delete User
            with get_db() as db:
                db.query(User).filter_by(id=id).delete()
                db.commit()
            return True
        except Exception:
            return False

    # ------------------------------------------------------------------
    # Misc
    # ------------------------------------------------------------------

    def get_valid_user_ids(self, user_ids: list[str]) -> list[str]:
        """Return the subset of ``user_ids`` that exist in the ``user`` table."""
        with get_db() as db:
            # Only the id column is needed, so avoid loading full rows.
            rows = db.query(User.id).filter(User.id.in_(user_ids)).all()
            return [row.id for row in rows]

    def get_super_admin_user(self) -> Optional[UserModel]:
        """Return the first user found with role ``"admin"``, or None if there is none.

        NOTE(review): no ordering is applied, so which admin is returned when
        several exist is up to the database.
        """
        with get_db() as db:
            row = db.query(User).filter_by(role="admin").first()
            if row is None:
                return None
            return UserModel.model_validate(row)
 | |
| 
 | |
| 
 | |
# Module-level UsersTable instance, created at import time.
Users = UsersTable()
 |