From b8da4a8cd8257d4846f3608e299618a0b4f185ed Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Tue, 29 Jul 2025 23:45:25 +0400 Subject: [PATCH] refac --- backend/open_webui/config.py | 6 --- backend/open_webui/env.py | 29 ++++++++++++++ backend/open_webui/main.py | 2 +- backend/open_webui/utils/auth.py | 66 ++++++++++++++++++++++++++------ 4 files changed, 85 insertions(+), 18 deletions(-) diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 49ab1a9aad..bf4325417d 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -779,12 +779,6 @@ if CUSTOM_NAME: pass -#################################### -# LICENSE_KEY -#################################### - -LICENSE_KEY = os.environ.get("LICENSE_KEY", "") - #################################### # STORAGE PROVIDER #################################### diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index 61518d59c6..01c6f0468b 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -7,6 +7,7 @@ import sys import shutil from uuid import uuid4 from pathlib import Path +from cryptography.hazmat.primitives import serialization import markdown from bs4 import BeautifulSoup @@ -430,6 +431,34 @@ ENABLE_COMPRESSION_MIDDLEWARE = ( os.environ.get("ENABLE_COMPRESSION_MIDDLEWARE", "True").lower() == "true" ) + +#################################### +# LICENSE_KEY +#################################### + +LICENSE_KEY = os.environ.get("LICENSE_KEY", "") + +LICENSE_BLOB = None +LICENSE_BLOB_PATH = os.environ.get("LICENSE_BLOB_PATH", DATA_DIR / "l.data") +if LICENSE_BLOB_PATH and os.path.exists(LICENSE_BLOB_PATH): + with open(LICENSE_BLOB_PATH, "rb") as f: + LICENSE_BLOB = f.read() + +LICENSE_PUBLIC_KEY = os.environ.get("LICENSE_PUBLIC_KEY", "") + +pk = None +if LICENSE_PUBLIC_KEY: + pk = serialization.load_pem_public_key( + f""" +-----BEGIN PUBLIC KEY----- +{LICENSE_PUBLIC_KEY} +-----END PUBLIC KEY----- +""".encode( + "utf-8" + ) + ) + + #################################### # MODELS #################################### diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 595d551d75..fe1fd6ded8 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -102,7 +102,6 @@ from open_webui.models.users import UserModel, Users from open_webui.models.chats import Chats from open_webui.config import ( - LICENSE_KEY, # Ollama ENABLE_OLLAMA_API, OLLAMA_BASE_URLS, @@ -395,6 +394,7 @@ from open_webui.config import ( reset_config, ) from open_webui.env import ( + LICENSE_KEY, AUDIT_EXCLUDED_PATHS, AUDIT_LOG_LEVEL, CHANGELOG, diff --git a/backend/open_webui/utils/auth.py b/backend/open_webui/utils/auth.py index 3262c803f3..5f30738cfe 100644 --- a/backend/open_webui/utils/auth.py +++ b/backend/open_webui/utils/auth.py @@ -8,6 +8,12 @@ import requests import os +from cryptography.hazmat.primitives.ciphers.aead import AESGCM +from cryptography.hazmat.primitives.asymmetric import ed25519 +from cryptography.hazmat.primitives import serialization +import json + + from datetime import datetime, timedelta import pytz from pytz import UTC @@ -18,7 +24,11 @@ from opentelemetry import trace from open_webui.models.users import Users from open_webui.constants import ERROR_MESSAGES + from open_webui.env import ( + OFFLINE_MODE, + LICENSE_BLOB, + pk, WEBUI_SECRET_KEY, TRUSTED_SIGNATURE_KEY, STATIC_DIR, @@ -74,6 +84,18 @@ def override_static(path: str, content: str): def get_license_data(app, key): + def data_handler(data): + for k, v in data.items(): + if k == "resources": + for p, c in v.items(): + globals().get("override_static", lambda a, b: None)(p, c) + elif k == "count": + setattr(app.state, "USER_COUNT", v) + elif k == "name": + setattr(app.state, "WEBUI_NAME", v) + elif k == "metadata": + setattr(app.state, "LICENSE_METADATA", v) + def handler(u): res = requests.post( f"{u}/api/v1/license/", @@ -83,16 +105,7 @@ def get_license_data(app, key): if getattr(res, "ok", False): payload = getattr(res, "json", lambda: {})() - for k, v in payload.items(): - if k == "resources": - for p, c in v.items(): - globals().get("override_static", lambda a, b: None)(p, c) - elif k == "count": - setattr(app.state, "USER_COUNT", v) - elif k == "name": - setattr(app.state, "WEBUI_NAME", v) - elif k == "metadata": - setattr(app.state, "LICENSE_METADATA", v) + data_handler(payload) return True else: log.error( @@ -100,13 +113,44 @@ def get_license_data(app, key): ) if key: - us = ["https://api.openwebui.com", "https://licenses.api.openwebui.com"] + us = [ + "https://api.openwebui.com", + "https://licenses.api.openwebui.com", + ] try: for u in us: if handler(u): return True except Exception as ex: log.exception(f"License: Uncaught Exception: {ex}") + + try: + if LICENSE_BLOB: + nl = 12 + kb = hashlib.sha256((key.replace("-", "").upper()).encode()).digest() + + def nt(b): + return b[:nl], b[nl:] + + lb = base64.b64decode(LICENSE_BLOB) + ln, lt = nt(lb) + + aesgcm = AESGCM(kb) + p = json.loads(aesgcm.decrypt(ln, lt, None)) + pk.verify(base64.b64decode(p["s"]), p["p"].encode()) + + pb = base64.b64decode(p["p"]) + pn, pt = nt(pb) + + data = json.loads(aesgcm.decrypt(pn, pt, None).decode()) + if not data.get("exp") and data.get("exp") < datetime.now().date(): + return False + + data_handler(data) + return True + except Exception as e: + log.error(f"License: {e}") + return False