| 
									
										
										
										
											2024-12-19 17:00:32 +08:00
										 |  |  | # tasks.py | 
					
						
							|  |  |  | import asyncio | 
					
						
							|  |  |  | from typing import Dict | 
					
						
							|  |  |  | from uuid import uuid4 | 
					
						
							| 
									
										
										
										
											2025-06-09 00:58:31 +08:00
										 |  |  | import json | 
					
						
							| 
									
										
										
										
											2025-07-03 07:25:39 +08:00
										 |  |  | import logging | 
					
						
							| 
									
										
										
										
											2025-06-09 21:21:10 +08:00
										 |  |  | from redis.asyncio import Redis | 
					
						
							| 
									
										
										
										
											2025-06-09 01:20:30 +08:00
										 |  |  | from fastapi import Request | 
					
						
							|  |  |  | from typing import Dict, List, Optional | 
					
						
							| 
									
										
										
										
											2024-12-19 17:00:32 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-23 01:40:29 +08:00
										 |  |  | from open_webui.env import SRC_LOG_LEVELS, REDIS_KEY_PREFIX | 
					
						
							| 
									
										
										
										
											2025-07-03 07:25:39 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
# Module-level logger; its level is taken from the "MAIN" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-19 17:00:32 +08:00
# A dictionary to keep track of active tasks
# (task ID -> asyncio.Task created in this process).
tasks: Dict[str, asyncio.Task] = {}
# Item ID -> list of task IDs that were created for that item in this process.
item_tasks = {}
					
						
							| 
									
										
										
										
											2024-12-19 17:00:32 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-23 01:40:29 +08:00
