| 
									
										
										
										
											2024-12-22 18:42:19 +08:00
										 |  |  | import json | 
					
						
							|  |  |  | import time | 
					
						
							|  |  |  | import uuid | 
					
						
							|  |  |  | from typing import Optional | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from open_webui.internal.db import Base, get_db | 
					
						
							|  |  |  | from open_webui.models.tags import TagModel, Tag, Tags | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from pydantic import BaseModel, ConfigDict | 
					
						
							|  |  |  | from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON | 
					
						
							|  |  |  | from sqlalchemy import or_, func, select, and_, text | 
					
						
							|  |  |  | from sqlalchemy.sql import exists | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #################### | 
					
						
							|  |  |  | # Message DB Schema | 
					
						
							|  |  |  | #################### | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-31 15:06:34 +08:00
class MessageReaction(Base):
    # ORM mapping for a single named reaction left by one user on one message.
    __tablename__ = "message_reaction"
    id = Column(Text, primary_key=True)
    user_id = Column(Text)  # user who reacted
    message_id = Column(Text)  # Message.id of the message being reacted to
    name = Column(Text)  # reaction name; MessageTable groups reactions by this value
    created_at = Column(BigInteger)  # time.time_ns() at insertion (nanoseconds)
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
class MessageReactionModel(BaseModel):
    # Pydantic representation of a MessageReaction row.
    # from_attributes=True lets model_validate() read directly from ORM objects.
    model_config = ConfigDict(from_attributes=True)

    id: str
    user_id: str
    message_id: str
    name: str
    created_at: int  # nanoseconds since the Unix epoch (populated from time.time_ns())
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-22 18:42:19 +08:00
class Message(Base):
    # ORM mapping for a channel message. A message whose parent_id is set is a
    # thread reply; top-level channel messages have parent_id = NULL.
    __tablename__ = "message"
    id = Column(Text, primary_key=True)

    user_id = Column(Text)  # id of the user passed to insert_new_message
    channel_id = Column(Text, nullable=True)

    parent_id = Column(Text, nullable=True)  # Message.id of the thread root, or NULL

    content = Column(Text)
    data = Column(JSON, nullable=True)
    meta = Column(JSON, nullable=True)

    created_at = Column(BigInteger)  # time_ns
    updated_at = Column(BigInteger)  # time_ns
					
						
							| 
									
										
										
										
											2024-12-22 18:42:19 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
class MessageModel(BaseModel):
    # Pydantic representation of a Message row.
    # from_attributes=True lets model_validate() read directly from ORM objects.
    model_config = ConfigDict(from_attributes=True)

    id: str
    user_id: str
    channel_id: Optional[str] = None

    parent_id: Optional[str] = None  # set for thread replies, None for top-level messages

    content: str
    data: Optional[dict] = None
    meta: Optional[dict] = None

    created_at: int  # nanoseconds since the Unix epoch (time.time_ns())
    updated_at: int  # nanoseconds since the Unix epoch (time.time_ns())
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #################### | 
					
						
							|  |  |  | # Forms | 
					
						
							|  |  |  | #################### | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
