| 
									
										
										
										
											2024-12-22 18:42:19 +08:00
										 |  |  | import json | 
					
						
							|  |  |  | import time | 
					
						
							|  |  |  | import uuid | 
					
						
							|  |  |  | from typing import Optional | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from open_webui.internal.db import Base, get_db | 
					
						
							| 
									
										
										
										
											2024-12-23 10:40:01 +08:00
										 |  |  | from open_webui.utils.access_control import has_access | 
					
						
							| 
									
										
										
										
											2024-12-22 18:42:19 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | from pydantic import BaseModel, ConfigDict | 
					
						
							|  |  |  | from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON | 
					
						
							|  |  |  | from sqlalchemy import or_, func, select, and_, text | 
					
						
							|  |  |  | from sqlalchemy.sql import exists | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #################### | 
					
						
							|  |  |  | # Channel DB Schema | 
					
						
							|  |  |  | #################### | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
class Channel(Base):
    """SQLAlchemy model mapped to the ``channel`` table."""

    __tablename__ = "channel"

    # Primary key; ChannelTable.insert_new_channel stores a UUID4 string here.
    id = Column(Text, primary_key=True)
    # User id supplied at creation; get_channels_by_user_id always returns
    # channels whose user_id matches the requesting user.
    user_id = Column(Text)
    # Optional channel type passed through unchanged by insert_new_channel.
    # NOTE(review): the set of valid values is not defined in this module.
    type = Column(Text, nullable=True)

    # Stored lower-cased by insert_new_channel.
    name = Column(Text)
    description = Column(Text, nullable=True)

    # Optional JSON payloads; this module stores them without interpreting them.
    data = Column(JSON, nullable=True)
    meta = Column(JSON, nullable=True)
    # Optional JSON handed to has_access() when filtering channels per user.
    access_control = Column(JSON, nullable=True)

    # Written from time.time_ns(), i.e. nanoseconds since the Unix epoch.
    created_at = Column(BigInteger)
    updated_at = Column(BigInteger)
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
class ChannelModel(BaseModel):
    """Pydantic representation of a ``Channel`` row.

    ``from_attributes=True`` lets ``ChannelModel.model_validate`` read the
    fields directly off a ``Channel`` ORM instance.
    """

    model_config = ConfigDict(from_attributes=True)

    id: str
    user_id: str
    type: Optional[str] = None

    name: str
    description: Optional[str] = None

    data: Optional[dict] = None
    meta: Optional[dict] = None
    access_control: Optional[dict] = None

    created_at: int  # epoch timestamp in nanoseconds (written via time.time_ns())
    updated_at: int  # epoch timestamp in nanoseconds (written via time.time_ns())
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #################### | 
					
						
							|  |  |  | # Forms | 
					
						
							|  |  |  | #################### | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
