open-webui/backend/open_webui/socket/main.py

652 lines
19 KiB
Python
Raw Normal View History

2024-06-05 02:13:43 +08:00
import asyncio
import random
2024-08-28 06:10:27 +08:00
import socketio
2024-09-22 08:14:59 +08:00
import logging
import sys
2024-09-22 08:12:55 +08:00
import time
2025-07-12 03:59:48 +08:00
from typing import Dict, Set
from redis import asyncio as aioredis
2025-07-12 03:59:48 +08:00
import pycrdt as Y
2024-09-22 08:12:55 +08:00
2024-12-27 13:51:09 +08:00
from open_webui.models.users import Users, UserNameResponse
2024-12-23 10:40:01 +08:00
from open_webui.models.channels import Channels
2024-12-25 09:03:14 +08:00
from open_webui.models.chats import Chats
2025-07-12 03:59:48 +08:00
from open_webui.models.notes import Notes, NoteUpdateForm
2025-03-29 02:47:14 +08:00
from open_webui.utils.redis import (
get_sentinels_from_env,
get_sentinel_url_from_env,
2025-03-29 02:47:14 +08:00
)
2024-12-25 09:03:14 +08:00
2024-09-21 05:43:22 +08:00
from open_webui.env import (
ENABLE_WEBSOCKET_SUPPORT,
WEBSOCKET_MANAGER,
WEBSOCKET_REDIS_URL,
2025-03-01 08:02:15 +08:00
WEBSOCKET_REDIS_LOCK_TIMEOUT,
WEBSOCKET_SENTINEL_PORT,
WEBSOCKET_SENTINEL_HOSTS,
2024-09-21 05:43:22 +08:00
)
2024-12-09 08:01:56 +08:00
from open_webui.utils.auth import decode_token
from open_webui.socket.utils import RedisDict, RedisLock
2025-07-12 03:59:48 +08:00
from open_webui.tasks import create_task, stop_item_tasks
from open_webui.utils.redis import get_redis_connection
from open_webui.utils.access_control import has_access, get_users_with_access
2024-06-04 14:39:52 +08:00
2024-09-22 08:14:59 +08:00
from open_webui.env import (
GLOBAL_LOG_LEVEL,
SRC_LOG_LEVELS,
)
# Send log output to stdout at the global level, then set this module's
# logger to the level configured for the "SOCKET" source.
logging.basicConfig(level=GLOBAL_LOG_LEVEL, stream=sys.stdout)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["SOCKET"])
2024-09-21 05:43:22 +08:00
2025-07-12 03:59:48 +08:00
# Async Redis connection handed to create_task / stop_item_tasks by the
# Yjs document update handler further down in this module.
REDIS = get_redis_connection(
    async_mode=True,
    redis_url=WEBSOCKET_REDIS_URL,
    redis_sentinels=get_sentinels_from_env(
        WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
    ),
)
2024-09-21 05:43:22 +08:00
# Build the Socket.IO server. With the "redis" manager an AsyncRedisManager is
# attached as client manager; otherwise the server uses its default manager.
if WEBSOCKET_MANAGER == "redis":
    # When sentinel hosts are configured the manager URL is derived from them,
    # otherwise the plain Redis URL is used as-is.
    mgr = socketio.AsyncRedisManager(
        get_sentinel_url_from_env(
            WEBSOCKET_REDIS_URL, WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
        )
        if WEBSOCKET_SENTINEL_HOSTS
        else WEBSOCKET_REDIS_URL
    )
    sio = socketio.AsyncServer(
        client_manager=mgr,
        always_connect=True,
        allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
        transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]),
        async_mode="asgi",
        cors_allowed_origins=[],
    )
else:
    sio = socketio.AsyncServer(
        always_connect=True,
        allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
        transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]),
        async_mode="asgi",
        cors_allowed_origins=[],
    )
# Seconds a usage entry may go un-refreshed before the cleanup loop drops it
TIMEOUT_DURATION = 3

