Compare commits

...

1 Commit

Author SHA1 Message Date
Shreya Shankar ae06d5876e optimizer: add directives for resolve operator 2025-12-27 23:15:30 -06:00
7 changed files with 1084 additions and 12 deletions

View File

@ -30,6 +30,8 @@ from .map_reduce_fusion import MapReduceFusionDirective
from .hierarchical_reduce import HierarchicalReduceDirective
from .cascade_filtering import CascadeFilteringDirective
from .arbitrary_rewrite import ArbitraryRewriteDirective
from .map_to_map_resolve_reduce import MapToMapResolveReduceDirective
from .map_resolve_to_map_with_categories import MapResolveToMapWithCategoriesDirective
# Registry of all available directives
ALL_DIRECTIVES = [
@ -53,6 +55,8 @@ ALL_DIRECTIVES = [
HierarchicalReduceDirective(),
CascadeFilteringDirective(),
ArbitraryRewriteDirective(),
MapToMapResolveReduceDirective(),
MapResolveToMapWithCategoriesDirective(),
]
ALL_COST_DIRECTIVES = [
@ -179,8 +183,10 @@ __all__ = [
"HierarchicalReduceDirective",
"CascadeFilteringDirective",
"ArbitraryRewriteDirective",
"MapToMapResolveReduceDirective",
"MapResolveToMapWithCategoriesDirective",
"ALL_DIRECTIVES",
"DIRECTIVE_REGISTRY",
"DIRECTIVE_REGISTRY",
"get_all_directive_strings",
"instantiate_directive"
]

View File

@ -123,7 +123,8 @@ class Directive(BaseModel, ABC):
try:
# 1. Execute the directive
# instantiate returns (new_ops_plan, message_history, call_cost)
instantiate_result = self.instantiate(
operators=(
[test_case.input_config]
if isinstance(test_case.input_config, dict)
@ -134,6 +135,11 @@ class Directive(BaseModel, ABC):
input_file_path=temp_file_path,
pipeline_code=fake_pipeline,
)
# Handle both 2-tuple and 3-tuple returns
if isinstance(instantiate_result, tuple):
actual_output = instantiate_result[0]
else:
actual_output = instantiate_result
# 2. Use LLM judge to evaluate
judge_result = self._llm_judge_test(
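
The change above lets the directive test runner accept both return shapes of instantiate(): older directives return (new_ops_plan, message_history), while the two new resolve directives also return a call_cost. A minimal standalone sketch of that unpacking pattern (the helper name is illustrative, not part of the commit):

def plan_from_instantiate_result(result):
    # Both the 2-tuple (plan, history) and the 3-tuple (plan, history, cost)
    # carry the rewritten ops plan in position 0; anything else is treated
    # as the plan itself.
    if isinstance(result, tuple):
        return result[0]
    return result

assert plan_from_instantiate_result(([{"name": "op"}], [])) == [{"name": "op"}]
assert plan_from_instantiate_result(([{"name": "op"}], [], 0.01)) == [{"name": "op"}]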

View File

@ -0,0 +1,366 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
from litellm import completion
from pydantic import BaseModel, Field
from docetl.reasoning_optimizer.instantiate_schemas import (
MapResolveToMapWithCategoriesInstantiateSchema,
)
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
class MapResolveToMapWithCategoriesDirective(Directive):
name: str = Field(
default="map_resolve_to_map_with_categories",
description="The name of the directive",
)
formal_description: str = Field(default="Map -> Resolve => Map (with categories)")
nl_description: str = Field(
default="Replace a Map -> Resolve pattern with a single Map operation that has predefined categories. The agent analyzes the data and task to propose a set of canonical categories, and the new Map forces outputs into one of these categories (or 'none of the above'). This effectively performs entity resolution deterministically by standardizing outputs upfront, avoiding the need for pairwise comparisons."
)
when_to_use: str = Field(
default="When a Map operation produces outputs that are then resolved/deduplicated, and the set of valid output categories can be enumerated upfront. This is more efficient than Resolve when the category space is small and well-defined (e.g., standardizing company types, product categories, sentiment labels). The target must be a Map operation followed by a Resolve operation."
)
instantiate_schema_type: Type[BaseModel] = (
MapResolveToMapWithCategoriesInstantiateSchema
)
example: str = Field(
default=(
"Original Pipeline:\n"
"- name: extract_sentiment\n"
" type: map\n"
" prompt: |\n"
" What is the sentiment of this review?\n"
" {{ input.review }}\n"
" output:\n"
" schema:\n"
" sentiment: string\n"
"\n"
"- name: normalize_sentiment\n"
" type: resolve\n"
" comparison_prompt: |\n"
" Are these sentiments equivalent?\n"
" Sentiment 1: {{ input1.sentiment }}\n"
" Sentiment 2: {{ input2.sentiment }}\n"
" resolution_prompt: |\n"
" Normalize these sentiments:\n"
" {% for input in inputs %}\n"
" - {{ input.sentiment }}\n"
" {% endfor %}\n"
" output:\n"
" schema:\n"
" sentiment: string\n"
"\n"
"Example InstantiateSchema:\n"
"MapResolveToMapWithCategoriesInstantiateSchema(\n"
" categories=['Positive', 'Negative', 'Neutral', 'Mixed'],\n"
" category_key='sentiment',\n"
" new_prompt='''Analyze the sentiment of this review and classify it into one of the following categories:\n"
"\n"
"Categories:\n"
"- Positive: Clearly positive sentiment, satisfaction, praise\n"
"- Negative: Clearly negative sentiment, complaints, criticism\n"
"- Neutral: No strong sentiment, factual statements\n"
"- Mixed: Contains both positive and negative elements\n"
"- None of the above: If the review doesn't fit any category\n"
"\n"
"Review: {{ input.review }}\n"
"\n"
"Return exactly one of: Positive, Negative, Neutral, Mixed, or None of the above.''',\n"
" include_none_of_above=True,\n"
")"
),
)
test_cases: List[DirectiveTestCase] = Field(
default_factory=lambda: [
DirectiveTestCase(
name="sentiment_categorization",
description="Should replace map+resolve with categorized map for sentiment",
input_config=[
{
"name": "extract_sentiment",
"type": "map",
"prompt": "What is the sentiment? {{ input.text }}",
"output": {"schema": {"sentiment": "string"}},
},
{
"name": "normalize_sentiment",
"type": "resolve",
"comparison_prompt": "Same sentiment? {{ input1.sentiment }} vs {{ input2.sentiment }}",
"resolution_prompt": "Normalize: {% for input in inputs %}{{ input.sentiment }}{% endfor %}",
"output": {"schema": {"sentiment": "string"}},
},
],
target_ops=["extract_sentiment", "normalize_sentiment"],
expected_behavior="Should create a single map with predefined sentiment categories",
should_pass=True,
),
]
)
def __eq__(self, other):
return isinstance(other, MapResolveToMapWithCategoriesDirective)
def __hash__(self):
return hash("MapResolveToMapWithCategoriesDirective")
def to_string_for_instantiate(
self, map_op: Dict, resolve_op: Dict, sample_data: List[Dict] = None
) -> str:
"""
Generate a prompt for an agent to instantiate this directive.
Args:
map_op (Dict): The map operation configuration.
resolve_op (Dict): The resolve operation configuration.
sample_data (List[Dict], optional): Sample data to help identify categories.
Returns:
str: The agent prompt for instantiating the directive.
"""
sample_str = ""
if sample_data:
sample_str = f"\n\nSample Input Data (first 5 items):\n{json.dumps(sample_data[:5], indent=2)}"
return (
f"You are an expert at optimizing data processing pipelines by replacing entity resolution with categorical constraints.\n\n"
f"Map Operation:\n"
f"{str(map_op)}\n\n"
f"Resolve Operation:\n"
f"{str(resolve_op)}\n\n"
f"Directive: {self.name}\n"
f"Your task is to replace the Map -> Resolve pattern with a single Map that uses predefined categories.\n\n"
f"Key Requirements:\n"
f"1. Analyze the map's output field and the resolve operation to understand what values are being normalized:\n"
f" - Look at the comparison_prompt to understand what variations are being matched\n"
f" - Look at the resolution_prompt to understand the canonical form\n\n"
f"2. Propose a set of categories that cover all expected outputs:\n"
f" - Categories should be mutually exclusive\n"
f" - Categories should be exhaustive (cover all realistic cases)\n"
f" - Consider including 'None of the above' for edge cases\n\n"
f"3. Optionally provide descriptions for each category to help the LLM classify correctly\n\n"
f"4. Create a new_prompt that:\n"
f" - Lists all valid categories with their descriptions\n"
f" - Instructs the LLM to output exactly one category\n"
f" - References the input using {{{{ input.key }}}} syntax\n"
f" - Includes any context from the original map prompt\n\n"
f"5. Identify the category_key (the output field that will contain the category)\n\n"
f"Benefits of this approach:\n"
f"- Eliminates O(n^2) pairwise comparisons from Resolve\n"
f"- Produces consistent, standardized outputs\n"
f"- Reduces cost by removing the Resolve operation entirely\n"
f"{sample_str}\n\n"
f"Example:\n"
f"{self.example}\n\n"
f"Please analyze the operations and propose appropriate categories. Output the MapResolveToMapWithCategoriesInstantiateSchema."
)
def llm_instantiate(
self,
map_op: Dict,
resolve_op: Dict,
agent_llm: str,
message_history: list = [],
sample_data: List[Dict] = None,
):
"""
Use LLM to instantiate this directive.
Args:
map_op (Dict): The map operation configuration.
resolve_op (Dict): The resolve operation configuration.
agent_llm (str): The LLM model to use.
message_history (List, optional): Conversation history for context.
sample_data (List[Dict], optional): Sample data to help identify categories.
Returns:
MapResolveToMapWithCategoriesInstantiateSchema: The structured output from the LLM.
"""
message_history.extend(
[
{
"role": "system",
"content": "You are a helpful AI assistant for document processing pipelines specializing in categorical classification.",
},
{
"role": "user",
"content": self.to_string_for_instantiate(
map_op, resolve_op, sample_data
),
},
]
)
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=MapResolveToMapWithCategoriesInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = MapResolveToMapWithCategoriesInstantiateSchema(**parsed_res)
message_history.append(
{"role": "assistant", "content": resp.choices[0].message.content}
)
return schema, message_history, call_cost
except Exception as err:
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
)
def apply(
self,
global_default_model,
ops_list: List[Dict],
map_op_name: str,
resolve_op_name: str,
rewrite: MapResolveToMapWithCategoriesInstantiateSchema,
) -> List[Dict]:
"""
Apply the directive to the pipeline config.
"""
new_ops_list = deepcopy(ops_list)
# Find position of the map and resolve ops
map_pos = None
resolve_pos = None
map_op = None
for i, op in enumerate(new_ops_list):
if op["name"] == map_op_name:
map_pos = i
map_op = op
elif op["name"] == resolve_op_name:
resolve_pos = i
if map_pos is None or resolve_pos is None:
raise ValueError(
f"Could not find map '{map_op_name}' and resolve '{resolve_op_name}' operations"
)
# Determine the model to use
default_model = map_op.get("model", global_default_model)
# Build the list of valid values for validation
valid_values = list(rewrite.categories)
if rewrite.include_none_of_above:
valid_values.append("None of the above")
# Modify the map operation with the new prompt and add validation
new_ops_list[map_pos]["prompt"] = rewrite.new_prompt
new_ops_list[map_pos]["model"] = default_model
# Add validation to ensure output is one of the categories
if "validate" not in new_ops_list[map_pos]:
new_ops_list[map_pos]["validate"] = []
# Add validation rule for the category key
validation_rule = f"output['{rewrite.category_key}'] in {valid_values}"
new_ops_list[map_pos]["validate"].append(validation_rule)
# Update the output schema to reflect the category key
if "output" not in new_ops_list[map_pos]:
new_ops_list[map_pos]["output"] = {"schema": {}}
new_ops_list[map_pos]["output"]["schema"][rewrite.category_key] = "string"
# Remove the resolve operation
new_ops_list.pop(resolve_pos)
return new_ops_list
def instantiate(
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
optimize_goal="acc",
global_default_model: str = None,
dataset: str = None,
**kwargs,
):
"""
Instantiate the directive for a list of operators.
"""
# Assert that there are exactly two target ops (map and resolve)
assert (
len(target_ops) == 2
), "There must be exactly two target ops (map and resolve) to instantiate this directive"
# Find the map and resolve operations
map_op = None
resolve_op = None
for op in operators:
if op["name"] == target_ops[0]:
if op.get("type") == "map":
map_op = op
elif op.get("type") == "resolve":
resolve_op = op
elif op["name"] == target_ops[1]:
if op.get("type") == "map":
map_op = op
elif op.get("type") == "resolve":
resolve_op = op
if map_op is None or resolve_op is None:
raise ValueError(
f"Could not find both a map and resolve operation in target_ops: {target_ops}"
)
# Verify the map comes before resolve
map_idx = next(
i for i, op in enumerate(operators) if op["name"] == map_op["name"]
)
resolve_idx = next(
i for i, op in enumerate(operators) if op["name"] == resolve_op["name"]
)
if map_idx >= resolve_idx:
raise ValueError(
f"Map operation '{map_op['name']}' must come before resolve operation '{resolve_op['name']}'"
)
# Load sample data if available
sample_data = None
if dataset:
try:
with open(dataset, "r") as f:
sample_data = json.load(f)
except Exception:
pass # Ignore if we can't load sample data
# Instantiate the directive
rewrite, message_history, call_cost = self.llm_instantiate(
map_op,
resolve_op,
agent_llm,
message_history,
sample_data,
)
# Apply the rewrite to the operators
new_ops_plan = self.apply(
global_default_model, operators, map_op["name"], resolve_op["name"], rewrite
)
return new_ops_plan, message_history, call_cost
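
The apply() above enforces the categorical output through a generated validate expression of the form output['<category_key>'] in [<categories>, 'None of the above']. A standalone sketch of how such an expression is expected to behave when evaluated against an operator's output; the eval-based check is an assumption about DocETL's validate semantics, not taken from this commit:

categories = ["Positive", "Negative", "Neutral", "Mixed"]
valid_values = categories + ["None of the above"]
# Same f-string shape as the validation_rule built in apply()
validation_rule = f"output['sentiment'] in {valid_values}"

def passes_validation(output: dict) -> bool:
    # Evaluate the rule string with the operator's output bound to `output`.
    return bool(eval(validation_rule, {}, {"output": output}))

assert passes_validation({"sentiment": "Positive"})
assert passes_validation({"sentiment": "None of the above"})
assert not passes_validation({"sentiment": "somewhat positive"})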

View File

@ -0,0 +1,340 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
from litellm import completion
from pydantic import BaseModel, Field
from docetl.reasoning_optimizer.instantiate_schemas import (
MapToMapResolveReduceInstantiateSchema,
)
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
class MapToMapResolveReduceDirective(Directive):
name: str = Field(
default="map_to_map_resolve_reduce", description="The name of the directive"
)
formal_description: str = Field(default="Map -> Reduce => Map -> Resolve -> Reduce")
nl_description: str = Field(
default="Insert a Resolve operation between Map and Reduce to deduplicate or normalize entities before aggregation. The Resolve operation uses code-powered blocking conditions to efficiently identify which pairs to compare, avoiding O(n^2) comparisons. This is useful when the Map output contains duplicate or near-duplicate entities that should be merged before the Reduce step."
)
when_to_use: str = Field(
default="When a Map operation produces outputs that may contain duplicates, variations, or near-duplicates (e.g., different spellings of names, similar categories), and these should be normalized before the Reduce aggregation step. The target must be a Map operation followed by a Reduce operation."
)
instantiate_schema_type: Type[BaseModel] = MapToMapResolveReduceInstantiateSchema
example: str = Field(
default=(
"Original Pipeline:\n"
"- name: extract_companies\n"
" type: map\n"
" prompt: |\n"
" Extract company names from this news article:\n"
" {{ input.article }}\n"
" output:\n"
" schema:\n"
" company_name: string\n"
"\n"
"- name: aggregate_companies\n"
" type: reduce\n"
" reduce_key: sector\n"
" prompt: |\n"
" List all unique companies in this sector:\n"
" {% for input in inputs %}\n"
" - {{ input.company_name }}\n"
" {% endfor %}\n"
" output:\n"
" schema:\n"
" companies: list[str]\n"
"\n"
"Example InstantiateSchema:\n"
"MapToMapResolveReduceInstantiateSchema(\n"
" resolve_name='normalize_company_names',\n"
" comparison_prompt='''Are these two company names referring to the same company?\n"
"Company 1: {{ input1.company_name }}\n"
"Company 2: {{ input2.company_name }}\n"
"Consider variations like abbreviations (IBM vs International Business Machines), \n"
"different legal suffixes (Inc, Corp, LLC), and common misspellings.''',\n"
" resolution_prompt='''Given these variations of a company name:\n"
"{% for input in inputs %}\n"
"- {{ input.company_name }}\n"
"{% endfor %}\n"
"Return the canonical/official company name.''',\n"
" blocking_conditions=[\n"
" \"input1['company_name'][:3].lower() == input2['company_name'][:3].lower()\",\n"
" \"input1['company_name'].split()[0].lower() == input2['company_name'].split()[0].lower()\",\n"
" ],\n"
" blocking_keys=['company_name'],\n"
" limit_comparisons=1000,\n"
" output_schema={'company_name': 'string'},\n"
")"
),
)
test_cases: List[DirectiveTestCase] = Field(
default_factory=lambda: [
DirectiveTestCase(
name="company_name_normalization",
description="Should insert resolve between map and reduce for company names",
input_config=[
{
"name": "extract_companies",
"type": "map",
"prompt": "Extract company name from: {{ input.text }}",
"output": {"schema": {"company_name": "string"}},
},
{
"name": "aggregate_by_sector",
"type": "reduce",
"reduce_key": "sector",
"prompt": "List companies:\n{% for input in inputs %}\n- {{ input.company_name }}\n{% endfor %}",
"output": {"schema": {"companies": "list[str]"}},
},
],
target_ops=["extract_companies", "aggregate_by_sector"],
expected_behavior="Should create a resolve operation between the map and reduce with appropriate blocking conditions",
should_pass=True,
),
]
)
def __eq__(self, other):
return isinstance(other, MapToMapResolveReduceDirective)
def __hash__(self):
return hash("MapToMapResolveReduceDirective")
def to_string_for_instantiate(self, map_op: Dict, reduce_op: Dict) -> str:
"""
Generate a prompt for an agent to instantiate this directive.
Args:
map_op (Dict): The map operation configuration.
reduce_op (Dict): The reduce operation configuration.
Returns:
str: The agent prompt for instantiating the directive.
"""
return (
f"You are an expert at optimizing data processing pipelines by inserting entity resolution steps.\n\n"
f"Map Operation:\n"
f"{str(map_op)}\n\n"
f"Reduce Operation:\n"
f"{str(reduce_op)}\n\n"
f"Directive: {self.name}\n"
f"Your task is to insert a Resolve operation between the Map and Reduce to deduplicate/normalize entities.\n\n"
f"Key Requirements:\n"
f"1. Create a comparison_prompt that determines if two items from the Map output are duplicates/variations:\n"
f" - Must reference {{ input1.key }} and {{ input2.key }} for comparing fields\n"
f" - Should handle common variations (abbreviations, misspellings, formatting differences)\n\n"
f"2. Create a resolution_prompt that merges matched items into a canonical form:\n"
f" - Must use {{% for input in inputs %}} to iterate over matched items\n"
f" - Should produce the most authoritative/complete version\n\n"
f"3. Create blocking_conditions to avoid O(n^2) comparisons:\n"
f" - These are Python expressions with access to 'input1' and 'input2' dicts\n"
f" - They should filter pairs to only those likely to match\n"
f" - Examples:\n"
f" * \"input1['name'][:3].lower() == input2['name'][:3].lower()\" (first 3 chars match)\n"
f" * \"input1['name'].split()[0].lower() == input2['name'].split()[0].lower()\" (first word matches)\n"
f" * \"abs(len(input1['name']) - len(input2['name'])) < 10\" (similar length)\n"
f" - Multiple conditions are OR'd together\n\n"
f"4. Set limit_comparisons to cap the number of pairs (recommended: 500-2000)\n\n"
f"5. The output_schema should match what the Reduce operation expects from each input\n\n"
f"Example:\n"
f"{self.example}\n\n"
f"Please output the MapToMapResolveReduceInstantiateSchema."
)
def llm_instantiate(
self, map_op: Dict, reduce_op: Dict, agent_llm: str, message_history: list = []
):
"""
Use LLM to instantiate this directive.
Args:
map_op (Dict): The map operation configuration.
reduce_op (Dict): The reduce operation configuration.
agent_llm (str): The LLM model to use.
message_history (List, optional): Conversation history for context.
Returns:
MapToMapResolveReduceInstantiateSchema: The structured output from the LLM.
"""
message_history.extend(
[
{
"role": "system",
"content": "You are a helpful AI assistant for document processing pipelines specializing in entity resolution.",
},
{
"role": "user",
"content": self.to_string_for_instantiate(map_op, reduce_op),
},
]
)
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=MapToMapResolveReduceInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = MapToMapResolveReduceInstantiateSchema(**parsed_res)
message_history.append(
{"role": "assistant", "content": resp.choices[0].message.content}
)
return schema, message_history, call_cost
except Exception as err:
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
)
def apply(
self,
global_default_model,
ops_list: List[Dict],
map_op_name: str,
reduce_op_name: str,
rewrite: MapToMapResolveReduceInstantiateSchema,
) -> List[Dict]:
"""
Apply the directive to the pipeline config.
"""
new_ops_list = deepcopy(ops_list)
# Find position of the map and reduce ops
map_pos = None
reduce_pos = None
map_op = None
for i, op in enumerate(ops_list):
if op["name"] == map_op_name:
map_pos = i
map_op = op
elif op["name"] == reduce_op_name:
reduce_pos = i
if map_pos is None or reduce_pos is None:
raise ValueError(
f"Could not find map '{map_op_name}' and reduce '{reduce_op_name}' operations"
)
# Determine the model to use
default_model = map_op.get("model", global_default_model)
# Find the reduce operation to get the reduce_key
reduce_op = None
for op in ops_list:
if op["name"] == reduce_op_name:
reduce_op = op
break
# Derive output schema from reduce_key - that's what's being grouped/resolved
reduce_key = reduce_op.get("reduce_key", []) if reduce_op else []
if isinstance(reduce_key, str):
reduce_key = [reduce_key]
# Build output schema from reduce_key fields
output_schema = {key: "string" for key in reduce_key}
# Create the resolve operation
resolve_op = {
"name": rewrite.resolve_name,
"type": "resolve",
"comparison_prompt": rewrite.comparison_prompt,
"resolution_prompt": rewrite.resolution_prompt,
"blocking_conditions": rewrite.blocking_conditions,
"blocking_keys": rewrite.blocking_keys,
"limit_comparisons": rewrite.limit_comparisons,
"model": default_model,
"output": {"schema": output_schema},
}
# Insert resolve operation after the map operation
new_ops_list.insert(map_pos + 1, resolve_op)
return new_ops_list
def instantiate(
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
optimize_goal="acc",
global_default_model: str = None,
**kwargs,
):
"""
Instantiate the directive for a list of operators.
"""
# Assert that there are exactly two target ops (map and reduce)
assert (
len(target_ops) == 2
), "There must be exactly two target ops (map and reduce) to instantiate this directive"
# Find the map and reduce operations
map_op = None
reduce_op = None
for op in operators:
if op["name"] == target_ops[0]:
if op.get("type") == "map":
map_op = op
elif op.get("type") == "reduce":
reduce_op = op
elif op["name"] == target_ops[1]:
if op.get("type") == "map":
map_op = op
elif op.get("type") == "reduce":
reduce_op = op
if map_op is None or reduce_op is None:
raise ValueError(
f"Could not find both a map and reduce operation in target_ops: {target_ops}"
)
# Verify the map comes before reduce
map_idx = next(
i for i, op in enumerate(operators) if op["name"] == map_op["name"]
)
reduce_idx = next(
i for i, op in enumerate(operators) if op["name"] == reduce_op["name"]
)
if map_idx >= reduce_idx:
raise ValueError(
f"Map operation '{map_op['name']}' must come before reduce operation '{reduce_op['name']}'"
)
# Instantiate the directive
rewrite, message_history, call_cost = self.llm_instantiate(
map_op,
reduce_op,
agent_llm,
message_history,
)
# Apply the rewrite to the operators
new_ops_plan = self.apply(
global_default_model, operators, map_op["name"], reduce_op["name"], rewrite
)
return new_ops_plan, message_history, call_cost
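
The Resolve operation inserted by apply() above relies on blocking_conditions, Python expressions over input1 and input2 that are OR'd together, to prune candidate pairs before any LLM comparison. A standalone sketch of that pruning on toy data; the eval-based pair filter illustrates the idea and is not DocETL's actual implementation:

from itertools import combinations

items = [
    {"company_name": "Acme Corporation"},
    {"company_name": "Acme Corp"},
    {"company_name": "acme inc"},
    {"company_name": "Globex LLC"},
]
blocking_conditions = [
    "input1['company_name'][:3].lower() == input2['company_name'][:3].lower()",
    "input1['company_name'].split()[0].lower() == input2['company_name'].split()[0].lower()",
]

def should_compare(a: dict, b: dict) -> bool:
    # Conditions are OR'd: any single match keeps the pair for comparison.
    env = {"input1": a, "input2": b}
    return any(eval(cond, {}, env) for cond in blocking_conditions)

all_pairs = list(combinations(items, 2))
kept = [pair for pair in all_pairs if should_compare(*pair)]
print(f"{len(kept)} of {len(all_pairs)} pairs survive blocking")  # 3 of 6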

View File

@ -61,7 +61,7 @@ class MapOpConfig(BaseModel):
...,
description="The keys of the output of the Map operator, to be referenced in the downstream operator's prompt. Can be a single key or a list of keys. Can be new keys or existing keys from the map operator we are rewriting.",
)
@classmethod
def validate_prompt_contains_input_key(cls, value: str) -> str:
"""
@ -563,7 +563,6 @@ class ChunkHeaderSummaryInstantiateSchema(BaseModel):
return v
class SamplingConfig(BaseModel):
"""Configuration for optional sampling in document chunking."""
@ -634,7 +633,7 @@ class DocumentChunkingInstantiateSchema(BaseModel):
default=None,
description="Optional sampling configuration. If provided, inserts a Sample operation between Gather and Map. Use by default UNLESS task requires processing ALL chunks (like comprehensive extraction of all instances).",
)
def validate_split_key_exists_in_input(self, input_file_path: str) -> None:
"""
Validates that the split_key exists in the input JSON file items.
@ -812,7 +811,6 @@ class DocumentChunkingTopKInstantiateSchema(BaseModel):
description="Configuration for the topk operation to select relevant chunks",
)
def validate_split_key_exists_in_input(self, input_file_path: str) -> None:
"""
Validates that the split_key exists in the input JSON file items.
@ -1211,6 +1209,109 @@ class SearchReplaceEdit(BaseModel):
# )
class MapToMapResolveReduceInstantiateSchema(BaseModel):
"""
Schema for inserting a Resolve operation between Map and Reduce.
Transforms Map -> Reduce => Map -> Resolve -> Reduce pattern.
The Resolve operation deduplicates/normalizes entities before aggregation.
"""
resolve_name: str = Field(..., description="The name of the new Resolve operator")
comparison_prompt: str = Field(
...,
description="Jinja prompt template for comparing two items. Must use {{ input1.key }} and {{ input2.key }} to reference fields from both items being compared.",
)
resolution_prompt: str = Field(
...,
description="Jinja prompt template for resolving matched items into a canonical form. Must use {% for input in inputs %} to iterate over matched items.",
)
blocking_conditions: List[str] = Field(
...,
description="List of Python expressions that determine if two items should be compared. Each expression has access to 'input1' and 'input2' dicts. Example: \"input1['category'].lower() == input2['category'].lower()\"",
)
blocking_keys: List[str] = Field(
...,
description="Keys to use for blocking/grouping items before comparison. Must include at least the reduce_key from the downstream Reduce operation, plus any additional context helpful for resolution.",
)
limit_comparisons: int = Field(
default=15000,
description="Maximum number of pairs to compare. Code-based blocked pairs are prioritized. Defaults to 15000 to avoid O(n^2) comparisons.",
gt=0,
)
@field_validator("comparison_prompt")
@classmethod
def check_comparison_prompt(cls, v: str) -> str:
if "input1" not in v or "input2" not in v:
raise ValueError(
"comparison_prompt must reference both 'input1' and 'input2' variables"
)
return v
@field_validator("resolution_prompt")
@classmethod
def check_resolution_prompt(cls, v: str) -> str:
if "inputs" not in v:
raise ValueError(
"resolution_prompt must reference 'inputs' variable for iterating over matched items"
)
return v
@field_validator("blocking_conditions")
@classmethod
def check_blocking_conditions(cls, v: List[str]) -> List[str]:
if not v:
raise ValueError(
"At least one blocking condition must be provided to avoid O(n^2) comparisons"
)
for condition in v:
if "input1" not in condition or "input2" not in condition:
raise ValueError(
f"Blocking condition must reference both 'input1' and 'input2': {condition}"
)
return v
class MapResolveToMapWithCategoriesInstantiateSchema(BaseModel):
"""
Schema for replacing Map -> Resolve with a single Map that has predefined categories.
The agent proposes categories based on analysis of the data/task, and the new Map
forces outputs into one of these categories (or 'none of the above'), effectively
doing entity resolution deterministically.
"""
categories: List[str] = Field(
...,
description="List of valid category values that the map output should be constrained to. Should include all distinct canonical values discovered from analyzing the data/task.",
)
category_key: str = Field(
...,
description="The key in the output schema that will contain the category value",
)
new_prompt: str = Field(
...,
description="The new prompt for the Map operation that includes the category list and instructs the LLM to output one of the predefined categories. Must reference {{ input.key }} for input fields.",
)
include_none_of_above: bool = Field(
default=True,
description="Whether to include 'None of the above' as a valid category option for items that don't fit any category",
)
@field_validator("categories")
@classmethod
def check_categories(cls, v: List[str]) -> List[str]:
if len(v) < 2:
raise ValueError("At least 2 categories must be provided")
if len(v) != len(set(v)):
raise ValueError("Categories must be unique")
return v
@field_validator("new_prompt")
@classmethod
def check_new_prompt(cls, v: str) -> str:
return MapOpConfig.validate_prompt_contains_input_key(v)
class ArbitraryRewriteInstantiateSchema(BaseModel):
"""
Schema for arbitrary pipeline rewrites using search/replace edits.
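
The two schemas added above (MapToMapResolveReduceInstantiateSchema and MapResolveToMapWithCategoriesInstantiateSchema) constrain what the instantiating agent may return through their field validators. A standalone sketch exercising both, with imports mirroring the test files below; it assumes the docetl package is importable and that the prompt validator only requires an {{ input.* }} reference:

from pydantic import ValidationError
from docetl.reasoning_optimizer.instantiate_schemas import (
    MapResolveToMapWithCategoriesInstantiateSchema,
    MapToMapResolveReduceInstantiateSchema,
)

# Valid categorical rewrite: at least 2 unique categories, prompt references the input.
MapResolveToMapWithCategoriesInstantiateSchema(
    categories=["Bug report", "Feature request"],
    category_key="ticket_type",
    new_prompt="Classify this ticket: {{ input.ticket }}",
)

# A blocking condition that never mentions input2 is rejected by check_blocking_conditions.
try:
    MapToMapResolveReduceInstantiateSchema(
        resolve_name="dedupe_names",
        comparison_prompt="Same entity? {{ input1.name }} vs {{ input2.name }}",
        resolution_prompt="Canonical form: {% for input in inputs %}{{ input.name }}{% endfor %}",
        blocking_conditions=["input1['name'].lower().startswith('acme')"],
        blocking_keys=["name"],
    )
except ValidationError as err:
    print("rejected:", err)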

View File

@ -6,7 +6,7 @@ Simple apply tests for directive testing - just ensure apply() doesn't crash.
# Simple apply tests - no pytest needed
from docetl.reasoning_optimizer.directives import (
ChainingDirective,
GleaningDirective,
ReduceGleaningDirective,
ReduceChainingDirective,
@ -14,7 +14,6 @@ from docetl.reasoning_optimizer.directives import (
OperatorFusionDirective,
DocSummarizationDirective,
IsolatingSubtasksDirective,
DeterministicDocCompressionDirective,
DocumentChunkingDirective,
DocumentChunkingTopKDirective,
@ -22,8 +21,12 @@ from docetl.reasoning_optimizer.directives import (
TakeHeadTailDirective,
ClarifyInstructionsDirective,
SwapWithCodeDirective,
HierarchicalReduceDirective,
MapToMapResolveReduceDirective,
MapResolveToMapWithCategoriesDirective,
)
# DocCompressionDirective is commented out in __init__.py, import directly
from docetl.reasoning_optimizer.directives.doc_compression import DocCompressionDirective
def test_chaining_apply():
@ -1012,6 +1015,237 @@ def test_cascade_filtering_apply():
assert "high-quality research paper" in result[5]["prompt"]
def test_map_to_map_resolve_reduce_apply():
"""Test that map_to_map_resolve_reduce apply doesn't crash"""
directive = MapToMapResolveReduceDirective()
# Pipeline with map followed by reduce
ops_list = [
{
"name": "extract_companies",
"type": "map",
"prompt": "Extract company name from: {{ input.article }}",
"model": "gpt-4o-mini",
"output": {"schema": {"company_name": "string"}},
},
{
"name": "aggregate_by_sector",
"type": "reduce",
"reduce_key": "sector",
"prompt": "List companies:\n{% for input in inputs %}\n- {{ input.company_name }}\n{% endfor %}",
"model": "gpt-4o-mini",
"output": {"schema": {"companies": "list[str]"}},
},
]
from docetl.reasoning_optimizer.instantiate_schemas import (
MapToMapResolveReduceInstantiateSchema,
)
rewrite = MapToMapResolveReduceInstantiateSchema(
resolve_name="normalize_company_names",
comparison_prompt="""Are these two company names referring to the same company?
Company 1: {{ input1.company_name }}
Company 2: {{ input2.company_name }}
Consider variations like abbreviations (IBM vs International Business Machines).""",
resolution_prompt="""Given these variations of a company name:
{% for input in inputs %}
- {{ input.company_name }}
{% endfor %}
Return the canonical/official company name.""",
blocking_conditions=[
"input1['company_name'][:3].lower() == input2['company_name'][:3].lower()",
"input1['company_name'].split()[0].lower() == input2['company_name'].split()[0].lower()",
],
blocking_keys=["sector", "company_name"], # Must include reduce_key (sector)
limit_comparisons=1000,
)
result = directive.apply(
"gpt-4o-mini", ops_list, "extract_companies", "aggregate_by_sector", rewrite
)
assert isinstance(result, list)
assert len(result) == 3 # map + resolve + reduce
# Check the resolve operation was inserted correctly
assert result[0]["name"] == "extract_companies"
assert result[0]["type"] == "map"
assert result[1]["name"] == "normalize_company_names"
assert result[1]["type"] == "resolve"
assert "comparison_prompt" in result[1]
assert "resolution_prompt" in result[1]
assert "blocking_conditions" in result[1]
assert "blocking_keys" in result[1]
assert result[1]["blocking_keys"] == ["sector", "company_name"]
assert result[1]["limit_comparisons"] == 1000
assert "input1" in result[1]["comparison_prompt"]
assert "input2" in result[1]["comparison_prompt"]
assert "inputs" in result[1]["resolution_prompt"]
# Output schema should be derived from reduce_key
assert result[1]["output"]["schema"] == {"sector": "string"}
assert result[2]["name"] == "aggregate_by_sector"
assert result[2]["type"] == "reduce"
def test_map_to_map_resolve_reduce_apply_with_multiple_reduce_keys():
"""Test map_to_map_resolve_reduce with multiple reduce_keys"""
directive = MapToMapResolveReduceDirective()
ops_list = [
{
"name": "extract_products",
"type": "map",
"prompt": "Extract product info: {{ input.description }}",
"model": "gpt-4o-mini",
"output": {"schema": {"product_name": "string", "category": "string"}},
},
{
"name": "aggregate_products",
"type": "reduce",
"reduce_key": ["brand", "region"], # Multiple reduce keys
"prompt": "List products:\n{% for input in inputs %}{{ input.product_name }}{% endfor %}",
"model": "gpt-4o-mini",
"output": {"schema": {"products": "list[str]"}},
},
]
from docetl.reasoning_optimizer.instantiate_schemas import (
MapToMapResolveReduceInstantiateSchema,
)
rewrite = MapToMapResolveReduceInstantiateSchema(
resolve_name="normalize_products",
comparison_prompt="Same product? {{ input1.product_name }} vs {{ input2.product_name }}",
resolution_prompt="Normalize: {% for input in inputs %}{{ input.product_name }}{% endfor %}",
blocking_conditions=[
"input1['category'] == input2['category']",
],
blocking_keys=["brand", "region", "product_name"], # Must include reduce_keys
limit_comparisons=500,
)
result = directive.apply(
"gpt-4o-mini", ops_list, "extract_products", "aggregate_products", rewrite
)
assert result[1]["type"] == "resolve"
assert "blocking_keys" in result[1]
assert result[1]["blocking_keys"] == ["brand", "region", "product_name"]
# Output schema should include all reduce_keys
assert result[1]["output"]["schema"] == {"brand": "string", "region": "string"}
def test_map_resolve_to_map_with_categories_apply():
"""Test that map_resolve_to_map_with_categories apply doesn't crash"""
directive = MapResolveToMapWithCategoriesDirective()
# Pipeline with map followed by resolve
ops_list = [
{
"name": "extract_sentiment",
"type": "map",
"prompt": "What is the sentiment? {{ input.review }}",
"model": "gpt-4o-mini",
"output": {"schema": {"sentiment": "string"}},
},
{
"name": "normalize_sentiment",
"type": "resolve",
"comparison_prompt": "Same sentiment? {{ input1.sentiment }} vs {{ input2.sentiment }}",
"resolution_prompt": "Normalize: {% for input in inputs %}{{ input.sentiment }}{% endfor %}",
"model": "gpt-4o-mini",
"output": {"schema": {"sentiment": "string"}},
},
]
from docetl.reasoning_optimizer.instantiate_schemas import (
MapResolveToMapWithCategoriesInstantiateSchema,
)
rewrite = MapResolveToMapWithCategoriesInstantiateSchema(
categories=["Positive", "Negative", "Neutral", "Mixed"],
category_key="sentiment",
new_prompt="""Classify the sentiment of this review into one of these categories:
- Positive: Clearly positive sentiment
- Negative: Clearly negative sentiment
- Neutral: No strong sentiment
- Mixed: Both positive and negative
- None of the above
Review: {{ input.review }}
Return exactly one category.""",
include_none_of_above=True,
)
result = directive.apply(
"gpt-4o-mini", ops_list, "extract_sentiment", "normalize_sentiment", rewrite
)
assert isinstance(result, list)
assert len(result) == 1 # map only, resolve removed
# Check the map operation was modified correctly
assert result[0]["name"] == "extract_sentiment"
assert result[0]["type"] == "map"
assert "Positive" in result[0]["prompt"]
assert "Negative" in result[0]["prompt"]
assert "None of the above" in result[0]["prompt"]
# Check validation was added
assert "validate" in result[0]
assert len(result[0]["validate"]) == 1
assert "Positive" in result[0]["validate"][0]
assert "None of the above" in result[0]["validate"][0]
def test_map_resolve_to_map_with_categories_no_none_of_above():
"""Test map_resolve_to_map_with_categories without 'None of the above' option"""
directive = MapResolveToMapWithCategoriesDirective()
ops_list = [
{
"name": "classify_type",
"type": "map",
"prompt": "What type is this? {{ input.text }}",
"model": "gpt-4o-mini",
"output": {"schema": {"item_type": "string"}},
},
{
"name": "normalize_type",
"type": "resolve",
"comparison_prompt": "Same type? {{ input1.item_type }} vs {{ input2.item_type }}",
"resolution_prompt": "Normalize: {% for input in inputs %}{{ input.item_type }}{% endfor %}",
"model": "gpt-4o-mini",
"output": {"schema": {"item_type": "string"}},
},
]
from docetl.reasoning_optimizer.instantiate_schemas import (
MapResolveToMapWithCategoriesInstantiateSchema,
)
rewrite = MapResolveToMapWithCategoriesInstantiateSchema(
categories=["TypeA", "TypeB", "TypeC"],
category_key="item_type",
new_prompt="""Classify into: TypeA, TypeB, or TypeC
Text: {{ input.text }}""",
include_none_of_above=False,
)
result = directive.apply(
"gpt-4o-mini", ops_list, "classify_type", "normalize_type", rewrite
)
assert len(result) == 1
# Check validation does NOT include 'None of the above'
assert "validate" in result[0]
assert "None of the above" not in result[0]["validate"][0]
assert "TypeA" in result[0]["validate"][0]
assert "TypeB" in result[0]["validate"][0]
assert "TypeC" in result[0]["validate"][0]
if __name__ == "__main__":
# Run all tests
test_chaining_apply()
@ -1078,4 +1312,16 @@ if __name__ == "__main__":
test_cascade_filtering_apply()
print("✅ Cascade filtering apply test passed")
test_map_to_map_resolve_reduce_apply()
print("✅ Map to map resolve reduce apply test passed")
test_map_to_map_resolve_reduce_apply_with_multiple_reduce_keys()
print("✅ Map to map resolve reduce with multiple reduce keys apply test passed")
test_map_resolve_to_map_with_categories_apply()
print("✅ Map resolve to map with categories apply test passed")
test_map_resolve_to_map_with_categories_no_none_of_above()
print("✅ Map resolve to map with categories (no none of above) apply test passed")
print("\n🎉 All directive apply tests passed!")

View File

@ -10,14 +10,13 @@ from typing import Dict, List
from datetime import datetime
from docetl.reasoning_optimizer.directives import (
ChainingDirective,
GleaningDirective,
ReduceGleaningDirective,
ReduceChainingDirective,
ChangeModelDirective,
DocSummarizationDirective,
IsolatingSubtasksDirective,
DeterministicDocCompressionDirective,
OperatorFusionDirective,
DocumentChunkingDirective,
@ -28,8 +27,12 @@ from docetl.reasoning_optimizer.directives import (
SwapWithCodeDirective,
HierarchicalReduceDirective,
CascadeFilteringDirective,
MapToMapResolveReduceDirective,
MapResolveToMapWithCategoriesDirective,
TestResult,
)
# DocCompressionDirective is commented out in __init__.py, import directly
from docetl.reasoning_optimizer.directives.doc_compression import DocCompressionDirective
def run_all_directive_tests(agent_llm: str = "gpt-4.1") -> Dict[str, List[TestResult]]:
"""
@ -68,6 +71,8 @@ def run_all_directive_tests(agent_llm: str = "gpt-4.1") -> Dict[str, List[TestRe
SwapWithCodeDirective(),
HierarchicalReduceDirective(),
CascadeFilteringDirective(),
MapToMapResolveReduceDirective(),
MapResolveToMapWithCategoriesDirective(),
]
all_results = {}
@ -176,6 +181,8 @@ def run_specific_directive_test(directive_name: str, agent_llm: str = "gpt-4o-mi
"clarify_instructions": ClarifyInstructionsDirective(),
"swap_with_code": SwapWithCodeDirective(),
"cascade_filtering": CascadeFilteringDirective(),
"map_to_map_resolve_reduce": MapToMapResolveReduceDirective(),
"map_resolve_to_map_with_categories": MapResolveToMapWithCategoriesDirective(),
}
if directive_name.lower() not in directive_map: