167 lines
		
	
	
		
			5.2 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			167 lines
		
	
	
		
			5.2 KiB
		
	
	
	
		
			Python
		
	
	
	
# tasks.py
 | 
						|
import asyncio
 | 
						|
from typing import Dict
 | 
						|
from uuid import uuid4
 | 
						|
import json
 | 
						|
from redis.asyncio import Redis
 | 
						|
from fastapi import Request
 | 
						|
from typing import Dict, List, Optional
 | 
						|
 | 
						|
# Asyncio tasks started by this process, keyed by their generated task ID.
tasks: Dict[str, asyncio.Task] = {}

# Task IDs grouped by the chat/owner ID that was passed to create_task().
chat_tasks: Dict[Optional[str], List[str]] = {}

# Redis layout shared by every instance when Redis is configured:
#   REDIS_TASKS_KEY                       hash: task_id -> chat_id ("" if none)
#   f"{REDIS_CHAT_TASKS_KEY}:{chat_id}"   set:  task IDs belonging to one chat
#   REDIS_PUBSUB_CHANNEL                  pub/sub channel carrying JSON commands
REDIS_TASKS_KEY = "open-webui:tasks"
REDIS_CHAT_TASKS_KEY = "open-webui:tasks:chat"
REDIS_PUBSUB_CHANNEL = "open-webui:tasks:commands"
 | 
						|
 | 
						|
 | 
						|
def is_redis(request: Request) -> bool:
    """
    Return True when the application has a Redis client configured.

    Looks at ``request.app.state.redis``: a missing attribute and an explicit
    ``None`` both mean "no Redis".
    """
    state = request.app.state
    if not hasattr(state, "redis"):
        return False
    return state.redis is not None
 | 
						|
 | 
						|
 | 
						|
async def redis_task_command_listener(app):
    """
    Listen on the shared Redis pub/sub channel and apply task commands locally.

    Runs until cancelled. A ``{"action": "stop", "task_id": ...}`` command
    cancels the matching task if this process started it (i.e. it is present
    in ``tasks``); other instances simply ignore it. Messages that fail to
    parse or be handled are logged and skipped, so one bad message cannot
    stop the listener.

    :param app: application whose ``state.redis`` holds the Redis client.
    """
    import logging

    log = logging.getLogger(__name__)
    redis: Redis = app.state.redis

    # The async context manager closes the pub/sub connection when the
    # listener exits -- including when it is cancelled at shutdown -- rather
    # than leaving it open until garbage collection.
    async with redis.pubsub() as pubsub:
        await pubsub.subscribe(REDIS_PUBSUB_CHANNEL)

        async for message in pubsub.listen():
            # Skip non-data events (e.g. subscription confirmations).
            if message["type"] != "message":
                continue

            try:
                command = json.loads(message["data"])
                if command.get("action") == "stop":
                    # Only the instance that owns the task has it in `tasks`.
                    local_task = tasks.get(command.get("task_id"))
                    if local_task:
                        local_task.cancel()
            except Exception:
                # Best-effort boundary: log with traceback and keep listening.
                log.exception("Error handling distributed task command")
 | 
						|
 | 
						|
 | 
						|
### ------------------------------
 | 
						|
### REDIS-ENABLED HANDLERS
 | 
						|
### ------------------------------
 | 
						|
 | 
						|
 | 
						|
async def redis_save_task(redis: Redis, task_id: str, chat_id: Optional[str]):
    """
    Record a task in Redis so every instance can see it.

    The task ID goes into the global tasks hash, mapped to its chat ID (or an
    empty string when there is none). When a chat ID is given, the task ID is
    also added to that chat's task set. Both writes go through one pipeline.
    """
    batch = redis.pipeline()
    batch.hset(REDIS_TASKS_KEY, task_id, chat_id if chat_id else "")
    if chat_id:
        chat_key = f"{REDIS_CHAT_TASKS_KEY}:{chat_id}"
        batch.sadd(chat_key, task_id)
    await batch.execute()
 | 
						|
 | 
						|
 | 
						|
async def redis_cleanup_task(redis: Redis, task_id: str, chat_id: Optional[str]):
    """
    Remove a task's bookkeeping from Redis.

    Deletes the task from the global tasks hash and, when a chat ID is given,
    removes it from that chat's task set. Both commands go through one
    pipeline.
    """
    pipe = redis.pipeline()
    pipe.hdel(REDIS_TASKS_KEY, task_id)
    if chat_id:
        # No "delete the set if it is now empty" step is needed: Redis removes
        # a set key automatically once its last member is removed. An explicit
        # SCARD-then-DEL would run as a separate round trip and could delete a
        # task that another instance added to the same chat in between.
        pipe.srem(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}", task_id)
    await pipe.execute()
 | 
						|
 | 
						|
 | 
						|
async def redis_list_tasks(redis: Redis) -> List[str]:
    """
    Return the IDs of every task registered in the shared Redis tasks hash.
    """
    task_ids = await redis.hkeys(REDIS_TASKS_KEY)
    return list(task_ids)
 | 
						|
 | 
						|
 | 
						|
async def redis_list_chat_tasks(redis: Redis, chat_id: str) -> List[str]:
    """
    Return the IDs of the tasks registered in Redis for one chat.

    The IDs come from a Redis set, so their order is not meaningful.
    """
    chat_key = f"{REDIS_CHAT_TASKS_KEY}:{chat_id}"
    members = await redis.smembers(chat_key)
    return list(members)
 | 
						|
 | 
						|
 | 
						|