# Shared state pools (as populated by the handlers in this module):
#   SESSION_POOL: sid -> dumped user model
#   USER_POOL:    user id -> list of sids
#   USAGE_POOL:   model id -> {sid: {"updated_at": <epoch seconds>}}
if WEBSOCKET_MANAGER != "redis":
    SESSION_POOL = {}
    USER_POOL = {}
    USAGE_POOL = {}

    DOCUMENTS = {}  # document_id -> {"ydoc": Y.Doc, "users": set of sids}
    DOCUMENT_USERS = {}  # document_id -> set of user sids

    # Without Redis there is no lock object: every lock operation reports success.
    aquire_func = release_func = renew_func = lambda: True
else:
    log.debug("Using Redis to manage websockets.")

    redis_sentinels = get_sentinels_from_env(
        WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
    )

    SESSION_POOL = RedisDict(
        "open-webui:session_pool",
        redis_url=WEBSOCKET_REDIS_URL,
        redis_sentinels=redis_sentinels,
    )
    USER_POOL = RedisDict(
        "open-webui:user_pool",
        redis_url=WEBSOCKET_REDIS_URL,
        redis_sentinels=redis_sentinels,
    )
    USAGE_POOL = RedisDict(
        "open-webui:usage_pool",
        redis_url=WEBSOCKET_REDIS_URL,
        redis_sentinels=redis_sentinels,
    )

    # TODO: Implement Yjs document management with Redis
    # Until then, documents are kept in plain in-process dicts.
    DOCUMENTS = {}  # document_id -> {"ydoc": Y.Doc, "users": set of sids}
    DOCUMENT_USERS = {}  # document_id -> set of user sids

    # Lock guarding the periodic usage-pool cleanup loop.
    clean_up_lock = RedisLock(
        redis_url=WEBSOCKET_REDIS_URL,
        lock_name="usage_cleanup_lock",
        timeout_secs=WEBSOCKET_REDIS_LOCK_TIMEOUT,
        redis_sentinels=redis_sentinels,
    )
    aquire_func = clean_up_lock.aquire_lock
    renew_func = clean_up_lock.renew_lock
    release_func = clean_up_lock.release_lock
2024-06-04 14:39:52 +08:00
2024-09-22 08:12:55 +08:00
async def periodic_usage_pool_cleanup():
    """Periodically drop timed-out entries from USAGE_POOL.

    The coroutine first tries to take the cleanup lock (up to three attempts,
    sleeping a randomized delay between them). If the lock cannot be taken it
    logs a warning and returns without cleaning. Otherwise it loops forever:
    renew the lock, remove every sid whose "updated_at" is older than
    TIMEOUT_DURATION seconds, remove models left without sids, then sleep
    TIMEOUT_DURATION seconds.

    Raises:
        Exception: if the lock cannot be renewed. The lock release function is
            always called on the way out of the loop.
    """
    max_retries = 2
    retry_delay = random.uniform(
        WEBSOCKET_REDIS_LOCK_TIMEOUT / 2, WEBSOCKET_REDIS_LOCK_TIMEOUT
    )

    for attempt in range(max_retries + 1):
        if aquire_func():
            break

        if attempt >= max_retries:
            log.warning(
                "Failed to acquire cleanup lock after retries. Skipping cleanup."
            )
            return

        log.debug(
            f"Cleanup lock already exists. Retry {attempt + 1} after {retry_delay}s..."
        )
        await asyncio.sleep(retry_delay)

    log.debug("Running periodic_cleanup")
    try:
        while True:
            if not renew_func():
                log.error("Unable to renew cleanup lock. Exiting usage pool cleanup.")
                raise Exception("Unable to renew usage pool cleanup lock.")

            now = int(time.time())
            # Iterate over a snapshot because entries are deleted while looping.
            for model_id, connections in list(USAGE_POOL.items()):
                # Creating a list of sids to remove if they have timed out
                expired_sids = [
                    sid
                    for sid, details in connections.items()
                    if now - details["updated_at"] > TIMEOUT_DURATION
                ]

                for sid in expired_sids:
                    del connections[sid]

                if not connections:
                    log.debug(f"Cleaning up model {model_id} from usage pool")
                    del USAGE_POOL[model_id]
                else:
                    # Write the pruned mapping back so the pool reflects it.
                    USAGE_POOL[model_id] = connections

            await asyncio.sleep(TIMEOUT_DURATION)
    finally:
        release_func()
