# open-webui/backend/open_webui/socket/utils.py
#
# 215 lines
# 7.0 KiB
# Python

import json
import uuid
from open_webui.utils.redis import get_redis_connection
from open_webui.env import REDIS_KEY_PREFIX
from typing import Optional, List, Tuple
import pycrdt as Y
class RedisLock:
    """Expiring lock stored under a single Redis key.

    Every instance generates its own random ``lock_id``. The lock is held by
    the instance whose id is currently stored under ``lock_name``; the key is
    written with an expiry of ``timeout_secs`` seconds, so a holder has to
    call :meth:`renew_lock` to keep it.
    """

    def __init__(
        self,
        redis_url,
        lock_name,
        timeout_secs,
        redis_sentinels=None,
        redis_cluster=False,
    ):
        """Create the lock handle and open the Redis connection.

        Args:
            redis_url: Passed through to ``get_redis_connection``.
            lock_name: Redis key that represents the lock.
            timeout_secs: Expiry, in seconds, applied on acquire and renew.
            redis_sentinels: Sentinel list passed to ``get_redis_connection``;
                ``None`` (the default) is treated as an empty list.
            redis_cluster: Passed through to ``get_redis_connection``.
        """
        self.lock_name = lock_name
        self.lock_id = str(uuid.uuid4())
        self.timeout_secs = timeout_secs
        self.lock_obtained = False
        self.redis = get_redis_connection(
            redis_url,
            # Build a fresh list per call instead of sharing one mutable
            # default object between all instances.
            [] if redis_sentinels is None else redis_sentinels,
            redis_cluster=redis_cluster,
            decode_responses=True,
        )

    def aquire_lock(self):
        """Try to take the lock; return a truthy value on success.

        The result is also remembered in ``self.lock_obtained``.
        """
        # nx=True will only set this key if it _hasn't_ already been set
        self.lock_obtained = self.redis.set(
            self.lock_name, self.lock_id, nx=True, ex=self.timeout_secs
        )
        return self.lock_obtained

    def acquire_lock(self):
        """Correctly spelled alias for :meth:`aquire_lock`."""
        return self.aquire_lock()

    def renew_lock(self):
        """Extend the expiry of a lock this instance still owns.

        Returns a falsy value when the key is missing or is held under a
        different id, so an expired holder cannot overwrite (and thereby
        steal) a lock that another instance has since acquired.
        """
        # NOTE: the ownership check and the SET are two separate commands,
        # mirroring release_lock; they are not atomic.
        if self.redis.get(self.lock_name) != self.lock_id:
            return False
        # xx=True will only set this key if it _has_ already been set
        return self.redis.set(
            self.lock_name, self.lock_id, xx=True, ex=self.timeout_secs
        )

    def release_lock(self):
        """Delete the lock key, but only if it still carries our id."""
        # NOTE: GET followed by DELETE is not atomic.
        if self.redis.get(self.lock_name) == self.lock_id:
            self.redis.delete(self.lock_name)
class RedisDict:
    """Dict-like wrapper around one Redis hash.

    Each dictionary key is a field of the hash called ``name``; values are
    stored as JSON text and decoded again on every read.
    """

    def __init__(self, name, redis_url, redis_sentinels=None, redis_cluster=False):
        """Bind to the hash ``name`` and open the Redis connection.

        Args:
            name: Redis key of the backing hash.
            redis_url: Passed through to ``get_redis_connection``.
            redis_sentinels: Sentinel list passed to ``get_redis_connection``;
                ``None`` (the default) is treated as an empty list.
            redis_cluster: Passed through to ``get_redis_connection``.
        """
        self.name = name
        self.redis = get_redis_connection(
            redis_url,
            # Build a fresh list per call instead of sharing one mutable
            # default object between all instances.
            [] if redis_sentinels is None else redis_sentinels,
            redis_cluster=redis_cluster,
            decode_responses=True,
        )

    def __setitem__(self, key, value):
        """Store ``value`` (JSON-encoded) under ``key``."""
        self.redis.hset(self.name, key, json.dumps(value))

    def __getitem__(self, key):
        """Return the decoded value for ``key``; raise ``KeyError`` if absent."""
        raw = self.redis.hget(self.name, key)
        if raw is None:
            raise KeyError(key)
        return json.loads(raw)

    def __delitem__(self, key):
        """Remove ``key``; raise ``KeyError`` if nothing was deleted."""
        if self.redis.hdel(self.name, key) == 0:
            raise KeyError(key)

    def __contains__(self, key):
        return self.redis.hexists(self.name, key)

    def __len__(self):
        return self.redis.hlen(self.name)

    def __iter__(self):
        """Iterate over the keys, like a regular ``dict``.

        Without this, ``for k in d`` would fall back to calling
        ``__getitem__`` with 0, 1, 2, ... instead of walking the hash.
        """
        return iter(self.keys())

    def keys(self):
        """Return the hash's field names."""
        return self.redis.hkeys(self.name)

    def values(self):
        """Return a list of all decoded values."""
        return [json.loads(raw) for raw in self.redis.hvals(self.name)]

    def items(self):
        """Return a list of ``(key, decoded_value)`` pairs."""
        stored = self.redis.hgetall(self.name)
        return [(field, json.loads(raw)) for field, raw in stored.items()]

    def get(self, key, default=None):
        """Return the value for ``key``, or ``default`` when it is missing."""
        try:
            return self[key]
        except KeyError:
            return default

    def clear(self):
        """Drop the whole hash."""
        self.redis.delete(self.name)

    def update(self, other=None, **kwargs):
        """Set several items, from a mapping / iterable of pairs and kwargs."""
        if other is not None:
            pairs = other.items() if hasattr(other, "items") else other
            for key, value in pairs:
                self[key] = value
        for key, value in kwargs.items():
            self[key] = value

    def setdefault(self, key, default=None):
        """Return the value for ``key``, storing ``default`` first if absent.

        The insert uses HSETNX, so the "only if missing" check and the write
        happen in one Redis command rather than as separate round trips.
        """
        self.redis.hsetnx(self.name, key, json.dumps(default))
        return self[key]
