enh: reply to message

This commit is contained in:
Timothy Jaeryang Baek 2025-09-27 04:05:12 -05:00
parent d7c54d92b5
commit 1a18928c94
10 changed files with 318 additions and 84 deletions

View File

@ -0,0 +1,34 @@
"""Add reply_to_id column to message
Revision ID: a5c220713937
Revises: 38d63c18f30f
Create Date: 2025-09-27 02:24:18.058455
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "a5c220713937"
down_revision: Union[str, None] = "38d63c18f30f"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add 'reply_to_id' column to the 'message' table for replying to messages
op.add_column(
"message",
sa.Column("reply_to_id", sa.Text(), nullable=True),
)
pass
def downgrade() -> None:
# Remove 'reply_to_id' column from the 'message' table
op.drop_column("message", "reply_to_id")
pass

View File

@ -5,6 +5,7 @@ from typing import Optional
from open_webui.internal.db import Base, get_db
from open_webui.models.tags import TagModel, Tag, Tags
from open_webui.models.users import Users, UserNameResponse
from pydantic import BaseModel, ConfigDict
@ -43,6 +44,7 @@ class Message(Base):
user_id = Column(Text)
channel_id = Column(Text, nullable=True)
reply_to_id = Column(Text, nullable=True)
parent_id = Column(Text, nullable=True)
content = Column(Text)
@ -60,6 +62,7 @@ class MessageModel(BaseModel):
user_id: str
channel_id: Optional[str] = None
reply_to_id: Optional[str] = None
parent_id: Optional[str] = None
content: str
@ -77,6 +80,7 @@ class MessageModel(BaseModel):
class MessageForm(BaseModel):
content: str
reply_to_id: Optional[str] = None
parent_id: Optional[str] = None
data: Optional[dict] = None
meta: Optional[dict] = None
@ -88,7 +92,15 @@ class Reactions(BaseModel):
count: int
class MessageResponse(MessageModel):
class MessageUserResponse(MessageModel):
user: Optional[UserNameResponse] = None
class MessageReplyToResponse(MessageUserResponse):
reply_to_message: Optional[MessageUserResponse] = None
class MessageResponse(MessageReplyToResponse):
latest_reply_at: Optional[int]
reply_count: int
reactions: list[Reactions]
@ -107,6 +119,7 @@ class MessageTable:
"id": id,
"user_id": user_id,
"channel_id": channel_id,
"reply_to_id": form_data.reply_to_id,
"parent_id": form_data.parent_id,
"content": form_data.content,
"data": form_data.data,
@ -122,25 +135,36 @@ class MessageTable:
db.refresh(result)
return MessageModel.model_validate(result) if result else None
def get_message_by_id(self, id: str) -> Optional[MessageResponse]:
def get_message_by_id(self, id: str) -> Optional[MessageReplyToResponse]:
with get_db() as db:
message = db.get(Message, id)
if not message:
return None
reply_to_message = (
self.get_message_by_id(message.reply_to_id)
if message.reply_to_id
else None
)
reactions = self.get_reactions_by_message_id(id)
replies = self.get_replies_by_message_id(id)
replies = self.get_thread_replies_by_message_id(id)
return MessageResponse(
**{
user = Users.get_user_by_id(message.user_id)
return MessageReplyToResponse.model_validate(
{
**MessageModel.model_validate(message).model_dump(),
"user": user.model_dump() if user else None,
"reply_to_message": (
reply_to_message.model_dump() if reply_to_message else None
),
"latest_reply_at": replies[0].created_at if replies else None,
"reply_count": len(replies),
"reactions": reactions,
}
)
def get_replies_by_message_id(self, id: str) -> list[MessageModel]:
def get_thread_replies_by_message_id(self, id: str) -> list[MessageReplyToResponse]:
with get_db() as db:
all_messages = (
db.query(Message)
@ -148,7 +172,19 @@ class MessageTable:
.order_by(Message.created_at.desc())
.all()
)
return [MessageModel.model_validate(message) for message in all_messages]
return [
MessageReplyToResponse.model_validate(
{
**MessageModel.model_validate(message).model_dump(),
"reply_to_message": (
self.get_message_by_id(message.reply_to_id).model_dump()
if message.reply_to_id
else None
),
}
)
for message in all_messages
]
def get_reply_user_ids_by_message_id(self, id: str) -> list[str]:
with get_db() as db:
@ -159,7 +195,7 @@ class MessageTable:
def get_messages_by_channel_id(
self, channel_id: str, skip: int = 0, limit: int = 50
) -> list[MessageModel]:
) -> list[MessageReplyToResponse]:
with get_db() as db:
all_messages = (
db.query(Message)
@ -169,7 +205,20 @@ class MessageTable:
.limit(limit)
.all()
)
return [MessageModel.model_validate(message) for message in all_messages]
return [
MessageReplyToResponse.model_validate(
{
**MessageModel.model_validate(message).model_dump(),
"reply_to_message": (
self.get_message_by_id(message.reply_to_id).model_dump()
if message.reply_to_id
else None
),
}
)
for message in all_messages
]
def get_messages_by_parent_id(
self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50

View File

@ -167,7 +167,7 @@ async def delete_channel_by_id(id: str, user=Depends(get_admin_user)):
class MessageUserResponse(MessageResponse):
user: UserNameResponse
pass
@router.get("/{id}/messages", response_model=list[MessageUserResponse])
@ -196,15 +196,17 @@ async def get_channel_messages(
user = Users.get_user_by_id(message.user_id)
users[message.user_id] = user
replies = Messages.get_replies_by_message_id(message.id)
latest_reply_at = replies[0].created_at if replies else None
thread_replies = Messages.get_thread_replies_by_message_id(message.id)
latest_thread_reply_at = (
thread_replies[0].created_at if thread_replies else None
)
messages.append(
MessageUserResponse(
**{
**message.model_dump(),
"reply_count": len(replies),
"latest_reply_at": latest_reply_at,
"reply_count": len(thread_replies),
"latest_reply_at": latest_thread_reply_at,
"reactions": Messages.get_reactions_by_message_id(message.id),
"user": UserNameResponse(**users[message.user_id].model_dump()),
}
@ -253,12 +255,26 @@ async def model_response_handler(request, channel, message, user):
mentions = extract_mentions(message.content)
message_content = replace_mentions(message.content)
model_mentions = {}
# check if the message is a reply to a message sent by a model
if (
message.reply_to_message
and message.reply_to_message.meta
and message.reply_to_message.meta.get("model_id", None)
):
model_id = message.reply_to_message.meta.get("model_id", None)
model_mentions[model_id] = {"id": model_id, "id_type": "M"}
# check if any of the mentions are models
model_mentions = [mention for mention in mentions if mention["id_type"] == "M"]
for mention in mentions:
if mention["id_type"] == "M" and mention["id"] not in model_mentions:
model_mentions[mention["id"]] = mention
if not model_mentions:
return False
for mention in model_mentions:
for mention in model_mentions.values():
model_id = mention["id"]
model = MODELS.get(model_id, None)
@ -406,24 +422,14 @@ async def new_message_handler(
try:
message = Messages.insert_new_message(form_data, channel.id, user.id)
if message:
message = Messages.get_message_by_id(message.id)
event_data = {
"channel_id": channel.id,
"message_id": message.id,
"data": {
"type": "message",
"data": MessageUserResponse(
**{
**message.model_dump(),
"reply_count": 0,
"latest_reply_at": None,
"reactions": Messages.get_reactions_by_message_id(
message.id
),
"user": UserNameResponse(**user.model_dump()),
}
).model_dump(),
"data": message.model_dump(),
},
"user": UserNameResponse(**user.model_dump()).model_dump(),
"channel": channel.model_dump(),
@ -447,23 +453,16 @@ async def new_message_handler(
"message_id": parent_message.id,
"data": {
"type": "message:reply",
"data": MessageUserResponse(
**{
**parent_message.model_dump(),
"user": UserNameResponse(
**Users.get_user_by_id(
parent_message.user_id
).model_dump()
),
}
).model_dump(),
"data": parent_message.model_dump(),
},
"user": UserNameResponse(**user.model_dump()).model_dump(),
"channel": channel.model_dump(),
},
to=f"channel:{channel.id}",
)
return MessageModel(**message.model_dump()), channel
return message, channel
else:
raise Exception("Error creating message")
except Exception as e:
log.exception(e)
raise HTTPException(
@ -651,14 +650,7 @@ async def update_message_by_id(
"message_id": message.id,
"data": {
"type": "message:update",
"data": MessageUserResponse(
**{
**message.model_dump(),
"user": UserNameResponse(
**user.model_dump()
).model_dump(),
}
).model_dump(),
"data": message.model_dump(),
},
"user": UserNameResponse(**user.model_dump()).model_dump(),
"channel": channel.model_dump(),
@ -724,9 +716,6 @@ async def add_reaction_to_message(
"type": "message:reaction:add",
"data": {
**message.model_dump(),
"user": UserNameResponse(
**Users.get_user_by_id(message.user_id).model_dump()
).model_dump(),
"name": form_data.name,
},
},
@ -793,9 +782,6 @@ async def remove_reaction_by_id_and_user_id_and_name(
"type": "message:reaction:remove",
"data": {
**message.model_dump(),
"user": UserNameResponse(
**Users.get_user_by_id(message.user_id).model_dump()
).model_dump(),
"name": form_data.name,
},
},
@ -882,16 +868,7 @@ async def delete_message_by_id(
"message_id": parent_message.id,
"data": {
"type": "message:reply",
"data": MessageUserResponse(
**{
**parent_message.model_dump(),
"user": UserNameResponse(
**Users.get_user_by_id(
parent_message.user_id
).model_dump()
),
}
).model_dump(),
"data": parent_message.model_dump(),
},
"user": UserNameResponse(**user.model_dump()).model_dump(),
"channel": channel.model_dump(),

View File

@ -248,6 +248,7 @@ export const getChannelThreadMessages = async (
};
type MessageForm = {
reply_to_id?: string;
parent_id?: string;
content: string;
data?: object;

View File

@ -20,12 +20,14 @@
let scrollEnd = true;
let messagesContainerElement = null;
let chatInputElement = null;
let top = false;
let channel = null;
let messages = null;
let replyToMessage = null;
let threadId = null;
let typingUsers = [];
@ -141,16 +143,20 @@
return;
}
const res = await sendMessage(localStorage.token, id, { content: content, data: data }).catch(
(error) => {
toast.error(`${error}`);
return null;
}
);
const res = await sendMessage(localStorage.token, id, {
content: content,
data: data,
reply_to_id: replyToMessage?.id ?? null
}).catch((error) => {
toast.error(`${error}`);
return null;
});
if (res) {
messagesContainerElement.scrollTop = messagesContainerElement.scrollHeight;
}
replyToMessage = null;
};
const onChange = async () => {
@ -222,8 +228,14 @@
{#key id}
<Messages
{channel}
{messages}
{top}
{messages}
{replyToMessage}
onReply={async (message) => {
replyToMessage = message;
await tick();
chatInputElement?.focus();
}}
onThread={(id) => {
threadId = id;
}}
@ -250,6 +262,8 @@
<div class=" pb-[1rem] px-2.5">
<MessageInput
id="root"
bind:chatInputElement
bind:replyToMessage
{typingUsers}
userSuggestions={true}
channelSuggestions={true}

View File

@ -23,20 +23,23 @@
import { getSessionUser } from '$lib/apis/auths';
import { uploadFile } from '$lib/apis/files';
import { WEBUI_API_BASE_URL } from '$lib/constants';
import { getSuggestionRenderer } from '../common/RichTextInput/suggestions';
import CommandSuggestionList from '../chat/MessageInput/CommandSuggestionList.svelte';
import InputMenu from './MessageInput/InputMenu.svelte';
import Tooltip from '../common/Tooltip.svelte';
import RichTextInput from '../common/RichTextInput.svelte';
import VoiceRecording from '../chat/MessageInput/VoiceRecording.svelte';
import InputMenu from './MessageInput/InputMenu.svelte';
import { uploadFile } from '$lib/apis/files';
import { WEBUI_API_BASE_URL } from '$lib/constants';
import FileItem from '../common/FileItem.svelte';
import Image from '../common/Image.svelte';
import FilesOverlay from '../chat/MessageInput/FilesOverlay.svelte';
import InputVariablesModal from '../chat/MessageInput/InputVariablesModal.svelte';
import { getSuggestionRenderer } from '../common/RichTextInput/suggestions';
import CommandSuggestionList from '../chat/MessageInput/CommandSuggestionList.svelte';
import MentionList from './MessageInput/MentionList.svelte';
import Skeleton from '../chat/Messages/Skeleton.svelte';
import XMark from '../icons/XMark.svelte';
export let placeholder = $i18n.t('Type here...');
@ -60,6 +63,8 @@
export let userSuggestions = false;
export let channelSuggestions = false;
export let replyToMessage = null;
export let typingUsersClassName = 'from-white dark:from-gray-900';
let loaded = false;
@ -773,6 +778,32 @@
class="flex-1 flex flex-col relative w-full shadow-lg rounded-3xl border border-gray-50 dark:border-gray-850 hover:border-gray-100 focus-within:border-gray-100 hover:dark:border-gray-800 focus-within:dark:border-gray-800 transition px-1 bg-white/90 dark:bg-gray-400/5 dark:text-gray-100"
dir={$settings?.chatDirection ?? 'auto'}
>
{#if replyToMessage !== null}
<div class="px-3 pt-3 text-left w-full flex flex-col z-10">
<div class="flex items-center justify-between w-full">
<div class="pl-[1px] flex items-center gap-2 text-sm">
<div class="translate-y-[0.5px]">
<span class=""
>{$i18n.t('Replying to {{NAME}}', {
NAME: replyToMessage?.meta?.model_name ?? replyToMessage.user.name
})}</span
>
</div>
</div>
<div>
<button
class="flex items-center dark:text-gray-500"
on:click={() => {
replyToMessage = null;
}}
>
<XMark />
</button>
</div>
</div>
</div>
{/if}
{#if files.length > 0}
<div class="mx-2 mt-2.5 -mb-1 flex flex-wrap gap-2">
{#each files as file, fileIdx}
@ -890,6 +921,7 @@
if (e.key === 'Escape') {
console.info('Escape');
replyToMessage = null;
}
}}
on:paste={async (e) => {

View File

@ -23,10 +23,12 @@
export let id = null;
export let channel = null;
export let messages = [];
export let replyToMessage = null;
export let top = false;
export let thread = false;
export let onLoad: Function = () => {};
export let onReply: Function = () => {};
export let onThread: Function = () => {};
let messagesLoading = false;
@ -94,10 +96,12 @@
<Message
{message}
{thread}
replyToMessage={replyToMessage?.id === message.id}
disabled={!channel?.write_access}
showUserProfile={messageIdx === 0 ||
messageList.at(messageIdx - 1)?.user_id !== message.user_id ||
messageList.at(messageIdx - 1)?.meta?.model_id !== message?.meta?.model_id}
messageList.at(messageIdx - 1)?.meta?.model_id !== message?.meta?.model_id ||
message?.reply_to_message}
onDelete={() => {
messages = messages.filter((m) => m.id !== message.id);
@ -123,6 +127,9 @@
return null;
});
}}
onReply={(message) => {
onReply(message);
}}
onThread={(id) => {
onThread(id);
}}

View File

@ -13,8 +13,9 @@
import { getContext, onMount } from 'svelte';
const i18n = getContext<Writable<i18nType>>('i18n');
import { settings, user, shortCodesToEmojis } from '$lib/stores';
import { formatDate } from '$lib/utils';
import { settings, user, shortCodesToEmojis } from '$lib/stores';
import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants';
import Markdown from '$lib/components/chat/Messages/Markdown.svelte';
@ -32,18 +33,20 @@
import FaceSmile from '$lib/components/icons/FaceSmile.svelte';
import EmojiPicker from '$lib/components/common/EmojiPicker.svelte';
import ChevronRight from '$lib/components/icons/ChevronRight.svelte';
import { formatDate } from '$lib/utils';
import Emoji from '$lib/components/common/Emoji.svelte';
import { t } from 'i18next';
import Skeleton from '$lib/components/chat/Messages/Skeleton.svelte';
import ArrowUpLeftAlt from '$lib/components/icons/ArrowUpLeftAlt.svelte';
export let message;
export let showUserProfile = true;
export let thread = false;
export let replyToMessage = false;
export let disabled = false;
export let onDelete: Function = () => {};
export let onEdit: Function = () => {};
export let onReply: Function = () => {};
export let onThread: Function = () => {};
export let onReaction: Function = () => {};
@ -65,9 +68,15 @@
{#if message}
<div
id="message-{message.id}"
class="flex flex-col justify-between px-5 {showUserProfile
? 'pt-1.5 pb-0.5'
: ''} w-full max-w-full mx-auto group hover:bg-gray-300/5 dark:hover:bg-gray-700/5 transition relative"
: ''} w-full max-w-full mx-auto group hover:bg-gray-300/5 dark:hover:bg-gray-700/5 transition relative {replyToMessage
? 'border-l-4 border-blue-500 bg-blue-100/10 dark:bg-blue-100/5 pl-4'
: ''} {(message?.reply_to_message?.meta?.model_id ?? message?.reply_to_message?.user_id) ===
$user?.id
? 'border-l-4 border-orange-500 bg-orange-100/10 dark:bg-orange-100/5 pl-4'
: ''}"
>
{#if !edit && !disabled}
<div
@ -95,6 +104,17 @@
</Tooltip>
</EmojiPicker>
<Tooltip content={$i18n.t('Reply')}>
<button
class="hover:bg-gray-100 dark:hover:bg-gray-800 transition rounded-lg p-0.5"
on:click={() => {
onReply(message);
}}
>
<ArrowUpLeftAlt className="size-5" />
</button>
</Tooltip>
{#if !thread}
<Tooltip content={$i18n.t('Reply in Thread')}>
<button
@ -134,6 +154,56 @@
</div>
{/if}
{#if message?.reply_to_message?.user}
<div class="relative text-xs mb-1">
<div
class="absolute h-3 w-7 left-[18px] top-2 rounded-tl-lg border-t-2 border-l-2 border-gray-300 dark:border-gray-500 z-0"
></div>
<button
class="ml-12 flex items-center space-x-2 relative z-0"
on:click={() => {
const messageElement = document.getElementById(
`message-${message.reply_to_message.id}`
);
if (messageElement) {
messageElement.scrollIntoView({ behavior: 'smooth', block: 'center' });
messageElement.classList.add('highlight');
setTimeout(() => {
messageElement.classList.remove('highlight');
}, 2000);
return;
}
}}
>
{#if message?.reply_to_message?.meta?.model_id}
<img
src={`${WEBUI_API_BASE_URL}/models/model/profile/image?id=${message.reply_to_message.meta.model_id}`}
alt={message.reply_to_message.meta.model_name ??
message.reply_to_message.meta.model_id}
class="size-4 ml-0.5 rounded-full object-cover"
/>
{:else}
<img
src={message.reply_to_message.user?.profile_image_url ??
`${WEBUI_BASE_URL}/static/favicon.png`}
alt={message.reply_to_message.user?.name ?? $i18n.t('Unknown User')}
class="size-4 ml-0.5 rounded-full object-cover"
/>
{/if}
<div class="shrink-0">
{message?.reply_to_message.meta?.model_name ??
message?.reply_to_message.user?.name ??
$i18n.t('Unknown User')}
</div>
<div class="italic text-sm text-gray-500 dark:text-gray-400 line-clamp-1 w-full flex-1">
<Markdown id={`${message.id}-reply-to`} content={message?.reply_to_message?.content} />
</div>
</button>
</div>
{/if}
<div
class=" flex w-full message-{message.id}"
id="message-{message.id}"
@ -151,7 +221,7 @@
<ProfilePreview user={message.user}>
<ProfileImage
src={message.user?.profile_image_url ?? `${WEBUI_BASE_URL}/static/favicon.png`}
className={'size-8 translate-y-1 ml-0.5'}
className={'size-8 ml-0.5'}
/>
</ProfilePreview>
{/if}
@ -348,3 +418,18 @@
</div>
</div>
{/if}
<style>
.highlight {
animation: highlightAnimation 2s ease-in-out;
}
@keyframes highlightAnimation {
0% {
background-color: rgba(0, 60, 255, 0.1);
}
100% {
background-color: transparent;
}
}
</style>

View File

@ -22,11 +22,14 @@
let messages = null;
let top = false;
let messagesContainerElement = null;
let chatInputElement = null;
let replyToMessage = null;
let typingUsers = [];
let typingUsersTimeout = {};
let messagesContainerElement = null;
$: if (threadId) {
initHandler();
}
@ -128,12 +131,15 @@
const res = await sendMessage(localStorage.token, channel.id, {
parent_id: threadId,
reply_to_id: replyToMessage?.id ?? null,
content: content,
data: data
}).catch((error) => {
toast.error(`${error}`);
return null;
});
replyToMessage = null;
};
const onChange = async () => {
@ -180,9 +186,16 @@
<Messages
id={threadId}
{channel}
{messages}
{top}
{messages}
{replyToMessage}
thread={true}
onReply={async (message) => {
replyToMessage = message;
await tick();
chatInputElement?.focus();
}}
onLoad={async () => {
const newMessages = await getChannelThreadMessages(
localStorage.token,
@ -207,6 +220,8 @@
<div class=" pb-[1rem] px-2.5 w-full">
<MessageInput
bind:replyToMessage
bind:chatInputElement
id={threadId}
disabled={!channel?.write_access}
placeholder={!channel?.write_access

View File

@ -0,0 +1,20 @@
<script lang="ts">
export let className = 'w-4 h-4';
export let strokeWidth = '1.5';
</script>
<svg
class={className}
aria-hidden="true"
xmlns="http://www.w3.org/2000/svg"
stroke-width={strokeWidth}
fill="none"
stroke="currentColor"
viewBox="0 0 24 24"
><path d="M10.25 4.75L6.75 8.25L10.25 11.75" stroke-linecap="round" stroke-linejoin="round"
></path><path
d="M6.75 8.25L12.75 8.25C14.9591 8.25 16.75 10.0409 16.75 12.25V19.25"
stroke-linecap="round"
stroke-linejoin="round"
></path></svg
>