"""Tests for the ``value_sampling`` option of ``ReduceOperation``."""
import pytest
|
|
import random
|
|
from docetl.operations.reduce import ReduceOperation
|
|
from tests.conftest import runner
|
|
|
|
|
|
@pytest.fixture
def default_model():
    """Model identifier handed to every ``ReduceOperation`` built in this module."""
    model_name = "gpt-4o-mini"
    return model_name
|
|
|
|
|
|
@pytest.fixture
def max_threads():
    """Thread-count value handed to every ``ReduceOperation`` built in this module."""
    thread_limit = 64
    return thread_limit
|
|
|
|
|
|
@pytest.fixture
def large_sample_data():
    """Build 1000 synthetic records spread across the groups "A", "B" and "C".

    Each record is a dict with three keys:

    * ``group`` - one of "A", "B", "C";
    * ``text`` - the sentence "This is a sample text about <topic>." where
      the topic is one of technology, science, politics, economics, culture;
    * ``importance`` - an integer between 1 and 10 inclusive.

    Returns:
        A list of 1000 such dicts.
    """
    groups = ["A", "B", "C"]
    topics = ["technology", "science", "politics", "economics", "culture"]

    # A private, fixed-seed generator makes the data identical on every run,
    # so a failing test can be reproduced, and leaves the global ``random``
    # state untouched for any other test that relies on it.
    rng = random.Random(42)

    # Dict displays evaluate their values left to right, so each record draws
    # group, then topic, then importance.
    return [
        {
            "group": rng.choice(groups),
            "text": f"This is a sample text about {rng.choice(topics)}.",
            "importance": rng.randint(1, 10),
        }
        for _ in range(1000)  # Generate 1000 items
    ]
|
|
|
|
|
|
def test_random_sampling(runner, default_model, max_threads, large_sample_data):
    """Reduce with ``method: random`` value sampling and check the output shape.

    Expects exactly three results, each carrying a non-empty ``summary``.
    """
    sampling = {"enabled": True, "method": "random", "sample_size": 50}
    op_config = {
        "name": "reduce_value_sampling",
        "type": "reduce",
        "reduce_key": "group",
        "value_sampling": sampling,
        "prompt": "Summarize the following texts: {{ inputs|map(attribute='text')|join(' | ') }}",
        "output": {"schema": {"summary": "string"}},
    }

    reduce_op = ReduceOperation(runner, op_config, default_model, max_threads)
    # The second element of the returned pair is not inspected by this test.
    outputs, _cost = reduce_op.execute(large_sample_data)

    assert len(outputs) == 3, "Should have results for all three groups A, B, and C"
    for item in outputs:
        assert "summary" in item, "Each result should have a summary"
        assert len(item["summary"]) > 0, "Summary should not be empty"
|
|
|
|
|
|
def test_first_n_sampling(runner, default_model, max_threads, large_sample_data):
    """Reduce with ``method: first_n`` value sampling and check the output shape.

    Expects exactly three results, each carrying a non-empty ``summary``.
    """
    sampling = {"enabled": True, "method": "first_n", "sample_size": 100}
    op_config = {
        "name": "reduce_value_sampling",
        "type": "reduce",
        "reduce_key": "group",
        "value_sampling": sampling,
        "prompt": "Summarize the following texts: {{ inputs|map(attribute='text')|join(' | ') }}",
        "output": {"schema": {"summary": "string"}},
    }

    reduce_op = ReduceOperation(runner, op_config, default_model, max_threads)
    # The second element of the returned pair is not inspected by this test.
    outputs, _cost = reduce_op.execute(large_sample_data)

    assert len(outputs) == 3, "Should have results for all three groups A, B, and C"
    for item in outputs:
        assert "summary" in item, "Each result should have a summary"
        assert len(item["summary"]) > 0, "Summary should not be empty"
|
|
|
|
|
|
def test_cluster_sampling(runner, default_model, max_threads, large_sample_data):
    """Reduce with ``method: cluster`` value sampling and check the output shape.

    The sampling config names an embedding model and embeds the ``text`` key.
    Expects exactly three results, each carrying a non-empty ``summary``.
    """
    sampling = {
        "enabled": True,
        "method": "cluster",
        "sample_size": 50,
        "embedding_model": "text-embedding-3-small",
        "embedding_keys": ["text"],
    }
    op_config = {
        "name": "reduce_value_sampling",
        "type": "reduce",
        "reduce_key": "group",
        "value_sampling": sampling,
        "prompt": "Summarize the following texts: {{ inputs|map(attribute='text')|join(' | ') }}",
        "output": {"schema": {"summary": "string"}},
    }

    reduce_op = ReduceOperation(runner, op_config, default_model, max_threads)
    # The second element of the returned pair is not inspected by this test.
    outputs, _cost = reduce_op.execute(large_sample_data)

    assert len(outputs) == 3, "Should have results for all three groups A, B, and C"
    for item in outputs:
        assert "summary" in item, "Each result should have a summary"
        assert len(item["summary"]) > 0, "Summary should not be empty"
|
|
|
|
|
|
def test_semantic_similarity_sampling(
    runner, default_model, max_threads, large_sample_data
):
    """Reduce with ``method: sem_sim`` value sampling queried on "technology".

    Expects exactly three results, each carrying a non-empty ``summary`` that
    mentions none of the other four topics used by the sample data.
    """
    sampling = {
        "enabled": True,
        "method": "sem_sim",
        "sample_size": 20,
        "embedding_model": "text-embedding-3-small",
        "embedding_keys": ["text"],
        "query_text": "technology",
    }
    op_config = {
        "name": "reduce_value_sampling",
        "type": "reduce",
        "reduce_key": "group",
        "value_sampling": sampling,
        "prompt": "Summarize the following texts: {{ inputs|map(attribute='text')|join(' | ') }}",
        "output": {"schema": {"summary": "string"}},
    }

    reduce_op = ReduceOperation(runner, op_config, default_model, max_threads)
    # The second element of the returned pair is not inspected by this test.
    outputs, _cost = reduce_op.execute(large_sample_data)

    assert len(outputs) == 3, "Should have results for all three groups A, B, and C"

    # Topics other than the query text; none may appear in any summary.
    excluded_topics = ("science", "politics", "economics", "culture")
    for item in outputs:
        assert "summary" in item, "Each result should have a summary"
        assert len(item["summary"]) > 0, "Summary should not be empty"

        # make sure there's no mention of "science", "politics", "economics", "culture"
        for topic in excluded_topics:
            assert topic not in item["summary"]
|
|
|
|
|
|
def test_invalid_sampling_method(
    runner, default_model, max_threads, large_sample_data
):
    """Constructing a ``ReduceOperation`` with an unknown sampling method fails.

    The constructor must raise ``ValueError`` whose message matches
    ``Invalid 'method'``; ``execute`` is never called.
    """
    sampling = {
        "enabled": True,
        "method": "invalid_method",
        "sample_size": 50,
    }
    op_config = {
        "name": "reduce_value_sampling",
        "type": "reduce",
        "reduce_key": "group",
        "value_sampling": sampling,
        "prompt": "Summarize the following texts: {{ inputs|map(attribute='text')|join(' | ') }}",
        "output": {"schema": {"summary": "string"}},
    }

    expected_error = pytest.raises(ValueError, match="Invalid 'method'")
    with expected_error:
        ReduceOperation(runner, op_config, default_model, max_threads)
|