# Redis names used by the task registry, all namespaced under REDIS_KEY_PREFIX.
# Hash: task ID -> item ID ("" when the task has no item).
REDIS_TASKS_KEY = f"{REDIS_KEY_PREFIX}:tasks"
# Prefix of the per-item sets of task IDs; ":<item_id>" is appended to it.
REDIS_ITEM_TASKS_KEY = f"{REDIS_KEY_PREFIX}:tasks:item"
# Pub/sub channel carrying JSON-encoded task commands (e.g. action "stop").
REDIS_PUBSUB_CHANNEL = f"{REDIS_KEY_PREFIX}:tasks:commands"
					
						
							| 
									
										
										
										
											2025-06-09 01:20:30 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | async def redis_task_command_listener(app): | 
					
						
							|  |  |  |     redis: Redis = app.state.redis | 
					
						
							|  |  |  |     pubsub = redis.pubsub() | 
					
						
							|  |  |  |     await pubsub.subscribe(REDIS_PUBSUB_CHANNEL) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     async for message in pubsub.listen(): | 
					
						
							|  |  |  |         if message["type"] != "message": | 
					
						
							|  |  |  |             continue | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             command = json.loads(message["data"]) | 
					
						
							|  |  |  |             if command.get("action") == "stop": | 
					
						
							|  |  |  |                 task_id = command.get("task_id") | 
					
						
							|  |  |  |                 local_task = tasks.get(task_id) | 
					
						
							|  |  |  |                 if local_task: | 
					
						
							|  |  |  |                     local_task.cancel() | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							| 
									
										
										
										
											2025-07-03 07:25:39 +08:00
										 |  |  |             log.exception(f"Error handling distributed task command: {e}") | 
					
						
							| 
									
										
										
										
											2025-06-09 01:20:30 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ### ------------------------------ | 
					
						
							|  |  |  | ### REDIS-ENABLED HANDLERS | 
					
						
							|  |  |  | ### ------------------------------ | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-11 22:14:48 +08:00
										 |  |  | async def redis_save_task(redis: Redis, task_id: str, item_id: Optional[str]): | 
					
						
							| 
									
										
										
										
											2025-06-09 01:20:30 +08:00
										 |  |  |     pipe = redis.pipeline() | 
					
						
							| 
									
										
										
										
											2025-07-11 22:14:48 +08:00
										 |  |  |     pipe.hset(REDIS_TASKS_KEY, task_id, item_id or "") | 
					
						
							|  |  |  |     if item_id: | 
					
						
							|  |  |  |         pipe.sadd(f"{REDIS_ITEM_TASKS_KEY}:{item_id}", task_id) | 
					
						
							| 
									
										
										
										
											2025-06-09 21:21:10 +08:00
										 |  |  |     await pipe.execute() | 
					
						
							| 
									
										
										
										
											2025-06-09 01:20:30 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-11 22:14:48 +08:00
										 |  |  | async def redis_cleanup_task(redis: Redis, task_id: str, item_id: Optional[str]): | 
					
						
							| 
									
										
										
										
											2025-06-09 01:20:30 +08:00
										 |  |  |     pipe = redis.pipeline() | 
					
						
							|  |  |  |     pipe.hdel(REDIS_TASKS_KEY, task_id) | 
					
						
							| 
									
										
										
										
											2025-07-11 22:14:48 +08:00
										 |  |  |     if item_id: | 
					
						
							|  |  |  |         pipe.srem(f"{REDIS_ITEM_TASKS_KEY}:{item_id}", task_id) | 
					
						
							|  |  |  |         if (await pipe.scard(f"{REDIS_ITEM_TASKS_KEY}:{item_id}").execute())[-1] == 0: | 
					
						
							|  |  |  |             pipe.delete(f"{REDIS_ITEM_TASKS_KEY}:{item_id}")  # Remove if empty set | 
					
						
							| 
									
										
										
										
											2025-06-09 21:21:10 +08:00
										 |  |  |     await pipe.execute() | 
					
						
							| 
									
										
										
										
											2025-06-09 01:20:30 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-06-09 21:21:10 +08:00
										 |  |  | async def redis_list_tasks(redis: Redis) -> List[str]: | 
					
						
							|  |  |  |     return list(await redis.hkeys(REDIS_TASKS_KEY)) | 
					
						
							| 
									
										
										
										
											2025-06-09 01:20:30 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-11 22:14:48 +08:00
										 |  |  | async def redis_list_item_tasks(redis: Redis, item_id: str) -> List[str]: | 
					
						
							|  |  |  |     return list(await redis.smembers(f"{REDIS_ITEM_TASKS_KEY}:{item_id}")) | 
					
						
							| 
									
										
										
										
											2025-06-09 01:20:30 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-06-09 21:21:10 +08:00
										 |  |  | async def redis_send_command(redis: Redis, command: dict): | 
					
						
							|  |  |  |     await redis.publish(REDIS_PUBSUB_CHANNEL, json.dumps(command)) | 
					
						
							| 
									
										
										
										
											2025-06-09 01:20:30 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-11 21:53:53 +08:00
										 |  |  | async def cleanup_task(redis, task_id: str, id=None): | 
					
						
							| 
									
										
										
										
											2024-12-19 17:00:32 +08:00
										 |  |  |     """
 | 
					
						
							|  |  |  |     Remove a completed or canceled task from the global `tasks` dictionary. | 
					
						
							|  |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2025-07-11 21:53:53 +08:00
										 |  |  |     if redis: | 
					
						
							|  |  |  |         await redis_cleanup_task(redis, task_id, id) | 
					
						
							| 
									
										
										
										
											2025-06-09 01:20:30 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-19 17:00:32 +08:00
										 |  |  |     tasks.pop(task_id, None)  # Remove the task if it exists | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-11 22:14:48 +08:00
										 |  |  |     # If an ID is provided, remove the task from the item_tasks dictionary | 
					
						
							|  |  |  |     if id and task_id in item_tasks.get(id, []): | 
					
						
							|  |  |  |         item_tasks[id].remove(task_id) | 
					
						
							|  |  |  |         if not item_tasks[id]:  # If no tasks left for this ID, remove the entry | 
					
						
							|  |  |  |             item_tasks.pop(id, None) | 
					
						
							| 
									
										
										
										
											2024-12-19 17:00:32 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-04-13 11:51:02 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-11 21:53:53 +08:00
										 |  |  | async def create_task(redis, coroutine, id=None): | 
					
						
							| 
									
										
										
										
											2024-12-19 17:00:32 +08:00
										 |  |  |     """
 | 
					
						
							|  |  |  |     Create a new asyncio task and add it to the global task dictionary. | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     task_id = str(uuid4())  # Generate a unique ID for the task | 
					
						
							|  |  |  |     task = asyncio.create_task(coroutine)  # Create the task | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Add a done callback for cleanup | 
					
						
							| 
									
										
										
										
											2025-06-09 21:21:10 +08:00
										 |  |  |     task.add_done_callback( | 
					
						
							| 
									
										
										
										
											2025-07-11 21:53:53 +08:00
										 |  |  |         lambda t: asyncio.create_task(cleanup_task(redis, task_id, id)) | 
					
						
							| 
									
										
										
										
											2025-06-09 21:21:10 +08:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2024-12-19 17:00:32 +08:00
										 |  |  |     tasks[task_id] = task | 
					
						
							| 
									
										
										
										
											2025-04-13 11:51:02 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # If an ID is provided, associate the task with that ID | 
					
						
							| 
									
										
										
										
											2025-07-11 22:14:48 +08:00
										 |  |  |     if item_tasks.get(id): | 
					
						
							|  |  |  |         item_tasks[id].append(task_id) | 
					
						
							| 
									
										
										
										
											2025-04-13 11:51:02 +08:00
										 |  |  |     else: | 
					
						
							| 
									
										
										
										
											2025-07-11 22:14:48 +08:00
										 |  |  |         item_tasks[id] = [task_id] | 
					
						
							| 
									
										
										
										
											2025-04-13 11:51:02 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-11 21:53:53 +08:00
										 |  |  |     if redis: | 
					
						
							|  |  |  |         await redis_save_task(redis, task_id, id) | 
					
						
							| 
									
										
										
										
											2025-06-09 01:20:30 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-19 17:00:32 +08:00
										 |  |  |     return task_id, task | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-11 21:53:53 +08:00
										 |  |  | async def list_tasks(redis): | 
					
						
							| 
									
										
										
										
											2024-12-19 17:00:32 +08:00
										 |  |  |     """
 | 
					
						
							|  |  |  |     List all currently active task IDs. | 
					
						
							|  |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2025-07-11 21:53:53 +08:00
										 |  |  |     if redis: | 
					
						
							|  |  |  |         return await redis_list_tasks(redis) | 
					
						
							| 
									
										
										
										
											2024-12-19 17:00:32 +08:00
										 |  |  |     return list(tasks.keys()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-11 22:14:48 +08:00
										 |  |  | async def list_task_ids_by_item_id(redis, id): | 
					
						
							| 
									
										
										
										
											2025-04-13 11:51:02 +08:00
										 |  |  |     """
 | 
					
						
							|  |  |  |     List all tasks associated with a specific ID. | 
					
						
							|  |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2025-07-11 21:53:53 +08:00
										 |  |  |     if redis: | 
					
						
							| 
									
										
										
										
											2025-07-11 22:14:48 +08:00
										 |  |  |         return await redis_list_item_tasks(redis, id) | 
					
						
							|  |  |  |     return item_tasks.get(id, []) | 
					
						
							| 
									
										
										
										
											2025-04-13 11:51:02 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-11 21:53:53 +08:00
										 |  |  | async def stop_task(redis, task_id: str): | 
					
						
							| 
									
										
										
										
											2024-12-19 17:00:32 +08:00
										 |  |  |     """
 | 
					
						
							|  |  |  |     Cancel a running task and remove it from the global task list. | 
					
						
							|  |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2025-07-11 21:53:53 +08:00
										 |  |  |     if redis: | 
					
						
							| 
									
										
										
										
											2025-06-09 01:20:30 +08:00
										 |  |  |         # PUBSUB: All instances check if they have this task, and stop if so. | 
					
						
							| 
									
										
										
										
											2025-06-09 21:21:10 +08:00
										 |  |  |         await redis_send_command( | 
					
						
							| 
									
										
										
										
											2025-07-11 21:53:53 +08:00
										 |  |  |             redis, | 
					
						
							| 
									
										
										
										
											2025-06-09 01:20:30 +08:00
										 |  |  |             { | 
					
						
							|  |  |  |                 "action": "stop", | 
					
						
							|  |  |  |                 "task_id": task_id, | 
					
						
							|  |  |  |             }, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         # Optionally check if task_id still in Redis a few moments later for feedback? | 
					
						
							|  |  |  |         return {"status": True, "message": f"Stop signal sent for {task_id}"} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-12 07:00:40 +08:00
										 |  |  |     task = tasks.pop(task_id) | 
					
						
							| 
									
										
										
										
											2024-12-19 17:00:32 +08:00
										 |  |  |     if not task: | 
					
						
							|  |  |  |         raise ValueError(f"Task with ID {task_id} not found.") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     task.cancel()  # Request task cancellation | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         await task  # Wait for the task to handle the cancellation | 
					
						
							|  |  |  |     except asyncio.CancelledError: | 
					
						
							|  |  |  |         # Task successfully canceled | 
					
						
							|  |  |  |         return {"status": True, "message": f"Task {task_id} successfully stopped."} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return {"status": False, "message": f"Failed to stop task {task_id}."} | 
					
						
							| 
									
										
										
										
											2025-07-11 22:41:09 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | async def stop_item_tasks(redis: Redis, item_id: str): | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     Stop all tasks associated with a specific item ID. | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     task_ids = await list_task_ids_by_item_id(redis, item_id) | 
					
						
							|  |  |  |     if not task_ids: | 
					
						
							|  |  |  |         return {"status": True, "message": f"No tasks found for item {item_id}."} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     for task_id in task_ids: | 
					
						
							|  |  |  |         result = await stop_task(redis, task_id) | 
					
						
							|  |  |  |         if not result["status"]: | 
					
						
							|  |  |  |             return result  # Return the first failure | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return {"status": True, "message": f"All tasks for item {item_id} stopped."} |