2024-09-22 08:12:55 +08:00
2024-09-24 21:31:55 +08:00
# ASGI application wrapping the Socket.IO server, served under /ws/socket.io
app = socketio.ASGIApp(sio, socketio_path="/ws/socket.io")
2024-09-22 08:12:55 +08:00
def get_models_in_use():
    """Return the model ids that currently have an entry in USAGE_POOL."""
    return list(USAGE_POOL.keys())
2025-06-16 14:42:34 +08:00
def get_active_user_ids():
    """Return the user ids that currently have an entry in USER_POOL."""
    active_user_ids = list(USER_POOL.keys())
    return active_user_ids
def get_user_active_status(user_id):
    """Return True when ``user_id`` has an entry in USER_POOL, else False."""
    is_active = user_id in USER_POOL
    return is_active
def get_user_id_from_session_pool(sid):
    """Return the "id" of the user stored for ``sid`` in SESSION_POOL, or None."""
    session_user = SESSION_POOL.get(sid)
    return session_user["id"] if session_user else None
def get_user_ids_from_room(room):
    """Return the distinct user ids of the sessions present in ``room``.

    Participants are read from the Socket.IO manager for namespace "/" and
    mapped to users through SESSION_POOL. Participants without a SESSION_POOL
    entry are skipped instead of raising.

    Args:
        room: Name of the Socket.IO room to inspect.

    Returns:
        A list of unique user ids, in no particular order.
    """
    participants = sio.manager.get_participants(
        namespace="/",
        room=room,
    )

    user_ids = set()
    for participant in participants:
        # The first element of each participant entry is the session id.
        session_user = SESSION_POOL.get(participant[0])
        if session_user:
            user_ids.add(session_user["id"])

    return list(user_ids)
def get_active_status_by_user_id(user_id):
    """Return True when ``user_id`` has an entry in USER_POOL, else False."""
    return user_id in USER_POOL
2024-09-22 08:12:55 +08:00
@sio.on("usage")
async def usage(sid, data):
    """Record that session ``sid`` is currently using the model in ``data["model"]``.

    Sessions without a SESSION_POOL entry are ignored.
    """
    if sid not in SESSION_POOL:
        return

    model_id = data["model"]
    # Record the timestamp for the last update
    current_time = int(time.time())

    # Merge this sid into whatever is already stored for the model and write
    # the whole mapping back.
    existing = USAGE_POOL[model_id] if model_id in USAGE_POOL else {}
    USAGE_POOL[model_id] = {**existing, sid: {"updated_at": current_time}}
2024-09-22 08:12:55 +08:00
2024-06-04 14:39:52 +08:00
@sio.event
async def connect(sid, environ, auth):
    """Register the connecting session when it presents a token for a known user.

    On success the dumped user model is stored under ``sid`` in SESSION_POOL
    and ``sid`` is appended to the user's list in USER_POOL. Connections
    without a usable token are left unregistered.
    """
    if not auth or "token" not in auth:
        return

    payload = decode_token(auth["token"])
    if payload is None or "id" not in payload:
        return

    user = Users.get_user_by_id(payload["id"])
    if not user:
        return

    SESSION_POOL[sid] = user.model_dump()

    # Reassign instead of mutating in place so the pool stores the new list.
    known_sids = USER_POOL[user.id] if user.id in USER_POOL else []
    USER_POOL[user.id] = known_sids + [sid]
