Compare commits

..

4 Commits

Author SHA1 Message Date
Shreya Shankar 54e6d08c2e Testing retriever and updating docs with logging 2025-12-26 16:42:20 -06:00
Shreya Shankar 83193f265a Merge main into addretriever branch 2025-12-26 16:11:08 -06:00
Shreya Shankar ba46f4510e feat: allow indexes to be built on datasets created by docetl pipelines 2025-11-23 19:55:17 -06:00
Shreya Shankar 53b4a17d03 feat: adding retrievers 2025-11-22 02:23:02 -08:00
92 changed files with 985 additions and 6587 deletions

View File

@ -1,776 +0,0 @@
---
name: docetl
description: Build and run LLM-powered data processing pipelines with DocETL. Use when users say "docetl", want to analyze unstructured data, process documents, extract information, or run ETL tasks on text. Helps with data collection, pipeline creation, execution, and optimization.
---
# DocETL Pipeline Development
DocETL is a system for creating LLM-powered data processing pipelines. This skill helps you build end-to-end pipelines: from data preparation to execution and optimization.
## Workflow Overview: Iterative Data Analysis
Work like a data analyst: **write → run → inspect → iterate**. Never write all scripts at once and run them all at once. Each phase should be completed and validated before moving to the next.
### Phase 1: Data Collection
1. Write data collection script
2. **Run it immediately** (with user permission)
3. **Inspect the dataset** - show the user:
- Total document count
- Keys/fields in each document
- Sample documents (first 3-5)
- Length distribution (avg chars, min/max)
- Any other relevant statistics
4. Iterate if needed (e.g., collect more data, fix parsing issues)
### Phase 2: Pipeline Development
1. Read sample documents to understand format
2. Write pipeline YAML with `sample: 10-20` for testing
3. **Run the test pipeline**
4. **Inspect intermediate results** - show the user:
- Extraction quality on samples
- Domain/category distributions
- Any validation failures
5. Iterate on prompts/schema based on results
6. Remove `sample` parameter and run full pipeline
7. **Show final results** - distributions, trends, key insights
### Phase 3: Visualization & Presentation
1. Write visualization script based on actual output structure
2. **Run and show the report** to the user
3. Iterate on charts/tables if needed
**Visualization Aesthetics:**
- **Clean and minimalist** - no clutter, generous whitespace
- **Warm and elegant color theme** - 1-2 accent colors max
- **Subtle borders** - not too rounded (border-radius: 8-10px max)
- **Sans-serif fonts** - system fonts like -apple-system, Segoe UI, Roboto
- **"Created by DocETL"** - add subtitle after the main title
- **Mix of charts and tables** - charts for distributions, tables for detailed summaries
- **Light background** - off-white (#f5f5f5) with white cards for content
**Report Structure:**
1. Title + "Created by DocETL" subtitle
2. Key stats cards (document count, categories, etc.)
3. Distribution charts (bar charts, pie charts)
4. Summary table with detailed analysis
5. Minimal footer
**Interactive Tables:**
- **All truncated content must be expandable** - never use static "..." truncation
- Long text: Show first ~250 chars with "(show more)" toggle
- Long lists: Show first 4-6 items with "(+N more)" toggle
- Use JavaScript to toggle visibility, not page reloads
**Source Document Links:**
- **Link aggregated results to source documents** - users should be able to drill down
- Clickable links that open a modal/popup with source content
- Modal should show: extracted fields + original source text
- Original text can be collapsed by default with "Show original" toggle
- Embed source data as JSON in the page for JavaScript access
**Key principle:** The user should see results at every step. Don't proceed to the next phase until the current phase produces good results.
## Step 1: Data Preparation
DocETL datasets must be **JSON arrays** or **CSV files**.
### JSON Format
```json
[
{"id": 1, "text": "First document content...", "metadata": "value"},
{"id": 2, "text": "Second document content...", "metadata": "value"}
]
```
### CSV Format
```csv
id,text,metadata
1,"First document content...","value"
2,"Second document content...","value"
```
### Data Collection Scripts
If user needs to collect data, write a Python script:
```python
import json
# Collect/transform data
documents = []
for source in sources:
documents.append({
"id": source.id,
"text": source.content, # DO NOT truncate text
# Add relevant fields
})
# Save as DocETL dataset
with open("dataset.json", "w") as f:
json.dump(documents, f, indent=2)
```
**Important:** Never truncate document text in collection scripts. DocETL operations like `split` handle long documents properly. Truncation loses information.
### After Running Data Collection
**Always run the collection script and inspect results before proceeding.** Show the user:
```python
import json
data = json.load(open("dataset.json"))
print(f"Total documents: {len(data)}")
print(f"Keys: {list(data[0].keys())}")
print(f"Avg length: {sum(len(str(d)) for d in data) // len(data)} chars")
# Show sample
print("\nSample document:")
print(json.dumps(data[0], indent=2)[:500])
```
Only proceed to pipeline development once the data looks correct.
## Step 2: Read and Understand the Data
**CRITICAL**: Before writing any prompts, READ the actual input data to understand:
- The structure and format of documents
- The vocabulary and terminology used
- What information is present vs. absent
- Edge cases and variations
```python
import json
with open("dataset.json") as f:
data = json.load(f)
# Examine several examples
for doc in data[:5]:
print(doc)
```
This understanding is essential for writing specific, effective prompts.
## Step 3: Pipeline Structure
Create a YAML file with this structure:
```yaml
default_model: gpt-5-nano
system_prompt:
dataset_description: <describe the data based on what you observed>
persona: <role for the LLM to adopt>
datasets:
input_data:
type: file
path: "dataset.json" # or dataset.csv
operations:
- name: <operation_name>
type: <operation_type>
prompt: |
<Detailed, specific prompt based on the actual data>
output:
schema:
<field_name>: <type>
pipeline:
steps:
- name: process
input: input_data
operations:
- <operation_name>
output:
type: file
path: "output.json"
intermediate_dir: "intermediates" # ALWAYS set this for debugging
```
### Key Configuration
- **default_model**: Use `gpt-5-nano` or `gpt-5-mini` for extraction/map operations
- **intermediate_dir**: Always set to log intermediate results
- **system_prompt**: Describe the data based on what you actually observed
### Model Selection by Operation Type
| Operation Type | Recommended Model | Rationale |
|---------------|-------------------|-----------|
| Map (extraction) | `gpt-5-nano` or `gpt-5-mini` | High volume, simple per-doc tasks |
| Filter | `gpt-5-nano` | Simple yes/no decisions |
| Reduce (summarization) | `gpt-4.1` or `gpt-5.1` | Complex synthesis across many docs |
| Resolve (deduplication) | `gpt-5-nano` or `gpt-5-mini` | Simple pairwise comparisons |
Use cheaper models for high-volume extraction, and more capable models for synthesis/summarization where quality matters most.
## Step 4: Writing Effective Prompts
**Prompts must be specific to the data, not generic.** After reading the input data:
### Bad (Generic) Prompt
```yaml
prompt: |
Extract key information from this document.
{{ input.text }}
```
### Good (Specific) Prompt
```yaml
prompt: |
You are analyzing a medical transcript from a doctor-patient visit.
The transcript follows this format:
- Doctor statements are prefixed with "DR:"
- Patient statements are prefixed with "PT:"
- Timestamps appear in brackets like [00:05:23]
From the following transcript, extract:
1. All medications mentioned (brand names or generic)
2. Dosages if specified
3. Patient-reported side effects or concerns
Transcript:
{{ input.transcript }}
Be thorough - patients often mention medication names informally.
If a medication is unclear, include it with a note.
```
### Prompt Writing Guidelines
1. **Describe the data format** you observed
2. **Be specific about what to extract** - list exact fields
3. **Mention edge cases** you noticed in the data
4. **Provide examples** if the task is ambiguous
5. **Set expectations** for handling missing/unclear information
## Step 5: Choosing Operations
Many tasks only need a **single map operation**. Use good judgement:
| Task | Recommended Approach |
|------|---------------------|
| Extract info from each doc | Single `map` |
| Multiple extractions | Multiple `map` operations chained |
| Extract then summarize | `map``reduce` |
| Filter then process | `filter``map` |
| Split long docs | `split``map``reduce` |
| Deduplicate entities | `map``unnest``resolve` |
## Operation Reference
### Map Operation
Applies an LLM transformation to each document independently.
```yaml
- name: extract_info
type: map
prompt: |
Analyze this document:
{{ input.text }}
Extract the main topic and 3 key points.
output:
schema:
topic: string
key_points: list[string]
model: gpt-5-nano # optional, uses default_model if not set
skip_on_error: true # recommended for large-scale runs
validate: # optional
- len(output["key_points"]) == 3
num_retries_on_validate_failure: 2 # optional
```
**Key parameters:**
- `prompt`: Jinja2 template, use `{{ input.field }}` to reference fields
- `output.schema`: Define output structure
- `skip_on_error`: Set `true` to continue on LLM errors (recommended at scale)
- `validate`: Python expressions to validate output
- `sample`: Process only N documents (for testing)
- `limit`: Stop after producing N outputs
### Filter Operation
Keeps or removes documents based on LLM criteria. Output schema must have exactly one boolean field.
```yaml
- name: filter_relevant
type: filter
skip_on_error: true
prompt: |
Document: {{ input.text }}
Is this document relevant to climate change?
Respond true or false.
output:
schema:
is_relevant: boolean
```
### Reduce Operation
Aggregates documents by a key using an LLM.
**Always include `fold_prompt` and `fold_batch_size`** for reduce operations. This handles cases where the group is too large to fit in context.
```yaml
- name: summarize_by_category
type: reduce
reduce_key: category # use "_all" to aggregate everything
skip_on_error: true
prompt: |
Summarize these {{ inputs | length }} items for category "{{ inputs[0].category }}":
{% for item in inputs %}
- {{ item.title }}: {{ item.description }}
{% endfor %}
Provide a 2-3 sentence summary of the key themes.
fold_prompt: |
You have a summary based on previous items, and new items to incorporate.
Previous summary (based on {{ output.item_count }} items):
{{ output.summary }}
New items ({{ inputs | length }} more):
{% for item in inputs %}
- {{ item.title }}: {{ item.description }}
{% endfor %}
Write a NEW summary that covers ALL items (previous + new).
IMPORTANT: Output a clean, standalone summary as if describing the entire dataset.
Do NOT mention "updated", "added", "new items", or reference the incremental process.
fold_batch_size: 100
output:
schema:
summary: string
item_count: int
validate:
- len(output["summary"].strip()) > 0
num_retries_on_validate_failure: 2
```
**Critical: Writing Good Fold Prompts**
The `fold_prompt` is called repeatedly as batches are processed. Its output must:
1. **Reflect ALL data seen so far**, not just the latest batch
2. **Be a clean, standalone output** - no "updated X" or "added Y items" language
3. **Match the same schema** as the initial `prompt` output
Bad fold_prompt output: "Added 50 new projects. The updated summary now includes..."
Good fold_prompt output: "Developers are building privacy-focused tools and local-first apps..."
**Estimating `fold_batch_size`:**
- **Use 100+ for most cases** - larger batches = fewer LLM calls = lower cost
- For very long documents, reduce to 50-75
- For short documents (tweets, titles), can use 150-200
- Models like gpt-4o-mini have 128k context, so batch size is rarely the bottleneck
**Key parameters:**
- `reduce_key`: Field to group by (or list of fields, or `_all`)
- `fold_prompt`: Template for incrementally adding items to existing output (required)
- `fold_batch_size`: Number of items per fold iteration (required, use 100+)
- `associative`: Set to `false` if order matters
### Split Operation
Divides long text into smaller chunks. No LLM call.
```yaml
- name: split_document
type: split
split_key: content
method: token_count # or "delimiter"
method_kwargs:
num_tokens: 500
model: gpt-5-nano
```
**Output adds:**
- `{split_key}_chunk`: The chunk content
- `{op_name}_id`: Original document ID
- `{op_name}_chunk_num`: Chunk number
### Unnest Operation
Flattens list fields into separate rows. No LLM call.
```yaml
- name: unnest_items
type: unnest
unnest_key: items # field containing the list
keep_empty: false # optional
```
**Example:** If a document has `items: ["a", "b", "c"]`, unnest creates 3 documents, each with `items: "a"`, `items: "b"`, `items: "c"`.
### Resolve Operation
Deduplicates and canonicalizes entities. Uses pairwise comparison.
```yaml
- name: dedupe_names
type: resolve
optimize: true # let optimizer find blocking rules
skip_on_error: true
comparison_prompt: |
Are these the same person?
Person 1: {{ input1.name }} ({{ input1.email }})
Person 2: {{ input2.name }} ({{ input2.email }})
Respond true or false.
resolution_prompt: |
Standardize this person's name:
{% for entry in inputs %}
- {{ entry.name }}
{% endfor %}
Return the canonical name.
output:
schema:
name: string
```
**Important:** Set `optimize: true` and run `docetl build` to generate efficient blocking rules. Without blocking, this is O(n²).
### Code Operations
Deterministic Python transformations without LLM calls.
**code_map:**
```yaml
- name: compute_stats
type: code_map
code: |
def transform(doc) -> dict:
return {
"word_count": len(doc["text"].split()),
"char_count": len(doc["text"])
}
```
**code_reduce:**
```yaml
- name: aggregate
type: code_reduce
reduce_key: category
code: |
def transform(items) -> dict:
total = sum(item["value"] for item in items)
return {"total": total, "count": len(items)}
```
**code_filter:**
```yaml
- name: filter_long
type: code_filter
code: |
def transform(doc) -> bool:
return len(doc["text"]) > 100
```
### Retrievers (LanceDB)
Augment LLM operations with retrieved context from a LanceDB index. Useful for:
- Finding related documents to compare against
- Providing additional context for extraction/classification
- Cross-referencing facts across a dataset
**Define a retriever:**
```yaml
retrievers:
facts_index:
type: lancedb
dataset: extracted_facts # dataset to index
index_dir: workloads/wiki/lance_index
build_index: if_missing # if_missing | always | never
index_types: ["fts", "embedding"] # or "hybrid"
fts:
index_phrase: "{{ input.fact }}: {{ input.source }}"
query_phrase: "{{ input.fact }}"
embedding:
model: openai/text-embedding-3-small
index_phrase: "{{ input.fact }}"
query_phrase: "{{ input.fact }}"
query:
mode: hybrid
top_k: 5
```
**Use in operations:**
```yaml
- name: find_conflicts
type: map
retriever: facts_index
prompt: |
Check if this fact conflicts with any retrieved facts:
Current fact: {{ input.fact }} (from {{ input.source }})
Related facts from other articles:
{{ retrieval_context }}
Return whether there's a genuine conflict.
output:
schema:
has_conflict: boolean
```
**Key points:**
- `{{ retrieval_context }}` is injected into prompts automatically
- Index is built on first use (when `build_index: if_missing`)
- Supports full-text (`fts`), vector (`embedding`), or `hybrid` search
- Use `save_retriever_output: true` to debug what was retrieved
- **Can index intermediate outputs**: Retriever can index the output of a previous pipeline step, enabling patterns like "extract facts → index facts → retrieve similar facts for each"
## Documentation Reference
For detailed parameters, advanced features, and more examples, read the docs:
- **Operations**: `docs/operators/` folder (map.md, reduce.md, filter.md, etc.)
- **Concepts**: `docs/concepts/` folder (pipelines.md, operators.md, schemas.md)
- **Examples**: `docs/examples/` folder
- **Optimization**: `docs/optimization/` folder
## Step 6: Environment Setup
Before running, verify API keys exist:
```bash
# Check for .env file
cat .env
```
Required keys depend on the model:
- OpenAI: `OPENAI_API_KEY`
- Anthropic: `ANTHROPIC_API_KEY`
- Google: `GEMINI_API_KEY`
If missing, prompt user to create `.env`:
```
OPENAI_API_KEY=sk-...
```
## Step 7: Execution
**Always test on a sample first, then run full pipeline.**
### Test Run (Required)
Add `sample: 10-20` to your first operation, then run:
```bash
docetl run pipeline.yaml
```
**Inspect the test results before proceeding:**
```python
import json
from collections import Counter
# Load intermediate results
data = json.load(open("intermediates/step_name/operation_name.json"))
print(f"Processed: {len(data)} docs")
# Check distributions
if "domain" in data[0]:
print("Domain distribution:")
for k, v in Counter(d["domain"] for d in data).most_common():
print(f" {k}: {v}")
# Show sample outputs
print("\nSample output:")
print(json.dumps(data[0], indent=2))
```
### Full Run
Once test results look good:
1. Remove the `sample` parameter from the pipeline
2. Ask user for permission (estimate cost based on test run)
3. Run full pipeline
4. **Show final results** - distributions, key insights, trends
Options:
- `--max_threads N` - Control parallelism
Check intermediate results in the `intermediate_dir` folder to debug each step.
## Step 8: Optimization (Optional)
Use MOAR optimizer to find the Pareto frontier of **cost vs. accuracy** tradeoffs. MOAR experiments with different pipeline rewrites and models to find optimal configurations.
Add to pipeline YAML:
```yaml
optimizer_config:
type: moar
save_dir: ./optimization_results
available_models:
- gpt-5-nano
- gpt-4o-mini
- gpt-4o
evaluation_file: evaluate.py # User must provide
metric_key: score
max_iterations: 20
model: gpt-5-nano
```
Create evaluation file (`evaluate.py`):
```python
def evaluate(outputs: list[dict]) -> dict:
# Score the outputs (0-1 scale recommended)
correct = sum(1 for o in outputs if is_correct(o))
return {"score": correct / len(outputs)}
```
Run optimization:
```bash
docetl build pipeline.yaml --optimizer moar
```
MOAR will produce multiple pipeline variants on the Pareto frontier - user can choose based on their cost/accuracy preferences.
## Output Schemas
**Keep schemas minimal and simple** unless the user explicitly requests more fields. Default to 1-3 output fields per operation. Only add more fields if the user specifically asks for them.
**Nesting limit:** Maximum 2 levels deep (e.g., `list[{field: str}]` is allowed, but no deeper).
```yaml
# Good - minimal, focused on the core task
output:
schema:
summary: string
# Good - a few fields when task requires it
output:
schema:
topic: string
keywords: list[string]
# Acceptable - 2 levels of nesting (list of objects)
output:
schema:
items: "list[{name: str, value: int}]"
# Bad - too many fields (unless user explicitly requested all of these)
output:
schema:
conflicts_found: bool
num_conflicts: int
conflicts: "list[{claim_a: str, source_a: str, claim_b: str, source_b: str}]"
analysis_summary: str
# Bad - more than 2 levels of nesting (not supported)
output:
schema:
data: "list[{nested: {too: {deep: str}}}]"
```
**Guidelines:**
- Start with the minimum fields needed to answer the user's question
- Avoid complex nested objects unless explicitly requested
- If you need structured data, prefer multiple simple operations over one complex schema
- Complex schemas increase LLM failures and costs
Supported types: `string`, `int`, `float`, `bool`, `list[type]`, `enum`
## Validation
**Always add validation to LLM-powered operations** (map, reduce, filter, resolve). Validation catches malformed outputs and retries automatically.
```yaml
- name: extract_keywords
type: map
prompt: |
Extract 3-5 keywords from: {{ input.text }}
output:
schema:
keywords: list[string]
validate:
- len(output["keywords"]) >= 3
- len(output["keywords"]) <= 5
num_retries_on_validate_failure: 2
```
Common validation patterns:
```yaml
# List length constraints
- len(output["items"]) >= 1
- len(output["items"]) <= 10
# Enum/allowed values
- output["sentiment"] in ["positive", "negative", "neutral"]
# String not empty
- len(output["summary"].strip()) > 0
# Numeric ranges
- output["score"] >= 0
- output["score"] <= 100
```
## Jinja2 Templating
For map operations, use `input`:
```yaml
prompt: |
Document: {{ input.text }}
{% if input.metadata %}
Context: {{ input.metadata }}
{% endif %}
```
For reduce operations, use `inputs` (list):
```yaml
prompt: |
Summarize these {{ inputs | length }} items:
{% for item in inputs %}
- {{ item.summary }}
{% endfor %}
```
## Troubleshooting
### Pipeline won't run
- Check `.env` has correct API keys
- Verify dataset file exists and is valid JSON/CSV
- Check YAML syntax
### Bad outputs
- Read more input data examples to improve prompt specificity
- Add `validate` rules with retries
- Simplify output schema
- Add concrete examples to prompt
### High costs
- Use `gpt-5-nano` or `gpt-4o-mini`
- Add `sample: 10` to test on subset first
- Run MOAR optimizer to find cost-efficient rewrites
### Check intermediate results
Look in `intermediate_dir` folder to debug each step.
## Quick Reference
```bash
# Run pipeline
docetl run pipeline.yaml
# Run with more parallelism
docetl run pipeline.yaml --max_threads 16
# Optimize pipeline (cost/accuracy tradeoff)
docetl build pipeline.yaml --optimizer moar
# Clear LLM cache
docetl clear-cache
# Check version
docetl version
```

View File

@ -13,8 +13,7 @@ DocETL is a tool for creating and executing data processing pipelines, especiall
2. A Python package for running production pipelines from the command line or Python code
> 💡 **Need Help Writing Your Pipeline?**
> You can use **Claude Code** (recommended) to help you write your pipeline—see the quickstart: https://ucbepic.github.io/docetl/quickstart-claude-code/
> If youd rather use ChatGPT or the Claude app, see [docetl.org/llms.txt](https://docetl.org/llms.txt) for a big prompt you can copy/paste before describing your task.
> Want to use an LLM like ChatGPT or Claude to help you write your pipeline? See [docetl.org/llms.txt](https://docetl.org/llms.txt) for a big prompt you can copy paste into ChatGPT or Claude, before describing your task.
### 🌟 Community Projects

View File

@ -1,15 +1,11 @@
__version__ = "0.2.6"
__version__ = "0.2.5"
import warnings
import litellm
from docetl.runner import DSLRunner
from docetl.optimizer import Optimizer
from docetl.apis.pd_accessors import SemanticAccessor
# Drop unsupported params for models like gpt-5 that don't support temperature=0
litellm.drop_params = True
# TODO: Remove after https://github.com/BerriAI/litellm/issues/7560 is fixed
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._config")

View File

@ -244,104 +244,5 @@ def version():
typer.echo(f"DocETL version: {docetl.__version__}")
@app.command("install-skill")
def install_skill(
uninstall: bool = typer.Option(
False, "--uninstall", "-u", help="Remove the installed skill instead"
),
):
"""
Install the DocETL Claude Code skill to your personal skills directory.
This makes the DocETL skill available in Claude Code for any project.
The skill helps you build and run DocETL pipelines.
"""
import shutil
# Find the skill source - try multiple locations
# 1. Installed package location (via importlib.resources)
# 2. Development location (relative to this file)
skill_source = None
# Try to find via package resources first
try:
import importlib.resources as pkg_resources
# For Python 3.9+, use files()
try:
package_root = Path(pkg_resources.files("docetl")).parent
potential_source = package_root / ".claude" / "skills" / "docetl"
if potential_source.exists():
skill_source = potential_source
except (TypeError, AttributeError):
pass
except ImportError:
pass
# Fallback: try relative to this file (development mode)
if skill_source is None:
dev_source = Path(__file__).parent.parent / ".claude" / "skills" / "docetl"
if dev_source.exists():
skill_source = dev_source
if skill_source is None or not skill_source.exists():
console.print(
Panel(
"[bold red]Error:[/bold red] Could not find the DocETL skill files.\n\n"
"This may happen if the package was not installed correctly.\n"
"Try reinstalling: [bold]pip install --force-reinstall docetl[/bold]",
title="[bold red]Skill Not Found[/bold red]",
border_style="red",
)
)
raise typer.Exit(1)
# Target directory
skill_target = Path.home() / ".claude" / "skills" / "docetl"
if uninstall:
if skill_target.exists():
shutil.rmtree(skill_target)
console.print(
Panel(
f"[bold green]Success![/bold green] DocETL skill removed from:\n"
f"[dim]{skill_target}[/dim]",
title="[bold green]Skill Uninstalled[/bold green]",
border_style="green",
)
)
else:
console.print(
Panel(
"[yellow]The DocETL skill is not currently installed.[/yellow]",
title="[yellow]Nothing to Uninstall[/yellow]",
border_style="yellow",
)
)
return
# Create parent directories if needed
skill_target.parent.mkdir(parents=True, exist_ok=True)
# Copy the skill
if skill_target.exists():
shutil.rmtree(skill_target)
shutil.copytree(skill_source, skill_target)
console.print(
Panel(
f"[bold green]Success![/bold green] DocETL skill installed to:\n"
f"[dim]{skill_target}[/dim]\n\n"
"[bold]Next steps:[/bold]\n"
"1. Restart Claude Code if it's running\n"
"2. The skill will automatically activate when you work on DocETL tasks\n\n"
"[dim]To uninstall: docetl install-skill --uninstall[/dim]",
title="[bold green]Skill Installed[/bold green]",
border_style="green",
)
)
if __name__ == "__main__":
app()

View File

@ -1,5 +1,4 @@
import os
import re
import threading
import time
from io import StringIO
@ -13,63 +12,6 @@ from docetl.utils import StageType, get_stage_description
install(show_locals=False)
# ANSI escape sequences to strip (cursor movement, line clearing, cursor visibility)
ANSI_CURSOR_PATTERNS = re.compile(
r"\x1b\[\?25[lh]" # Hide/show cursor
r"|\x1b\[\d*[ABCD]" # Cursor up/down/forward/back
r"|\x1b\[\d*[JK]" # Clear screen/line (but we handle \x1b[2K specially)
r"|\x1b\[s|\x1b\[u" # Save/restore cursor position
)
def process_carriage_returns(text: str) -> str:
"""
Process terminal control sequences for buffer-captured output.
Rich's spinner uses ANSI sequences for:
- `\r` - carriage return (move to start of line)
- `\x1b[2K` - clear entire line
- `\x1b[?25l/h` - hide/show cursor
- `\x1b[1A` - move cursor up
When captured to a buffer (not a real terminal), we simulate this by:
1. Handling `\r` as "replace from start of line"
2. Stripping cursor movement/visibility sequences
3. Keeping only meaningful content
"""
# First, strip cursor visibility and movement sequences
text = ANSI_CURSOR_PATTERNS.sub("", text)
# Process line by line
lines = text.split("\n")
processed_lines = []
for line in lines:
# Handle carriage returns - keep only content after the last \r
if "\r" in line:
segments = line.split("\r")
# Keep the last non-empty segment (after stripping ANSI clear codes)
for segment in reversed(segments):
# Strip the clear line code if present
cleaned = segment.replace("\x1b[2K", "").strip()
if cleaned:
# Keep the segment but preserve any color codes
processed_lines.append(segment.replace("\x1b[2K", ""))
break
else:
# All segments were empty, skip this line entirely
pass
else:
# No carriage return, keep the line as-is
if line.strip() or line == "": # Keep empty lines between content
processed_lines.append(line)
# Remove trailing empty lines
while processed_lines and not processed_lines[-1].strip():
processed_lines.pop()
return "\n".join(processed_lines)
class ThreadSafeConsole(Console):
def __init__(self, *args, **kwargs):
@ -82,18 +24,11 @@ class ThreadSafeConsole(Console):
self.optimizer_rationale = None
def get_output(self):
"""
Get the output from the buffer, processing carriage returns.
Rich's spinner uses carriage returns to overwrite lines in place.
We process these to only keep the latest content, preventing
duplicate spinner frames from flooding the output.
"""
# return self.export_text(styles=True)
value = self.buffer.getvalue()
self.buffer.truncate(0)
self.buffer.seek(0)
# Process carriage returns to handle spinner overwrites
return process_carriage_returns(value)
return value
def status(
self,
@ -101,15 +36,10 @@ class ThreadSafeConsole(Console):
*,
spinner: str = "dots",
spinner_style: "StyleType" = "status.spinner",
speed: float = 1.0,
refresh_per_second: float = 4,
speed: float = 0.1, # Much slower speed
refresh_per_second: float = 0.5, # Much slower refresh rate (every 2 seconds)
) -> "Status":
"""
Return a Rich Status with animation.
The carriage returns from the spinner animation are processed
in get_output() to prevent duplicate lines.
"""
status_renderable = Status(
status,
console=self,

View File

@ -465,8 +465,7 @@ class OpContainer:
return cached_data, 0, curr_logs
# Try to load from checkpoint if available
# Skip if this operation has bypass_cache: true
if not is_build and not self.config.get("bypass_cache", False):
if not is_build:
attempted_input_data = self.runner._load_from_checkpoint_if_exists(
self.name.split("/")[0], self.name.split("/")[-1]
)

View File

@ -849,9 +849,13 @@ class MOARSearch:
response = litellm.completion(
model=self.model,
messages=messages,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=ExpandResponseFormat,
)
call_cost = response._hidden_params.get("response_cost", 0)
call_cost = response._hidden_params["response_cost"]
with self.tree_lock:
self.console.log(
f"[green]💰 Adding LLM call cost:[/green] ${call_cost:.4f} (total before: ${self.total_search_cost:.4f})"

View File

@ -153,11 +153,10 @@ def run_moar_optimization(
"optimizer_config must contain 'max_iterations' (number of MOARSearch iterations) for MOAR optimizer"
)
# Use rewrite_agent_model for consistency with rest of codebase, fallback to model for backwards compatibility
model = optimizer_config.get("rewrite_agent_model") or optimizer_config.get("model")
model = optimizer_config.get("model")
if not model:
raise ValueError(
"optimizer_config must contain 'rewrite_agent_model' (LLM model name for directive instantiation) for MOAR optimizer"
"optimizer_config must contain 'model' (LLM model name for directive instantiation) for MOAR optimizer"
)
# Optional parameters

View File

@ -3,7 +3,7 @@ import os
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from pydantic import Field, field_validator
from pydantic import field_validator
from docetl.operations.base import BaseOperation
from docetl.operations.utils import RichLoopBar
@ -15,7 +15,6 @@ class CodeMapOperation(BaseOperation):
code: Any
concurrent_thread_count: int = os.cpu_count()
drop_keys: list[str] | None = None
limit: int | None = Field(None, gt=0)
@field_validator("code")
@classmethod
@ -45,10 +44,6 @@ class CodeMapOperation(BaseOperation):
raise ValueError(f"Invalid code configuration: {str(e)}")
def execute(self, input_data: list[dict]) -> tuple[list[dict], float]:
limit_value = self.config.get("limit")
if limit_value is not None:
input_data = input_data[:limit_value]
namespace = {}
exec(self.config["code"], namespace)
transform_fn = namespace["transform"]
@ -83,7 +78,6 @@ class CodeReduceOperation(BaseOperation):
type: str = "code_reduce"
code: Any
concurrent_thread_count: int = os.cpu_count()
limit: int | None = Field(None, gt=0)
@field_validator("code")
@classmethod
@ -137,12 +131,6 @@ class CodeReduceOperation(BaseOperation):
grouped_data = list(grouped_data.items())
limit_value = self.config.get("limit")
if limit_value is not None:
# Sort by group size (smallest first) and take the limit
grouped_data = sorted(grouped_data, key=lambda x: len(x[1]))
grouped_data = grouped_data[:limit_value]
results = []
with ThreadPoolExecutor(
max_workers=self.config.get("concurrent_thread_count", os.cpu_count())
@ -180,7 +168,6 @@ class CodeFilterOperation(BaseOperation):
type: str = "code_filter"
code: Any
concurrent_thread_count: int = os.cpu_count()
limit: int | None = Field(None, gt=0)
@field_validator("code")
@classmethod
@ -214,7 +201,6 @@ class CodeFilterOperation(BaseOperation):
exec(self.config["code"], namespace)
filter_fn = namespace["transform"]
limit_value = self.config.get("limit")
results = []
with ThreadPoolExecutor(
max_workers=self.config.get("concurrent_thread_count", os.cpu_count())
@ -229,6 +215,4 @@ class CodeFilterOperation(BaseOperation):
should_keep = futures[i].result()
if should_keep:
results.append(input_data[i])
if limit_value is not None and len(results) >= limit_value:
break
return results, 0.0

View File

@ -17,7 +17,6 @@ from rich.prompt import Confirm
from docetl.operations.base import BaseOperation
from docetl.operations.utils import strict_render
from docetl.operations.utils.blocking import RuntimeBlockingOptimizer
from docetl.operations.utils.progress import RichLoopBar
from docetl.utils import (
completion_cost,
@ -64,7 +63,6 @@ class EquijoinOperation(BaseOperation):
comparison_prompt: str
output: dict[str, Any] | None = None
blocking_threshold: float | None = None
blocking_target_recall: float | None = None
blocking_conditions: list[str] | None = None
limits: dict[str, int] | None = None
comparison_model: str | None = None
@ -252,58 +250,6 @@ class EquijoinOperation(BaseOperation):
if self.status:
self.status.stop()
# Track pre-computed embeddings from auto-optimization
precomputed_left_embeddings = None
precomputed_right_embeddings = None
# Auto-compute blocking threshold if no blocking configuration is provided
if not blocking_threshold and not blocking_conditions and not limit_comparisons:
# Get target recall from operation config (default 0.95)
target_recall = self.config.get("blocking_target_recall", 0.95)
self.console.log(
f"[yellow]No blocking configuration. Auto-computing threshold (target recall: {target_recall:.0%})...[/yellow]"
)
# Create comparison function for threshold optimization
def compare_fn_for_optimization(left_item, right_item):
return self.compare_pair(
self.config["comparison_prompt"],
self.config.get("comparison_model", self.default_model),
left_item,
right_item,
timeout_seconds=self.config.get("timeout", 120),
max_retries_per_timeout=self.config.get(
"max_retries_per_timeout", 2
),
)
# Run threshold optimization
optimizer = RuntimeBlockingOptimizer(
runner=self.runner,
config=self.config,
default_model=self.default_model,
max_threads=self.max_threads,
console=self.console,
target_recall=target_recall,
sample_size=min(100, len(left_data) * len(right_data) // 4),
)
(
blocking_threshold,
precomputed_left_embeddings,
precomputed_right_embeddings,
optimization_cost,
) = optimizer.optimize_equijoin(
left_data,
right_data,
compare_fn_for_optimization,
left_keys=left_keys,
right_keys=right_keys,
)
total_cost += optimization_cost
self.console.log(
f"[green]Using auto-computed blocking threshold: {blocking_threshold}[/green]"
)
# Initial blocking using multiprocessing
num_processes = min(cpu_count(), len(left_data))
@ -352,60 +298,45 @@ class EquijoinOperation(BaseOperation):
)
if blocking_threshold is not None:
# Use precomputed embeddings if available from auto-optimization
if (
precomputed_left_embeddings is not None
and precomputed_right_embeddings is not None
):
left_embeddings = precomputed_left_embeddings
right_embeddings = precomputed_right_embeddings
else:
embedding_model = self.config.get("embedding_model", self.default_model)
model_input_context_length = model_cost.get(embedding_model, {}).get(
"max_input_tokens", 8192
)
batch_size = 2000
embedding_model = self.config.get("embedding_model", self.default_model)
model_input_context_length = model_cost.get(embedding_model, {}).get(
"max_input_tokens", 8192
)
def get_embeddings(
input_data: list[dict[str, Any]], keys: list[str], name: str
) -> tuple[list[list[float]], float]:
texts = [
" ".join(str(item[key]) for key in keys if key in item)[
: model_input_context_length * 4
]
for item in input_data
def get_embeddings(
input_data: list[dict[str, Any]], keys: list[str], name: str
) -> tuple[list[list[float]], float]:
texts = [
" ".join(str(item[key]) for key in keys if key in item)[
: model_input_context_length * 4
]
embeddings = []
embedding_cost = 0
num_batches = (len(texts) + batch_size - 1) // batch_size
for item in input_data
]
for batch_idx, i in enumerate(range(0, len(texts), batch_size)):
batch = texts[i : i + batch_size]
if num_batches > 1:
self.console.log(
f"[dim]Creating {name} embeddings: batch {batch_idx + 1}/{num_batches} "
f"({min(i + batch_size, len(texts))}/{len(texts)} items)[/dim]"
)
response = self.runner.api.gen_embedding(
model=embedding_model,
input=batch,
)
embeddings.extend(
[data["embedding"] for data in response["data"]]
)
embedding_cost += completion_cost(response)
return embeddings, embedding_cost
embeddings = []
total_cost = 0
batch_size = 2000
for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
self.console.log(
f"On iteration {i} for creating embeddings for {name} data"
)
response = self.runner.api.gen_embedding(
model=embedding_model,
input=batch,
)
embeddings.extend([data["embedding"] for data in response["data"]])
total_cost += completion_cost(response)
return embeddings, total_cost
self.console.log(
f"[cyan]Creating embeddings for {len(left_data)} left + {len(right_data)} right items...[/cyan]"
)
left_embeddings, left_cost = get_embeddings(
left_data, left_keys, "left"
)
right_embeddings, right_cost = get_embeddings(
right_data, right_keys, "right"
)
total_cost += left_cost + right_cost
left_embeddings, left_cost = get_embeddings(left_data, left_keys, "left")
right_embeddings, right_cost = get_embeddings(
right_data, right_keys, "right"
)
total_cost += left_cost + right_cost
self.console.log(
f"Created embeddings for datasets. Total embedding creation cost: {total_cost}"
)
# Compute all cosine similarities in one call
from sklearn.metrics.pairwise import cosine_similarity

View File

@ -26,7 +26,6 @@ class ExtractOperation(BaseOperation):
timeout: int | None = None
skip_on_error: bool = False
litellm_completion_kwargs: dict[str, Any] = Field(default_factory=dict)
limit: int | None = Field(None, gt=0)
@field_validator("prompt")
def validate_prompt(cls, v):
@ -410,10 +409,6 @@ Return only the JSON object with your patterns, no explanatory text.
Returns:
tuple[list[dict], float]: A tuple containing the processed data and the total cost of the operation.
"""
limit_value = self.config.get("limit")
if limit_value is not None:
input_data = input_data[:limit_value]
if not input_data:
return [], 0.0

View File

@ -33,30 +33,6 @@ class FilterOperation(MapOperation):
return self
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._filter_key = next(
iter(
[
k
for k in self.config["output"]["schema"].keys()
if k != "_short_explanation"
]
)
)
self._filter_is_build = False
def _limit_applies_to_inputs(self) -> bool:
return False
def _handle_result(self, result: dict[str, Any]) -> tuple[dict | None, bool]:
keep_record = bool(result.get(self._filter_key))
result.pop(self._filter_key, None)
if self._filter_is_build or keep_record:
return result, keep_record
return None, False
def execute(
self, input_data: list[dict], is_build: bool = False
) -> tuple[list[dict], float]:
@ -70,10 +46,56 @@ class FilterOperation(MapOperation):
Returns:
tuple[list[dict], float]: A tuple containing the filtered list of dictionaries
and the total cost of the operation.
This method performs the following steps:
1. Processes each input item using an LLM model
2. Validates the output
3. Filters the results based on the specified filter key
4. Calculates the total cost of the operation
The method uses multi-threading to process items in parallel, improving performance
for large datasets.
Usage:
```python
from docetl.operations import FilterOperation
config = {
"prompt": "Determine if the following item is important: {{input}}",
"output": {
"schema": {"is_important": "bool"}
},
"model": "gpt-3.5-turbo"
}
filter_op = FilterOperation(config)
input_data = [
{"id": 1, "text": "Critical update"},
{"id": 2, "text": "Regular maintenance"}
]
results, cost = filter_op.execute(input_data)
print(f"Filtered results: {results}")
print(f"Total cost: {cost}")
```
"""
previous_state = self._filter_is_build
self._filter_is_build = is_build
try:
return super().execute(input_data)
finally:
self._filter_is_build = previous_state
filter_key = next(
iter(
[
k
for k in self.config["output"]["schema"].keys()
if k != "_short_explanation"
]
)
)
# Inject retrieval into inner map execution (super)
results, total_cost = super().execute(input_data)
# Drop records with filter_key values that are False
if not is_build:
results = [result for result in results if result[filter_key]]
# Drop the filter_key from the results
for result in results:
result.pop(filter_key, None)
return results, total_cost

View File

@ -43,7 +43,6 @@ class MapOperation(BaseOperation):
litellm_completion_kwargs: dict[str, Any] = {}
pdf_url_key: str | None = None
flush_partial_result: bool = False
limit: int | None = Field(None, gt=0)
# Calibration parameters
calibrate: bool = False
num_calibration_docs: int = Field(10, gt=0)
@ -152,12 +151,6 @@ class MapOperation(BaseOperation):
# Mark that we need to append document statement
self.config["_append_document_to_batch_prompt"] = True
def _limit_applies_to_inputs(self) -> bool:
return True
def _handle_result(self, result: dict[str, Any]) -> tuple[dict | None, bool]:
return result, True
def _generate_calibration_context(self, input_data: list[dict]) -> str:
"""
Generate calibration context by running the operation on a sample of documents
@ -278,27 +271,17 @@ Reference anchors:"""
The method uses parallel processing to improve performance.
"""
limit_value = self.config.get("limit")
# Check if there's no prompt and only drop_keys
if "prompt" not in self.config and "drop_keys" in self.config:
data_to_process = input_data
if limit_value is not None and self._limit_applies_to_inputs():
data_to_process = input_data[:limit_value]
# If only drop_keys is specified, simply drop the keys and return
dropped_results = []
for item in data_to_process:
for item in input_data:
new_item = {
k: v for k, v in item.items() if k not in self.config["drop_keys"]
}
dropped_results.append(new_item)
if limit_value is not None and len(dropped_results) >= limit_value:
break
return dropped_results, 0.0 # Return the modified data with no cost
if limit_value is not None and self._limit_applies_to_inputs():
input_data = input_data[:limit_value]
# Generate calibration context if enabled
calibration_context = ""
if self.config.get("calibrate", False) and "prompt" in self.config:
@ -544,87 +527,40 @@ Reference anchors:"""
return all_results, total_cost
limit_counter = 0
batch_size = self.max_batch_size if self.max_batch_size is not None else 1
total_batches = (len(input_data) + batch_size - 1) // batch_size
if total_batches == 0:
if self.status:
self.status.start()
return [], 0.0
worker_limit = self.max_batch_size or self.max_threads or 1
window_size = (
total_batches
if limit_value is None
else max(1, (limit_value + batch_size - 1) // batch_size)
)
results: list[dict] = []
total_cost = 0.0
limit_reached = False
op_name = self.config["name"]
if limit_value is not None and not self._limit_applies_to_inputs():
self.console.log(
f"[yellow]Note: Operation will terminate early once {limit_value} items pass the filter condition.[/yellow]"
)
with ThreadPoolExecutor(max_workers=worker_limit) as executor:
with RichLoopBar(
total=total_batches,
desc=f"Processing {op_name} (map) on all documents",
with ThreadPoolExecutor(max_workers=self.max_batch_size) as executor:
batch_size = self.max_batch_size if self.max_batch_size is not None else 1
futures = []
for i in range(0, len(input_data), batch_size):
batch = input_data[i : i + batch_size]
futures.append(executor.submit(_process_map_batch, batch))
results = []
total_cost = 0
pbar = RichLoopBar(
range(len(futures)),
desc=f"Processing {self.config['name']} (map) on all documents",
console=self.console,
) as pbar:
chunk_start = 0
while chunk_start < total_batches and not limit_reached:
chunk_end = min(total_batches, chunk_start + window_size)
chunk_ordinals = list(range(chunk_start, chunk_end))
futures = []
for ordinal in chunk_ordinals:
start_idx = ordinal * batch_size
batch = input_data[start_idx : start_idx + batch_size]
futures.append(executor.submit(_process_map_batch, batch))
for relative_idx, future in enumerate(futures):
if limit_value is not None and limit_counter >= limit_value:
limit_reached = True
break
result_list, item_cost = future.result()
total_cost += item_cost
if result_list:
if "drop_keys" in self.config:
result_list = [
{
k: v
for k, v in result.items()
if k not in self.config["drop_keys"]
}
for result in result_list
]
if self.config.get("flush_partial_results", False):
self.runner._flush_partial_results(
op_name, chunk_ordinals[relative_idx], result_list
)
for result in result_list:
processed_result, counts_towards_limit = (
self._handle_result(result)
)
if processed_result is not None:
results.append(processed_result)
if limit_value is not None and counts_towards_limit:
limit_counter += 1
if limit_counter >= limit_value:
limit_reached = True
break
pbar.update()
chunk_start = chunk_end
)
for batch_index in pbar:
result_list, item_cost = futures[batch_index].result()
if result_list:
if "drop_keys" in self.config:
result_list = [
{
k: v
for k, v in result.items()
if k not in self.config["drop_keys"]
}
for result in result_list
]
results.extend(result_list)
# --- BEGIN: Flush partial checkpoint ---
if self.config.get("flush_partial_results", False):
op_name = self.config["name"]
self.runner._flush_partial_results(
op_name, batch_index, result_list
)
# --- END: Flush partial checkpoint ---
total_cost += item_cost
if self.status:
self.status.start()

View File

@ -67,7 +67,6 @@ class ReduceOperation(BaseOperation):
timeout: int | None = None
litellm_completion_kwargs: dict[str, Any] = Field(default_factory=dict)
enable_observability: bool = False
limit: int | None = Field(None, gt=0)
@field_validator("prompt")
def validate_prompt(cls, v):
@ -286,12 +285,6 @@ class ReduceOperation(BaseOperation):
# Convert the grouped data to a list of tuples
grouped_data = list(grouped_data.items())
limit_value = self.config.get("limit")
if limit_value is not None:
# Sort by group size (smallest first) and take the limit
grouped_data = sorted(grouped_data, key=lambda x: len(x[1]))
grouped_data = grouped_data[:limit_value]
def process_group(
key: tuple, group_elems: list[dict]
) -> tuple[dict | None, float]:
@ -426,9 +419,6 @@ class ReduceOperation(BaseOperation):
if output is not None:
results.append(output)
if limit_value is not None and len(results) > limit_value:
results = results[:limit_value]
if self.config.get("persist_intermediates", False):
for result in results:
key = tuple(result[k] for k in self.config["reduce_key"])

View File

@ -10,10 +10,10 @@ import jinja2
from jinja2 import Template
from litellm import model_cost
from pydantic import Field, ValidationInfo, field_validator, model_validator
from rich.prompt import Confirm
from docetl.operations.base import BaseOperation
from docetl.operations.utils import RichLoopBar, rich_as_completed, strict_render
from docetl.operations.utils.blocking import RuntimeBlockingOptimizer
from docetl.utils import (
completion_cost,
extract_jinja_variables,
@ -40,7 +40,6 @@ class ResolveOperation(BaseOperation):
comparison_model: str | None = None
blocking_keys: list[str] | None = None
blocking_threshold: float | None = Field(None, ge=0, le=1)
blocking_target_recall: float | None = Field(None, ge=0, le=1)
blocking_conditions: list[str] | None = None
input: dict[str, Any] | None = None
embedding_batch_size: int | None = Field(None, gt=0)
@ -267,75 +266,26 @@ class ResolveOperation(BaseOperation):
blocking_keys = self.config.get("blocking_keys", [])
blocking_threshold = self.config.get("blocking_threshold")
blocking_conditions = self.config.get("blocking_conditions", [])
limit_comparisons = self.config.get("limit_comparisons")
total_cost = 0
if self.status:
self.status.stop()
# Track pre-computed embeddings from auto-optimization
precomputed_embeddings = None
# Auto-compute blocking threshold if no blocking configuration is provided
if not blocking_threshold and not blocking_conditions and not limit_comparisons:
# Get target recall from operation config (default 0.95)
target_recall = self.config.get("blocking_target_recall", 0.95)
self.console.log(
f"[yellow]No blocking configuration. Auto-computing threshold (target recall: {target_recall:.0%})...[/yellow]"
)
# Determine blocking keys if not set
auto_blocking_keys = blocking_keys if blocking_keys else None
if not auto_blocking_keys:
prompt_template = self.config.get("comparison_prompt", "")
prompt_vars = extract_jinja_variables(prompt_template)
prompt_vars = [
var
for var in prompt_vars
if var not in ["input", "input1", "input2"]
]
auto_blocking_keys = list(
set([var.split(".")[-1] for var in prompt_vars])
)
if not auto_blocking_keys:
auto_blocking_keys = list(input_data[0].keys())
blocking_keys = auto_blocking_keys
# Create comparison function for threshold optimization
def compare_fn_for_optimization(item1, item2):
return self.compare_pair(
self.config["comparison_prompt"],
self.config.get("comparison_model", self.default_model),
item1,
item2,
blocking_keys=[], # Don't use key-based shortcut during optimization
timeout_seconds=self.config.get("timeout", 120),
max_retries_per_timeout=self.config.get(
"max_retries_per_timeout", 2
),
)
# Run threshold optimization
optimizer = RuntimeBlockingOptimizer(
runner=self.runner,
config=self.config,
default_model=self.default_model,
max_threads=self.max_threads,
console=self.console,
target_recall=target_recall,
sample_size=min(100, len(input_data) * (len(input_data) - 1) // 4),
)
blocking_threshold, precomputed_embeddings, optimization_cost = (
optimizer.optimize_resolve(
input_data,
compare_fn_for_optimization,
blocking_keys=blocking_keys,
)
)
total_cost += optimization_cost
if not blocking_threshold and not blocking_conditions:
# Prompt the user for confirmation
if not Confirm.ask(
"[yellow]Warning: No blocking keys or conditions specified. "
"This may result in a large number of comparisons. "
"We recommend specifying at least one blocking key or condition, or using the optimizer to automatically come up with these. "
"Do you want to continue without blocking?[/yellow]",
console=self.runner.console,
):
raise ValueError("Operation cancelled by user.")
input_schema = self.config.get("input", {}).get("schema", {})
if not blocking_keys:
# Set them to all keys in the input data
blocking_keys = list(input_data[0].keys())
limit_comparisons = self.config.get("limit_comparisons")
total_cost = 0
def is_match(item1: dict[str, Any], item2: dict[str, Any]) -> bool:
return any(
@ -346,101 +296,120 @@ class ResolveOperation(BaseOperation):
# Calculate embeddings if blocking_threshold is set
embeddings = None
if blocking_threshold is not None:
# Use precomputed embeddings if available from auto-optimization
if precomputed_embeddings is not None:
embeddings = precomputed_embeddings
else:
self.console.log(
f"[cyan]Creating embeddings for {len(input_data)} items...[/cyan]"
)
def get_embeddings_batch(
items: list[dict[str, Any]]
) -> list[tuple[list[float], float]]:
embedding_model = self.config.get(
"embedding_model", "text-embedding-3-small"
)
model_input_context_length = model_cost.get(embedding_model, {}).get(
"max_input_tokens", 8192
)
batch_size = self.config.get("embedding_batch_size", 1000)
embeddings = []
embedding_cost = 0.0
num_batches = (len(input_data) + batch_size - 1) // batch_size
for batch_idx in range(num_batches):
start_idx = batch_idx * batch_size
end_idx = min(start_idx + batch_size, len(input_data))
batch = input_data[start_idx:end_idx]
if num_batches > 1:
self.console.log(
f"[dim]Creating embeddings: batch {batch_idx + 1}/{num_batches} "
f"({end_idx}/{len(input_data)} items)[/dim]"
)
texts = [
" ".join(
str(item[key]) for key in blocking_keys if key in item
)[: model_input_context_length * 3]
for item in batch
texts = [
" ".join(str(item[key]) for key in blocking_keys if key in item)[
: model_input_context_length * 3
]
response = self.runner.api.gen_embedding(
model=embedding_model, input=texts
)
embeddings.extend([data["embedding"] for data in response["data"]])
embedding_cost += completion_cost(response)
for item in items
]
total_cost += embedding_cost
response = self.runner.api.gen_embedding(
model=embedding_model, input=texts
)
return [
(data["embedding"], completion_cost(response))
for data in response["data"]
]
# Build a mapping of blocking key values to indices
# This is used later for cluster merging (when two items match, merge all items sharing their key values)
value_to_indices: dict[tuple[str, ...], list[int]] = {}
for i, item in enumerate(input_data):
key = tuple(str(item.get(k, "")) for k in blocking_keys)
if key not in value_to_indices:
value_to_indices[key] = []
value_to_indices[key].append(i)
embeddings = []
costs = []
with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
for i in range(
0, len(input_data), self.config.get("embedding_batch_size", 1000)
):
batch = input_data[
i : i + self.config.get("embedding_batch_size", 1000)
]
batch_results = list(executor.map(get_embeddings_batch, [batch]))
# Total number of pairs to potentially compare
n = len(input_data)
total_pairs = n * (n - 1) // 2
for result in batch_results:
embeddings.extend([r[0] for r in result])
costs.extend([r[1] for r in result])
# Apply code-based blocking conditions (check all pairs)
code_blocked_pairs = []
if blocking_conditions:
for i in range(n):
for j in range(i + 1, n):
if is_match(input_data[i], input_data[j]):
code_blocked_pairs.append((i, j))
total_cost += sum(costs)
# Generate all pairs to compare, ensuring no duplicate comparisons
def get_unique_comparison_pairs() -> (
tuple[list[tuple[int, int]], dict[tuple[str, ...], list[int]]]
):
# Create a mapping of values to their indices
value_to_indices: dict[tuple[str, ...], list[int]] = {}
for i, item in enumerate(input_data):
# Create a hashable key from the blocking keys
key = tuple(str(item.get(k, "")) for k in blocking_keys)
if key not in value_to_indices:
value_to_indices[key] = []
value_to_indices[key].append(i)
# Generate pairs for comparison, comparing each unique value combination only once
comparison_pairs = []
keys = list(value_to_indices.keys())
# First, handle comparisons between different values
for i in range(len(keys)):
for j in range(i + 1, len(keys)):
# Only need one comparison between different values
idx1 = value_to_indices[keys[i]][0]
idx2 = value_to_indices[keys[j]][0]
if idx1 < idx2: # Maintain ordering to avoid duplicates
comparison_pairs.append((idx1, idx2))
return comparison_pairs, value_to_indices
comparison_pairs, value_to_indices = get_unique_comparison_pairs()
# Filter pairs based on blocking conditions
def meets_blocking_conditions(pair: tuple[int, int]) -> bool:
i, j = pair
return (
is_match(input_data[i], input_data[j]) if blocking_conditions else False
)
# Start with pairs that meet blocking conditions, or empty list if no conditions
code_blocked_pairs = (
list(filter(meets_blocking_conditions, comparison_pairs))
if blocking_conditions
else []
)
# Apply cosine similarity blocking if threshold is specified
embedding_blocked_pairs = []
if blocking_threshold is not None and embeddings is not None:
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
similarity_matrix = cosine_similarity(embeddings)
# Add pairs that meet the cosine similarity threshold and aren't already blocked
code_blocked_set = set(code_blocked_pairs)
# Use numpy to efficiently find all pairs above threshold
i_indices, j_indices = np.triu_indices(n, k=1)
similarities = similarity_matrix[i_indices, j_indices]
above_threshold_mask = similarities >= blocking_threshold
for i, j in comparison_pairs:
if (i, j) not in code_blocked_set:
similarity = similarity_matrix[i, j]
if similarity >= blocking_threshold:
embedding_blocked_pairs.append((i, j))
# Get pairs above threshold
above_threshold_i = i_indices[above_threshold_mask]
above_threshold_j = j_indices[above_threshold_mask]
self.console.log(
f"Cosine similarity blocking: added {len(embedding_blocked_pairs)} pairs "
f"(threshold: {blocking_threshold})"
)
# Filter out pairs already in code_blocked_set
embedding_blocked_pairs = [
(int(i), int(j))
for i, j in zip(above_threshold_i, above_threshold_j)
if (i, j) not in code_blocked_set
]
# Combine pairs from both blocking methods
# Combine pairs with prioritization for sampling
all_blocked_pairs = code_blocked_pairs + embedding_blocked_pairs
# If no blocking was applied, compare all pairs
if not blocking_conditions and blocking_threshold is None:
all_blocked_pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]
# If no pairs are blocked at all, fall back to all comparison pairs
if not all_blocked_pairs:
all_blocked_pairs = comparison_pairs
# Apply limit_comparisons with prioritization
if limit_comparisons is not None and len(all_blocked_pairs) > limit_comparisons:
# Prioritize code-based pairs, then sample from embedding pairs if needed
@ -507,6 +476,18 @@ class ResolveOperation(BaseOperation):
cluster_map[root_idx] = root1
clusters[root_idx] = set()
# Calculate and print statistics
total_possible_comparisons = len(input_data) * (len(input_data) - 1) // 2
comparisons_made = len(blocked_pairs)
comparisons_saved = total_possible_comparisons - comparisons_made
self.console.log(
f"[green]Comparisons saved by deduping and blocking: {comparisons_saved} "
f"({(comparisons_saved / total_possible_comparisons) * 100:.2f}%)[/green]"
)
self.console.log(
f"[blue]Number of pairs to compare: {len(blocked_pairs)}[/blue]"
)
# Compute an auto-batch size based on the number of comparisons
def auto_batch() -> int:
# Maximum batch size limit for 4o-mini model
@ -532,14 +513,7 @@ class ResolveOperation(BaseOperation):
# Compare pairs and update clusters in real-time
batch_size = self.config.get("compare_batch_size", auto_batch())
# Log blocking summary
total_possible_comparisons = len(input_data) * (len(input_data) - 1) // 2
self.console.log(
f"Comparing {len(blocked_pairs):,} pairs "
f"({len(blocked_pairs)/total_possible_comparisons*100:.1f}% of {total_possible_comparisons:,} total, "
f"batch size: {batch_size})"
)
self.console.log(f"Using compare batch size: {batch_size}")
pair_costs = 0
pbar = RichLoopBar(

View File

@ -1,5 +1,4 @@
from .api import APIWrapper
from .blocking import RuntimeBlockingOptimizer
from .cache import (
cache,
cache_key,
@ -16,7 +15,6 @@ from .validation import safe_eval, convert_val, convert_dict_schema_to_list_sche
__all__ = [
'APIWrapper',
'RuntimeBlockingOptimizer',
'cache',
'cache_key',
'clear_cache',

View File

@ -1,567 +0,0 @@
"""
Runtime blocking threshold optimization utilities.
This module provides functionality for automatically computing embedding-based
blocking thresholds at runtime when no blocking configuration is provided.
"""
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Callable
import numpy as np
from litellm import model_cost
from rich.console import Console
from docetl.utils import completion_cost, extract_jinja_variables
class RuntimeBlockingOptimizer:
"""
Computes optimal embedding-based blocking thresholds at runtime.
This class samples pairs from the dataset, performs LLM comparisons,
and finds the optimal cosine similarity threshold that achieves a
target recall rate.
"""
def __init__(
self,
runner,
config: dict[str, Any],
default_model: str,
max_threads: int,
console: Console,
target_recall: float = 0.95,
sample_size: int = 100,
sampling_weight: float = 20.0,
):
"""
Initialize the RuntimeBlockingOptimizer.
Args:
runner: The pipeline runner instance.
config: Operation configuration.
default_model: Default LLM model for comparisons.
max_threads: Maximum threads for parallel processing.
console: Rich console for logging.
target_recall: Target recall rate (default 0.95).
sample_size: Number of pairs to sample for threshold estimation.
sampling_weight: Weight for exponential sampling towards higher similarities.
"""
self.runner = runner
self.config = config
self.default_model = default_model
self.max_threads = max_threads
self.console = console
self.target_recall = target_recall
self.sample_size = sample_size
self.sampling_weight = sampling_weight
def compute_embeddings(
self,
input_data: list[dict[str, Any]],
keys: list[str],
embedding_model: str | None = None,
batch_size: int = 1000,
) -> tuple[list[list[float]], float]:
"""
Compute embeddings for the input data.
Args:
input_data: List of input documents.
keys: Keys to use for embedding text.
embedding_model: Model to use for embeddings.
batch_size: Batch size for embedding computation.
Returns:
Tuple of (embeddings list, total cost).
"""
embedding_model = embedding_model or self.config.get(
"embedding_model", "text-embedding-3-small"
)
model_input_context_length = model_cost.get(embedding_model, {}).get(
"max_input_tokens", 8192
)
texts = [
" ".join(str(item[key]) for key in keys if key in item)[
: model_input_context_length * 3
]
for item in input_data
]
self.console.log(f"[cyan]Creating embeddings for {len(texts)} items...[/cyan]")
embeddings = []
total_cost = 0.0
num_batches = (len(texts) + batch_size - 1) // batch_size
for batch_idx, i in enumerate(range(0, len(texts), batch_size)):
batch = texts[i : i + batch_size]
if num_batches > 1:
self.console.log(
f"[dim] Batch {batch_idx + 1}/{num_batches} "
f"({len(embeddings) + len(batch)}/{len(texts)} items)[/dim]"
)
response = self.runner.api.gen_embedding(
model=embedding_model,
input=batch,
)
embeddings.extend([data["embedding"] for data in response["data"]])
total_cost += completion_cost(response)
return embeddings, total_cost
def calculate_cosine_similarities_self(
self, embeddings: list[list[float]]
) -> list[tuple[int, int, float]]:
"""
Calculate pairwise cosine similarities for self-join.
Args:
embeddings: List of embedding vectors.
Returns:
List of (i, j, similarity) tuples for all pairs where i < j.
"""
embeddings_array = np.array(embeddings)
norms = np.linalg.norm(embeddings_array, axis=1)
# Avoid division by zero
norms = np.where(norms == 0, 1e-10, norms)
dot_products = np.dot(embeddings_array, embeddings_array.T)
similarities_matrix = dot_products / np.outer(norms, norms)
i, j = np.triu_indices(len(embeddings), k=1)
similarities = list(
zip(i.tolist(), j.tolist(), similarities_matrix[i, j].tolist())
)
return similarities
def calculate_cosine_similarities_cross(
self,
left_embeddings: list[list[float]],
right_embeddings: list[list[float]],
) -> list[tuple[int, int, float]]:
"""
Calculate cosine similarities between two sets of embeddings.
Args:
left_embeddings: Embeddings for left dataset.
right_embeddings: Embeddings for right dataset.
Returns:
List of (left_idx, right_idx, similarity) tuples.
"""
left_array = np.array(left_embeddings)
right_array = np.array(right_embeddings)
dot_product = np.dot(left_array, right_array.T)
norm_left = np.linalg.norm(left_array, axis=1)
norm_right = np.linalg.norm(right_array, axis=1)
# Avoid division by zero
norm_left = np.where(norm_left == 0, 1e-10, norm_left)
norm_right = np.where(norm_right == 0, 1e-10, norm_right)
similarities = dot_product / np.outer(norm_left, norm_right)
return [
(i, j, float(sim))
for i, row in enumerate(similarities)
for j, sim in enumerate(row)
]
def sample_pairs(
self,
similarities: list[tuple[int, int, float]],
num_bins: int = 10,
stratified_fraction: float = 0.5,
) -> list[tuple[int, int]]:
"""
Sample pairs using a hybrid of stratified and exponential-weighted sampling.
This ensures coverage across the similarity distribution while still
focusing on high-similarity pairs where matches are more likely.
Args:
similarities: List of (i, j, similarity) tuples.
num_bins: Number of bins for stratified sampling.
stratified_fraction: Fraction of samples to allocate to stratified sampling.
Returns:
List of sampled (i, j) pairs.
"""
if len(similarities) == 0:
return []
sample_count = min(self.sample_size, len(similarities))
stratified_count = int(sample_count * stratified_fraction)
exponential_count = sample_count - stratified_count
sampled_indices = set()
sim_values = np.array([sim[2] for sim in similarities])
# Part 1: Stratified sampling across bins
if stratified_count > 0:
bin_edges = np.linspace(
sim_values.min(), sim_values.max() + 1e-9, num_bins + 1
)
samples_per_bin = max(1, stratified_count // num_bins)
for bin_idx in range(num_bins):
bin_mask = (sim_values >= bin_edges[bin_idx]) & (
sim_values < bin_edges[bin_idx + 1]
)
bin_indices = np.where(bin_mask)[0]
if len(bin_indices) > 0:
# Within each bin, use exponential weighting
bin_sims = sim_values[bin_indices]
bin_weights = np.exp(self.sampling_weight * bin_sims)
bin_weights /= bin_weights.sum()
n_to_sample = min(samples_per_bin, len(bin_indices))
chosen = np.random.choice(
bin_indices,
size=n_to_sample,
replace=False,
p=bin_weights,
)
sampled_indices.update(chosen.tolist())
# Part 2: Exponential-weighted sampling for remaining slots
if exponential_count > 0:
remaining_indices = [
i for i in range(len(similarities)) if i not in sampled_indices
]
if remaining_indices:
remaining_sims = sim_values[remaining_indices]
weights = np.exp(self.sampling_weight * remaining_sims)
weights /= weights.sum()
n_to_sample = min(exponential_count, len(remaining_indices))
chosen = np.random.choice(
remaining_indices,
size=n_to_sample,
replace=False,
p=weights,
)
sampled_indices.update(chosen.tolist())
sampled_pairs = [
(similarities[i][0], similarities[i][1]) for i in sampled_indices
]
return sampled_pairs
def _print_similarity_histogram(
self,
similarities: list[tuple[int, int, float]],
comparison_results: list[tuple[int, int, bool]],
threshold: float | None = None,
):
"""
Print a histogram of embedding cosine similarity distribution.
Args:
similarities: List of (i, j, similarity) tuples.
comparison_results: List of (i, j, is_match) from LLM comparisons.
threshold: Optional threshold to highlight in the histogram.
"""
# Filter out self-similarities (similarity == 1)
flat_similarities = [sim[2] for sim in similarities if sim[2] != 1]
if not flat_similarities:
return
hist, bin_edges = np.histogram(flat_similarities, bins=20)
max_bar_width, max_count = 40, max(hist) if max(hist) > 0 else 1
normalized_hist = [int(count / max_count * max_bar_width) for count in hist]
# Create a dictionary to store true labels
true_labels = {(i, j): is_match for i, j, is_match in comparison_results}
# Count pairs above threshold
pairs_above_threshold = (
sum(1 for sim in flat_similarities if sim >= threshold) if threshold else 0
)
total_pairs = len(flat_similarities)
lines = []
for i, count in enumerate(normalized_hist):
bar = "" * count
bin_start, bin_end = bin_edges[i], bin_edges[i + 1]
label = f"{bin_start:.2f}-{bin_end:.2f}"
# Count true matches and not matches in this bin
true_matches = 0
not_matches = 0
labeled_count = 0
for sim in similarities:
if bin_start <= sim[2] < bin_end:
if (sim[0], sim[1]) in true_labels:
labeled_count += 1
if true_labels[(sim[0], sim[1])]:
true_matches += 1
else:
not_matches += 1
# Calculate percentages of labeled pairs
if labeled_count > 0:
true_match_percent = (true_matches / labeled_count) * 100
label_info = f"[green]{true_match_percent:5.1f}%[/green] match"
else:
label_info = "[dim]--[/dim]"
# Highlight the bin containing the threshold
if threshold is not None and bin_start <= threshold < bin_end:
lines.append(
f"[bold yellow]{label}[/bold yellow] {bar:<{max_bar_width}} "
f"[dim]n={hist[i]:>5}[/dim] {label_info} [bold yellow]◀ threshold[/bold yellow]"
)
else:
lines.append(
f"{label} {bar:<{max_bar_width}} "
f"[dim]n={hist[i]:>5}[/dim] {label_info}"
)
from rich.panel import Panel
histogram_content = "\n".join(lines)
title = f"Similarity Distribution ({pairs_above_threshold:,} of {total_pairs:,} pairs ≥ {threshold:.4f})"
self.console.log(Panel(histogram_content, title=title, border_style="cyan"))
def find_optimal_threshold(
self,
comparisons: list[tuple[int, int, bool]],
similarities: list[tuple[int, int, float]],
) -> tuple[float, float]:
"""
Find the optimal similarity threshold that achieves target recall.
Args:
comparisons: List of (i, j, is_match) from LLM comparisons.
similarities: List of (i, j, similarity) tuples.
Returns:
Tuple of (optimal_threshold, achieved_recall).
"""
if not comparisons or not any(comp[2] for comp in comparisons):
# No matches found, use a high threshold to be conservative
self.console.log(
"[yellow]No matches found in sample. Using 99th percentile "
"similarity as threshold.[/yellow]"
)
all_sims = [sim[2] for sim in similarities]
threshold = float(np.percentile(all_sims, 99)) if all_sims else 0.9
return threshold, 0.0
true_labels = np.array([comp[2] for comp in comparisons])
sim_dict = {(i, j): sim for i, j, sim in similarities}
sim_scores = np.array([sim_dict.get((i, j), 0.0) for i, j, _ in comparisons])
thresholds = np.linspace(0, 1, 100)
recalls = []
for threshold in thresholds:
predictions = sim_scores >= threshold
tp = np.sum(predictions & true_labels)
fn = np.sum(~predictions & true_labels)
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
recalls.append(recall)
# Find highest threshold that achieves target recall
valid_indices = [i for i, r in enumerate(recalls) if r >= self.target_recall]
if not valid_indices:
# If no threshold achieves target recall, use the one with highest recall
best_idx = int(np.argmax(recalls))
optimal_threshold = float(thresholds[best_idx])
achieved_recall = float(recalls[best_idx])
self.console.log(
f"[yellow]Warning: Could not achieve target recall {self.target_recall:.0%}. "
f"Using threshold {optimal_threshold:.4f} with recall {achieved_recall:.2%}.[/yellow]"
)
else:
best_idx = max(valid_indices)
optimal_threshold = float(thresholds[best_idx])
achieved_recall = float(recalls[best_idx])
return round(optimal_threshold, 4), achieved_recall
def optimize_resolve(
self,
input_data: list[dict[str, Any]],
compare_fn: Callable[[dict, dict], tuple[bool, float, str]],
blocking_keys: list[str] | None = None,
) -> tuple[float, list[list[float]], float]:
"""
Compute optimal blocking threshold for resolve operation.
Args:
input_data: Input dataset.
compare_fn: Function to compare two items, returns (is_match, cost, prompt).
blocking_keys: Keys to use for blocking. If None, extracted from prompt.
Returns:
Tuple of (threshold, embeddings, total_cost).
"""
from rich.panel import Panel
# Determine blocking keys
if not blocking_keys:
prompt_template = self.config.get("comparison_prompt", "")
prompt_vars = extract_jinja_variables(prompt_template)
prompt_vars = [
var for var in prompt_vars if var not in ["input", "input1", "input2"]
]
blocking_keys = list(set([var.split(".")[-1] for var in prompt_vars]))
if not blocking_keys:
blocking_keys = list(input_data[0].keys())
# Compute embeddings
embeddings, embedding_cost = self.compute_embeddings(input_data, blocking_keys)
# Calculate similarities
similarities = self.calculate_cosine_similarities_self(embeddings)
# Sample pairs
sampled_pairs = self.sample_pairs(similarities)
if not sampled_pairs:
self.console.log(
"[yellow]No pairs to sample. Using default threshold 0.8.[/yellow]"
)
return 0.8, embeddings, embedding_cost
# Perform comparisons
comparisons = []
comparison_cost = 0.0
matches_found = 0
with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
futures = {
executor.submit(compare_fn, input_data[i], input_data[j]): (i, j)
for i, j in sampled_pairs
}
for future in as_completed(futures):
i, j = futures[future]
try:
is_match, cost, _ = future.result()
comparisons.append((i, j, is_match))
comparison_cost += cost
if is_match:
matches_found += 1
except Exception as e:
self.console.log(f"[red]Comparison error: {e}[/red]")
comparisons.append((i, j, False))
# Find optimal threshold
threshold, achieved_recall = self.find_optimal_threshold(
comparisons, similarities
)
total_cost = embedding_cost + comparison_cost
# Print histogram visualization
self._print_similarity_histogram(similarities, comparisons, threshold)
# Print summary
n = len(input_data)
total_pairs = n * (n - 1) // 2
pairs_above = sum(1 for s in similarities if s[2] >= threshold)
summary = (
f"[bold]Blocking keys:[/bold] {blocking_keys}\n"
f"[bold]Sampled:[/bold] {len(sampled_pairs)} pairs → {matches_found} matches ({matches_found/len(sampled_pairs)*100:.1f}%)\n"
f"[bold]Threshold:[/bold] {threshold:.4f}{achieved_recall:.1%} recall (target: {self.target_recall:.0%})\n"
f"[bold]Pairs to compare:[/bold] {pairs_above:,} of {total_pairs:,} ({pairs_above/total_pairs*100:.1f}%)\n"
f"[bold]Optimization cost:[/bold] ${total_cost:.4f}"
)
self.console.log(
Panel(
summary, title="Blocking Threshold Optimization", border_style="green"
)
)
return threshold, embeddings, total_cost
def optimize_equijoin(
self,
left_data: list[dict[str, Any]],
right_data: list[dict[str, Any]],
compare_fn: Callable[[dict, dict], tuple[bool, float]],
left_keys: list[str] | None = None,
right_keys: list[str] | None = None,
) -> tuple[float, list[list[float]], list[list[float]], float]:
"""
Compute optimal blocking threshold for equijoin operation.
Args:
left_data: Left dataset.
right_data: Right dataset.
compare_fn: Function to compare two items, returns (is_match, cost).
left_keys: Keys to use for left dataset embeddings.
right_keys: Keys to use for right dataset embeddings.
Returns:
Tuple of (threshold, left_embeddings, right_embeddings, total_cost).
"""
from rich.panel import Panel
# Determine keys
if not left_keys:
left_keys = list(left_data[0].keys()) if left_data else []
if not right_keys:
right_keys = list(right_data[0].keys()) if right_data else []
# Compute embeddings
left_embeddings, left_cost = self.compute_embeddings(left_data, left_keys)
right_embeddings, right_cost = self.compute_embeddings(right_data, right_keys)
embedding_cost = left_cost + right_cost
# Calculate cross similarities
similarities = self.calculate_cosine_similarities_cross(
left_embeddings, right_embeddings
)
# Sample pairs
sampled_pairs = self.sample_pairs(similarities)
if not sampled_pairs:
self.console.log(
"[yellow]No pairs to sample. Using default threshold 0.8.[/yellow]"
)
return 0.8, left_embeddings, right_embeddings, embedding_cost
# Perform comparisons
comparisons = []
comparison_cost = 0.0
matches_found = 0
with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
futures = {
executor.submit(compare_fn, left_data[i], right_data[j]): (i, j)
for i, j in sampled_pairs
}
for future in as_completed(futures):
i, j = futures[future]
try:
is_match, cost = future.result()
comparisons.append((i, j, is_match))
comparison_cost += cost
if is_match:
matches_found += 1
except Exception as e:
self.console.log(f"[red]Comparison error: {e}[/red]")
comparisons.append((i, j, False))
# Find optimal threshold
threshold, achieved_recall = self.find_optimal_threshold(
comparisons, similarities
)
total_cost = embedding_cost + comparison_cost
# Print histogram visualization
self._print_similarity_histogram(similarities, comparisons, threshold)
# Print summary
total_pairs = len(left_data) * len(right_data)
pairs_above = sum(1 for s in similarities if s[2] >= threshold)
summary = (
f"[bold]Left keys:[/bold] {left_keys} [bold]Right keys:[/bold] {right_keys}\n"
f"[bold]Sampled:[/bold] {len(sampled_pairs)} pairs → {matches_found} matches ({matches_found/len(sampled_pairs)*100:.1f}%)\n"
f"[bold]Threshold:[/bold] {threshold:.4f}{achieved_recall:.1%} recall (target: {self.target_recall:.0%})\n"
f"[bold]Pairs to compare:[/bold] {pairs_above:,} of {total_pairs:,} ({pairs_above/total_pairs*100:.1f}%)\n"
f"[bold]Optimization cost:[/bold] ${total_cost:.4f}"
)
self.console.log(
Panel(
summary, title="Blocking Threshold Optimization", border_style="green"
)
)
return threshold, left_embeddings, right_embeddings, total_cost

View File

@ -55,7 +55,7 @@ class Optimizer:
def __init__(
self,
runner: "DSLRunner",
rewrite_agent_model: str = "gpt-5.1",
rewrite_agent_model: str = "gpt-4o",
judge_agent_model: str = "gpt-4o-mini",
litellm_kwargs: dict[str, Any] = {},
resume: bool = False,
@ -67,7 +67,7 @@ class Optimizer:
Args:
yaml_file (str): Path to the YAML configuration file.
model (str): The name of the language model to use. Defaults to "gpt-5.1".
model (str): The name of the language model to use. Defaults to "gpt-4o".
resume (bool): Whether to resume optimization from a previous run. Defaults to False.
timeout (int): Timeout in seconds for operations. Defaults to 60.

View File

@ -1,926 +0,0 @@
"""
Fast decomposition analyzer for map operations.
This module provides a fast way to decompose map operations using directives,
running candidates on samples, and selecting the best via pairwise LLM comparison.
"""
import json
import os
import tempfile
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from copy import deepcopy
from typing import Any
import litellm
from litellm import completion, model_cost
from rich.console import Console
from docetl.console import get_console
from docetl.reasoning_optimizer.directives import (
ChainingDirective,
ClarifyInstructionsDirective,
DeterministicDocCompressionDirective,
DocumentChunkingDirective,
GleaningDirective,
IsolatingSubtasksDirective,
)
from docetl.runner import DSLRunner
from docetl.utils import completion_cost, count_tokens
# Drop unsupported params for models like gpt-5 that don't support temperature=0
litellm.drop_params = True
# Base directives always applicable to map operations
BASE_MAP_DIRECTIVES = [
ChainingDirective(),
IsolatingSubtasksDirective(),
GleaningDirective(),
ClarifyInstructionsDirective(),
]
# Directives that depend on document size
CHUNKING_DIRECTIVE = DocumentChunkingDirective()
COMPRESSION_DIRECTIVE = DeterministicDocCompressionDirective()
# Threshold for enabling DeterministicDocCompression (in characters)
DOC_COMPRESSION_CHAR_THRESHOLD = 1000
# Threshold for enabling DocumentChunking (as fraction of context window)
DOC_CHUNKING_CONTEXT_THRESHOLD = 0.10 # 10%
class FastDecomposer:
"""
Fast decomposition of map operations using directives and pairwise comparison.
Instead of the full optimizer flow, this:
1. Tries multiple directives to generate candidate decompositions
2. Runs each candidate on sample documents
3. Uses LLM judge for pairwise comparison
4. Returns the winning decomposition
"""
def __init__(
self,
yaml_config_path: str,
optimizer_model: str = "gpt-5.1",
sample_size: int = 5,
litellm_kwargs: dict[str, Any] | None = None,
console: Console | None = None,
):
"""
Initialize the decomposer.
Args:
yaml_config_path: Path to the pipeline YAML config file
optimizer_model: LLM model to use for directive instantiation and judging
sample_size: Number of sample documents to run candidates on
litellm_kwargs: Additional kwargs to pass to litellm.completion
console: Rich console for output (uses default if not provided)
"""
self.yaml_config_path = yaml_config_path
self.optimizer_model = optimizer_model
self.sample_size = sample_size
self.litellm_kwargs = litellm_kwargs or {}
if "temperature" not in self.litellm_kwargs:
self.litellm_kwargs["temperature"] = 0.0
self.total_cost = 0.0
self.console = console or get_console()
# Load the config
import yaml
with open(yaml_config_path, "r") as f:
self.config = yaml.safe_load(f)
self.operators = self.config.get("operations", [])
self.default_model = self.config.get("default_model", "gpt-4o-mini")
self.intermediate_dir = (
self.config.get("pipeline", {}).get("output", {}).get("intermediate_dir")
)
def _log(self, message: str) -> None:
"""Log a message to the console."""
self.console.log(message)
def get_model_context_limit(self, model: str) -> int:
"""
Get the context window limit for a model.
Args:
model: The model name (e.g., 'gpt-4o', 'azure/gpt-4')
Returns:
Maximum number of input tokens the model can handle
"""
model_info = model_cost.get(model, {})
# Try without provider prefix if not found
if not model_info:
model_name = model.split("/")[-1]
model_info = model_cost.get(model_name, {})
return model_info.get("max_input_tokens", 128000) # Default to 128k
def get_avg_doc_size(
self, sample_data: list[dict[str, Any]], op_config: dict[str, Any]
) -> tuple[float, float]:
"""
Calculate the average document size in characters and tokens.
Extracts the document content from sample data based on the operation's
prompt template (looks for {{ input.field_name }} patterns).
Args:
sample_data: List of sample documents
op_config: The operation configuration
Returns:
Tuple of (avg_chars, avg_tokens) for the document content
"""
import re
if not sample_data:
return 0.0, 0.0
prompt = op_config.get("prompt", "")
model = op_config.get("model", self.default_model)
# Extract field names from prompt template ({{ input.field_name }})
field_pattern = r"\{\{\s*input\.(\w+)\s*\}\}"
fields = re.findall(field_pattern, prompt)
if not fields:
# Fallback: use all string values from the first document
if sample_data:
fields = [
k
for k, v in sample_data[0].items()
if isinstance(v, str) and len(v) > 100
]
total_chars = 0
total_tokens = 0
for doc in sample_data:
doc_content = ""
for field in fields:
if field in doc:
value = doc[field]
if isinstance(value, str):
doc_content += value
else:
doc_content += str(value)
total_chars += len(doc_content)
if doc_content:
total_tokens += count_tokens(doc_content, model)
n = len(sample_data)
return total_chars / n if n > 0 else 0.0, total_tokens / n if n > 0 else 0.0
def get_applicable_directives(
self,
sample_data: list[dict[str, Any]],
op_config: dict[str, Any],
) -> list:
"""
Get the list of directives applicable based on data characteristics.
- DocumentChunkingDirective: only if avg doc size > 10% of context window
- DeterministicDocCompressionDirective: only if avg doc size > 1000 chars
Args:
sample_data: List of sample documents
op_config: The operation configuration
Returns:
List of applicable directive instances
"""
directives = [] # Build list with priority ordering
model = op_config.get("model", self.default_model)
context_limit = self.get_model_context_limit(model)
avg_chars, avg_tokens = self.get_avg_doc_size(sample_data, op_config)
self._log(
f"Document analysis: avg_chars={avg_chars:.0f}, avg_tokens={avg_tokens:.0f}, "
f"context_limit={context_limit}"
)
# Add DeterministicDocCompression FIRST if doc size > 1000 chars (high priority)
if avg_chars > DOC_COMPRESSION_CHAR_THRESHOLD:
self._log(
f"[cyan]Enabling DeterministicDocCompression (priority)[/cyan] "
f"(avg {avg_chars:.0f} chars > {DOC_COMPRESSION_CHAR_THRESHOLD})"
)
directives.append(COMPRESSION_DIRECTIVE)
# Add base directives
directives.extend(BASE_MAP_DIRECTIVES)
# Add DocumentChunking if avg tokens > 10% of context window
token_threshold = context_limit * DOC_CHUNKING_CONTEXT_THRESHOLD
if avg_tokens > token_threshold:
self._log(
f"[cyan]Enabling DocumentChunking[/cyan] "
f"(avg {avg_tokens:.0f} tokens > {token_threshold:.0f} = 10% of {context_limit})"
)
directives.append(CHUNKING_DIRECTIVE)
else:
self._log(
f"[dim]Skipping DocumentChunking[/dim] "
f"(avg {avg_tokens:.0f} tokens <= {token_threshold:.0f} = 10% of {context_limit})"
)
return directives
def load_sample_data(self, step_name: str, op_name: str) -> list[dict[str, Any]]:
"""
Load sample input data for an operation.
For the first operation, loads from the dataset.
For subsequent operations, loads from the previous operation's intermediate output.
Args:
step_name: Name of the pipeline step
op_name: Name of the operation
Returns:
List of sample documents
"""
# Find the operation's position
op_names = [op.get("name") for op in self.operators]
try:
op_idx = op_names.index(op_name)
except ValueError:
raise ValueError(f"Operation '{op_name}' not found in config")
if op_idx == 0:
# First operation - load from dataset
datasets = self.config.get("datasets", {})
# Get the first dataset (or the one used by this step)
for dataset_name, dataset_config in datasets.items():
dataset_path = dataset_config.get("path")
if dataset_path and os.path.exists(dataset_path):
with open(dataset_path, "r") as f:
data = json.load(f)
return data[: self.sample_size]
raise FileNotFoundError("No dataset found in config")
else:
# Load from previous operation's intermediate output
prev_op_name = op_names[op_idx - 1]
output_path = os.path.join(
self.intermediate_dir, step_name, f"{prev_op_name}.json"
)
if not os.path.exists(output_path):
raise FileNotFoundError(
f"No intermediate output found at {output_path}. "
"Run the previous operation first."
)
with open(output_path, "r") as f:
data = json.load(f)
return data[: self.sample_size]
def get_input_file_path(self) -> str:
"""Get the input file path for the pipeline."""
datasets = self.config.get("datasets", {})
for dataset_name, dataset_config in datasets.items():
path = dataset_config.get("path")
if path:
return path
return ""
def generate_candidates(
self,
op_name: str,
sample_data: list[dict[str, Any]],
target_op: dict[str, Any],
) -> list[dict[str, Any]]:
"""
Generate candidate decompositions using directives.
Directives are selected based on data characteristics:
- DocumentChunkingDirective: only if avg doc size > 10% of context window
- DeterministicDocCompressionDirective: only if avg doc size > 1000 chars
Args:
op_name: Name of the operation to decompose
sample_data: Sample data for analyzing document characteristics
target_op: The target operation configuration
Returns:
List of candidate dictionaries with 'name', 'ops', 'cost' keys
"""
candidates = []
# Add original as baseline
self._log("Adding original operation as baseline candidate")
candidates.append(
{
"name": "original",
"ops": deepcopy(self.operators),
"cost": 0.0,
"error": None,
}
)
input_file_path = self.get_input_file_path()
# Get applicable directives based on data characteristics
applicable_directives = self.get_applicable_directives(sample_data, target_op)
self._log(
f"Generating candidates using {len(applicable_directives)} directives..."
)
for i, directive in enumerate(applicable_directives, 1):
with self.console.status(
f"[bold cyan]({i}/{len(applicable_directives)}) Trying directive: {directive.name}...[/bold cyan]",
spinner="dots",
):
try:
new_ops_list, _, cost = directive.instantiate(
operators=deepcopy(self.operators),
target_ops=[op_name],
agent_llm=self.optimizer_model,
message_history=[],
global_default_model=self.default_model,
input_file_path=input_file_path,
)
self.total_cost += cost
self._log(
f" [green]✓[/green] {directive.name} generated {len(new_ops_list)} operations (cost: ${cost:.4f})"
)
candidates.append(
{
"name": directive.name,
"ops": new_ops_list,
"cost": cost,
"error": None,
}
)
except Exception as e:
# Directive not applicable or failed - skip it
self._log(f" [red]✗[/red] {directive.name} failed: {str(e)}")
candidates.append(
{
"name": directive.name,
"ops": None,
"cost": 0.0,
"error": str(e),
}
)
return candidates
def extract_ops_to_run(
self, ops_list: list[dict], original_op_name: str
) -> list[dict]:
"""
Extract the operations that replaced the original operation.
Args:
ops_list: The full transformed operations list
original_op_name: Name of the original operation that was decomposed
Returns:
List of operations that should be run on samples
"""
# Find ops that are new (not in original) or modified
original_names = {op["name"] for op in self.operators}
# Find the position where the original op was
original_idx = None
for i, op in enumerate(self.operators):
if op["name"] == original_op_name:
original_idx = i
break
if original_idx is None:
return ops_list
# Find new ops (those not in original list)
new_ops = []
for op in ops_list:
if op["name"] not in original_names or op["name"] == original_op_name:
new_ops.append(op)
return new_ops if new_ops else [ops_list[original_idx]]
def run_candidate_on_samples(
self,
candidate: dict[str, Any],
sample_data: list[dict[str, Any]],
original_op_name: str,
) -> list[dict[str, Any]]:
"""
Run a candidate's operations on sample data.
Args:
candidate: Candidate dictionary with 'ops' key
sample_data: List of sample documents
original_op_name: Name of the original operation
Returns:
List of output documents
"""
if candidate["ops"] is None:
return []
# Extract ops to run
ops_to_run = self.extract_ops_to_run(candidate["ops"], original_op_name)
if not ops_to_run:
return []
# Create a minimal config for running these ops
# Write sample data to a temp file
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
json.dump(sample_data, f)
temp_input_path = f.name
# Create a temp output file (required by DSLRunner validation)
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
temp_output_path = f.name
try:
# Create a minimal pipeline config
temp_config = {
"default_model": self.default_model,
"operations": ops_to_run,
"datasets": {
"sample_data": {
"type": "file",
"path": temp_input_path,
}
},
"pipeline": {
"steps": [
{
"name": "decompose_test",
"input": "sample_data",
"operations": [op["name"] for op in ops_to_run],
}
],
"output": {
"type": "file",
"path": temp_output_path,
},
},
}
# Create runner and execute
runner = DSLRunner(temp_config, max_threads=4)
# Run operations sequentially on the data
current_data = sample_data
for op_config in ops_to_run:
current_data = runner._run_operation(op_config, current_data)
self.total_cost += runner.total_cost
return current_data
finally:
# Clean up temp files
for temp_path in [temp_input_path, temp_output_path]:
try:
os.unlink(temp_path)
except Exception:
pass
def pairwise_compare(
self,
candidate_a: dict[str, Any],
candidate_b: dict[str, Any],
original_prompt: str,
output_schema: dict[str, Any],
) -> dict[str, Any]:
"""
Compare two candidates using LLM judge.
Args:
candidate_a: First candidate with 'name', 'outputs' keys
candidate_b: Second candidate with 'name', 'outputs' keys
original_prompt: The original operation's prompt
output_schema: The expected output schema
Returns:
The winning candidate
"""
# If one has no outputs, the other wins
if not candidate_a.get("outputs"):
return candidate_b
if not candidate_b.get("outputs"):
return candidate_a
system_prompt = """You are an expert judge comparing outputs from two data processing pipeline variants.
Your task is to determine which variant produces BETTER outputs based on:
1. **Completeness**: Does the output contain all required information?
2. **Accuracy**: Is the extracted/generated information correct?
3. **Consistency**: Are the outputs consistent across different samples?
4. **Quality**: Is the output well-structured and useful?
Be objective and focus on the actual output quality, not the approach used."""
user_prompt = f"""Compare outputs from two pipeline variants for this task:
## Original Task
**Prompt:**
```
{original_prompt[:2000]}{"..." if len(original_prompt) > 2000 else ""}
```
**Expected Output Schema:**
```json
{json.dumps(output_schema, indent=2)}
```
## Variant A: {candidate_a["name"]}
**Sample Outputs:**
```json
{json.dumps(candidate_a["outputs"][:3], indent=2, default=str)}
```
## Variant B: {candidate_b["name"]}
**Sample Outputs:**
```json
{json.dumps(candidate_b["outputs"][:3], indent=2, default=str)}
```
## Your Task
Which variant produces better outputs? Consider completeness, accuracy, consistency, and quality.
Respond with your analysis and final choice."""
response = completion(
model=self.optimizer_model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
response_format={
"type": "json_schema",
"json_schema": {
"name": "comparison_result",
"strict": True,
"schema": {
"type": "object",
"properties": {
"winner": {
"type": "string",
"enum": ["A", "B"],
"description": "Which variant is better: A or B",
},
"rationale": {
"type": "string",
"description": "Explanation of why this variant is better",
},
"a_strengths": {
"type": "array",
"items": {"type": "string"},
"description": "Strengths of variant A",
},
"b_strengths": {
"type": "array",
"items": {"type": "string"},
"description": "Strengths of variant B",
},
},
"required": [
"winner",
"rationale",
"a_strengths",
"b_strengths",
],
"additionalProperties": False,
},
},
},
**self.litellm_kwargs,
)
cost = completion_cost(response)
self.total_cost += cost
result = json.loads(response.choices[0].message.content)
winner = candidate_a if result["winner"] == "A" else candidate_b
winner["comparison_rationale"] = result["rationale"]
return winner
def decompose(
self,
step_name: str,
op_name: str,
) -> tuple[list[dict[str, Any]], str, int, float]:
"""
Decompose an operation using fast directive-based approach.
This is the main entry point. It:
1. Generates candidate decompositions using directives
2. Runs each on sample data
3. Uses pairwise LLM comparison to select the best
4. Returns the winning decomposition
Args:
step_name: Name of the pipeline step
op_name: Name of the operation to decompose
Returns:
Dict with:
- decomposed_ops: List of operations that replace the original
- winning_directive: Name of the winning directive
- candidates_evaluated: Number of candidates that were compared
- original_outputs: Sample outputs from the original operation
- decomposed_outputs: Sample outputs from the winning decomposition
- comparison_rationale: LLM's explanation of why the winner was chosen
- cost: Total LLM API cost
"""
self.console.rule("[bold blue]Fast Decomposition[/bold blue]")
self._log(f"[bold]Target operation:[/bold] {op_name}")
# Find the target operation config
target_op = None
for op in self.operators:
if op["name"] == op_name:
target_op = op
break
if target_op is None:
raise ValueError(f"Operation '{op_name}' not found in config")
# Verify it's a map operation
if target_op.get("type") != "map":
raise ValueError(
f"Operation '{op_name}' is type '{target_op.get('type')}', "
"but fast decomposition only supports 'map' operations"
)
# Load sample data
self._log(f"Loading sample data from step '{step_name}'...")
sample_data = self.load_sample_data(step_name, op_name)
self._log(f"[green]✓[/green] Loaded {len(sample_data)} sample documents")
# Generate candidates
self.console.rule("[bold]Generating Candidates[/bold]")
candidates = self.generate_candidates(op_name, sample_data, target_op)
# Filter out failed candidates
valid_candidates = [c for c in candidates if c["ops"] is not None]
self._log(f"[bold]{len(valid_candidates)}[/bold] valid candidates generated")
if len(valid_candidates) < 2:
# Only original (or nothing) - return original
self._log(
"[yellow]No alternative decompositions generated. Keeping original.[/yellow]"
)
return {
"decomposed_ops": self.operators,
"winning_directive": "original",
"candidates_evaluated": len(valid_candidates),
"original_outputs": [],
"decomposed_outputs": [],
"comparison_rationale": "No alternative decompositions were generated.",
"cost": self.total_cost,
}
# Run each candidate on samples IN PARALLEL
self.console.rule("[bold]Running Candidates on Samples[/bold]")
self._log(f"Running {len(valid_candidates)} candidates in parallel...")
def run_single_candidate(candidate):
"""Run a single candidate and return results."""
name = candidate["name"]
try:
outputs = self.run_candidate_on_samples(candidate, sample_data, op_name)
return {"name": name, "outputs": outputs, "error": None}
except Exception as e:
error_tb = traceback.format_exc()
return {
"name": name,
"outputs": [],
"error": str(e),
"traceback": error_tb,
}
# Run all candidates in parallel
with self.console.status(
"[bold cyan]Running candidates on samples...[/bold cyan]", spinner="dots"
):
with ThreadPoolExecutor(max_workers=len(valid_candidates)) as executor:
future_to_candidate = {
executor.submit(run_single_candidate, c): c
for c in valid_candidates
}
for future in as_completed(future_to_candidate):
candidate = future_to_candidate[future]
result = future.result()
if result["error"]:
candidate["outputs"] = []
candidate["run_error"] = result["error"]
self._log(
f" [red]✗[/red] {result['name']} failed: {result['error']}"
)
else:
candidate["outputs"] = result["outputs"]
self._log(
f" [green]✓[/green] {result['name']}: {len(result['outputs'])} outputs"
)
# Filter to candidates with outputs
candidates_with_outputs = [c for c in valid_candidates if c.get("outputs")]
self._log(
f"[bold]{len(candidates_with_outputs)}[/bold] candidates produced outputs"
)
if not candidates_with_outputs:
# All failed - return original
self._log("[red]All decomposition candidates failed to execute.[/red]")
# Log the errors for debugging
for c in valid_candidates:
if c.get("run_error"):
self._log(f" [red]{c['name']}:[/red] {c['run_error']}")
return {
"decomposed_ops": self.operators,
"winning_directive": "original",
"candidates_evaluated": 0,
"original_outputs": [],
"decomposed_outputs": [],
"comparison_rationale": "All decomposition candidates failed to execute.",
"cost": self.total_cost,
}
# Find the original candidate's outputs
original_candidate = next(
(c for c in candidates_with_outputs if c["name"] == "original"), None
)
original_outputs = original_candidate["outputs"] if original_candidate else []
# Parallel pairwise comparison: compare all candidates against original
self.console.rule("[bold]Pairwise Comparison[/bold]")
original_prompt = target_op.get("prompt", "")
output_schema = target_op.get("output", {}).get("schema", {})
# If only original exists, it wins by default
if len(candidates_with_outputs) == 1:
winner = candidates_with_outputs[0]
elif original_candidate is None:
# No original - just pick the first candidate
winner = candidates_with_outputs[0]
else:
# Compare all non-original candidates against original IN PARALLEL
challengers = [
c for c in candidates_with_outputs if c["name"] != "original"
]
if not challengers:
winner = original_candidate
else:
self._log(
f"Comparing {len(challengers)} candidates against original in parallel..."
)
def compare_against_original(challenger):
"""Compare a single challenger against original."""
try:
result = self.pairwise_compare(
original_candidate,
challenger,
original_prompt,
output_schema,
)
won = result["name"] == challenger["name"]
return {
"challenger": challenger,
"won": won,
"rationale": result.get("comparison_rationale", ""),
}
except Exception as e:
return {
"challenger": challenger,
"won": False,
"error": str(e),
}
# Run all comparisons in parallel
comparison_results = []
with self.console.status(
f"[bold cyan]Running {len(challengers)} comparisons in parallel...[/bold cyan]",
spinner="dots",
):
with ThreadPoolExecutor(max_workers=len(challengers)) as executor:
future_to_challenger = {
executor.submit(compare_against_original, c): c
for c in challengers
}
for future in as_completed(future_to_challenger):
result = future.result()
comparison_results.append(result)
challenger_name = result["challenger"]["name"]
if result.get("error"):
self._log(
f" [red]✗[/red] {challenger_name} vs original: error - {result['error']}"
)
elif result["won"]:
self._log(
f" [green]✓[/green] {challenger_name} beat original"
)
else:
self._log(
f" [dim]○[/dim] {challenger_name} lost to original"
)
# Find winners (candidates that beat original)
winners = [r for r in comparison_results if r.get("won")]
if not winners:
# Original beats all challengers
winner = original_candidate
self._log("[bold]Original wins against all challengers[/bold]")
elif len(winners) == 1:
# Single winner
winner = winners[0]["challenger"]
winner["comparison_rationale"] = winners[0].get("rationale", "")
else:
# Multiple winners beat original - run tiebreaker comparisons in parallel
self._log(
f"[bold]{len(winners)} candidates beat original - running tiebreaker...[/bold]"
)
# Compare all winners against each other in parallel (round-robin)
winner_candidates = [w["challenger"] for w in winners]
win_counts = {c["name"]: 0 for c in winner_candidates}
# Generate all pairwise matchups
matchups = []
for i, a in enumerate(winner_candidates):
for b in winner_candidates[i + 1 :]:
matchups.append((a, b))
if matchups:
def run_matchup(matchup):
a, b = matchup
try:
result = self.pairwise_compare(
a, b, original_prompt, output_schema
)
return {
"winner": result["name"],
"a": a["name"],
"b": b["name"],
}
except Exception:
return {
"winner": a["name"],
"a": a["name"],
"b": b["name"],
} # Default to first
with self.console.status(
f"[bold cyan]Running {len(matchups)} tiebreaker comparisons...[/bold cyan]",
spinner="dots",
):
with ThreadPoolExecutor(
max_workers=len(matchups)
) as executor:
for result in executor.map(run_matchup, matchups):
win_counts[result["winner"]] += 1
self._log(
f" [dim]{result['a']} vs {result['b']}{result['winner']}[/dim]"
)
# Pick candidate with most wins
best_name = max(win_counts, key=win_counts.get)
winner = next(
c for c in winner_candidates if c["name"] == best_name
)
self._log(
f"[bold]Tiebreaker winner: {best_name} ({win_counts[best_name]} wins)[/bold]"
)
# Extract the decomposed operations
decomposed_ops = self.extract_ops_to_run(winner["ops"], op_name)
# Final summary
self.console.rule("[bold green]Decomposition Complete[/bold green]")
self._log(f"[bold]Winner:[/bold] [green]{winner['name']}[/green]")
self._log(f"[bold]Candidates evaluated:[/bold] {len(candidates_with_outputs)}")
self._log(f"[bold]New operations:[/bold] {len(decomposed_ops)}")
self._log(f"[bold]Total cost:[/bold] ${self.total_cost:.4f}")
return {
"decomposed_ops": decomposed_ops,
"winning_directive": winner["name"],
"candidates_evaluated": len(candidates_with_outputs),
"original_outputs": original_outputs,
"decomposed_outputs": winner.get("outputs", []),
"comparison_rationale": winner.get("comparison_rationale", ""),
"cost": self.total_cost,
}

View File

@ -1,337 +0,0 @@
"""
Fast should_optimize analyzer using a single LLM call.
This module provides a lightweight alternative to the full MapOptimizer/ReduceOptimizer/JoinOptimizer
flow for quickly determining if an operation should be decomposed.
"""
import json
import os
from typing import Any
import litellm
from litellm import completion, model_cost
from docetl.utils import completion_cost, count_tokens
# Drop unsupported params for models like gpt-5 that don't support temperature=0
litellm.drop_params = True
class FastShouldOptimizeAnalyzer:
"""
Analyzes whether an operation should be optimized using a single LLM call.
Instead of running the operation on sample data and using complex evaluation logic,
this reads cached outputs from intermediate files and makes one judgment call.
"""
def __init__(
self,
intermediate_dir: str,
optimizer_model: str = "gpt-5.1",
litellm_kwargs: dict[str, Any] | None = None,
):
"""
Initialize the analyzer.
Args:
intermediate_dir: Path to the directory containing intermediate outputs
optimizer_model: The LLM model to use for analysis (default: gpt-5.1)
litellm_kwargs: Additional kwargs to pass to litellm.completion
"""
self.intermediate_dir = intermediate_dir
self.optimizer_model = optimizer_model
self.litellm_kwargs = litellm_kwargs or {}
if "temperature" not in self.litellm_kwargs:
self.litellm_kwargs["temperature"] = 0.0
def load_operation_data(self, step_name: str, op_name: str) -> list[dict[str, Any]]:
"""
Load data from the intermediate file for an operation.
Args:
step_name: Name of the pipeline step
op_name: Name of the operation
Returns:
List of dictionaries
Raises:
FileNotFoundError: If the intermediate file doesn't exist
"""
output_path = os.path.join(self.intermediate_dir, step_name, f"{op_name}.json")
if not os.path.exists(output_path):
raise FileNotFoundError(
f"No output file found at {output_path}. "
"Run the operation first to generate outputs."
)
with open(output_path, "r") as f:
return json.load(f)
def find_previous_operation(
self, operations: list[dict[str, Any]], op_name: str
) -> str | None:
"""
Find the operation that comes before op_name in the pipeline.
Args:
operations: List of operation configs from the pipeline
op_name: Name of the current operation
Returns:
Name of the previous operation, or None if this is the first operation
"""
op_names = [op.get("name") for op in operations]
try:
idx = op_names.index(op_name)
if idx > 0:
return op_names[idx - 1]
except ValueError:
pass
return None
def get_max_context_tokens(self) -> int:
"""
Get the maximum input tokens for the optimizer model.
Returns:
Maximum number of input tokens the model can handle
"""
model_info = model_cost.get(self.optimizer_model, {})
# Try without provider prefix if not found
if not model_info:
model_name = self.optimizer_model.split("/")[-1]
model_info = model_cost.get(model_name, {})
return model_info.get("max_input_tokens", 128000) # Default to 128k
def calculate_samples_that_fit(
self,
op_config: dict[str, Any],
outputs: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""
Calculate how many output samples fit in the context window.
Reserves space for system prompt, operation config, and response buffer,
then fills remaining space with as many samples as possible.
Args:
op_config: The operation configuration dictionary
outputs: List of all output documents
Returns:
List of samples that fit in the context window
"""
max_tokens = self.get_max_context_tokens()
# Reserve tokens for fixed parts
system_prompt_tokens = 500
op_config_tokens = count_tokens(
json.dumps(op_config, default=str), self.optimizer_model
)
response_buffer = 2000
available_for_samples = (
max_tokens - system_prompt_tokens - op_config_tokens - response_buffer
)
# Collect samples that fit
samples_to_include = []
tokens_used = 0
for output in outputs:
sample_json = json.dumps(output, default=str)
sample_tokens = count_tokens(sample_json, self.optimizer_model)
if tokens_used + sample_tokens <= available_for_samples:
samples_to_include.append(output)
tokens_used += sample_tokens
else:
break
return samples_to_include
def build_analysis_prompt(
self,
op_config: dict[str, Any],
samples: list[dict[str, Any]],
) -> tuple[str, str]:
"""
Build the system prompt and user prompt for the analysis LLM call.
Args:
op_config: The operation configuration dictionary
samples: List of output samples to analyze
Returns:
Tuple of (system_prompt, user_prompt)
"""
system_prompt = """You are an expert at analyzing LLM-powered data processing operations.
Your task is to determine if an operation would benefit from being decomposed into multiple
smaller, focused operations (also known as "optimization" or "decomposition").
An operation SHOULD be decomposed when:
1. The prompt asks the LLM to do multiple distinct tasks that could be done separately
2. The task is complex enough that breaking it into sequential steps would improve accuracy
3. The outputs show inconsistency, incompleteness, or quality issues that iterative refinement could fix
4. Long documents need to be processed in chunks rather than all at once
5. The prompt asks for both extraction AND analysis/synthesis in one step
An operation should NOT be decomposed when:
1. It performs a single, focused task well
2. The outputs are consistently high quality and complete
3. The task is simple and atomic (e.g., simple classification, single field extraction)
4. The operation is already well-scoped and produces reliable results
Be conservative - only recommend decomposition if there's clear evidence it would help."""
output_schema = op_config.get("output", {}).get("schema", {})
prompt_template = op_config.get("prompt", "No prompt specified")
# Truncate very long prompts for display
if len(prompt_template) > 3000:
prompt_template = prompt_template[:3000] + "\n... [truncated]"
user_prompt = f"""Analyze this data processing operation and its outputs:
## Operation Configuration
**Name:** {op_config.get('name', 'unknown')}
**Type:** {op_config.get('type', 'unknown')}
**Prompt Template:**
```
{prompt_template}
```
**Output Schema:**
```json
{json.dumps(output_schema, indent=2)}
```
## Sample Outputs ({len(samples)} samples from the operation)
```json
{json.dumps(samples, indent=2, default=str)}
```
## Your Task
Based on the operation configuration and sample outputs, determine:
1. Should this operation be decomposed/optimized?
2. If yes, what specific improvements would help?
Consider the quality, completeness, and consistency of the outputs when making your assessment."""
return system_prompt, user_prompt
def analyze(
self,
op_config: dict[str, Any],
step_name: str,
op_name: str,
) -> tuple[str, list[dict[str, Any]], int, float]:
"""
Analyze whether an operation should be optimized.
This is the main entry point. It loads outputs, builds the prompt,
makes a single LLM call, and returns the assessment.
Args:
op_config: The operation configuration dictionary
step_name: Name of the pipeline step
op_name: Name of the operation
Returns:
Tuple of (rationale, output_samples, num_docs_analyzed, cost):
- rationale: Empty string if no optimization needed, explanation if it should be optimized
- output_samples: The samples that were analyzed
- num_docs_analyzed: Number of documents that fit in the LLM prompt
- cost: LLM API cost in USD
Raises:
ValueError: If the operation is not an LLM-powered map operation
"""
# Validate operation type - only LLM-powered map operations
op_type = op_config.get("type", "")
if op_type != "map":
raise ValueError(
f"should_optimize only supports 'map' operations, got '{op_type}'. "
"Only LLM-powered map operations can be analyzed for decomposition."
)
# Load outputs from intermediate file
outputs = self.load_operation_data(step_name, op_name)
if not outputs:
return "No output samples available for analysis.", [], 0, 0.0
# Calculate samples that fit in context
samples = self.calculate_samples_that_fit(op_config, outputs)
if not samples:
return "Could not fit any samples in context window.", outputs[:5], 0, 0.0
# Build prompt
system_prompt, user_prompt = self.build_analysis_prompt(op_config, samples)
# Make LLM call with structured output
response = completion(
model=self.optimizer_model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
response_format={
"type": "json_schema",
"json_schema": {
"name": "optimization_analysis",
"strict": True,
"schema": {
"type": "object",
"properties": {
"should_optimize": {
"type": "boolean",
"description": "True if operation should be decomposed/optimized",
},
"rationale": {
"type": "string",
"description": "Explanation of why the operation should or should not be optimized",
},
"suggested_improvements": {
"type": "array",
"items": {"type": "string"},
"description": "Specific improvements if optimization is recommended (empty if not)",
},
},
"required": [
"should_optimize",
"rationale",
"suggested_improvements",
],
"additionalProperties": False,
},
},
},
**self.litellm_kwargs,
)
# Calculate cost
cost = completion_cost(response)
# Parse response
result = json.loads(response.choices[0].message.content)
num_docs_analyzed = len(samples)
if result["should_optimize"]:
# Build rationale string with improvements
rationale_parts = [result["rationale"]]
if result["suggested_improvements"]:
rationale_parts.append("\n\nSuggested improvements:")
for imp in result["suggested_improvements"]:
rationale_parts.append(f"- {imp}")
return "\n".join(rationale_parts), samples, num_docs_analyzed, cost
else:
# Return empty string to indicate no optimization needed
return "", samples, num_docs_analyzed, cost

View File

@ -30,7 +30,7 @@ class LLMClient:
Initialize the LLMClient.
Args:
model (str, optional): The name of the LLM model to use. Defaults to "gpt-5.1".
model (str, optional): The name of the LLM model to use. Defaults to "gpt-4o".
**litellm_kwargs: Additional keyword arguments for the LLM model.
"""
self.rewrite_agent_model = rewrite_agent_model

View File

@ -204,6 +204,10 @@ def get_openai_response(
response = litellm.completion(
model=model,
messages=messages,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=ResponseFormat,
)
assistant_response = response.choices[0].message.content

View File

@ -30,8 +30,6 @@ from .map_reduce_fusion import MapReduceFusionDirective
from .hierarchical_reduce import HierarchicalReduceDirective
from .cascade_filtering import CascadeFilteringDirective
from .arbitrary_rewrite import ArbitraryRewriteDirective
from .map_to_map_resolve_reduce import MapToMapResolveReduceDirective
from .map_resolve_to_map_with_categories import MapResolveToMapWithCategoriesDirective
# Registry of all available directives
ALL_DIRECTIVES = [
@ -55,8 +53,6 @@ ALL_DIRECTIVES = [
HierarchicalReduceDirective(),
CascadeFilteringDirective(),
ArbitraryRewriteDirective(),
MapToMapResolveReduceDirective(),
MapResolveToMapWithCategoriesDirective(),
]
ALL_COST_DIRECTIVES = [
@ -183,10 +179,8 @@ __all__ = [
"HierarchicalReduceDirective",
"CascadeFilteringDirective",
"ArbitraryRewriteDirective",
"MapToMapResolveReduceDirective",
"MapResolveToMapWithCategoriesDirective",
"ALL_DIRECTIVES",
"DIRECTIVE_REGISTRY",
"DIRECTIVE_REGISTRY",
"get_all_directive_strings",
"instantiate_directive"
]

View File

@ -324,6 +324,10 @@ Focus on quality over quantity - a few diverse, informative examples are better
model=self.agent_llm,
messages=self.message_history,
response_format=AgentDecision,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
)
call_cost += response._hidden_params["response_cost"]
@ -411,6 +415,10 @@ Provide your response as a JSON object matching this schema: {response_schema.mo
model=self.agent_llm,
messages=self.message_history,
response_format=response_schema,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
)
call_cost += schema_response._hidden_params["response_cost"]

View File

@ -123,8 +123,7 @@ class Directive(BaseModel, ABC):
try:
# 1. Execute the directive
# instantiate returns (new_ops_plan, message_history, call_cost)
instantiate_result = self.instantiate(
actual_output, _ = self.instantiate(
operators=(
[test_case.input_config]
if isinstance(test_case.input_config, dict)
@ -135,11 +134,6 @@ class Directive(BaseModel, ABC):
input_file_path=temp_file_path,
pipeline_code=fake_pipeline,
)
# Handle both 2-tuple and 3-tuple returns
if isinstance(instantiate_result, tuple):
actual_output = instantiate_result[0]
else:
actual_output = instantiate_result
# 2. Use LLM judge to evaluate
judge_result = self._llm_judge_test(
@ -222,6 +216,7 @@ class Directive(BaseModel, ABC):
model=agent_llm,
messages=messages,
response_format=JudgeResponse,
azure=True,
)
# Parse the JSON response

View File

@ -1,4 +1,5 @@
import json
import os
import re
from copy import deepcopy
from typing import Dict, List, Type
@ -168,15 +169,19 @@ class ChainingDirective(Directive):
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
# api_key=os.environ["GEMINI_API_KEY"],
azure=True,
response_format=ChainingInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
call_cost = resp._hidden_params["response_cost"]
try:
parsed_res = json.loads(resp.choices[0].message.content)
@ -197,12 +202,11 @@ class ChainingDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
)
def apply(

View File

@ -1,4 +1,5 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -210,14 +211,18 @@ class ChangeModelDirective(Directive):
},
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
# api_key=os.environ["GEMINI_API_KEY"],
azure=True,
response_format=ChangeModelInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
call_cost = resp._hidden_params["response_cost"]
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = ChangeModelInstantiateSchema(**parsed_res)
@ -235,12 +240,11 @@ class ChangeModelDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
)
def apply(

View File

@ -1,4 +1,5 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -204,14 +205,17 @@ class ChangeModelAccDirective(Directive):
},
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=ChangeModelInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
call_cost = resp._hidden_params["response_cost"]
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = ChangeModelInstantiateSchema(**parsed_res)
@ -229,12 +233,11 @@ class ChangeModelAccDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
)
def apply(

View File

@ -1,4 +1,5 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -281,14 +282,17 @@ class ChangeModelCostDirective(Directive):
},
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=ChangeModelInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
call_cost = resp._hidden_params["response_cost"]
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = ChangeModelInstantiateSchema(**parsed_res)
@ -306,12 +310,11 @@ class ChangeModelCostDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
)
def apply(

View File

@ -1,4 +1,5 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -184,14 +185,17 @@ class ChunkHeaderSummaryDirective(Directive):
},
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=ChunkHeaderSummaryInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
call_cost = resp._hidden_params["response_cost"]
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = ChunkHeaderSummaryInstantiateSchema(**parsed_res)
@ -200,12 +204,11 @@ class ChunkHeaderSummaryDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
)
def apply(
@ -248,7 +251,9 @@ class ChunkHeaderSummaryDirective(Directive):
)
if gather_idx - split_idx > 1:
raise ValueError("There should not be operators between split and gather")
raise ValueError(
"There should not be operators between split and gather"
)
# Get the split_key from the split operation
split_key = split_op.get("split_key")

View File

@ -1,4 +1,5 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -260,14 +261,17 @@ def transform(input_doc):
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=DeterministicDocCompressionInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
call_cost = resp._hidden_params["response_cost"]
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = DeterministicDocCompressionInstantiateSchema(**parsed_res)
@ -280,12 +284,11 @@ def transform(input_doc):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
)
def apply(

View File

@ -1,4 +1,5 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -267,15 +268,18 @@ class DocumentChunkingDirective(Directive):
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
call_cost = 0.0
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=DocumentChunkingInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
call_cost = resp._hidden_params["response_cost"]
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = DocumentChunkingInstantiateSchema(**parsed_res)
@ -286,12 +290,11 @@ class DocumentChunkingDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
)
def apply(
@ -401,19 +404,16 @@ class DocumentChunkingDirective(Directive):
sample_op["stratify_key"] = stratify_keys
if rewrite.sampling_config.samples_per_group:
sample_op["samples_per_group"] = (
rewrite.sampling_config.samples_per_group
)
sample_op["samples_per_group"] = rewrite.sampling_config.samples_per_group
# Add optional fields if provided
if rewrite.sampling_config.random_state is not None:
sample_op["random_state"] = rewrite.sampling_config.random_state
if rewrite.sampling_config.method_kwargs is not None:
try:
sample_op["method_kwargs"] = json.loads(
rewrite.sampling_config.method_kwargs
)
sample_op["method_kwargs"] = json.loads(rewrite.sampling_config.method_kwargs)
except Exception as e:
raise ValueError(f"Invalid method_kwargs: {e}")

View File

@ -1,4 +1,5 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -430,14 +431,17 @@ class DocumentChunkingTopKDirective(Directive):
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=DocumentChunkingTopKInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
call_cost = resp._hidden_params["response_cost"]
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = DocumentChunkingTopKInstantiateSchema(**parsed_res)
@ -447,12 +451,11 @@ class DocumentChunkingTopKDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
)
def apply(

View File

@ -1,4 +1,5 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -163,14 +164,19 @@ class DocCompressionDirective(Directive):
},
]
)
last_error = None
error_message = ""
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=DocCompressionInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
call_cost = resp._hidden_params["response_cost"]
try:
parsed_res = json.loads(resp.choices[0].message.content)
@ -181,12 +187,11 @@ class DocCompressionDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Error: {error_message}"
)
def apply(

View File

@ -1,4 +1,5 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -260,14 +261,17 @@ class DocSummarizationDirective(Directive):
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=DocSummarizationInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
call_cost = resp._hidden_params["response_cost"]
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = DocSummarizationInstantiateSchema(**parsed_res)
@ -276,12 +280,11 @@ class DocSummarizationDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
)
def apply(

View File

@ -1,4 +1,5 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -145,31 +146,32 @@ class GleaningDirective(Directive):
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
# api_key=os.environ["GEMINI_API_KEY"],
azure=True,
response_format=GleaningInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
call_cost = resp._hidden_params["response_cost"]
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = GleaningInstantiateSchema(**parsed_res)
GleaningInstantiateSchema.check_no_jinja_variables(
schema.validation_prompt
)
GleaningInstantiateSchema.check_no_jinja_variables(schema.validation_prompt)
message_history.append(
{"role": "assistant", "content": resp.choices[0].message.content}
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
)
def apply(

View File

@ -1,4 +1,5 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -179,14 +180,19 @@ class HierarchicalReduceDirective(Directive):
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=HierarchicalReduceInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
call_cost = resp._hidden_params["response_cost"]
try:
parsed_res = json.loads(resp.choices[0].message.content)
@ -208,12 +214,11 @@ class HierarchicalReduceDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
)
def apply(
@ -254,7 +259,7 @@ class HierarchicalReduceDirective(Directive):
"name": rewrite.map_config.name,
"type": "map",
"prompt": rewrite.map_config.prompt,
"model": default_model,
"model": default_model,
"litellm_completion_kwargs": {"temperature": 0},
"output": {"schema": {rewrite.map_config.output_keys[0]: "string"}},
}
@ -319,4 +324,4 @@ class HierarchicalReduceDirective(Directive):
new_ops_plan = self.apply(
global_default_model, operators, target_ops[0], rewrite
)
return new_ops_plan, message_history, call_cost
return new_ops_plan, message_history, call_cost

View File

@ -1,4 +1,5 @@
import json
import os
import re
from copy import deepcopy
from typing import Dict, List, Type
@ -259,14 +260,17 @@ class IsolatingSubtasksDirective(Directive):
original_op.get("output", {}).get("schema", {}).keys()
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=IsolatingSubtasksInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
call_cost = resp._hidden_params["response_cost"]
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = IsolatingSubtasksInstantiateSchema(**parsed_res)
@ -281,12 +285,11 @@ class IsolatingSubtasksDirective(Directive):
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease ensure all original output keys are covered by subtasks and try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
)
def apply(

View File

@ -1,4 +1,5 @@
import json
import os
import re
from copy import deepcopy
from typing import Dict, List, Type
@ -6,20 +7,14 @@ from typing import Dict, List, Type
from litellm import completion
from pydantic import BaseModel, Field
from docetl.reasoning_optimizer.instantiate_schemas import (
MapReduceFusionInstantiateSchema,
)
from docetl.reasoning_optimizer.instantiate_schemas import MapReduceFusionInstantiateSchema
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
class MapReduceFusionDirective(Directive):
name: str = Field(
default="map_reduce_fusion", description="The name of the directive"
)
formal_description: str = Field(
default="Map -> Reduce => Map (with new prompt) -> Reduce (with new propmt)"
)
name: str = Field(default="map_reduce_fusion", description="The name of the directive")
formal_description: str = Field(default="Map -> Reduce => Map (with new prompt) -> Reduce (with new propmt)")
nl_description: str = Field(
default="Transform a Map operation followed by a Reduce operation over long documents by updating the Map prompt to first extract or process only the relevant information from each document. Then, modify the Reduce prompt to operate on these processed outputs instead of the full document content."
)
@ -47,7 +42,7 @@ class MapReduceFusionDirective(Directive):
" type: reduce\\n"
" reduce_key: category\\n"
" prompt: |\\n"
' For each category \\"{{ reduce_key }}\\", extract all organization names from these documents:\\n'
" For each category \\\"{{ reduce_key }}\\\", extract all organization names from these documents:\\n"
" {% for input in inputs %}\\n"
" Document {{ loop.index }}: {{ input.content }}\\n"
" {% endfor %}\\n"
@ -84,7 +79,7 @@ class MapReduceFusionDirective(Directive):
"reduce_key": "category",
"prompt": "For each category '{{ reduce_key }}', extract organizations from: {% for input in inputs %}{{ input.content }}{% endfor %}",
"output": {"schema": {"organizations": "list[str]"}},
},
}
],
target_ops=["classify_document", "extract_organizations"],
expected_behavior="Should modify map to classify AND extract organizations per document, then reduce to aggregate pre-extracted organizations",
@ -106,7 +101,7 @@ class MapReduceFusionDirective(Directive):
"reduce_key": "doc_type",
"prompt": "For each document type '{{ reduce_key }}', find all people mentioned: {% for input in inputs %}Document: {{ input.text }}{% endfor %}",
"output": {"schema": {"people": "list[str]"}},
},
}
],
target_ops=["analyze_document", "find_people"],
expected_behavior="Should modify map to analyze type AND extract people per document, then reduce to combine pre-extracted people lists",
@ -164,7 +159,7 @@ class MapReduceFusionDirective(Directive):
reduce_op: Dict,
expected_document_key,
agent_llm: str,
message_history: list = [],
message_history: list = []
):
"""
Use LLM to instantiate this directive by transforming the map and reduce operations.
@ -178,85 +173,66 @@ class MapReduceFusionDirective(Directive):
Returns:
MapReduceFusionInstantiateSchema: The structured output from the LLM.
"""
message_history.extend([
{"role": "system", "content": "You are a helpful AI assistant for optimizing document processing pipelines."},
{"role": "user", "content": self.to_string_for_instantiate([map_op, reduce_op])},
])
message_history.extend(
[
{
"role": "system",
"content": "You are a helpful AI assistant for optimizing document processing pipelines.",
},
{
"role": "user",
"content": self.to_string_for_instantiate([map_op, reduce_op]),
},
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
response_format=MapReduceFusionInstantiateSchema,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=MapReduceFusionInstantiateSchema
)
call_cost = resp._hidden_params.get("response_cost", 0)
call_cost = resp._hidden_params["response_cost"]
try:
parsed_res = json.loads(resp.choices[0].message.content)
if (
"new_map_name" not in parsed_res
or "new_map_prompt" not in parsed_res
or "new_key" not in parsed_res
or "new_reduce_prompt" not in parsed_res
):
if "new_map_name" not in parsed_res or "new_map_prompt" not in parsed_res or "new_key" not in parsed_res or "new_reduce_prompt" not in parsed_res:
raise ValueError(
"Response from LLM is missing required keys: 'new_map_name', 'new_map_prompt', 'new_key', or 'new_reduce_prompt'"
)
schema = MapReduceFusionInstantiateSchema(
new_map_name=parsed_res["new_map_name"],
new_map_prompt=parsed_res["new_map_prompt"],
new_key=parsed_res["new_key"],
new_reduce_prompt=parsed_res["new_reduce_prompt"],
new_reduce_prompt=parsed_res["new_reduce_prompt"]
)
# Validate the schema
MapReduceFusionInstantiateSchema.validate_reduce_prompt_references_new_key(
schema.new_reduce_prompt, schema.new_key, expected_document_key
)
message_history.append(
{"role": "assistant", "content": resp.choices[0].message.content}
)
return schema, message_history, call_cost
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(
self,
global_default_model,
ops_list: List[Dict],
map_target: str,
reduce_target: str,
rewrite: MapReduceFusionInstantiateSchema,
) -> List[Dict]:
raise Exception(f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts.")
def apply(self, global_default_model, ops_list: List[Dict], map_target: str, reduce_target: str, rewrite: MapReduceFusionInstantiateSchema) -> List[Dict]:
"""
Apply the directive to the pipeline config.
"""
# Create a copy of the pipeline config
new_ops_list = deepcopy(ops_list)
# Find positions of the target ops
map_pos = None
reduce_pos = None
orig_map_op = None
orig_reduce_op = None
for i, op in enumerate(ops_list):
if op["name"] == map_target:
map_pos = i
@ -264,15 +240,13 @@ class MapReduceFusionDirective(Directive):
elif op["name"] == reduce_target:
reduce_pos = i
orig_reduce_op = op
if map_pos is None or reduce_pos is None:
raise ValueError(
f"Could not find target operations: {map_target}, {reduce_target}"
)
raise ValueError(f"Could not find target operations: {map_target}, {reduce_target}")
# Get default model
default_model = orig_map_op.get("model", global_default_model)
# Create the new map operation with fused functionality
new_map_op = {
"name": rewrite.new_map_name,
@ -283,11 +257,11 @@ class MapReduceFusionDirective(Directive):
"output": {
"schema": {
**orig_map_op.get("output", {}).get("schema", {}),
rewrite.new_key: "list[str]",
rewrite.new_key: "list[str]"
}
},
}
}
# Create the new reduce operation that works with pre-extracted data
new_reduce_op = {
"name": orig_reduce_op["name"],
@ -296,58 +270,54 @@ class MapReduceFusionDirective(Directive):
"prompt": rewrite.new_reduce_prompt,
"model": default_model,
"litellm_completion_kwargs": {"temperature": 0},
"output": orig_reduce_op.get("output", {}),
"output": orig_reduce_op.get("output", {})
}
# Replace the operations
new_ops_list[map_pos] = new_map_op
new_ops_list[reduce_pos] = new_reduce_op
return new_ops_list
def instantiate(
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
optimize_goal="acc",
global_default_model: str = None,
**kwargs,
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
optimize_goal="acc",
global_default_model: str = None,
**kwargs
):
"""
Instantiate the directive for a list of operators.
"""
# Assert that there are exactly two target ops (map and reduce)
if len(target_ops) != 2:
raise ValueError(
"map_reduce_fusion directive requires exactly two target operators: a map and a reduce operation"
)
raise ValueError("map_reduce_fusion directive requires exactly two target operators: a map and a reduce operation")
# Find the map and reduce operations
first_op = None
second_op = None
for op in operators:
if op["name"] == target_ops[0]:
first_op = op
elif op["name"] == target_ops[1]:
second_op = op
if first_op is None or second_op is None:
raise ValueError(f"Could not find target operations: {target_ops}")
if first_op.get("type") == "map" and second_op.get("type") == "reduce":
map_op = first_op
reduce_op = second_op
map_target = target_ops[0]
reduce_target = target_ops[1]
else:
raise ValueError(
"Target operators must be one map operation followed by one reduce operation!"
)
# Extract expected document key from the reduce prompt template
raise ValueError("Target operators must be one map operation followed by one reduce operation!")
# Extract expected document key from the reduce prompt template
prompt_template = reduce_op["prompt"]
# Find all occurrences of {{ input.key }} in the prompt
input_key_pattern = r"\{\{\s*([^\}\s]+)\s*\}\}"
@ -365,9 +335,7 @@ class MapReduceFusionDirective(Directive):
]
if document_key_candidates:
expected_document_key = document_key_candidates[
0
] # Pick the first candidate
expected_document_key = document_key_candidates[0] # Pick the first candidate
elif input_keys:
expected_document_key = input_keys[0] # Fall back to the first input key
else:
@ -376,12 +344,8 @@ class MapReduceFusionDirective(Directive):
print(f"Detected document key: {expected_document_key}")
# Instantiate the directive
rewrite, message_history, call_cost = self.llm_instantiate(
map_op, reduce_op, expected_document_key, agent_llm, message_history
)
rewrite, message_history, call_cost = self.llm_instantiate(map_op, reduce_op, expected_document_key, agent_llm, message_history)
# Apply the rewrite to the operators
new_ops_plan = self.apply(
global_default_model, operators, map_target, reduce_target, rewrite
)
return new_ops_plan, message_history, call_cost
new_ops_plan = self.apply(global_default_model, operators, map_target, reduce_target, rewrite)
return new_ops_plan, message_history, call_cost

View File

@ -1,361 +0,0 @@
import json
from copy import deepcopy
from typing import Dict, List, Type
from litellm import completion
from pydantic import BaseModel, Field
from docetl.reasoning_optimizer.instantiate_schemas import (
MapResolveToMapWithCategoriesInstantiateSchema,
)
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
class MapResolveToMapWithCategoriesDirective(Directive):
name: str = Field(
default="map_resolve_to_map_with_categories",
description="The name of the directive",
)
formal_description: str = Field(default="Map -> Resolve => Map (with categories)")
nl_description: str = Field(
default="Replace a Map -> Resolve pattern with a single Map operation that has predefined categories. The agent analyzes the data and task to propose a set of canonical categories, and the new Map forces outputs into one of these categories (or 'none of the above'). This effectively performs entity resolution deterministically by standardizing outputs upfront, avoiding the need for pairwise comparisons."
)
when_to_use: str = Field(
default="When a Map operation produces outputs that are then resolved/deduplicated, and the set of valid output categories can be enumerated upfront. This is more efficient than Resolve when the category space is small and well-defined (e.g., standardizing company types, product categories, sentiment labels). The target must be a Map operation followed by a Resolve operation."
)
instantiate_schema_type: Type[BaseModel] = (
MapResolveToMapWithCategoriesInstantiateSchema
)
example: str = Field(
default=(
"Original Pipeline:\n"
"- name: extract_sentiment\n"
" type: map\n"
" prompt: |\n"
" What is the sentiment of this review?\n"
" {{ input.review }}\n"
" output:\n"
" schema:\n"
" sentiment: string\n"
"\n"
"- name: normalize_sentiment\n"
" type: resolve\n"
" comparison_prompt: |\n"
" Are these sentiments equivalent?\n"
" Sentiment 1: {{ input1.sentiment }}\n"
" Sentiment 2: {{ input2.sentiment }}\n"
" resolution_prompt: |\n"
" Normalize these sentiments:\n"
" {% for input in inputs %}\n"
" - {{ input.sentiment }}\n"
" {% endfor %}\n"
" output:\n"
" schema:\n"
" sentiment: string\n"
"\n"
"Example InstantiateSchema:\n"
"MapResolveToMapWithCategoriesInstantiateSchema(\n"
" categories=['Positive', 'Negative', 'Neutral', 'Mixed'],\n"
" category_key='sentiment',\n"
" new_prompt='''Analyze the sentiment of this review and classify it into one of the following categories:\n"
"\n"
"Categories:\n"
"- Positive: Clearly positive sentiment, satisfaction, praise\n"
"- Negative: Clearly negative sentiment, complaints, criticism\n"
"- Neutral: No strong sentiment, factual statements\n"
"- Mixed: Contains both positive and negative elements\n"
"- None of the above: If the review doesn't fit any category\n"
"\n"
"Review: {{ input.review }}\n"
"\n"
"Return exactly one of: Positive, Negative, Neutral, Mixed, or None of the above.''',\n"
" include_none_of_above=True,\n"
")"
),
)
test_cases: List[DirectiveTestCase] = Field(
default_factory=lambda: [
DirectiveTestCase(
name="sentiment_categorization",
description="Should replace map+resolve with categorized map for sentiment",
input_config=[
{
"name": "extract_sentiment",
"type": "map",
"prompt": "What is the sentiment? {{ input.text }}",
"output": {"schema": {"sentiment": "string"}},
},
{
"name": "normalize_sentiment",
"type": "resolve",
"comparison_prompt": "Same sentiment? {{ input1.sentiment }} vs {{ input2.sentiment }}",
"resolution_prompt": "Normalize: {% for input in inputs %}{{ input.sentiment }}{% endfor %}",
"output": {"schema": {"sentiment": "string"}},
},
],
target_ops=["extract_sentiment", "normalize_sentiment"],
expected_behavior="Should create a single map with predefined sentiment categories",
should_pass=True,
),
]
)
def __eq__(self, other):
return isinstance(other, MapResolveToMapWithCategoriesDirective)
def __hash__(self):
return hash("MapResolveToMapWithCategoriesDirective")
def to_string_for_instantiate(
self, map_op: Dict, resolve_op: Dict, sample_data: List[Dict] = None
) -> str:
"""
Generate a prompt for an agent to instantiate this directive.
Args:
map_op (Dict): The map operation configuration.
resolve_op (Dict): The resolve operation configuration.
sample_data (List[Dict], optional): Sample data to help identify categories.
Returns:
str: The agent prompt for instantiating the directive.
"""
sample_str = ""
if sample_data:
sample_str = f"\n\nSample Input Data (first 5 items):\n{json.dumps(sample_data[:5], indent=2)}"
return (
f"You are an expert at optimizing data processing pipelines by replacing entity resolution with categorical constraints.\n\n"
f"Map Operation:\n"
f"{str(map_op)}\n\n"
f"Resolve Operation:\n"
f"{str(resolve_op)}\n\n"
f"Directive: {self.name}\n"
f"Your task is to replace the Map -> Resolve pattern with a single Map that uses predefined categories.\n\n"
f"Key Requirements:\n"
f"1. Analyze the map's output field and the resolve operation to understand what values are being normalized:\n"
f" - Look at the comparison_prompt to understand what variations are being matched\n"
f" - Look at the resolution_prompt to understand the canonical form\n\n"
f"2. Propose a set of categories that cover all expected outputs:\n"
f" - Categories should be mutually exclusive\n"
f" - Categories should be exhaustive (cover all realistic cases)\n"
f" - Consider including 'None of the above' for edge cases\n\n"
f"3. Optionally provide descriptions for each category to help the LLM classify correctly\n\n"
f"4. Create a new_prompt that:\n"
f" - Lists all valid categories with their descriptions\n"
f" - Instructs the LLM to output exactly one category\n"
f" - References the input using {{{{ input.key }}}} syntax\n"
f" - Includes any context from the original map prompt\n\n"
f"5. Identify the category_key (the output field that will contain the category)\n\n"
f"Benefits of this approach:\n"
f"- Eliminates O(n^2) pairwise comparisons from Resolve\n"
f"- Produces consistent, standardized outputs\n"
f"- Reduces cost by removing the Resolve operation entirely\n"
f"{sample_str}\n\n"
f"Example:\n"
f"{self.example}\n\n"
f"Please analyze the operations and propose appropriate categories. Output the MapResolveToMapWithCategoriesInstantiateSchema."
)
def llm_instantiate(
self,
map_op: Dict,
resolve_op: Dict,
agent_llm: str,
message_history: list = [],
sample_data: List[Dict] = None,
):
"""
Use LLM to instantiate this directive.
Args:
map_op (Dict): The map operation configuration.
resolve_op (Dict): The resolve operation configuration.
agent_llm (str): The LLM model to use.
message_history (List, optional): Conversation history for context.
sample_data (List[Dict], optional): Sample data to help identify categories.
Returns:
MapResolveToMapWithCategoriesInstantiateSchema: The structured output from the LLM.
"""
message_history.extend(
[
{
"role": "system",
"content": "You are a helpful AI assistant for document processing pipelines specializing in categorical classification.",
},
{
"role": "user",
"content": self.to_string_for_instantiate(
map_op, resolve_op, sample_data
),
},
]
)
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
response_format=MapResolveToMapWithCategoriesInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = MapResolveToMapWithCategoriesInstantiateSchema(**parsed_res)
message_history.append(
{"role": "assistant", "content": resp.choices[0].message.content}
)
return schema, message_history, call_cost
except Exception as err:
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
)
def apply(
self,
global_default_model,
ops_list: List[Dict],
map_op_name: str,
resolve_op_name: str,
rewrite: MapResolveToMapWithCategoriesInstantiateSchema,
) -> List[Dict]:
"""
Apply the directive to the pipeline config.
"""
new_ops_list = deepcopy(ops_list)
# Find position of the map and resolve ops
map_pos = None
resolve_pos = None
map_op = None
for i, op in enumerate(new_ops_list):
if op["name"] == map_op_name:
map_pos = i
map_op = op
elif op["name"] == resolve_op_name:
resolve_pos = i
if map_pos is None or resolve_pos is None:
raise ValueError(
f"Could not find map '{map_op_name}' and resolve '{resolve_op_name}' operations"
)
# Determine the model to use
default_model = map_op.get("model", global_default_model)
# Build the list of valid values for validation
valid_values = list(rewrite.categories)
if rewrite.include_none_of_above:
valid_values.append("None of the above")
# Modify the map operation with the new prompt and add validation
new_ops_list[map_pos]["prompt"] = rewrite.new_prompt
new_ops_list[map_pos]["model"] = default_model
# Add validation to ensure output is one of the categories
if "validate" not in new_ops_list[map_pos]:
new_ops_list[map_pos]["validate"] = []
# Add validation rule for the category key
validation_rule = f"output['{rewrite.category_key}'] in {valid_values}"
new_ops_list[map_pos]["validate"].append(validation_rule)
# Update the output schema to reflect the category key
if "output" not in new_ops_list[map_pos]:
new_ops_list[map_pos]["output"] = {"schema": {}}
new_ops_list[map_pos]["output"]["schema"][rewrite.category_key] = "string"
# Remove the resolve operation
new_ops_list.pop(resolve_pos)
return new_ops_list
def instantiate(
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
optimize_goal="acc",
global_default_model: str = None,
dataset: str = None,
**kwargs,
):
"""
Instantiate the directive for a list of operators.
"""
# Assert that there are exactly two target ops (map and resolve)
assert (
len(target_ops) == 2
), "There must be exactly two target ops (map and resolve) to instantiate this directive"
# Find the map and resolve operations
map_op = None
resolve_op = None
for op in operators:
if op["name"] == target_ops[0]:
if op.get("type") == "map":
map_op = op
elif op.get("type") == "resolve":
resolve_op = op
elif op["name"] == target_ops[1]:
if op.get("type") == "map":
map_op = op
elif op.get("type") == "resolve":
resolve_op = op
if map_op is None or resolve_op is None:
raise ValueError(
f"Could not find both a map and resolve operation in target_ops: {target_ops}"
)
# Verify the map comes before resolve
map_idx = next(
i for i, op in enumerate(operators) if op["name"] == map_op["name"]
)
resolve_idx = next(
i for i, op in enumerate(operators) if op["name"] == resolve_op["name"]
)
if map_idx >= resolve_idx:
raise ValueError(
f"Map operation '{map_op['name']}' must come before resolve operation '{resolve_op['name']}'"
)
# Load sample data if available
sample_data = None
if dataset:
try:
with open(dataset, "r") as f:
sample_data = json.load(f)
except Exception:
pass # Ignore if we can't load sample data
# Instantiate the directive
rewrite, message_history, call_cost = self.llm_instantiate(
map_op,
resolve_op,
agent_llm,
message_history,
sample_data,
)
# Apply the rewrite to the operators
new_ops_plan = self.apply(
global_default_model, operators, map_op["name"], resolve_op["name"], rewrite
)
return new_ops_plan, message_history, call_cost

View File

@ -1,335 +0,0 @@
import json
from copy import deepcopy
from typing import Dict, List, Type
from litellm import completion
from pydantic import BaseModel, Field
from docetl.reasoning_optimizer.instantiate_schemas import (
MapToMapResolveReduceInstantiateSchema,
)
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
class MapToMapResolveReduceDirective(Directive):
name: str = Field(
default="map_to_map_resolve_reduce", description="The name of the directive"
)
formal_description: str = Field(default="Map -> Reduce => Map -> Resolve -> Reduce")
nl_description: str = Field(
default="Insert a Resolve operation between Map and Reduce to deduplicate or normalize entities before aggregation. The Resolve operation uses code-powered blocking conditions to efficiently identify which pairs to compare, avoiding O(n^2) comparisons. This is useful when the Map output contains duplicate or near-duplicate entities that should be merged before the Reduce step."
)
when_to_use: str = Field(
default="When a Map operation produces outputs that may contain duplicates, variations, or near-duplicates (e.g., different spellings of names, similar categories), and these should be normalized before the Reduce aggregation step. The target must be a Map operation followed by a Reduce operation."
)
instantiate_schema_type: Type[BaseModel] = MapToMapResolveReduceInstantiateSchema
example: str = Field(
default=(
"Original Pipeline:\n"
"- name: extract_companies\n"
" type: map\n"
" prompt: |\n"
" Extract company names from this news article:\n"
" {{ input.article }}\n"
" output:\n"
" schema:\n"
" company_name: string\n"
"\n"
"- name: aggregate_companies\n"
" type: reduce\n"
" reduce_key: sector\n"
" prompt: |\n"
" List all unique companies in this sector:\n"
" {% for input in inputs %}\n"
" - {{ input.company_name }}\n"
" {% endfor %}\n"
" output:\n"
" schema:\n"
" companies: list[str]\n"
"\n"
"Example InstantiateSchema:\n"
"MapToMapResolveReduceInstantiateSchema(\n"
" resolve_name='normalize_company_names',\n"
" comparison_prompt='''Are these two company names referring to the same company?\n"
"Company 1: {{ input1.company_name }}\n"
"Company 2: {{ input2.company_name }}\n"
"Consider variations like abbreviations (IBM vs International Business Machines), \n"
"different legal suffixes (Inc, Corp, LLC), and common misspellings.''',\n"
" resolution_prompt='''Given these variations of a company name:\n"
"{% for input in inputs %}\n"
"- {{ input.company_name }}\n"
"{% endfor %}\n"
"Return the canonical/official company name.''',\n"
" blocking_conditions=[\n"
" \"input1['company_name'][:3].lower() == input2['company_name'][:3].lower()\",\n"
" \"input1['company_name'].split()[0].lower() == input2['company_name'].split()[0].lower()\",\n"
" ],\n"
" blocking_keys=['company_name'],\n"
" limit_comparisons=1000,\n"
" output_schema={'company_name': 'string'},\n"
")"
),
)
test_cases: List[DirectiveTestCase] = Field(
default_factory=lambda: [
DirectiveTestCase(
name="company_name_normalization",
description="Should insert resolve between map and reduce for company names",
input_config=[
{
"name": "extract_companies",
"type": "map",
"prompt": "Extract company name from: {{ input.text }}",
"output": {"schema": {"company_name": "string"}},
},
{
"name": "aggregate_by_sector",
"type": "reduce",
"reduce_key": "sector",
"prompt": "List companies:\n{% for input in inputs %}\n- {{ input.company_name }}\n{% endfor %}",
"output": {"schema": {"companies": "list[str]"}},
},
],
target_ops=["extract_companies", "aggregate_by_sector"],
expected_behavior="Should create a resolve operation between the map and reduce with appropriate blocking conditions",
should_pass=True,
),
]
)
def __eq__(self, other):
return isinstance(other, MapToMapResolveReduceDirective)
def __hash__(self):
return hash("MapToMapResolveReduceDirective")
def to_string_for_instantiate(self, map_op: Dict, reduce_op: Dict) -> str:
"""
Generate a prompt for an agent to instantiate this directive.
Args:
map_op (Dict): The map operation configuration.
reduce_op (Dict): The reduce operation configuration.
Returns:
str: The agent prompt for instantiating the directive.
"""
return (
f"You are an expert at optimizing data processing pipelines by inserting entity resolution steps.\n\n"
f"Map Operation:\n"
f"{str(map_op)}\n\n"
f"Reduce Operation:\n"
f"{str(reduce_op)}\n\n"
f"Directive: {self.name}\n"
f"Your task is to insert a Resolve operation between the Map and Reduce to deduplicate/normalize entities.\n\n"
f"Key Requirements:\n"
f"1. Create a comparison_prompt that determines if two items from the Map output are duplicates/variations:\n"
f" - Must reference {{ input1.key }} and {{ input2.key }} for comparing fields\n"
f" - Should handle common variations (abbreviations, misspellings, formatting differences)\n\n"
f"2. Create a resolution_prompt that merges matched items into a canonical form:\n"
f" - Must use {{% for input in inputs %}} to iterate over matched items\n"
f" - Should produce the most authoritative/complete version\n\n"
f"3. Create blocking_conditions to avoid O(n^2) comparisons:\n"
f" - These are Python expressions with access to 'input1' and 'input2' dicts\n"
f" - They should filter pairs to only those likely to match\n"
f" - Examples:\n"
f" * \"input1['name'][:3].lower() == input2['name'][:3].lower()\" (first 3 chars match)\n"
f" * \"input1['name'].split()[0].lower() == input2['name'].split()[0].lower()\" (first word matches)\n"
f" * \"abs(len(input1['name']) - len(input2['name'])) < 10\" (similar length)\n"
f" - Multiple conditions are OR'd together\n\n"
f"4. Set limit_comparisons to cap the number of pairs (recommended: 500-2000)\n\n"
f"5. The output_schema should match what the Reduce operation expects from each input\n\n"
f"Example:\n"
f"{self.example}\n\n"
f"Please output the MapToMapResolveReduceInstantiateSchema."
)
def llm_instantiate(
self, map_op: Dict, reduce_op: Dict, agent_llm: str, message_history: list = []
):
"""
Use LLM to instantiate this directive.
Args:
map_op (Dict): The map operation configuration.
reduce_op (Dict): The reduce operation configuration.
agent_llm (str): The LLM model to use.
message_history (List, optional): Conversation history for context.
Returns:
MapToMapResolveReduceInstantiateSchema: The structured output from the LLM.
"""
message_history.extend(
[
{
"role": "system",
"content": "You are a helpful AI assistant for document processing pipelines specializing in entity resolution.",
},
{
"role": "user",
"content": self.to_string_for_instantiate(map_op, reduce_op),
},
]
)
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
response_format=MapToMapResolveReduceInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = MapToMapResolveReduceInstantiateSchema(**parsed_res)
message_history.append(
{"role": "assistant", "content": resp.choices[0].message.content}
)
return schema, message_history, call_cost
except Exception as err:
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
)
def apply(
self,
global_default_model,
ops_list: List[Dict],
map_op_name: str,
reduce_op_name: str,
rewrite: MapToMapResolveReduceInstantiateSchema,
) -> List[Dict]:
"""
Apply the directive to the pipeline config.
"""
new_ops_list = deepcopy(ops_list)
# Find position of the map and reduce ops
map_pos = None
reduce_pos = None
map_op = None
for i, op in enumerate(ops_list):
if op["name"] == map_op_name:
map_pos = i
map_op = op
elif op["name"] == reduce_op_name:
reduce_pos = i
if map_pos is None or reduce_pos is None:
raise ValueError(
f"Could not find map '{map_op_name}' and reduce '{reduce_op_name}' operations"
)
# Determine the model to use
default_model = map_op.get("model", global_default_model)
# Find the reduce operation to get the reduce_key
reduce_op = None
for op in ops_list:
if op["name"] == reduce_op_name:
reduce_op = op
break
# Derive output schema from reduce_key - that's what's being grouped/resolved
reduce_key = reduce_op.get("reduce_key", []) if reduce_op else []
if isinstance(reduce_key, str):
reduce_key = [reduce_key]
# Build output schema from reduce_key fields
output_schema = {key: "string" for key in reduce_key}
# Create the resolve operation
resolve_op = {
"name": rewrite.resolve_name,
"type": "resolve",
"comparison_prompt": rewrite.comparison_prompt,
"resolution_prompt": rewrite.resolution_prompt,
"blocking_conditions": rewrite.blocking_conditions,
"blocking_keys": rewrite.blocking_keys,
"limit_comparisons": rewrite.limit_comparisons,
"model": default_model,
"output": {"schema": output_schema},
}
# Insert resolve operation after the map operation
new_ops_list.insert(map_pos + 1, resolve_op)
return new_ops_list
def instantiate(
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
optimize_goal="acc",
global_default_model: str = None,
**kwargs,
):
"""
Instantiate the directive for a list of operators.
"""
# Assert that there are exactly two target ops (map and reduce)
assert (
len(target_ops) == 2
), "There must be exactly two target ops (map and reduce) to instantiate this directive"
# Find the map and reduce operations
map_op = None
reduce_op = None
for op in operators:
if op["name"] == target_ops[0]:
if op.get("type") == "map":
map_op = op
elif op.get("type") == "reduce":
reduce_op = op
elif op["name"] == target_ops[1]:
if op.get("type") == "map":
map_op = op
elif op.get("type") == "reduce":
reduce_op = op
if map_op is None or reduce_op is None:
raise ValueError(
f"Could not find both a map and reduce operation in target_ops: {target_ops}"
)
# Verify the map comes before reduce
map_idx = next(
i for i, op in enumerate(operators) if op["name"] == map_op["name"]
)
reduce_idx = next(
i for i, op in enumerate(operators) if op["name"] == reduce_op["name"]
)
if map_idx >= reduce_idx:
raise ValueError(
f"Map operation '{map_op['name']}' must come before reduce operation '{reduce_op['name']}'"
)
# Instantiate the directive
rewrite, message_history, call_cost = self.llm_instantiate(
map_op,
reduce_op,
agent_llm,
message_history,
)
# Apply the rewrite to the operators
new_ops_plan = self.apply(
global_default_model, operators, map_op["name"], reduce_op["name"], rewrite
)
return new_ops_plan, message_history, call_cost

View File

@ -1,4 +1,5 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -171,14 +172,17 @@ class OperatorFusionDirective(Directive):
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=OperatorFusionInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
call_cost = resp._hidden_params["response_cost"]
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = OperatorFusionInstantiateSchema(**parsed_res)
@ -187,12 +191,11 @@ class OperatorFusionDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
)
def apply(
@ -212,6 +215,7 @@ class OperatorFusionDirective(Directive):
new_ops_list = deepcopy(ops_list)
op1_name, op2_name = target_ops[0], target_ops[1]
# Find the operations
op1_idx = next(i for i, op in enumerate(ops_list) if op["name"] == op1_name)
op2_idx = next(i for i, op in enumerate(ops_list) if op["name"] == op2_name)
@ -325,5 +329,5 @@ def transform(input_doc):
return (
self.apply(global_default_model, operators, target_ops, rewrite),
message_history,
call_cost,
call_cost
)

View File

@ -1,4 +1,5 @@
import json
import os
import re
from copy import deepcopy
from typing import Dict, List, Type
@ -126,7 +127,7 @@ class ReduceChainingDirective(Directive):
original_op: Dict,
expected_document_key: str,
agent_llm: str,
message_history: list = [],
message_history: list = []
):
"""
Use LLM to instantiate this directive by decomposing the reduce operation.
@ -154,15 +155,19 @@ class ReduceChainingDirective(Directive):
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
call_cost = 0
resp = completion(
model=agent_llm,
messages=message_history,
response_format=ReduceChainingInstantiateSchema,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=ReduceChainingInstantiateSchema
)
call_cost = resp._hidden_params.get("response_cost", 0)
call_cost = resp._hidden_params["response_cost"]
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = ReduceChainingInstantiateSchema(**parsed_res)
@ -177,12 +182,11 @@ class ReduceChainingDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
)
def apply(
@ -279,9 +283,7 @@ class ReduceChainingDirective(Directive):
]
if document_key_candidates:
expected_document_key = document_key_candidates[
0
] # Pick the first candidate
expected_document_key = document_key_candidates[0] # Pick the first candidate
elif input_keys:
expected_document_key = input_keys[0] # Fall back to the first input key
else:

View File

@ -1,4 +1,5 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -254,14 +255,17 @@ class ReduceGleaningDirective(Directive):
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=GleaningInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
call_cost = resp._hidden_params["response_cost"]
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = GleaningInstantiateSchema(**parsed_res)
@ -270,12 +274,11 @@ class ReduceGleaningDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
)
def apply(
@ -338,6 +341,5 @@ class ReduceGleaningDirective(Directive):
# Apply the rewrite to the operators
return (
self.apply(global_default_model, operators, target_ops[0], rewrite),
message_history,
call_cost,
message_history, call_cost
)

View File

@ -1,4 +1,5 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -233,14 +234,18 @@ class TakeHeadTailDirective(Directive):
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
call_cost = 0
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=TakeHeadTailInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
call_cost = resp._hidden_params["response_cost"]
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = TakeHeadTailInstantiateSchema(**parsed_res)
@ -249,12 +254,11 @@ class TakeHeadTailDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
)
def apply(

View File

@ -61,7 +61,7 @@ class MapOpConfig(BaseModel):
...,
description="The keys of the output of the Map operator, to be referenced in the downstream operator's prompt. Can be a single key or a list of keys. Can be new keys or existing keys from the map operator we are rewriting.",
)
@classmethod
def validate_prompt_contains_input_key(cls, value: str) -> str:
"""
@ -445,12 +445,8 @@ class DeterministicDocCompressionInstantiateSchema(BaseModel):
raise ValueError(
"Code must define a function named 'transform' that takes input_doc as parameter"
)
if (
"return {" not in v
and "return dict(" not in v
and "output = dict(" not in v
):
raise ValueError(f"Code must return a dictionary. Instead the code is: {v}")
if "return {" not in v and "return dict(" not in v:
raise ValueError("Code must return a dictionary")
return v
def validate_code_returns_target_keys(self, target_ops_configs: List[Dict]) -> None:
@ -567,6 +563,7 @@ class ChunkHeaderSummaryInstantiateSchema(BaseModel):
return v
class SamplingConfig(BaseModel):
"""Configuration for optional sampling in document chunking."""
@ -637,7 +634,7 @@ class DocumentChunkingInstantiateSchema(BaseModel):
default=None,
description="Optional sampling configuration. If provided, inserts a Sample operation between Gather and Map. Use by default UNLESS task requires processing ALL chunks (like comprehensive extraction of all instances).",
)
def validate_split_key_exists_in_input(self, input_file_path: str) -> None:
"""
Validates that the split_key exists in the input JSON file items.
@ -815,6 +812,7 @@ class DocumentChunkingTopKInstantiateSchema(BaseModel):
description="Configuration for the topk operation to select relevant chunks",
)
def validate_split_key_exists_in_input(self, input_file_path: str) -> None:
"""
Validates that the split_key exists in the input JSON file items.
@ -1213,109 +1211,6 @@ class SearchReplaceEdit(BaseModel):
# )
class MapToMapResolveReduceInstantiateSchema(BaseModel):
"""
Schema for inserting a Resolve operation between Map and Reduce.
Transforms Map -> Reduce => Map -> Resolve -> Reduce pattern.
The Resolve operation deduplicates/normalizes entities before aggregation.
"""
resolve_name: str = Field(..., description="The name of the new Resolve operator")
comparison_prompt: str = Field(
...,
description="Jinja prompt template for comparing two items. Must use {{ input1.key }} and {{ input2.key }} to reference fields from both items being compared.",
)
resolution_prompt: str = Field(
...,
description="Jinja prompt template for resolving matched items into a canonical form. Must use {% for input in inputs %} to iterate over matched items.",
)
blocking_conditions: List[str] = Field(
...,
description="List of Python expressions that determine if two items should be compared. Each expression has access to 'input1' and 'input2' dicts. Example: \"input1['category'].lower() == input2['category'].lower()\"",
)
blocking_keys: List[str] = Field(
...,
description="Keys to use for blocking/grouping items before comparison. Must include at least the reduce_key from the downstream Reduce operation, plus any additional context helpful for resolution.",
)
limit_comparisons: int = Field(
default=15000,
description="Maximum number of pairs to compare. Code-based blocked pairs are prioritized. Defaults to 15000 to avoid O(n^2) comparisons.",
gt=0,
)
@field_validator("comparison_prompt")
@classmethod
def check_comparison_prompt(cls, v: str) -> str:
if "input1" not in v or "input2" not in v:
raise ValueError(
"comparison_prompt must reference both 'input1' and 'input2' variables"
)
return v
@field_validator("resolution_prompt")
@classmethod
def check_resolution_prompt(cls, v: str) -> str:
if "inputs" not in v:
raise ValueError(
"resolution_prompt must reference 'inputs' variable for iterating over matched items"
)
return v
@field_validator("blocking_conditions")
@classmethod
def check_blocking_conditions(cls, v: List[str]) -> List[str]:
if not v:
raise ValueError(
"At least one blocking condition must be provided to avoid O(n^2) comparisons"
)
for condition in v:
if "input1" not in condition or "input2" not in condition:
raise ValueError(
f"Blocking condition must reference both 'input1' and 'input2': {condition}"
)
return v
class MapResolveToMapWithCategoriesInstantiateSchema(BaseModel):
"""
Schema for replacing Map -> Resolve with a single Map that has predefined categories.
The agent proposes categories based on analysis of the data/task, and the new Map
forces outputs into one of these categories (or 'none of the above'), effectively
doing entity resolution deterministically.
"""
categories: List[str] = Field(
...,
description="List of valid category values that the map output should be constrained to. Should include all distinct canonical values discovered from analyzing the data/task.",
)
category_key: str = Field(
...,
description="The key in the output schema that will contain the category value",
)
new_prompt: str = Field(
...,
description="The new prompt for the Map operation that includes the category list and instructs the LLM to output one of the predefined categories. Must reference {{ input.key }} for input fields.",
)
include_none_of_above: bool = Field(
default=True,
description="Whether to include 'None of the above' as a valid category option for items that don't fit any category",
)
@field_validator("categories")
@classmethod
def check_categories(cls, v: List[str]) -> List[str]:
if len(v) < 2:
raise ValueError("At least 2 categories must be provided")
if len(v) != len(set(v)):
raise ValueError("Categories must be unique")
return v
@field_validator("new_prompt")
@classmethod
def check_new_prompt(cls, v: str) -> str:
return MapOpConfig.validate_prompt_contains_input_key(v)
class ArbitraryRewriteInstantiateSchema(BaseModel):
"""
Schema for arbitrary pipeline rewrites using search/replace edits.

View File

@ -699,7 +699,7 @@ class DSLRunner(ConfigWrapper):
"litellm_kwargs", {}
)
kwargs["rewrite_agent_model"] = self.config.get("optimizer_config", {}).get(
"rewrite_agent_model", "gpt-5.1"
"rewrite_agent_model", "gpt-4o"
)
kwargs["judge_agent_model"] = self.config.get("optimizer_config", {}).get(
"judge_agent_model", "gpt-4o-mini"
@ -727,7 +727,7 @@ class DSLRunner(ConfigWrapper):
"litellm_kwargs", {}
)
kwargs["rewrite_agent_model"] = self.config.get("optimizer_config", {}).get(
"rewrite_agent_model", "gpt-5.1"
"rewrite_agent_model", "gpt-4o"
)
kwargs["judge_agent_model"] = self.config.get("optimizer_config", {}).get(
"judge_agent_model", "gpt-4o-mini"

View File

@ -34,9 +34,6 @@ To get started with DocETL:
2. Define your pipeline in a YAML file. Want to use an LLM like ChatGPT or Claude to help you write your pipeline? See [docetl.org/llms.txt](https://docetl.org/llms.txt) for a big prompt you can copy paste into ChatGPT or Claude, before describing your task.
3. Run your pipeline using the DocETL command-line interface
!!! tip "Fastest Way: Claude Code"
Clone this repo and run `claude` to use the built-in DocETL skill. Just describe your data processing task and Claude will create and run the pipeline for you. See [Quick Start (Claude Code)](quickstart-claude-code.md) for details.
## 🏛️ Project Origin
DocETL was created by members of the EPIC Data Lab and Data Systems and Foundations group at UC Berkeley. The EPIC (Effective Programming, Interaction, and Computation with Data) Lab focuses on developing low-code and no-code interfaces for data work, powered by next-generation predictive programming techniques. DocETL is one of the projects that emerged from our research efforts to streamline complex document processing tasks.

View File

@ -55,6 +55,14 @@ If you want to use only the parsing extra:
uv sync --extra parsing
```
If you want to use the parsing tools, you need to install the `parsing` extra:
```bash
poetry install --extras "parsing"
```
This will create a virtual environment and install all the required dependencies.
4. Set up your OpenAI API key:
Create a .env file in the project root and add your OpenAI API key:

View File

@ -91,10 +91,3 @@ The transform function should return True for items to keep and False for items
| reduce_key | Key(s) to group by (code_reduce only) | "_all" |
| pass_through | Pass through unmodified keys from first item in group (code_reduce only) | false |
| concurrent_thread_count | The number of threads to start | the number of logical CPU cores (os.cpu_count()) |
| limit | Maximum number of outputs to produce before stopping | Processes all data |
The `limit` parameter behaves differently for each operation type:
- **code_map**: Limits the number of input documents to process
- **code_filter**: Limits the number of documents that pass the filter (outputs)
- **code_reduce**: Limits the number of groups to reduce, selecting the smallest groups first (by document count)

View File

@ -34,9 +34,9 @@ This Equijoin operation matches job candidates to job postings:
The prompt template uses Jinja2 syntax, allowing you to reference input fields directly (e.g., `left.skills`). You can reference the left and right documents using `left` and `right` respectively.
!!! info "Automatic Blocking"
!!! warning "Performance Consideration"
If you don't specify any blocking configuration (`blocking_threshold`, `blocking_conditions`, or `limit_comparisons`), the Equijoin operation will automatically compute an optimal embedding-based blocking threshold at runtime. It samples pairs from your data, runs LLM comparisons on the sample, and finds a threshold that achieves 95% recall by default. You can adjust this with the `blocking_target_recall` parameter.
For large datasets, running comparisons with an LLM can be time-consuming. It's recommended to optimize your pipeline using `docetl build pipeline.yaml` to generate efficient blocking rules for the operation.
## Blocking
@ -95,19 +95,10 @@ A full Equijoin step combining both ideas might look like:
Equijoin shares many parameters with the Resolve operation. For a detailed list of required and optional parameters, please see the [Parameters section in the Resolve operation documentation](resolve.md#required-parameters).
### Equijoin-Specific Parameters
| Parameter | Description | Default |
| ------------------------- | --------------------------------------------------------------------------------- | ----------------------------- |
| `limits` | Maximum matches for each left/right item: `{"left": n, "right": m}` | No limit |
| `blocking_keys` | Keys for embedding blocking: `{"left": [...], "right": [...]}` | All keys from each dataset |
| `blocking_threshold` | Embedding similarity threshold for considering pairs | Auto-computed if not set |
| `blocking_target_recall` | Target recall when auto-computing blocking threshold (0.0 to 1.0) | 0.95 |
Key differences from Resolve:
Key differences for Equijoin include:
- `resolution_prompt` is not used in Equijoin.
- `blocking_keys` uses a dict with `left` and `right` keys instead of a simple list.
- `limits` parameter is specific to Equijoin, allowing you to set maximum matches for each left and right item.
## Incorporating Into a Pipeline

View File

@ -140,12 +140,9 @@ This strategy asks the LLM to generate regex patterns matching the desired conte
| `timeout` | Timeout for LLM calls in seconds | 120 |
| `skip_on_error` | Continue processing if errors occur | false |
| `litellm_completion_kwargs` | Additional parameters for LiteLLM calls | {} |
| `limit` | Maximum number of documents to extract from before stopping | Processes all data |
| `retriever` | Name of a retriever to use for RAG. See [Retrievers](../retrievers.md). | None |
| `save_retriever_output` | If true, saves the retrieved context to `_<operation_name>_retrieved_context` in the output. | False |
When `limit` is set, Extract only reformats and submits the first _N_ documents. This is handy when the upstream dataset is large and you want to cap cost while previewing results.
## Best Practices
Create specific extraction prompts with clear criteria about what content to extract or exclude. Choose the appropriate extraction method based on your needs:

View File

@ -83,10 +83,6 @@ This example demonstrates how the Filter operation distinguishes between high-im
See [map optional parameters](./map.md#optional-parameters) for additional configuration options, including `batch_prompt` and `max_batch_size`.
### Limiting filtered outputs
`limit` behaves slightly differently for filter operations than for map operations. Because filter drops documents whose predicate evaluates to `false`, the limit counts only the documents that would be retained (i.e., the ones whose boolean output is `true`). DocETL will continue evaluating additional inputs until it has collected `limit` passing documents and then stop scheduling further LLM calls. This ensures you can request “the first N matches” without paying to score the entire dataset.
!!! info "Validation"
For more details on validation techniques and implementation, see [operators](../concepts/operators.md#validation).

View File

@ -140,7 +140,6 @@ This example demonstrates how the Map operation can transform long, unstructured
| `optimize` | Flag to enable operation optimization | `True` |
| `recursively_optimize` | Flag to enable recursive optimization of operators synthesized as part of rewrite rules | `false` |
| `sample` | Number of samples to use for the operation | Processes all data |
| `limit` | Maximum number of outputs to produce before stopping | Processes all data |
| `tools` | List of tool definitions for LLM use | None |
| `validate` | List of Python expressions to validate the output | None |
| `flush_partial_results` | Write results of individual batches of map operation to disk for faster inspection | False |
@ -161,10 +160,6 @@ This example demonstrates how the Map operation can transform long, unstructured
Note: If `drop_keys` is specified, `prompt` and `output` become optional parameters.
### Limiting execution
Set `limit` when you only need the first _N_ map results or want to cap LLM spend. The operation slices the processed dataset to the first `limit` entries and also stops scheduling new prompts once that many outputs have been produced, even if a prompt returns multiple records. Filter operations inherit this behavior but redefine the count so the limit only applies to records whose filter predicate evaluates to `true` (see [Filter](./filter.md#optional-parameters)).
!!! info "Validation and Gleaning"

View File

@ -52,7 +52,6 @@ This Reduce operation processes customer feedback grouped by department:
| Parameter | Description | Default |
| ------------------------- | ------------------------------------------------------------------------------------------------------ | --------------------------- |
| `sample` | Number of samples to use for the operation | None |
| `limit` | Maximum number of groups to process before stopping | All groups |
| `synthesize_resolve` | If false, won't synthesize a resolve operation between map and reduce | true |
| `model` | The language model to use | Falls back to default_model |
| `input` | Specifies the schema or keys to subselect from each item | All keys from input items |
@ -70,10 +69,6 @@ This Reduce operation processes customer feedback grouped by department:
| `retriever` | Name of a retriever to use for RAG. See [Retrievers](../retrievers.md). | None |
| `save_retriever_output` | If true, saves the retrieved context to `_<operation_name>_retrieved_context` in the output. | False |
### Limiting group processing
Set `limit` to short-circuit the reduce phase after _N_ groups have been aggregated. When `limit` is set, groups are sorted by size (smallest first) and only the _N_ smallest groups are processed. This is useful for previewing results or capping LLM usage while minimizing cost by processing groups with fewer documents. Groups beyond the limit are never scheduled, so you avoid extra fold/merge calls. If a grouped reduce returns more than one record per group, the final output list is truncated to `limit`.
## Advanced Features
### Incremental Folding

View File

@ -44,9 +44,9 @@ This Resolve operation processes patient names to identify and standardize dupli
Note: The prompt templates use Jinja2 syntax, allowing you to reference input fields directly (e.g., `input1.patient_name`).
!!! info "Automatic Blocking"
!!! warning "Performance Consideration"
If you don't specify any blocking configuration (`blocking_threshold`, `blocking_conditions`, or `limit_comparisons`), the Resolve operation will automatically compute an optimal embedding-based blocking threshold at runtime. It samples pairs from your data, runs LLM comparisons on the sample, and finds a threshold that achieves 95% recall by default. You can adjust this with the `blocking_target_recall` parameter.
You should not run this operation as-is unless your dataset is small! Running O(n^2) comparisons with an LLM can be extremely time-consuming for large datasets. Instead, optimize your pipeline first using `docetl build pipeline.yaml` and run the optimized version, which will generate efficient blocking rules for the operation. Make sure you've set `optimize: true` in your resolve operation config.
## Blocking
@ -132,8 +132,7 @@ After determining eligible pairs for comparison, the Resolve operation uses a Un
| `resolution_model` | The language model to use for reducing matched entries | Falls back to `default_model` |
| `comparison_model` | The language model to use for comparing potential matches | Falls back to `default_model` |
| `blocking_keys` | List of keys to use for initial blocking | All keys in the input data |
| `blocking_threshold` | Embedding similarity threshold for considering entries as potential matches | Auto-computed if not set |
| `blocking_target_recall` | Target recall when auto-computing blocking threshold (0.0 to 1.0) | 0.95 |
| `blocking_threshold` | Embedding similarity threshold for considering entries as potential matches | None |
| `blocking_conditions` | List of conditions for initial blocking | [] |
| `input` | Specifies the schema or keys to subselect from each item to pass into the prompts | All keys from input items |
| `embedding_batch_size` | The number of entries to send to the embedding model at a time | 1000 |
@ -141,9 +140,9 @@ After determining eligible pairs for comparison, the Resolve operation uses a Un
| `limit_comparisons` | Maximum number of comparisons to perform | None |
| `timeout` | Timeout for each LLM call in seconds | 120 |
| `max_retries_per_timeout` | Maximum number of retries per timeout | 2 |
| `sample` | Number of samples to use for the operation | None |
| `litellm_completion_kwargs` | Additional parameters to pass to LiteLLM completion calls. | {} |
| `bypass_cache` | If true, bypass the cache for this operation. | False |
| `sample` | Number of samples to use for the operation | None |
| `litellm_completion_kwargs` | Additional parameters to pass to LiteLLM completion calls. | {} |
| `bypass_cache` | If true, bypass the cache for this operation. | False |
## Best Practices

View File

@ -14,7 +14,7 @@ All fields in `optimizer_config` are required (no defaults):
| `evaluation_file` | `str` | Path to Python file containing `@register_eval` decorated function |
| `metric_key` | `str` | Key in evaluation results dictionary to use as accuracy metric |
| `max_iterations` | `int` | Maximum number of MOARSearch iterations to run |
| `rewrite_agent_model` | `str` | LLM model to use for directive instantiation during search |
| `model` | `str` | LLM model to use for directive instantiation during search |
!!! warning "All Fields Required"
MOAR will error if any required field is missing. There are no defaults.
@ -67,19 +67,19 @@ The optimizer will use the sample dataset, but your final pipeline uses the full
### Available Models
!!! info "LiteLLM Model Names"
Use LiteLLM model names (e.g., `gpt-4o-mini`, `gpt-4o`, `gpt-5.1`). Make sure your API keys are set in your environment.
Use LiteLLM model names (e.g., `gpt-4o-mini`, `gpt-4o`, `gpt-4.1`). Make sure your API keys are set in your environment.
```yaml
available_models: # LiteLLM model names - ensure API keys are set
- gpt-5.1-nano # Cheapest, lower accuracy
- gpt-5.1-mini # Low cost, decent accuracy
- gpt-5.1 # Balanced
- gpt-4.1-nano # Cheapest, lower accuracy
- gpt-4.1-mini # Low cost, decent accuracy
- gpt-4.1 # Balanced
- gpt-4o # Higher cost, better accuracy
```
### Model for Directive Instantiation
The `rewrite_agent_model` field specifies which LLM to use for generating optimization directives during the search process. This doesn't affect the models tested in `available_models`.
The `model` field specifies which LLM to use for generating optimization directives during the search process. This doesn't affect the models tested in `available_models`.
!!! tip "Cost Consideration"
Use a cheaper model (like `gpt-4o-mini`) for directive instantiation to reduce search costs.
@ -104,12 +104,12 @@ optimizer_config:
available_models:
- gpt-4o-mini
- gpt-4o
- gpt-5.1-mini
- gpt-5.1
- gpt-4.1-mini
- gpt-4.1
evaluation_file: evaluate_medications.py
metric_key: medication_extraction_score
max_iterations: 40
rewrite_agent_model: gpt-5.1
model: gpt-4.1
dataset_path: data/sample.json # Optional
exploration_weight: 1.414 # Optional
```

View File

@ -25,15 +25,15 @@ optimizer_config:
dataset_path: workloads/medical/raw_sample.json # Use sample for faster optimization
save_dir: workloads/medical/moar_results
available_models: # LiteLLM model names - ensure API keys are set in your environment
- gpt-5.1-nano
- gpt-5.1-mini
- gpt-5.1
- gpt-4.1-nano
- gpt-4.1-mini
- gpt-4.1
- gpt-4o
- gpt-4o-mini
evaluation_file: workloads/medical/evaluate_medications.py
metric_key: medication_extraction_score
max_iterations: 40
rewrite_agent_model: gpt-5.1
model: gpt-4.1
system_prompt:
dataset_description: a collection of transcripts of doctor visits

View File

@ -109,12 +109,12 @@ optimizer_config:
available_models: # LiteLLM model names - ensure API keys are set in your environment
- gpt-4o-mini
- gpt-4o
- gpt-5.1-mini
- gpt-5.1
- gpt-4.1-mini
- gpt-4.1
evaluation_file: evaluate_medications.py
metric_key: medication_extraction_score # This must match a key in your evaluation function's return dictionary
max_iterations: 40
rewrite_agent_model: gpt-5.1
model: gpt-4.1
dataset_path: data/transcripts_sample.json # Optional: use sample/hold-out dataset
```

View File

@ -19,7 +19,7 @@ High-level summary of the optimization run:
{
"optimizer": "moar",
"input_pipeline": "pipeline.yaml",
"rewrite_agent_model": "gpt-5.1",
"model": "gpt-4.1",
"max_iterations": 40,
"save_dir": "results/moar_optimization",
"dataset": "transcripts",

View File

@ -80,9 +80,9 @@ Test your evaluation function independently and check MOAR logs for errors.
```yaml
available_models:
- gpt-5.1-nano # Cheapest, lower accuracy
- gpt-5.1-mini # Low cost, decent accuracy
- gpt-5.1 # Balanced
- gpt-4.1-nano # Cheapest, lower accuracy
- gpt-4.1-mini # Low cost, decent accuracy
- gpt-4.1 # Balanced
- gpt-4o # Higher cost, better accuracy
```

View File

@ -1,45 +0,0 @@
# Quick Start with Claude Code
The fastest way to build DocETL pipelines is with [Claude Code](https://claude.ai/download), Anthropic's agentic coding tool. DocETL includes a built-in Claude Code skill that helps you create, run, and debug pipelines interactively.
## Option 1: Clone the Repository (Recommended)
This gives you the full development environment with the skill already configured.
1. Follow the [Installation from Source](installation.md#installation-from-source) instructions
2. Run `claude` in the repository directory
The skill is located at `.claude/skills/docetl/SKILL.md`.
## Option 2: Install via pip
If you already have DocETL installed via pip, you can install the skill separately:
```bash
pip install docetl
docetl install-skill
```
This copies the skill to `~/.claude/skills/docetl/`. Then run `claude` in any directory.
To uninstall: `docetl install-skill --uninstall`
## Usage
Simply describe what you want to do with your data. The skill activates automatically when you mention "docetl" or describe unstructured data processing tasks:
```
> I have a folder of customer support tickets in JSON format.
> I want to extract the main issue, sentiment, and suggested resolution for each.
```
Claude will:
1. **Read your data** to understand its structure
2. **Write a tailored pipeline** with prompts specific to your documents
3. **Run the pipeline** and show you the results
4. **Debug issues** if any operations fail
## Alternative: Manual Pipeline Authoring
If you prefer not to use Claude Code, see the [Quick Start Tutorial](tutorial.md) for writing pipelines by hand.

View File

@ -1,33 +1,29 @@
## Retrievers (LanceDB OSS)
Retrievers let you augment LLM operations with retrieved context from a LanceDB index built over a DocETL dataset. You define retrievers once at the top-level, then attach them to any LLM-powered operation using `retriever: <name>`. At runtime, DocETL performs full-text, vector, or hybrid search and injects the results into your prompt as `{{ retrieval_context }}`.
Retrievers let you augment LLM operations with retrieved context from a LanceDB index built over one of your DocETL datasets. You define retrievers once at the top-level, then attach them to any LLM-powered operation using `retriever: <name>`. At runtime, DocETL performs fulltext, vector, or hybrid search and injects the results into your prompt as `{{ retrieval_context }}`.
LanceDB supports built-in full-text search, vector search, and hybrid with RRF reranking. See the official docs: [LanceDB Hybrid Search docs](https://lancedb.com/docs/search/hybrid-search/).
### Key points
- Always OSS LanceDB (local `index_dir`).
- A retriever references a dataset from the pipeline config, or the output of a previous pipeline step.
- A retriever references an existing dataset from the pipeline config.
- Operations do not override retriever settings. One source of truth = consistency.
- `{{ retrieval_context }}` is available to your prompt; if not used, DocETL prepends a short "extra context" section automatically.
- `{{ retrieval_context }}` is available to your prompt; if not used, DocETL prepends a short “extra context” section automatically.
## Configuration
## Configuration (clear separation of index vs query)
Add a top-level `retrievers` section. Each retriever has:
- `dataset`: dataset name to index (can be a dataset or output of a previous pipeline step)
- `dataset`: dataset name to index
- `index_dir`: LanceDB path
- `index_types`: which indexes to build: `fts`, `embedding`, or `hybrid` (both)
- `fts.index_phrase`: Jinja template for indexing each row for full-text search
- `fts.query_phrase`: Jinja template for building the FTS query at runtime
- `embedding.model`: embedding model for vector index and queries
- `embedding.index_phrase`: Jinja template for indexing each row for embeddings
- `embedding.query_phrase`: Jinja template for building the embedding query
- `index_types`: which indexes to build: `fts`, `embedding`, or `hybrid` (interpreted as both `fts` and `embedding`)
- `fts.index_phrase`: Jinja for how to index each dataset row for FTS (context: `input`)
- `fts.query_phrase`: Jinja for how to build the FTS query (context: operation context)
- `embedding.model`: embedding model used for the vector index and for query vectors
- `embedding.index_phrase`: Jinja for how to index each dataset row for embedding (context: `input`)
- `embedding.query_phrase`: Jinja for how to build the embedding query text (context: operation context)
- `query.mode`: `fts` | `embedding` | `hybrid` (defaults to `hybrid` when both indexes exist)
- `query.top_k`: number of results to retrieve
### Basic example
```yaml
datasets:
transcripts:
@ -41,356 +37,158 @@ retrievers:
type: lancedb
dataset: transcripts
index_dir: workloads/medical/lance_index
build_index: if_missing # if_missing | always | never
index_types: ["fts", "embedding"]
build_index: if_missing # if_missing | always | never
index_types: ["fts", "embedding"] # or "hybrid"
fts:
index_phrase: "{{ input.src }}"
query_phrase: "{{ input.src[:1000] }}"
# How to index each row (context: input == dataset row)
index_phrase: >
{{ input.src }}
# How to build the query (map/filter/extract context: input; reduce: reduce_key & inputs)
query_phrase: >
{{ input.get("src","")[:1000] if input else "" }}
embedding:
model: openai/text-embedding-3-small
index_phrase: "{{ input.src }}"
query_phrase: "{{ input.src[:1000] }}"
# How to index each row for embedding (context: input == dataset row)
index_phrase: >
{{ input.src }}
# How to build the query text to embed (op context)
query_phrase: >
{{ input.get("src","")[:1000] if input else "" }}
query:
mode: hybrid
top_k: 8
```
## Multi-step pipelines with retrieval
Most pipelines have a single step, but you can define multiple steps where **the output of one step becomes the input (and retriever source) for the next**. This is powerful for patterns like:
1. Extract structured data from documents
2. Build a retrieval index on that extracted data
3. Use retrieval to find related items and process them
### Example: Extract facts, then find conflicts
```yaml
datasets:
articles:
type: file
path: workloads/wiki/articles.json
default_model: gpt-4o-mini
# Retriever indexes output of step 1 (extract_facts_step)
retrievers:
facts_index:
type: lancedb
dataset: extract_facts_step # References output of a pipeline step!
index_dir: workloads/wiki/facts_lance_index
build_index: if_missing
index_types: ["fts", "embedding"]
fts:
index_phrase: "{{ input.fact }} from {{ input.title }}"
query_phrase: "{{ input.fact }}"
embedding:
model: openai/text-embedding-3-small
index_phrase: "{{ input.fact }}"
query_phrase: "{{ input.fact }}"
query:
mode: hybrid
top_k: 5
operations:
- name: extract_facts
type: map
prompt: |
Extract factual claims from this article.
Article: {{ input.title }}
Text: {{ input.text }}
output:
schema:
facts: list[string]
- name: unnest_facts
type: unnest
unnest_key: facts
- name: find_conflicts
type: map
retriever: facts_index # Uses the retriever
prompt: |
Check if this fact conflicts with similar facts from other articles.
Current fact: {{ input.facts }} (from {{ input.title }})
Similar facts from other articles:
{{ retrieval_context }}
Return true only if there's a genuine contradiction.
output:
schema:
has_conflict: boolean
pipeline:
steps:
# Step 1: Extract and unnest facts
- name: extract_facts_step
input: articles
operations:
- extract_facts
- unnest_facts
# Step 2: Use retrieval to find conflicts
- name: find_conflicts_step
input: extract_facts_step # Input is output of step 1
operations:
- find_conflicts
output:
type: file
path: workloads/wiki/conflicts.json
intermediate_dir: workloads/wiki/intermediates
```
In this example:
- **Step 1** (`extract_facts_step`) extracts facts from articles
- The **retriever** (`facts_index`) indexes the output of step 1
- **Step 2** (`find_conflicts_step`) processes each fact, using retrieval to find similar facts from other articles
Notes:
- Index build is automatic the first time a retriever is used (when `build_index: if_missing`).
- `fts.index_phrase` and `embedding.index_phrase` are evaluated with `input` for each dataset record (here `input` is the dataset row).
- `fts.query_phrase` and `embedding.query_phrase` are evaluated with the operation context.
## Configuration reference
### Minimal example
Top-level (retrievers.<name>):
Here's the simplest possible retriever config (FTS only):
| Parameter | Type | Required | Default | Description |
| --- | --- | --- | --- | --- |
| type | string | yes | - | Must be `lancedb`. |
| dataset | string | yes | - | Name of an existing dataset in the pipeline config. |
| index_dir | string | yes | - | Filesystem path for the LanceDB database. Created if missing. |
| build_index | enum | no | `if_missing` | `if_missing` \| `always` \| `never`. Controls when to build the index. |
| index_types | list[string] \| string | yes | - | Which indexes to build: `fts`, `embedding`, or `"hybrid"` (interpreted as both). |
```yaml
retrievers:
my_search: # name can be anything you want
type: lancedb
dataset: my_dataset # must match a dataset name or pipeline step
index_dir: ./my_lance_index
index_types: ["fts"]
fts:
index_phrase: "{{ input.text }}" # what to index from each row
query_phrase: "{{ input.query }}" # what to search for at runtime
```
FTS section (retrievers.<name>.fts):
### Full example with all options
| Parameter | Type | Required | Default | Description |
| --- | --- | --- | --- | --- |
| index_phrase | jinja string | required if `fts` in index_types | - | How to index each dataset row. Context: `row`. |
| query_phrase | jinja string | recommended for FTS/hybrid queries | - | How to construct the FTS query. Context: op context (see below). |
```yaml
retrievers:
my_search:
type: lancedb
dataset: my_dataset
index_dir: ./my_lance_index
build_index: if_missing # optional, default: if_missing
index_types: ["fts", "embedding"] # can be ["fts"], ["embedding"], or both
fts:
index_phrase: "{{ input.text }}"
query_phrase: "{{ input.query }}"
embedding:
model: openai/text-embedding-3-small
index_phrase: "{{ input.text }}" # optional, falls back to fts.index_phrase
query_phrase: "{{ input.query }}"
query: # optional section
mode: hybrid # optional, auto-selects based on index_types
top_k: 10 # optional, default: 5
```
Embedding section (retrievers.<name>.embedding):
---
| Parameter | Type | Required | Default | Description |
| --- | --- | --- | --- | --- |
| model | string | required if `embedding` in index_types | - | Embedding model used for both index vectors and query vectors. |
| index_phrase | jinja string | no | falls back to `fts.index_phrase` if present | How to index each dataset row for embedding. Context: `row`. |
| query_phrase | jinja string | recommended for embedding/hybrid queries | - | How to construct the text to embed at query time. Context: op context. |
### Required fields
Query section (retrievers.<name>.query):
| Field | Description |
| --- | --- |
| `type` | Must be `lancedb` |
| `dataset` | Name of a dataset or pipeline step to index |
| `index_dir` | Path where LanceDB stores the index (created if missing) |
| `index_types` | List of index types: `["fts"]`, `["embedding"]`, or `["fts", "embedding"]` |
| Parameter | Type | Required | Default | Description |
| --- | --- | --- | --- | --- |
| mode | enum | no | auto | `fts` \| `embedding` \| `hybrid`. If omitted: `hybrid` when both indexes exist, else whichever index exists. |
| top_k | int | no | 5 | Number of results to return. |
---
### Optional fields
| Field | Default | Description |
| --- | --- | --- |
| `build_index` | `if_missing` | When to build: `if_missing`, `always`, or `never` |
| `query.mode` | auto | `fts`, `embedding`, or `hybrid`. Auto-selects based on what indexes exist |
| `query.top_k` | 5 | Number of results to return |
---
### The `fts` section
Required if `"fts"` is in `index_types`. Configures full-text search.
| Field | Required | Description |
| --- | --- | --- |
| `index_phrase` | yes | Jinja template: what text to index from each dataset row |
| `query_phrase` | yes | Jinja template: what text to search for at query time |
**Jinja variables available:**
| Template | Variables | When it runs |
| --- | --- | --- |
| `index_phrase` | `input` = the dataset row | Once per row when building the index |
| `query_phrase` | `input` = current item (map/filter/extract) | At query time for each item processed |
| `query_phrase` | `reduce_key`, `inputs` (reduce operations) | At query time for each group |
**Example - Medical knowledge base:**
```yaml
datasets:
drugs:
type: file
path: drugs.json # [{"name": "Aspirin", "uses": "pain, fever"}, ...]
patient_notes:
type: file
path: notes.json # [{"symptoms": "headache and fever"}, ...]
retrievers:
drug_lookup:
type: lancedb
dataset: drugs # index the drugs dataset
index_dir: ./drug_index
index_types: ["fts"]
fts:
index_phrase: "{{ input.name }}: {{ input.uses }}" # index: "Aspirin: pain, fever"
query_phrase: "{{ input.symptoms }}" # search with patient symptoms
operations:
- name: find_treatment
type: map
retriever: drug_lookup # attach the retriever
prompt: |
Patient symptoms: {{ input.symptoms }}
Relevant drugs from knowledge base:
{{ retrieval_context }}
Recommend a treatment.
output:
schema:
recommendation: string
```
When processing `{"symptoms": "headache and fever"}`:
1. `query_phrase` renders to `"headache and fever"`
2. FTS searches the index and finds `"Aspirin: pain, fever"` as a match
3. `{{ retrieval_context }}` in your prompt contains the matched results
---
### The `embedding` section
Required if `"embedding"` is in `index_types`. Configures vector/semantic search.
| Field | Required | Description |
| --- | --- | --- |
| `model` | yes | Embedding model, e.g. `openai/text-embedding-3-small` |
| `index_phrase` | no | Jinja template for text to embed. Falls back to `fts.index_phrase` |
| `query_phrase` | yes | Jinja template for query text to embed |
**Jinja variables:** Same as FTS section.
**Example - Semantic search:**
```yaml
retrievers:
semantic_docs:
type: lancedb
dataset: documentation
index_dir: ./docs_index
index_types: ["embedding"]
embedding:
model: openai/text-embedding-3-small
index_phrase: "{{ input.content }}"
query_phrase: "{{ input.question }}"
```
---
### The `query` section (optional)
Controls search behavior. You can omit this entire section.
| Field | Default | Description |
| --- | --- | --- |
| `mode` | auto | `fts`, `embedding`, or `hybrid`. Auto-selects `hybrid` if both indexes exist |
| `top_k` | 5 | Number of results to retrieve |
**Example - Override defaults:**
```yaml
retrievers:
my_search:
# ... other config ...
query:
mode: fts # force FTS even if embedding index exists
top_k: 20 # return more results
```
---
Notes:
- Hybrid search uses LanceDBs built-in reranking (RRF) by default.
- Jinja contexts:
- Map / Filter / Extract: `{"input": <current item>}`
- Reduce: `{"reduce_key": {...}, "inputs": [items]}`
- Jinja for indexing uses `{"input": <dataset row>}`
- Keep query phrases concise; slice long fields, e.g. `{{ input.src[:1000] }}`.
- The injected `retrieval_context` is truncated conservatively (~1000 chars per doc).
## Using a retriever in operations
Attach a retriever to any LLM operation (map, filter, reduce, extract) with `retriever: <retriever_name>`. The retrieved results are available as `{{ retrieval_context }}` in your prompt.
Attach the retriever to any LLM-powered op with `retriever: <name>`. Include `{{ retrieval_context }}` in your prompt or let DocETL prepend it automatically.
### Operation Parameters
When using a retriever with an operation, the following additional parameters are available:
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| retriever | string | - | Name of the retriever to use (must match a key in `retrievers`). |
| save_retriever_output | bool | false | If true, saves retrieved context to `_<operation_name>_retrieved_context` in output. |
### Map example
| retriever | string | - | Name of the retriever to use (must be defined in the `retrievers` section). |
| save_retriever_output | bool | false | If true, saves the retrieved context to `_<operation_name>_retrieved_context` in the output. Useful for debugging and verifying retrieval quality. |
### Map
```yaml
- name: tag_visit
type: map
retriever: medical_r
save_retriever_output: true
output:
schema:
tag: string
prompt: |
Classify this medical visit. Related context:
{{ retrieval_context }}
Transcript: {{ input.src }}
operations:
- name: tag_visit
type: map
retriever: medical_r
save_retriever_output: true # Optional: save retrieved context to output
output:
schema:
tag: string
confidence: float
prompt: |
Classify the medical visit. Use the extra context if helpful:
{{ retrieval_context }}
Transcript:
{{ input.src }}
```
### Filter example
When `save_retriever_output: true`, each output document will include a `_tag_visit_retrieved_context` field containing the exact context that was retrieved and used for that document.
### Extract
```yaml
- name: filter_relevant
type: filter
retriever: medical_r
prompt: |
Is this transcript relevant to medication counseling?
Context: {{ retrieval_context }}
Transcript: {{ input.src }}
output:
schema:
is_relevant: boolean
- name: extract_side_effects
type: extract
retriever: medical_r
document_keys: ["src"]
prompt: "Extract side effects mentioned in the text."
```
### Reduce example
When using reduce, the retrieval context is computed per group.
### Filter
```yaml
- name: summarize_by_medication
type: reduce
retriever: medical_r
reduce_key: medication
output:
schema:
summary: string
prompt: |
Summarize key points for medication '{{ reduce_key.medication }}'.
Related context: {{ retrieval_context }}
Inputs:
{% for item in inputs %}
- {{ item.src }}
{% endfor %}
- name: filter_relevant
type: filter
retriever: medical_r
prompt: "Is this transcript relevant to medication counseling? Return is_relevant: boolean."
output:
schema:
is_relevant: bool
_short_explanation: string
```
### Reduce
When using reduce, the retrieval context is computed per group. The Jinja context provides both `reduce_key` and `inputs`.
```yaml
- name: summarize_by_medication
type: reduce
retriever: medical_r
reduce_key: "medication"
output:
schema:
summary: string
prompt: |
Summarize key points for medication '{{ reduce_key.medication }}'.
Use the extra context if helpful:
{{ retrieval_context }}
Inputs:
{{ inputs }}
```
## Jinja template contexts
- Map / Filter / Extract: `{"input": current_item}`
- Reduce: `{"reduce_key": {...}, "inputs": [items]}`
## Token budget and truncation
- DocETL uses a conservative default to limit the size of `retrieval_context` by truncating each retrieved text to ~1000 characters.
## Troubleshooting
- No results: the retriever injects “No extra context available.” and continues.
- Index issues: set `build_index: always` to rebuild; ensure `index_dir` exists and is writable.
- Embeddings: DocETL uses its embedding router and caches results where possible.
- **No results**: the retriever injects "No extra context available." and continues.
- **Index issues**: set `build_index: always` to rebuild; ensure `index_dir` exists and is writable.
- **Token limits**: `retrieval_context` is truncated to ~1000 chars per retrieved doc.

View File

@ -267,6 +267,10 @@ class AgentCommunicator:
response = completion(
model=self.model,
messages=messages,
azure=True,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
response_format={"type": "json_object"}
)
@ -286,6 +290,10 @@ class AgentCommunicator:
response = completion(
model=self.model,
messages=messages + [{"role": "user", "content": request_msg}],
azure=True,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
response_format={"type": "json_object"}
)

View File

@ -14,7 +14,6 @@ nav:
- Getting Started:
- Overview: index.md
- Installation: installation.md
- Quick Start (Claude Code): quickstart-claude-code.md
- Quick Start Tutorial: tutorial.md
- Quick Start Tutorial (Python API): tutorial-pythonapi.md
- Best Practices: best-practices.md

View File

@ -1,6 +1,6 @@
[project]
name = "docetl"
version = "0.2.6"
version = "0.2.5"
description = "ETL with LLM operations."
readme = "README.md"
requires-python = ">=3.10"
@ -122,7 +122,7 @@ build-backend = "hatchling.build"
[tool.hatch.build]
packages = ["docetl"]
include = ["docetl/**", "server/**", "README.md", "LICENSE", ".claude/skills/**"]
include = ["docetl/**", "server/**", "README.md", "LICENSE"]
exclude = ["website/**/*"]
[tool.pytest.ini_options]

View File

@ -27,7 +27,6 @@ class OptimizeResult(BaseModel):
should_optimize: str | None = None
input_data: list[dict[str, Any]] | None = None
output_data: list[dict[str, Any]] | None = None
num_docs_analyzed: int | None = None
cost: float | None = None
error: str | None = None
created_at: datetime
@ -36,25 +35,4 @@ class OptimizeResult(BaseModel):
class OptimizeRequest(BaseModel):
yaml_config: str
step_name: str
op_name: str
class DecomposeRequest(BaseModel):
yaml_config: str
step_name: str
op_name: str
class DecomposeResult(BaseModel):
task_id: str
status: TaskStatus
decomposed_operations: list[dict[str, Any]] | None = None
winning_directive: str | None = None
candidates_evaluated: int | None = None
original_outputs: list[dict[str, Any]] | None = None
decomposed_outputs: list[dict[str, Any]] | None = None
comparison_rationale: str | None = None
cost: float | None = None
error: str | None = None
created_at: datetime
completed_at: datetime | None = None
op_name: str

View File

@ -2,23 +2,13 @@ from typing import Any
import uuid
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
from docetl.runner import DSLRunner
from docetl.optimizers.fast_should_optimize import FastShouldOptimizeAnalyzer
from docetl.optimizers.fast_decomposer import FastDecomposer
import asyncio
from asyncio import Task
from rich.logging import RichHandler
import logging
import yaml
from datetime import datetime, timedelta
from enum import Enum
from server.app.models import (
OptimizeResult,
TaskStatus,
OptimizeRequest,
PipelineRequest,
DecomposeRequest,
DecomposeResult,
)
from server.app.models import OptimizeResult, TaskStatus, OptimizeRequest, PipelineRequest
# Setup logging
FORMAT = "%(message)s"
@ -28,14 +18,10 @@ logging.basicConfig(
router = APIRouter()
# Task storage for optimize tasks
# Task storage
tasks: dict[str, OptimizeResult] = {}
asyncio_tasks: dict[str, Task] = {}
# Task storage for decompose tasks
decompose_tasks: dict[str, DecomposeResult] = {}
decompose_asyncio_tasks: dict[str, Task] = {}
# Configuration
COMPLETED_TASK_TTL = timedelta(hours=1)
@ -44,111 +30,50 @@ async def cleanup_old_tasks():
while True:
try:
current_time = datetime.now()
# Clean up optimize tasks
task_ids_to_remove = []
for task_id, task in tasks.items():
if (task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED] and
task.completed_at and
task.completed_at and
current_time - task.completed_at > COMPLETED_TASK_TTL):
task_ids_to_remove.append(task_id)
for task_id in task_ids_to_remove:
del tasks[task_id]
# Clean up decompose tasks
decompose_ids_to_remove = []
for task_id, task in decompose_tasks.items():
if (task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED] and
task.completed_at and
current_time - task.completed_at > COMPLETED_TASK_TTL):
decompose_ids_to_remove.append(task_id)
for task_id in decompose_ids_to_remove:
del decompose_tasks[task_id]
await asyncio.sleep(60)
except Exception as e:
logging.error(f"Error in cleanup task: {e}")
await asyncio.sleep(60)
async def run_optimization(task_id: str, yaml_config: str, step_name: str, op_name: str):
"""Execute the optimization task using fast single-LLM-call analysis."""
"""Execute the optimization task"""
try:
tasks[task_id].status = TaskStatus.PROCESSING
# yaml_config is a file path, not YAML content - read and parse the file
if yaml_config.endswith(".yaml") or yaml_config.endswith(".yml"):
with open(yaml_config, "r") as f:
config = yaml.safe_load(f)
else:
# Fallback: try parsing as YAML string
config = yaml.safe_load(yaml_config)
# Validate that we got a dict
if not isinstance(config, dict):
raise ValueError(
f"Invalid yaml_config: expected dict after parsing, got {type(config).__name__}. "
f"Input: {str(yaml_config)[:200]}"
)
# Get intermediate directory from config
intermediate_dir = (
config.get("pipeline", {})
.get("output", {})
.get("intermediate_dir")
)
if not intermediate_dir:
raise ValueError("No intermediate_dir configured in pipeline output")
# Find the operation config
op_config = None
for op in config.get("operations", []):
if op.get("name") == op_name:
op_config = op
break
if not op_config:
raise ValueError(f"Operation '{op_name}' not found in config")
# Get optimizer settings from config
optimizer_model = (
config.get("optimizer_config", {})
.get("rewrite_agent_model", "gpt-5.1")
)
litellm_kwargs = (
config.get("optimizer_config", {})
.get("litellm_kwargs", {})
)
# Create analyzer and run in thread pool
analyzer = FastShouldOptimizeAnalyzer(
intermediate_dir=intermediate_dir,
optimizer_model=optimizer_model,
litellm_kwargs=litellm_kwargs,
)
should_optimize, output_data, num_docs_analyzed, cost = await asyncio.to_thread(
analyzer.analyze,
op_config,
# Run the actual optimization in a separate thread to not block
runner = DSLRunner.from_yaml(yaml_config)
should_optimize, input_data, output_data, cost = await asyncio.to_thread(
runner.should_optimize,
step_name,
op_name,
)
# Update task result
tasks[task_id].status = TaskStatus.COMPLETED
tasks[task_id].should_optimize = should_optimize
tasks[task_id].input_data = None # We don't have input_data in fast mode
tasks[task_id].input_data = input_data
tasks[task_id].output_data = output_data
tasks[task_id].num_docs_analyzed = num_docs_analyzed
tasks[task_id].cost = cost
tasks[task_id].completed_at = datetime.now()
except asyncio.CancelledError:
runner.is_cancelled = True
tasks[task_id].status = TaskStatus.CANCELLED
tasks[task_id].completed_at = datetime.now()
raise
except Exception as e:
import traceback
error_traceback = traceback.format_exc()
@ -156,10 +81,11 @@ async def run_optimization(task_id: str, yaml_config: str, step_name: str, op_na
tasks[task_id].error = f"{str(e)}\n{error_traceback}"
tasks[task_id].completed_at = datetime.now()
raise
finally:
if task_id in asyncio_tasks:
del asyncio_tasks[task_id]
runner.reset_env()
@router.on_event("startup")
async def startup_event():
@ -227,140 +153,6 @@ async def cancel_optimize_task(task_id: str):
return {"message": "Task cancelled successfully"}
# ============================================================================
# Decompose endpoints
# ============================================================================
async def run_decomposition(task_id: str, yaml_config: str, step_name: str, op_name: str):
"""Execute the decomposition task using fast directive-based approach."""
try:
decompose_tasks[task_id].status = TaskStatus.PROCESSING
# yaml_config is a file path
if not (yaml_config.endswith(".yaml") or yaml_config.endswith(".yml")):
raise ValueError("yaml_config must be a path to a YAML file")
# Get optimizer settings from config
with open(yaml_config, "r") as f:
config = yaml.safe_load(f)
optimizer_model = (
config.get("optimizer_config", {})
.get("rewrite_agent_model", "gpt-5.1")
)
litellm_kwargs = (
config.get("optimizer_config", {})
.get("litellm_kwargs", {})
)
# Create decomposer and run in thread pool
decomposer = FastDecomposer(
yaml_config_path=yaml_config,
optimizer_model=optimizer_model,
sample_size=5,
litellm_kwargs=litellm_kwargs,
)
result = await asyncio.to_thread(
decomposer.decompose,
step_name,
op_name,
)
# Update task result
decompose_tasks[task_id].status = TaskStatus.COMPLETED
decompose_tasks[task_id].decomposed_operations = result["decomposed_ops"]
decompose_tasks[task_id].winning_directive = result["winning_directive"]
decompose_tasks[task_id].candidates_evaluated = result["candidates_evaluated"]
decompose_tasks[task_id].original_outputs = result["original_outputs"]
decompose_tasks[task_id].decomposed_outputs = result["decomposed_outputs"]
decompose_tasks[task_id].comparison_rationale = result["comparison_rationale"]
decompose_tasks[task_id].cost = result["cost"]
decompose_tasks[task_id].completed_at = datetime.now()
except asyncio.CancelledError:
decompose_tasks[task_id].status = TaskStatus.CANCELLED
decompose_tasks[task_id].completed_at = datetime.now()
raise
except Exception as e:
import traceback
error_traceback = traceback.format_exc()
decompose_tasks[task_id].status = TaskStatus.FAILED
decompose_tasks[task_id].error = f"{str(e)}\n{error_traceback}"
decompose_tasks[task_id].completed_at = datetime.now()
raise
finally:
if task_id in decompose_asyncio_tasks:
del decompose_asyncio_tasks[task_id]
@router.post("/decompose", status_code=202)
async def submit_decompose_task(request: DecomposeRequest):
"""Submit a new decomposition task"""
task_id = str(uuid.uuid4())
# Create task record
decompose_tasks[task_id] = DecomposeResult(
task_id=task_id,
status=TaskStatus.PENDING,
created_at=datetime.now()
)
# Create and store the asyncio task
task = asyncio.create_task(
run_decomposition(
task_id,
request.yaml_config,
request.step_name,
request.op_name
)
)
decompose_asyncio_tasks[task_id] = task
return {"task_id": task_id}
@router.get("/decompose/{task_id}")
async def get_decompose_status(task_id: str) -> DecomposeResult:
"""Get the current status of a decomposition task"""
if task_id not in decompose_tasks:
raise HTTPException(
status_code=404,
detail="Task not found or has been cleaned up"
)
return decompose_tasks[task_id]
@router.post("/decompose/{task_id}/cancel")
async def cancel_decompose_task(task_id: str):
"""Cancel a running decomposition task"""
if task_id not in decompose_tasks:
raise HTTPException(
status_code=404,
detail="Task not found or has been cleaned up"
)
if task_id not in decompose_asyncio_tasks:
raise HTTPException(
status_code=400,
detail="Task already finished or cannot be cancelled"
)
asyncio_task = decompose_asyncio_tasks[task_id]
asyncio_task.cancel()
try:
await asyncio_task
except asyncio.CancelledError:
pass
return {"message": "Decomposition task cancelled successfully"}
# Keep the original run_pipeline endpoint
@router.post("/run_pipeline")
def run_pipeline(request: PipelineRequest) -> dict[str, Any]:
@ -377,12 +169,6 @@ def run_pipeline(request: PipelineRequest) -> dict[str, Any]:
@router.websocket("/ws/run_pipeline/{client_id}")
async def websocket_run_pipeline(websocket: WebSocket, client_id: str):
"""
WebSocket endpoint for running pipelines with real-time output streaming.
Note: The old 'optimize' flag for full pipeline optimization has been removed.
Use the /decompose endpoint for fast operation decomposition instead.
"""
await websocket.accept()
runner = None
try:
@ -392,8 +178,19 @@ async def websocket_run_pipeline(websocket: WebSocket, client_id: str):
if config.get("clear_intermediate", False):
runner.clear_intermediate()
async def run_pipeline():
return await asyncio.to_thread(runner.load_run_save)
if config.get("optimize", False):
logging.info(f"Optimizing pipeline with model {config.get('optimizer_model', 'gpt-4o')}")
# Set the runner config to the optimizer config
runner.config["optimizer_config"]["rewrite_agent_model"] = config.get("optimizer_model", "gpt-4o")
runner.config["optimizer_config"]["judge_agent_model"] = config.get("optimizer_model", "gpt-4o-mini")
async def run_pipeline():
return await asyncio.to_thread(runner.optimize, return_pipeline=False)
else:
async def run_pipeline():
return await asyncio.to_thread(runner.load_run_save)
pipeline_task = asyncio.create_task(run_pipeline())
@ -401,6 +198,18 @@ async def websocket_run_pipeline(websocket: WebSocket, client_id: str):
console_output = runner.console.file.getvalue()
await websocket.send_json({"type": "output", "data": console_output})
if config.get("optimize", False):
optimizer_progress = runner.console.get_optimizer_progress()
rationale = runner.console.optimizer_rationale
await websocket.send_json({
"type": "optimizer_progress",
"status": optimizer_progress[0],
"progress": optimizer_progress[1],
"rationale": rationale[1] if rationale is not None else "",
"should_optimize": rationale[0] if rationale is not None else False,
"validator_prompt": rationale[2] if rationale is not None else ""
})
# Check for incoming messages from the user
try:
user_message = await asyncio.wait_for(
@ -410,7 +219,8 @@ async def websocket_run_pipeline(websocket: WebSocket, client_id: str):
if user_message == "kill":
runner.console.log("Stopping process...")
runner.is_cancelled = True
await websocket.send_json({
"type": "error",
"message": "Process stopped by user request"
@ -440,16 +250,41 @@ async def websocket_run_pipeline(websocket: WebSocket, client_id: str):
# Sleep for a short duration to ensure all output is captured
await asyncio.sleep(3)
await websocket.send_json(
{
"type": "result",
"data": {
"message": "Pipeline executed successfully",
"cost": result,
"yaml_config": config["yaml_config"],
},
}
)
# If optimize is true, send back the optimized operations
if config.get("optimize", False):
optimized_config, cost = result
# Send the operations back in order
new_pipeline_steps = optimized_config["pipeline"]["steps"]
new_pipeline_op_name_to_op_map = {op["name"]: op for op in optimized_config["operations"]}
new_ops_in_order = []
for new_step in new_pipeline_steps:
for op in new_step.get("operations", []):
if op not in new_ops_in_order:
new_ops_in_order.append(new_pipeline_op_name_to_op_map[op])
await websocket.send_json(
{
"type": "result",
"data": {
"message": "Pipeline executed successfully",
"cost": cost,
"optimized_ops": new_ops_in_order,
"yaml_config": config["yaml_config"],
},
}
)
else:
await websocket.send_json(
{
"type": "result",
"data": {
"message": "Pipeline executed successfully",
"cost": result,
"yaml_config": config["yaml_config"],
},
}
)
except WebSocketDisconnect:
if runner is not None:
runner.reset_env()
@ -464,159 +299,3 @@ async def websocket_run_pipeline(websocket: WebSocket, client_id: str):
if runner is not None:
runner.reset_env()
await websocket.close()
@router.websocket("/ws/decompose/{client_id}")
async def websocket_decompose(websocket: WebSocket, client_id: str):
"""
WebSocket endpoint for fast operation decomposition with real-time output streaming.
Expects JSON message with:
- yaml_config: Path to the pipeline YAML config file
- step_name: Name of the pipeline step
- op_name: Name of the operation to decompose
"""
await websocket.accept()
decomposer = None
try:
config = await websocket.receive_json()
yaml_config = config["yaml_config"]
step_name = config["step_name"]
op_name = config["op_name"]
# Validate yaml_config is a path
if not (yaml_config.endswith(".yaml") or yaml_config.endswith(".yml")):
raise ValueError("yaml_config must be a path to a YAML file")
# Get optimizer settings from config
with open(yaml_config, "r") as f:
pipeline_config = yaml.safe_load(f)
optimizer_model = (
pipeline_config.get("optimizer_config", {})
.get("rewrite_agent_model", "gpt-5.1")
)
litellm_kwargs = (
pipeline_config.get("optimizer_config", {})
.get("litellm_kwargs", {})
)
# Create a ThreadSafeConsole for streaming output
from docetl.console import ThreadSafeConsole
console = ThreadSafeConsole(
force_terminal=True,
soft_wrap=True,
highlight=False,
log_path=False,
color_system="truecolor",
width=120,
style="bright_white on black",
record=True,
)
# Create decomposer with the console
decomposer = FastDecomposer(
yaml_config_path=yaml_config,
optimizer_model=optimizer_model,
sample_size=5,
litellm_kwargs=litellm_kwargs,
console=console,
)
# Run decomposition in a thread
async def run_decomposition():
return await asyncio.to_thread(
decomposer.decompose,
step_name,
op_name,
)
decompose_task = asyncio.create_task(run_decomposition())
# Stream console output while decomposition runs
accumulated_output = ""
while not decompose_task.done():
# get_output() processes carriage returns and clears the buffer
new_output = console.get_output()
if new_output:
# Handle spinner updates: if new output doesn't start with newline
# and accumulated doesn't end with newline, replace the last line
if accumulated_output and not accumulated_output.endswith('\n') and not new_output.startswith('\n'):
# Find the last newline in accumulated output
last_newline = accumulated_output.rfind('\n')
if last_newline >= 0:
# Replace everything after the last newline
accumulated_output = accumulated_output[:last_newline + 1] + new_output
else:
# No newline in accumulated, replace everything
accumulated_output = new_output
else:
accumulated_output += new_output
await websocket.send_json({"type": "output", "data": accumulated_output})
# Check for kill message
try:
user_message = await asyncio.wait_for(
websocket.receive_json(), timeout=0.1
)
if user_message == "kill":
console.log("[red]Stopping decomposition...[/red]")
decompose_task.cancel()
await websocket.send_json({
"type": "error",
"message": "Decomposition stopped by user request"
})
raise Exception("Process stopped by user request")
except asyncio.TimeoutError:
pass
except asyncio.CancelledError:
await websocket.send_json({
"type": "error",
"message": "Decomposition stopped by user request"
})
raise
await asyncio.sleep(0.5)
# Get final result
result = await decompose_task
# Send any remaining console output
final_output = console.get_output()
if final_output:
accumulated_output += final_output
await websocket.send_json({"type": "output", "data": accumulated_output})
await asyncio.sleep(1)
# Send the result
await websocket.send_json({
"type": "result",
"data": {
"decomposed_operations": result["decomposed_ops"],
"winning_directive": result["winning_directive"],
"candidates_evaluated": result["candidates_evaluated"],
"original_outputs": result["original_outputs"],
"decomposed_outputs": result["decomposed_outputs"],
"comparison_rationale": result["comparison_rationale"],
"cost": result["cost"],
},
})
except WebSocketDisconnect:
print(f"Decompose client {client_id} disconnected")
except Exception as e:
import traceback
error_traceback = traceback.format_exc()
print(f"Decompose error:\n{error_traceback}")
try:
await websocket.send_json({
"type": "error",
"data": str(e),
"traceback": error_traceback
})
except Exception:
pass
finally:
await websocket.close()

View File

@ -1,5 +1,4 @@
import pytest
from docetl.operations.extract import ExtractOperation
from docetl.operations.filter import FilterOperation
from docetl.operations.unnest import UnnestOperation
from docetl.operations.equijoin import EquijoinOperation
@ -34,29 +33,6 @@ def filter_sample_data():
{"text": "Brief.", "word_count": 1},
]
@pytest.fixture
def extract_config_missing_key():
return {
"name": "missing_key_extract",
"type": "extract",
"prompt": "Identify portions of the document that contain numeric data: {{ input.text }}",
"document_keys": ["missing_key"],
"model": "gpt-4o-mini",
"skip_on_error": True,
}
@pytest.fixture
def extract_config_limit():
return {
"name": "limit_extract",
"type": "extract",
"prompt": "Extract the main topic from this text: {{ input.text }}",
"document_keys": ["text"],
"model": "gpt-4o-mini",
"limit": 1,
}
def test_filter_operation(
filter_config, default_model, max_threads, filter_sample_data, runner
@ -68,25 +44,6 @@ def test_filter_operation(
assert all(len(result["text"].split()) > 3 for result in results)
def test_filter_operation_limit_counts_true_outputs(
filter_config, default_model, max_threads, runner
):
filter_config["limit"] = 1
filter_config["bypass_cache"] = True
sample_data = [
{"text": "Tiny.", "word_count": 1},
{"text": "This example clearly exceeds three words.", "word_count": 7},
{"text": "Another sufficiently long sentence lives here.", "word_count": 6},
]
operation = FilterOperation(runner, filter_config, default_model, max_threads)
results, cost = operation.execute(sample_data)
assert len(results) == 1
assert results[0]["text"] == sample_data[1]["text"]
assert cost >= 0
def test_filter_operation_empty_input(
filter_config, default_model, max_threads, runner
):
@ -97,22 +54,6 @@ def test_filter_operation_empty_input(
assert cost == 0
def test_extract_operation_limit(
extract_config_limit, default_model, max_threads, runner
):
operation = ExtractOperation(
runner, extract_config_limit, default_model, max_threads
)
sample_data = [
{"id": 1, "text": "Document about contracts"},
{"id": 2, "text": "Document about finance"},
]
results, cost = operation.execute(sample_data)
assert len(results) == 1
assert results[0]["id"] == 1
# Unnest Operation Tests
@pytest.fixture
def unnest_config():

View File

@ -194,21 +194,6 @@ def test_map_operation_with_batching(
any(vs in result["sentiment"] for vs in valid_sentiments) for result in results
)
def test_map_operation_limit(
map_config,
default_model,
max_threads,
map_sample_data,
runner,
):
map_config["limit"] = 2
map_config["bypass_cache"] = True
operation = MapOperation(runner, map_config, default_model, max_threads)
results, cost = operation.execute(map_sample_data)
assert len(results) == 2
assert cost >= 0
def test_map_operation_with_empty_input(
map_config_with_batching, default_model, max_threads, runner

View File

@ -54,19 +54,6 @@ def test_reduce_operation(
assert cost > 0
def test_reduce_operation_limit(
reduce_config, default_model, max_threads, reduce_sample_data, runner
):
reduce_config["limit"] = 2
reduce_config["bypass_cache"] = True
operation = ReduceOperation(runner, reduce_config, default_model, max_threads)
results, cost = operation.execute(reduce_sample_data)
assert len(results) == 2
assert {result["group"] for result in results}.issubset({"A", "C"})
assert cost > 0
def test_reduce_operation_with_all_key(
reduce_config, default_model, max_threads, reduce_sample_data, runner
):

View File

@ -6,7 +6,7 @@ Simple apply tests for directive testing - just ensure apply() doesn't crash.
# Simple apply tests - no pytest needed
from docetl.reasoning_optimizer.directives import (
ChainingDirective,
ChainingDirective,
GleaningDirective,
ReduceGleaningDirective,
ReduceChainingDirective,
@ -14,6 +14,7 @@ from docetl.reasoning_optimizer.directives import (
OperatorFusionDirective,
DocSummarizationDirective,
IsolatingSubtasksDirective,
DocCompressionDirective,
DeterministicDocCompressionDirective,
DocumentChunkingDirective,
DocumentChunkingTopKDirective,
@ -21,12 +22,8 @@ from docetl.reasoning_optimizer.directives import (
TakeHeadTailDirective,
ClarifyInstructionsDirective,
SwapWithCodeDirective,
HierarchicalReduceDirective,
MapToMapResolveReduceDirective,
MapResolveToMapWithCategoriesDirective,
HierarchicalReduceDirective
)
# DocCompressionDirective is commented out in __init__.py, import directly
from docetl.reasoning_optimizer.directives.doc_compression import DocCompressionDirective
def test_chaining_apply():
@ -1015,237 +1012,6 @@ def test_cascade_filtering_apply():
assert "high-quality research paper" in result[5]["prompt"]
def test_map_to_map_resolve_reduce_apply():
"""Test that map_to_map_resolve_reduce apply doesn't crash"""
directive = MapToMapResolveReduceDirective()
# Pipeline with map followed by reduce
ops_list = [
{
"name": "extract_companies",
"type": "map",
"prompt": "Extract company name from: {{ input.article }}",
"model": "gpt-4o-mini",
"output": {"schema": {"company_name": "string"}},
},
{
"name": "aggregate_by_sector",
"type": "reduce",
"reduce_key": "sector",
"prompt": "List companies:\n{% for input in inputs %}\n- {{ input.company_name }}\n{% endfor %}",
"model": "gpt-4o-mini",
"output": {"schema": {"companies": "list[str]"}},
},
]
from docetl.reasoning_optimizer.instantiate_schemas import (
MapToMapResolveReduceInstantiateSchema,
)
rewrite = MapToMapResolveReduceInstantiateSchema(
resolve_name="normalize_company_names",
comparison_prompt="""Are these two company names referring to the same company?
Company 1: {{ input1.company_name }}
Company 2: {{ input2.company_name }}
Consider variations like abbreviations (IBM vs International Business Machines).""",
resolution_prompt="""Given these variations of a company name:
{% for input in inputs %}
- {{ input.company_name }}
{% endfor %}
Return the canonical/official company name.""",
blocking_conditions=[
"input1['company_name'][:3].lower() == input2['company_name'][:3].lower()",
"input1['company_name'].split()[0].lower() == input2['company_name'].split()[0].lower()",
],
blocking_keys=["sector", "company_name"], # Must include reduce_key (sector)
limit_comparisons=1000,
)
result = directive.apply(
"gpt-4o-mini", ops_list, "extract_companies", "aggregate_by_sector", rewrite
)
assert isinstance(result, list)
assert len(result) == 3 # map + resolve + reduce
# Check the resolve operation was inserted correctly
assert result[0]["name"] == "extract_companies"
assert result[0]["type"] == "map"
assert result[1]["name"] == "normalize_company_names"
assert result[1]["type"] == "resolve"
assert "comparison_prompt" in result[1]
assert "resolution_prompt" in result[1]
assert "blocking_conditions" in result[1]
assert "blocking_keys" in result[1]
assert result[1]["blocking_keys"] == ["sector", "company_name"]
assert result[1]["limit_comparisons"] == 1000
assert "input1" in result[1]["comparison_prompt"]
assert "input2" in result[1]["comparison_prompt"]
assert "inputs" in result[1]["resolution_prompt"]
# Output schema should be derived from reduce_key
assert result[1]["output"]["schema"] == {"sector": "string"}
assert result[2]["name"] == "aggregate_by_sector"
assert result[2]["type"] == "reduce"
def test_map_to_map_resolve_reduce_apply_with_multiple_reduce_keys():
"""Test map_to_map_resolve_reduce with multiple reduce_keys"""
directive = MapToMapResolveReduceDirective()
ops_list = [
{
"name": "extract_products",
"type": "map",
"prompt": "Extract product info: {{ input.description }}",
"model": "gpt-4o-mini",
"output": {"schema": {"product_name": "string", "category": "string"}},
},
{
"name": "aggregate_products",
"type": "reduce",
"reduce_key": ["brand", "region"], # Multiple reduce keys
"prompt": "List products:\n{% for input in inputs %}{{ input.product_name }}{% endfor %}",
"model": "gpt-4o-mini",
"output": {"schema": {"products": "list[str]"}},
},
]
from docetl.reasoning_optimizer.instantiate_schemas import (
MapToMapResolveReduceInstantiateSchema,
)
rewrite = MapToMapResolveReduceInstantiateSchema(
resolve_name="normalize_products",
comparison_prompt="Same product? {{ input1.product_name }} vs {{ input2.product_name }}",
resolution_prompt="Normalize: {% for input in inputs %}{{ input.product_name }}{% endfor %}",
blocking_conditions=[
"input1['category'] == input2['category']",
],
blocking_keys=["brand", "region", "product_name"], # Must include reduce_keys
limit_comparisons=500,
)
result = directive.apply(
"gpt-4o-mini", ops_list, "extract_products", "aggregate_products", rewrite
)
assert result[1]["type"] == "resolve"
assert "blocking_keys" in result[1]
assert result[1]["blocking_keys"] == ["brand", "region", "product_name"]
# Output schema should include all reduce_keys
assert result[1]["output"]["schema"] == {"brand": "string", "region": "string"}
def test_map_resolve_to_map_with_categories_apply():
"""Test that map_resolve_to_map_with_categories apply doesn't crash"""
directive = MapResolveToMapWithCategoriesDirective()
# Pipeline with map followed by resolve
ops_list = [
{
"name": "extract_sentiment",
"type": "map",
"prompt": "What is the sentiment? {{ input.review }}",
"model": "gpt-4o-mini",
"output": {"schema": {"sentiment": "string"}},
},
{
"name": "normalize_sentiment",
"type": "resolve",
"comparison_prompt": "Same sentiment? {{ input1.sentiment }} vs {{ input2.sentiment }}",
"resolution_prompt": "Normalize: {% for input in inputs %}{{ input.sentiment }}{% endfor %}",
"model": "gpt-4o-mini",
"output": {"schema": {"sentiment": "string"}},
},
]
from docetl.reasoning_optimizer.instantiate_schemas import (
MapResolveToMapWithCategoriesInstantiateSchema,
)
rewrite = MapResolveToMapWithCategoriesInstantiateSchema(
categories=["Positive", "Negative", "Neutral", "Mixed"],
category_key="sentiment",
new_prompt="""Classify the sentiment of this review into one of these categories:
- Positive: Clearly positive sentiment
- Negative: Clearly negative sentiment
- Neutral: No strong sentiment
- Mixed: Both positive and negative
- None of the above
Review: {{ input.review }}
Return exactly one category.""",
include_none_of_above=True,
)
result = directive.apply(
"gpt-4o-mini", ops_list, "extract_sentiment", "normalize_sentiment", rewrite
)
assert isinstance(result, list)
assert len(result) == 1 # map only, resolve removed
# Check the map operation was modified correctly
assert result[0]["name"] == "extract_sentiment"
assert result[0]["type"] == "map"
assert "Positive" in result[0]["prompt"]
assert "Negative" in result[0]["prompt"]
assert "None of the above" in result[0]["prompt"]
# Check validation was added
assert "validate" in result[0]
assert len(result[0]["validate"]) == 1
assert "Positive" in result[0]["validate"][0]
assert "None of the above" in result[0]["validate"][0]
def test_map_resolve_to_map_with_categories_no_none_of_above():
"""Test map_resolve_to_map_with_categories without 'None of the above' option"""
directive = MapResolveToMapWithCategoriesDirective()
ops_list = [
{
"name": "classify_type",
"type": "map",
"prompt": "What type is this? {{ input.text }}",
"model": "gpt-4o-mini",
"output": {"schema": {"item_type": "string"}},
},
{
"name": "normalize_type",
"type": "resolve",
"comparison_prompt": "Same type? {{ input1.item_type }} vs {{ input2.item_type }}",
"resolution_prompt": "Normalize: {% for input in inputs %}{{ input.item_type }}{% endfor %}",
"model": "gpt-4o-mini",
"output": {"schema": {"item_type": "string"}},
},
]
from docetl.reasoning_optimizer.instantiate_schemas import (
MapResolveToMapWithCategoriesInstantiateSchema,
)
rewrite = MapResolveToMapWithCategoriesInstantiateSchema(
categories=["TypeA", "TypeB", "TypeC"],
category_key="item_type",
new_prompt="""Classify into: TypeA, TypeB, or TypeC
Text: {{ input.text }}""",
include_none_of_above=False,
)
result = directive.apply(
"gpt-4o-mini", ops_list, "classify_type", "normalize_type", rewrite
)
assert len(result) == 1
# Check validation does NOT include 'None of the above'
assert "validate" in result[0]
assert "None of the above" not in result[0]["validate"][0]
assert "TypeA" in result[0]["validate"][0]
assert "TypeB" in result[0]["validate"][0]
assert "TypeC" in result[0]["validate"][0]
if __name__ == "__main__":
# Run all tests
test_chaining_apply()
@ -1312,16 +1078,4 @@ if __name__ == "__main__":
test_cascade_filtering_apply()
print("✅ Cascade filtering apply test passed")
test_map_to_map_resolve_reduce_apply()
print("✅ Map to map resolve reduce apply test passed")
test_map_to_map_resolve_reduce_apply_with_multiple_reduce_keys()
print("✅ Map to map resolve reduce with multiple reduce keys apply test passed")
test_map_resolve_to_map_with_categories_apply()
print("✅ Map resolve to map with categories apply test passed")
test_map_resolve_to_map_with_categories_no_none_of_above()
print("✅ Map resolve to map with categories (no none of above) apply test passed")
print("\n🎉 All directive apply tests passed!")

View File

@ -10,13 +10,14 @@ from typing import Dict, List
from datetime import datetime
from docetl.reasoning_optimizer.directives import (
ChainingDirective,
ChainingDirective,
GleaningDirective,
ReduceGleaningDirective,
ReduceChainingDirective,
ChangeModelDirective,
DocSummarizationDirective,
IsolatingSubtasksDirective,
DocCompressionDirective,
DeterministicDocCompressionDirective,
OperatorFusionDirective,
DocumentChunkingDirective,
@ -27,12 +28,8 @@ from docetl.reasoning_optimizer.directives import (
SwapWithCodeDirective,
HierarchicalReduceDirective,
CascadeFilteringDirective,
MapToMapResolveReduceDirective,
MapResolveToMapWithCategoriesDirective,
TestResult,
TestResult
)
# DocCompressionDirective is commented out in __init__.py, import directly
from docetl.reasoning_optimizer.directives.doc_compression import DocCompressionDirective
def run_all_directive_tests(agent_llm: str = "gpt-4.1") -> Dict[str, List[TestResult]]:
"""
@ -71,8 +68,6 @@ def run_all_directive_tests(agent_llm: str = "gpt-4.1") -> Dict[str, List[TestRe
SwapWithCodeDirective(),
HierarchicalReduceDirective(),
CascadeFilteringDirective(),
MapToMapResolveReduceDirective(),
MapResolveToMapWithCategoriesDirective(),
]
all_results = {}
@ -181,8 +176,6 @@ def run_specific_directive_test(directive_name: str, agent_llm: str = "gpt-4o-mi
"clarify_instructions": ClarifyInstructionsDirective(),
"swap_with_code": SwapWithCodeDirective(),
"cascade_filtering": CascadeFilteringDirective(),
"map_to_map_resolve_reduce": MapToMapResolveReduceDirective(),
"map_resolve_to_map_with_categories": MapResolveToMapWithCategoriesDirective(),
}
if directive_name.lower() not in directive_map:

View File

@ -698,7 +698,7 @@ wheels = [
[[package]]
name = "docetl"
version = "0.2.6"
version = "0.2.5"
source = { editable = "." }
dependencies = [
{ name = "asteval" },

View File

@ -5,9 +5,4 @@ export const API_ROUTES = {
CANCEL: (taskId: string) =>
`/api/shouldOptimize?taskId=${taskId}&cancel=true`,
},
DECOMPOSE: {
SUBMIT: "/api/decompose",
STATUS: (taskId: string) => `/api/decompose?taskId=${taskId}`,
CANCEL: (taskId: string) => `/api/decompose?taskId=${taskId}&cancel=true`,
},
} as const;

View File

@ -1,79 +0,0 @@
import { NextRequest, NextResponse } from "next/server";
const FASTAPI_URL = `${
process.env.NEXT_PUBLIC_BACKEND_HTTPS ? "https" : "http"
}://${process.env.NEXT_PUBLIC_BACKEND_HOST}:${
process.env.NEXT_PUBLIC_BACKEND_PORT
}`;
// Helper to handle errors consistently
const handleError = (error: unknown, status = 500) => {
const message =
error instanceof Error ? error.message : "Internal server error";
return NextResponse.json({ error: message }, { status });
};
// Helper to proxy requests to FastAPI
async function proxyRequest(path: string, init?: RequestInit) {
const response = await fetch(`${FASTAPI_URL}${path}`, {
...init,
headers: {
"Content-Type": "application/json",
...init?.headers,
},
});
if (!response.ok) {
const error = await response.text();
throw new Error(`FastAPI server error: ${error}`);
}
return response.json();
}
export async function POST(request: NextRequest): Promise<NextResponse> {
try {
// Extract task ID from the URL if it exists
const taskId = request.nextUrl.searchParams.get("taskId");
const isCancel = request.nextUrl.searchParams.get("cancel") === "true";
// Handle different POST scenarios
if (taskId) {
if (isCancel) {
// Cancel task
const data = await proxyRequest(`/decompose/${taskId}/cancel`, {
method: "POST",
});
return NextResponse.json(data);
}
// Invalid request with taskId but no cancel
return handleError(new Error("Invalid request"), 400);
}
// Submit new task
const body = await request.json();
const data = await proxyRequest("/decompose", {
method: "POST",
body: JSON.stringify(body),
});
return NextResponse.json(data, { status: 202 });
} catch (error) {
return handleError(error);
}
}
export async function GET(request: NextRequest): Promise<NextResponse> {
try {
// Extract task ID from the URL
const taskId = request.nextUrl.searchParams.get("taskId");
if (!taskId) {
return handleError(new Error("Task ID is required"), 400);
}
const data = await proxyRequest(`/decompose/${taskId}`);
return NextResponse.json(data);
} catch (error) {
return handleError(error);
}
}

View File

@ -33,77 +33,6 @@ export function getNamespaceDir(homeDir: string, namespace: string) {
return path.join(homeDir, ".docetl", namespace);
}
/**
* Safely parse a JSON value - returns the parsed value if it's a string,
* or the original value if it's already an object.
* @param value - The value to parse
* @param fieldName - Name of the field (for error messages)
* @returns The parsed value
*/
function safeParseJSON(
value: unknown,
fieldName: string
): Record<string, unknown> | unknown[] {
if (value === null || value === undefined) {
throw new Error(`Field '${fieldName}' is null or undefined`);
}
// If it's already an object or array, return as-is
if (typeof value === "object") {
return value as Record<string, unknown> | unknown[];
}
// If it's a string, try to parse it
if (typeof value === "string") {
try {
return JSON.parse(value);
} catch (e) {
throw new Error(
`Field '${fieldName}' contains invalid JSON: ${value.slice(0, 100)}${value.length > 100 ? "..." : ""}`
);
}
}
throw new Error(
`Field '${fieldName}' has unexpected type '${typeof value}': ${String(value).slice(0, 100)}`
);
}
/**
* Recursively sanitize an object for YAML serialization.
* Converts any nested [object Object] strings to proper objects.
*/
function sanitizeForYaml(obj: unknown, path: string = ""): unknown {
if (obj === null || obj === undefined) {
return obj;
}
if (typeof obj === "string") {
// Check for [object Object] which indicates a serialization issue
if (obj === "[object Object]") {
console.warn(
`Warning: Found '[object Object]' string at path '${path}'. This indicates a serialization issue.`
);
return {}; // Return empty object as fallback
}
return obj;
}
if (Array.isArray(obj)) {
return obj.map((item, idx) => sanitizeForYaml(item, `${path}[${idx}]`));
}
if (typeof obj === "object") {
const result: Record<string, unknown> = {};
for (const [key, value] of Object.entries(obj)) {
result[key] = sanitizeForYaml(value, path ? `${path}.${key}` : key);
}
return result;
}
return obj;
}
export function generatePipelineConfig(
namespace: string,
default_model: string,
@ -153,18 +82,9 @@ export function generatePipelineConfig(
const litellm_completion_kwargs =
op.otherKwargs?.litellm_completion_kwargs;
if (litellm_completion_kwargs) {
try {
op.otherKwargs.litellm_completion_kwargs = safeParseJSON(
litellm_completion_kwargs,
`${op.name}.litellm_completion_kwargs`
);
} catch (e) {
console.warn(
`Warning: Could not parse litellm_completion_kwargs for operation '${op.name}':`,
e
);
// Keep the original value if parsing fails
}
op.otherKwargs.litellm_completion_kwargs = JSON.parse(
litellm_completion_kwargs
);
}
const newOp: Record<string, unknown> = {
@ -255,14 +175,10 @@ export function generatePipelineConfig(
// If it's a sample operation with custom method, parse the samples as key-value pairs
if (op.type === "sample" && op.otherKwargs?.method === "custom") {
try {
newOp.samples = safeParseJSON(
op.otherKwargs.samples,
`${op.name}.samples`
);
newOp.samples = JSON.parse(op.otherKwargs.samples);
} catch (error) {
console.warn(
`Failed to parse custom samples for operation '${op.name}':`,
error
"Failed to parse custom samples as JSON, using raw value"
);
}
}
@ -301,12 +217,14 @@ export function generatePipelineConfig(
// Fix type errors by asserting the pipeline config type
let pipelineConfig: any = {
from_docwrangler: true,
optimizer_model: optimizerModel,
datasets,
default_model,
optimizer_config: {
rewrite_agent_model: optimizerModel,
...(enable_observability && { force_decompose: true }),
},
...(enable_observability && {
optimizer_config: {
force_decompose: true,
},
}),
operations: updatedOperations,
pipeline: {
steps: [
@ -390,9 +308,7 @@ export function generatePipelineConfig(
const outputOpName = operationsToRun[currentOpIndex].name;
outputPath = path.join(outputBase, "data_processing", outputOpName + ".json");
// Sanitize the config before serializing to YAML to catch any [object Object] issues
const sanitizedConfig = sanitizeForYaml(pipelineConfig, "pipelineConfig");
const yamlString = yaml.dump(sanitizedConfig);
const yamlString = yaml.dump(pipelineConfig);
console.log(yamlString);

View File

@ -104,7 +104,6 @@ export interface OptimizeResult {
should_optimize?: string;
input_data?: Array<Record<string, unknown>>;
output_data?: Array<Record<string, unknown>>;
num_docs_analyzed?: number;
cost?: number;
error?: string;
created_at: string;
@ -115,24 +114,3 @@ export interface APIKey {
name: string;
value: string;
}
export interface DecomposeRequest {
yaml_config: string;
step_name: string;
op_name: string;
}
export interface DecomposeResult {
task_id: string;
status: TaskStatus;
decomposed_operations?: Array<Record<string, unknown>>;
winning_directive?: string;
candidates_evaluated?: number;
original_outputs?: Array<Record<string, unknown>>;
decomposed_outputs?: Array<Record<string, unknown>>;
comparison_rationale?: string;
cost?: number;
error?: string;
created_at: string;
completed_at?: string;
}

View File

@ -24,14 +24,12 @@ interface AnsiRendererProps {
text: string;
readyState: number;
setTerminalOutput: (text: string) => void;
isDecomposing?: boolean;
}
const AnsiRenderer: React.FC<AnsiRendererProps> = ({
text,
readyState,
setTerminalOutput,
isDecomposing = false,
}) => {
const html = convert.toHtml(text);
const scrollRef = useRef<HTMLDivElement>(null);
@ -53,9 +51,7 @@ const AnsiRenderer: React.FC<AnsiRendererProps> = ({
}
};
// Consider connected if either WebSocket is open OR decomposing is in progress
const isConnected = readyState === WebSocket.OPEN || isDecomposing;
const isWebSocketClosed = readyState === WebSocket.CLOSED && !isDecomposing;
const isWebSocketClosed = readyState === WebSocket.CLOSED;
return (
<div
@ -96,11 +92,9 @@ const AnsiRenderer: React.FC<AnsiRendererProps> = ({
)}
</div>
<div className="flex justify-between items-center text-xs text-gray-500">
<div className={isWebSocketClosed ? "text-red-500" : isDecomposing ? "text-blue-400" : ""}>
<div className={isWebSocketClosed ? "text-red-500" : ""}>
Status:{" "}
{isDecomposing
? "Decomposing..."
: readyState === WebSocket.CONNECTING
{readyState === WebSocket.CONNECTING
? "Connecting"
: readyState === WebSocket.OPEN
? "Connected"
@ -112,7 +106,10 @@ const AnsiRenderer: React.FC<AnsiRendererProps> = ({
</div>
<button
onClick={() => setTerminalOutput("")}
className="hover:text-white transition-colors"
className={`hover:text-white transition-colors ${
isWebSocketClosed ? "cursor-not-allowed opacity-50" : ""
}`}
disabled={isWebSocketClosed}
>
Clear
</button>

View File

@ -1,271 +0,0 @@
import React from "react";
import {
Dialog,
DialogContent,
DialogHeader,
DialogTitle,
DialogDescription,
} from "@/components/ui/dialog";
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/ui/table";
import { Button } from "@/components/ui/button";
import { ChevronLeft, ChevronRight, Check, X } from "lucide-react";
import { DecomposeResult } from "@/app/types";
interface DecompositionComparisonDialogProps {
isOpen: boolean;
result: DecomposeResult | null;
operationName?: string;
onOpenChange: (open: boolean) => void;
onApply: () => void;
onCancel: () => void;
}
export const DecompositionComparisonDialog: React.FC<
DecompositionComparisonDialogProps
> = ({ isOpen, result, operationName, onOpenChange, onApply, onCancel }) => {
const [currentPage, setCurrentPage] = React.useState(1);
const originalOutputs = result?.original_outputs || [];
const decomposedOutputs = result?.decomposed_outputs || [];
const maxRows = Math.max(originalOutputs.length, decomposedOutputs.length);
const renderOutputTable = (
data: Array<Record<string, unknown>>,
title: string
) => {
if (!data.length) {
return (
<div className="text-center text-muted-foreground py-8">
No outputs available
</div>
);
}
const columns = Object.keys(data[0]).filter(
(column) => !column.startsWith("_observability")
);
const currentRow = data[currentPage - 1];
if (!currentRow) return null;
return (
<div className="space-y-2">
<h4 className="font-medium text-sm text-muted-foreground">{title}</h4>
<div className="border rounded-md">
<div className="max-h-[250px] overflow-auto">
<Table className="relative w-full border-collapse">
<TableHeader>
<TableRow className="sticky top-0 bg-background z-10 border-b">
{columns.map((column) => (
<TableHead
key={column}
className="h-8 px-3 text-left align-middle bg-background text-xs"
>
{column}
</TableHead>
))}
</TableRow>
</TableHeader>
<TableBody>
<TableRow>
{columns.map((column) => (
<TableCell key={column} className="p-2 align-top">
<pre className="whitespace-pre-wrap font-mono text-xs text-left max-w-[300px]">
{typeof currentRow[column] === "object"
? JSON.stringify(currentRow[column], null, 2)
: String(currentRow[column] ?? "")}
</pre>
</TableCell>
))}
</TableRow>
</TableBody>
</Table>
</div>
</div>
</div>
);
};
React.useEffect(() => {
if (isOpen) {
setCurrentPage(1);
}
}, [isOpen]);
if (!result) return null;
const isOriginalWinner = result.winning_directive === "original";
return (
<Dialog open={isOpen} onOpenChange={onOpenChange}>
<DialogContent className="max-w-7xl max-h-[90vh] flex flex-col">
<DialogHeader className="flex-shrink-0 border-b pb-4">
<DialogTitle className="text-xl">
Review Decomposition Results
</DialogTitle>
<DialogDescription>
Compare the original operation outputs with the decomposed version
before applying changes.
</DialogDescription>
</DialogHeader>
<div className="flex-1 overflow-y-auto py-4 space-y-6">
{/* Summary */}
<div className="p-4 bg-muted rounded-lg space-y-2">
<div className="flex items-center gap-4 flex-wrap">
{operationName && (
<div className="flex items-center">
<span className="font-medium text-sm mr-2">Operation:</span>
<span className="bg-primary/15 text-primary rounded-md px-2 py-1 text-sm">
{operationName}
</span>
</div>
)}
<div className="flex items-center">
<span className="font-medium text-sm mr-2">
Winning Strategy:
</span>
<span
className={`rounded-md px-2 py-1 text-sm ${
isOriginalWinner
? "bg-yellow-100 text-yellow-800"
: "bg-green-100 text-green-800"
}`}
>
{result.winning_directive}
</span>
</div>
<div className="flex items-center">
<span className="font-medium text-sm mr-2">
Candidates Evaluated:
</span>
<span className="text-sm">{result.candidates_evaluated}</span>
</div>
{result.cost && (
<div className="flex items-center">
<span className="font-medium text-sm mr-2">Cost:</span>
<span className="text-sm">${result.cost.toFixed(4)}</span>
</div>
)}
</div>
</div>
{/* Comparison Rationale */}
{result.comparison_rationale && (
<div className="space-y-2">
<h3 className="font-medium text-base">Why This Was Chosen</h3>
<div className="p-3 bg-blue-50 border border-blue-200 rounded-lg text-sm">
{result.comparison_rationale}
</div>
</div>
)}
{/* Side by side comparison */}
<div className="space-y-4">
<div className="flex items-center justify-between">
<h3 className="font-medium text-base">
Output Comparison (Sample {currentPage} of {maxRows || 1})
</h3>
<div className="flex items-center space-x-1">
<Button
variant="outline"
size="sm"
onClick={() =>
setCurrentPage((prev) => Math.max(prev - 1, 1))
}
disabled={currentPage === 1}
className="px-2 py-1"
>
<ChevronLeft className="h-4 w-4" />
Previous
</Button>
<Button
variant="outline"
size="sm"
onClick={() =>
setCurrentPage((prev) => Math.min(prev + 1, maxRows))
}
disabled={currentPage === maxRows || maxRows === 0}
className="px-2 py-1"
>
Next
<ChevronRight className="h-4 w-4" />
</Button>
</div>
</div>
<div className="grid grid-cols-2 gap-4">
{renderOutputTable(originalOutputs, "Original Output")}
{renderOutputTable(
decomposedOutputs,
`Decomposed Output (${result.winning_directive})`
)}
</div>
</div>
{/* Decomposed Operations Preview */}
{result.decomposed_operations &&
result.decomposed_operations.length > 0 &&
!isOriginalWinner && (
<div className="space-y-2">
<h3 className="font-medium text-base">
New Operations ({result.decomposed_operations.length})
</h3>
<div className="space-y-2">
{result.decomposed_operations.map((op, idx) => (
<div
key={idx}
className="p-3 border rounded-lg bg-muted/50"
>
<div className="flex items-center gap-2 mb-2">
<span className="font-medium text-sm">
{(op.name as string) || `Operation ${idx + 1}`}
</span>
<span className="text-xs bg-primary/10 text-primary px-2 py-0.5 rounded">
{op.type as string}
</span>
</div>
{op.prompt && (
<pre className="text-xs font-mono bg-background p-2 rounded border max-h-[100px] overflow-auto">
{String(op.prompt).slice(0, 500)}
{String(op.prompt).length > 500 ? "..." : ""}
</pre>
)}
</div>
))}
</div>
</div>
)}
</div>
<div className="flex justify-end items-center gap-3 pt-4 border-t mt-4">
<Button variant="outline" onClick={onCancel}>
<X className="h-4 w-4 mr-2" />
Keep Original
</Button>
<Button
onClick={onApply}
disabled={isOriginalWinner}
title={
isOriginalWinner
? "Original was determined to be the best option"
: undefined
}
>
<Check className="h-4 w-4 mr-2" />
Apply Decomposition
</Button>
</div>
</DialogContent>
</Dialog>
);
};
DecompositionComparisonDialog.displayName = "DecompositionComparisonDialog";

View File

@ -39,6 +39,7 @@ import { Skeleton } from "@/components/ui/skeleton";
import { debounce } from "lodash";
import { Guardrails, GleaningConfig } from "./operations/args";
import createOperationComponent from "./operations/components";
import { useWebSocket } from "@/contexts/WebSocketContext";
import { Badge } from "./ui/badge";
import {
Popover,
@ -160,7 +161,7 @@ const OperationHeader: React.FC<OperationHeaderProps> = React.memo(
}`}
/>
</HoverCardTrigger>
<HoverCardContent className="w-96 max-h-80" side="bottom" align="start">
<HoverCardContent className="w-72" side="bottom" align="start">
<div className="flex flex-col space-y-1">
<p className="text-sm font-medium">
{optimizeResult === undefined || optimizeResult === null
@ -169,15 +170,13 @@ const OperationHeader: React.FC<OperationHeaderProps> = React.memo(
? "Decomposition Status"
: "Decomposition Recommended"}
</p>
<div className="max-h-64 overflow-y-auto">
<p className="text-sm text-muted-foreground whitespace-pre-wrap">
{optimizeResult === undefined || optimizeResult === null
? "Analyzing operation complexity..."
: optimizeResult === ""
? "No decomposition needed for this operation"
: optimizeResult}
</p>
</div>
<p className="text-sm text-muted-foreground">
{optimizeResult === undefined || optimizeResult === null
? "Analyzing operation complexity..."
: optimizeResult === ""
? "No decomposition needed for this operation"
: "Recommended decomposition: " + optimizeResult}
</p>
</div>
</HoverCardContent>
</HoverCard>
@ -368,7 +367,7 @@ const OperationHeader: React.FC<OperationHeaderProps> = React.memo(
disabled={disabled}
>
<Zap className="mr-2 h-4 w-4" />
Decompose Operation
Optimize Operation
</Button>
)}
@ -721,6 +720,7 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
output: pipelineOutput,
setOutput,
isLoadingOutputs,
setIsLoadingOutputs,
numOpRun,
setNumOpRun,
currentFile,
@ -730,14 +730,18 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
sampleSize,
setCost,
defaultModel,
optimizerModel,
setTerminalOutput,
namespace,
apiKeys,
systemPrompt,
extraPipelineSettings,
onRequestDecompositionRef,
} = usePipelineContext();
const { toast } = useToast();
const operationRef = useRef(operation);
const { connect, sendMessage, lastMessage, readyState, disconnect } =
useWebSocket();
useEffect(() => {
operationRef.current = operation;
@ -804,19 +808,75 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
const handleOptimizeConfirm = useCallback(async () => {
if (!operation) return;
setShowOptimizeDialog(false);
try {
// Clear the output
setTerminalOutput("");
setIsLoadingOutputs(true);
// Use the decomposition flow from context if available
if (onRequestDecompositionRef.current) {
onRequestDecompositionRef.current(operation.id, operation.name);
} else {
// Write pipeline config
const response = await fetch("/api/writePipelineConfig", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
default_model: defaultModel,
data: { path: currentFile?.path || "" },
operations,
operation_id: operation.id,
name: pipelineName,
sample_size: sampleSize,
optimize: true,
clear_intermediate: false,
system_prompt: systemPrompt,
namespace: namespace,
apiKeys: apiKeys,
optimizerModel: optimizerModel,
extraPipelineSettings: extraPipelineSettings,
}),
});
if (!response.ok) {
throw new Error(await response.text());
}
const { filePath } = await response.json();
// Ensure WebSocket is connected
await connect();
// Send message to run the pipeline
sendMessage({
yaml_config: filePath,
optimize: true,
});
} catch (error) {
console.error("Error optimizing operation:", error);
toast({
title: "Error",
description: "Decomposition handler not available",
description: error.message,
variant: "destructive",
});
// Close the WebSocket connection
disconnect();
} finally {
setShowOptimizeDialog(false);
}
}, [operation, onRequestDecompositionRef, toast]);
}, [
operation,
defaultModel,
currentFile,
operations,
pipelineName,
sampleSize,
optimizerModel,
connect,
sendMessage,
systemPrompt,
namespace,
apiKeys,
extraPipelineSettings,
]);
const onShowOutput = useCallback(async () => {
if (!operation) return;
@ -1164,7 +1224,7 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
>
<AlertDialogContent>
<AlertDialogHeader>
<AlertDialogTitle>Decompose Operation</AlertDialogTitle>
<AlertDialogTitle>Optimize Operation</AlertDialogTitle>
<AlertDialogDescription>
{!hasOpenAIKey && !isLocalMode ? (
<div className="space-y-2">
@ -1172,8 +1232,8 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
OpenAI API Key Required
</p>
<p>
To decompose operations, please add your OpenAI API key in
Edit {">"}
To use the optimizer, please add your OpenAI API key in Edit{" "}
{">"}
Edit API Keys.
</p>
<button
@ -1185,10 +1245,11 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
</div>
) : (
<p>
This will analyze the operation and try different
decomposition strategies (chaining, chunking, gleaning, etc.)
to improve accuracy. The best strategy will be selected via
LLM-as-a-judge comparison. Do you want to proceed?
This will analyze the operation and replace it with another
pipeline that has higher accuracy (as determined by an
LLM-as-a-judge), if it can be found. Do you want to proceed?
The process may take between 2 and 10 minutes, depending on
how complex your data is.
</p>
)}
</AlertDialogDescription>
@ -1199,7 +1260,7 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
onClick={handleOptimizeConfirm}
disabled={!hasOpenAIKey && !isLocalMode}
>
Decompose
Proceed
</AlertDialogAction>
</AlertDialogFooter>
</AlertDialogContent>

View File

@ -23,7 +23,6 @@ interface OptimizationDialogProps {
operationName?: string;
inputData?: Array<Record<string, unknown>>;
outputData?: Array<Record<string, unknown>>;
numDocsAnalyzed?: number;
onOpenChange: (open: boolean) => void;
onDecompose?: () => void;
}
@ -35,7 +34,6 @@ export const OptimizationDialog: React.FC<OptimizationDialogProps> = ({
inputData,
outputData,
operationName,
numDocsAnalyzed,
onOpenChange,
onDecompose,
}) => {
@ -142,9 +140,9 @@ export const OptimizationDialog: React.FC<OptimizationDialogProps> = ({
<DialogHeader className="flex-shrink-0 border-b pb-4">
<DialogTitle className="text-xl">Operation Too Complex</DialogTitle>
<p className="text-base mt-2">
After analyzing {numDocsAnalyzed ?? "several"} documents, this
operation might be too complex for the LLM to handle efficiently. We
recommend breaking it down into smaller, more manageable steps.
This operation might be too complex for the LLM to handle
efficiently. We recommend breaking it down into smaller, more
manageable steps.
</p>
</DialogHeader>

View File

@ -81,7 +81,6 @@ const useOutputContext = () => {
const {
output,
isLoadingOutputs,
isDecomposing,
terminalOutput,
setTerminalOutput,
optimizerProgress,
@ -92,7 +91,6 @@ const useOutputContext = () => {
return {
output,
isLoadingOutputs,
isDecomposing,
terminalOutput,
setTerminalOutput,
optimizerProgress,
@ -387,7 +385,7 @@ VisualizeContent.displayName = "VisualizeContent";
// Move ConsoleContent outside
export const ConsoleContent = memo(() => {
const { terminalOutput, setTerminalOutput, optimizerProgress, isDecomposing } =
const { terminalOutput, setTerminalOutput, optimizerProgress } =
useOutputContext();
const { readyState } = useWebSocket();
@ -471,7 +469,6 @@ export const ConsoleContent = memo(() => {
text={terminalOutput || ""}
readyState={readyState}
setTerminalOutput={setTerminalOutput}
isDecomposing={isDecomposing}
/>
</div>
</div>
@ -481,7 +478,7 @@ ConsoleContent.displayName = "ConsoleContent";
// Main Output component
export const Output = memo(() => {
const { output, isLoadingOutputs, isDecomposing, sampleSize, operations } =
const { output, isLoadingOutputs, sampleSize, operations } =
useOutputContext();
const operation = useOperation(output?.operationId);
@ -503,14 +500,10 @@ export const Output = memo(() => {
);
}, [operation]);
// Effect for tab changes - switch to console when loading or decomposing
// Effect for tab changes
useEffect(() => {
if (isLoadingOutputs || isDecomposing) {
setActiveTab("console");
} else {
setActiveTab("table");
}
}, [isLoadingOutputs, isDecomposing]);
setActiveTab(isLoadingOutputs ? "console" : "table");
}, [isLoadingOutputs]);
// Memoize columns
const columns = useMemo(() => {

View File

@ -40,12 +40,9 @@ import { Input } from "@/components/ui/input";
import { schemaDictToItemSet } from "./utils";
import { v4 as uuidv4 } from "uuid";
import { useOptimizeCheck } from "@/hooks/useOptimizeCheck";
import { useDecomposeWebSocket } from "@/hooks/useDecomposeWebSocket";
import { canBeOptimized } from "@/lib/utils";
import { Textarea } from "./ui/textarea";
import { OptimizationDialog } from "@/components/OptimizationDialog";
import { DecompositionComparisonDialog } from "@/components/DecompositionComparisonDialog";
import { DecomposeResult } from "@/app/types";
import {
Popover,
PopoverContent,
@ -243,9 +240,6 @@ const PipelineGUI: React.FC = () => {
apiKeys,
extraPipelineSettings,
setExtraPipelineSettings,
isDecomposing,
setIsDecomposing,
onRequestDecompositionRef,
} = usePipelineContext();
const [isSettingsOpen, setIsSettingsOpen] = useState(false);
const { toast } = useToast();
@ -259,14 +253,12 @@ const PipelineGUI: React.FC = () => {
outputData?: Array<Record<string, unknown>>;
operationName?: string;
operationId?: string;
numDocsAnalyzed?: number;
}>({
isOpen: false,
content: "",
prompt: undefined,
operationName: undefined,
operationId: undefined,
numDocsAnalyzed: undefined,
});
const [isEditingName, setIsEditingName] = useState(false);
const [editedPipelineName, setEditedPipelineName] = useState(pipelineName);
@ -285,13 +277,14 @@ const PipelineGUI: React.FC = () => {
setCost((prev) => prev + result.cost);
if (result.should_optimize) {
const lastOp = operations[operations.length - 1];
const { dismiss } = toast({
title: `Hey! Consider decomposing ${lastOp.name}`,
toast({
title: `Hey! Consider decomposing ${operations[operations.length - 1].name
}`,
description: (
<span
className="cursor-pointer text-blue-500 hover:text-blue-700"
onClick={() => {
const lastOp = operations[operations.length - 1];
setOptimizationDialog({
isOpen: true,
content: result.should_optimize,
@ -300,9 +293,7 @@ const PipelineGUI: React.FC = () => {
operationId: lastOp.id,
inputData: result.input_data,
outputData: result.output_data,
numDocsAnalyzed: result.num_docs_analyzed,
});
dismiss();
}}
>
Click here to see why.
@ -321,144 +312,6 @@ const PipelineGUI: React.FC = () => {
},
});
const [decomposeDialog, setDecomposeDialog] = useState<{
isOpen: boolean;
result: DecomposeResult | null;
targetOpId: string | null;
operationName: string | null;
}>({
isOpen: false,
result: null,
targetOpId: null,
operationName: null,
});
// Ref to store the current decomposition target (avoids stale closure issue)
const decompositionTargetRef = useRef<{
operationId: string;
operationName: string;
} | null>(null);
const { startDecomposition, cancel: cancelDecomposition } = useDecomposeWebSocket({
namespace,
onOutput: (output) => {
// Stream output to the terminal
setTerminalOutput(output);
},
onComplete: (result) => {
setIsDecomposing(false);
setCost((prev) => prev + (result.cost || 0));
// Read from ref to avoid stale closure
const target = decompositionTargetRef.current;
console.log("Decomposition complete:", {
result,
target,
decomposed_operations: result.decomposed_operations,
});
// Show the comparison dialog instead of auto-applying
setDecomposeDialog({
isOpen: true,
result,
targetOpId: target?.operationId || null,
operationName: target?.operationName || null,
});
},
onError: (error) => {
setIsDecomposing(false);
toast({
title: "Decomposition Failed",
description: error,
variant: "destructive",
});
},
});
const handleApplyDecomposition = () => {
const { result, targetOpId } = decomposeDialog;
console.log("handleApplyDecomposition called:", { result, targetOpId, decomposeDialog });
if (!result || !targetOpId) {
console.warn("handleApplyDecomposition: missing result or targetOpId", { result, targetOpId });
return;
}
if (
result.decomposed_operations &&
result.decomposed_operations.length > 0 &&
result.winning_directive !== "original"
) {
setOperations((prev) => {
// Find the index of the target operation
const targetIdx = prev.findIndex((op) => op.id === targetOpId);
if (targetIdx === -1) return prev;
// Convert decomposed operations to our Operation format
const newOps: Operation[] = result.decomposed_operations!.map(
(opConfig) => ({
id: uuidv4(),
llmType:
opConfig.type === "map" ||
opConfig.type === "reduce" ||
opConfig.type === "filter" ||
opConfig.type === "parallel_map"
? "LLM"
: "non-LLM",
type: opConfig.type as Operation["type"],
name: opConfig.name as string,
prompt: opConfig.prompt as string | undefined,
output: opConfig.output
? {
schema: schemaDictToItemSet(
(opConfig.output as { schema: Record<string, string> })
.schema
),
}
: undefined,
gleaning: opConfig.gleaning as Operation["gleaning"],
otherKwargs: Object.fromEntries(
Object.entries(opConfig).filter(
([key]) =>
!["name", "type", "prompt", "output", "gleaning"].includes(key)
)
),
visibility: true,
})
);
// Replace the target operation with the new operations
const newOperations = [...prev];
newOperations.splice(targetIdx, 1, ...newOps);
return newOperations;
});
toast({
title: "Decomposition Applied",
description: `Applied ${result.winning_directive} directive. Created ${result.decomposed_operations.length} operation(s).`,
});
}
setDecomposeDialog({
isOpen: false,
result: null,
targetOpId: null,
operationName: null,
});
};
const handleCancelDecomposition = () => {
toast({
title: "Decomposition Cancelled",
description: "Keeping the original operation.",
});
setDecomposeDialog({
isOpen: false,
result: null,
targetOpId: null,
operationName: null,
});
};
const { restoreFromYAML } = useRestorePipeline({
setOperations,
setPipelineName,
@ -475,18 +328,79 @@ const PipelineGUI: React.FC = () => {
if (lastMessage) {
if (lastMessage.type === "output") {
setTerminalOutput(lastMessage.data);
} else if (lastMessage.type === "optimizer_progress") {
setOptimizerProgress({
status: lastMessage.status,
progress: lastMessage.progress,
shouldOptimize: lastMessage.should_optimize,
rationale: lastMessage.rationale,
validatorPrompt: lastMessage.validator_prompt,
});
} else if (lastMessage.type === "result") {
const runCost = lastMessage.data.cost || 0;
setOptimizerProgress(null);
// Trigger should_optimize check for the last operation if enabled
if (autoOptimizeCheck) {
const lastOp = operations[operations.length - 1];
if (lastOp && canBeOptimized(lastOp.type)) {
submitTask({
yaml_config: lastMessage.data.yaml_config,
step_name: "data_processing",
op_name: lastOp.name,
});
// See if there was an optimized operation
const optimizedOps = lastMessage.data.optimized_ops;
if (optimizedOps) {
const newOperations = optimizedOps.map((optimizedOp) => {
const {
id,
type,
name,
prompt,
output,
validate,
gleaning,
sample,
...otherKwargs
} = optimizedOp;
// Find matching operation in previous operations list
const existingOp = operations.find((op) => op.name === name);
return {
id: id || uuidv4(),
llmType:
type === "map" ||
type === "reduce" ||
type === "resolve" ||
type === "filter" ||
type === "parallel_map" ||
type === "rank" ||
type === "extract"
? "LLM"
: "non-LLM",
type: type,
name: name || "Untitled Operation",
prompt: prompt,
output: output
? {
schema: schemaDictToItemSet(output.schema),
}
: undefined,
validate: validate,
gleaning: gleaning,
sample: sample,
otherKwargs: otherKwargs || {},
...(existingOp?.runIndex && { runIndex: existingOp.runIndex }),
visibility: true,
} as Operation;
});
setOperations(newOperations);
} else {
// No optimized operations, so we need to check if we should optimize the last operation
// Trigger should optimize for the last operation
if (autoOptimizeCheck) {
const lastOp = operations[operations.length - 1];
if (lastOp && canBeOptimized(lastOp.type)) {
submitTask({
yaml_config: lastMessage.data.yaml_config,
step_name: "data_processing", // TODO: Make this a constant
op_name: lastOp.name,
});
}
}
}
@ -778,35 +692,20 @@ const PipelineGUI: React.FC = () => {
};
const handleStop = () => {
// Stop pipeline run if running
if (isLoadingOutputs) {
sendMessage("kill");
if (readyState === WebSocket.CLOSED) {
setIsLoadingOutputs(false);
}
}
sendMessage("kill");
// Stop decomposition if running
if (isDecomposing) {
cancelDecomposition();
setIsDecomposing(false);
if (readyState === WebSocket.CLOSED && isLoadingOutputs) {
setIsLoadingOutputs(false);
}
};
const handleOptimizeFromDialog = async () => {
if (!optimizationDialog.operationId || !optimizationDialog.operationName) return;
// Store target in ref so onComplete callback can access it
decompositionTargetRef.current = {
operationId: optimizationDialog.operationId,
operationName: optimizationDialog.operationName,
};
if (!optimizationDialog.operationId) return;
try {
setIsDecomposing(true);
setTerminalOutput(""); // Clear terminal output
setTerminalOutput("");
setIsLoadingOutputs(true);
// Write the pipeline config to get a YAML file path
const response = await fetch("/api/writePipelineConfig", {
method: "POST",
headers: {
@ -833,115 +732,24 @@ const PipelineGUI: React.FC = () => {
const { filePath } = await response.json();
// Use the WebSocket decompose endpoint for streaming output
await startDecomposition(
filePath,
"data_processing", // Matches the step name in generatePipelineConfig
optimizationDialog.operationName
);
await connect();
toast({
title: "Decomposing Operation",
description: "Analyzing and generating candidate decompositions. Watch the console for progress.",
sendMessage({
yaml_config: filePath,
optimize: true,
});
} catch (error) {
console.error("Error starting decomposition:", error);
setIsDecomposing(false);
console.error("Error optimizing operation:", error);
toast({
title: "Error",
description: error instanceof Error ? error.message : "Failed to start decomposition",
description: error.message,
variant: "destructive",
});
disconnect();
setIsLoadingOutputs(false);
}
};
// Handler for decomposition requests from OperationCard dropdown
const handleDecompositionRequest = useCallback(
async (operationId: string, operationName: string) => {
// Store target in ref so onComplete callback can access it
decompositionTargetRef.current = {
operationId,
operationName,
};
try {
setIsDecomposing(true);
setTerminalOutput(""); // Clear terminal output
// Set the dialog state so we know which operation we're decomposing
setDecomposeDialog((prev) => ({
...prev,
targetOpId: operationId,
operationName: operationName,
}));
// Write the pipeline config to get a YAML file path
const response = await fetch("/api/writePipelineConfig", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
default_model: defaultModel,
data: { path: currentFile?.path || "" },
operations,
operation_id: operationId,
name: pipelineName,
sample_size: sampleSize,
optimize: true,
namespace: namespace,
apiKeys: apiKeys,
optimizerModel: optimizerModel,
extraPipelineSettings: extraPipelineSettings,
}),
});
if (!response.ok) {
throw new Error(await response.text());
}
const { filePath } = await response.json();
// Start decomposition via WebSocket
startDecomposition(filePath, "data_processing", operationName);
toast({
title: "Decomposing Operation",
description:
"Analyzing and generating candidate decompositions. Watch the console for progress.",
});
} catch (error) {
console.error("Error starting decomposition:", error);
setIsDecomposing(false);
toast({
title: "Error",
description:
error instanceof Error ? error.message : "Failed to start decomposition",
variant: "destructive",
});
}
},
[
defaultModel,
currentFile,
operations,
pipelineName,
sampleSize,
namespace,
apiKeys,
optimizerModel,
extraPipelineSettings,
startDecomposition,
setIsDecomposing,
setTerminalOutput,
toast,
]
);
// Register the decomposition handler in context so OperationCard can use it
// Using ref assignment instead of state to avoid infinite loops
onRequestDecompositionRef.current = handleDecompositionRequest;
return (
<div className="flex flex-col h-full">
<div
@ -1245,7 +1053,7 @@ const PipelineGUI: React.FC = () => {
variant="destructive"
className="rounded-sm whitespace-nowrap"
onClick={handleStop}
disabled={!isLoadingOutputs && !isDecomposing}
disabled={!isLoadingOutputs}
>
<StopCircle size={16} className="mr-2" />
Stop
@ -1336,22 +1144,11 @@ const PipelineGUI: React.FC = () => {
operationName={optimizationDialog.operationName}
inputData={optimizationDialog.inputData}
outputData={optimizationDialog.outputData}
numDocsAnalyzed={optimizationDialog.numDocsAnalyzed}
onOpenChange={(open) =>
setOptimizationDialog((prev) => ({ ...prev, isOpen: open }))
}
onDecompose={handleOptimizeFromDialog}
/>
<DecompositionComparisonDialog
isOpen={decomposeDialog.isOpen}
result={decomposeDialog.result}
operationName={decomposeDialog.operationName || undefined}
onOpenChange={(open) =>
setDecomposeDialog((prev) => ({ ...prev, isOpen: open }))
}
onApply={handleApplyDecomposition}
onCancel={handleCancelDecomposition}
/>
</div>
);
};

View File

@ -31,7 +31,6 @@ interface PromptInputProps {
onChange: (value: string) => void;
disableValidation?: boolean;
placeholder?: string;
operationType?: "map" | "reduce" | "filter" | "resolve" | "other";
}
export const PromptInput: React.FC<PromptInputProps> = React.memo(
@ -39,62 +38,31 @@ export const PromptInput: React.FC<PromptInputProps> = React.memo(
prompt,
onChange,
disableValidation = false,
placeholder = "Enter prompt (Jinja2 template optional)",
operationType = "map",
placeholder = "Enter prompt (must be a Jinja2 template)",
}) => {
const hasJinjaTemplate = (value: string) => {
const validateJinjaTemplate = (value: string) => {
if (disableValidation) return true;
const hasOpenBrace = value.includes("{{");
const hasCloseBrace = value.includes("}}");
return hasOpenBrace && hasCloseBrace;
};
const isReduce = operationType === "reduce";
const templateVar = isReduce ? "inputs" : "input";
const exampleKey = isReduce ? "inputs[0].content" : "input.content";
return (
<>
<Textarea
placeholder={placeholder}
className={`mb-1 rounded-sm text-sm font-mono ${
!hasJinjaTemplate(prompt) ? "border-amber-400" : ""
!validateJinjaTemplate(prompt) ? "border-red-500" : ""
}`}
rows={Math.max(3, Math.ceil(prompt.split("\n").length))}
value={prompt}
onChange={(e) => onChange(e.target.value)}
/>
{!hasJinjaTemplate(prompt) && (
<div className="text-amber-600 text-sm mb-1 space-y-1">
<div>
<strong>Warning:</strong> No Jinja2 template found.
</div>
<div className="text-xs text-muted-foreground">
{isReduce ? (
<>
{"{"}
{"{"} inputs {"}"}
{"}"} (the list of all documents) will be auto-appended. For
example, if documents have keys {'"'}id{'"'} and {'"'}content
{'"'}, all docs with all keys will be included. To reference
specific keys, use a for loop: {"{"}% for doc in inputs %{"}"}{" "}
{"{"}
{"{"} doc.content {"}"}
{"}"} {"{"}% endfor %{"}"}.
</>
) : (
<>
{"{"}
{"{"} input {"}"}
{"}"} (the full document) will be auto-appended. For example,
if documents have keys {'"'}id{'"'}, {'"'}title{'"'}, and {'"'}
content{'"'}, all three will be included. To use only specific
keys (e.g., just {'"'}content{'"'}), add {"{"}
{"{"} input.content {"}"}
{"}"} to your prompt.
</>
)}
</div>
{!validateJinjaTemplate(prompt) && (
<div className="text-red-500 text-sm mb-1">
Prompt must contain Jinja2 template syntax {"{"}
{"{"} and {"}"}
{"}"}
</div>
)}
</>
@ -109,7 +77,6 @@ PromptInput.propTypes = {
onChange: PropTypes.func.isRequired,
disableValidation: PropTypes.bool,
placeholder: PropTypes.string,
operationType: PropTypes.oneOf(["map", "reduce", "filter", "resolve", "other"]),
};
interface SchemaFormProps {

View File

@ -314,7 +314,6 @@ export const ReduceOperationComponent: React.FC<OperationComponentProps> = ({
<PromptInput
prompt={operation.prompt || ""}
onChange={handlePromptChange}
operationType="reduce"
/>
<OutputSchema
schema={schemaItems}

View File

@ -26,7 +26,7 @@ const ToastViewport = React.forwardRef<
ToastViewport.displayName = ToastPrimitives.Viewport.displayName;
const toastVariants = cva(
"group pointer-events-auto relative flex w-full items-start justify-between space-x-2 overflow-hidden max-h-[80vh] whitespace-pre-wrap break-words rounded-md border p-4 pr-6 shadow-lg transition-all data-[swipe=cancel]:translate-x-0 data-[swipe=end]:translate-x-[var(--radix-toast-swipe-end-x)] data-[swipe=move]:translate-x-[var(--radix-toast-swipe-move-x)] data-[swipe=move]:transition-none data-[state=open]:animate-in data-[state=closed]:animate-out data-[swipe=end]:animate-out data-[state=closed]:fade-out-80 data-[state=closed]:slide-out-to-right-full data-[state=open]:slide-in-from-top-full data-[state=open]:sm:slide-in-from-bottom-full",
"group pointer-events-auto relative flex w-full items-center justify-between space-x-2 overflow-y-auto max-h-[80vh] whitespace-pre-wrap break-words rounded-md border p-4 pr-6 shadow-lg transition-all data-[swipe=cancel]:translate-x-0 data-[swipe=end]:translate-x-[var(--radix-toast-swipe-end-x)] data-[swipe=move]:translate-x-[var(--radix-toast-swipe-move-x)] data-[swipe=move]:transition-none data-[state=open]:animate-in data-[state=closed]:animate-out data-[swipe=end]:animate-out data-[state=closed]:fade-out-80 data-[state=closed]:slide-out-to-right-full data-[state=open]:slide-in-from-top-full data-[state=open]:sm:slide-in-from-bottom-full",
{
variants: {
variant: {

View File

@ -22,33 +22,29 @@ export function Toaster() {
const descId = `toast-desc-${id}`;
return (
<Toast key={id} {...props}>
<div className="grid gap-1 flex-1 overflow-hidden">
<div className="flex items-center gap-2 shrink-0">
{title && <ToastTitle>{title}</ToastTitle>}
{description && (
<Button
variant="ghost"
size="sm"
className="h-5 w-5 p-0 opacity-70 hover:opacity-100"
onClick={() => {
const el = document.getElementById(descId);
const text = el?.textContent ?? "";
if (text) {
navigator.clipboard.writeText(text);
}
}}
title="Copy message"
>
<Copy size={12} />
</Button>
)}
</div>
<div className="grid gap-1">
{title && <ToastTitle>{title}</ToastTitle>}
{description && (
<div className="overflow-y-auto max-h-[60vh]">
<ToastDescription id={descId}>{description}</ToastDescription>
</div>
<ToastDescription id={descId}>{description}</ToastDescription>
)}
</div>
{description && (
<Button
variant="ghost"
size="sm"
className="h-6 px-2 text-xs"
onClick={() => {
const el = document.getElementById(descId);
const text = el?.textContent ?? "";
if (text) {
navigator.clipboard.writeText(text);
}
}}
title="Copy message"
>
<Copy size={12} />
</Button>
)}
{action}
<ToastClose />
</Toast>

View File

@ -29,7 +29,6 @@ interface PipelineState {
validatorPrompt: string;
} | null;
isLoadingOutputs: boolean;
isDecomposing: boolean;
numOpRun: number;
pipelineName: string;
sampleSize: number | null;
@ -60,7 +59,6 @@ interface PipelineContextType extends PipelineState {
} | null>
>;
setIsLoadingOutputs: React.Dispatch<React.SetStateAction<boolean>>;
setIsDecomposing: React.Dispatch<React.SetStateAction<boolean>>;
setNumOpRun: React.Dispatch<React.SetStateAction<number>>;
setPipelineName: React.Dispatch<React.SetStateAction<string>>;
setSampleSize: React.Dispatch<React.SetStateAction<number | null>>;
@ -85,10 +83,6 @@ interface PipelineContextType extends PipelineState {
setExtraPipelineSettings: React.Dispatch<
React.SetStateAction<Record<string, unknown> | null>
>;
// Ref for triggering decomposition from OperationCard (using ref to avoid infinite loops)
onRequestDecompositionRef: React.MutableRefObject<
((operationId: string, operationName: string) => void) | null
>;
}
const PipelineContext = createContext<PipelineContextType | undefined>(
@ -296,7 +290,6 @@ export const PipelineProvider: React.FC<{ children: React.ReactNode }> = ({
localStorageKeys.IS_LOADING_OUTPUTS_KEY,
false
),
isDecomposing: false, // Transient state, not persisted
numOpRun: loadFromLocalStorage(localStorageKeys.NUM_OP_RUN_KEY, 0),
pipelineName: loadFromLocalStorage(
localStorageKeys.PIPELINE_NAME_KEY,
@ -340,11 +333,6 @@ export const PipelineProvider: React.FC<{ children: React.ReactNode }> = ({
const stateRef = useRef(state);
const [isMounted, setIsMounted] = useState(false);
// Ref for triggering decomposition from OperationCard (using ref to avoid infinite loops)
const onRequestDecompositionRef = useRef<
((operationId: string, operationName: string) => void) | null
>(null);
useEffect(() => {
stateRef.current = state;
}, [state]);
@ -435,7 +423,6 @@ export const PipelineProvider: React.FC<{ children: React.ReactNode }> = ({
output: null,
terminalOutput: "",
isLoadingOutputs: false,
isDecomposing: false,
numOpRun: 0,
pipelineName: mockPipelineName,
sampleSize: mockSampleSize,
@ -542,10 +529,6 @@ export const PipelineProvider: React.FC<{ children: React.ReactNode }> = ({
(value) => setStateAndUpdate("isLoadingOutputs", value),
[setStateAndUpdate]
),
setIsDecomposing: useCallback(
(value) => setStateAndUpdate("isDecomposing", value),
[setStateAndUpdate]
),
setNumOpRun: useCallback(
(value) => setStateAndUpdate("numOpRun", value),
[setStateAndUpdate]
@ -606,7 +589,6 @@ export const PipelineProvider: React.FC<{ children: React.ReactNode }> = ({
(value) => setStateAndUpdate("extraPipelineSettings", value),
[setStateAndUpdate]
),
onRequestDecompositionRef,
};
return (

View File

@ -1,104 +0,0 @@
import { useState, useEffect } from "react";
import axios from "axios";
import { DecomposeResult, DecomposeRequest } from "@/app/types";
import { API_ROUTES } from "@/app/api/constants";
interface UseDecomposeProps {
onComplete?: (result: DecomposeResult) => void;
onError?: (error: string) => void;
pollInterval?: number;
}
export function useDecompose({
onComplete,
onError,
pollInterval = 2000,
}: UseDecomposeProps = {}) {
const [taskId, setTaskId] = useState<string | null>(null);
const [result, setResult] = useState<DecomposeResult | null>(null);
const [error, setError] = useState<string | null>(null);
const [isLoading, setIsLoading] = useState(false);
const submitTask = async (request: DecomposeRequest) => {
try {
setIsLoading(true);
setError(null);
setResult(null);
const response = await axios.post<{ task_id: string }>(
API_ROUTES.DECOMPOSE.SUBMIT,
request
);
setTaskId(response.data.task_id);
} catch (err) {
const errorMessage =
err instanceof Error ? err.message : "Failed to submit decompose task";
setError(errorMessage);
onError?.(errorMessage);
setIsLoading(false);
}
};
const cancelTask = async () => {
if (!taskId) return;
try {
await axios.post(API_ROUTES.DECOMPOSE.CANCEL(taskId));
setTaskId(null);
setIsLoading(false);
} catch (err) {
const errorMessage =
err instanceof Error ? err.message : "Failed to cancel task";
setError(errorMessage);
onError?.(errorMessage);
}
};
useEffect(() => {
if (!taskId) return;
const pollTask = async () => {
try {
const response = await axios.get<DecomposeResult>(
API_ROUTES.DECOMPOSE.STATUS(taskId)
);
setResult(response.data);
if (
["completed", "failed", "cancelled"].includes(response.data.status)
) {
setTaskId(null);
setIsLoading(false);
if (response.data.status === "completed") {
onComplete?.(response.data);
} else if (response.data.status === "failed" && response.data.error) {
setError(response.data.error);
onError?.(response.data.error);
}
}
} catch (err) {
const errorMessage =
err instanceof Error ? err.message : "Failed to fetch task status";
setError(errorMessage);
onError?.(errorMessage);
setTaskId(null);
setIsLoading(false);
}
};
const interval = setInterval(pollTask, pollInterval);
return () => clearInterval(interval);
}, [taskId, onComplete, onError, pollInterval]);
return {
submitTask,
cancelTask,
result,
error,
isLoading,
isRunning: !!taskId,
};
}

View File

@ -1,145 +0,0 @@
import { useState, useRef, useCallback } from "react";
import { DecomposeResult } from "@/app/types";
interface DecomposeWebSocketMessage {
type: "output" | "result" | "error";
data?: any;
message?: string;
traceback?: string;
}
interface UseDecomposeWebSocketProps {
namespace: string;
onOutput?: (output: string) => void;
onComplete?: (result: DecomposeResult) => void;
onError?: (error: string) => void;
}
export function useDecomposeWebSocket({
namespace,
onOutput,
onComplete,
onError,
}: UseDecomposeWebSocketProps) {
const [isConnected, setIsConnected] = useState(false);
const [isRunning, setIsRunning] = useState(false);
const ws = useRef<WebSocket | null>(null);
const connect = useCallback((): Promise<void> => {
return new Promise((resolve, reject) => {
if (ws.current?.readyState === WebSocket.OPEN) {
resolve();
return;
}
if (!namespace) {
reject(new Error("Namespace is required for WebSocket connection"));
return;
}
const wsUrl = `${
process.env.NEXT_PUBLIC_BACKEND_HTTPS === "true" ? "wss://" : "ws://"
}${process.env.NEXT_PUBLIC_BACKEND_HOST}:${
process.env.NEXT_PUBLIC_BACKEND_PORT
}/ws/decompose/${namespace}`;
ws.current = new WebSocket(wsUrl);
ws.current.onopen = () => {
setIsConnected(true);
resolve();
};
ws.current.onclose = () => {
setIsConnected(false);
setIsRunning(false);
};
ws.current.onerror = (error: Event) => {
console.error("Decompose WebSocket error:", error);
onError?.("WebSocket connection failed");
reject(new Error("WebSocket connection failed"));
};
ws.current.onmessage = (event) => {
try {
const message: DecomposeWebSocketMessage = JSON.parse(event.data);
if (message.type === "output") {
onOutput?.(message.data);
} else if (message.type === "result") {
setIsRunning(false);
// Convert the result data to DecomposeResult format
const result: DecomposeResult = {
task_id: "", // Not used in WebSocket mode
status: "completed",
decomposed_operations: message.data.decomposed_operations,
winning_directive: message.data.winning_directive,
candidates_evaluated: message.data.candidates_evaluated,
original_outputs: message.data.original_outputs,
decomposed_outputs: message.data.decomposed_outputs,
comparison_rationale: message.data.comparison_rationale,
cost: message.data.cost,
created_at: new Date().toISOString(),
completed_at: new Date().toISOString(),
};
onComplete?.(result);
// Close the connection after receiving result
ws.current?.close();
} else if (message.type === "error") {
setIsRunning(false);
const errorMsg = message.data || message.message || "Unknown error";
onError?.(errorMsg);
ws.current?.close();
}
} catch (error) {
console.error("Error parsing WebSocket message:", error);
onError?.("Failed to parse WebSocket message");
}
};
});
}, [namespace, onOutput, onComplete, onError]);
const startDecomposition = useCallback(
async (yamlConfig: string, stepName: string, opName: string) => {
try {
await connect();
setIsRunning(true);
if (ws.current?.readyState === WebSocket.OPEN) {
ws.current.send(
JSON.stringify({
yaml_config: yamlConfig,
step_name: stepName,
op_name: opName,
})
);
}
} catch (error) {
setIsRunning(false);
throw error;
}
},
[connect]
);
const cancel = useCallback(() => {
if (ws.current?.readyState === WebSocket.OPEN) {
ws.current.send(JSON.stringify("kill"));
}
}, []);
const disconnect = useCallback(() => {
if (ws.current) {
ws.current.close();
}
}, []);
return {
isConnected,
isRunning,
startDecomposition,
cancel,
disconnect,
};
}

View File

@ -6,8 +6,7 @@ export function cn(...inputs: ClassValue[]) {
}
export function canBeOptimized(operationType: string) {
// Only map operations can be decomposed
return operationType === "map";
return ["resolve", "map", "reduce", "filter"].includes(operationType);
}
export const generateId = () => {