class MessageForm(BaseModel):
    # Caller-supplied payload for creating or editing a message.
    content: str
    # Thread root to reply to. Used by MessageTable.insert_new_message;
    # MessageTable.update_message_by_id ignores it.
    parent_id: Optional[str] = None
    data: Optional[dict] = None
    meta: Optional[dict] = None
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-31 15:06:34 +08:00
class Reactions(BaseModel):
    # Aggregate of all reaction rows sharing one name on a single message.
    name: str
    user_ids: list[str]  # one entry per reaction row, in query order
    count: int  # number of reaction rows with this name
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
class MessageResponse(MessageModel):
    # MessageModel enriched with thread and reaction summaries.
    # latest_reply_at has no default, so it must be passed explicitly (None is allowed).
    latest_reply_at: Optional[int]  # created_at of the newest reply, or None if no replies
    reply_count: int
    reactions: list[Reactions]
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-22 18:42:19 +08:00
										 |  |  | class MessageTable: | 
					
						
							|  |  |  |     def insert_new_message( | 
					
						
							|  |  |  |         self, form_data: MessageForm, channel_id: str, user_id: str | 
					
						
							|  |  |  |     ) -> Optional[MessageModel]: | 
					
						
							|  |  |  |         with get_db() as db: | 
					
						
							|  |  |  |             id = str(uuid.uuid4()) | 
					
						
							| 
									
										
										
										
											2024-12-23 16:19:30 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             ts = int(time.time_ns()) | 
					
						
							| 
									
										
										
										
											2024-12-22 18:42:19 +08:00
										 |  |  |             message = MessageModel( | 
					
						
							|  |  |  |                 **{ | 
					
						
							|  |  |  |                     "id": id, | 
					
						
							|  |  |  |                     "user_id": user_id, | 
					
						
							|  |  |  |                     "channel_id": channel_id, | 
					
						
							| 
									
										
										
										
											2024-12-31 15:06:34 +08:00
										 |  |  |                     "parent_id": form_data.parent_id, | 
					
						
							| 
									
										
										
										
											2024-12-22 18:42:19 +08:00
										 |  |  |                     "content": form_data.content, | 
					
						
							|  |  |  |                     "data": form_data.data, | 
					
						
							|  |  |  |                     "meta": form_data.meta, | 
					
						
							| 
									
										
										
										
											2024-12-23 16:19:30 +08:00
										 |  |  |                     "created_at": ts, | 
					
						
							|  |  |  |                     "updated_at": ts, | 
					
						
							| 
									
										
										
										
											2024-12-22 18:42:19 +08:00
										 |  |  |                 } | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             result = Message(**message.model_dump()) | 
					
						
							|  |  |  |             db.add(result) | 
					
						
							|  |  |  |             db.commit() | 
					
						
							|  |  |  |             db.refresh(result) | 
					
						
							|  |  |  |             return MessageModel.model_validate(result) if result else None | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-31 15:06:34 +08:00
										 |  |  |     def get_message_by_id(self, id: str) -> Optional[MessageResponse]: | 
					
						
							| 
									
										
										
										
											2024-12-22 18:42:19 +08:00
										 |  |  |         with get_db() as db: | 
					
						
							|  |  |  |             message = db.get(Message, id) | 
					
						
							| 
									
										
										
										
											2024-12-31 15:06:34 +08:00
										 |  |  |             if not message: | 
					
						
							|  |  |  |                 return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             reactions = self.get_reactions_by_message_id(id) | 
					
						
							| 
									
										
										
										
											2024-12-31 18:05:11 +08:00
										 |  |  |             replies = self.get_replies_by_message_id(id) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-31 15:06:34 +08:00
										 |  |  |             return MessageResponse( | 
					
						
							|  |  |  |                 **{ | 
					
						
							|  |  |  |                     **MessageModel.model_validate(message).model_dump(), | 
					
						
							| 
									
										
										
										
											2024-12-31 18:05:11 +08:00
										 |  |  |                     "latest_reply_at": replies[0].created_at if replies else None, | 
					
						
							|  |  |  |                     "reply_count": len(replies), | 
					
						
							| 
									
										
										
										
											2024-12-31 15:06:34 +08:00
										 |  |  |                     "reactions": reactions, | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2024-12-22 18:42:19 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-31 18:05:11 +08:00
										 |  |  |     def get_replies_by_message_id(self, id: str) -> list[MessageModel]: | 
					
						
							|  |  |  |         with get_db() as db: | 
					
						
							|  |  |  |             all_messages = ( | 
					
						
							|  |  |  |                 db.query(Message) | 
					
						
							|  |  |  |                 .filter_by(parent_id=id) | 
					
						
							|  |  |  |                 .order_by(Message.created_at.desc()) | 
					
						
							|  |  |  |                 .all() | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             return [MessageModel.model_validate(message) for message in all_messages] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_reply_user_ids_by_message_id(self, id: str) -> list[str]: | 
					
						
							|  |  |  |         with get_db() as db: | 
					
						
							|  |  |  |             return [ | 
					
						
							|  |  |  |                 message.user_id | 
					
						
							|  |  |  |                 for message in db.query(Message).filter_by(parent_id=id).all() | 
					
						
							|  |  |  |             ] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-22 18:42:19 +08:00
										 |  |  |     def get_messages_by_channel_id( | 
					
						
							|  |  |  |         self, channel_id: str, skip: int = 0, limit: int = 50 | 
					
						
							|  |  |  |     ) -> list[MessageModel]: | 
					
						
							|  |  |  |         with get_db() as db: | 
					
						
							|  |  |  |             all_messages = ( | 
					
						
							|  |  |  |                 db.query(Message) | 
					
						
							| 
									
										
										
										
											2024-12-31 16:51:43 +08:00
										 |  |  |                 .filter_by(channel_id=channel_id, parent_id=None) | 
					
						
							| 
									
										
										
										
											2024-12-23 11:28:15 +08:00
										 |  |  |                 .order_by(Message.created_at.desc()) | 
					
						
							| 
									
										
										
										
											2024-12-22 18:42:19 +08:00
										 |  |  |                 .offset(skip) | 
					
						
							| 
									
										
										
										
											2024-12-23 11:28:15 +08:00
										 |  |  |                 .limit(limit) | 
					
						
							| 
									
										
										
										
											2024-12-22 18:42:19 +08:00
										 |  |  |                 .all() | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             return [MessageModel.model_validate(message) for message in all_messages] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-31 16:51:43 +08:00
										 |  |  |     def get_messages_by_parent_id( | 
					
						
							|  |  |  |         self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50 | 
					
						
							|  |  |  |     ) -> list[MessageModel]: | 
					
						
							|  |  |  |         with get_db() as db: | 
					
						
							|  |  |  |             message = db.get(Message, parent_id) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if not message: | 
					
						
							|  |  |  |                 return [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             all_messages = ( | 
					
						
							|  |  |  |                 db.query(Message) | 
					
						
							|  |  |  |                 .filter_by(channel_id=channel_id, parent_id=parent_id) | 
					
						
							|  |  |  |                 .order_by(Message.created_at.desc()) | 
					
						
							|  |  |  |                 .offset(skip) | 
					
						
							|  |  |  |                 .limit(limit) | 
					
						
							|  |  |  |                 .all() | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-03 12:48:50 +08:00
										 |  |  |             # If length of all_messages is less than limit, then add the parent message | 
					
						
							|  |  |  |             if len(all_messages) < limit: | 
					
						
							|  |  |  |                 all_messages.append(message) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             return [MessageModel.model_validate(message) for message in all_messages] | 
					
						
							| 
									
										
										
										
											2024-12-31 16:51:43 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-22 18:42:19 +08:00
										 |  |  |     def update_message_by_id( | 
					
						
							|  |  |  |         self, id: str, form_data: MessageForm | 
					
						
							|  |  |  |     ) -> Optional[MessageModel]: | 
					
						
							|  |  |  |         with get_db() as db: | 
					
						
							|  |  |  |             message = db.get(Message, id) | 
					
						
							|  |  |  |             message.content = form_data.content | 
					
						
							|  |  |  |             message.data = form_data.data | 
					
						
							|  |  |  |             message.meta = form_data.meta | 
					
						
							| 
									
										
										
										
											2024-12-23 11:28:15 +08:00
										 |  |  |             message.updated_at = int(time.time_ns()) | 
					
						
							| 
									
										
										
										
											2024-12-22 18:42:19 +08:00
										 |  |  |             db.commit() | 
					
						
							|  |  |  |             db.refresh(message) | 
					
						
							|  |  |  |             return MessageModel.model_validate(message) if message else None | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-31 15:06:34 +08:00
										 |  |  |     def add_reaction_to_message( | 
					
						
							|  |  |  |         self, id: str, user_id: str, name: str | 
					
						
							|  |  |  |     ) -> Optional[MessageReactionModel]: | 
					
						
							|  |  |  |         with get_db() as db: | 
					
						
							|  |  |  |             reaction_id = str(uuid.uuid4()) | 
					
						
							|  |  |  |             reaction = MessageReactionModel( | 
					
						
							|  |  |  |                 id=reaction_id, | 
					
						
							|  |  |  |                 user_id=user_id, | 
					
						
							|  |  |  |                 message_id=id, | 
					
						
							|  |  |  |                 name=name, | 
					
						
							|  |  |  |                 created_at=int(time.time_ns()), | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             result = MessageReaction(**reaction.model_dump()) | 
					
						
							|  |  |  |             db.add(result) | 
					
						
							|  |  |  |             db.commit() | 
					
						
							|  |  |  |             db.refresh(result) | 
					
						
							|  |  |  |             return MessageReactionModel.model_validate(result) if result else None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_reactions_by_message_id(self, id: str) -> list[Reactions]: | 
					
						
							|  |  |  |         with get_db() as db: | 
					
						
							|  |  |  |             all_reactions = db.query(MessageReaction).filter_by(message_id=id).all() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             reactions = {} | 
					
						
							|  |  |  |             for reaction in all_reactions: | 
					
						
							|  |  |  |                 if reaction.name not in reactions: | 
					
						
							|  |  |  |                     reactions[reaction.name] = { | 
					
						
							|  |  |  |                         "name": reaction.name, | 
					
						
							|  |  |  |                         "user_ids": [], | 
					
						
							|  |  |  |                         "count": 0, | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  |                 reactions[reaction.name]["user_ids"].append(reaction.user_id) | 
					
						
							|  |  |  |                 reactions[reaction.name]["count"] += 1 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             return [Reactions(**reaction) for reaction in reactions.values()] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def remove_reaction_by_id_and_user_id_and_name( | 
					
						
							|  |  |  |         self, id: str, user_id: str, name: str | 
					
						
							|  |  |  |     ) -> bool: | 
					
						
							|  |  |  |         with get_db() as db: | 
					
						
							|  |  |  |             db.query(MessageReaction).filter_by( | 
					
						
							|  |  |  |                 message_id=id, user_id=user_id, name=name | 
					
						
							|  |  |  |             ).delete() | 
					
						
							|  |  |  |             db.commit() | 
					
						
							|  |  |  |             return True | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-01 05:04:27 +08:00
										 |  |  |     def delete_reactions_by_id(self, id: str) -> bool: | 
					
						
							|  |  |  |         with get_db() as db: | 
					
						
							|  |  |  |             db.query(MessageReaction).filter_by(message_id=id).delete() | 
					
						
							|  |  |  |             db.commit() | 
					
						
							|  |  |  |             return True | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def delete_replies_by_id(self, id: str) -> bool: | 
					
						
							|  |  |  |         with get_db() as db: | 
					
						
							|  |  |  |             db.query(Message).filter_by(parent_id=id).delete() | 
					
						
							|  |  |  |             db.commit() | 
					
						
							|  |  |  |             return True | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-22 18:42:19 +08:00
										 |  |  |     def delete_message_by_id(self, id: str) -> bool: | 
					
						
							|  |  |  |         with get_db() as db: | 
					
						
							|  |  |  |             db.query(Message).filter_by(id=id).delete() | 
					
						
							| 
									
										
										
										
											2024-12-31 15:06:34 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             # Delete all reactions to this message | 
					
						
							|  |  |  |             db.query(MessageReaction).filter_by(message_id=id).delete() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-22 18:42:19 +08:00
										 |  |  |             db.commit() | 
					
						
							|  |  |  |             return True | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
# Module-level shared instance of MessageTable.
Messages = MessageTable()