"""
|
|
Generic ContextBuilder for the Open Notebook project.
|
|
|
|
This module provides a flexible ContextBuilder class that can handle any parameters
|
|
and build context from sources, notebooks, insights, and notes.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Any, Dict, List, Literal, Optional
|
|
|
|
from loguru import logger
|
|
|
|
from open_notebook.domain.notebook import Note, Notebook, Source
|
|
from open_notebook.exceptions import DatabaseOperationError, NotFoundError
|
|
|
|
from .text_utils import token_count
|
|
|
|
|
|
@dataclass
class ContextItem:
    """A single entry (source, note, or insight) collected into the context."""

    id: str
    type: Literal["source", "note", "insight"]
    content: Dict[str, Any]
    priority: int = 0
    token_count: Optional[int] = None

    def __post_init__(self):
        """Compute the token count from the stringified content when absent."""
        if self.token_count is not None:
            return
        self.token_count = token_count(str(self.content))
|
|
|
|
|
|
@dataclass
class ContextConfig:
    """Configuration options controlling how context is assembled."""

    sources: Optional[Dict[str, str]] = None  # {source_id: inclusion_level}
    notes: Optional[Dict[str, str]] = None  # {note_id: inclusion_level}
    include_insights: bool = True
    include_notes: bool = True
    max_tokens: Optional[int] = None
    priority_weights: Optional[Dict[str, int]] = None  # {type: weight}

    def __post_init__(self):
        """Replace omitted collection fields with fresh default values."""
        self.sources = {} if self.sources is None else self.sources
        self.notes = {} if self.notes is None else self.notes
        if self.priority_weights is None:
            self.priority_weights = {"source": 100, "note": 50, "insight": 75}
|
|
|
|
|
|
class ContextBuilder:
    """
    Generic ContextBuilder that can handle any parameters and build context
    from sources, notebooks, insights, and notes.

    Items are collected via the private ``_add_*`` helpers, deduplicated,
    sorted by priority, optionally truncated to a token budget, and finally
    grouped by type into the response dict.
    """

    def __init__(self, **kwargs):
        """
        Initialize ContextBuilder with flexible parameters.

        Supported parameters:
        - source_id: str - Include specific source
        - notebook_id: str - Include notebook content
        - include_insights: bool - Include source insights
        - include_notes: bool - Include notes
        - context_config: ContextConfig - Custom context configuration
        - max_tokens: int - Maximum token limit
        - priority_order: List[str] - Custom priority order
        """
        # Store all parameters so _process_custom_params (and subclasses)
        # can inspect anything not explicitly extracted below.
        self.params = kwargs

        # Extract commonly used parameters.
        self.source_id: Optional[str] = kwargs.get('source_id')
        self.notebook_id: Optional[str] = kwargs.get('notebook_id')
        self.include_insights: bool = kwargs.get('include_insights', True)
        self.include_notes: bool = kwargs.get('include_notes', True)
        self.max_tokens: Optional[int] = kwargs.get('max_tokens')

        # Context configuration: synthesize one from the flat kwargs, or
        # adopt the caller-supplied ContextConfig.
        context_config_arg: Optional[ContextConfig] = kwargs.get('context_config')
        self.context_config: ContextConfig
        if context_config_arg is None:
            self.context_config = ContextConfig(
                include_insights=self.include_insights,
                include_notes=self.include_notes,
                max_tokens=self.max_tokens
            )
        else:
            self.context_config = context_config_arg
            # Bug fix: build() truncates based on self.max_tokens only, so a
            # token limit carried solely by the supplied ContextConfig was
            # silently ignored. Fall back to the config's limit when no
            # explicit max_tokens kwarg was given (backward compatible:
            # an explicit kwarg still wins).
            if self.max_tokens is None:
                self.max_tokens = context_config_arg.max_tokens

        # Items storage.
        self.items: List[ContextItem] = []

        logger.debug(f"ContextBuilder initialized with params: {list(kwargs.keys())}")

    async def build(self) -> Dict[str, Any]:
        """
        Build context based on provided parameters.

        Returns:
            Dict containing the built context with metadata

        Raises:
            DatabaseOperationError: If any step of context building fails;
                the original exception is chained as the cause.
        """
        try:
            logger.info("Starting context building")

            # Clear existing items so build() is safe to call repeatedly.
            self.items = []

            # Build context based on parameters.
            if self.source_id:
                await self._add_source_context(self.source_id)

            if self.notebook_id:
                await self._add_notebook_context(self.notebook_id)

            # Process any additional custom parameters.
            await self._process_custom_params()

            # Post-processing: dedupe (keeps first occurrence), then sort so
            # truncation below removes the lowest-priority items first.
            self.remove_duplicates()
            self.prioritize()

            if self.max_tokens:
                self.truncate_to_fit(self.max_tokens)

            # Format and return response.
            return self._format_response()

        except Exception as e:
            logger.error(f"Error building context: {str(e)}")
            # Chain the original exception so the root cause survives.
            raise DatabaseOperationError(f"Failed to build context: {str(e)}") from e

    async def _add_source_context(
        self,
        source_id: str,
        inclusion_level: str = "insights"
    ) -> None:
        """
        Add source and its insights to context.

        Args:
            source_id: ID of the source
            inclusion_level: "insights", "full content", or "not in"
        """
        if inclusion_level == "not in":
            return

        try:
            # Ensure source ID has table prefix (SurrealDB-style "table:id").
            full_source_id = (
                source_id if source_id.startswith("source:")
                else f"source:{source_id}"
            )

            source = await Source.get(full_source_id)
            if not source:
                logger.warning(f"Source {source_id} not found")
                return

            # "full content" requests the long form; anything else gets short.
            context_size: Literal["short", "long"] = "long" if "full content" in inclusion_level else "short"
            source_context = await source.get_context(context_size=context_size)

            # Add source item with its configured priority weight.
            priority = (self.context_config.priority_weights or {}).get("source", 100)
            item = ContextItem(
                id=source.id or "",
                type="source",
                content=source_context,
                priority=priority
            )
            self.add_item(item)

            # Add insights if requested and available.
            if self.include_insights and "insights" in inclusion_level:
                insights = await source.get_insights()
                for insight in insights:
                    insight_priority = (self.context_config.priority_weights or {}).get("insight", 75)
                    insight_item = ContextItem(
                        id=insight.id or "",
                        type="insight",
                        content={
                            "id": insight.id,
                            "source_id": source.id,
                            "insight_type": insight.insight_type,
                            "content": insight.content
                        },
                        priority=insight_priority
                    )
                    self.add_item(insight_item)

            logger.debug(f"Added source context for {source_id}")

        except NotFoundError:
            # Missing sources are tolerated; the rest of the context still builds.
            logger.warning(f"Source {source_id} not found")
        except Exception as e:
            logger.error(f"Error adding source context for {source_id}: {str(e)}")
            raise

    async def _add_notebook_context(self, notebook_id: str) -> None:
        """
        Add notebook content based on context configuration.

        Args:
            notebook_id: ID of the notebook

        Raises:
            NotFoundError: If the notebook does not exist.
        """
        try:
            notebook = await Notebook.get(notebook_id)
            if not notebook:
                raise NotFoundError(f"Notebook {notebook_id} not found")

            # Process sources from context config, or fall back to all
            # notebook sources at the "insights" level.
            config_sources = self.context_config.sources
            if config_sources:
                for source_id, status in config_sources.items():
                    await self._add_source_context(source_id, status)
            else:
                sources = await notebook.get_sources()
                for source in sources:
                    if source.id:
                        await self._add_source_context(source.id, "insights")

            # Process notes from context config, or fall back to all notes
            # with full content.
            if self.include_notes:
                config_notes = self.context_config.notes
                if config_notes:
                    for note_id, status in config_notes.items():
                        if "not in" not in status:
                            await self._add_note_context(note_id, status)
                else:
                    notes = await notebook.get_notes()
                    for note in notes:
                        if note.id:
                            await self._add_note_context(note.id, "full content")

            logger.debug(f"Added notebook context for {notebook_id}")

        except Exception as e:
            logger.error(f"Error adding notebook context for {notebook_id}: {str(e)}")
            raise

    async def _add_note_context(
        self,
        note_id: str,
        inclusion_level: str = "full content"
    ) -> None:
        """
        Add note to context.

        Args:
            note_id: ID of the note
            inclusion_level: "full content" or "not in"
        """
        if inclusion_level == "not in":
            return

        try:
            # Ensure note ID has table prefix.
            full_note_id = (
                note_id if note_id.startswith("note:")
                else f"note:{note_id}"
            )

            note = await Note.get(full_note_id)
            if not note:
                logger.warning(f"Note {note_id} not found")
                return

            # Get note context.
            context_size: Literal["short", "long"] = "long" if "full content" in inclusion_level else "short"
            # NOTE(review): unlike Source.get_context this call is not awaited —
            # presumably Note.get_context is synchronous; confirm against the
            # domain model.
            note_context = note.get_context(context_size=context_size)

            # Add note item with its configured priority weight.
            priority = (self.context_config.priority_weights or {}).get("note", 50)
            item = ContextItem(
                id=note.id or "",
                type="note",
                content=note_context,
                priority=priority
            )
            self.add_item(item)

            logger.debug(f"Added note context for {note_id}")

        except NotFoundError:
            logger.warning(f"Note {note_id} not found")
        except Exception as e:
            # NOTE(review): notes are best-effort — unlike the source/notebook
            # helpers, errors here are logged and swallowed, not re-raised.
            logger.error(f"Error adding note context for {note_id}: {str(e)}")

    async def _process_custom_params(self) -> None:
        """Process any additional custom parameters."""
        # Hook for future extensions - can be overridden in subclasses
        # or used to process additional kwargs.
        for key, value in self.params.items():
            if key.startswith('custom_'):
                logger.debug(f"Processing custom parameter: {key}={value}")
                # Custom processing logic can be added here

    def add_item(self, item: ContextItem) -> None:
        """
        Add a ContextItem to the builder.

        Args:
            item: ContextItem to add
        """
        self.items.append(item)
        logger.debug(f"Added item {item.id} with priority {item.priority}")

    def prioritize(self) -> None:
        """Sort items by priority (higher priority first)."""
        self.items.sort(key=lambda x: x.priority, reverse=True)
        logger.debug(f"Prioritized {len(self.items)} items")

    def truncate_to_fit(self, max_tokens: int) -> None:
        """
        Remove items if total token count exceeds limit.

        Assumes prioritize() has run, so popping from the end removes the
        lowest-priority items first.

        Args:
            max_tokens: Maximum allowed tokens
        """
        if not max_tokens:
            return

        total_tokens = sum(item.token_count or 0 for item in self.items)

        if total_tokens <= max_tokens:
            logger.debug(f"Token count {total_tokens} within limit {max_tokens}")
            return

        logger.info(f"Truncating from {total_tokens} to {max_tokens} tokens")

        # Remove items from the end (lowest priority) until under limit.
        current_tokens = total_tokens
        removed_count = 0

        while current_tokens > max_tokens and self.items:
            removed_item = self.items.pop()
            current_tokens -= (removed_item.token_count or 0)
            removed_count += 1

        logger.info(f"Removed {removed_count} items, final token count: {current_tokens}")

    def remove_duplicates(self) -> None:
        """Remove duplicate items based on ID (first occurrence wins)."""
        seen_ids = set()
        deduplicated_items = []

        for item in self.items:
            if item.id not in seen_ids:
                deduplicated_items.append(item)
                seen_ids.add(item.id)

        removed_count = len(self.items) - len(deduplicated_items)
        self.items = deduplicated_items

        if removed_count > 0:
            logger.debug(f"Removed {removed_count} duplicate items")

    def _format_response(self) -> Dict[str, Any]:
        """
        Format the final response.

        Returns:
            Formatted context response: items grouped by type plus token and
            count metadata (and notebook_id when one was supplied).
        """
        # Group items by type.
        sources = []
        notes = []
        insights = []

        for item in self.items:
            if item.type == "source":
                sources.append(item.content)
            elif item.type == "note":
                notes.append(item.content)
            elif item.type == "insight":
                insights.append(item.content)

        # Calculate total tokens.
        total_tokens = sum(item.token_count or 0 for item in self.items)

        response = {
            "sources": sources,
            "notes": notes,
            "insights": insights,
            "total_tokens": total_tokens,
            "total_items": len(self.items),
            "metadata": {
                "source_count": len(sources),
                "note_count": len(notes),
                "insight_count": len(insights),
                "config": {
                    "include_insights": self.include_insights,
                    "include_notes": self.include_notes,
                    "max_tokens": self.max_tokens
                }
            }
        }

        # Add notebook_id if provided.
        if self.notebook_id:
            response["notebook_id"] = self.notebook_id

        logger.info(f"Built context with {len(self.items)} items, {total_tokens} tokens")

        return response
|
|
|
|
|
|
# Convenience functions for common use cases
|
|
|
|
async def build_notebook_context(
    notebook_id: str,
    context_config: Optional[ContextConfig] = None,
    max_tokens: Optional[int] = None
) -> Dict[str, Any]:
    """
    Build context for a notebook.

    Args:
        notebook_id: ID of the notebook
        context_config: Optional context configuration
        max_tokens: Optional token limit

    Returns:
        Built context
    """
    notebook_builder = ContextBuilder(
        notebook_id=notebook_id,
        context_config=context_config,
        max_tokens=max_tokens,
    )
    return await notebook_builder.build()
|
|
|
|
|
|
async def build_source_context(
    source_id: str,
    include_insights: bool = True,
    max_tokens: Optional[int] = None
) -> Dict[str, Any]:
    """
    Build context for a single source.

    Args:
        source_id: ID of the source
        include_insights: Whether to include insights
        max_tokens: Optional token limit

    Returns:
        Built context
    """
    source_builder = ContextBuilder(
        source_id=source_id,
        include_insights=include_insights,
        max_tokens=max_tokens,
    )
    return await source_builder.build()
|
|
|
|
|
|
async def build_mixed_context(
    source_ids: Optional[List[str]] = None,
    note_ids: Optional[List[str]] = None,
    notebook_id: Optional[str] = None,
    max_tokens: Optional[int] = None
) -> Dict[str, Any]:
    """
    Build context from mixed sources.

    Args:
        source_ids: List of source IDs
        note_ids: List of note IDs
        notebook_id: Optional notebook ID
        max_tokens: Optional token limit

    Returns:
        Built context
    """
    config = ContextConfig(max_tokens=max_tokens)

    # Every requested source is included at the "insights" level and every
    # requested note at "full content".
    if source_ids:
        config.sources = dict.fromkeys(source_ids, "insights")
    if note_ids:
        config.notes = dict.fromkeys(note_ids, "full content")

    mixed_builder = ContextBuilder(
        notebook_id=notebook_id,
        context_config=config,
        max_tokens=max_tokens,
    )
    return await mixed_builder.build()