Compare commits

1 Commits

| Author | SHA1 | Date |
|---|---|---|
|  | ae06d5876e |  |
@@ -30,6 +30,8 @@ from .map_reduce_fusion import MapReduceFusionDirective
from .hierarchical_reduce import HierarchicalReduceDirective
from .cascade_filtering import CascadeFilteringDirective
from .arbitrary_rewrite import ArbitraryRewriteDirective
+from .map_to_map_resolve_reduce import MapToMapResolveReduceDirective
+from .map_resolve_to_map_with_categories import MapResolveToMapWithCategoriesDirective

# Registry of all available directives
ALL_DIRECTIVES = [
@@ -53,6 +55,8 @@ ALL_DIRECTIVES = [
    HierarchicalReduceDirective(),
    CascadeFilteringDirective(),
    ArbitraryRewriteDirective(),
+   MapToMapResolveReduceDirective(),
+   MapResolveToMapWithCategoriesDirective(),
]

ALL_COST_DIRECTIVES = [
@@ -179,8 +183,10 @@ __all__ = [
    "HierarchicalReduceDirective",
    "CascadeFilteringDirective",
    "ArbitraryRewriteDirective",
+   "MapToMapResolveReduceDirective",
+   "MapResolveToMapWithCategoriesDirective",
    "ALL_DIRECTIVES",
    "DIRECTIVE_REGISTRY",
    "get_all_directive_strings",
    "instantiate_directive"
]
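Reviewer note (not part of the diff): ALL_DIRECTIVES is a flat list of directive instances, so the two new rewrites become discoverable by name alongside the existing ones. A quick sanity check, assuming the package import path used elsewhere in this PR:

from docetl.reasoning_optimizer.directives import ALL_DIRECTIVES

# The new entries should show up by their `name` fields.
names = [d.name for d in ALL_DIRECTIVES]
assert "map_to_map_resolve_reduce" in names
assert "map_resolve_to_map_with_categories" in names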
@@ -123,7 +123,8 @@ class Directive(BaseModel, ABC):

        try:
            # 1. Execute the directive
-           actual_output, _ = self.instantiate(
+           # instantiate returns (new_ops_plan, message_history, call_cost)
+           instantiate_result = self.instantiate(
                operators=(
                    [test_case.input_config]
                    if isinstance(test_case.input_config, dict)
@@ -134,6 +135,11 @@ class Directive(BaseModel, ABC):
                input_file_path=temp_file_path,
                pipeline_code=fake_pipeline,
            )
+           # Handle both 2-tuple and 3-tuple returns
+           if isinstance(instantiate_result, tuple):
+               actual_output = instantiate_result[0]
+           else:
+               actual_output = instantiate_result

            # 2. Use LLM judge to evaluate
            judge_result = self._llm_judge_test(
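Reviewer note (not part of the diff): the tuple handling above exists because the new directives' instantiate() returns a 3-tuple (new_ops_plan, message_history, call_cost), while this test harness only needs the plan. A minimal sketch of the shape being unpacked (values illustrative):

# Illustrative only: what a new-style instantiate() hands back.
instantiate_result = ([{"name": "extract_sentiment", "type": "map"}], [], 0.0123)
actual_output = (
    instantiate_result[0] if isinstance(instantiate_result, tuple) else instantiate_result
)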
@@ -0,0 +1,366 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type

from litellm import completion
from pydantic import BaseModel, Field

from docetl.reasoning_optimizer.instantiate_schemas import (
    MapResolveToMapWithCategoriesInstantiateSchema,
)

from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase


class MapResolveToMapWithCategoriesDirective(Directive):
    name: str = Field(
        default="map_resolve_to_map_with_categories",
        description="The name of the directive",
    )
    formal_description: str = Field(default="Map -> Resolve => Map (with categories)")
    nl_description: str = Field(
        default="Replace a Map -> Resolve pattern with a single Map operation that has predefined categories. The agent analyzes the data and task to propose a set of canonical categories, and the new Map forces outputs into one of these categories (or 'none of the above'). This effectively performs entity resolution deterministically by standardizing outputs upfront, avoiding the need for pairwise comparisons."
    )
    when_to_use: str = Field(
        default="When a Map operation produces outputs that are then resolved/deduplicated, and the set of valid output categories can be enumerated upfront. This is more efficient than Resolve when the category space is small and well-defined (e.g., standardizing company types, product categories, sentiment labels). The target must be a Map operation followed by a Resolve operation."
    )
    instantiate_schema_type: Type[BaseModel] = (
        MapResolveToMapWithCategoriesInstantiateSchema
    )

    example: str = Field(
        default=(
            "Original Pipeline:\n"
            "- name: extract_sentiment\n"
            "  type: map\n"
            "  prompt: |\n"
            "    What is the sentiment of this review?\n"
            "    {{ input.review }}\n"
            "  output:\n"
            "    schema:\n"
            "      sentiment: string\n"
            "\n"
            "- name: normalize_sentiment\n"
            "  type: resolve\n"
            "  comparison_prompt: |\n"
            "    Are these sentiments equivalent?\n"
            "    Sentiment 1: {{ input1.sentiment }}\n"
            "    Sentiment 2: {{ input2.sentiment }}\n"
            "  resolution_prompt: |\n"
            "    Normalize these sentiments:\n"
            "    {% for input in inputs %}\n"
            "    - {{ input.sentiment }}\n"
            "    {% endfor %}\n"
            "  output:\n"
            "    schema:\n"
            "      sentiment: string\n"
            "\n"
            "Example InstantiateSchema:\n"
            "MapResolveToMapWithCategoriesInstantiateSchema(\n"
            "    categories=['Positive', 'Negative', 'Neutral', 'Mixed'],\n"
            "    category_key='sentiment',\n"
            "    new_prompt='''Analyze the sentiment of this review and classify it into one of the following categories:\n"
            "\n"
            "Categories:\n"
            "- Positive: Clearly positive sentiment, satisfaction, praise\n"
            "- Negative: Clearly negative sentiment, complaints, criticism\n"
            "- Neutral: No strong sentiment, factual statements\n"
            "- Mixed: Contains both positive and negative elements\n"
            "- None of the above: If the review doesn't fit any category\n"
            "\n"
            "Review: {{ input.review }}\n"
            "\n"
            "Return exactly one of: Positive, Negative, Neutral, Mixed, or None of the above.''',\n"
            "    include_none_of_above=True,\n"
            ")"
        ),
    )

    test_cases: List[DirectiveTestCase] = Field(
        default_factory=lambda: [
            DirectiveTestCase(
                name="sentiment_categorization",
                description="Should replace map+resolve with categorized map for sentiment",
                input_config=[
                    {
                        "name": "extract_sentiment",
                        "type": "map",
                        "prompt": "What is the sentiment? {{ input.text }}",
                        "output": {"schema": {"sentiment": "string"}},
                    },
                    {
                        "name": "normalize_sentiment",
                        "type": "resolve",
                        "comparison_prompt": "Same sentiment? {{ input1.sentiment }} vs {{ input2.sentiment }}",
                        "resolution_prompt": "Normalize: {% for input in inputs %}{{ input.sentiment }}{% endfor %}",
                        "output": {"schema": {"sentiment": "string"}},
                    },
                ],
                target_ops=["extract_sentiment", "normalize_sentiment"],
                expected_behavior="Should create a single map with predefined sentiment categories",
                should_pass=True,
            ),
        ]
    )

    def __eq__(self, other):
        return isinstance(other, MapResolveToMapWithCategoriesDirective)

    def __hash__(self):
        return hash("MapResolveToMapWithCategoriesDirective")

    def to_string_for_instantiate(
        self, map_op: Dict, resolve_op: Dict, sample_data: List[Dict] = None
    ) -> str:
        """
        Generate a prompt for an agent to instantiate this directive.

        Args:
            map_op (Dict): The map operation configuration.
            resolve_op (Dict): The resolve operation configuration.
            sample_data (List[Dict], optional): Sample data to help identify categories.

        Returns:
            str: The agent prompt for instantiating the directive.
        """
        sample_str = ""
        if sample_data:
            sample_str = f"\n\nSample Input Data (first 5 items):\n{json.dumps(sample_data[:5], indent=2)}"

        return (
            f"You are an expert at optimizing data processing pipelines by replacing entity resolution with categorical constraints.\n\n"
            f"Map Operation:\n"
            f"{str(map_op)}\n\n"
            f"Resolve Operation:\n"
            f"{str(resolve_op)}\n\n"
            f"Directive: {self.name}\n"
            f"Your task is to replace the Map -> Resolve pattern with a single Map that uses predefined categories.\n\n"
            f"Key Requirements:\n"
            f"1. Analyze the map's output field and the resolve operation to understand what values are being normalized:\n"
            f"   - Look at the comparison_prompt to understand what variations are being matched\n"
            f"   - Look at the resolution_prompt to understand the canonical form\n\n"
            f"2. Propose a set of categories that cover all expected outputs:\n"
            f"   - Categories should be mutually exclusive\n"
            f"   - Categories should be exhaustive (cover all realistic cases)\n"
            f"   - Consider including 'None of the above' for edge cases\n\n"
            f"3. Optionally provide descriptions for each category to help the LLM classify correctly\n\n"
            f"4. Create a new_prompt that:\n"
            f"   - Lists all valid categories with their descriptions\n"
            f"   - Instructs the LLM to output exactly one category\n"
            f"   - References the input using {{{{ input.key }}}} syntax\n"
            f"   - Includes any context from the original map prompt\n\n"
            f"5. Identify the category_key (the output field that will contain the category)\n\n"
            f"Benefits of this approach:\n"
            f"- Eliminates O(n^2) pairwise comparisons from Resolve\n"
            f"- Produces consistent, standardized outputs\n"
            f"- Reduces cost by removing the Resolve operation entirely\n"
            f"{sample_str}\n\n"
            f"Example:\n"
            f"{self.example}\n\n"
            f"Please analyze the operations and propose appropriate categories. Output the MapResolveToMapWithCategoriesInstantiateSchema."
        )

    def llm_instantiate(
        self,
        map_op: Dict,
        resolve_op: Dict,
        agent_llm: str,
        message_history: list = [],
        sample_data: List[Dict] = None,
    ):
        """
        Use LLM to instantiate this directive.

        Args:
            map_op (Dict): The map operation configuration.
            resolve_op (Dict): The resolve operation configuration.
            agent_llm (str): The LLM model to use.
            message_history (List, optional): Conversation history for context.
            sample_data (List[Dict], optional): Sample data to help identify categories.

        Returns:
            MapResolveToMapWithCategoriesInstantiateSchema: The structured output from the LLM.
        """

        message_history.extend(
            [
                {
                    "role": "system",
                    "content": "You are a helpful AI assistant for document processing pipelines specializing in categorical classification.",
                },
                {
                    "role": "user",
                    "content": self.to_string_for_instantiate(
                        map_op, resolve_op, sample_data
                    ),
                },
            ]
        )

        for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):

            resp = completion(
                model=agent_llm,
                messages=message_history,
                api_key=os.environ.get("AZURE_API_KEY"),
                api_base=os.environ.get("AZURE_API_BASE"),
                api_version=os.environ.get("AZURE_API_VERSION"),
                azure=True,
                response_format=MapResolveToMapWithCategoriesInstantiateSchema,
            )

            call_cost = resp._hidden_params["response_cost"]

            try:
                parsed_res = json.loads(resp.choices[0].message.content)
                schema = MapResolveToMapWithCategoriesInstantiateSchema(**parsed_res)

                message_history.append(
                    {"role": "assistant", "content": resp.choices[0].message.content}
                )
                return schema, message_history, call_cost
            except Exception as err:
                error_message = f"Validation error: {err}\nPlease try again."
                message_history.append({"role": "user", "content": error_message})

        raise Exception(
            f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
        )

    def apply(
        self,
        global_default_model,
        ops_list: List[Dict],
        map_op_name: str,
        resolve_op_name: str,
        rewrite: MapResolveToMapWithCategoriesInstantiateSchema,
    ) -> List[Dict]:
        """
        Apply the directive to the pipeline config.
        """
        new_ops_list = deepcopy(ops_list)

        # Find position of the map and resolve ops
        map_pos = None
        resolve_pos = None
        map_op = None

        for i, op in enumerate(new_ops_list):
            if op["name"] == map_op_name:
                map_pos = i
                map_op = op
            elif op["name"] == resolve_op_name:
                resolve_pos = i

        if map_pos is None or resolve_pos is None:
            raise ValueError(
                f"Could not find map '{map_op_name}' and resolve '{resolve_op_name}' operations"
            )

        # Determine the model to use
        default_model = map_op.get("model", global_default_model)

        # Build the list of valid values for validation
        valid_values = list(rewrite.categories)
        if rewrite.include_none_of_above:
            valid_values.append("None of the above")

        # Modify the map operation with the new prompt and add validation
        new_ops_list[map_pos]["prompt"] = rewrite.new_prompt
        new_ops_list[map_pos]["model"] = default_model

        # Add validation to ensure output is one of the categories
        if "validate" not in new_ops_list[map_pos]:
            new_ops_list[map_pos]["validate"] = []

        # Add validation rule for the category key
        validation_rule = f"output['{rewrite.category_key}'] in {valid_values}"
        new_ops_list[map_pos]["validate"].append(validation_rule)

        # Update the output schema to reflect the category key
        if "output" not in new_ops_list[map_pos]:
            new_ops_list[map_pos]["output"] = {"schema": {}}
        new_ops_list[map_pos]["output"]["schema"][rewrite.category_key] = "string"

        # Remove the resolve operation
        new_ops_list.pop(resolve_pos)

        return new_ops_list

    def instantiate(
        self,
        operators: List[Dict],
        target_ops: List[str],
        agent_llm: str,
        message_history: list = [],
        optimize_goal="acc",
        global_default_model: str = None,
        dataset: str = None,
        **kwargs,
    ):
        """
        Instantiate the directive for a list of operators.
        """
        # Assert that there are exactly two target ops (map and resolve)
        assert (
            len(target_ops) == 2
        ), "There must be exactly two target ops (map and resolve) to instantiate this directive"

        # Find the map and resolve operations
        map_op = None
        resolve_op = None

        for op in operators:
            if op["name"] == target_ops[0]:
                if op.get("type") == "map":
                    map_op = op
                elif op.get("type") == "resolve":
                    resolve_op = op
            elif op["name"] == target_ops[1]:
                if op.get("type") == "map":
                    map_op = op
                elif op.get("type") == "resolve":
                    resolve_op = op

        if map_op is None or resolve_op is None:
            raise ValueError(
                f"Could not find both a map and resolve operation in target_ops: {target_ops}"
            )

        # Verify the map comes before resolve
        map_idx = next(
            i for i, op in enumerate(operators) if op["name"] == map_op["name"]
        )
        resolve_idx = next(
            i for i, op in enumerate(operators) if op["name"] == resolve_op["name"]
        )

        if map_idx >= resolve_idx:
            raise ValueError(
                f"Map operation '{map_op['name']}' must come before resolve operation '{resolve_op['name']}'"
            )

        # Load sample data if available
        sample_data = None
        if dataset:
            try:
                with open(dataset, "r") as f:
                    sample_data = json.load(f)
            except Exception:
                pass  # Ignore if we can't load sample data

        # Instantiate the directive
        rewrite, message_history, call_cost = self.llm_instantiate(
            map_op,
            resolve_op,
            agent_llm,
            message_history,
            sample_data,
        )

        # Apply the rewrite to the operators
        new_ops_plan = self.apply(
            global_default_model, operators, map_op["name"], resolve_op["name"], rewrite
        )
        return new_ops_plan, message_history, call_cost
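Reviewer note (not part of the diff): for the sentiment test case above, apply() should leave a single map op carrying the generated membership check. A minimal sketch of the expected result, mirroring the test assertions:

# Expected shape of apply()'s output for the sentiment example (illustrative).
rewritten = [
    {
        "name": "extract_sentiment",
        "type": "map",
        "prompt": "Classify ... {{ input.review }} ...",  # rewrite.new_prompt with the category list
        "model": "gpt-4o-mini",
        "output": {"schema": {"sentiment": "string"}},
        "validate": [
            "output['sentiment'] in ['Positive', 'Negative', 'Neutral', 'Mixed', 'None of the above']"
        ],
    },
    # normalize_sentiment (the resolve op) is removed entirely.
]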
@@ -0,0 +1,340 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type

from litellm import completion
from pydantic import BaseModel, Field

from docetl.reasoning_optimizer.instantiate_schemas import (
    MapToMapResolveReduceInstantiateSchema,
)

from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase


class MapToMapResolveReduceDirective(Directive):
    name: str = Field(
        default="map_to_map_resolve_reduce", description="The name of the directive"
    )
    formal_description: str = Field(default="Map -> Reduce => Map -> Resolve -> Reduce")
    nl_description: str = Field(
        default="Insert a Resolve operation between Map and Reduce to deduplicate or normalize entities before aggregation. The Resolve operation uses code-powered blocking conditions to efficiently identify which pairs to compare, avoiding O(n^2) comparisons. This is useful when the Map output contains duplicate or near-duplicate entities that should be merged before the Reduce step."
    )
    when_to_use: str = Field(
        default="When a Map operation produces outputs that may contain duplicates, variations, or near-duplicates (e.g., different spellings of names, similar categories), and these should be normalized before the Reduce aggregation step. The target must be a Map operation followed by a Reduce operation."
    )
    instantiate_schema_type: Type[BaseModel] = MapToMapResolveReduceInstantiateSchema

    example: str = Field(
        default=(
            "Original Pipeline:\n"
            "- name: extract_companies\n"
            "  type: map\n"
            "  prompt: |\n"
            "    Extract company names from this news article:\n"
            "    {{ input.article }}\n"
            "  output:\n"
            "    schema:\n"
            "      company_name: string\n"
            "\n"
            "- name: aggregate_companies\n"
            "  type: reduce\n"
            "  reduce_key: sector\n"
            "  prompt: |\n"
            "    List all unique companies in this sector:\n"
            "    {% for input in inputs %}\n"
            "    - {{ input.company_name }}\n"
            "    {% endfor %}\n"
            "  output:\n"
            "    schema:\n"
            "      companies: list[str]\n"
            "\n"
            "Example InstantiateSchema:\n"
            "MapToMapResolveReduceInstantiateSchema(\n"
            "    resolve_name='normalize_company_names',\n"
            "    comparison_prompt='''Are these two company names referring to the same company?\n"
            "Company 1: {{ input1.company_name }}\n"
            "Company 2: {{ input2.company_name }}\n"
            "Consider variations like abbreviations (IBM vs International Business Machines), \n"
            "different legal suffixes (Inc, Corp, LLC), and common misspellings.''',\n"
            "    resolution_prompt='''Given these variations of a company name:\n"
            "{% for input in inputs %}\n"
            "- {{ input.company_name }}\n"
            "{% endfor %}\n"
            "Return the canonical/official company name.''',\n"
            "    blocking_conditions=[\n"
            "        \"input1['company_name'][:3].lower() == input2['company_name'][:3].lower()\",\n"
            "        \"input1['company_name'].split()[0].lower() == input2['company_name'].split()[0].lower()\",\n"
            "    ],\n"
            "    blocking_keys=['company_name'],\n"
            "    limit_comparisons=1000,\n"
            "    output_schema={'company_name': 'string'},\n"
            ")"
        ),
    )

    test_cases: List[DirectiveTestCase] = Field(
        default_factory=lambda: [
            DirectiveTestCase(
                name="company_name_normalization",
                description="Should insert resolve between map and reduce for company names",
                input_config=[
                    {
                        "name": "extract_companies",
                        "type": "map",
                        "prompt": "Extract company name from: {{ input.text }}",
                        "output": {"schema": {"company_name": "string"}},
                    },
                    {
                        "name": "aggregate_by_sector",
                        "type": "reduce",
                        "reduce_key": "sector",
                        "prompt": "List companies:\n{% for input in inputs %}\n- {{ input.company_name }}\n{% endfor %}",
                        "output": {"schema": {"companies": "list[str]"}},
                    },
                ],
                target_ops=["extract_companies", "aggregate_by_sector"],
                expected_behavior="Should create a resolve operation between the map and reduce with appropriate blocking conditions",
                should_pass=True,
            ),
        ]
    )

    def __eq__(self, other):
        return isinstance(other, MapToMapResolveReduceDirective)

    def __hash__(self):
        return hash("MapToMapResolveReduceDirective")

    def to_string_for_instantiate(self, map_op: Dict, reduce_op: Dict) -> str:
        """
        Generate a prompt for an agent to instantiate this directive.

        Args:
            map_op (Dict): The map operation configuration.
            reduce_op (Dict): The reduce operation configuration.

        Returns:
            str: The agent prompt for instantiating the directive.
        """
        return (
            f"You are an expert at optimizing data processing pipelines by inserting entity resolution steps.\n\n"
            f"Map Operation:\n"
            f"{str(map_op)}\n\n"
            f"Reduce Operation:\n"
            f"{str(reduce_op)}\n\n"
            f"Directive: {self.name}\n"
            f"Your task is to insert a Resolve operation between the Map and Reduce to deduplicate/normalize entities.\n\n"
            f"Key Requirements:\n"
            f"1. Create a comparison_prompt that determines if two items from the Map output are duplicates/variations:\n"
            f"   - Must reference {{ input1.key }} and {{ input2.key }} for comparing fields\n"
            f"   - Should handle common variations (abbreviations, misspellings, formatting differences)\n\n"
            f"2. Create a resolution_prompt that merges matched items into a canonical form:\n"
            f"   - Must use {{% for input in inputs %}} to iterate over matched items\n"
            f"   - Should produce the most authoritative/complete version\n\n"
            f"3. Create blocking_conditions to avoid O(n^2) comparisons:\n"
            f"   - These are Python expressions with access to 'input1' and 'input2' dicts\n"
            f"   - They should filter pairs to only those likely to match\n"
            f"   - Examples:\n"
            f"     * \"input1['name'][:3].lower() == input2['name'][:3].lower()\" (first 3 chars match)\n"
            f"     * \"input1['name'].split()[0].lower() == input2['name'].split()[0].lower()\" (first word matches)\n"
            f"     * \"abs(len(input1['name']) - len(input2['name'])) < 10\" (similar length)\n"
            f"   - Multiple conditions are OR'd together\n\n"
            f"4. Set limit_comparisons to cap the number of pairs (recommended: 500-2000)\n\n"
            f"5. The output_schema should match what the Reduce operation expects from each input\n\n"
            f"Example:\n"
            f"{self.example}\n\n"
            f"Please output the MapToMapResolveReduceInstantiateSchema."
        )

    def llm_instantiate(
        self, map_op: Dict, reduce_op: Dict, agent_llm: str, message_history: list = []
    ):
        """
        Use LLM to instantiate this directive.

        Args:
            map_op (Dict): The map operation configuration.
            reduce_op (Dict): The reduce operation configuration.
            agent_llm (str): The LLM model to use.
            message_history (List, optional): Conversation history for context.

        Returns:
            MapToMapResolveReduceInstantiateSchema: The structured output from the LLM.
        """

        message_history.extend(
            [
                {
                    "role": "system",
                    "content": "You are a helpful AI assistant for document processing pipelines specializing in entity resolution.",
                },
                {
                    "role": "user",
                    "content": self.to_string_for_instantiate(map_op, reduce_op),
                },
            ]
        )

        for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):

            resp = completion(
                model=agent_llm,
                messages=message_history,
                api_key=os.environ.get("AZURE_API_KEY"),
                api_base=os.environ.get("AZURE_API_BASE"),
                api_version=os.environ.get("AZURE_API_VERSION"),
                azure=True,
                response_format=MapToMapResolveReduceInstantiateSchema,
            )

            call_cost = resp._hidden_params["response_cost"]

            try:
                parsed_res = json.loads(resp.choices[0].message.content)
                schema = MapToMapResolveReduceInstantiateSchema(**parsed_res)

                message_history.append(
                    {"role": "assistant", "content": resp.choices[0].message.content}
                )
                return schema, message_history, call_cost
            except Exception as err:
                error_message = f"Validation error: {err}\nPlease try again."
                message_history.append({"role": "user", "content": error_message})

        raise Exception(
            f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
        )

    def apply(
        self,
        global_default_model,
        ops_list: List[Dict],
        map_op_name: str,
        reduce_op_name: str,
        rewrite: MapToMapResolveReduceInstantiateSchema,
    ) -> List[Dict]:
        """
        Apply the directive to the pipeline config.
        """
        new_ops_list = deepcopy(ops_list)

        # Find position of the map and reduce ops
        map_pos = None
        reduce_pos = None
        map_op = None

        for i, op in enumerate(ops_list):
            if op["name"] == map_op_name:
                map_pos = i
                map_op = op
            elif op["name"] == reduce_op_name:
                reduce_pos = i

        if map_pos is None or reduce_pos is None:
            raise ValueError(
                f"Could not find map '{map_op_name}' and reduce '{reduce_op_name}' operations"
            )

        # Determine the model to use
        default_model = map_op.get("model", global_default_model)

        # Find the reduce operation to get the reduce_key
        reduce_op = None
        for op in ops_list:
            if op["name"] == reduce_op_name:
                reduce_op = op
                break

        # Derive output schema from reduce_key - that's what's being grouped/resolved
        reduce_key = reduce_op.get("reduce_key", []) if reduce_op else []
        if isinstance(reduce_key, str):
            reduce_key = [reduce_key]

        # Build output schema from reduce_key fields
        output_schema = {key: "string" for key in reduce_key}

        # Create the resolve operation
        resolve_op = {
            "name": rewrite.resolve_name,
            "type": "resolve",
            "comparison_prompt": rewrite.comparison_prompt,
            "resolution_prompt": rewrite.resolution_prompt,
            "blocking_conditions": rewrite.blocking_conditions,
            "blocking_keys": rewrite.blocking_keys,
            "limit_comparisons": rewrite.limit_comparisons,
            "model": default_model,
            "output": {"schema": output_schema},
        }

        # Insert resolve operation after the map operation
        new_ops_list.insert(map_pos + 1, resolve_op)

        return new_ops_list

    def instantiate(
        self,
        operators: List[Dict],
        target_ops: List[str],
        agent_llm: str,
        message_history: list = [],
        optimize_goal="acc",
        global_default_model: str = None,
        **kwargs,
    ):
        """
        Instantiate the directive for a list of operators.
        """
        # Assert that there are exactly two target ops (map and reduce)
        assert (
            len(target_ops) == 2
        ), "There must be exactly two target ops (map and reduce) to instantiate this directive"

        # Find the map and reduce operations
        map_op = None
        reduce_op = None

        for op in operators:
            if op["name"] == target_ops[0]:
                if op.get("type") == "map":
                    map_op = op
                elif op.get("type") == "reduce":
                    reduce_op = op
            elif op["name"] == target_ops[1]:
                if op.get("type") == "map":
                    map_op = op
                elif op.get("type") == "reduce":
                    reduce_op = op

        if map_op is None or reduce_op is None:
            raise ValueError(
                f"Could not find both a map and reduce operation in target_ops: {target_ops}"
            )

        # Verify the map comes before reduce
        map_idx = next(
            i for i, op in enumerate(operators) if op["name"] == map_op["name"]
        )
        reduce_idx = next(
            i for i, op in enumerate(operators) if op["name"] == reduce_op["name"]
        )

        if map_idx >= reduce_idx:
            raise ValueError(
                f"Map operation '{map_op['name']}' must come before reduce operation '{reduce_op['name']}'"
            )

        # Instantiate the directive
        rewrite, message_history, call_cost = self.llm_instantiate(
            map_op,
            reduce_op,
            agent_llm,
            message_history,
        )

        # Apply the rewrite to the operators
        new_ops_plan = self.apply(
            global_default_model, operators, map_op["name"], reduce_op["name"], rewrite
        )
        return new_ops_plan, message_history, call_cost
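Reviewer note (not part of the diff): the blocking_conditions contract spelled out in to_string_for_instantiate() amounts to comparing a pair only if any expression is truthy. A minimal sketch of that OR semantics, assuming eval-style evaluation against the two record dicts (the real resolve operator may differ in detail):

def should_compare(input1: dict, input2: dict, blocking_conditions: list) -> bool:
    # Conditions are OR'd together: any truthy expression admits the pair.
    return any(
        eval(cond, {}, {"input1": input1, "input2": input2})
        for cond in blocking_conditions
    )

a = {"company_name": "IBM Corp"}
b = {"company_name": "IBM Corporation"}
assert should_compare(
    a, b, ["input1['company_name'][:3].lower() == input2['company_name'][:3].lower()"]
)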
@@ -61,7 +61,7 @@ class MapOpConfig(BaseModel):
        ...,
        description="The keys of the output of the Map operator, to be referenced in the downstream operator's prompt. Can be a single key or a list of keys. Can be new keys or existing keys from the map operator we are rewriting.",
    )

    @classmethod
    def validate_prompt_contains_input_key(cls, value: str) -> str:
        """
@@ -563,7 +563,6 @@ class ChunkHeaderSummaryInstantiateSchema(BaseModel):
        return v


class SamplingConfig(BaseModel):
    """Configuration for optional sampling in document chunking."""
@@ -634,7 +633,7 @@ class DocumentChunkingInstantiateSchema(BaseModel):
        default=None,
        description="Optional sampling configuration. If provided, inserts a Sample operation between Gather and Map. Use by default UNLESS task requires processing ALL chunks (like comprehensive extraction of all instances).",
    )

    def validate_split_key_exists_in_input(self, input_file_path: str) -> None:
        """
        Validates that the split_key exists in the input JSON file items.
@@ -812,7 +811,6 @@ class DocumentChunkingTopKInstantiateSchema(BaseModel):
        description="Configuration for the topk operation to select relevant chunks",
    )

    def validate_split_key_exists_in_input(self, input_file_path: str) -> None:
        """
        Validates that the split_key exists in the input JSON file items.
@@ -1211,6 +1209,109 @@ class SearchReplaceEdit(BaseModel):
# )


class MapToMapResolveReduceInstantiateSchema(BaseModel):
    """
    Schema for inserting a Resolve operation between Map and Reduce.
    Transforms Map -> Reduce => Map -> Resolve -> Reduce pattern.
    The Resolve operation deduplicates/normalizes entities before aggregation.
    """

    resolve_name: str = Field(..., description="The name of the new Resolve operator")
    comparison_prompt: str = Field(
        ...,
        description="Jinja prompt template for comparing two items. Must use {{ input1.key }} and {{ input2.key }} to reference fields from both items being compared.",
    )
    resolution_prompt: str = Field(
        ...,
        description="Jinja prompt template for resolving matched items into a canonical form. Must use {% for input in inputs %} to iterate over matched items.",
    )
    blocking_conditions: List[str] = Field(
        ...,
        description="List of Python expressions that determine if two items should be compared. Each expression has access to 'input1' and 'input2' dicts. Example: \"input1['category'].lower() == input2['category'].lower()\"",
    )
    blocking_keys: List[str] = Field(
        ...,
        description="Keys to use for blocking/grouping items before comparison. Must include at least the reduce_key from the downstream Reduce operation, plus any additional context helpful for resolution.",
    )
    limit_comparisons: int = Field(
        default=15000,
        description="Maximum number of pairs to compare. Code-based blocked pairs are prioritized. Defaults to 15000 to avoid O(n^2) comparisons.",
        gt=0,
    )

    @field_validator("comparison_prompt")
    @classmethod
    def check_comparison_prompt(cls, v: str) -> str:
        if "input1" not in v or "input2" not in v:
            raise ValueError(
                "comparison_prompt must reference both 'input1' and 'input2' variables"
            )
        return v

    @field_validator("resolution_prompt")
    @classmethod
    def check_resolution_prompt(cls, v: str) -> str:
        if "inputs" not in v:
            raise ValueError(
                "resolution_prompt must reference 'inputs' variable for iterating over matched items"
            )
        return v

    @field_validator("blocking_conditions")
    @classmethod
    def check_blocking_conditions(cls, v: List[str]) -> List[str]:
        if not v:
            raise ValueError(
                "At least one blocking condition must be provided to avoid O(n^2) comparisons"
            )
        for condition in v:
            if "input1" not in condition or "input2" not in condition:
                raise ValueError(
                    f"Blocking condition must reference both 'input1' and 'input2': {condition}"
                )
        return v


class MapResolveToMapWithCategoriesInstantiateSchema(BaseModel):
    """
    Schema for replacing Map -> Resolve with a single Map that has predefined categories.
    The agent proposes categories based on analysis of the data/task, and the new Map
    forces outputs into one of these categories (or 'none of the above'), effectively
    doing entity resolution deterministically.
    """

    categories: List[str] = Field(
        ...,
        description="List of valid category values that the map output should be constrained to. Should include all distinct canonical values discovered from analyzing the data/task.",
    )
    category_key: str = Field(
        ...,
        description="The key in the output schema that will contain the category value",
    )
    new_prompt: str = Field(
        ...,
        description="The new prompt for the Map operation that includes the category list and instructs the LLM to output one of the predefined categories. Must reference {{ input.key }} for input fields.",
    )
    include_none_of_above: bool = Field(
        default=True,
        description="Whether to include 'None of the above' as a valid category option for items that don't fit any category",
    )

    @field_validator("categories")
    @classmethod
    def check_categories(cls, v: List[str]) -> List[str]:
        if len(v) < 2:
            raise ValueError("At least 2 categories must be provided")
        if len(v) != len(set(v)):
            raise ValueError("Categories must be unique")
        return v

    @field_validator("new_prompt")
    @classmethod
    def check_new_prompt(cls, v: str) -> str:
        return MapOpConfig.validate_prompt_contains_input_key(v)


class ArbitraryRewriteInstantiateSchema(BaseModel):
    """
    Schema for arbitrary pipeline rewrites using search/replace edits.
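Reviewer note (not part of the diff): the field validators above fail fast on malformed instantiations, which is what drives the retry loop in llm_instantiate(). A quick sketch of one failure mode, assuming the import path used elsewhere in this PR:

from pydantic import ValidationError

from docetl.reasoning_optimizer.instantiate_schemas import (
    MapToMapResolveReduceInstantiateSchema,
)

try:
    MapToMapResolveReduceInstantiateSchema(
        resolve_name="normalize",
        comparison_prompt="Same? {{ input1.name }} vs {{ input2.name }}",
        resolution_prompt="Pick one: {% for input in inputs %}{{ input.name }}{% endfor %}",
        blocking_conditions=[],  # rejected: at least one condition is required
        blocking_keys=["name"],
    )
except ValidationError as err:
    print(err)  # "At least one blocking condition must be provided..."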
@@ -6,7 +6,7 @@ Simple apply tests for directive testing - just ensure apply() doesn't crash.
# Simple apply tests - no pytest needed

from docetl.reasoning_optimizer.directives import (
    ChainingDirective,
    GleaningDirective,
    ReduceGleaningDirective,
    ReduceChainingDirective,
@@ -14,7 +14,6 @@ from docetl.reasoning_optimizer.directives import (
    OperatorFusionDirective,
    DocSummarizationDirective,
    IsolatingSubtasksDirective,
-   DocCompressionDirective,
    DeterministicDocCompressionDirective,
    DocumentChunkingDirective,
    DocumentChunkingTopKDirective,
@@ -22,8 +21,12 @@ from docetl.reasoning_optimizer.directives import (
    TakeHeadTailDirective,
    ClarifyInstructionsDirective,
    SwapWithCodeDirective,
-   HierarchicalReduceDirective
+   HierarchicalReduceDirective,
+   MapToMapResolveReduceDirective,
+   MapResolveToMapWithCategoriesDirective,
)
+# DocCompressionDirective is commented out in __init__.py, import directly
+from docetl.reasoning_optimizer.directives.doc_compression import DocCompressionDirective


def test_chaining_apply():
@@ -1012,6 +1015,237 @@ def test_cascade_filtering_apply():
    assert "high-quality research paper" in result[5]["prompt"]


def test_map_to_map_resolve_reduce_apply():
    """Test that map_to_map_resolve_reduce apply doesn't crash"""
    directive = MapToMapResolveReduceDirective()

    # Pipeline with map followed by reduce
    ops_list = [
        {
            "name": "extract_companies",
            "type": "map",
            "prompt": "Extract company name from: {{ input.article }}",
            "model": "gpt-4o-mini",
            "output": {"schema": {"company_name": "string"}},
        },
        {
            "name": "aggregate_by_sector",
            "type": "reduce",
            "reduce_key": "sector",
            "prompt": "List companies:\n{% for input in inputs %}\n- {{ input.company_name }}\n{% endfor %}",
            "model": "gpt-4o-mini",
            "output": {"schema": {"companies": "list[str]"}},
        },
    ]

    from docetl.reasoning_optimizer.instantiate_schemas import (
        MapToMapResolveReduceInstantiateSchema,
    )

    rewrite = MapToMapResolveReduceInstantiateSchema(
        resolve_name="normalize_company_names",
        comparison_prompt="""Are these two company names referring to the same company?
Company 1: {{ input1.company_name }}
Company 2: {{ input2.company_name }}
Consider variations like abbreviations (IBM vs International Business Machines).""",
        resolution_prompt="""Given these variations of a company name:
{% for input in inputs %}
- {{ input.company_name }}
{% endfor %}
Return the canonical/official company name.""",
        blocking_conditions=[
            "input1['company_name'][:3].lower() == input2['company_name'][:3].lower()",
            "input1['company_name'].split()[0].lower() == input2['company_name'].split()[0].lower()",
        ],
        blocking_keys=["sector", "company_name"],  # Must include reduce_key (sector)
        limit_comparisons=1000,
    )

    result = directive.apply(
        "gpt-4o-mini", ops_list, "extract_companies", "aggregate_by_sector", rewrite
    )
    assert isinstance(result, list)
    assert len(result) == 3  # map + resolve + reduce

    # Check the resolve operation was inserted correctly
    assert result[0]["name"] == "extract_companies"
    assert result[0]["type"] == "map"

    assert result[1]["name"] == "normalize_company_names"
    assert result[1]["type"] == "resolve"
    assert "comparison_prompt" in result[1]
    assert "resolution_prompt" in result[1]
    assert "blocking_conditions" in result[1]
    assert "blocking_keys" in result[1]
    assert result[1]["blocking_keys"] == ["sector", "company_name"]
    assert result[1]["limit_comparisons"] == 1000
    assert "input1" in result[1]["comparison_prompt"]
    assert "input2" in result[1]["comparison_prompt"]
    assert "inputs" in result[1]["resolution_prompt"]
    # Output schema should be derived from reduce_key
    assert result[1]["output"]["schema"] == {"sector": "string"}

    assert result[2]["name"] == "aggregate_by_sector"
    assert result[2]["type"] == "reduce"


def test_map_to_map_resolve_reduce_apply_with_multiple_reduce_keys():
    """Test map_to_map_resolve_reduce with multiple reduce_keys"""
    directive = MapToMapResolveReduceDirective()

    ops_list = [
        {
            "name": "extract_products",
            "type": "map",
            "prompt": "Extract product info: {{ input.description }}",
            "model": "gpt-4o-mini",
            "output": {"schema": {"product_name": "string", "category": "string"}},
        },
        {
            "name": "aggregate_products",
            "type": "reduce",
            "reduce_key": ["brand", "region"],  # Multiple reduce keys
            "prompt": "List products:\n{% for input in inputs %}{{ input.product_name }}{% endfor %}",
            "model": "gpt-4o-mini",
            "output": {"schema": {"products": "list[str]"}},
        },
    ]

    from docetl.reasoning_optimizer.instantiate_schemas import (
        MapToMapResolveReduceInstantiateSchema,
    )

    rewrite = MapToMapResolveReduceInstantiateSchema(
        resolve_name="normalize_products",
        comparison_prompt="Same product? {{ input1.product_name }} vs {{ input2.product_name }}",
        resolution_prompt="Normalize: {% for input in inputs %}{{ input.product_name }}{% endfor %}",
        blocking_conditions=[
            "input1['category'] == input2['category']",
        ],
        blocking_keys=["brand", "region", "product_name"],  # Must include reduce_keys
        limit_comparisons=500,
    )

    result = directive.apply(
        "gpt-4o-mini", ops_list, "extract_products", "aggregate_products", rewrite
    )
    assert result[1]["type"] == "resolve"
    assert "blocking_keys" in result[1]
    assert result[1]["blocking_keys"] == ["brand", "region", "product_name"]
    # Output schema should include all reduce_keys
    assert result[1]["output"]["schema"] == {"brand": "string", "region": "string"}


def test_map_resolve_to_map_with_categories_apply():
    """Test that map_resolve_to_map_with_categories apply doesn't crash"""
    directive = MapResolveToMapWithCategoriesDirective()

    # Pipeline with map followed by resolve
    ops_list = [
        {
            "name": "extract_sentiment",
            "type": "map",
            "prompt": "What is the sentiment? {{ input.review }}",
            "model": "gpt-4o-mini",
            "output": {"schema": {"sentiment": "string"}},
        },
        {
            "name": "normalize_sentiment",
            "type": "resolve",
            "comparison_prompt": "Same sentiment? {{ input1.sentiment }} vs {{ input2.sentiment }}",
            "resolution_prompt": "Normalize: {% for input in inputs %}{{ input.sentiment }}{% endfor %}",
            "model": "gpt-4o-mini",
            "output": {"schema": {"sentiment": "string"}},
        },
    ]

    from docetl.reasoning_optimizer.instantiate_schemas import (
        MapResolveToMapWithCategoriesInstantiateSchema,
    )

    rewrite = MapResolveToMapWithCategoriesInstantiateSchema(
        categories=["Positive", "Negative", "Neutral", "Mixed"],
        category_key="sentiment",
        new_prompt="""Classify the sentiment of this review into one of these categories:
- Positive: Clearly positive sentiment
- Negative: Clearly negative sentiment
- Neutral: No strong sentiment
- Mixed: Both positive and negative
- None of the above

Review: {{ input.review }}

Return exactly one category.""",
        include_none_of_above=True,
    )

    result = directive.apply(
        "gpt-4o-mini", ops_list, "extract_sentiment", "normalize_sentiment", rewrite
    )
    assert isinstance(result, list)
    assert len(result) == 1  # map only, resolve removed

    # Check the map operation was modified correctly
    assert result[0]["name"] == "extract_sentiment"
    assert result[0]["type"] == "map"
    assert "Positive" in result[0]["prompt"]
    assert "Negative" in result[0]["prompt"]
    assert "None of the above" in result[0]["prompt"]

    # Check validation was added
    assert "validate" in result[0]
    assert len(result[0]["validate"]) == 1
    assert "Positive" in result[0]["validate"][0]
    assert "None of the above" in result[0]["validate"][0]


def test_map_resolve_to_map_with_categories_no_none_of_above():
    """Test map_resolve_to_map_with_categories without 'None of the above' option"""
    directive = MapResolveToMapWithCategoriesDirective()

    ops_list = [
        {
            "name": "classify_type",
            "type": "map",
            "prompt": "What type is this? {{ input.text }}",
            "model": "gpt-4o-mini",
            "output": {"schema": {"item_type": "string"}},
        },
        {
            "name": "normalize_type",
            "type": "resolve",
            "comparison_prompt": "Same type? {{ input1.item_type }} vs {{ input2.item_type }}",
            "resolution_prompt": "Normalize: {% for input in inputs %}{{ input.item_type }}{% endfor %}",
            "model": "gpt-4o-mini",
            "output": {"schema": {"item_type": "string"}},
        },
    ]

    from docetl.reasoning_optimizer.instantiate_schemas import (
        MapResolveToMapWithCategoriesInstantiateSchema,
    )

    rewrite = MapResolveToMapWithCategoriesInstantiateSchema(
        categories=["TypeA", "TypeB", "TypeC"],
        category_key="item_type",
        new_prompt="""Classify into: TypeA, TypeB, or TypeC
Text: {{ input.text }}""",
        include_none_of_above=False,
    )

    result = directive.apply(
        "gpt-4o-mini", ops_list, "classify_type", "normalize_type", rewrite
    )
    assert len(result) == 1

    # Check validation does NOT include 'None of the above'
    assert "validate" in result[0]
    assert "None of the above" not in result[0]["validate"][0]
    assert "TypeA" in result[0]["validate"][0]
    assert "TypeB" in result[0]["validate"][0]
    assert "TypeC" in result[0]["validate"][0]


if __name__ == "__main__":
    # Run all tests
    test_chaining_apply()
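Reviewer note (not part of the diff): the {"brand": "string", "region": "string"} expectation in the multi-key test follows directly from how apply() derives the resolve op's output schema from reduce_key alone:

# Minimal reproduction of the schema derivation in apply().
reduce_key = ["brand", "region"]
if isinstance(reduce_key, str):
    reduce_key = [reduce_key]
output_schema = {key: "string" for key in reduce_key}
assert output_schema == {"brand": "string", "region": "string"}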
@@ -1078,4 +1312,16 @@ if __name__ == "__main__":
    test_cascade_filtering_apply()
    print("✅ Cascade filtering apply test passed")

    test_map_to_map_resolve_reduce_apply()
    print("✅ Map to map resolve reduce apply test passed")

    test_map_to_map_resolve_reduce_apply_with_multiple_reduce_keys()
    print("✅ Map to map resolve reduce with multiple reduce keys apply test passed")

    test_map_resolve_to_map_with_categories_apply()
    print("✅ Map resolve to map with categories apply test passed")

    test_map_resolve_to_map_with_categories_no_none_of_above()
    print("✅ Map resolve to map with categories (no none of above) apply test passed")

    print("\n🎉 All directive apply tests passed!")
@@ -10,14 +10,13 @@ from typing import Dict, List
from datetime import datetime

from docetl.reasoning_optimizer.directives import (
-   ChainingDirective,
    ChainingDirective,
    GleaningDirective,
    ReduceGleaningDirective,
    ReduceChainingDirective,
    ChangeModelDirective,
    DocSummarizationDirective,
    IsolatingSubtasksDirective,
    DocCompressionDirective,
    DeterministicDocCompressionDirective,
    OperatorFusionDirective,
    DocumentChunkingDirective,
@@ -28,8 +27,12 @@ from docetl.reasoning_optimizer.directives import (
    SwapWithCodeDirective,
    HierarchicalReduceDirective,
    CascadeFilteringDirective,
-   TestResult
+   MapToMapResolveReduceDirective,
+   MapResolveToMapWithCategoriesDirective,
+   TestResult,
)
+# DocCompressionDirective is commented out in __init__.py, import directly
+from docetl.reasoning_optimizer.directives.doc_compression import DocCompressionDirective

def run_all_directive_tests(agent_llm: str = "gpt-4.1") -> Dict[str, List[TestResult]]:
    """
@@ -68,6 +71,8 @@ def run_all_directive_tests(agent_llm: str = "gpt-4.1") -> Dict[str, List[TestResult]]:
        SwapWithCodeDirective(),
        HierarchicalReduceDirective(),
        CascadeFilteringDirective(),
+       MapToMapResolveReduceDirective(),
+       MapResolveToMapWithCategoriesDirective(),
    ]

    all_results = {}
@@ -176,6 +181,8 @@ def run_specific_directive_test(directive_name: str, agent_llm: str = "gpt-4o-mini"):
        "clarify_instructions": ClarifyInstructionsDirective(),
        "swap_with_code": SwapWithCodeDirective(),
        "cascade_filtering": CascadeFilteringDirective(),
+       "map_to_map_resolve_reduce": MapToMapResolveReduceDirective(),
+       "map_resolve_to_map_with_categories": MapResolveToMapWithCategoriesDirective(),
    }

    if directive_name.lower() not in directive_map:
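Reviewer note (not part of the diff): with the directive_map entries above, the new rewrites can be exercised individually through the existing runner (signature as shown in this diff):

# run_specific_directive_test is defined in this same module.
run_specific_directive_test("map_to_map_resolve_reduce", agent_llm="gpt-4o-mini")
run_specific_directive_test("map_resolve_to_map_with_categories", agent_llm="gpt-4o-mini")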