2024-06-04 16:10:31 +08:00
2024-06-05 00:52:27 +08:00
@sio.on("user-join")
async def user_join(sid, data):
    """Register the session for the token's user and join the user's channel rooms.

    Args:
        sid: Socket.IO session id of the caller.
        data: Event payload; expected to carry ``{"auth": {"token": ...}}``.

    Returns:
        ``{"id": ..., "name": ...}`` for the resolved user, or None when the
        payload has no token, the token does not decode to an id, or the user
        does not exist.
    """
    auth = data["auth"] if "auth" in data else None
    if not auth or "token" not in auth:
        return

    data = decode_token(auth["token"])
    if data is None or "id" not in data:
        return

    user = Users.get_user_by_id(data["id"])
    if not user:
        return

    SESSION_POOL[sid] = user.model_dump()

    # The connect handler may already have registered this sid for the user;
    # only append it when it is not listed yet so the pool holds no duplicates.
    known_sids = USER_POOL[user.id] if user.id in USER_POOL else []
    if sid not in known_sids:
        USER_POOL[user.id] = known_sids + [sid]

    # Join all the channels
    channels = Channels.get_channels_by_user_id(user.id)
    log.debug(f"{channels=}")
    for channel in channels:
        await sio.enter_room(sid, f"channel:{channel.id}")

    return {"id": user.id, "name": user.name}
2024-06-05 00:52:27 +08:00
2024-12-24 05:00:58 +08:00
@sio.on("join-channels")
async def join_channel(sid, data):
    """Put session ``sid`` into the room of every channel of the token's user.

    Does nothing when the payload has no token, the token does not decode to
    an id, or the user cannot be found.
    """
    auth = data["auth"] if "auth" in data else None
    has_token = bool(auth) and "token" in auth
    if not has_token:
        return

    payload = decode_token(auth["token"])
    if payload is None or "id" not in payload:
        return

    user = Users.get_user_by_id(payload["id"])
    if not user:
        return

    # Join all the channels
    channels = Channels.get_channels_by_user_id(user.id)
    log.debug(f"{channels=}")
    for joined_channel in channels:
        room_name = f"channel:{joined_channel.id}"
        await sio.enter_room(sid, room_name)
2024-12-27 13:51:09 +08:00
@sio.on("channel-events")
async def channel_events(sid, data):
    """Relay "typing" events to the channel room the sender is a member of.

    Events from sessions that are not in the channel's room are dropped, and
    event types other than "typing" are ignored.
    """
    room = f"channel:{data['channel_id']}"
    participants = sio.manager.get_participants(
        namespace="/",
        room=room,
    )

    # Only sessions that are actually in the room may post to it.
    member_sids = {member_sid for member_sid, _ in participants}
    if sid not in member_sids:
        return

    event_data = data["data"]
    event_type = event_data["type"]
    if event_type != "typing":
        return

    payload = {
        "channel_id": data["channel_id"],
        "message_id": data.get("message_id", None),
        "data": event_data,
        "user": UserNameResponse(**SESSION_POOL[sid]).model_dump(),
    }
    await sio.emit("channel-events", payload, room=room)
