open-notebook/api/routers/chat.py

494 lines
19 KiB
Python
Raw Permalink Normal View History

import asyncio
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, HTTPException, Query
from langchain_core.runnables import RunnableConfig
from loguru import logger
from pydantic import BaseModel, Field
from open_notebook.database.repository import ensure_record_id, repo_query
from open_notebook.domain.notebook import ChatSession, Note, Notebook, Source
from open_notebook.exceptions import (
NotFoundError,
)
from open_notebook.graphs.chat import graph as chat_graph
# Router on which the chat endpoints declared in this module are registered.
router = APIRouter()
# Request/Response models
class CreateSessionRequest(BaseModel):
    """Request body for creating a chat session attached to a notebook."""

    notebook_id: str = Field(
        default=..., description="Notebook ID to create session for"
    )
    title: Optional[str] = Field(default=None, description="Optional session title")
    model_override: Optional[str] = Field(
        default=None,
        description="Optional model override for this session",
    )
class UpdateSessionRequest(BaseModel):
    """Request body for updating a chat session; both fields are optional."""

    title: Optional[str] = Field(default=None, description="New session title")
    model_override: Optional[str] = Field(
        default=None,
        description="Model override for this session",
    )
class ChatMessage(BaseModel):
    """A single chat message in the shape returned by the API."""

    id: str = Field(default=..., description="Message ID")
    type: str = Field(default=..., description="Message type (human|ai)")
    content: str = Field(default=..., description="Message content")
    timestamp: Optional[str] = Field(default=None, description="Message timestamp")
class ChatSessionResponse(BaseModel):
    """Chat session metadata returned by the session endpoints."""

    id: str = Field(default=..., description="Session ID")
    title: str = Field(default=..., description="Session title")
    notebook_id: Optional[str] = Field(default=None, description="Notebook ID")
    created: str = Field(default=..., description="Creation timestamp")
    updated: str = Field(default=..., description="Last update timestamp")
    message_count: Optional[int] = Field(
        default=None,
        description="Number of messages in session",
    )
    model_override: Optional[str] = Field(
        default=None,
        description="Model override for this session",
    )
class ChatSessionWithMessagesResponse(ChatSessionResponse):
    """Chat session metadata extended with the session's message list."""

    messages: List[ChatMessage] = Field(
        description="Session messages",
        default_factory=list,
    )
class ExecuteChatRequest(BaseModel):
    """Request body for sending one user message to a chat session."""

    session_id: str = Field(default=..., description="Chat session ID")
    message: str = Field(default=..., description="User message content")
    context: Dict[str, Any] = Field(
        default=...,
        description="Chat context with sources and notes",
    )
    model_override: Optional[str] = Field(
        default=None,
        description="Optional model override for this message",
    )
class ExecuteChatResponse(BaseModel):
    """Response of a chat execution: the session ID and its message list."""

    session_id: str = Field(default=..., description="Session ID")
    messages: List[ChatMessage] = Field(
        default=..., description="Updated message list"
    )
class BuildContextRequest(BaseModel):
    """Request body for building chat context from a notebook."""

    notebook_id: str = Field(default=..., description="Notebook ID")
    context_config: Dict[str, Any] = Field(
        default=..., description="Context configuration"
    )
class BuildContextResponse(BaseModel):
    """Built context together with its size estimates."""

    context: Dict[str, Any] = Field(default=..., description="Built context data")
    token_count: int = Field(default=..., description="Estimated token count")
    char_count: int = Field(default=..., description="Character count")
class SuccessResponse(BaseModel):
    """Generic acknowledgement returned by operations with no other payload."""

    success: bool = Field(default=True, description="Operation success status")
    message: str = Field(default=..., description="Success message")