class ChannelForm(BaseModel):
    """Caller-supplied fields used when creating or updating a channel."""

    # Required; insert_new_channel lower-cases it before storing.
    name: str
    description: Optional[str] = None
    # Optional free-form dicts, stored as JSON without interpretation here.
    data: Optional[dict] = None
    meta: Optional[dict] = None
    # Optional dict later evaluated by has_access() for per-user filtering.
    access_control: Optional[dict] = None
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class ChannelTable: | 
					
						
							|  |  |  |     def insert_new_channel( | 
					
						
							| 
									
										
										
										
											2024-12-31 15:06:34 +08:00
										 |  |  |         self, type: Optional[str], form_data: ChannelForm, user_id: str | 
					
						
							| 
									
										
										
										
											2024-12-22 18:42:19 +08:00
										 |  |  |     ) -> Optional[ChannelModel]: | 
					
						
							|  |  |  |         with get_db() as db: | 
					
						
							| 
									
										
										
										
											2024-12-22 19:10:10 +08:00
										 |  |  |             channel = ChannelModel( | 
					
						
							| 
									
										
										
										
											2024-12-22 18:42:19 +08:00
										 |  |  |                 **{ | 
					
						
							| 
									
										
										
										
											2024-12-23 13:20:24 +08:00
										 |  |  |                     **form_data.model_dump(), | 
					
						
							| 
									
										
										
										
											2024-12-31 15:06:34 +08:00
										 |  |  |                     "type": type, | 
					
						
							| 
									
										
										
										
											2024-12-23 13:33:13 +08:00
										 |  |  |                     "name": form_data.name.lower(), | 
					
						
							| 
									
										
										
										
											2024-12-22 18:42:19 +08:00
										 |  |  |                     "id": str(uuid.uuid4()), | 
					
						
							|  |  |  |                     "user_id": user_id, | 
					
						
							| 
									
										
										
										
											2024-12-23 13:20:24 +08:00
										 |  |  |                     "created_at": int(time.time_ns()), | 
					
						
							|  |  |  |                     "updated_at": int(time.time_ns()), | 
					
						
							| 
									
										
										
										
											2024-12-22 18:42:19 +08:00
										 |  |  |                 } | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-22 19:10:10 +08:00
										 |  |  |             new_channel = Channel(**channel.model_dump()) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-22 18:42:19 +08:00
										 |  |  |             db.add(new_channel) | 
					
						
							|  |  |  |             db.commit() | 
					
						
							| 
									
										
										
										
											2024-12-22 19:10:10 +08:00
										 |  |  |             return channel | 
					
						
							| 
									
										
										
										
											2024-12-22 18:42:19 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def get_channels(self) -> list[ChannelModel]: | 
					
						
							|  |  |  |         with get_db() as db: | 
					
						
							|  |  |  |             channels = db.query(Channel).all() | 
					
						
							|  |  |  |             return [ChannelModel.model_validate(channel) for channel in channels] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-23 10:40:01 +08:00
										 |  |  |     def get_channels_by_user_id( | 
					
						
							|  |  |  |         self, user_id: str, permission: str = "read" | 
					
						
							|  |  |  |     ) -> list[ChannelModel]: | 
					
						
							|  |  |  |         channels = self.get_channels() | 
					
						
							|  |  |  |         return [ | 
					
						
							|  |  |  |             channel | 
					
						
							|  |  |  |             for channel in channels | 
					
						
							|  |  |  |             if channel.user_id == user_id | 
					
						
							|  |  |  |             or has_access(user_id, permission, channel.access_control) | 
					
						
							|  |  |  |         ] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-22 18:42:19 +08:00
										 |  |  |     def get_channel_by_id(self, id: str) -> Optional[ChannelModel]: | 
					
						
							|  |  |  |         with get_db() as db: | 
					
						
							|  |  |  |             channel = db.query(Channel).filter(Channel.id == id).first() | 
					
						
							|  |  |  |             return ChannelModel.model_validate(channel) if channel else None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def update_channel_by_id( | 
					
						
							|  |  |  |         self, id: str, form_data: ChannelForm | 
					
						
							|  |  |  |     ) -> Optional[ChannelModel]: | 
					
						
							|  |  |  |         with get_db() as db: | 
					
						
							|  |  |  |             channel = db.query(Channel).filter(Channel.id == id).first() | 
					
						
							|  |  |  |             if not channel: | 
					
						
							|  |  |  |                 return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             channel.name = form_data.name | 
					
						
							|  |  |  |             channel.data = form_data.data | 
					
						
							|  |  |  |             channel.meta = form_data.meta | 
					
						
							|  |  |  |             channel.access_control = form_data.access_control | 
					
						
							| 
									
										
										
										
											2024-12-23 13:20:24 +08:00
										 |  |  |             channel.updated_at = int(time.time_ns()) | 
					
						
							| 
									
										
										
										
											2024-12-22 18:42:19 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             db.commit() | 
					
						
							|  |  |  |             return ChannelModel.model_validate(channel) if channel else None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def delete_channel_by_id(self, id: str): | 
					
						
							|  |  |  |         with get_db() as db: | 
					
						
							|  |  |  |             db.query(Channel).filter(Channel.id == id).delete() | 
					
						
							|  |  |  |             db.commit() | 
					
						
							|  |  |  |             return True | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
# Module-level ChannelTable instance, created once at import time.
Channels = ChannelTable()