2025-07-12 03:59:48 +08:00
@sio.on("yjs:document:join")
async def yjs_document_join(sid, data):
    """Handle user joining a document.

    For ids of the form ``note:<id>`` the note must exist and the session's
    user must be an admin, the note owner, or have "read" access to it.
    On success the session is added to the document's user sets and to the
    ``doc_<document_id>`` room, receives the current document state, and the
    other room members are told about the new user.

    Args:
        sid: Socket.IO session id of the caller.
        data: Event payload with ``document_id`` and optional ``user_id``,
            ``user_name`` and ``user_color``.
    """
    user = SESSION_POOL.get(sid)

    try:
        document_id = data["document_id"]

        if document_id.startswith("note:"):
            note_id = document_id.split(":")[1]
            note = Notes.get_note_by_id(note_id)
            if not note:
                log.error(f"Note {note_id} not found")
                return

            if not user:
                # No session entry means the socket never authenticated. Handle
                # it here rather than relying on an AttributeError below being
                # swallowed by the blanket except.
                log.error(
                    f"Unauthenticated session {sid} tried to join note {note_id}"
                )
                await sio.emit(
                    "error", {"message": "Failed to join document"}, room=sid
                )
                return

            if (
                user.get("role") != "admin"
                and user.get("id") != note.user_id
                and not has_access(
                    user.get("id"), type="read", access_control=note.access_control
                )
            ):
                log.error(
                    f"User {user.get('id')} does not have access to note {note_id}"
                )
                return

        # Presentation details are client-supplied; the sid is the fallback id.
        user_id = data.get("user_id", sid)
        user_name = data.get("user_name", "Anonymous")
        user_color = data.get("user_color", "#000000")

        log.info(f"User {user_id} joining document {document_id}")

        # Initialize document if it doesn't exist
        if document_id not in DOCUMENTS:
            DOCUMENTS[document_id] = {
                "ydoc": Y.Doc(),  # Create actual Yjs document
                "users": set(),
            }
            DOCUMENT_USERS[document_id] = set()

        # Add user to document
        DOCUMENTS[document_id]["users"].add(sid)
        DOCUMENT_USERS[document_id].add(sid)

        # Join Socket.IO room
        await sio.enter_room(sid, f"doc_{document_id}")

        # Send current document state as a proper Yjs update
        ydoc = DOCUMENTS[document_id]["ydoc"]

        # Encode the entire document state as an update
        state_update = ydoc.get_update()
        await sio.emit(
            "yjs:document:state",
            {
                "document_id": document_id,
                "state": list(state_update),  # Convert bytes to list for JSON
            },
            room=sid,
        )

        # Notify other users about the new user
        await sio.emit(
            "yjs:user:joined",
            {
                "document_id": document_id,
                "user_id": user_id,
                "user_name": user_name,
                "user_color": user_color,
            },
            room=f"doc_{document_id}",
            skip_sid=sid,
        )

        log.info(f"User {user_id} successfully joined document {document_id}")

    except Exception as e:
        log.error(f"Error in yjs_document_join: {e}")
        await sio.emit("error", {"message": "Failed to join document"}, room=sid)
async def document_save_handler(document_id, data, user):
    """Persist ``data`` for a ``note:<id>`` document if ``user`` may access the note.

    Ids that do not start with ``note:`` are ignored. The save is skipped (with
    an error log) when the note does not exist, when no user is given, or when
    the user is neither admin, owner, nor granted access.

    Args:
        document_id: Document identifier, e.g. ``note:<note id>``.
        data: Value stored as the note's ``data`` through NoteUpdateForm.
        user: Session user dict (as stored in SESSION_POOL), or None.
    """
    if document_id.startswith("note:"):
        note_id = document_id.split(":")[1]
        note = Notes.get_note_by_id(note_id)
        if not note:
            log.error(f"Note {note_id} not found")
            return

        if not user:
            # The caller looks the user up by sid when the delayed save runs;
            # the session may be gone by then.
            log.error(f"No session user available to save note {note_id}")
            return

        # NOTE(review): saving is gated on "read" access; confirm whether a
        # stricter permission is intended for writes.
        if (
            user.get("role") != "admin"
            and user.get("id") != note.user_id
            and not has_access(
                user.get("id"), type="read", access_control=note.access_control
            )
        ):
            log.error(f"User {user.get('id')} does not have access to note {note_id}")
            return

        Notes.update_note_by_id(note_id, NoteUpdateForm(data=data))
