open-notebook/tests/test_source_chat.py

296 lines
9.9 KiB
Python

"""
Integration tests for the Source Chat LangGraph graph.

These tests verify that the Source Chat LangGraph graph integrates correctly
with the existing Open Notebook infrastructure.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from open_notebook.domain.notebook import Source, SourceInsight
from open_notebook.graphs.source_chat import (
SourceChatState,
_format_source_context,
call_model_with_source_context,
source_chat_graph,
)
@pytest.fixture
def mock_source():
    """Build a ``MagicMock`` specced on ``Source`` with fixed test data.

    The mock exposes ``id``, ``title``, ``topics`` and ``full_text`` as
    attributes, and its ``model_dump()`` returns a dict holding the same
    four values.
    """
    attributes = {
        "id": "source:test123",
        "title": "Test Source",
        "topics": ["AI", "Machine Learning"],
        "full_text": "This is test content for the source.",
    }

    fake_source = MagicMock(spec=Source)
    for field_name, field_value in attributes.items():
        setattr(fake_source, field_name, field_value)

    # Give model_dump() its own topics list so the dumped dict and the
    # attribute never share a mutable object.
    fake_source.model_dump.return_value = {
        **attributes,
        "topics": list(attributes["topics"]),
    }
    return fake_source
@pytest.fixture
def mock_insight():
    """Build a ``MagicMock`` specced on ``SourceInsight`` with fixed test data.

    The mock exposes ``id``, ``insight_type`` and ``content`` as attributes,
    and its ``model_dump()`` returns a dict holding the same three values.
    """
    attributes = {
        "id": "insight:test456",
        "insight_type": "summary",
        "content": "This is a test insight about the source.",
    }

    fake_insight = MagicMock(spec=SourceInsight)
    for field_name, field_value in attributes.items():
        setattr(fake_insight, field_name, field_value)

    # A fresh dict keeps the dump independent from the local mapping.
    fake_insight.model_dump.return_value = dict(attributes)
    return fake_insight
@pytest.fixture
def sample_state():
    """Provide a ``SourceChatState`` holding one user question.

    Only ``messages`` and ``source_id`` carry data; every other key is
    explicitly set to ``None``.
    """
    opening_question = HumanMessage(
        content="What are the main topics in this source?"
    )
    return SourceChatState(
        messages=[opening_question],
        source_id="source:test123",
        source=None,
        insights=None,
        context=None,
        model_override=None,
        context_indicators=None,
    )
@pytest.fixture
def sample_config():
    """Provide a config dict whose ``configurable`` entry names a thread and a model."""
    configurable = {
        "thread_id": "test_thread",
        "model_id": "test_model",
    }
    return {"configurable": configurable}
class TestSourceChatState:
    """Checks on the structure of ``SourceChatState``."""

    def test_source_chat_state_creation(self, sample_state):
        """The fixture-built state exposes the values it was created with."""
        state = sample_state

        assert state["source_id"] == "source:test123"
        assert len(state["messages"]) == 1

        # Keys that the fixture leaves unset must read back as None.
        for unset_key in ("source", "insights"):
            assert state[unset_key] is None
class TestContextFormatting:
    """Exercise ``_format_source_context`` on different context payloads."""

    def test_format_source_context_with_sources(self):
        """A payload with one source renders its fields and the metadata header."""
        source_entry = {
            "id": "source:test123",
            "title": "Test Source",
            "full_text": "This is test content.",
        }
        payload = {
            "sources": [source_entry],
            "insights": [],
            "metadata": {"source_count": 1, "insight_count": 0},
            "total_tokens": 100,
        }

        rendered = _format_source_context(payload)

        expected_fragments = (
            "## SOURCE CONTENT",
            "source:test123",
            "Test Source",
            "This is test content.",
            "## CONTEXT METADATA",
        )
        for fragment in expected_fragments:
            assert fragment in rendered

    def test_format_source_context_with_insights(self):
        """A payload with one insight renders its id, type and content."""
        insight_entry = {
            "id": "insight:test456",
            "insight_type": "summary",
            "content": "Test insight content.",
        }
        payload = {
            "sources": [],
            "insights": [insight_entry],
            "metadata": {"source_count": 0, "insight_count": 1},
            "total_tokens": 50,
        }

        rendered = _format_source_context(payload)

        expected_fragments = (
            "## SOURCE INSIGHTS",
            "insight:test456",
            "summary",
            "Test insight content.",
        )
        for fragment in expected_fragments:
            assert fragment in rendered

    def test_format_source_context_empty(self):
        """An empty payload still renders the metadata section with zero counts."""
        payload = {
            "sources": [],
            "insights": [],
            "metadata": {"source_count": 0, "insight_count": 0},
            "total_tokens": 0,
        }

        rendered = _format_source_context(payload)

        expected_fragments = (
            "## CONTEXT METADATA",
            "Source count: 0",
            "Insight count: 0",
        )
        for fragment in expected_fragments:
            assert fragment in rendered
class TestSourceChatIntegration:
    """Test the integration of source chat components."""

    @patch('open_notebook.graphs.source_chat.ContextBuilder')
    @patch('open_notebook.graphs.source_chat.provision_langchain_model')
    @patch('open_notebook.graphs.source_chat.Prompter')
    @pytest.mark.asyncio
    async def test_call_model_with_source_context(
        self,
        mock_prompter,
        mock_provision_model,
        mock_context_builder,
        sample_state,
        sample_config,
        mock_source,
        mock_insight
    ):
        """Test the main model calling function with mocked dependencies.

        ``ContextBuilder``, ``provision_langchain_model`` and ``Prompter`` are
        patched in the ``source_chat`` module. The ``patch`` decorators are
        applied bottom-up, so the mock parameters arrive in the reverse of the
        decorator order. The test checks the keys of the returned mapping and
        that each patched collaborator was used exactly once.
        """
        # Mock the ContextBuilder: the class returns an instance whose
        # build() yields a context dict built from the fixture dumps.
        mock_builder_instance = AsyncMock()
        mock_builder_instance.build.return_value = {
            "sources": [mock_source.model_dump()],
            "insights": [mock_insight.model_dump()],
            "metadata": {"source_count": 1, "insight_count": 1},
            "total_tokens": 150
        }
        mock_context_builder.return_value = mock_builder_instance

        # Mock the Prompter
        mock_prompter_instance = MagicMock()
        mock_prompter_instance.render.return_value = "Rendered prompt"
        mock_prompter.return_value = mock_prompter_instance

        # Mock the model
        # NOTE(review): attributes of an AsyncMock are themselves async, so
        # mock_model.invoke(...) yields a coroutine unless awaited; confirm
        # the code under test awaits invoke() (otherwise use a MagicMock).
        mock_model = AsyncMock()
        mock_ai_message = AIMessage(content="Test response from AI")
        mock_model.invoke.return_value = mock_ai_message
        mock_provision_model.return_value = mock_model

        # Call the function
        result = await call_model_with_source_context(sample_state, sample_config)  # type: ignore[misc]

        # Verify the result
        assert "messages" in result
        assert result["messages"] == mock_ai_message
        assert "source" in result
        assert "insights" in result
        assert "context" in result
        assert "context_indicators" in result

        # Verify mocks were called correctly
        mock_context_builder.assert_called_once()
        mock_builder_instance.build.assert_called_once()
        mock_prompter.assert_called_once_with(prompt_template="source_chat")
        mock_provision_model.assert_called_once()

    def test_source_chat_graph_structure(self):
        """Test that the source chat graph is properly structured."""
        # Verify the graph has the expected structure
        assert source_chat_graph is not None

        # Check that the graph contains the agent node; membership is tested
        # directly on the node collection instead of copying it into a list.
        nodes = source_chat_graph.get_graph().nodes
        assert "source_chat_agent" in nodes

        # Check that the graph has the checkpointer
        assert source_chat_graph.checkpointer is not None

    @pytest.mark.asyncio
    async def test_source_chat_state_validation(self):
        """Test that the source chat validates required state fields."""
        # Build a state whose source_id is present but empty.
        invalid_state = SourceChatState(
            messages=[HumanMessage(content="Test")],
            source_id="",  # Empty source_id should cause error
            source=None,
            insights=None,
            context=None,
            model_override=None,
            context_indicators=None
        )
        config = {"configurable": {"thread_id": "test"}}

        # This should raise an error because source_id is empty
        with pytest.raises(ValueError, match="source_id is required"):
            await call_model_with_source_context(invalid_state, config)  # type: ignore[misc, arg-type]
class TestSourceChatGraphExecution:
    """Run the compiled source chat graph with its collaborators mocked."""

    @patch('open_notebook.graphs.source_chat.Source')
    @patch('open_notebook.graphs.source_chat.ContextBuilder')
    @patch('open_notebook.graphs.source_chat.provision_langchain_model')
    @patch('open_notebook.graphs.source_chat.Prompter')
    @pytest.mark.asyncio
    async def test_graph_execution_flow(
        self,
        mock_prompter,
        mock_provision_model,
        mock_context_builder,
        mock_source_class,
        sample_state,
        sample_config,
    ):
        """Invoke the graph once and check the shape of the resulting state.

        ``Source``, ``ContextBuilder``, ``provision_langchain_model`` and
        ``Prompter`` are patched in the ``source_chat`` module;
        ``mock_source_class`` is injected by the outermost patch and is not
        configured further.
        """
        # Context builder: the instance's build() yields a small context dict.
        built_context = {
            "sources": [{"id": "source:test123", "title": "Test"}],
            "insights": [{"id": "insight:test456", "content": "Test insight"}],
            "metadata": {"source_count": 1, "insight_count": 1},
            "total_tokens": 100,
        }
        fake_builder = AsyncMock()
        fake_builder.build.return_value = built_context
        mock_context_builder.return_value = fake_builder

        # Prompter: render() always produces the same prompt text.
        fake_prompter = MagicMock()
        fake_prompter.render.return_value = "Test prompt"
        mock_prompter.return_value = fake_prompter

        # Model: invoke() is configured with a canned AI reply.
        canned_reply = AIMessage(content="AI response")
        fake_model = AsyncMock()
        fake_model.invoke.return_value = canned_reply
        mock_provision_model.return_value = fake_model

        # Run the whole graph.
        final_state = await source_chat_graph.ainvoke(sample_state, sample_config)

        # The resulting state keeps messages and echoes the source id.
        assert "messages" in final_state
        assert "source_id" in final_state
        assert final_state["source_id"] == "source:test123"
if __name__ == "__main__":
    # Run this file's tests verbosely and exit with pytest's return code, so
    # that a failing run is reported as a non-zero process status.
    raise SystemExit(pytest.main([__file__, "-v"]))