open-notebook/open_notebook/utils/context_builder.py

502 lines
17 KiB
Python

"""
Generic ContextBuilder for the Open Notebook project.
This module provides a flexible ContextBuilder class that can handle any parameters
and build context from sources, notebooks, insights, and notes.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional
from loguru import logger
from open_notebook.domain.notebook import Note, Notebook, Source
from open_notebook.exceptions import DatabaseOperationError, NotFoundError
from .text_utils import token_count
@dataclass
class ContextItem:
    """One unit of context (a source, a note or an insight) plus ranking data."""

    id: str
    type: Literal["source", "note", "insight"]
    content: Dict[str, Any]
    priority: int = 0
    token_count: Optional[int] = None

    def __post_init__(self):
        """Derive ``token_count`` from the stringified content when none was supplied."""
        if self.token_count is not None:
            return
        # Inside a method the name ``token_count`` resolves to the module-level
        # helper imported from text_utils, not to this dataclass field.
        self.token_count = token_count(str(self.content))
@dataclass
class ContextConfig:
    """Configuration for context building."""

    sources: Optional[Dict[str, str]] = None  # {source_id: inclusion_level}
    notes: Optional[Dict[str, str]] = None  # {note_id: inclusion_level}
    include_insights: bool = True
    include_notes: bool = True
    max_tokens: Optional[int] = None
    priority_weights: Optional[Dict[str, int]] = None  # {type: weight}

    def __post_init__(self):
        """Replace ``None`` placeholders with fresh per-instance dicts."""
        # Dict defaults cannot be declared directly on a dataclass field, so
        # every instance receives its own containers here instead.
        self.sources = {} if self.sources is None else self.sources
        self.notes = {} if self.notes is None else self.notes
        self.priority_weights = (
            {"source": 100, "note": 50, "insight": 75}
            if self.priority_weights is None
            else self.priority_weights
        )
class ContextBuilder:
    """
    Generic ContextBuilder that can handle any parameters and build context
    from sources, notebooks, insights, and notes.

    Collected items are de-duplicated by ID, sorted by priority (highest
    first), optionally truncated to a token budget, and grouped by type in
    the dict returned from ``build``.
    """

    def __init__(self, **kwargs):
        """
        Initialize ContextBuilder with flexible parameters.

        Supported parameters:
        - source_id: str - Include specific source
        - notebook_id: str - Include notebook content
        - include_insights: bool - Include source insights
        - include_notes: bool - Include notes
        - context_config: ContextConfig - Custom context configuration
        - max_tokens: int - Maximum token limit
        - priority_order: List[str] - Accepted and kept in ``self.params``,
          but not used by this class yet

        When ``context_config`` is given, ``include_insights``,
        ``include_notes`` and ``max_tokens`` default to the values stored on
        it. An explicit keyword argument still wins; ``max_tokens=None``
        counts as "not given".
        """
        # Store all parameters for flexibility
        self.params = kwargs

        # Extract commonly used parameters
        self.source_id: Optional[str] = kwargs.get('source_id')
        self.notebook_id: Optional[str] = kwargs.get('notebook_id')

        context_config_arg: Optional[ContextConfig] = kwargs.get('context_config')

        # Fall back to the supplied config's flags instead of silently
        # ignoring them (they are not read anywhere else).
        if context_config_arg is not None:
            fallback_insights = context_config_arg.include_insights
            fallback_notes = context_config_arg.include_notes
            fallback_max_tokens = context_config_arg.max_tokens
        else:
            fallback_insights, fallback_notes, fallback_max_tokens = True, True, None

        self.include_insights: bool = kwargs.get('include_insights', fallback_insights)
        self.include_notes: bool = kwargs.get('include_notes', fallback_notes)
        max_tokens_arg: Optional[int] = kwargs.get('max_tokens')
        self.max_tokens: Optional[int] = (
            max_tokens_arg if max_tokens_arg is not None else fallback_max_tokens
        )

        # Context configuration
        self.context_config: ContextConfig
        if context_config_arg is None:
            self.context_config = ContextConfig(
                include_insights=self.include_insights,
                include_notes=self.include_notes,
                max_tokens=self.max_tokens,
            )
        else:
            self.context_config = context_config_arg

        # Items storage
        self.items: List[ContextItem] = []

        logger.debug(f"ContextBuilder initialized with params: {list(kwargs.keys())}")

    async def build(self) -> Dict[str, Any]:
        """
        Build context based on provided parameters.

        Returns:
            Dict containing the built context with metadata

        Raises:
            DatabaseOperationError: wrapping any error raised while building
                (the original exception is chained as ``__cause__``).
        """
        try:
            logger.info("Starting context building")

            # Clear existing items
            self.items = []

            # Build context based on parameters
            if self.source_id:
                await self._add_source_context(self.source_id)

            if self.notebook_id:
                await self._add_notebook_context(self.notebook_id)
            else:
                # Explicit source/note selections on the config are otherwise
                # only read while walking a notebook; without one they would be
                # dropped silently, so add them directly.
                await self._add_configured_sources()
                if self.include_notes:
                    await self._add_configured_notes()

            # Process any additional custom parameters
            await self._process_custom_params()

            # Apply post-processing
            self.remove_duplicates()
            self.prioritize()

            if self.max_tokens:
                self.truncate_to_fit(self.max_tokens)

            # Format and return response
            return self._format_response()

        except Exception as e:
            logger.error(f"Error building context: {str(e)}")
            raise DatabaseOperationError(f"Failed to build context: {str(e)}") from e

    def _priority_for(self, item_type: str, default: int) -> int:
        """Return the configured priority weight for ``item_type``, or ``default``."""
        return (self.context_config.priority_weights or {}).get(item_type, default)

    async def _add_configured_sources(self) -> None:
        """Add every source listed in ``context_config.sources`` at its inclusion level."""
        for source_id, status in (self.context_config.sources or {}).items():
            await self._add_source_context(source_id, status)

    async def _add_configured_notes(self) -> None:
        """Add every note listed in ``context_config.notes`` whose status lacks "not in"."""
        for note_id, status in (self.context_config.notes or {}).items():
            if "not in" not in status:
                await self._add_note_context(note_id, status)

    async def _add_source_context(
        self,
        source_id: str,
        inclusion_level: str = "insights"
    ) -> None:
        """
        Add source and its insights to context.

        Args:
            source_id: ID of the source, with or without the "source:" prefix
            inclusion_level: "insights", "full content", or "not in".
                A level containing "full content" requests the long source
                context. Insights are added only when the level contains
                "insights" and ``include_insights`` is enabled.

        Raises:
            Any exception other than NotFoundError is logged and re-raised;
            NotFoundError (or an empty lookup result) is logged and skipped.
        """
        if inclusion_level == "not in":
            return

        try:
            # Ensure source ID has table prefix
            full_source_id = (
                source_id if source_id.startswith("source:")
                else f"source:{source_id}"
            )

            source = await Source.get(full_source_id)
            if not source:
                logger.warning(f"Source {source_id} not found")
                return

            # Determine context size based on inclusion level
            context_size: Literal["short", "long"] = (
                "long" if "full content" in inclusion_level else "short"
            )
            source_context = await source.get_context(context_size=context_size)

            # Add source item
            self.add_item(
                ContextItem(
                    id=source.id or "",
                    type="source",
                    content=source_context,
                    priority=self._priority_for("source", 100),
                )
            )

            # Add insights if requested and available
            if self.include_insights and "insights" in inclusion_level:
                insight_priority = self._priority_for("insight", 75)
                insights = await source.get_insights()
                for insight in insights:
                    self.add_item(
                        ContextItem(
                            id=insight.id or "",
                            type="insight",
                            content={
                                "id": insight.id,
                                "source_id": source.id,
                                "insight_type": insight.insight_type,
                                "content": insight.content
                            },
                            priority=insight_priority,
                        )
                    )

            logger.debug(f"Added source context for {source_id}")

        except NotFoundError:
            logger.warning(f"Source {source_id} not found")
        except Exception as e:
            logger.error(f"Error adding source context for {source_id}: {str(e)}")
            raise

    async def _add_notebook_context(self, notebook_id: str) -> None:
        """
        Add notebook content based on context configuration.

        Sources: those listed in ``context_config.sources`` at their inclusion
        level, or, if none are listed, every notebook source at "insights".
        Notes (only when ``include_notes``): those listed in
        ``context_config.notes``, or, if none are listed, every notebook note
        at "full content".

        Args:
            notebook_id: ID of the notebook

        Raises:
            NotFoundError: if ``Notebook.get`` returns nothing. Any error is
                logged and re-raised.
        """
        try:
            notebook = await Notebook.get(notebook_id)
            if not notebook:
                raise NotFoundError(f"Notebook {notebook_id} not found")

            # Process sources from context config or get all
            if self.context_config.sources:
                await self._add_configured_sources()
            else:
                # Default: get all sources with insights
                sources = await notebook.get_sources()
                for source in sources:
                    if source.id:
                        await self._add_source_context(source.id, "insights")

            # Process notes from context config or get all
            if self.include_notes:
                if self.context_config.notes:
                    await self._add_configured_notes()
                else:
                    # Default: get all notes with full content
                    notes = await notebook.get_notes()
                    for note in notes:
                        if note.id:
                            await self._add_note_context(note.id, "full content")

            logger.debug(f"Added notebook context for {notebook_id}")

        except Exception as e:
            logger.error(f"Error adding notebook context for {notebook_id}: {str(e)}")
            raise

    async def _add_note_context(
        self,
        note_id: str,
        inclusion_level: str = "full content"
    ) -> None:
        """
        Add note to context.

        Args:
            note_id: ID of the note, with or without the "note:" prefix
            inclusion_level: "full content" or "not in"; any level without
                "full content" requests the short note context.

        Raises:
            Any exception other than NotFoundError is logged and re-raised,
            matching ``_add_source_context``; NotFoundError (or an empty
            lookup result) is logged and skipped.
        """
        if inclusion_level == "not in":
            return

        try:
            # Ensure note ID has table prefix
            full_note_id = (
                note_id if note_id.startswith("note:")
                else f"note:{note_id}"
            )

            note = await Note.get(full_note_id)
            if not note:
                logger.warning(f"Note {note_id} not found")
                return

            # Get note context
            context_size: Literal["short", "long"] = (
                "long" if "full content" in inclusion_level else "short"
            )
            # NOTE(review): called without await, unlike Source.get_context;
            # presumably Note.get_context is synchronous — confirm.
            note_context = note.get_context(context_size=context_size)

            # Add note item
            self.add_item(
                ContextItem(
                    id=note.id or "",
                    type="note",
                    content=note_context,
                    priority=self._priority_for("note", 50),
                )
            )

            logger.debug(f"Added note context for {note_id}")

        except NotFoundError:
            logger.warning(f"Note {note_id} not found")
        except Exception as e:
            logger.error(f"Error adding note context for {note_id}: {str(e)}")
            raise

    async def _process_custom_params(self) -> None:
        """Process any additional custom parameters."""
        # Hook for future extensions - can be overridden in subclasses
        # or used to process additional kwargs
        for key, value in self.params.items():
            if key.startswith('custom_'):
                logger.debug(f"Processing custom parameter: {key}={value}")
                # Custom processing logic can be added here

    def add_item(self, item: ContextItem) -> None:
        """
        Add a ContextItem to the builder.

        Args:
            item: ContextItem to add
        """
        self.items.append(item)
        logger.debug(f"Added item {item.id} with priority {item.priority}")

    def prioritize(self) -> None:
        """Sort items by priority (higher priority first); the sort is stable for ties."""
        self.items.sort(key=lambda x: x.priority, reverse=True)
        logger.debug(f"Prioritized {len(self.items)} items")

    def truncate_to_fit(self, max_tokens: int) -> None:
        """
        Remove items if total token count exceeds limit.

        Items are popped from the end of ``self.items``. ``build`` calls
        ``prioritize`` first, so the lowest-priority items go first.

        Args:
            max_tokens: Maximum allowed tokens; a falsy value disables truncation
        """
        if not max_tokens:
            return

        total_tokens = sum(item.token_count or 0 for item in self.items)

        if total_tokens <= max_tokens:
            logger.debug(f"Token count {total_tokens} within limit {max_tokens}")
            return

        logger.info(f"Truncating from {total_tokens} to {max_tokens} tokens")

        # Remove items from the end (lowest priority) until under limit
        current_tokens = total_tokens
        removed_count = 0

        while current_tokens > max_tokens and self.items:
            removed_item = self.items.pop()
            current_tokens -= (removed_item.token_count or 0)
            removed_count += 1

        logger.info(f"Removed {removed_count} items, final token count: {current_tokens}")

    def remove_duplicates(self) -> None:
        """
        Remove duplicate items based on ID, keeping the first occurrence.

        Items with an empty ID are always kept. ``""`` is the placeholder used
        when a record has no ID, so it cannot identify a duplicate.
        """
        seen_ids: set = set()
        deduplicated_items: List[ContextItem] = []

        for item in self.items:
            if item.id:
                if item.id in seen_ids:
                    continue
                seen_ids.add(item.id)
            deduplicated_items.append(item)

        removed_count = len(self.items) - len(deduplicated_items)
        self.items = deduplicated_items

        if removed_count > 0:
            logger.debug(f"Removed {removed_count} duplicate items")

    def _format_response(self) -> Dict[str, Any]:
        """
        Format the final response.

        Returns:
            Formatted context response: item contents grouped by type, token
            and item totals, and a metadata block echoing the effective flags.
        """
        # Group items by type
        sources = []
        notes = []
        insights = []

        for item in self.items:
            if item.type == "source":
                sources.append(item.content)
            elif item.type == "note":
                notes.append(item.content)
            elif item.type == "insight":
                insights.append(item.content)

        # Calculate total tokens
        total_tokens = sum(item.token_count or 0 for item in self.items)

        response = {
            "sources": sources,
            "notes": notes,
            "insights": insights,
            "total_tokens": total_tokens,
            "total_items": len(self.items),
            "metadata": {
                "source_count": len(sources),
                "note_count": len(notes),
                "insight_count": len(insights),
                "config": {
                    "include_insights": self.include_insights,
                    "include_notes": self.include_notes,
                    "max_tokens": self.max_tokens
                }
            }
        }

        # Add notebook_id if provided
        if self.notebook_id:
            response["notebook_id"] = self.notebook_id

        logger.info(f"Built context with {len(self.items)} items, {total_tokens} tokens")

        return response
# Convenience functions for common use cases
async def build_notebook_context(
    notebook_id: str,
    context_config: Optional[ContextConfig] = None,
    max_tokens: Optional[int] = None
) -> Dict[str, Any]:
    """
    Collect the context for a single notebook.

    Args:
        notebook_id: Notebook whose content should be gathered.
        context_config: Optional per-item inclusion settings; ``None`` lets
            the builder create its own default configuration.
        max_tokens: Optional token budget forwarded to the builder.

    Returns:
        The dict produced by ``ContextBuilder.build``.
    """
    return await ContextBuilder(
        notebook_id=notebook_id,
        context_config=context_config,
        max_tokens=max_tokens,
    ).build()
async def build_source_context(
    source_id: str,
    include_insights: bool = True,
    max_tokens: Optional[int] = None
) -> Dict[str, Any]:
    """
    Collect the context for one source, optionally with its insights.

    Args:
        source_id: Source to include (with or without the "source:" prefix).
        include_insights: Forwarded to the builder to toggle insight items.
        max_tokens: Optional token budget forwarded to the builder.

    Returns:
        The dict produced by ``ContextBuilder.build``.
    """
    return await ContextBuilder(
        source_id=source_id,
        include_insights=include_insights,
        max_tokens=max_tokens,
    ).build()
async def build_mixed_context(
    source_ids: Optional[List[str]] = None,
    note_ids: Optional[List[str]] = None,
    notebook_id: Optional[str] = None,
    max_tokens: Optional[int] = None
) -> Dict[str, Any]:
    """
    Collect context from an explicit mix of sources and notes.

    Every listed source is configured at the "insights" level and every
    listed note at "full content"; empty or missing lists leave the
    corresponding config entry at its default.

    Args:
        source_ids: Sources to include.
        note_ids: Notes to include.
        notebook_id: Optional notebook passed through to the builder.
        max_tokens: Optional token budget (set on the config and the builder).

    Returns:
        The dict produced by ``ContextBuilder.build``.
    """
    config = ContextConfig(max_tokens=max_tokens)

    if source_ids:
        config.sources = dict.fromkeys(source_ids, "insights")
    if note_ids:
        config.notes = dict.fromkeys(note_ids, "full content")

    return await ContextBuilder(
        notebook_id=notebook_id,
        context_config=config,
        max_tokens=max_tokens,
    ).build()