async def redis_send_command(redis: Redis, command: dict):
    """
    Broadcast a command to all instances.

    The command is JSON-encoded and published on the shared task channel.
    """
    payload = json.dumps(command)
    await redis.publish(REDIS_PUBSUB_CHANNEL, payload)
 | 
						|
 | 
						|
 | 
						|
async def cleanup_task(request, task_id: str, id=None):
    """
    Remove a completed or canceled task from the task registries.

    The in-process registries (``tasks`` and ``chat_tasks``) are cleared
    first, so a Redis failure cannot leave stale local entries behind. The
    Redis entries are removed afterwards when Redis is configured; a Redis
    error still propagates to the caller.

    :param request: request whose app state may carry a Redis client.
    :param task_id: ID returned by ``create_task``.
    :param id: chat/owner ID the task was grouped under, if any.
    """
    tasks.pop(task_id, None)  # Remove the task if it exists

    # If an ID is provided, drop the task from that ID's list, and drop the
    # list itself once it is empty.
    if id:
        chat_task_ids = chat_tasks.get(id)
        if chat_task_ids and task_id in chat_task_ids:
            chat_task_ids.remove(task_id)
            if not chat_task_ids:
                chat_tasks.pop(id, None)

    if is_redis(request):
        await redis_cleanup_task(request.app.state.redis, task_id, id)
 | 
						|
 | 
						|
 | 
						|
# Strong references to in-flight cleanup tasks spawned by create_task's done
# callback. asyncio keeps only weak references to tasks, so an otherwise
# unreferenced task could be garbage-collected before it finishes.
_cleanup_tasks = set()


async def create_task(request, coroutine, id=None):
    """
    Create a new asyncio task and add it to the task registries.

    When Redis is configured, the task is registered there *before* it is
    scheduled. Otherwise a task that finishes quickly could have its Redis
    cleanup race ahead of its registration and leave a stale entry behind.
    If that registration fails, the coroutine is closed without running and
    the error is re-raised.

    :param request: request whose app state may carry a Redis client.
    :param coroutine: coroutine to run as a task.
    :param id: optional chat/owner ID to group the task under.
    :return: ``(task_id, task)``.
    """
    task_id = str(uuid4())  # Generate a unique ID for the task

    if is_redis(request):
        try:
            await redis_save_task(request.app.state.redis, task_id, id)
        except BaseException:
            coroutine.close()  # never scheduled; avoid "never awaited" warning
            raise

    task = asyncio.create_task(coroutine)
    tasks[task_id] = task

    # Only group tasks that actually have an ID. cleanup_task skips falsy IDs,
    # so entries stored under None would never be removed and would grow
    # without bound.
    if id:
        chat_tasks.setdefault(id, []).append(task_id)

    def _on_done(_finished):
        # Schedule async cleanup and keep a reference until it completes.
        cleanup = asyncio.create_task(cleanup_task(request, task_id, id))
        _cleanup_tasks.add(cleanup)
        cleanup.add_done_callback(_cleanup_tasks.discard)

    task.add_done_callback(_on_done)

    return task_id, task
 | 
						|
 | 
						|
 | 
						|
async def list_tasks(request):
    """
    Return the IDs of all active tasks.

    With Redis configured, the IDs come from the shared Redis tasks hash;
    otherwise they are the IDs of the tasks started by this process.
    """
    if not is_redis(request):
        return list(tasks)
    return await redis_list_tasks(request.app.state.redis)
 | 
						|
 | 
						|
 | 
						|
async def list_task_ids_by_chat_id(request, id):
    """
    List all task IDs associated with a specific chat/owner ID.

    Reads from Redis when it is configured; otherwise reads the in-process
    ``chat_tasks`` registry. The local result is a copy. Callers therefore
    cannot mutate the registry by accident, and the returned list does not
    change underneath them as tasks finish.
    """
    if is_redis(request):
        return await redis_list_chat_tasks(request.app.state.redis, id)
    return list(chat_tasks.get(id, []))
 | 
						|
 | 
						|
 | 
						|
async def stop_task(request, task_id: str):
    """
    Cancel a running task and remove it from the global task list.

    With Redis configured, a "stop" command is broadcast to every instance
    and the call returns immediately. Whichever instance has the task in its
    local ``tasks`` cancels it. Without Redis, the local task is cancelled
    and awaited.

    :return: ``{"status": bool, "message": str}``.
    :raises ValueError: (local mode only) if no task with ``task_id`` exists.
    """
    if is_redis(request):
        # PUBSUB: All instances check if they have this task, and stop if so.
        await redis_send_command(
            request.app.state.redis,
            {
                "action": "stop",
                "task_id": task_id,
            },
        )
        # Optionally check if task_id still in Redis a few moments later for feedback?
        return {"status": True, "message": f"Stop signal sent for {task_id}"}

    task = tasks.get(task_id)
    if not task:
        raise ValueError(f"Task with ID {task_id} not found.")

    task.cancel()  # Request task cancellation
    try:
        await task  # Wait for the task to handle the cancellation
    except asyncio.CancelledError:
        # Task successfully canceled
        tasks.pop(task_id, None)  # Remove it from the dictionary
        return {"status": True, "message": f"Task {task_id} successfully stopped."}
    except Exception:
        # Either the task had already failed before cancel() was called, or it
        # raised something other than CancelledError while being cancelled.
        # That is the task's own error, not a failure of this call, so report
        # "not stopped" below instead of propagating it to the caller.
        pass

    return {"status": False, "message": f"Failed to stop task {task_id}."}
 |