@sio.on("yjs:document:update")
async def yjs_document_update(sid, data):
    """Apply a client's Yjs update to the server copy, relay it, and schedule a save.

    The update is applied to the in-memory document, broadcast to the other
    sessions in the ``doc_<document_id>`` room, and a save of ``data["data"]``
    is registered as a task keyed by the document id after a 0.5 s delay.
    Any error is logged and swallowed.
    """
    try:
        document_id = data["document_id"]

        # Stop tasks registered under this document id before handling the new
        # update (presumably the previously scheduled save — confirm against
        # open_webui.tasks).
        await stop_item_tasks(REDIS, document_id)

        user_id = data.get("user_id", sid)
        update = data["update"]  # List of bytes from frontend

        if document_id not in DOCUMENTS:
            log.warning(f"Document {document_id} not found")
            return

        # Apply the update to the server's Yjs document
        server_doc = DOCUMENTS[document_id]["ydoc"]
        raw_update = bytes(update)
        try:
            server_doc.apply_update(raw_update)
        except Exception as e:
            log.error(f"Failed to apply Yjs update: {e}")
            return

        # Broadcast update to all other users in the document
        outgoing = {
            "document_id": document_id,
            "user_id": user_id,
            "update": update,
            "socket_id": sid,  # Add socket_id to match frontend filtering
        }
        await sio.emit(
            "yjs:document:update",
            outgoing,
            room=f"doc_{document_id}",
            skip_sid=sid,
        )

        async def debounced_save():
            # Wait briefly before saving; the session user is looked up only
            # once the delay has elapsed.
            await asyncio.sleep(0.5)
            await document_save_handler(
                document_id, data.get("data", {}), SESSION_POOL.get(sid)
            )

        await create_task(REDIS, debounced_save(), document_id)

    except Exception as e:
        log.error(f"Error in yjs_document_update: {e}")
@sio.on("yjs:document:leave")
async def yjs_document_leave(sid, data):
    """Handle user leaving a document.

    Removes the session from the document's user sets and room, notifies the
    room, and drops the in-memory document once no users remain. Any error is
    logged and swallowed.
    """
    try:
        document_id = data["document_id"]
        user_id = data.get("user_id", sid)
        room = f"doc_{document_id}"

        log.info(f"User {user_id} leaving document {document_id}")

        if document_id in DOCUMENTS:
            DOCUMENTS[document_id]["users"].discard(sid)

        if document_id in DOCUMENT_USERS:
            DOCUMENT_USERS[document_id].discard(sid)

        # Leave Socket.IO room
        await sio.leave_room(sid, room)

        # Notify other users
        left_event = {"document_id": document_id, "user_id": user_id}
        await sio.emit("yjs:user:left", left_event, room=room)

        # Membership is looked up again here, after the awaits above.
        no_users_left = (
            document_id in DOCUMENTS and not DOCUMENTS[document_id]["users"]
        )
        if no_users_left:
            # If no users left, clean up the document
            log.info(f"Cleaning up document {document_id} as no users are left")
            del DOCUMENTS[document_id]
            del DOCUMENT_USERS[document_id]

    except Exception as e:
        log.error(f"Error in yjs_document_leave: {e}")
@sio.on("yjs:awareness:update")
async def yjs_awareness_update(sid, data):
    """Handle awareness updates (cursors, selections, etc.).

    The update is relayed unchanged to every other session in the
    ``doc_<document_id>`` room. Any error is logged and swallowed.
    """
    try:
        document_id = data["document_id"]
        user_id = data.get("user_id", sid)
        update = data["update"]

        # Broadcast awareness update to all other users in the document
        relayed = {
            "document_id": document_id,
            "user_id": user_id,
            "update": update,
        }
        await sio.emit(
            "yjs:awareness:update",
            relayed,
            room=f"doc_{document_id}",
            skip_sid=sid,
        )

    except Exception as e:
        log.error(f"Error in yjs_awareness_update: {e}")