class YdocManager:
    """Tracks per-document update blobs and the users attached to each document.

    When a Redis client is supplied (its calls are awaited), updates are kept
    in a Redis list and users in a Redis set, both under ``redis_key_prefix``.
    Without one, plain in-process dictionaries are used instead.
    """

    def __init__(
        self,
        redis=None,
        redis_key_prefix: str = f"{REDIS_KEY_PREFIX}:ydoc:documents",
    ):
        """Set up the in-memory stores and remember the optional Redis client.

        Args:
            redis: Client whose methods are awaited; ``None`` selects the
                in-memory backend.
            redis_key_prefix: Prefix of every Redis key this manager touches.
        """
        self._updates = {}
        self._users = {}
        self._redis = redis
        self._redis_key_prefix = redis_key_prefix

    @staticmethod
    def _normalize_id(document_id: str) -> str:
        """Replace ':' in a document id so it cannot clash with the key separator."""
        return document_id.replace(":", "_")

    def _redis_key(self, document_id: str, suffix: str) -> str:
        """Build ``<prefix>:<normalized id>:<suffix>`` for an already-normalized id."""
        return f"{self._redis_key_prefix}:{document_id}:{suffix}"

    async def append_to_updates(self, document_id: str, update: bytes):
        """Append one update blob to the document's update log."""
        document_id = self._normalize_id(document_id)
        if self._redis:
            # Bytes are stored as a JSON array of integers.
            await self._redis.rpush(
                self._redis_key(document_id, "updates"), json.dumps(list(update))
            )
        else:
            self._updates.setdefault(document_id, []).append(update)

    async def get_updates(self, document_id: str) -> List[bytes]:
        """Return every stored update for the document, oldest first."""
        document_id = self._normalize_id(document_id)
        if self._redis:
            stored = await self._redis.lrange(
                self._redis_key(document_id, "updates"), 0, -1
            )
            return [bytes(json.loads(item)) for item in stored]
        # Copy so callers cannot mutate the manager's internal list.
        return list(self._updates.get(document_id, ()))

    async def document_exists(self, document_id: str) -> bool:
        """Return whether any updates are stored for the document."""
        document_id = self._normalize_id(document_id)
        if self._redis:
            found = await self._redis.exists(self._redis_key(document_id, "updates"))
            return found > 0
        return document_id in self._updates

    async def get_users(self, document_id: str) -> List[str]:
        """Return the users attached to the document as a list."""
        document_id = self._normalize_id(document_id)
        if self._redis:
            members = await self._redis.smembers(self._redis_key(document_id, "users"))
            return list(members)
        # The in-memory store holds a set; hand out a list copy so the return
        # type matches the Redis branch and the internal set stays private.
        return list(self._users.get(document_id, ()))

    async def add_user(self, document_id: str, user_id: str):
        """Attach ``user_id`` to the document."""
        document_id = self._normalize_id(document_id)
        if self._redis:
            await self._redis.sadd(self._redis_key(document_id, "users"), user_id)
        else:
            self._users.setdefault(document_id, set()).add(user_id)

    async def remove_user(self, document_id: str, user_id: str):
        """Detach ``user_id`` from the document; a no-op if not attached."""
        document_id = self._normalize_id(document_id)
        if self._redis:
            await self._redis.srem(self._redis_key(document_id, "users"), user_id)
        else:
            members = self._users.get(document_id)
            if members is not None:
                members.discard(user_id)

    async def remove_user_from_all_documents(self, user_id: str):
        """Detach ``user_id`` everywhere and clear documents left without users."""
        if self._redis:
            # Let the server filter for the per-document user sets instead of
            # fetching every key under the prefix and filtering client-side.
            user_keys = await self._redis.keys(f"{self._redis_key_prefix}:*:users")
            for key in user_keys:
                removed = await self._redis.srem(key, user_id)
                if not removed:
                    # The user was not in this document; nothing changed.
                    continue
                # Keys look like "<prefix>:<document id>:users".
                document_id = key.split(":")[-2]
                if not await self.get_users(document_id):
                    await self.clear_document(document_id)
        else:
            # Iterate over a snapshot: entries are deleted inside the loop.
            for document_id in list(self._users):
                members = self._users[document_id]
                if user_id not in members:
                    continue
                members.remove(user_id)
                if not members:
                    del self._users[document_id]
                    await self.clear_document(document_id)

    async def clear_document(self, document_id: str):
        """Delete the document's stored updates and its user set."""
        document_id = self._normalize_id(document_id)
        if self._redis:
            await self._redis.delete(self._redis_key(document_id, "updates"))
            await self._redis.delete(self._redis_key(document_id, "users"))
        else:
            self._updates.pop(document_id, None)
            self._users.pop(document_id, None)