"""
|
|
Integration tests for Source Chat Langgraph.
|
|
|
|
These tests verify that the Source Chat Langgraph integrates correctly
|
|
with the existing Open Notebook infrastructure.
|
|
"""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from langchain_core.messages import AIMessage, HumanMessage

from open_notebook.domain.notebook import Source, SourceInsight
from open_notebook.graphs.source_chat import (
    SourceChatState,
    _format_source_context,
    call_model_with_source_context,
    source_chat_graph,
)


@pytest.fixture
def mock_source():
    """Create a mock Source object for testing."""
    source = MagicMock(spec=Source)
    source.id = "source:test123"
    source.title = "Test Source"
    source.topics = ["AI", "Machine Learning"]
    source.full_text = "This is test content for the source."
    source.model_dump.return_value = {
        "id": "source:test123",
        "title": "Test Source",
        "topics": ["AI", "Machine Learning"],
        "full_text": "This is test content for the source.",
    }
    return source
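

# Note: spec=Source makes attribute *access* fail for names the real Source
# class does not define; assignment is still unchecked unless spec_set is used.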


@pytest.fixture
def mock_insight():
    """Create a mock SourceInsight object for testing."""
    insight = MagicMock(spec=SourceInsight)
    insight.id = "insight:test456"
    insight.insight_type = "summary"
    insight.content = "This is a test insight about the source."
    insight.model_dump.return_value = {
        "id": "insight:test456",
        "insight_type": "summary",
        "content": "This is a test insight about the source.",
    }
    return insight


@pytest.fixture
def sample_state():
    """Create a sample SourceChatState for testing."""
    return SourceChatState(
        messages=[HumanMessage(content="What are the main topics in this source?")],
        source_id="source:test123",
        source=None,
        insights=None,
        context=None,
        model_override=None,
        context_indicators=None,
    )


@pytest.fixture
def sample_config():
    """Create a sample configuration for testing."""
    return {
        "configurable": {
            "thread_id": "test_thread",
            "model_id": "test_model",
        }
    }
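

# The "configurable" mapping follows LangGraph's RunnableConfig convention:
# thread_id selects the checkpointer thread, and model_id is assumed to be
# read by provision_langchain_model when picking the chat model.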


class TestSourceChatState:
    """Test the SourceChatState TypedDict structure."""

    def test_source_chat_state_creation(self, sample_state):
        """Test that SourceChatState can be created with required fields."""
        assert sample_state["source_id"] == "source:test123"
        assert len(sample_state["messages"]) == 1
        assert sample_state["source"] is None
        assert sample_state["insights"] is None


class TestContextFormatting:
    """Test the context formatting functionality."""

    def test_format_source_context_with_sources(self):
        """Test formatting context data containing sources."""
        context_data = {
            "sources": [{
                "id": "source:test123",
                "title": "Test Source",
                "full_text": "This is test content.",
            }],
            "insights": [],
            "metadata": {
                "source_count": 1,
                "insight_count": 0,
            },
            "total_tokens": 100,
        }

        result = _format_source_context(context_data)

        assert "## SOURCE CONTENT" in result
        assert "source:test123" in result
        assert "Test Source" in result
        assert "This is test content." in result
        assert "## CONTEXT METADATA" in result

    def test_format_source_context_with_insights(self):
        """Test formatting context data containing insights."""
        context_data = {
            "sources": [],
            "insights": [{
                "id": "insight:test456",
                "insight_type": "summary",
                "content": "Test insight content.",
            }],
            "metadata": {
                "source_count": 0,
                "insight_count": 1,
            },
            "total_tokens": 50,
        }

        result = _format_source_context(context_data)

        assert "## SOURCE INSIGHTS" in result
        assert "insight:test456" in result
        assert "summary" in result
        assert "Test insight content." in result

    def test_format_source_context_empty(self):
        """Test formatting empty context data."""
        context_data = {
            "sources": [],
            "insights": [],
            "metadata": {
                "source_count": 0,
                "insight_count": 0,
            },
            "total_tokens": 0,
        }

        result = _format_source_context(context_data)

        assert "## CONTEXT METADATA" in result
        assert "Source count: 0" in result
        assert "Insight count: 0" in result


class TestSourceChatIntegration:
    """Test the integration of source chat components."""

    @patch("open_notebook.graphs.source_chat.ContextBuilder")
    @patch("open_notebook.graphs.source_chat.provision_langchain_model")
    @patch("open_notebook.graphs.source_chat.Prompter")
    @pytest.mark.asyncio
    async def test_call_model_with_source_context(
        self,
        mock_prompter,
        mock_provision_model,
        mock_context_builder,
        sample_state,
        sample_config,
        mock_source,
        mock_insight,
    ):
        """Test the main model-calling function with mocked dependencies."""

        # Mock the ContextBuilder
        mock_builder_instance = AsyncMock()
        mock_builder_instance.build.return_value = {
            "sources": [mock_source.model_dump()],
            "insights": [mock_insight.model_dump()],
            "metadata": {"source_count": 1, "insight_count": 1},
            "total_tokens": 150,
        }
        mock_context_builder.return_value = mock_builder_instance

        # Mock the Prompter
        mock_prompter_instance = MagicMock()
        mock_prompter_instance.render.return_value = "Rendered prompt"
        mock_prompter.return_value = mock_prompter_instance

        # Mock the model
        mock_model = AsyncMock()
        mock_ai_message = AIMessage(content="Test response from AI")
        mock_model.invoke.return_value = mock_ai_message
        mock_provision_model.return_value = mock_model

        # Call the function
        result = await call_model_with_source_context(sample_state, sample_config)  # type: ignore[misc]

        # Verify the result
        assert "messages" in result
        assert result["messages"] == mock_ai_message
        assert "source" in result
        assert "insights" in result
        assert "context" in result
        assert "context_indicators" in result

        # Verify the mocks were called correctly
        mock_context_builder.assert_called_once()
        mock_builder_instance.build.assert_called_once()
        mock_prompter.assert_called_once_with(prompt_template="source_chat")
        mock_provision_model.assert_called_once()

    def test_source_chat_graph_structure(self):
        """Test that the source chat graph is properly structured."""
        # Verify the graph compiled
        assert source_chat_graph is not None

        # Check that the graph has the expected node
        nodes = source_chat_graph.get_graph().nodes
        assert "source_chat_agent" in nodes

        # Check that the graph has a checkpointer
        assert source_chat_graph.checkpointer is not None

    @pytest.mark.asyncio
    async def test_source_chat_state_validation(self):
        """Test that source chat validates required state fields."""
        # Build a state with an empty source_id
        invalid_state = SourceChatState(
            messages=[HumanMessage(content="Test")],
            source_id="",  # An empty source_id should trigger validation
            source=None,
            insights=None,
            context=None,
            model_override=None,
            context_indicators=None,
        )

        config = {"configurable": {"thread_id": "test"}}

        # The call should raise because source_id is empty
        with pytest.raises(ValueError, match="source_id is required"):
            await call_model_with_source_context(invalid_state, config)  # type: ignore[misc, arg-type]


class TestSourceChatGraphExecution:
    """Test the execution of the source chat graph."""

    @patch("open_notebook.graphs.source_chat.Source")
    @patch("open_notebook.graphs.source_chat.ContextBuilder")
    @patch("open_notebook.graphs.source_chat.provision_langchain_model")
    @patch("open_notebook.graphs.source_chat.Prompter")
    @pytest.mark.asyncio
    async def test_graph_execution_flow(
        self,
        mock_prompter,
        mock_provision_model,
        mock_context_builder,
        mock_source_class,
        sample_state,
        sample_config,
    ):
        """Test the complete graph execution flow with mocked dependencies."""

        # Set up the mocks (as in the previous test, but for full graph
        # execution); Source is patched as well so the graph never reaches
        # the real domain lookups.
        mock_builder_instance = AsyncMock()
        mock_builder_instance.build.return_value = {
            "sources": [{"id": "source:test123", "title": "Test"}],
            "insights": [{"id": "insight:test456", "content": "Test insight"}],
            "metadata": {"source_count": 1, "insight_count": 1},
            "total_tokens": 100,
        }
        mock_context_builder.return_value = mock_builder_instance

        mock_prompter_instance = MagicMock()
        mock_prompter_instance.render.return_value = "Test prompt"
        mock_prompter.return_value = mock_prompter_instance

        mock_model = AsyncMock()
        mock_model.invoke.return_value = AIMessage(content="AI response")
        mock_provision_model.return_value = mock_model

        # Execute the graph
        result = await source_chat_graph.ainvoke(sample_state, sample_config)

        # Verify the result structure
        assert "messages" in result
        assert "source_id" in result
        assert result["source_id"] == "source:test123"


if __name__ == "__main__":
    # Run the tests in this module with verbose output
    pytest.main([__file__, "-v"])