335 lines
		
	
	
		
			9.6 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			335 lines
		
	
	
		
			9.6 KiB
		
	
	
	
		
			Python
		
	
	
	
import time
 | 
						|
from typing import Optional
 | 
						|
 | 
						|
from open_webui.internal.db import Base, JSONField, get_db
 | 
						|
 | 
						|
 | 
						|
from open_webui.models.chats import Chats
 | 
						|
from open_webui.models.groups import Groups
 | 
						|
 | 
						|
 | 
						|
from pydantic import BaseModel, ConfigDict
 | 
						|
from sqlalchemy import BigInteger, Column, String, Text
 | 
						|
 | 
						|
####################
 | 
						|
# User DB Schema
 | 
						|
####################
 | 
						|
 | 
						|
 | 
						|
class User(Base):
    """SQLAlchemy mapping for rows of the ``user`` table."""

    __tablename__ = "user"

    # Identity and profile columns.
    id = Column(String, primary_key=True)
    name = Column(String)
    email = Column(String)
    role = Column(String)
    profile_image_url = Column(Text)

    # Integer timestamps (UsersTable.insert_new_user writes epoch seconds).
    last_active_at = Column(BigInteger)
    updated_at = Column(BigInteger)
    created_at = Column(BigInteger)

    # Optional API key; unique across users when present.
    api_key = Column(String, nullable=True, unique=True)
    # Optional JSON payloads.
    settings = Column(JSONField, nullable=True)
    info = Column(JSONField, nullable=True)

    # Unique value used by UsersTable.get_user_by_oauth_sub for lookups.
    oauth_sub = Column(Text, unique=True)
 | 
						|
 | 
						|
 | 
						|
class UserSettings(BaseModel):
    # Settings payload stored per user. Keys other than ``ui`` are accepted
    # as well because the model is configured with extra="allow".
    model_config = ConfigDict(extra="allow")

    ui: Optional[dict] = {}
 | 
						|
 | 
						|
 | 
						|
class UserModel(BaseModel):
    # Pydantic view of a ``User`` row. from_attributes=True lets
    # ``UserModel.model_validate`` read values straight off an ORM object.
    model_config = ConfigDict(from_attributes=True)

    # Identity and profile.
    id: str
    name: str
    email: str
    role: str = "pending"
    profile_image_url: str

    # All three are timestamps in epoch.
    last_active_at: int
    updated_at: int
    created_at: int

    # Optional extras; each defaults to None when not supplied.
    api_key: Optional[str] = None
    settings: Optional[UserSettings] = None
    info: Optional[dict] = None

    oauth_sub: Optional[str] = None
 | 
						|
 | 
						|
 | 
						|
####################
 | 
						|
# Forms
 | 
						|
####################
 | 
						|
 | 
						|
 | 
						|
class UserResponse(BaseModel):
    # User shape limited to identity, contact, role and avatar fields;
    # every field is required.
    id: str
    name: str
    email: str
    role: str
    profile_image_url: str
 | 
						|
 | 
						|
 | 
						|
class UserNameResponse(BaseModel):
    # Same fields as UserResponse minus ``email``; every field is required.
    id: str
    name: str
    role: str
    profile_image_url: str
 | 
						|
 | 
						|
 | 
						|
class UserRoleUpdateForm(BaseModel):
    # Input form pairing a user id with a role value; both are required.
    id: str
    role: str
 | 
						|
 | 
						|
 | 
						|
class UserUpdateForm(BaseModel):
    # Input form for editing a user's profile fields. ``password`` is the
    # only optional field and defaults to None when omitted.
    name: str
    email: str
    profile_image_url: str
    password: Optional[str] = None
 | 
						|
 | 
						|
 | 
						|
