167 lines
		
	
	
		
			5.2 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			167 lines
		
	
	
		
			5.2 KiB
		
	
	
	
		
			Python
		
	
	
	
| # tasks.py
 | |
| import asyncio
 | |
| from typing import Dict
 | |
| from uuid import uuid4
 | |
| import json
 | |
| from redis.asyncio import Redis
 | |
| from fastapi import Request
 | |
| from typing import Dict, List, Optional
 | |
| 
 | |
| # A dictionary to keep track of active tasks
 | |
| tasks: Dict[str, asyncio.Task] = {}
 | |
| chat_tasks = {}
 | |
| 
 | |
| 
 | |
| REDIS_TASKS_KEY = "open-webui:tasks"
 | |
| REDIS_CHAT_TASKS_KEY = "open-webui:tasks:chat"
 | |
| REDIS_PUBSUB_CHANNEL = "open-webui:tasks:commands"
 | |
| 
 | |
| 
 | |
| def is_redis(request: Request) -> bool:
 | |
|     # Called everywhere a request is available to check Redis
 | |
|     return hasattr(request.app.state, "redis") and (request.app.state.redis is not None)
 | |
| 
 | |
| 
 | |
| async def redis_task_command_listener(app):
 | |
|     redis: Redis = app.state.redis
 | |
|     pubsub = redis.pubsub()
 | |
|     await pubsub.subscribe(REDIS_PUBSUB_CHANNEL)
 | |
| 
 | |
|     async for message in pubsub.listen():
 | |
|         if message["type"] != "message":
 | |
|             continue
 | |
|         try:
 | |
|             command = json.loads(message["data"])
 | |
|             if command.get("action") == "stop":
 | |
|                 task_id = command.get("task_id")
 | |
|                 local_task = tasks.get(task_id)
 | |
|                 if local_task:
 | |
|                     local_task.cancel()
 | |
|         except Exception as e:
 | |
|             print(f"Error handling distributed task command: {e}")
 | |
| 
 | |
| 
 | |
| ### ------------------------------
 | |
| ### REDIS-ENABLED HANDLERS
 | |
| ### ------------------------------
 | |
| 
 | |
| 
 | |
| async def redis_save_task(redis: Redis, task_id: str, chat_id: Optional[str]):
 | |
|     pipe = redis.pipeline()
 | |
|     pipe.hset(REDIS_TASKS_KEY, task_id, chat_id or "")
 | |
|     if chat_id:
 | |
|         pipe.sadd(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}", task_id)
 | |
|     await pipe.execute()
 | |
| 
 | |
| 
 | |
| async def redis_cleanup_task(redis: Redis, task_id: str, chat_id: Optional[str]):
 | |
|     pipe = redis.pipeline()
 | |
|     pipe.hdel(REDIS_TASKS_KEY, task_id)
 | |
|     if chat_id:
 | |
|         pipe.srem(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}", task_id)
 | |
|         if (await pipe.scard(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}").execute())[-1] == 0:
 | |
|             pipe.delete(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}")  # Remove if empty set
 | |
|     await pipe.execute()
 | |
| 
 | |
| 
 | |
| async def redis_list_tasks(redis: Redis) -> List[str]:
 | |
|     return list(await redis.hkeys(REDIS_TASKS_KEY))
 | |
| 
 | |
| 
 | |
| async def redis_list_chat_tasks(redis: Redis, chat_id: str) -> List[str]:
 | |
|     return list(await redis.smembers(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}"))
 | |
| 
 | |
| 
 | |
| async def redis_send_command(redis: Redis, command: dict):
 | |
|     await redis.publish(REDIS_PUBSUB_CHANNEL, json.dumps(command))
 | |
| 
 | |
| 
 | |
| async def cleanup_task(request, task_id: str, id=None):
 | |
|     """
 | |
|     Remove a completed or canceled task from the global `tasks` dictionary.
 | |
|     """
 | |
|     if is_redis(request):
 | |
|         await redis_cleanup_task(request.app.state.redis, task_id, id)
 | |
| 
 | |
|     tasks.pop(task_id, None)  # Remove the task if it exists
 | |
| 
 | |
|     # If an ID is provided, remove the task from the chat_tasks dictionary
 | |
|     if id and task_id in chat_tasks.get(id, []):
 | |
|         chat_tasks[id].remove(task_id)
 | |
|         if not chat_tasks[id]:  # If no tasks left for this ID, remove the entry
 | |
|             chat_tasks.pop(id, None)
 | |
| 
 | |
| 
 | |
| async def create_task(request, coroutine, id=None):
 | |
|     """
 | |
|     Create a new asyncio task and add it to the global task dictionary.
 | |
|     """
 | |
|     task_id = str(uuid4())  # Generate a unique ID for the task
 | |
|     task = asyncio.create_task(coroutine)  # Create the task
 | |
| 
 | |
|     # Add a done callback for cleanup
 | |
|     task.add_done_callback(
 | |
|         lambda t: asyncio.create_task(cleanup_task(request, task_id, id))
 | |
|     )
 | |
|     tasks[task_id] = task
 | |
| 
 | |
|     # If an ID is provided, associate the task with that ID
 | |
|     if chat_tasks.get(id):
 | |
|         chat_tasks[id].append(task_id)
 | |
|     else:
 | |
|         chat_tasks[id] = [task_id]
 | |
| 
 | |
|     if is_redis(request):
 | |
|         await redis_save_task(request.app.state.redis, task_id, id)
 | |
| 
 | |
|     return task_id, task
 | |
| 
 | |
| 
 | |
| async def list_tasks(request):
 | |
|     """
 | |
|     List all currently active task IDs.
 | |
|     """
 | |
|     if is_redis(request):
 | |
|         return await redis_list_tasks(request.app.state.redis)
 | |
|     return list(tasks.keys())
 | |
| 
 | |
| 
 | |
| async def list_task_ids_by_chat_id(request, id):
 | |
|     """
 | |
|     List all tasks associated with a specific ID.
 | |
|     """
 | |
|     if is_redis(request):
 | |
|         return await redis_list_chat_tasks(request.app.state.redis, id)
 | |
|     return chat_tasks.get(id, [])
 | |
| 
 | |
| 
 | |
| async def stop_task(request, task_id: str):
 | |
|     """
 | |
|     Cancel a running task and remove it from the global task list.
 | |
|     """
 | |
|     if is_redis(request):
 | |
|         # PUBSUB: All instances check if they have this task, and stop if so.
 | |
|         await redis_send_command(
 | |
|             request.app.state.redis,
 | |
|             {
 | |
|                 "action": "stop",
 | |
|                 "task_id": task_id,
 | |
|             },
 | |
|         )
 | |
|         # Optionally check if task_id still in Redis a few moments later for feedback?
 | |
|         return {"status": True, "message": f"Stop signal sent for {task_id}"}
 | |
| 
 | |
|     task = tasks.get(task_id)
 | |
|     if not task:
 | |
|         raise ValueError(f"Task with ID {task_id} not found.")
 | |
| 
 | |
|     task.cancel()  # Request task cancellation
 | |
|     try:
 | |
|         await task  # Wait for the task to handle the cancellation
 | |
|     except asyncio.CancelledError:
 | |
|         # Task successfully canceled
 | |
|         tasks.pop(task_id, None)  # Remove it from the dictionary
 | |
|         return {"status": True, "message": f"Task {task_id} successfully stopped."}
 | |
| 
 | |
|     return {"status": False, "message": f"Failed to stop task {task_id}."}
 |