Compare commits


15 Commits

Author SHA1 Message Date
Shreya Shankar bcac6872f5
Embedding blocking threshold optimization (#473)
* feat: Add runtime blocking threshold optimization

Co-authored-by: ss.shankar505 <ss.shankar505@gmail.com>

* Checkpoint before follow-up message

Co-authored-by: ss.shankar505 <ss.shankar505@gmail.com>

* Refactor: Simplify target_recall retrieval in Equijoin and Resolve

Co-authored-by: ss.shankar505 <ss.shankar505@gmail.com>

* Refactor: Improve blocking documentation and add auto-blocking

Co-authored-by: ss.shankar505 <ss.shankar505@gmail.com>

* allow resolve and equijoin to figure out blocking thresholds on the fly.

---------

Co-authored-by: Cursor Agent <cursoragent@cursor.com>
2025-12-29 19:47:21 -06:00
Shreya Shankar 57a284bcb1
Fast Decomposition for Map Operations in DocWrangler (#472)
* refactor: docwrangler to use a faster decomposition flow, only if the last operation in a pipeline is a map operation.

* refactor: docwrangler to use a faster decomposition flow, only if the last operation in a pipeline is a map operation.

* refactor: update MOAR documentation
2025-12-29 18:22:02 -06:00
Shreya Shankar cfbb64470a
optimizer: add directives for resolve operator (#470) 2025-12-29 13:57:51 -06:00
Shreya Shankar a2e2e86c22 update: claude code skill when doing hn demo 2025-12-28 15:10:52 -06:00
Shreya Shankar b4a5e98972
Update README with Claude Code integration (#471)
Co-authored-by: Cursor Agent <cursoragent@cursor.com>
2025-12-28 12:49:14 -06:00
Shreya Shankar 9eee1a2553
feat: add claude-code skill (#469)
* feat: add claude-code skill

* feat: add claude-code skill

* get ready to bump up version
2025-12-27 22:23:30 -06:00
Shreya Shankar fa86e23cfb
add limit param to llm ops (#466) 2025-12-26 19:46:28 -08:00
Shreya Shankar 56a1b7e794
Feat: Add first-class LanceDB retrievers and prompt augmentation across LLM operators (#460)
* feat: adding retrievers

* feat: allow indexes to be built on datasets created by docetl pipelines

* Testing retriever and updating docs with logging
2025-12-26 16:48:48 -06:00
Shreya Shankar f78009f0b1
add reduce to pandas api (#468) 2025-12-19 11:26:32 +05:30
Lindsey Wei 81d110404d
Add MOAR optimizer to docetl (#464)
* feat: adding conditional gleaning (#375)

* fix: improve caching and don't raise error for bad gather configs

* fix: improve caching and don't raise error for bad gather configs

* feat: adding conditional gleaning

* chore: bump up fastapi and python multipart (#376)

* merge

* chore: bump up fastapi and python multipart

* chore: bump up fastapi and python multipart

* pipeline for chaining

* chaining + gleaning(map)

* Clean up and reorganize pytest tests (#377)

* Replace api_wrapper with runner in test fixtures and configurations

Co-authored-by: ss.shankar505 <ss.shankar505@gmail.com>

* Refactor test fixtures and reorganize configuration in test files

Co-authored-by: ss.shankar505 <ss.shankar505@gmail.com>

---------

Co-authored-by: Cursor Agent <cursoragent@cursor.com>

* Refactor api.py for structured output (#378)

* Refactor API wrapper with modular design for LLM calls and output handling

Co-authored-by: ss.shankar505 <ss.shankar505@gmail.com>

* Refactor APIWrapper: Simplify LLM call logic and improve modularity

Co-authored-by: ss.shankar505 <ss.shankar505@gmail.com>

* Refactor output mode handling in APIWrapper with flexible configuration

Co-authored-by: ss.shankar505 <ss.shankar505@gmail.com>

* Add comprehensive tests for DocETL output modes with synthetic data

Co-authored-by: ss.shankar505 <ss.shankar505@gmail.com>

* Refactor output modes tests with improved pytest structure and DSLRunner

Co-authored-by: ss.shankar505 <ss.shankar505@gmail.com>

* Fix runtime errors

* Add nested JSON parsing for string values in API response

Co-authored-by: ss.shankar505 <ss.shankar505@gmail.com>

* Handle nested JSON parsing by extracting matching key values

Co-authored-by: ss.shankar505 <ss.shankar505@gmail.com>

* Simplify JSON parsing logic in API utility functions

Co-authored-by: ss.shankar505 <ss.shankar505@gmail.com>

* Add to tests

* Add documentation for DocETL output modes and configuration options

Co-authored-by: ss.shankar505 <ss.shankar505@gmail.com>

* Add docs

---------

Co-authored-by: Cursor Agent <cursoragent@cursor.com>

* Add operators to pandas API (#379)

* baseline w/ 3 rewrites

* added sample

* 3 choices

* feat: add global bypass cache (#383)

* baseline experiments

* baseline experiments

* update directives

* MCTS random expand

* MCTS v1

* MCTS V1 w/ instantiation check on chaining

* MCTS V2

* remove key

* delete key

* update evaluation code

* >1 instantiation for chaining & true f1 score

* true f1 score as acc

* hypervolume as value algo

* Add directives folder from PR #1

* tests for directives

* updates on instantiate_schemas.py

* adding experiment folder from P1

* added init for mcts

* partially merged PR1

* fix instantiate schema

* added chunking directives

* passed directive tests

* fix bug in doc_chunking instantiate

* fix doc_chunking bug

* update prompt for chunk_header_summary

* update gitignore

* fix mcts bugs when rebasing

* add chunk sampling directive

* print AUC at end

* reward update
merged

* fixed bug in doc_compression and avoid using chunk_header_summary twice

* fixed chunk sampling

* adding blackvault dataset and refactoring hard coded cuad statistics

* scaled cost, progressive widening

* feat: add head tail directive

* feat: add head tail directive

* feat: add head tail directive

* merge chunk sampling + doc chunking into one directive. also modify sample operator to stratify by multiple keys

* add game reviews experiemnt and fix doc chunking

* trim down some fat in game reviews eval logic

* syntax check for stratify key

* added split_key check and fixed bug in take_head_tail

* added **kwargs

* remove reduce in op fusion

* adding medec workload

* document_key check for doc_compression & load schema data

* add sustainability expt

* add reduce chaining

* update expand retry & op fusion

* remove hard coding of blackvault

* remove hard coding of blackvault

* adding agent for instantiation (that is allowed to read docs), and a clarify instructions directive

* adding agent for instantiation (that is allowed to read docs), and a clarify instructions directive

* add swap with code directive

* optimizing cost in first 5 iterations

* added mcts log and remove plot

* add memory along tree path and action reward in log

* add biodex and have train and test splits

* add biodex and have train and test splits

* added map reduce fusion, small changes to mcts

* added reference point for calculating hypervolume in evaluation

* fix bug in calculating hypervolume

* integrate with modal

* add run_all.py script

* add run_all.py script

* fix bug in baseline applying ops

* update mcts chat history

* update mcts chat history

* control message length to fit context window

* fix agent baseline loop

* remove print statement

* update run baseline

* adding extremely simple agent baseline that just gives us entirely new pipelines

* adding extremely simple agent baseline that just gives us entirely new pipelines

* update readme with simple / naive agent baseline

* update

* change modal image

* edit gitignore to operators doc

* adding retrieval based rewrite directive

* debug simple baseline

* fix baseline agent and top k chunking directive

* fix baseline agent and top k chunking directive

* fix baseline agent and top k chunking directive

* fix chunking topk directive and add hierarchical reduce

* merge the original plan execution of three methods

* add plot and hypervolume calculation

* feat: add cascade filtering directive

* feat: add arbitrary rewrite directive

* remove unnecessary field

* fixed fall back method for baseline & mcts

* adding some validation for arbitrary rewrite

* update sample operation

* change model choice

* updates on plot

* fix gpt-5 temperature in query engine

* update model choice in base

* use rp@5 for biodex eval

* separate cost & acc optimize change model directive

* add abacus to plot

* added memo table for each node and provide it to the agent for directive selection

* add facility dataset (we dont need to use for paper) and lotus baselines

* edit prompts in change model directive

* change models

* added rules on compositions of directives. clean up mcts using utils

* add multiple instantiation for selected directives

* fix gemini errors

* adding PZ baselines (still in progress

* add all PZ baselines to main

* add all PZ baselines to main

* fix bugs about multiple instantiations

* add utility to plot the test set

* fix linter errors

* fix PZ script for medec dataset

* add original pipeline on test set to plot

* change color of mcts

* test all model on the input query before search

* add search cost calculation

* fixed bug in json file path

* add concurrency control; start searching from pareto model plans

* set reward to be vertical distance to the step frontier

* newest mcts version

* small changes during exp

* change to run lotus, add bootstrap

* Add needs_code_filter variable for operator fusion

* add validation step; randomly select when values have ties

* eval function change

* clean up run test frontier

* Delete bootstrap_evaluation.py

* Delete BioDEX_evaluate.py

* Delete CUAD_evaluate.py

* Delete CUAD_sample50.py

* Delete bootstrapping.py

* Delete evaluate_blackvault.py

* Delete exp_graph.py

* Delete exp_graph_max.py

* Delete re_evaluate_zero_scores.py

* Delete split_json_data.py

* Delete test_evaluation.py

* Delete test_gemini_models.py

* Delete test_operations_modal.py

* Delete test_validate_frontier.py

* Delete validate_pareto_frontier.py

* Delete docetl/reasoning_optimizer/Untitled-1.py

* Delete docetl/BioDEX_test.py

* Delete docetl/mcts/execute_res_HV directory

* Delete docetl/mcts/graph_baseline.py

* Delete docetl/mcts/graph.py

* remove acc comparator

* delete graph

* delete irrelevant files

* mcts clean up

* update naming

* clean up simple agent

* update readme

* add user specified eval function

* Delete experiments/reasoning/run_tests.py

* Delete experiments/reasoning/run_test_frontier.py

* Delete experiments/reasoning/run_baseline.py

* Delete experiments/reasoning/run_all.py

* Delete experiments/reasoning/plot_result.py

* Delete experiments/reasoning/plot_matrix.py

* Delete experiments/reasoning/combine_biodex_test_results.py

* Delete experiments/reasoning/create_biodex_test_summary.py

* Delete experiments/reasoning/generate_biodex_summary.py

* Delete experiments/reasoning/TEST_FRONTIER_README.md

* Delete experiments/reasoning/utils.py

* Delete experiments/reasoning/README.md

* Delete experiments/reasoning/othersystems directory

* Delete experiments/reasoning/outputs/blackvault_baseline

* Delete experiments/reasoning/outputs/blackvault_mcts

* Delete experiments/reasoning/outputs/cuad_lotus_evaluation.json

* Delete compute_words_per_document.py

* Delete docetl/moar/acc_comparator.py

* Delete docetl/moar/acc_comparator.py.backup

* Delete docetl/moar/instantiation_check.py

* Delete docetl/reasoning_optimizer/build_optimization.py

* Delete docetl/reasoning_optimizer/generate_rewrite_plan.py

* remove auc

* simplify util

* Delete experiments/reasoning/data directory

* requirements

* eval function cleaning

* update readme

* update readme

* update readme

* remove facility

* update readme

* Update api to be consistent with main

* remove unused files

* clean up relative imports and model choices

* change all printing to use the rich console

* add documentation for moar and CLI

* move evaluation code from experiments dir to moar dir

* fix documentation

---------

Co-authored-by: Shreya Shankar <ss.shankar505@gmail.com>
Co-authored-by: linxiwei <lindseywei@visitor-10-57-110-173.wifi.berkeley.edu>
Co-authored-by: Cursor Agent <cursoragent@cursor.com>
Co-authored-by: linxiwei <lindseywei@visitor-10-57-109-39.wifi.berkeley.edu>
Co-authored-by: linxiwei <lindseywei@lindseys-mbp-7.lan>
Co-authored-by: linxiwei <lindseywei@visitor-10-57-111-186.wifi.berkeley.edu>
Co-authored-by: linxiwei <lindseywei@wifi-10-41-110-112.wifi.berkeley.edu>
Co-authored-by: linxiwei <lindseywei@wifi-10-44-111-72.wifi.berkeley.edu>
Co-authored-by: Lindsey Wei <152750390+LindseyyyW@users.noreply.github.com>
Co-authored-by: linxiwei <lindseywei@wifi-10-44-110-253.wifi.berkeley.edu>
Co-authored-by: linxiwei <lindseywei@wifi-10-41-110-92.wifi.berkeley.edu>
Co-authored-by: linxiwei <lindseywei@wifi-10-41-108-154.wifi.berkeley.edu>
Co-authored-by: linxiwei <lindseywei@wifi-10-41-109-198.wifi.berkeley.edu>
Co-authored-by: linxiwei <lindseywei@Lindseys-MacBook-Pro-7.local>
Co-authored-by: linxiwei <lindseywei@192.168.0.110>
2025-11-28 13:47:42 -06:00
Shreya Shankar 0110071cd5
fix: add code ops and extract to python api (#463) 2025-11-24 14:18:59 -06:00
Shreya Shankar a184a3c1e9
fix: add code ops and extract to python api (#462)
* fix: add code ops and extract to python api

* fix: add code ops and extract to python api
2025-11-24 14:10:07 -06:00
Shreya Shankar 7cca6f57b5
Graceful jinja template handling with user confirmation (#452)
* feat: Add user confirmation for non-Jinja prompts

This commit introduces a confirmation step for prompts that do not contain Jinja2 syntax. It also modifies strict_render to automatically append document context when Jinja syntax is absent.

Co-authored-by: ss.shankar505 <ss.shankar505@gmail.com>

* Refactor: Move DOCETL_CONSOLE import to function scope

Co-authored-by: ss.shankar505 <ss.shankar505@gmail.com>

* Refactor: Move has_jinja_syntax to docetl.utils

Co-authored-by: ss.shankar505 <ss.shankar505@gmail.com>

---------

Co-authored-by: Cursor Agent <cursoragent@cursor.com>
2025-11-22 14:02:19 -08:00
Shreya Shankar 9ebfc1bd58
feat: add ability to sort chronologically for epstein emails (#458) 2025-11-16 09:08:10 -08:00
Shreya Shankar 067e671650
feat: add ability to sort chronologically for epstein emails (#457)
* feat: add ability to sort chronologically for epstein emails

* feat: add ability to sort chronologically for epstein emails
2025-11-16 08:48:56 -08:00
146 changed files with 30990 additions and 707 deletions

View File

@ -0,0 +1,776 @@
---
name: docetl
description: Build and run LLM-powered data processing pipelines with DocETL. Use when users say "docetl", want to analyze unstructured data, process documents, extract information, or run ETL tasks on text. Helps with data collection, pipeline creation, execution, and optimization.
---
# DocETL Pipeline Development
DocETL is a system for creating LLM-powered data processing pipelines. This skill helps you build end-to-end pipelines: from data preparation to execution and optimization.
## Workflow Overview: Iterative Data Analysis
Work like a data analyst: **write → run → inspect → iterate**. Don't write every script up front and then run them all at once; complete and validate each phase before moving on to the next.
### Phase 1: Data Collection
1. Write data collection script
2. **Run it immediately** (with user permission)
3. **Inspect the dataset** - show the user:
- Total document count
- Keys/fields in each document
- Sample documents (first 3-5)
- Length distribution (avg chars, min/max)
- Any other relevant statistics
4. Iterate if needed (e.g., collect more data, fix parsing issues)
### Phase 2: Pipeline Development
1. Read sample documents to understand format
2. Write pipeline YAML with `sample: 10-20` for testing
3. **Run the test pipeline**
4. **Inspect intermediate results** - show the user:
- Extraction quality on samples
- Domain/category distributions
- Any validation failures
5. Iterate on prompts/schema based on results
6. Remove `sample` parameter and run full pipeline
7. **Show final results** - distributions, trends, key insights
### Phase 3: Visualization & Presentation
1. Write a visualization script based on the actual output structure (a minimal generator sketch follows the guidelines below)
2. **Run and show the report** to the user
3. Iterate on charts/tables if needed
**Visualization Aesthetics:**
- **Clean and minimalist** - no clutter, generous whitespace
- **Warm and elegant color theme** - 1-2 accent colors max
- **Subtle borders** - not too rounded (border-radius: 8-10px max)
- **Sans-serif fonts** - system fonts like -apple-system, Segoe UI, Roboto
- **"Created by DocETL"** - add subtitle after the main title
- **Mix of charts and tables** - charts for distributions, tables for detailed summaries
- **Light background** - off-white (#f5f5f5) with white cards for content
**Report Structure:**
1. Title + "Created by DocETL" subtitle
2. Key stats cards (document count, categories, etc.)
3. Distribution charts (bar charts, pie charts)
4. Summary table with detailed analysis
5. Minimal footer
**Interactive Tables:**
- **All truncated content must be expandable** - never use static "..." truncation
- Long text: Show first ~250 chars with "(show more)" toggle
- Long lists: Show first 4-6 items with "(+N more)" toggle
- Use JavaScript to toggle visibility, not page reloads
**Source Document Links:**
- **Link aggregated results to source documents** - users should be able to drill down
- Clickable links that open a modal/popup with source content
- Modal should show: extracted fields + original source text
- Original text can be collapsed by default with "Show original" toggle
- Embed source data as JSON in the page for JavaScript access
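To make the report guidelines above concrete, here is a minimal generator sketch in plain Python (standard library only). The `topic` and `summary` field names are hypothetical placeholders; substitute the fields your pipeline actually outputs.
```python
# Minimal report generator sketch. "topic" and "summary" are hypothetical fields.
import html
import json
from collections import Counter

data = json.load(open("output.json"))
topics = Counter(d.get("topic", "unknown") for d in data)

def expandable(text: str, limit: int = 250) -> str:
    """Render long text with a '(show more)' toggle instead of static truncation."""
    text = html.escape(text)
    if len(text) <= limit:
        return text
    return (
        f"{text[:limit]}<span class='rest' style='display:none'>{text[limit:]}</span> "
        "<a href='#' onclick=\"var r=this.previousElementSibling;"
        "r.style.display=r.style.display==='none'?'inline':'none';return false;\">(show more)</a>"
    )

rows = "\n".join(
    f"<tr><td>{html.escape(str(d.get('topic', '')))}</td>"
    f"<td>{expandable(str(d.get('summary', '')))}</td></tr>"
    for d in data
)
cards = f"<div class='card'>{len(data)} documents</div><div class='card'>{len(topics)} topics</div>"
chart = "\n".join(
    f"<div>{html.escape(t)}: {'█' * max(1, int(40 * c / max(1, len(data))))} {c}</div>"
    for t, c in topics.most_common(10)
)

page = f"""<!doctype html><html><head><meta charset="utf-8"><style>
body {{ font-family: -apple-system, 'Segoe UI', Roboto, sans-serif; background: #f5f5f5; margin: 2rem; }}
.card, .chart, table {{ background: #fff; border: 1px solid #e0e0e0; border-radius: 8px; padding: 1rem; margin: 1rem 0; }}
.card {{ display: inline-block; margin-right: 1rem; }}
td, th {{ padding: 0.5rem; text-align: left; vertical-align: top; }}
</style></head><body>
<h1>Pipeline Report</h1><p style="color:#888">Created by DocETL</p>
{cards}
<div class="chart"><h3>Topic distribution</h3>{chart}</div>
<table><tr><th>Topic</th><th>Summary</th></tr>{rows}</table>
</body></html>"""

with open("report.html", "w") as f:
    f.write(page)
print("Wrote report.html")
```
Open `report.html` in a browser and iterate on styling, charts, and modals from there.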
**Key principle:** The user should see results at every step. Don't proceed to the next phase until the current phase produces good results.
## Step 1: Data Preparation
DocETL datasets must be **JSON arrays** or **CSV files**.
### JSON Format
```json
[
{"id": 1, "text": "First document content...", "metadata": "value"},
{"id": 2, "text": "Second document content...", "metadata": "value"}
]
```
### CSV Format
```csv
id,text,metadata
1,"First document content...","value"
2,"Second document content...","value"
```
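If your raw data is a CSV but you prefer working with the JSON format, a small standard-library conversion sketch (column names are examples only):
```python
import csv
import json

# Convert a CSV file into the JSON-array format shown above.
# Note: csv.DictReader loads every value as a string.
with open("dataset.csv", newline="") as f:
    documents = list(csv.DictReader(f))

with open("dataset.json", "w") as f:
    json.dump(documents, f, indent=2)
print(f"Wrote {len(documents)} documents")
```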
### Data Collection Scripts
If user needs to collect data, write a Python script:
```python
import json
# Collect/transform data
documents = []
for source in sources:
documents.append({
"id": source.id,
"text": source.content, # DO NOT truncate text
# Add relevant fields
})
# Save as DocETL dataset
with open("dataset.json", "w") as f:
json.dump(documents, f, indent=2)
```
**Important:** Never truncate document text in collection scripts. DocETL operations like `split` handle long documents properly. Truncation loses information.
### After Running Data Collection
**Always run the collection script and inspect results before proceeding.** Show the user:
```python
import json
data = json.load(open("dataset.json"))
print(f"Total documents: {len(data)}")
print(f"Keys: {list(data[0].keys())}")
print(f"Avg length: {sum(len(str(d)) for d in data) // len(data)} chars")
# Show sample
print("\nSample document:")
print(json.dumps(data[0], indent=2)[:500])
```
Only proceed to pipeline development once the data looks correct.
## Step 2: Read and Understand the Data
**CRITICAL**: Before writing any prompts, READ the actual input data to understand:
- The structure and format of documents
- The vocabulary and terminology used
- What information is present vs. absent
- Edge cases and variations
```python
import json
with open("dataset.json") as f:
data = json.load(f)
# Examine several examples
for doc in data[:5]:
print(doc)
```
This understanding is essential for writing specific, effective prompts.
## Step 3: Pipeline Structure
Create a YAML file with this structure:
```yaml
default_model: gpt-5-nano
system_prompt:
dataset_description: <describe the data based on what you observed>
persona: <role for the LLM to adopt>
datasets:
input_data:
type: file
path: "dataset.json" # or dataset.csv
operations:
- name: <operation_name>
type: <operation_type>
prompt: |
<Detailed, specific prompt based on the actual data>
output:
schema:
<field_name>: <type>
pipeline:
steps:
- name: process
input: input_data
operations:
- <operation_name>
output:
type: file
path: "output.json"
intermediate_dir: "intermediates" # ALWAYS set this for debugging
```
### Key Configuration
- **default_model**: Use `gpt-5-nano` or `gpt-5-mini` for extraction/map operations
- **intermediate_dir**: Always set to log intermediate results
- **system_prompt**: Describe the data based on what you actually observed
### Model Selection by Operation Type
| Operation Type | Recommended Model | Rationale |
|---------------|-------------------|-----------|
| Map (extraction) | `gpt-5-nano` or `gpt-5-mini` | High volume, simple per-doc tasks |
| Filter | `gpt-5-nano` | Simple yes/no decisions |
| Reduce (summarization) | `gpt-4.1` or `gpt-5.1` | Complex synthesis across many docs |
| Resolve (deduplication) | `gpt-5-nano` or `gpt-5-mini` | Simple pairwise comparisons |
Use cheaper models for high-volume extraction, and more capable models for synthesis/summarization where quality matters most.
## Step 4: Writing Effective Prompts
**Prompts must be specific to the data, not generic.** After reading the input data:
### Bad (Generic) Prompt
```yaml
prompt: |
Extract key information from this document.
{{ input.text }}
```
### Good (Specific) Prompt
```yaml
prompt: |
You are analyzing a medical transcript from a doctor-patient visit.
The transcript follows this format:
- Doctor statements are prefixed with "DR:"
- Patient statements are prefixed with "PT:"
- Timestamps appear in brackets like [00:05:23]
From the following transcript, extract:
1. All medications mentioned (brand names or generic)
2. Dosages if specified
3. Patient-reported side effects or concerns
Transcript:
{{ input.transcript }}
Be thorough - patients often mention medication names informally.
If a medication is unclear, include it with a note.
```
### Prompt Writing Guidelines
1. **Describe the data format** you observed
2. **Be specific about what to extract** - list exact fields
3. **Mention edge cases** you noticed in the data
4. **Provide examples** if the task is ambiguous
5. **Set expectations** for handling missing/unclear information
## Step 5: Choosing Operations
Many tasks only need a **single map operation**. Use good judgement:
| Task | Recommended Approach |
|------|---------------------|
| Extract info from each doc | Single `map` |
| Multiple extractions | Multiple `map` operations chained |
| Extract then summarize | `map` → `reduce` |
| Filter then process | `filter` → `map` |
| Split long docs | `split` → `map` → `reduce` |
| Deduplicate entities | `map` → `unnest` → `resolve` |
## Operation Reference
### Map Operation
Applies an LLM transformation to each document independently.
```yaml
- name: extract_info
type: map
prompt: |
Analyze this document:
{{ input.text }}
Extract the main topic and 3 key points.
output:
schema:
topic: string
key_points: list[string]
model: gpt-5-nano # optional, uses default_model if not set
skip_on_error: true # recommended for large-scale runs
validate: # optional
- len(output["key_points"]) == 3
num_retries_on_validate_failure: 2 # optional
```
**Key parameters:**
- `prompt`: Jinja2 template, use `{{ input.field }}` to reference fields
- `output.schema`: Define output structure
- `skip_on_error`: Set `true` to continue on LLM errors (recommended at scale)
- `validate`: Python expressions to validate output
- `sample`: Process only N documents (for testing)
- `limit`: Stop after producing N outputs
### Filter Operation
Keeps or removes documents based on LLM criteria. Output schema must have exactly one boolean field.
```yaml
- name: filter_relevant
type: filter
skip_on_error: true
prompt: |
Document: {{ input.text }}
Is this document relevant to climate change?
Respond true or false.
output:
schema:
is_relevant: boolean
```
### Reduce Operation
Aggregates documents by a key using an LLM.
**Always include `fold_prompt` and `fold_batch_size`** for reduce operations. This handles cases where the group is too large to fit in context.
```yaml
- name: summarize_by_category
type: reduce
reduce_key: category # use "_all" to aggregate everything
skip_on_error: true
prompt: |
Summarize these {{ inputs | length }} items for category "{{ inputs[0].category }}":
{% for item in inputs %}
- {{ item.title }}: {{ item.description }}
{% endfor %}
Provide a 2-3 sentence summary of the key themes.
fold_prompt: |
You have a summary based on previous items, and new items to incorporate.
Previous summary (based on {{ output.item_count }} items):
{{ output.summary }}
New items ({{ inputs | length }} more):
{% for item in inputs %}
- {{ item.title }}: {{ item.description }}
{% endfor %}
Write a NEW summary that covers ALL items (previous + new).
IMPORTANT: Output a clean, standalone summary as if describing the entire dataset.
Do NOT mention "updated", "added", "new items", or reference the incremental process.
fold_batch_size: 100
output:
schema:
summary: string
item_count: int
validate:
- len(output["summary"].strip()) > 0
num_retries_on_validate_failure: 2
```
**Critical: Writing Good Fold Prompts**
The `fold_prompt` is called repeatedly as batches are processed. Its output must:
1. **Reflect ALL data seen so far**, not just the latest batch
2. **Be a clean, standalone output** - no "updated X" or "added Y items" language
3. **Match the same schema** as the initial `prompt` output
Bad fold_prompt output: "Added 50 new projects. The updated summary now includes..."
Good fold_prompt output: "Developers are building privacy-focused tools and local-first apps..."
**Estimating `fold_batch_size`** (a rough token-budget sketch follows this list):
- **Use 100+ for most cases** - larger batches = fewer LLM calls = lower cost
- For very long documents, reduce to 50-75
- For short documents (tweets, titles), can use 150-200
- Models like gpt-4o-mini have 128k context, so batch size is rarely the bottleneck
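A rough token-budget check, using the common ~4 characters per token approximation rather than a real tokenizer:
```python
import json

data = json.load(open("dataset.json"))

# Rough heuristic: ~4 characters per token (an approximation, not a real tokenizer).
avg_chars = sum(len(str(d)) for d in data) / len(data)
tokens_per_doc = avg_chars / 4

context_window = 128_000        # e.g. a 128k-context model; adjust for your model
budget = context_window * 0.5   # leave headroom for the prompt and the running summary

max_batch = int(budget // max(1, tokens_per_doc))
print(f"~{tokens_per_doc:.0f} tokens/doc -> fold_batch_size up to ~{max_batch}")
```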
**Key parameters:**
- `reduce_key`: Field to group by (or list of fields, or `_all`)
- `fold_prompt`: Template for incrementally adding items to existing output (required)
- `fold_batch_size`: Number of items per fold iteration (required, use 100+)
- `associative`: Set to `false` if order matters
### Split Operation
Divides long text into smaller chunks. No LLM call.
```yaml
- name: split_document
type: split
split_key: content
method: token_count # or "delimiter"
method_kwargs:
num_tokens: 500
model: gpt-5-nano
```
**Output adds:**
- `{split_key}_chunk`: The chunk content
- `{op_name}_id`: Original document ID
- `{op_name}_chunk_num`: Chunk number
### Unnest Operation
Flattens list fields into separate rows. No LLM call.
```yaml
- name: unnest_items
type: unnest
unnest_key: items # field containing the list
keep_empty: false # optional
```
**Example:** If a document has `items: ["a", "b", "c"]`, unnest creates 3 documents, each with `items: "a"`, `items: "b"`, `items: "c"`.
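For intuition only, the row expansion is equivalent to this plain-Python sketch (not the operator's actual implementation):
```python
def unnest(docs: list[dict], unnest_key: str, keep_empty: bool = False) -> list[dict]:
    """Expand each document into one row per element of doc[unnest_key]."""
    out = []
    for doc in docs:
        values = doc.get(unnest_key) or []
        if not values and keep_empty:
            out.append({**doc, unnest_key: None})
        for value in values:
            out.append({**doc, unnest_key: value})
    return out

print(unnest([{"id": 1, "items": ["a", "b", "c"]}], "items"))
# -> three documents: items="a", items="b", items="c"
```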
### Resolve Operation
Deduplicates and canonicalizes entities. Uses pairwise comparison.
```yaml
- name: dedupe_names
type: resolve
optimize: true # let optimizer find blocking rules
skip_on_error: true
comparison_prompt: |
Are these the same person?
Person 1: {{ input1.name }} ({{ input1.email }})
Person 2: {{ input2.name }} ({{ input2.email }})
Respond true or false.
resolution_prompt: |
Standardize this person's name:
{% for entry in inputs %}
- {{ entry.name }}
{% endfor %}
Return the canonical name.
output:
schema:
name: string
```
**Important:** Set `optimize: true` and run `docetl build` to generate efficient blocking rules. Without blocking, this is O(n²).
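To see why blocking matters, note that pairwise comparison grows quadratically with the number of entities:
```python
# Without blocking, resolve compares every pair of entities: n * (n - 1) / 2 LLM calls.
for n in (100, 1_000, 10_000):
    pairs = n * (n - 1) // 2
    print(f"{n:>6} entities -> {pairs:,} pairwise comparisons")
# 10,000 entities already means ~50 million comparisons, hence blocking via `optimize: true`.
```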
### Code Operations
Deterministic Python transformations without LLM calls.
**code_map:**
```yaml
- name: compute_stats
type: code_map
code: |
def transform(doc) -> dict:
return {
"word_count": len(doc["text"].split()),
"char_count": len(doc["text"])
}
```
**code_reduce:**
```yaml
- name: aggregate
type: code_reduce
reduce_key: category
code: |
def transform(items) -> dict:
total = sum(item["value"] for item in items)
return {"total": total, "count": len(items)}
```
**code_filter:**
```yaml
- name: filter_long
type: code_filter
code: |
def transform(doc) -> bool:
return len(doc["text"]) > 100
```
### Retrievers (LanceDB)
Augment LLM operations with retrieved context from a LanceDB index. Useful for:
- Finding related documents to compare against
- Providing additional context for extraction/classification
- Cross-referencing facts across a dataset
**Define a retriever:**
```yaml
retrievers:
facts_index:
type: lancedb
dataset: extracted_facts # dataset to index
index_dir: workloads/wiki/lance_index
build_index: if_missing # if_missing | always | never
index_types: ["fts", "embedding"] # or "hybrid"
fts:
index_phrase: "{{ input.fact }}: {{ input.source }}"
query_phrase: "{{ input.fact }}"
embedding:
model: openai/text-embedding-3-small
index_phrase: "{{ input.fact }}"
query_phrase: "{{ input.fact }}"
query:
mode: hybrid
top_k: 5
```
**Use in operations:**
```yaml
- name: find_conflicts
type: map
retriever: facts_index
prompt: |
Check if this fact conflicts with any retrieved facts:
Current fact: {{ input.fact }} (from {{ input.source }})
Related facts from other articles:
{{ retrieval_context }}
Return whether there's a genuine conflict.
output:
schema:
has_conflict: boolean
```
**Key points:**
- `{{ retrieval_context }}` is injected into prompts automatically
- Index is built on first use (when `build_index: if_missing`)
- Supports full-text (`fts`), vector (`embedding`), or `hybrid` search
- Use `save_retriever_output: true` to debug what was retrieved
- **Can index intermediate outputs**: Retriever can index the output of a previous pipeline step, enabling patterns like "extract facts → index facts → retrieve similar facts for each"
## Documentation Reference
For detailed parameters, advanced features, and more examples, read the docs:
- **Operations**: `docs/operators/` folder (map.md, reduce.md, filter.md, etc.)
- **Concepts**: `docs/concepts/` folder (pipelines.md, operators.md, schemas.md)
- **Examples**: `docs/examples/` folder
- **Optimization**: `docs/optimization/` folder
## Step 6: Environment Setup
Before running, verify API keys exist:
```bash
# Check for .env file
cat .env
```
Required keys depend on the model:
- OpenAI: `OPENAI_API_KEY`
- Anthropic: `ANTHROPIC_API_KEY`
- Google: `GEMINI_API_KEY`
If missing, prompt user to create `.env`:
```
OPENAI_API_KEY=sk-...
```
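A small pre-flight check you can run before executing the pipeline (DocETL itself uses `python-dotenv`, so it should already be available; adjust the key list to your providers):
```python
import os
from dotenv import load_dotenv

load_dotenv()  # reads .env from the current directory if present

required = ["OPENAI_API_KEY"]  # adjust to the providers your pipeline uses
missing = [key for key in required if not os.getenv(key)]
if missing:
    raise SystemExit(f"Missing API keys in .env: {', '.join(missing)}")
print("All required API keys found.")
```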
## Step 7: Execution
**Always test on a sample first, then run full pipeline.**
### Test Run (Required)
Add `sample: 10-20` to your first operation, then run:
```bash
docetl run pipeline.yaml
```
**Inspect the test results before proceeding:**
```python
import json
from collections import Counter
# Load intermediate results
data = json.load(open("intermediates/step_name/operation_name.json"))
print(f"Processed: {len(data)} docs")
# Check distributions
if "domain" in data[0]:
print("Domain distribution:")
for k, v in Counter(d["domain"] for d in data).most_common():
print(f" {k}: {v}")
# Show sample outputs
print("\nSample output:")
print(json.dumps(data[0], indent=2))
```
### Full Run
Once test results look good:
1. Remove the `sample` parameter from the pipeline
2. Ask the user for permission (estimate the full-run cost by scaling up the test-run cost, as in the sketch below)
3. Run full pipeline
4. **Show final results** - distributions, key insights, trends
Options:
- `--max_threads N` - Control parallelism
Check intermediate results in the `intermediate_dir` folder to debug each step.
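One way to estimate the full-run cost is to scale the test-run cost (reported when the test run finishes) by the ratio of total documents to sampled documents. The numbers below are hypothetical placeholders, and linear scaling ignores document-length variance, so treat it as a ballpark:
```python
import json

total_docs = len(json.load(open("dataset.json")))
sample_size = 20        # whatever you used for `sample:` in the test run
test_run_cost = 0.42    # hypothetical: cost reported at the end of the test run, in dollars

estimated_full_cost = test_run_cost * (total_docs / sample_size)
print(f"Estimated full-run cost: ~${estimated_full_cost:.2f} for {total_docs} documents")
```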
## Step 8: Optimization (Optional)
Use MOAR optimizer to find the Pareto frontier of **cost vs. accuracy** tradeoffs. MOAR experiments with different pipeline rewrites and models to find optimal configurations.
Add to pipeline YAML:
```yaml
optimizer_config:
type: moar
save_dir: ./optimization_results
available_models:
- gpt-5-nano
- gpt-4o-mini
- gpt-4o
evaluation_file: evaluate.py # User must provide
metric_key: score
max_iterations: 20
model: gpt-5-nano
```
Create evaluation file (`evaluate.py`):
```python
def evaluate(outputs: list[dict]) -> dict:
    # Score the outputs (0-1 scale recommended).
    # is_correct is task-specific; this hypothetical check just requires a non-empty "summary".
    def is_correct(output: dict) -> bool:
        return bool(str(output.get("summary", "")).strip())
    correct = sum(1 for o in outputs if is_correct(o))
    return {"score": correct / len(outputs) if outputs else 0.0}
```
Run optimization:
```bash
docetl build pipeline.yaml --optimizer moar
```
MOAR produces multiple pipeline variants along the Pareto frontier; the user can then choose one based on their cost/accuracy preferences.
## Output Schemas
**Keep schemas minimal and simple** unless the user explicitly requests more fields. Default to 1-3 output fields per operation. Only add more fields if the user specifically asks for them.
**Nesting limit:** Maximum 2 levels deep (e.g., `list[{field: str}]` is allowed, but no deeper).
```yaml
# Good - minimal, focused on the core task
output:
schema:
summary: string
# Good - a few fields when task requires it
output:
schema:
topic: string
keywords: list[string]
# Acceptable - 2 levels of nesting (list of objects)
output:
schema:
items: "list[{name: str, value: int}]"
# Bad - too many fields (unless user explicitly requested all of these)
output:
schema:
conflicts_found: bool
num_conflicts: int
conflicts: "list[{claim_a: str, source_a: str, claim_b: str, source_b: str}]"
analysis_summary: str
# Bad - more than 2 levels of nesting (not supported)
output:
schema:
data: "list[{nested: {too: {deep: str}}}]"
```
**Guidelines:**
- Start with the minimum fields needed to answer the user's question
- Avoid complex nested objects unless explicitly requested
- If you need structured data, prefer multiple simple operations over one complex schema
- Complex schemas increase LLM failures and costs
Supported types: `string`, `int`, `float`, `bool`, `list[type]`, `enum`
## Validation
**Always add validation to LLM-powered operations** (map, reduce, filter, resolve). Validation catches malformed outputs and retries automatically.
```yaml
- name: extract_keywords
type: map
prompt: |
Extract 3-5 keywords from: {{ input.text }}
output:
schema:
keywords: list[string]
validate:
- len(output["keywords"]) >= 3
- len(output["keywords"]) <= 5
num_retries_on_validate_failure: 2
```
Common validation patterns:
```yaml
# List length constraints
- len(output["items"]) >= 1
- len(output["items"]) <= 10
# Enum/allowed values
- output["sentiment"] in ["positive", "negative", "neutral"]
# String not empty
- len(output["summary"].strip()) > 0
# Numeric ranges
- output["score"] >= 0
- output["score"] <= 100
```
## Jinja2 Templating
For map operations, use `input`:
```yaml
prompt: |
Document: {{ input.text }}
{% if input.metadata %}
Context: {{ input.metadata }}
{% endif %}
```
For reduce operations, use `inputs` (list):
```yaml
prompt: |
Summarize these {{ inputs | length }} items:
{% for item in inputs %}
- {{ item.summary }}
{% endfor %}
```
## Troubleshooting
### Pipeline won't run
- Check `.env` has correct API keys
- Verify dataset file exists and is valid JSON/CSV
- Check YAML syntax
### Bad outputs
- Read more input data examples to improve prompt specificity
- Add `validate` rules with retries
- Simplify output schema
- Add concrete examples to prompt
### High costs
- Use `gpt-5-nano` or `gpt-4o-mini`
- Add `sample: 10` to test on subset first
- Run MOAR optimizer to find cost-efficient rewrites
### Check intermediate results
Look in `intermediate_dir` folder to debug each step.
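A quick sketch for scanning the intermediates (assumes `intermediate_dir: intermediates` as in the pipeline template above):
```python
import json
from pathlib import Path

# One JSON file per operation, grouped by pipeline step.
for path in sorted(Path("intermediates").rglob("*.json")):
    try:
        records = json.load(open(path))
        keys = list(records[0].keys()) if records else []
        print(f"{path}: {len(records)} records, keys = {keys}")
    except (json.JSONDecodeError, OSError) as e:
        print(f"{path}: could not read ({e})")
```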
## Quick Reference
```bash
# Run pipeline
docetl run pipeline.yaml
# Run with more parallelism
docetl run pipeline.yaml --max_threads 16
# Optimize pipeline (cost/accuracy tradeoff)
docetl build pipeline.yaml --optimizer moar
# Clear LLM cache
docetl clear-cache
# Check version
docetl version
```
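If you prefer driving DocETL from Python rather than the CLI, a minimal sketch using `DSLRunner`; the load/execute/save flow mirrors how DocETL runs plans internally, but treat it as a sketch rather than a stable public API:
```python
from dotenv import load_dotenv
from docetl.runner import DSLRunner

load_dotenv()  # pick up API keys from .env

runner = DSLRunner.from_yaml("pipeline.yaml", max_threads=16)
runner.load()  # load datasets
if runner.last_op_container:
    results, _, _ = runner.last_op_container.next()  # execute the pipeline
    runner.save(results)                             # write the configured output file
print(f"Total cost: ${runner.total_cost:.4f}")
runner.reset_env()
```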

.gitignore (vendored): 21 changed lines

@ -57,7 +57,26 @@ website/next-env.d.ts
# experiments
experiments/*.json
experiments/reasoning/data/*.json
metrics_vs_cost.png
tests/data/anthropic-red-team-attempts.jsonl
tests/data/get_freshstack.py
tests/data/get_freshstack.py
!experiments/reasoning/data/operators_documentation.txt
*.txt
*.yaml
*.yml
*.json
!*lotus_evaluation.json
!*pz_evaluation.json
!*pz_*_evaluation.json
*.png
graph.py
graph_baseline.py
kendalltau.py
slides/*
node_modules/*
reaction_index/
*experiments/reasoning/othersystems/biodex/.chroma-biodex/

View File

@ -13,7 +13,8 @@ DocETL is a tool for creating and executing data processing pipelines, especiall
2. A Python package for running production pipelines from the command line or Python code
> 💡 **Need Help Writing Your Pipeline?**
> Want to use an LLM like ChatGPT or Claude to help you write your pipeline? See [docetl.org/llms.txt](https://docetl.org/llms.txt) for a big prompt you can copy paste into ChatGPT or Claude, before describing your task.
> You can use **Claude Code** (recommended) to help you write your pipeline—see the quickstart: https://ucbepic.github.io/docetl/quickstart-claude-code/
> If you'd rather use ChatGPT or the Claude app, see [docetl.org/llms.txt](https://docetl.org/llms.txt) for a big prompt you can copy/paste before describing your task.
### 🌟 Community Projects

View File

@ -1,11 +1,15 @@
__version__ = "0.2.5"
__version__ = "0.2.6"
import warnings
import litellm
from docetl.runner import DSLRunner
from docetl.optimizer import Optimizer
from docetl.apis.pd_accessors import SemanticAccessor
# Drop unsupported params for models like gpt-5 that don't support temperature=0
litellm.drop_params = True
# TODO: Remove after https://github.com/BerriAI/litellm/issues/7560 is fixed
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._config")

View File

@ -59,8 +59,12 @@ from rich import print
from docetl.runner import DSLRunner
from docetl.schemas import (
ClusterOp,
CodeFilterOp,
CodeMapOp,
CodeReduceOp,
Dataset,
EquijoinOp,
ExtractOp,
FilterOp,
GatherOp,
MapOp,
@ -340,6 +344,14 @@ class Pipeline:
self.operations.append(ClusterOp(**op, type=op_type))
elif op_type == "sample":
self.operations.append(SampleOp(**op, type=op_type))
elif op_type == "code_map":
self.operations.append(CodeMapOp(**op, type=op_type))
elif op_type == "code_reduce":
self.operations.append(CodeReduceOp(**op, type=op_type))
elif op_type == "code_filter":
self.operations.append(CodeFilterOp(**op, type=op_type))
elif op_type == "extract":
self.operations.append(ExtractOp(**op, type=op_type))
self.steps = [PipelineStep(**step) for step in config["pipeline"]["steps"]]
self.output = PipelineOutput(**config["pipeline"]["output"])
self.default_model = config.get("default_model")
@ -363,6 +375,10 @@ __all__ = [
"SplitOp",
"GatherOp",
"UnnestOp",
"CodeMapOp",
"CodeReduceOp",
"CodeFilterOp",
"ExtractOp",
"PipelineStep",
"PipelineOutput",
"ParsingTool",

View File

@ -673,6 +673,68 @@ Record 2: {record_template.replace('input0', 'input2')}"""
return self._record_operation(results, "reduce", reduce_config, reduce_cost)
def reduce(
self,
prompt: str,
output: dict[str, Any] | None = None,
*,
output_schema: dict[str, Any] | None = None,
reduce_keys: str | list[str] = ["_all"],
**kwargs,
) -> pd.DataFrame:
"""
Reduce/aggregate all rows using a language model.
This is a simplified wrapper around the agg() method for common reduce operations.
Documentation: https://ucbepic.github.io/docetl/operators/reduce/
Args:
prompt: Jinja template string for the reduction prompt. Use {% for item in inputs %}
to iterate over input rows.
output: Output configuration with keys:
- "schema": Dictionary defining the expected output structure and types
- "mode": Optional output mode. Either "tools" (default) or "structured_output"
output_schema: DEPRECATED. Use 'output' parameter instead.
reduce_keys: Keys to group by for reduction (default: ["_all"] for all rows)
**kwargs: Additional configuration options passed to agg():
- model: LLM model to use
- validate: List of validation expressions
- num_retries_on_validate_failure: Number of retries
- timeout: Timeout in seconds (default: 120)
- max_retries_per_timeout: Max retries per timeout (default: 2)
Returns:
pd.DataFrame: Aggregated DataFrame with columns matching output['schema']
Examples:
>>> # Summarize all texts into one summary
>>> df.semantic.reduce(
... prompt=\"\"\"Summarize the following items:
... {% for item in inputs %}
... - {{ item.text }}
... {% endfor %}\"\"\",
... output={"schema": {"summary": "str"}}
... )
>>> # Reduce by group
>>> df.semantic.reduce(
... prompt=\"\"\"Combine items for {{ reduce_key }}:
... {% for item in inputs %}
... - {{ item.description }}
... {% endfor %}\"\"\",
... output={"schema": {"combined": "str"}},
... reduce_keys=["category"]
... )
"""
return self.agg(
reduce_prompt=prompt,
output=output,
output_schema=output_schema,
reduce_keys=reduce_keys,
reduce_kwargs=kwargs,
)
def filter(
self,
prompt: str,

View File

@ -3,10 +3,15 @@ from pathlib import Path
import typer
from dotenv import load_dotenv
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from docetl.operations.utils import clear_cache as cc
from docetl.runner import DSLRunner
console = Console(stderr=True)
app = typer.Typer(pretty_exceptions_enable=False)
@ -15,6 +20,12 @@ def build(
yaml_file: Path = typer.Argument(
..., help="Path to the YAML file containing the pipeline configuration"
),
optimizer: str = typer.Option(
"moar",
"--optimizer",
"-o",
help="Optimizer to use: 'moar' (default) or 'v1' (deprecated)",
),
max_threads: int | None = typer.Option(
None, help="Maximum number of threads to use for running operations"
),
@ -31,8 +42,8 @@ def build(
Args:
yaml_file (Path): Path to the YAML file containing the pipeline configuration.
optimizer (str): Optimizer to use - 'moar' or 'v1' (required).
max_threads (int | None): Maximum number of threads to use for running operations.
model (str): Model to use for optimization. Defaults to "gpt-4o".
resume (bool): Whether to resume optimization from a previous run. Defaults to False.
save_path (Path): Path to save the optimized pipeline configuration.
"""
@ -44,13 +55,147 @@ def build(
if os.path.exists(env_file):
load_dotenv(env_file)
runner = DSLRunner.from_yaml(str(yaml_file), max_threads=max_threads)
runner.optimize(
save=True,
return_pipeline=False,
resume=resume,
save_path=save_path,
)
# Validate optimizer choice
if optimizer not in ["moar", "v1"]:
typer.echo(
f"Error: optimizer must be 'moar' or 'v1', got '{optimizer}'", err=True
)
raise typer.Exit(1)
# Load YAML to check for optimizer_config
import yaml as yaml_lib
with open(yaml_file, "r") as f:
config = yaml_lib.safe_load(f)
if optimizer == "moar":
optimizer_config = config.get("optimizer_config", {})
if not optimizer_config:
example_yaml = """optimizer_config:
type: moar
save_dir: ./moar_results
available_models:
- gpt-5
- gpt-4o
evaluation_file: workloads/medical/evaluate_medications.py
metric_key: medication_extraction_score
max_iterations: 40
model: gpt-5"""
error_panel = Panel(
f"[bold red]Error:[/bold red] optimizer_config section is required in YAML for MOAR optimizer.\n\n"
f"[bold]Example:[/bold]\n"
f"[dim]{example_yaml}[/dim]\n\n"
f"[yellow]Note:[/yellow] dataset_name is inferred from the 'datasets' section. "
f"dataset_path can optionally be specified in optimizer_config, otherwise it's inferred from the 'datasets' section.",
title="[bold red]Missing optimizer_config[/bold red]",
border_style="red",
)
console.print(error_panel)
raise typer.Exit(1)
if optimizer_config.get("type") != "moar":
error_panel = Panel(
f"[bold red]Error:[/bold red] optimizer_config.type must be 'moar', got '[yellow]{optimizer_config.get('type')}[/yellow]'",
title="[bold red]Invalid optimizer type[/bold red]",
border_style="red",
)
console.print(error_panel)
raise typer.Exit(1)
# Validate required fields in optimizer_config
required_fields = {
"save_dir": "Output directory for MOAR results",
"available_models": "List of model names to use",
"evaluation_file": "Path to evaluation function file",
"metric_key": "Key to extract from evaluation results",
"max_iterations": "Number of MOARSearch iterations",
"model": "LLM model name for directive instantiation",
}
missing_fields = [
field for field in required_fields if not optimizer_config.get(field)
]
if missing_fields:
# Create a table for required fields
fields_table = Table(
show_header=True, header_style="bold cyan", box=None, padding=(0, 2)
)
fields_table.add_column("Field", style="yellow")
fields_table.add_column("Description", style="dim")
for field, desc in required_fields.items():
style = "bold red" if field in missing_fields else "dim"
fields_table.add_row(f"[{style}]{field}[/{style}]", desc)
# Create example YAML
example_yaml = """optimizer_config:
type: moar
save_dir: ./moar_results
available_models:
- gpt-5
- gpt-4o
evaluation_file: workloads/medical/evaluate_medications.py
metric_key: medication_extraction_score
max_iterations: 40
model: gpt-5"""
missing_list = ", ".join(
[f"[bold red]{f}[/bold red]" for f in missing_fields]
)
# Build error content with table rendered separately
from rich.console import Group
error_group = Group(
f"[bold red]Missing required fields:[/bold red] {missing_list}\n",
"[bold]Required fields:[/bold]",
fields_table,
f"\n[bold]Example:[/bold]\n[dim]{example_yaml}[/dim]\n",
"[yellow]Note:[/yellow] dataset_name is inferred from the 'datasets' section. "
"dataset_path can optionally be specified in optimizer_config, otherwise it's inferred from the 'datasets' section.",
)
error_panel = Panel(
error_group,
title="[bold red]Missing Required Fields[/bold red]",
border_style="red",
)
console.print(error_panel)
raise typer.Exit(1)
# Run MOAR optimization
from docetl.moar.cli_helpers import run_moar_optimization
try:
results = run_moar_optimization(
yaml_path=str(yaml_file),
optimizer_config=optimizer_config,
)
typer.echo("\n✅ MOAR optimization completed successfully!")
typer.echo(f" Results saved to: {optimizer_config.get('save_dir')}")
if results.get("evaluation_file"):
typer.echo(f" Evaluation: {results['evaluation_file']}")
except Exception as e:
typer.echo(f"Error running MOAR optimization: {e}", err=True)
raise typer.Exit(1)
else: # v1 optimizer (deprecated)
console.print(
Panel(
"[bold yellow]Warning:[/bold yellow] The V1 optimizer is deprecated. "
"Please use MOAR optimizer instead: [bold]docetl build pipeline.yaml --optimizer moar[/bold]",
title="[bold yellow]Deprecated Optimizer[/bold yellow]",
border_style="yellow",
)
)
runner = DSLRunner.from_yaml(str(yaml_file), max_threads=max_threads)
runner.optimize(
save=True,
return_pipeline=False,
resume=resume,
save_path=save_path,
)
@app.command()
@ -99,5 +244,104 @@ def version():
typer.echo(f"DocETL version: {docetl.__version__}")
@app.command("install-skill")
def install_skill(
uninstall: bool = typer.Option(
False, "--uninstall", "-u", help="Remove the installed skill instead"
),
):
"""
Install the DocETL Claude Code skill to your personal skills directory.
This makes the DocETL skill available in Claude Code for any project.
The skill helps you build and run DocETL pipelines.
"""
import shutil
# Find the skill source - try multiple locations
# 1. Installed package location (via importlib.resources)
# 2. Development location (relative to this file)
skill_source = None
# Try to find via package resources first
try:
import importlib.resources as pkg_resources
# For Python 3.9+, use files()
try:
package_root = Path(pkg_resources.files("docetl")).parent
potential_source = package_root / ".claude" / "skills" / "docetl"
if potential_source.exists():
skill_source = potential_source
except (TypeError, AttributeError):
pass
except ImportError:
pass
# Fallback: try relative to this file (development mode)
if skill_source is None:
dev_source = Path(__file__).parent.parent / ".claude" / "skills" / "docetl"
if dev_source.exists():
skill_source = dev_source
if skill_source is None or not skill_source.exists():
console.print(
Panel(
"[bold red]Error:[/bold red] Could not find the DocETL skill files.\n\n"
"This may happen if the package was not installed correctly.\n"
"Try reinstalling: [bold]pip install --force-reinstall docetl[/bold]",
title="[bold red]Skill Not Found[/bold red]",
border_style="red",
)
)
raise typer.Exit(1)
# Target directory
skill_target = Path.home() / ".claude" / "skills" / "docetl"
if uninstall:
if skill_target.exists():
shutil.rmtree(skill_target)
console.print(
Panel(
f"[bold green]Success![/bold green] DocETL skill removed from:\n"
f"[dim]{skill_target}[/dim]",
title="[bold green]Skill Uninstalled[/bold green]",
border_style="green",
)
)
else:
console.print(
Panel(
"[yellow]The DocETL skill is not currently installed.[/yellow]",
title="[yellow]Nothing to Uninstall[/yellow]",
border_style="yellow",
)
)
return
# Create parent directories if needed
skill_target.parent.mkdir(parents=True, exist_ok=True)
# Copy the skill
if skill_target.exists():
shutil.rmtree(skill_target)
shutil.copytree(skill_source, skill_target)
console.print(
Panel(
f"[bold green]Success![/bold green] DocETL skill installed to:\n"
f"[dim]{skill_target}[/dim]\n\n"
"[bold]Next steps:[/bold]\n"
"1. Restart Claude Code if it's running\n"
"2. The skill will automatically activate when you work on DocETL tasks\n\n"
"[dim]To uninstall: docetl install-skill --uninstall[/dim]",
title="[bold green]Skill Installed[/bold green]",
border_style="green",
)
)
if __name__ == "__main__":
app()

View File

@ -1,4 +1,5 @@
import os
import re
import threading
import time
from io import StringIO
@ -12,6 +13,63 @@ from docetl.utils import StageType, get_stage_description
install(show_locals=False)
# ANSI escape sequences to strip (cursor movement, line clearing, cursor visibility)
ANSI_CURSOR_PATTERNS = re.compile(
r"\x1b\[\?25[lh]" # Hide/show cursor
r"|\x1b\[\d*[ABCD]" # Cursor up/down/forward/back
r"|\x1b\[\d*[JK]" # Clear screen/line (but we handle \x1b[2K specially)
r"|\x1b\[s|\x1b\[u" # Save/restore cursor position
)
def process_carriage_returns(text: str) -> str:
"""
Process terminal control sequences for buffer-captured output.
Rich's spinner uses ANSI sequences for:
- `\r` - carriage return (move to start of line)
- `\x1b[2K` - clear entire line
- `\x1b[?25l/h` - hide/show cursor
- `\x1b[1A` - move cursor up
When captured to a buffer (not a real terminal), we simulate this by:
1. Handling `\r` as "replace from start of line"
2. Stripping cursor movement/visibility sequences
3. Keeping only meaningful content
"""
# First, strip cursor visibility and movement sequences
text = ANSI_CURSOR_PATTERNS.sub("", text)
# Process line by line
lines = text.split("\n")
processed_lines = []
for line in lines:
# Handle carriage returns - keep only content after the last \r
if "\r" in line:
segments = line.split("\r")
# Keep the last non-empty segment (after stripping ANSI clear codes)
for segment in reversed(segments):
# Strip the clear line code if present
cleaned = segment.replace("\x1b[2K", "").strip()
if cleaned:
# Keep the segment but preserve any color codes
processed_lines.append(segment.replace("\x1b[2K", ""))
break
else:
# All segments were empty, skip this line entirely
pass
else:
# No carriage return, keep the line as-is
if line.strip() or line == "": # Keep empty lines between content
processed_lines.append(line)
# Remove trailing empty lines
while processed_lines and not processed_lines[-1].strip():
processed_lines.pop()
return "\n".join(processed_lines)
class ThreadSafeConsole(Console):
def __init__(self, *args, **kwargs):
@ -24,11 +82,18 @@ class ThreadSafeConsole(Console):
self.optimizer_rationale = None
def get_output(self):
# return self.export_text(styles=True)
"""
Get the output from the buffer, processing carriage returns.
Rich's spinner uses carriage returns to overwrite lines in place.
We process these to only keep the latest content, preventing
duplicate spinner frames from flooding the output.
"""
value = self.buffer.getvalue()
self.buffer.truncate(0)
self.buffer.seek(0)
return value
# Process carriage returns to handle spinner overwrites
return process_carriage_returns(value)
def status(
self,
@ -36,10 +101,15 @@ class ThreadSafeConsole(Console):
*,
spinner: str = "dots",
spinner_style: "StyleType" = "status.spinner",
speed: float = 0.1, # Much slower speed
refresh_per_second: float = 0.5, # Much slower refresh rate (every 2 seconds)
speed: float = 1.0,
refresh_per_second: float = 4,
) -> "Status":
"""
Return a Rich Status with animation.
The carriage returns from the spinner animation are processed
in get_output() to prevent duplicate lines.
"""
status_renderable = Status(
status,
console=self,

View File

@ -465,7 +465,8 @@ class OpContainer:
return cached_data, 0, curr_logs
# Try to load from checkpoint if available
if not is_build:
# Skip if this operation has bypass_cache: true
if not is_build and not self.config.get("bypass_cache", False):
attempted_input_data = self.runner._load_from_checkpoint_if_exists(
self.name.split("/")[0], self.name.split("/")[-1]
)

docetl/moar/MOARSearch.py (new file, 1363 lines)

File diff suppressed because it is too large.

docetl/moar/Node.py (new file, 735 lines)

@ -0,0 +1,735 @@
from __future__ import annotations
import math
import os
import random
from datetime import datetime
from typing import Any, Dict, Optional
import yaml
from dotenv import load_dotenv
from docetl.reasoning_optimizer.directives import Directive
from docetl.runner import DSLRunner
from docetl.utils import extract_output_from_json
class Node:
"""
A Node class for Monte Carlo Tree Search that represents a state in the search tree.
Each node holds:
- YAML file path and parsed content
- Visit count and value for UCB calculation
- Parent and children relationships
- Methods for tree traversal and expansion
"""
# A class-level counter for unique IDs
_id_counter = 0
@classmethod
def get_next_id(cls) -> int:
"""
Get the next available ID from the counter without incrementing it.
Returns:
int: The next ID that would be assigned
"""
return cls._id_counter
@classmethod
def increment_id_counter(cls) -> int:
"""
Increment the ID counter and return the new ID.
Returns:
int: The newly assigned ID
"""
new_id = cls._id_counter
cls._id_counter += 1
return new_id
def __init__(
self,
yaml_file_path: str,
parent: Optional[Node] = None,
c: float = 1.414,
message_history=[],
id: Optional[int] = None,
is_multi_instance: bool = False,
console=None,
):
"""
Initialize a Node with YAML file information.
Args:
yaml_file_path: Path to the YAML configuration file
parent: Parent node in the search tree
c: Exploration constant for UCB calculation (default: sqrt(2))
is_multi_instance: Whether this node is a multi-instance candidate (default: False)
console: Console instance for logging (default: None, uses DOCETL_CONSOLE)
"""
from docetl.console import DOCETL_CONSOLE
self.console = console if console is not None else DOCETL_CONSOLE
self.yaml_file_path = yaml_file_path
self.parsed_yaml = self._load_yaml()
# Where the JSON results will be written (if defined in the YAML). This is
# useful later for evaluation without having to guess filenames.
try:
self.result_path: str | None = (
self.parsed_yaml.get("pipeline", {}).get("output", {}).get("path")
)
except Exception:
self.result_path = None
self.on_frontier = False
self.used_actions = {}
self.op_dict = {} # Dict: op_name -> op
for op in self.parsed_yaml["operations"]:
op_name = op["name"]
self.op_dict[op_name] = op
self.used_actions[op_name] = set()
self.visits = 0
self.value = 0
self.parent = parent
self.children = []
self.c = c # Exploration constant for UCB
self.cost = -1.0
self.scaled_cost = -1.0 # Scaled cost in [0,1] range for reward calculations
self.sample_result = []
self.latest_action = None # Latest action that led to this node
self.optimization_goal = (
None # The optimization goal used for this node (e.g., 'acc', 'cost')
)
# Message history from root to this node (accumulated LLM conversations)
self.message_history = message_history
# Memo list to track (directive, target_operator) pairs from root to this node
self.memo = []
# Flag to indicate if this is a multi-instance candidate
self.is_multi_instance = is_multi_instance
# Assign a unique ID to this node
if id:
self.id = id
else:
self.id = Node._id_counter
Node._id_counter += 1
def execute_plan(self, max_threads: Optional[int] = None) -> tuple[float, list]:
"""
This method execute the query plan by running the YAML file with docetl.
Args:
max_threads (Optional[int]): Maximum number of threads to use for running operations.
Returns:
tuple[float, list]: A tuple containing (total_cost, result_data)
Raises:
Exception: If the pipeline execution fails.
"""
self.console.log(f"[dim]EXECUTING PLAN:[/dim] {self.yaml_file_path}")
# Get the current working directory (where the user called the command)
cwd = os.getcwd()
# Load .env file from the current working directory if it exists
env_file = os.path.join(cwd, ".env")
if os.path.exists(env_file):
load_dotenv(env_file)
try:
runner = DSLRunner.from_yaml(self.yaml_file_path, max_threads=max_threads)
# Print the query plan
runner.print_query_plan()
# Load datasets and execute the pipeline
runner.load()
# Execute the pipeline and get the result data
if runner.last_op_container:
result_data, _, _ = runner.last_op_container.next()
runner.save(result_data)
else:
result_data = []
# Get the total cost
total_cost = runner.total_cost
# Reset the environment
runner.reset_env()
self.cost = total_cost
try:
self.sample_result = extract_output_from_json(self.yaml_file_path)[:1]
except Exception as e:
self.console.log(
f"[yellow]Error extracting output from JSON for {self.yaml_file_path}: {e}[/yellow]"
)
self.sample_result = []
return total_cost
except Exception as e:
self.cost = -1 # Indicate failure
self.value = -float("inf")
# Log -inf occurrence for debugging
self._log_inf_occurrence("execution_failure", str(e), self.yaml_file_path)
raise Exception(f"Failed to execute plan {self.yaml_file_path}: {str(e)}")
def _load_yaml(self) -> Dict[str, Any]:
"""
Load and parse the YAML file.
Returns:
Parsed YAML content as a dictionary
"""
try:
with open(self.yaml_file_path, "r", encoding="utf-8") as file:
return yaml.safe_load(file)
except Exception as e:
self.console.log(
f"[yellow]Error loading YAML file {self.yaml_file_path}: {e}[/yellow]"
)
return {}
def best_child(self) -> Node:
"""
Return the child with the highest UCB (Upper Confidence Bound) value.
If there are ties, randomly select among the tied children.
UCB formula: value/visits + c * sqrt(ln(parent_visits) / visits)
Returns:
Child node with highest UCB, or None if no children exist
"""
def ucb(child: Node) -> float:
if child.cost == -1 or child.visits == 0:
return float("-inf")
exploitation = child.value / child.visits
exploration = self.c * math.sqrt(math.log(self.visits) / child.visits)
return exploitation + exploration
# Print visits and value for each child
for child in self.children:
self.console.log(
f"[dim]Child {child.yaml_file_path}: visits = {child.visits}, value = {child.value}[/dim]"
)
# Calculate UCB values for all children
ucb_values = [(child, ucb(child)) for child in self.children]
# Find the maximum UCB value
max_ucb = max(ucb_values, key=lambda x: x[1])[1]
# Find all children with the maximum UCB value (ties)
tied_children = [child for child, ucb_val in ucb_values if ucb_val == max_ucb]
# Randomly select among tied children
return random.choice(tied_children)
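# Illustrative UCB calculation (numbers are made up, not from any real run): for a
# child with value=3.0, visits=4, parent visits=20 and c=1.414,
#   UCB = 3.0/4 + 1.414 * sqrt(ln(20)/4) ≈ 0.75 + 1.22 ≈ 1.97
# so a rarely-visited child with a decent average value can still win selection.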
def add_child(self, child: Node):
"""
Add a child node during tree expansion and set this node as its parent.
Args:
child: The child node to attach to this node
"""
self.children.append(child)
child.parent = self
def is_leaf(self) -> bool:
"""
Check if this node is a leaf (has no children).
Returns:
True if the node has no children, False otherwise
"""
return len(self.children) == 0
def mark_action_used(self, op_name, action: Directive):
"""
Mark a rewrite action as used for a given operator.
Args:
op_name: Name of the operator the action was applied to
action: The directive to mark as used
"""
self.used_actions[op_name].add(action)
def is_root(self) -> bool:
"""
Check if this node is the root (has no parent).
Returns:
True if the node has no parent, False otherwise
"""
return self.parent is None
def update_value(self, value: float):
"""
Update the node's value (typically after a simulation).
Args:
value: The value to add to the current node value
"""
# Guard against NaN and -inf values to prevent corruption of node.value
# Don't backpropagate -inf (failed evaluations) or NaN to parent nodes
if (
value is None
or (isinstance(value, float) and (value != value))
or value == float("-inf")
): # NaN or -inf check
self.console.log(
f"[yellow]⚠️ Skipping backpropagation of -inf / NaN value to node {self.get_id()}[/yellow]"
)
# Log -inf occurrence for debugging
self._log_inf_occurrence(
"backpropagation_skipped",
f"Skipped backpropagation of value: {value}",
self.yaml_file_path,
)
return
self.value = self.value + value
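# Note: `value != value` above is the standard float NaN test (NaN never equals
# itself); math.isnan(value) would be an equivalent, more explicit spelling.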
def update_visit(self):
"""
Increment the node's visit count by 1 (typically after a simulation).
"""
self.visits += 1
def get_ucb(self) -> float:
"""
Calculate the UCB value for this node.
Returns:
UCB value for this node
"""
if self.visits == 0:
return float("inf")
if self.parent is None:
return self.value / self.visits
exploitation = self.value / self.visits
exploration = self.c * math.sqrt(math.log(self.parent.visits) / self.visits)
return exploitation + exploration
def get_id(self) -> int:
"""
Return the unique identifier for this node.
Returns:
int: The unique ID of the node
"""
return self.id
def set_id_to_counter(self):
"""
Change this node's ID to the next available counter ID.
This is used after selecting the best multi-instance candidate.
Also renames the associated files to match the new ID.
Returns:
int: The new ID assigned to this node
"""
old_id = self.id
new_id = self.increment_id_counter()
# Rename files to match the new ID
self._rename_files_for_new_id(old_id, new_id)
self.id = new_id
return self.id
def _log_inf_occurrence(
self, failure_type: str, error_message: str, yaml_path: str
):
"""
Log -inf occurrences to a dedicated log file for debugging.
Args:
failure_type: Type of failure (e.g., "execution_failure", "evaluation_failure")
error_message: The error message that caused the failure
yaml_path: Path to the YAML file that failed
"""
try:
# Create log directory if it doesn't exist
log_dir = os.path.join(os.path.dirname(yaml_path), "inf_logs")
os.makedirs(log_dir, exist_ok=True)
# Create log file path
log_file = os.path.join(log_dir, "inf_occurrences.txt")
# Get current timestamp
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# Get node information
node_id = getattr(self, "id", "unknown")
latest_action = getattr(self, "latest_action", "unknown")
parent_id = getattr(self.parent, "id", "root") if self.parent else "root"
# Format log entry
log_entry = f"""
{'='*80}
Timestamp: {timestamp}
Node ID: {node_id}
Parent ID: {parent_id}
Latest Action: {latest_action}
Failure Type: {failure_type}
YAML Path: {yaml_path}
Error Message: {error_message}
{'='*80}
"""
# Append to log file
with open(log_file, "a", encoding="utf-8") as f:
f.write(log_entry)
except Exception as log_error:
self.console.log(
f"[yellow]Warning: Failed to log -inf occurrence: {log_error}[/yellow]"
)
def _rename_files_for_new_id(self, old_id, new_id):
"""
Rename the YAML and output files to match the new node ID.
Args:
old_id: The old node ID (e.g., "7-2")
new_id: The new node ID (e.g., 8)
"""
try:
# Rename YAML file
if os.path.exists(self.yaml_file_path):
old_yaml_path = self.yaml_file_path
new_yaml_path = old_yaml_path.replace(
f"_{old_id}.yaml", f"_{new_id}.yaml"
)
os.rename(old_yaml_path, new_yaml_path)
self.yaml_file_path = new_yaml_path
except Exception as e:
self.console.log(
f"[yellow]Warning: Could not rename YAML file from {old_id} to {new_id}: {e}[/yellow]"
)
try:
# Rename output JSON file
if self.result_path and os.path.exists(self.result_path):
old_result_path = self.result_path
new_result_path = old_result_path.replace(
f"_{old_id}.json", f"_{new_id}.json"
)
os.rename(old_result_path, new_result_path)
self.result_path = new_result_path
# Also update the YAML to point to the new output file
if hasattr(self, "parsed_yaml") and self.parsed_yaml:
self.parsed_yaml["pipeline"]["output"]["path"] = new_result_path
# Rewrite the YAML file with the updated output path
with open(self.yaml_file_path, "w") as f:
yaml.dump(
self.parsed_yaml,
f,
default_flow_style=False,
allow_unicode=True,
sort_keys=False,
)
except Exception as e:
self.console.log(
f"[yellow]Warning: Could not rename result file from {old_id} to {new_id}: {e}[/yellow]"
)
def add_memo_entry(self, directive_name: str, target_operator: str):
"""
Add a (directive, target_operator) pair to the memo list.
Args:
directive_name: Name of the directive that was applied
target_operator: Name of the target operator the directive was applied to
"""
self.memo.append((directive_name, target_operator))
def get_optimization_path(self) -> str:
"""
Get a formatted string showing the optimization path from root to this node.
Returns:
Formatted path string like "ROOT → chaining(extract_clause) → gleaning(extract_entity)"
"""
if not self.memo:
return "ROOT"
path_parts = ["ROOT"]
for directive, target_op in self.memo:
path_parts.append(f"{directive}({target_op})")
return "".join(path_parts)
def get_exploration_tree_summary(
self, root: Node, node_accuracies: Optional[Dict["Node", float]] = None
) -> str:
"""
Generate a comprehensive but concise summary of the entire exploration tree.
This gives the LLM agent complete context about what has been tried.
Args:
root: Root node of the tree
node_accuracies: Optional dictionary mapping nodes to their accuracy values
Returns:
Formatted tree summary optimized for LLM consumption
"""
# Collect all exploration paths and their outcomes
successful_paths = []
failed_paths = []
def traverse_tree(node, current_path="ROOT"):
# Add this node's path if it's not the root
if node != root:
if hasattr(node, "cost") and node.cost != -1:
# Include accuracy if available
if node_accuracies and node in node_accuracies:
accuracy = node_accuracies[node]
successful_paths.append(
f"{current_path} (cost: ${node.cost:.2f}, accuracy: {accuracy:.3f})"
)
else:
successful_paths.append(
f"{current_path} (cost: ${node.cost:.2f})"
)
else:
failed_paths.append(f"{current_path} (failed)")
# Traverse children
for child in node.children:
if child.memo:
# Get the most recent directive-operator pair for this child
latest_directive, latest_target = child.memo[-1]
child_path = f"{current_path}{latest_directive}({latest_target})"
else:
child_path = f"{current_path}{child.latest_action.name if child.latest_action else 'unknown'}"
traverse_tree(child, child_path)
traverse_tree(root)
# Group paths by directive patterns for better insights
directive_patterns = {}
for path in successful_paths + failed_paths:
# Extract directive sequence
directives = []
parts = path.split(" → ")
for part in parts[1:]: # Skip ROOT
if "(" in part:
directive = part.split("(")[0]
directives.append(directive)
if directives:
pattern_key = "".join(directives)
if pattern_key not in directive_patterns:
directive_patterns[pattern_key] = {"successful": [], "failed": []}
if "(failed)" in path:
directive_patterns[pattern_key]["failed"].append(path)
else:
directive_patterns[pattern_key]["successful"].append(path)
# Build summary
summary_parts = []
# Current position
current_path = self.get_optimization_path()
summary_parts.append(f"CURRENT POSITION: {current_path}")
# Successful explorations
if successful_paths:
summary_parts.append(
f"\nSUCCESSFUL EXPLORATIONS ({len(successful_paths)} total):"
)
# Show best performers first - sort by accuracy (highest first), then by cost (lowest first)
def extract_sort_key(path):
if "cost: $" not in path:
return (
0,
float("inf"),
) # lowest accuracy, highest cost for failed cases
try:
cost_part = path.split("cost: $")[1]
if ", accuracy:" in cost_part:
cost_str = cost_part.split(", accuracy:")[0]
accuracy_str = cost_part.split(", accuracy:")[1].split(")")[0]
return (
-float(accuracy_str),
float(cost_str),
) # negative accuracy for descending order
else:
cost_str = cost_part.split(")")[0]
return (
0,
float(cost_str),
) # no accuracy info, sort by cost only
except (ValueError, IndexError):
return (0, float("inf"))
sorted_successful = sorted(successful_paths, key=extract_sort_key)
for i, path in enumerate(sorted_successful):
summary_parts.append(f" {i+1}. {path}")
return "\n".join(summary_parts)
def get_memo_for_llm(
self, root_node: Node, node_accuracies: Optional[Dict["Node", float]] = None
) -> str:
"""
Get a comprehensive exploration summary formatted for LLM prompts.
Args:
root_node: Root node of the tree
node_accuracies: Optional dictionary mapping nodes to their accuracy values
Returns:
Complete exploration context to guide decision making
"""
return self.get_exploration_tree_summary(root_node, node_accuracies)
def delete(self, selected_node_final_id=None):
"""
Delete this node and clean up its resources.
For multi-instance candidates, moves files to backup_plans folder instead of deleting.
Args:
selected_node_final_id: The final ID of the selected node (for backup naming)
This method:
1. Removes the node from its parent's children list
2. Moves multi-instance files to backup or deletes regular files
3. Clears references to prevent memory leaks
The global node ID counter is maintained (not decremented) to ensure
unique IDs across the lifetime of the program.
"""
# Remove from parent's children list
if self.parent and self in self.parent.children:
self.parent.children.remove(self)
# Check if this is a multi-instance candidate using the explicit flag
if self.is_multi_instance and selected_node_final_id is not None:
# Move files to backup_plans folder with renamed IDs
self._backup_multi_instance_files(selected_node_final_id)
else:
# Regular deletion for non-multi-instance nodes
self._delete_files_permanently()
# Clear references to help with garbage collection
self.parent = None
self.children = []
self.parsed_yaml = {}
self.message_history = []
self.memo = []
self.sample_result = []
def _backup_multi_instance_files(self, selected_node_final_id):
"""
Move multi-instance files to backup_plans folder with new naming scheme.
Args:
selected_node_final_id: The final ID of the selected node
"""
try:
# Extract the instantiation number from current ID (e.g., "7-2" -> "2")
current_id_str = str(self.id)
if "-" in current_id_str:
instantiation_num = current_id_str.split("-")[1]
new_backup_id = f"{selected_node_final_id}-{instantiation_num}"
else:
new_backup_id = f"{selected_node_final_id}-backup"
# Create backup_plans directory
if os.path.exists(self.yaml_file_path):
yaml_dir = os.path.dirname(self.yaml_file_path)
backup_dir = os.path.join(yaml_dir, "backup_plans")
os.makedirs(backup_dir, exist_ok=True)
# Move YAML file
yaml_filename = os.path.basename(self.yaml_file_path)
# Replace old ID with new backup ID in filename
new_yaml_filename = yaml_filename.replace(
f"_{current_id_str}.yaml", f"_{new_backup_id}.yaml"
)
backup_yaml_path = os.path.join(backup_dir, new_yaml_filename)
if os.path.exists(self.yaml_file_path):
os.rename(self.yaml_file_path, backup_yaml_path)
self.console.log(
f"[dim]Moved YAML to backup:[/dim] {self.yaml_file_path}{backup_yaml_path}"
)
# Move result JSON file
if self.result_path and os.path.exists(self.result_path):
result_filename = os.path.basename(self.result_path)
new_result_filename = result_filename.replace(
f"_{current_id_str}.json", f"_{new_backup_id}.json"
)
backup_result_path = os.path.join(backup_dir, new_result_filename)
os.rename(self.result_path, backup_result_path)
self.console.log(
f"[dim]Moved result to backup:[/dim] {self.result_path}{backup_result_path}"
)
except Exception as e:
self.console.log(
f"[yellow]Warning: Could not backup files for multi-instance node {self.id}: {e}[/yellow]"
)
# Fall back to regular deletion if backup fails
self._delete_files_permanently()
def _delete_files_permanently(self):
"""
Permanently delete files (for non-multi-instance nodes).
"""
try:
if os.path.exists(self.yaml_file_path) and str(
self.yaml_file_path
).endswith((".yaml", ".yml")):
# Only delete if it looks like a generated file (contains numbers)
if any(
char.isdigit()
for char in os.path.basename(str(self.yaml_file_path))
):
os.remove(self.yaml_file_path)
self.console.log(
f"[dim]Deleted generated YAML file:[/dim] {self.yaml_file_path}"
)
except Exception as e:
self.console.log(
f"[yellow]Warning: Could not delete YAML file {self.yaml_file_path}: {e}[/yellow]"
)
try:
if self.result_path and os.path.exists(self.result_path):
# Only delete if it looks like a generated file (contains numbers)
if any(char.isdigit() for char in os.path.basename(self.result_path)):
os.remove(self.result_path)
self.console.log(
f"[dim]Deleted generated result file:[/dim] {self.result_path}"
)
except Exception as e:
self.console.log(
f"[yellow]Warning: Could not delete result file {self.result_path}: {e}[/yellow]"
)

425
docetl/moar/ParetoFrontier.py Normal file
View File

@ -0,0 +1,425 @@
import os
from typing import Any, Callable, Dict, List, Optional, Tuple
import matplotlib
matplotlib.use("Agg") # Use non-interactive backend
import matplotlib.pyplot as plt
from .Node import Node
class ParetoFrontier:
"""
Pareto Frontier class for managing cost-accuracy optimization.
This class maintains a collection of plans, estimates their accuracy through
pairwise comparisons, constructs and updates the Pareto frontier, and provides
value calculations for MCTS integration.
"""
def __init__(
self,
action_rewards: Dict[str, float],
action_cost_changes: Dict[str, float],
action_accuracy_changes: Dict[str, float],
dataset_name: str,
evaluate_func: Callable[[str], Dict[str, Any]],
console=None,
):
"""
Initialize the Pareto Frontier.
Args:
action_rewards: Reference to MCTS action_rewards dictionary
action_cost_changes: Reference to MCTS action_cost_changes dictionary
action_accuracy_changes: Reference to MCTS action_accuracy_changes dictionary
dataset_name: Name of the dataset being optimized (for evaluation and metric selection)
evaluate_func: Evaluation function (results_file_path: str) -> dict
console: Console instance for logging (default: None, uses DOCETL_CONSOLE)
"""
from docetl.console import DOCETL_CONSOLE
self.console = console if console is not None else DOCETL_CONSOLE
self.dataset_name = dataset_name
self.evaluate_func = evaluate_func
# Dataset-to-primary-metric mapping
self.dataset_metrics = {
"cuad": "avg_f1",
"blackvault": "avg_distinct_locations",
"game_reviews": "weighted_score",
"medec": "combined_score",
"sustainability": "combined_score",
"biodex": "avg_rp_at_5", # Optimize for RP@5 as specified
"facility": "combined_score",
}
# Internal state
self.plans: List[Node] = []
self.plans_accuracy: Dict[Node, float] = {}
self.plans_cost: Dict[Node, float] = {} # Real costs
self.frontier_plans: List[Node] = [] # List of nodes on frontier
self.frontier_data: List[List[float]] = []  # [accuracy, real_cost] pairs for plans on the frontier
self.action_rewards = action_rewards
self.action_cost_changes = action_cost_changes
self.action_accuracy_changes = action_accuracy_changes
# Distance to current Pareto frontier: positive for on-frontier, negative for off-frontier
self.node_distances: Dict[Node, float] = {}
# Root plan reference point
self.root_accuracy: Optional[float] = None
self.root_cost: Optional[float] = None
def add_plan(self, node: Node) -> Dict[Node, int]:
"""
Add a new plan (Node) to the frontier and estimate its accuracy.
Args:
node: Node object representing the plan
Returns:
Dict mapping affected nodes to their frontier-update values (the new node is added with 0 if it was not otherwise affected)
"""
if node.cost == -1: # Handle error case
return {}
# Store plan information
self.plans.append(node)
self.plans_cost[node] = node.cost
# Estimate accuracy through pairwise comparisons
if len(self.plans_accuracy) == 0:
# First plan gets baseline accuracy
estimated_accuracy = 0.5
else:
estimated_accuracy = self.estimate_accuracy_via_comparisons(node)
self.plans_accuracy[node] = estimated_accuracy
# Update Pareto frontier
affected_nodes = self.update_pareto_frontier()
if node not in affected_nodes:
affected_nodes[node] = 0
return affected_nodes
def add_plan_f1(self, node: Node, accuracy: float) -> Tuple[Dict[Node, int], bool]:
"""
Add a new plan (Node) to the frontier with pre-evaluated accuracy.
Args:
node: Node object representing the plan
accuracy: Pre-evaluated accuracy score for the node
Returns:
Tuple of (dict mapping affected nodes to reward adjustments, bool indicating whether the frontier was updated)
"""
if node.cost == -1: # Handle error case
self.plans_accuracy[node] = float("-inf")
return {}, False
# Store plan information
self.plans.append(node)
self.plans_cost[node] = node.cost # Store real cost
# Scaled cost will be calculated in update_pareto_frontier_HV
# Store the pre-evaluated accuracy
self.plans_accuracy[node] = accuracy
# Update Pareto frontier
affected_nodes, is_frontier_updated = self.update_pareto_frontier_HV(node)
return affected_nodes, is_frontier_updated
def get_all_plans_summary(self) -> List[Dict[str, Any]]:
"""
Get summary of all plans with their metrics.
Returns:
List of dictionaries containing plan information and metrics
"""
summaries = []
for node in self.plans:
summary = {
"node": node.get_id(),
"path": node.yaml_file_path,
"cost": node.cost,
"accuracy": self.plans_accuracy[node],
"value": node.value,
"is_frontier": node in self.frontier_plans,
}
summaries.append(summary)
return summaries
# Helper function to project point onto step function frontier
def project_to_frontier(self, node_acc, node_cost, frontier_data):
"""
Project point onto the step function formed by frontier.
For a step function interpretation, the reward is simply the vertical distance
to the step function (accuracy distance only).
"""
if not frontier_data:
return node_acc
# Sort frontier by cost (ascending)
frontier_sorted = sorted(frontier_data, key=lambda x: x[1]) # Sort by cost
# Find the step function accuracy for this cost
step_function_accuracy = (
0.0 # Default if cost is lower than all frontier points
)
for fp_acc, fp_cost in frontier_sorted:
if node_cost >= fp_cost:
# Cost is >= this frontier point's cost, so step function is at this accuracy
step_function_accuracy = fp_acc
else:
# Cost is < this frontier point's cost, so we use the previous step
break
# Return the vertical (accuracy) distance to the step function
vertical_distance = abs(node_acc - step_function_accuracy)
return vertical_distance
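# Illustrative projection (hypothetical numbers): with a frontier of
# [[0.6, 1.0], [0.8, 3.0]] (each entry is [accuracy, cost]), a plan with cost=2.0
# falls on the 0.6-accuracy step, so a plan accuracy of 0.7 yields a vertical
# distance of |0.7 - 0.6| = 0.1.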
def _update_action_rewards(self, node: Node, reward: float) -> None:
"""
Update action rewards and track cost/accuracy changes based on the reward received by a node.
Updates the cumulative sum for the latest action that led to this node.
Args:
node: The node that received the reward
reward: The reward value to incorporate
"""
if not node.latest_action or not self.action_rewards:
return
action = node.latest_action
if action in self.action_rewards:
# Update cumulative reward sum
self.action_rewards[action] += reward
# Track cost and accuracy changes
if (
node.parent
and node.parent in self.plans_cost
and node in self.plans_cost
):
cost_change = self.plans_cost[node] - self.plans_cost[node.parent]
self.action_cost_changes[action] += cost_change
if (
node.parent
and node.parent in self.plans_accuracy
and node in self.plans_accuracy
):
accuracy_change = (
self.plans_accuracy[node] - self.plans_accuracy[node.parent]
)
self.action_accuracy_changes[action] += accuracy_change
def update_pareto_frontier_HV(self, new_node) -> Tuple[Dict[Node, int], bool]:
"""
Update the Pareto frontier based on current plans and calculate hyper-volume indicator.
"""
valid_nodes = [node for node in self.plans if node.cost != -1]
affected_nodes = {}
if not valid_nodes:
self.frontier_plans = []
self.frontier_data = []
return affected_nodes, False
# Save old frontier nodes before updating
old_frontier_nodes = self.frontier_plans
# Sort by real cost for frontier calculation
valid_nodes.sort(key=lambda node: self.plans_cost[node])
# Reconstruct old frontier data using real costs
archive_frontier_data = []
for node in old_frontier_nodes:
if node in valid_nodes: # Only include valid nodes
acc = self.plans_accuracy.get(node, float("-inf"))
real_cost = self.plans_cost[node]
archive_frontier_data.append([acc, real_cost])
else:
self.console.log(
f"[yellow]INVALID NODE:[/yellow] {node.id}, [dim]cost:[/dim] {node.cost}, [dim]in_valid_nodes:[/dim] {node in valid_nodes}"
)
frontier = []
max_accuracy_so_far = float("-inf")
for node in valid_nodes:
accuracy = self.plans_accuracy.get(node, 0.0)
# Plan is on frontier if it has higher accuracy than all lower-cost plans
if accuracy > max_accuracy_so_far:
frontier.append(node)
max_accuracy_so_far = accuracy
new_frontier_data = []
for node in frontier:
acc = self.plans_accuracy.get(node)
real_cost = self.plans_cost[node] # Use real cost
new_frontier_data.append([acc, real_cost])
# Check if frontier actually changed
old_frontier_set = set(old_frontier_nodes)
new_frontier_set = set(frontier)
frontier_updated = old_frontier_set != new_frontier_set
# Update affected nodes based on frontier changes
for node in valid_nodes:
node_real_cost = self.plans_cost[node]
node_acc = self.plans_accuracy[node]
if node in new_frontier_set and node not in old_frontier_set:
# Newly on frontier - reward based on vertical distance to OLD frontier step function
node.on_frontier = True
vertical_distance_to_old = self.project_to_frontier(
node_acc, node_real_cost, archive_frontier_data
)
affected_nodes[node] = vertical_distance_to_old
# Update node distances - positive for on frontier
self.node_distances[node] = vertical_distance_to_old
# Update action rewards
self._update_action_rewards(node, vertical_distance_to_old)
elif (node not in new_frontier_set and node in old_frontier_set) or (
node.id == new_node.id
):
# Newly off frontier - give negative reward based on vertical distance to NEW frontier step function
node.on_frontier = False
vertical_distance = self.project_to_frontier(
node_acc, node_real_cost, new_frontier_data
)
affected_nodes[node] = -vertical_distance
# Update node distances - negative for off frontier
self.node_distances[node] = -vertical_distance
# Update action rewards
if node.id == new_node.id:
self._update_action_rewards(node, -vertical_distance)
elif node not in new_frontier_set:
# stay off frontier nodes - update the reward to be negative vertical distance to the NEW frontier step function
node.on_frontier = False
vertical_distance = self.project_to_frontier(
node_acc, node_real_cost, new_frontier_data
)
old_distance = self.node_distances.get(node, 0)
distance_diff = -vertical_distance - old_distance
affected_nodes[node] = distance_diff
# Update node distances - negative for off frontier
self.node_distances[node] = -vertical_distance
self.frontier_plans = frontier
self.frontier_data = new_frontier_data
if new_node.id > 0:
graph_dir = str(new_node.yaml_file_path).rsplit("/", 1)[0] + "/graph/"
os.makedirs(graph_dir, exist_ok=True)
save_path = graph_dir + f"plan_{new_node.id}.png"
self.plot_plans(save_path, new_node.id, str(new_node.yaml_file_path))
return affected_nodes, frontier_updated
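# Frontier sweep, illustrated with hypothetical plans (cost, accuracy):
#   A=(1.0, 0.50), B=(2.0, 0.70), C=(3.0, 0.65)
# sorted by cost, A and B strictly raise the best accuracy seen so far and land on
# the frontier; C costs more than B without beating its accuracy, so it stays off
# the frontier and, as the newly added plan, receives a negative reward equal to
# its vertical distance to the new frontier step (|0.65 - 0.70| = 0.05).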
def plot_plans(self, save_path=None, plan_num=None, yaml_file=None):
"""
Plot all current plans as dots on a cost vs. accuracy graph, annotating each with its id.
Frontier plans are blue, non-frontier plans are grey.
Args:
save_path: If provided, save the plot to this path instead of showing it
plan_num: If provided, include the plan number in the title
yaml_file: If provided, include the YAML file path in the title
"""
if plt is None:
raise ImportError(
"matplotlib is required for plotting. Please install it with 'pip install matplotlib'."
)
plt.figure(figsize=(10, 8))
# Separate frontier and non-frontier plans
frontier_nodes = [node for node in self.plans if node in self.frontier_plans]
non_frontier_nodes = [
node for node in self.plans if node not in self.frontier_plans
]
# Plot non-frontier plans (grey)
if non_frontier_nodes:
costs = [self.plans_cost[node] for node in non_frontier_nodes]
accuracies = [self.plans_accuracy[node] for node in non_frontier_nodes]
ids = [node.get_id() for node in non_frontier_nodes]
plt.scatter(costs, accuracies, color="grey", label="Off Frontier")
for x, y, label in zip(costs, accuracies, ids):
plt.annotate(
str(label),
(x, y),
textcoords="offset points",
xytext=(5, 5),
ha="left",
fontsize=9,
color="grey",
)
# Plot frontier plans (blue)
if frontier_nodes:
costs = [self.plans_cost[node] for node in frontier_nodes]
accuracies = [self.plans_accuracy[node] for node in frontier_nodes]
ids = [node.get_id() for node in frontier_nodes]
plt.scatter(costs, accuracies, color="blue", label="Frontier")
for x, y, label in zip(costs, accuracies, ids):
plt.annotate(
str(label),
(x, y),
textcoords="offset points",
xytext=(5, 5),
ha="left",
fontsize=9,
color="blue",
)
plt.xlabel("Cost")
plt.ylabel("Accuracy")
if plan_num is not None:
plt.title(f"Plan {plan_num}: {yaml_file}")
else:
plt.title("Plans: Cost vs. Accuracy")
plt.grid(True, linestyle="--", alpha=0.5)
plt.legend()
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=150, bbox_inches="tight")
plt.close()
def __len__(self) -> int:
"""Return number of plans in the frontier."""
return len(self.plans)
def __contains__(self, node: Node) -> bool:
"""Check if plan is managed by this frontier."""
return node in self.plans
def delete_plan(self, node: Node) -> None:
"""
Completely delete a node from all ParetoFrontier data structures.
"""
if node in self.plans:
self.plans.remove(node)
accuracy = self.plans_accuracy.pop(node, None)
cost = self.plans_cost.pop(node, None)
if node in self.frontier_plans:
self.frontier_plans.remove(node)
if accuracy is not None and cost is not None:
try:
self.frontier_data.remove([accuracy, cost])
except ValueError:
pass
self.node_distances.pop(node, None)

33
docetl/moar/__init__.py Normal file
View File

@ -0,0 +1,33 @@
"""
MCTS (Monte Carlo Tree Search) module for DocETL optimization.
This module provides Monte Carlo Tree Search optimization for DocETL pipelines
using Pareto frontier analysis and multi-objective optimization.
"""
# Default list of available models for MOAR optimization
# Defined here before imports to avoid circular import issues
AVAILABLE_MODELS = [
# "gpt-5",
# "gpt-5-mini",
# "gpt-5-nano",
"gpt-4.1",
"gpt-4.1-mini",
"gpt-4.1-nano",
"gpt-4o",
"gpt-4o-mini",
# "gemini-2.5-pro",
# "gemini-2.5-flash",
# "gemini-2.5-flash-lite"
]
from .MOARSearch import MOARSearch # noqa: E402
from .Node import Node # noqa: E402
from .ParetoFrontier import ParetoFrontier # noqa: E402
__all__ = [
"MOARSearch",
"Node",
"ParetoFrontier",
"AVAILABLE_MODELS"
]
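# Typical import pattern for downstream code (illustrative):
#   from docetl.moar import MOARSearch, ParetoFrontier, AVAILABLE_MODELS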

305
docetl/moar/cli_helpers.py Normal file
View File

@ -0,0 +1,305 @@
"""
Helper functions for running MOAR optimizer from CLI.
"""
import json
from datetime import datetime
from pathlib import Path
from typing import Any, Dict
import yaml
from docetl.console import DOCETL_CONSOLE
from docetl.moar import MOARSearch
from docetl.reasoning_optimizer.directives import ALL_DIRECTIVES
from docetl.utils_dataset import get_dataset_stats
from docetl.utils_evaluation import load_custom_evaluate_func
def infer_dataset_info(yaml_path: str, config: dict) -> tuple[str, str]:
"""
Infer dataset path and name from YAML config.
Args:
yaml_path: Path to YAML file
config: Full YAML config dictionary
Returns:
tuple: (dataset_path, dataset_name)
Raises:
ValueError: If datasets section is missing or empty
"""
datasets = config.get("datasets", {})
if not datasets:
raise ValueError("YAML config must contain a 'datasets' section")
# Get the first dataset (assuming single dataset per config)
dataset_name, dataset_config = next(iter(datasets.items()))
dataset_path = dataset_config.get("path")
if not dataset_path:
raise ValueError(f"Dataset '{dataset_name}' in config must have a 'path' field")
# Resolve relative paths - try as-is first, then relative to YAML file location
if Path(dataset_path).is_absolute():
# Already absolute, use as-is
pass
elif Path(dataset_path).exists():
# Path exists as-is (relative to current working directory)
dataset_path = str(Path(dataset_path).resolve())
else:
# Try resolving relative to YAML file location
yaml_dir = Path(yaml_path).parent
resolved_path = yaml_dir / dataset_path
if resolved_path.exists():
dataset_path = str(resolved_path.resolve())
else:
# Use the resolved path anyway (might be created later)
dataset_path = str(resolved_path.resolve())
return dataset_path, dataset_name
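# Illustrative config shape this helper reads (dataset name and path are hypothetical):
#   datasets:
#     contracts:
#       path: data/contracts.json
# infer_dataset_info("pipelines/extract.yaml", config) resolves the relative path
# against the current working directory first, then against the YAML file's
# directory, and returns ("<resolved>/data/contracts.json", "contracts").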
def load_evaluation_function(config: dict, dataset_file_path: str):
"""
Load evaluation function from optimizer_config.
Args:
config: optimizer_config dictionary from YAML
dataset_file_path: Path to the dataset file
Returns:
callable: Evaluation function
Raises:
ValueError: If required parameters are missing
"""
evaluation_file = config.get("evaluation_file")
if not evaluation_file:
raise ValueError(
"optimizer_config must contain 'evaluation_file' (path to Python file with @docetl.register_eval decorated function)"
)
DOCETL_CONSOLE.log(
f"[bold blue]📊 Loading evaluation function from: {evaluation_file}[/bold blue]"
)
evaluate_func = load_custom_evaluate_func(evaluation_file, dataset_file_path)
DOCETL_CONSOLE.log("[green]✅ Evaluation function loaded[/green]")
return evaluate_func
def run_moar_optimization(
yaml_path: str,
optimizer_config: dict,
) -> Dict[str, Any]:
"""
Run MOAR optimization from CLI.
Args:
yaml_path: Path to the YAML pipeline file
optimizer_config: optimizer_config dictionary from YAML (must contain save_dir)
Returns:
dict: Experiment summary
"""
# Load full config to infer dataset info
with open(yaml_path, "r") as f:
full_config = yaml.safe_load(f)
# Use dataset_path from optimizer_config if provided, otherwise infer from datasets section
if optimizer_config.get("dataset_path"):
dataset_path = optimizer_config.get("dataset_path")
# Resolve relative paths
if Path(dataset_path).is_absolute():
dataset_path = str(Path(dataset_path).resolve())
elif Path(dataset_path).exists():
dataset_path = str(Path(dataset_path).resolve())
else:
yaml_dir = Path(yaml_path).parent
dataset_path = str((yaml_dir / dataset_path).resolve())
# Infer dataset name from datasets section
_, dataset_name = infer_dataset_info(yaml_path, full_config)
else:
# Infer both dataset path and name from config
dataset_path, dataset_name = infer_dataset_info(yaml_path, full_config)
# Extract MOAR parameters from optimizer_config (all required, no defaults)
save_dir = optimizer_config.get("save_dir")
if not save_dir:
raise ValueError("optimizer_config must contain 'save_dir' for MOAR optimizer")
available_models = optimizer_config.get("available_models")
if not available_models:
raise ValueError(
"optimizer_config must contain 'available_models' (list of model names) for MOAR optimizer"
)
evaluation_file = optimizer_config.get("evaluation_file")
if not evaluation_file:
raise ValueError(
"optimizer_config must contain 'evaluation_file' (path to Python file with @docetl.register_eval decorated function) for MOAR optimizer"
)
metric_key = optimizer_config.get("metric_key")
if not metric_key:
raise ValueError(
"optimizer_config must contain 'metric_key' (key to extract from evaluation results) for MOAR optimizer"
)
max_iterations = optimizer_config.get("max_iterations")
if max_iterations is None:
raise ValueError(
"optimizer_config must contain 'max_iterations' (number of MOARSearch iterations) for MOAR optimizer"
)
# Use rewrite_agent_model for consistency with rest of codebase, fallback to model for backwards compatibility
model = optimizer_config.get("rewrite_agent_model") or optimizer_config.get("model")
if not model:
raise ValueError(
"optimizer_config must contain 'rewrite_agent_model' (LLM model name for directive instantiation) for MOAR optimizer"
)
# Optional parameters
exploration_weight = optimizer_config.get("exploration_weight", 1.414)
build_first_layer = optimizer_config.get("build_first_layer", False)
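# Sketch of the optimizer_config keys consumed above (values are hypothetical):
#   optimizer_config:
#     save_dir: moar_runs/contracts
#     available_models: [gpt-4.1-mini, gpt-4o-mini]
#     evaluation_file: eval/contracts_eval.py
#     metric_key: avg_f1
#     max_iterations: 20
#     rewrite_agent_model: gpt-4.1
#     dataset_path: data/contracts.json   # optional; otherwise inferred from datasets:
#     exploration_weight: 1.414           # optional
#     build_first_layer: false            # optional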
# Resolve save directory (handle relative paths)
save_dir = Path(save_dir)
if not save_dir.is_absolute():
# Resolve relative to current working directory
save_dir = Path.cwd() / save_dir
save_dir.mkdir(parents=True, exist_ok=True)
DOCETL_CONSOLE.log("[bold blue]🌳 Running MOARSearch[/bold blue]")
DOCETL_CONSOLE.log(f"[dim]Input Pipeline:[/dim] {yaml_path}")
DOCETL_CONSOLE.log(f"[dim]Save Directory:[/dim] {save_dir}")
DOCETL_CONSOLE.log(f"[dim]Max Iterations:[/dim] {max_iterations}")
DOCETL_CONSOLE.log(f"[dim]Exploration Weight (c):[/dim] {exploration_weight}")
DOCETL_CONSOLE.log(f"[dim]Model:[/dim] {model}")
DOCETL_CONSOLE.log(f"[dim]Dataset:[/dim] {dataset_name}")
DOCETL_CONSOLE.log()
# Load sample input data
DOCETL_CONSOLE.log("[bold blue]🚀 Initializing MOARSearch...[/bold blue]")
with open(dataset_path, "r") as f:
dataset_data = json.load(f)
# Take only the first 5 documents for sample input
if isinstance(dataset_data, list):
sample_input_data = dataset_data[:5]
else:
sample_input_data = dataset_data
# Use all registered rewrite directives
available_actions = set(ALL_DIRECTIVES)
# Get dataset statistics
dataset_stats = get_dataset_stats(yaml_path, dataset_name)
# Load evaluation function (pass dataset_path so it can be provided to eval function)
evaluate_func = load_evaluation_function(optimizer_config, dataset_path)
# Initialize MOARSearch
moar = MOARSearch(
root_yaml_path=yaml_path,
available_actions=available_actions,
sample_input=sample_input_data,
dataset_stats=dataset_stats,
dataset_name=dataset_name,
available_models=available_models,
evaluate_func=evaluate_func,
exploration_constant=exploration_weight,
max_iterations=max_iterations,
model=model,
output_dir=str(save_dir),
build_first_layer=build_first_layer,
custom_metric_key=metric_key,
sample_dataset_path=dataset_path, # Use the dataset_path (which may be from optimizer_config)
)
DOCETL_CONSOLE.log(
f"[green]✅ MOARSearch initialized with root node: {yaml_path}[/green]"
)
# Run MOARSearch optimization
DOCETL_CONSOLE.log(
f"[bold blue]\n🔍 Running MOARSearch optimization for {max_iterations} iterations...[/bold blue]"
)
start_time = datetime.now()
best_nodes = moar.search()
end_time = datetime.now()
duration = (end_time - start_time).total_seconds()
DOCETL_CONSOLE.log(
f"[green]✅ MOARSearch optimization completed in {duration:.2f} seconds[/green]"
)
# Run evaluation
DOCETL_CONSOLE.log("[bold blue]📊 Running evaluation...[/bold blue]")
# Prepare nodes for evaluation
nodes_for_evaluation = []
for n in moar.pareto_frontier.plans:
n.moar_accuracy = moar.pareto_frontier.plans_accuracy.get(n)
n.on_frontier = n in moar.pareto_frontier.frontier_plans
nodes_for_evaluation.append(n)
from docetl.utils_evaluation import run_evaluation
eval_results = run_evaluation(
nodes_or_files=nodes_for_evaluation,
evaluate_func=evaluate_func,
metric_key=metric_key,
output_path=save_dir,
dataset_name=dataset_name,
)
# Save experiment summary
results = {
"optimizer": "moar",
"input_pipeline": yaml_path,
"model": model,
"max_iterations": max_iterations,
"exploration_weight": exploration_weight,
"save_dir": str(save_dir),
"dataset": dataset_name,
"start_time": start_time.isoformat(),
"end_time": end_time.isoformat(),
"duration_seconds": duration,
"num_best_nodes": len(best_nodes) if best_nodes else 0,
"total_nodes_explored": (
len(moar.all_nodes) if hasattr(moar, "all_nodes") else 0
),
"total_search_cost": (
moar.total_search_cost if hasattr(moar, "total_search_cost") else 0
),
}
if eval_results:
results["evaluation_file"] = str(save_dir / "evaluation_metrics.json")
# Save Pareto frontier if available
if hasattr(moar, "pareto_frontier") and moar.pareto_frontier.frontier_plans:
pareto_file = save_dir / "pareto_frontier.json"
pareto_data = []
for node in moar.pareto_frontier.frontier_plans:
pareto_data.append(
{
"node_id": node.get_id(),
"yaml_path": node.yaml_file_path,
"cost": node.cost,
"accuracy": moar.pareto_frontier.plans_accuracy.get(node),
}
)
with open(pareto_file, "w") as f:
json.dump(pareto_data, f, indent=2)
results["pareto_frontier_file"] = str(pareto_file)
# Save experiment summary
summary_file = save_dir / "experiment_summary.json"
with open(summary_file, "w") as f:
json.dump(results, f, indent=2)
DOCETL_CONSOLE.log(f"[green]✅ Experiment summary saved to: {summary_file}[/green]")
return results

485
docetl/moar/search_utils.py Normal file
View File

@ -0,0 +1,485 @@
"""
Utility functions for MCTS implementation.
This module contains helper functions and utilities used by the MCTS algorithm
but not core to the MCTS structure itself.
"""
import json
import math
import os
import time
from typing import List, Optional
from pydantic import BaseModel
from docetl.reasoning_optimizer.directives import (
DIRECTIVE_GROUPS,
Directive,
get_all_cost_directive_strings,
get_all_directive_strings,
)
from docetl.reasoning_optimizer.load_data import load_input_doc
from docetl.reasoning_optimizer.op_descriptions import * # noqa: F403, F405
# Maximum number of tokens we will allow in the prompt we send to the model.
# The Azure GPT-5 family allows 272,000 tokens.
MAX_CONTEXT_TOKENS = 270_000
class ExpandResponseFormat(BaseModel):
directive: str
operators: List[str]
def count_tokens(messages):
"""Count estimated tokens in messages list."""
# messages should be a list of dicts, each with a "content" key
total_chars = sum(
len(m.get("content", "")) for m in messages if isinstance(m, dict)
)
return max(1, total_chars // 4)
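# Rough sanity check for the 4-chars-per-token heuristic above: a history whose
# message contents total ~40,000 characters is estimated at ~10,000 tokens.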
def trim_history(history: list, keep_system_first: bool = True) -> list:
"""Trim the conversation history in-place so its estimated token count
(via ``count_tokens``) does not exceed ``MAX_CONTEXT_TOKENS``.
We always keep the very first system message and the first user message so the
assistant retains the global instructions and the initial query context. After
that we drop the oldest messages until the budget is satisfied. Returns the
trimmed history list.
"""
# Determine starting index to preserve the initial system message and first user message
start_idx = 0
if keep_system_first and history:
if history[0].get("role") == "system":
start_idx = 1
# Find the first user message after the system message
for i in range(1, len(history)):
if history[i].get("role") == "user":
start_idx = i + 1
break
elif history[0].get("role") == "user":
# If first message is user, keep it and find the next user message
start_idx = 1
for i in range(1, len(history)):
if history[i].get("role") == "user":
start_idx = i + 1
break
# Drop oldest messages (just after the preserved block) until within limit
while len(history) > start_idx + 1 and count_tokens(history) > MAX_CONTEXT_TOKENS:
history.pop(start_idx)
return history
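# Illustrative usage (hypothetical message list): the initial system message and
# first user message survive, and the oldest turns after them are popped until the
# token estimate fits the budget.
#   history = [{"role": "system", "content": sys_prompt},
#              {"role": "user", "content": first_query},
#              ...]  # many assistant/user turns
#   history = trim_history(history)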
def get_directive_group(directive_name: str) -> Optional[str]:
"""
Get the group name for a directive.
Args:
directive_name: Name of the directive
Returns:
Group name if found, None otherwise
"""
for group_name, directives in DIRECTIVE_GROUPS.items():
for directive in directives:
if directive.name == directive_name:
return group_name
return None
def get_excluded_directives_for_operation(node, op_name: str) -> set:
"""Get compression directives to exclude for code_map and extract operations."""
op_type = node.op_dict[op_name].get("type")
compression_exclusions = set()
if op_type in ["code_map", "extract"]:
compression_exclusions = set(DIRECTIVE_GROUPS.get("compression", []))
return compression_exclusions
def is_action_applicable(node, action: Directive) -> bool:
"""Check if an action is applicable to a node."""
return True
def update_pipeline(orig_config, new_ops_list, target_ops):
"""
Update the pipeline configuration with new operations.
Args:
orig_config (dict): The original pipeline configuration
new_ops_list (list): The entire pipeline operations list (not a subset)
target_ops (list): List of target operation names to replace
Returns:
dict: Updated pipeline configuration
"""
if new_ops_list is not None:
op_names = [op.get("name") for op in new_ops_list if "name" in op]
# Update the pipeline steps to use the new operation names
if "pipeline" in orig_config and "steps" in orig_config["pipeline"]:
for step in orig_config["pipeline"]["steps"]:
if "operations" in step:
new_ops = []
for op in step["operations"]:
if op == target_ops[0]:
new_ops.extend(op_names)
step["operations"] = new_ops
return orig_config
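# Illustrative rewrite (hypothetical operator names): if a step originally runs
# ["extract_clause"] and target_ops == ["extract_clause"], passing a new_ops_list
# whose names are ["split_doc", "map_chunk", "reduce_doc"] leaves the step with
# operations ["split_doc", "map_chunk", "reduce_doc"].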
def fix_models(parsed_yaml):
"""No-op: Model names should be specified correctly in the YAML."""
pass
def is_fully_explored(node, max_children_multiplier: float = 1.0) -> bool:
"""Check if a node has been fully explored based on visit count."""
if node.parent is None:
return True
allowed_children = max(
2, 1 + math.floor(math.sqrt(float(node.visits)) * max_children_multiplier)
)
# Not only check the number of children, but also ensure all children have been simulated
# A child is considered simulated if it has been visited at least once (visits > 0)
# This ensures that children created but not yet simulated won't cause the parent to be considered fully explored
if len(node.children) < allowed_children:
return False
# Check if all children have been simulated (visited at least once)
# This prevents selecting children that were just created but not yet simulated
for child in node.children:
if child.visits == 0:
# This child hasn't been selected/simulated yet, so parent is not fully explored
return False
return True
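# Illustrative budget (made-up numbers): a node with visits=9 and the default
# multiplier allows max(2, 1 + floor(sqrt(9))) = 4 children, so with only 3
# simulated children it is still expandable; likewise any child with visits == 0
# keeps the parent expandable.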
def should_continue_search(
iteration: int,
max_iterations: int,
start_time: float,
max_time: Optional[float] = None,
) -> bool:
"""Determine if search should continue based on iteration count and time."""
if iteration >= max_iterations:
return False
if max_time is not None:
elapsed_time = time.time() - start_time
if elapsed_time >= max_time:
return False
return True
def calculate_ucb1(
node, parent_visits: int, exploration_constant: float = math.sqrt(2)
) -> float:
"""Calculate UCB1 value for node selection."""
if node.visits == 0:
return float("inf")
exploitation = node.value / node.visits
exploration = exploration_constant * math.sqrt(
math.log(parent_visits) / node.visits
)
return exploitation + exploration
def print_tree_visits_and_values(node=None, depth=0, file_handle=None, console=None):
"""Print tree structure with visit counts and values."""
if node is None:
return
indent = " " * depth
node_info = f"{indent}Node ID: {node.get_id()}, Visits: {node.visits}, Value: {node.value:.4f}"
if file_handle:
file_handle.write(node_info + "\n")
else:
if console is None:
from docetl.console import DOCETL_CONSOLE
console = DOCETL_CONSOLE
console.log(node_info)
for child in node.children:
print_tree_visits_and_values(child, depth + 1, file_handle, console)
def log_tree_to_file(root_node, iteration_num, output_dir="./outputs", console=None):
"""Log the tree structure to a file."""
log_file_path = os.path.join(output_dir, f"moar_tree_iteration_{iteration_num}.txt")
with open(log_file_path, "w") as f:
f.write(f"MOAR Tree Structure - Iteration {iteration_num}\n")
f.write("=" * 50 + "\n")
print_tree_visits_and_values(root_node, file_handle=f, console=console)
def create_expansion_prompt_acc(
node,
action_options,
input_query,
available_actions,
action_cost_changes,
action_accuracy_changes,
action_counts,
sample_input,
root_node,
yaml_file_path,
dataset=None,
node_accuracies=None,
model_stats=None,
available_models=None,
) -> tuple[str, str]:
"""Create expansion prompt for accuracy optimization."""
# Use provided available_models list, or extract from model_stats as fallback
if available_models:
available_models_list = available_models
elif model_stats:
available_models_list = list(model_stats.keys())
else:
available_models_list = []
available_actions_str = ""
for item in action_options:
op_name = item[0]
action_name = item[1]
action_str = f"Operator: {op_name}, Rewrite directive: {action_name}\n"
available_actions_str += action_str
action_stats = []
for action in available_actions:
cost_change = action_cost_changes.get(action, 0)
accuracy_change = action_accuracy_changes.get(action, 0)
count = action_counts.get(action, 0)
if count > 0:
avg_cost_change = cost_change / count
avg_accuracy_change = accuracy_change / count
action_stats.append(
f"- {action.name}: {count} uses, avg change in cost: {avg_cost_change:+.2f}, avg change in accuracy: {avg_accuracy_change:+.4f}"
)
else:
action_stats.append(
f"- {action.name}: {count} uses, avg change in cost: Unknown (never tried), avg change in accuracy: Unknown (never tried)"
)
action_stats_str = "\n".join(action_stats)
input_schema = load_input_doc(yaml_file_path)
user_message = f"""
I have a set of operations used to process long documents, along with a list of possible rewrite directives aimed at improving the quality of the query result.
Given a query pipeline made up of these operations, recommend one specific rewrite directive (specify by its name) that would improve accuracy and specify which operators (specify by their names) in the pipeline the directive should be applied to.
Make sure that your chosen directive is in the provided list of rewrite directives.
Pipeline:
Pipelines in DocETL are the core structures that define the flow of data processing. A pipeline consists of five main components: \n
- Default Model: The language model to use for the pipeline. Limit your choice of model to {available_models_list} \n
- System Prompts: A description of your dataset and the "persona" you'd like the LLM to adopt when analyzing your data. \n
- Datasets: The input data sources for your pipeline. \n
- Operators: The processing steps that transform your data. \n
- Pipeline Specification: The sequence of steps and the output configuration. \n
Operators:
Operators form the building blocks of data processing pipelines. Below is the list of operators:
{op_map.to_string()}\n
{op_extract.to_string()}\n
{op_parallel_map.to_string()}\n
{op_filter.to_string()}\n
{op_reduce.to_string()}\n
{op_split.to_string()}\n
{op_gather.to_string()}\n
{op_unnest.to_string()}\n
{op_sample.to_string()}\n
{op_resolve.to_string()}\n
Rewrite directives:
{get_all_directive_strings()}\n
Your valid choice of operation and rewrite directive combination. Only choose one of these:\n
{available_actions_str}
Action Performance History:
Based on previous executions across DIFFERENT query pipelines, here's how each action has performed:\n
{action_stats_str}
Note: These statistics come from applying actions to various other query pipelines, not the current one. Use this as general guidance about action effectiveness, but consider that performance may vary significantly for your specific pipeline structure and data.
Model Performance Reference:
If you are considering a model change directive, here are model statistics on this specific dataset: \n {str(model_stats if model_stats else {})}
These show the accuracy (acc) and cost for each model. Only reference this when evaluating model change options.
Selection Strategy:
Consider the current query pipeline, which directive can best improve the accuracy.
Prioritize exploration of untested actions while balancing with exploitation of proven performers:
- Actions with 0 uses have unknown potential, so you should explore them if applicable.
- fusion directives are helpful
- Consider both immediate improvement and learning about the action space
{node.get_memo_for_llm(root_node, node_accuracies)}
Make sure you read every rewrite directive carefully.
Make sure you only choose from the valid choices above and avoid already used combinations or approaches too similar to what has already been tried in the current optimization path.
Input document schema with token statistics: {input_schema} \n
Input data sample: {json.dumps(sample_input, indent=2)[:5000]} \n
The original query in YAML format using our operations: {input_query} \n
The original query result: {json.dumps(node.sample_result, indent=2)[:3000]} \n
"""
# Create a condensed version for message history (without full operator/directive descriptions)
condensed_user_message = f"""
Recommend one specific rewrite directive for accuracy optimization.
Valid choices:
{available_actions_str}
Action Performance History:
{action_stats_str}
Current pipeline: {input_query}
"""
return user_message, condensed_user_message
def create_expansion_prompt_cost(
node,
action_options,
input_query,
available_actions,
action_cost_changes,
action_accuracy_changes,
action_counts,
sample_input,
root_node,
yaml_file_path,
dataset=None,
node_accuracies=None,
model_stats=None,
available_models=None,
) -> tuple[str, str]:
"""Create expansion prompt for cost optimization."""
# Use provided available_models list, or extract from model_stats as fallback
if available_models:
available_models_list = available_models
elif model_stats:
available_models_list = list(model_stats.keys())
else:
available_models_list = []
available_actions_str = ""
for item in action_options:
op_name = item[0]
action_name = item[1]
action_str = f"Operator: {op_name}, Rewrite directive: {action_name}\n"
available_actions_str += action_str
action_stats = []
for action in available_actions:
cost_change = action_cost_changes.get(action, 0)
accuracy_change = action_accuracy_changes.get(action, 0)
count = action_counts.get(action, 0)
if count > 0:
avg_cost_change = cost_change / count
avg_accuracy_change = accuracy_change / count
action_stats.append(
f"- {action.name}: {count} uses, avg change in cost: {avg_cost_change:+.2f}, avg change in accuracy: {avg_accuracy_change:+.4f}"
)
else:
action_stats.append(
f"- {action.name}: {count} uses, avg change in cost: Unknown (never tried), avg change in accuracy: Unknown (never tried)"
)
action_stats_str = "\n".join(action_stats)
input_schema = load_input_doc(yaml_file_path)
user_message = f"""
I have a set of operations used to process long documents, along with a list of possible rewrite directives designed to improve the cost effectiveness of the pipeline, while maintaining similar or better accuracy.
Given a query pipeline composed of these operations, recommend one specific rewrite directive (identified by its name from the provided list) that would improve cost effectiveness. Also, specify which operator(s) (by name) in the pipeline the directive should be applied to.
Make sure your recommended directive is selected from the provided list.
Pipeline:
Pipelines in DocETL are the core structures that define the flow of data processing. A pipeline consists of five main components: \n
- Default Model: The language model to use for the pipeline. Limit your choice of model to {available_models_list} \n
- System Prompts: A description of your dataset and the "persona" you'd like the LLM to adopt when analyzing your data. \n
- Datasets: The input data sources for your pipeline. \n
- Operators: The processing steps that transform your data. \n
- Pipeline Specification: The sequence of steps and the output configuration. \n
Operators:
Operators form the building blocks of data processing pipelines. Below is the list of operators:
{op_map.to_string()}\n
{op_extract.to_string()}\n
{op_parallel_map.to_string()}\n
{op_filter.to_string()}\n
{op_reduce.to_string()}\n
{op_split.to_string()}\n
{op_gather.to_string()}\n
{op_unnest.to_string()}\n
{op_sample.to_string()}\n
{op_resolve.to_string()}\n
Rewrite directives:
{get_all_cost_directive_strings()}\n
Your valid choice of operation and rewrite directive combination. Only choose one of these:\n
{available_actions_str}
Action Performance History:
Based on previous executions across DIFFERENT query pipelines, here's how each action has performed:\n
{action_stats_str}
Note: These statistics come from applying actions to various other query pipelines, not the current one. Use this as general guidance about action effectiveness, but consider that performance may vary significantly for your specific pipeline structure and data.
Model Performance Reference:
If you are considering a model change directive, here are model statistics on this specific dataset: \n {str(model_stats if model_stats else {})}
These show the accuracy (acc) and cost for each model. Only reference this when evaluating model change options.
Selection Strategy:
Consider the current query pipeline, which directive can best improve cost effectiveness.
Prioritize exploration of untested actions while balancing with exploitation of proven performers:
- Actions with 0 uses have unknown potential, so you should explore them if applicable.
- fusion directives are helpful
- Consider both immediate improvement and learning about the action space
{node.get_memo_for_llm(root_node, node_accuracies)}
Make sure you only choose from the valid choices above and avoid already used combinations or approaches too similar to what has already been tried in the current optimization path.
Input document schema with token statistics: {input_schema} \n
Input data sample: {json.dumps(sample_input, indent=2)[:5000]} \n
The original query in YAML format using our operations: {input_query} \n
The original query result: {json.dumps(node.sample_result, indent=2)[:3000]} \n
"""
condensed_user_message = f"""
Recommend one specific rewrite directive for cost optimization.
Valid choices:
{available_actions_str}
Action Performance History:
{action_stats_str}
Current pipeline: {input_query}
"""
return user_message, condensed_user_message

View File

@ -4,8 +4,8 @@ from docetl.operations.code_operations import CodeFilterOperation, CodeMapOperat
from docetl.operations.equijoin import EquijoinOperation
from docetl.operations.filter import FilterOperation
from docetl.operations.gather import GatherOperation
from docetl.operations.link_resolve import LinkResolveOperation
from docetl.operations.map import MapOperation, ParallelMapOperation
from docetl.operations.link_resolve import LinkResolveOperation
from docetl.operations.reduce import ReduceOperation
from docetl.operations.resolve import ResolveOperation
from docetl.operations.rank import RankOperation

View File

@ -2,6 +2,7 @@
The BaseOperation class is an abstract base class for all operations in the docetl framework. It provides a common structure and interface for various data processing operations.
"""
import traceback
from abc import ABC, ABCMeta, abstractmethod
from typing import Any
@ -87,6 +88,7 @@ class BaseOperation(ABC, metaclass=BaseOperationMeta):
type: str
skip_on_error: bool = False
gleaning: GleaningConfig | None = None
retriever: str | None = None
@abstractmethod
def execute(self, input_data: list[dict]) -> tuple[list[dict], float]:
@ -109,3 +111,26 @@ class BaseOperation(ABC, metaclass=BaseOperationMeta):
"""Perform syntax checks on the operation configuration."""
# Validate the configuration using Pydantic
self.schema.model_validate(self.config, context=context)
def _maybe_build_retrieval_context(self, context: dict[str, Any]) -> str:
"""Build retrieval context string if a retriever is configured."""
retriever_name = self.config.get("retriever")
if not retriever_name:
return ""
retrievers = getattr(self.runner, "retrievers", {})
if retriever_name not in retrievers:
raise ValueError(
f"Retriever '{retriever_name}' not found in configuration."
)
retriever = retrievers[retriever_name]
try:
result = retriever.retrieve(context)
return result.rendered_context or ""
except Exception as e:
# Soft-fail to avoid blocking the op
self.console.log(
f"[yellow]Warning: retrieval failed for '{retriever_name}': {e}[/yellow]"
)
# Print traceback to help debug
self.console.log(traceback.format_exc())
return "No extra context available."

View File

@ -7,6 +7,7 @@ from jinja2 import Template
from .base import BaseOperation
from .clustering_utils import get_embeddings_for_clustering
from .utils import RichLoopBar, strict_render
from docetl.utils import has_jinja_syntax, prompt_user_for_non_jinja_confirmation
class ClusterOperation(BaseOperation):
@ -19,6 +20,19 @@ class ClusterOperation(BaseOperation):
self.max_batch_size: int = self.config.get(
"max_batch_size", kwargs.get("max_batch_size", float("inf"))
)
# Check for non-Jinja prompts and prompt user for confirmation
if "summary_prompt" in self.config and not has_jinja_syntax(
self.config["summary_prompt"]
):
if not prompt_user_for_non_jinja_confirmation(
self.config["summary_prompt"], self.config["name"], "summary_prompt"
):
raise ValueError(
f"Operation '{self.config['name']}' cancelled by user. Please add Jinja2 template syntax to your summary_prompt."
)
# Mark that we need to append document statement (cluster uses inputs)
self.config["_append_document_to_prompt"] = True
self.config["_is_reduce_operation"] = True
def syntax_check(self) -> None:
"""
@ -48,11 +62,16 @@ class ClusterOperation(BaseOperation):
if not isinstance(self.config["summary_prompt"], str):
raise TypeError("'prompt' must be a string")
# Check if the prompt is a valid Jinja2 template
try:
Template(self.config["summary_prompt"])
except Exception as e:
raise ValueError(f"Invalid Jinja2 template in 'prompt': {str(e)}")
# Check if the prompt has Jinja syntax
if not has_jinja_syntax(self.config["summary_prompt"]):
# This will be handled during initialization with user confirmation
pass
else:
# Check if the prompt is a valid Jinja2 template
try:
Template(self.config["summary_prompt"])
except Exception as e:
raise ValueError(f"Invalid Jinja2 template in 'prompt': {str(e)}")
# Check optional parameters
if "max_batch_size" in self.config:

View File

@ -4,6 +4,8 @@ This module contains utilities for clustering based on different methods.
We use these in map and reduce operations.
"""
import json
from docetl.operations.utils import APIWrapper
from docetl.utils import completion_cost
@ -26,10 +28,10 @@ def get_embeddings_for_clustering(
for i in range(0, len(items), batch_size):
batch = items[i : i + batch_size]
texts = [
" ".join(str(item[key]) for key in embedding_keys if key in item)[:10000]
" ".join(str(item[key]) for key in embedding_keys if key in item)[:1000]
for item in batch
]
response = api_wrapper.gen_embedding(embedding_model, texts)
response = api_wrapper.gen_embedding(embedding_model, json.dumps(texts))
embeddings.extend([data["embedding"] for data in response["data"]])
cost += completion_cost(response)

View File

@ -1,5 +1,9 @@
import inspect
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from pydantic import Field, field_validator
from docetl.operations.base import BaseOperation
from docetl.operations.utils import RichLoopBar
@ -8,9 +12,25 @@ from docetl.operations.utils import RichLoopBar
class CodeMapOperation(BaseOperation):
class schema(BaseOperation.schema):
type: str = "code_map"
code: str
code: Any
concurrent_thread_count: int = os.cpu_count()
drop_keys: list[str] | None = None
limit: int | None = Field(None, gt=0)
@field_validator("code")
@classmethod
def validate_code(cls, v: Any) -> str:
if isinstance(v, str):
return v
if callable(v):
try:
src = inspect.getsource(v)
except OSError as e:
raise ValueError(
"Unable to retrieve source for provided function. Please pass a normal def function."
) from e
return f"{src}\ntransform = {v.__name__}"
raise TypeError("code must be a string or a callable")
def syntax_check(self) -> None:
config = self.schema(**self.config)
@ -25,6 +45,10 @@ class CodeMapOperation(BaseOperation):
raise ValueError(f"Invalid code configuration: {str(e)}")
def execute(self, input_data: list[dict]) -> tuple[list[dict], float]:
limit_value = self.config.get("limit")
if limit_value is not None:
input_data = input_data[:limit_value]
namespace = {}
exec(self.config["code"], namespace)
transform_fn = namespace["transform"]
@ -57,8 +81,24 @@ class CodeMapOperation(BaseOperation):
class CodeReduceOperation(BaseOperation):
class schema(BaseOperation.schema):
type: str = "code_reduce"
code: str
code: Any
concurrent_thread_count: int = os.cpu_count()
limit: int | None = Field(None, gt=0)
@field_validator("code")
@classmethod
def validate_code(cls, v: Any) -> str:
if isinstance(v, str):
return v
if callable(v):
try:
src = inspect.getsource(v)
except OSError as e:
raise ValueError(
"Unable to retrieve source for provided function. Please pass a normal def function."
) from e
return f"{src}\ntransform = {v.__name__}"
raise TypeError("code must be a string or a callable")
def syntax_check(self) -> None:
config = self.schema(**self.config)
@ -97,6 +137,12 @@ class CodeReduceOperation(BaseOperation):
grouped_data = list(grouped_data.items())
limit_value = self.config.get("limit")
if limit_value is not None:
# Sort by group size (smallest first) and take the limit
grouped_data = sorted(grouped_data, key=lambda x: len(x[1]))
grouped_data = grouped_data[:limit_value]
results = []
with ThreadPoolExecutor(
max_workers=self.config.get("concurrent_thread_count", os.cpu_count())
@ -132,8 +178,24 @@ class CodeReduceOperation(BaseOperation):
class CodeFilterOperation(BaseOperation):
class schema(BaseOperation.schema):
type: str = "code_filter"
code: str
code: Any
concurrent_thread_count: int = os.cpu_count()
limit: int | None = Field(None, gt=0)
@field_validator("code")
@classmethod
def validate_code(cls, v: Any) -> str:
if isinstance(v, str):
return v
if callable(v):
try:
src = inspect.getsource(v)
except OSError as e:
raise ValueError(
"Unable to retrieve source for provided function. Please pass a normal def function."
) from e
return f"{src}\ntransform = {v.__name__}"
raise TypeError("code must be a string or a callable")
def syntax_check(self) -> None:
config = self.schema(**self.config)
@ -152,6 +214,7 @@ class CodeFilterOperation(BaseOperation):
exec(self.config["code"], namespace)
filter_fn = namespace["transform"]
limit_value = self.config.get("limit")
results = []
with ThreadPoolExecutor(
max_workers=self.config.get("concurrent_thread_count", os.cpu_count())
@ -166,4 +229,6 @@ class CodeFilterOperation(BaseOperation):
should_keep = futures[i].result()
if should_keep:
results.append(input_data[i])
if limit_value is not None and len(results) >= limit_value:
break
return results, 0.0
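To make the `validate_code` validators above concrete, here is a minimal, self-contained sketch of the callable-to-source conversion and the exec-based loading that `execute` performs; the function `double_score` is a made-up example.

import inspect

def double_score(doc: dict) -> dict:
    # Toy transform: double a numeric field.
    return {**doc, "score": doc.get("score", 0) * 2}

# Mirrors validate_code: turn the callable into source text that ends
# with `transform = <function name>`.
code_str = f"{inspect.getsource(double_score)}\ntransform = double_score"

# Mirrors execute: exec the string and pull out the `transform` callable.
namespace: dict = {}
exec(code_str, namespace)
assert namespace["transform"]({"score": 3}) == {"score": 6}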

View File

@ -17,8 +17,13 @@ from rich.prompt import Confirm
from docetl.operations.base import BaseOperation
from docetl.operations.utils import strict_render
from docetl.operations.utils.blocking import RuntimeBlockingOptimizer
from docetl.operations.utils.progress import RichLoopBar
from docetl.utils import completion_cost
from docetl.utils import (
completion_cost,
has_jinja_syntax,
prompt_user_for_non_jinja_confirmation,
)
# Global variables to store shared data
_right_data = None
@ -59,6 +64,7 @@ class EquijoinOperation(BaseOperation):
comparison_prompt: str
output: dict[str, Any] | None = None
blocking_threshold: float | None = None
blocking_target_recall: float | None = None
blocking_conditions: list[str] | None = None
limits: dict[str, int] | None = None
comparison_model: str | None = None
@ -89,6 +95,41 @@ class EquijoinOperation(BaseOperation):
)
return v
@field_validator("comparison_prompt")
def validate_comparison_prompt(cls, v):
# Check if it has Jinja syntax
if not has_jinja_syntax(v):
# This will be handled during initialization with user confirmation
return v
# If it has Jinja syntax, validate it's a valid template
from jinja2 import Template
try:
Template(v)
except Exception as e:
raise ValueError(
f"Invalid Jinja2 template in 'comparison_prompt': {str(e)}"
)
return v
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Check for non-Jinja prompts and prompt user for confirmation
if "comparison_prompt" in self.config and not has_jinja_syntax(
self.config["comparison_prompt"]
):
if not prompt_user_for_non_jinja_confirmation(
self.config["comparison_prompt"],
self.config["name"],
"comparison_prompt",
):
raise ValueError(
f"Operation '{self.config['name']}' cancelled by user. Please add Jinja2 template syntax to your comparison_prompt."
)
# Mark that we need to append document statement
# Note: equijoin uses left and right, so we'll handle it in strict_render
self.config["_append_document_to_comparison_prompt"] = True
def compare_pair(
self,
comparison_prompt: str,
@ -211,6 +252,58 @@ class EquijoinOperation(BaseOperation):
if self.status:
self.status.stop()
# Track pre-computed embeddings from auto-optimization
precomputed_left_embeddings = None
precomputed_right_embeddings = None
# Auto-compute blocking threshold if no blocking configuration is provided
if not blocking_threshold and not blocking_conditions and not limit_comparisons:
# Get target recall from operation config (default 0.95)
target_recall = self.config.get("blocking_target_recall", 0.95)
self.console.log(
f"[yellow]No blocking configuration. Auto-computing threshold (target recall: {target_recall:.0%})...[/yellow]"
)
# Create comparison function for threshold optimization
def compare_fn_for_optimization(left_item, right_item):
return self.compare_pair(
self.config["comparison_prompt"],
self.config.get("comparison_model", self.default_model),
left_item,
right_item,
timeout_seconds=self.config.get("timeout", 120),
max_retries_per_timeout=self.config.get(
"max_retries_per_timeout", 2
),
)
# Run threshold optimization
optimizer = RuntimeBlockingOptimizer(
runner=self.runner,
config=self.config,
default_model=self.default_model,
max_threads=self.max_threads,
console=self.console,
target_recall=target_recall,
sample_size=min(100, len(left_data) * len(right_data) // 4),
)
(
blocking_threshold,
precomputed_left_embeddings,
precomputed_right_embeddings,
optimization_cost,
) = optimizer.optimize_equijoin(
left_data,
right_data,
compare_fn_for_optimization,
left_keys=left_keys,
right_keys=right_keys,
)
total_cost += optimization_cost
self.console.log(
f"[green]Using auto-computed blocking threshold: {blocking_threshold}[/green]"
)
# Initial blocking using multiprocessing
num_processes = min(cpu_count(), len(left_data))
@ -259,45 +352,60 @@ class EquijoinOperation(BaseOperation):
)
if blocking_threshold is not None:
embedding_model = self.config.get("embedding_model", self.default_model)
model_input_context_length = model_cost.get(embedding_model, {}).get(
"max_input_tokens", 8192
)
def get_embeddings(
input_data: list[dict[str, Any]], keys: list[str], name: str
) -> tuple[list[list[float]], float]:
texts = [
" ".join(str(item[key]) for key in keys if key in item)[
: model_input_context_length * 4
]
for item in input_data
]
embeddings = []
total_cost = 0
# Use precomputed embeddings if available from auto-optimization
if (
precomputed_left_embeddings is not None
and precomputed_right_embeddings is not None
):
left_embeddings = precomputed_left_embeddings
right_embeddings = precomputed_right_embeddings
else:
embedding_model = self.config.get("embedding_model", self.default_model)
model_input_context_length = model_cost.get(embedding_model, {}).get(
"max_input_tokens", 8192
)
batch_size = 2000
for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
self.console.log(
f"On iteration {i} for creating embeddings for {name} data"
)
response = self.runner.api.gen_embedding(
model=embedding_model,
input=batch,
)
embeddings.extend([data["embedding"] for data in response["data"]])
total_cost += completion_cost(response)
return embeddings, total_cost
left_embeddings, left_cost = get_embeddings(left_data, left_keys, "left")
right_embeddings, right_cost = get_embeddings(
right_data, right_keys, "right"
)
total_cost += left_cost + right_cost
self.console.log(
f"Created embeddings for datasets. Total embedding creation cost: {total_cost}"
)
def get_embeddings(
input_data: list[dict[str, Any]], keys: list[str], name: str
) -> tuple[list[list[float]], float]:
texts = [
" ".join(str(item[key]) for key in keys if key in item)[
: model_input_context_length * 4
]
for item in input_data
]
embeddings = []
embedding_cost = 0
num_batches = (len(texts) + batch_size - 1) // batch_size
for batch_idx, i in enumerate(range(0, len(texts), batch_size)):
batch = texts[i : i + batch_size]
if num_batches > 1:
self.console.log(
f"[dim]Creating {name} embeddings: batch {batch_idx + 1}/{num_batches} "
f"({min(i + batch_size, len(texts))}/{len(texts)} items)[/dim]"
)
response = self.runner.api.gen_embedding(
model=embedding_model,
input=batch,
)
embeddings.extend(
[data["embedding"] for data in response["data"]]
)
embedding_cost += completion_cost(response)
return embeddings, embedding_cost
self.console.log(
f"[cyan]Creating embeddings for {len(left_data)} left + {len(right_data)} right items...[/cyan]"
)
left_embeddings, left_cost = get_embeddings(
left_data, left_keys, "left"
)
right_embeddings, right_cost = get_embeddings(
right_data, right_keys, "right"
)
total_cost += left_cost + right_cost
# Compute all cosine similarities in one call
from sklearn.metrics.pairwise import cosine_similarity
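For reference, a minimal sketch of an equijoin config that exercises the auto-computed threshold path above: no blocking_threshold, blocking_conditions, or limit_comparisons, only the new blocking_target_recall knob. Operation, dataset, and key names are hypothetical.

equijoin_config = {
    "name": "match_transactions_to_invoices",
    "type": "equijoin",
    "comparison_prompt": (
        "Do these two records refer to the same purchase?\n"
        "Left: {{ left.description }}\n"
        "Right: {{ right.memo }}"
    ),
    # With no blocking settings, RuntimeBlockingOptimizer samples pairs,
    # labels them with the comparison prompt, and picks a cosine-similarity
    # threshold aiming for ~90% recall.
    "blocking_target_recall": 0.90,
}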

View File

@ -11,6 +11,7 @@ from pydantic import Field, field_validator
from docetl.operations.base import BaseOperation
from docetl.operations.utils import RichLoopBar, strict_render
from docetl.utils import has_jinja_syntax, prompt_user_for_non_jinja_confirmation
class ExtractOperation(BaseOperation):
@ -25,9 +26,14 @@ class ExtractOperation(BaseOperation):
timeout: int | None = None
skip_on_error: bool = False
litellm_completion_kwargs: dict[str, Any] = Field(default_factory=dict)
limit: int | None = Field(None, gt=0)
@field_validator("prompt")
def validate_prompt(cls, v):
# Check if it has Jinja syntax
if not has_jinja_syntax(v):
# This will be handled during initialization with user confirmation
return v
try:
Template(v)
except Exception as e:
@ -47,6 +53,16 @@ class ExtractOperation(BaseOperation):
self.extraction_key_suffix = f"_extracted_{self.config['name']}"
else:
self.extraction_key_suffix = self.config["extraction_key_suffix"]
# Check for non-Jinja prompts and prompt user for confirmation
if "prompt" in self.config and not has_jinja_syntax(self.config["prompt"]):
if not prompt_user_for_non_jinja_confirmation(
self.config["prompt"], self.config["name"], "prompt"
):
raise ValueError(
f"Operation '{self.config['name']}' cancelled by user. Please add Jinja2 template syntax to your prompt."
)
# Mark that we need to append document statement
self.config["_append_document_to_prompt"] = True
def _reformat_text_with_line_numbers(self, text: str, line_width: int = 80) -> str:
"""
@ -103,7 +119,7 @@ class ExtractOperation(BaseOperation):
def _execute_line_number_strategy(
self, item: dict, doc_key: str
) -> tuple[list[dict[str, Any]], float]:
) -> tuple[list[str], float, str]:
"""
Executes the line number extraction strategy for a single document key.
@ -132,10 +148,18 @@ class ExtractOperation(BaseOperation):
formatted_text = self._reformat_text_with_line_numbers(text_content)
# Render the prompt
extraction_instructions = strict_render(self.config["prompt"], {"input": item})
# Retrieval context
retrieval_context = self._maybe_build_retrieval_context({"input": item})
extraction_instructions = strict_render(
self.config["prompt"],
{"input": item, "retrieval_context": retrieval_context},
)
augmented_prompt_template = """
You are extracting specific content from text documents. Extract information according to these instructions: {{ extraction_instructions }}
Extra context (may be helpful):
{{ retrieval_context }}
The text is formatted with line numbers as follows:
{{ formatted_text }}
@ -162,6 +186,7 @@ Do not include explanatory text in your response, only the JSON object.
{
"extraction_instructions": extraction_instructions,
"formatted_text": formatted_text,
"retrieval_context": retrieval_context,
},
)
@ -202,7 +227,7 @@ Do not include explanatory text in your response, only the JSON object.
self.console.log(
f"[bold red]Error parsing LLM response: {llm_result.response}. Skipping.[/bold red]"
)
return [], llm_result.total_cost
return [], llm_result.total_cost, retrieval_context
for line_range in parsed_output.get("line_ranges", []):
start_line = line_range.get("start_line", 0)
@ -230,20 +255,20 @@ Do not include explanatory text in your response, only the JSON object.
extracted_texts.append("".join(extracted_content))
return extracted_texts, llm_result.total_cost
return extracted_texts, llm_result.total_cost, retrieval_context
except Exception as e:
if self.config.get("skip_on_error", True):
self.console.log(
f"[bold red]Error parsing LLM response: {str(e)}. Skipping.[/bold red]"
)
return [], llm_result.total_cost
return [], llm_result.total_cost, retrieval_context
else:
raise RuntimeError(f"Error parsing LLM response: {str(e)}") from e
def _execute_regex_strategy(
self, item: dict, doc_key: str
) -> tuple[list[str], float]:
) -> tuple[list[str], float, str]:
"""
Executes the regex extraction strategy for a single document key.
@ -252,7 +277,7 @@ Do not include explanatory text in your response, only the JSON object.
doc_key (str): The key of the document text to process.
Returns:
tuple[list[str], float]: A tuple containing the extraction results and the cost.
tuple[list[str], float, str]: A tuple containing the extraction results, cost, and retrieval context.
"""
import re
@ -262,7 +287,7 @@ Do not include explanatory text in your response, only the JSON object.
self.console.log(
f"[yellow]Warning: Key '{doc_key}' not found or not a string in document. Skipping.[/yellow]"
)
return [], 0.0
return [], 0.0, ""
else:
raise ValueError(
f"Key '{doc_key}' not found or not a string in document"
@ -271,13 +296,21 @@ Do not include explanatory text in your response, only the JSON object.
text_content = item[doc_key]
# Prepare the context for prompt rendering
context = {"input": item, "text_content": text_content}
retrieval_context = self._maybe_build_retrieval_context({"input": item})
context = {
"input": item,
"text_content": text_content,
"retrieval_context": retrieval_context,
}
# Render the prompt
extraction_instructions = strict_render(self.config["prompt"], context)
augmented_prompt_template = """
You are creating regex patterns to extract specific content from text. Extract information according to these instructions: {{ extraction_instructions }}
Extra context (may be helpful):
{{ retrieval_context }}
The text to analyze is:
{{ text_content }}
@ -356,14 +389,14 @@ Return only the JSON object with your patterns, no explanatory text.
else:
raise ValueError(f"Invalid regex pattern '{pattern}': {str(e)}")
return extracted_texts, llm_result.total_cost
return extracted_texts, llm_result.total_cost, retrieval_context
except Exception as e:
if self.config.get("skip_on_error", True):
self.console.log(
f"[bold red]Error parsing LLM response: {str(e)}. Skipping.[/bold red]"
)
return [], llm_result.total_cost
return [], llm_result.total_cost, retrieval_context
else:
raise RuntimeError(f"Error parsing LLM response: {str(e)}") from e
@ -377,6 +410,10 @@ Return only the JSON object with your patterns, no explanatory text.
Returns:
tuple[list[dict], float]: A tuple containing the processed data and the total cost of the operation.
"""
limit_value = self.config.get("limit")
if limit_value is not None:
input_data = input_data[:limit_value]
if not input_data:
return [], 0.0
@ -431,7 +468,7 @@ Return only the JSON object with your patterns, no explanatory text.
doc_key, future, output_item = futures[i]
try:
extracted_texts_duped, cost = future.result()
extracted_texts_duped, cost, retrieval_context = future.result()
# Remove duplicates and empty strings
extracted_texts_duped = [
@ -453,6 +490,12 @@ Return only the JSON object with your patterns, no explanatory text.
else:
output_item[output_key] = extracted_texts
# Save retrieved context if enabled
if self.config.get("save_retriever_output", False):
output_item[f"_{self.config['name']}_retrieved_context"] = (
retrieval_context if retrieval_context else ""
)
except Exception as e:
if self.config.get("skip_on_error", True):
self.console.log(

View File

@ -33,6 +33,30 @@ class FilterOperation(MapOperation):
return self
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._filter_key = next(
iter(
[
k
for k in self.config["output"]["schema"].keys()
if k != "_short_explanation"
]
)
)
self._filter_is_build = False
def _limit_applies_to_inputs(self) -> bool:
return False
def _handle_result(self, result: dict[str, Any]) -> tuple[dict | None, bool]:
keep_record = bool(result.get(self._filter_key))
result.pop(self._filter_key, None)
if self._filter_is_build or keep_record:
return result, keep_record
return None, False
def execute(
self, input_data: list[dict], is_build: bool = False
) -> tuple[list[dict], float]:
@ -46,55 +70,10 @@ class FilterOperation(MapOperation):
Returns:
tuple[list[dict], float]: A tuple containing the filtered list of dictionaries
and the total cost of the operation.
This method performs the following steps:
1. Processes each input item using an LLM model
2. Validates the output
3. Filters the results based on the specified filter key
4. Calculates the total cost of the operation
The method uses multi-threading to process items in parallel, improving performance
for large datasets.
Usage:
```python
from docetl.operations import FilterOperation
config = {
"prompt": "Determine if the following item is important: {{input}}",
"output": {
"schema": {"is_important": "bool"}
},
"model": "gpt-3.5-turbo"
}
filter_op = FilterOperation(config)
input_data = [
{"id": 1, "text": "Critical update"},
{"id": 2, "text": "Regular maintenance"}
]
results, cost = filter_op.execute(input_data)
print(f"Filtered results: {results}")
print(f"Total cost: {cost}")
```
"""
filter_key = next(
iter(
[
k
for k in self.config["output"]["schema"].keys()
if k != "_short_explanation"
]
)
)
results, total_cost = super().execute(input_data)
# Drop records with filter_key values that are False
if not is_build:
results = [result for result in results if result[filter_key]]
# Drop the filter_key from the results
for result in results:
result.pop(filter_key, None)
return results, total_cost
previous_state = self._filter_is_build
self._filter_is_build = is_build
try:
return super().execute(input_data)
finally:
self._filter_is_build = previous_state
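A small sketch of the keep/drop rule that `_handle_result` implements, assuming an output schema such as {"is_important": "bool", "_short_explanation": "str"}, where is_important is the filter key.

def handle_filter_result(result: dict, filter_key: str, is_build: bool) -> tuple[dict | None, bool]:
    # Keep the record only if the filter key is truthy; the key itself is
    # always removed from the emitted record. When is_build is True, the
    # record is returned even if it did not pass (matching the old behavior
    # of skipping the drop step during builds).
    keep = bool(result.get(filter_key))
    result.pop(filter_key, None)
    if is_build or keep:
        return result, keep
    return None, False

kept, passed = handle_filter_result({"id": 1, "is_important": True}, "is_important", is_build=False)
assert kept == {"id": 1} and passed is True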

View File

@ -7,11 +7,29 @@ from sklearn.metrics.pairwise import cosine_similarity
from docetl.operations.base import BaseOperation
from docetl.operations.utils import RichLoopBar, strict_render
from docetl.utils import has_jinja_syntax, prompt_user_for_non_jinja_confirmation
from .clustering_utils import get_embeddings_for_clustering
class LinkResolveOperation(BaseOperation):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Check for non-Jinja prompts and prompt user for confirmation
if "comparison_prompt" in self.config and not has_jinja_syntax(
self.config["comparison_prompt"]
):
if not prompt_user_for_non_jinja_confirmation(
self.config["comparison_prompt"],
self.config["name"],
"comparison_prompt",
):
raise ValueError(
f"Operation '{self.config['name']}' cancelled by user. Please add Jinja2 template syntax to your comparison_prompt."
)
# Mark that we need to append document statement
# Note: link_resolve uses link_value, id_value, and item, so strict_render will handle it
self.config["_append_document_to_comparison_prompt"] = True
def execute(self, input_data: list[dict]) -> tuple[list[dict], float]:
"""
Executes the resolve links operation on the provided dataset.

View File

@ -17,6 +17,7 @@ from docetl.base_schemas import Tool, ToolFunction
from docetl.operations.base import BaseOperation
from docetl.operations.utils import RichLoopBar, strict_render, validate_output_types
from docetl.operations.utils.api import OutputMode
from docetl.utils import has_jinja_syntax, prompt_user_for_non_jinja_confirmation
class MapOperation(BaseOperation):
@ -42,6 +43,7 @@ class MapOperation(BaseOperation):
litellm_completion_kwargs: dict[str, Any] = {}
pdf_url_key: str | None = None
flush_partial_result: bool = False
limit: int | None = Field(None, gt=0)
# Calibration parameters
calibrate: bool = False
num_calibration_docs: int = Field(10, gt=0)
@ -49,6 +51,11 @@ class MapOperation(BaseOperation):
@field_validator("batch_prompt")
def validate_batch_prompt(cls, v):
if v is not None:
# Check if it has Jinja syntax
if not has_jinja_syntax(v):
# This will be handled during initialization with user confirmation
# We'll mark it for later processing
return v
try:
template = Template(v)
# Test render with a minimal inputs list to validate template
@ -62,6 +69,11 @@ class MapOperation(BaseOperation):
@field_validator("prompt")
def validate_prompt(cls, v):
if v is not None:
# Check if it has Jinja syntax
if not has_jinja_syntax(v):
# This will be handled during initialization with user confirmation
# We'll mark it for later processing
return v
try:
Template(v)
except Exception as e:
@ -118,6 +130,33 @@ class MapOperation(BaseOperation):
"max_batch_size", kwargs.get("max_batch_size", None)
)
self.clustering_method = "random"
# Check for non-Jinja prompts and prompt user for confirmation
if "prompt" in self.config and not has_jinja_syntax(self.config["prompt"]):
if not prompt_user_for_non_jinja_confirmation(
self.config["prompt"], self.config["name"], "prompt"
):
raise ValueError(
f"Operation '{self.config['name']}' cancelled by user. Please add Jinja2 template syntax to your prompt."
)
# Mark that we need to append document statement
self.config["_append_document_to_prompt"] = True
if "batch_prompt" in self.config and not has_jinja_syntax(
self.config["batch_prompt"]
):
if not prompt_user_for_non_jinja_confirmation(
self.config["batch_prompt"], self.config["name"], "batch_prompt"
):
raise ValueError(
f"Operation '{self.config['name']}' cancelled by user. Please add Jinja2 template syntax to your batch_prompt."
)
# Mark that we need to append document statement
self.config["_append_document_to_batch_prompt"] = True
def _limit_applies_to_inputs(self) -> bool:
return True
def _handle_result(self, result: dict[str, Any]) -> tuple[dict | None, bool]:
return result, True
def _generate_calibration_context(self, input_data: list[dict]) -> str:
"""
@ -239,17 +278,27 @@ Reference anchors:"""
The method uses parallel processing to improve performance.
"""
limit_value = self.config.get("limit")
# Check if there's no prompt and only drop_keys
if "prompt" not in self.config and "drop_keys" in self.config:
data_to_process = input_data
if limit_value is not None and self._limit_applies_to_inputs():
data_to_process = input_data[:limit_value]
# If only drop_keys is specified, simply drop the keys and return
dropped_results = []
for item in input_data:
for item in data_to_process:
new_item = {
k: v for k, v in item.items() if k not in self.config["drop_keys"]
}
dropped_results.append(new_item)
if limit_value is not None and len(dropped_results) >= limit_value:
break
return dropped_results, 0.0 # Return the modified data with no cost
if limit_value is not None and self._limit_applies_to_inputs():
input_data = input_data[:limit_value]
# Generate calibration context if enabled
calibration_context = ""
if self.config.get("calibrate", False) and "prompt" in self.config:
@ -276,7 +325,17 @@ Reference anchors:"""
item: dict, initial_result: dict | None = None
) -> tuple[dict | None, float]:
prompt = strict_render(self.config["prompt"], {"input": item})
# Build retrieval context (if configured)
retrieval_context = self._maybe_build_retrieval_context({"input": item})
ctx = {"input": item, "retrieval_context": retrieval_context}
rendered = strict_render(self.config["prompt"], ctx)
# If template didn't use retrieval_context, prepend a standard header
prompt = (
f"Here is some extra context:\n{retrieval_context}\n\n{rendered}"
if retrieval_context
and "retrieval_context" not in self.config["prompt"]
else rendered
)
messages = [{"role": "user", "content": prompt}]
if self.config.get("pdf_url_key", None):
# Append the pdf to the prompt
@ -384,6 +443,12 @@ Reference anchors:"""
output[f"_observability_{self.config['name']}"] = {
"prompt": prompt
}
# Add retrieved context if save_retriever_output is enabled
if self.config.get("save_retriever_output", False):
for output in outputs:
output[f"_{self.config['name']}_retrieved_context"] = (
retrieval_context if retrieval_context else ""
)
return outputs, llm_result.total_cost
return None, llm_result.total_cost
@ -479,40 +544,87 @@ Reference anchors:"""
return all_results, total_cost
with ThreadPoolExecutor(max_workers=self.max_batch_size) as executor:
batch_size = self.max_batch_size if self.max_batch_size is not None else 1
futures = []
for i in range(0, len(input_data), batch_size):
batch = input_data[i : i + batch_size]
futures.append(executor.submit(_process_map_batch, batch))
results = []
total_cost = 0
pbar = RichLoopBar(
range(len(futures)),
desc=f"Processing {self.config['name']} (map) on all documents",
console=self.console,
limit_counter = 0
batch_size = self.max_batch_size if self.max_batch_size is not None else 1
total_batches = (len(input_data) + batch_size - 1) // batch_size
if total_batches == 0:
if self.status:
self.status.start()
return [], 0.0
worker_limit = self.max_batch_size or self.max_threads or 1
window_size = (
total_batches
if limit_value is None
else max(1, (limit_value + batch_size - 1) // batch_size)
)
results: list[dict] = []
total_cost = 0.0
limit_reached = False
op_name = self.config["name"]
if limit_value is not None and not self._limit_applies_to_inputs():
self.console.log(
f"[yellow]Note: Operation will terminate early once {limit_value} items pass the filter condition.[/yellow]"
)
for batch_index in pbar:
result_list, item_cost = futures[batch_index].result()
if result_list:
if "drop_keys" in self.config:
result_list = [
{
k: v
for k, v in result.items()
if k not in self.config["drop_keys"]
}
for result in result_list
]
results.extend(result_list)
# --- BEGIN: Flush partial checkpoint ---
if self.config.get("flush_partial_results", False):
op_name = self.config["name"]
self.runner._flush_partial_results(
op_name, batch_index, result_list
)
# --- END: Flush partial checkpoint ---
total_cost += item_cost
with ThreadPoolExecutor(max_workers=worker_limit) as executor:
with RichLoopBar(
total=total_batches,
desc=f"Processing {op_name} (map) on all documents",
console=self.console,
) as pbar:
chunk_start = 0
while chunk_start < total_batches and not limit_reached:
chunk_end = min(total_batches, chunk_start + window_size)
chunk_ordinals = list(range(chunk_start, chunk_end))
futures = []
for ordinal in chunk_ordinals:
start_idx = ordinal * batch_size
batch = input_data[start_idx : start_idx + batch_size]
futures.append(executor.submit(_process_map_batch, batch))
for relative_idx, future in enumerate(futures):
if limit_value is not None and limit_counter >= limit_value:
limit_reached = True
break
result_list, item_cost = future.result()
total_cost += item_cost
if result_list:
if "drop_keys" in self.config:
result_list = [
{
k: v
for k, v in result.items()
if k not in self.config["drop_keys"]
}
for result in result_list
]
if self.config.get("flush_partial_results", False):
self.runner._flush_partial_results(
op_name, chunk_ordinals[relative_idx], result_list
)
for result in result_list:
processed_result, counts_towards_limit = (
self._handle_result(result)
)
if processed_result is not None:
results.append(processed_result)
if limit_value is not None and counts_towards_limit:
limit_counter += 1
if limit_counter >= limit_value:
limit_reached = True
break
pbar.update()
chunk_start = chunk_end
if self.status:
self.status.start()
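A stripped-down sketch of the windowed batching loop above: batches are submitted one window at a time, so that when a limit is set, later windows are never scheduled once enough results have been collected. Here process_batch stands in for _process_map_batch and simply returns a list of outputs per batch.

from concurrent.futures import ThreadPoolExecutor

def run_windowed(items, batch_size, limit, process_batch):
    batches = [items[i:i + batch_size] for i in range(0, len(items), batch_size)]
    window = len(batches) if limit is None else max(1, (limit + batch_size - 1) // batch_size)
    results = []
    with ThreadPoolExecutor(max_workers=4) as pool:
        start = 0
        while start < len(batches) and (limit is None or len(results) < limit):
            futures = [pool.submit(process_batch, b) for b in batches[start:start + window]]
            for fut in futures:
                for out in fut.result():
                    results.append(out)
                    if limit is not None and len(results) >= limit:
                        break
                if limit is not None and len(results) >= limit:
                    break
            start += window
    return results[:limit] if limit is not None else results

# Squares the first few numbers, then stops scheduling further windows.
print(run_windowed(list(range(100)), batch_size=10, limit=5,
                   process_batch=lambda b: [x * x for x in b]))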

View File

@ -32,7 +32,11 @@ from docetl.operations.utils import (
# Import OutputMode enum for structured output checks
from docetl.operations.utils.api import OutputMode
from docetl.utils import completion_cost
from docetl.utils import (
completion_cost,
has_jinja_syntax,
prompt_user_for_non_jinja_confirmation,
)
class ReduceOperation(BaseOperation):
@ -63,10 +67,15 @@ class ReduceOperation(BaseOperation):
timeout: int | None = None
litellm_completion_kwargs: dict[str, Any] = Field(default_factory=dict)
enable_observability: bool = False
limit: int | None = Field(None, gt=0)
@field_validator("prompt")
def validate_prompt(cls, v):
if v is not None:
# Check if it has Jinja syntax
if not has_jinja_syntax(v):
# This will be handled during initialization with user confirmation
return v
try:
template = Template(v)
template_vars = template.environment.parse(v).find_all(
@ -84,6 +93,10 @@ class ReduceOperation(BaseOperation):
@field_validator("fold_prompt")
def validate_fold_prompt(cls, v):
if v is not None:
# Check if it has Jinja syntax
if not has_jinja_syntax(v):
# This will be handled during initialization with user confirmation
return v
try:
fold_template = Template(v)
fold_template_vars = fold_template.environment.parse(v).find_all(
@ -104,6 +117,10 @@ class ReduceOperation(BaseOperation):
@field_validator("merge_prompt")
def validate_merge_prompt(cls, v):
if v is not None:
# Check if it has Jinja syntax
if not has_jinja_syntax(v):
# This will be handled during initialization with user confirmation
return v
try:
merge_template = Template(v)
merge_template_vars = merge_template.environment.parse(v).find_all(
@ -181,6 +198,39 @@ class ReduceOperation(BaseOperation):
)
self.intermediates = {}
self.lineage_keys = self.config.get("output", {}).get("lineage", [])
# Check for non-Jinja prompts and prompt user for confirmation
if "prompt" in self.config and not has_jinja_syntax(self.config["prompt"]):
if not prompt_user_for_non_jinja_confirmation(
self.config["prompt"], self.config["name"], "prompt"
):
raise ValueError(
f"Operation '{self.config['name']}' cancelled by user. Please add Jinja2 template syntax to your prompt."
)
# Mark that we need to append document statement (for reduce, use inputs)
self.config["_append_document_to_prompt"] = True
self.config["_is_reduce_operation"] = True
if "fold_prompt" in self.config and not has_jinja_syntax(
self.config["fold_prompt"]
):
if not prompt_user_for_non_jinja_confirmation(
self.config["fold_prompt"], self.config["name"], "fold_prompt"
):
raise ValueError(
f"Operation '{self.config['name']}' cancelled by user. Please add Jinja2 template syntax to your fold_prompt."
)
self.config["_append_document_to_fold_prompt"] = True
self.config["_is_reduce_operation"] = True
if "merge_prompt" in self.config and not has_jinja_syntax(
self.config["merge_prompt"]
):
if not prompt_user_for_non_jinja_confirmation(
self.config["merge_prompt"], self.config["name"], "merge_prompt"
):
raise ValueError(
f"Operation '{self.config['name']}' cancelled by user. Please add Jinja2 template syntax to your merge_prompt."
)
self.config["_append_document_to_merge_prompt"] = True
self.config["_is_reduce_operation"] = True
def execute(self, input_data: list[dict]) -> tuple[list[dict], float]:
"""
@ -236,6 +286,12 @@ class ReduceOperation(BaseOperation):
# Convert the grouped data to a list of tuples
grouped_data = list(grouped_data.items())
limit_value = self.config.get("limit")
if limit_value is not None:
# Sort by group size (smallest first) and take the limit
grouped_data = sorted(grouped_data, key=lambda x: len(x[1]))
grouped_data = grouped_data[:limit_value]
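A tiny illustration (toy data) of the selection just above: when a limit is set, the smallest groups are reduced first and the rest are skipped.

grouped = [
    (("us",), [{"id": 1}, {"id": 2}, {"id": 3}]),
    (("uk",), [{"id": 4}]),
    (("de",), [{"id": 5}, {"id": 6}]),
]
selected = sorted(grouped, key=lambda kv: len(kv[1]))[:2]
assert [k for k, _ in selected] == [("uk",), ("de",)]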
def process_group(
key: tuple, group_elems: list[dict]
) -> tuple[dict | None, float]:
@ -248,6 +304,16 @@ class ReduceOperation(BaseOperation):
group_list = group_elems
total_cost = 0.0
# Build retrieval context once per group
try:
retrieval_context = self._maybe_build_retrieval_context(
{
"reduce_key": dict(zip(self.config["reduce_key"], key)),
"inputs": group_list,
}
)
except Exception:
retrieval_context = "No extra context available."
# Apply value sampling if enabled
value_sampling = self.config.get("value_sampling", {})
@ -277,18 +343,26 @@ class ReduceOperation(BaseOperation):
# Only execute merge-based plans if associative = True
if "merge_prompt" in self.config and self.config.get("associative", True):
result, prompts, cost = self._parallel_fold_and_merge(key, group_list)
result, prompts, cost = self._parallel_fold_and_merge(
key, group_list, retrieval_context
)
elif self.config.get("fold_batch_size", None) and self.config.get(
"fold_batch_size"
) >= len(group_list):
# If the fold batch size is greater than or equal to the number of items in the group,
# we can just run a single fold operation
result, prompt, cost = self._batch_reduce(key, group_list)
result, prompt, cost = self._batch_reduce(
key, group_list, None, retrieval_context
)
prompts = [prompt]
elif "fold_prompt" in self.config:
result, prompts, cost = self._incremental_reduce(key, group_list)
result, prompts, cost = self._incremental_reduce(
key, group_list, retrieval_context
)
else:
result, prompt, cost = self._batch_reduce(key, group_list)
result, prompt, cost = self._batch_reduce(
key, group_list, None, retrieval_context
)
prompts = [prompt]
total_cost += cost
@ -300,6 +374,16 @@ class ReduceOperation(BaseOperation):
# Add the _observability_{self.config['name']} key to the result
result[f"_observability_{self.config['name']}"] = {"prompts": prompts}
# Add retrieved context if save_retriever_output is enabled
if self.config.get("save_retriever_output", False):
ctx = (
retrieval_context
if retrieval_context
and retrieval_context != "No extra context available."
else ""
)
result[f"_{self.config['name']}_retrieved_context"] = ctx
# Apply pass-through at the group level
if (
result is not None
@ -342,6 +426,9 @@ class ReduceOperation(BaseOperation):
if output is not None:
results.append(output)
if limit_value is not None and len(results) > limit_value:
results = results[:limit_value]
if self.config.get("persist_intermediates", False):
for result in results:
key = tuple(result[k] for k in self.config["reduce_key"])
@ -418,7 +505,7 @@ class ReduceOperation(BaseOperation):
return [group_list[i] for i in top_k_indices], cost
def _parallel_fold_and_merge(
self, key: tuple, group_list: list[dict]
self, key: tuple, group_list: list[dict], retrieval_context: str
) -> tuple[dict | None, float]:
"""
Perform parallel folding and merging on a group of items.
@ -583,7 +670,7 @@ class ReduceOperation(BaseOperation):
)
def _incremental_reduce(
self, key: tuple, group_list: list[dict]
self, key: tuple, group_list: list[dict], retrieval_context: str
) -> tuple[dict | None, list[str], float]:
"""
Perform an incremental reduce operation on a group of items.
@ -679,6 +766,7 @@ class ReduceOperation(BaseOperation):
batch: list[dict],
current_output: dict | None,
scratchpad: str | None = None,
retrieval_context: str | None = None,
) -> tuple[dict | None, str, float]:
"""
Perform an incremental fold operation on a batch of items.
@ -695,7 +783,7 @@ class ReduceOperation(BaseOperation):
the prompt used, and the cost of the fold operation.
"""
if current_output is None:
return self._batch_reduce(key, batch, scratchpad)
return self._batch_reduce(key, batch, scratchpad, retrieval_context)
start_time = time.time()
fold_prompt = strict_render(
@ -704,8 +792,15 @@ class ReduceOperation(BaseOperation):
"inputs": batch,
"output": current_output,
"reduce_key": dict(zip(self.config["reduce_key"], key)),
"retrieval_context": retrieval_context or "",
},
)
if retrieval_context and "retrieval_context" not in self.config.get(
"fold_prompt", ""
):
fold_prompt = (
f"Here is some extra context:\n{retrieval_context}\n\n{fold_prompt}"
)
response = self.runner.api.call_llm(
self.config.get("model", self.default_model),
@ -730,6 +825,7 @@ class ReduceOperation(BaseOperation):
end_time = time.time()
self._update_fold_time(end_time - start_time)
fold_cost = response.total_cost
if response.validated:
structured_mode = (
@ -751,7 +847,7 @@ class ReduceOperation(BaseOperation):
return None, fold_prompt, fold_cost
def _merge_results(
self, key: tuple, outputs: list[dict]
self, key: tuple, outputs: list[dict], retrieval_context: str | None = None
) -> tuple[dict | None, str, float]:
"""
Merge multiple outputs into a single result.
@ -772,8 +868,15 @@ class ReduceOperation(BaseOperation):
{
"outputs": outputs,
"reduce_key": dict(zip(self.config["reduce_key"], key)),
"retrieval_context": retrieval_context or "",
},
)
if retrieval_context and "retrieval_context" not in self.config.get(
"merge_prompt", ""
):
merge_prompt = (
f"Here is some extra context:\n{retrieval_context}\n\n{merge_prompt}"
)
response = self.runner.api.call_llm(
self.config.get("model", self.default_model),
"merge",
@ -798,6 +901,7 @@ class ReduceOperation(BaseOperation):
end_time = time.time()
self._update_merge_time(end_time - start_time)
merge_cost = response.total_cost
if response.validated:
structured_mode = (
@ -867,7 +971,11 @@ class ReduceOperation(BaseOperation):
self.merge_times.append(time)
def _batch_reduce(
self, key: tuple, group_list: list[dict], scratchpad: str | None = None
self,
key: tuple,
group_list: list[dict],
scratchpad: str | None = None,
retrieval_context: str | None = None,
) -> tuple[dict | None, str, float]:
"""
Perform a batch reduce operation on a group of items.
@ -887,8 +995,13 @@ class ReduceOperation(BaseOperation):
{
"reduce_key": dict(zip(self.config["reduce_key"], key)),
"inputs": group_list,
"retrieval_context": retrieval_context or "",
},
)
if retrieval_context and "retrieval_context" not in self.config.get(
"prompt", ""
):
prompt = f"Here is some extra context:\n{retrieval_context}\n\n{prompt}"
item_cost = 0
response = self.runner.api.call_llm(

View File

@ -10,11 +10,16 @@ import jinja2
from jinja2 import Template
from litellm import model_cost
from pydantic import Field, ValidationInfo, field_validator, model_validator
from rich.prompt import Confirm
from docetl.operations.base import BaseOperation
from docetl.operations.utils import RichLoopBar, rich_as_completed, strict_render
from docetl.utils import completion_cost, extract_jinja_variables
from docetl.operations.utils.blocking import RuntimeBlockingOptimizer
from docetl.utils import (
completion_cost,
extract_jinja_variables,
has_jinja_syntax,
prompt_user_for_non_jinja_confirmation,
)
def find_cluster(item, cluster_map):
@ -35,6 +40,7 @@ class ResolveOperation(BaseOperation):
comparison_model: str | None = None
blocking_keys: list[str] | None = None
blocking_threshold: float | None = Field(None, ge=0, le=1)
blocking_target_recall: float | None = Field(None, ge=0, le=1)
blocking_conditions: list[str] | None = None
input: dict[str, Any] | None = None
embedding_batch_size: int | None = Field(None, gt=0)
@ -48,6 +54,10 @@ class ResolveOperation(BaseOperation):
@field_validator("comparison_prompt")
def validate_comparison_prompt(cls, v):
if v is not None:
# Check if it has Jinja syntax
if not has_jinja_syntax(v):
# This will be handled during initialization with user confirmation
return v
try:
comparison_template = Template(v)
comparison_vars = comparison_template.environment.parse(v).find_all(
@ -70,6 +80,10 @@ class ResolveOperation(BaseOperation):
@field_validator("resolution_prompt")
def validate_resolution_prompt(cls, v):
if v is not None:
# Check if it has Jinja syntax
if not has_jinja_syntax(v):
# This will be handled during initialization with user confirmation
return v
try:
reduction_template = Template(v)
reduction_vars = reduction_template.environment.parse(v).find_all(
@ -123,6 +137,38 @@ class ResolveOperation(BaseOperation):
return self
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Check for non-Jinja prompts and prompt user for confirmation
if "comparison_prompt" in self.config and not has_jinja_syntax(
self.config["comparison_prompt"]
):
if not prompt_user_for_non_jinja_confirmation(
self.config["comparison_prompt"],
self.config["name"],
"comparison_prompt",
):
raise ValueError(
f"Operation '{self.config['name']}' cancelled by user. Please add Jinja2 template syntax to your comparison_prompt."
)
# Mark that we need to append document statement
# Note: comparison_prompt uses input1 and input2, so we'll handle it specially in strict_render
self.config["_append_document_to_comparison_prompt"] = True
if "resolution_prompt" in self.config and not has_jinja_syntax(
self.config["resolution_prompt"]
):
if not prompt_user_for_non_jinja_confirmation(
self.config["resolution_prompt"],
self.config["name"],
"resolution_prompt",
):
raise ValueError(
f"Operation '{self.config['name']}' cancelled by user. Please add Jinja2 template syntax to your resolution_prompt."
)
# Mark that we need to append document statement (resolution uses inputs)
self.config["_append_document_to_resolution_prompt"] = True
self.config["_is_reduce_operation"] = True
def compare_pair(
self,
comparison_prompt: str,
@ -221,26 +267,75 @@ class ResolveOperation(BaseOperation):
blocking_keys = self.config.get("blocking_keys", [])
blocking_threshold = self.config.get("blocking_threshold")
blocking_conditions = self.config.get("blocking_conditions", [])
limit_comparisons = self.config.get("limit_comparisons")
total_cost = 0
if self.status:
self.status.stop()
if not blocking_threshold and not blocking_conditions:
# Prompt the user for confirmation
if not Confirm.ask(
"[yellow]Warning: No blocking keys or conditions specified. "
"This may result in a large number of comparisons. "
"We recommend specifying at least one blocking key or condition, or using the optimizer to automatically come up with these. "
"Do you want to continue without blocking?[/yellow]",
console=self.runner.console,
):
raise ValueError("Operation cancelled by user.")
# Track pre-computed embeddings from auto-optimization
precomputed_embeddings = None
# Auto-compute blocking threshold if no blocking configuration is provided
if not blocking_threshold and not blocking_conditions and not limit_comparisons:
# Get target recall from operation config (default 0.95)
target_recall = self.config.get("blocking_target_recall", 0.95)
self.console.log(
f"[yellow]No blocking configuration. Auto-computing threshold (target recall: {target_recall:.0%})...[/yellow]"
)
# Determine blocking keys if not set
auto_blocking_keys = blocking_keys if blocking_keys else None
if not auto_blocking_keys:
prompt_template = self.config.get("comparison_prompt", "")
prompt_vars = extract_jinja_variables(prompt_template)
prompt_vars = [
var
for var in prompt_vars
if var not in ["input", "input1", "input2"]
]
auto_blocking_keys = list(
set([var.split(".")[-1] for var in prompt_vars])
)
if not auto_blocking_keys:
auto_blocking_keys = list(input_data[0].keys())
blocking_keys = auto_blocking_keys
# Create comparison function for threshold optimization
def compare_fn_for_optimization(item1, item2):
return self.compare_pair(
self.config["comparison_prompt"],
self.config.get("comparison_model", self.default_model),
item1,
item2,
blocking_keys=[], # Don't use key-based shortcut during optimization
timeout_seconds=self.config.get("timeout", 120),
max_retries_per_timeout=self.config.get(
"max_retries_per_timeout", 2
),
)
# Run threshold optimization
optimizer = RuntimeBlockingOptimizer(
runner=self.runner,
config=self.config,
default_model=self.default_model,
max_threads=self.max_threads,
console=self.console,
target_recall=target_recall,
sample_size=min(100, len(input_data) * (len(input_data) - 1) // 4),
)
blocking_threshold, precomputed_embeddings, optimization_cost = (
optimizer.optimize_resolve(
input_data,
compare_fn_for_optimization,
blocking_keys=blocking_keys,
)
)
total_cost += optimization_cost
input_schema = self.config.get("input", {}).get("schema", {})
if not blocking_keys:
# Set them to all keys in the input data
blocking_keys = list(input_data[0].keys())
limit_comparisons = self.config.get("limit_comparisons")
total_cost = 0
def is_match(item1: dict[str, Any], item2: dict[str, Any]) -> bool:
return any(
@ -251,120 +346,101 @@ class ResolveOperation(BaseOperation):
# Calculate embeddings if blocking_threshold is set
embeddings = None
if blocking_threshold is not None:
def get_embeddings_batch(
items: list[dict[str, Any]]
) -> list[tuple[list[float], float]]:
# Use precomputed embeddings if available from auto-optimization
if precomputed_embeddings is not None:
embeddings = precomputed_embeddings
else:
self.console.log(
f"[cyan]Creating embeddings for {len(input_data)} items...[/cyan]"
)
embedding_model = self.config.get(
"embedding_model", "text-embedding-3-small"
)
model_input_context_length = model_cost.get(embedding_model, {}).get(
"max_input_tokens", 8192
)
batch_size = self.config.get("embedding_batch_size", 1000)
embeddings = []
embedding_cost = 0.0
num_batches = (len(input_data) + batch_size - 1) // batch_size
texts = [
" ".join(str(item[key]) for key in blocking_keys if key in item)[
: model_input_context_length * 3
for batch_idx in range(num_batches):
start_idx = batch_idx * batch_size
end_idx = min(start_idx + batch_size, len(input_data))
batch = input_data[start_idx:end_idx]
if num_batches > 1:
self.console.log(
f"[dim]Creating embeddings: batch {batch_idx + 1}/{num_batches} "
f"({end_idx}/{len(input_data)} items)[/dim]"
)
texts = [
" ".join(
str(item[key]) for key in blocking_keys if key in item
)[: model_input_context_length * 3]
for item in batch
]
for item in items
]
response = self.runner.api.gen_embedding(
model=embedding_model, input=texts
)
embeddings.extend([data["embedding"] for data in response["data"]])
embedding_cost += completion_cost(response)
response = self.runner.api.gen_embedding(
model=embedding_model, input=texts
)
return [
(data["embedding"], completion_cost(response))
for data in response["data"]
]
total_cost += embedding_cost
embeddings = []
costs = []
with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
for i in range(
0, len(input_data), self.config.get("embedding_batch_size", 1000)
):
batch = input_data[
i : i + self.config.get("embedding_batch_size", 1000)
]
batch_results = list(executor.map(get_embeddings_batch, [batch]))
# Build a mapping of blocking key values to indices
# This is used later for cluster merging (when two items match, merge all items sharing their key values)
value_to_indices: dict[tuple[str, ...], list[int]] = {}
for i, item in enumerate(input_data):
key = tuple(str(item.get(k, "")) for k in blocking_keys)
if key not in value_to_indices:
value_to_indices[key] = []
value_to_indices[key].append(i)
for result in batch_results:
embeddings.extend([r[0] for r in result])
costs.extend([r[1] for r in result])
# Total number of pairs to potentially compare
n = len(input_data)
total_pairs = n * (n - 1) // 2
total_cost += sum(costs)
# Generate all pairs to compare, ensuring no duplicate comparisons
def get_unique_comparison_pairs() -> (
tuple[list[tuple[int, int]], dict[tuple[str, ...], list[int]]]
):
# Create a mapping of values to their indices
value_to_indices: dict[tuple[str, ...], list[int]] = {}
for i, item in enumerate(input_data):
# Create a hashable key from the blocking keys
key = tuple(str(item.get(k, "")) for k in blocking_keys)
if key not in value_to_indices:
value_to_indices[key] = []
value_to_indices[key].append(i)
# Generate pairs for comparison, comparing each unique value combination only once
comparison_pairs = []
keys = list(value_to_indices.keys())
# First, handle comparisons between different values
for i in range(len(keys)):
for j in range(i + 1, len(keys)):
# Only need one comparison between different values
idx1 = value_to_indices[keys[i]][0]
idx2 = value_to_indices[keys[j]][0]
if idx1 < idx2: # Maintain ordering to avoid duplicates
comparison_pairs.append((idx1, idx2))
return comparison_pairs, value_to_indices
comparison_pairs, value_to_indices = get_unique_comparison_pairs()
# Filter pairs based on blocking conditions
def meets_blocking_conditions(pair: tuple[int, int]) -> bool:
i, j = pair
return (
is_match(input_data[i], input_data[j]) if blocking_conditions else False
)
# Start with pairs that meet blocking conditions, or empty list if no conditions
code_blocked_pairs = (
list(filter(meets_blocking_conditions, comparison_pairs))
if blocking_conditions
else []
)
# Apply code-based blocking conditions (check all pairs)
code_blocked_pairs = []
if blocking_conditions:
for i in range(n):
for j in range(i + 1, n):
if is_match(input_data[i], input_data[j]):
code_blocked_pairs.append((i, j))
# Apply cosine similarity blocking if threshold is specified
embedding_blocked_pairs = []
if blocking_threshold is not None and embeddings is not None:
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
similarity_matrix = cosine_similarity(embeddings)
# Add pairs that meet the cosine similarity threshold and aren't already blocked
code_blocked_set = set(code_blocked_pairs)
for i, j in comparison_pairs:
if (i, j) not in code_blocked_set:
similarity = similarity_matrix[i, j]
if similarity >= blocking_threshold:
embedding_blocked_pairs.append((i, j))
# Use numpy to efficiently find all pairs above threshold
i_indices, j_indices = np.triu_indices(n, k=1)
similarities = similarity_matrix[i_indices, j_indices]
above_threshold_mask = similarities >= blocking_threshold
self.console.log(
f"Cosine similarity blocking: added {len(embedding_blocked_pairs)} pairs "
f"(threshold: {blocking_threshold})"
)
# Get pairs above threshold
above_threshold_i = i_indices[above_threshold_mask]
above_threshold_j = j_indices[above_threshold_mask]
# Combine pairs with prioritization for sampling
# Filter out pairs already in code_blocked_set
embedding_blocked_pairs = [
(int(i), int(j))
for i, j in zip(above_threshold_i, above_threshold_j)
if (i, j) not in code_blocked_set
]
# Combine pairs from both blocking methods
all_blocked_pairs = code_blocked_pairs + embedding_blocked_pairs
# If no pairs are blocked at all, fall back to all comparison pairs
if not all_blocked_pairs:
all_blocked_pairs = comparison_pairs
# If no blocking was applied, compare all pairs
if not blocking_conditions and blocking_threshold is None:
all_blocked_pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]
# Apply limit_comparisons with prioritization
if limit_comparisons is not None and len(all_blocked_pairs) > limit_comparisons:
# Prioritize code-based pairs, then sample from embedding pairs if needed
@ -431,18 +507,6 @@ class ResolveOperation(BaseOperation):
cluster_map[root_idx] = root1
clusters[root_idx] = set()
# Calculate and print statistics
total_possible_comparisons = len(input_data) * (len(input_data) - 1) // 2
comparisons_made = len(blocked_pairs)
comparisons_saved = total_possible_comparisons - comparisons_made
self.console.log(
f"[green]Comparisons saved by deduping and blocking: {comparisons_saved} "
f"({(comparisons_saved / total_possible_comparisons) * 100:.2f}%)[/green]"
)
self.console.log(
f"[blue]Number of pairs to compare: {len(blocked_pairs)}[/blue]"
)
# Compute an auto-batch size based on the number of comparisons
def auto_batch() -> int:
# Maximum batch size limit for 4o-mini model
@ -468,7 +532,14 @@ class ResolveOperation(BaseOperation):
# Compare pairs and update clusters in real-time
batch_size = self.config.get("compare_batch_size", auto_batch())
self.console.log(f"Using compare batch size: {batch_size}")
# Log blocking summary
total_possible_comparisons = len(input_data) * (len(input_data) - 1) // 2
self.console.log(
f"Comparing {len(blocked_pairs):,} pairs "
f"({len(blocked_pairs)/total_possible_comparisons*100:.1f}% of {total_possible_comparisons:,} total, "
f"batch size: {batch_size})"
)
pair_costs = 0
pbar = RichLoopBar(

View File

@ -1,4 +1,5 @@
from .api import APIWrapper
from .blocking import RuntimeBlockingOptimizer
from .cache import (
cache,
cache_key,
@ -15,6 +16,7 @@ from .validation import safe_eval, convert_val, convert_dict_schema_to_list_sche
__all__ = [
'APIWrapper',
'RuntimeBlockingOptimizer',
'cache',
'cache_key',
'clear_cache',

View File

@ -86,23 +86,35 @@ class APIWrapper(object):
from litellm import Router
# Build model list: operation model first, then fallbacks
model_list = [{
"model_name": operation_model,
"litellm_params": {
"model": operation_model,
**({"api_base": self.default_lm_api_base} if self.default_lm_api_base else {})
model_list = [
{
"model_name": operation_model,
"litellm_params": {
"model": operation_model,
**(
{"api_base": self.default_lm_api_base}
if self.default_lm_api_base
else {}
),
},
}
}]
]
model_names = [operation_model]
# Add fallback models, skipping duplicates
seen = {operation_model}
for cfg in self.fallback_models_config:
name = cfg.get("model_name") if isinstance(cfg, dict) else (cfg if isinstance(cfg, str) else None)
name = (
cfg.get("model_name")
if isinstance(cfg, dict)
else (cfg if isinstance(cfg, str) else None)
)
if not name or name in seen:
continue
seen.add(name)
params = cfg.get("litellm_params", {}).copy() if isinstance(cfg, dict) else {}
params = (
cfg.get("litellm_params", {}).copy() if isinstance(cfg, dict) else {}
)
params["model"] = name
if self.default_lm_api_base and "api_base" not in params:
params["api_base"] = self.default_lm_api_base
@ -171,7 +183,11 @@ class APIWrapper(object):
extra_kwargs["api_base"] = self.default_embedding_api_base
# Use embedding router if available (for fallback models)
embedding_fn = self.embedding_router.embedding if self.embedding_router else embedding
embedding_fn = (
self.embedding_router.embedding
if self.embedding_router
else embedding
)
result = embedding_fn(model=model, input=input, **extra_kwargs)
# Cache the result
c.set(key, result)
@ -261,6 +277,10 @@ class APIWrapper(object):
):
model = "azure/" + model
# Pop off temperature if it's gpt-5 in the model name
if "gpt-5" in model:
litellm_completion_kwargs.pop("temperature", None)
total_cost = 0.0
validated = False
with cache as c:
@ -361,7 +381,9 @@ class APIWrapper(object):
# Use router if available (for fallback models), otherwise use direct completion
# When using router, ensure gleaning model is tried first, then fallback models
if self.router and self.fallback_models_config:
completion_fn = self._get_router_with_operation_model(gleaning_model)
completion_fn = self._get_router_with_operation_model(
gleaning_model
)
else:
completion_fn = completion
@ -848,6 +870,10 @@ Your main result must be sent via send_output. The updated_scratchpad is only fo
else:
completion_fn = completion
# Pop off temperature if it's gpt-5 in the model name
if "gpt-5" in model:
extra_litellm_kwargs.pop("temperature", None)
if use_structured_output:
try:
response = completion_fn(

View File

@ -0,0 +1,567 @@
"""
Runtime blocking threshold optimization utilities.
This module provides functionality for automatically computing embedding-based
blocking thresholds at runtime when no blocking configuration is provided.
"""
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Callable
import numpy as np
from litellm import model_cost
from rich.console import Console
from docetl.utils import completion_cost, extract_jinja_variables
class RuntimeBlockingOptimizer:
"""
Computes optimal embedding-based blocking thresholds at runtime.
This class samples pairs from the dataset, performs LLM comparisons,
and finds the optimal cosine similarity threshold that achieves a
target recall rate.
"""
def __init__(
self,
runner,
config: dict[str, Any],
default_model: str,
max_threads: int,
console: Console,
target_recall: float = 0.95,
sample_size: int = 100,
sampling_weight: float = 20.0,
):
"""
Initialize the RuntimeBlockingOptimizer.
Args:
runner: The pipeline runner instance.
config: Operation configuration.
default_model: Default LLM model for comparisons.
max_threads: Maximum threads for parallel processing.
console: Rich console for logging.
target_recall: Target recall rate (default 0.95).
sample_size: Number of pairs to sample for threshold estimation.
sampling_weight: Weight for exponential sampling towards higher similarities.
"""
self.runner = runner
self.config = config
self.default_model = default_model
self.max_threads = max_threads
self.console = console
self.target_recall = target_recall
self.sample_size = sample_size
self.sampling_weight = sampling_weight
def compute_embeddings(
self,
input_data: list[dict[str, Any]],
keys: list[str],
embedding_model: str | None = None,
batch_size: int = 1000,
) -> tuple[list[list[float]], float]:
"""
Compute embeddings for the input data.
Args:
input_data: List of input documents.
keys: Keys to use for embedding text.
embedding_model: Model to use for embeddings.
batch_size: Batch size for embedding computation.
Returns:
Tuple of (embeddings list, total cost).
"""
embedding_model = embedding_model or self.config.get(
"embedding_model", "text-embedding-3-small"
)
model_input_context_length = model_cost.get(embedding_model, {}).get(
"max_input_tokens", 8192
)
texts = [
" ".join(str(item[key]) for key in keys if key in item)[
: model_input_context_length * 3
]
for item in input_data
]
self.console.log(f"[cyan]Creating embeddings for {len(texts)} items...[/cyan]")
embeddings = []
total_cost = 0.0
num_batches = (len(texts) + batch_size - 1) // batch_size
for batch_idx, i in enumerate(range(0, len(texts), batch_size)):
batch = texts[i : i + batch_size]
if num_batches > 1:
self.console.log(
f"[dim] Batch {batch_idx + 1}/{num_batches} "
f"({len(embeddings) + len(batch)}/{len(texts)} items)[/dim]"
)
response = self.runner.api.gen_embedding(
model=embedding_model,
input=batch,
)
embeddings.extend([data["embedding"] for data in response["data"]])
total_cost += completion_cost(response)
return embeddings, total_cost
def calculate_cosine_similarities_self(
self, embeddings: list[list[float]]
) -> list[tuple[int, int, float]]:
"""
Calculate pairwise cosine similarities for self-join.
Args:
embeddings: List of embedding vectors.
Returns:
List of (i, j, similarity) tuples for all pairs where i < j.
"""
embeddings_array = np.array(embeddings)
norms = np.linalg.norm(embeddings_array, axis=1)
# Avoid division by zero
norms = np.where(norms == 0, 1e-10, norms)
dot_products = np.dot(embeddings_array, embeddings_array.T)
similarities_matrix = dot_products / np.outer(norms, norms)
i, j = np.triu_indices(len(embeddings), k=1)
similarities = list(
zip(i.tolist(), j.tolist(), similarities_matrix[i, j].tolist())
)
return similarities
def calculate_cosine_similarities_cross(
self,
left_embeddings: list[list[float]],
right_embeddings: list[list[float]],
) -> list[tuple[int, int, float]]:
"""
Calculate cosine similarities between two sets of embeddings.
Args:
left_embeddings: Embeddings for left dataset.
right_embeddings: Embeddings for right dataset.
Returns:
List of (left_idx, right_idx, similarity) tuples.
"""
left_array = np.array(left_embeddings)
right_array = np.array(right_embeddings)
dot_product = np.dot(left_array, right_array.T)
norm_left = np.linalg.norm(left_array, axis=1)
norm_right = np.linalg.norm(right_array, axis=1)
# Avoid division by zero
norm_left = np.where(norm_left == 0, 1e-10, norm_left)
norm_right = np.where(norm_right == 0, 1e-10, norm_right)
similarities = dot_product / np.outer(norm_left, norm_right)
return [
(i, j, float(sim))
for i, row in enumerate(similarities)
for j, sim in enumerate(row)
]
def sample_pairs(
self,
similarities: list[tuple[int, int, float]],
num_bins: int = 10,
stratified_fraction: float = 0.5,
) -> list[tuple[int, int]]:
"""
Sample pairs using a hybrid of stratified and exponential-weighted sampling.
This ensures coverage across the similarity distribution while still
focusing on high-similarity pairs where matches are more likely.
Args:
similarities: List of (i, j, similarity) tuples.
num_bins: Number of bins for stratified sampling.
stratified_fraction: Fraction of samples to allocate to stratified sampling.
Returns:
List of sampled (i, j) pairs.
"""
if len(similarities) == 0:
return []
sample_count = min(self.sample_size, len(similarities))
stratified_count = int(sample_count * stratified_fraction)
exponential_count = sample_count - stratified_count
sampled_indices = set()
sim_values = np.array([sim[2] for sim in similarities])
# Part 1: Stratified sampling across bins
if stratified_count > 0:
bin_edges = np.linspace(
sim_values.min(), sim_values.max() + 1e-9, num_bins + 1
)
samples_per_bin = max(1, stratified_count // num_bins)
for bin_idx in range(num_bins):
bin_mask = (sim_values >= bin_edges[bin_idx]) & (
sim_values < bin_edges[bin_idx + 1]
)
bin_indices = np.where(bin_mask)[0]
if len(bin_indices) > 0:
# Within each bin, use exponential weighting
bin_sims = sim_values[bin_indices]
bin_weights = np.exp(self.sampling_weight * bin_sims)
bin_weights /= bin_weights.sum()
n_to_sample = min(samples_per_bin, len(bin_indices))
chosen = np.random.choice(
bin_indices,
size=n_to_sample,
replace=False,
p=bin_weights,
)
sampled_indices.update(chosen.tolist())
# Part 2: Exponential-weighted sampling for remaining slots
if exponential_count > 0:
remaining_indices = [
i for i in range(len(similarities)) if i not in sampled_indices
]
if remaining_indices:
remaining_sims = sim_values[remaining_indices]
weights = np.exp(self.sampling_weight * remaining_sims)
weights /= weights.sum()
n_to_sample = min(exponential_count, len(remaining_indices))
chosen = np.random.choice(
remaining_indices,
size=n_to_sample,
replace=False,
p=weights,
)
sampled_indices.update(chosen.tolist())
sampled_pairs = [
(similarities[i][0], similarities[i][1]) for i in sampled_indices
]
return sampled_pairs
def _print_similarity_histogram(
self,
similarities: list[tuple[int, int, float]],
comparison_results: list[tuple[int, int, bool]],
threshold: float | None = None,
):
"""
Print a histogram of embedding cosine similarity distribution.
Args:
similarities: List of (i, j, similarity) tuples.
comparison_results: List of (i, j, is_match) from LLM comparisons.
threshold: Optional threshold to highlight in the histogram.
"""
# Filter out self-similarities (similarity == 1)
flat_similarities = [sim[2] for sim in similarities if sim[2] != 1]
if not flat_similarities:
return
hist, bin_edges = np.histogram(flat_similarities, bins=20)
max_bar_width, max_count = 40, max(hist) if max(hist) > 0 else 1
normalized_hist = [int(count / max_count * max_bar_width) for count in hist]
# Create a dictionary to store true labels
true_labels = {(i, j): is_match for i, j, is_match in comparison_results}
# Count pairs above threshold
pairs_above_threshold = (
sum(1 for sim in flat_similarities if sim >= threshold) if threshold else 0
)
total_pairs = len(flat_similarities)
lines = []
for i, count in enumerate(normalized_hist):
bar = "" * count
bin_start, bin_end = bin_edges[i], bin_edges[i + 1]
label = f"{bin_start:.2f}-{bin_end:.2f}"
# Count true matches and not matches in this bin
true_matches = 0
not_matches = 0
labeled_count = 0
for sim in similarities:
if bin_start <= sim[2] < bin_end:
if (sim[0], sim[1]) in true_labels:
labeled_count += 1
if true_labels[(sim[0], sim[1])]:
true_matches += 1
else:
not_matches += 1
# Calculate percentages of labeled pairs
if labeled_count > 0:
true_match_percent = (true_matches / labeled_count) * 100
label_info = f"[green]{true_match_percent:5.1f}%[/green] match"
else:
label_info = "[dim]--[/dim]"
# Highlight the bin containing the threshold
if threshold is not None and bin_start <= threshold < bin_end:
lines.append(
f"[bold yellow]{label}[/bold yellow] {bar:<{max_bar_width}} "
f"[dim]n={hist[i]:>5}[/dim] {label_info} [bold yellow]◀ threshold[/bold yellow]"
)
else:
lines.append(
f"{label} {bar:<{max_bar_width}} "
f"[dim]n={hist[i]:>5}[/dim] {label_info}"
)
from rich.panel import Panel
histogram_content = "\n".join(lines)
title = f"Similarity Distribution ({pairs_above_threshold:,} of {total_pairs:,} pairs ≥ {threshold:.4f})"
self.console.log(Panel(histogram_content, title=title, border_style="cyan"))
def find_optimal_threshold(
self,
comparisons: list[tuple[int, int, bool]],
similarities: list[tuple[int, int, float]],
) -> tuple[float, float]:
"""
Find the optimal similarity threshold that achieves target recall.
Args:
comparisons: List of (i, j, is_match) from LLM comparisons.
similarities: List of (i, j, similarity) tuples.
Returns:
Tuple of (optimal_threshold, achieved_recall).
"""
if not comparisons or not any(comp[2] for comp in comparisons):
# No matches found, use a high threshold to be conservative
self.console.log(
"[yellow]No matches found in sample. Using 99th percentile "
"similarity as threshold.[/yellow]"
)
all_sims = [sim[2] for sim in similarities]
threshold = float(np.percentile(all_sims, 99)) if all_sims else 0.9
return threshold, 0.0
true_labels = np.array([comp[2] for comp in comparisons])
sim_dict = {(i, j): sim for i, j, sim in similarities}
sim_scores = np.array([sim_dict.get((i, j), 0.0) for i, j, _ in comparisons])
thresholds = np.linspace(0, 1, 100)
recalls = []
for threshold in thresholds:
predictions = sim_scores >= threshold
tp = np.sum(predictions & true_labels)
fn = np.sum(~predictions & true_labels)
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
recalls.append(recall)
# Find highest threshold that achieves target recall
valid_indices = [i for i, r in enumerate(recalls) if r >= self.target_recall]
if not valid_indices:
# If no threshold achieves target recall, use the one with highest recall
best_idx = int(np.argmax(recalls))
optimal_threshold = float(thresholds[best_idx])
achieved_recall = float(recalls[best_idx])
self.console.log(
f"[yellow]Warning: Could not achieve target recall {self.target_recall:.0%}. "
f"Using threshold {optimal_threshold:.4f} with recall {achieved_recall:.2%}.[/yellow]"
)
else:
best_idx = max(valid_indices)
optimal_threshold = float(thresholds[best_idx])
achieved_recall = float(recalls[best_idx])
return round(optimal_threshold, 4), achieved_recall
def optimize_resolve(
self,
input_data: list[dict[str, Any]],
compare_fn: Callable[[dict, dict], tuple[bool, float, str]],
blocking_keys: list[str] | None = None,
) -> tuple[float, list[list[float]], float]:
"""
Compute optimal blocking threshold for resolve operation.
Args:
input_data: Input dataset.
compare_fn: Function to compare two items, returns (is_match, cost, prompt).
blocking_keys: Keys to use for blocking. If None, extracted from prompt.
Returns:
Tuple of (threshold, embeddings, total_cost).
"""
from rich.panel import Panel
# Determine blocking keys
if not blocking_keys:
prompt_template = self.config.get("comparison_prompt", "")
prompt_vars = extract_jinja_variables(prompt_template)
prompt_vars = [
var for var in prompt_vars if var not in ["input", "input1", "input2"]
]
blocking_keys = list(set([var.split(".")[-1] for var in prompt_vars]))
if not blocking_keys:
blocking_keys = list(input_data[0].keys())
# Compute embeddings
embeddings, embedding_cost = self.compute_embeddings(input_data, blocking_keys)
# Calculate similarities
similarities = self.calculate_cosine_similarities_self(embeddings)
# Sample pairs
sampled_pairs = self.sample_pairs(similarities)
if not sampled_pairs:
self.console.log(
"[yellow]No pairs to sample. Using default threshold 0.8.[/yellow]"
)
return 0.8, embeddings, embedding_cost
# Perform comparisons
comparisons = []
comparison_cost = 0.0
matches_found = 0
with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
futures = {
executor.submit(compare_fn, input_data[i], input_data[j]): (i, j)
for i, j in sampled_pairs
}
for future in as_completed(futures):
i, j = futures[future]
try:
is_match, cost, _ = future.result()
comparisons.append((i, j, is_match))
comparison_cost += cost
if is_match:
matches_found += 1
except Exception as e:
self.console.log(f"[red]Comparison error: {e}[/red]")
comparisons.append((i, j, False))
# Find optimal threshold
threshold, achieved_recall = self.find_optimal_threshold(
comparisons, similarities
)
total_cost = embedding_cost + comparison_cost
# Print histogram visualization
self._print_similarity_histogram(similarities, comparisons, threshold)
# Print summary
n = len(input_data)
total_pairs = n * (n - 1) // 2
pairs_above = sum(1 for s in similarities if s[2] >= threshold)
summary = (
f"[bold]Blocking keys:[/bold] {blocking_keys}\n"
f"[bold]Sampled:[/bold] {len(sampled_pairs)} pairs → {matches_found} matches ({matches_found/len(sampled_pairs)*100:.1f}%)\n"
f"[bold]Threshold:[/bold] {threshold:.4f}{achieved_recall:.1%} recall (target: {self.target_recall:.0%})\n"
f"[bold]Pairs to compare:[/bold] {pairs_above:,} of {total_pairs:,} ({pairs_above/total_pairs*100:.1f}%)\n"
f"[bold]Optimization cost:[/bold] ${total_cost:.4f}"
)
self.console.log(
Panel(
summary, title="Blocking Threshold Optimization", border_style="green"
)
)
return threshold, embeddings, total_cost
def optimize_equijoin(
self,
left_data: list[dict[str, Any]],
right_data: list[dict[str, Any]],
compare_fn: Callable[[dict, dict], tuple[bool, float]],
left_keys: list[str] | None = None,
right_keys: list[str] | None = None,
) -> tuple[float, list[list[float]], list[list[float]], float]:
"""
Compute optimal blocking threshold for equijoin operation.
Args:
left_data: Left dataset.
right_data: Right dataset.
compare_fn: Function to compare two items, returns (is_match, cost).
left_keys: Keys to use for left dataset embeddings.
right_keys: Keys to use for right dataset embeddings.
Returns:
Tuple of (threshold, left_embeddings, right_embeddings, total_cost).
"""
from rich.panel import Panel
# Determine keys
if not left_keys:
left_keys = list(left_data[0].keys()) if left_data else []
if not right_keys:
right_keys = list(right_data[0].keys()) if right_data else []
# Compute embeddings
left_embeddings, left_cost = self.compute_embeddings(left_data, left_keys)
right_embeddings, right_cost = self.compute_embeddings(right_data, right_keys)
embedding_cost = left_cost + right_cost
# Calculate cross similarities
similarities = self.calculate_cosine_similarities_cross(
left_embeddings, right_embeddings
)
# Sample pairs
sampled_pairs = self.sample_pairs(similarities)
if not sampled_pairs:
self.console.log(
"[yellow]No pairs to sample. Using default threshold 0.8.[/yellow]"
)
return 0.8, left_embeddings, right_embeddings, embedding_cost
# Perform comparisons
comparisons = []
comparison_cost = 0.0
matches_found = 0
with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
futures = {
executor.submit(compare_fn, left_data[i], right_data[j]): (i, j)
for i, j in sampled_pairs
}
for future in as_completed(futures):
i, j = futures[future]
try:
is_match, cost = future.result()
comparisons.append((i, j, is_match))
comparison_cost += cost
if is_match:
matches_found += 1
except Exception as e:
self.console.log(f"[red]Comparison error: {e}[/red]")
comparisons.append((i, j, False))
# Find optimal threshold
threshold, achieved_recall = self.find_optimal_threshold(
comparisons, similarities
)
total_cost = embedding_cost + comparison_cost
# Print histogram visualization
self._print_similarity_histogram(similarities, comparisons, threshold)
# Print summary
total_pairs = len(left_data) * len(right_data)
pairs_above = sum(1 for s in similarities if s[2] >= threshold)
summary = (
f"[bold]Left keys:[/bold] {left_keys} [bold]Right keys:[/bold] {right_keys}\n"
f"[bold]Sampled:[/bold] {len(sampled_pairs)} pairs → {matches_found} matches ({matches_found/len(sampled_pairs)*100:.1f}%)\n"
f"[bold]Threshold:[/bold] {threshold:.4f}{achieved_recall:.1%} recall (target: {self.target_recall:.0%})\n"
f"[bold]Pairs to compare:[/bold] {pairs_above:,} of {total_pairs:,} ({pairs_above/total_pairs*100:.1f}%)\n"
f"[bold]Optimization cost:[/bold] ${total_cost:.4f}"
)
self.console.log(
Panel(
summary, title="Blocking Threshold Optimization", border_style="green"
)
)
return threshold, left_embeddings, right_embeddings, total_cost
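A standalone sketch of the threshold-selection step used by find_optimal_threshold above: sweep candidate thresholds, measure recall against the LLM-labeled sample, and keep the highest threshold that still meets the target; the function and argument names here are illustrative:

```python
import numpy as np

def pick_threshold(labels, sims, target_recall=0.95):
    # labels[k] is the LLM verdict for sampled pair k; sims[k] is its cosine similarity.
    labels = np.asarray(labels, dtype=bool)
    sims = np.asarray(sims, dtype=float)
    best = None
    for t in np.linspace(0, 1, 100):
        preds = sims >= t
        tp, fn = np.sum(preds & labels), np.sum(~preds & labels)
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        if recall >= target_recall:
            best = (float(t), float(recall))  # later (higher) thresholds overwrite earlier ones
    return best  # None if no threshold reaches the target recall

# pick_threshold([True, True, False, True], [0.91, 0.88, 0.42, 0.79]) -> (~0.79, 1.0)
```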

View File

@ -7,6 +7,8 @@ from jinja2.exceptions import UndefinedError
from rich import print as rprint
from rich.prompt import Prompt
from docetl.utils import has_jinja_syntax
aeval = Interpreter()
@ -28,17 +30,44 @@ def strict_render(template: Template | str, context: dict[str, Any]) -> str:
# Create strict environment
env = Environment(undefined=StrictUndefined)
# Convert string to Template if needed
# Only process string templates for non-Jinja syntax check
if isinstance(template, str):
template_string = template
# # If "inputs" in the context, make sure they are not accessing some attribute of inputs
# if "inputs" in context and "{{ inputs." in template:
# raise UndefinedError("The inputs variable is a list, so you cannot access attributes of inputs. Use inputs[index].key instead.")
# Check if template doesn't have Jinja syntax and append document statement
if not has_jinja_syntax(template_string):
# Determine the operation type based on context variables
if "left" in context and "right" in context:
# Equijoin operation - append both documents
template_string = (
f"{template_string}\n\nHere are the documents:\n"
f"Left document: {{{{ left }}}}\n"
f"Right document: {{{{ right }}}}"
)
elif "input1" in context and "input2" in context:
# Comparison operation (resolve) - append both documents
template_string = (
f"{template_string}\n\nHere are the documents:\n"
f"Document 1: {{{{ input1 }}}}\n"
f"Document 2: {{{{ input2 }}}}"
)
elif "inputs" in context:
# Reduce operation - append "Here are the documents: {{ inputs }}"
template_string = (
f"{template_string}\n\nHere are the documents: {{{{ inputs }}}}"
)
elif "input" in context:
# Regular operation - append "Here is the document: {{ input }}"
template_string = (
f"{template_string}\n\nHere is the document: {{{{ input }}}}"
)
# Convert string template to Template object
try:
template = env.from_string(template)
template = env.from_string(template_string)
except Exception as e:
raise ValueError(f"Invalid template: {str(e)}")
# If template is already a Template object, use it as-is
try:
return template.render(context)
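The net effect of the added branch: a prompt with no Jinja syntax gets the relevant document placeholder appended before rendering. A small illustration using Jinja2 directly (strict_render itself handles more cases than shown here):

```python
from jinja2 import Environment, StrictUndefined

env = Environment(undefined=StrictUndefined)
prompt = "Summarize the key findings."             # no {{ ... }} syntax present
prompt += "\n\nHere is the document: {{ input }}"  # appended for the single-input case
print(env.from_string(prompt).render({"input": {"title": "Q3 report", "text": "..."}}))
```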
@ -122,7 +151,7 @@ def _is_integer(value: Any) -> bool:
def _is_number(value: Any) -> bool:
"""Return True if value is a real number (int or float) but not a bool."""
return (isinstance(value, (int, float)) and not isinstance(value, bool))
return isinstance(value, (int, float)) and not isinstance(value, bool)
def _validate_scalar(value: Any, schema: dict[str, Any]) -> bool:
@ -176,7 +205,9 @@ def _validate_value_against_schema(value: Any, schema: dict[str, Any]) -> bool:
return False
# Validate known properties
for key, prop_schema in properties.items():
if key in value and not _validate_value_against_schema(value[key], prop_schema):
if key in value and not _validate_value_against_schema(
value[key], prop_schema
):
return False
# additionalProperties constraint
if not additional_props_allowed:

View File

@ -55,7 +55,7 @@ class Optimizer:
def __init__(
self,
runner: "DSLRunner",
rewrite_agent_model: str = "gpt-4o",
rewrite_agent_model: str = "gpt-5.1",
judge_agent_model: str = "gpt-4o-mini",
litellm_kwargs: dict[str, Any] = {},
resume: bool = False,
@ -67,7 +67,7 @@ class Optimizer:
Args:
yaml_file (str): Path to the YAML configuration file.
model (str): The name of the language model to use. Defaults to "gpt-4o".
model (str): The name of the language model to use. Defaults to "gpt-5.1".
resume (bool): Whether to resume optimization from a previous run. Defaults to False.
timeout (int): Timeout in seconds for operations. Defaults to 60.

View File

@ -0,0 +1,926 @@
"""
Fast decomposition analyzer for map operations.
This module provides a fast way to decompose map operations using directives,
running candidates on samples, and selecting the best via pairwise LLM comparison.
"""
import json
import os
import tempfile
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from copy import deepcopy
from typing import Any
import litellm
from litellm import completion, model_cost
from rich.console import Console
from docetl.console import get_console
from docetl.reasoning_optimizer.directives import (
ChainingDirective,
ClarifyInstructionsDirective,
DeterministicDocCompressionDirective,
DocumentChunkingDirective,
GleaningDirective,
IsolatingSubtasksDirective,
)
from docetl.runner import DSLRunner
from docetl.utils import completion_cost, count_tokens
# Drop unsupported params for models like gpt-5 that don't support temperature=0
litellm.drop_params = True
# Base directives always applicable to map operations
BASE_MAP_DIRECTIVES = [
ChainingDirective(),
IsolatingSubtasksDirective(),
GleaningDirective(),
ClarifyInstructionsDirective(),
]
# Directives that depend on document size
CHUNKING_DIRECTIVE = DocumentChunkingDirective()
COMPRESSION_DIRECTIVE = DeterministicDocCompressionDirective()
# Threshold for enabling DeterministicDocCompression (in characters)
DOC_COMPRESSION_CHAR_THRESHOLD = 1000
# Threshold for enabling DocumentChunking (as fraction of context window)
DOC_CHUNKING_CONTEXT_THRESHOLD = 0.10 # 10%
class FastDecomposer:
"""
Fast decomposition of map operations using directives and pairwise comparison.
Instead of the full optimizer flow, this:
1. Tries multiple directives to generate candidate decompositions
2. Runs each candidate on sample documents
3. Uses LLM judge for pairwise comparison
4. Returns the winning decomposition
"""
def __init__(
self,
yaml_config_path: str,
optimizer_model: str = "gpt-5.1",
sample_size: int = 5,
litellm_kwargs: dict[str, Any] | None = None,
console: Console | None = None,
):
"""
Initialize the decomposer.
Args:
yaml_config_path: Path to the pipeline YAML config file
optimizer_model: LLM model to use for directive instantiation and judging
sample_size: Number of sample documents to run candidates on
litellm_kwargs: Additional kwargs to pass to litellm.completion
console: Rich console for output (uses default if not provided)
"""
self.yaml_config_path = yaml_config_path
self.optimizer_model = optimizer_model
self.sample_size = sample_size
self.litellm_kwargs = litellm_kwargs or {}
if "temperature" not in self.litellm_kwargs:
self.litellm_kwargs["temperature"] = 0.0
self.total_cost = 0.0
self.console = console or get_console()
# Load the config
import yaml
with open(yaml_config_path, "r") as f:
self.config = yaml.safe_load(f)
self.operators = self.config.get("operations", [])
self.default_model = self.config.get("default_model", "gpt-4o-mini")
self.intermediate_dir = (
self.config.get("pipeline", {}).get("output", {}).get("intermediate_dir")
)
def _log(self, message: str) -> None:
"""Log a message to the console."""
self.console.log(message)
def get_model_context_limit(self, model: str) -> int:
"""
Get the context window limit for a model.
Args:
model: The model name (e.g., 'gpt-4o', 'azure/gpt-4')
Returns:
Maximum number of input tokens the model can handle
"""
model_info = model_cost.get(model, {})
# Try without provider prefix if not found
if not model_info:
model_name = model.split("/")[-1]
model_info = model_cost.get(model_name, {})
return model_info.get("max_input_tokens", 128000) # Default to 128k
def get_avg_doc_size(
self, sample_data: list[dict[str, Any]], op_config: dict[str, Any]
) -> tuple[float, float]:
"""
Calculate the average document size in characters and tokens.
Extracts the document content from sample data based on the operation's
prompt template (looks for {{ input.field_name }} patterns).
Args:
sample_data: List of sample documents
op_config: The operation configuration
Returns:
Tuple of (avg_chars, avg_tokens) for the document content
"""
import re
if not sample_data:
return 0.0, 0.0
prompt = op_config.get("prompt", "")
model = op_config.get("model", self.default_model)
# Extract field names from prompt template ({{ input.field_name }})
field_pattern = r"\{\{\s*input\.(\w+)\s*\}\}"
fields = re.findall(field_pattern, prompt)
if not fields:
# Fallback: use all string values from the first document
if sample_data:
fields = [
k
for k, v in sample_data[0].items()
if isinstance(v, str) and len(v) > 100
]
total_chars = 0
total_tokens = 0
for doc in sample_data:
doc_content = ""
for field in fields:
if field in doc:
value = doc[field]
if isinstance(value, str):
doc_content += value
else:
doc_content += str(value)
total_chars += len(doc_content)
if doc_content:
total_tokens += count_tokens(doc_content, model)
n = len(sample_data)
return total_chars / n if n > 0 else 0.0, total_tokens / n if n > 0 else 0.0
def get_applicable_directives(
self,
sample_data: list[dict[str, Any]],
op_config: dict[str, Any],
) -> list:
"""
Get the list of directives applicable based on data characteristics.
- DocumentChunkingDirective: only if avg doc size > 10% of context window
- DeterministicDocCompressionDirective: only if avg doc size > 1000 chars
Args:
sample_data: List of sample documents
op_config: The operation configuration
Returns:
List of applicable directive instances
"""
directives = [] # Build list with priority ordering
model = op_config.get("model", self.default_model)
context_limit = self.get_model_context_limit(model)
avg_chars, avg_tokens = self.get_avg_doc_size(sample_data, op_config)
self._log(
f"Document analysis: avg_chars={avg_chars:.0f}, avg_tokens={avg_tokens:.0f}, "
f"context_limit={context_limit}"
)
# Add DeterministicDocCompression FIRST if doc size > 1000 chars (high priority)
if avg_chars > DOC_COMPRESSION_CHAR_THRESHOLD:
self._log(
f"[cyan]Enabling DeterministicDocCompression (priority)[/cyan] "
f"(avg {avg_chars:.0f} chars > {DOC_COMPRESSION_CHAR_THRESHOLD})"
)
directives.append(COMPRESSION_DIRECTIVE)
# Add base directives
directives.extend(BASE_MAP_DIRECTIVES)
# Add DocumentChunking if avg tokens > 10% of context window
token_threshold = context_limit * DOC_CHUNKING_CONTEXT_THRESHOLD
if avg_tokens > token_threshold:
self._log(
f"[cyan]Enabling DocumentChunking[/cyan] "
f"(avg {avg_tokens:.0f} tokens > {token_threshold:.0f} = 10% of {context_limit})"
)
directives.append(CHUNKING_DIRECTIVE)
else:
self._log(
f"[dim]Skipping DocumentChunking[/dim] "
f"(avg {avg_tokens:.0f} tokens <= {token_threshold:.0f} = 10% of {context_limit})"
)
return directives
def load_sample_data(self, step_name: str, op_name: str) -> list[dict[str, Any]]:
"""
Load sample input data for an operation.
For the first operation, loads from the dataset.
For subsequent operations, loads from the previous operation's intermediate output.
Args:
step_name: Name of the pipeline step
op_name: Name of the operation
Returns:
List of sample documents
"""
# Find the operation's position
op_names = [op.get("name") for op in self.operators]
try:
op_idx = op_names.index(op_name)
except ValueError:
raise ValueError(f"Operation '{op_name}' not found in config")
if op_idx == 0:
# First operation - load from dataset
datasets = self.config.get("datasets", {})
# Get the first dataset (or the one used by this step)
for dataset_name, dataset_config in datasets.items():
dataset_path = dataset_config.get("path")
if dataset_path and os.path.exists(dataset_path):
with open(dataset_path, "r") as f:
data = json.load(f)
return data[: self.sample_size]
raise FileNotFoundError("No dataset found in config")
else:
# Load from previous operation's intermediate output
prev_op_name = op_names[op_idx - 1]
output_path = os.path.join(
self.intermediate_dir, step_name, f"{prev_op_name}.json"
)
if not os.path.exists(output_path):
raise FileNotFoundError(
f"No intermediate output found at {output_path}. "
"Run the previous operation first."
)
with open(output_path, "r") as f:
data = json.load(f)
return data[: self.sample_size]
def get_input_file_path(self) -> str:
"""Get the input file path for the pipeline."""
datasets = self.config.get("datasets", {})
for dataset_name, dataset_config in datasets.items():
path = dataset_config.get("path")
if path:
return path
return ""
def generate_candidates(
self,
op_name: str,
sample_data: list[dict[str, Any]],
target_op: dict[str, Any],
) -> list[dict[str, Any]]:
"""
Generate candidate decompositions using directives.
Directives are selected based on data characteristics:
- DocumentChunkingDirective: only if avg doc size > 10% of context window
- DeterministicDocCompressionDirective: only if avg doc size > 1000 chars
Args:
op_name: Name of the operation to decompose
sample_data: Sample data for analyzing document characteristics
target_op: The target operation configuration
Returns:
List of candidate dictionaries with 'name', 'ops', 'cost' keys
"""
candidates = []
# Add original as baseline
self._log("Adding original operation as baseline candidate")
candidates.append(
{
"name": "original",
"ops": deepcopy(self.operators),
"cost": 0.0,
"error": None,
}
)
input_file_path = self.get_input_file_path()
# Get applicable directives based on data characteristics
applicable_directives = self.get_applicable_directives(sample_data, target_op)
self._log(
f"Generating candidates using {len(applicable_directives)} directives..."
)
for i, directive in enumerate(applicable_directives, 1):
with self.console.status(
f"[bold cyan]({i}/{len(applicable_directives)}) Trying directive: {directive.name}...[/bold cyan]",
spinner="dots",
):
try:
new_ops_list, _, cost = directive.instantiate(
operators=deepcopy(self.operators),
target_ops=[op_name],
agent_llm=self.optimizer_model,
message_history=[],
global_default_model=self.default_model,
input_file_path=input_file_path,
)
self.total_cost += cost
self._log(
f" [green]✓[/green] {directive.name} generated {len(new_ops_list)} operations (cost: ${cost:.4f})"
)
candidates.append(
{
"name": directive.name,
"ops": new_ops_list,
"cost": cost,
"error": None,
}
)
except Exception as e:
# Directive not applicable or failed - skip it
self._log(f" [red]✗[/red] {directive.name} failed: {str(e)}")
candidates.append(
{
"name": directive.name,
"ops": None,
"cost": 0.0,
"error": str(e),
}
)
return candidates
def extract_ops_to_run(
self, ops_list: list[dict], original_op_name: str
) -> list[dict]:
"""
Extract the operations that replaced the original operation.
Args:
ops_list: The full transformed operations list
original_op_name: Name of the original operation that was decomposed
Returns:
List of operations that should be run on samples
"""
# Find ops that are new (not in original) or modified
original_names = {op["name"] for op in self.operators}
# Find the position where the original op was
original_idx = None
for i, op in enumerate(self.operators):
if op["name"] == original_op_name:
original_idx = i
break
if original_idx is None:
return ops_list
# Find new ops (those not in original list)
new_ops = []
for op in ops_list:
if op["name"] not in original_names or op["name"] == original_op_name:
new_ops.append(op)
return new_ops if new_ops else [ops_list[original_idx]]
def run_candidate_on_samples(
self,
candidate: dict[str, Any],
sample_data: list[dict[str, Any]],
original_op_name: str,
) -> list[dict[str, Any]]:
"""
Run a candidate's operations on sample data.
Args:
candidate: Candidate dictionary with 'ops' key
sample_data: List of sample documents
original_op_name: Name of the original operation
Returns:
List of output documents
"""
if candidate["ops"] is None:
return []
# Extract ops to run
ops_to_run = self.extract_ops_to_run(candidate["ops"], original_op_name)
if not ops_to_run:
return []
# Create a minimal config for running these ops
# Write sample data to a temp file
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
json.dump(sample_data, f)
temp_input_path = f.name
# Create a temp output file (required by DSLRunner validation)
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
temp_output_path = f.name
try:
# Create a minimal pipeline config
temp_config = {
"default_model": self.default_model,
"operations": ops_to_run,
"datasets": {
"sample_data": {
"type": "file",
"path": temp_input_path,
}
},
"pipeline": {
"steps": [
{
"name": "decompose_test",
"input": "sample_data",
"operations": [op["name"] for op in ops_to_run],
}
],
"output": {
"type": "file",
"path": temp_output_path,
},
},
}
# Create runner and execute
runner = DSLRunner(temp_config, max_threads=4)
# Run operations sequentially on the data
current_data = sample_data
for op_config in ops_to_run:
current_data = runner._run_operation(op_config, current_data)
self.total_cost += runner.total_cost
return current_data
finally:
# Clean up temp files
for temp_path in [temp_input_path, temp_output_path]:
try:
os.unlink(temp_path)
except Exception:
pass
def pairwise_compare(
self,
candidate_a: dict[str, Any],
candidate_b: dict[str, Any],
original_prompt: str,
output_schema: dict[str, Any],
) -> dict[str, Any]:
"""
Compare two candidates using LLM judge.
Args:
candidate_a: First candidate with 'name', 'outputs' keys
candidate_b: Second candidate with 'name', 'outputs' keys
original_prompt: The original operation's prompt
output_schema: The expected output schema
Returns:
The winning candidate
"""
# If one has no outputs, the other wins
if not candidate_a.get("outputs"):
return candidate_b
if not candidate_b.get("outputs"):
return candidate_a
system_prompt = """You are an expert judge comparing outputs from two data processing pipeline variants.
Your task is to determine which variant produces BETTER outputs based on:
1. **Completeness**: Does the output contain all required information?
2. **Accuracy**: Is the extracted/generated information correct?
3. **Consistency**: Are the outputs consistent across different samples?
4. **Quality**: Is the output well-structured and useful?
Be objective and focus on the actual output quality, not the approach used."""
user_prompt = f"""Compare outputs from two pipeline variants for this task:
## Original Task
**Prompt:**
```
{original_prompt[:2000]}{"..." if len(original_prompt) > 2000 else ""}
```
**Expected Output Schema:**
```json
{json.dumps(output_schema, indent=2)}
```
## Variant A: {candidate_a["name"]}
**Sample Outputs:**
```json
{json.dumps(candidate_a["outputs"][:3], indent=2, default=str)}
```
## Variant B: {candidate_b["name"]}
**Sample Outputs:**
```json
{json.dumps(candidate_b["outputs"][:3], indent=2, default=str)}
```
## Your Task
Which variant produces better outputs? Consider completeness, accuracy, consistency, and quality.
Respond with your analysis and final choice."""
response = completion(
model=self.optimizer_model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
response_format={
"type": "json_schema",
"json_schema": {
"name": "comparison_result",
"strict": True,
"schema": {
"type": "object",
"properties": {
"winner": {
"type": "string",
"enum": ["A", "B"],
"description": "Which variant is better: A or B",
},
"rationale": {
"type": "string",
"description": "Explanation of why this variant is better",
},
"a_strengths": {
"type": "array",
"items": {"type": "string"},
"description": "Strengths of variant A",
},
"b_strengths": {
"type": "array",
"items": {"type": "string"},
"description": "Strengths of variant B",
},
},
"required": [
"winner",
"rationale",
"a_strengths",
"b_strengths",
],
"additionalProperties": False,
},
},
},
**self.litellm_kwargs,
)
cost = completion_cost(response)
self.total_cost += cost
result = json.loads(response.choices[0].message.content)
winner = candidate_a if result["winner"] == "A" else candidate_b
winner["comparison_rationale"] = result["rationale"]
return winner
def decompose(
self,
step_name: str,
op_name: str,
) -> dict[str, Any]:
"""
Decompose an operation using fast directive-based approach.
This is the main entry point. It:
1. Generates candidate decompositions using directives
2. Runs each on sample data
3. Uses pairwise LLM comparison to select the best
4. Returns the winning decomposition
Args:
step_name: Name of the pipeline step
op_name: Name of the operation to decompose
Returns:
Dict with:
- decomposed_ops: List of operations that replace the original
- winning_directive: Name of the winning directive
- candidates_evaluated: Number of candidates that were compared
- original_outputs: Sample outputs from the original operation
- decomposed_outputs: Sample outputs from the winning decomposition
- comparison_rationale: LLM's explanation of why the winner was chosen
- cost: Total LLM API cost
"""
self.console.rule("[bold blue]Fast Decomposition[/bold blue]")
self._log(f"[bold]Target operation:[/bold] {op_name}")
# Find the target operation config
target_op = None
for op in self.operators:
if op["name"] == op_name:
target_op = op
break
if target_op is None:
raise ValueError(f"Operation '{op_name}' not found in config")
# Verify it's a map operation
if target_op.get("type") != "map":
raise ValueError(
f"Operation '{op_name}' is type '{target_op.get('type')}', "
"but fast decomposition only supports 'map' operations"
)
# Load sample data
self._log(f"Loading sample data from step '{step_name}'...")
sample_data = self.load_sample_data(step_name, op_name)
self._log(f"[green]✓[/green] Loaded {len(sample_data)} sample documents")
# Generate candidates
self.console.rule("[bold]Generating Candidates[/bold]")
candidates = self.generate_candidates(op_name, sample_data, target_op)
# Filter out failed candidates
valid_candidates = [c for c in candidates if c["ops"] is not None]
self._log(f"[bold]{len(valid_candidates)}[/bold] valid candidates generated")
if len(valid_candidates) < 2:
# Only original (or nothing) - return original
self._log(
"[yellow]No alternative decompositions generated. Keeping original.[/yellow]"
)
return {
"decomposed_ops": self.operators,
"winning_directive": "original",
"candidates_evaluated": len(valid_candidates),
"original_outputs": [],
"decomposed_outputs": [],
"comparison_rationale": "No alternative decompositions were generated.",
"cost": self.total_cost,
}
# Run each candidate on samples IN PARALLEL
self.console.rule("[bold]Running Candidates on Samples[/bold]")
self._log(f"Running {len(valid_candidates)} candidates in parallel...")
def run_single_candidate(candidate):
"""Run a single candidate and return results."""
name = candidate["name"]
try:
outputs = self.run_candidate_on_samples(candidate, sample_data, op_name)
return {"name": name, "outputs": outputs, "error": None}
except Exception as e:
error_tb = traceback.format_exc()
return {
"name": name,
"outputs": [],
"error": str(e),
"traceback": error_tb,
}
# Run all candidates in parallel
with self.console.status(
"[bold cyan]Running candidates on samples...[/bold cyan]", spinner="dots"
):
with ThreadPoolExecutor(max_workers=len(valid_candidates)) as executor:
future_to_candidate = {
executor.submit(run_single_candidate, c): c
for c in valid_candidates
}
for future in as_completed(future_to_candidate):
candidate = future_to_candidate[future]
result = future.result()
if result["error"]:
candidate["outputs"] = []
candidate["run_error"] = result["error"]
self._log(
f" [red]✗[/red] {result['name']} failed: {result['error']}"
)
else:
candidate["outputs"] = result["outputs"]
self._log(
f" [green]✓[/green] {result['name']}: {len(result['outputs'])} outputs"
)
# Filter to candidates with outputs
candidates_with_outputs = [c for c in valid_candidates if c.get("outputs")]
self._log(
f"[bold]{len(candidates_with_outputs)}[/bold] candidates produced outputs"
)
if not candidates_with_outputs:
# All failed - return original
self._log("[red]All decomposition candidates failed to execute.[/red]")
# Log the errors for debugging
for c in valid_candidates:
if c.get("run_error"):
self._log(f" [red]{c['name']}:[/red] {c['run_error']}")
return {
"decomposed_ops": self.operators,
"winning_directive": "original",
"candidates_evaluated": 0,
"original_outputs": [],
"decomposed_outputs": [],
"comparison_rationale": "All decomposition candidates failed to execute.",
"cost": self.total_cost,
}
# Find the original candidate's outputs
original_candidate = next(
(c for c in candidates_with_outputs if c["name"] == "original"), None
)
original_outputs = original_candidate["outputs"] if original_candidate else []
# Parallel pairwise comparison: compare all candidates against original
self.console.rule("[bold]Pairwise Comparison[/bold]")
original_prompt = target_op.get("prompt", "")
output_schema = target_op.get("output", {}).get("schema", {})
# If only original exists, it wins by default
if len(candidates_with_outputs) == 1:
winner = candidates_with_outputs[0]
elif original_candidate is None:
# No original - just pick the first candidate
winner = candidates_with_outputs[0]
else:
# Compare all non-original candidates against original IN PARALLEL
challengers = [
c for c in candidates_with_outputs if c["name"] != "original"
]
if not challengers:
winner = original_candidate
else:
self._log(
f"Comparing {len(challengers)} candidates against original in parallel..."
)
def compare_against_original(challenger):
"""Compare a single challenger against original."""
try:
result = self.pairwise_compare(
original_candidate,
challenger,
original_prompt,
output_schema,
)
won = result["name"] == challenger["name"]
return {
"challenger": challenger,
"won": won,
"rationale": result.get("comparison_rationale", ""),
}
except Exception as e:
return {
"challenger": challenger,
"won": False,
"error": str(e),
}
# Run all comparisons in parallel
comparison_results = []
with self.console.status(
f"[bold cyan]Running {len(challengers)} comparisons in parallel...[/bold cyan]",
spinner="dots",
):
with ThreadPoolExecutor(max_workers=len(challengers)) as executor:
future_to_challenger = {
executor.submit(compare_against_original, c): c
for c in challengers
}
for future in as_completed(future_to_challenger):
result = future.result()
comparison_results.append(result)
challenger_name = result["challenger"]["name"]
if result.get("error"):
self._log(
f" [red]✗[/red] {challenger_name} vs original: error - {result['error']}"
)
elif result["won"]:
self._log(
f" [green]✓[/green] {challenger_name} beat original"
)
else:
self._log(
f" [dim]○[/dim] {challenger_name} lost to original"
)
# Find winners (candidates that beat original)
winners = [r for r in comparison_results if r.get("won")]
if not winners:
# Original beats all challengers
winner = original_candidate
self._log("[bold]Original wins against all challengers[/bold]")
elif len(winners) == 1:
# Single winner
winner = winners[0]["challenger"]
winner["comparison_rationale"] = winners[0].get("rationale", "")
else:
# Multiple winners beat original - run tiebreaker comparisons in parallel
self._log(
f"[bold]{len(winners)} candidates beat original - running tiebreaker...[/bold]"
)
# Compare all winners against each other in parallel (round-robin)
winner_candidates = [w["challenger"] for w in winners]
win_counts = {c["name"]: 0 for c in winner_candidates}
# Generate all pairwise matchups
matchups = []
for i, a in enumerate(winner_candidates):
for b in winner_candidates[i + 1 :]:
matchups.append((a, b))
if matchups:
def run_matchup(matchup):
a, b = matchup
try:
result = self.pairwise_compare(
a, b, original_prompt, output_schema
)
return {
"winner": result["name"],
"a": a["name"],
"b": b["name"],
}
except Exception:
return {
"winner": a["name"],
"a": a["name"],
"b": b["name"],
} # Default to first
with self.console.status(
f"[bold cyan]Running {len(matchups)} tiebreaker comparisons...[/bold cyan]",
spinner="dots",
):
with ThreadPoolExecutor(
max_workers=len(matchups)
) as executor:
for result in executor.map(run_matchup, matchups):
win_counts[result["winner"]] += 1
self._log(
f" [dim]{result['a']} vs {result['b']}{result['winner']}[/dim]"
)
# Pick candidate with most wins
best_name = max(win_counts, key=win_counts.get)
winner = next(
c for c in winner_candidates if c["name"] == best_name
)
self._log(
f"[bold]Tiebreaker winner: {best_name} ({win_counts[best_name]} wins)[/bold]"
)
# Extract the decomposed operations
decomposed_ops = self.extract_ops_to_run(winner["ops"], op_name)
# Final summary
self.console.rule("[bold green]Decomposition Complete[/bold green]")
self._log(f"[bold]Winner:[/bold] [green]{winner['name']}[/green]")
self._log(f"[bold]Candidates evaluated:[/bold] {len(candidates_with_outputs)}")
self._log(f"[bold]New operations:[/bold] {len(decomposed_ops)}")
self._log(f"[bold]Total cost:[/bold] ${self.total_cost:.4f}")
return {
"decomposed_ops": decomposed_ops,
"winning_directive": winner["name"],
"candidates_evaluated": len(candidates_with_outputs),
"original_outputs": original_outputs,
"decomposed_outputs": winner.get("outputs", []),
"comparison_rationale": winner.get("comparison_rationale", ""),
"cost": self.total_cost,
}
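A minimal usage sketch of the decomposer defined above, following the constructor and decompose() signatures in this file; the import path, config path, and operation names are placeholders:

```python
from docetl.fast_decomposer import FastDecomposer  # module path assumed

decomposer = FastDecomposer(
    yaml_config_path="pipeline.yaml",  # placeholder pipeline config
    optimizer_model="gpt-5.1",
    sample_size=5,
)
result = decomposer.decompose(step_name="extract_step", op_name="extract_findings")
print(result["winning_directive"], len(result["decomposed_ops"]), f"${result['cost']:.4f}")
```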

View File

@ -0,0 +1,337 @@
"""
Fast should_optimize analyzer using a single LLM call.
This module provides a lightweight alternative to the full MapOptimizer/ReduceOptimizer/JoinOptimizer
flow for quickly determining if an operation should be decomposed.
"""
import json
import os
from typing import Any
import litellm
from litellm import completion, model_cost
from docetl.utils import completion_cost, count_tokens
# Drop unsupported params for models like gpt-5 that don't support temperature=0
litellm.drop_params = True
class FastShouldOptimizeAnalyzer:
"""
Analyzes whether an operation should be optimized using a single LLM call.
Instead of running the operation on sample data and using complex evaluation logic,
this reads cached outputs from intermediate files and makes one judgment call.
"""
def __init__(
self,
intermediate_dir: str,
optimizer_model: str = "gpt-5.1",
litellm_kwargs: dict[str, Any] | None = None,
):
"""
Initialize the analyzer.
Args:
intermediate_dir: Path to the directory containing intermediate outputs
optimizer_model: The LLM model to use for analysis (default: gpt-5.1)
litellm_kwargs: Additional kwargs to pass to litellm.completion
"""
self.intermediate_dir = intermediate_dir
self.optimizer_model = optimizer_model
self.litellm_kwargs = litellm_kwargs or {}
if "temperature" not in self.litellm_kwargs:
self.litellm_kwargs["temperature"] = 0.0
def load_operation_data(self, step_name: str, op_name: str) -> list[dict[str, Any]]:
"""
Load data from the intermediate file for an operation.
Args:
step_name: Name of the pipeline step
op_name: Name of the operation
Returns:
List of dictionaries
Raises:
FileNotFoundError: If the intermediate file doesn't exist
"""
output_path = os.path.join(self.intermediate_dir, step_name, f"{op_name}.json")
if not os.path.exists(output_path):
raise FileNotFoundError(
f"No output file found at {output_path}. "
"Run the operation first to generate outputs."
)
with open(output_path, "r") as f:
return json.load(f)
def find_previous_operation(
self, operations: list[dict[str, Any]], op_name: str
) -> str | None:
"""
Find the operation that comes before op_name in the pipeline.
Args:
operations: List of operation configs from the pipeline
op_name: Name of the current operation
Returns:
Name of the previous operation, or None if this is the first operation
"""
op_names = [op.get("name") for op in operations]
try:
idx = op_names.index(op_name)
if idx > 0:
return op_names[idx - 1]
except ValueError:
pass
return None
def get_max_context_tokens(self) -> int:
"""
Get the maximum input tokens for the optimizer model.
Returns:
Maximum number of input tokens the model can handle
"""
model_info = model_cost.get(self.optimizer_model, {})
# Try without provider prefix if not found
if not model_info:
model_name = self.optimizer_model.split("/")[-1]
model_info = model_cost.get(model_name, {})
return model_info.get("max_input_tokens", 128000) # Default to 128k
def calculate_samples_that_fit(
self,
op_config: dict[str, Any],
outputs: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""
Calculate how many output samples fit in the context window.
Reserves space for system prompt, operation config, and response buffer,
then fills remaining space with as many samples as possible.
Args:
op_config: The operation configuration dictionary
outputs: List of all output documents
Returns:
List of samples that fit in the context window
"""
max_tokens = self.get_max_context_tokens()
# Reserve tokens for fixed parts
system_prompt_tokens = 500
op_config_tokens = count_tokens(
json.dumps(op_config, default=str), self.optimizer_model
)
response_buffer = 2000
available_for_samples = (
max_tokens - system_prompt_tokens - op_config_tokens - response_buffer
)
# Collect samples that fit
samples_to_include = []
tokens_used = 0
for output in outputs:
sample_json = json.dumps(output, default=str)
sample_tokens = count_tokens(sample_json, self.optimizer_model)
if tokens_used + sample_tokens <= available_for_samples:
samples_to_include.append(output)
tokens_used += sample_tokens
else:
break
return samples_to_include
def build_analysis_prompt(
self,
op_config: dict[str, Any],
samples: list[dict[str, Any]],
) -> tuple[str, str]:
"""
Build the system prompt and user prompt for the analysis LLM call.
Args:
op_config: The operation configuration dictionary
samples: List of output samples to analyze
Returns:
Tuple of (system_prompt, user_prompt)
"""
system_prompt = """You are an expert at analyzing LLM-powered data processing operations.
Your task is to determine if an operation would benefit from being decomposed into multiple
smaller, focused operations (also known as "optimization" or "decomposition").
An operation SHOULD be decomposed when:
1. The prompt asks the LLM to do multiple distinct tasks that could be done separately
2. The task is complex enough that breaking it into sequential steps would improve accuracy
3. The outputs show inconsistency, incompleteness, or quality issues that iterative refinement could fix
4. Long documents need to be processed in chunks rather than all at once
5. The prompt asks for both extraction AND analysis/synthesis in one step
An operation should NOT be decomposed when:
1. It performs a single, focused task well
2. The outputs are consistently high quality and complete
3. The task is simple and atomic (e.g., simple classification, single field extraction)
4. The operation is already well-scoped and produces reliable results
Be conservative - only recommend decomposition if there's clear evidence it would help."""
output_schema = op_config.get("output", {}).get("schema", {})
prompt_template = op_config.get("prompt", "No prompt specified")
# Truncate very long prompts for display
if len(prompt_template) > 3000:
prompt_template = prompt_template[:3000] + "\n... [truncated]"
user_prompt = f"""Analyze this data processing operation and its outputs:
## Operation Configuration
**Name:** {op_config.get('name', 'unknown')}
**Type:** {op_config.get('type', 'unknown')}
**Prompt Template:**
```
{prompt_template}
```
**Output Schema:**
```json
{json.dumps(output_schema, indent=2)}
```
## Sample Outputs ({len(samples)} samples from the operation)
```json
{json.dumps(samples, indent=2, default=str)}
```
## Your Task
Based on the operation configuration and sample outputs, determine:
1. Should this operation be decomposed/optimized?
2. If yes, what specific improvements would help?
Consider the quality, completeness, and consistency of the outputs when making your assessment."""
return system_prompt, user_prompt
def analyze(
self,
op_config: dict[str, Any],
step_name: str,
op_name: str,
) -> tuple[str, list[dict[str, Any]], int, float]:
"""
Analyze whether an operation should be optimized.
This is the main entry point. It loads outputs, builds the prompt,
makes a single LLM call, and returns the assessment.
Args:
op_config: The operation configuration dictionary
step_name: Name of the pipeline step
op_name: Name of the operation
Returns:
Tuple of (rationale, output_samples, num_docs_analyzed, cost):
- rationale: Empty string if no optimization needed, explanation if it should be optimized
- output_samples: The samples that were analyzed
- num_docs_analyzed: Number of documents that fit in the LLM prompt
- cost: LLM API cost in USD
Raises:
ValueError: If the operation is not an LLM-powered map operation
"""
# Validate operation type - only LLM-powered map operations
op_type = op_config.get("type", "")
if op_type != "map":
raise ValueError(
f"should_optimize only supports 'map' operations, got '{op_type}'. "
"Only LLM-powered map operations can be analyzed for decomposition."
)
# Load outputs from intermediate file
outputs = self.load_operation_data(step_name, op_name)
if not outputs:
return "No output samples available for analysis.", [], 0, 0.0
# Calculate samples that fit in context
samples = self.calculate_samples_that_fit(op_config, outputs)
if not samples:
return "Could not fit any samples in context window.", outputs[:5], 0, 0.0
# Build prompt
system_prompt, user_prompt = self.build_analysis_prompt(op_config, samples)
# Make LLM call with structured output
response = completion(
model=self.optimizer_model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
response_format={
"type": "json_schema",
"json_schema": {
"name": "optimization_analysis",
"strict": True,
"schema": {
"type": "object",
"properties": {
"should_optimize": {
"type": "boolean",
"description": "True if operation should be decomposed/optimized",
},
"rationale": {
"type": "string",
"description": "Explanation of why the operation should or should not be optimized",
},
"suggested_improvements": {
"type": "array",
"items": {"type": "string"},
"description": "Specific improvements if optimization is recommended (empty if not)",
},
},
"required": [
"should_optimize",
"rationale",
"suggested_improvements",
],
"additionalProperties": False,
},
},
},
**self.litellm_kwargs,
)
# Calculate cost
cost = completion_cost(response)
# Parse response
result = json.loads(response.choices[0].message.content)
num_docs_analyzed = len(samples)
if result["should_optimize"]:
# Build rationale string with improvements
rationale_parts = [result["rationale"]]
if result["suggested_improvements"]:
rationale_parts.append("\n\nSuggested improvements:")
for imp in result["suggested_improvements"]:
rationale_parts.append(f"- {imp}")
return "\n".join(rationale_parts), samples, num_docs_analyzed, cost
else:
# Return empty string to indicate no optimization needed
return "", samples, num_docs_analyzed, cost

View File

@ -30,7 +30,7 @@ class LLMClient:
Initialize the LLMClient.
Args:
model (str, optional): The name of the LLM model to use. Defaults to "gpt-4o".
model (str, optional): The name of the LLM model to use. Defaults to "gpt-5.1".
**litellm_kwargs: Additional keyword arguments for the LLM model.
"""
self.rewrite_agent_model = rewrite_agent_model

View File

@ -0,0 +1,85 @@
# DocETL Reasoning Optimizer - Agent Guide
## Overview
The reasoning optimizer uses a graph-search algorithm with LLM-based directives to rewrite and optimize DocETL pipelines. This directory contains the directive system for pipeline transformations.
## Creating New Directives - Interactive Process
As an AI agent, I can guide you through creating new rewrite directives step-by-step. Here's the process:
### 1. Initial Consultation
- **What I'll ask**: Describe what transformation you want the directive to perform
- **What you provide**: High-level description of the directive's purpose and when to use it
- **What I'll do**: Analyze existing directives and suggest the best approach
### 2. Schema Design Phase
- **What I'll ask**: Confirm the configuration parameters needed for the directive
- **What I'll propose**: The instantiate schema structure (Pydantic models) that the LLM agent will output
- **What you confirm**: Whether the schema captures all necessary parameters
### 3. Directive Specification
- **What I'll propose**:
- `name`: Technical identifier for the directive
- `formal_description`: Brief transformation pattern (e.g., "Op => Code Map -> Op")
- `nl_description`: Natural language explanation
- `when_to_use`: Specific use cases
- **What you confirm**: Whether these descriptions accurately capture the directive's purpose
### 4. Example Creation
- **What I'll propose**: Example showing original operation and expected instantiate schema output
- **What you confirm**: Whether the example demonstrates the directive correctly
### 5. Test Case Design
- **What I'll propose**: Test cases with input operations and expected behaviors
- **What you confirm**: Whether test cases cover important scenarios
### 6. Implementation Review
- **What I'll show**: Complete directive implementation including:
- Schema classes in `instantiate_schemas.py`
- Directive class with all required methods
- Registration in `__init__.py`
- Apply tests in `tests/reasoning_optimizer/test_directive_apply.py`
- **What you confirm**: Final approval before implementation
## Existing Directive Patterns
### Single Operation Modification
- **Gleaning**: Adds validation loops (`validation_prompt`, `num_rounds`)
- **Change Model**: Switches LLM model (`model`)
- **Deterministic Doc Compression**: Adds regex-based preprocessing
### Operation Replacement
- **Chaining**: Replaces complex operation with sequential simpler ones
- **Isolating Subtasks**: Breaks operation into independent parallel tasks
### Pipeline Preprocessing
- **Doc Summarization**: Adds document summarization before main processing
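To make the single-operation-modification pattern above concrete, here is a hedged sketch of a map operation before and after a gleaning-style rewrite (the `gleaning` keys mirror the example in the directives README; the prompt and values are illustrative):
```python
# Hedged sketch: a map op before and after a gleaning-style rewrite.
# The "gleaning" keys mirror the directives README; values are illustrative.
original_op = {
    "name": "extract_insights",
    "type": "map",
    "prompt": "Extract key insights from: {{ input.document }}",
    "output": {"schema": {"insights": "list[string]"}},
}

rewritten_op = {
    **original_op,
    "gleaning": {
        "validation_prompt": "Check that the output contains at least 2 insights.",
        "num_rounds": 3,
        "model": "gpt-4o-mini",
    },
}
```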
## Key Files and Structure
- `directives/`: Individual directive implementations
- `instantiate_schemas.py`: Pydantic schemas for LLM outputs
- `agent.py`: Core MCTS agent that applies directives
- `op_descriptions.py`: Operation type descriptions for the agent
## Testing Commands
### Instantiation Tests
- Single directive: `python experiments/reasoning/run_tests.py --directive=directive_name`
- All directives: `python experiments/reasoning/run_tests.py`
### Apply Tests
- All directive apply methods: `python tests/reasoning_optimizer/test_directive_apply.py`
### Full MCTS
- Complete optimization: `python experiments/reasoning/run_mcts.py`
## Workflow for New Directive
1. Describe desired transformation → I analyze and propose approach
2. Confirm schema design → I implement Pydantic models
3. Confirm descriptions → I write directive metadata
4. Confirm examples → I create demonstration cases
5. Confirm test cases → I design validation scenarios
6. Review implementation → I write complete directive code
7. Test and iterate → We verify functionality together
**Ready to create a new directive? Describe what transformation you want to implement!**

View File

@ -0,0 +1,521 @@
import json
import os
import threading
import time
from typing import Dict, List
import litellm
import yaml
from pydantic import BaseModel
from docetl.reasoning_optimizer.directives import (
get_all_directive_strings,
instantiate_directive,
)
from docetl.reasoning_optimizer.load_data import load_input_doc
from docetl.utils import load_config
from .op_descriptions import * # noqa: F403, F405
# argparse removed - use experiments/reasoning/run_baseline.py for CLI
# Global dictionary of rate limiters per model
model_rate_limiters: Dict[str, "TokenRateLimiter"] = {}
# Use environment variable or default to current directory
data_dir = os.environ.get("EXPERIMENT_DATA_DIR", "./data/")
def get_rate_limiter(model: str, max_tpm: int) -> "TokenRateLimiter":
if model not in model_rate_limiters:
model_rate_limiters[model] = TokenRateLimiter(max_tpm)
return model_rate_limiters[model]
class TokenRateLimiter:
def __init__(self, max_tpm):
self.max_tpm = max_tpm
self.tokens_used = 0
self.lock = threading.Lock()
self.reset_time = time.time() + 60 # 60 seconds window
def allow(self, tokens):
with self.lock:
now = time.time()
if now >= self.reset_time:
self.tokens_used = 0
self.reset_time = now + 60
if self.tokens_used + tokens > self.max_tpm:
return False
self.tokens_used += tokens
return True
def wait_for_slot(self, tokens):
while not self.allow(tokens):
time_to_wait = max(0, self.reset_time - time.time())
time.sleep(time_to_wait)
def count_tokens(messages):
# messages should be a list of dicts, each with a "content" key
total_chars = sum(
len(m.get("content", "")) for m in messages if isinstance(m, dict)
)
return max(1, total_chars // 4)
# ------------------------------------------------------------------
# 🔒 Context-window safety helpers
# ------------------------------------------------------------------
# Maximum number of tokens we will allow in the prompt we send to the model.
# The Azure GPT-5 family allows 272,000 tokens.
MAX_CONTEXT_TOKENS = 270_000
def _trim_history(history: list, keep_system_first: bool = True) -> list:
"""Trim the conversation history in-place so its estimated token count
(via ``count_tokens``) does not exceed ``MAX_CONTEXT_TOKENS``.
We always keep the very first system message and the first user message so the
assistant retains the global instructions and the initial query context. After
that we drop the oldest messages until the budget is satisfied. Returns the
trimmed history list.
"""
# Determine starting index to preserve the initial system message and first user message
start_idx = 0
if keep_system_first and history:
if history[0].get("role") == "system":
start_idx = 1
# Find the first user message after the system message
for i in range(1, len(history)):
if history[i].get("role") == "user":
start_idx = i + 1
break
elif history[0].get("role") == "user":
# If first message is user, keep it and find the next user message
start_idx = 1
for i in range(1, len(history)):
if history[i].get("role") == "user":
start_idx = i + 1
break
# Drop oldest messages (just after the preserved block) until within limit
while len(history) > start_idx + 1 and count_tokens(history) > MAX_CONTEXT_TOKENS:
history.pop(start_idx)
return history
class ResponseFormat(BaseModel):
directive: str
operators: List[str]
def get_openai_response(
input_query,
input_schema,
input_data_sample,
model="o3",
max_tpm=5000000,
message_history=[],
curr_plan_output="",
prev_plan_cost: float = 0.0,
iteration=1,
):
"""
The first LLM call. Generates a rewrite plan given the rewrite directives.
"""
if iteration == 1:
user_message = f"""
I have a set of operations used to process long documents, along with a list of possible rewrite directives aimed at improving the quality of the query result.
Given a query pipeline made up of these operations, recommend one specific rewrite directive (specify by its name) that would improve accuracy and specify which operators (specify by the names) in the pipeline the directive should be applied to.
Make sure that your chosen directive is in the provided list of rewrite directives.
Pipeline:
Pipelines in DocETL are the core structures that define the flow of data processing. A pipeline consists of five main components: \n
- Default Model: The language model to use for the pipeline. Limit your choice of model to gpt-5-nano, gpt-4o-mini, gpt-5 \n
- System Prompts: A description of your dataset and the "persona" you'd like the LLM to adopt when analyzing your data. \n
- Datasets: The input data sources for your pipeline. \n
- Operators: The processing steps that transform your data. \n
- Pipeline Specification: The sequence of steps and the output configuration. \n
Operators:
Operators form the building blocks of data processing pipelines. Below is the list of operators:
{op_map.to_string()}\n
{op_extract.to_string()}\n
{op_parallel_map.to_string()}\n
{op_filter.to_string()}\n
{op_reduce.to_string()}\n
{op_split.to_string()}\n
{op_gather.to_string()}\n
{op_unnest.to_string()}\n
{op_sample.to_string()}\n
{op_resolve.to_string()}\n
Rewrite directives:
{get_all_directive_strings()}\n
Input document schema with token statistics: {input_schema} \n
Cost of previous plan execution: ${prev_plan_cost:.4f} \n
The original query in YAML format using our operations: {input_query} \n
Input data sample: {json.dumps(input_data_sample, indent=2)[:3000]} \n
Sample of the result from executing the original query: {json.dumps(curr_plan_output, indent=2)[:3000]} \n
"""
else:
user_message = f"""
Given the previously rewritten pipeline, recommend one specific rewrite directive (specify by its name) that would improve accuracy and specify which operator (specify by the name) in the pipeline the directive should be applied to.
Make sure that your chosen directive is in the provided list of rewrite directives.
Rewrite directives:
{get_all_directive_strings()}\n
Cost of previous plan execution: ${prev_plan_cost:.4f} \n
The original query in YAML format using our operations: {input_query} \n
Sample of the result from executing the original query: {json.dumps(curr_plan_output, indent=2)[:3000]} \n
"""
if len(message_history) == 0:
message_history.extend(
[
{
"role": "system",
"content": "You are an expert query optimization agent for document processing pipelines. Your role is to analyze user queries and apply rewrite directives to create more accurate execution plans. Your output must follow the structured output format.",
},
{"role": "user", "content": user_message},
]
)
else:
message_history.append({"role": "user", "content": user_message})
# Trim the history to prevent context window overflow before sending to the model
message_history = _trim_history(message_history)
messages = message_history
# Enforce rate limit for the specified model
if max_tpm > 0:
limiter = get_rate_limiter(model, max_tpm)
tokens = count_tokens(messages)
limiter.wait_for_slot(tokens)
# Count the number of tokens in the messages for debugging/monitoring
num_tokens = count_tokens(messages)
print(f"Token count for current messages: {num_tokens}")
# litellm._turn_on_debug()
response = litellm.completion(
model=model,
messages=messages,
response_format=ResponseFormat,
)
assistant_response = response.choices[0].message.content
# Add user and assistant messages to message_history as dicts
message_history.append({"role": "assistant", "content": assistant_response})
return assistant_response, message_history
def update_yaml_operations(input_file_path, output_file_path, new_operations):
"""
Load a YAML file, replace the operations section, and save to a new file.
Args:
input_file_path (str): Path to the original YAML file
output_file_path (str): Path where the modified YAML will be saved
new_operations (list): List of operation dictionaries to replace the original operations
"""
# Load the original YAML file
with open(input_file_path, "r") as file:
config = yaml.safe_load(file)
# Replace the operations section
config["operations"] = new_operations
# Write the modified config to a new YAML file
with open(output_file_path, "w") as file:
yaml.dump(
config, file, default_flow_style=False, allow_unicode=True, sort_keys=False
)
print(f"Modified YAML saved to: {output_file_path}")
def update_pipeline(orig_config, new_ops_list, target_ops):
"""
Update the pipeline configuration with new operations.
Args:
orig_config (dict): The original pipeline configuration
new_ops_list (list): The entire pipeline operations list (not a subset)
target_ops (list): List of target operation names to replace
Returns:
dict: Updated pipeline configuration
"""
if new_ops_list is not None:
op_names = [op.get("name") for op in new_ops_list if "name" in op]
# Update the pipeline steps to use the new operation names
if "pipeline" in orig_config and "steps" in orig_config["pipeline"]:
for step in orig_config["pipeline"]["steps"]:
if "operations" in step:
new_ops = []
for op in step["operations"]:
if op == target_ops[0]:
new_ops.extend(op_names)
step["operations"] = new_ops
return orig_config
def fix_models(parsed_yaml):
"""No-op: Model names should be specified correctly in the YAML."""
pass
def update_sample(new_ops_list, target_ops, orig_operators):
"""
Update sample settings in new operations based on original operators.
Args:
new_ops_list (list): List of new operations to update
target_ops (list): List of target operation names
orig_operators (list): List of original operators
Returns:
list: Updated new operations list with sample settings
"""
# Build a mapping from op name to op config in orig_operators
op_name_to_config = {op.get("name"): op for op in orig_operators if "name" in op}
# For each op in new_ops_list, if the corresponding op in orig_operators has 'sample', add it
sample_size = -1
for target_op_name in target_ops:
target_op = op_name_to_config[target_op_name]
if "sample" in target_op:
sample_size = target_op["sample"]
print("SAMPLE SIZE: ", sample_size)
for op in new_ops_list:
if sample_size != -1:
op["sample"] = sample_size
return new_ops_list
def save_message_history(message_history, filepath):
"""
Save message history to a JSON file.
Args:
message_history (list): List of message dictionaries
filepath (str): Path to save the message history
"""
with open(filepath, "w") as f:
json.dump(message_history, f, indent=2)
print(f"Message history saved to: {filepath}")
def load_message_history(filepath):
"""
Load message history from a JSON file.
Args:
filepath (str): Path to the message history file
Returns:
list: List of message dictionaries, or empty list if file doesn't exist
"""
if os.path.exists(filepath):
with open(filepath, "r") as f:
return json.load(f)
return []
def run_single_iteration(
yaml_path,
model,
max_tpm,
message_history,
iteration_num,
orig_output_sample,
prev_plan_cost: float,
output_dir=None,
dataset="cuad",
sample_data=None,
):
"""
Run a single iteration of the optimization process.
Args:
yaml_path (str): Path to the YAML file
model (str): Model name
max_tpm (int): Tokens per minute limit
message_history (list): Cumulative message history
iteration_num (int): Current iteration number
orig_output_sample: Sample of the output from the previous plan execution, shown to the agent
prev_plan_cost (float): Cost of executing the previous plan
output_dir (str, optional): Directory for saving iteration artifacts (defaults to data_dir)
dataset (str): Dataset name passed through to directive instantiation
sample_data (list, optional): Sample input documents shown to the agent
Returns:
tuple: (output_file_path, updated_message_history, total_cost)
"""
print(f"\n=== Running Iteration {iteration_num} ===")
print(f"Input file: {yaml_path}")
# Parse input yaml file to get the list of operations
orig_config = load_config(yaml_path)
orig_operators = orig_config["operations"]
# Use provided sample data
random_sample = sample_data if sample_data is not None else []
with open(yaml_path, "r") as f:
input_query = f.read()
with open(yaml_path, "r") as file:
config = yaml.safe_load(file)
global_default_model = config.get("default_model")
datasets = config.get("datasets", {})
input_file_path = None
if isinstance(datasets, dict) and datasets:
first_dataset = next(iter(datasets.values()))
if isinstance(first_dataset, dict):
input_file_path = first_dataset.get("path")
input_schema = load_input_doc(yaml_path)
reply, message_history = get_openai_response(
input_query,
input_schema,
random_sample,
model=model,
max_tpm=max_tpm,
message_history=message_history,
curr_plan_output=orig_output_sample,
prev_plan_cost=prev_plan_cost,
iteration=iteration_num,
)
# Use output_dir if provided, otherwise fall back to data_dir
save_dir = output_dir if output_dir else data_dir
# Parse agent response
try:
parsed = json.loads(reply)
directive = parsed.get("directive")
target_ops = parsed.get("operators")
print(f"Directive: {directive}, Target ops: {target_ops}")
# Log directive and target ops to baseline_log.txt in the same directory as output YAML
log_file_path = os.path.join(save_dir, "baseline_log.txt")
log_message = f"Iteration {iteration_num}: Directive: {directive}, Target ops: {target_ops}\n"
with open(log_file_path, "a") as log_file:
log_file.write(log_message)
except Exception as e:
print(f"Failed to parse agent response: {e}")
return None, message_history, 0.0
try:
new_ops_list, message_history, cost = instantiate_directive(
directive_name=directive,
operators=orig_operators,
target_ops=target_ops,
agent_llm=model,
message_history=message_history,
global_default_model=global_default_model,
input_file_path=input_file_path,
pipeline_code=orig_config,
dataset=dataset,
)
orig_config["operations"] = new_ops_list
# Update pipeline steps to reflect new operation names
orig_config = update_pipeline(orig_config, new_ops_list, target_ops)
# Apply special post-processing for chaining directive
if directive == "chaining":
new_ops_list = update_sample(new_ops_list, target_ops, orig_operators)
orig_config["operations"] = new_ops_list
# fix_models is currently a no-op; model names should already be specified correctly in the YAML
fix_models(orig_config)
except ValueError as e:
print(f"Failed to instantiate directive '{directive}': {e}")
return None, message_history, 0.0
output_file_path = os.path.join(
save_dir,
f"iteration_{iteration_num}.yaml",
)
# Model names should be specified correctly in the YAML - no automatic prefixing
# Add bypass_cache: true at the top level
orig_config["bypass_cache"] = True
# Save the modified config
with open(output_file_path, "w") as file:
yaml.dump(
orig_config,
file,
default_flow_style=False,
allow_unicode=True,
sort_keys=False,
)
print(f"Modified YAML saved to: {output_file_path}")
# Execute the pipeline to get cost and sample outputs for next iteration
total_cost = 0.0
try:
from dotenv import load_dotenv
from docetl.runner import DSLRunner
# Update output path if output_dir is provided
if output_dir:
json_output_path = os.path.join(
output_dir, f"iteration_{iteration_num}_results.json"
)
orig_config["pipeline"]["output"]["path"] = json_output_path
# Save updated YAML with new output path
with open(output_file_path, "w") as file:
yaml.dump(
orig_config,
file,
default_flow_style=False,
allow_unicode=True,
sort_keys=False,
)
# Load environment
cwd = os.getcwd()
env_file = os.path.join(cwd, ".env")
if os.path.exists(env_file):
load_dotenv(env_file)
print("🔄 Executing pipeline to get cost and sample outputs...")
runner = DSLRunner.from_yaml(output_file_path)
runner.load()
if runner.last_op_container:
result_data, _, _ = runner.last_op_container.next()
runner.save(result_data)
total_cost = runner.total_cost
print(f"✅ Pipeline executed successfully, cost: ${total_cost:.4f}")
else:
print("⚠️ No results from pipeline execution")
raise Exception("No results from pipeline execution")
runner.reset_env()
except Exception as e:
print(f"❌ Pipeline execution failed: {e}")
raise e
return output_file_path, message_history, total_cost
# Use experiments/reasoning/run_baseline.py to run experiments

View File

@ -0,0 +1,458 @@
# Directives - How to Add a New Directive
This guide explains how to add a new directive to the reasoning optimizer. Directives are transformations that can be applied to pipeline operations to improve their effectiveness.
## What is a Directive?
A directive is a transformation rule that modifies pipeline operations. For example:
- **Chaining**: Breaks complex operations into sequential steps
- **Gleaning**: Adds validation loops to improve output quality
- **Change Model**: Switches the LLM model for better performance
- **Doc Summarization**: Adds preprocessing to summarize long documents
## Key Concepts
### What is an Instantiate Schema?
An **instantiate schema** is a Pydantic model that defines the structured output an agent (LLM) must produce to apply a directive. It acts as the "configuration blueprint" that tells the system exactly how to transform an operation.
For example, when applying the **gleaning** directive:
1. The agent looks at the original operation
2. The agent outputs a `GleaningInstantiateSchema` containing:
```python
{
"gleaning_config": {
"validation_prompt": "Check that the output has at least 2 insights...",
"num_rounds": 3,
"model": "gpt-4o-mini"
}
}
```
3. The directive uses this schema to add gleaning configuration to the operation
The instantiate schema ensures the agent provides all required parameters in the correct format for the directive to work.
### Workflow Overview
1. **Agent Analysis**: Agent examines the original operation and determines how to apply the directive
2. **Schema Generation**: Agent outputs structured configuration using the instantiate schema format
3. **Directive Application**: The directive's `apply()` method uses this configuration to transform the pipeline
4. **Validation**: Schema validators ensure the configuration is valid before application
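A hedged end-to-end sketch of this workflow, using the registry helper from `__init__.py` and the three-value unpacking that `agent.py` uses (the pipeline, directive choice, and model name below are illustrative):
```python
from docetl.reasoning_optimizer.directives import instantiate_directive

# Hedged sketch: drive the four workflow steps for one target operation.
# The example pipeline and model are illustrative; running this requires an
# LLM API key, since the agent call happens inside instantiate_directive.
pipeline_config = {
    "operations": [
        {
            "name": "extract_insights",
            "type": "map",
            "prompt": "Extract key insights from: {{ input.document }}",
            "output": {"schema": {"insights": "list[string]"}},
        }
    ]
}

new_ops_list, message_history, cost = instantiate_directive(
    directive_name="gleaning",
    operators=pipeline_config["operations"],
    target_ops=["extract_insights"],
    agent_llm="gpt-4o-mini",
    message_history=[],
)
pipeline_config["operations"] = new_ops_list
```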
## Quick Start - Adding a New Directive
### 1. Define Your Instantiate Schema
First, add your schema classes to `docetl/reasoning_optimizer/instantiate_schemas.py`. The instantiate schema defines what the agent must output:
```python
class MyDirectiveConfig(BaseModel):
"""Configuration parameters for your directive."""
param1: str = Field(..., description="Description of param1")
param2: int = Field(default=3, description="Description of param2")
model: str = Field(default="gpt-4o-mini", description="The LLM model to use")
class MyDirectiveInstantiateSchema(BaseModel):
"""
Schema that the agent must output to instantiate this directive.
This is what gets returned by the LLM when asked to apply the directive.
"""
my_directive_config: MyDirectiveConfig = Field(
..., description="The configuration to apply to the target operation"
)
# Add validators if needed
@field_validator("my_directive_config")
@classmethod
def validate_config(cls, v):
# Add validation logic here
return v
```
### 2. Create Your Directive Class
Create a new file in this directory (e.g., `my_directive.py`):
```python
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
from litellm import completion
from pydantic import BaseModel, Field
from docetl.reasoning_optimizer.instantiate_schemas import MyDirectiveInstantiateSchema
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
class MyDirective(Directive):
name: str = Field(default="my_directive", description="The name of the directive")
formal_description: str = Field(default="Op => Modified_Op")
nl_description: str = Field(
default="Natural language description of what this directive does"
)
when_to_use: str = Field(
default="When to apply this directive (specific use cases)"
)
# This tells the system what schema format the agent should output
instantiate_schema_type: Type[BaseModel] = Field(default=MyDirectiveInstantiateSchema)
example: str = Field(
default="""
Original Op (MapOpConfig):
- name: example_op
type: map
prompt: |
Example prompt: {{ input.document }}
output:
schema:
result: "string"
Example InstantiateSchema (what the agent should output):
MyDirectiveConfig(
param1="example_value",
param2=5,
model="gpt-4o-mini"
)
"""
)
test_cases: List[DirectiveTestCase] = Field(
default_factory=lambda: [
DirectiveTestCase(
name="basic_functionality",
description="Should apply directive transformation correctly",
input_config={
"name": "test_op",
"type": "map",
"prompt": "Test prompt: {{ input.text }}",
"output": {"schema": {"result": "string"}},
},
target_ops=["test_op"],
expected_behavior="Should modify the operation with directive configuration",
should_pass=True,
),
]
)
def __eq__(self, other):
return isinstance(other, MyDirective)
def __hash__(self):
return hash("MyDirective")
```
### 3. Implement Required Methods
```python
def to_string_for_instantiate(self, original_op: Dict) -> str:
"""
Generate a prompt that asks the agent to output the instantiate schema.
This prompt explains to the LLM what configuration it needs to generate.
"""
return (
f"You are an expert at [specific domain expertise for your directive].\n\n"
f"Original Operation:\n"
f"{str(original_op)}\n\n"
f"Directive: {self.name}\n"
f"Your task is to instantiate this directive by generating a MyDirectiveConfig "
f"that specifies [specific instructions for what the directive should do].\n\n"
f"The agent must output the configuration in this exact format:\n"
f"- param1: [explanation of how to set this]\n"
f"- param2: [explanation of how to set this]\n"
f"- model: [which model to use]\n\n"
f"Example:\n"
f"{self.example}\n\n"
f"Please output only the InstantiateSchema (MyDirectiveConfig object) "
f"that specifies how to apply this directive to the original operation."
)
def llm_instantiate(
self,
original_op: Dict,
agent_llm: str,
message_history: list = [],
) -> tuple:
"""
Call the LLM to generate the instantiate schema.
The LLM will output structured data matching MyDirectiveInstantiateSchema.
"""
message_history.extend([
{
"role": "user",
"content": self.to_string_for_instantiate(original_op),
},
])
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=MyDirectiveInstantiateSchema, # Forces structured output
)
try:
parsed_res = json.loads(resp.choices[0].message.content)
if "my_directive_config" not in parsed_res:
raise ValueError("Response missing required key 'my_directive_config'")
config = parsed_res["my_directive_config"]
schema = MyDirectiveInstantiateSchema(my_directive_config=config)
message_history.append({
"role": "assistant",
"content": resp.choices[0].message.content
})
return schema, message_history
except Exception as err:
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts.")
def apply(
self, ops_list: List[Dict], target_op: str, rewrite: MyDirectiveInstantiateSchema
) -> List[Dict]:
"""
Apply the directive using the instantiate schema configuration.
The 'rewrite' parameter contains the agent's generated configuration.
"""
new_ops_list = deepcopy(ops_list)
# Find the target operation
pos_to_replace = [i for i, op in enumerate(ops_list) if op["name"] == target_op][0]
target_operator = new_ops_list[pos_to_replace]
# Apply transformation using the agent's configuration
target_operator["my_directive_param"] = rewrite.my_directive_config.param1
target_operator["model"] = rewrite.my_directive_config.model
return new_ops_list
def instantiate(
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
**kwargs,
) -> tuple:
"""
Main method that orchestrates directive instantiation:
1. Get agent to generate instantiate schema
2. Apply the transformation using that schema
"""
assert len(target_ops) == 1, "This directive requires exactly one target op"
target_op_config = [op for op in operators if op["name"] == target_ops[0]][0]
# Step 1: Agent generates the instantiate schema
rewrite, message_history = self.llm_instantiate(
target_op_config, agent_llm, message_history
)
# Step 2: Apply transformation using the schema
return self.apply(operators, target_ops[0], rewrite), message_history
```
### 4. Register Your Directive
Add your directive to `__init__.py`:
```python
from .my_directive import MyDirective
ALL_DIRECTIVES = [
ChainingDirective(),
GleaningDirective(),
ChangeModelDirective(),
DocSummarizationDirective(),
MyDirective(), # Add your directive here
]
```
### 5. Update Test Runner
Add your directive to `tests/reasoning_optimizer/test_runner.py` in two places:
1. **Import section**:
```python
from docetl.reasoning_optimizer.directives import (
ChainingDirective,
GleaningDirective,
ChangeModelDirective,
DocSummarizationDirective,
IsolatingSubtasksDirective,
MyDirective, # Add your directive here
TestResult
)
```
2. **Both directive lists**:
```python
# In run_all_directive_tests function
directives = [
ChainingDirective(),
GleaningDirective(),
ChangeModelDirective(),
DocSummarizationDirective(),
IsolatingSubtasksDirective(),
MyDirective() # Add your directive here
]
# In run_specific_directive_test function
directive_map = {
"chaining": ChainingDirective(),
"gleaning": GleaningDirective(),
"change_model": ChangeModelDirective(),
"doc_summarization": DocSummarizationDirective(),
"isolating_subtasks": IsolatingSubtasksDirective(),
"my_directive": MyDirective() # Add your directive here
}
```
## Real Examples of Instantiate Schemas
### Gleaning Directive
```python
# What the agent outputs:
{
"gleaning_config": {
"validation_prompt": "Check that the output contains at least 2 insights and each has supporting actions",
"num_rounds": 3,
"model": "gpt-4o-mini"
}
}
# How it's applied: Adds gleaning configuration to the operation
target_operator["gleaning"] = {
"validation_prompt": rewrite.gleaning_config.validation_prompt,
"num_rounds": rewrite.gleaning_config.num_rounds,
"model": rewrite.gleaning_config.model,
}
```
### Chaining Directive
```python
# What the agent outputs:
{
"new_ops": [
{
"name": "extract_conditions",
"prompt": "Identify new medical conditions from: {{ input.summary }}",
"output_keys": ["conditions"],
"model": "gpt-4o-mini"
},
{
"name": "extract_treatments",
"prompt": "Extract treatments for conditions {{ input.conditions }} from {{ input.summary }}",
"output_keys": ["treatments"],
"model": "gpt-4o-mini"
}
]
}
# How it's applied: Replaces one operation with multiple chained operations
```
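After `apply()`, the operations list contains the two chained ops in place of the original operation. A hedged sketch of the result (the output schemas here are an illustrative assumption; the agent output above only specifies `output_keys`, and the exact conversion lives in the chaining directive's `apply()`):
```python
# Hedged sketch of the operations list after applying chaining:
# the single original map op is replaced by the two ops from new_ops.
# Output schemas are an illustrative assumption based on output_keys.
ops_after_chaining = [
    {
        "name": "extract_conditions",
        "type": "map",
        "model": "gpt-4o-mini",
        "prompt": "Identify new medical conditions from: {{ input.summary }}",
        "output": {"schema": {"conditions": "list[string]"}},
    },
    {
        "name": "extract_treatments",
        "type": "map",
        "model": "gpt-4o-mini",
        "prompt": "Extract treatments for conditions {{ input.conditions }} from {{ input.summary }}",
        "output": {"schema": {"treatments": "list[string]"}},
    },
]
```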
### Change Model Directive
```python
# What the agent outputs:
{
"change_model_config": {
"model": "gpt-4o"
}
}
# How it's applied: Changes the model field
target_operator["model"] = rewrite.change_model_config.model
```
## Testing Your Directive
### Individual Directive Testing
Test your directive by running its test cases:
```python
from docetl.reasoning_optimizer.directives import MyDirective
directive = MyDirective()
test_results = directive.run_tests(agent_llm="gpt-4o-mini")
for result in test_results:
print(f"{result.test_name}: {'PASS' if result.passed else 'FAIL'}")
print(f"Reason: {result.reason}")
```
### Command Line Testing
Run directive instantiation tests from the command line:
```bash
# Test a specific directive
python experiments/reasoning/run_tests.py --directive=isolating_subtasks
# Test all directive instantiation tests
python experiments/reasoning/run_tests.py
```
### Apply Method Testing
Test that directive `apply()` methods work correctly:
```bash
# Test all directive apply methods
python tests/reasoning_optimizer/test_directive_apply.py
```
This ensures the `apply()` method doesn't crash when given realistic pipeline configurations and rewrite schemas.
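You can also exercise `apply()` directly, without any LLM call, by hand-building a rewrite schema. A hedged sketch using the illustrative `MyDirective` classes from this guide (not a real shipped directive):
```python
from docetl.reasoning_optimizer.instantiate_schemas import MyDirectiveInstantiateSchema
from docetl.reasoning_optimizer.directives.my_directive import MyDirective

# Hedged sketch: call apply() with a hand-built rewrite schema, skipping
# llm_instantiate(). MyDirective* are the illustrative classes from this guide.
ops_list = [
    {
        "name": "test_op",
        "type": "map",
        "prompt": "Test prompt: {{ input.text }}",
        "output": {"schema": {"result": "string"}},
    }
]
rewrite = MyDirectiveInstantiateSchema(
    my_directive_config={"param1": "example_value", "param2": 5, "model": "gpt-4o-mini"}
)
new_ops = MyDirective().apply(ops_list, target_op="test_op", rewrite=rewrite)
print(new_ops)
```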
### Integration Testing
Full pipeline integration testing can be done via `experiments/reasoning/run_mcts.py`.
## Common Patterns
### 1. Single Operation Modification
Adds configuration to existing operation:
- **Gleaning**: Adds validation config
- **Change Model**: Modifies model parameter
### 2. Operation Replacement
Replaces one operation with multiple:
- **Chaining**: Creates sequence of simpler operations
### 3. Pipeline Preprocessing
Adds operations at pipeline start:
- **Doc Summarization**: Adds summarization step before main processing
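For pattern 3, the `apply()` typically prepends a new operation to the pipeline. A hedged sketch (the prompt and schema are illustrative, not the actual `doc_summarization` implementation):
```python
from copy import deepcopy
from typing import Dict, List

# Hedged sketch of a pattern-3 apply(): prepend a preprocessing op at the
# start of the pipeline. A real preprocessing directive would typically also
# rewrite downstream prompts to consume the summary instead of the full text.
def apply_preprocessing(ops_list: List[Dict]) -> List[Dict]:
    new_ops_list = deepcopy(ops_list)
    summarize_op = {
        "name": "summarize_documents",
        "type": "map",
        "model": "gpt-4o-mini",
        "prompt": "Summarize this document: {{ input.document }}",
        "output": {"schema": {"summary": "string"}},
    }
    new_ops_list.insert(0, summarize_op)
    return new_ops_list
```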
## File Structure Summary
```
docetl/reasoning_optimizer/
├── directives/
│ ├── my_directive.py # Your directive implementation
│ ├── __init__.py # Register in ALL_DIRECTIVES
│ └── README.md # This file
└── instantiate_schemas.py # Define your schema classes
experiments/reasoning/ # Testing framework
└── run_mcts.py # Full pipeline testing
```
## Best Practices
1. **Schema First**: Design the instantiate schema before implementing the directive - it defines the interface
2. **Clear Agent Instructions**: The `to_string_for_instantiate()` method should clearly explain what the agent needs to output
3. **Validation**: Use Pydantic validators in your schema to catch invalid configurations
4. **Error Handling**: Handle LLM failures gracefully with retry logic
5. **Comprehensive Testing**: Test edge cases where the agent might output invalid configurations
6. **Documentation**: Clearly document what your instantiate schema fields mean and how they're used
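For practice 3 above, a hedged validator sketch (the field names and bounds are illustrative):
```python
from pydantic import BaseModel, Field, field_validator

# Hedged sketch of practice 3: reject obviously invalid configurations
# before apply() runs. Field names and bounds are illustrative.
class MyDirectiveConfig(BaseModel):
    param1: str = Field(..., description="Description of param1")
    param2: int = Field(default=3, description="Description of param2")
    model: str = Field(default="gpt-4o-mini", description="The LLM model to use")

    @field_validator("param2")
    @classmethod
    def param2_must_be_positive(cls, v: int) -> int:
        if v < 1:
            raise ValueError("param2 must be at least 1")
        return v
```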
The instantiate schema is the critical bridge between the agent's reasoning and your directive's implementation - design it carefully!

View File

@ -0,0 +1,192 @@
from typing import Dict, List
from .base import (
Directive,
DirectiveTestCase,
TestResult,
MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS,
DEFAULT_MODEL,
DEFAULT_MAX_TPM,
DEFAULT_OUTPUT_DIR
)
from .chaining import ChainingDirective
from .gleaning import GleaningDirective
from .change_model import ChangeModelDirective
from .change_model_acc import ChangeModelAccDirective # noqa: F401
from .change_model_cost import ChangeModelCostDirective
from .doc_summarization import DocSummarizationDirective
from .isolating_subtasks import IsolatingSubtasksDirective
# from .doc_compression import DocCompressionDirective
from .deterministic_doc_compression import DeterministicDocCompressionDirective
from .reduce_gleaning import ReduceGleaningDirective
from .reduce_chaining import ReduceChainingDirective
from .operator_fusion import OperatorFusionDirective
from .doc_chunking import DocumentChunkingDirective
from .doc_chunking_topk import DocumentChunkingTopKDirective
from .chunk_header_summary import ChunkHeaderSummaryDirective
from .take_head_tail import TakeHeadTailDirective
from .clarify_instructions import ClarifyInstructionsDirective
from .swap_with_code import SwapWithCodeDirective
from .map_reduce_fusion import MapReduceFusionDirective
from .hierarchical_reduce import HierarchicalReduceDirective
from .cascade_filtering import CascadeFilteringDirective
from .arbitrary_rewrite import ArbitraryRewriteDirective
from .map_to_map_resolve_reduce import MapToMapResolveReduceDirective
from .map_resolve_to_map_with_categories import MapResolveToMapWithCategoriesDirective
# Registry of all available directives
ALL_DIRECTIVES = [
ChainingDirective(),
GleaningDirective(),
ReduceGleaningDirective(),
ReduceChainingDirective(),
ChangeModelCostDirective(),
DocSummarizationDirective(),
IsolatingSubtasksDirective(),
# DocCompressionDirective(),
DeterministicDocCompressionDirective(),
OperatorFusionDirective(),
DocumentChunkingDirective(),
DocumentChunkingTopKDirective(),
ChunkHeaderSummaryDirective(),
TakeHeadTailDirective(),
ClarifyInstructionsDirective(),
SwapWithCodeDirective(),
MapReduceFusionDirective(),
HierarchicalReduceDirective(),
CascadeFilteringDirective(),
ArbitraryRewriteDirective(),
MapToMapResolveReduceDirective(),
MapResolveToMapWithCategoriesDirective(),
]
ALL_COST_DIRECTIVES = [
ReduceChainingDirective(),
DocSummarizationDirective(),
# DocCompressionDirective(),
DeterministicDocCompressionDirective(),
DocumentChunkingTopKDirective(),
OperatorFusionDirective(),
TakeHeadTailDirective(),
MapReduceFusionDirective(),
CascadeFilteringDirective(),
ArbitraryRewriteDirective(),
]
DIRECTIVE_GROUPS = {
"compression": [
# DocCompressionDirective(),
DocSummarizationDirective(),
DeterministicDocCompressionDirective(),
],
"chunking": [
DocumentChunkingDirective(),
DocumentChunkingTopKDirective(),
]
}
MULTI_INSTANCE_DIRECTIVES = [
DocumentChunkingDirective(),
DocumentChunkingTopKDirective(),
DeterministicDocCompressionDirective(),
TakeHeadTailDirective(),
CascadeFilteringDirective(),
ClarifyInstructionsDirective(),
]
# Create a mapping from directive names to directive instances
DIRECTIVE_REGISTRY = {directive.name: directive for directive in ALL_DIRECTIVES}
def get_all_directive_strings() -> str:
"""
Generate string descriptions for all available directives for use in prompts.
Returns:
str: Formatted string containing all directive descriptions
"""
return "\n".join([directive.to_string_for_plan() for directive in ALL_DIRECTIVES])
def get_all_cost_directive_strings() -> str:
return "\n".join([directive.to_string_for_plan() for directive in ALL_COST_DIRECTIVES])
def instantiate_directive(
directive_name: str,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list,
global_default_model: str = None,
dataset: str = None,
**kwargs
):
"""
Centralized method to instantiate any directive by name.
Args:
directive_name: Name of the directive to instantiate
operators: List of pipeline operators
target_ops: List of target operation names
agent_llm: LLM model to use
message_history: Conversation history
global_default_model: Default model from the pipeline config, passed through to the directive
dataset: Dataset name passed through to the directive
**kwargs: Additional arguments to pass to directive
Returns:
Tuple of (new_ops_list, updated_message_history, cost)
Raises:
ValueError: If directive_name is not recognized
"""
if message_history is None:
message_history = []
if directive_name not in DIRECTIVE_REGISTRY:
available = list(DIRECTIVE_REGISTRY.keys())
raise ValueError(f"Unknown directive '{directive_name}'. Available: {available}")
directive = DIRECTIVE_REGISTRY[directive_name]
return directive.instantiate(
operators=operators,
target_ops=target_ops,
agent_llm=agent_llm,
message_history=message_history,
global_default_model=global_default_model,
dataset=dataset,
**kwargs
)
__all__ = [
"Directive",
"DirectiveTestCase",
"TestResult",
"MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS",
"DEFAULT_MODEL",
"DEFAULT_MAX_TPM",
"DEFAULT_OUTPUT_DIR",
"ChainingDirective",
"GleaningDirective",
"ReduceGleaningDirective",
"ReduceChainingDirective",
"ChangeModelDirective",
"DocSummarizationDirective",
"IsolatingSubtasksDirective",
# "DocCompressionDirective",
"DeterministicDocCompressionDirective",
"OperatorFusionDirective",
"DocumentChunkingDirective",
"DocumentChunkingTopKDirective",
"ChunkHeaderSummaryDirective",
"TakeHeadTailDirective",
"ClarifyInstructionsDirective",
"SwapWithCodeDirective",
"MapReduceFusionDirective",
"HierarchicalReduceDirective",
"CascadeFilteringDirective",
"ArbitraryRewriteDirective",
"MapToMapResolveReduceDirective",
"MapResolveToMapWithCategoriesDirective",
"ALL_DIRECTIVES",
"DIRECTIVE_REGISTRY",
"get_all_directive_strings",
"instantiate_directive"
]

View File

@ -0,0 +1,458 @@
import json
import os
from typing import Dict, List, Literal, Optional
from litellm import completion, model_cost
from pydantic import BaseModel, Field
from rich import print as rprint
from docetl.operations.utils.llm import count_tokens
class AgentDecision(BaseModel):
"""Schema for agent decision-making in agentic loops."""
action: Literal["read_next_docs", "read_operator_doc", "output_schema"] = Field(
..., description="The action the agent wants to take"
)
reasoning: str = Field(
...,
description="Explanation of why the agent chose this action and what they learned from current samples",
)
operator_name: Optional[str] = Field(
None,
description="For read_operator_doc action: the operator name to read documentation for (e.g., 'map', 'filter', 'reduce')",
)
class ReadNextDocTool:
"""Tool for iteratively reading documents from input data."""
def __init__(
self,
input_data: List[Dict],
context_window: int = 32000,
docs_per_iteration: int = 3,
):
self.input_data = input_data
self.current_index = 0
self.context_window = context_window
self.total_docs = len(input_data)
self.docs_per_iteration = docs_per_iteration
def read_next_doc(self) -> Optional[Dict]:
"""Read the next document from the input data."""
if self.current_index >= len(self.input_data):
return None
doc = self.input_data[self.current_index]
self.current_index += 1
return doc
def read_next_docs(self, count: int = None) -> List[Dict]:
"""Read the next N documents from the input data."""
if count is None:
count = self.docs_per_iteration
docs = []
for _ in range(count):
if self.current_index >= len(self.input_data):
break
docs.append(self.input_data[self.current_index])
self.current_index += 1
return docs
def has_more_docs(self) -> bool:
"""Check if there are more documents to read."""
return self.current_index < len(self.input_data)
def get_remaining_count(self) -> int:
"""Get the number of remaining documents."""
return len(self.input_data) - self.current_index
def reset(self) -> None:
"""Reset the iterator to the beginning."""
self.current_index = 0
def estimate_token_count(text: str, model: str = "gpt-4.1-mini") -> int:
"""Use proper token counting instead of rough estimation."""
return count_tokens(text, model)
def truncate_message_content(messages: List[Dict], max_tokens: int) -> List[Dict]:
"""
Truncate message content to fit within token limits.
Preserves system message and latest user message, truncates middle content.
"""
if not messages:
return messages
# Calculate total token count
total_tokens = sum(estimate_token_count(msg.get("content", "")) for msg in messages)
if total_tokens <= max_tokens:
return messages
# Keep system message and latest user message
truncated_messages = []
if messages[0].get("role") == "system":
truncated_messages.append(messages[0])
remaining_messages = messages[1:]
else:
remaining_messages = messages
# Always keep the latest message
if remaining_messages:
truncated_messages.append(remaining_messages[-1])
middle_messages = remaining_messages[:-1]
else:
middle_messages = []
# Calculate available tokens for middle messages
system_tokens = (
estimate_token_count(truncated_messages[0].get("content", ""))
if truncated_messages
else 0
)
latest_tokens = (
estimate_token_count(truncated_messages[-1].get("content", ""))
if len(truncated_messages) > 1
else 0
)
available_tokens = (
max_tokens - system_tokens - latest_tokens - 1000
) # Buffer for response
# Add middle messages until we hit the limit, preferring the most recent ones
current_tokens = 0
kept_middle = []
for msg in reversed(middle_messages):  # Consider most recent first
msg_tokens = estimate_token_count(msg.get("content", ""))
if current_tokens + msg_tokens <= available_tokens:
current_tokens += msg_tokens
kept_middle.append(msg)
else:
break
# Re-insert the kept messages in chronological order, just before the latest message
for msg in reversed(kept_middle):
truncated_messages.insert(-1, msg)
return truncated_messages
class AgenticDirectiveRunner:
"""
Utility class for running agentic directives that iteratively process documents.
Manages context windows, document iteration, and decision-making loops.
"""
def __init__(
self,
input_data: List[Dict],
agent_llm: str = "gpt-4.1-mini",
validation_func: Optional[callable] = None,
enable_operator_docs: bool = False,
):
self.input_data = input_data
self.agent_llm = agent_llm
self.context_window = self._get_model_context_window(agent_llm)
self.enable_operator_docs = enable_operator_docs
# Double the max iterations if operator docs are enabled to allow more exploration
self.docs_per_iteration = 3
self.max_iterations = 6 if enable_operator_docs else 3
self.doc_reader = ReadNextDocTool(
input_data, self.context_window, self.docs_per_iteration
)
self.message_history = []
self.validation_func = validation_func
self.docs_path = os.path.join(
os.path.dirname(
os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
),
"docs",
"operators",
)
def _get_model_context_window(self, model: str) -> int:
"""Get the context window size for the given model."""
model_cost_info = model_cost.get(model, {})
if not model_cost_info:
# Try stripping the first part before the /
split_model = model.split("/")
if len(split_model) > 1:
model_cost_info = model_cost.get("/".join(split_model[1:]), {})
if not model_cost_info:
model_cost_info = model_cost.get(model.split("/")[-1], {})
return model_cost_info.get("max_input_tokens", 32768)
def _read_operator_doc(self, operator_name: str) -> Optional[str]:
"""
Read the documentation for a specific operator.
Args:
operator_name: Name of the operator (e.g., 'map', 'filter', 'reduce')
Returns:
The markdown documentation content or None if not found
"""
doc_file = os.path.join(self.docs_path, f"{operator_name}.md")
if os.path.exists(doc_file):
try:
with open(doc_file, "r") as f:
content = f.read()
return content
except Exception as e:
return f"Error reading documentation for {operator_name}: {str(e)}"
else:
# Try alternative names (e.g., 'parallel-map' for 'parallel_map')
alt_name = operator_name.replace("_", "-")
doc_file = os.path.join(self.docs_path, f"{alt_name}.md")
if os.path.exists(doc_file):
try:
with open(doc_file, "r") as f:
content = f.read()
return content
except Exception as e:
return f"Error reading documentation for {operator_name}: {str(e)}"
return f"Documentation not found for operator: {operator_name}"
def _truncate_doc_to_tokens(self, doc: Dict, max_tokens: int) -> str:
"""
Truncate document content to fit within the specified token limit.
Args:
doc: The document dictionary to truncate
max_tokens: Maximum number of tokens to allow
Returns:
Truncated document string
"""
doc_str = json.dumps(doc, indent=2)
doc_tokens = estimate_token_count(doc_str, self.agent_llm)
if doc_tokens <= max_tokens:
return doc_str
# If document is too long, truncate it
# Estimate characters per token (rough approximation)
chars_per_token = len(doc_str) / doc_tokens
target_chars = int(max_tokens * chars_per_token)
truncated = doc_str[:target_chars]
return truncated + "... [truncated]"
def run_agentic_loop(
self, system_prompt: str, initial_user_message: str, response_schema: BaseModel
):
"""
Run an agentic loop where the agent analyzes input data for directive instantiation.
Args:
system_prompt: System message for the agent
initial_user_message: Initial user message with task description
response_schema: Pydantic schema for the expected response
Returns:
Tuple of (parsed_response, message_history)
"""
call_cost = 0.0
# Initialize message history
self.message_history = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": initial_user_message},
]
max_iterations = min(
self.max_iterations, len(self.input_data)
) # Conservative limit for analysis
rprint(
f"[blue]🤖 Determining rewrite instantiation with {len(self.input_data)} documents available[/blue]"
)
for iteration in range(max_iterations):
# Calculate remaining context
current_tokens = sum(
estimate_token_count(msg.get("content", ""), self.agent_llm)
for msg in self.message_history
)
remaining_tokens = self.context_window - current_tokens - 2000 # Buffer
# Add context info for analysis
context_info = f"""
Analysis Progress:
- Remaining context window: ~{remaining_tokens} tokens
- Documents analyzed: {self.doc_reader.current_index}/{self.doc_reader.total_docs}
- Documents remaining: {self.doc_reader.get_remaining_count()}
Analyze the input samples to understand patterns, edge cases, and data characteristics that will help you complete your task effectively.
"""
# Create action guidance
if self.enable_operator_docs:
action_guidance = f"""
Choose your next action:
- read_next_docs: If you need more examples to understand data patterns, edge cases, or to gather more information for your analysis (reads ~{self.docs_per_iteration} documents at once)
- read_operator_doc: If you need to understand how a specific operator works, its parameters, or see examples (specify operator_name like 'map', 'filter', 'reduce')
- output_schema: If you have sufficient examples to complete your task based on the patterns and insights you've gathered from the data
Focus on quality over quantity - a few diverse, informative examples are better than many similar ones.
"""
else:
action_guidance = f"""
Choose your next action:
- read_next_docs: If you need more examples to understand data patterns, edge cases, or to gather more information for your analysis (reads ~{self.docs_per_iteration} documents at once)
- output_schema: If you have sufficient examples to complete your task based on the patterns and insights you've gathered from the data
Focus on quality over quantity - a few diverse, informative examples are better than many similar ones.
"""
# Update the latest user message with context info
if self.message_history[-1]["role"] == "user":
self.message_history[-1]["content"] += context_info + action_guidance
# Truncate messages if needed
self.message_history = truncate_message_content(
self.message_history, self.context_window - 2000
)
rprint(
f"[yellow]🧠 Iteration {iteration + 1}/{max_iterations}: Asking {self.agent_llm} agent to decide next action (tokens: {remaining_tokens} remaining)[/yellow]"
)
# Get structured agent decision
response = completion(
model=self.agent_llm,
messages=self.message_history,
response_format=AgentDecision,
)
call_cost += response._hidden_params["response_cost"]
try:
decision_json = json.loads(response.choices[0].message.content)
decision = AgentDecision(**decision_json)
except Exception as e:
raise Exception(f"Failed to parse agent decision: {str(e)}")
self.message_history.append(
{"role": "assistant", "content": response.choices[0].message.content}
)
# Handle agent's decision
if decision.action == "read_operator_doc":
# Agent wants to read operator documentation
if not self.enable_operator_docs:
user_message = "Operator documentation reading is not enabled for this directive."
self.message_history.append(
{"role": "user", "content": user_message}
)
elif not decision.operator_name:
user_message = "Please specify which operator documentation you want to read (e.g., 'map', 'filter', 'reduce')."
self.message_history.append(
{"role": "user", "content": user_message}
)
else:
rprint(
f"[blue]📖 Agent reading documentation for operator: {decision.operator_name}[/blue]"
)
doc_content = self._read_operator_doc(decision.operator_name)
user_message = f"Documentation for '{decision.operator_name}' operator:\n\n{doc_content}\n\nAnalyze this documentation to understand how to use this operator effectively."
self.message_history.append(
{"role": "user", "content": user_message}
)
elif decision.action == "read_next_docs":
# Agent wants to analyze more data
next_docs = self.doc_reader.read_next_docs()
if not next_docs:
# No more documents - force output
rprint(
"[red]📄 No more documents available. Proceeding with schema generation.[/red]"
)
user_message = "No more documents available. Based on the samples you've analyzed, please complete your task."
self.message_history.append(
{"role": "user", "content": user_message}
)
break
else:
rprint(
f"[green]📄 Agent reading {len(next_docs)} documents (up to {self.doc_reader.current_index}/{len(self.input_data)})[/green]"
)
docs_content = "\n".join(
[
f"Sample {self.doc_reader.current_index - len(next_docs) + i + 1}:\n{self._truncate_doc_to_tokens(doc, 1000)}"
for i, doc in enumerate(next_docs)
]
)
user_message = f"{docs_content}\n\nAnalyze these samples for patterns, edge cases, and characteristics that will help with your task."
self.message_history.append(
{"role": "user", "content": user_message}
)
elif decision.action == "output_schema":
# Agent is ready to create improved prompt
rprint(
f"[cyan]✨ Agent ready to generate final schema after analyzing {self.doc_reader.current_index} documents[/cyan]"
)
schema_prompt = f"""Based on your analysis of the input samples, complete your task using the patterns and insights you've gathered from the data.
Provide your response as a JSON object matching this schema: {response_schema.model_json_schema()}"""
self.message_history.append({"role": "user", "content": schema_prompt})
break
# Get the final schema response with validation and retries
rprint("[magenta]🔧 Generating final rewrite schema...[/magenta]")
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS
error_message = ""
for attempt in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
schema_response = completion(
model=self.agent_llm,
messages=self.message_history,
response_format=response_schema,
)
call_cost += schema_response._hidden_params["response_cost"]
try:
parsed_response = json.loads(schema_response.choices[0].message.content)
schema_instance = response_schema(**parsed_response)
# Add any additional validation if provided
if self.validation_func:
self.validation_func(schema_instance)
rprint(
f"[green]✅ Schema validation passed on attempt {attempt + 1}[/green]"
)
self.message_history.append(
{
"role": "assistant",
"content": schema_response.choices[0].message.content,
}
)
return schema_instance, self.message_history, call_cost
except Exception as err:
error_message = f"Validation error: {err}\nPlease try again with a corrected response."
rprint(
f"[red]❌ Schema validation failed on attempt {attempt + 1}: {str(err)}[/red]"
)
if attempt < MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS - 1:
rprint(
f"[yellow]🔄 Retrying schema generation (attempt {attempt + 2}/{MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS})[/yellow]"
)
self.message_history.append(
{
"role": "assistant",
"content": schema_response.choices[0].message.content,
}
)
self.message_history.append(
{"role": "user", "content": error_message}
)
raise Exception(
f"Failed to generate valid schema after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Error: {error_message}"
)

View File

@ -0,0 +1,361 @@
import json
from typing import Dict, List, Type
from pydantic import BaseModel, Field
from docetl.reasoning_optimizer.instantiate_schemas import (
ArbitraryRewriteInstantiateSchema,
)
from docetl.reasoning_optimizer.op_descriptions import get_all_operator_descriptions
from .agent_utils import AgenticDirectiveRunner
from .base import Directive, DirectiveTestCase
class ArbitraryRewriteDirective(Directive):
name: str = Field(
default="arbitrary_rewrite", description="The name of the directive"
)
formal_description: str = Field(
default="Pipeline => Modified Pipeline (search/replace edits)"
)
nl_description: str = Field(
default="Allows the agent to make arbitrary edits to the pipeline using search/replace operations on the JSON representation. The agent can add, modify, remove, or replace operations to optimize for cost, accuracy, or both. This is a catch-all directive for optimizations that don't fit into other specific directive patterns."
)
when_to_use: str = Field(
default="When you identify obvious optimizations that can make the pipeline cheaper or more accurate, but they don't fit into existing directive patterns. Use this for complex multi-operation changes, reordering operations, consolidating redundant operations, or making systematic improvements across the pipeline."
)
instantiate_schema_type: Type[BaseModel] = Field(
default=ArbitraryRewriteInstantiateSchema
)
example: str = Field(
default="""
# Example 1: Consolidating Redundant Operations
Original Pipeline JSON (formatted for readability):
[
{
"name": "extract_names",
"type": "map",
"model": "gpt-4o",
"prompt": "Extract all person names from: {{ input.text }}",
"output": {"schema": {"names": "list[string]"}}
},
{
"name": "extract_locations",
"type": "map",
"model": "gpt-4o",
"prompt": "Extract all locations from: {{ input.text }}",
"output": {"schema": {"locations": "list[string]"}}
},
{
"name": "extract_dates",
"type": "map",
"model": "gpt-4o",
"prompt": "Extract all dates from: {{ input.text }}",
"output": {"schema": {"dates": "list[string]"}}
}
]
Example SearchReplaceEdits (agent recognizes redundant LLM calls):
search_replace_edits=[
SearchReplaceEdit(
search=' {\\n "name": "extract_names",\\n "type": "map",\\n "model": "gpt-4o",\\n "prompt": "Extract all person names from: {{ input.text }}",\\n "output": {"schema": {"names": "list[string]"}}\\n },\\n {\\n "name": "extract_locations",\\n "type": "map",\\n "model": "gpt-4o",\\n "prompt": "Extract all locations from: {{ input.text }}",\\n "output": {"schema": {"locations": "list[string]"}}\\n },\\n {\\n "name": "extract_dates",\\n "type": "map",\\n "model": "gpt-4o",\\n "prompt": "Extract all dates from: {{ input.text }}",\\n "output": {"schema": {"dates": "list[string]"}}\\n }',
replace=' {\\n "name": "extract_all_entities",\\n "type": "map",\\n "model": "gpt-4o-mini",\\n "prompt": "Extract all entities from the text:\\\\n{{ input.text }}\\\\n\\\\nReturn names, locations, and dates.",\\n "output": {\\n "schema": {\\n "names": "list[string]",\\n "locations": "list[string]",\\n "dates": "list[string]"\\n }\\n }\\n }',
reasoning="Consolidate three separate GPT-4o calls into one GPT-4o-mini call"
)
]
# Example 2: Reordering for Efficiency
Original Pipeline JSON:
[
{
"name": "summarize_documents",
"type": "map",
"model": "gpt-4o",
"prompt": "Summarize this document: {{ input.full_text }}",
"output": {"schema": {"summary": "string"}}
},
{
"name": "filter_relevant",
"type": "filter",
"model": "gpt-4o-mini",
"prompt": "Is this document about technology? {{ input.full_text }}",
"output": {"schema": {"is_relevant": "boolean"}}
}
]
Example SearchReplaceEdits (agent recognizes inefficient ordering):
search_replace_edits=[
SearchReplaceEdit(
search='{\\n "name": "summarize_documents",\\n "type": "map",\\n "model": "gpt-4o",\\n "prompt": "Summarize this document: {{ input.full_text }}",\\n "output": {"schema": {"summary": "string"}}\\n },\\n {\\n "name": "filter_relevant",',
replace='{\\n "name": "filter_relevant",',
reasoning="Remove summarize operation from before filter"
),
SearchReplaceEdit(
search='"output": {"schema": {"is_relevant": "boolean"}}\\n }',
replace='"output": {"schema": {"is_relevant": "boolean"}}\\n },\\n {\\n "name": "summarize_documents",\\n "type": "map",\\n "model": "gpt-4o",\\n "prompt": "Summarize this technology document: {{ input.full_text }}",\\n "output": {"schema": {"summary": "string"}}\\n }',
reasoning="Add summarization after filter to only process relevant documents"
)
]
Note: The agent must provide exact string matches including whitespace for the search strings.
"""
)
test_cases: List[DirectiveTestCase] = Field(
default_factory=lambda: [
DirectiveTestCase(
name="consolidate_operations",
description="Should identify and consolidate redundant operations",
input_config=[
{
"name": "extract_a",
"type": "map",
"model": "gpt-4o",
"prompt": "Extract A from: {{ input.text }}",
"output": {"schema": {"a": "string"}},
},
{
"name": "extract_b",
"type": "map",
"model": "gpt-4o",
"prompt": "Extract B from: {{ input.text }}",
"output": {"schema": {"b": "string"}},
},
],
target_ops=[], # Analyzes entire pipeline
expected_behavior="Should propose consolidating the two extraction operations into one",
should_pass=True,
)
]
)
def __eq__(self, other):
return isinstance(other, ArbitraryRewriteDirective)
def __hash__(self):
return hash("ArbitraryRewriteDirective")
def _order_ops_by_pipeline_steps(
self, ops_list: List[Dict], pipeline_code: Dict = None
) -> List[Dict]:
"""Order operations according to pipeline steps if pipeline_code is provided."""
if not (
pipeline_code
and "pipeline" in pipeline_code
and "steps" in pipeline_code["pipeline"]
):
return ops_list
# Get the order from pipeline steps
steps = pipeline_code["pipeline"]["steps"]
step_order = []
for step in steps:
if isinstance(step, dict) and "operations" in step:
step_order.extend(step["operations"])
elif isinstance(step, str):
step_order.append(step)
# Create a mapping of op names to ops
ops_by_name = {op["name"]: op for op in ops_list}
# Reorder ops according to pipeline steps
ordered_ops = []
for op_name in step_order:
if op_name in ops_by_name:
ordered_ops.append(ops_by_name[op_name])
# Add any ops that weren't in steps (shouldn't happen, but just in case)
for op in ops_list:
if op not in ordered_ops:
ordered_ops.append(op)
return ordered_ops
def to_string_for_instantiate(
self, pipeline_ops: List[Dict], pipeline_code: Dict = None
) -> str:
"""Generate prompt for agent to analyze pipeline and propose edits."""
# Convert ops to pretty JSON string for display (already ordered by instantiate)
pipeline_json = json.dumps(pipeline_ops, indent=4)
return (
f"You are an expert at optimizing data processing pipelines for cost and accuracy.\n\n"
f"Current Pipeline Operations (as JSON array):\n"
f"{pipeline_json}\n\n"
f"Your task is to analyze this pipeline and propose search/replace edits to optimize it.\n\n"
f"IMPORTANT: The above JSON is ONLY the operations array. When creating search/replace edits, "
f"work with this exact JSON structure. Do not include pipeline wrapper or other fields.\n\n"
f"You have access to:\n"
f"1. Sample input data through read_next_docs() - to understand data patterns and flow\n"
f"2. Operator documentation through read_operator_doc(operator_name) - to learn about available operators\n\n"
f"IMPORTANT: Read documentation for no more than 2 operators before analyzing sample data to get a sense for how best to rewrite the pipeline.\n\n"
f"Use these tools to:\n"
f"1. Understand the data flow through the pipeline\n"
f"2. Identify inefficiencies or redundancies\n"
f"3. Find opportunities for consolidation or reordering\n"
f"4. Determine if cheaper models could be used\n"
f"5. Discover new operators that might be more efficient\n\n"
f"IMPORTANT: You will provide search/replace edits that work on the JSON string representation.\n"
f"Each edit consists of:\n"
f"- search: An exact string to find in the JSON (including whitespace)\n"
f"- replace: The string to replace it with (can be empty to delete)\n"
f"- reasoning: Why this edit improves the pipeline\n\n"
f"The edits will be applied sequentially to the JSON string, then parsed back to operations.\n\n"
f"Guidelines for search/replace:\n"
f"- The search string must match EXACTLY including all whitespace, quotes, brackets, etc.\n"
f"- You can delete operations by replacing them with empty string (but be careful with commas)\n"
f"- You can add operations by replacing closing brackets with new operations\n"
f"- You can reorder by using multiple search/replace operations\n"
f"- Each edit operates on the result of the previous edit\n\n"
f"Types of optimizations to look for:\n"
f"- Redundant operations that could be consolidated\n"
f"- Operations in suboptimal order (e.g., expensive operations before filters)\n"
f"- Opportunities to use cheaper models\n"
f"- Complex operations that could be broken into simpler steps\n"
f"- Independent operations that could be parallelized\n\n"
f"Examples:\n"
f"{self.example}\n\n"
f"Analyze the pipeline and sample data strategically. When you have identified optimizations, "
f"output your proposed edits as an ArbitraryRewriteInstantiateSchema.\n\n"
f"Remember: Your search strings must match the JSON exactly as it appears above."
)
def llm_instantiate(
self,
pipeline_ops: List[Dict],
input_file_path: str,
agent_llm: str,
message_history: list = [],
pipeline_code: Dict = None,
):
"""Use agentic approach to analyze pipeline and generate edits."""
# Load sample input data
try:
with open(input_file_path, "r") as f:
input_data = json.load(f)
if not isinstance(input_data, list) or len(input_data) == 0:
raise ValueError(
"Input file must contain a non-empty list of sample data"
)
except Exception as e:
raise Exception(
f"Failed to load input data from {input_file_path}: {str(e)}"
)
# Set up agentic runner with operator doc reading enabled
runner = AgenticDirectiveRunner(
input_data=input_data,
agent_llm=agent_llm,
enable_operator_docs=True, # Enable reading operator documentation
)
# Create system prompt with operator descriptions
operator_descriptions = get_all_operator_descriptions()
system_prompt = (
"You are an expert at optimizing data processing pipelines. "
"You analyze pipeline structures and data flow to identify opportunities "
"for cost reduction and accuracy improvement through strategic search/replace edits on the JSON representation.\n\n"
f"{operator_descriptions}"
)
# Create initial user message
initial_message = self.to_string_for_instantiate(pipeline_ops, pipeline_code)
# Run the agentic loop
try:
schema, updated_message_history, call_cost = runner.run_agentic_loop(
system_prompt=system_prompt,
initial_user_message=initial_message,
response_schema=ArbitraryRewriteInstantiateSchema,
)
# Update message history
message_history.extend(updated_message_history)
return schema, message_history, call_cost
except Exception as e:
raise Exception(
f"Failed to instantiate arbitrary_rewrite directive: {str(e)}"
)
def apply(
self,
ops_list: List[Dict],
rewrite: ArbitraryRewriteInstantiateSchema,
) -> List[Dict]:
"""Apply the search/replace edits to the pipeline."""
# Convert operations list to JSON string with consistent formatting
pipeline_json = json.dumps(ops_list, indent=4)
# Apply each search/replace edit in sequence
for i, edit in enumerate(rewrite.search_replace_edits):
if edit.search in pipeline_json:
pipeline_json = pipeline_json.replace(
edit.search, edit.replace, 1
) # Replace first occurrence only
else:
# Log warning but continue with other edits
print(
f"Warning: Search string not found in edit {i+1}: {edit.search[:50]}..."
)
# Parse the modified JSON back to operations list
try:
new_ops_list = json.loads(pipeline_json)
if not isinstance(new_ops_list, list):
raise ValueError("Pipeline must be a list of operations")
# Get rid of any empty operations in new_ops_list
new_ops_list = [op for op in new_ops_list if op]
return new_ops_list
except json.JSONDecodeError as e:
raise ValueError(
f"Failed to parse pipeline after edits. The search/replace operations resulted in invalid JSON: {e}\n"
f"Modified JSON:\n{pipeline_json[:500]}..."
)
def instantiate(
self,
operators: List[Dict],
target_ops: List[str] = None,
agent_llm: str = "gpt-4o-mini",
message_history: list = [],
**kwargs,
):
"""
Main method that orchestrates directive instantiation.
For ArbitraryRewrite, we analyze the entire pipeline rather than specific target ops.
"""
input_file_path = kwargs.get("input_file_path", None)
pipeline_code = kwargs.get("pipeline_code", None)
if not input_file_path:
raise ValueError(
"input_file_path is required for ArbitraryRewrite directive"
)
# Order ops according to pipeline steps before everything else
ordered_ops = self._order_ops_by_pipeline_steps(operators, pipeline_code)
# Step 1: Agent analyzes pipeline and generates edits
rewrite, message_history, call_cost = self.llm_instantiate(
ordered_ops,
input_file_path,
agent_llm,
message_history,
pipeline_code,
)
# Step 2: Apply the edits
return (
self.apply(ordered_ops, rewrite),
message_history,
call_cost,
)
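Note: the apply() step above works purely on the JSON string representation of the operations. A standalone sketch of that search/replace application, using plain dicts in place of the SearchReplaceEdit models (an illustration under those assumptions, not the package API):

import json

def apply_edits(ops_list, edits):
    # Serialize with the same indent the directive uses so search strings can match.
    pipeline_json = json.dumps(ops_list, indent=4)
    for i, edit in enumerate(edits):
        if edit["search"] in pipeline_json:
            pipeline_json = pipeline_json.replace(edit["search"], edit["replace"], 1)
        else:
            print(f"Warning: search string not found in edit {i + 1}")
    # Parse back and drop any operations emptied out by the edits.
    return [op for op in json.loads(pipeline_json) if op]

ops = [{"name": "extract_names", "type": "map", "model": "gpt-4o"}]
edits = [{"search": '"model": "gpt-4o"', "replace": '"model": "gpt-4o-mini"'}]
print(apply_edits(ops, edits))  # model is swapped to gpt-4o-mini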

View File

@@ -0,0 +1,231 @@
import json
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from litellm import completion
from pydantic import BaseModel, Field
# Configuration constants
MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS = 3
DEFAULT_MODEL = "gpt-4.1"
DEFAULT_MAX_TPM = 5000000
DEFAULT_OUTPUT_DIR = "./outputs"
class TestResult(BaseModel):
test_name: str
passed: bool
reason: str
actual_output: Any
execution_error: Optional[str] = None
class DirectiveTestCase(BaseModel):
name: str
description: str
input_config: Dict[str, Any] | List[Dict[str, Any]]
target_ops: List[str]
expected_behavior: str
should_pass: bool = True
class Directive(BaseModel, ABC):
name: str = Field(..., description="The name of the directive")
formal_description: str = Field(
..., description="A description of the directive; e.g., map => map -> map"
)
nl_description: str = Field(
..., description="An english description of the directive"
)
when_to_use: str = Field(..., description="When to use the directive")
instantiate_schema_type: BaseModel = Field(
...,
description="The schema the agent must conform to when instantiating the directive",
)
example: str = Field(..., description="An example of the directive being used")
test_cases: List[DirectiveTestCase] = Field(
default_factory=list, description="Test cases for this directive"
)
def to_string_for_plan(self) -> str:
"""Serialize directive for prompts."""
parts = [
f"### {self.name}",
f"**Format:** {self.formal_description}",
f"**Description:** {self.nl_description}",
f"**When to Use:** {self.when_to_use}",
]
return "\n\n".join(parts)
@abstractmethod
def to_string_for_instantiate(self, *args, **kwargs) -> str:
pass
@abstractmethod
def llm_instantiate(self, *args, **kwargs):
pass
@abstractmethod
def apply(self, *args, **kwargs) -> list:
pass
@abstractmethod
def instantiate(self, *args, **kwargs):
pass
def run_tests(self, agent_llm: str = "gpt-4o-mini") -> List[TestResult]:
"""Run all test cases for this directive using LLM judge"""
import json
import os
import tempfile
results = []
for test_case in self.test_cases:
try:
# Create fake sample data and pipeline for directives that need them
sample_data = [
{
"text": "Sample document 1",
"feedback": "Great product!",
"doc1": "Document A",
"doc2": "Document B",
},
{
"text": "Sample document 2",
"feedback": "Could be better",
"doc1": "Report 1",
"doc2": "Report 2",
},
{
"text": "Sample document 3",
"feedback": "Excellent service",
"doc1": "Policy A",
"doc2": "Policy B",
},
]
fake_pipeline = {
"operations": (
[test_case.input_config]
if isinstance(test_case.input_config, dict)
else test_case.input_config
),
"name": "test_pipeline",
}
# Create temporary input file
with tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False
) as f:
json.dump(sample_data, f, indent=2)
temp_file_path = f.name
try:
# 1. Execute the directive
# instantiate returns (new_ops_plan, message_history, call_cost)
instantiate_result = self.instantiate(
operators=(
[test_case.input_config]
if isinstance(test_case.input_config, dict)
else test_case.input_config
),
target_ops=test_case.target_ops,
agent_llm=agent_llm,
input_file_path=temp_file_path,
pipeline_code=fake_pipeline,
)
# Handle both 2-tuple and 3-tuple returns
if isinstance(instantiate_result, tuple):
actual_output = instantiate_result[0]
else:
actual_output = instantiate_result
# 2. Use LLM judge to evaluate
judge_result = self._llm_judge_test(
test_case=test_case,
actual_output=actual_output,
agent_llm=agent_llm,
)
results.append(
TestResult(
test_name=test_case.name,
passed=judge_result["passed"],
reason=judge_result["reason"],
actual_output=actual_output,
)
)
finally:
# Clean up temporary file
try:
os.unlink(temp_file_path)
except Exception:
pass
except Exception as e:
# Handle execution errors
expected_to_fail = not test_case.should_pass
results.append(
TestResult(
test_name=test_case.name,
passed=expected_to_fail,
reason=f"Execution {'failed as expected' if expected_to_fail else 'failed unexpectedly'}: {str(e)}",
actual_output=None,
execution_error=str(e),
)
)
return results
def _llm_judge_test(
self, test_case: DirectiveTestCase, actual_output: Any, agent_llm: str
) -> Dict[str, Any]:
"""Use LLM as judge to evaluate test results"""
user_message = f"""
Evaluate whether the directive execution meets the expected behavior.
**Test Case:** {test_case.name}
**Description:** {test_case.description}
**Expected Behavior:** {test_case.expected_behavior}
**Should Pass:** {test_case.should_pass}
**Original Input Configuration:**
{test_case.input_config}
**Actual Output from Directive:**
{actual_output}
**Evaluation Criteria:**
- If should_pass=True: Does the output demonstrate the expected behavior?
- If should_pass=False: Did the directive appropriately reject/not modify the input?
- Does the transformation make logical sense?
- Are all required elements preserved?
Provide your reasoning and a boolean pass/fail decision.
"""
messages = [
{
"role": "system",
"content": f"You are an expert judge evaluating whether a {self.name} directive execution meets the specified criteria. Focus on logical correctness and adherence to expected behavior.",
},
{"role": "user", "content": user_message},
]
class JudgeResponse(BaseModel):
passed: bool
reason: str
response = completion(
model=agent_llm,
messages=messages,
response_format=JudgeResponse,
)
# Parse the JSON response
parsed_content = json.loads(response.choices[0].message.content)
return {"passed": parsed_content["passed"], "reason": parsed_content["reason"]}

View File

@@ -0,0 +1,443 @@
import json
import re
from copy import deepcopy
from typing import Dict, List, Type
from pydantic import BaseModel, Field
from docetl.reasoning_optimizer.instantiate_schemas import (
CascadeFilteringInstantiateSchema,
)
from .agent_utils import AgenticDirectiveRunner
from .base import Directive, DirectiveTestCase
class CascadeFilteringDirective(Directive):
name: str = Field(
default="cascade_filtering", description="The name of the directive"
)
formal_description: str = Field(
default="Filter => (Code Filter* -> Filter(gpt-5-nano)* ->) Filter"
)
nl_description: str = Field(
default="Optimizes filtering costs by injecting a cascade of cheaper filters before the main filter. Starts with deterministic code filters (cheapest), then gpt-5-nano filters (ordered by prompt length), before the original expensive filter. Pre-filters prioritize high recall (rarely rejecting valid documents) and can have lower precision (okay to let through some invalid docs that the main filter will catch)."
)
when_to_use: str = Field(
default="When you have an expensive Filter operation (using costly models or complex prompts) and the data contains patterns that allow for cheaper pre-filtering. The pre-filters must have high recall (not dropping valid documents) but can have lower precision, as the final filter provides the actual precision."
)
instantiate_schema_type: Type[BaseModel] = Field(
default=CascadeFilteringInstantiateSchema
)
example: str = Field(
default="""
# Example 1: Disjunction - Legal Document Relevance Filter
Target Operation:
- name: filter_litigation_relevant_docs
type: filter
model: gpt-4o
prompt: |
Determine if this document is relevant to our patent litigation case.
The document is relevant if it contains ANY of:
1. Prior art references dated before 2015-03-15 describing similar technology
2. Internal emails discussing the patent claims or invalidity arguments
3. Technical specifications that contradict our patent claims
4. License agreements mentioning the patents-in-suit
5. Expert testimony or declarations about the technology
6. Financial damages calculations or royalty discussions
Patent numbers in suit: 8,123,456 and 9,234,567
Technology area: adaptive bitrate streaming for video delivery
Document: {{ input.document_text }}
Metadata: {{ input.document_metadata }}
output:
schema:
is_relevant: boolean
Example InstantiateSchema (agent recognizes the OR conditions can be split):
CascadeFilteringInstantiateSchema(
code_pre_filters=[
CodePreFilter(
name="has_patent_numbers",
code=\"\"\"def transform(input_doc):
text = (input_doc.get('document_text', '') + ' ' + str(input_doc.get('document_metadata', {}))).lower()
# If patent numbers are mentioned, likely relevant
return '8,123,456' in text or '9,234,567' in text or 'patent' in text\"\"\",
reasoning="Documents mentioning our patent numbers or patents in general might be relevant"
),
CodePreFilter(
name="has_tech_or_legal_terms",
code=\"\"\"def transform(input_doc):
text = input_doc.get('document_text', '').lower()
# Must have streaming tech or legal terms to possibly be relevant
terms = ['streaming', 'bitrate', 'video', 'codec', 'prior art', 'license',
'damages', 'royalty', 'testimony', 'expert', 'invalidity']
return any(term in text for term in terms)\"\"\",
reasoning="Documents without technical or legal terminology cannot be relevant"
)
],
llm_pre_filters=[
LLMPreFilter(
name="check_prior_art_date",
prompt="Does this mention technology from before March 2015? {{ input.document_metadata }} If yes; set 'keep' to true. If no; set 'keep' to false.",
reasoning="Prior art must predate the patent filing"
),
LLMPreFilter(
name="check_email_discussion",
prompt="Is this an email or discussion about patents? {{ input.document_text }} If yes; set 'keep' to true. If no; set 'keep' to false.",
reasoning="Internal emails about patents might be relevant"
),
LLMPreFilter(
name="check_technical_specs",
prompt="Does this contain technical specifications? {{ input.document_text }} If yes; set 'keep' to true. If no; set 'keep' to false.",
reasoning="Technical specs might contradict our claims"
)
],
analysis_summary="Filter has 6 OR criteria. Pre-filters check each criterion cheaply, eliminating 70% of documents before expensive analysis"
)
# Example 2: Complex Reasoning - Misinformation Detection Filter
Target Operation:
- name: filter_misinformation
type: filter
model: gpt-4o
prompt: |
Analyze this social media post to determine if it contains health misinformation.
Consider:
- Claims that contradict scientific consensus
- Misleading statistics or cherry-picked data
- Conspiracy theories about health organizations
- Dangerous medical advice without proper qualifications
- Manipulation of legitimate studies to support false conclusions
- Emotional manipulation to spread fear about vaccines/treatments
Requires nuanced understanding of:
- Context and speaker credibility
- Difference between opinion and factual claims
- Satirical vs serious content
- Cultural/religious beliefs vs dangerous misinformation
Post: {{ input.post_content }}
Author Profile: {{ input.author_info }}
Engagement Metrics: {{ input.engagement_stats }}
output:
schema:
is_misinformation: boolean
Example InstantiateSchema (agent deduces proxy filters for complex reasoning):
CascadeFilteringInstantiateSchema(
code_pre_filters=[
CodePreFilter(
name="has_health_content",
code=\"\"\"def transform(input_doc):
text = input_doc.get('post_content', '').lower()
# Must mention health/medical topics to potentially be health misinfo
health_terms = ['vaccine', 'covid', 'cancer', 'cure', 'treatment', 'doctor',
'medical', 'health', 'disease', 'virus', 'immune', 'pharmaceutical']
return any(term in text for term in health_terms)\"\"\",
reasoning="Posts without health-related content cannot be health misinformation"
),
CodePreFilter(
name="has_claims_language",
code=\"\"\"def transform(input_doc):
text = input_doc.get('post_content', '').lower()
# Look for language that makes claims rather than just sharing experience
claim_markers = ['proven', 'studies show', 'research', 'scientists', 'they don\\'t want',
'truth about', 'actually', 'fact', 'evidence', 'causes', 'prevents']
return any(marker in text for marker in claim_markers) or '!' in text\"\"\",
reasoning="Posts without claim-making language are less likely to be misinformation"
)
],
llm_pre_filters=[
LLMPreFilter(
name="check_medical_claims",
prompt="Does this post make medical or health claims? {{ input.post_content }} If yes; set 'keep' to true. If no; set 'keep' to false.",
reasoning="Posts making medical claims need detailed analysis"
),
LLMPreFilter(
name="check_high_engagement",
prompt="Is this a high-engagement post (>1000 shares or viral)? {{ input.engagement_stats }} If yes; set 'keep' to true. If no; set 'keep' to false.",
reasoning="High-engagement posts have more potential for harm if they contain misinformation"
)
],
analysis_summary="Complex task requiring reasoning about credibility, context, and scientific accuracy. Pre-filters identify posts that definitely need review (30% of total) by checking for health content, claim-making language, and viral spread potential"
)
"""
)
test_cases: List[DirectiveTestCase] = Field(
default_factory=lambda: [
DirectiveTestCase(
name="single_filter_cascade",
description="Should create cascade of filters before expensive filter",
input_config={
"name": "filter_quality",
"type": "filter",
"model": "gpt-4o",
"prompt": "Is this a high-quality research paper? Paper: {{ input.paper }}",
"output": {"schema": {"is_quality": "boolean"}},
},
target_ops=["filter_quality"],
expected_behavior="Should inject cheaper pre-filters (code and/or gpt-5-nano) before the expensive gpt-4o filter",
should_pass=True,
)
]
)
def __eq__(self, other):
return isinstance(other, CascadeFilteringDirective)
def __hash__(self):
return hash("CascadeFilteringDirective")
def _extract_input_keys(self, prompt: str) -> List[str]:
"""Extract input field names from a Jinja2 prompt template."""
# Match {{ input.fieldname }} patterns
pattern = r"\{\{\s*input\.(\w+)\s*\}\}"
matches = re.findall(pattern, prompt)
return list(set(matches)) # Return unique field names
def to_string_for_instantiate(
self, target_ops_configs: List[Dict], pipeline_code: Dict = None
) -> str:
"""Generate prompt for agent to analyze data and create cascade filters."""
assert (
len(target_ops_configs) == 1
), "CascadeFiltering directive only supports single target operation"
assert (
target_ops_configs[0]["type"] == "filter"
), "Target operation must be a filter"
op = target_ops_configs[0]
input_keys = self._extract_input_keys(op.get("prompt", ""))
pipeline_context = ""
if pipeline_code:
pipeline_context = f"""
Pipeline Context:
{json.dumps(pipeline_code, indent=2)}
The target filter '{op['name']}' fits into this broader pipeline. Consider what types of documents flow into this filter and what the downstream operations expect.
"""
return (
f"You are an expert at optimizing filter operations by creating efficient filter cascades.\n\n"
f"Target Filter Operation:\n"
f"{json.dumps(op, indent=2)}\n\n"
f"Input fields used in the original prompt: {input_keys}\n\n"
f"{pipeline_context}\n"
f"Your task is to analyze sample input data and create a cascade of cheaper pre-filters that can eliminate many documents before they reach the expensive main filter.\n\n"
f"You will be given access to sample input data through a read_next_docs() function. Use this to:\n"
f"1. Understand what documents should pass vs. fail the main filter\n"
f"2. Identify simple patterns that can predict filter outcomes\n"
f"3. Design code-based filters for deterministic patterns (cheapest)\n"
f"4. Design gpt-5-nano filters for simple semantic patterns (still cheap)\n\n"
f"Guidelines for code_pre_filters:\n"
f"- Must be Python functions with signature: def transform(input_doc): return bool\n"
f"- Should use regex, keyword matching, length checks, or simple logic\n"
f"- Must be deterministic and fast\n"
f"- Return True to keep the document, False to filter it out\n"
f"- Access document fields using input_doc.get('fieldname', default_value)\n\n"
f"Guidelines for llm_pre_filters:\n"
f"- Must be Jinja2 templates that reference input fields using {{{{ input.fieldname }}}} syntax\n"
f"- Must reference the same input fields as the original prompt\n"
f"- Should be simple, focused prompts that elicit a yes/no or true/false response\n"
f"- Keep prompts SHORT - they will be ordered by length automatically\n"
f"- Must use at least one of these input fields: {input_keys}\n\n"
f"General Guidelines:\n"
f"- Pre-filters MUST have HIGH RECALL (rarely reject documents that would pass the main filter)\n"
f"- Pre-filters can have LOWER PRECISION (okay to let through documents that fail - the main filter will catch them)\n"
f"- When in doubt, let the document through (return True) - false negatives are worse than false positives\n\n"
f"- All documents will have the same keys. So don't write code that checks if a particular key exists in a document or not.\n"
f"Example transformation:\n"
f"{self.example}\n\n"
f"Analyze samples strategically - look for patterns that distinguish documents that pass vs fail the filter.\n"
f"When you have identified enough patterns to create effective pre-filters, output your result.\n\n"
f"Remember: The goal is to reduce costs by filtering out documents early with cheaper methods while maintaining the same final accuracy."
)
def llm_instantiate(
self,
target_ops_configs: List[Dict],
input_file_path: str,
agent_llm: str,
message_history: list = [],
pipeline_code: Dict = None,
):
"""Use agentic approach to analyze data and generate cascade filters."""
# Load sample input data
try:
with open(input_file_path, "r") as f:
input_data = json.load(f)
if not isinstance(input_data, list) or len(input_data) == 0:
raise ValueError(
"Input file must contain a non-empty list of sample data"
)
except Exception as e:
raise Exception(
f"Failed to load input data from {input_file_path}: {str(e)}"
)
# Extract input keys from original prompt for validation
original_prompt = target_ops_configs[0].get("prompt", "")
expected_input_keys = self._extract_input_keys(original_prompt)
def validate_llm_prompts(schema_instance):
"""Validate that LLM prompts use correct input fields."""
for llm_filter in schema_instance.llm_pre_filters:
used_keys = set(
re.findall(r"\{\{\s*input\.(\w+)\s*\}\}", llm_filter.prompt)
)
invalid_keys = used_keys - set(expected_input_keys)
if invalid_keys:
raise ValueError(
f"LLM filter '{llm_filter.name}' uses invalid input fields: {invalid_keys}. "
f"Available fields from original prompt: {expected_input_keys}"
)
if not used_keys:
raise ValueError(
f"LLM filter '{llm_filter.name}' must reference at least one input field "
f"from the original prompt: {expected_input_keys}"
)
# Set up agentic runner with validation
runner = AgenticDirectiveRunner(
input_data=input_data,
agent_llm=agent_llm,
validation_func=validate_llm_prompts,
)
# Create system prompt
system_prompt = (
"You are an expert at optimizing data processing pipelines by creating efficient filter cascades. "
"You analyze data to identify patterns that allow for cheap pre-filtering before expensive operations. "
"Your goal is to reduce costs while maintaining accuracy by using progressively more expensive filters."
)
# Create initial user message
initial_message = self.to_string_for_instantiate(
target_ops_configs, pipeline_code
)
# Run the agentic loop
try:
schema, updated_message_history, call_cost = runner.run_agentic_loop(
system_prompt=system_prompt,
initial_user_message=initial_message,
response_schema=CascadeFilteringInstantiateSchema,
)
# Update message history
message_history.extend(updated_message_history)
return schema, message_history, call_cost
except Exception as e:
raise Exception(
f"Failed to instantiate cascade_filtering directive: {str(e)}"
)
def apply(
self,
global_default_model: str,
ops_list: List[Dict],
target_ops: List[str],
rewrite: CascadeFilteringInstantiateSchema,
) -> List[Dict]:
"""Apply the directive by injecting pre-filters before the target filter."""
new_ops_list = []
for op in ops_list:
if op["name"] in target_ops:
# First, add all code filters
for i, code_filter in enumerate(rewrite.code_pre_filters):
code_filter_op = {
"name": f"{code_filter.name}_{op['name']}",
"type": "code_filter",
"code": code_filter.code,
}
new_ops_list.append(code_filter_op)
# Sort LLM filters by prompt length (shortest first for cost efficiency)
sorted_llm_filters = sorted(
rewrite.llm_pre_filters, key=lambda f: len(f.prompt)
)
# Then, add all LLM filters (using gpt-5-nano)
for i, llm_filter in enumerate(sorted_llm_filters):
llm_filter_op = {
"name": f"{llm_filter.name}_{op['name']}",
"type": "filter",
"model": "gpt-5-nano",
"prompt": llm_filter.prompt,
"output": {"schema": {"keep": "boolean"}},
}
new_ops_list.append(llm_filter_op)
# Finally, add the original filter
new_ops_list.append(deepcopy(op))
else:
# Keep other operations as-is
new_ops_list.append(deepcopy(op))
return new_ops_list
def instantiate(
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
global_default_model: str = None,
**kwargs,
):
"""
Main method that orchestrates directive instantiation:
1. Use agentic approach to analyze data and generate cascade filters
2. Apply the transformation by injecting pre-filters
"""
assert (
len(target_ops) == 1
), "CascadeFiltering directive requires exactly one target operation"
input_file_path = kwargs.get("input_file_path", None)
pipeline_code = kwargs.get("pipeline_code", None)
if not input_file_path:
raise ValueError(
"input_file_path is required for CascadeFiltering directive"
)
# Get configuration for target operation
target_ops_configs = [op for op in operators if op["name"] in target_ops]
if not target_ops_configs:
raise ValueError(f"Target operation {target_ops[0]} not found in operators")
if target_ops_configs[0]["type"] != "filter":
raise ValueError(
f"Target operation {target_ops[0]} must be a filter operation"
)
# Step 1: Agent analyzes data and generates cascade filters
rewrite, message_history, call_cost = self.llm_instantiate(
target_ops_configs,
input_file_path,
agent_llm,
message_history,
pipeline_code,
)
# Step 2: Apply transformation
return (
self.apply(global_default_model, operators, target_ops, rewrite),
message_history,
call_cost,
)
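Note: to make the resulting pipeline shape concrete, here is a standalone sketch of the injection step in apply(), using plain dicts for the pre-filters; the filter names and prompts below are invented for illustration:

from copy import deepcopy

def inject_cascade(ops_list, target_name, code_pre_filters, llm_pre_filters):
    new_ops = []
    for op in ops_list:
        if op["name"] == target_name:
            for cf in code_pre_filters:
                new_ops.append({"name": f"{cf['name']}_{op['name']}",
                                "type": "code_filter", "code": cf["code"]})
            # Shorter prompts first, so the cheapest LLM checks run earliest.
            for lf in sorted(llm_pre_filters, key=lambda f: len(f["prompt"])):
                new_ops.append({"name": f"{lf['name']}_{op['name']}",
                                "type": "filter", "model": "gpt-5-nano",
                                "prompt": lf["prompt"],
                                "output": {"schema": {"keep": "boolean"}}})
        new_ops.append(deepcopy(op))
    return new_ops

pipeline = [{"name": "filter_quality", "type": "filter", "model": "gpt-4o",
             "prompt": "Is this a high-quality paper? {{ input.paper }}",
             "output": {"schema": {"is_quality": "boolean"}}}]
cascade = inject_cascade(
    pipeline, "filter_quality",
    code_pre_filters=[{"name": "has_abstract",
                       "code": "def transform(input_doc):\n    return 'abstract' in input_doc.get('paper', '').lower()"}],
    llm_pre_filters=[{"name": "looks_like_research",
                      "prompt": "Does this look like a research paper? {{ input.paper }}"}])
print([op["name"] for op in cascade])
# ['has_abstract_filter_quality', 'looks_like_research_filter_quality', 'filter_quality']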

View File

@@ -0,0 +1,313 @@
import json
import re
from copy import deepcopy
from typing import Dict, List, Type
from litellm import completion
from pydantic import BaseModel, Field
from docetl.reasoning_optimizer.instantiate_schemas import ChainingInstantiateSchema
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
class ChainingDirective(Directive):
name: str = Field(default="chaining", description="The name of the directive")
formal_description: str = Field(default="Op => Map* -> Op")
nl_description: str = Field(
default="Decompose a complex operation into a sequence by inserting one or more Map steps that rewrite the input for the next operation. Each Map step outputs a 'result' string, and the downstream operation uses this result in its prompt."
)
when_to_use: str = Field(
default="When the original task is too complex for one step and should be split into a series (e.g., first extract key facts, then generate a summary based on those facts)."
)
instantiate_schema_type: Type[BaseModel] = ChainingInstantiateSchema
example: str = Field(
default=(
"Original Op (MapOpConfig):\n"
"- name: extract_newly_prescribed_treatments\n"
" type: map\n"
" prompt: |\n"
" For a hospital discharge summary, extract every treatment that was prescribed specifically for newly diagnosed conditions.\n"
' Discharge summary: "{{ input.summary }}"\n'
" output:\n"
" schema:\n"
" treatments: list[str]\n"
"\n"
"Example InstantiateSchema (must refer to the same input document keys as the original Op, and subsequent Map operators must refer to the output of the previous Map operator):\n"
"[\n"
" MapOpConfig(\n"
" name='identify_new_conditions',\n"
" prompt='''Review the following hospital discharge summary:\n"
"{{ input.summary }}\n"
"Identify all medical conditions that are explicitly marked as new diagnoses (e.g., 'new diagnosis of atrial fibrillation', 'recent onset heart failure').\n"
"Return a list of newly diagnosed conditions.''',\n"
" output_keys=['new_conditions'],\n"
" ),\n"
" MapOpConfig(\n"
" name='extract_treatments_for_new_conditions',\n"
" prompt='''For each newly diagnosed condition listed below, extract every treatment or medication prescribed for that specific condition from the discharge summary.\n"
"Discharge summary: {{ input.summary }}\n"
"Newly diagnosed conditions: {{ input.new_conditions }}\n"
"Return a list of prescribed treatments or medications for each condition.''',\n"
" output_keys=['treatments'],\n"
" ),\n"
"]"
),
)
test_cases: List[DirectiveTestCase] = Field(
default_factory=lambda: [
DirectiveTestCase(
name="complex_contract_extraction",
description="Should decompose complex contract extraction into separate ops for each term type and a final unification op",
input_config={
"name": "extract_contract_terms",
"type": "map",
"prompt": "Extract all payment terms, liability clauses, and termination conditions from: {{ input.contract }}",
"output": {"schema": {"terms": "list[str]"}},
},
target_ops=["extract_contract_terms"],
expected_behavior="Should create one op for each of: payment terms, liability clauses, and termination conditions, then a final op to unify all results into 'terms'",
should_pass=True,
),
DirectiveTestCase(
name="medical_treatment_analysis",
description="Should chain complex medical analysis into steps",
input_config={
"name": "analyze_patient_care",
"type": "map",
"prompt": "From this medical record, identify all diagnoses, treatments prescribed, and patient outcomes: {{ input.medical_record }}",
"output": {"schema": {"analysis": "string"}},
},
target_ops=["analyze_patient_care"],
expected_behavior="Should decompose into separate steps for diagnoses identification, treatment extraction, and outcome analysis",
should_pass=True,
),
DirectiveTestCase(
name="research_paper_analysis",
description="Should chain research paper analysis into structured steps",
input_config={
"name": "analyze_research_paper",
"type": "map",
"prompt": "From this research paper, extract the methodology, key findings, limitations, and future work directions: {{ input.paper }}",
"output": {"schema": {"analysis": "string"}},
},
target_ops=["analyze_research_paper"],
expected_behavior="Should decompose into separate steps for methodology extraction, findings identification, limitation analysis, and future work extraction",
should_pass=True,
),
]
)
def __eq__(self, other):
return isinstance(other, ChainingDirective)
def __hash__(self):
return hash("ChainingDirective")
def to_string_for_instantiate(self, original_op: Dict) -> str:
"""
Generate a prompt for an agent to instantiate this directive.
Args:
original_op (Dict): The original operation configuration.
Returns:
str: The agent prompt for instantiating the directive.
"""
return (
f"You are an expert at decomposing complex data processing operations into modular steps.\n\n"
f"Original Operation:\n"
f"{str(original_op)}\n"
f"Directive: {self.name}\n"
f"Your task is to instantiate this directive by generating a list of new Map operators (as MapOpConfig objects) that decompose the original operation into a sequence of simpler steps. "
f"Each Map step should output a 'result' or other relevant keys, and downstream steps should use the outputs of previous steps as their input. "
f"Ensure that the chain of Map operators together accomplishes the intent of the original operation, but in a more modular and stepwise fashion.\n\n"
f"""Key Issues to ensure:\n
1. Ensure the new prompts don't reference categories, lists, or criteria without providing them or making them available from previous steps.\n
2. Every detail, category, instruction, requirement, and definition from the original must be present in the new configuration.\n
3. Confirm that each step can access all required information either from the original document or from outputs of preceding steps.\n"""
f"Example:\n"
f"{self.example}\n\n"
f"Please output only the InstantiateSchema (a list of MapOpConfig objects) for the new chain, referring to the same input document keys as the original operation and chaining outputs appropriately."
)
def llm_instantiate(
self,
original_op: Dict,
expected_input_keys: List[str],
expected_output_keys: List[str],
agent_llm: str,
message_history: list = [],
):
"""
Use LLM to instantiate this directive by decomposing the original operation.
Args:
original_op (Dict): The original operation.
expected_input_keys (List[str]): A list of input keys that the operation is expected to reference in its prompt. Each key should correspond to a field in the input document that must be used by the operator.
expected_output_keys (List[str]): A list of output keys that the last operation is expected to produce.
agent_llm (str): The LLM model to use.
message_history (List, optional): Conversation history for context.
Returns:
ChainingInstantiateSchema: The structured output from the LLM.
"""
message_history.extend(
[
{
"role": "system",
"content": "You are a helpful AI assistant for document processing pipelines.",
},
{
"role": "user",
"content": self.to_string_for_instantiate(original_op),
},
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
response_format=ChainingInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
if "new_ops" not in parsed_res:
raise ValueError(
"Response from LLM is missing required key 'new_ops'"
)
new_ops = parsed_res["new_ops"]
schema = ChainingInstantiateSchema(new_ops=new_ops)
# Validate the chain with required input/output keys
ChainingInstantiateSchema.validate_chain(
new_ops=schema.new_ops,
required_input_keys=expected_input_keys,
expected_output_keys=expected_output_keys,
)
message_history.append(
{"role": "assistant", "content": resp.choices[0].message.content}
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(
self,
global_default_model,
ops_list: List[Dict],
target_op: str,
rewrite: ChainingInstantiateSchema,
) -> List[Dict]:
"""
Apply the directive to the pipeline config.
"""
# Create a copy of the pipeline config
new_ops_list = deepcopy(ops_list)
# Find position of the target ops to replace
for i, op in enumerate(ops_list):
if op["name"] == target_op:
pos_to_replace = i
orig_op = op
break
# pos_to_replace = [i for i, op in enumerate(ops_list) if op["name"] == target_op][0]
# Create the new ops from the rewrite
new_ops = []
default_model = global_default_model
if "model" in orig_op:
defualt_model = orig_op["model"]
for i, op in enumerate(rewrite.new_ops):
if i < len(rewrite.new_ops) - 1:
new_ops.append(
{
"name": op.name,
"type": "map",
"prompt": op.prompt,
"model": defualt_model,
"litellm_completion_kwargs": {"temperature": 0},
"output": {"schema": {key: "string" for key in op.output_keys}},
}
)
else:
# Last op in the chain
new_ops.append(
{
"name": op.name,
"type": "map",
"prompt": op.prompt,
"model": defualt_model,
"litellm_completion_kwargs": {"temperature": 0},
"output": new_ops_list[pos_to_replace]["output"],
}
)
# Remove the target op and insert the new ops
new_ops_list.pop(pos_to_replace)
new_ops_list[pos_to_replace:pos_to_replace] = new_ops
return new_ops_list
def instantiate(
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
optimize_goal="acc",
global_default_model: str = None,
**kwargs,
):
"""
Instantiate the directive for a list of operators.
"""
# Assert that there is only one target op
assert (
len(target_ops) == 1
), "There must be exactly one target op to instantiate this chaining directive"
target_op_config = [op for op in operators if op["name"] == target_ops[0]][0]
# Get the expected input/output keys
expected_output_keys = list(target_op_config["output"]["schema"].keys())
# Extract expected input keys from the target op's prompt template
prompt_template = target_op_config["prompt"]
# Find all occurrences of {{ input.key }} in the prompt
input_key_pattern = r"\{\{\s*input\.([^\}\s]+)\s*\}\}"
expected_input_keys = list(set(re.findall(input_key_pattern, prompt_template)))
print("input key: ", expected_input_keys)
print("output key: ", expected_output_keys)
# Instantiate the directive
rewrite, message_history, call_cost = self.llm_instantiate(
target_op_config,
expected_input_keys,
expected_output_keys,
agent_llm,
message_history,
)
# Apply the rewrite to the operators
new_ops_plan = self.apply(
global_default_model, operators, target_ops[0], rewrite
)
return new_ops_plan, message_history, call_cost
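Note: a standalone sketch of what apply() produces for a chaining rewrite, using plain dicts instead of MapOpConfig objects; only the last map keeps the original output schema, and the step names and prompts are illustrative:

from copy import deepcopy

def apply_chain(ops_list, target_name, new_steps, default_model="gpt-4o-mini"):
    ops = deepcopy(ops_list)
    pos = next(i for i, op in enumerate(ops) if op["name"] == target_name)
    original = ops.pop(pos)
    model = original.get("model", default_model)
    chain = []
    for i, step in enumerate(new_steps):
        # Intermediate steps get string outputs; the last step inherits the original schema.
        out = ({"schema": {k: "string" for k in step["output_keys"]}}
               if i < len(new_steps) - 1 else original["output"])
        chain.append({"name": step["name"], "type": "map", "prompt": step["prompt"],
                      "model": model, "output": out})
    ops[pos:pos] = chain
    return ops

pipeline = [{"name": "extract_treatments", "type": "map",
             "prompt": "Extract treatments for new diagnoses: {{ input.summary }}",
             "output": {"schema": {"treatments": "list[str]"}}}]
steps = [
    {"name": "identify_new_conditions",
     "prompt": "List new diagnoses in {{ input.summary }}",
     "output_keys": ["new_conditions"]},
    {"name": "extract_treatments_for_new_conditions",
     "prompt": "Extract treatments for {{ input.new_conditions }} from {{ input.summary }}",
     "output_keys": ["treatments"]},
]
print([op["name"] for op in apply_chain(pipeline, "extract_treatments", steps)])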

View File

@@ -0,0 +1,311 @@
import json
from copy import deepcopy
from typing import Dict, List, Type
from litellm import completion
from pydantic import BaseModel, Field
from docetl.reasoning_optimizer.instantiate_schemas import ChangeModelInstantiateSchema
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
class ChangeModelDirective(Directive):
name: str = Field(default="change model", description="The name of the directive")
formal_description: str = Field(
default="Op => Op* (same operation with a different model choice)"
)
nl_description: str = Field(
default="Rewrites an operator to use a different LLM model based on task requirements. Generally, simpler tasks like extraction or classification may work well with cheaper models (gpt-4o-mini, gpt-5-nano), while complex reasoning tasks often benefit from more powerful models (gpt-5), though actual performance and constraints should guide the choice."
)
when_to_use: str = Field(
default="When the current model choice may not be optimal for the task requirements, considering factors like task complexity, performance needs, cost constraints, and quality requirements."
)
instantiate_schema_type: Type[BaseModel] = ChangeModelInstantiateSchema
example: str = Field(
default=(
"Original Op (MapOpConfig):\n"
"- name: extract_insights\n"
" type: map\n"
" prompt: |\n"
" From the user log below, list 2-3 concise insights (1-2 words each) and 1-2 supporting actions per insight.\n"
" Return as a list of dictionaries with 'insight' and 'supporting_actions'.\n"
" Log: {{ input.log }}\n"
" output:\n"
" schema:\n"
' insights_summary: "string"\n'
" model: gpt-5\n"
"\n"
"Example InstantiateSchema:\n"
"{\n"
' "model": "gpt-4o-mini"\n'
"}"
),
)
allowed_model_list: List[str] = Field(
default_factory=list,
description="The allowed list of models to choose from",
)
model_info: str = Field(
default=(
"""
OpenAI-MRCR evaluates a model's ability to locate and disambiguate multiple well-hidden "needles" within a large context.
Below are the actual performance scores for the 8-needle retrieval task at various context lengths. Use these results to compare the retrieval capabilities of each model.
The below results are the mean match ratio.
Input Tokens (1000s) | GPT-5 | GPT-5 nano | GPT-4o mini
---------------------|---------|--------------|-------------------|
8 | 99% | 69% | 32% |
16 | 100% | 64% | 30% |
32 | 96% | 55% | 27% |
64 | 98% | 45% | 25% |
128 | 97% | 44% | 25% |
256 | 92% | 40% | - |
The context window and pricing details for each model are shown below (token prices are per 1 million tokens):
Model | GPT-5-nano | GPT-4o-mini | GPT-5
-------------------|--------------|-------------|----------
Context Window | 400,000 | 128,000 | 400,000
Max Output Tokens | 128,000 | 16,384 | 128,000
Input Token Price | $0.05 | $0.15 | $1.25
Output Token Price | $0.40 | $0.60 | $10
"""
),
)
test_cases: List[DirectiveTestCase] = Field(
default_factory=lambda: [
DirectiveTestCase(
name="cost_optimization_simple_task",
description="Should suggest cheaper model for simple extraction task",
input_config={
"name": "extract_names",
"type": "map",
"prompt": "Extract person names from: {{ input.text }}",
"output": {"schema": {"names": "list[str]"}},
"model": "gpt-5",
},
target_ops=["extract_names"],
expected_behavior="Should recommend gpt-4o-mini or gpt-5-nano for cost savings on simple task",
should_pass=True,
),
DirectiveTestCase(
name="complex_analysis_needs_powerful_model",
description="Should recommend powerful model for complex reasoning task",
input_config={
"name": "analyze_legal_implications",
"type": "map",
"prompt": "Analyze the legal implications and potential risks in this complex contract: {{ input.contract }}",
"output": {"schema": {"analysis": "string"}},
"model": "gpt-4o-mini",
},
target_ops=["analyze_legal_implications"],
expected_behavior="Should recommend gpt-5 for complex legal analysis requiring strong reasoning",
should_pass=True,
),
DirectiveTestCase(
name="high_volume_processing_optimization",
description="Should optimize model for high-volume document processing",
input_config={
"name": "batch_document_summary",
"type": "map",
"prompt": "Summarize this document in 2-3 sentences: {{ input.document }}",
"output": {"schema": {"summary": "str"}},
"model": "gpt-5",
},
target_ops=["batch_document_summary"],
expected_behavior="Should recommend faster/cheaper model like gpt-4o-mini or gpt-5-nano for high-volume simple summarization",
should_pass=True,
),
]
)
def __eq__(self, other):
return isinstance(other, ChangeModelDirective)
def __hash__(self):
return hash("ChangeModelDirective")
def to_string_for_instantiate(self, original_op: Dict, optimize_goal) -> str:
"""
Generate a prompt for an agent to instantiate this directive.
Args:
original_op (Dict): The original operation configuration.
optimize_goal (str): "acc" to optimize for accuracy; any other value optimizes for cost.
Returns:
str: The agent prompt for instantiating the directive.
"""
if optimize_goal == "acc":
return (
f"You are an expert at choosing the most suitable model for a given task based on complexity and cost considerations.\n\n"
f"Original Operation:\n"
f"{str(original_op)}\n\n"
f"Directive: {self.name}\n"
f"Your task is to instantiate this directive by suggesting a better model for executing the original operation.\n\n"
f"MODEL SELECTION CONSIDERATIONS:\n"
f"• Generally, simpler tasks (extraction, basic classification, straightforward summarization) may work well with cheaper models like gpt-4o-mini or gpt-5-nano\n"
f"• Complex reasoning tasks (analysis, interpretation, multi-step thinking, legal/medical analysis) often benefit from more powerful models like gpt-5\n"
f"• However, consider actual performance needs, quality requirements, and cost constraints when making the choice\n"
f"• Sometimes a powerful model may be needed for seemingly simple tasks if quality is critical, or a cheaper model may suffice for complex tasks if budget is constrained\n\n"
f"You have a list of allowed models to choose from: {str(self.allowed_model_list)}.\n\n"
f"Consider the information about the allowed models: \n {self.model_info}\n"
f"Your response should include the new model choice for the operation."
f"Ensure that your chosen model is in the list of allowed models."
f"Example:\n"
f"{self.example}\n\n"
f"Please output only the ChangeModelInstantiateSchema as JSON."
)
else:
return (
f"You are an expert at choosing the most suitable model for a given task to optimize cost.\n\n"
f"Original Operation:\n"
f"{str(original_op)}\n\n"
f"Directive: {self.name}\n"
f"Your task is to instantiate this directive by generating a ChangeModelConfig that suggests a more cost-effective model for executing the original operation."
f"You have a list of allowed models to choose from: {str(self.allowed_model_list)}.\n\n"
f"Consider the information about the allowed models: \n {self.model_info}\n"
f"The ChangeModelConfig should include the new model choice for the operation that reduces costs while maintaining adequate performance."
f"Ensure that your chosen model is in the list of allowed models."
f"Example:\n"
f"{self.example}\n\n"
f"Please output only the InstantiateSchema (a ChangeModelConfig object)."
)
def llm_instantiate(
self,
global_default_model: str,
original_op: Dict,
agent_llm: str,
message_history: list = [],
optimize_goal="acc",
):
"""
Use LLM to instantiate this directive.
Args:
original_op (Dict): The original operation.
agent_llm (str): The LLM model to use.
message_history (List, optional): Conversation history for context.
Returns:
ChangeModelInstantiateSchema: The structured output from the LLM.
"""
message_history.extend(
[
{
"role": "system",
"content": "You are a helpful AI assistant for document processing pipelines.",
},
{
"role": "user",
"content": self.to_string_for_instantiate(
original_op, optimize_goal
),
},
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
response_format=ChangeModelInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = ChangeModelInstantiateSchema(**parsed_res)
orig_model = global_default_model
if "model" in original_op:
orig_model = original_op.get("model")
# Validate the model is in the allowed model list
ChangeModelInstantiateSchema.validate_diff_model_in_list(
orig_model=orig_model,
model=schema.model,
list_of_model=self.allowed_model_list,
)
message_history.append(
{"role": "assistant", "content": resp.choices[0].message.content}
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(
self,
global_default_model: str,
ops_list: List[Dict],
target_op: str,
rewrite: ChangeModelInstantiateSchema,
) -> List[Dict]:
"""
Apply the directive to the pipeline config by adding gleaning configuration to the target operator.
"""
# Create a copy of the pipeline config
new_ops_list = deepcopy(ops_list)
# Find position of the target op to modify
pos_to_replace = [
i for i, op in enumerate(ops_list) if op["name"] == target_op
][0]
# Add change model configuration to the target operator
target_operator = new_ops_list[pos_to_replace]
target_operator["model"] = rewrite.model
return new_ops_list
def instantiate(
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
optimize_goal="acc",
global_default_model: str = None,
allowed_model_list: List[str] = None,
**kwargs,
):
"""
Instantiate the directive for a list of operators.
"""
# Update allowed_model_list if provided
if allowed_model_list is not None:
self.allowed_model_list = allowed_model_list
new_ops_list = deepcopy(operators)
inst_error = 0
for target_op in target_ops:
target_op_config = [op for op in operators if op["name"] == target_op][0]
# Instantiate the directive
try:
rewrite, message_history, call_cost = self.llm_instantiate(
global_default_model,
target_op_config,
agent_llm,
message_history,
optimize_goal=optimize_goal,
)
print(rewrite)
except Exception:
inst_error += 1
continue  # skip apply when instantiation failed; rewrite would be stale or unbound
new_ops_list = self.apply(
global_default_model, new_ops_list, target_op, rewrite
)
if inst_error == len(target_ops):
print("CHANEG MODEL ERROR")
return None, message_history
return new_ops_list, message_history, call_cost
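Note: the apply() step here is a simple in-place model swap. A minimal sketch with plain dicts, for illustration only:

from copy import deepcopy

def change_model(ops_list, target_name, new_model):
    ops = deepcopy(ops_list)
    for op in ops:
        if op["name"] == target_name:
            op["model"] = new_model
    return ops

pipeline = [{"name": "extract_names", "type": "map", "model": "gpt-5",
             "prompt": "Extract person names from: {{ input.text }}",
             "output": {"schema": {"names": "list[str]"}}}]
print(change_model(pipeline, "extract_names", "gpt-4o-mini")[0]["model"])  # gpt-4o-mini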

View File

@@ -0,0 +1,303 @@
import json
from copy import deepcopy
from typing import Dict, List, Type
from litellm import completion
from pydantic import BaseModel, Field
from docetl.reasoning_optimizer.instantiate_schemas import ChangeModelInstantiateSchema
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
class ChangeModelAccDirective(Directive):
name: str = Field(
default="change model acc", description="The name of the directive"
)
formal_description: str = Field(
default="Op => Op* (same operation with a more accurate model choice)"
)
nl_description: str = Field(
default="Rewrites an operator to use a more powerful LLM model to optimize accuracy. Prioritizes model performance and quality over cost considerations, typically suggesting more capable models like gpt-5 for complex reasoning tasks."
)
when_to_use: str = Field(
default="When accuracy and quality are the primary concerns, and cost is secondary. Suitable for complex reasoning tasks, critical analysis, or when maximum model performance is needed. Usually changing to a more expensive model will improve accuracy, so you should try this directive if it has not been used in the past iterations."
)
instantiate_schema_type: Type[BaseModel] = ChangeModelInstantiateSchema
example: str = Field(
default=(
"Original Op (MapOpConfig):\n"
"- name: analyze_complex_data\n"
" type: map\n"
" prompt: |\n"
" Analyze this complex financial data and provide detailed insights with risk assessment.\n"
" Data: {{ input.financial_data }}\n"
" output:\n"
" schema:\n"
' analysis: "string"\n'
" model: gpt-4o-mini\n"
"\n"
"Example InstantiateSchema:\n"
"{\n"
' "model": "gpt-5"\n'
"}"
),
)
allowed_model_list: List[str] = Field(
default_factory=list,
description="The allowed list of models to choose from",
)
model_info: str = Field(
default=(
"""
OpenAI-MRCR evaluates a model's ability to locate and disambiguate multiple well-hidden "needles" within a large context.
Below are the actual performance scores for the 2-needle retrieval task at various context lengths. Use these results to compare the retrieval capabilities of each model.
The below results are the mean match ratio.
Context Length | gpt-5 | gpt-5-nano | gpt-4o-mini | gemini-2.5-pro | gemini-2.5-flash | gpt-4.1 | gpt-4.1-mini | gpt-4.1-nano | gemini-2.5-flash-lite
---------------|-------|------------|-------------|----------------|------------------|---------|--------------|--------------|----------------------
128k | 97% | 44% | 25% | 83.6% | 86.2% | 61.3% | 47.1% | 36.7% | 39.9%
1M | - | - | - | 62.8% | 60.0% | 45.9% | 34.6% | 14.2% | 18.1%
The context window and pricing details for each model are shown below (token prices are per 1 million tokens):
| Family | Model | Input Price /1M | Output Price /1M | Context Window (API) |
|----------------|----------------------------|---------------------------------|-----------------------------------|-----------------------------|
| **GPT-5** | azure/gpt-5 | $1.25 | $10.00 | 400K (272K in + 128K out) |
| | azure/gpt-5-mini | $0.25 | $2.00 | 400K |
| | azure/gpt-5-nano | $0.05 | $0.40 | 400K |
| **GPT-4.1** | azure/gpt-4.1 | $2.00 | $8.00 | 1M |
| | azure/gpt-4.1-mini | $0.40 | $1.60 | 1M |
| | azure/gpt-4.1-nano | $0.10 | $0.40 | 1M |
| **GPT-4o** | azure/gpt-4o | $2.50 | $10.00 | 128K |
| | azure/gpt-4o-mini | $0.15 | $0.60 | 128K (16K output cap) |
| **Gemini 2.5** | gemini/gemini-2.5-pro | $1.25 (200K) / $2.50 (>200K) | $10.00 (200K) / $15.00 (>200K) | 1M (2M soon) |
| | gemini/gemini-2.5-flash | $0.30 | $2.50 | 1M |
| | gemini/gemini-2.5-flash-lite | $0.10 | $0.40 | 1M |
"""
),
)
test_cases: List[DirectiveTestCase] = Field(
default_factory=lambda: [
DirectiveTestCase(
name="complex_reasoning_needs_powerful_model",
description="Should recommend most powerful model for complex reasoning task",
input_config={
"name": "analyze_legal_implications",
"type": "map",
"prompt": "Analyze the legal implications and potential risks in this complex contract: {{ input.contract }}",
"output": {"schema": {"analysis": "string"}},
"model": "gpt-4o-mini",
},
target_ops=["analyze_legal_implications"],
expected_behavior="Should recommend gpt-5 for complex legal analysis requiring strong reasoning",
should_pass=True,
),
DirectiveTestCase(
name="accuracy_over_cost_scientific_analysis",
description="Should prioritize accuracy for scientific analysis task",
input_config={
"name": "analyze_research_data",
"type": "map",
"prompt": "Perform detailed statistical analysis and provide research insights: {{ input.data }}",
"output": {"schema": {"insights": "string"}},
"model": "gpt-5-nano",
},
target_ops=["analyze_research_data"],
expected_behavior="Should recommend gpt-5 for accurate scientific analysis despite higher cost",
should_pass=True,
),
DirectiveTestCase(
name="critical_medical_analysis_high_accuracy",
description="Should recommend most accurate model for critical medical analysis",
input_config={
"name": "medical_diagnosis_support",
"type": "map",
"prompt": "Analyze medical symptoms and provide diagnostic insights: {{ input.symptoms }}",
"output": {"schema": {"diagnosis": "string"}},
"model": "gpt-4o-mini",
},
target_ops=["medical_diagnosis_support"],
expected_behavior="Should recommend gpt-5 for critical medical analysis requiring highest accuracy",
should_pass=True,
),
]
)
def __eq__(self, other):
return isinstance(other, ChangeModelAccDirective)
def __hash__(self):
return hash("ChangeModelAccDirective")
def to_string_for_instantiate(self, original_op: Dict) -> str:
"""
Generate a prompt for an agent to instantiate this directive for accuracy optimization.
Args:
            original_op (Dict): The original operation configuration.
Returns:
str: The agent prompt for instantiating the directive.
"""
return (
f"You are an expert at choosing the most accurate and powerful model for a given task, prioritizing quality over cost.\n\n"
f"Original Operation:\n"
f"{str(original_op)}\n\n"
f"Directive: {self.name}\n"
f"Your task is to instantiate this directive by suggesting the most accurate model for executing the original operation.\n\n"
f"TASK COMPLEXITY ANALYSIS AND MODEL SELECTION:\n"
f"First, carefully analyze the original operation to understand:\n"
f"• What specific task is being performed (extraction, analysis, transformation, reasoning, etc.)\n"
f"• How much cognitive complexity and intelligence is required\n"
f"• Whether the task involves simple pattern matching or sophisticated reasoning\n"
f"• If the task requires domain expertise, multi-step thinking, or nuanced understanding\n"
f"• The criticality and precision requirements of the output\n"
f"• If the task requires processing very long context (1M+ tokens)\n\n"
f"Based on your analysis of task complexity, select the model that will provide the most accurate response:\n"
f"• For simple extraction or formatting tasks: Consider efficient models from the available options\n"
f"• For moderate complexity tasks requiring some reasoning: Use capable models from gpt-5 or gemini series\n"
f"• For complex reasoning, analysis, interpretation, legal/medical tasks, or critical decisions: Strongly prefer the most advanced models from gpt-5 or gemini series\n"
f"• For tasks requiring very long context (1M+ tokens): Consider models with extended context like gpt-4.1 series or gemini models\n"
f"• For highly specialized or extremely complex cognitive tasks: Use the most powerful model available\n\n"
f"Remember: The goal is maximum accuracy given the intelligence and context requirements of the specific task.\n"
f"You have a list of allowed models to choose from: {str(self.allowed_model_list)}.\n\n"
f"Consider the information about the allowed models: \n {self.model_info}\n"
f"Your response should include the new model choice for the operation that maximizes accuracy given the task complexity and context requirements."
f"Ensure that your chosen model is in the list of allowed models."
f"Example:\n"
f"{self.example}\n\n"
f"Please output only the ChangeModelInstantiateSchema as JSON."
)
def llm_instantiate(
self,
global_default_model: str,
original_op: Dict,
agent_llm: str,
message_history: list = [],
):
"""
Use LLM to instantiate this directive for accuracy optimization.
Args:
original_op (Dict): The original operation.
agent_llm (str): The LLM model to use.
message_history (List, optional): Conversation history for context.
Returns:
            Tuple of (ChangeModelInstantiateSchema, updated message history, LLM call cost).
"""
message_history.extend(
[
{
"role": "system",
"content": "You are a helpful AI assistant for document processing pipelines focused on accuracy optimization.",
},
{
"role": "user",
"content": self.to_string_for_instantiate(original_op),
},
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
response_format=ChangeModelInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = ChangeModelInstantiateSchema(**parsed_res)
orig_model = global_default_model
if "model" in original_op:
orig_model = original_op.get("model")
# Validate the model is in the allowed model list
ChangeModelInstantiateSchema.validate_diff_model_in_list(
orig_model=orig_model,
model=schema.model,
list_of_model=self.allowed_model_list,
)
message_history.append(
{"role": "assistant", "content": resp.choices[0].message.content}
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(
self,
global_default_model: str,
ops_list: List[Dict],
target_op: str,
rewrite: ChangeModelInstantiateSchema,
) -> List[Dict]:
"""
Apply the directive to the pipeline config by changing the model of the target operator.
"""
# Create a copy of the pipeline config
new_ops_list = deepcopy(ops_list)
# Find position of the target op to modify
pos_to_replace = [
i for i, op in enumerate(ops_list) if op["name"] == target_op
][0]
# Add change model configuration to the target operator
target_operator = new_ops_list[pos_to_replace]
target_operator["model"] = rewrite.model
return new_ops_list
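    # Illustrative sketch of apply() in isolation (assumed standalone usage; the
    # MOAR optimizer normally drives this through instantiate()):
    #
    #     directive = ChangeModelAccDirective(
    #         allowed_model_list=["azure/gpt-5", "azure/gpt-4o-mini"]
    #     )
    #     ops = [{"name": "analyze_legal_implications", "type": "map",
    #             "prompt": "...", "model": "gpt-4o-mini"}]
    #     rewrite = ChangeModelInstantiateSchema(model="azure/gpt-5")
    #     new_ops = directive.apply("gpt-4o-mini", ops,
    #                               "analyze_legal_implications", rewrite)
    #     new_ops[0]["model"]  # -> "azure/gpt-5"; the input ops list is left untouched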
def instantiate(
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
global_default_model: str = None,
allowed_model_list: List[str] = None,
**kwargs,
):
"""
Instantiate the directive for a list of operators.
"""
# Update allowed_model_list if provided
if allowed_model_list is not None:
self.allowed_model_list = allowed_model_list
new_ops_list = deepcopy(operators)
inst_error = 0
for target_op in target_ops:
target_op_config = [op for op in operators if op["name"] == target_op][0]
# Instantiate the directive
try:
rewrite, message_history, call_cost = self.llm_instantiate(
global_default_model,
target_op_config,
agent_llm,
message_history,
)
print(rewrite)
            except Exception:
                inst_error += 1
                continue
new_ops_list = self.apply(
global_default_model, new_ops_list, target_op, rewrite
)
if inst_error == len(target_ops):
print("CHANGE MODEL ACC ERROR")
return None, message_history
return new_ops_list, message_history, call_cost


@@ -0,0 +1,395 @@
import json
from copy import deepcopy
from typing import Dict, List, Type
from litellm import completion
from pydantic import BaseModel, Field
from docetl.reasoning_optimizer.instantiate_schemas import ChangeModelInstantiateSchema
def get_cheaper_models(current_model: str, model_stats: Dict = None) -> List[str]:
"""
Get list of models that are cheaper than the current model.
Args:
current_model: The current model name (may include "azure/" or "gemini/" prefix)
model_stats: Dictionary of model statistics from MOARSearch
Returns:
List of model names that are cheaper than current_model, sorted by cost (cheapest first)
"""
if model_stats is None or not model_stats:
return []
current_model_stats = model_stats.get(current_model)
if current_model_stats is None:
return []
current_cost = current_model_stats.get("cost")
if current_cost is None:
return []
    cheaper_models = []
    for model, stats in model_stats.items():
        if not isinstance(stats, dict):
            continue
        model_cost = stats.get("cost")
        if model_cost is not None and model_cost < current_cost:
            cheaper_models.append((model_cost, model))
    # Sort by cost so the cheapest models come first, as documented above
    cheaper_models.sort(key=lambda pair: pair[0])
    return [model for _, model in cheaper_models]
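# Illustrative usage (a minimal sketch; the exact MOARSearch stats layout is an
# assumption -- this helper only relies on the per-model "cost" entry):
#
#     stats = {
#         "azure/gpt-5": {"cost": 1.25},
#         "azure/gpt-4o-mini": {"cost": 0.15},
#         "azure/gpt-5-nano": {"cost": 0.05},
#     }
#     get_cheaper_models("azure/gpt-5", stats)
#     # -> ["azure/gpt-5-nano", "azure/gpt-4o-mini"]  (cheapest first)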
from .base import ( # noqa: E402
MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS,
Directive,
DirectiveTestCase,
)
def create_model_specific_directives(
current_model: str, allowed_model_list: List[str] = None
):
"""Create model-specific directives for the current model."""
directive = ChangeModelCostDirective(target_model=current_model)
directive.name = f"change to {current_model}"
directive.nl_description = f"Rewrites an operator to use the {current_model} model to optimize expenses while maintaining adequate performance."
if allowed_model_list is not None:
directive.allowed_model_list = allowed_model_list
else:
directive.allowed_model_list = [
current_model
] # Only allow this specific model if no list provided
return directive
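# Illustrative usage (a minimal sketch; the calling convention in MOARSearch is an
# assumption -- shown only to clarify what the factory returns):
#
#     d = create_model_specific_directives(
#         "azure/gpt-4o-mini",
#         allowed_model_list=["azure/gpt-4o-mini", "azure/gpt-5-nano"],
#     )
#     d.name          # -> "change to azure/gpt-4o-mini"
#     d.target_model  # -> "azure/gpt-4o-mini"
#     # Because target_model is set, llm_instantiate() short-circuits and returns
#     # ChangeModelInstantiateSchema(model="azure/gpt-4o-mini") without an LLM call.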
class ChangeModelCostDirective(Directive):
name: str = Field(
default="change model cost", description="The name of the directive"
)
formal_description: str = Field(
default="Op => Op* (same operation with a more cost-effective model choice)"
)
nl_description: str = Field(
default="Rewrites an operator to use a more cost-effective LLM model to optimize expenses. Prioritizes cost savings while maintaining adequate performance, typically suggesting cheaper models like gpt-4o-mini or gpt-5-nano for simpler tasks."
)
when_to_use: str = Field(
default="When cost optimization is the primary concern and the task can be performed adequately by a less expensive model. Suitable for simple extraction, basic classification, or high-volume processing where budget constraints are important."
)
instantiate_schema_type: Type[BaseModel] = ChangeModelInstantiateSchema
example: str = Field(
default=(
"Original Op (MapOpConfig):\n"
"- name: extract_names\n"
" type: map\n"
" prompt: |\n"
" Extract person names from the following text.\n"
" Text: {{ input.text }}\n"
" output:\n"
" schema:\n"
' names: "list[str]"\n'
" model: gpt-5\n"
"\n"
"Example InstantiateSchema:\n"
"{\n"
' "model": "gpt-4o-mini"\n'
"}"
),
)
allowed_model_list: List[str] = Field(
default_factory=list,
description="The allowed list of models to choose from",
)
target_model: str = Field(
default="",
description="The specific target model for this directive instance",
)
model_info: str = Field(
default=(
"""
OpenAI-MRCR evaluates a model's ability to locate and disambiguate multiple well-hidden "needles" within a large context.
Below are the actual performance scores for the 2-needle retrieval task at various context lengths. Use these results to compare the retrieval capabilities of each model.
The below results are the mean match ratio.
Context Length | gpt-5 | gpt-5-nano | gpt-4o-mini | gemini-2.5-pro | gemini-2.5-flash | gpt-4.1 | gpt-4.1-mini | gpt-4.1-nano | gemini-2.5-flash-lite
---------------|-------|------------|-------------|----------------|------------------|---------|--------------|--------------|----------------------
128k | 97% | 44% | 25% | 83.6% | 86.2% | 61.3% | 47.1% | 36.7% | 39.9%
1M | - | - | - | 62.8% | 60.0% | 45.9% | 34.6% | 14.2% | 18.1%
The context window and pricing details for each model are shown below (token prices are per 1 million tokens):
| Family | Model | Input Price /1M | Output Price /1M | Context Window (API) |
|----------------|----------------------------|---------------------------------|-----------------------------------|-----------------------------|
| **GPT-5** | azure/gpt-5 | $1.25 | $10.00 | 400K (272K in + 128K out) |
| | azure/gpt-5-mini | $0.25 | $2.00 | 400K |
| | azure/gpt-5-nano | $0.05 | $0.40 | 400K |
| **GPT-4.1** | azure/gpt-4.1 | $2.00 | $8.00 | 1M |
| | azure/gpt-4.1-mini | $0.40 | $1.60 | 1M |
| | azure/gpt-4.1-nano | $0.10 | $0.40 | 1M |
| **GPT-4o** | azure/gpt-4o | $2.50 | $10.00 | 128K |
| | azure/gpt-4o-mini | $0.15 | $0.60 | 128K (16K output cap) |
| **Gemini 2.5** | gemini/gemini-2.5-pro | $1.25 (200K) / $2.50 (>200K) | $10.00 (200K) / $15.00 (>200K) | 1M (2M soon) |
| | gemini/gemini-2.5-flash | $0.30 | $2.50 | 1M |
| | gemini/gemini-2.5-flash-lite | $0.10 | $0.40 | 1M |
"""
),
)
test_cases: List[DirectiveTestCase] = Field(
default_factory=lambda: [
DirectiveTestCase(
name="cost_optimization_simple_extraction",
description="Should suggest cheaper model for simple extraction task",
input_config={
"name": "extract_names",
"type": "map",
"prompt": "Extract person names from: {{ input.text }}",
"output": {"schema": {"names": "list[str]"}},
"model": "gpt-5",
},
target_ops=["extract_names"],
expected_behavior="Should recommend gpt-4o-mini or gpt-5-nano for cost savings on simple task",
should_pass=True,
),
DirectiveTestCase(
name="high_volume_processing_cost_optimization",
description="Should optimize model for high-volume document processing to reduce costs",
input_config={
"name": "batch_document_summary",
"type": "map",
"prompt": "Summarize this document in 2-3 sentences: {{ input.document }}",
"output": {"schema": {"summary": "str"}},
"model": "gpt-5",
},
target_ops=["batch_document_summary"],
expected_behavior="Should recommend cheaper model like gpt-4o-mini or gpt-5-nano for high-volume simple summarization",
should_pass=True,
),
DirectiveTestCase(
name="basic_classification_cost_savings",
description="Should recommend cost-effective model for basic classification",
input_config={
"name": "classify_sentiment",
"type": "map",
"prompt": "Classify the sentiment of this text as positive, negative, or neutral: {{ input.text }}",
"output": {"schema": {"sentiment": "string"}},
"model": "gpt-5",
},
target_ops=["classify_sentiment"],
expected_behavior="Should recommend cheaper model for basic sentiment classification to reduce costs",
should_pass=True,
),
]
)
def __eq__(self, other):
return (
isinstance(other, ChangeModelCostDirective)
and self.target_model == other.target_model
)
def __hash__(self):
return hash(f"ChangeModelCostDirective_{self.target_model}")
def to_string_for_instantiate(
self, original_op: Dict, dataset: str, model_stats: Dict = None
) -> str:
"""
Generate a prompt for an agent to instantiate this directive for cost optimization.
Args:
            original_op (Dict): The original operation configuration.
dataset: The dataset name
model_stats: Dictionary of model statistics from MOARSearch (optional)
Returns:
str: The agent prompt for instantiating the directive.
"""
model_stats_str = ""
if model_stats:
model_stats_str = f"You have a list of model statistics on the task with the original query pipeline: \n {str(model_stats)}\n"
return (
f"You are an expert at choosing the most cost-effective model for a given task while maintaining adequate performance.\n\n"
f"Original Operation:\n"
f"{str(original_op)}\n\n"
f"Directive: {self.name}\n"
f"Your task is to instantiate this directive by suggesting the cheapest model that meets the requirements for executing the original operation.\n\n"
f"COST OPTIMIZATION STRATEGY:\n"
f"• Choose the cheapest model that meets the task requirements\n"
f"• For tasks requiring ultra-long context (1M+ context window), use gpt-4.1 series or gemini models\n"
f"• For tasks that can work with document samples or fit within 272k context window, use gpt-5-nano\n"
f"• For simple tasks (extraction, basic classification, straightforward summarization), use the cheapest available model like gpt-4o-mini or gpt-5-nano\n"
f"• For high-volume processing where cost accumulates quickly, prioritize the most economical option\n"
f"• Only use expensive models if the task absolutely requires capabilities not available in cheaper alternatives\n"
f"• Consider document length and context requirements when selecting models\n\n"
f"You have a list of allowed models to choose from: {str(self.allowed_model_list)}.\n\n"
f"Consider the information about the allowed models: \n {self.model_info}\n"
f"{model_stats_str}"
f"Your response should include the cheapest model choice that meets the operation requirements."
f"Ensure that your chosen model is in the list of allowed models."
f"Example:\n"
f"{self.example}\n\n"
f"Please output only the ChangeModelInstantiateSchema as JSON."
)
def llm_instantiate(
self,
global_default_model: str,
original_op: Dict,
agent_llm: str,
dataset: str,
message_history: list = [],
model_stats: Dict = None,
):
"""
Use LLM to instantiate this directive for cost optimization.
Args:
original_op (Dict): The original operation.
agent_llm (str): The LLM model to use.
dataset: The dataset name
message_history (List, optional): Conversation history for context.
model_stats: Dictionary of model statistics from MOARSearch (optional)
Returns:
            Tuple of (ChangeModelInstantiateSchema, updated message history, LLM call cost).
"""
# If target_model is specified, use it directly without LLM call
if self.target_model:
schema = ChangeModelInstantiateSchema(model=self.target_model)
return schema, message_history, 0.0
# Otherwise, use LLM to choose the model
message_history.extend(
[
{
"role": "system",
"content": "You are a helpful AI assistant for document processing pipelines focused on cost optimization.",
},
{
"role": "user",
"content": self.to_string_for_instantiate(
original_op, dataset, model_stats
),
},
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
response_format=ChangeModelInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = ChangeModelInstantiateSchema(**parsed_res)
orig_model = global_default_model
if "model" in original_op:
orig_model = original_op.get("model")
# Validate the model is in the allowed model list
ChangeModelInstantiateSchema.validate_diff_model_in_list(
orig_model=orig_model,
model=schema.model,
list_of_model=self.allowed_model_list,
)
message_history.append(
{"role": "assistant", "content": resp.choices[0].message.content}
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(
self,
global_default_model: str,
ops_list: List[Dict],
target_op: str,
rewrite: ChangeModelInstantiateSchema,
) -> List[Dict]:
"""
Apply the directive to the pipeline config by changing the model of the target operator.
"""
# Create a copy of the pipeline config
new_ops_list = deepcopy(ops_list)
# Find position of the target op to modify
pos_to_replace = [
i for i, op in enumerate(ops_list) if op["name"] == target_op
][0]
# Add change model configuration to the target operator
target_operator = new_ops_list[pos_to_replace]
target_operator["model"] = rewrite.model
return new_ops_list
def instantiate(
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
global_default_model: str = None,
dataset: str = None,
model_stats: Dict = None,
allowed_model_list: List[str] = None,
**kwargs,
):
"""
Instantiate the directive for a list of operators.
Args:
operators: List of operator configurations
target_ops: List of target operation names
agent_llm: LLM model to use for instantiation
message_history: Conversation history
global_default_model: Default model for the pipeline
dataset: Dataset name
model_stats: Dictionary of model statistics from MOARSearch (optional)
allowed_model_list: List of allowed models (optional)
**kwargs: Additional keyword arguments
"""
# Update allowed_model_list if provided
if allowed_model_list is not None:
self.allowed_model_list = allowed_model_list
new_ops_list = deepcopy(operators)
inst_error = 0
for target_op in target_ops:
target_op_config = [op for op in operators if op["name"] == target_op][0]
# Instantiate the directive
try:
rewrite, message_history, call_cost = self.llm_instantiate(
global_default_model,
target_op_config,
agent_llm,
dataset,
message_history,
model_stats,
)
print(rewrite)
            except Exception:
                inst_error += 1
                continue
new_ops_list = self.apply(
global_default_model, new_ops_list, target_op, rewrite
)
if inst_error == len(target_ops):
print("CHANGE MODEL COST ERROR")
return None, message_history
return new_ops_list, message_history, call_cost


@@ -0,0 +1,348 @@
import json
from copy import deepcopy
from typing import Dict, List, Type
from litellm import completion
from pydantic import BaseModel, Field
from docetl.reasoning_optimizer.instantiate_schemas import (
ChunkHeaderSummaryInstantiateSchema,
)
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
class ChunkHeaderSummaryDirective(Directive):
name: str = Field(
default="chunk_header_summary", description="The name of the directive"
)
formal_description: str = Field(default="Split -> Gather => Split -> Map -> Gather")
nl_description: str = Field(
default="Transforms an existing Split -> Gather pipeline by inserting a Map operation between them that extracts headers and creates summaries from each chunk. The Gather operation is then modified to use summaries for middle chunks and headers for document structure. This directive enhances chunking pipelines with header extraction and chunk summarization capabilities. Only use this if it is clear that chunk-level analysis is insufficient because the chunk requires headers and summaries from other chunks to be interpreted correctly."
)
when_to_use: str = Field(
default="Use only when you have an existing chunking pipeline (Split -> Gather) processing documents with clear hierarchical structure (legal contracts, technical manuals, research papers), and it is evident that chunk-level analysis is not accurate because the chunk needs headers and summaries from other chunks to make sense. This is beneficial when full chunk content in gather would be too verbose, and summarized or structured context is required for correct downstream processing. The target operators should be the split and gather. Make sure you specify these two operators when choosing this directive."
)
instantiate_schema_type: Type[BaseModel] = Field(
default=ChunkHeaderSummaryInstantiateSchema
)
example: str = Field(
default="""
Original Pipeline (Split -> Gather):
- name: split_contract_analysis
type: split
split_key: contract_text
method: token_count
method_kwargs:
num_tokens: 1000
- name: gather_contract_context
type: gather
content_key: contract_text_chunk
doc_id_key: split_contract_analysis_id
order_key: split_contract_analysis_chunk_num
peripheral_chunks:
previous:
tail:
count: 1
Example InstantiateSchema Options (what the agent should output):
# Basic header and summary extraction:
{
"header_extraction_prompt": "Extract any section headers or subsection titles from this contract chunk: {{ input.contract_text_chunk }}. Return the headers with their hierarchical levels.",
"summary_prompt": "Summarize the key legal concepts and clause types in this contract chunk: {{ input.contract_text_chunk }}. Focus on liability, indemnification, and related contractual obligations.",
}
""",
)
test_cases: List[DirectiveTestCase] = Field(
default_factory=lambda: [
DirectiveTestCase(
name="legal_document_with_structure",
description="Should transform split->gather pipeline to include header extraction and summarization",
input_config=[
{
"name": "split_legal_docs",
"type": "split",
"split_key": "agreement_text",
"method": "token_count",
"method_kwargs": {"num_tokens": 1000},
},
{
"name": "gather_legal_context",
"type": "gather",
"content_key": "agreement_text_chunk",
"doc_id_key": "split_legal_docs_id",
"order_key": "split_legal_docs_chunk_num",
"peripheral_chunks": {
"previous": {"tail": {"count": 1}},
"next": {"head": {"count": 1}},
},
},
],
target_ops=["split_legal_docs", "gather_legal_context"],
expected_behavior="Should insert parallel_map between split and gather for header extraction and summarization, with gather using doc_header_key",
should_pass=True,
),
DirectiveTestCase(
name="technical_manual_analysis",
description="Should transform split->gather pipeline for technical documentation with header and summary context",
input_config=[
{
"name": "split_manual",
"type": "split",
"split_key": "manual_text",
"method": "token_count",
"method_kwargs": {"num_tokens": 800},
},
{
"name": "gather_manual_context",
"type": "gather",
"content_key": "manual_text_chunk",
"doc_id_key": "split_manual_id",
"order_key": "split_manual_chunk_num",
"peripheral_chunks": {
"previous": {"tail": {"count": 2}},
"next": {"head": {"count": 1}},
},
},
],
target_ops=["split_manual", "gather_manual_context"],
expected_behavior="Should insert parallel_map between split and gather for header extraction and chunk summarization, enhancing technical documentation processing",
should_pass=True,
),
]
)
def __eq__(self, other):
return isinstance(other, ChunkHeaderSummaryDirective)
def __hash__(self):
return hash("ChunkHeaderSummaryDirective")
def to_string_for_instantiate(self, original_op: Dict) -> str:
"""
Generate a prompt for an agent to instantiate this directive.
Args:
original_op (Dict): The original operation.
Returns:
str: The agent prompt for instantiating the directive.
"""
return (
f"You are an expert at enhancing document processing pipelines with header extraction and chunk summarization.\n\n"
f"Original Pipeline:\n"
f"{str(original_op)}\n\n"
f"Directive: {self.name}\n"
f"Your task is to instantiate this directive by creating a configuration that enhances an existing Split -> Gather pipeline "
f"by inserting a Map operation between them for header extraction and summarization.\n\n"
f"Key requirements:\n"
f"1. header_extraction_prompt: Create a prompt to extract headers/section titles from each chunk:\n"
f" - Use '{{{{ input.<split_key>_chunk }}}}' to reference chunk content (you'll know the split_key from the existing split operation)\n"
f" - Focus on document structure (titles, headings, section numbers)\n"
f" - Should output 'headers' field with hierarchical level information\n"
f"2. summary_prompt: Create a prompt to summarize each chunk:\n"
f" - Use '{{{{ input.<split_key>_chunk }}}}' to reference chunk content\n"
f" - Focus on key concepts relevant to the downstream processing\n"
f" - Should output '<split_key>_summary' field\n"
f" - Keep summaries concise but informative for context\n"
f"The header extraction helps maintain document structure context.\n"
f"The summary provides condensed context from surrounding chunks.\n"
f"The gather operation combines both for comprehensive context.\n\n"
f"Example:\n"
f"{self.example}\n\n"
f"Please output only the ChunkHeaderSummaryInstantiateSchema object as JSON."
)
def llm_instantiate(
self,
original_op: Dict,
agent_llm: str,
message_history: list = [],
):
"""
        Use LLM to instantiate this directive by creating header-extraction and summary prompts.
Args:
original_op (Dict): The original operation.
agent_llm (str): The LLM model to use.
message_history (List, optional): Conversation history for context.
Returns:
            Tuple of (ChunkHeaderSummaryInstantiateSchema, updated message history, LLM call cost).
"""
message_history.extend(
[
{
"role": "user",
"content": self.to_string_for_instantiate(original_op),
},
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
response_format=ChunkHeaderSummaryInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = ChunkHeaderSummaryInstantiateSchema(**parsed_res)
message_history.append(
{"role": "assistant", "content": resp.choices[0].message.content}
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(
self,
global_default_model: str,
ops_list: List[Dict],
target_ops: List[str],
rewrite: ChunkHeaderSummaryInstantiateSchema,
) -> List[Dict]:
"""
Apply the directive to the pipeline config by inserting a parallel_map operation
between the existing split and gather operations.
"""
# Create a copy of the pipeline config
new_ops_list = deepcopy(ops_list)
# Find the split and gather operations
split_op = None
gather_op = None
split_idx = None
gather_idx = None
for i, op in enumerate(new_ops_list):
if op["name"] in target_ops:
if op["type"] == "split":
split_op = op
split_idx = i
elif op["type"] == "gather":
gather_op = op
gather_idx = i
if not split_op or not gather_op:
raise ValueError(
"Both split and gather operations must be provided as target operations"
)
if split_idx >= gather_idx:
raise ValueError(
"Split operation must come before gather operation in the pipeline"
)
if gather_idx - split_idx > 1:
raise ValueError("There should not be operators between split and gather")
# Get the split_key from the split operation
split_key = split_op.get("split_key")
if not split_key:
raise ValueError(
f"Split operation '{split_op['name']}' must have a 'split_key' field"
)
# Create the parallel map operation for header extraction and summarization
parallel_map_name = f"parallel_map_{split_op['name']}_header_summary"
parallel_map_op = {
"name": parallel_map_name,
"type": "parallel_map",
"prompts": [
{
"name": f"header_extraction_{split_op['name']}",
"prompt": rewrite.header_extraction_prompt,
"output_keys": ["headers"],
},
{
"name": f"summary_{split_op['name']}",
"prompt": rewrite.summary_prompt,
"output_keys": [f"{split_key}_summary"],
},
],
"model": global_default_model,
"output": {
"schema": {
"headers": "list[{header: string, level: integer}]",
f"{split_key}_summary": "string",
}
},
}
# Add doc_header_key to the gather operation to use extracted headers
new_ops_list[gather_idx]["doc_header_key"] = "headers"
# Insert the parallel map operation between split and gather
new_ops_list.insert(gather_idx, parallel_map_op)
return new_ops_list
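    # Illustrative before/after shape (a minimal sketch of what apply() produces,
    # assuming a split op named "split_contract_analysis" with split_key
    # "contract_text" followed directly by a gather op):
    #
    #     split_contract_analysis (split)
    #     gather_contract_context (gather)
    #
    # becomes
    #
    #     split_contract_analysis (split)
    #     parallel_map_split_contract_analysis_header_summary (parallel_map)
    #         -> outputs "headers" and "contract_text_summary"
    #     gather_contract_context (gather, now with doc_header_key: "headers")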
def instantiate(
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
global_default_model: str = None,
**kwargs,
):
"""
Instantiate the directive for a list of operators.
"""
# Assert that there are exactly two target ops
assert (
len(target_ops) == 2
), "There must be exactly two target ops (split and gather) to instantiate this chunk header summary directive"
# Find split and gather operations
split_op = None
gather_op = None
for op in operators:
if op["name"] in target_ops:
if op["type"] == "split":
split_op = op
elif op["type"] == "gather":
gather_op = op
if not split_op:
raise ValueError(
f"Chunk header summary directive requires a split operation among target operations, but none found in {target_ops}"
)
if not gather_op:
raise ValueError(
f"Chunk header summary directive requires a gather operation among target operations, but none found in {target_ops}"
)
# Create a combined context for instantiation
pipeline_context = {
"split_op": split_op,
"gather_op": gather_op,
"target_ops": target_ops,
}
# Instantiate the directive
rewrite, message_history, call_cost = self.llm_instantiate(
pipeline_context, agent_llm, message_history
)
# Apply the rewrite to the operators
return (
self.apply(global_default_model, operators, target_ops, rewrite),
message_history,
call_cost,
)


@@ -0,0 +1,282 @@
import json
from copy import deepcopy
from typing import Dict, List, Type
from pydantic import BaseModel, Field
from docetl.reasoning_optimizer.instantiate_schemas import (
ClarifyInstructionsInstantiateSchema,
)
from .agent_utils import AgenticDirectiveRunner
from .base import Directive, DirectiveTestCase
class ClarifyInstructionsDirective(Directive):
name: str = Field(
default="clarify_instructions", description="The name of the directive"
)
formal_description: str = Field(default="Single-op => Op")
nl_description: str = Field(
default="Improves a single operation's prompt clarity and specificity by analyzing sample input data to identify patterns, resolve ambiguities, and create more precise instructions that reduce LLM confusion and improve output consistency"
)
when_to_use: str = Field(
default="When a single operation has a vague or ambiguous prompt that could benefit from more specific instructions based on actual data patterns. Particularly useful when you have multiple samples of input data and want to create a prompt for one specific operation that handles the patterns and edge cases present in your dataset"
)
instantiate_schema_type: Type[BaseModel] = Field(
default=ClarifyInstructionsInstantiateSchema
)
example: str = Field(
default="""
Target Operation:
- name: extract_key_findings
type: map
prompt: |
Extract the key findings from: {{ input.research_paper }}
output:
schema:
findings: "list[str]"
After analyzing sample research papers, the agent might discover papers contain
abstracts, conclusions, and results sections with different formats.
Example InstantiateSchema (what the agent should output):
ClarifyInstructionsInstantiateSchema(
clarified_prompt="Extract the key findings from the research paper: {{ input.research_paper }}. Focus on: 1) Main experimental results and statistical significance from Results sections, 2) Primary conclusions from Abstract and Conclusion sections, 3) Novel contributions explicitly stated by authors. Ignore methodological details, related work summaries, and future work suggestions. Return 3-7 concise findings as bullet points."
)
"""
)
test_cases: List[DirectiveTestCase] = Field(
default_factory=lambda: [
DirectiveTestCase(
name="single_target_clarification",
description="Should create clarified prompt for single target operation",
input_config={
"name": "analyze_feedback",
"type": "map",
"prompt": "Analyze the feedback: {{ input.feedback }}",
"output": {
"schema": {"sentiment": "string", "issues": "list[str]"}
},
},
target_ops=["analyze_feedback"],
expected_behavior="Should replace the original prompt with a more specific version based on analysis of sample feedback data. The clarified prompt should reference {{ input.feedback }} and provide specific guidance on what to analyze.",
should_pass=True,
),
DirectiveTestCase(
name="preserve_input_variables",
description="Should preserve all input variable references from original prompt",
input_config={
"name": "compare_documents",
"type": "map",
"prompt": "Compare {{ input.doc1 }} with {{ input.doc2 }} and identify differences",
"output": {"schema": {"differences": "list[str]"}},
},
target_ops=["compare_documents"],
expected_behavior="Should create clarified prompt that still references both {{ input.doc1 }} and {{ input.doc2 }} while providing more specific comparison instructions.",
should_pass=True,
),
]
)
def __eq__(self, other):
return isinstance(other, ClarifyInstructionsDirective)
def __hash__(self):
return hash("ClarifyInstructionsDirective")
def to_string_for_instantiate(
self, target_ops_configs: List[Dict], pipeline_code: Dict = None
) -> str:
"""
Generate a prompt that asks the agent to analyze sample data and create an improved prompt.
"""
assert (
len(target_ops_configs) == 1
), "ClarifyInstructions directive only supports single target operation"
op = target_ops_configs[0]
original_prompt = op.get("prompt", "")
# Build pipeline context
pipeline_context = ""
if pipeline_code:
pipeline_context = f"""
Pipeline Context:
{json.dumps(pipeline_code, indent=2)}
The target operation '{op['name']}' fits into this broader pipeline. Consider:
- What data flows into this operation from previous steps
- How this operation's output will be used by subsequent operations
- The overall goal of the pipeline when creating your improved prompt
"""
return (
f"You are an expert at analyzing data patterns and creating precise, effective prompts.\n\n"
f"Target Operation:\n"
f"{json.dumps(op, indent=2)}\n\n"
f"Original Prompt: {original_prompt}\n\n"
f"{pipeline_context}\n"
f"Your task is to analyze sample input data and create a significantly improved version of this prompt.\n\n"
f"You will be given access to sample input data through a read_next_doc() function. Use this to:\n"
f"1. Understand the actual structure and patterns in the input data\n"
f"2. Identify ambiguities in the original prompt that could be clarified\n"
f"3. Discover specific patterns, formats, or edge cases that should be addressed\n"
f"4. Consider how this operation fits into the broader pipeline context\n"
f"5. Create a more specific, actionable prompt based on these insights\n\n"
f"Guidelines for the improved prompt:\n"
f"- Must preserve ALL original input variable references (like {{{{ input.fieldname }}}})\n"
f"- Should be significantly more specific than the original\n"
f"- Include concrete instructions based on patterns observed in the data\n"
f"- Address potential ambiguities or edge cases discovered in samples\n"
f"- Consider the pipeline context and how this operation contributes to the overall goal\n"
f"- Maintain the same general task but with much clearer execution details\n\n"
f"Example transformation:\n"
f"{self.example}\n\n"
f"Analyze samples strategically - focus on diversity and understanding patterns rather than reading every document.\n"
f"When you have enough information to create a substantially improved prompt, output your result.\n\n"
f"Remember: Your goal is to make the prompt so clear and specific that it produces more consistent, higher-quality results."
)
def llm_instantiate(
self,
target_ops_configs: List[Dict],
input_file_path: str,
agent_llm: str,
message_history: list = [],
pipeline_code: Dict = None,
):
"""
Use agentic approach to analyze sample data and generate improved prompt.
"""
# Load sample input data
try:
with open(input_file_path, "r") as f:
input_data = json.load(f)
if not isinstance(input_data, list) or len(input_data) == 0:
raise ValueError(
"Input file must contain a non-empty list of sample data"
)
except Exception as e:
raise Exception(
f"Failed to load input data from {input_file_path}: {str(e)}"
)
# Create validation function for input variable preservation
original_prompt = target_ops_configs[0].get("prompt", "")
def validate_input_variables(schema_instance):
ClarifyInstructionsInstantiateSchema.validate_input_variables_preserved(
schema_instance.clarified_prompt, original_prompt
)
# Set up agentic runner with validation
runner = AgenticDirectiveRunner(
input_data=input_data,
agent_llm=agent_llm,
validation_func=validate_input_variables,
)
# Create system prompt for the agentic runner
system_prompt = (
"You are an expert prompt engineer who analyzes data to create better, more specific prompts. "
"Your goal is to examine input samples to understand patterns, identify ambiguities, and create "
"significantly improved prompts that are more specific and actionable than the original. "
"You also consider the broader pipeline context to ensure the improved prompt serves the overall data processing goal."
)
# Create initial user message
initial_message = self.to_string_for_instantiate(
target_ops_configs, pipeline_code
)
# Run the agentic loop (validation is handled internally)
try:
schema, updated_message_history, call_cost = runner.run_agentic_loop(
system_prompt=system_prompt,
initial_user_message=initial_message,
response_schema=ClarifyInstructionsInstantiateSchema,
)
# Update message history
message_history.extend(updated_message_history)
return schema, message_history, call_cost
except Exception as e:
raise Exception(
f"Failed to instantiate clarify_instructions directive: {str(e)}"
)
def apply(
self,
global_default_model: str,
ops_list: List[Dict],
target_ops: List[str],
rewrite: ClarifyInstructionsInstantiateSchema,
) -> List[Dict]:
"""
Apply the directive by replacing the target operation's prompt with the clarified version.
"""
new_ops_list = deepcopy(ops_list)
# Find and update the target operation
for i, op in enumerate(new_ops_list):
if op["name"] in target_ops:
# Update the prompt with the clarified version
new_ops_list[i]["prompt"] = rewrite.clarified_prompt
break
return new_ops_list
def instantiate(
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
global_default_model: str = None,
**kwargs,
):
"""
Main method that orchestrates directive instantiation:
1. Use agentic approach to analyze data and generate improved prompt
2. Apply the transformation using that improved prompt
"""
assert (
len(target_ops) == 1
), "ClarifyInstructions directive requires exactly one target operation"
input_file_path = kwargs.get("input_file_path", None)
pipeline_code = kwargs.get("pipeline_code", None)
if not input_file_path:
raise ValueError(
"input_file_path is required for ClarifyInstructions directive"
)
# Get configuration for target operation
target_ops_configs = [op for op in operators if op["name"] in target_ops]
if not target_ops_configs:
raise ValueError(f"Target operation {target_ops[0]} not found in operators")
# Step 1: Agent analyzes data and generates improved prompt
rewrite, message_history, call_cost = self.llm_instantiate(
target_ops_configs,
input_file_path,
agent_llm,
message_history,
pipeline_code,
)
# Step 2: Apply transformation using the improved prompt
return (
self.apply(global_default_model, operators, target_ops, rewrite),
message_history,
call_cost,
)


@@ -0,0 +1,350 @@
import json
from copy import deepcopy
from typing import Dict, List, Type
from litellm import completion
from pydantic import BaseModel, Field
from docetl.reasoning_optimizer.instantiate_schemas import (
DeterministicDocCompressionInstantiateSchema,
)
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
# TODO: For the agent instantiating the rewrite directive,
# we might want to allow it to look at some example documents /
# enough documents until it feels like it has a good understanding of the data.
class DeterministicDocCompressionDirective(Directive):
name: str = Field(
default="deterministic_doc_compression", description="The name of the directive"
)
formal_description: str = Field(default="Op => Code Map -> Op")
nl_description: str = Field(
default="Reduces LLM processing costs by using deterministic logic (regex, patterns) to compress documents before expensive downstream operations, removing irrelevant content that could distract the LLM"
)
when_to_use: str = Field(
default="When documents contain identifiable patterns or keywords and you want to reduce token costs for downstream LLM operations while improving accuracy by eliminating distracting irrelevant content"
)
instantiate_schema_type: Type[BaseModel] = Field(
default=DeterministicDocCompressionInstantiateSchema
)
example: str = Field(
default="""
Target Operations:
- name: analyze_regulatory_compliance
type: map
prompt: |
Analyze regulatory compliance issues in this legal document: {{ input.legal_document }}
Focus on identifying violations, required actions, and compliance deadlines.
output:
schema:
violations: "list[str]"
required_actions: "list[str]"
deadlines: "list[str]"
Example InstantiateSchema (what the agent should output):
{
"name": "extract_compliance_sections",
"code": '''
def transform(input_doc):
import re
legal_document = input_doc.get('legal_document', '')
# Patterns to identify compliance-related content
compliance_patterns = [
        r'(?i)(violat(?:e|ion)|breach|non-complian(?:ce|t))',
        r'(?i)(deadline|due date|expir(?:e|ation)|within.*days?)',
        r'(?i)(shall|must|required|mandatory|obligation)',
        r'(?i)(section|article|clause)\\s+\\d+.*(complian(?:ce|t)|regulat(?:e|ory))'
]
relevant_spans = []
# Find all matches and extract context around them
for pattern in compliance_patterns:
for match in re.finditer(pattern, legal_document):
start_pos = match.start()
end_pos = match.end()
# Extract 300 chars before and 800 chars after the match
context_start = max(0, start_pos - 300)
context_end = min(len(legal_document), end_pos + 800)
# Extract the context around the match
context = legal_document[context_start:context_end]
relevant_spans.append((context_start, context_end, context))
# Merge overlapping spans and remove duplicates
if relevant_spans:
# Sort by start position
relevant_spans.sort(key=lambda x: x[0])
merged_spans = [relevant_spans[0]]
for current_start, current_end, current_text in relevant_spans[1:]:
last_start, last_end, last_text = merged_spans[-1]
if current_start <= last_end + 100: # Merge if close enough
# Extend the last span
new_end = max(last_end, current_end)
new_text = legal_document[last_start:new_end]
merged_spans[-1] = (last_start, new_end, new_text)
else:
merged_spans.append((current_start, current_end, current_text))
# Extract just the text portions
compressed_text = '\\n\\n--- SECTION BREAK ---\\n\\n'.join([span[2] for span in merged_spans])
else:
compressed_text = legal_document # Fallback if no matches
return {
'legal_document': compressed_text
}
'''
}
"""
)
test_cases: List[DirectiveTestCase] = Field(
default_factory=lambda: [
DirectiveTestCase(
name="detailed_merger_agreement_analysis",
description="Should compress merger agreement for comprehensive legal analysis",
input_config={
"name": "analyze_merger_agreement_terms",
"type": "map",
"prompt": """Perform a comprehensive legal analysis of this merger agreement: {{ input.merger_agreement }}
Analyze and extract the following:
1. Purchase price structure and payment mechanisms (cash, stock, earnouts, escrow arrangements)
2. Material adverse change (MAC) definitions and carve-outs that could affect deal completion
3. Representations and warranties with survival periods and liability caps
4. Closing conditions precedent, including regulatory approvals and third-party consents
5. Termination rights and associated breakup fees or reverse breakup fees
6. Indemnification provisions including baskets, caps, and survival periods
7. Employee retention arrangements and change-in-control provisions
8. Integration planning requirements and operational restrictions during pendency
9. Dispute resolution mechanisms and governing law provisions
10. Post-closing adjustments and working capital mechanisms
For each area, provide specific clause references, dollar amounts where applicable,
time periods, and risk assessment (High/Medium/Low) with justification.""",
"output": {
"schema": {
"purchase_price_analysis": "string",
"mac_provisions": "list[str]",
"representations_warranties": "list[str]",
"closing_conditions": "list[str]",
"termination_rights": "string",
"indemnification_terms": "string",
"employee_provisions": "list[str]",
"integration_restrictions": "list[str]",
"dispute_resolution": "string",
"post_closing_adjustments": "string",
"risk_assessment": "string",
}
},
},
target_ops=["analyze_merger_agreement_terms"],
expected_behavior="Should add Code Map operation that extracts merger agreement sections using regex patterns for legal terms, financial provisions, and risk-related clauses. The return dictionary of the transform function should be {'merger_agreement': ....} only.",
should_pass=True,
),
DirectiveTestCase(
name="multi_document_analysis_compression",
description="Should compress document for multiple analysis operations",
input_config=[
{
"name": "extract_financial_metrics",
"type": "map",
"prompt": "Extract revenue, profit, and expense figures from: {{ input.earnings_report }}",
"output": {
"schema": {
"revenue": "string",
"profit": "string",
"expenses": "list[str]",
}
},
},
{
"name": "assess_financial_risks",
"type": "map",
"prompt": "Identify financial risks and warning signs in: {{ input.earnings_report }}",
"output": {
"schema": {
"risks": "list[str]",
"warning_signs": "list[str]",
}
},
},
],
target_ops=["extract_financial_metrics", "assess_financial_risks"],
expected_behavior="Should add Code Map operation that extracts financial content needed for both operations. The return dictionary of the transform function should be {'earnings_report': ....} only.",
should_pass=True,
),
]
)
def __eq__(self, other):
return isinstance(other, DeterministicDocCompressionDirective)
def __hash__(self):
return hash("DeterministicDocCompressionDirective")
def to_string_for_instantiate(self, target_ops_configs: List[Dict]) -> str:
"""
Generate a prompt that asks the agent to output the instantiate schema.
This prompt explains to the LLM what configuration it needs to generate.
"""
ops_str = "\n".join(
[
f"Operation {i+1}:\n{str(op)}\n"
for i, op in enumerate(target_ops_configs)
]
)
return (
f"You are an expert at document processing and Python programming.\n\n"
f"Target Operations:\n"
f"{ops_str}\n\n"
f"Directive: {self.name}\n"
f"Your task is to instantiate this directive by specifying a Code Map operation "
f"that specifies how to compress the input document using deterministic logic.\n\n"
f"The directive will insert a Code Map operation that:\n"
f"1. Takes document field(s) from the input\n"
f"2. Uses deterministic logic (regex, keyword matching, pattern extraction) to compress them\n"
f"3. Returns a dictionary with the EXACT SAME document field key and compressed content\n"
f"4. Reduces token usage and improves focus for the downstream operations\n\n"
f"The agent must output the configuration specifying:\n"
f"- name: A descriptive name for the Code Map operation\n"
f"- code: Python code defining a 'transform' function that:\n"
f" * Takes input_doc as parameter\n"
f" * Imports 're' and other standard libraries WITHIN the function itself\n"
f" * Only uses standard Python libraries (re, string, json, etc.) - no external packages\n"
f" * Uses deterministic logic to extract relevant content patterns\n"
f" * For each pattern match, extracts context around it (e.g., ±500 chars, or -300 to +800 chars)\n"
f" * Use your judgment to determine appropriate character counts that capture enough context\n"
f" * Merges overlapping context spans to avoid duplication\n"
f" * Returns a dictionary with the EXACT document key and compressed value: {{document_key: compressed_content}}\n\n"
f"CRITICAL: The returned dictionary must use the exact same document field names as the original, "
f"not modified versions like 'document_key_compressed'. The downstream operations expect the exact same field names.\n\n"
f"IMPORTANT: Focus on extracting the minimal content necessary for ALL target operations "
f"using deterministic pattern matching. Analyze each operation's prompt to identify what types "
f"of content patterns to look for.\n\n"
f"Example:\n"
f"{self.example}\n\n"
f"Please output only the DeterministicDocCompressionInstantiateSchema as JSON "
f"that specifies how to apply this directive to the target operations."
)
def llm_instantiate(
self,
target_ops_configs: List[Dict],
agent_llm: str,
message_history: list = [],
):
"""
Call the LLM to generate the instantiate schema.
The LLM will output structured data matching DeterministicDocCompressionInstantiateSchema.
"""
message_history.extend(
[
{
"role": "user",
"content": self.to_string_for_instantiate(target_ops_configs),
},
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
response_format=DeterministicDocCompressionInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = DeterministicDocCompressionInstantiateSchema(**parsed_res)
# Validate against target operations
schema.validate_against_target_ops(target_ops_configs)
message_history.append(
{"role": "assistant", "content": resp.choices[0].message.content}
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(
self,
global_default_model: str,
ops_list: List[Dict],
target_ops: List[str],
rewrite: DeterministicDocCompressionInstantiateSchema,
) -> List[Dict]:
"""
Apply the directive using the instantiate schema configuration.
Inserts a Code Map operation before the first target operation.
"""
new_ops_list = deepcopy(ops_list)
# Find the position of the first target operation
first_target_pos = min(
[i for i, op in enumerate(ops_list) if op["name"] in target_ops]
)
# Create the Code Map operation
code_map_op = {
"name": rewrite.name,
"type": "code_map",
"code": rewrite.code,
}
# Insert the Code Map operation before the first target operation
new_ops_list.insert(first_target_pos, code_map_op)
return new_ops_list
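    # Illustrative sketch of apply() (assumed standalone usage; the rewrite and
    # operation below are hypothetical stand-ins):
    #
    #     directive = DeterministicDocCompressionDirective()
    #     rewrite = DeterministicDocCompressionInstantiateSchema(
    #         name="extract_compliance_sections",
    #         code="def transform(input_doc):\n    ...",
    #     )
    #     ops = [{"name": "analyze_regulatory_compliance", "type": "map", ...}]
    #     new_ops = directive.apply(None, ops, ["analyze_regulatory_compliance"], rewrite)
    #     # new_ops[0] is now the code_map op "extract_compliance_sections",
    #     # followed by the original map, so the map sees the compressed document.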
def instantiate(
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
global_default_model: str = None,
**kwargs,
):
"""
Main method that orchestrates directive instantiation:
1. Get agent to generate instantiate schema for all target operations
2. Apply the transformation using that schema
"""
assert len(target_ops) >= 1, "This directive requires at least one target op"
# Get configurations for all target operations
target_ops_configs = [op for op in operators if op["name"] in target_ops]
# Step 1: Agent generates the instantiate schema considering all target ops
rewrite, message_history, call_cost = self.llm_instantiate(
target_ops_configs, agent_llm, message_history
)
# Step 2: Apply transformation using the schema
return (
self.apply(global_default_model, operators, target_ops, rewrite),
message_history,
call_cost,
)


@@ -0,0 +1,466 @@
import json
from copy import deepcopy
from typing import Dict, List, Type
from litellm import completion
from pydantic import BaseModel, Field
from docetl.reasoning_optimizer.instantiate_schemas import (
DocumentChunkingInstantiateSchema,
)
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
class DocumentChunkingDirective(Directive):
name: str = Field(default="doc_chunking", description="The name of the directive")
formal_description: str = Field(
default="Map => Split -> Gather -> [Sample] -> Map -> Reduce"
)
nl_description: str = Field(
default="Transforms a single Map operation into a chunking pipeline: splits long documents into chunks, gathers context around each chunk, optionally samples a subset of chunks for efficiency, processes chunks with a new Map operation, then reduces the results. By default, sampling is applied unless the task requires processing ALL chunks. This directive can only be applied to a top-level Map operation, not to a sub-map within a pipeline that already contains a split, gather, or reduce sequence."
)
when_to_use: str = Field(
default="Use when you need to process long documents to extract information, and the document is too long for a single Map operation. The agent will automatically decide whether to sample chunks (for tasks like categorization, theme extraction) or process all chunks (for comprehensive extraction of all instances). Do not apply if the target operation is already part of a split -> gather -> map -> reduce pipeline. Use different gather configs: 'previous.head' for documents with key metadata/definitions at the start, 'previous.tail' for maintaining references, and 'next.head' only for tables/clauses spanning chunks."
)
instantiate_schema_type: Type[BaseModel] = Field(
default=DocumentChunkingInstantiateSchema
)
example: str = Field(
default="""
Original Op (MapOpConfig):
- name: extract_contract_terms
type: map
prompt: |
Extract all payment terms, deadlines, and penalty clauses from the contract:
{{ input.contract_text }}
Return a comprehensive list of all terms found.
output:
schema:
contract_terms: list[str]
Example InstantiateSchema Options (what the agent should output):
# Basic context - good for most cases:
{
"chunk_size": 10000,
"split_key": "contract_text",
"sub_prompt": "You are analyzing a chunk of a larger document. Extract all payment terms, deadlines, and penalty clauses from this contract chunk: {{ input.contract_text_chunk_rendered }}. Return a comprehensive list of all terms found.",
"reduce_prompt": "Combine results from multiple document chunks: Extract all payment terms, deadlines, and penalty clauses by combining the results from each chunk: {% for input in inputs %}{{ input.contract_terms | join(', ') }}{% if not loop.last %}, {% endif %}{% endfor %}. Remove duplicates and return a comprehensive list of all terms found.",
"gather_config": {
"previous": {
"tail": {
"count": 0.5
}
}
  }
}
# Rich context - for complex documents needing document-level metadata:
{
"gather_config": {
"previous": {
"head": {
"count": 1,
"content_key": "contract_text_chunk"
},
"tail": {
"count": 2,
"content_key": "contract_text_chunk"
}
}
}
}
# Forward context - for tables or clauses spanning chunks:
{
"gather_config": {
"previous": {
"tail": {
"count": 1,
"content_key": "contract_text_chunk"
}
},
"next": {
"head": {
"count": 1,
"content_key": "contract_text_chunk"
}
}
}
}
""",
)
test_cases: List[DirectiveTestCase] = Field(
default_factory=lambda: [
DirectiveTestCase(
name="comprehensive_legal_analysis",
description="Should transform complex legal document analysis into chunking pipeline",
input_config={
"name": "analyze_legal_document",
"type": "map",
"prompt": "From this legal document, extract all liability clauses with risk ratings (1-10), identify all parties and their obligations, find all monetary amounts with currencies, extract all dates and deadlines with legal consequences, and list all governing laws or jurisdictions mentioned. For each liability clause, assess the risk level considering industry standards and provide specific reasoning. Group findings by document section if clearly indicated. Return comprehensive analysis ensuring no critical legal elements are missed: {{ input.legal_document }}",
"output": {
"schema": {
"liability_analysis": "list[str]",
"parties_obligations": "list[str]",
"financial_terms": "list[str]",
"critical_dates": "list[str]",
"governing_laws": "list[str]",
"risk_assessment": "str",
}
},
},
target_ops=["analyze_legal_document"],
expected_behavior="Should create chunking pipeline where sub_prompt covers all extraction tasks (liability, parties, financial, dates, laws) with same risk assessment criteria, and reduce_prompt aggregates all findings maintaining the complete analytical framework and output schema from original prompt",
should_pass=True,
),
DirectiveTestCase(
name="clinical_trial_comprehensive_extraction",
description="Should transform detailed clinical research analysis into chunking pipeline",
input_config={
"name": "extract_clinical_data",
"type": "map",
"prompt": "Analyze this clinical trial document and extract: primary and secondary endpoints with measurement criteria, all adverse events categorized by severity (mild/moderate/severe), patient demographics including inclusion/exclusion criteria, statistical significance results with p-values and confidence intervals, drug dosages and administration protocols, and study methodology details. For each adverse event, determine if it's treatment-related based on temporal relationship and biological plausibility. Calculate overall safety profile score (1-10) considering frequency and severity of events. Ensure all regulatory compliance elements are captured: {{ input.trial_document }}",
"output": {
"schema": {
"endpoints": "list[str]",
"adverse_events": "list[str]",
"demographics": "str",
"statistical_results": "list[str]",
"protocols": "list[str]",
"safety_assessment": "str",
"compliance_status": "str",
}
},
},
target_ops=["extract_clinical_data"],
expected_behavior="Should create chunking pipeline where sub_prompt preserves all clinical analysis requirements (endpoints, adverse events, demographics, statistics, protocols) with same assessment criteria, and reduce_prompt combines results maintaining complete clinical framework and safety scoring methodology from original prompt",
should_pass=True,
),
DirectiveTestCase(
name="financial_comprehensive_analysis",
description="Should transform complex financial document analysis into chunking pipeline",
input_config={
"name": "analyze_financial_report",
"type": "map",
"prompt": "From this annual financial report, extract all revenue streams with growth rates and market segments, identify all risk factors with impact assessments (low/medium/high), find all forward-looking statements and their associated uncertainties, extract key financial ratios and calculate trend analysis over mentioned periods, identify all subsidiaries with their contribution to consolidated results, and analyze competitive positioning statements. For each risk factor, assess potential financial impact in dollar ranges and likelihood percentages. Ensure all material information affecting investor decisions is captured and categorized by urgency level: {{ input.financial_report }}",
"output": {
"schema": {
"revenue_analysis": "list[str]",
"risk_factors": "list[str]",
"forward_statements": "list[str]",
"financial_ratios": "str",
"subsidiaries": "list[str]",
"competitive_analysis": "str",
"material_disclosures": "list[str]",
}
},
},
target_ops=["analyze_financial_report"],
expected_behavior="Should create chunking pipeline where sub_prompt maintains all financial analysis requirements (revenue, risks, statements, ratios, subsidiaries, competitive analysis) with same impact assessment methodology, and reduce_prompt aggregates preserving complete financial analytical framework and materiality assessments from original prompt",
should_pass=True,
),
]
)
def __eq__(self, other):
return isinstance(other, DocumentChunkingDirective)
def __hash__(self):
return hash("DocumentChunkingDirective")
def to_string_for_instantiate(self, original_op: Dict) -> str:
"""
Generate a prompt for an agent to instantiate this directive.
Args:
original_op (Dict): The original operation.
Returns:
str: The agent prompt for instantiating the directive.
"""
return (
f"You are an expert at transforming document processing operations into chunking pipelines.\n\n"
f"Original Operation:\n"
f"{str(original_op)}\n\n"
f"Directive: {self.name}\n"
f"Your task is to instantiate this directive by creating a configuration that transforms the original Map operation "
f"into a Split -> Gather -> Map -> Reduce pipeline for processing long documents in chunks.\n\n"
f"Key requirements:\n"
f"1. chunk_size: Choose an appropriate token count (typically 10000-15000) based on the complexity of the task\n"
f"2. split_key: Identify the document field to split from the original operation's prompt. Make sure to use the same field as the original operation's prompt.\n"
f"3. sub_prompt: Take the original prompt exactly but:\n"
f" - Add instruction at start: 'You are analyzing a chunk of a larger document.'\n"
f" - Replace '{{{{ input.<split_key> }}}}' with '{{{{ input.<split_key>_chunk_rendered }}}}'\n"
f" - Keep everything else identical (same task, same output schema)\n"
f"4. reduce_prompt: Take original task instructions but adapt for aggregation:\n"
f" - Start with: 'Combine results from multiple document chunks:'\n"
f" - Include same task context/requirements as original prompt\n"
f" - Use '{{% for input in inputs %}}' to iterate over chunk results\n"
f" - Combine/deduplicate to match original output schema exactly\n"
f"5. sampling_config: IMPORTANT - Include sampling by default UNLESS the task requires ALL chunks:\n"
f" - ALWAYS use sampling for: categorization, theme identification, sentiment analysis, document type classification\n"
f" - NEVER use sampling for: comprehensive extraction ('extract ALL instances'), complete analysis requiring every chunk\n"
f" - Default sampling: method='uniform' with stratify_key, samples=5-10 chunks\n"
f" - For simple tasks (categorization): samples=1-3 chunks\n"
f" - For complex analysis: samples=5-15 chunks\n"
f" - For stratified sampling: specify method='uniform' and stratify_key (note: split document ID is automatically included)\n"
f" - Set sampling_config=null only if you need to process every single chunk\n"
f"6. gather_config: Configure context from surrounding chunks. Structure:\n"
f" gather_config:\n"
f" previous: # chunks before current chunk\n"
f" head: # first chunk(s) in document\n"
f" count: 1\n"
f" content_key: full_content_chunk # optional, defaults to main key chunk\n"
f" tail: # chunk(s) immediately before current\n"
f" count: 2\n"
f" content_key: full_content_chunk # optional\n"
f" next: # chunks after current chunk\n"
f" head: # chunk(s) immediately after current\n"
f" count: 1\n"
f" content_key: full_content_chunk # optional\n"
f" Usage guidelines:\n"
f" - Use 'previous.head' when document has important metadata/definitions at start\n"
f" - Use 'previous.tail' to maintain references and immediate context\n"
f" - Use 'next.head' only for tables/clauses that span chunk boundaries\n"
f" - Minimize context: only include what's needed for accurate operation\n"
f" - Count can be float (e.g., 0.5 for half chunk, 1.5 for chunk and a half)\n"
f" - More context increases token usage and cost - be judicious\n"
f" - Default to 0.5 previous tail if unsure about context needs\n"
f"The sub_prompt should focus on the main chunk content and extract the same type of information as the original.\n"
f"The reduce_prompt must produce the same output schema as the original operation.\n\n"
f"Example:\n"
f"{self.example}\n\n"
f"Please output only the DocumentChunkingInstantiateSchema object as JSON."
)
def llm_instantiate(
self,
operators,
input_file_path,
original_op: Dict,
agent_llm: str,
message_history: list = [],
) -> tuple:
"""
Use LLM to instantiate this directive by creating chunking configuration.
Args:
            operators (List[Dict]): All operators in the pipeline, used to validate the stratify key.
            input_file_path (str): Path to the input file, used to validate the split key.
            original_op (Dict): The original operation.
            agent_llm (str): The LLM model to use.
            message_history (List, optional): Conversation history for context.
        Returns:
            tuple: (DocumentChunkingInstantiateSchema, updated message history, LLM call cost).
"""
message_history.extend(
[
{
"role": "user",
"content": self.to_string_for_instantiate(original_op),
},
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
call_cost = 0.0
resp = completion(
model=agent_llm,
messages=message_history,
response_format=DocumentChunkingInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = DocumentChunkingInstantiateSchema(**parsed_res)
schema.validate_stratify_key_in_pipeline(operators)
schema.validate_split_key_exists_in_input(input_file_path)
message_history.append(
{"role": "assistant", "content": resp.choices[0].message.content}
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(
self,
global_default_model: str,
ops_list: List[Dict],
target_op: str,
rewrite: DocumentChunkingInstantiateSchema,
) -> List[Dict]:
"""
Apply the directive to the pipeline config by replacing the target operation
with a split -> gather -> map -> reduce sequence.
"""
# Create a copy of the pipeline config
new_ops_list = deepcopy(ops_list)
target_op_config = [op for op in new_ops_list if op["name"] == target_op][0]
# Find position of the target op to replace
pos_to_replace = [
i for i, op in enumerate(ops_list) if op["name"] == target_op
][0]
original_op = ops_list[pos_to_replace]
original_model = target_op_config.get("model", global_default_model)
# Create operation names based on the original operation name
split_name = f"split_{target_op}"
gather_name = f"gather_{target_op}"
sample_name = f"sample_{target_op}_chunks"
map_name = f"map_{target_op}_chunks"
reduce_name = f"reduce_{target_op}"
# Create the split operation
split_op = {
"name": split_name,
"type": "split",
"split_key": rewrite.split_key,
"method": "token_count",
"method_kwargs": {
"num_tokens": rewrite.chunk_size,
"model": original_model,
},
}
# Create the gather operation with agent-configured context
# Convert Pydantic model to dict, excluding None values
gather_config_dict = (
rewrite.gather_config.model_dump(exclude_none=True)
if rewrite.gather_config
else {}
)
# Use default config if the gather_config is empty (all fields were None)
if not gather_config_dict:
gather_config_dict = {"previous": {"tail": {"count": 1}}}
gather_op = {
"name": gather_name,
"type": "gather",
"content_key": f"{rewrite.split_key}_chunk",
"doc_id_key": f"{split_name}_id",
"order_key": f"{split_name}_chunk_num",
"peripheral_chunks": gather_config_dict,
}
# Create the map operation for processing chunks
map_op = {
"name": map_name,
"type": "map",
"prompt": rewrite.sub_prompt,
"model": original_model,
"litellm_completion_kwargs": {"temperature": 0},
"output": deepcopy(original_op["output"]), # Same output schema as original
}
# Create the reduce operation
reduce_op = {
"name": reduce_name,
"type": "reduce",
"reduce_key": f"{split_name}_id",
"prompt": rewrite.reduce_prompt,
"model": original_model,
"litellm_completion_kwargs": {"temperature": 0},
"output": deepcopy(original_op["output"]), # Same output schema as original
"associative": False, # Order matters for chunks,
"pass_through": True,
}
# Construct operation sequence
ops_sequence = [split_op, gather_op]
# Add sample operation if sampling config is provided
if rewrite.sampling_config:
sample_op = {
"name": sample_name,
"type": "sample",
"method": rewrite.sampling_config.method,
"samples": rewrite.sampling_config.samples,
}
# Always stratify by split document ID and set samples_per_group
stratify_keys = [f"{split_name}_id"]
# Add agent-specified stratify key if provided
if rewrite.sampling_config.stratify_key:
stratify_keys.append(rewrite.sampling_config.stratify_key)
sample_op["stratify_key"] = stratify_keys
if rewrite.sampling_config.samples_per_group:
sample_op["samples_per_group"] = (
rewrite.sampling_config.samples_per_group
)
# Add optional fields if provided
if rewrite.sampling_config.random_state is not None:
sample_op["random_state"] = rewrite.sampling_config.random_state
if rewrite.sampling_config.method_kwargs is not None:
try:
sample_op["method_kwargs"] = json.loads(
rewrite.sampling_config.method_kwargs
)
except Exception as e:
raise ValueError(f"Invalid method_kwargs: {e}")
ops_sequence.append(sample_op)
# Add map and reduce operations
ops_sequence.extend([map_op, reduce_op])
# Replace the target operation with the new sequence
new_ops_list[pos_to_replace : pos_to_replace + 1] = ops_sequence
return new_ops_list
def instantiate(
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
global_default_model: str = None,
**kwargs,
):
"""
Instantiate the directive for a list of operators.
"""
# Assert that there is only one target op
assert (
len(target_ops) == 1
), "There must be exactly one target op to instantiate this document chunking directive"
target_op_config = [op for op in operators if op["name"] == target_ops[0]][0]
input_file_path = kwargs.get("input_file_path", None)
# Validate that the target operation is a map operation
if target_op_config.get("type") != "map":
raise ValueError(
f"Document chunking directive can only be applied to map operations, but target operation '{target_ops[0]}' is of type '{target_op_config.get('type')}'"
)
# Instantiate the directive
rewrite, message_history, call_cost = self.llm_instantiate(
operators, input_file_path, target_op_config, agent_llm, message_history
)
# Apply the rewrite to the operators
return (
self.apply(global_default_model, operators, target_ops[0], rewrite),
message_history,
call_cost,
)
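To make the rewrite concrete, here is a hedged sketch of the operator sequence that apply() above emits for a hypothetical map op named extract_terms reading input.report (all names, prompts, counts, and models are illustrative placeholders; the model and litellm_completion_kwargs fields are omitted for brevity):

# Hedged sketch of the split -> gather -> [sample] -> map -> reduce shape
# produced by apply(); every concrete value below is a placeholder.
rewritten_ops = [
    {
        "name": "split_extract_terms",
        "type": "split",
        "split_key": "report",
        "method": "token_count",
        "method_kwargs": {"num_tokens": 10000, "model": "gpt-4o-mini"},
    },
    {
        "name": "gather_extract_terms",
        "type": "gather",
        "content_key": "report_chunk",
        "doc_id_key": "split_extract_terms_id",
        "order_key": "split_extract_terms_chunk_num",
        "peripheral_chunks": {"previous": {"tail": {"count": 1}}},
    },
    {  # present only when the agent returns a sampling_config
        "name": "sample_extract_terms_chunks",
        "type": "sample",
        "method": "uniform",
        "samples": 5,
        "stratify_key": ["split_extract_terms_id"],
    },
    {
        "name": "map_extract_terms_chunks",
        "type": "map",
        "prompt": "<agent-generated sub_prompt over {{ input.report_chunk_rendered }}>",
        "output": {"schema": {"terms": "list[str]"}},
    },
    {
        "name": "reduce_extract_terms",
        "type": "reduce",
        "reduce_key": "split_extract_terms_id",
        "prompt": "<agent-generated reduce_prompt iterating over inputs>",
        "output": {"schema": {"terms": "list[str]"}},
        "associative": False,
        "pass_through": True,
    },
]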

View File

@ -0,0 +1,619 @@
import json
from copy import deepcopy
from typing import Dict, List, Type
from litellm import completion
from pydantic import BaseModel, Field
from docetl.reasoning_optimizer.instantiate_schemas import (
DocumentChunkingTopKInstantiateSchema,
)
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
class DocumentChunkingTopKDirective(Directive):
name: str = Field(
default="doc_chunking_topk", description="The name of the directive"
)
formal_description: str = Field(
default="Map/Filter => Split -> TopK -> Reduce (-> Code Filter if original was Filter)"
)
nl_description: str = Field(
default="Cost optimization directive for documents where only certain portions are relevant to the task (when at least half the document is irrelevant). Works with both Map and Filter operations. Transforms into a retrieval-augmented pipeline: splits documents into chunks, uses topk to retrieve the most relevant chunks, processes them in a reduce operation. For Filter operations, adds a final code_filter step to return boolean results. Ideal when processing full documents would be wasteful due to irrelevant content."
)
when_to_use: str = Field(
default="Use when only certain portions of documents are relevant to the task and at least half of the document content is irrelevant. Perfect for complex filters (e.g., 'does this review mention competitor products more favorably?') or targeted extraction from documents with localized relevant sections. Works with both Map (extraction) and Filter (boolean decision) operations. The retrieval step (embedding or FTS) finds the relevant chunks, avoiding processing irrelevant content. For filters, the final code_filter converts the reduce output to True/False."
)
instantiate_schema_type: Type[BaseModel] = Field(
default=DocumentChunkingTopKInstantiateSchema
)
example: str = Field(
default="""
# Example 1: Complex Filter on Long Customer Reviews
Original Op:
- name: filter_competitor_mentions
type: filter
prompt: |
Analyze this customer review to determine if it mentions competitor products
more positively than our product.
Our Product: {{ input.our_product }}
Review: {{ input.review_text }}
Review ID: {{ input.review_id }}
Return true if the review speaks more favorably about competitor products than ours.
Consider: feature comparisons, performance mentions, value assessments, recommendations.
output:
schema:
mentions_competitors_more_positively: bool
InstantiateSchema (filter with embedding search):
{
"chunk_size": 10000,
"split_key": "review_text",
"reduce_prompt": "Analyze this customer review to determine if it mentions competitor products more positively than our product.\\n\\nOur Product: {{ inputs[0].our_product }}\\nReview: the top {{ inputs|length }} most relevant chunks from the document (ordered by relevance):\\n{% for input in inputs|sort(attribute='_topk_filter_competitor_mentions_chunks_rank') %}\\nChunk (Rank {{ input._topk_filter_competitor_mentions_chunks_rank }}, Score {{ input._topk_filter_competitor_mentions_chunks_score }}):\\n{{ input.review_text_chunk }}\\n{% endfor %}\\nReview ID: {{ inputs[0].review_id }}\\n\\nReturn true if the review speaks more favorably about competitor products than ours.\\nConsider: feature comparisons, performance mentions, value assessments, recommendations.",
"topk_config": {
"method": "embedding",
"k": 5,
"query": "competitor comparison versus alternative better than superior inferior worse features performance value recommendation prefer instead",
"keys": ["review_text_chunk"],
"embedding_model": "text-embedding-3-small"
}
}
# Example 2: Map Operation - Extract Specific Sections from Long Documents
Original Op:
- name: extract_methodology_from_paper
type: map
prompt: |
Extract detailed methodology from this research paper:
Paper: {{ input.paper_content }}
Title: {{ input.title }}
Extract: study design, sample size, data collection methods,
statistical analyses, and validation approaches.
output:
schema:
study_design: str
sample_size: dict
data_collection: list[str]
statistical_methods: list[str]
validation: str
InstantiateSchema (map with embedding search):
{
"chunk_size": 15000,
"split_key": "paper_content",
"reduce_prompt": "Extract detailed methodology from this research paper:\\n\\nPaper: the top {{ inputs|length }} most relevant chunks from the document (ordered by relevance):\\n{% for input in inputs|sort(attribute='_topk_extract_methodology_from_paper_chunks_rank') %}\\nChunk (Rank {{ input._topk_extract_methodology_from_paper_chunks_rank }}, Score {{ input._topk_extract_methodology_from_paper_chunks_score }}):\\n{{ input.paper_content_chunk }}\\n{% endfor %}\\nTitle: {{ inputs[0].title }}\\n\\nExtract: study design, sample size, data collection methods, statistical analyses, and validation approaches.",
"topk_config": {
"method": "embedding",
"k": 8,
"query": "methodology methods study design sample size participants data collection statistical analysis validation procedure protocol experimental",
"keys": ["paper_content_chunk"],
"embedding_model": "text-embedding-3-small"
}
}
# Example 3: Filter with FTS - Check Contract Compliance
Original Op:
- name: filter_contracts_with_liability_caps
type: filter
prompt: |
Determine if this contract contains liability cap provisions
that limit damages to less than $1 million.
Contract: {{ input.contract_text }}
Contract ID: {{ input.contract_id }}
Party: {{ input.counterparty }}
Return true if contract caps liability below $1M, false otherwise.
output:
schema:
has_low_liability_cap: bool
InstantiateSchema (filter with FTS for legal terms):
{
"chunk_size": 12000,
"split_key": "contract_text",
"reduce_prompt": "Determine if this contract contains liability cap provisions that limit damages to less than $1 million.\\n\\nContract: the top {{ inputs|length }} most relevant chunks from the document (ordered by relevance):\\n{% for input in inputs|sort(attribute='_topk_filter_contracts_with_liability_caps_chunks_rank') %}\\nSection (Rank {{ input._topk_filter_contracts_with_liability_caps_chunks_rank }}, Score {{ input._topk_filter_contracts_with_liability_caps_chunks_score }}):\\n{{ input.contract_text_chunk }}\\n{% endfor %}\\nContract ID: {{ inputs[0].contract_id }}\\nParty: {{ inputs[0].counterparty }}\\n\\nReturn true if contract caps liability below $1M, false otherwise.",
"topk_config": {
"method": "fts",
"k": 10,
"query": "liability limitation cap maximum damages indirect consequential million dollars aggregate total exposure indemnification",
"keys": ["contract_text_chunk"]
}
}
""",
)
test_cases: List[DirectiveTestCase] = Field(
default_factory=lambda: [
DirectiveTestCase(
name="clinical_trial_adverse_events_extraction",
description="Should transform clinical trial safety analysis into chunking pipeline with embedding-based topk for thematic content",
input_config={
"name": "extract_clinical_trial_safety",
"type": "map",
"prompt": """Analyze this clinical trial protocol and safety report to extract comprehensive safety information:
Protocol Number: {{ input.protocol_id }}
Study Phase: {{ input.study_phase }}
Document: {{ input.trial_document }}
Extract and analyze:
1. ALL adverse events (AEs) with:
- Event description and medical terminology (MedDRA preferred terms)
- Severity grade (1-5 per CTCAE v5.0)
- Relationship to study drug (definitely, probably, possibly, unlikely, not related)
- Onset timing relative to treatment start
- Resolution status and duration
- Actions taken (dose reduced, interrupted, discontinued)
2. Serious adverse events (SAEs) with additional details:
- Hospitalization requirements
- Life-threatening classification
- Death outcomes with causality assessment
- Expedited reporting timeline compliance
3. Laboratory abnormalities:
- Clinically significant lab value shifts
- Grade 3/4 laboratory toxicities
- Liver function test elevations (ALT, AST, bilirubin)
- Renal function changes (creatinine, eGFR)
- Hematologic abnormalities
4. Dose-limiting toxicities (DLTs) and maximum tolerated dose (MTD) determination
5. Safety run-in period results if applicable
6. Data safety monitoring board (DSMB) recommendations and protocol modifications
Ensure all safety data is captured with appropriate medical coding and regulatory compliance.""",
"output": {
"schema": {
"adverse_events": "list[dict]",
"serious_adverse_events": "list[dict]",
"lab_abnormalities": "list[dict]",
"dose_limiting_toxicities": "list[dict]",
"dsmb_recommendations": "list[str]",
"safety_summary": "dict",
}
},
},
target_ops=["extract_clinical_trial_safety"],
expected_behavior="Should create chunking pipeline with topk using embedding search to find sections discussing adverse events, safety data, laboratory results, and DSMB recommendations. Chunks should be 5-8k tokens with k=10-15 to capture all safety-related sections",
should_pass=True,
),
DirectiveTestCase(
name="sec_filing_risk_factors_extraction",
description="Should transform SEC filing analysis into chunking pipeline with FTS-based topk for specific regulatory terms",
input_config={
"name": "extract_sec_risk_disclosures",
"type": "map",
"prompt": """Extract and analyze all risk factor disclosures from this SEC 10-K filing:
Company: {{ input.company_ticker }}
Filing Period: {{ input.filing_period }}
Document: {{ input.form_10k }}
Industry: {{ input.industry_classification }}
Identify and categorize:
1. BUSINESS AND OPERATIONAL RISKS:
- Supply chain vulnerabilities and dependencies
- Key customer concentration (customers >10% revenue)
- Competition and market share threats
- Product obsolescence and innovation risks
- Manufacturing and quality control risks
- Intellectual property disputes and patent expirations
2. FINANCIAL AND MARKET RISKS:
- Liquidity and cash flow concerns
- Debt covenants and refinancing risks
- Foreign exchange exposure by currency
- Interest rate sensitivity analysis
- Credit risk and counterparty exposure
- Goodwill and intangible asset impairment risks
3. REGULATORY AND COMPLIANCE RISKS:
- SEC investigation disclosures
- FDA/regulatory approval dependencies
- Environmental liabilities and remediation costs
- Tax disputes and uncertain tax positions
- FCPA and anti-corruption compliance
- Data privacy (GDPR, CCPA) obligations
4. CYBERSECURITY AND TECHNOLOGY RISKS:
- Data breach history and potential impacts
- IT system dependencies and modernization needs
- Third-party technology provider risks
- Business continuity and disaster recovery
5. LITIGATION AND LEGAL RISKS:
- Material pending litigation with potential damages
- Class action lawsuit exposure
- Warranty and product liability claims
- Employment and labor disputes
6. ESG AND REPUTATIONAL RISKS:
- Climate change physical and transition risks
- Social license to operate concerns
- Executive succession planning
- Related party transaction risks
For each risk, extract:
- Risk description and specific company exposure
- Quantitative impact estimates if disclosed
- Mitigation strategies mentioned
- Changes from prior year disclosure
- Forward-looking statements and warnings""",
"output": {
"schema": {
"business_operational_risks": "list[dict]",
"financial_market_risks": "list[dict]",
"regulatory_compliance_risks": "list[dict]",
"cybersecurity_technology_risks": "list[dict]",
"litigation_legal_risks": "list[dict]",
"esg_reputational_risks": "list[dict]",
"risk_factor_changes": "list[dict]",
"material_risk_summary": "dict",
}
},
},
target_ops=["extract_sec_risk_disclosures"],
expected_behavior="Should create chunking pipeline with topk using FTS to search for specific risk-related keywords and sections (Item 1A, risk factors, legal proceedings, etc.). Chunks should be 6-10k tokens with k=15-20 to ensure comprehensive risk coverage",
should_pass=True,
),
DirectiveTestCase(
name="insurance_claim_analysis_with_dynamic_query",
description="Should transform insurance claim analysis with Jinja template query based on claim type",
input_config={
"name": "analyze_insurance_claim",
"type": "map",
"prompt": """Analyze this insurance claim file for coverage determination and fraud indicators:
Claim Number: {{ input.claim_id }}
Policy Number: {{ input.policy_number }}
Claim Type: {{ input.claim_type }}
Claimed Amount: {{ input.claimed_amount }}
Policy Documents: {{ input.policy_documents }}
Claim Documents: {{ input.claim_submission }}
Prior Claims History: {{ input.claims_history }}
Perform comprehensive analysis:
1. COVERAGE DETERMINATION:
- Verify incident date falls within policy period
- Check specific peril coverage for {{ input.claim_type }} claims
- Identify applicable policy limits and sublimits
- Calculate deductibles and co-insurance
- Review exclusions that may apply
- Assess pre-existing condition clauses (if medical)
- Verify additional living expense limits (if property)
2. CLAIM VALIDATION:
- Cross-reference damage description with photos/evidence
- Verify repair estimates against market rates
- Validate medical treatment necessity and coding
- Check for duplicate submissions or double-dipping
- Verify loss circumstances match policy terms
3. FRAUD INDICATORS ASSESSMENT:
- Pattern analysis against known fraud schemes
- Inconsistencies in statements or documentation
- Suspicious timing (policy inception, premium issues)
- Inflated valuations or treatment costs
- Missing or altered documentation
- Prior suspicious claims pattern
4. THIRD-PARTY LIABILITY:
- Subrogation opportunities
- Other insurance coverage available
- Responsible party identification
- Coordination of benefits requirements
5. REGULATORY COMPLIANCE:
- State-specific claim handling requirements
- Unfair claim settlement practices act compliance
- Required notices and timelines
- Bad faith claim indicators
6. SETTLEMENT RECOMMENDATION:
- Covered amount calculation
- Recommended settlement range
- Payment breakdown by category
- Reserve recommendations
- Special investigation unit (SIU) referral if warranted""",
"output": {
"schema": {
"coverage_analysis": "dict",
"claim_validation": "dict",
"fraud_indicators": "list[dict]",
"third_party_liability": "dict",
"regulatory_compliance": "dict",
"settlement_recommendation": "dict",
"siu_referral": "bool",
"reserve_amount": "float",
}
},
},
target_ops=["analyze_insurance_claim"],
expected_behavior="Should create chunking pipeline with topk using dynamic Jinja query that incorporates claim_type to search for relevant policy sections and prior claims. Query should adapt based on whether it's property, auto, medical, or liability claim. Chunks should be 5-7k tokens with k=12-18",
should_pass=True,
),
]
)
def __eq__(self, other):
return isinstance(other, DocumentChunkingTopKDirective)
def __hash__(self):
return hash("DocumentChunkingTopKDirective")
def to_string_for_instantiate(self, original_op: Dict) -> str:
"""
Generate a prompt for an agent to instantiate this directive.
Args:
original_op (Dict): The original operation.
Returns:
str: The agent prompt for instantiating the directive.
"""
op_type = original_op.get("type", "map")
return (
f"You are an expert at transforming document processing operations into chunking pipelines with intelligent topk-based retrieval.\n\n"
f"Original Operation:\n"
f"{str(original_op)}\n\n"
f"Directive: {self.name}\n"
f"Your task is to instantiate this directive by creating a configuration that transforms the original {op_type.capitalize()} operation "
f"into a Split -> TopK -> Reduce pipeline for processing very long documents with intelligent chunk selection.\n"
f"{'For Filter operations, a final code_filter step will be automatically added to return boolean results.' if op_type == 'filter' else ''}\n\n"
f"Key requirements:\n"
f"1. chunk_size: Choose an appropriate token count (typically 10000-15000) for cost-effective processing of long documents\n"
f"2. split_key: Identify the document field to split from the original operation's prompt (the longest text field)\n"
f"3. reduce_prompt: Use the EXACT SAME prompt as the original, with ONE change:\n"
f" - Where the original references '{{{{ input.<split_key> }}}}', replace it with:\n"
f" 'the top {{{{ inputs|length }}}} most relevant chunks from the document (ordered by relevance):\\n{{% for input in inputs|sort(attribute='_<topk_name>_rank') %}}\\nChunk (Rank {{{{ input._<topk_name>_rank }}}}, Score {{{{ input._<topk_name>_score }}}}):\\n{{{{ input.<split_key>_chunk }}}}\\n{{% endfor %}}'\n"
f" - Keep EVERYTHING else identical - same instructions, same output requirements\n"
f" - For other context fields (non-document fields), use {{{{ inputs[0].field_name }}}} instead of {{{{ input.field_name }}}}\n"
f"4. topk_config: REQUIRED - Configure intelligent chunk selection:\n"
f" - method: Choose 'embedding' for semantic similarity or 'fts' for keyword matching\n"
f" * Use 'embedding' when: looking for conceptual comparisons, themes, or abstract relationships\n"
f" * Use 'fts' when: searching for specific terms, legal clauses, technical codes, or exact phrases\n"
f" - k: Number of chunks to retrieve (typically 5-10 for comprehensive coverage)\n"
f" * For complex tasks needing most of the document as context: k=10\n"
f" * For targeted extraction from specific sections: k=5\n"
f" - query: Craft carefully to find chunks relevant to the {'filter decision' if op_type == 'filter' else 'extraction task'}\n"
f" * For embedding: use terms related to the comparison/decision criteria\n"
f" * For fts: use specific keywords that appear in relevant sections\n"
f" * Can use Jinja: '{{{{ input.competitor_name }}}} comparison versus {{{{ input.our_product }}}}'\n"
f" - keys: Always use the chunk key, typically ['<split_key>_chunk']\n"
f" - embedding_model: (optional, only for embedding method) defaults to 'text-embedding-3-small'\n"
f"The topk query should be carefully crafted to find the most relevant chunks.\n"
f"The reduce_prompt must process chunks directly and {'output a boolean decision' if op_type == 'filter' else 'preserve the original output schema'}.\n\n"
f"Example:\n"
f"{self.example}\n\n"
f"Please output only the DocumentChunkingTopKInstantiateSchema object as JSON."
)
def llm_instantiate(
self,
operators,
input_file_path,
original_op: Dict,
agent_llm: str,
message_history: list = [],
):
"""
Use LLM to instantiate this directive by creating chunking configuration with topk.
Args:
            operators (List[Dict]): All operators in the pipeline.
            input_file_path (str): Path to the input file, used to validate the split key.
            original_op (Dict): The original operation.
            agent_llm (str): The LLM model to use.
            message_history (List, optional): Conversation history for context.
        Returns:
            tuple: (DocumentChunkingTopKInstantiateSchema, updated message history, LLM call cost).
"""
message_history.extend(
[
{
"role": "user",
"content": self.to_string_for_instantiate(original_op),
},
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
response_format=DocumentChunkingTopKInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = DocumentChunkingTopKInstantiateSchema(**parsed_res)
schema.validate_split_key_exists_in_input(input_file_path)
message_history.append(
{"role": "assistant", "content": resp.choices[0].message.content}
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(
self,
global_default_model: str,
ops_list: List[Dict],
target_op: str,
rewrite: DocumentChunkingTopKInstantiateSchema,
) -> List[Dict]:
"""
Apply the directive to the pipeline config by replacing the target operation
with a split -> topk -> reduce sequence.
"""
# Create a copy of the pipeline config
new_ops_list = deepcopy(ops_list)
target_op_config = [op for op in new_ops_list if op["name"] == target_op][0]
original_model = target_op_config.get("model", global_default_model)
# Find position of the target op to replace
pos_to_replace = [
i for i, op in enumerate(ops_list) if op["name"] == target_op
][0]
original_op = ops_list[pos_to_replace]
# Create operation names based on the original operation name
split_name = f"split_{target_op}"
topk_name = f"topk_{target_op}_chunks"
reduce_name = f"reduce_{target_op}"
# Create the split operation
split_op = {
"name": split_name,
"type": "split",
"split_key": rewrite.split_key,
"method": "token_count",
"method_kwargs": {
"num_tokens": rewrite.chunk_size,
"model": original_model,
},
}
# Create the topk operation
topk_op = {
"name": topk_name,
"type": "topk",
"method": rewrite.topk_config.method,
"k": rewrite.topk_config.k,
"keys": rewrite.topk_config.keys,
"query": rewrite.topk_config.query,
"stratify_key": [f"{split_name}_id"],
}
# Add stratify_key if specified
if rewrite.topk_config.stratify_key:
topk_op["stratify_key"] = topk_op["stratify_key"] + [
rewrite.topk_config.stratify_key
]
# Add embedding_model for embedding method
if (
rewrite.topk_config.method == "embedding"
and rewrite.topk_config.embedding_model
):
topk_op["embedding_model"] = rewrite.topk_config.embedding_model
# Check if original operation is a filter
is_filter = original_op.get("type") == "filter"
# Create the reduce operation that directly processes the chunks
if is_filter:
# For filter operations, reduce should output a boolean field
reduce_output = deepcopy(original_op["output"])
# Ensure we have a boolean field in the output schema
if "schema" in reduce_output:
# Get the first boolean field name from the schema
bool_field = None
for field_name, field_type in reduce_output["schema"].items():
if "bool" in field_type.lower():
bool_field = field_name
break
if not bool_field:
# If no boolean field found, create one
bool_field = "filter_result"
reduce_output["schema"] = {bool_field: "bool"}
else:
reduce_output = deepcopy(original_op["output"])
reduce_op = {
"name": reduce_name,
"type": "reduce",
"reduce_key": f"{split_name}_id",
"prompt": rewrite.reduce_prompt,
"model": original_model,
"litellm_completion_kwargs": {"temperature": 0},
"output": reduce_output,
"associative": False, # Order matters for chunks
"pass_through": True,
}
# Construct operation sequence
ops_sequence = [split_op, topk_op, reduce_op]
# If original was a filter, add a code_filter operation
if is_filter:
# Find the boolean field name in the output schema
bool_field = None
if "schema" in reduce_output:
for field_name, field_type in reduce_output["schema"].items():
if "bool" in field_type.lower():
bool_field = field_name
break
if bool_field:
code_filter_op = {
"name": f"code_filter_{target_op}",
"type": "code_filter",
"code": f"def transform(input_doc):\n return input_doc.get('{bool_field}', False)",
}
ops_sequence.append(code_filter_op)
# Replace the target operation with the new sequence
new_ops_list[pos_to_replace : pos_to_replace + 1] = ops_sequence
return new_ops_list
def instantiate(
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
global_default_model: str = None,
**kwargs,
) -> tuple:
"""
Instantiate the directive for a list of operators.
"""
# Assert that there is only one target op
assert (
len(target_ops) == 1
), "There must be exactly one target op to instantiate this document chunking topk directive"
target_op_config = [op for op in operators if op["name"] == target_ops[0]][0]
input_file_path = kwargs.get("input_file_path", None)
# Validate that the target operation is a map or filter operation
if target_op_config.get("type") not in ["map", "filter"]:
raise ValueError(
f"Document chunking topk directive can only be applied to map or filter operations, but target operation '{target_ops[0]}' is of type '{target_op_config.get('type')}'"
)
# Instantiate the directive
rewrite, message_history, call_cost = self.llm_instantiate(
operators, input_file_path, target_op_config, agent_llm, message_history
)
# Apply the rewrite to the operators
return (
self.apply(global_default_model, operators, target_ops[0], rewrite),
message_history,
call_cost,
)
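For comparison, a hedged sketch of what apply() above emits when the target is a filter op, here hypothetically named filter_reviews over input.review_text (all names, prompts, k, and models are placeholders; the model and litellm_completion_kwargs fields are omitted for brevity):

# Hedged sketch of the split -> topk -> reduce -> code_filter shape for a
# filter target; every concrete value below is a placeholder.
rewritten_ops = [
    {
        "name": "split_filter_reviews",
        "type": "split",
        "split_key": "review_text",
        "method": "token_count",
        "method_kwargs": {"num_tokens": 10000, "model": "gpt-4o-mini"},
    },
    {
        "name": "topk_filter_reviews_chunks",
        "type": "topk",
        "method": "embedding",
        "k": 5,
        "keys": ["review_text_chunk"],
        "query": "<agent-generated retrieval query>",
        "stratify_key": ["split_filter_reviews_id"],
        "embedding_model": "text-embedding-3-small",
    },
    {
        "name": "reduce_filter_reviews",
        "type": "reduce",
        "reduce_key": "split_filter_reviews_id",
        "prompt": "<agent-generated reduce_prompt over the retrieved chunks>",
        "output": {"schema": {"matches_criteria": "bool"}},
        "associative": False,
        "pass_through": True,
    },
    {
        "name": "code_filter_filter_reviews",
        "type": "code_filter",
        "code": "def transform(input_doc):\n    return input_doc.get('matches_criteria', False)",
    },
]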

View File

@ -0,0 +1,254 @@
import json
from copy import deepcopy
from typing import Dict, List, Type
from litellm import completion
from pydantic import BaseModel, Field
from docetl.reasoning_optimizer.instantiate_schemas import (
DocCompressionInstantiateSchema,
)
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
class DocCompressionDirective(Directive):
name: str = Field(
default="doc_compression", description="The name of the directive"
)
formal_description: str = Field(default="Op => Extract -> Op")
nl_description: str = Field(
default="Reduces LLM processing costs by using an Extract operator to intelligently compress documents before expensive downstream operations, removing irrelevant content that could distract the LLM"
)
when_to_use: str = Field(
default="When documents contain irrelevant content and you want to reduce token costs for downstream LLM operations while improving accuracy by having the LLM focus on only the essential content"
)
instantiate_schema_type: Type[BaseModel] = Field(
default=DocCompressionInstantiateSchema
)
example: str = Field(
default="""
Target Operations:
- name: analyze_regulatory_impact
type: map
prompt: |
Analyze the potential regulatory impact described in: {{ input.legal_document }}
Consider stakeholder groups, compliance burdens, and implementation feasibility.
output:
schema:
stakeholder_impacts: "list[str]"
compliance_changes: "string"
- name: extract_key_dates
type: map
prompt: |
Extract important deadlines and dates from: {{ input.legal_document }}
output:
schema:
deadlines: "list[str]"
Example InstantiateSchema (what the agent should output):
DocCompressionConfig(
name="extract_regulatory_content",
document_key="legal_document",
prompt="Extract the minimal content necessary spanning: sections defining new regulatory requirements, stakeholder obligations, compliance deadlines, implementation timelines, and enforcement mechanisms. Focus only on substantive regulatory changes and specific dates, not background or procedural text.",
model="gpt-4o-mini"
)
"""
)
test_cases: List[DirectiveTestCase] = Field(
default_factory=lambda: [
DirectiveTestCase(
name="single_target_compression",
description="Should insert Extract operation before single target operation",
input_config={
"name": "analyze_document",
"type": "map",
"prompt": "Analyze this document: {{ input.document }}",
"output": {"schema": {"analysis": "string"}},
},
target_ops=["analyze_document"],
expected_behavior="Should add an Extract operation that compresses the document field before the analysis. The extract operation document_keys should be 'document' only.",
should_pass=True,
),
DirectiveTestCase(
name="multiple_target_compression",
description="Should insert Extract operation before first target operation and consider all targets",
input_config=[
{
"name": "extract_findings",
"type": "map",
"prompt": "Extract key findings from: {{ input.report }}",
"output": {"schema": {"findings": "list[str]"}},
},
{
"name": "analyze_impact",
"type": "map",
"prompt": "Analyze the business impact in: {{ input.report }}",
"output": {"schema": {"impact": "string"}},
},
],
target_ops=["extract_findings", "analyze_impact"],
expected_behavior="Should add Extract operation before extract_findings that considers content needed for both operations. The extract operation document_keys should be 'report' only.",
should_pass=True,
),
]
)
def __eq__(self, other):
return isinstance(other, DocCompressionDirective)
def __hash__(self):
return hash("DocCompressionDirective")
def to_string_for_instantiate(self, target_ops_configs: List[Dict]) -> str:
"""
Generate a prompt that asks the agent to output the instantiate schema.
This prompt explains to the LLM what configuration it needs to generate.
"""
ops_str = "\n".join(
[
f"Operation {i+1}:\n{str(op)}\n"
for i, op in enumerate(target_ops_configs)
]
)
return (
f"You are an expert at document analysis and optimization.\n\n"
f"Target Operations:\n"
f"{ops_str}\n\n"
f"Directive: {self.name}\n"
f"Your task is to instantiate this directive by generating a DocCompressionConfig "
f"that specifies how to compress the input document before processing.\n\n"
f"The directive will insert an Extract operation that:\n"
f"1. Takes a long document field from the input\n"
f"2. Extracts only the MINIMAL content relevant to ALL the target operations\n"
f"3. Replaces the original document field with the compressed content\n"
f"4. Reduces token usage and improves focus for the downstream operations\n\n"
f"The agent must output the configuration specifying:\n"
f"- name: A descriptive name for the Extract operation\n"
f"- document_key: Which document field contains the long content to compress\n"
f"- prompt: Plain text instructions for what MINIMAL content to extract (NOT a Jinja template)\n"
f"- model: Which model to use for extraction (typically a cheaper model like gpt-4o-mini)\n\n"
f"IMPORTANT: The extraction prompt should focus on extracting the minimal content necessary "
f"for ALL target operations. Analyze each operation's prompt to identify the "
f"specific information types needed across all operations, then design an extraction prompt "
f"that gets just those essential pieces while removing all irrelevant material.\n\n"
f"Example:\n"
f"{self.example}\n\n"
f"Please output only the InstantiateSchema (DocCompressionConfig object) "
f"that specifies how to apply this directive to the target operations."
)
def llm_instantiate(
self,
target_ops_configs: List[Dict],
input_file_path: str,
agent_llm: str,
message_history: list = [],
):
"""
Call the LLM to generate the instantiate schema.
The LLM will output structured data matching DocCompressionInstantiateSchema.
"""
message_history.extend(
[
{
"role": "user",
"content": self.to_string_for_instantiate(target_ops_configs),
},
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
response_format=DocCompressionInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = DocCompressionInstantiateSchema(**parsed_res)
schema.validate_document_keys_exists_in_input(input_file_path)
message_history.append(
{"role": "assistant", "content": resp.choices[0].message.content}
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(
self,
global_default_model: str,
ops_list: List[Dict],
target_ops: List[str],
rewrite: DocCompressionInstantiateSchema,
) -> List[Dict]:
"""
Apply the directive using the instantiate schema configuration.
Inserts an Extract operation before the first target operation.
"""
new_ops_list = deepcopy(ops_list)
# Find the position of the first target operation
first_target_pos = min(
[i for i, op in enumerate(ops_list) if op["name"] in target_ops]
)
extract_op = {
"name": rewrite.name,
"type": "extract",
"prompt": rewrite.prompt,
"document_keys": [rewrite.document_key],
"litellm_completion_kwargs": {"temperature": 0},
"model": rewrite.model,
}
# Insert the Extract operation before the first target operation
new_ops_list.insert(first_target_pos, extract_op)
return new_ops_list
def instantiate(
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
global_default_model: str = None,
**kwargs,
):
"""
Main method that orchestrates directive instantiation:
1. Get agent to generate instantiate schema for all target operations
2. Apply the transformation using that schema
"""
assert len(target_ops) >= 1, "This directive requires at least one target op"
input_file_path = kwargs.get("input_file_path", None)
# Get configurations for all target operations
target_ops_configs = [op for op in operators if op["name"] in target_ops]
# Step 1: Agent generates the instantiate schema considering all target ops
rewrite, message_history, call_cost = self.llm_instantiate(
target_ops_configs, input_file_path, agent_llm, message_history
)
# Step 2: Apply transformation using the schema
return (
self.apply(global_default_model, operators, target_ops, rewrite),
message_history,
call_cost,
)
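A hedged before/after sketch of the insertion performed by apply() above, using a hypothetical map op extract_findings over input.report (names and prompts are placeholders):

# Hedged sketch: the extract op is inserted before the first target op;
# every concrete name and prompt below is a placeholder.
before = [
    {
        "name": "extract_findings",
        "type": "map",
        "prompt": "Extract key findings from: {{ input.report }}",
        "output": {"schema": {"findings": "list[str]"}},
    }
]
after = [
    {
        "name": "compress_report",
        "type": "extract",
        "prompt": "Extract only the passages needed to identify key findings.",
        "document_keys": ["report"],
        "litellm_completion_kwargs": {"temperature": 0},
        "model": "gpt-4o-mini",
    },
    {
        "name": "extract_findings",
        "type": "map",
        "prompt": "Extract key findings from: {{ input.report }}",
        "output": {"schema": {"findings": "list[str]"}},
    },
]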

View File

@ -0,0 +1,345 @@
import json
from copy import deepcopy
from typing import Dict, List, Type
from litellm import completion
from pydantic import BaseModel, Field
from docetl.reasoning_optimizer.instantiate_schemas import (
DocSummarizationInstantiateSchema,
)
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
class DocSummarizationDirective(Directive):
name: str = Field(
default="doc_summarization", description="The name of the directive"
)
formal_description: str = Field(default="Op => Map -> Op")
nl_description: str = Field(
default=(
"Adds a Map summarization operator at the very beginning of the pipeline to shorten the document before any downstream operations. "
"This reduces the number of tokens processed in later steps, saving cost and improving efficiency. "
"The summary is constructed to include all information required by any downstream operator. "
"Target operations should be all operators that reference the document key being summarized "
"(e.g., all ops using {{ input.transcript }})."
)
)
when_to_use: str = Field(
default=(
"Use when documents are too long or detailed for the downstream pipeline. "
"Summarization should preserve essential information and make subsequent tasks more efficient. "
"Target ops should include all operators that use the document key being summarized, and the summary model should be cheap."
)
)
instantiate_schema_type: Type[BaseModel] = Field(
default=DocSummarizationInstantiateSchema
)
example: str = Field(
default="""
"Original Pipeline - Complex medical reasoning across multiple operators:\n"
"[\n"
" {\n"
" name: 'assess_drug_interactions',\n"
" type: 'map',\n"
" prompt: 'Analyze potential drug interactions from this consultation: {{ input.transcript }}. Consider contraindications.',\n"
" output: { schema: { interaction_risks: 'list[dict]' } }\n"
" },\n"
" {\n"
" name: 'predict_side_effects',\n"
" type: 'map', \n"
" prompt: 'Based on the transcript, predict likely side effects for this patient: {{ input.transcript }}. Consider patient demographics and medical history.',\n"
" output: { schema: { predicted_effects: 'list[dict]' } }\n"
" },\n"
" {\n"
" name: 'generate_monitoring_plan',\n"
" type: 'map',\n"
" prompt: 'Create a monitoring plan based on this consultation: {{ input.transcript }}. Focus on symptoms to watch for.',\n"
" output: { schema: { monitoring_plan: 'string' } }\n"
" }\n"
"]\n"
"\n"
"Problem: 50-page transcript with scheduling, insurance, small talk distracts from medical reasoning.\n"
"\n"
"Example InstantiateSchema:\n"
"[\n"
" DocSummarizationConfig(\n"
" name='extract_medical_essentials',\n"
" document_key='transcript',\n"
" prompt='Extract a summary of medical information from this consultation transcript for drug interaction analysis, side effect prediction, and monitoring plan creation. Include: all medications with dosages, patient complaints/symptoms, medical history, current conditions, patient demographics (age, weight), allergies, and any contraindications mentioned. Exclude scheduling, insurance, and casual conversation: {{ input.transcript }}',\n"
" model='gpt-4o-mini'\n"
" )\n"
"]\n"
"\n"
"Result: All three reasoning operations work on focused medical facts, dramatically improving accuracy."
""",
)
test_cases: List[DirectiveTestCase] = Field(
default_factory=lambda: [
DirectiveTestCase(
name="medical_pipeline_analysis",
description="Should add summarization for multi-step medical reasoning pipeline",
input_config=[
{
"name": "assess_drug_interactions",
"type": "map",
"prompt": "Analyze potential drug interactions from this consultation: {{ input.transcript }}. Consider contraindications and patient allergies.",
"output": {"schema": {"interaction_risks": "list[str]"}},
},
{
"name": "predict_side_effects",
"type": "map",
"prompt": "Predict likely side effects for this patient based on: {{ input.transcript }}. Consider age, weight, and medical history.",
"output": {"schema": {"predicted_effects": "list[str]"}},
},
{
"name": "create_monitoring_plan",
"type": "map",
"prompt": "Create patient monitoring plan from: {{ input.transcript }}. Focus on symptoms to watch and lab work needed.",
"output": {"schema": {"monitoring_plan": "string"}},
},
],
target_ops=[
"assess_drug_interactions",
"predict_side_effects",
"create_monitoring_plan",
],
expected_behavior="Should add summarization that extracts comprehensive medical information (medications, dosages, allergies, demographics, symptoms, history) needed for all three downstream reasoning tasks",
should_pass=True,
),
DirectiveTestCase(
name="legal_contract_pipeline",
description="Should add summarization for multi-step legal analysis pipeline",
input_config=[
{
"name": "identify_liability_risks",
"type": "map",
"prompt": "Identify liability and indemnification risks in: {{ input.contract_document }}. Focus on limitation of liability clauses.",
"output": {"schema": {"liability_risks": "list[str]"}},
},
{
"name": "analyze_termination_terms",
"type": "map",
"prompt": "Analyze termination and breach conditions in: {{ input.contract_document }}. Include notice requirements and penalties.",
"output": {"schema": {"termination_analysis": "string"}},
},
{
"name": "assess_compliance_requirements",
"type": "map",
"prompt": "Assess regulatory compliance obligations from: {{ input.contract_document }}. Include data protection and industry standards.",
"output": {"schema": {"compliance_requirements": "list[str]"}},
},
],
target_ops=[
"identify_liability_risks",
"analyze_termination_terms",
"assess_compliance_requirements",
],
expected_behavior="Should add summarization that extracts all legally relevant clauses, terms, obligations, penalties, and compliance requirements needed across the legal analysis pipeline",
should_pass=True,
),
DirectiveTestCase(
name="financial_analysis_pipeline",
description="Should add summarization for multi-step investment analysis pipeline",
input_config=[
{
"name": "evaluate_financial_health",
"type": "map",
"prompt": "Evaluate company financial health from: {{ input.annual_report }}. Calculate key ratios and assess profitability trends.",
"output": {
"schema": {
"financial_metrics": "str",
"health_score": "float",
}
},
},
{
"name": "assess_market_position",
"type": "map",
"prompt": "Assess competitive market position using: {{ input.annual_report }}. Analyze market share and competitive advantages.",
"output": {
"schema": {
"market_analysis": "string",
"competitive_strengths": "list[str]",
}
},
},
{
"name": "identify_growth_risks",
"type": "map",
"prompt": "Identify growth opportunities and risks from: {{ input.annual_report }}. Include regulatory and market risks.",
"output": {
"schema": {
"growth_opportunities": "list[str]",
"risk_factors": "list[str]",
}
},
},
],
target_ops=[
"evaluate_financial_health",
"assess_market_position",
"identify_growth_risks",
],
expected_behavior="Should add summarization that extracts financial data, market information, competitive landscape, and risk factors needed for comprehensive investment analysis",
should_pass=True,
),
]
)
def __eq__(self, other):
return isinstance(other, DocSummarizationDirective)
def __hash__(self):
return hash("DocSummarizationDirective")
def to_string_for_instantiate(
self, operators: List[Dict], target_ops: List[str]
) -> str:
"""
Generate a prompt for an agent to instantiate this directive.
Args:
operators: List of all operators in the pipeline
target_ops: List of target operation names that need the summarized content
Returns:
str: The agent prompt for instantiating the directive.
"""
return (
f"You are an expert at adding document summarization to data processing pipelines.\n\n"
f"Full Pipeline Context:\n"
f"{str(operators)}\n\n"
f"Target Operations (that will use summarized content): {target_ops}\n\n"
f"Directive: {self.name}\n"
f"Your task is to instantiate this directive by generating a DocSummarizationConfig that adds a Map summarization operator "
f"at the start of the pipeline.\n\n"
f"Analysis steps:\n"
f"1. Identify which input field contains long documents that could benefit from summarization\n"
f"2. Analyze ALL target operations' prompts to understand what information each needs\n"
f"3. Create a comprehensive summarization prompt that preserves ALL information needed by ANY target operation\n"
f"4. Ensure the summary contains sufficient detail for all downstream reasoning tasks\n\n"
f"The document_key should be the field name containing the long content to summarize.\n"
f"The prompt should instruct the LLM to extract and preserve ALL information types needed across ALL target operations.\n"
f"The output will replace the original document field with the summarized version.\n\n"
f"Example:\n"
f"{self.example}\n\n"
f"Please output only the InstantiateSchema (a DocSummarizationConfig object)."
)
def llm_instantiate(
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
):
"""
Use LLM to instantiate this directive by creating a summarization operation.
Args:
operators (List[Dict]): All operators in the pipeline.
target_ops (List[str]): Target operation names that need summarized content.
agent_llm (str): The LLM model to use.
message_history (List, optional): Conversation history for context.
Returns:
DocSummarizationInstantiateSchema: The structured output from the LLM.
"""
message_history.extend(
[
{
"role": "user",
"content": self.to_string_for_instantiate(operators, target_ops),
},
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
response_format=DocSummarizationInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = DocSummarizationInstantiateSchema(**parsed_res)
message_history.append(
{"role": "assistant", "content": resp.choices[0].message.content}
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(
self,
global_default_model: str,
ops_list: List[Dict],
target_op: str,
rewrite: DocSummarizationInstantiateSchema,
) -> List[Dict]:
"""
Apply the directive to the pipeline config by adding a summarization Map operator at the start.
"""
# Create a copy of the pipeline config
new_ops_list = deepcopy(ops_list)
# Create the summarization Map operator using the LLM-generated name
summarization_op = {
"name": rewrite.name,
"type": "map",
"prompt": rewrite.prompt,
"model": rewrite.model,
"litellm_completion_kwargs": {"temperature": 0},
"output": {"schema": {rewrite.document_key: "string"}},
}
# Insert the summarization operator at the beginning of the pipeline
new_ops_list.insert(0, summarization_op)
return new_ops_list
def instantiate(
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
global_default_model: str = None,
**kwargs,
):
"""
Instantiate the directive for a list of operators.
"""
# Validate that target ops exist in the pipeline
operator_names = {op["name"] for op in operators}
for target_op in target_ops:
if target_op not in operator_names:
raise ValueError(
f"Target operation '{target_op}' not found in pipeline"
)
# Instantiate the directive using full pipeline context
rewrite, message_history, call_cost = self.llm_instantiate(
operators, target_ops, agent_llm, message_history
)
# Apply the rewrite to the operators (use first target op for compatibility with apply method)
return (
self.apply(global_default_model, operators, target_ops[0], rewrite),
message_history,
call_cost,
)
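For readers skimming the diff, here is a minimal standalone sketch (not part of the change) of the rewrite that apply() performs above. Plain dicts stand in for the pipeline config and for the LLM-generated DocSummarizationInstantiateSchema; the operator name, prompt, and model are illustrative placeholders.

from copy import deepcopy

ops = [{
    "name": "evaluate_financial_health",
    "type": "map",
    "prompt": "Evaluate company financial health from: {{ input.annual_report }}",
    "output": {"schema": {"financial_metrics": "str", "health_score": "float"}},
}]

# Fields mirror what apply() reads off the schema: name, prompt, model, document_key.
rewrite = {
    "name": "summarize_annual_report",  # hypothetical op name
    "prompt": "Summarize {{ input.annual_report }}, preserving financial data and risks.",
    "model": "gpt-4o-mini",
    "document_key": "annual_report",
}

new_ops = deepcopy(ops)
new_ops.insert(0, {
    "name": rewrite["name"],
    "type": "map",
    "prompt": rewrite["prompt"],
    "model": rewrite["model"],
    "litellm_completion_kwargs": {"temperature": 0},
    "output": {"schema": {rewrite["document_key"]: "string"}},
})
print([op["name"] for op in new_ops])
# ['summarize_annual_report', 'evaluate_financial_health']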


@@ -0,0 +1,231 @@
import json
from copy import deepcopy
from typing import Dict, List, Type
from litellm import completion
from pydantic import BaseModel, Field
from docetl.reasoning_optimizer.instantiate_schemas import GleaningInstantiateSchema
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
class GleaningDirective(Directive):
name: str = Field(default="gleaning", description="The name of the directive")
formal_description: str = Field(default="Map => Map_m (with gleaning config)")
nl_description: str = Field(
default="""Adds a validation loop to Map: after each LLM generation, a separate "judge" LLM evaluates the output using a yes/no validation prompt. If the output fails, the original LLM refines its answer and repeats until the output passes or the max number of rounds is reached."""
)
when_to_use: str = Field(
default="When initial Map outputs may not meet quality criteria and must be checked or improved automatically (e.g., too short, missing required info)."
)
# Schema type the agent must produce when instantiating this directive
instantiate_schema_type: Type[BaseModel] = Field(default=GleaningInstantiateSchema)
example: str = Field(
default="""
Original Op (MapOpConfig):
- name: extract_insights
type: map
prompt: |
From the user log below, list 2-3 concise insights (1-2 words each) and 1-2 supporting actions per insight.
Return as a list of dictionaries with 'insight' and 'supporting_actions'.
Log: {{ input.log }}
output:
schema:
insights_summary: "string"
Example InstantiateSchema (what the agent should output):
{
"validation_prompt": "There should be at least 2 insights, and each insight should have at least 1 supporting action.",
"num_rounds": 2,
"model": "gpt-4o-mini"
}
""",
)
test_cases: List[DirectiveTestCase] = Field(
default_factory=lambda: [
DirectiveTestCase(
name="quality_validation_insights",
description="Should add validation for insight extraction quality",
input_config={
"name": "extract_insights",
"type": "map",
"prompt": "From the user log below, list 2-3 concise insights and supporting actions: {{ input.log }}",
"output": {"schema": {"insights_summary": "string"}},
},
target_ops=["extract_insights"],
expected_behavior="Should add gleaning config with validation prompt checking for minimum number of insights and supporting actions",
should_pass=True,
),
DirectiveTestCase(
name="contract_term_validation",
description="Should add validation for contract term extraction completeness",
input_config={
"name": "extract_contract_terms",
"type": "map",
"prompt": "Extract payment terms, deadlines, and penalties from: {{ input.contract }}",
"output": {"schema": {"terms": "list[str]"}},
},
target_ops=["extract_contract_terms"],
expected_behavior="Should add gleaning validation to ensure all required term types are extracted",
should_pass=True,
),
DirectiveTestCase(
name="financial_report_analysis",
description="Should add validation for financial analysis accuracy",
input_config={
"name": "analyze_financial_report",
"type": "map",
"prompt": "Extract revenue, expenses, profit margins, and key financial ratios from: {{ input.report }}",
"output": {"schema": {"financial_analysis": "string"}},
},
target_ops=["analyze_financial_report"],
expected_behavior="Should add gleaning validation to ensure all financial metrics are accurately extracted and calculated",
should_pass=True,
),
]
)
def __eq__(self, other):
return isinstance(other, GleaningDirective)
def __hash__(self):
return hash("GleaningDirective")
def to_string_for_instantiate(self, original_op: Dict) -> str:
"""
Generate a prompt for an agent to instantiate this directive.
Args:
original_op (Dict): The original operation configuration.
Returns:
str: The agent prompt for instantiating the directive.
"""
return (
f"You are an expert at adding validation and refinement loops to data processing operations.\n\n"
f"Original Operation:\n"
f"{str(original_op)}\n\n"
f"Directive: {self.name}\n"
f"Your task is to instantiate this directive by generating a configuration that adds validation loops to the original operation. "
f"The gleaning configuration should include a validation prompt that evaluates the output quality and provides feedback for improvement, "
f"along with the number of refinement rounds to attempt. In the prompt, you shuold not use any Jinja variables. \n"
f"Example:\n"
f"{self.example}\n\n"
f"Please output only the GleaningInstantiateSchema object that specifies how to validate and refine the output of the original operation."
)
def llm_instantiate(
self,
original_op: Dict,
agent_llm: str,
message_history: list = [],
):
"""
Use LLM to instantiate this directive by generating a gleaning (validation-loop) configuration for the original operation.
Args:
original_op (Dict): The original operation.
agent_llm (str): The LLM model to use.
message_history (List, optional): Conversation history for context.
Returns:
GleaningInstantiateSchema: The structured output from the LLM.
"""
message_history.extend(
[
{
"role": "user",
"content": self.to_string_for_instantiate(original_op),
},
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
response_format=GleaningInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = GleaningInstantiateSchema(**parsed_res)
GleaningInstantiateSchema.check_no_jinja_variables(
schema.validation_prompt
)
message_history.append(
{"role": "assistant", "content": resp.choices[0].message.content}
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(
self,
default_model: str,
ops_list: List[Dict],
target_op: str,
rewrite: GleaningInstantiateSchema,
) -> List[Dict]:
"""
Apply the directive to the pipeline config by adding gleaning configuration to the target operator.
"""
# Create a copy of the pipeline config
new_ops_list = deepcopy(ops_list)
# Find position of the target op to modify
pos_to_replace = [
i for i, op in enumerate(ops_list) if op["name"] == target_op
][0]
# Add gleaning configuration to the target operator
target_operator = new_ops_list[pos_to_replace]
target_operator["gleaning"] = {
"validation_prompt": rewrite.validation_prompt,
"num_rounds": rewrite.num_rounds,
"model": rewrite.model,
}
return new_ops_list
def instantiate(
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
global_default_model: str = None,
**kwargs,
):
"""
Instantiate the directive for a list of operators.
"""
# Assert that there is only one target op
assert (
len(target_ops) == 1
), "There must be exactly one target op to instantiate this chaining directive"
target_op_config = [op for op in operators if op["name"] == target_ops[0]][0]
# Instantiate the directive
rewrite, message_history, call_cost = self.llm_instantiate(
target_op_config, agent_llm, message_history
)
# Apply the rewrite to the operators
return (
self.apply(global_default_model, operators, target_ops[0], rewrite),
message_history,
call_cost,
)
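As an illustration of what the gleaning rewrite does to the target operator, a minimal sketch follows. Hand-written values stand in for the LLM-generated GleaningInstantiateSchema; the validation prompt, round count, and model are taken from the example above rather than produced by an agent.

from copy import deepcopy

ops = [{
    "name": "extract_insights",
    "type": "map",
    "prompt": "From the user log below, list 2-3 concise insights: {{ input.log }}",
    "output": {"schema": {"insights_summary": "string"}},
}]

new_ops = deepcopy(ops)
target = next(op for op in new_ops if op["name"] == "extract_insights")
# apply() attaches the gleaning config to the existing map; nothing else changes.
target["gleaning"] = {
    "validation_prompt": "There should be at least 2 insights, each with at least 1 supporting action.",
    "num_rounds": 2,
    "model": "gpt-4o-mini",
}
print(new_ops[0]["gleaning"]["num_rounds"])  # 2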


@@ -0,0 +1,322 @@
import json
from copy import deepcopy
from typing import Dict, List, Type
from litellm import completion
from pydantic import BaseModel, Field
from docetl.reasoning_optimizer.instantiate_schemas import (
HierarchicalReduceInstantiateSchema,
)
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
class HierarchicalReduceDirective(Directive):
name: str = Field(
default="hierarchical_reduce", description="The name of the directive"
)
formal_description: str = Field(
default="Reduce => (Map* ->) Reduce -> Reduce (optionally with Map before first Reduce)"
)
nl_description: str = Field(
default="Transform a reduce operation that aggregates large groups of documents by first aggregating at a finer granularity (reduce_key + additional_key), then rolling up to the desired level (reduce_key only). This hierarchical approach can capture nuances that might be lost in a single large-scale aggregation and allows for intermediate validation."
)
when_to_use: str = Field(
default="When a reduce operation processes many documents per group and it would be beneficial to first aggregate at a finer granularity before rolling up. Useful when there's a semantic hierarchy in the data (e.g., aggregate by state+city first, then by state only) or when you want to prevent information loss in large-scale aggregations. The target operator must be a reduce operator."
)
instantiate_schema_type: Type[BaseModel] = HierarchicalReduceInstantiateSchema
example: str = Field(
default=(
"Original Reduce Op:\n"
"- name: summarize_by_state\n"
" type: reduce\n"
" reduce_key: state\n"
" prompt: |\n"
" Summarize voting patterns from these social media posts:\n"
" {% for input in inputs %}\n"
" Post: {{ input.content }}\n"
" {% endfor %}\n"
" Return a summary of voting patterns.\n"
" output:\n"
" schema:\n"
" summary: string\n"
"\n"
"Example InstantiateSchema (with Map for synthetic key):\n"
"HierarchicalReduceInstantiateSchema(\n"
" map_config=MapOpConfig(\n"
" name='extract_city',\n"
" prompt='Extract the city mentioned in this post:\\n{{ input.content }} made in this state:\\n{{ input.state }}\\nReturn the city name or \"Unknown\" if not found.',\n"
" output_keys=['city'],\n"
" ),\n"
" additional_key='city',\n"
" reduce_1_name='summarize_by_state_city',\n"
" # First reduce: Process raw posts at city level\n"
" reduce_1_prompt='Goal: Summarize voting patterns from social media posts to understand public sentiment.\\n\\nFor this state and city, analyze these posts:\\n{% for input in inputs %}\\nPost: {{ input.content }}\\n{% endfor %}\\nReturn a summary of voting patterns and key themes.',\n"
" # Second reduce: Explicitly work with summaries from first reduce\n"
" reduce_2_prompt='Goal: Summarize voting patterns from social media posts to understand public sentiment.\\n\\nWe have already summarized voting patterns at the city level. Your task is to combine these city-level summaries into a comprehensive state-level analysis:\\n{% for input in inputs %}\\nCity: {{ input.city }}\\nCity-Level Summary: {{ input.summary }}\\n{% endfor %}\\nSynthesize these city summaries into a unified state-level summary of voting patterns.',\n"
")\n"
"\n"
"Example InstantiateSchema (using existing key):\n"
"HierarchicalReduceInstantiateSchema(\n"
" map_config=None, # No Map needed when using existing key\n"
" additional_key='county', # Assuming 'county' already exists in the data\n"
" reduce_1_name='summarize_by_state_county',\n"
" # First reduce: Process raw posts at county level\n"
" reduce_1_prompt='Goal: Summarize voting patterns from social media posts.\\n\\nAnalyze posts for this state and county:\\n{% for input in inputs %}\\nPost: {{ input.content }}\\n{% endfor %}\\nReturn voting pattern summary for this county.',\n"
" # Second reduce: Explicitly work with county summaries\n"
" reduce_2_prompt='Goal: Summarize voting patterns from social media posts.\\n\\nWe have already analyzed voting patterns at the county level. Your task is to synthesize these county-level summaries into a state-level overview:\\n{% for input in inputs %}\\nCounty: {{ input.county }}\\nCounty Analysis: {{ input.summary }}\\n{% endfor %}\\nCombine these county analyses into a comprehensive state voting pattern summary.',\n"
")"
),
)
test_cases: List[DirectiveTestCase] = Field(
default_factory=lambda: [
DirectiveTestCase(
name="voting_pattern_aggregation",
description="Should create hierarchical aggregation for voting patterns",
input_config={
"name": "analyze_voting",
"type": "reduce",
"reduce_key": "state",
"prompt": "Analyze voting sentiments from these posts:\n{% for input in inputs %}\nPost: {{ input.post }}\n{% endfor %}\nReturn voting sentiment analysis.",
"output": {"schema": {"analysis": "string"}},
},
target_ops=["analyze_voting"],
expected_behavior="Should create two reduce operations: first by state+additional_key, then by state only, optionally with a Map to create synthetic keys",
should_pass=True,
),
DirectiveTestCase(
name="sales_hierarchical_aggregation",
description="Should create hierarchical aggregation for sales data",
input_config={
"name": "aggregate_sales",
"type": "reduce",
"reduce_key": "region",
"prompt": "Aggregate sales data from these records:\n{% for input in inputs %}\nSales: {{ input.sales_data }}\n{% endfor %}\nReturn total sales metrics.",
"output": {"schema": {"metrics": "object"}},
},
target_ops=["aggregate_sales"],
expected_behavior="Should create hierarchical reduce pattern for sales aggregation",
should_pass=True,
),
]
)
def __eq__(self, other):
return isinstance(other, HierarchicalReduceDirective)
def __hash__(self):
return hash("HierarchicalReduceDirective")
def to_string_for_instantiate(self, original_op: Dict) -> str:
"""
Generate a prompt for an agent to instantiate this directive.
Args:
original_op (Dict): The original reduce operation.
Returns:
str: The agent prompt for instantiating the directive.
"""
return (
f"You are an expert at optimizing data processing operations using hierarchical aggregation patterns.\n\n"
f"Original Operation:\n"
f"{str(original_op)}\n"
f"Directive: {self.name}\n"
f"Your task is to instantiate this directive by creating a hierarchical reduce pattern that:\n"
f"1. First aggregates data at a finer granularity (original reduce_key + additional_key)\n"
f"2. Then rolls up these finer-grained aggregations to the desired level (original reduce_key only)\n\n"
f"Key Requirements:\n"
f"1. Identify or create an appropriate additional key for finer granularity:\n"
f" - Check if there's an existing key in the data that provides meaningful sub-grouping\n"
f" - If no suitable key exists, create a Map operation to synthesize one (e.g., extract city from text)\n"
f"2. Adapt the original reduce prompt for both aggregation levels:\n"
f" - reduce_1_prompt: Should aggregate at the finer granularity (both keys) from raw data\n"
f" - reduce_2_prompt: Should combine the outputs from reduce_1 to the target granularity\n"
f"3. IMPORTANT: Both reduce prompts should understand the overall goal from the original operation:\n"
f" - Each reduce operation runs independently without access to the original prompt\n"
f" - Both prompts should share a common context/prefix that explains the overall goal\n"
f" - The second reduce MUST explicitly acknowledge it's working with summaries/aggregations from the first reduce\n"
f" - Example: 'We have already summarized X at the Y level. Your task is to combine these Y-level summaries...'\n"
f"4. Both reduce operations should produce the same output schema as the original\n"
f"5. The reduce_2_prompt must reference the outputs from reduce_1, not the original documents\n\n"
f"This hierarchical approach is especially useful when:\n"
f"- There are many documents per group in the original reduce\n"
f"- There's a natural hierarchy in the data (geographic, temporal, categorical)\n"
f"- You want to prevent information loss in large-scale aggregations\n\n"
f"Example:\n"
f"{self.example}\n\n"
f"Please output the HierarchicalReduceInstantiateSchema with the hierarchical aggregation details."
)
def llm_instantiate(
self, original_op: Dict, agent_llm: str, message_history: list = []
):
"""
Use LLM to instantiate this directive by creating a hierarchical reduce pattern.
Args:
original_op (Dict): The original reduce operation.
agent_llm (str): The LLM model to use.
message_history (List, optional): Conversation history for context.
Returns:
HierarchicalReduceInstantiateSchema: The structured output from the LLM.
"""
message_history.extend(
[
{
"role": "system",
"content": "You are a helpful AI assistant for document processing pipelines specializing in hierarchical aggregation patterns.",
},
{
"role": "user",
"content": self.to_string_for_instantiate(original_op),
},
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
response_format=HierarchicalReduceInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = HierarchicalReduceInstantiateSchema(**parsed_res)
# Validate that if map_config is provided, additional_key should match one of the output keys
if schema.map_config:
if len(schema.map_config.output_keys) != 1:
raise ValueError(
"Map config must have exactly one output key for hierarchical reduce"
)
if schema.additional_key != schema.map_config.output_keys[0]:
raise ValueError(
f"When creating a synthetic key with Map, additional_key must match the map output key '{schema.map_config.output_keys[0]}'"
)
message_history.append(
{"role": "assistant", "content": resp.choices[0].message.content}
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(
self,
global_default_model,
ops_list: List[Dict],
target_op: str,
rewrite: HierarchicalReduceInstantiateSchema,
) -> List[Dict]:
"""
Apply the directive to the pipeline config.
"""
# Create a copy of the pipeline config
new_ops_list = deepcopy(ops_list)
# Find position of the target reduce op to modify
pos_to_modify = None
orig_op = None
for i, op in enumerate(ops_list):
if op["name"] == target_op:
pos_to_modify = i
orig_op = op
break
if pos_to_modify is None:
raise ValueError(
f"Target operation '{target_op}' not found in operations list"
)
# Determine the model to use
default_model = orig_op.get("model", global_default_model)
operations_to_insert = []
# Create the optional Map operation if specified
if rewrite.map_config:
new_map_op = {
"name": rewrite.map_config.name,
"type": "map",
"prompt": rewrite.map_config.prompt,
"model": default_model,
"litellm_completion_kwargs": {"temperature": 0},
"output": {"schema": {rewrite.map_config.output_keys[0]: "string"}},
}
operations_to_insert.append(new_map_op)
# Create the first reduce operation (finer granularity)
first_reduce_op = deepcopy(orig_op)
first_reduce_op["name"] = rewrite.reduce_1_name
first_reduce_op["reduce_key"] = (
[orig_op["reduce_key"], rewrite.additional_key]
if isinstance(orig_op["reduce_key"], str)
else orig_op["reduce_key"] + [rewrite.additional_key]
)
first_reduce_op["prompt"] = rewrite.reduce_1_prompt
first_reduce_op["model"] = default_model
operations_to_insert.append(first_reduce_op)
# Create the second reduce operation (target granularity)
second_reduce_op = deepcopy(orig_op)
second_reduce_op["prompt"] = rewrite.reduce_2_prompt
second_reduce_op["model"] = default_model
operations_to_insert.append(second_reduce_op)
# Replace the original operation with the new operations
new_ops_list[pos_to_modify : pos_to_modify + 1] = operations_to_insert
return new_ops_list
def instantiate(
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
optimize_goal="acc",
global_default_model: str = None,
**kwargs,
):
"""
Instantiate the directive for a list of operators.
"""
# Assert that there is only one target op
assert (
len(target_ops) == 1
), "There must be exactly one target op to instantiate this hierarchical reduce directive"
target_op_config = [op for op in operators if op["name"] == target_ops[0]][0]
# Ensure it's a reduce operation
if target_op_config.get("type") != "reduce":
raise ValueError(
f"Target operation '{target_ops[0]}' must be a reduce operation"
)
# Instantiate the directive
rewrite, message_history, call_cost = self.llm_instantiate(
target_op_config,
agent_llm,
message_history,
)
# Apply the rewrite to the operators
new_ops_plan = self.apply(
global_default_model, operators, target_ops[0], rewrite
)
return new_ops_plan, message_history, call_cost
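A minimal sketch of the operator list that the hierarchical rewrite yields for the voting example above, assuming the agent proposes a synthetic `city` key; names and prompts are illustrative placeholders, and in the real flow apply() builds these from the LLM-filled schema.

from copy import deepcopy

orig = {
    "name": "summarize_by_state",
    "type": "reduce",
    "reduce_key": "state",
    "prompt": "Summarize voting patterns: {% for input in inputs %}{{ input.content }}{% endfor %}",
    "output": {"schema": {"summary": "string"}},
}

map_op = {  # optional synthetic-key Map
    "name": "extract_city",
    "type": "map",
    "prompt": "Extract the city mentioned in this post: {{ input.content }}",
    "output": {"schema": {"city": "string"}},
}

reduce_1 = deepcopy(orig)  # finer granularity: (state, city)
reduce_1["name"] = "summarize_by_state_city"
reduce_1["reduce_key"] = [orig["reduce_key"], "city"]
reduce_1["prompt"] = "Summarize voting patterns for this state and city: ..."

reduce_2 = deepcopy(orig)  # roll-up: state only, same name and output schema
reduce_2["prompt"] = (
    "Combine these city-level summaries: "
    "{% for input in inputs %}{{ input.summary }}{% endfor %}"
)

new_ops = [map_op, reduce_1, reduce_2]
print([(op["name"], op.get("reduce_key")) for op in new_ops])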


@@ -0,0 +1,412 @@
import json
import re
from copy import deepcopy
from typing import Dict, List, Type
from litellm import completion
from pydantic import BaseModel, Field
from docetl.reasoning_optimizer.instantiate_schemas import (
IsolatingSubtasksInstantiateSchema,
)
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
class IsolatingSubtasksDirective(Directive):
name: str = Field(
default="isolating_subtasks", description="The name of the directive"
)
formal_description: str = Field(default="Map => Parallel Map -> Map")
nl_description: str = Field(
default="Rewrites a single Map into a Parallel Map that isolates subtasks and generates separate outputs for each, followed by a Map that aggregates or synthesizes the results."
)
when_to_use: str = Field(
default="When the original Map is overloaded—either the prompt asks for many different things OR the output schema has many fields—and subtasks are better handled independently (e.g., extract each attribute in parallel, then combine into a unified output)."
)
instantiate_schema_type: Type[BaseModel] = Field(
default=IsolatingSubtasksInstantiateSchema
)
example: str = Field(
default="""
Original Op (MapOpConfig):
- name: extract_contract_info
type: map
prompt: |
Extract the following from this contract: {{ input.document }}
- party names
- agreement date
- governing law
- termination clauses
output:
schema:
parties: "string"
agreement_date: "string"
governing_law: "string"
termination_clauses: "string"
Example InstantiateSchema (what the agent should output):
IsolatingSubtasksConfig(
subtasks=[
SubtaskConfig(
name="Extract Basic Contract Info",
prompt="Extract party names and agreement date from: {{ input.document }}",
output_keys=["parties", "agreement_date"]
),
SubtaskConfig(
name="Extract Legal Terms",
prompt="Extract governing law and termination clauses from: {{ input.document }}",
output_keys=["governing_law", "termination_clauses"]
)
],
aggregation_prompt="Combine the basic info {{ input.subtask_1_output }} with legal terms {{ input.subtask_2_output }} into the final contract summary."
)
"""
)
test_cases: List[DirectiveTestCase] = Field(
default_factory=lambda: [
DirectiveTestCase(
name="complex_prompt_simple_output",
description="Complex multi-task prompt but simple list output - should isolate by prompt complexity",
input_config={
"name": "analyze_document",
"type": "map",
"prompt": "Analyze this document for: 1) sentiment and emotional tone, 2) key topics and themes, 3) factual accuracy and bias, 4) writing quality and readability, 5) actionable insights and recommendations. Document: {{ input.text }}",
"output": {"schema": {"results": "list[string]"}},
},
target_ops=["analyze_document"],
expected_behavior="Should create >1 parallel map prompts covering all analysis aspects (sentiment, topics, bias, quality, insights) and aggregation prompt referencing all subtask outputs",
should_pass=True,
),
DirectiveTestCase(
name="contract_analysis_many_fields",
description="Legal contract extraction with many specific fields",
input_config={
"name": "extract_contract_terms",
"type": "map",
"prompt": "Extract contract information from: {{ input.contract_text }}",
"output": {
"schema": {
"parties": "string",
"agreement_date": "string",
"governing_law": "string",
"termination_clause": "string",
"payment_terms": "string",
"liability_cap": "string",
}
},
},
target_ops=["extract_contract_terms"],
expected_behavior="Should create >1 parallel map prompts covering all 6 fields (parties, agreement_date, governing_law, termination_clause, payment_terms, liability_cap) and aggregation prompt referencing all subtask outputs",
should_pass=True,
),
DirectiveTestCase(
name="medical_transcript_processing",
description="Medical data extraction - different subtasks for different medical info types",
input_config={
"name": "process_medical_record",
"type": "map",
"prompt": "Extract patient demographics, symptoms, diagnosis, and treatment plan from: {{ input.transcript }}",
"output": {
"schema": {
"patient_info": "string",
"symptoms": "string",
"diagnosis": "string",
"treatment": "string",
}
},
},
target_ops=["process_medical_record"],
expected_behavior="Should create >1 parallel map prompts covering all medical aspects (demographics, symptoms, diagnosis, treatment) and aggregation prompt referencing all subtask outputs",
should_pass=True,
),
DirectiveTestCase(
name="research_paper_summary",
description="Academic paper analysis with focus on different aspects",
input_config={
"name": "summarize_research",
"type": "map",
"prompt": "Analyze this research paper for methodology, key findings, limitations, and practical applications: {{ input.paper_text }}",
"output": {
"schema": {"summary": "string", "key_points": "list[string]"}
},
},
target_ops=["summarize_research"],
expected_behavior="Should create >1 parallel map prompts covering all research aspects (methodology, findings, limitations, applications) and aggregation prompt referencing all subtask outputs",
should_pass=True,
),
DirectiveTestCase(
name="customer_feedback_analysis",
description="Multi-aspect customer feedback analysis with simple output",
input_config={
"name": "analyze_feedback",
"type": "map",
"prompt": "Analyze customer feedback for: product quality issues, service experience problems, pricing concerns, feature requests, and overall satisfaction. Feedback: {{ input.feedback_text }}",
"output": {
"schema": {
"insights": "list[string]",
"priority_score": "string",
}
},
},
target_ops=["analyze_feedback"],
expected_behavior="Should create >1 parallel map prompts covering all feedback aspects (quality, service, pricing, features, satisfaction) and aggregation prompt referencing all subtask outputs",
should_pass=True,
),
DirectiveTestCase(
name="financial_report_extraction",
description="Financial document with many specific metrics to extract",
input_config={
"name": "extract_financial_data",
"type": "map",
"prompt": "Extract financial metrics from earnings report: {{ input.report }}",
"output": {
"schema": {
"revenue": "string",
"profit_margin": "string",
"cash_flow": "string",
"debt_ratio": "string",
"growth_rate": "string",
"market_share": "string",
}
},
},
target_ops=["extract_financial_data"],
expected_behavior="Should create >1 parallel map prompts covering all 6 financial metrics (revenue, profit_margin, cash_flow, debt_ratio, growth_rate, market_share) and aggregation prompt referencing all subtask outputs",
should_pass=True,
),
]
)
def __eq__(self, other):
return isinstance(other, IsolatingSubtasksDirective)
def __hash__(self):
return hash("IsolatingSubtasksDirective")
def to_string_for_instantiate(self, original_op: Dict) -> str:
"""
Generate a prompt that asks the agent to output the instantiate schema.
"""
# Extract original operation details
original_name = original_op.get("name", "unknown")
original_prompt = original_op.get("prompt", "")
original_output_schema = original_op.get("output", {}).get("schema", {})
original_output_keys = (
list(original_output_schema.keys()) if original_output_schema else []
)
# Find the input key from the original prompt (look for {{ input.XXX }} pattern)
input_matches = re.findall(r"\{\{\s*input\.([^}\s]+)\s*\}\}", original_prompt)
input_key = input_matches[0] if input_matches else "document"
return (
f"You are an expert at analyzing overloaded map operations and breaking them into focused subtasks.\n\n"
f"Original Operation:\n"
f"Name: {original_name}\n"
f"Prompt: {original_prompt}\n"
f"Output Schema: {original_output_schema}\n"
f"Output Keys: {original_output_keys}\n\n"
f"This map operation is overloaded - either the prompt asks for many different things "
f"OR it has {len(original_output_keys)} output fields to generate. "
f"Your task is to create an IsolatingSubtasksConfig with:\n\n"
f"1. **SUBTASKS**: Group the {len(original_output_keys)} output fields into 2-4 logical subtasks "
f"where each subtask handles related fields that can be processed independently:\n"
f" - Each subtask needs a descriptive 'name'\n"
f" - Each subtask needs a focused Jinja 'prompt' that uses {{{{ input.{input_key} }}}} (same as original)\n"
f" - Each subtask needs 'output_keys' listing which fields it extracts\n"
f" - Every original output key must appear in exactly one subtask's output_keys\n\n"
f"2. **AGGREGATION_PROMPT**: A Jinja template that combines all subtask results:\n"
f" - Must reference ALL subtask outputs as {{{{ input.subtask_1_output }}}}, {{{{ input.subtask_2_output }}}}, etc.\n"
f" - Should synthesize/combine the subtask results into the final output\n"
f" - Final result must match the original output schema exactly\n"
f" - Example: 'Combine basic info {{{{ input.subtask_1_output }}}} with details {{{{ input.subtask_2_output }}}} into complete result.'\n\n"
f"CRITICAL REQUIREMENTS:\n"
f"- All {len(original_output_keys)} original output keys must be covered by subtasks\n"
f"- Subtask prompts must use {{{{ input.{input_key} }}}} (same input as original)\n"
f"- Aggregation prompt must reference {{{{ input.subtask_N_output }}}} for each subtask\n"
f"- No information should be lost in the transformation\n\n"
f"Example logical grouping for contract extraction:\n"
f"- Subtask 1: Basic info (parties, dates) \n"
f"- Subtask 2: Legal terms (governing law, clauses)\n"
f"- Subtask 3: Commercial terms (pricing, commitments)\n\n"
f"Please output the IsolatingSubtasksConfig that transforms this overloaded operation."
)
def llm_instantiate(
self,
original_op: Dict,
agent_llm: str,
message_history: list = [],
):
"""
Call the LLM to generate the instantiate schema with validation.
"""
message_history.extend(
[
{
"role": "user",
"content": self.to_string_for_instantiate(original_op),
},
]
)
original_output_keys = list(
original_op.get("output", {}).get("schema", {}).keys()
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
response_format=IsolatingSubtasksInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = IsolatingSubtasksInstantiateSchema(**parsed_res)
# Use the schema's validation methods
schema.validate_subtasks_coverage(original_output_keys)
schema.validate_aggregation_references_all_subtasks()
message_history.append(
{"role": "assistant", "content": resp.choices[0].message.content}
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease ensure all original output keys are covered by subtasks and try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(
self,
global_default_model,
ops_list: List[Dict],
target_op: str,
rewrite: IsolatingSubtasksInstantiateSchema,
) -> List[Dict]:
"""
Apply the directive by replacing the target map with parallel map + aggregation map.
"""
new_ops_list = deepcopy(ops_list)
# Find the target operation
pos_to_replace = None
original_op = None
for i, op in enumerate(ops_list):
if op["name"] == target_op:
pos_to_replace = i
original_op = op
break
if pos_to_replace is None:
raise ValueError(f"Target operation '{target_op}' not found")
# Create the parallel map operation
parallel_map_op = {
"name": f"{target_op}_parallel",
"type": "parallel_map",
"litellm_completion_kwargs": {"temperature": 0},
"prompts": [],
}
assert original_op
# Copy over other fields from original operation (sample, random_sample, etc.)
for key, value in original_op.items():
if key not in ["name", "type", "prompt", "output"]:
parallel_map_op[key] = value
# Set up output schema for parallel map
parallel_output_schema = {}
# Add each subtask as a prompt in the parallel map
for i, subtask in enumerate(rewrite.subtasks, 1):
subtask_output_key = f"subtask_{i}_output"
prompt_config = {
"name": subtask.name,
"output_keys": [subtask_output_key],
"prompt": subtask.prompt,
}
parallel_map_op["prompts"].append(prompt_config)
parallel_output_schema[subtask_output_key] = "string"
# Set the output schema for parallel map
parallel_map_op["output"] = {"schema": parallel_output_schema}
# Use the same model as the original operation
default_model = original_op.get("model", global_default_model)
parallel_map_op["model"] = default_model
# Check if aggregation is needed by comparing subtask output keys with original keys
subtask_output_keys = set()
for subtask in rewrite.subtasks:
subtask_output_keys.update(subtask.output_keys)
# Check if aggregation is needed: aggregation_prompt is empty
if not rewrite.aggregation_prompt.strip():
# Just return the parallel map - it already produces the right output
parallel_map_op["output"] = original_op.get("output", {})
new_ops_list[pos_to_replace : pos_to_replace + 1] = [parallel_map_op]
else:
# Need aggregation step
aggregation_map_op = {
"name": f"{target_op}_aggregate",
"type": "map",
"prompt": rewrite.aggregation_prompt,
"litellm_completion_kwargs": {"temperature": 0},
"output": original_op.get(
"output", {}
), # Same output schema as original
}
# Use the same model as the original operation
aggregation_map_op["model"] = default_model
# Replace the original operation with both operations
new_ops_list[pos_to_replace : pos_to_replace + 1] = [
parallel_map_op,
aggregation_map_op,
]
return new_ops_list
def instantiate(
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
global_default_model: str = None,
**kwargs,
):
"""
Main method that orchestrates directive instantiation.
"""
assert len(target_ops) == 1, "This directive requires exactly one target op"
target_op_config = [op for op in operators if op["name"] == target_ops[0]][0]
# Step 1: Agent generates the instantiate schema
rewrite, message_history, call_cost = self.llm_instantiate(
target_op_config, agent_llm, message_history
)
# Step 2: Apply transformation using the schema
return (
self.apply(global_default_model, operators, target_ops[0], rewrite),
message_history,
call_cost,
)
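To make the Parallel Map -> Map shape concrete, here is a minimal sketch of the two operators apply() would emit for the contract example, with illustrative subtask prompts and an assumed aggregation prompt.

subtasks = [
    {"name": "Extract Basic Contract Info",
     "prompt": "Extract party names and agreement date from: {{ input.document }}"},
    {"name": "Extract Legal Terms",
     "prompt": "Extract governing law and termination clauses from: {{ input.document }}"},
]

# One prompt per subtask; each writes a synthetic subtask_N_output field.
parallel_map_op = {
    "name": "extract_contract_info_parallel",
    "type": "parallel_map",
    "litellm_completion_kwargs": {"temperature": 0},
    "prompts": [
        {"name": s["name"], "prompt": s["prompt"], "output_keys": [f"subtask_{i}_output"]}
        for i, s in enumerate(subtasks, 1)
    ],
    "output": {"schema": {f"subtask_{i}_output": "string"
                          for i in range(1, len(subtasks) + 1)}},
}

# The follow-up Map synthesizes the subtask outputs back into the original schema.
aggregation_map_op = {
    "name": "extract_contract_info_aggregate",
    "type": "map",
    "prompt": "Combine {{ input.subtask_1_output }} and {{ input.subtask_2_output }} "
              "into the final contract summary.",
    "litellm_completion_kwargs": {"temperature": 0},
    "output": {"schema": {"parties": "string", "agreement_date": "string",
                          "governing_law": "string", "termination_clauses": "string"}},
}
print(len(parallel_map_op["prompts"]), aggregation_map_op["type"])  # 2 map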


@@ -0,0 +1,387 @@
import json
import re
from copy import deepcopy
from typing import Dict, List, Type
from litellm import completion
from pydantic import BaseModel, Field
from docetl.reasoning_optimizer.instantiate_schemas import (
MapReduceFusionInstantiateSchema,
)
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
class MapReduceFusionDirective(Directive):
name: str = Field(
default="map_reduce_fusion", description="The name of the directive"
)
formal_description: str = Field(
default="Map -> Reduce => Map (with new prompt) -> Reduce (with new propmt)"
)
nl_description: str = Field(
default="Transform a Map operation followed by a Reduce operation over long documents by updating the Map prompt to first extract or process only the relevant information from each document. Then, modify the Reduce prompt to operate on these processed outputs instead of the full document content."
)
when_to_use: str = Field(
default="There is a Map operation followed by a Reduce operation, and the Reduce step iterates over long documents to extract specific information (e.g., locations, entities, themes) that could be pre-extracted per document in the Map step. This makes the Reduce operation more efficient and focused. The target operators must be a Map operator followed by a Reduce operator. When selecting this directive, specify the Map and Reduce operators as the target operators."
)
instantiate_schema_type: Type[BaseModel] = MapReduceFusionInstantiateSchema
example: str = Field(
default=(
"Original Pipeline:\\n"
"Map Op (classify_document):\\n"
"- name: classify_document\\n"
" type: map\\n"
" prompt: |\\n"
" Classify the following document into a category:\\n"
" {{ input.content }}\\n"
" Choose from: news, research, legal, business\\n"
" output:\\n"
" schema:\\n"
" category: str\\n"
"\\n"
"Reduce Op (extract_organizations):\\n"
"- name: extract_organizations\\n"
" type: reduce\\n"
" reduce_key: category\\n"
" prompt: |\\n"
' For each category \\"{{ reduce_key }}\\", extract all organization names from these documents:\\n'
" {% for input in inputs %}\\n"
" Document {{ loop.index }}: {{ input.content }}\\n"
" {% endfor %}\\n"
" Return a list of unique organization names.\\n"
" output:\\n"
" schema:\\n"
" organizations: list[str]\\n"
"\\n"
"Example InstantiateSchema Output:\\n"
"MapReduceFusionInstantiateSchema(\\n"
" new_map_name='fused_classify_extract_organizations',\\n"
" new_map_prompt='Analyze the following document:\\n{{ input.content }}\\n\\n1. Classify the document into a category (news, research, legal, business)\\n\\n2. Extract all organization names mentioned in the document\\n\\nProvide both the category and list of organizations.',\\n"
" new_key='organizations',\\n"
" new_reduce_prompt='For each category \\\"{{ reduce_key }}\\\", combine all organization names from these pre-extracted lists:\\n{% for input in inputs %}\\nOrganizations from document {{ loop.index }}: {{ input.organizations }}\\n{% endfor %}\\nReturn a single deduplicated list of all unique organization names.'\\n"
")"
),
)
test_cases: List[DirectiveTestCase] = Field(
default_factory=lambda: [
DirectiveTestCase(
name="classification_organization_fusion",
description="Should fuse document classification with organization extraction",
input_config=[
{
"name": "classify_document",
"type": "map",
"prompt": "Classify this document: {{ input.content }}",
"output": {"schema": {"category": "str"}},
},
{
"name": "extract_organizations",
"type": "reduce",
"reduce_key": "category",
"prompt": "For each category '{{ reduce_key }}', extract organizations from: {% for input in inputs %}{{ input.content }}{% endfor %}",
"output": {"schema": {"organizations": "list[str]"}},
},
],
target_ops=["classify_document", "extract_organizations"],
expected_behavior="Should modify map to classify AND extract organizations per document, then reduce to aggregate pre-extracted organizations",
should_pass=True,
),
DirectiveTestCase(
name="analysis_entity_fusion",
description="Should fuse document analysis with entity extraction per category",
input_config=[
{
"name": "analyze_document",
"type": "map",
"prompt": "Analyze document type: {{ input.text }}",
"output": {"schema": {"doc_type": "str"}},
},
{
"name": "find_people",
"type": "reduce",
"reduce_key": "doc_type",
"prompt": "For each document type '{{ reduce_key }}', find all people mentioned: {% for input in inputs %}Document: {{ input.text }}{% endfor %}",
"output": {"schema": {"people": "list[str]"}},
},
],
target_ops=["analyze_document", "find_people"],
expected_behavior="Should modify map to analyze type AND extract people per document, then reduce to combine pre-extracted people lists",
should_pass=True,
),
]
)
def __eq__(self, other):
return isinstance(other, MapReduceFusionDirective)
def __hash__(self):
return hash("MapReduceFusionDirective")
def to_string_for_instantiate(self, original_ops: List[Dict]) -> str:
"""
Generate a prompt for an agent to instantiate this directive.
Args:
original_ops (List[Dict]): List containing the map and reduce operations.
Returns:
str: The agent prompt for instantiating the directive.
"""
map_op, reduce_op = original_ops[0], original_ops[1]
return (
f"You are an expert at optimizing data processing pipelines for efficiency.\\n\\n"
f"Original Map Operation:\\n"
f"{str(map_op)}\\n\\n"
f"Original Reduce Operation:\\n"
f"{str(reduce_op)}\\n\\n"
f"Directive: {self.name}\\n"
f"Your task is to apply map-reduce fusion: modify the Map operation to pre-extract the information that the Reduce operation needs, "
f"then update the Reduce operation to work with these pre-extracted results instead of processing full documents.\\n\\n"
f"Key Requirements:\\n"
f"1. Analyze what specific information the Reduce operation extracts from documents\\n"
f"2. Create a new Map prompt that does BOTH the original Map task AND extracts the information needed by Reduce\\n"
f"3. Create a new key name for the pre-extracted information that the Reduce will reference\\n"
f"4. Create a new Reduce prompt that aggregates the pre-extracted information instead of processing full documents\\n"
f"5. The Reduce operation should reference the new key (e.g., {{ input.new_key }}) instead of full document content\\n\\n"
f"Output Format:\\n"
f"Return a MapReduceFusionInstantiateSchema with:\\n"
f"- new_map_name: Combined name for the fused map operation\\n"
f"- new_map_prompt: Prompt that does both original map task + extraction\\n"
f"- new_key: Key name for the extracted information\\n"
f"- new_reduce_prompt: Prompt that works with pre-extracted data\\n\\n"
f"Example:\\n"
f"{self.example}\\n\\n"
f"Please output only the MapReduceFusionInstantiateSchema with the four required fields."
)
def llm_instantiate(
self,
map_op: Dict,
reduce_op: Dict,
expected_document_key,
agent_llm: str,
message_history: list = [],
):
"""
Use LLM to instantiate this directive by transforming the map and reduce operations.
Args:
map_op (Dict): The original map operation.
reduce_op (Dict): The original reduce operation.
agent_llm (str): The LLM model to use.
message_history (List, optional): Conversation history for context.
Returns:
MapReduceFusionInstantiateSchema: The structured output from the LLM.
"""
message_history.extend(
[
{
"role": "system",
"content": "You are a helpful AI assistant for optimizing document processing pipelines.",
},
{
"role": "user",
"content": self.to_string_for_instantiate([map_op, reduce_op]),
},
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
response_format=MapReduceFusionInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
if (
"new_map_name" not in parsed_res
or "new_map_prompt" not in parsed_res
or "new_key" not in parsed_res
or "new_reduce_prompt" not in parsed_res
):
raise ValueError(
"Response from LLM is missing required keys: 'new_map_name', 'new_map_prompt', 'new_key', or 'new_reduce_prompt'"
)
schema = MapReduceFusionInstantiateSchema(
new_map_name=parsed_res["new_map_name"],
new_map_prompt=parsed_res["new_map_prompt"],
new_key=parsed_res["new_key"],
new_reduce_prompt=parsed_res["new_reduce_prompt"],
)
# Validate the schema
MapReduceFusionInstantiateSchema.validate_reduce_prompt_references_new_key(
schema.new_reduce_prompt, schema.new_key, expected_document_key
)
message_history.append(
{"role": "assistant", "content": resp.choices[0].message.content}
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(
self,
global_default_model,
ops_list: List[Dict],
map_target: str,
reduce_target: str,
rewrite: MapReduceFusionInstantiateSchema,
) -> List[Dict]:
"""
Apply the directive to the pipeline config.
"""
# Create a copy of the pipeline config
new_ops_list = deepcopy(ops_list)
# Find positions of the target ops
map_pos = None
reduce_pos = None
orig_map_op = None
orig_reduce_op = None
for i, op in enumerate(ops_list):
if op["name"] == map_target:
map_pos = i
orig_map_op = op
elif op["name"] == reduce_target:
reduce_pos = i
orig_reduce_op = op
if map_pos is None or reduce_pos is None:
raise ValueError(
f"Could not find target operations: {map_target}, {reduce_target}"
)
# Get default model
default_model = orig_map_op.get("model", global_default_model)
# Create the new map operation with fused functionality
new_map_op = {
"name": rewrite.new_map_name,
"type": "map",
"prompt": rewrite.new_map_prompt,
"model": default_model,
"litellm_completion_kwargs": {"temperature": 0},
"output": {
"schema": {
**orig_map_op.get("output", {}).get("schema", {}),
rewrite.new_key: "list[str]",
}
},
}
# Create the new reduce operation that works with pre-extracted data
new_reduce_op = {
"name": orig_reduce_op["name"],
"type": "reduce",
"reduce_key": orig_reduce_op.get("reduce_key"),
"prompt": rewrite.new_reduce_prompt,
"model": default_model,
"litellm_completion_kwargs": {"temperature": 0},
"output": orig_reduce_op.get("output", {}),
}
# Replace the operations
new_ops_list[map_pos] = new_map_op
new_ops_list[reduce_pos] = new_reduce_op
return new_ops_list
def instantiate(
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
optimize_goal="acc",
global_default_model: str = None,
**kwargs,
):
"""
Instantiate the directive for a list of operators.
"""
# Assert that there are exactly two target ops (map and reduce)
if len(target_ops) != 2:
raise ValueError(
"map_reduce_fusion directive requires exactly two target operators: a map and a reduce operation"
)
# Find the map and reduce operations
first_op = None
second_op = None
for op in operators:
if op["name"] == target_ops[0]:
first_op = op
elif op["name"] == target_ops[1]:
second_op = op
if first_op is None or second_op is None:
raise ValueError(f"Could not find target operations: {target_ops}")
if first_op.get("type") == "map" and second_op.get("type") == "reduce":
map_op = first_op
reduce_op = second_op
map_target = target_ops[0]
reduce_target = target_ops[1]
else:
raise ValueError(
"Target operators must be one map operation followed by one reduce operation!"
)
# Extract expected document key from the reduce prompt template
prompt_template = reduce_op["prompt"]
# Find all Jinja variable references (e.g., {{ input.content }}, {{ reduce_key }}) in the prompt
input_key_pattern = r"\{\{\s*([^\}\s]+)\s*\}\}"
input_keys = list(set(re.findall(input_key_pattern, prompt_template)))
print("input_keys: ", input_keys)
# Heuristic: pick the key that's most likely to contain document content
# Look for common document field names
document_key_candidates = [
key
for key in input_keys
if any(
doc_word in key.lower()
for doc_word in ["document", "text", "content", "body", "description"]
)
]
if document_key_candidates:
expected_document_key = document_key_candidates[
0
] # Pick the first candidate
elif input_keys:
expected_document_key = input_keys[0] # Fall back to the first input key
else:
raise ValueError("No input keys found in the reduce operation prompt")
print(f"Detected document key: {expected_document_key}")
# Instantiate the directive
rewrite, message_history, call_cost = self.llm_instantiate(
map_op, reduce_op, expected_document_key, agent_llm, message_history
)
# Apply the rewrite to the operators
new_ops_plan = self.apply(
global_default_model, operators, map_target, reduce_target, rewrite
)
return new_ops_plan, message_history, call_cost
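A minimal sketch of the fused pipeline for the classification/organization example, with illustrative names and prompts in place of the LLM-generated MapReduceFusionInstantiateSchema fields.

# The fused Map now does the original classification AND pre-extracts the
# organizations the Reduce needs (new_key added to the map's output schema).
fused_map_op = {
    "name": "fused_classify_extract_organizations",
    "type": "map",
    "prompt": "Classify this document and extract all organization names: {{ input.content }}",
    "litellm_completion_kwargs": {"temperature": 0},
    "output": {"schema": {"category": "str", "organizations": "list[str]"}},
}

# The Reduce keeps its name and output schema but aggregates the pre-extracted
# lists instead of re-reading full documents.
new_reduce_op = {
    "name": "extract_organizations",
    "type": "reduce",
    "reduce_key": "category",
    "prompt": ("For category '{{ reduce_key }}', merge these pre-extracted lists: "
               "{% for input in inputs %}{{ input.organizations }}{% endfor %} "
               "Return a deduplicated list of organization names."),
    "litellm_completion_kwargs": {"temperature": 0},
    "output": {"schema": {"organizations": "list[str]"}},
}
print(list(fused_map_op["output"]["schema"]))  # ['category', 'organizations']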


@@ -0,0 +1,361 @@
import json
from copy import deepcopy
from typing import Dict, List, Type
from litellm import completion
from pydantic import BaseModel, Field
from docetl.reasoning_optimizer.instantiate_schemas import (
MapResolveToMapWithCategoriesInstantiateSchema,
)
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
class MapResolveToMapWithCategoriesDirective(Directive):
name: str = Field(
default="map_resolve_to_map_with_categories",
description="The name of the directive",
)
formal_description: str = Field(default="Map -> Resolve => Map (with categories)")
nl_description: str = Field(
default="Replace a Map -> Resolve pattern with a single Map operation that has predefined categories. The agent analyzes the data and task to propose a set of canonical categories, and the new Map forces outputs into one of these categories (or 'none of the above'). This effectively performs entity resolution deterministically by standardizing outputs upfront, avoiding the need for pairwise comparisons."
)
when_to_use: str = Field(
default="When a Map operation produces outputs that are then resolved/deduplicated, and the set of valid output categories can be enumerated upfront. This is more efficient than Resolve when the category space is small and well-defined (e.g., standardizing company types, product categories, sentiment labels). The target must be a Map operation followed by a Resolve operation."
)
instantiate_schema_type: Type[BaseModel] = (
MapResolveToMapWithCategoriesInstantiateSchema
)
example: str = Field(
default=(
"Original Pipeline:\n"
"- name: extract_sentiment\n"
" type: map\n"
" prompt: |\n"
" What is the sentiment of this review?\n"
" {{ input.review }}\n"
" output:\n"
" schema:\n"
" sentiment: string\n"
"\n"
"- name: normalize_sentiment\n"
" type: resolve\n"
" comparison_prompt: |\n"
" Are these sentiments equivalent?\n"
" Sentiment 1: {{ input1.sentiment }}\n"
" Sentiment 2: {{ input2.sentiment }}\n"
" resolution_prompt: |\n"
" Normalize these sentiments:\n"
" {% for input in inputs %}\n"
" - {{ input.sentiment }}\n"
" {% endfor %}\n"
" output:\n"
" schema:\n"
" sentiment: string\n"
"\n"
"Example InstantiateSchema:\n"
"MapResolveToMapWithCategoriesInstantiateSchema(\n"
" categories=['Positive', 'Negative', 'Neutral', 'Mixed'],\n"
" category_key='sentiment',\n"
" new_prompt='''Analyze the sentiment of this review and classify it into one of the following categories:\n"
"\n"
"Categories:\n"
"- Positive: Clearly positive sentiment, satisfaction, praise\n"
"- Negative: Clearly negative sentiment, complaints, criticism\n"
"- Neutral: No strong sentiment, factual statements\n"
"- Mixed: Contains both positive and negative elements\n"
"- None of the above: If the review doesn't fit any category\n"
"\n"
"Review: {{ input.review }}\n"
"\n"
"Return exactly one of: Positive, Negative, Neutral, Mixed, or None of the above.''',\n"
" include_none_of_above=True,\n"
")"
),
)
test_cases: List[DirectiveTestCase] = Field(
default_factory=lambda: [
DirectiveTestCase(
name="sentiment_categorization",
description="Should replace map+resolve with categorized map for sentiment",
input_config=[
{
"name": "extract_sentiment",
"type": "map",
"prompt": "What is the sentiment? {{ input.text }}",
"output": {"schema": {"sentiment": "string"}},
},
{
"name": "normalize_sentiment",
"type": "resolve",
"comparison_prompt": "Same sentiment? {{ input1.sentiment }} vs {{ input2.sentiment }}",
"resolution_prompt": "Normalize: {% for input in inputs %}{{ input.sentiment }}{% endfor %}",
"output": {"schema": {"sentiment": "string"}},
},
],
target_ops=["extract_sentiment", "normalize_sentiment"],
expected_behavior="Should create a single map with predefined sentiment categories",
should_pass=True,
),
]
)
def __eq__(self, other):
return isinstance(other, MapResolveToMapWithCategoriesDirective)
def __hash__(self):
return hash("MapResolveToMapWithCategoriesDirective")
def to_string_for_instantiate(
self, map_op: Dict, resolve_op: Dict, sample_data: List[Dict] = None
) -> str:
"""
Generate a prompt for an agent to instantiate this directive.
Args:
map_op (Dict): The map operation configuration.
resolve_op (Dict): The resolve operation configuration.
sample_data (List[Dict], optional): Sample data to help identify categories.
Returns:
str: The agent prompt for instantiating the directive.
"""
sample_str = ""
if sample_data:
sample_str = f"\n\nSample Input Data (first 5 items):\n{json.dumps(sample_data[:5], indent=2)}"
return (
f"You are an expert at optimizing data processing pipelines by replacing entity resolution with categorical constraints.\n\n"
f"Map Operation:\n"
f"{str(map_op)}\n\n"
f"Resolve Operation:\n"
f"{str(resolve_op)}\n\n"
f"Directive: {self.name}\n"
f"Your task is to replace the Map -> Resolve pattern with a single Map that uses predefined categories.\n\n"
f"Key Requirements:\n"
f"1. Analyze the map's output field and the resolve operation to understand what values are being normalized:\n"
f" - Look at the comparison_prompt to understand what variations are being matched\n"
f" - Look at the resolution_prompt to understand the canonical form\n\n"
f"2. Propose a set of categories that cover all expected outputs:\n"
f" - Categories should be mutually exclusive\n"
f" - Categories should be exhaustive (cover all realistic cases)\n"
f" - Consider including 'None of the above' for edge cases\n\n"
f"3. Optionally provide descriptions for each category to help the LLM classify correctly\n\n"
f"4. Create a new_prompt that:\n"
f" - Lists all valid categories with their descriptions\n"
f" - Instructs the LLM to output exactly one category\n"
f" - References the input using {{{{ input.key }}}} syntax\n"
f" - Includes any context from the original map prompt\n\n"
f"5. Identify the category_key (the output field that will contain the category)\n\n"
f"Benefits of this approach:\n"
f"- Eliminates O(n^2) pairwise comparisons from Resolve\n"
f"- Produces consistent, standardized outputs\n"
f"- Reduces cost by removing the Resolve operation entirely\n"
f"{sample_str}\n\n"
f"Example:\n"
f"{self.example}\n\n"
f"Please analyze the operations and propose appropriate categories. Output the MapResolveToMapWithCategoriesInstantiateSchema."
)
def llm_instantiate(
self,
map_op: Dict,
resolve_op: Dict,
agent_llm: str,
message_history: list = [],
sample_data: List[Dict] = None,
):
"""
Use LLM to instantiate this directive.
Args:
map_op (Dict): The map operation configuration.
resolve_op (Dict): The resolve operation configuration.
agent_llm (str): The LLM model to use.
message_history (List, optional): Conversation history for context.
sample_data (List[Dict], optional): Sample data to help identify categories.
Returns:
MapResolveToMapWithCategoriesInstantiateSchema: The structured output from the LLM.
"""
message_history.extend(
[
{
"role": "system",
"content": "You are a helpful AI assistant for document processing pipelines specializing in categorical classification.",
},
{
"role": "user",
"content": self.to_string_for_instantiate(
map_op, resolve_op, sample_data
),
},
]
)
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
response_format=MapResolveToMapWithCategoriesInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = MapResolveToMapWithCategoriesInstantiateSchema(**parsed_res)
message_history.append(
{"role": "assistant", "content": resp.choices[0].message.content}
)
return schema, message_history, call_cost
except Exception as err:
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
)
def apply(
self,
global_default_model,
ops_list: List[Dict],
map_op_name: str,
resolve_op_name: str,
rewrite: MapResolveToMapWithCategoriesInstantiateSchema,
) -> List[Dict]:
"""
Apply the directive to the pipeline config.
"""
new_ops_list = deepcopy(ops_list)
# Find position of the map and resolve ops
map_pos = None
resolve_pos = None
map_op = None
for i, op in enumerate(new_ops_list):
if op["name"] == map_op_name:
map_pos = i
map_op = op
elif op["name"] == resolve_op_name:
resolve_pos = i
if map_pos is None or resolve_pos is None:
raise ValueError(
f"Could not find map '{map_op_name}' and resolve '{resolve_op_name}' operations"
)
# Determine the model to use
default_model = map_op.get("model", global_default_model)
# Build the list of valid values for validation
valid_values = list(rewrite.categories)
if rewrite.include_none_of_above:
valid_values.append("None of the above")
# Modify the map operation with the new prompt and add validation
new_ops_list[map_pos]["prompt"] = rewrite.new_prompt
new_ops_list[map_pos]["model"] = default_model
# Add validation to ensure output is one of the categories
if "validate" not in new_ops_list[map_pos]:
new_ops_list[map_pos]["validate"] = []
# Add validation rule for the category key
validation_rule = f"output['{rewrite.category_key}'] in {valid_values}"
new_ops_list[map_pos]["validate"].append(validation_rule)
# Update the output schema to reflect the category key
if "output" not in new_ops_list[map_pos]:
new_ops_list[map_pos]["output"] = {"schema": {}}
new_ops_list[map_pos]["output"]["schema"][rewrite.category_key] = "string"
# Remove the resolve operation
new_ops_list.pop(resolve_pos)
return new_ops_list
def instantiate(
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
optimize_goal="acc",
global_default_model: str = None,
dataset: str = None,
**kwargs,
):
"""
Instantiate the directive for a list of operators.
"""
# Assert that there are exactly two target ops (map and resolve)
assert (
len(target_ops) == 2
), "There must be exactly two target ops (map and resolve) to instantiate this directive"
# Find the map and resolve operations
map_op = None
resolve_op = None
for op in operators:
if op["name"] == target_ops[0]:
if op.get("type") == "map":
map_op = op
elif op.get("type") == "resolve":
resolve_op = op
elif op["name"] == target_ops[1]:
if op.get("type") == "map":
map_op = op
elif op.get("type") == "resolve":
resolve_op = op
if map_op is None or resolve_op is None:
raise ValueError(
f"Could not find both a map and resolve operation in target_ops: {target_ops}"
)
# Verify the map comes before resolve
map_idx = next(
i for i, op in enumerate(operators) if op["name"] == map_op["name"]
)
resolve_idx = next(
i for i, op in enumerate(operators) if op["name"] == resolve_op["name"]
)
if map_idx >= resolve_idx:
raise ValueError(
f"Map operation '{map_op['name']}' must come before resolve operation '{resolve_op['name']}'"
)
# Load sample data if available
sample_data = None
if dataset:
try:
with open(dataset, "r") as f:
sample_data = json.load(f)
except Exception:
pass # Ignore if we can't load sample data
# Instantiate the directive
rewrite, message_history, call_cost = self.llm_instantiate(
map_op,
resolve_op,
agent_llm,
message_history,
sample_data,
)
# Apply the rewrite to the operators
new_ops_plan = self.apply(
global_default_model, operators, map_op["name"], resolve_op["name"], rewrite
)
return new_ops_plan, message_history, call_cost
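For reference, a minimal sketch of how this rewrite plays out when the schema is hand-built instead of produced by llm_instantiate. The op configs and model name below are invented, the schema import path is assumed to match the other directive files, and only the fields that apply() actually reads are shown (the real schema may require more); run it with the directive class above in scope.

from docetl.reasoning_optimizer.instantiate_schemas import (
    MapResolveToMapWithCategoriesInstantiateSchema,  # assumed module path
)

# Toy pipeline ops (invented for illustration).
ops = [
    {
        "name": "extract_sentiment",
        "type": "map",
        "prompt": "What is the sentiment of this review? {{ input.review }}",
        "output": {"schema": {"sentiment": "string"}},
    },
    {
        "name": "normalize_sentiment",
        "type": "resolve",
        "comparison_prompt": "Do {{ input1.sentiment }} and {{ input2.sentiment }} mean the same thing?",
        "resolution_prompt": "Pick one canonical label:\n{% for input in inputs %}{{ input.sentiment }}\n{% endfor %}",
    },
]

rewrite = MapResolveToMapWithCategoriesInstantiateSchema(
    categories=["positive", "negative", "neutral"],
    include_none_of_above=True,
    new_prompt=(
        "Classify the sentiment of {{ input.review }} as exactly one of: "
        "positive, negative, neutral, or None of the above."
    ),
    category_key="sentiment",
)

new_ops = MapResolveToMapWithCategoriesDirective().apply(
    global_default_model="gpt-4o-mini",  # hypothetical model name
    ops_list=ops,
    map_op_name="extract_sentiment",
    resolve_op_name="normalize_sentiment",
    rewrite=rewrite,
)
# new_ops is a single map op; apply() also attaches a validate rule like
#   "output['sentiment'] in ['positive', 'negative', 'neutral', 'None of the above']"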


@@ -0,0 +1,335 @@
import json
from copy import deepcopy
from typing import Dict, List, Type
from litellm import completion
from pydantic import BaseModel, Field
from docetl.reasoning_optimizer.instantiate_schemas import (
MapToMapResolveReduceInstantiateSchema,
)
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
class MapToMapResolveReduceDirective(Directive):
name: str = Field(
default="map_to_map_resolve_reduce", description="The name of the directive"
)
formal_description: str = Field(default="Map -> Reduce => Map -> Resolve -> Reduce")
nl_description: str = Field(
default="Insert a Resolve operation between Map and Reduce to deduplicate or normalize entities before aggregation. The Resolve operation uses code-powered blocking conditions to efficiently identify which pairs to compare, avoiding O(n^2) comparisons. This is useful when the Map output contains duplicate or near-duplicate entities that should be merged before the Reduce step."
)
when_to_use: str = Field(
default="When a Map operation produces outputs that may contain duplicates, variations, or near-duplicates (e.g., different spellings of names, similar categories), and these should be normalized before the Reduce aggregation step. The target must be a Map operation followed by a Reduce operation."
)
instantiate_schema_type: Type[BaseModel] = MapToMapResolveReduceInstantiateSchema
example: str = Field(
default=(
"Original Pipeline:\n"
"- name: extract_companies\n"
" type: map\n"
" prompt: |\n"
" Extract company names from this news article:\n"
" {{ input.article }}\n"
" output:\n"
" schema:\n"
" company_name: string\n"
"\n"
"- name: aggregate_companies\n"
" type: reduce\n"
" reduce_key: sector\n"
" prompt: |\n"
" List all unique companies in this sector:\n"
" {% for input in inputs %}\n"
" - {{ input.company_name }}\n"
" {% endfor %}\n"
" output:\n"
" schema:\n"
" companies: list[str]\n"
"\n"
"Example InstantiateSchema:\n"
"MapToMapResolveReduceInstantiateSchema(\n"
" resolve_name='normalize_company_names',\n"
" comparison_prompt='''Are these two company names referring to the same company?\n"
"Company 1: {{ input1.company_name }}\n"
"Company 2: {{ input2.company_name }}\n"
"Consider variations like abbreviations (IBM vs International Business Machines), \n"
"different legal suffixes (Inc, Corp, LLC), and common misspellings.''',\n"
" resolution_prompt='''Given these variations of a company name:\n"
"{% for input in inputs %}\n"
"- {{ input.company_name }}\n"
"{% endfor %}\n"
"Return the canonical/official company name.''',\n"
" blocking_conditions=[\n"
" \"input1['company_name'][:3].lower() == input2['company_name'][:3].lower()\",\n"
" \"input1['company_name'].split()[0].lower() == input2['company_name'].split()[0].lower()\",\n"
" ],\n"
" blocking_keys=['company_name'],\n"
" limit_comparisons=1000,\n"
" output_schema={'company_name': 'string'},\n"
")"
),
)
test_cases: List[DirectiveTestCase] = Field(
default_factory=lambda: [
DirectiveTestCase(
name="company_name_normalization",
description="Should insert resolve between map and reduce for company names",
input_config=[
{
"name": "extract_companies",
"type": "map",
"prompt": "Extract company name from: {{ input.text }}",
"output": {"schema": {"company_name": "string"}},
},
{
"name": "aggregate_by_sector",
"type": "reduce",
"reduce_key": "sector",
"prompt": "List companies:\n{% for input in inputs %}\n- {{ input.company_name }}\n{% endfor %}",
"output": {"schema": {"companies": "list[str]"}},
},
],
target_ops=["extract_companies", "aggregate_by_sector"],
expected_behavior="Should create a resolve operation between the map and reduce with appropriate blocking conditions",
should_pass=True,
),
]
)
def __eq__(self, other):
return isinstance(other, MapToMapResolveReduceDirective)
def __hash__(self):
return hash("MapToMapResolveReduceDirective")
def to_string_for_instantiate(self, map_op: Dict, reduce_op: Dict) -> str:
"""
Generate a prompt for an agent to instantiate this directive.
Args:
map_op (Dict): The map operation configuration.
reduce_op (Dict): The reduce operation configuration.
Returns:
str: The agent prompt for instantiating the directive.
"""
return (
f"You are an expert at optimizing data processing pipelines by inserting entity resolution steps.\n\n"
f"Map Operation:\n"
f"{str(map_op)}\n\n"
f"Reduce Operation:\n"
f"{str(reduce_op)}\n\n"
f"Directive: {self.name}\n"
f"Your task is to insert a Resolve operation between the Map and Reduce to deduplicate/normalize entities.\n\n"
f"Key Requirements:\n"
f"1. Create a comparison_prompt that determines if two items from the Map output are duplicates/variations:\n"
f" - Must reference {{ input1.key }} and {{ input2.key }} for comparing fields\n"
f" - Should handle common variations (abbreviations, misspellings, formatting differences)\n\n"
f"2. Create a resolution_prompt that merges matched items into a canonical form:\n"
f" - Must use {{% for input in inputs %}} to iterate over matched items\n"
f" - Should produce the most authoritative/complete version\n\n"
f"3. Create blocking_conditions to avoid O(n^2) comparisons:\n"
f" - These are Python expressions with access to 'input1' and 'input2' dicts\n"
f" - They should filter pairs to only those likely to match\n"
f" - Examples:\n"
f" * \"input1['name'][:3].lower() == input2['name'][:3].lower()\" (first 3 chars match)\n"
f" * \"input1['name'].split()[0].lower() == input2['name'].split()[0].lower()\" (first word matches)\n"
f" * \"abs(len(input1['name']) - len(input2['name'])) < 10\" (similar length)\n"
f" - Multiple conditions are OR'd together\n\n"
f"4. Set limit_comparisons to cap the number of pairs (recommended: 500-2000)\n\n"
f"5. The output_schema should match what the Reduce operation expects from each input\n\n"
f"Example:\n"
f"{self.example}\n\n"
f"Please output the MapToMapResolveReduceInstantiateSchema."
)
def llm_instantiate(
self, map_op: Dict, reduce_op: Dict, agent_llm: str, message_history: list = []
):
"""
Use LLM to instantiate this directive.
Args:
map_op (Dict): The map operation configuration.
reduce_op (Dict): The reduce operation configuration.
agent_llm (str): The LLM model to use.
message_history (List, optional): Conversation history for context.
Returns:
MapToMapResolveReduceInstantiateSchema: The structured output from the LLM.
"""
message_history.extend(
[
{
"role": "system",
"content": "You are a helpful AI assistant for document processing pipelines specializing in entity resolution.",
},
{
"role": "user",
"content": self.to_string_for_instantiate(map_op, reduce_op),
},
]
)
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
response_format=MapToMapResolveReduceInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = MapToMapResolveReduceInstantiateSchema(**parsed_res)
message_history.append(
{"role": "assistant", "content": resp.choices[0].message.content}
)
return schema, message_history, call_cost
except Exception as err:
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
)
def apply(
self,
global_default_model,
ops_list: List[Dict],
map_op_name: str,
reduce_op_name: str,
rewrite: MapToMapResolveReduceInstantiateSchema,
) -> List[Dict]:
"""
Apply the directive to the pipeline config.
"""
new_ops_list = deepcopy(ops_list)
# Find position of the map and reduce ops
map_pos = None
reduce_pos = None
map_op = None
for i, op in enumerate(ops_list):
if op["name"] == map_op_name:
map_pos = i
map_op = op
elif op["name"] == reduce_op_name:
reduce_pos = i
if map_pos is None or reduce_pos is None:
raise ValueError(
f"Could not find map '{map_op_name}' and reduce '{reduce_op_name}' operations"
)
# Determine the model to use
default_model = map_op.get("model", global_default_model)
# Find the reduce operation to get the reduce_key
reduce_op = None
for op in ops_list:
if op["name"] == reduce_op_name:
reduce_op = op
break
# Derive output schema from reduce_key - that's what's being grouped/resolved
reduce_key = reduce_op.get("reduce_key", []) if reduce_op else []
if isinstance(reduce_key, str):
reduce_key = [reduce_key]
# Build output schema from reduce_key fields
output_schema = {key: "string" for key in reduce_key}
# Create the resolve operation
resolve_op = {
"name": rewrite.resolve_name,
"type": "resolve",
"comparison_prompt": rewrite.comparison_prompt,
"resolution_prompt": rewrite.resolution_prompt,
"blocking_conditions": rewrite.blocking_conditions,
"blocking_keys": rewrite.blocking_keys,
"limit_comparisons": rewrite.limit_comparisons,
"model": default_model,
"output": {"schema": output_schema},
}
# Insert resolve operation after the map operation
new_ops_list.insert(map_pos + 1, resolve_op)
return new_ops_list
def instantiate(
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
optimize_goal="acc",
global_default_model: str = None,
**kwargs,
):
"""
Instantiate the directive for a list of operators.
"""
# Assert that there are exactly two target ops (map and reduce)
assert (
len(target_ops) == 2
), "There must be exactly two target ops (map and reduce) to instantiate this directive"
# Find the map and reduce operations
map_op = None
reduce_op = None
for op in operators:
if op["name"] == target_ops[0]:
if op.get("type") == "map":
map_op = op
elif op.get("type") == "reduce":
reduce_op = op
elif op["name"] == target_ops[1]:
if op.get("type") == "map":
map_op = op
elif op.get("type") == "reduce":
reduce_op = op
if map_op is None or reduce_op is None:
raise ValueError(
f"Could not find both a map and reduce operation in target_ops: {target_ops}"
)
# Verify the map comes before reduce
map_idx = next(
i for i, op in enumerate(operators) if op["name"] == map_op["name"]
)
reduce_idx = next(
i for i, op in enumerate(operators) if op["name"] == reduce_op["name"]
)
if map_idx >= reduce_idx:
raise ValueError(
f"Map operation '{map_op['name']}' must come before reduce operation '{reduce_op['name']}'"
)
# Instantiate the directive
rewrite, message_history, call_cost = self.llm_instantiate(
map_op,
reduce_op,
agent_llm,
message_history,
)
# Apply the rewrite to the operators
new_ops_plan = self.apply(
global_default_model, operators, map_op["name"], reduce_op["name"], rewrite
)
return new_ops_plan, message_history, call_cost
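To make the blocking semantics concrete, here is a small illustration of how OR'd blocking conditions prune the pairwise comparison space before any LLM call. This is not the actual Resolve implementation; the documents and condition strings are invented, and it mirrors only the semantics described in the prompt above (Python expressions with access to input1 and input2).

from itertools import combinations

# Illustrative only — not how docetl's Resolve evaluates blocking internally.
blocking_conditions = [
    "input1['company_name'][:3].lower() == input2['company_name'][:3].lower()",
    "input1['company_name'].split()[0].lower() == input2['company_name'].split()[0].lower()",
]

docs = [
    {"company_name": "Acme Inc"},
    {"company_name": "Acme Corporation"},
    {"company_name": "Globex LLC"},
]

# A pair survives blocking if ANY condition evaluates to True for it.
candidate_pairs = [
    (a, b)
    for a, b in combinations(docs, 2)
    if any(eval(cond, {"input1": a, "input2": b}) for cond in blocking_conditions)
]
# Only ("Acme Inc", "Acme Corporation") reaches the LLM comparison_prompt;
# both Globex pairings are skipped, keeping comparisons well below O(n^2).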


@@ -0,0 +1,329 @@
import json
from copy import deepcopy
from typing import Dict, List, Type
from litellm import completion
from pydantic import BaseModel, Field
from docetl.reasoning_optimizer.instantiate_schemas import (
OperatorFusionInstantiateSchema,
)
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
class OperatorFusionDirective(Directive):
name: str = Field(
default="operator_fusion", description="The name of the directive"
)
formal_description: str = Field(default="Op1 -> Op2 => Op2")
nl_description: str = Field(
default="Combines two sequential operations into a single operation to reduce LLM processing costs by avoiding duplicate document reads and API calls"
)
when_to_use: str = Field(
default="When you have two sequential LLM operations processing the same documents keys and want to optimize cost by combining them into one operation that performs both tasks. The target operators should be two consecutive operations. Make sure you specify two operators when choosing this directive."
)
instantiate_schema_type: Type[BaseModel] = Field(
default=OperatorFusionInstantiateSchema
)
example: str = Field(
default="""
Example 1 - Map + Filter Fusion:
Original: extract_sentiment (map) -> filter_positive (filter)
Agent output: {"fused_prompt": "Extract sentiment from {{ input.review }} AND determine if it's positive (true/false)"}
Example 2 - Map + Map Fusion:
Original: extract_entities (map) -> classify_urgency (map)
Agent output: {"fused_prompt": "Extract entities from {{ input.text }} AND classify urgency level"}
Example 3 - Map + Reduce Fusion:
Original: extract_themes (map) -> summarize_themes (reduce)
Agent output: {"fused_prompt": "For each group of feedback, extract themes and summarize them: {% for item in inputs %}{{ item.feedback }}{% endfor %}"}
"""
)
test_cases: List[DirectiveTestCase] = Field(
default_factory=lambda: [
DirectiveTestCase(
name="sentiment_analysis_fusion",
description="Fuse sentiment extraction + quality filter",
input_config=[
{
"name": "extract_sentiment",
"type": "map",
"prompt": "What is the sentiment of this review? {{ input.review_text }}",
"output": {
"schema": {"sentiment": "string", "confidence": "float"}
},
"model": "gpt-4o-mini",
},
{
"name": "filter_confident",
"type": "filter",
"prompt": "Is this sentiment confident (>0.8)? Confidence: {{ input.confidence }}",
"output": {"schema": {"is_confident": "boolean"}},
"model": "gpt-4o-mini",
},
],
target_ops=["extract_sentiment", "filter_confident"],
expected_behavior="Should create map + code_filter for sentiment extraction with confidence filtering",
should_pass=True,
),
DirectiveTestCase(
name="document_processing_fusion",
description="Fuse document classification + summarization",
input_config=[
{
"name": "classify_doc",
"type": "map",
"prompt": "Classify this document: {{ input.content }}",
"output": {"schema": {"category": "string"}},
"model": "gpt-4o-mini",
},
{
"name": "summarize_doc",
"type": "map",
"prompt": "Summarize this document: {{ input.content }}",
"output": {"schema": {"summary": "string"}},
"model": "gpt-4o-mini",
},
],
target_ops=["classify_doc", "summarize_doc"],
expected_behavior="Should create single map combining classification and summarization",
should_pass=True,
),
DirectiveTestCase(
name="extract_and_aggregate_fusion",
description="Fuse entity extraction + aggregation",
input_config=[
{
"name": "extract_mentions",
"type": "map",
"prompt": "Extract company mentions from: {{ input.article }}",
"output": {"schema": {"companies": "list[str]"}},
"model": "gpt-4o-mini",
},
{
"name": "count_mentions",
"type": "reduce",
"reduce_key": "topic",
"prompt": "Count company mentions: {% for item in inputs %}{{ item.companies }}{% endfor %}",
"output": {"schema": {"mention_counts": "str"}},
"model": "gpt-4o-mini",
},
],
target_ops=["extract_mentions", "count_mentions"],
expected_behavior="Should create single reduce combining extraction and counting",
should_pass=True,
),
]
)
def __eq__(self, other):
return isinstance(other, OperatorFusionDirective)
def __hash__(self):
return hash("OperatorFusionDirective")
def to_string_for_instantiate(self, original_ops: List[Dict]) -> str:
"""
Generate a prompt that asks the agent to output the instantiate schema.
"""
op1, op2 = original_ops[0], original_ops[1]
return (
f"You are an expert at optimizing data processing pipelines for cost efficiency.\n\n"
f"Two Sequential Operations:\n"
f"Operation 1: {op1}\n"
f"Operation 2: {op2}\n\n"
f"Your task is to fuse these two operations into a single operation that performs both tasks, "
f"reducing LLM API calls and processing costs.\n\n"
f"Create a combined prompt that:\n"
f"1. Performs the logic of both operations in one LLM call\n"
f"2. Uses the same input references as the original operations\n"
f"3. If either operation is a filter, include boolean logic for filtering\n"
f"4. Maintains the same output schema requirements\n\n"
f"IMPORTANT: If either operation is a filter, your fused prompt MUST include logic that "
f"outputs a boolean field for filtering purposes. A code_filter will be automatically added.\n\n"
f"Example outputs:\n"
f"{self.example}\n\n"
f"Please output only the OperatorFusionInstantiateSchema with 'fused_prompt' field "
f"that specifies how to combine these operations efficiently."
)
def llm_instantiate(
self,
original_ops: List[Dict],
agent_llm: str,
message_history: list = [],
) -> tuple:
"""
Call the LLM to generate the instantiate schema.
"""
message_history.extend(
[
{
"role": "user",
"content": self.to_string_for_instantiate(original_ops),
},
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
response_format=OperatorFusionInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = OperatorFusionInstantiateSchema(**parsed_res)
message_history.append(
{"role": "assistant", "content": resp.choices[0].message.content}
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(
self,
global_default_model: str,
ops_list: List[Dict],
target_ops: List[str],
rewrite: OperatorFusionInstantiateSchema,
) -> List[Dict]:
"""
Apply the fusion directive by combining two operations and optionally adding code_filter.
"""
assert (
len(target_ops) == 2
), "Operator fusion requires exactly two target operations"
new_ops_list = deepcopy(ops_list)
op1_name, op2_name = target_ops[0], target_ops[1]
# Find the operations
op1_idx = next(i for i, op in enumerate(ops_list) if op["name"] == op1_name)
op2_idx = next(i for i, op in enumerate(ops_list) if op["name"] == op2_name)
op1, op2 = ops_list[op1_idx], ops_list[op2_idx]
# Determine fused operation type and schema based on the combination
op1_type, op2_type = op1.get("type"), op2.get("type")
assert (
op1_type != "reduce" and op2_type != "reduce"
), "Cannot apply fusion on reduce"
default_model = op1.get("model", global_default_model)
# Create base fused operation
fused_op = {
"name": f"fused_{op1_name}_{op2_name}",
"prompt": rewrite.fused_prompt,
"model": default_model,
"litellm_completion_kwargs": {"temperature": 0},
}
needs_code_filter = False
# Determine type, schema, and code_filter need based on combination
if op1_type == "map" and op2_type == "map":
# map + map => fuse into one map
fused_op["type"] = "map"
fused_op["output"] = {
"schema": {**op1["output"]["schema"], **op2["output"]["schema"]}
}
needs_code_filter = False
elif (op1_type == "map" and op2_type == "filter") or (
op1_type == "filter" and op2_type == "map"
):
# map + filter OR filter + map => fuse into map (with union of schemas) + code filter
fused_op["type"] = "map"
fused_op["output"] = {
"schema": {**op1["output"]["schema"], **op2["output"]["schema"]}
}
needs_code_filter = True
elif op1_type == "filter" and op2_type == "filter":
# filter + filter => fuse into one filter with bool output
fused_op["type"] = "filter"
fused_op["output"] = {"schema": {"_bool": "bool"}}
needs_code_filter = False
# Replace the original operations
if op1_idx < op2_idx:
# Remove in reverse order to maintain indices
new_ops_list.pop(op2_idx)
new_ops_list.pop(op1_idx)
new_ops_list.insert(op1_idx, fused_op)
else:
new_ops_list.pop(op1_idx)
new_ops_list.pop(op2_idx)
new_ops_list.insert(op2_idx, fused_op)
# Add code_filter if needed
if needs_code_filter:
# Get the filter field name from the filter operation
filter_op = op1 if op1.get("type") == "filter" else op2
filter_field = list(filter_op["output"]["schema"].keys())[0]
code_filter_op = {
"name": f"filter_{fused_op['name']}",
"type": "code_filter",
"code": f"""
def transform(input_doc):
return input_doc.get('{filter_field}', False)
""",
}
# Insert code_filter after the fused operation
fused_idx = next(
i for i, op in enumerate(new_ops_list) if op["name"] == fused_op["name"]
)
new_ops_list.insert(fused_idx + 1, code_filter_op)
return new_ops_list
def instantiate(
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
global_default_model: str = None,
**kwargs,
):
"""
Main method that orchestrates directive instantiation.
"""
assert (
len(target_ops) == 2
), "Operator fusion requires exactly two target operations"
# Get the two operations to fuse
target_op_configs = [op for op in operators if op["name"] in target_ops]
# Ensure they are in the correct order
target_op_configs.sort(key=lambda op: target_ops.index(op["name"]))
# Step 1: Agent generates the instantiate schema
rewrite, message_history, call_cost = self.llm_instantiate(
target_op_configs, agent_llm, message_history
)
# Step 2: Apply transformation using the schema
return (
self.apply(global_default_model, operators, target_ops, rewrite),
message_history,
call_cost,
)


@@ -0,0 +1,304 @@
import json
import re
from copy import deepcopy
from typing import Dict, List, Type
from litellm import completion
from pydantic import BaseModel, Field
from docetl.reasoning_optimizer.instantiate_schemas import (
ReduceChainingInstantiateSchema,
)
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
class ReduceChainingDirective(Directive):
name: str = Field(
default="reduce_chaining", description="The name of the directive"
)
formal_description: str = Field(default="Reduce => Map -> Reduce")
nl_description: str = Field(
default="Transform a reduce operation that processes long documents by inserting a Map step that extracts/processes relevant information from each document first, then modifying the reduce prompt to work with the processed results instead of full document content."
)
when_to_use: str = Field(
default="When a reduce operation iterates through long documents to extract specific information (e.g., locations, entities, themes) that could be pre-extracted per document to make the reduce step more efficient and focused. The target operator must be a reduce operator. You should specify a reduce operator as the target operator when choosing this directive."
)
instantiate_schema_type: Type[BaseModel] = ReduceChainingInstantiateSchema
example: str = Field(
default=(
"Original Reduce Op:\n"
"- name: extract_all_locations\n"
" type: reduce\n"
" reduce_key: document_collection\n"
" prompt: |\n"
" Extract all distinct locations mentioned across these documents:\n"
" {% for input in inputs %}\n"
" Document: {{ input.document }}\n"
" {% endfor %}\n"
" Return a list of unique location names.\n"
" output:\n"
" schema:\n"
" locations: list[str]\n"
"\n"
"Example InstantiateSchema:\n"
"ReduceChainingInstantiateSchema(\n"
" map_name='extract_document_locations',\n"
" map_prompt='Extract all location names mentioned in this document:\\n{{ input.document }}\\nReturn a list of locations.',\n"
" new_key='locations',\n"
" modified_reduce_prompt='Combine and deduplicate all locations from these documents:\\n{% for input in inputs %}\\nLocations from document: {{ input.locations }}\\n{% endfor %}\\nReturn a list of unique location names.',\n"
")"
),
)
test_cases: List[DirectiveTestCase] = Field(
default_factory=lambda: [
DirectiveTestCase(
name="extract_distinct_entities",
description="Should decompose entity extraction into map-reduce pattern",
input_config={
"name": "extract_all_people",
"type": "reduce",
"reduce_key": "doc_id",
"prompt": "Extract all distinct person names mentioned across these documents:\n{% for input in inputs %}\nDocument: {{ input.text }}\n{% endfor %}\nReturn a list of unique person names.",
"output": {"schema": {"people": "list[str]"}},
},
target_ops=["extract_all_people"],
expected_behavior="Should create a map op to extract people from each document, then modify reduce to work with extracted lists",
should_pass=True,
),
DirectiveTestCase(
name="theme_analysis",
description="Should decompose theme analysis into map-reduce pattern",
input_config={
"name": "analyze_themes",
"type": "reduce",
"reduce_key": "category",
"prompt": "Identify common themes across these research papers:\n{% for input in inputs %}\nPaper: {{ input.content }}\n{% endfor %}\nReturn the main themes.",
"output": {"schema": {"themes": "list[str]"}},
},
target_ops=["analyze_themes"],
expected_behavior="Should create a map op to extract themes from each paper, then reduce to identify common ones",
should_pass=True,
),
]
)
def __eq__(self, other):
return isinstance(other, ReduceChainingDirective)
def __hash__(self):
return hash("ReduceChainingDirective")
def to_string_for_instantiate(self, original_op: Dict) -> str:
"""
Generate a prompt for an agent to instantiate this directive.
Args:
original_op (Dict): The original reduce operation configuration.
Returns:
str: The agent prompt for instantiating the directive.
"""
return (
f"You are an expert at optimizing data processing operations by decomposing complex reduce operations.\n\n"
f"Original Operation:\n"
f"{str(original_op)}\n"
f"Directive: {self.name}\n"
f"Your task is to instantiate this directive by creating a Map operation that preprocesses individual documents, "
f"and then modifying the Reduce operation to work with the preprocessed results instead of raw document content.\n\n"
f"The goal is to make the reduce operation more efficient by having the map operation extract or process "
f"the specific information needed from each document, rather than having the reduce operation process the full document content.\n\n"
f"Key Requirements:\n"
f"1. Create a Map operation that processes individual documents and extracts the relevant information\n"
f"2. Choose an appropriate new key name for the Map operation's output\n"
f"3. Modify the original reduce prompt to work with the processed results instead of the original document content\n"
f"4. Ensure the final output schema and semantics remain the same\n"
f"5. The modified reduce prompt should reference the new key, not the original document key\n\n"
f"Example:\n"
f"{self.example}\n\n"
f"Please output the ReduceChainingInstantiateSchema with the map operation details and modified reduce prompt."
)
def llm_instantiate(
self,
original_op: Dict,
expected_document_key: str,
agent_llm: str,
message_history: list = [],
):
"""
Use LLM to instantiate this directive by decomposing the reduce operation.
Args:
original_op (Dict): The original reduce operation.
expected_document_key (str): The key that contains the document content to be processed.
agent_llm (str): The LLM model to use.
message_history (List, optional): Conversation history for context.
Returns:
ReduceChainingInstantiateSchema: The structured output from the LLM.
"""
message_history.extend(
[
{
"role": "system",
"content": "You are a helpful AI assistant for document processing pipelines.",
},
{
"role": "user",
"content": self.to_string_for_instantiate(original_op),
},
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
response_format=ReduceChainingInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = ReduceChainingInstantiateSchema(**parsed_res)
# Validate the schema
ReduceChainingInstantiateSchema.validate_reduce_prompt_references_new_key(
schema.modified_reduce_prompt, schema.new_key, expected_document_key
)
message_history.append(
{"role": "assistant", "content": resp.choices[0].message.content}
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(
self,
global_default_model,
ops_list: List[Dict],
target_op: str,
rewrite: ReduceChainingInstantiateSchema,
) -> List[Dict]:
"""
Apply the directive to the pipeline config.
"""
# Create a copy of the pipeline config
new_ops_list = deepcopy(ops_list)
# Find position of the target reduce op to modify
pos_to_modify = None
orig_op = None
for i, op in enumerate(ops_list):
if op["name"] == target_op:
pos_to_modify = i
orig_op = op
break
if pos_to_modify is None:
raise ValueError(
f"Target operation '{target_op}' not found in operations list"
)
# Determine the model to use
default_model = orig_op.get("model", global_default_model)
# Create the new map operation
new_map_op = {
"name": rewrite.map_name,
"type": "map",
"prompt": rewrite.map_prompt,
"model": default_model,
"litellm_completion_kwargs": {"temperature": 0},
"output": {"schema": {rewrite.new_key: "string"}},
}
# Modify the reduce operation
modified_reduce_op = deepcopy(orig_op)
modified_reduce_op["prompt"] = rewrite.modified_reduce_prompt
# Insert the map operation before the reduce operation
new_ops_list.insert(pos_to_modify, new_map_op)
# Update the reduce operation (now at pos_to_modify + 1)
new_ops_list[pos_to_modify + 1] = modified_reduce_op
return new_ops_list
def instantiate(
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
optimize_goal="acc",
global_default_model: str = None,
**kwargs,
):
"""
Instantiate the directive for a list of operators.
"""
# Assert that there is only one target op
assert (
len(target_ops) == 1
), "There must be exactly one target op to instantiate this reduce chaining directive"
target_op_config = [op for op in operators if op["name"] == target_ops[0]][0]
# Ensure it's a reduce operation
if target_op_config.get("type") != "reduce":
raise ValueError(
f"Target operation '{target_ops[0]}' must be a reduce operation"
)
# Extract expected document key from the reduce prompt template
prompt_template = target_op_config["prompt"]
# Find all occurrences of {{ input.key }} in the prompt
input_key_pattern = r"\{\{\s*([^\}\s]+)\s*\}\}"
input_keys = list(set(re.findall(input_key_pattern, prompt_template)))
print("input_keys: ", input_keys)
# Heuristic: pick the key that's most likely to contain document content
# Look for common document field names
document_key_candidates = [
key
for key in input_keys
if any(
doc_word in key.lower()
for doc_word in ["document", "text", "content", "body", "description"]
)
]
if document_key_candidates:
expected_document_key = document_key_candidates[
0
] # Pick the first candidate
elif input_keys:
expected_document_key = input_keys[0] # Fall back to the first input key
else:
raise ValueError("No input keys found in the reduce operation prompt")
print(f"Detected document key: {expected_document_key}")
# Instantiate the directive
rewrite, message_history, call_cost = self.llm_instantiate(
target_op_config,
expected_document_key,
agent_llm,
message_history,
)
# Apply the rewrite to the operators
new_ops_plan = self.apply(
global_default_model, operators, target_ops[0], rewrite
)
return new_ops_plan, message_history, call_cost
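A quick sanity check of the document-key heuristic above (the prompt string here is invented): note that the regex captures the full dotted reference, so candidates look like "input.text" rather than a bare "text", and the doc-word filter still matches because it substring-matches the lowercased key.

import re

# Invented reduce prompt, shaped like the test cases above.
prompt_template = (
    "Extract all distinct person names across these documents:\n"
    "{% for input in inputs %}\nDocument: {{ input.text }}\n{% endfor %}\n"
    "Return a list of unique person names."
)

input_key_pattern = r"\{\{\s*([^\}\s]+)\s*\}\}"
input_keys = list(set(re.findall(input_key_pattern, prompt_template)))
print(input_keys)  # ['input.text'] -> 'text' hits the document-word list,
                   # so expected_document_key becomes 'input.text'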


@@ -0,0 +1,343 @@
import json
from copy import deepcopy
from typing import Dict, List, Type
from litellm import completion
from pydantic import BaseModel, Field
from docetl.reasoning_optimizer.instantiate_schemas import GleaningInstantiateSchema
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
class ReduceGleaningDirective(Directive):
name: str = Field(
default="reduce_gleaning", description="The name of the directive"
)
formal_description: str = Field(default="Reduce => Reduce_m (with gleaning config)")
nl_description: str = Field(
default="""Adds a validation loop to Reduce operations: after each LLM generation during the reduce process, a separate "judge" LLM evaluates the output using a yes/no validation prompt. If the output fails, the original LLM refines its answer and repeats until the output passes or the max number of rounds is reached."""
)
when_to_use: str = Field(
default="When reduce operations process complex documents that require comprehensive analysis and synthesis (e.g., research paper analysis, customer feedback consolidation, legal document review, literature synthesis) where outputs must be validated for completeness, accuracy, and proper coverage of all input materials."
)
instantiate_schema_type: Type[BaseModel] = Field(default=GleaningInstantiateSchema)
example: str = Field(
default="""
Original Op (ReduceOpConfig):
- name: synthesize_research_findings
type: reduce
reduce_key: research_domain
prompt: |
You are analyzing research papers in the {{ research_domain }} domain.
Synthesize the following research findings into a comprehensive domain overview:
{% for paper in inputs %}
**Paper {{ loop.index }}:**
- Title: {{ paper.title }}
- Key Findings: {{ paper.key_findings }}
- Methodology: {{ paper.methodology }}
- Limitations: {{ paper.limitations }}
- Future Work: {{ paper.future_work }}
{% endfor %}
Generate a synthesis with the following structure:
- **Domain Overview**: 2-3 sentences describing the field
- **Major Findings**: List of 4-6 key insights across all papers
- **Methodological Approaches**: Summary of research methods used
- **Research Gaps**: Identified limitations and areas needing investigation
- **Future Directions**: Consolidated recommendations for future research
output:
schema:
domain_overview: "string"
major_findings: "list[str]"
methodological_approaches: "string"
research_gaps: "list[str]"
future_directions: "list[str]"
Example InstantiateSchema (what the agent should output):
{
"validation_prompt": "Verify that the synthesis includes all required sections (domain overview, major findings, methodological approaches, research gaps, future directions). Each major finding should be supported by evidence from multiple papers. Research gaps should be specific and actionable. The domain overview should accurately represent the scope covered by the input papers.",
"num_rounds": 3,
"model": "gpt-4o-mini"
}
""",
)
test_cases: List[DirectiveTestCase] = Field(
default_factory=lambda: [
DirectiveTestCase(
name="customer_feedback_analysis",
description="Should add validation for comprehensive customer feedback analysis by product category",
input_config={
"name": "analyze_feedback_by_product",
"type": "reduce",
"reduce_key": "product_category",
"prompt": """Analyze customer feedback for {{ product_category }} products.
Create a comprehensive analysis from the following feedback:
{% for review in inputs %}
Review {{ loop.index }}:
- Rating: {{ review.rating }}/5
- Comment: {{ review.comment }}
- Customer Type: {{ review.customer_type }}
- Date: {{ review.date }}
{% endfor %}
Provide analysis with:
- Overall sentiment and satisfaction level
- Top 3 most praised features
- Top 3 most criticized issues
- Recommendations for product improvements
- Customer segment insights""",
"output": {
"schema": {
"overall_sentiment": "string",
"satisfaction_level": "string",
"top_praised_features": "list[str]",
"top_criticized_issues": "list[str]",
"improvement_recommendations": "list[str]",
"segment_insights": "string",
}
},
},
target_ops=["analyze_feedback_by_product"],
expected_behavior="Should add gleaning validation to ensure all feedback is considered, sentiment analysis is accurate, and recommendations are actionable",
should_pass=True,
),
DirectiveTestCase(
name="legal_contract_analysis",
description="Should add validation for thorough legal contract analysis by contract type",
input_config={
"name": "analyze_contracts_by_type",
"type": "reduce",
"reduce_key": "contract_type",
"prompt": """Analyze {{ contract_type }} contracts and extract key legal provisions.
Review the following contracts:
{% for contract in inputs %}
Contract {{ loop.index }}:
- Party Names: {{ contract.parties }}
- Key Terms: {{ contract.key_terms }}
- Obligations: {{ contract.obligations }}
- Termination Clauses: {{ contract.termination }}
- Risk Factors: {{ contract.risks }}
{% endfor %}
Provide comprehensive analysis including:
- Common contractual patterns across all contracts
- Standard terms and deviations
- Risk assessment summary
- Compliance requirements identified
- Recommendations for contract standardization""",
"output": {
"schema": {
"common_patterns": "list[str]",
"standard_terms": "list[str]",
"risk_assessment": "string",
"compliance_requirements": "list[str]",
"standardization_recommendations": "list[str]",
}
},
},
target_ops=["analyze_contracts_by_type"],
expected_behavior="Should add gleaning validation to ensure all contract provisions are captured, risk analysis is thorough, and recommendations are legally sound",
should_pass=True,
),
DirectiveTestCase(
name="research_literature_synthesis",
description="Should add validation for comprehensive literature synthesis by research topic",
input_config={
"name": "synthesize_literature_by_topic",
"type": "reduce",
"reduce_key": "research_topic",
"prompt": """Synthesize research literature on {{ research_topic }}.
Analyze the following academic papers:
{% for paper in inputs %}
Paper {{ loop.index }}:
- Title: {{ paper.title }}
- Abstract: {{ paper.abstract }}
- Methodology: {{ paper.methodology }}
- Key Results: {{ paper.results }}
- Conclusions: {{ paper.conclusions }}
- Limitations: {{ paper.limitations }}
{% endfor %}
Create a literature synthesis with:
- Theoretical frameworks identified across papers
- Consensus findings and contradictory results
- Methodological approaches comparison
- Research gaps and limitations summary
- Future research directions and recommendations""",
"output": {
"schema": {
"theoretical_frameworks": "list[str]",
"consensus_findings": "list[str]",
"contradictory_results": "list[str]",
"methodological_comparison": "string",
"research_gaps": "list[str]",
"future_directions": "list[str]",
}
},
},
target_ops=["synthesize_literature_by_topic"],
expected_behavior="Should add gleaning validation to ensure all papers are properly synthesized, contradictions are identified, and research gaps are accurately captured",
should_pass=True,
),
]
)
def __eq__(self, other):
return isinstance(other, ReduceGleaningDirective)
def __hash__(self):
return hash("ReduceGleaningDirective")
def to_string_for_instantiate(self, original_op: Dict) -> str:
"""
Generate a prompt for an agent to instantiate this directive.
Args:
original_op (Dict): The original operation configuration.
Returns:
str: The agent prompt for instantiating the directive.
"""
return (
f"You are an expert at adding validation and refinement loops to reduce operations for document processing tasks.\n\n"
f"Original Operation:\n"
f"{str(original_op)}\n\n"
f"Directive: {self.name}\n"
f"Your task is to instantiate this directive by generating a configuration that adds validation loops to the reduce operation. "
f"The gleaning configuration should include a validation prompt that evaluates the quality of the reduce output, "
f"focusing on document analysis criteria such as:\n"
f"- Completeness: Are all input documents/items properly considered and synthesized?\n"
f"- Accuracy: Are the extracted insights, patterns, and conclusions accurate?\n"
f"- Structure: Does the output follow the required format and include all requested fields?\n"
f"- Comprehensiveness: Are key themes, patterns, and insights captured across all inputs?\n"
f"- Consistency: Are the analysis and recommendations internally consistent?\n\n"
f"For reduce operations, the LLM processes groups of related documents and creates consolidated, synthesized outputs. "
f"The validation should ensure proper document analysis, synthesis quality, and adherence to output requirements.\n\n"
f"Example:\n"
f"{self.example}\n\n"
f"Please output only the GleaningInstantiateSchema object that specifies how to validate and refine the output of the reduce operation."
)
def llm_instantiate(
self,
original_op: Dict,
agent_llm: str,
message_history: list = [],
):
"""
Use LLM to instantiate this directive by decomposing the original operation.
Args:
original_op (Dict): The original operation.
agent_llm (str): The LLM model to use.
message_history (List, optional): Conversation history for context.
Returns:
GleaningInstantiateSchema: The structured output from the LLM.
"""
message_history.extend(
[
{
"role": "user",
"content": self.to_string_for_instantiate(original_op),
},
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
response_format=GleaningInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = GleaningInstantiateSchema(**parsed_res)
message_history.append(
{"role": "assistant", "content": resp.choices[0].message.content}
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(
self,
global_default_model: str,
ops_list: List[Dict],
target_op: str,
rewrite: GleaningInstantiateSchema,
) -> List[Dict]:
"""
Apply the directive to the pipeline config by adding gleaning configuration to the target reduce operator.
"""
# Create a copy of the pipeline config
new_ops_list = deepcopy(ops_list)
# Find position of the target op to modify
pos_to_replace = [
i for i, op in enumerate(ops_list) if op["name"] == target_op
][0]
# Add gleaning configuration to the target operator
target_operator = new_ops_list[pos_to_replace]
target_operator["gleaning"] = {
"validation_prompt": rewrite.validation_prompt,
"num_rounds": rewrite.num_rounds,
"model": rewrite.model,
}
return new_ops_list
def instantiate(
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
global_default_model: str = None,
**kwargs,
):
"""
Instantiate the directive for a list of operators.
"""
# Assert that there is only one target op
assert (
len(target_ops) == 1
), "There must be exactly one target op to instantiate this reduce gleaning directive"
target_op_config = [op for op in operators if op["name"] == target_ops[0]][0]
# Verify the target operation is a reduce operation
if target_op_config.get("type") != "reduce":
raise ValueError(
f"ReduceGleaningDirective can only be applied to reduce operations, got {target_op_config.get('type')}"
)
# Instantiate the directive
rewrite, message_history, call_cost = self.llm_instantiate(
target_op_config, agent_llm, message_history
)
# Apply the rewrite to the operators
return (
self.apply(global_default_model, operators, target_ops[0], rewrite),
message_history,
call_cost,
)
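Since apply() here only attaches a gleaning block, the net effect on a target reduce op boils down to the assignment below. All values are invented for illustration; the field names are the ones apply() writes.

# Invented reduce op, abbreviated; the prompt and schema stay untouched.
reduce_op = {
    "name": "analyze_feedback_by_product",
    "type": "reduce",
    "reduce_key": "product_category",
    "prompt": "...",
    "output": {"schema": {"overall_sentiment": "string"}},
}

# What ReduceGleaningDirective.apply() adds (values here are hypothetical):
reduce_op["gleaning"] = {
    "validation_prompt": (
        "Does the analysis consider every review, and are all required output "
        "fields filled in with internally consistent conclusions? Answer yes or no."
    ),
    "num_rounds": 2,
    "model": "gpt-4o-mini",
}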


@@ -0,0 +1,339 @@
import json
from copy import deepcopy
from typing import Dict, List, Type
from pydantic import BaseModel, Field
from docetl.reasoning_optimizer.instantiate_schemas import SwapWithCodeInstantiateSchema
from .agent_utils import AgenticDirectiveRunner
from .base import Directive, DirectiveTestCase
class SwapWithCodeDirective(Directive):
name: str = Field(default="swap_with_code", description="The name of the directive")
formal_description: str = Field(default="Reduce => Code Reduce + Map")
nl_description: str = Field(
default="Replaces a Reduce operation with a Code Reduce operation for deterministic logic plus an optional Map operation to format the output. The Code Reduce handles the core reduction logic (like counting, collecting, or aggregating) while the optional Map operation converts the result to match the expected schema format."
)
when_to_use: str = Field(
default="When a reduce operation performs logic that can be implemented more efficiently or deterministically with code rather than an LLM. Examples include: counting distinct values, finding most common elements, basic aggregations, set operations, or mathematical computations. Particularly useful when the reduction logic is straightforward but the output needs to be formatted in a specific way for downstream operations."
)
instantiate_schema_type: Type[BaseModel] = Field(
default=SwapWithCodeInstantiateSchema
)
example: str = Field(
default="""
Target Operation:
- name: summarize_locations
type: reduce
reduce_key: "_all"
prompt: |
Summarize all distinct locations from these documents:
{% for input in inputs %}{{ input.locations }}{% endfor %}
output:
schema:
summary: "str"
distinct_locations: "list[str]"
The agent might convert this to:
1. Code Reduce that collects distinct locations: {"locations": ["NYC", "SF", "LA"]}
2. Optional Map that formats this as: {"summary": "Found 3 distinct locations: NYC, SF, LA", "distinct_locations": ["NYC", "SF", "LA"]}
Example InstantiateSchema (what the agent should output):
SwapWithCodeInstantiateSchema(
code_reduce_name="collect_distinct_locations",
code="def transform(inputs):\n locations = set()\n for item in inputs:\n if 'locations' in item and isinstance(item['locations'], list):\n locations.update(item['locations'])\n return {'distinct_locations': sorted(list(locations))}",
map_prompt="Create a summary of the locations: {{ input.distinct_locations }}. Output format: summary (string describing the count and locations), distinct_locations (the original list)."
)
"""
)
test_cases: List[DirectiveTestCase] = Field(
default_factory=lambda: [
DirectiveTestCase(
name="basic_reduce_to_code_reduce",
description="Should convert reduce operation to code reduce + optional map",
input_config={
"name": "count_items",
"type": "reduce",
"reduce_key": "_all",
"prompt": "Count the total items: {% for input in inputs %}{{ input.count }}{% endfor %}",
"output": {"schema": {"total": "int"}},
},
target_ops=["count_items"],
expected_behavior="Should replace reduce with code reduce that performs counting logic and optional map to format output",
should_pass=True,
),
DirectiveTestCase(
name="reduce_with_grouping",
description="Should handle reduce operations with grouping keys",
input_config={
"name": "group_by_category",
"type": "reduce",
"reduce_key": "category",
"prompt": "List all items for this category: {% for input in inputs %}{{ input.name }}{% endfor %}",
"output": {"schema": {"category": "str", "items": "list[str]"}},
},
target_ops=["group_by_category"],
expected_behavior="Should preserve reduce_key grouping in the code reduce operation",
should_pass=True,
),
]
)
def __eq__(self, other):
return isinstance(other, SwapWithCodeDirective)
def __hash__(self):
return hash("SwapWithCodeDirective")
def to_string_for_instantiate(
self, target_ops_configs: List[Dict], pipeline_code: Dict = None
) -> str:
"""
Generate a prompt that asks the agent to analyze sample data and create a code reduce + optional map replacement.
"""
assert (
len(target_ops_configs) == 1
), "SwapWithCode directive only supports single target operation"
op = target_ops_configs[0]
original_prompt = op.get("prompt", "")
reduce_key = op.get("reduce_key", "_all")
output_schema = op.get("output", {}).get("schema", {})
# Build pipeline context
pipeline_context = ""
if pipeline_code:
pipeline_context = f"""
Pipeline Context:
{json.dumps(pipeline_code, indent=2)}
The target reduce operation '{op['name']}' fits into this broader pipeline. Consider:
- What data flows into this operation from previous steps
- How this operation's output will be used by subsequent operations
- The overall goal of the pipeline when designing your code reduce + map solution
"""
return (
f"You are an expert at analyzing reduce operations and implementing efficient code-based alternatives.\n\n"
f"Target Reduce Operation:\n"
f"{json.dumps(op, indent=2)}\n\n"
f"Original Prompt: {original_prompt}\n"
f"Reduce Key: {reduce_key}\n"
f"Expected Output Schema: {json.dumps(output_schema, indent=2)}\n\n"
f"{pipeline_context}\n"
f"Your task is to replace this reduce operation with:\n"
f"1. A Code Reduce operation that implements the core reduction logic deterministically\n"
f"2. An optional Map operation that formats the code reduce output to match the expected schema\n\n"
f"You will be given access to sample input data through a read_next_docs() function. Use this to:\n"
f"1. Understand the actual structure and patterns in the input data\n"
f"2. Identify what the reduce operation is trying to accomplish\n"
f"3. Design efficient Python code that performs the same reduction logic\n"
f"4. Determine if a follow-up map operation is needed to format the output correctly\n"
f"5. Consider edge cases and data variations in your implementation\n\n"
f"Guidelines for the replacement:\n"
f"- The Code Reduce must implement a 'transform' function that takes a list of inputs and returns a dictionary\n"
f"- Include all necessary imports within the transform function\n"
f"- The code should handle the same reduce_key grouping as the original operation\n"
f"- If the code reduce output doesn't match the expected schema, provide a map_prompt to format it\n"
f"- The map operation (if needed) should reference fields from the code reduce output using {{{{ input.field_name }}}}\n"
f"- Focus on correctness, efficiency, and handling edge cases found in the sample data\n\n"
f"Examples of good candidates for code reduce:\n"
f"- Counting, summing, or basic mathematical operations\n"
f"- Collecting distinct values or creating sets\n"
f"- Finding min/max values or sorting\n"
f"- Simple aggregations or list operations\n"
f"- Deterministic text processing\n\n"
f"Example transformation:\n"
f"{self.example}\n\n"
f"Analyze samples strategically to understand the data patterns and reduction requirements.\n"
f"When you have enough information to create an efficient code-based solution, output your result."
)
def llm_instantiate(
self,
target_ops_configs: List[Dict],
input_file_path: str,
agent_llm: str,
message_history: list = [],
pipeline_code: Dict = None,
):
"""
Use agentic approach to analyze sample data and generate code reduce + optional map replacement.
"""
# Load sample input data
try:
with open(input_file_path, "r") as f:
input_data = json.load(f)
if not isinstance(input_data, list) or len(input_data) == 0:
raise ValueError(
"Input file must contain a non-empty list of sample data"
)
except Exception as e:
raise Exception(
f"Failed to load input data from {input_file_path}: {str(e)}"
)
# Create validation function
def validate_code_reduce_schema(schema_instance):
# Basic validation is handled by Pydantic validators
# Could add additional validation here if needed
pass
# Set up agentic runner with validation
runner = AgenticDirectiveRunner(
input_data=input_data,
agent_llm=agent_llm,
validation_func=validate_code_reduce_schema,
)
# Create system prompt for the agentic runner
system_prompt = (
"You are an expert at analyzing reduce operations and designing efficient code-based alternatives. "
"Your goal is to examine input samples to understand the reduction logic, then implement it as "
"efficient Python code with optional formatting. You consider both performance and correctness "
"while ensuring the output matches the expected schema through code reduce + optional map pattern."
)
# Create initial user message
initial_message = self.to_string_for_instantiate(
target_ops_configs, pipeline_code
)
# Run the agentic loop
try:
schema, updated_message_history, call_cost = runner.run_agentic_loop(
system_prompt=system_prompt,
initial_user_message=initial_message,
response_schema=SwapWithCodeInstantiateSchema,
)
# Update message history
message_history.extend(updated_message_history)
return schema, message_history, call_cost
except Exception as e:
raise Exception(f"Failed to instantiate swap_with_code directive: {str(e)}")
def apply(
self,
global_default_model: str,
ops_list: List[Dict],
target_ops: List[str],
rewrite: SwapWithCodeInstantiateSchema,
) -> List[Dict]:
"""
Apply the directive by replacing the reduce operation with code reduce + optional map.
"""
new_ops_list = deepcopy(ops_list)
# Find the target operation
target_pos = None
target_op = None
for i, op in enumerate(new_ops_list):
if op["name"] in target_ops:
target_pos = i
target_op = op
break
if target_pos is None:
raise ValueError(f"Target operation {target_ops[0]} not found")
# Get model from original reduce operation or use global default
default_model = target_op.get("model", global_default_model)
# Create the code reduce operation
code_reduce_op = {
"name": rewrite.code_reduce_name,
"type": "code_reduce",
"code": rewrite.code,
"reduce_key": target_op.get("reduce_key", "_all"),
}
# Start with just the code reduce operation
replacement_ops = [code_reduce_op]
# Add optional map operation if specified
if rewrite.map_prompt is not None and rewrite.map_prompt.strip():
map_op = {
"name": target_op[
"name"
], # Keep the original name for the final output
"type": "map",
"prompt": rewrite.map_prompt,
"model": default_model,
"output": target_op.get("output", {}),
}
replacement_ops.append(map_op)
else:
# If no map operation, rename the code reduce to match original name
code_reduce_op["name"] = target_op["name"]
# Add output schema if it exists in original
if "output" in target_op:
code_reduce_op["output"] = target_op["output"]
# Replace the target operation with the replacement operations
new_ops_list = (
new_ops_list[:target_pos] + replacement_ops + new_ops_list[target_pos + 1 :]
)
return new_ops_list
def instantiate(
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
global_default_model: str = None,
**kwargs,
):
"""
Main method that orchestrates directive instantiation:
1. Use an agentic approach to analyze data and generate a code reduce + optional map
2. Apply the transformation using that configuration
"""
assert (
len(target_ops) == 1
), "SwapWithCode directive requires exactly one target operation"
input_file_path = kwargs.get("input_file_path", None)
pipeline_code = kwargs.get("pipeline_code", None)
if not input_file_path:
raise ValueError("input_file_path is required for SwapWithCode directive")
# Get configuration for target operation
target_ops_configs = [op for op in operators if op["name"] in target_ops]
if not target_ops_configs:
raise ValueError(f"Target operation {target_ops[0]} not found in operators")
# Validate that target operation is a reduce operation
target_op = target_ops_configs[0]
if target_op.get("type") != "reduce":
raise ValueError(
f"SwapWithCode directive can only be applied to reduce operations, but {target_ops[0]} is of type {target_op.get('type')}"
)
# Step 1: Agent analyzes data and generates code reduce + optional map solution
rewrite, message_history, call_cost = self.llm_instantiate(
target_ops_configs,
input_file_path,
agent_llm,
message_history,
pipeline_code,
)
# Step 2: Apply transformation using the generated configuration
return (
self.apply(global_default_model, operators, target_ops, rewrite),
message_history, call_cost
)
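
For orientation, a minimal sketch of how the apply() step above rewrites a pipeline. It assumes the directive class is named SwapWithCodeDirective and constructs with its declared defaults, and it uses types.SimpleNamespace as a stand-in for the instantiate schema, since apply() only reads code_reduce_name, code, and map_prompt from it; the operation names, prompt, and code below are illustrative, not from the repository.

from types import SimpleNamespace

# Hypothetical pipeline containing a single reduce operation (illustrative only)
original_ops = [
    {
        "name": "summarize_by_dept",
        "type": "reduce",
        "reduce_key": "department",
        "prompt": "Summarize feedback for {{ reduce_key.department }} ...",
        "output": {"schema": {"summary": "string"}},
        "model": "gpt-4o-mini",
    }
]

# Stand-in for a SwapWithCodeInstantiateSchema instance (attribute names match what apply() reads)
rewrite = SimpleNamespace(
    code_reduce_name="count_feedback_by_dept",
    code="def transform(group):\n    return {'summary': f'{len(group)} documents'}",
    map_prompt=None,  # no LLM map, so the code_reduce keeps the original op name
)

new_ops = SwapWithCodeDirective().apply(  # assumed class name; constructed with defaults
    global_default_model="gpt-4o-mini",
    ops_list=original_ops,
    target_ops=["summarize_by_dept"],
    rewrite=rewrite,
)
# new_ops[0] is now a code_reduce named "summarize_by_dept" that carries over the
# original reduce_key and output schema.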

View File

@ -0,0 +1,321 @@
import json
from copy import deepcopy
from typing import Dict, List, Type
from litellm import completion
from pydantic import BaseModel, Field
from docetl.reasoning_optimizer.instantiate_schemas import TakeHeadTailInstantiateSchema
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
class TakeHeadTailDirective(Directive):
name: str = Field(default="take_head_tail", description="The name of the directive")
formal_description: str = Field(
default="LLM_Op => Code Map -> LLM_Op",
description="Inserts a Code Map operation before any LLM-powered operation (Map, Filter, Reduce) to truncate document content to head and tail words",
)
nl_description: str = Field(
default="Reduces document length by keeping only the first k words and optionally the last l words of the longest document field. This improves cost efficiency and can enhance accuracy for tasks that only require document beginnings (like classification)."
)
when_to_use: str = Field(
default="When any LLM operation (Map, Filter, Reduce) only needs the beginning (and optionally end) of documents, such as classification tasks, filtering by document type, reducing document summaries, or when full document content causes accuracy issues due to too much context."
)
instantiate_schema_type: Type[BaseModel] = Field(
default=TakeHeadTailInstantiateSchema
)
example: str = Field(
default="""
Original Map Operation (Research Paper Classification):
- name: classify_research_domain
type: map
prompt: |
What research domain does this paper belong to? Classify as Computer Science, Biology, Physics, Chemistry, or Other based on: {{ input.paper_text }}
output:
schema:
domain: "string"
confidence: "float"
model: gpt-4o-mini
TakeHeadTailInstantiateSchema:
{
"name": "extract_paper_abstract",
"document_key": "paper_text",
"head_words": 150,
"tail_words": 0
}
Resulting Pipeline:
- name: extract_paper_abstract
type: code_map
code: |
def transform(input_doc):
paper_text_content = input_doc.get('paper_text', '')
words = paper_text_content.split()
if len(words) <= 150:
truncated = paper_text_content
else:
head = ' '.join(words[:150])
truncated = head
return {'paper_text': truncated}
- name: classify_research_domain
type: map
prompt: |
What research domain does this paper belong to? Classify as Computer Science, Biology, Physics, Chemistry, or Other based on: {{ input.paper_text }}
output:
schema:
domain: "string"
confidence: "float"
model: gpt-4o-mini
"""
)
test_cases: List[DirectiveTestCase] = Field(
default_factory=lambda: [
DirectiveTestCase(
name="research_paper_classification",
description="Classify research papers by domain using only abstract/introduction",
input_config={
"name": "classify_paper_domain",
"type": "map",
"prompt": "What research domain does this paper belong to (CS, Biology, Physics, etc.)? Base your classification on the content: {{ input.full_text }}",
"output": {"schema": {"domain": "string", "confidence": "float"}},
"model": "gpt-4o-mini",
},
target_ops=["classify_paper_domain"],
expected_behavior="Should truncate full_text to first ~150 words (abstract/intro) since paper classification only needs the beginning, not the full methodology/results sections",
should_pass=True,
),
DirectiveTestCase(
name="document_metadata_extraction",
description="Extract metadata from document headers/footers for indexing",
input_config={
"name": "extract_document_metadata",
"type": "map",
"prompt": "Extract the title, author, date, and document type from this document: {{ input.content }}",
"output": {
"schema": {
"title": "string",
"author": "string",
"date": "string",
"doc_type": "string",
}
},
"model": "gpt-4o-mini",
},
target_ops=["extract_document_metadata"],
expected_behavior="Should keep both head (~100 words for headers/title) and tail (~50 words for footers/signatures) since metadata appears at document beginning and end",
should_pass=True,
),
DirectiveTestCase(
name="email_priority_classification",
description="Classify email priority using subject and first paragraph",
input_config={
"name": "classify_email_priority",
"type": "map",
"prompt": "Classify this email as HIGH, MEDIUM, or LOW priority based on urgency indicators: {{ input.email_body }}",
"output": {"schema": {"priority": "string", "reasoning": "string"}},
"model": "gpt-4o-mini",
},
target_ops=["classify_email_priority"],
expected_behavior="Should truncate email_body to first ~75 words since email priority is determined by subject line and opening, not full conversation thread",
should_pass=True,
),
DirectiveTestCase(
name="legal_document_type_identification",
description="Identify legal document type from contract headers and signature blocks",
input_config={
"name": "identify_legal_doc_type",
"type": "map",
"prompt": "What type of legal document is this (contract, agreement, policy, etc.)? Analyze: {{ input.legal_text }}",
"output": {
"schema": {
"document_type": "string",
"parties_involved": "list[string]",
}
},
"model": "gpt-4o-mini",
},
target_ops=["identify_legal_doc_type"],
expected_behavior="Should keep head (~200 words for title/parties) and tail (~100 words for signature blocks) since legal doc type is indicated at beginning and parties sign at end",
should_pass=True,
),
DirectiveTestCase(
name="spam_email_filtering",
description="Filter out spam emails based on subject line and opening content",
input_config={
"name": "filter_spam_emails",
"type": "filter",
"prompt": "Is this email spam? Look for suspicious patterns in: {{ input.email_content }}",
"output": {"schema": {"_bool": "bool"}},
"model": "gpt-4o-mini",
},
target_ops=["filter_spam_emails"],
expected_behavior="Should truncate email_content to first ~100 words since spam detection relies on subject, sender, and opening content, not full email thread",
should_pass=True,
),
DirectiveTestCase(
name="research_findings_synthesis",
description="Reduce multiple research papers into a unified findings summary",
input_config={
"name": "synthesize_research_findings",
"type": "reduce",
"prompt": "Synthesize the key findings from these research abstracts and conclusions: {% for doc in inputs %}{{ doc.paper_content }}{% endfor %}",
"output": {
"schema": {"synthesis": "string", "key_themes": "list[string]"}
},
"model": "gpt-4o-mini",
},
target_ops=["synthesize_research_findings"],
expected_behavior="Should keep head (~200 words for abstracts) and tail (~150 words for conclusions) from each paper since synthesis needs both research goals and outcomes",
should_pass=True,
),
]
)
def __eq__(self, other):
return isinstance(other, TakeHeadTailDirective)
def __hash__(self):
return hash("TakeHeadTailDirective")
def to_string_for_instantiate(self, original_op: Dict) -> str:
op_type = original_op.get("type", "operation")
op_type_caps = op_type.capitalize()
return (
f"You are an expert at optimizing document processing pipelines for cost and accuracy.\n\n"
f"Original {op_type_caps} Operation:\n"
f"{str(original_op)}\n\n"
f"Directive: {self.name}\n"
f"Your task is to instantiate this directive by generating a TakeHeadTailInstantiateSchema "
f"that specifies how to truncate document content to improve efficiency.\n\n"
f"The directive will insert a Code Map operation before the target {op_type_caps} that:\n"
f"1. Identifies the document key with the longest text content\n"
f"2. Keeps only the first 'head_words' words\n"
f"3. Optionally keeps the last 'tail_words' words (default 0)\n"
f"4. Returns the truncated content in the same key\n\n"
f"Guidelines:\n"
f"- Choose head_words based on how much context the task likely needs\n"
f"- Set tail_words > 0 only if the task benefits from document endings (e.g., conclusions, signatures)\n"
f"- For classification/filtering: typically 50-150 head_words, tail_words=0\n"
f"- For summarization/reduction: typically 100-300 head_words, tail_words=50-100\n"
f"- For metadata extraction: head=100-200, tail=50-100 (headers + footers)\n"
f"- Identify the document_key by looking at {{{{ input.KEY }}}} references in the prompt\n\n"
f"Operation Type Considerations:\n"
f"- Map: Focus on what information is needed for transformation\n"
f"- Filter: Focus on what information is needed for the boolean decision\n"
f"- Reduce: Focus on what information is needed for aggregation/synthesis\n\n"
f"Example configuration:\n"
f"{self.example}\n\n"
f"Please output only the TakeHeadTailInstantiateSchema object that specifies:\n"
f"- name: descriptive name for the truncation operation\n"
f"- document_key: the key containing text to truncate\n"
f"- head_words: number of words to keep from start\n"
f"- tail_words: number of words to keep from end (0 if not needed)"
)
def llm_instantiate(
self,
original_op: Dict,
agent_llm: str,
message_history: list = [],
):
message_history.extend(
[
{
"role": "user",
"content": self.to_string_for_instantiate(original_op),
},
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
response_format=TakeHeadTailInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = TakeHeadTailInstantiateSchema(**parsed_res)
message_history.append(
{"role": "assistant", "content": resp.choices[0].message.content}
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(
self,
ops_list: List[Dict],
target_op: str,
rewrite: TakeHeadTailInstantiateSchema,
) -> List[Dict]:
new_ops_list = deepcopy(ops_list)
pos_to_replace = [
i for i, op in enumerate(ops_list) if op["name"] == target_op
][0]
head_words = rewrite.head_words
tail_words = rewrite.tail_words
document_key = rewrite.document_key
code_map_function = f"""def transform(input_doc):
{document_key}_content = input_doc.get('{document_key}', '')
words = {document_key}_content.split()
if len(words) <= {head_words + tail_words}:
# Document is short enough, keep as is
truncated = {document_key}_content
else:
head = ' '.join(words[:{head_words}])
if {tail_words} > 0:
tail = ' '.join(words[-{tail_words}:])
truncated = head + ' ... ' + tail
else:
truncated = head
return {{'{document_key}': truncated}}"""
code_map_op = {
"name": rewrite.name,
"type": "code_map",
"code": code_map_function,
}
new_ops_list.insert(pos_to_replace, code_map_op)
return new_ops_list
def instantiate(
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
**kwargs,
):
assert (
len(target_ops) == 1
), "TakeHeadTail directive requires exactly one target op"
target_op_config = [op for op in operators if op["name"] == target_ops[0]][0]
rewrite, message_history, call_cost = self.llm_instantiate(
target_op_config, agent_llm, message_history
)
return self.apply(operators, target_ops[0], rewrite), message_history, call_cost
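
As a quick end-to-end illustration of this directive's apply() step, the sketch below uses hand-picked truncation values and types.SimpleNamespace as a stand-in for TakeHeadTailInstantiateSchema (apply() only reads name, document_key, head_words, and tail_words). The operation and values are illustrative, and it assumes TakeHeadTailDirective constructs with its declared defaults.

from types import SimpleNamespace

ops = [
    {
        "name": "classify_paper_domain",
        "type": "map",
        "prompt": "Classify the research domain of: {{ input.full_text }}",
        "output": {"schema": {"domain": "string"}},
        "model": "gpt-4o-mini",
    }
]

# Stand-in for a TakeHeadTailInstantiateSchema instance
rewrite = SimpleNamespace(
    name="truncate_full_text", document_key="full_text", head_words=150, tail_words=0
)

new_ops = TakeHeadTailDirective().apply(ops, "classify_paper_domain", rewrite)
# new_ops[0] is the generated code_map that truncates full_text to its first ~150 words;
# new_ops[1] is the original map operation, untouched.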

File diff suppressed because it is too large

View File

@ -0,0 +1,180 @@
import json
from typing import Any, Dict, List
import tiktoken
from dotenv import load_dotenv
from docetl.dataset import Dataset, create_parsing_tool_map
from docetl.utils import load_config
def extract_input_schema(
data: List[Dict[str, Any]], max_samples: int = 10
) -> Dict[str, Dict[str, Any]]:
"""
Extract the input schema from JSON data by analyzing the structure.
Args:
data: List of dictionaries containing the dataset
max_samples: Maximum number of records to analyze (default: 10)
Returns:
Dict[str, Dict[str, Any]]: Schema mapping field names to their type and token info
"""
if not data:
return {}
# Sample records for analysis (to avoid processing entire large datasets)
sample_size = min(max_samples, len(data))
sample_data = data[:sample_size]
schema = {}
def count_field_tokens(value: Any, model="gpt-4o") -> int:
"""Count tokens for a single field value."""
try:
encoding = tiktoken.encoding_for_model(model)
if value is None:
return 0
elif isinstance(value, (dict, list)):
value_str = json.dumps(value, ensure_ascii=False)
else:
value_str = str(value)
tokens = encoding.encode(value_str)
return len(tokens)
except Exception:
return 0
def infer_type(value: Any) -> str:
"""Infer the type of a value."""
if value is None:
return "string" # Default to string for null values
elif isinstance(value, bool):
return "boolean"
elif isinstance(value, int):
return "integer"
elif isinstance(value, float):
return "number"
elif isinstance(value, list):
if not value:
return "list[string]" # Default for empty lists
# Check first few elements to determine list type
sample_elements = value[:3]
element_types = [infer_type(elem) for elem in sample_elements]
# If all elements are the same type, use that type
if len(set(element_types)) == 1:
return f"list[{element_types[0]}]"
else:
return "list[string]" # Mixed types default to string
elif isinstance(value, dict):
return "dict"
else:
return "string"
# Analyze each field across all sample records
all_fields = set()
for record in sample_data:
all_fields.update(record.keys())
# For each field, determine the most common type and token statistics
for field in sorted(all_fields):
field_values = []
field_tokens = []
for record in sample_data:
if field in record:
value = record[field]
field_values.append(value)
token_count = count_field_tokens(value)
field_tokens.append(token_count)
if not field_values:
schema[field] = {"type": "string", "avg_tokens": 0}
continue
# Count type occurrences
type_counts = {}
for value in field_values:
value_type = infer_type(value)
type_counts[value_type] = type_counts.get(value_type, 0) + 1
# Use the most common type
most_common_type = max(type_counts.items(), key=lambda x: x[1])[0]
avg_tokens = sum(field_tokens) / len(field_tokens)
schema[field] = {"type": most_common_type, "avg_tokens": round(avg_tokens, 1)}
return schema
def count_tokens_in_data(data, model="gpt-4o"):
"""
Count the total number of tokens in the data using tiktoken.
Args:
data: List of dictionaries containing the dataset
model: The model to use for tokenization (default: gpt-4o)
Returns:
int: Total number of tokens
"""
try:
# Get the encoding for the specified model
encoding = tiktoken.encoding_for_model(model)
total_tokens = 0
for item in data:
# Convert the entire item to a JSON string for tokenization
item_str = json.dumps(item, ensure_ascii=False)
tokens = encoding.encode(item_str)
total_tokens += len(tokens)
return total_tokens
except Exception as e:
print(f" [WARNING] Could not count tokens: {e}")
return None
def load_input_doc(yaml_path):
doc_info = ""
load_dotenv()
try:
config = load_config(yaml_path)
except Exception as e:
print(f"[ERROR] Failed to load config: {e}")
return doc_info
parsing_tool_map = create_parsing_tool_map(config.get("parsing_tools", None))
datasets_config = config.get("datasets", {})
if not datasets_config:
print("[ERROR] No datasets found in config.")
return doc_info
for name, dataset_config in datasets_config.items():
doc_info += f"Dataset: {name}\n"
try:
ds = Dataset(
runner=None,
type=dataset_config["type"],
path_or_data=dataset_config["path"],
source=dataset_config.get("source", "local"),
parsing=dataset_config.get("parsing", []),
user_defined_parsing_tool_map=parsing_tool_map,
)
data = ds.load()
if data:
doc_info += f" Type: {ds.type}\n"
doc_info += f" Records loaded: {len(data)}\n"
schema = extract_input_schema(data)
doc_info += " Input schema:\n"
for field, field_info in schema.items():
doc_info += f" {field}: {field_info['type']} (avg: {field_info['avg_tokens']} tokens)\n"
token_count = count_tokens_in_data(data)
if token_count is not None:
doc_info += f" Total tokens: {token_count:,}\n"
except Exception as e:
doc_info += f" [ERROR] Failed to load dataset '{name}': {e}\n"
return doc_info
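
A small usage sketch of the helpers above on in-memory records; the field names and values are made up for the example, and the avg_tokens numbers depend on the tokenizer.

sample = [
    {"title": "Q3 report", "pages": 12, "tags": ["finance", "quarterly"]},
    {"title": "Q4 report", "pages": 15, "tags": ["finance"]},
]
schema = extract_input_schema(sample)
# Expected shape:
# {"pages": {"type": "integer", "avg_tokens": ...},
#  "tags": {"type": "list[string]", "avg_tokens": ...},
#  "title": {"type": "string", "avg_tokens": ...}}
print(schema)
print(count_tokens_in_data(sample))  # total token count across both records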

View File

@ -0,0 +1,638 @@
from typing import Optional
from pydantic import BaseModel, Field
class Operator(BaseModel):
"""Operator model for each DocETL operator."""
# Fields matching your spreadsheet columns
name: str = Field(..., description="Operation name")
type_llm_or_not: str = Field(
..., description="Type (LLM-powered or not LLM-powered)"
)
description: str = Field(..., description="Description")
when_to_use: str = Field(..., description="When to Use")
required_parameters: str = Field(..., description="Required Parameters")
optional_parameters: Optional[str] = Field(None, description="Optional Parameters")
returns: str = Field(..., description="Returns")
minimal_example_configuration: str = Field(
..., description="Minimal Example Configuration"
)
def to_string(self) -> str:
"""Serialize operator for prompts."""
parts = [
f"## {self.name} ({self.type_llm_or_not})",
f"**Description:** {self.description}",
f"**When to Use:** {self.when_to_use}",
f"**Required Parameters:**\n{self.required_parameters}",
]
if self.optional_parameters:
parts.append(f"**Optional Parameters:**\n{self.optional_parameters}")
parts.append(f"**Returns:** {self.returns}")
parts.append(
f"**Example Configuration:**\n{self.minimal_example_configuration}\n"
)
return "\n\n".join(parts)
op_map = Operator(
name="Map",
type_llm_or_not="LLM-powered",
description="Processes each document independently by making an LLM call with your prompt template. Creates one output for each input document, with the output conforming to your specified schema.",
when_to_use="Use when you need to process each document individually - like extracting, summarizing, classifying, or generating new fields based on document content.",
required_parameters="""
name: Unique name for the operation
type: Must be "map"
model: LLM to use to execute the prompt
prompt: Jinja2 template with {{ input.key }} for each document field you want to reference
output.schema: Dictionary defining the structure and types of the output""",
optional_parameters="""
gleaning: Iteratively refines outputs that don't meet quality criteria. The LLM reviews its initial output and improves it based on validation feedback. Config requires:
- if: Python expression for when to refine; e.g., "len(output[key]) == 0" (optional)
- num_rounds: Max refinement iterations
- validation_prompt: What to improve
- model: LLM to use to execute the validation prompt and provide feedback (defaults to same model as operation model)
(Default: gleaning is not enabled)
calibrate: (bool) Processes a sample of documents first to create reference examples, then uses those examples in all subsequent prompts to ensure consistent outputs across the dataset
(Default: calibrate is False)
num_calibration_docs: Number of docs to use for calibration (default: 10)""",
returns="Each original document, augmented with new keys specified in output_schema",
minimal_example_configuration="""
name: gen_insights
type: map
model: gpt-4o-mini
prompt: From the user log below, list 2-3 concise insights (1-2 words each) and 1-2 supporting actions per insight. Return as a list of dictionaries with 'insight' and 'supporting_actions'. Log: {{ input.log }}
output:
schema:
insights_summary: "string"
""",
)
op_extract = Operator(
name="Extract",
type_llm_or_not="LLM-powered",
description="""Pulls out specific portions of text exactly as they appear in the source document. The LLM identifies which parts to extract by providing line number ranges or regex patterns. Extracted text is saved to the original field name with "_extracted" suffix (e.g., report_text → report_text_extracted).""",
when_to_use="Use when you need exact text from documents - like pulling out direct quotes, specific contract clauses, key findings, or any content that must be preserved word-for-word without LLM paraphrasing.",
required_parameters="""
name: Unique name for the operation
type: Must be "extract"
prompt: Instructions for what to extract (this is NOT a Jinja template; we will run the prompt independently for each key in document_keys)
document_keys: List of document fields to extract from
model: LLM to use to execute the extraction
""",
optional_parameters="""
extraction_method: How the LLM specifies what to extract from each document:
- "line_number" (default): The LLM outputs the line numbers or ranges in the document to extract. Use when relevant information is best identified by its position within each document.
- "regex": The LLM generates a custom regex pattern for each document to match the target text. Use when the information varies in location but follows identifiable text patterns (e.g., emails, dates, or structured phrases).
""",
returns="""Each original document, augmented with {key}_extracted for each key in the specified document_keys""",
minimal_example_configuration="""
name: findings
type: extract
prompt: Extract all sections that discuss key findings, results, or conclusions from this research report. Focus on paragraphs that:
- Summarize experimental outcomes
- Present statistical results
- Describe discovered insights
- State conclusions drawn from the research
Only extract the most important and substantive findings.
document_keys: ["report_text"]
model: gpt-4o-mini
""",
)
op_parallel_map = Operator(
name="Parallel Map",
type_llm_or_not="LLM-powered",
description="Runs multiple independent map operations concurrently on each document. Each prompt generates specific fields, and all outputs are combined into a single result per input document. More efficient than sequential maps when transformations are independent.",
when_to_use="Use when you need multiple independent analyses of the same document - like extracting different types of information, running multiple classifications, or generating various summaries from the same input.",
required_parameters="""
name: Unique name for the operation
type: Must be "parallel_map"
prompts: List of prompt configs, each with:
- prompt: Jinja2 template (should reference document fields using {{ input.key }})
- output_keys: List of fields this prompt generates
output.schema: Combined schema for all outputs. This must be the union of the output_keys from all prompts: every key in the output schema should be generated by at least one prompt, and no keys should be missing.
""",
optional_parameters="""
model: Default LLM for all prompts
Per-prompt options:
- model: (optional) Override the default LLM for this specific prompt
- gleaning: (optional) Validation and refinement configuration for this prompt, akin to gleaning in map operation
""",
returns="Each original document augmented with new keys specified in output_schema",
minimal_example_configuration="""
name: analyze_resume
type: parallel_map
prompts:
- prompt: Extract skills from: {{ input.resume }}
output_keys: [skills]
model: gpt-4o-mini
- prompt: Calculate years of experience from: {{ input.resume }}
output_keys: [years_exp]
- prompt: Rate writing quality 1-10: {{ input.cover_letter }}
output_keys: [writing_score]
output:
schema:
skills: list[string]
years_exp: float
writing_score: integer
""",
)
op_filter = Operator(
name="Filter",
type_llm_or_not="LLM-powered",
description="Evaluates each document with an LLM prompt and only keeps documents where the output is true. Documents evaluating to false are removed from the dataset entirely.",
when_to_use="Use when you need to keep only documents meeting specific criteria - like filtering high-impact articles, relevant records, quality content, or documents matching complex conditions that require LLM judgment.",
required_parameters="""
name: Unique name for the operation
type: Must be "filter"
model: LLM to use to execute the prompt
prompt: Jinja2 template that guides the LLM to return true or false (should reference document fields using {{ input.key }})
output.schema: Must contain exactly one boolean field
""",
optional_parameters="""
gleaning: Iteratively refines outputs that don't meet quality criteria. The LLM reviews its initial output and improves it based on validation feedback. Config requires:
- if: Python expression for when to refine; e.g., "len(output[key]) == 0" (optional)
- num_rounds: Max refinement iterations
- validation_prompt: What to improve
- model: LLM to use to execute the validation prompt and provide feedback (defaults to same model as operation model)
(Default: gleaning is not enabled)
calibrate: (bool) Processes a sample of documents first to create reference examples, then uses those examples in all subsequent prompts to ensure consistent outputs across the dataset
(Default: calibrate is False)
num_calibration_docs: Number of docs to use for calibration (default: 10)
""",
returns="Subset of documents; each document has same keys as before",
minimal_example_configuration="""
name: filter_insightful_comments
type: filter
prompt: Is this comment insightful?
Comment: {{ input.comment }}
Consider whether the comment adds a new perspective, explains reasoning, or deepens the discussion. Return true if it is insightful, false otherwise.
output:
schema:
is_insightful: boolean
model: gpt-4o-mini
""",
)
op_reduce = Operator(
name="Reduce",
type_llm_or_not="LLM-powered",
description="Aggregates multiple documents with the same key value(s) into a single output. Groups documents by reduce_key, then applies an LLM prompt to each group to create one aggregated output per unique key combination.",
when_to_use="Use when you need to summarize, consolidate, or analyze groups of related documents - like combining all reviews for a product, summarizing feedback by department, or aggregating patient records by ID.",
required_parameters="""
name: Unique name for the operation
type: Must be "reduce"
model: LLM used to execute the prompts
reduce_key: Key or keys to group by (can be a string, a list of strings, or "_all" to aggregate over all documents)
prompt: Jinja2 template that references {{ inputs }} (the list of grouped documents, where each document is a dictionary) and {{ reduce_key }}
fold_prompt: Template for incrementally processing large groups by folding batches into the current state. The template should reference:
- {{ inputs }}: The current batch of documents to process
- {{ output }}: The current aggregated state (matches output.schema)
The LLM processes the data in batches, updating the state after each batch using the fold_prompt.
fold_batch_size: Number of documents to process in each fold batch
output.schema: Schema for the aggregated output
""",
optional_parameters="""
value_sampling: Run the reduce operation only on a sample, to reduce processing cost and time. Specify a method:
- random: Randomly select a subset of items from each group.
- first_n: Select the first N items in each group.
- cluster: Use clustering (e.g., K-means) to select a diverse, representative subset.
- semantic_similarity: Select items most relevant to a provided query, based on embeddings.
Optional parameters for value_sampling:
- enabled: true to activate value sampling (default: false)
- method: One of the above sampling methods
- sample_size: Number of items to sample from each group
- For cluster: no additional parameters required
- For semantic_similarity:
- embedding_model: Embedding model to use (e.g., text-embedding-3-small)
- embedding_keys: List of fields to embed (e.g., [review])
- query_text: Text to focus sampling (e.g., "battery life and performance")
""",
returns="""A set of fewer documents; each document has reduce_keys and all the keys specified in the output schema. Note that many keys in the original documents, if not specified in reduce_keys (i.e., the groupby), will get dropped.""",
minimal_example_configuration="""
name: summarize_feedback
type: reduce
reduce_key: department
model: gpt-4o-mini
prompt: Summarize the customer feedback for {{ reduce_key.department }}:
{% for item in inputs %}
Feedback {{ loop.index }}: {{ item.feedback }}
{% endfor %}
Provide main points and overall sentiment.
fold_prompt: Incrementally update the summary and sentiment based on a batch of new feedback for {{ reduce_key.department }}.
Current summary: {{ output.summary }}
Current sentiment: {{ output.sentiment }}
New feedback batch:
{% for item in inputs %}
Feedback {{ loop.index }}: {{ item.feedback }}
{% endfor %}
Return the updated summary and sentiment after incorporating the new feedback.
fold_batch_size: 25
output:
schema:
summary: string
sentiment: string
""",
)
op_split = Operator(
name="Split",
type_llm_or_not="Not LLM-powered",
description="Divides long text into smaller chunks based on token count or delimiters. Creates multiple output documents from each input, one per chunk. Adds chunk metadata including chunk ID and sequence number.",
when_to_use="""Use when documents exceed LLM token limits, when processing long transcripts/reports/contracts and we need to read every portion of the document for the task. E.g., "extract all mentions of X from this document." """,
required_parameters="""
name: Unique name for the operation
type: Must be "split"
split_key: Field containing the text to split
method: "token_count" or "delimiter" (defaults to "token_count")
method_kwargs:
- For "token_count": num_tokens (integer) Number of tokens per split
- For "delimiter": delimiter (string) Delimiter to use for splitting the text
""",
optional_parameters=None,
returns="""
The Split operation generates multiple output items for each input item. Each output includes:
- All original key-value pairs from the input item
- {split_key}_chunk: The content of the split chunk
- {op_name}_id: A unique identifier for the original document
- {op_name}_chunk_num: The sequential number of the chunk within its original document
""",
minimal_example_configuration="""
name: split_transcript
type: split
split_key: transcript
method: token_count
method_kwargs:
num_tokens: 500
""",
)
op_gather = Operator(
name="Gather",
type_llm_or_not="Not LLM-powered",
description="""Adds context from surrounding chunks to each chunk after splitting. Includes content from previous/next chunks and maintains document structure through header hierarchies. Creates a "rendered" version of each chunk with its context.""",
when_to_use="Use after Split when chunks need context for accurate processing - essential for legal documents, technical manuals, or any content where references span chunks. Helps maintain document structure and cross-references.",
required_parameters="""
name: Unique name for the operation
type: Must be "gather"
content_key: Field containing the chunk content to gather
doc_id_key: Field linking all chunks from the same original document
order_key: Field specifying the order/sequence of each chunk within the document
peripheral_chunks: Configuration for including context from surrounding (previous and/or next) chunks. This helps each chunk retain important context that may be necessary for accurate downstream processing.
- You can specify previous and/or next context, and control how much content to include before and after each chunk.
- For each (previous/next), you can define:
- head: The first chunk(s) in the section (e.g., the chunk immediately before/after the current one)
- middle: Chunks between head and tail (e.g., summarized versions of further-away chunks)
- tail: The last chunk(s) in the section (e.g., furthest before/after the current chunk you want to include)
- For each section, specify:
- count: Number of chunks to include (head and tail only)
- content_key: Which field to use for context (e.g., full text or summary)
Example:
peripheral_chunks:
previous:
head:
count: 1
content_key: full_content
middle:
content_key: summary_content
tail:
count: 2
content_key: full_content
next:
head:
count: 1
content_key: full_content
- This config means:
Include the full content of 1 chunk before, summaries of all in-between previous chunks, and the full content of the 2 furthest-back previous chunks.
Include the full content of the 1 chunk immediately after the current chunk.
- Use full content for immediate context (head/tail) and summaries for middle sections to balance completeness and token efficiency.
- Only include next chunks if future context is important; by default, focus on previous for most text documents.
""",
optional_parameters="""
doc_header_key: (optional) Field containing extracted headers for each chunk.
- This field provides the hierarchical structure of document sections, enabling the Gather operation to reconstruct header context for each chunk.
- To use this, you must first run a map operation that extracts headers from each chunk, using the following schema:
headers: list of {header: string, level: integer}
- Example map operation:
name: extract_headers
type: map
input:
- agreement_text_chunk
prompt: Extract any section headers from the following merger agreement chunk:
{{ input.agreement_text_chunk }}
Return the headers as a list, preserving their hierarchy.
output.schema:
headers: list[{header: string, level: integer}]
""",
returns="""
The Gather operation produces one output item for each input chunk. Each output includes:
- All original key-value pairs from the input document
- {content_key}_rendered: The content of the chunk, enriched with:
Reconstructed header hierarchy (if doc_header_key is provided)
Previous context (chunks before the current chunk, if configured)
The main chunk, clearly marked
Next context (chunks after the current chunk, if configured)
Indications of skipped content between included sections (e.g., "[... 500 characters skipped ...]")
For example, if your content_key is agreement_text_chunk, the Gather operation adds:
agreement_text_chunk_rendered
Note: No additional unique identifier or chunk number fields are created by Gather. (Those are typically added by Split.) Gather focuses on adding the rendered, context-enriched content field.
""",
minimal_example_configuration="""
name: add_context
type: gather
content_key: text_chunk
doc_id_key: split_doc_id
order_key: split_chunk_num
peripheral_chunks:
previous:
head:
count: 1
content_key: text_chunk
tail:
count: 2
content_key: text_chunk
""",
)
op_unnest = Operator(
name="Unnest",
type_llm_or_not="Not LLM-powered",
description="Expands arrays into multiple documents (one per element) or flattens dictionary fields into the parent document. For arrays, replaces the array with individual elements. For dicts, adds specified fields to parent while keeping original dict.",
when_to_use="Use when you need to process array elements individually, or when flattening nested data structures for easier analysis. Essentially, this operation is for normalizing nested data.",
required_parameters="""
name: Unique name for the operation
type: Must be "unnest"
unnest_key: Field containing the array or dictionary to expand
""",
optional_parameters="""
keep_empty: If true, empty arrays being exploded will be kept in the output (with value None). Default: false
recursive: If true, the unnest operation will be applied recursively to nested arrays. Default: false
depth: The maximum depth for recursive unnesting (only applicable if recursive is true)
""",
returns="Returns one output document for each element in the unnested array or dictionary. Each output preserves all original fields from the input document, and adds fields from the expanded element. For arrays, each item becomes its own document. For dictionaries, specified expand_fields are added as top-level fields in the output.",
minimal_example_configuration="""
name: expand_user
type: unnest
unnest_key: user_info
""",
)
op_sample = Operator(
name="Sample",
type_llm_or_not="Not LLM-powered",
description="Selects a subset of documents from the input according to the specified sampling method. Used to generate a representative sample for further analysis or processing.",
when_to_use="Use when you want to work with a smaller subset of your data for debugging, rapid prototyping, or to reduce compute cost. Also useful for sampling before downstream processing. Stratification can be applied to uniform, first, outliers, top_embedding, and top_fts methods. It ensures that the sample maintains the distribution of specified key(s) in the data or retrieves top items from each stratum.",
required_parameters="""
name: Unique name for the operation
type: Must be "sample"
method: The sampling method to use ("uniform", "outliers", "custom", "first", "top_embedding", or "top_fts")
samples: Either a list of key-value pairs representing document ids and values, an integer count of samples, or a float fraction of samples.
""",
optional_parameters="""
method: The sampling method to use. Options:
- uniform: Randomly select the specified number or fraction of documents. When combined with stratification, maintains the distribution of the stratified groups.
- first: Select the first N documents from the dataset. When combined with stratification, takes proportionally from each group.
- top_embedding: Select top documents based on embedding similarity to a query. Requires the following in method_kwargs: keys: A list of keys to use for creating embeddings, query: The query string to match against (supports Jinja templates), embedding_model: (Optional) The embedding model to use. Defaults to "text-embedding-3-small".
- top_fts: Retrieves the top N items using full-text search with BM25 algorithm. Requires the following in method_kwargs: keys: A list of keys to search within, query: The query string for keyword matching (supports Jinja templates).
- outliers: Select or remove documents considered outliers based on embedding distance. Requires the following in method_kwargs: embedding_keys (fields to embed), std (standard deviation cutoff) or samples (number/fraction of outlier samples), and keep (true to keep, false to remove outliers; default false). Optionally, method_kwargs.center can specify the center point.
- custom: Samples specific items by matching key-value pairs. Stratification is not supported with custom sampling.
samples: The number of samples to select (integer), fraction of documents to sample (float), or explicit list of document IDs (for custom).
random_state: Integer to seed the random generator for reproducible results (default: random each run).
stratify_key: Field or list of fields to stratify by (for uniform method stratified sampling)
samples_per_group: When stratifying, sample N items per group vs. dividing total (for uniform method)
method_kwargs: Additional parameters required by the chosen sampling method, such as:
- embedding_keys: List of fields to embed (for outliers)
- std: Number of standard deviations for outlier cutoff (for outliers)
- keep: true to keep or false to remove outliers (for outliers; default false)
- center: Dictionary specifying the center for distance calculations (for outliers)
""",
returns="A subset of input documents, with the same schema as the original input.",
minimal_example_configuration="""
name: stratified_sample
type: sample
method: uniform
samples: 0.2
stratify_key: category
""",
)
op_resolve = Operator(
name="Resolve",
type_llm_or_not="LLM-powered",
description="Identifies and canonicalizes duplicate or matching entities across your dataset using LLM-driven pairwise comparison and resolution prompts. Useful for data cleaning, deduplication, and standardizing variations created by LLMs in preceding map operations.",
when_to_use="Use when you need to standardize documents that may refer to the same real-world entity but have inconsistent or duplicated fields (e.g., names, product titles, organizations) due to extraction, human error, or LLM variation.",
required_parameters="""
name: Unique name for the operation
type: Must be "resolve"
comparison_prompt: Jinja2 template for comparing two candidate records (refer to as {{ input1 }}, {{ input2 }})
resolution_prompt: Jinja2 template for consolidating/mapping a set of matched records (refer to as {{ inputs }})
output.schema: Dictionary defining the structure and types of the resolved output
embedding_model: Model to use for creating embeddings for blocking (default: falls back to default_model)
comparison_model: LLM to use for comparisons (default: falls back to default_model)
resolution_model: LLM to use for final resolution (default: falls back to default_model)
""",
optional_parameters="""
blocking_keys: List of fields for blocking; records must match on at least one key to be compared (default: all input keys)
blocking_threshold: Embedding similarity threshold for blocking (only compare above this value)
blocking_conditions: List of Python expressions for custom blocking logic (e.g., "left['ssn'][-4:] == right['ssn'][-4:]")
embedding_batch_size: Number of records sent to embedding model per batch (default: 1000)
compare_batch_size: Number of record pairs compared per batch (default: 500)
limit_comparisons: Maximum number of pairwise comparisons (default: no limit)
""",
returns="One output document per input document, preserving the original document structure, but with specified fields in the output schema updated to their resolved (standardized) values. All other fields remain unchanged.",
minimal_example_configuration="""
name: standardize_patient_names
type: resolve
comparison_model: gpt-4o-mini
resolution_model: gpt-4o
embedding_model: text-embedding-3-small
comparison_prompt: |
Compare the following two patient name entries:
Patient 1: {{ input1.patient_name }}
Date of Birth 1: {{ input1.date_of_birth }}
Patient 2: {{ input2.patient_name }}
Date of Birth 2: {{ input2.date_of_birth }}
Are these entries likely referring to the same patient? Respond "True" or "False".
resolution_prompt: |
Standardize these patient names into a single, consistent format:
{% for entry in inputs %}
Patient Name {{ loop.index }}: {{ entry.patient_name }}
{% endfor %}
Provide a single, standardized patient name.
output:
schema:
patient_name: string
blocking_keys:
- last_name
- date_of_birth
blocking_threshold: 0.8
blocking_conditions:
- "left['last_name'][:2].lower() == right['last_name'][:2].lower()"
- "left['date_of_birth'] == right['date_of_birth']"
""",
)
op_code_map = Operator(
name="Code Map",
type_llm_or_not="not LLM-powered",
description="Applies a Python function to each document independently using custom code. Returns a dictionary of key-value pairs to UPDATE the original document with. Useful for deterministic transformations, regex processing, calculations, or leveraging external Python libraries.",
when_to_use="Use when you need deterministic processing, complex calculations, regex/pattern matching, or want to leverage existing Python libraries. Ideal for structured data transformations that don't require LLM reasoning.",
required_parameters="""
name: Unique name for the operation
type: Must be "code_map"
code: Python code defining a function named 'transform' that takes an input document and returns a dictionary of updates. Must include all necessary imports within the function. Format: def transform(input_doc): ...""",
optional_parameters="""
drop_keys: List of keys to remove from output (default: None)
concurrent_thread_count: Number of threads to use (default: number of logical CPU cores)""",
returns="Each original document, updated with the key-value pairs returned by the transform function",
minimal_example_configuration="""
name: extract_keywords_deterministic
type: code_map
code: |
def transform(input_doc):
import re
text = input_doc.get('content', '')
# Extract words that are all caps (potential keywords)
keywords = re.findall(r'\\b[A-Z]{2,}\\b', text)
# Extract email addresses
emails = re.findall(r'\\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Za-z]{2,}\\b', text)
return {
'keywords': list(set(keywords)),
'email_count': len(emails),
'word_count': len(text.split())
}
""",
)
op_code_filter = Operator(
name="Code Filter",
type_llm_or_not="not LLM-powered",
description="Filters documents based on custom Python logic. Uses a Python function that returns True to keep documents and False to filter them out. Useful for deterministic filtering based on calculations, regex patterns, or structured data conditions.",
when_to_use="Use when you need deterministic filtering logic that doesn't require LLM reasoning - like filtering based on numeric thresholds, text patterns, data completeness, or complex boolean conditions.",
required_parameters="""
name: Unique name for the operation
type: Must be "code_filter"
code: Python code defining a function named 'transform' that takes an input document and returns a boolean (True to keep, False to filter out). Must include all necessary imports within the function. Format: def transform(input_doc): ...""",
optional_parameters="""
concurrent_thread_count: Number of threads to use (default: number of logical CPU cores)""",
returns="Subset of input documents where the transform function returned True. Documents retain all original fields.",
minimal_example_configuration="""
name: filter_valid_scores
type: code_filter
code: |
def transform(input_doc):
score = input_doc.get('confidence_score', 0)
text_length = len(input_doc.get('content', ''))
# Keep documents with high confidence and sufficient content
return score >= 0.8 and text_length >= 100
""",
)
op_topk = Operator(
name="TopK",
type_llm_or_not="LLM-powered or not LLM-powered",
description="Retrieves the most relevant items from your dataset using semantic similarity, full-text search, or LLM-based comparison. Provides a specialized interface for retrieval tasks where you need to find and rank the best matching documents based on specific criteria.",
when_to_use="Use when you need to find the most relevant documents for a query, filter large datasets to the most important items, implement retrieval-augmented generation (RAG) pipelines, or build recommendation systems. Choose this over general sampling when you specifically need the 'best' matches according to some criteria.",
required_parameters="""
name: Unique name for the operation
type: Must be "topk"
method: Retrieval method to use ("embedding" for semantic similarity, "fts" for full-text search, or "llm_compare" for LLM-based ranking)
k: Number of items to retrieve (integer) or percentage (float between 0 and 1)
keys: List of document fields to use for matching/comparison
query: Query or ranking criteria (Jinja templates supported for embedding and fts methods only)
""",
optional_parameters="""
embedding_model: Model for embeddings (default: "text-embedding-3-small"). Used for embedding and llm_compare methods.
model: LLM model for comparisons (required for llm_compare method)
batch_size: Batch size for LLM ranking (default: 10, only for llm_compare method)
stratify_key: Key(s) for stratified retrieval - ensures you retrieve top items from each group (string or list of strings). Not supported with llm_compare method.
Method-specific notes:
- embedding: Uses semantic similarity via embeddings. Supports Jinja templates in query.
- fts: Uses BM25 full-text search algorithm. No API costs. Supports Jinja templates in query.
- llm_compare: Uses LLM for complex ranking based on multiple criteria. Most expensive but most flexible. Does NOT support Jinja templates in query (ranking criteria must be consistent across all documents).
""",
returns="Top k documents based on the specified method and query, with the same schema as the original input",
minimal_example_configuration="""
name: find_relevant_tickets
type: topk
method: embedding
k: 5
keys:
- subject
- description
- customer_feedback
query: "payment processing errors with international transactions"
embedding_model: text-embedding-3-small
""",
)
# List of all operators
ALL_OPERATORS = [
op_map,
op_extract,
op_parallel_map,
op_filter,
op_reduce,
op_split,
op_gather,
op_unnest,
op_sample,
op_resolve,
op_code_map,
op_code_filter,
op_topk,
]
def get_all_operator_descriptions() -> str:
"""
Generate a comprehensive string containing all operator descriptions.
This is useful for providing context about available operators in prompts.
Returns:
str: Formatted string containing all operator descriptions
"""
descriptions = []
descriptions.append("# Available DocETL Operators\n")
descriptions.append(
"Below are all the operators available in the DocETL pipeline:\n"
)
for op in ALL_OPERATORS:
descriptions.append(op.to_string())
return "\n".join(descriptions)

View File

@ -0,0 +1,6 @@
from .base import Retriever
from .lancedb import LanceDBRetriever
__all__ = ["Retriever", "LanceDBRetriever"]

docetl/retrievers/base.py
View File

@ -0,0 +1,35 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any
@dataclass
class RetrievalResult:
"""Container for retrieval outputs."""
docs: list[dict]
rendered_context: str
meta: dict[str, Any]
class Retriever(ABC):
"""Abstract base class for retrievers."""
name: str
def __init__(self, runner, name: str, config: dict[str, Any]):
self.runner = runner
self.name = name
self.config = config
@abstractmethod
def ensure_index(self) -> None:
"""Create or verify the underlying index/table."""
raise NotImplementedError
@abstractmethod
def retrieve(self, context: dict[str, Any]) -> RetrievalResult:
"""Execute retrieval based on the provided Jinja context."""
raise NotImplementedError
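
For illustration, a minimal in-memory retriever that satisfies this interface; it is not part of the repository, and the "docs" and "top_k" config keys are made up for the example.

class StaticRetriever(Retriever):
    """Returns a fixed list of documents from its config; handy for tests."""

    def ensure_index(self) -> None:
        # Nothing to build for an in-memory list
        pass

    def retrieve(self, context: dict[str, Any]) -> RetrievalResult:
        docs = list(self.config.get("docs", []))[: self.config.get("top_k", 5)]
        rendered = "\n".join(f"- {d.get('text', '')}" for d in docs)
        return RetrievalResult(
            docs=docs,
            rendered_context=rendered,
            meta={"retriever": self.name, "num_docs": len(docs)},
        )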

View File

@ -0,0 +1,358 @@
from __future__ import annotations
import json
import os
import threading
from functools import lru_cache
from typing import Any
from jinja2 import Template
from .base import RetrievalResult, Retriever
class LanceDBRetriever(Retriever):
"""OSS LanceDB-backed retriever supporting FTS, vector, and hybrid search."""
_index_lock = threading.Lock()
_ensured = False
def _connect(self):
# Lazy import to avoid hard dependency during linting/tests without install
try:
import lancedb # type: ignore
except Exception as e:
raise ImportError(
"lancedb is required for LanceDBRetriever. Please add it to dependencies."
) from e
index_dir = self.config["index_dir"]
os.makedirs(index_dir, exist_ok=True)
return lancedb.connect(index_dir)
def _table_name(self) -> str:
# Stable table name based on retriever name
return f"docetl_{self.name}"
def _iter_dataset_rows(self) -> list[dict]:
"""Load dataset referenced by this retriever."""
dataset_name = self.config["dataset"]
# Ensure datasets are loaded
if not getattr(self.runner, "datasets", None):
self.runner.load()
if dataset_name not in self.runner.datasets:
raise ValueError(
f"Retriever '{self.name}' references unknown dataset '{dataset_name}'."
)
return self.runner.datasets[dataset_name].load()
def _render_input_phrase(self, tmpl: str | None, input_obj: dict) -> str:
if not tmpl:
return ""
try:
return Template(tmpl).render(input=input_obj)
except Exception:
return ""
def _batch_embed(self, texts: list[str]) -> list[list[float]]:
model = self.config.get("embedding", {}).get("model")
if not model:
# If no embedding model, return empty vectors
return []
# Use DocETL embedding router + cache
resp = self.runner.api.gen_embedding(model, json.dumps(texts))
# litellm embedding response
vectors = [d["embedding"] for d in resp["data"]]
return vectors
def _index_types(self) -> set[str]:
idx = self.config.get("index_types", [])
if isinstance(idx, str):
idx = [idx]
idx = set(idx or [])
# Interpret hybrid as both
if "hybrid" in idx:
idx.update({"fts", "embedding"})
return idx
def ensure_index(self) -> None:
build_policy = self.config.get("build_index", "if_missing")
with self._index_lock:
# Only build once per process; 'always' means rebuild once at startup
if self._ensured:
return
db = self._connect()
table_name = self._table_name()
dataset_name = self.config.get("dataset", "<unknown>")
idx_types = self._index_types()
try:
self.runner.console.log(
f"[cyan]LanceDB[/cyan] ensure_index: table='{table_name}', dataset='{dataset_name}', "
f"index_types={sorted(list(idx_types))}, build_policy='{build_policy}'"
)
except Exception:
pass
table = None
try:
existing_names = db.table_names()
except Exception as exc:
# Surface the real error instead of sidestepping
raise RuntimeError(f"Failed to list LanceDB tables: {exc}") from exc
if build_policy == "never":
try:
self.runner.console.log(
f"[yellow]LanceDB[/yellow] skipping index build (build_index=never) for '{table_name}'"
)
except Exception:
pass
self._ensured = True
return
if table_name in existing_names:
if build_policy == "always":
try:
self.runner.console.log(
f"[cyan]LanceDB[/cyan] dropping existing table '{table_name}' (build_index=always)"
)
except Exception:
pass
db.drop_table(table_name)
else:
# if_missing and table exists -> nothing to do
try:
self.runner.console.log(
f"[green]LanceDB[/green] index exists for '{table_name}', skipping build (build_index=if_missing)"
)
except Exception:
pass
self._ensured = True
return
if table is None:
try:
rows = self._iter_dataset_rows()
except Exception:
# Dataset may not be available yet (e.g., produced by a prior step)
try:
self.runner.console.log(
f"[yellow]LanceDB[/yellow] dataset '{dataset_name}' not available; "
f"deferring index build for '{table_name}'"
)
except Exception:
pass
return
idx_types = self._index_types()
# Build per-row phrases for FTS and Embedding
fts_index_tmpl = self.config.get("fts", {}).get("index_phrase", None)
emb_index_tmpl = self.config.get("embedding", {}).get(
"index_phrase", None
)
fts_texts: list[str] = []
if "fts" in idx_types:
fts_texts = [
self._render_input_phrase(fts_index_tmpl, r) for r in rows
]
emb_texts: list[str] = []
if "embedding" in idx_types:
emb_texts = [
self._render_input_phrase(emb_index_tmpl, r) for r in rows
]
# If no template, fall back to fts texts
if not any(emb_texts) and fts_texts:
emb_texts = fts_texts
# Compute embeddings for embedding index
vectors: list[list[float]] = []
if "embedding" in idx_types:
# Batch embeddings
batch = 128
for i in range(0, len(emb_texts), batch):
chunk = emb_texts[i : i + batch]
if not chunk:
continue
vectors.extend(self._batch_embed(chunk))
# Construct records for LanceDB
data = []
for idx, r in enumerate(rows):
rec = {
"id": r.get("id", idx),
"text": (
fts_texts[idx]
if fts_texts
else (emb_texts[idx] if emb_texts else "")
),
}
if vectors and idx < len(vectors):
rec["vector"] = vectors[idx]
# keep some fields for rendering
data.append(rec)
# Create table
table = db.create_table(table_name, data=data, mode="overwrite")
try:
self.runner.console.log(
f"[green]LanceDB[/green] created table '{table_name}' with {len(data)} rows "
f"(fts={'fts' in idx_types}, embedding={'embedding' in idx_types})"
)
except Exception:
pass
# Create FTS index if fts enabled
if "fts" in idx_types:
try:
table.create_fts_index("text")
try:
self.runner.console.log(
f"[green]LanceDB[/green] created FTS index on column 'text' for '{table_name}'"
)
except Exception:
pass
except Exception:
# Index may already exist or backend unavailable; continue
try:
self.runner.console.log(
f"[yellow]LanceDB[/yellow] FTS index creation skipped/failed for '{table_name}'"
)
except Exception:
pass
self._ensured = True
try:
self.runner.console.log(
f"[green]LanceDB[/green] index ready for '{table_name}'"
)
except Exception:
pass
def _render_query_phrase(self, tmpl: str | None, context: dict[str, Any]) -> str:
if not tmpl:
return ""
try:
return Template(tmpl).render(**context)
except Exception:
return ""
def _select_mode(self) -> str:
qmode = self.config.get("query", {}).get("mode", None)
if qmode:
return qmode
# Auto: hybrid if both index types present
idx = self._index_types()
if "fts" in idx and "embedding" in idx:
return "hybrid"
if "fts" in idx:
return "fts"
if "embedding" in idx:
return "embedding"
return "fts"
def _reranker(self):
# Lazy import and instantiate RRFReranker when needed
try:
from lancedb.rerankers import RRFReranker # type: ignore
rr = RRFReranker()
return rr
except Exception:
return None
def _limit_and_format(self, rows: list[dict]) -> list[dict]:
top_k = int(self.config.get("query", {}).get("top_k", 5))
return rows[:top_k]
def _fetch(self, context: dict[str, Any]) -> list[dict]:
db = self._connect()
table = db.open_table(self._table_name())
mode = self._select_mode()
top_k = int(self.config.get("query", {}).get("top_k", 5))
# Build queries
text_query = self._render_query_phrase(
self.config.get("fts", {}).get("query_phrase", None), context
)
vector_query: list[float] | None = None
emb_qtext = self._render_query_phrase(
self.config.get("embedding", {}).get("query_phrase", None), context
)
if emb_qtext:
vecs = self._batch_embed([emb_qtext])
vector_query = vecs[0] if vecs else None
# Execute
try:
if mode == "hybrid" and text_query and vector_query is not None:
q = (
table.search(query_type="hybrid")
.vector(vector_query)
.text(text_query)
)
rr = self._reranker()
if rr:
q = q.rerank(rr)
df = q.limit(top_k).to_pandas()
elif mode == "fts" and text_query:
# FTS-only
q = table.search(text_query, query_type="fts", fts_columns="text")
df = q.limit(top_k).to_pandas()
elif mode == "embedding" and vector_query is not None:
q = table.search(vector_query)
df = q.limit(top_k).to_pandas()
else:
return []
except Exception:
# Fallbacks
try:
q = (
table.search(text_query, query_type="fts", fts_columns="text")
if text_query
else table.search(vector_query)
)
df = q.limit(top_k).to_pandas()
except Exception:
return []
# Convert to list of dicts
rows = df.to_dict(orient="records") if hasattr(df, "to_dict") else []
return self._limit_and_format(rows)
def _render_docs(self, docs: list[dict]) -> str:
# Minimal, opinionated rendering: bullets of text (truncated)
max_chars = 1000
lines: list[str] = []
for i, d in enumerate(docs, start=1):
snippet = str(d.get("text", ""))[:max_chars]
if snippet:
lines.append(f"- [{i}] {snippet}")
return "\n".join(lines)
@staticmethod
@lru_cache(maxsize=256)
def _cache_key(name: str, ctx_key: str) -> str:
return f"{name}:{ctx_key}"
def _ctx_fingerprint(self, context: dict[str, Any]) -> str:
# Make a stable string key from context, using only parts used by templates
if "input" in context:
base = context["input"]
elif "reduce_key" in context:
base = {
"reduce_key": context.get("reduce_key", {}),
"inputs_len": len(context.get("inputs", [])),
}
else:
base = context
try:
return json.dumps(base, sort_keys=True, ensure_ascii=False)[:2000]
except Exception:
return str(base)[:2000]
def retrieve(self, context: dict[str, Any]) -> RetrievalResult:
self.ensure_index()
# Build a fingerprint (reserved for future caching)
_ = self._ctx_fingerprint(context)
docs = self._fetch(context)
rendered = self._render_docs(docs) if docs else "No extra context available."
meta = {"retriever": self.name, "num_docs": len(docs)}
return RetrievalResult(docs=docs, rendered_context=rendered, meta=meta)

View File

@ -93,6 +93,7 @@ class DSLRunner(ConfigWrapper):
config: dict[str, Any] | None
parsing_tools: list[schemas.ParsingTool] | None
datasets: dict[str, schemas.Dataset]
retrievers: dict[str, Any] | None
operations: list[OpType]
pipeline: schemas.PipelineSpec
@ -119,6 +120,7 @@ class DSLRunner(ConfigWrapper):
self.total_cost = 0
self._initialize_state()
self._setup_parsing_tools()
self._setup_retrievers()
self._build_operation_graph(config)
self._compute_operation_hashes()
@ -140,6 +142,31 @@ class DSLRunner(ConfigWrapper):
self.config.get("parsing_tools", None)
)
def _setup_retrievers(self) -> None:
"""Instantiate retrievers from configuration (lazy index creation)."""
from docetl.retrievers.lancedb import LanceDBRetriever
self.retrievers: dict[str, Any] = {}
retrievers_cfg = self.config.get("retrievers", {}) or {}
for name, rconf in retrievers_cfg.items():
if not isinstance(rconf, dict):
raise ValueError(f"Invalid retriever '{name}' configuration")
if rconf.get("type") != "lancedb":
raise ValueError(
f"Unsupported retriever type '{rconf.get('type')}' for '{name}'. Only 'lancedb' is supported."
)
required = ["dataset", "index_dir", "index_types"]
for key in required:
if key not in rconf:
raise ValueError(
f"Retriever '{name}' missing required key '{key}'."
)
# Defaults
rconf.setdefault("query", {"top_k": 5})
rconf.setdefault("build_index", "if_missing")
self.retrievers[name] = LanceDBRetriever(self, name, rconf)
def _build_operation_graph(self, config: dict) -> None:
"""Build the DAG of operations from configuration"""
self.config = config
@ -672,7 +699,7 @@ class DSLRunner(ConfigWrapper):
"litellm_kwargs", {}
)
kwargs["rewrite_agent_model"] = self.config.get("optimizer_config", {}).get(
"rewrite_agent_model", "gpt-4o"
"rewrite_agent_model", "gpt-5.1"
)
kwargs["judge_agent_model"] = self.config.get("optimizer_config", {}).get(
"judge_agent_model", "gpt-4o-mini"
@ -700,7 +727,7 @@ class DSLRunner(ConfigWrapper):
"litellm_kwargs", {}
)
kwargs["rewrite_agent_model"] = self.config.get("optimizer_config", {}).get(
"rewrite_agent_model", "gpt-4o"
"rewrite_agent_model", "gpt-5.1"
)
kwargs["judge_agent_model"] = self.config.get("optimizer_config", {}).get(
"judge_agent_model", "gpt-4o-mini"

View File

@ -4,7 +4,9 @@ from .base_schemas import * # noqa: F403
# ruff: noqa: F403
from .operations import (
cluster,
code_operations,
equijoin,
extract,
filter,
gather,
map,
@ -26,6 +28,10 @@ GatherOp = gather.GatherOperation.schema
UnnestOp = unnest.UnnestOperation.schema
ClusterOp = cluster.ClusterOperation.schema
SampleOp = sample.SampleOperation.schema
CodeMapOp = code_operations.CodeMapOperation.schema
CodeReduceOp = code_operations.CodeReduceOperation.schema
CodeFilterOp = code_operations.CodeFilterOperation.schema
ExtractOp = extract.ExtractOperation.schema
OpType = (
MapOp
@ -37,6 +43,10 @@ OpType = (
| SplitOp
| GatherOp
| UnnestOp
| CodeMapOp
| CodeReduceOp
| CodeFilterOp
| ExtractOp
)
Dataset = dataset.Dataset.schema

View File

@ -10,6 +10,7 @@ from jinja2 import Environment, meta
from litellm import ModelResponse
from litellm import completion_cost as lcc
from lzstring import LZString
from rich.prompt import Confirm
class Decryptor:
@ -79,6 +80,72 @@ class CapturedOutput:
self.optimizer_output[self.step][stage_type] = output
def has_jinja_syntax(template_string: str) -> bool:
"""
Check if a string contains Jinja2 template syntax.
Args:
template_string (str): The string to check.
Returns:
bool: True if the string contains Jinja2 syntax ({{ }} or {% %}), False otherwise.
"""
# Check for Jinja2 expression syntax {{ }}
if re.search(r"\{\{.*?\}\}", template_string):
return True
# Check for Jinja2 statement syntax {% %}
if re.search(r"\{%.*?%\}", template_string):
return True
return False
def prompt_user_for_non_jinja_confirmation(
prompt_text: str, operation_name: str, prompt_field: str = "prompt"
) -> bool:
"""
Prompt the user for confirmation when a prompt doesn't contain Jinja syntax.
Args:
prompt_text (str): The prompt text that doesn't contain Jinja syntax.
operation_name (str): The name of the operation.
prompt_field (str): The name of the prompt field (e.g., "prompt", "batch_prompt").
Returns:
bool: True if user confirms, False otherwise.
"""
from docetl.console import DOCETL_CONSOLE
console = DOCETL_CONSOLE
console.print(
f"\n[bold yellow]⚠ Warning:[/bold yellow] The '{prompt_field}' in operation '{operation_name}' "
f"does not appear to be a Jinja2 template (no {{}} or {{% %}} syntax found)."
)
console.print(
f"[dim]Prompt:[/dim] {prompt_text[:100]}{'...' if len(prompt_text) > 100 else ''}"
)
console.print(
"\n[bold]We will automatically append the document(s) to your prompt during execution:[/bold]"
)
console.print(
" • For single-document operations: 'Here is the document: {{ input }}'"
)
console.print(" • For reduce operations: 'Here are the documents: {{ inputs }}'")
console.print()
try:
return Confirm.ask(
"Do you want to proceed with inserting all documents as-is?",
default=True,
console=console,
)
except Exception:
# If Confirm fails (e.g., in non-interactive mode), default to True
console.print(
"[dim]Non-interactive mode: proceeding with document insertion[/dim]"
)
return True
def extract_jinja_variables(template_string: str) -> list[str]:
"""
Extract variables from a Jinja2 template string.
@ -299,3 +366,88 @@ class classproperty:
def __get__(self, obj: Any | None, owner: type) -> Any:
return self.f(owner)
def extract_output_from_json(yaml_file_path, json_output_path=None):
"""
Extract output fields from JSON file based on the output schema defined in the YAML file.
If the last operation doesn't have an output schema, returns all keys from the output data.
Args:
yaml_file_path (str): Path to the YAML configuration file
json_output_path (str): Path to the JSON output file to extract from
Returns:
List[Dict]: Extracted data containing only the fields specified in the output schema,
or all fields if no output schema is defined
"""
# Load YAML configuration
with open(yaml_file_path, "r") as f:
config = yaml.safe_load(f)
if json_output_path is None:
json_output_path = config.get("pipeline", {}).get("output", {}).get("path")
if json_output_path is None:
raise ValueError("No output path found in YAML file")
# Load JSON output data
with open(json_output_path, "r") as f:
output_data = json.load(f)
# Find the last operation in the pipeline
pipeline = config.get("pipeline", {})
steps = pipeline.get("steps", [])
if not steps:
raise ValueError("No pipeline steps found in YAML file")
# Get the last step and its operations
last_step = steps[-1]
last_step_operations = last_step.get("operations", [])
if not last_step_operations:
raise ValueError("No operations found in the last pipeline step")
# Get the name of the last operation in the last step
last_operation_name = last_step_operations[-1]
# Find this operation in the operations list
operations = config.get("operations", [])
last_operation = None
for op in operations:
if op.get("name") == last_operation_name:
last_operation = op
break
if not last_operation:
raise ValueError(
f"Operation '{last_operation_name}' not found in operations list"
)
output_schema = last_operation.get("output", {}).get("schema", {})
if not output_schema:
# If no output schema, return all keys from the output data
if isinstance(output_data, list) and len(output_data) > 0:
# Get all unique keys from all items
all_keys = set()
for item in output_data:
if isinstance(item, dict):
all_keys.update(item.keys())
# Return all data with all keys
return output_data
else:
# If output_data is not a list or is empty, return as-is
return output_data if isinstance(output_data, list) else [output_data]
# Extract the field names from the schema
schema_fields = list(output_schema.keys())
# Extract only the specified fields from each item in the output data
extracted_data = []
for item in output_data:
extracted_item = {}
for field in schema_fields:
if field in item:
extracted_item[field] = item[field]
extracted_data.append(extracted_item)
return extracted_data

129
docetl/utils_dataset.py Normal file
View File

@ -0,0 +1,129 @@
"""
Dataset utility functions for DocETL.
"""
import json
from typing import Any, Dict, List
import yaml
def compute_dataset_stats(
data: List[Dict[str, Any]], dataset_name: str = "data"
) -> str:
"""
Compute statistics for a dataset by analyzing the actual data.
Args:
data: List of data records
dataset_name: Name of the dataset
Returns:
str: Formatted dataset statistics
"""
if not data:
return (
f"Dataset: {dataset_name}\nType: file\nRecords loaded: 0\nNo data available"
)
num_records = len(data)
total_tokens = 0
field_stats = {}
# Analyze each record
for record in data:
if isinstance(record, dict):
for key, value in record.items():
# Skip if key starts with "GT "
if key.startswith("GT "):
continue
if key not in field_stats:
field_stats[key] = {
"total_chars": 0,
"count": 0,
"type": type(value).__name__,
}
if isinstance(value, str):
char_count = len(value)
field_stats[key]["total_chars"] += char_count
field_stats[key]["count"] += 1
total_tokens += (
char_count / 4
) # 4 characters per token approximation
elif isinstance(value, (int, float)):
# Numbers are typically short, estimate as ~5 characters
field_stats[key]["total_chars"] += 5
field_stats[key]["count"] += 1
total_tokens += 1.25
elif isinstance(value, list):
# For lists, estimate based on string representation
str_repr = str(value)
char_count = len(str_repr)
field_stats[key]["total_chars"] += char_count
field_stats[key]["count"] += 1
total_tokens += char_count / 4
# Format the output
stats_lines = [
f"Dataset: {dataset_name}",
"Type: file",
f"Records loaded: {num_records}",
"Input schema:",
]
for field, stats in field_stats.items():
if stats["count"] > 0:
avg_tokens = (stats["total_chars"] / stats["count"]) / 4
field_type = "string" if stats["type"] in ["str"] else stats["type"]
stats_lines.append(
f" {field}: {field_type} (avg: {avg_tokens:.1f} tokens)"
)
stats_lines.append(f"Total tokens: {int(total_tokens):,}")
return "\n ".join(stats_lines)
def get_dataset_stats(yaml_path: str, dataset_name: str | None = None) -> str:
"""
Get dataset statistics by loading and analyzing the actual data from YAML config.
Args:
yaml_path: Path to the YAML configuration file
dataset_name: Optional name of the dataset (if not provided, uses first dataset in config)
Returns:
str: Formatted dataset statistics
"""
# Load the YAML config to get the data path
with open(yaml_path, "r") as f:
config = yaml.safe_load(f)
# Extract dataset info from config
datasets = config.get("datasets", {})
if not datasets:
return f"Dataset: {dataset_name or 'unknown'}\nType: file\nRecords loaded: 0\nNo datasets found in config"
# Get the first dataset (or specified dataset)
if dataset_name and dataset_name in datasets:
dataset_config = datasets[dataset_name]
actual_dataset_name = dataset_name
else:
actual_dataset_name, dataset_config = next(iter(datasets.items()))
data_path = dataset_config.get("path")
if not data_path:
return f"Dataset: {actual_dataset_name}\nType: file\nRecords loaded: 0\nNo data path found"
# Load the data
try:
with open(data_path, "r") as f:
data = json.load(f)
return compute_dataset_stats(data, actual_dataset_name)
except Exception as e:
return f"Dataset: {actual_dataset_name}\nType: file\nRecords loaded: 0\nError loading data: {e}"

441
docetl/utils_evaluation.py Normal file
View File

@ -0,0 +1,441 @@
"""
Evaluation utility functions for DocETL.
"""
import importlib.util
import inspect
import json
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional
from docetl.console import DOCETL_CONSOLE
def register_eval(
func: Callable[[str, str], Dict[str, Any]]
) -> Callable[[str, str], Dict[str, Any]]:
"""
Decorator to mark a function as a DocETL evaluation function.
The decorated function should take two arguments (dataset_file_path, results_file_path) and return
a dictionary of evaluation metrics.
Example:
@docetl.register_eval
def evaluate_results(dataset_file_path: str, results_file_path: str) -> Dict[str, Any]:
# ... evaluation logic ...
return {"score": 0.95}
"""
func._docetl_eval = True
return func
def load_custom_evaluate_func(
evaluation_file_path: str, dataset_file_path: str
) -> Callable[[str], Dict[str, Any]]:
"""
Load a custom evaluation function from a Python file and wrap it to pass dataset_file_path.
The file should contain a function decorated with @docetl.register_eval.
If multiple functions are decorated, an error is raised.
Args:
evaluation_file_path: Path to a Python file containing a function decorated with @docetl.register_eval
dataset_file_path: Path to the dataset file to pass to the evaluation function
Returns:
callable: Wrapped evaluation function that takes (results_file_path: str) -> dict
Raises:
ValueError: If the file doesn't exist, doesn't contain a decorated function, or has multiple decorated functions
"""
func_path = Path(evaluation_file_path)
if not func_path.exists():
raise ValueError(f"Evaluation file not found: {evaluation_file_path}")
# Use a unique module name based on the file path to avoid conflicts
module_name = f"docetl_eval_{func_path.stem}_{hash(str(func_path))}"
spec = importlib.util.spec_from_file_location(module_name, func_path)
if spec is None or spec.loader is None:
raise ValueError(f"Could not load module from: {evaluation_file_path}")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
# Find all functions decorated with @docetl.register_eval
eval_functions = []
for name, obj in inspect.getmembers(module, inspect.isfunction):
if hasattr(obj, "_docetl_eval") and obj._docetl_eval:
eval_functions.append((name, obj))
if len(eval_functions) == 0:
raise ValueError(
f"Module {evaluation_file_path} must contain a function decorated with @docetl.register_eval. "
f"Found functions: {[name for name, _ in inspect.getmembers(module, inspect.isfunction)]}"
)
if len(eval_functions) > 1:
function_names = [name for name, _ in eval_functions]
raise ValueError(
f"Module {evaluation_file_path} contains multiple functions decorated with @docetl.register_eval: {function_names}. "
f"Only one evaluation function is allowed per file."
)
# Wrap the function to pass dataset_file_path
original_func = eval_functions[0][1]
def wrapped_func(results_file_path: str) -> Dict[str, Any]:
return original_func(dataset_file_path, results_file_path)
return wrapped_func
def _extract_node_data(item: Any) -> tuple[Optional[str], Dict[str, Any]]:
"""Extract node data from either a node object or a dict/file path."""
if hasattr(item, "result_path"):
jf = item.result_path
node_data = {
"node_id": item.get_id(),
"cost": item.cost,
"visits": getattr(item, "visits", 0),
"value": getattr(item, "value", 0),
}
else:
jf = item.get("file_path") if isinstance(item, dict) else item
node_data = {
"node_id": (
item.get("node_id", "unknown") if isinstance(item, dict) else "unknown"
),
"cost": item.get("cost", 0.0) if isinstance(item, dict) else 0.0,
"visits": item.get("visits", 0) if isinstance(item, dict) else 0,
"value": item.get("value", 0) if isinstance(item, dict) else 0,
}
return jf, node_data
def _get_display_path(jf: str, output_path: Path) -> str:
"""Get display path for a result file."""
jp = Path(jf).resolve()
op_root = output_path.resolve()
if hasattr(jp, "is_relative_to") and jp.is_relative_to(op_root):
return str(jp.relative_to(op_root))
else:
return jp.name
def _add_frontier_info(result: Dict[str, Any], item: Any) -> Dict[str, Any]:
"""Add frontier information if available."""
if hasattr(item, "result_path"):
result.update(
{
"moar_accuracy": getattr(item, "moar_accuracy", None),
"on_frontier": getattr(item, "on_frontier", False),
}
)
return result
def identify_pareto_frontier(
eval_results: List[Dict[str, Any]], metric_key: str
) -> List[Dict[str, Any]]:
"""
Identify the Pareto frontier for evaluation results based on accuracy vs cost.
Args:
eval_results: List of evaluation results with cost and accuracy metrics
metric_key: Key to use for accuracy metric
Returns:
Updated eval_results with 'on_frontier' field set to True/False
Raises:
KeyError: If required metrics are missing from results
"""
if not eval_results:
return eval_results
# Filter out results that don't have the required metrics
valid_results = [r for r in eval_results if metric_key in r and "cost" in r]
if not valid_results:
DOCETL_CONSOLE.log(
f"[yellow]⚠️ No valid results with {metric_key} and cost metrics[/yellow]"
)
return eval_results
# Validate that all results have the required metrics
for r in valid_results:
if metric_key not in r:
raise KeyError(
f"Missing required accuracy metric '{metric_key}' in evaluation result. "
f"Available keys: {list(r.keys())}. "
f"This metric is required for Pareto frontier identification."
)
if "cost" not in r:
raise KeyError(
f"Missing required 'cost' metric in evaluation result. "
f"Available keys: {list(r.keys())}"
)
# Sort by cost (ascending) and accuracy (descending for maximization)
valid_results.sort(key=lambda x: (x["cost"], -x[metric_key]))
# Identify Pareto frontier: points that are not dominated by any other point
frontier = []
for i, candidate in enumerate(valid_results):
is_dominated = False
for j, other in enumerate(valid_results):
if i == j:
continue
# Check if other point dominates candidate
# Dominated if: other has lower cost AND higher/equal accuracy, OR same cost AND higher accuracy
if (
other["cost"] < candidate["cost"]
and other[metric_key] >= candidate[metric_key]
) or (
other["cost"] == candidate["cost"]
and other[metric_key] > candidate[metric_key]
):
is_dominated = True
break
if not is_dominated:
frontier.append(candidate)
# Mark all results with frontier status
frontier_set = set(id(f) for f in frontier)
for r in eval_results:
r["on_frontier"] = id(r) in frontier_set
return eval_results
def print_pareto_frontier_summary(
eval_results: List[Dict[str, Any]],
metric_key: str,
dataset_name: Optional[str] = None,
) -> None:
"""
Print a summary of the Pareto frontier points.
Args:
eval_results: List of evaluation results with 'on_frontier' field
metric_key: Key to use for accuracy metric
dataset_name: Optional dataset name for display purposes
Raises:
KeyError: If required metrics are missing from frontier points
"""
frontier_points = [r for r in eval_results if r.get("on_frontier", False)]
if not frontier_points:
dataset_str = f" for {dataset_name}" if dataset_name else ""
DOCETL_CONSOLE.log(
f"[yellow]📊 No Pareto frontier points found{dataset_str}[/yellow]"
)
return
# Sort frontier points by cost for better display
frontier_points.sort(key=lambda x: x["cost"])
dataset_str = f" for {dataset_name.upper()}" if dataset_name else ""
DOCETL_CONSOLE.log(f"\n🏆 Pareto Frontier Summary{dataset_str}:")
DOCETL_CONSOLE.log("=" * 60)
DOCETL_CONSOLE.log(
f"{'Rank':<4} {'Cost ($)':<10} {metric_key.upper():<15} {'File':<30}"
)
DOCETL_CONSOLE.log("=" * 60)
for i, point in enumerate(frontier_points, 1):
if "cost" not in point:
raise KeyError(
f"Missing required 'cost' metric in frontier point. "
f"Available keys: {list(point.keys())}"
)
if metric_key not in point:
raise KeyError(
f"Missing required accuracy metric '{metric_key}' in frontier point. "
f"Available keys: {list(point.keys())}. "
f"This metric is required for Pareto frontier summary."
)
cost = point["cost"]
accuracy = point[metric_key]
file_name = point.get("file", "unknown")
DOCETL_CONSOLE.log(f"{i:<4} ${cost:<9.4f} {accuracy:<15.4f} {file_name:<30}")
DOCETL_CONSOLE.log("=" * 60)
DOCETL_CONSOLE.log(f"Total frontier points: {len(frontier_points)}")
def save_pareto_frontier_results(
eval_results: List[Dict[str, Any]],
output_path: Path,
metric_key: str,
dataset_name: Optional[str] = None,
) -> None:
"""
Save Pareto frontier results to a separate JSON file for analysis.
Args:
eval_results: List of evaluation results with 'on_frontier' field
output_path: Output directory path
metric_key: Key to use for accuracy metric
dataset_name: Optional dataset name
Raises:
KeyError: If required metrics are missing from frontier points
"""
frontier_points = [r for r in eval_results if r.get("on_frontier", False)]
if not frontier_points:
return
# Sort frontier points by cost
frontier_points.sort(key=lambda x: x["cost"])
# Add rank and accuracy metric information
for i, point in enumerate(frontier_points):
point["rank"] = i + 1
point["accuracy_metric"] = metric_key
# Calculate cost-effectiveness ratios between consecutive points
cost_effectiveness_analysis = []
for i in range(len(frontier_points) - 1):
curr = frontier_points[i]
next_point = frontier_points[i + 1]
if "cost" not in curr or "cost" not in next_point:
raise KeyError(
f"Missing required 'cost' metric in frontier point. "
f"Available keys: {list(curr.keys() if 'cost' not in curr else next_point.keys())}"
)
if metric_key not in curr or metric_key not in next_point:
raise KeyError(
f"Missing required accuracy metric '{metric_key}' in frontier point. "
f"Available keys: {list(curr.keys() if metric_key not in curr else next_point.keys())}. "
f"This metric is required for cost-effectiveness analysis."
)
cost_diff = next_point["cost"] - curr["cost"]
accuracy_diff = next_point[metric_key] - curr[metric_key]
if cost_diff > 0 and accuracy_diff > 0:
cost_effectiveness = cost_diff / accuracy_diff
cost_effectiveness_analysis.append(
{
"from_file": curr["file"],
"to_file": next_point["file"],
"cost_increase": cost_diff,
"accuracy_increase": accuracy_diff,
"cost_per_unit_improvement": cost_effectiveness,
}
)
# Create frontier summary
frontier_summary = {
"accuracy_metric": metric_key,
"total_frontier_points": len(frontier_points),
"frontier_points": frontier_points,
"cost_effectiveness_analysis": cost_effectiveness_analysis,
}
if dataset_name:
frontier_summary["dataset"] = dataset_name
# Save to file
frontier_file = output_path / "pareto_frontier.json"
with open(frontier_file, "w") as f:
json.dump(frontier_summary, f, indent=2)
DOCETL_CONSOLE.log(f"📊 Pareto frontier results written to {frontier_file}")
def run_evaluation(
nodes_or_files: List[Any],
evaluate_func: Callable[[str], Dict[str, Any]],
metric_key: str,
output_path: Path,
dataset_name: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""
Run evaluation on a set of nodes or files using a custom evaluation function.
This is a general-purpose evaluation function that does not depend on
experiment-specific datasets. It processes nodes/files, extracts metrics,
identifies the Pareto frontier, and saves results.
Args:
nodes_or_files: List of node objects (with result_path) or file paths
evaluate_func: Evaluation function (results_file_path: str) -> dict
metric_key: Key to extract from evaluation results for accuracy metric
output_path: Path to save evaluation results
dataset_name: Optional dataset name for display purposes
Returns:
List of evaluation results with 'on_frontier' field set
Raises:
ValueError: If metric_key is not provided
KeyError: If required metrics are missing from evaluation results
"""
if not metric_key:
raise ValueError("metric_key must be provided")
eval_results = []
# Process evaluation items
for item in nodes_or_files:
jf, node_data = _extract_node_data(item)
if jf is None or not Path(jf).exists():
continue
try:
metrics = evaluate_func(jf)
display_path = _get_display_path(jf, output_path)
# Extract the custom metric
accuracy_value = metrics.get(metric_key)
if accuracy_value is None:
DOCETL_CONSOLE.log(
f"[yellow]⚠️ Warning: Metric key '{metric_key}' not found in evaluation results for {jf}. "
f"Available keys: {list(metrics.keys())}[/yellow]"
)
# Try to find a numeric value as fallback
accuracy_value = next(
(v for v in metrics.values() if isinstance(v, (int, float))), None
)
if accuracy_value is None:
DOCETL_CONSOLE.log(
f"[red]❌ Skipping {jf}: No valid accuracy metric found[/red]"
)
continue
result = {
"file": display_path,
metric_key: accuracy_value,
**metrics, # Include all metrics from custom function
**node_data,
}
result = _add_frontier_info(result, item)
eval_results.append(result)
except Exception as e:
DOCETL_CONSOLE.log(f"[red] ⚠️ Evaluation failed for {jf}: {e}[/red]")
# Identify Pareto frontier
if eval_results:
DOCETL_CONSOLE.log("\n🔍 Identifying Pareto frontier...")
eval_results = identify_pareto_frontier(eval_results, metric_key)
# Print Pareto frontier summary
print_pareto_frontier_summary(eval_results, metric_key, dataset_name)
# Save Pareto frontier results to separate file
save_pareto_frontier_results(
eval_results, output_path, metric_key, dataset_name
)
# Save evaluation results
if eval_results:
eval_out_file = output_path / "evaluation_metrics.json"
with open(eval_out_file, "w") as f:
json.dump(eval_results, f, indent=2)
DOCETL_CONSOLE.log(f"📊 Evaluation results written to {eval_out_file}")
return eval_results

View File

@ -101,6 +101,63 @@
ignore_init_summary: false
trim_doctest_flags: true
::: docetl.schemas.CodeMapOp
options:
show_root_heading: true
heading_level: 3
show_if_no_docstring: false
docstring_options:
ignore_init_summary: false
trim_doctest_flags: true
::: docetl.schemas.CodeReduceOp
options:
show_root_heading: true
heading_level: 3
show_if_no_docstring: false
docstring_options:
ignore_init_summary: false
trim_doctest_flags: true
::: docetl.schemas.CodeFilterOp
options:
show_root_heading: true
heading_level: 3
show_if_no_docstring: false
docstring_options:
ignore_init_summary: false
trim_doctest_flags: true
::: docetl.schemas.ExtractOp
options:
show_root_heading: true
heading_level: 3
show_if_no_docstring: false
docstring_options:
ignore_init_summary: false
trim_doctest_flags: true
### Callable support for code ops
Code operations (`code_map`, `code_reduce`, `code_filter`) accept either a string containing Python code that defines a `transform` function, or a regular Python function. When you pass a function, it does not need to be named `transform`; DocETL binds it internally.
Example:
```python
from docetl.api import CodeMapOp
def my_map(doc: dict) -> dict:
return {"double": doc["x"] * 2}
code_map = CodeMapOp(name="double_x", type="code_map", code=my_map)
```
- Map: `fn(doc: dict) -> dict`
- Filter: `fn(doc: dict) -> bool`
- Reduce: `fn(group: list[dict]) -> dict`
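For the filter and reduce forms, a similar sketch (illustrative only; it assumes `CodeFilterOp` and `CodeReduceOp` are exported from `docetl.api` alongside `CodeMapOp`):
```python
from docetl.api import CodeFilterOp, CodeReduceOp  # assumed exports, mirroring CodeMapOp above

def keep_long_docs(doc: dict) -> bool:
    # Filter callable: return True to keep the document
    return len(doc.get("text", "")) > 500

def total_length(group: list[dict]) -> dict:
    # Reduce callable: aggregate a group of documents into one record
    return {"total_chars": sum(len(d.get("text", "")) for d in group)}

code_filter = CodeFilterOp(name="keep_long", type="code_filter", code=keep_long_docs)
code_reduce = CodeReduceOp(name="sum_lengths", type="code_reduce", code=total_length)
```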
See also: [Code Operators](../operators/code.md), [Extract Operator](../operators/extract.md)
## Dataset and Pipeline
::: docetl.schemas.Dataset

View File

@ -56,13 +56,31 @@ The DocETL optimizer operates using the following mechanism:
### Using the Optimizer
You can invoke the optimizer using the following command:
DocETL provides two optimizer options:
#### MOAR Optimizer (Recommended)
The MOAR optimizer uses Monte Carlo Tree Search to find Pareto-optimal solutions balancing accuracy and cost:
```bash
docetl build your_pipeline.yaml
docetl build your_pipeline.yaml --optimizer moar
```
This command will save the optimized pipeline to `your_pipeline_opt.yaml`. Note that the optimizer will only rewrite operators where you've set `optimize: true`. Leaving this field unset will skip optimization for that operator.
See the [MOAR Optimizer Guide](../optimization/moar.md) for detailed instructions.
#### V1 Optimizer (Deprecated)
!!! warning "Deprecated"
The V1 optimizer is deprecated and no longer recommended. Use MOAR instead.
The V1 optimizer uses a greedy approach with validation. It's still available for backward compatibility:
```bash
docetl build your_pipeline.yaml --optimizer v1
```
!!! warning "Not Recommended"
The V1 optimizer should not be used for new projects. Use MOAR instead.
<!-- ### Automatic Entity Resolution

View File

@ -34,6 +34,9 @@ To get started with DocETL:
2. Define your pipeline in a YAML file. Want to use an LLM like ChatGPT or Claude to help you write your pipeline? See [docetl.org/llms.txt](https://docetl.org/llms.txt) for a big prompt you can copy paste into ChatGPT or Claude, before describing your task.
3. Run your pipeline using the DocETL command-line interface
!!! tip "Fastest Way: Claude Code"
Clone this repo and run `claude` to use the built-in DocETL skill. Just describe your data processing task and Claude will create and run the pipeline for you. See [Quick Start (Claude Code)](quickstart-claude-code.md) for details.
## 🏛️ Project Origin
DocETL was created by members of the EPIC Data Lab and Data Systems and Foundations group at UC Berkeley. The EPIC (Effective Programming, Interaction, and Computation with Data) Lab focuses on developing low-code and no-code interfaces for data work, powered by next-generation predictive programming techniques. DocETL is one of the projects that emerged from our research efforts to streamline complex document processing tasks.

View File

@ -55,8 +55,6 @@ If you want to use only the parsing extra:
uv sync --extra parsing
```
This will create a virtual environment and install all the required dependencies.
4. Set up your OpenAI API key:
Create a .env file in the project root and add your OpenAI API key:

View File

@ -91,3 +91,10 @@ The transform function should return True for items to keep and False for items
| reduce_key | Key(s) to group by (code_reduce only) | "_all" |
| pass_through | Pass through unmodified keys from first item in group (code_reduce only) | false |
| concurrent_thread_count | The number of threads to start | the number of logical CPU cores (os.cpu_count()) |
| limit | Maximum number of outputs to produce before stopping | Processes all data |
The `limit` parameter behaves differently for each operation type, as sketched in the example after this list:
- **code_map**: Limits the number of input documents to process
- **code_filter**: Limits the number of documents that pass the filter (outputs)
- **code_reduce**: Limits the number of groups to reduce, selecting the smallest groups first (by document count)
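A minimal sketch of these three counting rules, written as plain Python rather than DocETL internals, purely to illustrate what each `limit` counts:
```python
def apply_code_map(docs, fn, limit=None):
    # code_map: cap the number of *input* documents that are processed
    subset = docs if limit is None else docs[:limit]
    return [fn(d) for d in subset]

def apply_code_filter(docs, predicate, limit=None):
    # code_filter: cap the number of documents that *pass* the predicate
    kept = []
    for d in docs:
        if predicate(d):
            kept.append(d)
            if limit is not None and len(kept) >= limit:
                break
    return kept

def apply_code_reduce(groups, fn, limit=None):
    # code_reduce: reduce only the N *smallest* groups (by document count)
    ordered = sorted(groups.values(), key=len)
    selected = ordered if limit is None else ordered[:limit]
    return [fn(g) for g in selected]
```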

View File

@ -34,9 +34,9 @@ This Equijoin operation matches job candidates to job postings:
The prompt template uses Jinja2 syntax, allowing you to reference input fields directly (e.g., `left.skills`). You can reference the left and right documents using `left` and `right` respectively.
!!! warning "Performance Consideration"
!!! info "Automatic Blocking"
For large datasets, running comparisons with an LLM can be time-consuming. It's recommended to optimize your pipeline using `docetl build pipeline.yaml` to generate efficient blocking rules for the operation.
If you don't specify any blocking configuration (`blocking_threshold`, `blocking_conditions`, or `limit_comparisons`), the Equijoin operation will automatically compute an optimal embedding-based blocking threshold at runtime. It samples pairs from your data, runs LLM comparisons on the sample, and finds a threshold that achieves 95% recall by default. You can adjust this with the `blocking_target_recall` parameter.
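As a rough illustration of that procedure (not the actual DocETL implementation), one way to pick such a threshold from a sample of embedding-scored pairs labeled by LLM comparisons:
```python
def pick_blocking_threshold(scored_pairs, target_recall=0.95):
    """scored_pairs: list of (similarity, is_match) tuples, where is_match comes from
    LLM comparisons on the sampled pairs. Returns the highest similarity threshold
    that still retains at least target_recall of the observed true matches."""
    match_sims = sorted((sim for sim, is_match in scored_pairs if is_match), reverse=True)
    if not match_sims:
        return 0.0  # no matches observed in the sample; don't block anything
    needed = max(1, int(round(target_recall * len(match_sims))))
    # Keeping pairs with similarity >= threshold retains the `needed` highest-scoring matches.
    return match_sims[needed - 1]
```
Here `target_recall` plays the role of the `blocking_target_recall` parameter.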
## Blocking
@ -95,10 +95,19 @@ A full Equijoin step combining both ideas might look like:
Equijoin shares many parameters with the Resolve operation. For a detailed list of required and optional parameters, please see the [Parameters section in the Resolve operation documentation](resolve.md#required-parameters).
Key differences for Equijoin include:
### Equijoin-Specific Parameters
| Parameter | Description | Default |
| ------------------------- | --------------------------------------------------------------------------------- | ----------------------------- |
| `limits` | Maximum matches for each left/right item: `{"left": n, "right": m}` | No limit |
| `blocking_keys` | Keys for embedding blocking: `{"left": [...], "right": [...]}` | All keys from each dataset |
| `blocking_threshold` | Embedding similarity threshold for considering pairs | Auto-computed if not set |
| `blocking_target_recall` | Target recall when auto-computing blocking threshold (0.0 to 1.0) | 0.95 |
Key differences from Resolve:
- `resolution_prompt` is not used in Equijoin.
- `limits` parameter is specific to Equijoin, allowing you to set maximum matches for each left and right item.
- `blocking_keys` uses a dict with `left` and `right` keys instead of a simple list.
## Incorporating Into a Pipeline

View File

@ -140,6 +140,11 @@ This strategy asks the LLM to generate regex patterns matching the desired conte
| `timeout` | Timeout for LLM calls in seconds | 120 |
| `skip_on_error` | Continue processing if errors occur | false |
| `litellm_completion_kwargs` | Additional parameters for LiteLLM calls | {} |
| `limit` | Maximum number of documents to extract from before stopping | Processes all data |
| `retriever` | Name of a retriever to use for RAG. See [Retrievers](../retrievers.md). | None |
| `save_retriever_output` | If true, saves the retrieved context to `_<operation_name>_retrieved_context` in the output. | False |
When `limit` is set, Extract only reformats and submits the first _N_ documents. This is handy when the upstream dataset is large and you want to cap cost while previewing results.
## Best Practices

View File

@ -83,6 +83,10 @@ This example demonstrates how the Filter operation distinguishes between high-im
See [map optional parameters](./map.md#optional-parameters) for additional configuration options, including `batch_prompt` and `max_batch_size`.
### Limiting filtered outputs
`limit` behaves slightly differently for filter operations than for map operations. Because filter drops documents whose predicate evaluates to `false`, the limit counts only the documents that would be retained (i.e., the ones whose boolean output is `true`). DocETL will continue evaluating additional inputs until it has collected `limit` passing documents and then stop scheduling further LLM calls. This ensures you can request “the first N matches” without paying to score the entire dataset.
!!! info "Validation"
For more details on validation techniques and implementation, see [operators](../concepts/operators.md#validation).

View File

@ -140,6 +140,7 @@ This example demonstrates how the Map operation can transform long, unstructured
| `optimize` | Flag to enable operation optimization | `True` |
| `recursively_optimize` | Flag to enable recursive optimization of operators synthesized as part of rewrite rules | `false` |
| `sample` | Number of samples to use for the operation | Processes all data |
| `limit` | Maximum number of outputs to produce before stopping | Processes all data |
| `tools` | List of tool definitions for LLM use | None |
| `validate` | List of Python expressions to validate the output | None |
| `flush_partial_results` | Write results of individual batches of map operation to disk for faster inspection | False |
@ -155,9 +156,15 @@ This example demonstrates how the Map operation can transform long, unstructured
| `pdf_url_key` | If specified, the key in the input that contains the URL of the PDF to process. | None |
| `calibrate` | Improve consistency across documents by using sample data as reference anchors. | False |
| `num_calibration_docs` | Number of documents to use sample and generate outputs for, for calibration. | 10 |
| `retriever` | Name of a retriever to use for RAG. See [Retrievers](../retrievers.md). | None |
| `save_retriever_output` | If true, saves the retrieved context to `_<operation_name>_retrieved_context` in the output. | False |
Note: If `drop_keys` is specified, `prompt` and `output` become optional parameters.
### Limiting execution
Set `limit` when you only need the first _N_ map results or want to cap LLM spend. The operation slices the processed dataset to the first `limit` entries and also stops scheduling new prompts once that many outputs have been produced, even if a prompt returns multiple records. Filter operations inherit this behavior but redefine the count so the limit only applies to records whose filter predicate evaluates to `true` (see [Filter](./filter.md#optional-parameters)).
!!! info "Validation and Gleaning"

View File

@ -52,6 +52,7 @@ This Reduce operation processes customer feedback grouped by department:
| Parameter | Description | Default |
| ------------------------- | ------------------------------------------------------------------------------------------------------ | --------------------------- |
| `sample` | Number of samples to use for the operation | None |
| `limit` | Maximum number of groups to process before stopping | All groups |
| `synthesize_resolve` | If false, won't synthesize a resolve operation between map and reduce | true |
| `model` | The language model to use | Falls back to default_model |
| `input` | Specifies the schema or keys to subselect from each item | All keys from input items |
@ -66,6 +67,12 @@ This Reduce operation processes customer feedback grouped by department:
| `max_retries_per_timeout` | Maximum number of retries per timeout | 2 |
| `litellm_completion_kwargs` | Additional parameters to pass to LiteLLM completion calls. | {} |
| `bypass_cache` | If true, bypass the cache for this operation. | False |
| `retriever` | Name of a retriever to use for RAG. See [Retrievers](../retrievers.md). | None |
| `save_retriever_output` | If true, saves the retrieved context to `_<operation_name>_retrieved_context` in the output. | False |
### Limiting group processing
Set `limit` to short-circuit the reduce phase after _N_ groups have been aggregated. When `limit` is set, groups are sorted by size (smallest first) and only the _N_ smallest groups are processed, which makes it useful for previewing results or capping LLM spend at minimal cost. Groups beyond the limit are never scheduled, so you avoid extra fold/merge calls. If a grouped reduce returns more than one record per group, the final output list is truncated to `limit`.
## Advanced Features

View File

@ -44,9 +44,9 @@ This Resolve operation processes patient names to identify and standardize dupli
Note: The prompt templates use Jinja2 syntax, allowing you to reference input fields directly (e.g., `input1.patient_name`).
!!! warning "Performance Consideration"
!!! info "Automatic Blocking"
You should not run this operation as-is unless your dataset is small! Running O(n^2) comparisons with an LLM can be extremely time-consuming for large datasets. Instead, optimize your pipeline first using `docetl build pipeline.yaml` and run the optimized version, which will generate efficient blocking rules for the operation. Make sure you've set `optimize: true` in your resolve operation config.
If you don't specify any blocking configuration (`blocking_threshold`, `blocking_conditions`, or `limit_comparisons`), the Resolve operation will automatically compute an optimal embedding-based blocking threshold at runtime. It samples pairs from your data, runs LLM comparisons on the sample, and finds a threshold that achieves 95% recall by default. You can adjust this with the `blocking_target_recall` parameter.
## Blocking
@ -132,7 +132,8 @@ After determining eligible pairs for comparison, the Resolve operation uses a Un
| `resolution_model` | The language model to use for reducing matched entries | Falls back to `default_model` |
| `comparison_model` | The language model to use for comparing potential matches | Falls back to `default_model` |
| `blocking_keys` | List of keys to use for initial blocking | All keys in the input data |
| `blocking_threshold` | Embedding similarity threshold for considering entries as potential matches | None |
| `blocking_threshold` | Embedding similarity threshold for considering entries as potential matches | Auto-computed if not set |
| `blocking_target_recall` | Target recall when auto-computing blocking threshold (0.0 to 1.0) | 0.95 |
| `blocking_conditions` | List of conditions for initial blocking | [] |
| `input` | Specifies the schema or keys to subselect from each item to pass into the prompts | All keys from input items |
| `embedding_batch_size` | The number of entries to send to the embedding model at a time | 1000 |
@ -140,9 +141,9 @@ After determining eligible pairs for comparison, the Resolve operation uses a Un
| `limit_comparisons` | Maximum number of comparisons to perform | None |
| `timeout` | Timeout for each LLM call in seconds | 120 |
| `max_retries_per_timeout` | Maximum number of retries per timeout | 2 |
| `sample` | Number of samples to use for the operation | None |
| `litellm_completion_kwargs` | Additional parameters to pass to LiteLLM completion calls. | {} |
| `bypass_cache` | If true, bypass the cache for this operation. | False |
| `sample` | Number of samples to use for the operation | None |
| `litellm_completion_kwargs` | Additional parameters to pass to LiteLLM completion calls. | {} |
| `bypass_cache` | If true, bypass the cache for this operation. | False |
## Best Practices

34
docs/optimization/moar.md Normal file
View File

@ -0,0 +1,34 @@
# MOAR Optimizer
The MOAR (Multi-Objective Agentic Rewrites) optimizer explores different ways to optimize your pipeline, finding solutions that balance accuracy and cost.
## What is MOAR?
When optimizing pipelines, you trade off cost and accuracy. MOAR explores many different pipeline configurations (changing models, adding validation steps, combining operations, and so on) and evaluates each one to find the best trade-offs. It returns a frontier of plans that balance cost and accuracy, giving you multiple optimized options to choose from based on your budget and accuracy requirements.
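For intuition, here is a tiny sketch of what "frontier" means in this context; the plan names and numbers are made up:
```python
plans = [
    {"name": "cheap model, no gleaning",  "cost": 0.12, "accuracy": 0.71},
    {"name": "cheap model + validation",  "cost": 0.25, "accuracy": 0.80},
    {"name": "larger model, no gleaning", "cost": 0.90, "accuracy": 0.79},  # dominated
    {"name": "larger model + validation", "cost": 1.10, "accuracy": 0.88},
]

def on_frontier(plan, all_plans):
    # A plan is dominated if some other plan is no more expensive, at least as accurate,
    # and strictly better on at least one of the two axes.
    return not any(
        other is not plan
        and other["cost"] <= plan["cost"]
        and other["accuracy"] >= plan["accuracy"]
        and (other["cost"] < plan["cost"] or other["accuracy"] > plan["accuracy"])
        for other in all_plans
    )

frontier = [p for p in plans if on_frontier(p, plans)]
# Everything except "larger model, no gleaning" survives: it costs more than
# "cheap model + validation" while being slightly less accurate.
```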
## Quick Navigation
- **[Getting Started](moar/getting-started.md)** - Step-by-step guide to run your first MOAR optimization
- **[Configuration](moar/configuration.md)** - Complete reference for all configuration options
- **[Evaluation Functions](moar/evaluation.md)** - How to write and use evaluation functions
- **[Understanding Results](moar/results.md)** - What MOAR outputs and how to interpret it
- **[Examples](moar/examples.md)** - Complete working examples
- **[Troubleshooting](moar/troubleshooting.md)** - Common issues and solutions
## When to Use MOAR
!!! success "Good for"
- Finding cost-accuracy trade-offs across different models
- When you want multiple optimization options to choose from
- Custom evaluation metrics specific to your use case
- Exploring different pipeline configurations automatically
## Basic Workflow
1. **Create your pipeline YAML** - Define your DocETL pipeline
2. **Write an evaluation function** - Create a Python function to measure accuracy
3. **Configure MOAR** - Set up `optimizer_config` in your YAML
4. **Run optimization** - Execute `docetl build pipeline.yaml --optimizer moar`
5. **Review results** - Choose from the cost-accuracy frontier
Ready to get started? Head to the [Getting Started guide](moar/getting-started.md).

View File

@ -0,0 +1,116 @@
# MOAR Configuration Reference
Complete reference for all MOAR configuration options.
## Required Fields
All fields in `optimizer_config` are required (no defaults):
| Field | Type | Description |
|-------|------|-------------|
| `type` | `str` | Must be `"moar"` |
| `save_dir` | `str` | Directory where MOAR results will be saved |
| `available_models` | `list[str]` | List of LiteLLM model names to explore (e.g., `["gpt-4o-mini", "gpt-4o"]`). Make sure your API keys are set in your environment for these models. |
| `evaluation_file` | `str` | Path to Python file containing `@register_eval` decorated function |
| `metric_key` | `str` | Key in evaluation results dictionary to use as accuracy metric |
| `max_iterations` | `int` | Maximum number of MOARSearch iterations to run |
| `rewrite_agent_model` | `str` | LLM model to use for directive instantiation during search |
!!! warning "All Fields Required"
MOAR will error if any required field is missing. There are no defaults.
## Optional Fields
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `dataset_path` | `str` | Inferred from `datasets` | Path to dataset file to use for optimization. Use a sample/hold-out dataset to avoid optimizing on your test set. |
| `exploration_weight` | `float` | `1.414` | UCB exploration constant (higher = more exploration); see the sketch after this table |
| `build_first_layer` | `bool` | `False` | Whether to build initial model-specific nodes |
| `ground_truth_path` | `str` | `None` | Path to ground truth file (for evaluation) |
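To give a feel for what `exploration_weight` controls, the snippet below shows the standard UCB1 scoring rule in which such a constant `c` appears; this is only an illustration of the knob, not MOAR's documented internals:
```python
import math

def ucb_score(value: float, visits: int, parent_visits: int, c: float = 1.414) -> float:
    # Higher c weights the exploration term more heavily, so less-visited
    # pipeline configurations get tried more often.
    if visits == 0:
        return float("inf")  # unvisited nodes are always explored first
    exploitation = value / visits
    exploration = c * math.sqrt(math.log(parent_visits) / visits)
    return exploitation + exploration
```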
## Dataset Path
### Automatic Inference
If `dataset_path` is not specified, MOAR will automatically infer it from the `datasets` section of your YAML:
```yaml
datasets:
transcripts:
path: data/full_dataset.json # This will be used if dataset_path not specified
type: file
optimizer_config:
# dataset_path not specified - will use data/full_dataset.json
# ... other config ...
```
### Using Sample/Hold-Out Datasets
!!! tip "Best Practice"
Use a sample or hold-out dataset for optimization to avoid optimizing on your test set.
```yaml
optimizer_config:
dataset_path: data/sample_dataset.json # Use sample/hold-out for optimization
# ... other config ...
datasets:
transcripts:
path: data/full_dataset.json # Full dataset for final pipeline
```
The optimizer will use the sample dataset, but your final pipeline uses the full dataset. This ensures you don't overfit to your test set during optimization.
## Model Configuration
### Available Models
!!! info "LiteLLM Model Names"
Use LiteLLM model names (e.g., `gpt-4o-mini`, `gpt-4o`, `gpt-5.1`). Make sure your API keys are set in your environment.
```yaml
available_models: # LiteLLM model names - ensure API keys are set
- gpt-5.1-nano # Cheapest, lower accuracy
- gpt-5.1-mini # Low cost, decent accuracy
- gpt-5.1 # Balanced
- gpt-4o # Higher cost, better accuracy
```
### Model for Directive Instantiation
The `rewrite_agent_model` field specifies which LLM to use for generating optimization directives during the search process. This doesn't affect the models tested in `available_models`.
!!! tip "Cost Consideration"
Use a cheaper model (like `gpt-4o-mini`) for directive instantiation to reduce search costs.
## Iteration Count
The `max_iterations` parameter controls how many pipeline configurations MOAR explores:
- **10-20 iterations**: Quick exploration, good for testing
- **40 iterations**: Recommended for most use cases
- **100+ iterations**: For complex pipelines or when you need the absolute best results
!!! note "Time vs Quality"
More iterations give better results but take longer and cost more.
## Complete Example
```yaml
optimizer_config:
type: moar
save_dir: results/moar_optimization
available_models:
- gpt-4o-mini
- gpt-4o
- gpt-5.1-mini
- gpt-5.1
evaluation_file: evaluate_medications.py
metric_key: medication_extraction_score
max_iterations: 40
rewrite_agent_model: gpt-5.1
dataset_path: data/sample.json # Optional
exploration_weight: 1.414 # Optional
```

View File

@ -0,0 +1,193 @@
# Evaluation Functions
How to write evaluation functions for MOAR optimization.
## How Evaluation Functions Work
Your evaluation function receives the pipeline output and computes metrics by comparing it to the original dataset. MOAR uses one specific metric from your returned dictionary (specified by `metric_key`) to optimize for accuracy.
!!! info "Function Signature"
Your function must have exactly this signature:
```python
def evaluate_results(dataset_file_path: str, results_file_path: str) -> Dict[str, Any]:
```
### What You Receive
- **`results_file_path`**: Path to JSON file containing your pipeline's output
- **`dataset_file_path`**: Path to JSON file containing the original dataset
### What You Return
A dictionary with numeric metrics. The key specified in `optimizer_config.metric_key` will be used as the accuracy metric for optimization.
!!! tip "Using Original Input Data"
Pipeline output includes the original input data. For example, if your dataset has a `src` attribute, it will be available in the output. You can use this directly for comparison without loading the dataset file separately.
## Basic Example
```python
import json
from typing import Any, Dict
from docetl.utils_evaluation import register_eval
@register_eval
def evaluate_results(dataset_file_path: str, results_file_path: str) -> Dict[str, Any]:
# Load pipeline output
with open(results_file_path, 'r') as f:
output = json.load(f)
total_correct = 0
for result in output:
# For example, if your dataset has a 'src' attribute, it's available in the output
original_text = result.get("src", "").lower()
# Replace "your_extraction_key" with the actual key from your pipeline output
extracted_items = result.get("your_extraction_key", [])
# Check if extracted items appear in original text
for item in extracted_items:
if str(item).lower() in original_text:
total_correct += 1
return {
"extraction_score": total_correct, # This key is used if metric_key="extraction_score"
"total_extracted": sum(len(r.get("your_extraction_key", [])) for r in output),
}
```
## Requirements
!!! warning "Critical Requirements"
- The function must be decorated with `@docetl.register_eval`
- It must take exactly two arguments: `dataset_file_path` and `results_file_path`
- It must return a dictionary with numeric metrics
- The `metric_key` in your `optimizer_config` must match one of the keys in this dictionary
- Only one function per file can be decorated with `@register_eval`
## Performance Considerations
!!! tip "Keep It Fast"
Your evaluation function will be called many times during optimization. Make sure it's efficient:
- Avoid expensive computations
- Cache results if possible
- Keep the function simple and fast
## Common Evaluation Patterns
### Pattern 1: Extraction Verification with Recall
Check if extracted items appear in the document text and compute recall:
```python
@register_eval
def evaluate_results(dataset_file_path: str, results_file_path: str) -> Dict[str, Any]:
with open(results_file_path, 'r') as f:
output = json.load(f)
# For example, if your dataset has a 'src' attribute, it's available in the output
total_correct = 0
total_extracted = 0
total_expected = 0
for result in output:
# Replace "src" with the actual key from your dataset
original_text = result.get("src", "").lower()
extracted_items = result.get("your_extraction_key", []) # Replace with your key
# Count correct extractions (items that appear in text)
for item in extracted_items:
total_extracted += 1
if str(item).lower() in original_text:
total_correct += 1
# Count expected items (if you have ground truth)
# total_expected += len(expected_items)
precision = total_correct / total_extracted if total_extracted > 0 else 0.0
recall = total_correct / total_expected if total_expected > 0 else 0.0
return {
"extraction_score": total_correct, # Use this as metric_key
"precision": precision,
"recall": recall,
}
```
### Pattern 2: Comparing Against Ground Truth
Load ground truth from the dataset file and compare:
```python
@register_eval
def evaluate_results(dataset_file_path: str, results_file_path: str) -> Dict[str, Any]:
with open(results_file_path, 'r') as f:
predictions = json.load(f)
with open(dataset_file_path, 'r') as f:
ground_truth = json.load(f)
# Compare predictions with ground truth
# Adjust keys based on your data structure
correct = 0
total = len(predictions)
for pred, truth in zip(predictions, ground_truth):
# Example: compare classification labels
if pred.get("predicted_label") == truth.get("true_label"):
correct += 1
return {
"accuracy": correct / total if total > 0 else 0.0,
"correct": correct,
"total": total,
}
```
### Pattern 3: External Evaluation (File or API)
Load additional data or call an API for evaluation:
```python
import requests
from pathlib import Path
@register_eval
def evaluate_results(dataset_file_path: str, results_file_path: str) -> Dict[str, Any]:
with open(results_file_path, 'r') as f:
output = json.load(f)
# Option A: Load ground truth from a separate file
ground_truth_path = Path(dataset_file_path).parent / "ground_truth.json"
with open(ground_truth_path, 'r') as f:
ground_truth = json.load(f)
# Option B: Call an API for evaluation
# response = requests.post("https://api.example.com/evaluate", json=output)
# api_score = response.json()["score"]
# Evaluate using ground truth
scores = []
for result, truth in zip(output, ground_truth):
# Your evaluation logic here
score = compute_score(result, truth)
scores.append(score)
return {
"average_score": sum(scores) / len(scores) if scores else 0.0,
"scores": scores,
}
```
## Testing Your Function
!!! tip "Test Before Running"
Test your evaluation function independently before running MOAR:
```python
result = evaluate_results("dataset.json", "results.json")
print(result) # Check that your metric_key is present
```
This helps catch errors early and ensures your function works correctly.

View File

@ -0,0 +1,133 @@
# MOAR Examples
Complete working examples for MOAR optimization.
## Medication Extraction Example
This example extracts medications from medical transcripts and evaluates extraction accuracy.
!!! note "Metric Key"
The `metric_key` in the `optimizer_config` section specifies which key from your evaluation function's return dictionary will be used as the accuracy metric. In this example, `metric_key: medication_extraction_score` means MOAR will optimize using the `medication_extraction_score` value returned by the evaluation function.
### pipeline.yaml
```yaml
datasets:
transcripts:
path: workloads/medical/raw.json
type: file
default_model: gpt-4o-mini
bypass_cache: true
optimizer_config:
type: moar
dataset_path: workloads/medical/raw_sample.json # Use sample for faster optimization
save_dir: workloads/medical/moar_results
available_models: # LiteLLM model names - ensure API keys are set in your environment
- gpt-5.1-nano
- gpt-5.1-mini
- gpt-5.1
- gpt-4o
- gpt-4o-mini
evaluation_file: workloads/medical/evaluate_medications.py
metric_key: medication_extraction_score
max_iterations: 40
rewrite_agent_model: gpt-5.1
system_prompt:
dataset_description: a collection of transcripts of doctor visits
persona: a medical practitioner analyzing patient symptoms and reactions to medications
operations:
- name: extract_medications
type: map
output:
schema:
medication: list[str]
prompt: |
Analyze the following transcript of a conversation between a doctor and a patient:
{{ input.src }}
Extract and list all medications mentioned in the transcript.
If no medications are mentioned, return an empty list.
pipeline:
steps:
- name: medication_extraction
input: transcripts
operations:
- extract_medications
output:
type: file
path: workloads/medical/extracted_medications_results.json
```
### evaluate_medications.py
```python
import json
from typing import Any, Dict
from docetl.utils_evaluation import register_eval
@register_eval
def evaluate_results(dataset_file_path: str, results_file_path: str) -> Dict[str, Any]:
"""
Evaluate medication extraction results.
Checks if each extracted medication appears verbatim in the original transcript.
In this example, the dataset has a 'src' attribute with the original input text.
"""
# Load pipeline output
with open(results_file_path, 'r') as f:
output = json.load(f)
total_correct_medications = 0
total_extracted_medications = 0
# Evaluate each result
for result in output:
# In this example, the dataset has a 'src' attribute with the original transcript
original_transcript = result.get("src", "").lower()
extracted_medications = result.get("medication", [])
# Check each extracted medication
for medication in extracted_medications:
total_extracted_medications += 1
medication_lower = str(medication).lower().strip()
# Check if medication appears in transcript
if medication_lower in original_transcript:
total_correct_medications += 1
# Calculate metrics
precision = total_correct_medications / total_extracted_medications if total_extracted_medications > 0 else 0.0
return {
"medication_extraction_score": total_correct_medications, # This is used as the accuracy metric
"total_correct_medications": total_correct_medications,
"total_extracted_medications": total_extracted_medications,
"precision": precision,
}
```
### Running the Optimization
```bash
docetl build workloads/medical/pipeline_medication_extraction.yaml --optimizer moar
```
!!! tip "Using Sample Datasets"
Notice that `dataset_path` points to `raw_sample.json` for optimization, while the main pipeline uses `raw.json`. This prevents optimizing on your test set.
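If you don't have a sample file yet, one way to create it is to randomly sample the full dataset. This is only a sketch; it assumes `raw.json` is a JSON list of records and writes `raw_sample.json` alongside it:

```python
import json
import random

# Draw a reproducible random sample of the full dataset for optimization
with open("workloads/medical/raw.json") as f:
    data = json.load(f)

random.seed(0)
sample = random.sample(data, k=min(50, len(data)))  # sample size is an arbitrary choice

with open("workloads/medical/raw_sample.json", "w") as f:
    json.dump(sample, f, indent=2)
```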
## Key Points
!!! info "Evaluation Function"
    - In this example, the evaluation function uses the `src` attribute from output items (no need to load the dataset separately)
- Checks if extracted medications appear verbatim in the transcript
- Returns multiple metrics, with `medication_extraction_score` as the primary one
!!! info "Configuration"
- Uses a sample dataset for optimization (`dataset_path`)
- Includes multiple models in `available_models` to explore trade-offs
- Sets `max_iterations` to 40 for a good balance of exploration and time


@ -0,0 +1,157 @@
# Getting Started with MOAR
This guide walks you through running your first MOAR optimization step by step.
## Step 1: Create Your Pipeline YAML
Start with a standard DocETL pipeline YAML file:
```yaml
datasets:
transcripts:
path: data/transcripts.json
type: file
default_model: gpt-4o-mini
operations:
- name: extract_medications
type: map
output:
schema:
medication: list[str]
prompt: |
Extract all medications mentioned in: {{ input.src }}
pipeline:
steps:
- name: medication_extraction
input: transcripts
operations:
- extract_medications
output:
type: file
path: results.json
```
!!! note "Standard Pipeline"
Your pipeline doesn't need any special configuration for MOAR. Just create a normal DocETL pipeline.
## Step 2: Create an Evaluation Function
Create a Python file with an evaluation function. This function will be called for each pipeline configuration that MOAR explores.
!!! info "How Evaluation Works"
- Your function receives the pipeline output and the original dataset
- You compute evaluation metrics by comparing the output to the dataset
- You return a dictionary of metrics
- MOAR uses one specific key from this dictionary (specified by `metric_key`) as the accuracy metric to optimize
```python
# evaluate_medications.py
import json
from typing import Any, Dict
from docetl.utils_evaluation import register_eval
@register_eval
def evaluate_results(dataset_file_path: str, results_file_path: str) -> Dict[str, Any]:
"""
Evaluate pipeline output against the original dataset.
"""
# Load pipeline output
with open(results_file_path, 'r') as f:
output = json.load(f)
# Load original dataset for comparison
with open(dataset_file_path, 'r') as f:
dataset = json.load(f)
# Compute your evaluation metrics
    correct_count = 0
    total_extracted = 0
    for result in output:
        # Compare each result with the original data
        # For example, if your dataset has a 'src' attribute, it's available in the output
        original_text = result.get("src", "").lower()
        extracted_items = result.get("medication", [])
        # Check whether each extracted item appears in the original text
        for item in extracted_items:
            total_extracted += 1
            if str(item).lower() in original_text:
                correct_count += 1
    # Return dictionary of metrics
    return {
        "medication_extraction_score": correct_count,  # This key will be used if metric_key matches
        "total_extracted": total_extracted,
        "precision": correct_count / total_extracted if total_extracted > 0 else 0.0,
    }
```
!!! warning "Important Requirements"
    - The function must be decorated with `@register_eval` (imported from `docetl.utils_evaluation`)
- It must take exactly two arguments: `dataset_file_path` and `results_file_path`
- It must return a dictionary with numeric metrics
- The `metric_key` in your `optimizer_config` must match one of the keys in this dictionary
- Only one function per file can be decorated with `@register_eval`
For more details on evaluation functions, see the [Evaluation Functions guide](evaluation.md).
## Step 3: Configure the Optimizer
Add an `optimizer_config` section to your YAML. The `metric_key` specifies which key from your evaluation function's return dictionary will be used as the accuracy metric for optimization:
```yaml
optimizer_config:
type: moar
save_dir: results/moar_optimization
available_models: # LiteLLM model names - ensure API keys are set in your environment
- gpt-4o-mini
- gpt-4o
- gpt-5.1-mini
- gpt-5.1
evaluation_file: evaluate_medications.py
metric_key: medication_extraction_score # This must match a key in your evaluation function's return dictionary
max_iterations: 40
rewrite_agent_model: gpt-5.1
dataset_path: data/transcripts_sample.json # Optional: use sample/hold-out dataset
```
!!! tip "Using Sample Datasets"
Use `dataset_path` to specify a sample or hold-out dataset for optimization. This prevents optimizing on your test set. The main pipeline will still use the full dataset from the `datasets` section.
For complete configuration details, see the [Configuration Reference](configuration.md).
## Step 4: Run the Optimizer
Run MOAR optimization using the CLI:
```bash
docetl build pipeline.yaml --optimizer moar
```
!!! success "What Happens Next"
MOAR will:
1. Explore different pipeline configurations
2. Evaluate each configuration using your evaluation function
3. Build a cost-accuracy frontier of optimal solutions
4. Save results to your `save_dir`
## Step 5: Review Results
After optimization completes, check your `save_dir` for:
- **`experiment_summary.json`** - High-level summary of the run
- **`pareto_frontier.json`** - List of optimal solutions
- **`evaluation_metrics.json`** - Detailed evaluation results
- **`pipeline_*.yaml`** - Optimized pipeline configurations
For details on interpreting results, see [Understanding Results](results.md).
## Next Steps
- Learn about [configuration options](configuration.md)
- See [complete examples](examples.md)
- Read [troubleshooting tips](troubleshooting.md)


@ -0,0 +1,93 @@
# Understanding MOAR Results
What MOAR outputs and how to interpret the results.
## Output Files
After running MOAR optimization, you'll find several files in your `save_dir`:
- **`experiment_summary.json`** - High-level summary
- **`pareto_frontier.json`** - Optimal solutions
- **`evaluation_metrics.json`** - Detailed evaluation results
- **`pipeline_*.yaml`** - Optimized pipeline configurations
## experiment_summary.json
High-level summary of the optimization run:
```json
{
"optimizer": "moar",
"input_pipeline": "pipeline.yaml",
"rewrite_agent_model": "gpt-5.1",
"max_iterations": 40,
"save_dir": "results/moar_optimization",
"dataset": "transcripts",
"start_time": "2024-01-15T10:30:00",
"end_time": "2024-01-15T11:15:00",
"duration_seconds": 2700,
"num_best_nodes": 5,
"total_nodes_explored": 120,
"total_search_cost": 15.50
}
```
!!! info "Key Metrics"
- `num_best_nodes`: Number of solutions on the Pareto frontier
- `total_nodes_explored`: Total configurations tested
- `total_search_cost`: Total cost of the optimization search
## pareto_frontier.json
List of Pareto-optimal solutions (the cost-accuracy frontier):
```json
[
{
"node_id": 5,
"yaml_path": "results/moar_optimization/pipeline_5.yaml",
"cost": 0.05,
"accuracy": 0.92
},
{
"node_id": 12,
"yaml_path": "results/moar_optimization/pipeline_12.yaml",
"cost": 0.08,
"accuracy": 0.95
}
]
```
!!! tip "Choosing a Solution"
Review the Pareto frontier to find solutions that match your priorities:
- **Low cost priority**: Choose solutions with lower cost
- **High accuracy priority**: Choose solutions with higher accuracy
- **Balanced**: Choose solutions in the middle
Each solution includes a `yaml_path` pointing to the optimized pipeline configuration.
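For example, to pick the most accurate solution that fits within a cost budget, you can filter `pareto_frontier.json` directly. This is a minimal sketch; the budget value and `save_dir` path are just illustrations:

```python
import json

# Load the Pareto frontier produced by MOAR
with open("results/moar_optimization/pareto_frontier.json") as f:
    frontier = json.load(f)

# Keep solutions within an illustrative per-run cost budget, then take the most accurate.
# If nothing fits the budget, fall back to the full frontier.
budget = 0.10
affordable = [s for s in frontier if s["cost"] <= budget]
best = max(affordable or frontier, key=lambda s: s["accuracy"])

print(f"Chosen pipeline: {best['yaml_path']} (cost={best['cost']}, accuracy={best['accuracy']})")
```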
## evaluation_metrics.json
Detailed evaluation results for all explored configurations. This file contains comprehensive metrics for every pipeline configuration tested during optimization.
## Pipeline Configurations
Each solution on the Pareto frontier has a corresponding YAML file (e.g., `pipeline_5.yaml`) containing the optimized pipeline configuration. You can:
1. Review the changes MOAR made
2. Test the pipeline on your full dataset
3. Use it in production
## Next Steps
After reviewing the results:
1. **Review the Pareto frontier** - See available options
2. **Choose a solution** - Based on your accuracy/cost priorities
3. **Test the chosen pipeline** - Run it on your full dataset
4. **Integrate into production** - Use the optimized configuration
!!! success "Success"
You now have multiple optimized pipeline options to choose from, each representing a different point on the cost-accuracy trade-off curve.


@ -0,0 +1,113 @@
# Troubleshooting MOAR
Common issues and solutions when using MOAR optimization.
## Error: Missing required accuracy metric
!!! error "Error Message"
`KeyError: Missing required accuracy metric 'your_metric_key'`
**Solution:**
Check that:
1. Your evaluation function returns a dictionary with the `metric_key` you specified
2. The `metric_key` in `optimizer_config` matches the key in your evaluation results
3. Your evaluation function is working correctly (test it independently)
```python
# Test your function
result = evaluate_results("dataset.json", "results.json")
print(result) # Verify your metric_key is present
```
## Error: Evaluation function takes wrong number of arguments
!!! error "Error Message"
`TypeError: evaluate_results() takes 1 positional argument but 2 were given`
**Solution:**
Make sure your evaluation function has exactly this signature:
```python
def evaluate_results(dataset_file_path: str, results_file_path: str) -> Dict[str, Any]:
```
And that it's decorated with `@register_eval` (imported from `docetl.utils_evaluation`).
## All accuracies showing as 0.0
!!! warning "Symptom"
All solutions show 0.0 accuracy in the Pareto frontier.
**Possible causes:**
1. **Evaluation function failing silently** - Check the error logs
2. **Result files don't exist** - Make sure pipelines are executing successfully
3. **Metric key doesn't match** - Verify `metric_key` matches what your function returns
**Solution:**
Test your evaluation function independently and check MOAR logs for errors.
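As a quick sanity check, you can also assert that the key named by `metric_key` is present in what your function returns. The paths and metric name below are placeholders for your own setup:

```python
# Run the evaluation function directly on a known dataset/results pair
result = evaluate_results("data/transcripts_sample.json", "results.json")

# Fail loudly if the configured metric_key is missing or non-numeric
metric_key = "medication_extraction_score"  # must match optimizer_config.metric_key
assert metric_key in result, f"'{metric_key}' not in returned keys: {list(result.keys())}"
assert isinstance(result[metric_key], (int, float)), "metric must be numeric"
print(result[metric_key])
```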
## Optimization taking too long
!!! tip "Speed Up Optimization"
If optimization is taking too long, try:
- Reduce `max_iterations` (e.g., from 40 to 20)
- Use a smaller sample dataset via `dataset_path`
- Reduce the number of models in `available_models`
    - Use a faster model for the rewrite agent (`rewrite_agent_model`)
## Best Practices
### Using Sample/Hold-Out Datasets
!!! tip "Avoid Overfitting"
Always use a sample or hold-out dataset for optimization to avoid optimizing on your test set:
```yaml
optimizer_config:
dataset_path: data/sample_100.json # Use sample/hold-out for optimization
```
### Choosing Models
!!! tip "Model Selection"
Include a range of models in `available_models` to explore cost-accuracy trade-offs:
```yaml
available_models:
- gpt-5.1-nano # Cheapest, lower accuracy
- gpt-5.1-mini # Low cost, decent accuracy
- gpt-5.1 # Balanced
- gpt-4o # Higher cost, better accuracy
```
### Iteration Count
!!! tip "Iteration Guidelines"
- **10-20 iterations**: Quick exploration, good for testing
- **40 iterations**: Recommended for most use cases
- **100+ iterations**: For complex pipelines or when you need the absolute best results
### Evaluation Function Performance
!!! tip "Keep Functions Fast"
Your evaluation function will be called many times. Make sure it's efficient:
- Avoid expensive computations
    - Cache results and static resources (like ground truth files) if possible; see the sketch after this list
- Keep the function simple and fast
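One common way to follow the caching advice above is to load ground truth or other static resources once and reuse them across calls, since MOAR invokes your function for every configuration it tests. A minimal sketch, assuming a hypothetical `data/ground_truth.json` aligned with your results:

```python
import json
from functools import lru_cache
from typing import Any, Dict

from docetl.utils_evaluation import register_eval


@lru_cache(maxsize=1)
def _load_ground_truth(path: str) -> list:
    # Parsed once and cached; repeated evaluation calls reuse the result
    with open(path, "r") as f:
        return json.load(f)


@register_eval
def evaluate_results(dataset_file_path: str, results_file_path: str) -> Dict[str, Any]:
    ground_truth = _load_ground_truth("data/ground_truth.json")  # hypothetical path
    with open(results_file_path, "r") as f:
        output = json.load(f)
    # Assumes results are aligned (same order) as the ground-truth records
    correct = sum(
        1
        for pred, truth in zip(output, ground_truth)
        if pred.get("predicted_label") == truth.get("true_label")
    )
    return {"accuracy": correct / len(output) if output else 0.0}
```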
## Getting Help
If you're still experiencing issues:
1. Check the MOAR logs for detailed error messages
2. Verify your evaluation function works independently
3. Test with a smaller `max_iterations` to isolate issues
4. Review the [Configuration Reference](configuration.md) to ensure all required fields are set


@ -1,6 +1,29 @@
# DocETL Optimizer
The DocETL optimizer finds a plan that improves the accuracy of your document processing pipelines. It works by analyzing and potentially rewriting operations marked for optimization, finding optimal plans for execution.
DocETL provides two optimizer options to improve your document processing pipelines:
## MOAR Optimizer (Recommended)
The **MOAR (Multi-Objective Agentic Rewrites)** optimizer uses Monte Carlo Tree Search to explore optimization space and find Pareto-optimal solutions that balance accuracy and cost. It's the recommended optimizer for most use cases.
**Key Features:**
- Multi-objective optimization (accuracy + cost)
- Returns multiple Pareto-optimal solutions
- Automatic model exploration
- Custom evaluation functions
- Intelligent search using MCTS
See the [MOAR Optimizer Guide](moar.md) for detailed documentation and examples.
## V1 Optimizer (Deprecated)
!!! warning "Deprecated"
The V1 optimizer is deprecated and no longer recommended. Use MOAR instead for all new optimizations.
The V1 optimizer uses a greedy approach with validation to find improved pipeline configurations. It's still available for backward compatibility but should not be used for new projects.
The rest of this page describes the general optimization concepts that apply to both optimizers.
## Key Features
@ -81,3 +104,17 @@ After applying the optimizer, your pipeline could be transformed into a more eff
3. **Reduce Operation**: For each contract, combine the extracted and tagged clauses from each chunk.
The goal of the DocETL optimizer is to try many ways of rewriting your pipeline and then select the best one. This may take some time (20-30 minutes for very complex tasks and large documents). But the optimizer's ability to break down complex tasks into more manageable sub-steps can lead to more accurate and reliable results.
## Choosing an Optimizer
**Use MOAR if:**
- You want to explore cost-accuracy trade-offs
- You need multiple solution options (Pareto frontier)
- You have custom evaluation metrics
- You want automatic model exploration
!!! warning "V1 Optimizer Deprecated"
The V1 optimizer is deprecated. Use MOAR instead. If you have existing V1-optimized pipelines, they will continue to work, but new optimizations should use MOAR.
For detailed MOAR usage, see the [MOAR Optimizer Guide](moar.md).


@ -0,0 +1,45 @@
# Quick Start with Claude Code
The fastest way to build DocETL pipelines is with [Claude Code](https://claude.ai/download), Anthropic's agentic coding tool. DocETL includes a built-in Claude Code skill that helps you create, run, and debug pipelines interactively.
## Option 1: Clone the Repository (Recommended)
This gives you the full development environment with the skill already configured.
1. Follow the [Installation from Source](installation.md#installation-from-source) instructions
2. Run `claude` in the repository directory
The skill is located at `.claude/skills/docetl/SKILL.md`.
## Option 2: Install via pip
If you already have DocETL installed via pip, you can install the skill separately:
```bash
pip install docetl
docetl install-skill
```
This copies the skill to `~/.claude/skills/docetl/`. Then run `claude` in any directory.
To uninstall: `docetl install-skill --uninstall`
## Usage
Simply describe what you want to do with your data. The skill activates automatically when you mention "docetl" or describe unstructured data processing tasks:
```
> I have a folder of customer support tickets in JSON format.
> I want to extract the main issue, sentiment, and suggested resolution for each.
```
Claude will:
1. **Read your data** to understand its structure
2. **Write a tailored pipeline** with prompts specific to your documents
3. **Run the pipeline** and show you the results
4. **Debug issues** if any operations fail
## Alternative: Manual Pipeline Authoring
If you prefer not to use Claude Code, see the [Quick Start Tutorial](tutorial.md) for writing pipelines by hand.

docs/retrievers.md Normal file

@ -0,0 +1,396 @@
## Retrievers (LanceDB OSS)
Retrievers let you augment LLM operations with retrieved context from a LanceDB index built over a DocETL dataset. You define retrievers once at the top-level, then attach them to any LLM-powered operation using `retriever: <name>`. At runtime, DocETL performs full-text, vector, or hybrid search and injects the results into your prompt as `{{ retrieval_context }}`.
LanceDB supports built-in full-text search, vector search, and hybrid search with RRF (reciprocal rank fusion) reranking. See the official docs: [LanceDB Hybrid Search docs](https://lancedb.com/docs/search/hybrid-search/).
### Key points
- Always uses open-source (OSS) LanceDB with a local `index_dir`.
- A retriever references a dataset from the pipeline config, or the output of a previous pipeline step.
- Operations do not override retriever settings; a single source of truth keeps configuration consistent.
- `{{ retrieval_context }}` is available in your prompt; if you don't reference it, DocETL automatically prepends a short "extra context" section.
## Configuration
Add a top-level `retrievers` section. Each retriever has:
- `dataset`: dataset name to index (can be a dataset or output of a previous pipeline step)
- `index_dir`: LanceDB path
- `index_types`: which indexes to build: `fts`, `embedding`, or both (`["fts", "embedding"]`)
- `fts.index_phrase`: Jinja template for indexing each row for full-text search
- `fts.query_phrase`: Jinja template for building the FTS query at runtime
- `embedding.model`: embedding model for vector index and queries
- `embedding.index_phrase`: Jinja template for indexing each row for embeddings
- `embedding.query_phrase`: Jinja template for building the embedding query
- `query.mode`: `fts` | `embedding` | `hybrid` (defaults to `hybrid` when both indexes exist)
- `query.top_k`: number of results to retrieve
### Basic example
```yaml
datasets:
transcripts:
type: file
path: workloads/medical/raw.json
default_model: gpt-4o-mini
retrievers:
medical_r:
type: lancedb
dataset: transcripts
index_dir: workloads/medical/lance_index
build_index: if_missing # if_missing | always | never
index_types: ["fts", "embedding"]
fts:
index_phrase: "{{ input.src }}"
query_phrase: "{{ input.src[:1000] }}"
embedding:
model: openai/text-embedding-3-small
index_phrase: "{{ input.src }}"
query_phrase: "{{ input.src[:1000] }}"
query:
mode: hybrid
top_k: 8
```
## Multi-step pipelines with retrieval
Most pipelines have a single step, but you can define multiple steps where **the output of one step becomes the input (and retriever source) for the next**. This is powerful for patterns like:
1. Extract structured data from documents
2. Build a retrieval index on that extracted data
3. Use retrieval to find related items and process them
### Example: Extract facts, then find conflicts
```yaml
datasets:
articles:
type: file
path: workloads/wiki/articles.json
default_model: gpt-4o-mini
# Retriever indexes output of step 1 (extract_facts_step)
retrievers:
facts_index:
type: lancedb
dataset: extract_facts_step # References output of a pipeline step!
index_dir: workloads/wiki/facts_lance_index
build_index: if_missing
index_types: ["fts", "embedding"]
fts:
index_phrase: "{{ input.fact }} from {{ input.title }}"
query_phrase: "{{ input.fact }}"
embedding:
model: openai/text-embedding-3-small
index_phrase: "{{ input.fact }}"
query_phrase: "{{ input.fact }}"
query:
mode: hybrid
top_k: 5
operations:
- name: extract_facts
type: map
prompt: |
Extract factual claims from this article.
Article: {{ input.title }}
Text: {{ input.text }}
output:
schema:
facts: list[string]
- name: unnest_facts
type: unnest
unnest_key: facts
- name: find_conflicts
type: map
retriever: facts_index # Uses the retriever
prompt: |
Check if this fact conflicts with similar facts from other articles.
Current fact: {{ input.facts }} (from {{ input.title }})
Similar facts from other articles:
{{ retrieval_context }}
Return true only if there's a genuine contradiction.
output:
schema:
has_conflict: boolean
pipeline:
steps:
# Step 1: Extract and unnest facts
- name: extract_facts_step
input: articles
operations:
- extract_facts
- unnest_facts
# Step 2: Use retrieval to find conflicts
- name: find_conflicts_step
input: extract_facts_step # Input is output of step 1
operations:
- find_conflicts
output:
type: file
path: workloads/wiki/conflicts.json
intermediate_dir: workloads/wiki/intermediates
```
In this example:
- **Step 1** (`extract_facts_step`) extracts facts from articles
- The **retriever** (`facts_index`) indexes the output of step 1
- **Step 2** (`find_conflicts_step`) processes each fact, using retrieval to find similar facts from other articles
## Configuration reference
### Minimal example
Here's the simplest possible retriever config (FTS only):
```yaml
retrievers:
my_search: # name can be anything you want
type: lancedb
dataset: my_dataset # must match a dataset name or pipeline step
index_dir: ./my_lance_index
index_types: ["fts"]
fts:
index_phrase: "{{ input.text }}" # what to index from each row
query_phrase: "{{ input.query }}" # what to search for at runtime
```
### Full example with all options
```yaml
retrievers:
my_search:
type: lancedb
dataset: my_dataset
index_dir: ./my_lance_index
build_index: if_missing # optional, default: if_missing
index_types: ["fts", "embedding"] # can be ["fts"], ["embedding"], or both
fts:
index_phrase: "{{ input.text }}"
query_phrase: "{{ input.query }}"
embedding:
model: openai/text-embedding-3-small
index_phrase: "{{ input.text }}" # optional, falls back to fts.index_phrase
query_phrase: "{{ input.query }}"
query: # optional section
mode: hybrid # optional, auto-selects based on index_types
top_k: 10 # optional, default: 5
```
---
### Required fields
| Field | Description |
| --- | --- |
| `type` | Must be `lancedb` |
| `dataset` | Name of a dataset or pipeline step to index |
| `index_dir` | Path where LanceDB stores the index (created if missing) |
| `index_types` | List of index types: `["fts"]`, `["embedding"]`, or `["fts", "embedding"]` |
---
### Optional fields
| Field | Default | Description |
| --- | --- | --- |
| `build_index` | `if_missing` | When to build: `if_missing`, `always`, or `never` |
| `query.mode` | auto | `fts`, `embedding`, or `hybrid`. Auto-selects based on what indexes exist |
| `query.top_k` | 5 | Number of results to return |
---
### The `fts` section
Required if `"fts"` is in `index_types`. Configures full-text search.
| Field | Required | Description |
| --- | --- | --- |
| `index_phrase` | yes | Jinja template: what text to index from each dataset row |
| `query_phrase` | yes | Jinja template: what text to search for at query time |
**Jinja variables available:**
| Template | Variables | When it runs |
| --- | --- | --- |
| `index_phrase` | `input` = the dataset row | Once per row when building the index |
| `query_phrase` | `input` = current item (map/filter/extract) | At query time for each item processed |
| `query_phrase` | `reduce_key`, `inputs` (reduce operations) | At query time for each group |
**Example - Medical knowledge base:**
```yaml
datasets:
drugs:
type: file
path: drugs.json # [{"name": "Aspirin", "uses": "pain, fever"}, ...]
patient_notes:
type: file
path: notes.json # [{"symptoms": "headache and fever"}, ...]
retrievers:
drug_lookup:
type: lancedb
dataset: drugs # index the drugs dataset
index_dir: ./drug_index
index_types: ["fts"]
fts:
index_phrase: "{{ input.name }}: {{ input.uses }}" # index: "Aspirin: pain, fever"
query_phrase: "{{ input.symptoms }}" # search with patient symptoms
operations:
- name: find_treatment
type: map
retriever: drug_lookup # attach the retriever
prompt: |
Patient symptoms: {{ input.symptoms }}
Relevant drugs from knowledge base:
{{ retrieval_context }}
Recommend a treatment.
output:
schema:
recommendation: string
```
When processing `{"symptoms": "headache and fever"}`:
1. `query_phrase` renders to `"headache and fever"` (see the rendering sketch after this list)
2. FTS searches the index and finds `"Aspirin: pain, fever"` as a match
3. `{{ retrieval_context }}` in your prompt contains the matched results
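For step 1, here is a minimal sketch of how these Jinja templates resolve, using the `jinja2` package directly (DocETL renders them internally; this snippet is only for illustration):

```python
from jinja2 import Template

# The same templates as in the retriever config above
index_phrase = Template("{{ input.name }}: {{ input.uses }}")
query_phrase = Template("{{ input.symptoms }}")

drug_row = {"name": "Aspirin", "uses": "pain, fever"}  # a row of the indexed dataset
patient_note = {"symptoms": "headache and fever"}      # the item being processed

print(index_phrase.render(input=drug_row))      # "Aspirin: pain, fever" (indexed once per row)
print(query_phrase.render(input=patient_note))  # "headache and fever" (rendered at query time)
```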
---
### The `embedding` section
Required if `"embedding"` is in `index_types`. Configures vector/semantic search.
| Field | Required | Description |
| --- | --- | --- |
| `model` | yes | Embedding model, e.g. `openai/text-embedding-3-small` |
| `index_phrase` | no | Jinja template for text to embed. Falls back to `fts.index_phrase` |
| `query_phrase` | yes | Jinja template for query text to embed |
**Jinja variables:** Same as FTS section.
**Example - Semantic search:**
```yaml
retrievers:
semantic_docs:
type: lancedb
dataset: documentation
index_dir: ./docs_index
index_types: ["embedding"]
embedding:
model: openai/text-embedding-3-small
index_phrase: "{{ input.content }}"
query_phrase: "{{ input.question }}"
```
---
### The `query` section (optional)
Controls search behavior. You can omit this entire section.
| Field | Default | Description |
| --- | --- | --- |
| `mode` | auto | `fts`, `embedding`, or `hybrid`. Auto-selects `hybrid` if both indexes exist |
| `top_k` | 5 | Number of results to retrieve |
**Example - Override defaults:**
```yaml
retrievers:
my_search:
# ... other config ...
query:
mode: fts # force FTS even if embedding index exists
top_k: 20 # return more results
```
---
## Using a retriever in operations
Attach a retriever to any LLM operation (map, filter, reduce, extract) with `retriever: <retriever_name>`. The retrieved results are available as `{{ retrieval_context }}` in your prompt.
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| retriever | string | - | Name of the retriever to use (must match a key in `retrievers`). |
| save_retriever_output | bool | false | If true, saves retrieved context to `_<operation_name>_retrieved_context` in output. |
### Map example
```yaml
- name: tag_visit
type: map
retriever: medical_r
save_retriever_output: true
output:
schema:
tag: string
prompt: |
Classify this medical visit. Related context:
{{ retrieval_context }}
Transcript: {{ input.src }}
```
### Filter example
```yaml
- name: filter_relevant
type: filter
retriever: medical_r
prompt: |
Is this transcript relevant to medication counseling?
Context: {{ retrieval_context }}
Transcript: {{ input.src }}
output:
schema:
is_relevant: boolean
```
### Reduce example
When using reduce, the retrieval context is computed per group.
```yaml
- name: summarize_by_medication
type: reduce
retriever: medical_r
reduce_key: medication
output:
schema:
summary: string
prompt: |
Summarize key points for medication '{{ reduce_key.medication }}'.
Related context: {{ retrieval_context }}
Inputs:
{% for item in inputs %}
- {{ item.src }}
{% endfor %}
```
## Troubleshooting
- **No results**: the retriever injects "No extra context available." and continues.
- **Index issues**: set `build_index: always` to rebuild; ensure `index_dir` exists and is writable.
- **Token limits**: `retrieval_context` is truncated to ~1000 chars per retrieved doc.


@ -246,4 +246,43 @@ print(f"Optimized pipeline execution completed. Total cost: ${cost:.2f}")
# Continue processing with other pandas operations
```
Learn more about the pandas integration in the [pandas documentation](pandas/index.md).
## Using a Code Operation to Normalize Medication Names (Python API)
You can insert a code operation in the Python API pipeline to perform per-document transformations without calling an LLM. Code ops accept a Python function (it does not need to be named `transform`).
```python
from docetl.api import CodeMapOp, PipelineStep
# Normalize medication names (lowercase/trim)
def normalize_medication(doc: dict) -> dict:
med = doc.get("medication", "")
return {"medication_norm": med.lower().strip() if isinstance(med, str) else med}
# Add this op after unnesting and before resolve
operations = [
# ... existing extract_medications (MapOp), unnest_medications (UnnestOp)
CodeMapOp(
name="normalize_medication",
type="code_map",
code=normalize_medication,
),
# Optionally, update Resolve/Reduce to use "medication_norm" instead of "medication"
]
# And include it in the step order, e.g.
step = PipelineStep(
name="medical_info_extraction",
input="transcripts",
operations=[
"extract_medications",
"unnest_medications",
"normalize_medication", # new code op here
"resolve_medications",
"summarize_prescriptions",
],
)
```
This keeps your deterministic preprocessing in Python while still leveraging DocETL for the LLM-powered stages.
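For a quick sanity check, you can call the code op's function directly on a sample document before wiring it into the pipeline:

```python
print(normalize_medication({"medication": "  Aspirin "}))
# {'medication_norm': 'aspirin'}
```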


@ -36,7 +36,7 @@ DocETL uses [LiteLLM](https://github.com/BerriAI/litellm) under the hood, which
## Preparing the Data
Organize your medical transcript data in a JSON file as a list of objects. Each object should have a "src" key containing the transcript text. You can download the example dataset [here](assets/medical_transcripts.json).
!!! example "Sample Data Structure"


@ -0,0 +1,246 @@
# Running MOARSearch and Simple Agent Experiments
This guide explains how to run the MOAR optimizer (`run_moar.py`) and the simple agent baseline optimizer (`run_simple_agent.py`).
## Table of Contents
- [Prerequisites](#prerequisites)
- [MOAR](#moarsearch-run_moarpy)
- [Simple Agent](#simple-agent-run_simple_agentpy)
- [Supported Datasets](#supported-datasets)
- [Custom Datasets with User-Authored Accuracy Functions](#custom-datasets-with-user-authored-accuracy-functions)
- [Available Models](#available-models)
- [Output Files](#output-files)
## Prerequisites
### Dataset and Pipeline Files
Download the 6 workload experiment datasets and initial pipelines from the [Google Drive Link](https://drive.google.com/drive/folders/1pNFqYCguAZL3iGYHd-jDoialrtf73fbR?usp=drive_link). Extract the `data/` and `pipelines/` folders into the `experiments/reasoning/` directory.
### Python Dependencies
Install the required Python packages:
```bash
pip install -r experiments/reasoning/requirements.txt
```
### Environment Variables
#### Required Environment Variables
Set up your Azure API credentials (required for LLM calls):
```bash
export AZURE_API_KEY="your_key_here"
export AZURE_API_BASE="your_endpoint_here"
export AZURE_API_VERSION="your_version_here"
```
#### Optional Environment Variables
- **`EXPERIMENT_OUTPUT_DIR`**: Directory to save experiment outputs. If not set, defaults are used:
- `run_moar.py`: `./outputs`
- `run_simple_agent.py`: `./outputs/simple_agent`
- **`EXPERIMENT_DATA_DIR`**: Directory containing input data files. Defaults to `./data/` if not set. Can also be set via `--data_dir` parameter in `run_moar.py`.
## MOARSearch (`run_moar.py`)
MOARSearch is a multi-objective optimization algorithm that uses graph search to find Pareto-optimal pipelines balancing cost and accuracy.
### Local Execution
#### Basic Usage
```bash
python experiments/reasoning/run_moar.py \
--yaml_path experiments/reasoning/pipelines/cuad.yaml \
--dataset_path experiments/reasoning/data/train/cuad.json \
--experiment_name cuad_moar \
--dataset cuad
```
### Modal Execution (Recommended)
Modal execution is recommended for longer-running experiments as it provides better resource management and persistence.
```bash
modal run experiments/reasoning/run_moar.py \
--yaml-path=experiments/reasoning/pipelines/medec.yaml \
--dataset-path=experiments/reasoning/data/train/medec.json \
--experiment-name=medec_moar \
--dataset=medec \
--max-iterations=30
```
Outputs are written to `/mnt/docetl-ro-experiments/outputs/{experiment_name}` in the shared Modal volume.
### Parameters
| Parameter | Required | Default | Description |
|-----------|----------|---------|-------------|
| `--yaml_path` | ✅ Yes | - | Path to the user-authored input YAML pipeline file |
| `--dataset_path` | ✅ Yes | - | Path to the dataset file for sample input data |
| `--experiment_name` | ✅ Yes | - | Unique experiment identifier |
| `--dataset` | Conditional | `cuad` | Dataset name for evaluation. **Required** if `--accuracy_function` is not provided. Must be one of: cuad, blackvault, medec, biodex, sustainability, game_reviews. If using `--accuracy_function`, can be any string (used for naming). |
| `--max_iterations` | No | `100` | Maximum MOARSearch iterations |
| `--exploration_weight` | No | `1.414` | UCB exploration parameter c (controls exploration vs exploitation) |
| `--model` | No | `gpt-5` | LLM model to use for directive instantiation |
| `--available_models` | No | All 11 default models | Space-separated list of models for first layer testing. Example: `gpt-5 gpt-5-mini gpt-4o` |
| `--data_dir` | No | `EXPERIMENT_DATA_DIR` env var or `./data/` | Directory containing input data files |
| `--output_dir` | No | `EXPERIMENT_OUTPUT_DIR` env var or `./outputs` | Directory to save experiment outputs |
| `--ground_truth` | No | - | Path to ground-truth file |
| `--accuracy_function` | No | - | Path to Python file containing custom `evaluate_results` function (for user datasets) |
| `--accuracy_metric_key` | No | - | Key to extract from evaluation results dict for accuracy metric (required with `--accuracy_function`) |
## Simple Agent (`run_simple_agent.py`)
The Simple Agent is a baseline optimizer that uses basic tool calling to generate and test pipeline configurations iteratively.
### Local Execution
#### Basic Usage
```bash
python experiments/reasoning/run_simple_agent.py \
--dataset medec \
--model gpt-4o-mini
```
### Modal Execution (Recommended)
```bash
modal run experiments/reasoning/run_simple_agent.py \
--dataset=medec \
--model=gpt-4o-mini \
--experiment-name=medec_simple
```
### Parameters
| Parameter | Required | Default | Description |
|-----------|----------|---------|-------------|
| `--dataset` | Conditional | - | Dataset name. **Required** if `--accuracy_function` is not provided. Must be one of: cuad, blackvault, medec, biodex, sustainability, game_reviews. If using `--accuracy_function`, can be any string (used for naming). |
| `--model` | No | `gpt-5` | LLM model to use for optimization |
| `--experiment_name` | No | `simple_agent_{dataset}` | Unique experiment identifier |
| `--output_dir` | No | `outputs/simple_agent` | Output directory |
| `--ground_truth` | No | - | Path to ground truth file for evaluation |
| `--available_models` | No | All 11 default models | Space-separated list of available models. Example: `gpt-5 gpt-5-mini gpt-4o` |
| `--accuracy_function` | No | - | Path to Python file containing custom `evaluate_results` function (for user datasets) |
| `--accuracy_metric_key` | No | - | Key to extract from evaluation results dict for accuracy metric (required with `--accuracy_function`) |
## Supported Datasets
Both `run_moar.py` and `run_simple_agent.py` support the following datasets:
- `cuad` - Legal clause extraction
- `blackvault` - UFO sighting analysis
- `medec` - Medical entity classification
- `biodex` - Adverse drug reaction extraction
- `sustainability` - Sustainability analysis
- `game_reviews` - Game review sentiment analysis
### Custom Datasets with User-Authored Accuracy Functions
For datasets not listed above, you can provide your own accuracy evaluation function using the `--accuracy_function` and `--accuracy_metric_key` parameters.
#### Creating a Custom Accuracy Function
Create a Python file (e.g., `my_evaluate.py`) with an `evaluate_results` function:
```python
# my_evaluate.py
import json
def evaluate_results(method_name, results_file_path):
"""
Evaluate pipeline results and return metrics.
Args:
method_name: Name of the method being evaluated
results_file_path: Path to the JSON file containing pipeline results
Returns:
dict: Dictionary containing evaluation metrics. Must include the metric
specified by --accuracy_metric_key.
"""
# Load results
with open(results_file_path, 'r') as f:
results = json.load(f)
# Your evaluation logic here
# Calculate metrics based on your dataset's requirements
metrics = {
"my_accuracy_metric": 0.95, # Primary accuracy metric (specify key with --accuracy_metric_key)
# Add any other metrics you want to track
}
return metrics
```
#### Using Custom Accuracy Functions
**MOARSearch:**
```bash
python experiments/reasoning/run_moar.py \
--yaml_path my_pipeline.yaml \
--dataset_path my_data.json \
--experiment_name my_custom_experiment \
--dataset my_custom_dataset \
--accuracy_function my_evaluate.py \
--accuracy_metric_key my_accuracy_metric
```
**Simple Agent:**
```bash
python experiments/reasoning/run_simple_agent.py \
--dataset my_custom_dataset \
--accuracy_function my_evaluate.py \
--accuracy_metric_key my_accuracy_metric
```
**Note:** When using custom accuracy functions, the `--dataset` parameter can be any string (it's used for naming/organization). The actual evaluation logic comes from your custom function.
## Available Models
Both scripts support the `--available_models` parameter to specify which models should be tested. If not provided, the following default model list is used:
- `gpt-5`
- `gpt-5-mini`
- `gpt-5-nano`
- `gpt-4.1`
- `gpt-4.1-mini`
- `gpt-4.1-nano`
- `gpt-4o`
- `gpt-4o-mini`
- `gemini-2.5-pro`
- `gemini-2.5-flash`
- `gemini-2.5-flash-lite`
## Output Files
### MOARSearch Output
Results are saved to `outputs/{experiment_name}/` containing:
- `experiment_summary.json` - Experiment metadata and results
- `pareto_frontier.json` - All Pareto-optimal solutions with accuracy and cost
- `evaluation_metrics.json` - Detailed evaluation results
- `moar_tree_log.txt` - Search tree structure and visit counts
- `cost_vs_{metric}.png` - Plot showing cost vs performance (dataset-specific)
- Pipeline YAML files for each explored configuration
### Simple Agent Output
Results are saved to `outputs/simple_agent/{experiment_name}/` containing:
- `final_pipeline.yaml` - Final optimized pipeline configuration
- `iteration_{n}_output.json` - Pipeline outputs for each iteration
- `evaluation_metrics.json` - Performance evaluation results
- `cost_vs_{metric}.png` - Plot showing cost vs performance across iterations

Some files were not shown because too many files have changed in this diff.