class UsersTable:
    """Data-access helper for the ``user`` table.

    Each method opens its own session through ``get_db()``.

    Single-user lookups, updates and deletes are best-effort: any failure
    (including "no such user") is reported as ``None`` or ``False`` instead
    of an exception. ``insert_new_user`` and the list/count queries do not
    catch exceptions.
    """

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _get_one_by(self, **criteria) -> Optional[UserModel]:
        """Return the first user whose columns equal ``criteria``, else ``None``."""
        try:
            with get_db() as db:
                row = db.query(User).filter_by(**criteria).first()
                if row is None:
                    return None
                return UserModel.model_validate(row)
        except Exception:
            return None

    def _update_and_fetch(self, id: str, values: dict) -> Optional[UserModel]:
        """Apply ``values`` to the user with ``id``, commit, and return the row.

        Returns ``None`` when no such user exists or when anything fails.
        """
        try:
            with get_db() as db:
                db.query(User).filter_by(id=id).update(values)
                db.commit()

                row = db.query(User).filter_by(id=id).first()
                if row is None:
                    return None
                return UserModel.model_validate(row)
        except Exception:
            return None

    # ------------------------------------------------------------------
    # Create
    # ------------------------------------------------------------------

    def insert_new_user(
        self,
        id: str,
        name: str,
        email: str,
        profile_image_url: str = "/user.png",
        role: str = "pending",
        oauth_sub: Optional[str] = None,
    ) -> Optional[UserModel]:
        """Insert a new user row and return it as a ``UserModel``.

        ``last_active_at``, ``created_at`` and ``updated_at`` are all set to
        the same epoch-seconds value. Exceptions raised while validating or
        writing are not caught here.
        """
        # Read the clock once so the three timestamps cannot disagree.
        now = int(time.time())

        with get_db() as db:
            user = UserModel(
                id=id,
                name=name,
                email=email,
                role=role,
                profile_image_url=profile_image_url,
                last_active_at=now,
                created_at=now,
                updated_at=now,
                oauth_sub=oauth_sub,
            )
            row = User(**user.model_dump())
            db.add(row)
            db.commit()
            db.refresh(row)
            return user

    # ------------------------------------------------------------------
    # Read
    # ------------------------------------------------------------------

    def get_user_by_id(self, id: str) -> Optional[UserModel]:
        """Return the user with this ``id``, or ``None``."""
        return self._get_one_by(id=id)

    def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
        """Return the user owning ``api_key``, or ``None``."""
        return self._get_one_by(api_key=api_key)

    def get_user_by_email(self, email: str) -> Optional[UserModel]:
        """Return the first user with this ``email``, or ``None``."""
        return self._get_one_by(email=email)

    def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
        """Return the user whose ``oauth_sub`` equals ``sub``, or ``None``."""
        return self._get_one_by(oauth_sub=sub)

    def get_users(
        self, skip: Optional[int] = None, limit: Optional[int] = None
    ) -> list[UserModel]:
        """Return users ordered by ``created_at`` descending.

        ``skip`` and ``limit`` are applied only when truthy, so ``None`` and
        ``0`` both mean "no offset" / "no limit".
        """
        with get_db() as db:
            query = db.query(User).order_by(User.created_at.desc())

            if skip:
                query = query.offset(skip)
            if limit:
                query = query.limit(limit)

            return [UserModel.model_validate(row) for row in query.all()]

    def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserModel]:
        """Return the users whose id is in ``user_ids``; unknown ids are skipped."""
        with get_db() as db:
            rows = db.query(User).filter(User.id.in_(user_ids)).all()
            return [UserModel.model_validate(row) for row in rows]

    def get_num_users(self) -> Optional[int]:
        """Return the number of rows in the ``user`` table."""
        with get_db() as db:
            return db.query(User).count()

    def get_first_user(self) -> Optional[UserModel]:
        """Return the user with the smallest ``created_at``, or ``None``."""
        try:
            with get_db() as db:
                row = db.query(User).order_by(User.created_at).first()
                if row is None:
                    return None
                return UserModel.model_validate(row)
        except Exception:
            return None

    def get_user_webhook_url_by_id(self, id: str) -> Optional[str]:
        """Return ``settings["ui"]["notifications"]["webhook_url"]`` for a user.

        Returns ``None`` when the user, the settings, or any level of that
        path is missing, or when the lookup fails.
        """
        try:
            with get_db() as db:
                row = db.query(User).filter_by(id=id).first()
                if row is None or row.settings is None:
                    return None

                ui = row.settings.get("ui", {})
                notifications = ui.get("notifications", {})
                return notifications.get("webhook_url", None)
        except Exception:
            return None

    def get_user_api_key_by_id(self, id: str) -> Optional[str]:
        """Return the stored API key for the user, or ``None``."""
        try:
            with get_db() as db:
                row = db.query(User).filter_by(id=id).first()
                return None if row is None else row.api_key
        except Exception:
            return None

    def get_valid_user_ids(self, user_ids: list[str]) -> list[str]:
        """Return the subset of ``user_ids`` that exist in the table."""
        with get_db() as db:
            # Select only the id column; full rows are not needed here.
            rows = db.query(User.id).filter(User.id.in_(user_ids)).all()
            return [row[0] for row in rows]

    # ------------------------------------------------------------------
    # Update
    # ------------------------------------------------------------------

    def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
        """Set the user's ``role`` and return the updated user, or ``None``."""
        return self._update_and_fetch(id, {"role": role})

    def update_user_profile_image_url_by_id(
        self, id: str, profile_image_url: str
    ) -> Optional[UserModel]:
        """Set ``profile_image_url`` and return the updated user, or ``None``."""
        return self._update_and_fetch(id, {"profile_image_url": profile_image_url})

    def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
        """Set ``last_active_at`` to the current epoch second.

        Returns the updated user, or ``None``.
        """
        return self._update_and_fetch(id, {"last_active_at": int(time.time())})

    def update_user_oauth_sub_by_id(
        self, id: str, oauth_sub: str
    ) -> Optional[UserModel]:
        """Set ``oauth_sub`` and return the updated user, or ``None``."""
        return self._update_and_fetch(id, {"oauth_sub": oauth_sub})

    def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
        """Apply the column/value pairs in ``updated`` to the user.

        Returns the updated user, or ``None``.
        """
        return self._update_and_fetch(id, updated)

    def update_user_settings_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
        """Merge ``updated`` into the user's settings (top-level keys only).

        Keys in ``updated`` overwrite existing keys of the same name; missing
        settings are treated as an empty dict. Returns the updated user, or
        ``None`` when the user does not exist or anything fails.
        """
        try:
            with get_db() as db:
                row = db.query(User).filter_by(id=id).first()
                if row is None:
                    return None

                # Merge into a copy so the ORM-loaded dict is not mutated
                # in place before the UPDATE is issued.
                merged = {} if row.settings is None else dict(row.settings)
                merged.update(updated)

                db.query(User).filter_by(id=id).update({"settings": merged})
                db.commit()

                refreshed = db.query(User).filter_by(id=id).first()
                if refreshed is None:
                    return None
                return UserModel.model_validate(refreshed)
        except Exception:
            return None

    def update_user_api_key_by_id(self, id: str, api_key: str) -> bool:
        """Store ``api_key`` for the user.

        Returns ``True`` only when exactly one row was updated; ``False``
        otherwise, including on any error.
        """
        try:
            with get_db() as db:
                affected = db.query(User).filter_by(id=id).update({"api_key": api_key})
                db.commit()
                return affected == 1
        except Exception:
            return False

    # ------------------------------------------------------------------
    # Delete
    # ------------------------------------------------------------------

    def delete_user_by_id(self, id: str) -> bool:
        """Delete a user after removing group memberships and chats.

        The user row is deleted only when ``Chats.delete_chats_by_user_id``
        reports success. Returns ``True`` on success and ``False`` when chat
        deletion reports failure or any step raises.
        """
        try:
            # Remove User from Groups
            Groups.remove_user_from_all_groups(id)

            # Delete User Chats
            if not Chats.delete_chats_by_user_id(id):
                return False

            # Delete User
            with get_db() as db:
                db.query(User).filter_by(id=id).delete()
                db.commit()

            return True
        except Exception:
            return False
 | 
						|
 | 
						|
 | 
						|
# Module-level UsersTable instance.
Users = UsersTable()
 |