diff --git a/backend/open_webui/migrations/versions/a5c220713937_add_reply_to_id_column_to_message.py b/backend/open_webui/migrations/versions/a5c220713937_add_reply_to_id_column_to_message.py new file mode 100644 index 0000000000..dd2b7d1a68 --- /dev/null +++ b/backend/open_webui/migrations/versions/a5c220713937_add_reply_to_id_column_to_message.py @@ -0,0 +1,34 @@ +"""Add reply_to_id column to message + +Revision ID: a5c220713937 +Revises: 38d63c18f30f +Create Date: 2025-09-27 02:24:18.058455 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = "a5c220713937" +down_revision: Union[str, None] = "38d63c18f30f" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Add 'reply_to_id' column to the 'message' table for replying to messages + op.add_column( + "message", + sa.Column("reply_to_id", sa.Text(), nullable=True), + ) + pass + + +def downgrade() -> None: + # Remove 'reply_to_id' column from the 'message' table + op.drop_column("message", "reply_to_id") + + pass diff --git a/backend/open_webui/models/messages.py b/backend/open_webui/models/messages.py index ff4553ee9d..197befa061 100644 --- a/backend/open_webui/models/messages.py +++ b/backend/open_webui/models/messages.py @@ -5,6 +5,7 @@ from typing import Optional from open_webui.internal.db import Base, get_db from open_webui.models.tags import TagModel, Tag, Tags +from open_webui.models.users import Users, UserNameResponse from pydantic import BaseModel, ConfigDict @@ -43,6 +44,7 @@ class Message(Base): user_id = Column(Text) channel_id = Column(Text, nullable=True) + reply_to_id = Column(Text, nullable=True) parent_id = Column(Text, nullable=True) content = Column(Text) @@ -60,6 +62,7 @@ class MessageModel(BaseModel): user_id: str channel_id: Optional[str] = None + reply_to_id: Optional[str] = None parent_id: Optional[str] = None content: str @@ -77,6 +80,7 @@ class MessageModel(BaseModel): class MessageForm(BaseModel): content: str + reply_to_id: Optional[str] = None parent_id: Optional[str] = None data: Optional[dict] = None meta: Optional[dict] = None @@ -88,7 +92,15 @@ class Reactions(BaseModel): count: int -class MessageResponse(MessageModel): +class MessageUserResponse(MessageModel): + user: Optional[UserNameResponse] = None + + +class MessageReplyToResponse(MessageUserResponse): + reply_to_message: Optional[MessageUserResponse] = None + + +class MessageResponse(MessageReplyToResponse): latest_reply_at: Optional[int] reply_count: int reactions: list[Reactions] @@ -107,6 +119,7 @@ class MessageTable: "id": id, "user_id": user_id, "channel_id": channel_id, + "reply_to_id": form_data.reply_to_id, "parent_id": form_data.parent_id, "content": form_data.content, "data": form_data.data, @@ -122,25 +135,36 @@ class MessageTable: db.refresh(result) return MessageModel.model_validate(result) if result else None - def get_message_by_id(self, id: str) -> Optional[MessageResponse]: + def get_message_by_id(self, id: str) -> Optional[MessageReplyToResponse]: with get_db() as db: message = db.get(Message, id) if not message: return None + reply_to_message = ( + self.get_message_by_id(message.reply_to_id) + if message.reply_to_id + else None + ) reactions = self.get_reactions_by_message_id(id) - replies = self.get_replies_by_message_id(id) + replies = self.get_thread_replies_by_message_id(id) - return MessageResponse( - **{ + user = Users.get_user_by_id(message.user_id) + + return MessageReplyToResponse.model_validate( + { **MessageModel.model_validate(message).model_dump(), + "user": user.model_dump() if user else None, + "reply_to_message": ( + reply_to_message.model_dump() if reply_to_message else None + ), "latest_reply_at": replies[0].created_at if replies else None, "reply_count": len(replies), "reactions": reactions, } ) - def get_replies_by_message_id(self, id: str) -> list[MessageModel]: + def get_thread_replies_by_message_id(self, id: str) -> list[MessageReplyToResponse]: with get_db() as db: all_messages = ( db.query(Message) @@ -148,7 +172,19 @@ class MessageTable: .order_by(Message.created_at.desc()) .all() ) - return [MessageModel.model_validate(message) for message in all_messages] + return [ + MessageReplyToResponse.model_validate( + { + **MessageModel.model_validate(message).model_dump(), + "reply_to_message": ( + self.get_message_by_id(message.reply_to_id).model_dump() + if message.reply_to_id + else None + ), + } + ) + for message in all_messages + ] def get_reply_user_ids_by_message_id(self, id: str) -> list[str]: with get_db() as db: @@ -159,7 +195,7 @@ class MessageTable: def get_messages_by_channel_id( self, channel_id: str, skip: int = 0, limit: int = 50 - ) -> list[MessageModel]: + ) -> list[MessageReplyToResponse]: with get_db() as db: all_messages = ( db.query(Message) @@ -169,7 +205,20 @@ class MessageTable: .limit(limit) .all() ) - return [MessageModel.model_validate(message) for message in all_messages] + + return [ + MessageReplyToResponse.model_validate( + { + **MessageModel.model_validate(message).model_dump(), + "reply_to_message": ( + self.get_message_by_id(message.reply_to_id).model_dump() + if message.reply_to_id + else None + ), + } + ) + for message in all_messages + ] def get_messages_by_parent_id( self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50 diff --git a/backend/open_webui/routers/channels.py b/backend/open_webui/routers/channels.py index e7b8366347..eb5f3c29f5 100644 --- a/backend/open_webui/routers/channels.py +++ b/backend/open_webui/routers/channels.py @@ -167,7 +167,7 @@ async def delete_channel_by_id(id: str, user=Depends(get_admin_user)): class MessageUserResponse(MessageResponse): - user: UserNameResponse + pass @router.get("/{id}/messages", response_model=list[MessageUserResponse]) @@ -196,15 +196,17 @@ async def get_channel_messages( user = Users.get_user_by_id(message.user_id) users[message.user_id] = user - replies = Messages.get_replies_by_message_id(message.id) - latest_reply_at = replies[0].created_at if replies else None + thread_replies = Messages.get_thread_replies_by_message_id(message.id) + latest_thread_reply_at = ( + thread_replies[0].created_at if thread_replies else None + ) messages.append( MessageUserResponse( **{ **message.model_dump(), - "reply_count": len(replies), - "latest_reply_at": latest_reply_at, + "reply_count": len(thread_replies), + "latest_reply_at": latest_thread_reply_at, "reactions": Messages.get_reactions_by_message_id(message.id), "user": UserNameResponse(**users[message.user_id].model_dump()), } @@ -253,12 +255,26 @@ async def model_response_handler(request, channel, message, user): mentions = extract_mentions(message.content) message_content = replace_mentions(message.content) + model_mentions = {} + + # check if the message is a reply to a message sent by a model + if ( + message.reply_to_message + and message.reply_to_message.meta + and message.reply_to_message.meta.get("model_id", None) + ): + model_id = message.reply_to_message.meta.get("model_id", None) + model_mentions[model_id] = {"id": model_id, "id_type": "M"} + # check if any of the mentions are models - model_mentions = [mention for mention in mentions if mention["id_type"] == "M"] + for mention in mentions: + if mention["id_type"] == "M" and mention["id"] not in model_mentions: + model_mentions[mention["id"]] = mention + if not model_mentions: return False - for mention in model_mentions: + for mention in model_mentions.values(): model_id = mention["id"] model = MODELS.get(model_id, None) @@ -406,24 +422,14 @@ async def new_message_handler( try: message = Messages.insert_new_message(form_data, channel.id, user.id) - if message: + message = Messages.get_message_by_id(message.id) event_data = { "channel_id": channel.id, "message_id": message.id, "data": { "type": "message", - "data": MessageUserResponse( - **{ - **message.model_dump(), - "reply_count": 0, - "latest_reply_at": None, - "reactions": Messages.get_reactions_by_message_id( - message.id - ), - "user": UserNameResponse(**user.model_dump()), - } - ).model_dump(), + "data": message.model_dump(), }, "user": UserNameResponse(**user.model_dump()).model_dump(), "channel": channel.model_dump(), @@ -447,23 +453,16 @@ async def new_message_handler( "message_id": parent_message.id, "data": { "type": "message:reply", - "data": MessageUserResponse( - **{ - **parent_message.model_dump(), - "user": UserNameResponse( - **Users.get_user_by_id( - parent_message.user_id - ).model_dump() - ), - } - ).model_dump(), + "data": parent_message.model_dump(), }, "user": UserNameResponse(**user.model_dump()).model_dump(), "channel": channel.model_dump(), }, to=f"channel:{channel.id}", ) - return MessageModel(**message.model_dump()), channel + return message, channel + else: + raise Exception("Error creating message") except Exception as e: log.exception(e) raise HTTPException( @@ -651,14 +650,7 @@ async def update_message_by_id( "message_id": message.id, "data": { "type": "message:update", - "data": MessageUserResponse( - **{ - **message.model_dump(), - "user": UserNameResponse( - **user.model_dump() - ).model_dump(), - } - ).model_dump(), + "data": message.model_dump(), }, "user": UserNameResponse(**user.model_dump()).model_dump(), "channel": channel.model_dump(), @@ -724,9 +716,6 @@ async def add_reaction_to_message( "type": "message:reaction:add", "data": { **message.model_dump(), - "user": UserNameResponse( - **Users.get_user_by_id(message.user_id).model_dump() - ).model_dump(), "name": form_data.name, }, }, @@ -793,9 +782,6 @@ async def remove_reaction_by_id_and_user_id_and_name( "type": "message:reaction:remove", "data": { **message.model_dump(), - "user": UserNameResponse( - **Users.get_user_by_id(message.user_id).model_dump() - ).model_dump(), "name": form_data.name, }, }, @@ -882,16 +868,7 @@ async def delete_message_by_id( "message_id": parent_message.id, "data": { "type": "message:reply", - "data": MessageUserResponse( - **{ - **parent_message.model_dump(), - "user": UserNameResponse( - **Users.get_user_by_id( - parent_message.user_id - ).model_dump() - ), - } - ).model_dump(), + "data": parent_message.model_dump(), }, "user": UserNameResponse(**user.model_dump()).model_dump(), "channel": channel.model_dump(), diff --git a/src/lib/apis/channels/index.ts b/src/lib/apis/channels/index.ts index 548572c6fb..ac51e5a5d0 100644 --- a/src/lib/apis/channels/index.ts +++ b/src/lib/apis/channels/index.ts @@ -248,6 +248,7 @@ export const getChannelThreadMessages = async ( }; type MessageForm = { + reply_to_id?: string; parent_id?: string; content: string; data?: object; diff --git a/src/lib/components/channel/Channel.svelte b/src/lib/components/channel/Channel.svelte index 7502066bdb..efe2853e2e 100644 --- a/src/lib/components/channel/Channel.svelte +++ b/src/lib/components/channel/Channel.svelte @@ -20,12 +20,14 @@ let scrollEnd = true; let messagesContainerElement = null; + let chatInputElement = null; let top = false; let channel = null; let messages = null; + let replyToMessage = null; let threadId = null; let typingUsers = []; @@ -141,16 +143,20 @@ return; } - const res = await sendMessage(localStorage.token, id, { content: content, data: data }).catch( - (error) => { - toast.error(`${error}`); - return null; - } - ); + const res = await sendMessage(localStorage.token, id, { + content: content, + data: data, + reply_to_id: replyToMessage?.id ?? null + }).catch((error) => { + toast.error(`${error}`); + return null; + }); if (res) { messagesContainerElement.scrollTop = messagesContainerElement.scrollHeight; } + + replyToMessage = null; }; const onChange = async () => { @@ -222,8 +228,14 @@ {#key id} { + replyToMessage = message; + await tick(); + chatInputElement?.focus(); + }} onThread={(id) => { threadId = id; }} @@ -250,6 +262,8 @@
+ {#if replyToMessage !== null} +
+
+
+
+ {$i18n.t('Replying to {{NAME}}', { + NAME: replyToMessage?.meta?.model_name ?? replyToMessage.user.name + })} +
+
+
+ +
+
+
+ {/if} + {#if files.length > 0}
{#each files as file, fileIdx} @@ -890,6 +921,7 @@ if (e.key === 'Escape') { console.info('Escape'); + replyToMessage = null; } }} on:paste={async (e) => { diff --git a/src/lib/components/channel/Messages.svelte b/src/lib/components/channel/Messages.svelte index 23ca41d19b..0f1c666bb4 100644 --- a/src/lib/components/channel/Messages.svelte +++ b/src/lib/components/channel/Messages.svelte @@ -23,10 +23,12 @@ export let id = null; export let channel = null; export let messages = []; + export let replyToMessage = null; export let top = false; export let thread = false; export let onLoad: Function = () => {}; + export let onReply: Function = () => {}; export let onThread: Function = () => {}; let messagesLoading = false; @@ -94,10 +96,12 @@ { messages = messages.filter((m) => m.id !== message.id); @@ -123,6 +127,9 @@ return null; }); }} + onReply={(message) => { + onReply(message); + }} onThread={(id) => { onThread(id); }} diff --git a/src/lib/components/channel/Messages/Message.svelte b/src/lib/components/channel/Messages/Message.svelte index 73f41e8d4a..d498f0089b 100644 --- a/src/lib/components/channel/Messages/Message.svelte +++ b/src/lib/components/channel/Messages/Message.svelte @@ -13,8 +13,9 @@ import { getContext, onMount } from 'svelte'; const i18n = getContext>('i18n'); - import { settings, user, shortCodesToEmojis } from '$lib/stores'; + import { formatDate } from '$lib/utils'; + import { settings, user, shortCodesToEmojis } from '$lib/stores'; import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants'; import Markdown from '$lib/components/chat/Messages/Markdown.svelte'; @@ -32,18 +33,20 @@ import FaceSmile from '$lib/components/icons/FaceSmile.svelte'; import EmojiPicker from '$lib/components/common/EmojiPicker.svelte'; import ChevronRight from '$lib/components/icons/ChevronRight.svelte'; - import { formatDate } from '$lib/utils'; import Emoji from '$lib/components/common/Emoji.svelte'; - import { t } from 'i18next'; import Skeleton from '$lib/components/chat/Messages/Skeleton.svelte'; + import ArrowUpLeftAlt from '$lib/components/icons/ArrowUpLeftAlt.svelte'; export let message; export let showUserProfile = true; export let thread = false; + + export let replyToMessage = false; export let disabled = false; export let onDelete: Function = () => {}; export let onEdit: Function = () => {}; + export let onReply: Function = () => {}; export let onThread: Function = () => {}; export let onReaction: Function = () => {}; @@ -65,9 +68,15 @@ {#if message}
{#if !edit && !disabled}
+ + + + {#if !thread} +
+ {/if}
{/if} @@ -348,3 +418,18 @@
{/if} + + diff --git a/src/lib/components/channel/Thread.svelte b/src/lib/components/channel/Thread.svelte index 2cf73e3311..4b56af62b0 100644 --- a/src/lib/components/channel/Thread.svelte +++ b/src/lib/components/channel/Thread.svelte @@ -22,11 +22,14 @@ let messages = null; let top = false; + let messagesContainerElement = null; + let chatInputElement = null; + + let replyToMessage = null; + let typingUsers = []; let typingUsersTimeout = {}; - let messagesContainerElement = null; - $: if (threadId) { initHandler(); } @@ -128,12 +131,15 @@ const res = await sendMessage(localStorage.token, channel.id, { parent_id: threadId, + reply_to_id: replyToMessage?.id ?? null, content: content, data: data }).catch((error) => { toast.error(`${error}`); return null; }); + + replyToMessage = null; }; const onChange = async () => { @@ -180,9 +186,16 @@ { + replyToMessage = message; + + await tick(); + chatInputElement?.focus(); + }} onLoad={async () => { const newMessages = await getChannelThreadMessages( localStorage.token, @@ -207,6 +220,8 @@
+ export let className = 'w-4 h-4'; + export let strokeWidth = '1.5'; + + +