100 lines
3.3 KiB
Python
100 lines
3.3 KiB
Python
import hmac
import os
from typing import Optional

from fastapi import HTTPException, Request
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
|
|
|
|
|
|
class PasswordAuthMiddleware(BaseHTTPMiddleware):
    """
    Middleware to check password authentication for all API requests.

    Only active when the OPEN_NOTEBOOK_PASSWORD environment variable is set;
    the variable is read once, when the middleware is constructed. While
    active, a request must carry an ``Authorization: Bearer <password>``
    header unless its path is in ``excluded_paths`` or its method is OPTIONS.
    Rejected requests receive a 401 JSON response with a
    ``WWW-Authenticate: Bearer`` header.
    """

    def __init__(self, app, excluded_paths: Optional[list] = None):
        """
        Args:
            app: The application being wrapped; forwarded to the base class.
            excluded_paths: Request paths (matched exactly) that bypass
                authentication. ``None`` or an empty list selects the
                built-in defaults.
        """
        super().__init__(app)
        self.password = os.environ.get("OPEN_NOTEBOOK_PASSWORD")
        self.excluded_paths = excluded_paths or ["/", "/health", "/docs", "/openapi.json", "/redoc"]

    @staticmethod
    def _unauthorized(detail: str) -> JSONResponse:
        """Build the 401 response shared by every rejection path."""
        return JSONResponse(
            status_code=401,
            content={"detail": detail},
            headers={"WWW-Authenticate": "Bearer"},
        )

    def _password_matches(self, candidate: str) -> bool:
        """
        Return True when ``candidate`` equals the configured password.

        Uses ``hmac.compare_digest`` instead of ``==`` so the comparison is
        not a plain short-circuiting string compare on attacker-supplied
        input. Both sides are encoded to bytes because ``compare_digest``
        rejects non-ASCII ``str`` arguments; ``surrogatepass`` keeps the
        encoding lossless, so the result matches plain string equality.
        """
        return hmac.compare_digest(
            candidate.encode("utf-8", "surrogatepass"),
            self.password.encode("utf-8", "surrogatepass"),
        )

    async def dispatch(self, request: Request, call_next):
        """
        Authenticate ``request`` and forward it, or answer with a 401.

        Args:
            request: The incoming request.
            call_next: Awaitable callable that passes the request on and
                returns the downstream response.

        Returns:
            The downstream response when the request is allowed through,
            otherwise a 401 ``JSONResponse``.
        """
        # Skip authentication if no password is set
        if not self.password:
            return await call_next(request)

        # Skip authentication for excluded paths
        if request.url.path in self.excluded_paths:
            return await call_next(request)

        # Skip authentication for CORS preflight requests (OPTIONS)
        if request.method == "OPTIONS":
            return await call_next(request)

        # Check authorization header
        auth_header = request.headers.get("Authorization")
        if not auth_header:
            return self._unauthorized("Missing authorization header")

        # Expected format: "Bearer {password}"
        try:
            # Unpacking raises ValueError when the header has no space.
            scheme, credentials = auth_header.split(" ", 1)
        except ValueError:
            return self._unauthorized("Invalid authorization header format")
        if scheme.lower() != "bearer":
            return self._unauthorized("Invalid authorization header format")

        # Check password
        if not self._password_matches(credentials):
            return self._unauthorized("Invalid password")

        # Password is correct, proceed with the request
        return await call_next(request)
|
|
|
|
|
|
# Optional: HTTPBearer security scheme for OpenAPI documentation
|
|
# NOTE(review): per the FastAPI docs, auto_error=False makes this scheme yield
# None instead of raising when no bearer credentials are supplied - confirm
# against the installed FastAPI version. `security` is not referenced anywhere
# else in the visible part of this file.
security = HTTPBearer(auto_error=False)
|
|
|
|
|
|
def check_api_password(credentials: Optional[HTTPAuthorizationCredentials] = None) -> bool:
    """
    Utility function to check API password.

    Can be used as a dependency in individual routes if needed. The expected
    password is read from the OPEN_NOTEBOOK_PASSWORD environment variable on
    every call.

    Args:
        credentials: Parsed bearer credentials; its ``credentials`` attribute
            is compared against the configured password.

    Returns:
        True when no password is configured or the supplied one matches.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when a
            password is configured and credentials are missing or wrong.
    """
    password = os.environ.get("OPEN_NOTEBOOK_PASSWORD")

    # No password set, allow access
    if not password:
        return True

    # No credentials provided
    if not credentials:
        raise HTTPException(
            status_code=401,
            detail="Missing authorization",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Check password. hmac.compare_digest replaces `!=` so the comparison is
    # not a plain short-circuiting string compare on caller-supplied input.
    # Both sides are encoded because compare_digest rejects non-ASCII str;
    # "surrogatepass" keeps the encoding lossless, so the outcome matches
    # plain string equality.
    supplied = credentials.credentials.encode("utf-8", "surrogatepass")
    expected = password.encode("utf-8", "surrogatepass")
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(
            status_code=401,
            detail="Invalid password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    return True