@router.get("/chat/sessions", response_model=List[ChatSessionResponse])
async def get_sessions(notebook_id: str = Query(..., description="Notebook ID")):
    """Get all chat sessions for a notebook.

    Args:
        notebook_id: ID of the notebook whose sessions are listed.

    Returns:
        One ``ChatSessionResponse`` per session. ``message_count`` is always
        0 here because messages are not counted by this endpoint.

    Raises:
        HTTPException: 404 if the notebook is not found, 500 on any other
            failure.
    """
    try:
        # Get notebook to verify it exists
        notebook = await Notebook.get(notebook_id)
        if not notebook:
            raise HTTPException(status_code=404, detail="Notebook not found")

        # Get sessions for this notebook
        sessions = await notebook.get_chat_sessions()
        return [
            ChatSessionResponse(
                id=session.id or "",
                title=session.title or "Untitled Session",
                notebook_id=notebook_id,
                created=str(session.created),
                updated=str(session.updated),
                message_count=0,  # TODO: Add message count if needed
                model_override=getattr(session, "model_override", None),
            )
            for session in sessions
        ]
    except HTTPException:
        # Deliberate HTTP errors (such as the 404 above) must pass through;
        # the generic handler below would otherwise rewrap them as a 500.
        raise
    except NotFoundError:
        raise HTTPException(status_code=404, detail="Notebook not found")
    except Exception as e:
        logger.error(f"Error fetching chat sessions: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Error fetching chat sessions: {str(e)}"
        ) from e
@router.post("/chat/sessions", response_model=ChatSessionResponse)
async def create_session(request: CreateSessionRequest):
    """Create a new chat session and relate it to a notebook.

    Args:
        request: Notebook ID plus optional title and model override.

    Returns:
        The newly created session with ``message_count`` set to 0.

    Raises:
        HTTPException: 404 if the notebook is not found, 500 on any other
            failure.
    """
    try:
        # Verify notebook exists
        notebook = await Notebook.get(request.notebook_id)
        if not notebook:
            raise HTTPException(status_code=404, detail="Notebook not found")

        # The default title embeds the event loop's clock reading. That is a
        # monotonic clock, not wall-clock time, so it only makes titles vary.
        default_title = f"Chat Session {asyncio.get_running_loop().time():.0f}"

        # Create new session
        session = ChatSession(
            title=request.title or default_title,
            model_override=request.model_override,
        )
        await session.save()

        # Relate session to notebook
        await session.relate_to_notebook(request.notebook_id)

        return ChatSessionResponse(
            id=session.id or "",
            title=session.title or "",
            notebook_id=request.notebook_id,
            created=str(session.created),
            updated=str(session.updated),
            message_count=0,
            model_override=session.model_override,
        )
    except HTTPException:
        # Deliberate HTTP errors (such as the 404 above) must pass through;
        # the generic handler below would otherwise rewrap them as a 500.
        raise
    except NotFoundError:
        raise HTTPException(status_code=404, detail="Notebook not found")
    except Exception as e:
        logger.error(f"Error creating chat session: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Error creating chat session: {str(e)}"
        ) from e
