docetl/tests/test_split.py

257 lines
9.2 KiB
Python

import pytest
from docetl.operations.split import SplitOperation
from docetl.operations.map import MapOperation
from docetl.operations.gather import GatherOperation
from tests.conftest import runner
@pytest.fixture
def default_model():
return "gpt-4o-mini"
@pytest.fixture
def max_threads():
return 4
@pytest.fixture
def split_config():
return {
"name": "split_doc",
"type": "split",
"split_key": "content",
"method": "token_count",
"method_kwargs": {"num_tokens": 10},
"name": "split_doc",
}
@pytest.fixture
def split_config_delimiter():
return {
"name": "split_doc_delimiter",
"type": "split",
"split_key": "content",
"method": "delimiter",
"method_kwargs": {"delimiter": "\n", "num_splits_to_group": 2},
"name": "split_doc_delimiter",
}
@pytest.fixture
def map_config():
return {
"name": "summarize_doc",
"type": "map",
"prompt": "Summarize the following text:\n\n{{input.content_chunk}}\n\nSummary:",
"output": {"schema": {"summary": "string"}},
"model": "gpt-4o-mini",
}
@pytest.fixture
def gather_config():
return {
"name": "gather_doc",
"type": "gather",
"content_key": "content_chunk",
"doc_id_key": "split_doc_id",
"order_key": "split_doc_chunk_num",
"peripheral_chunks": {
"previous": {
"head": {"content_key": "summary", "count": 1},
"middle": {"content_key": "summary"},
"tail": {"content_key": "content_chunk", "count": 1},
},
"next": {
"head": {"content_key": "content_chunk", "count": 1},
},
},
}
@pytest.fixture
def input_data():
return [
{
"id": "1",
"content": """
Machine learning is a subset of artificial intelligence that focuses on the development of algorithms and statistical models that enable computer systems to improve their performance on a specific task through experience. Instead of explicitly programming rules, machine learning allows systems to learn patterns from data and make decisions with minimal human intervention.
There are several types of machine learning:
1. Supervised Learning: The algorithm learns from labeled training data, trying to minimize the error between its predictions and the actual labels.
2. Unsupervised Learning: The algorithm tries to find patterns in unlabeled data without predefined categories or labels.
3. Semi-supervised Learning: This approach uses both labeled and unlabeled data for training.
4. Reinforcement Learning: The algorithm learns by interacting with an environment, receiving feedback in the form of rewards or penalties.
Machine learning has numerous applications across various fields, including:
- Image and speech recognition
- Natural language processing
- Recommendation systems
- Fraud detection
- Autonomous vehicles
- Medical diagnosis
- Financial market analysis
As the field continues to evolve, new techniques and applications are constantly emerging, pushing the boundaries of what's possible with artificial intelligence.
""",
},
]
def test_split_map_gather_operations(
runner,
split_config,
map_config,
gather_config,
input_data,
default_model,
max_threads,
):
split_op = SplitOperation(runner, split_config, default_model, max_threads)
map_op = MapOperation(runner, map_config, default_model, max_threads)
gather_op = GatherOperation(runner, gather_config, default_model, max_threads)
# Execute split operation
split_results, split_cost = split_op.execute(input_data)
assert len(split_results) > len(
input_data
), "Split operation should produce more chunks than input items"
assert split_cost == 0, "Split operation cost should be zero"
for result in split_results:
assert (
"split_doc_chunk_num" in result
), "Each result should have a split_doc_chunk_num"
assert "content_chunk" in result, "Each result should have content_chunk"
assert "split_doc_id" in result, "Each result should have split_doc_id"
# Execute map operation for summarization
map_results, map_cost = map_op.execute(split_results)
assert len(map_results) == len(
split_results
), "Map operation should produce same number of results as split operation"
for result in map_results:
assert "summary" in result, "Each result should have a summary"
assert (
len(result["summary"]) > 10
), "Summary should be longer than 10 characters"
# Execute gather operation
gather_results, gather_cost = gather_op.execute(map_results)
assert len(gather_results) == len(
map_results
), "Gather operation should produce same number of results as map operation"
assert gather_cost == 0, "Gather operation cost should be zero"
for idx, result in enumerate(gather_results):
assert (
"content_chunk_rendered" in result
), "Each result should have content_chunk_rendered"
formatted_content = result["content_chunk_rendered"]
assert (
"--- Previous Context ---" in formatted_content
), "Formatted content should include previous context"
assert (
"--- Next Context ---" in formatted_content
), "Formatted content should include next context"
assert (
"--- Begin Main Chunk ---" in formatted_content
), "Formatted content should include main chunk delimiters"
assert (
"--- End Main Chunk ---" in formatted_content
), "Formatted content should include main chunk delimiters"
if idx < 23:
assert (
"characters skipped" in formatted_content
), "Formatted content should include skipped characters information"
# Check that the gather operation preserved all split and map operation fields
for split_result, map_result, gather_result in zip(
split_results, map_results, gather_results
):
for key in split_result:
assert (
key in gather_result
), f"Gather result should preserve split result key: {key}"
for key in map_result:
assert (
key in gather_result
), f"Gather result should preserve map result key: {key}"
def test_split_map_gather_empty_input(
runner, split_config, map_config, gather_config, default_model, max_threads
):
split_op = SplitOperation(runner, split_config, default_model, max_threads)
map_op = MapOperation(runner, map_config, default_model, max_threads)
gather_op = GatherOperation(runner, gather_config, default_model, max_threads)
split_results, split_cost = split_op.execute([])
assert len(split_results) == 0
assert split_cost == 0
map_results, map_cost = map_op.execute(split_results)
assert len(map_results) == 0
assert map_cost == 0
gather_results, gather_cost = gather_op.execute(map_results)
assert len(gather_results) == 0
assert gather_cost == 0
def test_split_delimiter_no_summarization(
runner, split_config_delimiter, default_model, max_threads
):
input_data = [
{"id": "1", "content": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5\nLine 6"},
{"id": "2", "content": "Paragraph 1\n\nParagraph 2\n\nParagraph 3"},
]
split_op = SplitOperation(
runner, split_config_delimiter, default_model, max_threads
)
results, cost = split_op.execute(input_data)
assert len(results) == 5 # 3 chunks for first item, 2 for second
assert cost == 0 # No LLM calls, so cost should be 0
# Check first item's chunks
assert results[0]["content_chunk"] == "Line 1\nLine 2"
assert results[1]["content_chunk"] == "Line 3\nLine 4"
assert results[2]["content_chunk"] == "Line 5\nLine 6"
# Check second item's chunks
assert results[3]["content_chunk"] == "Paragraph 1\nParagraph 2"
assert results[4]["content_chunk"] == "Paragraph 3"
# Check that all results have the necessary fields
for result in results:
assert "split_doc_delimiter_id" in result
assert "split_doc_delimiter_chunk_num" in result
assert "id" in result # Original field should be preserved
# Check that chunk numbers are correct
assert results[0]["split_doc_delimiter_chunk_num"] == 1
assert results[1]["split_doc_delimiter_chunk_num"] == 2
assert results[2]["split_doc_delimiter_chunk_num"] == 3
assert results[3]["split_doc_delimiter_chunk_num"] == 1
assert results[4]["split_doc_delimiter_chunk_num"] == 2
# Check that document IDs are consistent within each original item
assert (
results[0]["split_doc_delimiter_id"]
== results[1]["split_doc_delimiter_id"]
== results[2]["split_doc_delimiter_id"]
)
assert results[3]["split_doc_delimiter_id"] == results[4]["split_doc_delimiter_id"]
assert results[0]["split_doc_delimiter_id"] != results[3]["split_doc_delimiter_id"]