211 lines
		
	
	
		
			7.3 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			211 lines
		
	
	
		
			7.3 KiB
		
	
	
	
		
			Python
		
	
	
	
import asyncio
 | 
						|
import json
 | 
						|
import logging
 | 
						|
import uuid
 | 
						|
from typing import Optional
 | 
						|
 | 
						|
import aiohttp
 | 
						|
import websockets
 | 
						|
from pydantic import BaseModel
 | 
						|
 | 
						|
from open_webui.env import SRC_LOG_LEVELS
 | 
						|
 | 
						|
# Module-level logger; its level is taken from the "MAIN" entry of SRC_LOG_LEVELS.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=SRC_LOG_LEVELS["MAIN"])
 | 
						|
 | 
						|
 | 
						|
class ResultModel(BaseModel):
    """
    Outcome of a code execution request.

    Every field defaults to an empty string; ``None`` is also accepted.
    """

    # Text written to the kernel's standard output stream.
    stdout: str | None = ""
    # Text written to standard error, plus any traceback or timeout notice.
    stderr: str | None = ""
    # Rich results (plain text or base64 PNG data URIs), newline-joined.
    result: str | None = ""
 | 
						|
 | 
						|
 | 
						|
class JupyterCodeExecuter:
 | 
						|
    """
 | 
						|
    Execute code in jupyter notebook
 | 
						|
    """
 | 
						|
 | 
						|
    def __init__(
 | 
						|
        self,
 | 
						|
        base_url: str,
 | 
						|
        code: str,
 | 
						|
        token: str = "",
 | 
						|
        password: str = "",
 | 
						|
        timeout: int = 60,
 | 
						|
    ):
 | 
						|
        """
 | 
						|
        :param base_url: Jupyter server URL (e.g., "http://localhost:8888")
 | 
						|
        :param code: Code to execute
 | 
						|
        :param token: Jupyter authentication token (optional)
 | 
						|
        :param password: Jupyter password (optional)
 | 
						|
        :param timeout: WebSocket timeout in seconds (default: 60s)
 | 
						|
        """
 | 
						|
        self.base_url = base_url
 | 
						|
        self.code = code
 | 
						|
        self.token = token
 | 
						|
        self.password = password
 | 
						|
        self.timeout = timeout
 | 
						|
        self.kernel_id = ""
 | 
						|
        if self.base_url[-1] != "/":
 | 
						|
            self.base_url += "/"
 | 
						|
        self.session = aiohttp.ClientSession(trust_env=True, base_url=self.base_url)
 | 
						|
        self.params = {}
 | 
						|
        self.result = ResultModel()
 | 
						|
 | 
						|
    async def __aenter__(self):
 | 
						|
        return self
 | 
						|
 | 
						|
    async def __aexit__(self, exc_type, exc_val, exc_tb):
 | 
						|
        if self.kernel_id:
 | 
						|
            try:
 | 
						|
                async with self.session.delete(
 | 
						|
                    f"api/kernels/{self.kernel_id}", params=self.params
 | 
						|
                ) as response:
 | 
						|
                    response.raise_for_status()
 | 
						|
            except Exception as err:
 | 
						|
                logger.exception("close kernel failed, %s", err)
 | 
						|
        await self.session.close()
 | 
						|
 | 
						|
    async def run(self) -> ResultModel:
 | 
						|
        try:
 | 
						|
            await self.sign_in()
 | 
						|
            await self.init_kernel()
 | 
						|
            await self.execute_code()
 | 
						|
        except Exception as err:
 | 
						|
            logger.exception("execute code failed, %s", err)
 | 
						|
            self.result.stderr = f"Error: {err}"
 | 
						|
        return self.result
 | 
						|
 | 
						|
    async def sign_in(self) -> None:
 | 
						|
        # password authentication
 | 
						|
        if self.password and not self.token:
 | 
						|
            async with self.session.get("login") as response:
 | 
						|
                response.raise_for_status()
 | 
						|
                xsrf_token = response.cookies["_xsrf"].value
 | 
						|
                if not xsrf_token:
 | 
						|
                    raise ValueError("_xsrf token not found")
 | 
						|
                self.session.cookie_jar.update_cookies(response.cookies)
 | 
						|
                self.session.headers.update({"X-XSRFToken": xsrf_token})
 | 
						|
            async with self.session.post(
 | 
						|
                "login",
 | 
						|
                data={"_xsrf": xsrf_token, "password": self.password},
 | 
						|
                allow_redirects=False,
 | 
						|
            ) as response:
 | 
						|
                response.raise_for_status()
 | 
						|
                self.session.cookie_jar.update_cookies(response.cookies)
 | 
						|
 | 
						|
        # token authentication
 | 
						|
        if self.token:
 | 
						|
            self.params.update({"token": self.token})
 | 
						|
 | 
						|
    async def init_kernel(self) -> None:
 | 
						|
        async with self.session.post(url="api/kernels", params=self.params) as response:
 | 
						|
            response.raise_for_status()
 | 
						|
            kernel_data = await response.json()
 | 
						|
            self.kernel_id = kernel_data["id"]
 | 
						|
 | 
						|
    def init_ws(self) -> (str, dict):
 | 
						|
        ws_base = self.base_url.replace("http", "ws", 1)
 | 
						|
        ws_params = "?" + "&".join([f"{key}={val}" for key, val in self.params.items()])
 | 
						|
        websocket_url = f"{ws_base}api/kernels/{self.kernel_id}/channels{ws_params if len(ws_params) > 1 else ''}"
 | 
						|
        ws_headers = {}
 | 
						|
        if self.password and not self.token:
 | 
						|
            ws_headers = {
 | 
						|
                "Cookie": "; ".join(
 | 
						|
                    [
 | 
						|
                        f"{cookie.key}={cookie.value}"
 | 
						|
                        for cookie in self.session.cookie_jar
 | 
						|
                    ]
 | 
						|
                ),
 | 
						|
                **self.session.headers,
 | 
						|
            }
 | 
						|
        return websocket_url, ws_headers
 | 
						|
 | 
						|
    async def execute_code(self) -> None:
 | 
						|
        # initialize ws
 | 
						|
        websocket_url, ws_headers = self.init_ws()
 | 
						|
        # execute
 | 
						|
        async with websockets.connect(
 | 
						|
            websocket_url, additional_headers=ws_headers
 | 
						|
        ) as ws:
 | 
						|
            await self.execute_in_jupyter(ws)
 | 
						|
 | 
						|
    async def execute_in_jupyter(self, ws) -> None:
 | 
						|
        # send message
 | 
						|
        msg_id = uuid.uuid4().hex
 | 
						|
        await ws.send(
 | 
						|
            json.dumps(
 | 
						|
                {
 | 
						|
                    "header": {
 | 
						|
                        "msg_id": msg_id,
 | 
						|
                        "msg_type": "execute_request",
 | 
						|
                        "username": "user",
 | 
						|
                        "session": uuid.uuid4().hex,
 | 
						|
                        "date": "",
 | 
						|
                        "version": "5.3",
 | 
						|
                    },
 | 
						|
                    "parent_header": {},
 | 
						|
                    "metadata": {},
 | 
						|
                    "content": {
 | 
						|
                        "code": self.code,
 | 
						|
                        "silent": False,
 | 
						|
                        "store_history": True,
 | 
						|
                        "user_expressions": {},
 | 
						|
                        "allow_stdin": False,
 | 
						|
                        "stop_on_error": True,
 | 
						|
                    },
 | 
						|
                    "channel": "shell",
 | 
						|
                }
 | 
						|
            )
 | 
						|
        )
 | 
						|
        # parse message
 | 
						|
        stdout, stderr, result = "", "", []
 | 
						|
        while True:
 | 
						|
            try:
 | 
						|
                # wait for message
 | 
						|
                message = await asyncio.wait_for(ws.recv(), self.timeout)
 | 
						|
                message_data = json.loads(message)
 | 
						|
                # msg id not match, skip
 | 
						|
                if message_data.get("parent_header", {}).get("msg_id") != msg_id:
 | 
						|
                    continue
 | 
						|
                # check message type
 | 
						|
                msg_type = message_data.get("msg_type")
 | 
						|
                match msg_type:
 | 
						|
                    case "stream":
 | 
						|
                        if message_data["content"]["name"] == "stdout":
 | 
						|
                            stdout += message_data["content"]["text"]
 | 
						|
                        elif message_data["content"]["name"] == "stderr":
 | 
						|
                            stderr += message_data["content"]["text"]
 | 
						|
                    case "execute_result" | "display_data":
 | 
						|
                        data = message_data["content"]["data"]
 | 
						|
                        if "image/png" in data:
 | 
						|
                            result.append(f"data:image/png;base64,{data['image/png']}")
 | 
						|
                        elif "text/plain" in data:
 | 
						|
                            result.append(data["text/plain"])
 | 
						|
                    case "error":
 | 
						|
                        stderr += "\n".join(message_data["content"]["traceback"])
 | 
						|
                    case "status":
 | 
						|
                        if message_data["content"]["execution_state"] == "idle":
 | 
						|
                            break
 | 
						|
 | 
						|
            except asyncio.TimeoutError:
 | 
						|
                stderr += "\nExecution timed out."
 | 
						|
                break
 | 
						|
        self.result.stdout = stdout.strip()
 | 
						|
        self.result.stderr = stderr.strip()
 | 
						|
        self.result.result = "\n".join(result).strip() if result else ""
 | 
						|
 | 
						|
 | 
						|
async def execute_code_jupyter(
    base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60
) -> dict:
    """
    Run ``code`` on the Jupyter server at ``base_url`` and return its output.

    :param base_url: Jupyter server URL
    :param code: Code to execute
    :param token: Jupyter authentication token (optional)
    :param password: Jupyter password (optional)
    :param timeout: Seconds to wait for each kernel message
    :return: dict with ``stdout``, ``stderr`` and ``result`` keys
    """
    runner = JupyterCodeExecuter(
        base_url=base_url,
        code=code,
        token=token,
        password=password,
        timeout=timeout,
    )
    async with runner:
        outcome = await runner.run()
        return outcome.model_dump()
 |