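"""Tests for DocETL gather synthesis: an end-to-end check that the optimizer
synthesizes a gather operation, plus a hand-built split -> map -> gather chain."""
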
import json
import os
import random
import tempfile

import pytest

from docetl.runner import DSLRunner
from docetl.operations.split import SplitOperation
from docetl.operations.map import MapOperation
from docetl.operations.gather import GatherOperation
from tests.conftest import runner
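

# Helper: build a string of `length` randomly chosen fruit words, used below
# to synthesize markdown-style documents large enough to be worth splitting.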
def generate_random_content(length):
    words = [
        "apple", "banana", "cherry", "date", "elderberry",
        "fig", "grape", "honeydew", "kiwi", "lemon",
        "mango", "nectarine", "orange", "papaya", "quince",
        "raspberry", "strawberry", "tangerine", "ugli fruit", "watermelon",
    ]
    return " ".join(random.choices(words, k=length))
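

# Five synthetic markdown documents, each with nested sections and
# subsections, so the pipeline has real header structure to work with.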
@pytest.fixture
def sample_data():
    documents = []
    for i in range(5):
        document = f"# Document {i + 1}\n\n"
        document += generate_random_content(100) + "\n\n"
        for j in range(3):
            document += f"## Section {j + 1}\n\n"
            document += generate_random_content(60) + "\n\n"
            for k in range(2):
                document += f"### Subsection {k + 1}\n\n"
                document += generate_random_content(40) + "\n\n"
        documents.append({"id": i + 1, "content": document})
    return documents
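

# Writes a minimal one-step map pipeline config plus the sample dataset to
# temp files. The config is serialized with json.dump into a .yaml file,
# which works because JSON is a subset of YAML.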
@pytest.fixture
def config_yaml(sample_data):
    with tempfile.NamedTemporaryFile(
        mode="w+", suffix=".yaml", delete=False
    ) as temp_file, tempfile.NamedTemporaryFile(
        mode="w+", suffix=".json", delete=False
    ) as long_documents_file, tempfile.NamedTemporaryFile(
        mode="w+", suffix=".json", delete=False
    ) as output_file:
        config = {
            "datasets": {
                "long_documents": {"type": "file", "path": long_documents_file.name}
            },
            "default_model": "gpt-4o-mini",
            "operations": [
                {
                    "name": "count_words",
                    "type": "map",
                    "optimize": True,
                    "recursively_optimize": False,
                    "output": {"schema": {"count": "integer"}},
                    "prompt": "Count the number of words that start with the letter 'a' in the following text:\n\n{{ input.content }}\n\nReturn only the count as an integer.",
                }
            ],
            "pipeline": {
                "steps": [
                    {
                        "name": "word_analysis",
                        "input": "long_documents",
                        "operations": ["count_words"],
                    }
                ],
                "output": {"type": "file", "path": output_file.name},
            },
        }
        json.dump(config, temp_file)
        temp_file.flush()

        # Create sample data file
        json.dump(sample_data, long_documents_file)
        long_documents_file.flush()

        return temp_file.name, long_documents_file.name, output_file.name
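

# End-to-end check: optimizing the lone map operation over long documents
# should leave a synthesized operation of type "gather" in the optimized
# config, carrying the standard gather keys asserted below.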
@pytest.mark.flaky(reruns=2, reruns_delay=1)
def test_synth_gather(config_yaml):
    config_path, long_documents_path, output_path = config_yaml

    # Build a runner from the config and optimize the pipeline
    runner = DSLRunner.from_yaml(config_path)
    optimized_pipeline, cost = runner.optimize(return_pipeline=True)

    # Check whether a gather operation was synthesized
    synthesized_gather_found = False
    for step in optimized_pipeline.config["pipeline"]["steps"]:
        for op in step["operations"]:
            synthesized_op = next(
                operation
                for operation in optimized_pipeline.config["operations"]
                if operation["name"] == op
            )
            if synthesized_op.get("type") == "gather":
                synthesized_gather_found = True

                # Check that the synthesized operation has the correct properties
                assert synthesized_op["type"] == "gather"
                assert "content_key" in synthesized_op
                assert "doc_id_key" in synthesized_op
                assert "order_key" in synthesized_op
                assert "peripheral_chunks" in synthesized_op
                assert "doc_header_key" in synthesized_op
                break
        if synthesized_gather_found:
            break

    assert (
        synthesized_gather_found
    ), "No synthesized gather operation found in the optimized config"

    # Run the optimized pipeline
    optimized_pipeline.load_run_save()

    # Check that the output file was created
    assert os.path.exists(output_path), "Output file was not created"

    # Load and check the output
    with open(output_path, "r") as f:
        output = json.load(f)

    with open(long_documents_path, "r") as f:
        sample_data = json.load(f)

    assert len(output) == len(
        sample_data
    ), "Output should have the same number of items as input"
    for item in output:
        assert "count" in item, "Each output item should have a 'count' field"
        assert isinstance(item["count"], int), "The 'count' field should be an integer"

    # Clean up temporary files
    os.remove(config_path)
    os.remove(long_documents_path)
    os.remove(output_path)
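

# Exercises a hand-built split -> map -> gather chain against the operation
# classes directly, without going through the optimizer.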
def test_split_map_gather(sample_data, runner):
    default_model = "gpt-4o-mini"

    # Define split operation
    split_config = {
        "name": "split_doc",
        "type": "split",
        "split_key": "content",
        "method": "token_count",
        "method_kwargs": {"num_tokens": 100},
    }

    # Define map operation to extract headers
    map_config = {
        "name": "extract_headers",
        "type": "map",
        "optimize": True,
        "prompt": """Analyze the following chunk of a document and extract any headers you see.

{{ input.content_chunk }}

Provide your analysis as a list of dictionaries, where each dictionary contains a 'header' (string) and 'level' (integer). For example:

[
    {"header": "Document 1", "level": 1},
    {"header": "Section 1", "level": 2}
]

Only include headers you find in the text; do not add any that are not present.""",
        "output": {"schema": {"headers": "list[{header: string, level: integer}]"}},
        "model": default_model,
    }

    # Define gather operation
    gather_config = {
        "name": "gather_doc",
        "type": "gather",
        "content_key": "content_chunk",
        "doc_id_key": "split_doc_id",
        "order_key": "split_doc_chunk_num",
        # Keep one chunk of context on each side of the main chunk
        "peripheral_chunks": {
            "previous": {"tail": {"count": 1}},
            "next": {"head": {"count": 1}},
        },
        "doc_header_key": "headers",
    }

    # Initialize operations
    split_op = SplitOperation(runner, split_config, default_model, max_threads=64)
    map_op = MapOperation(runner, map_config, default_model, max_threads=64)
    gather_op = GatherOperation(runner, gather_config, default_model, max_threads=64)

    # Execute operations
    split_results, split_cost = split_op.execute(sample_data)
    map_results, map_cost = map_op.execute(split_results)
    gather_results, gather_cost = gather_op.execute(map_results)

    # Assertions
    assert len(gather_results) == len(
        split_results
    ), "Number of gathered results should match split results"

    for result in gather_results:
        assert "headers" in result, "Each gathered result should have a 'headers' field"
        assert isinstance(
            result["headers"], list
        ), "The 'headers' field should be a list"

        for header in result["headers"]:
            assert "header" in header, "Each header should have a 'header' field"
            assert "level" in header, "Each header should have a 'level' field"
            assert isinstance(
                header["header"], str
            ), "The 'header' field should be a string"
            assert isinstance(
                header["level"], int
            ), "The 'level' field should be an integer"

        assert (
            "content_chunk_rendered" in result
        ), "Each result should have content_chunk_rendered"
        formatted_content = result["content_chunk_rendered"]

        assert (
            "--- Previous Context ---" in formatted_content
        ), "Formatted content should include previous context"
        assert (
            "--- Next Context ---" in formatted_content
        ), "Formatted content should include next context"
        assert (
            "--- Begin Main Chunk ---" in formatted_content
        ), "Formatted content should include main chunk delimiters"
        assert (
            "--- End Main Chunk ---" in formatted_content
        ), "Formatted content should include main chunk delimiters"

    assert split_cost == 0, "Split operation cost should be zero"
    assert map_cost > 0, "Map operation cost should be greater than zero"
    assert gather_cost == 0, "Gather operation cost should be zero"


# Run the tests
# if __name__ == "__main__":
#     sd = sample_data()
#     config = config_yaml(sd)
#     test_synth_gather(config)
#     test_split_map_gather(sd)