2024-06-04 14:39:52 +08:00
@sio.event
async def disconnect(sid):
    """Forget a disconnected session.

    Removes ``sid`` from every in-memory document it had joined (dropping
    documents left without users) and from SESSION_POOL / USER_POOL. Sessions
    that were never registered in SESSION_POOL only get the document cleanup.
    """
    # A socket can vanish without sending "yjs:document:leave"; without this
    # its sid would stay in the document user sets and the document would
    # never be released. No "yjs:user:left" event is emitted here because the
    # user id announced at join time is client-supplied and not stored.
    for document_id in list(DOCUMENTS.keys()):
        document_sids = DOCUMENTS[document_id]["users"]
        if sid not in document_sids:
            continue

        document_sids.discard(sid)
        if document_id in DOCUMENT_USERS:
            DOCUMENT_USERS[document_id].discard(sid)

        if not document_sids:
            log.info(f"Cleaning up document {document_id} as no users are left")
            del DOCUMENTS[document_id]
            DOCUMENT_USERS.pop(document_id, None)

    if sid not in SESSION_POOL:
        # Unknown session id disconnected; nothing else to clean up.
        return

    user = SESSION_POOL[sid]
    del SESSION_POOL[sid]

    user_id = user["id"]
    # Tolerate a missing USER_POOL entry instead of raising KeyError.
    remaining_sids = [_sid for _sid in USER_POOL.get(user_id, []) if _sid != sid]
    if remaining_sids:
        USER_POOL[user_id] = remaining_sids
    elif user_id in USER_POOL:
        del USER_POOL[user_id]
2024-07-12 01:40:10 +08:00
def get_event_emitter(request_info, update_db=True):
    """Build a coroutine that pushes chat events to a user's sessions.

    Args:
        request_info: Mapping with "user_id" and, optionally, "session_id",
            "chat_id" and "message_id" ("chat_id" and "message_id" are
            required when a database update is performed).
        update_db: When true, "status", "message" and "replace" events are
            also written to the chat through ``Chats``.

    Returns:
        An async callable taking ``event_data``.
    """

    async def __event_emitter__(event_data):
        user_id = request_info["user_id"]

        # Every session of the user, plus the requesting session if given,
        # without duplicates.
        pool_sids = USER_POOL.get(user_id, [])
        origin_sids = (
            [request_info.get("session_id")] if request_info.get("session_id") else []
        )
        recipients = list(set(pool_sids + origin_sids))

        await asyncio.gather(
            *[
                sio.emit(
                    "chat-events",
                    {
                        "chat_id": request_info.get("chat_id", None),
                        "message_id": request_info.get("message_id", None),
                        "data": event_data,
                    },
                    to=recipient,
                )
                for recipient in recipients
            ]
        )

        if not update_db:
            return

        event_type = event_data["type"] if "type" in event_data else None

        if event_type == "status":
            Chats.add_message_status_to_chat_by_id_and_message_id(
                request_info["chat_id"],
                request_info["message_id"],
                event_data.get("data", {}),
            )
        elif event_type == "message":
            # Append the new content to what is already stored for the message.
            message = Chats.get_message_by_id_and_message_id(
                request_info["chat_id"],
                request_info["message_id"],
            )
            if message:
                appended = message.get("content", "") + event_data.get(
                    "data", {}
                ).get("content", "")
                Chats.upsert_message_to_chat_by_id_and_message_id(
                    request_info["chat_id"],
                    request_info["message_id"],
                    {
                        "content": appended,
                    },
                )
        elif event_type == "replace":
            # Overwrite the stored content with the event's content.
            replacement = event_data.get("data", {}).get("content", "")
            Chats.upsert_message_to_chat_by_id_and_message_id(
                request_info["chat_id"],
                request_info["message_id"],
                {
                    "content": replacement,
                },
            )

    return __event_emitter__
2024-07-31 20:35:02 +08:00
def get_event_call(request_info):
    """Build a coroutine that sends a chat event to one session and awaits its reply.

    Args:
        request_info: Mapping with "session_id" and, optionally, "chat_id"
            and "message_id".

    Returns:
        An async callable taking ``event_data`` and returning whatever
        ``sio.call`` returns for the "chat-events" event.
    """

    async def __event_caller__(event_data):
        payload = {
            "chat_id": request_info.get("chat_id", None),
            "message_id": request_info.get("message_id", None),
            "data": event_data,
        }
        return await sio.call(
            "chat-events",
            payload,
            to=request_info["session_id"],
        )

    return __event_caller__


# Alias: the same factory is exposed under both names.
get_event_caller = get_event_call