@router.get(
    "/chat/sessions/{session_id}", response_model=ChatSessionWithMessagesResponse
)
async def get_session(session_id: str):
    """Get a specific session with its messages.

    Args:
        session_id: Session ID, with or without the ``chat_session:`` prefix.

    Returns:
        The session metadata plus the messages found in the chat graph state
        for this session. ``notebook_id`` is ``None`` when no ``refers_to``
        relation exists for the session.

    Raises:
        HTTPException: 404 if the session is not found, 500 on any other
            failure.
    """
    try:
        # Ensure session_id has proper table prefix
        full_session_id = (
            session_id
            if session_id.startswith("chat_session:")
            else f"chat_session:{session_id}"
        )
        session = await ChatSession.get(full_session_id)
        if not session:
            raise HTTPException(status_code=404, detail="Session not found")

        # Get session state from LangGraph to retrieve messages.
        # NOTE(review): the thread is keyed by the raw path value rather than
        # the prefixed ID, so prefixed and unprefixed requests address
        # different threads - confirm callers always use one form.
        thread_state = chat_graph.get_state(
            config=RunnableConfig(configurable={"thread_id": session_id})
        )

        # Extract messages from state
        messages: list[ChatMessage] = []
        if thread_state and thread_state.values and "messages" in thread_state.values:
            for index, msg in enumerate(thread_state.values["messages"]):
                messages.append(
                    ChatMessage(
                        # The attribute can exist but be None, which would
                        # fail validation of the required string field, so
                        # fall back on any falsy value.
                        id=getattr(msg, "id", None) or f"msg_{index}",
                        type=msg.type if hasattr(msg, "type") else "unknown",
                        content=msg.content if hasattr(msg, "content") else str(msg),
                        timestamp=None,  # LangChain messages don't have timestamps by default
                    )
                )

        # Find notebook_id (we need to query the relationship)
        notebook_query = await repo_query(
            "SELECT out FROM refers_to WHERE in = $session_id",
            {"session_id": ensure_record_id(full_session_id)},
        )
        notebook_id = notebook_query[0]["out"] if notebook_query else None
        if not notebook_id:
            # This might be an old session created before API migration
            logger.warning(
                f"No notebook relationship found for session {session_id} - may be an orphaned session"
            )

        return ChatSessionWithMessagesResponse(
            id=session.id or "",
            title=session.title or "Untitled Session",
            notebook_id=notebook_id,
            created=str(session.created),
            updated=str(session.updated),
            message_count=len(messages),
            messages=messages,
            model_override=getattr(session, "model_override", None),
        )
    except HTTPException:
        # Deliberate HTTP errors (such as the 404 above) must pass through;
        # the generic handler below would otherwise rewrap them as a 500.
        raise
    except NotFoundError:
        raise HTTPException(status_code=404, detail="Session not found")
    except Exception as e:
        logger.error(f"Error fetching session: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Error fetching session: {str(e)}"
        ) from e
@router.put("/chat/sessions/{session_id}", response_model=ChatSessionResponse)
async def update_session(session_id: str, request: UpdateSessionRequest):
    """Update a session's title and/or model override.

    Only fields explicitly present in the request body are applied, so an
    explicit ``null`` clears a value while an omitted field leaves it as is.

    Args:
        session_id: Session ID, with or without the ``chat_session:`` prefix.
        request: Fields to update.

    Returns:
        The updated session metadata with ``message_count`` set to 0.

    Raises:
        HTTPException: 404 if the session is not found, 500 on any other
            failure.
    """
    try:
        # Ensure session_id has proper table prefix
        full_session_id = (
            session_id
            if session_id.startswith("chat_session:")
            else f"chat_session:{session_id}"
        )
        session = await ChatSession.get(full_session_id)
        if not session:
            raise HTTPException(status_code=404, detail="Session not found")

        # exclude_unset distinguishes "field omitted" from "field set to null"
        update_data = request.model_dump(exclude_unset=True)
        if "title" in update_data:
            session.title = update_data["title"]
        if "model_override" in update_data:
            session.model_override = update_data["model_override"]
        await session.save()

        # Find notebook_id
        notebook_query = await repo_query(
            "SELECT out FROM refers_to WHERE in = $session_id",
            {"session_id": ensure_record_id(full_session_id)},
        )
        notebook_id = notebook_query[0]["out"] if notebook_query else None

        return ChatSessionResponse(
            id=session.id or "",
            title=session.title or "",
            notebook_id=notebook_id,
            created=str(session.created),
            updated=str(session.updated),
            message_count=0,
            model_override=session.model_override,
        )
    except HTTPException:
        # Deliberate HTTP errors (such as the 404 above) must pass through;
        # the generic handler below would otherwise rewrap them as a 500.
        raise
    except NotFoundError:
        raise HTTPException(status_code=404, detail="Session not found")
    except Exception as e:
        logger.error(f"Error updating session: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Error updating session: {str(e)}"
        ) from e
@router.delete("/chat/sessions/{session_id}", response_model=SuccessResponse)
async def delete_session(session_id: str):
    """Delete a chat session.

    Args:
        session_id: Session ID, with or without the ``chat_session:`` prefix.

    Returns:
        A ``SuccessResponse`` confirming the deletion.

    Raises:
        HTTPException: 404 if the session is not found, 500 on any other
            failure.
    """
    try:
        # Ensure session_id has proper table prefix
        full_session_id = (
            session_id
            if session_id.startswith("chat_session:")
            else f"chat_session:{session_id}"
        )
        session = await ChatSession.get(full_session_id)
        if not session:
            raise HTTPException(status_code=404, detail="Session not found")

        await session.delete()
        return SuccessResponse(success=True, message="Session deleted successfully")
    except HTTPException:
        # Deliberate HTTP errors (such as the 404 above) must pass through;
        # the generic handler below would otherwise rewrap them as a 500.
        raise
    except NotFoundError:
        raise HTTPException(status_code=404, detail="Session not found")
    except Exception as e:
        logger.error(f"Error deleting session: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Error deleting session: {str(e)}"
        ) from e
@router.post("/chat/execute", response_model=ExecuteChatResponse)
async def execute_chat(request: ExecuteChatRequest):
    """Execute a chat request and get AI response.

    The user message is appended to the messages currently held in the chat
    graph state for the session, the graph is invoked, and the resulting
    message list is returned.

    Args:
        request: Session ID, user message, chat context and an optional
            per-request model override.

    Returns:
        The session ID and the message list produced by the graph.

    Raises:
        HTTPException: 404 if the session is not found, 500 on any other
            failure.
    """
    # Imported locally, as in the original handler.
    from langchain_core.messages import HumanMessage

    try:
        # Ensure session_id has proper table prefix
        full_session_id = (
            request.session_id
            if request.session_id.startswith("chat_session:")
            else f"chat_session:{request.session_id}"
        )
        session = await ChatSession.get(full_session_id)
        if not session:
            raise HTTPException(status_code=404, detail="Session not found")

        # Determine model override (per-request override takes precedence over session-level)
        model_override = (
            request.model_override
            if request.model_override is not None
            else getattr(session, "model_override", None)
        )

        # Get current state.
        # NOTE(review): the thread is keyed by the raw request value rather
        # than the prefixed ID - confirm callers always use one form.
        current_state = chat_graph.get_state(
            config=RunnableConfig(
                configurable={"thread_id": request.session_id}
            )
        )

        # Work on copies so the snapshot returned by the graph is not
        # mutated in place while the input for this run is assembled.
        state_values = dict(current_state.values) if current_state else {}
        state_values["messages"] = list(state_values.get("messages") or [])
        state_values["context"] = request.context
        state_values["model_override"] = model_override

        # Add user message to state
        state_values["messages"].append(HumanMessage(content=request.message))

        # Execute chat graph
        result = chat_graph.invoke(
            input=state_values,  # type: ignore[arg-type]
            config=RunnableConfig(
                configurable={
                    "thread_id": request.session_id,
                    "model_id": model_override,
                }
            ),
        )

        # Update session timestamp
        await session.save()

        # Convert messages to response format
        messages: list[ChatMessage] = []
        for index, msg in enumerate(result.get("messages", [])):
            messages.append(
                ChatMessage(
                    # The attribute can exist but be None, which would fail
                    # validation of the required string field, so fall back
                    # on any falsy value.
                    id=getattr(msg, "id", None) or f"msg_{index}",
                    type=msg.type if hasattr(msg, "type") else "unknown",
                    content=msg.content if hasattr(msg, "content") else str(msg),
                    timestamp=None,
                )
            )

        return ExecuteChatResponse(session_id=request.session_id, messages=messages)
    except HTTPException:
        # Deliberate HTTP errors (such as the 404 above) must pass through;
        # the generic handler below would otherwise rewrap them as a 500.
        raise
    except NotFoundError:
        raise HTTPException(status_code=404, detail="Session not found")
    except Exception as e:
        logger.error(f"Error executing chat: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Error executing chat: {str(e)}"
        ) from e
@router.post("/chat/context", response_model=BuildContextResponse)
async def build_context(request: BuildContextRequest):
    """Build context for a notebook based on context configuration.

    When ``context_config`` is non-empty, its ``sources`` and ``notes``
    mappings (ID -> status string) select what is included:

    * a status containing ``"not in"`` skips the item;
    * for sources, ``"insights"`` selects the short context and
      ``"full content"`` the long context;
    * for notes, only ``"full content"`` includes the note (long context).

    When ``context_config`` is empty, every source and note of the notebook
    is included with its short context. Items that fail to load or render are
    logged and skipped rather than failing the whole request.

    Args:
        request: Notebook ID and context configuration.

    Returns:
        The collected context plus its character count and an estimated
        token count.

    Raises:
        HTTPException: 404 if the notebook is not found, 500 on any other
            failure.
    """
    try:
        # Verify notebook exists
        notebook = await Notebook.get(request.notebook_id)
        if not notebook:
            raise HTTPException(status_code=404, detail="Notebook not found")

        context_data: dict[str, list[dict[str, str]]] = {"sources": [], "notes": []}
        # Rendered pieces are collected and joined once at the end instead of
        # growing a string with repeated concatenation.
        content_parts: list[str] = []

        if request.context_config:
            # Process sources; a missing or null mapping means "none selected"
            source_config = request.context_config.get("sources") or {}
            for source_id, status in source_config.items():
                try:
                    if "not in" in status:
                        continue
                    if "insights" in status:
                        context_size = "short"
                    elif "full content" in status:
                        context_size = "long"
                    else:
                        # Unrecognized status contributes nothing
                        continue

                    # Add table prefix if not present
                    full_source_id = (
                        source_id
                        if source_id.startswith("source:")
                        else f"source:{source_id}"
                    )
                    source = await Source.get(full_source_id)
                    if not source:
                        continue

                    source_context = await source.get_context(
                        context_size=context_size
                    )
                    context_data["sources"].append(source_context)
                    content_parts.append(str(source_context))
                except Exception as e:
                    logger.warning(f"Error processing source {source_id}: {str(e)}")
                    continue

            # Process notes
            note_config = request.context_config.get("notes") or {}
            for note_id, status in note_config.items():
                try:
                    if "not in" in status or "full content" not in status:
                        continue

                    # Add table prefix if not present
                    full_note_id = (
                        note_id if note_id.startswith("note:") else f"note:{note_id}"
                    )
                    note = await Note.get(full_note_id)
                    if not note:
                        continue

                    note_context = note.get_context(context_size="long")
                    context_data["notes"].append(note_context)
                    content_parts.append(str(note_context))
                except Exception as e:
                    logger.warning(f"Error processing note {note_id}: {str(e)}")
                    continue
        else:
            # Default behavior - include all sources and notes with short context
            for source in await notebook.get_sources():
                try:
                    source_context = await source.get_context(context_size="short")
                    context_data["sources"].append(source_context)
                    content_parts.append(str(source_context))
                except Exception as e:
                    logger.warning(f"Error processing source {source.id}: {str(e)}")
                    continue

            for note in await notebook.get_notes():
                try:
                    note_context = note.get_context(context_size="short")
                    context_data["notes"].append(note_context)
                    content_parts.append(str(note_context))
                except Exception as e:
                    logger.warning(f"Error processing note {note.id}: {str(e)}")
                    continue

        # Calculate character and token counts
        total_content = "".join(content_parts)
        char_count = len(total_content)

        # Use token count utility if available
        try:
            from open_notebook.utils import token_count

            estimated_tokens = token_count(total_content) if total_content else 0
        except ImportError:
            # Fallback to simple estimation
            estimated_tokens = char_count // 4

        return BuildContextResponse(
            context=context_data, token_count=estimated_tokens, char_count=char_count
        )
    except HTTPException:
        raise
    except NotFoundError:
        # Match the other handlers, which report a missing notebook as 404
        raise HTTPException(status_code=404, detail="Notebook not found")
    except Exception as e:
        logger.error(f"Error building context: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Error building context: {str(e)}"
        ) from e