Compare commits
2 Commits
main ... epsteinsor

| Author | SHA1 | Date |
|---|---|---|
| | 7847b26743 | |
| | 8dbe96d87c | |
@@ -1,776 +0,0 @@
---
name: docetl
description: Build and run LLM-powered data processing pipelines with DocETL. Use when users say "docetl", want to analyze unstructured data, process documents, extract information, or run ETL tasks on text. Helps with data collection, pipeline creation, execution, and optimization.
---

# DocETL Pipeline Development

DocETL is a system for creating LLM-powered data processing pipelines. This skill helps you build end-to-end pipelines: from data preparation to execution and optimization.

## Workflow Overview: Iterative Data Analysis

Work like a data analyst: **write → run → inspect → iterate**. Never write all scripts at once and run them all at once. Each phase should be completed and validated before moving to the next.

### Phase 1: Data Collection
1. Write data collection script
2. **Run it immediately** (with user permission)
3. **Inspect the dataset** - show the user:
   - Total document count
   - Keys/fields in each document
   - Sample documents (first 3-5)
   - Length distribution (avg chars, min/max)
   - Any other relevant statistics
4. Iterate if needed (e.g., collect more data, fix parsing issues)

### Phase 2: Pipeline Development
1. Read sample documents to understand format
2. Write pipeline YAML with `sample: 10-20` for testing
3. **Run the test pipeline**
4. **Inspect intermediate results** - show the user:
   - Extraction quality on samples
   - Domain/category distributions
   - Any validation failures
5. Iterate on prompts/schema based on results
6. Remove `sample` parameter and run full pipeline
7. **Show final results** - distributions, trends, key insights

### Phase 3: Visualization & Presentation
1. Write visualization script based on actual output structure
2. **Run and show the report** to the user
3. Iterate on charts/tables if needed

**Visualization Aesthetics:**
- **Clean and minimalist** - no clutter, generous whitespace
- **Warm and elegant color theme** - 1-2 accent colors max
- **Subtle borders** - not too rounded (border-radius: 8-10px max)
- **Sans-serif fonts** - system fonts like -apple-system, Segoe UI, Roboto
- **"Created by DocETL"** - add subtitle after the main title
- **Mix of charts and tables** - charts for distributions, tables for detailed summaries
- **Light background** - off-white (#f5f5f5) with white cards for content

**Report Structure:**
1. Title + "Created by DocETL" subtitle
2. Key stats cards (document count, categories, etc.)
3. Distribution charts (bar charts, pie charts)
4. Summary table with detailed analysis
5. Minimal footer

**Interactive Tables:**
- **All truncated content must be expandable** - never use static "..." truncation
- Long text: Show first ~250 chars with "(show more)" toggle
- Long lists: Show first 4-6 items with "(+N more)" toggle
- Use JavaScript to toggle visibility, not page reloads

**Source Document Links:**
- **Link aggregated results to source documents** - users should be able to drill down
- Clickable links that open a modal/popup with source content
- Modal should show: extracted fields + original source text
- Original text can be collapsed by default with "Show original" toggle
- Embed source data as JSON in the page for JavaScript access (see the sketch after this list)
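
A minimal sketch of the embed-as-JSON pattern, assuming the pipeline output lives in `output.json` and that each record has `topic` and `text` fields (hypothetical names; substitute the fields from your output schema):

```python
import json
from pathlib import Path

records = json.load(open("output.json"))

# Build one table row per record; "topic" is a hypothetical field name.
rows = "\n".join(
    f"<tr><td>{r.get('topic', '')}</td>"
    f"<td class='src' data-idx='{i}'>view source</td></tr>"
    for i, r in enumerate(records)
)

html = f"""<!DOCTYPE html>
<html><head><meta charset="utf-8"><title>Analysis Report</title></head>
<body style="background:#f5f5f5;font-family:-apple-system,'Segoe UI',Roboto,sans-serif">
<h1>Analysis Report</h1><p>Created by DocETL</p>
<table>{rows}</table>
<div id="modal" style="display:none;background:#fff;border-radius:8px;padding:16px"></div>
<script>
// Source documents embedded as JSON so clicks can drill down without a server.
const DATA = {json.dumps(records)};
document.querySelectorAll(".src").forEach(el => el.onclick = () => {{
  const rec = DATA[el.dataset.idx];
  const modal = document.getElementById("modal");
  modal.textContent = rec.text || JSON.stringify(rec, null, 2);
  modal.style.display = "block";
}});
</script>
</body></html>"""

Path("report.html").write_text(html)
print("Wrote report.html")
```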

**Key principle:** The user should see results at every step. Don't proceed to the next phase until the current phase produces good results.

## Step 1: Data Preparation

DocETL datasets must be **JSON arrays** or **CSV files**.

### JSON Format
```json
[
  {"id": 1, "text": "First document content...", "metadata": "value"},
  {"id": 2, "text": "Second document content...", "metadata": "value"}
]
```

### CSV Format
```csv
id,text,metadata
1,"First document content...","value"
2,"Second document content...","value"
```

### Data Collection Scripts

If user needs to collect data, write a Python script:

```python
import json

# Collect/transform data
documents = []
for source in sources:
    documents.append({
        "id": source.id,
        "text": source.content,  # DO NOT truncate text
        # Add relevant fields
    })

# Save as DocETL dataset
with open("dataset.json", "w") as f:
    json.dump(documents, f, indent=2)
```

**Important:** Never truncate document text in collection scripts. DocETL operations like `split` handle long documents properly. Truncation loses information.

### After Running Data Collection

**Always run the collection script and inspect results before proceeding.** Show the user:

```python
import json
data = json.load(open("dataset.json"))

print(f"Total documents: {len(data)}")
print(f"Keys: {list(data[0].keys())}")
print(f"Avg length: {sum(len(str(d)) for d in data) // len(data)} chars")

# Show sample
print("\nSample document:")
print(json.dumps(data[0], indent=2)[:500])
```

Only proceed to pipeline development once the data looks correct.

## Step 2: Read and Understand the Data

**CRITICAL**: Before writing any prompts, READ the actual input data to understand:
- The structure and format of documents
- The vocabulary and terminology used
- What information is present vs. absent
- Edge cases and variations

```python
import json
with open("dataset.json") as f:
    data = json.load(f)
# Examine several examples
for doc in data[:5]:
    print(doc)
```

This understanding is essential for writing specific, effective prompts.

## Step 3: Pipeline Structure

Create a YAML file with this structure:

```yaml
default_model: gpt-5-nano

system_prompt:
  dataset_description: <describe the data based on what you observed>
  persona: <role for the LLM to adopt>

datasets:
  input_data:
    type: file
    path: "dataset.json" # or dataset.csv

operations:
  - name: <operation_name>
    type: <operation_type>
    prompt: |
      <Detailed, specific prompt based on the actual data>
    output:
      schema:
        <field_name>: <type>

pipeline:
  steps:
    - name: process
      input: input_data
      operations:
        - <operation_name>
  output:
    type: file
    path: "output.json"
    intermediate_dir: "intermediates" # ALWAYS set this for debugging
```

### Key Configuration

- **default_model**: Use `gpt-5-nano` or `gpt-5-mini` for extraction/map operations
- **intermediate_dir**: Always set to log intermediate results
- **system_prompt**: Describe the data based on what you actually observed

### Model Selection by Operation Type

| Operation Type | Recommended Model | Rationale |
|---------------|-------------------|-----------|
| Map (extraction) | `gpt-5-nano` or `gpt-5-mini` | High volume, simple per-doc tasks |
| Filter | `gpt-5-nano` | Simple yes/no decisions |
| Reduce (summarization) | `gpt-4.1` or `gpt-5.1` | Complex synthesis across many docs |
| Resolve (deduplication) | `gpt-5-nano` or `gpt-5-mini` | Simple pairwise comparisons |

Use cheaper models for high-volume extraction, and more capable models for synthesis/summarization where quality matters most.

## Step 4: Writing Effective Prompts

**Prompts must be specific to the data, not generic.** After reading the input data:

### Bad (Generic) Prompt
```yaml
prompt: |
  Extract key information from this document.
  {{ input.text }}
```

### Good (Specific) Prompt
```yaml
prompt: |
  You are analyzing a medical transcript from a doctor-patient visit.

  The transcript follows this format:
  - Doctor statements are prefixed with "DR:"
  - Patient statements are prefixed with "PT:"
  - Timestamps appear in brackets like [00:05:23]

  From the following transcript, extract:
  1. All medications mentioned (brand names or generic)
  2. Dosages if specified
  3. Patient-reported side effects or concerns

  Transcript:
  {{ input.transcript }}

  Be thorough - patients often mention medication names informally.
  If a medication is unclear, include it with a note.
```

### Prompt Writing Guidelines

1. **Describe the data format** you observed
2. **Be specific about what to extract** - list exact fields
3. **Mention edge cases** you noticed in the data
4. **Provide examples** if the task is ambiguous
5. **Set expectations** for handling missing/unclear information

## Step 5: Choosing Operations

Many tasks only need a **single map operation**. Use good judgement:

| Task | Recommended Approach |
|------|---------------------|
| Extract info from each doc | Single `map` |
| Multiple extractions | Multiple `map` operations chained |
| Extract then summarize | `map` → `reduce` |
| Filter then process | `filter` → `map` |
| Split long docs | `split` → `map` → `reduce` |
| Deduplicate entities | `map` → `unnest` → `resolve` |

## Operation Reference

### Map Operation

Applies an LLM transformation to each document independently.

```yaml
- name: extract_info
  type: map
  prompt: |
    Analyze this document:
    {{ input.text }}

    Extract the main topic and 3 key points.
  output:
    schema:
      topic: string
      key_points: list[string]
  model: gpt-5-nano # optional, uses default_model if not set
  skip_on_error: true # recommended for large-scale runs
  validate: # optional
    - len(output["key_points"]) == 3
  num_retries_on_validate_failure: 2 # optional
```

**Key parameters:**
- `prompt`: Jinja2 template, use `{{ input.field }}` to reference fields
- `output.schema`: Define output structure
- `skip_on_error`: Set `true` to continue on LLM errors (recommended at scale)
- `validate`: Python expressions to validate output
- `sample`: Process only N documents (for testing)
- `limit`: Stop after producing N outputs

### Filter Operation

Keeps or removes documents based on LLM criteria. Output schema must have exactly one boolean field.

```yaml
- name: filter_relevant
  type: filter
  skip_on_error: true
  prompt: |
    Document: {{ input.text }}

    Is this document relevant to climate change?
    Respond true or false.
  output:
    schema:
      is_relevant: boolean
```

### Reduce Operation

Aggregates documents by a key using an LLM.

**Always include `fold_prompt` and `fold_batch_size`** for reduce operations. This handles cases where the group is too large to fit in context.

```yaml
- name: summarize_by_category
  type: reduce
  reduce_key: category # use "_all" to aggregate everything
  skip_on_error: true
  prompt: |
    Summarize these {{ inputs | length }} items for category "{{ inputs[0].category }}":

    {% for item in inputs %}
    - {{ item.title }}: {{ item.description }}
    {% endfor %}

    Provide a 2-3 sentence summary of the key themes.
  fold_prompt: |
    You have a summary based on previous items, and new items to incorporate.

    Previous summary (based on {{ output.item_count }} items):
    {{ output.summary }}

    New items ({{ inputs | length }} more):
    {% for item in inputs %}
    - {{ item.title }}: {{ item.description }}
    {% endfor %}

    Write a NEW summary that covers ALL items (previous + new).

    IMPORTANT: Output a clean, standalone summary as if describing the entire dataset.
    Do NOT mention "updated", "added", "new items", or reference the incremental process.
  fold_batch_size: 100
  output:
    schema:
      summary: string
      item_count: int
  validate:
    - len(output["summary"].strip()) > 0
  num_retries_on_validate_failure: 2
```

**Critical: Writing Good Fold Prompts**

The `fold_prompt` is called repeatedly as batches are processed. Its output must:
1. **Reflect ALL data seen so far**, not just the latest batch
2. **Be a clean, standalone output** - no "updated X" or "added Y items" language
3. **Match the same schema** as the initial `prompt` output

Bad fold_prompt output: "Added 50 new projects. The updated summary now includes..."
Good fold_prompt output: "Developers are building privacy-focused tools and local-first apps..."

**Estimating `fold_batch_size`** (a quick token-count sketch follows this list):
- **Use 100+ for most cases** - larger batches = fewer LLM calls = lower cost
- For very long documents, reduce to 50-75
- For short documents (tweets, titles), can use 150-200
- Models like gpt-4o-mini have 128k context, so batch size is rarely the bottleneck
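
A rough way to sanity-check the batch size against your own data; a sketch assuming `tiktoken` is installed and that the text your reduce prompt includes lives in a `text` field (adjust to your dataset):

```python
import json

import tiktoken

enc = tiktoken.get_encoding("o200k_base")  # tokenizer family used by recent OpenAI models
data = json.load(open("dataset.json"))

# "text" is a hypothetical field name; count whatever fields the prompt actually includes.
avg_tokens = sum(len(enc.encode(str(d.get("text", "")))) for d in data) / len(data)

context_budget = 100_000  # leave headroom below a ~128k-token context window
print(f"Avg tokens per doc: {avg_tokens:.0f}")
print(f"Docs that comfortably fit per fold: ~{int(context_budget // max(avg_tokens, 1))}")
```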

**Key parameters:**
- `reduce_key`: Field to group by (or list of fields, or `_all`)
- `fold_prompt`: Template for incrementally adding items to existing output (required)
- `fold_batch_size`: Number of items per fold iteration (required, use 100+)
- `associative`: Set to `false` if order matters

### Split Operation

Divides long text into smaller chunks. No LLM call.

```yaml
- name: split_document
  type: split
  split_key: content
  method: token_count # or "delimiter"
  method_kwargs:
    num_tokens: 500
  model: gpt-5-nano
```

**Output adds** (see the inspection sketch after this list):
- `{split_key}_chunk`: The chunk content
- `{op_name}_id`: Original document ID
- `{op_name}_chunk_num`: Chunk number
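
A hedged sketch of what those keys look like in practice, assuming the example above (operation named `split_document`, `split_key: content`, a pipeline step named `process`, and `intermediate_dir` set as recommended):

```python
import json
from collections import Counter

# Intermediate files follow <intermediate_dir>/<step name>/<operation name>.json
chunks = json.load(open("intermediates/process/split_document.json"))

print(sorted(chunks[0].keys()))  # includes content_chunk, split_document_id, split_document_chunk_num
per_doc = Counter(c["split_document_id"] for c in chunks)
print(f"{len(per_doc)} source documents were split into {len(chunks)} chunks")
```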

### Unnest Operation

Flattens list fields into separate rows. No LLM call.

```yaml
- name: unnest_items
  type: unnest
  unnest_key: items # field containing the list
  keep_empty: false # optional
```

**Example:** If a document has `items: ["a", "b", "c"]`, unnest creates 3 documents with `items: "a"`, `items: "b"`, and `items: "c"` respectively.

### Resolve Operation

Deduplicates and canonicalizes entities. Uses pairwise comparison.

```yaml
- name: dedupe_names
  type: resolve
  optimize: true # let optimizer find blocking rules
  skip_on_error: true
  comparison_prompt: |
    Are these the same person?

    Person 1: {{ input1.name }} ({{ input1.email }})
    Person 2: {{ input2.name }} ({{ input2.email }})

    Respond true or false.
  resolution_prompt: |
    Standardize this person's name:

    {% for entry in inputs %}
    - {{ entry.name }}
    {% endfor %}

    Return the canonical name.
  output:
    schema:
      name: string
```

**Important:** Set `optimize: true` and run `docetl build` to generate efficient blocking rules. Without blocking, this is O(n²) in the number of entities (e.g., 1,000 entities → roughly 500,000 pairwise comparisons).

### Code Operations

Deterministic Python transformations without LLM calls.

**code_map:**
```yaml
- name: compute_stats
  type: code_map
  code: |
    def transform(doc) -> dict:
        return {
            "word_count": len(doc["text"].split()),
            "char_count": len(doc["text"])
        }
```

**code_reduce:**
```yaml
- name: aggregate
  type: code_reduce
  reduce_key: category
  code: |
    def transform(items) -> dict:
        total = sum(item["value"] for item in items)
        return {"total": total, "count": len(items)}
```

**code_filter:**
```yaml
- name: filter_long
  type: code_filter
  code: |
    def transform(doc) -> bool:
        return len(doc["text"]) > 100
```

### Retrievers (LanceDB)

Augment LLM operations with retrieved context from a LanceDB index. Useful for:
- Finding related documents to compare against
- Providing additional context for extraction/classification
- Cross-referencing facts across a dataset

**Define a retriever:**
```yaml
retrievers:
  facts_index:
    type: lancedb
    dataset: extracted_facts # dataset to index
    index_dir: workloads/wiki/lance_index
    build_index: if_missing # if_missing | always | never
    index_types: ["fts", "embedding"] # or "hybrid"
    fts:
      index_phrase: "{{ input.fact }}: {{ input.source }}"
      query_phrase: "{{ input.fact }}"
    embedding:
      model: openai/text-embedding-3-small
      index_phrase: "{{ input.fact }}"
      query_phrase: "{{ input.fact }}"
    query:
      mode: hybrid
      top_k: 5
```

**Use in operations:**
```yaml
- name: find_conflicts
  type: map
  retriever: facts_index
  prompt: |
    Check if this fact conflicts with any retrieved facts:

    Current fact: {{ input.fact }} (from {{ input.source }})

    Related facts from other articles:
    {{ retrieval_context }}

    Return whether there's a genuine conflict.
  output:
    schema:
      has_conflict: boolean
```

**Key points:**
- `{{ retrieval_context }}` is injected into prompts automatically
- Index is built on first use (when `build_index: if_missing`)
- Supports full-text (`fts`), vector (`embedding`), or `hybrid` search
- Use `save_retriever_output: true` to debug what was retrieved
- **Can index intermediate outputs**: Retriever can index the output of a previous pipeline step, enabling patterns like "extract facts → index facts → retrieve similar facts for each"

## Documentation Reference

For detailed parameters, advanced features, and more examples, read the docs:
- **Operations**: `docs/operators/` folder (map.md, reduce.md, filter.md, etc.)
- **Concepts**: `docs/concepts/` folder (pipelines.md, operators.md, schemas.md)
- **Examples**: `docs/examples/` folder
- **Optimization**: `docs/optimization/` folder

## Step 6: Environment Setup

Before running, verify API keys exist:

```bash
# Check for .env file
cat .env
```

Required keys depend on the model:
- OpenAI: `OPENAI_API_KEY`
- Anthropic: `ANTHROPIC_API_KEY`
- Google: `GEMINI_API_KEY`

If missing, prompt user to create `.env`:
```
OPENAI_API_KEY=sk-...
```
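
A small sketch to verify the keys are actually visible before a run (assuming `python-dotenv` is installed, which DocETL itself uses):

```python
import os

from dotenv import load_dotenv

load_dotenv()  # reads .env from the current directory
for key in ("OPENAI_API_KEY", "ANTHROPIC_API_KEY", "GEMINI_API_KEY"):
    print(f"{key}: {'set' if os.environ.get(key) else 'missing'}")
```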

## Step 7: Execution

**Always test on a sample first, then run full pipeline.**

### Test Run (Required)
Add `sample: 10-20` to your first operation, then run:
```bash
docetl run pipeline.yaml
```

**Inspect the test results before proceeding:**
```python
import json
from collections import Counter

# Load intermediate results
data = json.load(open("intermediates/step_name/operation_name.json"))

print(f"Processed: {len(data)} docs")

# Check distributions
if "domain" in data[0]:
    print("Domain distribution:")
    for k, v in Counter(d["domain"] for d in data).most_common():
        print(f"  {k}: {v}")

# Show sample outputs
print("\nSample output:")
print(json.dumps(data[0], indent=2))
```

### Full Run
Once test results look good:
1. Remove the `sample` parameter from the pipeline
2. Ask user for permission (estimate cost based on test run; see the sketch after this list)
3. Run full pipeline
4. **Show final results** - distributions, key insights, trends
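
A back-of-the-envelope way to turn the test run into a full-run estimate; a sketch assuming you noted the sample run's cost and sample size (the numbers below are placeholders):

```python
# Placeholder numbers from your own test run.
sample_cost = 0.04   # dollars spent on the sample run
sample_docs = 20     # value of the `sample` parameter used for the test
total_docs = 5_000   # size of the full dataset

estimated_full_cost = sample_cost * (total_docs / sample_docs)
print(f"Estimated full-run cost: ~${estimated_full_cost:.2f}")
```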

Options:
- `--max_threads N` - Control parallelism

Check intermediate results in the `intermediate_dir` folder to debug each step.

## Step 8: Optimization (Optional)

Use MOAR optimizer to find the Pareto frontier of **cost vs. accuracy** tradeoffs. MOAR experiments with different pipeline rewrites and models to find optimal configurations.

Add to pipeline YAML:

```yaml
optimizer_config:
  type: moar
  save_dir: ./optimization_results
  available_models:
    - gpt-5-nano
    - gpt-4o-mini
    - gpt-4o
  evaluation_file: evaluate.py # User must provide
  metric_key: score
  max_iterations: 20
  model: gpt-5-nano
```

Create evaluation file (`evaluate.py`):
```python
def evaluate(outputs: list[dict]) -> dict:
    # Score the outputs (0-1 scale recommended)
    correct = sum(1 for o in outputs if is_correct(o))
    return {"score": correct / len(outputs)}
```
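
`is_correct` above is a task-specific placeholder you have to define. As a hedged illustration, an evaluation that only checks extraction coverage might look like this (the `keywords` field is an assumption; score against whatever your pipeline actually produces):

```python
def evaluate(outputs: list[dict]) -> dict:
    """Fraction of outputs with a non-empty keyword list (illustrative only)."""
    if not outputs:
        return {"score": 0.0}
    non_empty = sum(1 for o in outputs if o.get("keywords"))
    return {"score": non_empty / len(outputs)}
```

The returned key must match `metric_key` in `optimizer_config` (here, `score`).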

Run optimization:
```bash
docetl build pipeline.yaml --optimizer moar
```

MOAR will produce multiple pipeline variants on the Pareto frontier; the user can choose based on their cost/accuracy preferences.

## Output Schemas

**Keep schemas minimal and simple** unless the user explicitly requests more fields. Default to 1-3 output fields per operation. Only add more fields if the user specifically asks for them.

**Nesting limit:** Maximum 2 levels deep (e.g., `list[{field: str}]` is allowed, but no deeper).

```yaml
# Good - minimal, focused on the core task
output:
  schema:
    summary: string

# Good - a few fields when task requires it
output:
  schema:
    topic: string
    keywords: list[string]

# Acceptable - 2 levels of nesting (list of objects)
output:
  schema:
    items: "list[{name: str, value: int}]"

# Bad - too many fields (unless user explicitly requested all of these)
output:
  schema:
    conflicts_found: bool
    num_conflicts: int
    conflicts: "list[{claim_a: str, source_a: str, claim_b: str, source_b: str}]"
    analysis_summary: str

# Bad - more than 2 levels of nesting (not supported)
output:
  schema:
    data: "list[{nested: {too: {deep: str}}}]"
```

**Guidelines:**
- Start with the minimum fields needed to answer the user's question
- Avoid complex nested objects unless explicitly requested
- If you need structured data, prefer multiple simple operations over one complex schema
- Complex schemas increase LLM failures and costs

Supported types: `string`, `int`, `float`, `bool`, `list[type]`, `enum`

## Validation

**Always add validation to LLM-powered operations** (map, reduce, filter, resolve). Validation catches malformed outputs and retries automatically.

```yaml
- name: extract_keywords
  type: map
  prompt: |
    Extract 3-5 keywords from: {{ input.text }}
  output:
    schema:
      keywords: list[string]
  validate:
    - len(output["keywords"]) >= 3
    - len(output["keywords"]) <= 5
  num_retries_on_validate_failure: 2
```

Common validation patterns:
```yaml
# List length constraints
- len(output["items"]) >= 1
- len(output["items"]) <= 10

# Enum/allowed values
- output["sentiment"] in ["positive", "negative", "neutral"]

# String not empty
- len(output["summary"].strip()) > 0

# Numeric ranges
- output["score"] >= 0
- output["score"] <= 100
```

## Jinja2 Templating

For map operations, use `input`:
```yaml
prompt: |
  Document: {{ input.text }}
  {% if input.metadata %}
  Context: {{ input.metadata }}
  {% endif %}
```

For reduce operations, use `inputs` (list):
```yaml
prompt: |
  Summarize these {{ inputs | length }} items:
  {% for item in inputs %}
  - {{ item.summary }}
  {% endfor %}
```

## Troubleshooting

### Pipeline won't run
- Check `.env` has correct API keys
- Verify dataset file exists and is valid JSON/CSV
- Check YAML syntax

### Bad outputs
- Read more input data examples to improve prompt specificity
- Add `validate` rules with retries
- Simplify output schema
- Add concrete examples to prompt

### High costs
- Use `gpt-5-nano` or `gpt-4o-mini`
- Add `sample: 10` to test on subset first
- Run MOAR optimizer to find cost-efficient rewrites

### Check intermediate results
Look in the `intermediate_dir` folder to debug each step.
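
A quick sketch for surveying everything in the intermediates folder at once (assuming the `intermediate_dir: "intermediates"` setting from the pipeline example):

```python
import json
from pathlib import Path

for path in sorted(Path("intermediates").rglob("*.json")):
    try:
        records = json.load(open(path))
        print(f"{path}: {len(records)} records")
    except json.JSONDecodeError:
        print(f"{path}: could not parse")
```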

## Quick Reference

```bash
# Run pipeline
docetl run pipeline.yaml

# Run with more parallelism
docetl run pipeline.yaml --max_threads 16

# Optimize pipeline (cost/accuracy tradeoff)
docetl build pipeline.yaml --optimizer moar

# Clear LLM cache
docetl clear-cache

# Check version
docetl version
```

@@ -57,26 +57,7 @@ website/next-env.d.ts

# experiments
experiments/*.json
experiments/reasoning/data/*.json

metrics_vs_cost.png
tests/data/anthropic-red-team-attempts.jsonl
tests/data/get_freshstack.py
!experiments/reasoning/data/operators_documentation.txt
*.txt
*.yaml
*.yml
*.json
!*lotus_evaluation.json
!*pz_evaluation.json
!*pz_*_evaluation.json
*.png
graph.py
graph_baseline.py
kendalltau.py

slides/*
node_modules/*

reaction_index/
*experiments/reasoning/othersystems/biodex/.chroma-biodex/
tests/data/get_freshstack.py

@@ -13,8 +13,7 @@ DocETL is a tool for creating and executing data processing pipelines, especiall
2. A Python package for running production pipelines from the command line or Python code

> 💡 **Need Help Writing Your Pipeline?**
> You can use **Claude Code** (recommended) to help you write your pipeline—see the quickstart: https://ucbepic.github.io/docetl/quickstart-claude-code/
> If you’d rather use ChatGPT or the Claude app, see [docetl.org/llms.txt](https://docetl.org/llms.txt) for a big prompt you can copy/paste before describing your task.
> Want to use an LLM like ChatGPT or Claude to help you write your pipeline? See [docetl.org/llms.txt](https://docetl.org/llms.txt) for a big prompt you can copy paste into ChatGPT or Claude, before describing your task.

### 🌟 Community Projects

@@ -1,15 +1,11 @@
__version__ = "0.2.6"
__version__ = "0.2.5"

import warnings
import litellm

from docetl.runner import DSLRunner
from docetl.optimizer import Optimizer
from docetl.apis.pd_accessors import SemanticAccessor

# Drop unsupported params for models like gpt-5 that don't support temperature=0
litellm.drop_params = True

# TODO: Remove after https://github.com/BerriAI/litellm/issues/7560 is fixed
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._config")

@@ -59,12 +59,8 @@ from rich import print
from docetl.runner import DSLRunner
from docetl.schemas import (
    ClusterOp,
    CodeFilterOp,
    CodeMapOp,
    CodeReduceOp,
    Dataset,
    EquijoinOp,
    ExtractOp,
    FilterOp,
    GatherOp,
    MapOp,

@@ -344,14 +340,6 @@ class Pipeline:
                self.operations.append(ClusterOp(**op, type=op_type))
            elif op_type == "sample":
                self.operations.append(SampleOp(**op, type=op_type))
            elif op_type == "code_map":
                self.operations.append(CodeMapOp(**op, type=op_type))
            elif op_type == "code_reduce":
                self.operations.append(CodeReduceOp(**op, type=op_type))
            elif op_type == "code_filter":
                self.operations.append(CodeFilterOp(**op, type=op_type))
            elif op_type == "extract":
                self.operations.append(ExtractOp(**op, type=op_type))
        self.steps = [PipelineStep(**step) for step in config["pipeline"]["steps"]]
        self.output = PipelineOutput(**config["pipeline"]["output"])
        self.default_model = config.get("default_model")

@@ -375,10 +363,6 @@ __all__ = [
    "SplitOp",
    "GatherOp",
    "UnnestOp",
    "CodeMapOp",
    "CodeReduceOp",
    "CodeFilterOp",
    "ExtractOp",
    "PipelineStep",
    "PipelineOutput",
    "ParsingTool",
@ -673,68 +673,6 @@ Record 2: {record_template.replace('input0', 'input2')}"""
|
|||
|
||||
return self._record_operation(results, "reduce", reduce_config, reduce_cost)
|
||||
|
||||
def reduce(
|
||||
self,
|
||||
prompt: str,
|
||||
output: dict[str, Any] | None = None,
|
||||
*,
|
||||
output_schema: dict[str, Any] | None = None,
|
||||
reduce_keys: str | list[str] = ["_all"],
|
||||
**kwargs,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Reduce/aggregate all rows using a language model.
|
||||
|
||||
This is a simplified wrapper around the agg() method for common reduce operations.
|
||||
|
||||
Documentation: https://ucbepic.github.io/docetl/operators/reduce/
|
||||
|
||||
Args:
|
||||
prompt: Jinja template string for the reduction prompt. Use {% for item in inputs %}
|
||||
to iterate over input rows.
|
||||
output: Output configuration with keys:
|
||||
- "schema": Dictionary defining the expected output structure and types
|
||||
- "mode": Optional output mode. Either "tools" (default) or "structured_output"
|
||||
output_schema: DEPRECATED. Use 'output' parameter instead.
|
||||
reduce_keys: Keys to group by for reduction (default: ["_all"] for all rows)
|
||||
**kwargs: Additional configuration options passed to agg():
|
||||
- model: LLM model to use
|
||||
- validate: List of validation expressions
|
||||
- num_retries_on_validate_failure: Number of retries
|
||||
- timeout: Timeout in seconds (default: 120)
|
||||
- max_retries_per_timeout: Max retries per timeout (default: 2)
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: Aggregated DataFrame with columns matching output['schema']
|
||||
|
||||
Examples:
|
||||
>>> # Summarize all texts into one summary
|
||||
>>> df.semantic.reduce(
|
||||
... prompt=\"\"\"Summarize the following items:
|
||||
... {% for item in inputs %}
|
||||
... - {{ item.text }}
|
||||
... {% endfor %}\"\"\",
|
||||
... output={"schema": {"summary": "str"}}
|
||||
... )
|
||||
|
||||
>>> # Reduce by group
|
||||
>>> df.semantic.reduce(
|
||||
... prompt=\"\"\"Combine items for {{ reduce_key }}:
|
||||
... {% for item in inputs %}
|
||||
... - {{ item.description }}
|
||||
... {% endfor %}\"\"\",
|
||||
... output={"schema": {"combined": "str"}},
|
||||
... reduce_keys=["category"]
|
||||
... )
|
||||
"""
|
||||
return self.agg(
|
||||
reduce_prompt=prompt,
|
||||
output=output,
|
||||
output_schema=output_schema,
|
||||
reduce_keys=reduce_keys,
|
||||
reduce_kwargs=kwargs,
|
||||
)
|
||||
|
||||
def filter(
|
||||
self,
|
||||
prompt: str,
|
||||
|
|
|

docetl/cli.py (260)
|
@ -3,15 +3,10 @@ from pathlib import Path
|
|||
|
||||
import typer
|
||||
from dotenv import load_dotenv
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
|
||||
from docetl.operations.utils import clear_cache as cc
|
||||
from docetl.runner import DSLRunner
|
||||
|
||||
console = Console(stderr=True)
|
||||
|
||||
app = typer.Typer(pretty_exceptions_enable=False)
|
||||
|
||||
|
||||
|
|
@ -20,12 +15,6 @@ def build(
|
|||
yaml_file: Path = typer.Argument(
|
||||
..., help="Path to the YAML file containing the pipeline configuration"
|
||||
),
|
||||
optimizer: str = typer.Option(
|
||||
"moar",
|
||||
"--optimizer",
|
||||
"-o",
|
||||
help="Optimizer to use: 'moar' (default) or 'v1' (deprecated)",
|
||||
),
|
||||
max_threads: int | None = typer.Option(
|
||||
None, help="Maximum number of threads to use for running operations"
|
||||
),
|
||||
|
|
@ -42,8 +31,8 @@ def build(
|
|||
|
||||
Args:
|
||||
yaml_file (Path): Path to the YAML file containing the pipeline configuration.
|
||||
optimizer (str): Optimizer to use - 'moar' or 'v1' (required).
|
||||
max_threads (int | None): Maximum number of threads to use for running operations.
|
||||
model (str): Model to use for optimization. Defaults to "gpt-4o".
|
||||
resume (bool): Whether to resume optimization from a previous run. Defaults to False.
|
||||
save_path (Path): Path to save the optimized pipeline configuration.
|
||||
"""
|
||||
|
|
@ -55,147 +44,13 @@ def build(
|
|||
if os.path.exists(env_file):
|
||||
load_dotenv(env_file)
|
||||
|
||||
# Validate optimizer choice
|
||||
if optimizer not in ["moar", "v1"]:
|
||||
typer.echo(
|
||||
f"Error: optimizer must be 'moar' or 'v1', got '{optimizer}'", err=True
|
||||
)
|
||||
raise typer.Exit(1)
|
||||
|
||||
# Load YAML to check for optimizer_config
|
||||
import yaml as yaml_lib
|
||||
|
||||
with open(yaml_file, "r") as f:
|
||||
config = yaml_lib.safe_load(f)
|
||||
|
||||
if optimizer == "moar":
|
||||
optimizer_config = config.get("optimizer_config", {})
|
||||
if not optimizer_config:
|
||||
example_yaml = """optimizer_config:
|
||||
type: moar
|
||||
save_dir: ./moar_results
|
||||
available_models:
|
||||
- gpt-5
|
||||
- gpt-4o
|
||||
evaluation_file: workloads/medical/evaluate_medications.py
|
||||
metric_key: medication_extraction_score
|
||||
max_iterations: 40
|
||||
model: gpt-5"""
|
||||
|
||||
error_panel = Panel(
|
||||
f"[bold red]Error:[/bold red] optimizer_config section is required in YAML for MOAR optimizer.\n\n"
|
||||
f"[bold]Example:[/bold]\n"
|
||||
f"[dim]{example_yaml}[/dim]\n\n"
|
||||
f"[yellow]Note:[/yellow] dataset_name is inferred from the 'datasets' section. "
|
||||
f"dataset_path can optionally be specified in optimizer_config, otherwise it's inferred from the 'datasets' section.",
|
||||
title="[bold red]Missing optimizer_config[/bold red]",
|
||||
border_style="red",
|
||||
)
|
||||
console.print(error_panel)
|
||||
raise typer.Exit(1)
|
||||
|
||||
if optimizer_config.get("type") != "moar":
|
||||
error_panel = Panel(
|
||||
f"[bold red]Error:[/bold red] optimizer_config.type must be 'moar', got '[yellow]{optimizer_config.get('type')}[/yellow]'",
|
||||
title="[bold red]Invalid optimizer type[/bold red]",
|
||||
border_style="red",
|
||||
)
|
||||
console.print(error_panel)
|
||||
raise typer.Exit(1)
|
||||
|
||||
# Validate required fields in optimizer_config
|
||||
required_fields = {
|
||||
"save_dir": "Output directory for MOAR results",
|
||||
"available_models": "List of model names to use",
|
||||
"evaluation_file": "Path to evaluation function file",
|
||||
"metric_key": "Key to extract from evaluation results",
|
||||
"max_iterations": "Number of MOARSearch iterations",
|
||||
"model": "LLM model name for directive instantiation",
|
||||
}
|
||||
|
||||
missing_fields = [
|
||||
field for field in required_fields if not optimizer_config.get(field)
|
||||
]
|
||||
if missing_fields:
|
||||
# Create a table for required fields
|
||||
fields_table = Table(
|
||||
show_header=True, header_style="bold cyan", box=None, padding=(0, 2)
|
||||
)
|
||||
fields_table.add_column("Field", style="yellow")
|
||||
fields_table.add_column("Description", style="dim")
|
||||
|
||||
for field, desc in required_fields.items():
|
||||
style = "bold red" if field in missing_fields else "dim"
|
||||
fields_table.add_row(f"[{style}]{field}[/{style}]", desc)
|
||||
|
||||
# Create example YAML
|
||||
example_yaml = """optimizer_config:
|
||||
type: moar
|
||||
save_dir: ./moar_results
|
||||
available_models:
|
||||
- gpt-5
|
||||
- gpt-4o
|
||||
evaluation_file: workloads/medical/evaluate_medications.py
|
||||
metric_key: medication_extraction_score
|
||||
max_iterations: 40
|
||||
model: gpt-5"""
|
||||
|
||||
missing_list = ", ".join(
|
||||
[f"[bold red]{f}[/bold red]" for f in missing_fields]
|
||||
)
|
||||
|
||||
# Build error content with table rendered separately
|
||||
from rich.console import Group
|
||||
|
||||
error_group = Group(
|
||||
f"[bold red]Missing required fields:[/bold red] {missing_list}\n",
|
||||
"[bold]Required fields:[/bold]",
|
||||
fields_table,
|
||||
f"\n[bold]Example:[/bold]\n[dim]{example_yaml}[/dim]\n",
|
||||
"[yellow]Note:[/yellow] dataset_name is inferred from the 'datasets' section. "
|
||||
"dataset_path can optionally be specified in optimizer_config, otherwise it's inferred from the 'datasets' section.",
|
||||
)
|
||||
|
||||
error_panel = Panel(
|
||||
error_group,
|
||||
title="[bold red]Missing Required Fields[/bold red]",
|
||||
border_style="red",
|
||||
)
|
||||
console.print(error_panel)
|
||||
raise typer.Exit(1)
|
||||
|
||||
# Run MOAR optimization
|
||||
from docetl.moar.cli_helpers import run_moar_optimization
|
||||
|
||||
try:
|
||||
results = run_moar_optimization(
|
||||
yaml_path=str(yaml_file),
|
||||
optimizer_config=optimizer_config,
|
||||
)
|
||||
typer.echo("\n✅ MOAR optimization completed successfully!")
|
||||
typer.echo(f" Results saved to: {optimizer_config.get('save_dir')}")
|
||||
if results.get("evaluation_file"):
|
||||
typer.echo(f" Evaluation: {results['evaluation_file']}")
|
||||
except Exception as e:
|
||||
typer.echo(f"Error running MOAR optimization: {e}", err=True)
|
||||
raise typer.Exit(1)
|
||||
|
||||
else: # v1 optimizer (deprecated)
|
||||
console.print(
|
||||
Panel(
|
||||
"[bold yellow]Warning:[/bold yellow] The V1 optimizer is deprecated. "
|
||||
"Please use MOAR optimizer instead: [bold]docetl build pipeline.yaml --optimizer moar[/bold]",
|
||||
title="[bold yellow]Deprecated Optimizer[/bold yellow]",
|
||||
border_style="yellow",
|
||||
)
|
||||
)
|
||||
runner = DSLRunner.from_yaml(str(yaml_file), max_threads=max_threads)
|
||||
runner.optimize(
|
||||
save=True,
|
||||
return_pipeline=False,
|
||||
resume=resume,
|
||||
save_path=save_path,
|
||||
)
|
||||
runner = DSLRunner.from_yaml(str(yaml_file), max_threads=max_threads)
|
||||
runner.optimize(
|
||||
save=True,
|
||||
return_pipeline=False,
|
||||
resume=resume,
|
||||
save_path=save_path,
|
||||
)
|
||||
|
||||
|
||||
@app.command()
|
||||
|
|
@ -244,104 +99,5 @@ def version():
|
|||
typer.echo(f"DocETL version: {docetl.__version__}")
|
||||
|
||||
|
||||
@app.command("install-skill")
|
||||
def install_skill(
|
||||
uninstall: bool = typer.Option(
|
||||
False, "--uninstall", "-u", help="Remove the installed skill instead"
|
||||
),
|
||||
):
|
||||
"""
|
||||
Install the DocETL Claude Code skill to your personal skills directory.
|
||||
|
||||
This makes the DocETL skill available in Claude Code for any project.
|
||||
The skill helps you build and run DocETL pipelines.
|
||||
"""
|
||||
import shutil
|
||||
|
||||
# Find the skill source - try multiple locations
|
||||
# 1. Installed package location (via importlib.resources)
|
||||
# 2. Development location (relative to this file)
|
||||
skill_source = None
|
||||
|
||||
# Try to find via package resources first
|
||||
try:
|
||||
import importlib.resources as pkg_resources
|
||||
|
||||
# For Python 3.9+, use files()
|
||||
try:
|
||||
package_root = Path(pkg_resources.files("docetl")).parent
|
||||
potential_source = package_root / ".claude" / "skills" / "docetl"
|
||||
if potential_source.exists():
|
||||
skill_source = potential_source
|
||||
except (TypeError, AttributeError):
|
||||
pass
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Fallback: try relative to this file (development mode)
|
||||
if skill_source is None:
|
||||
dev_source = Path(__file__).parent.parent / ".claude" / "skills" / "docetl"
|
||||
if dev_source.exists():
|
||||
skill_source = dev_source
|
||||
|
||||
if skill_source is None or not skill_source.exists():
|
||||
console.print(
|
||||
Panel(
|
||||
"[bold red]Error:[/bold red] Could not find the DocETL skill files.\n\n"
|
||||
"This may happen if the package was not installed correctly.\n"
|
||||
"Try reinstalling: [bold]pip install --force-reinstall docetl[/bold]",
|
||||
title="[bold red]Skill Not Found[/bold red]",
|
||||
border_style="red",
|
||||
)
|
||||
)
|
||||
raise typer.Exit(1)
|
||||
|
||||
# Target directory
|
||||
skill_target = Path.home() / ".claude" / "skills" / "docetl"
|
||||
|
||||
if uninstall:
|
||||
if skill_target.exists():
|
||||
shutil.rmtree(skill_target)
|
||||
console.print(
|
||||
Panel(
|
||||
f"[bold green]Success![/bold green] DocETL skill removed from:\n"
|
||||
f"[dim]{skill_target}[/dim]",
|
||||
title="[bold green]Skill Uninstalled[/bold green]",
|
||||
border_style="green",
|
||||
)
|
||||
)
|
||||
else:
|
||||
console.print(
|
||||
Panel(
|
||||
"[yellow]The DocETL skill is not currently installed.[/yellow]",
|
||||
title="[yellow]Nothing to Uninstall[/yellow]",
|
||||
border_style="yellow",
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Create parent directories if needed
|
||||
skill_target.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Copy the skill
|
||||
if skill_target.exists():
|
||||
shutil.rmtree(skill_target)
|
||||
|
||||
shutil.copytree(skill_source, skill_target)
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
f"[bold green]Success![/bold green] DocETL skill installed to:\n"
|
||||
f"[dim]{skill_target}[/dim]\n\n"
|
||||
"[bold]Next steps:[/bold]\n"
|
||||
"1. Restart Claude Code if it's running\n"
|
||||
"2. The skill will automatically activate when you work on DocETL tasks\n\n"
|
||||
"[dim]To uninstall: docetl install-skill --uninstall[/dim]",
|
||||
title="[bold green]Skill Installed[/bold green]",
|
||||
border_style="green",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from io import StringIO
|
||||
|
|
@ -13,63 +12,6 @@ from docetl.utils import StageType, get_stage_description
|
|||
|
||||
install(show_locals=False)
|
||||
|
||||
# ANSI escape sequences to strip (cursor movement, line clearing, cursor visibility)
|
||||
ANSI_CURSOR_PATTERNS = re.compile(
|
||||
r"\x1b\[\?25[lh]" # Hide/show cursor
|
||||
r"|\x1b\[\d*[ABCD]" # Cursor up/down/forward/back
|
||||
r"|\x1b\[\d*[JK]" # Clear screen/line (but we handle \x1b[2K specially)
|
||||
r"|\x1b\[s|\x1b\[u" # Save/restore cursor position
|
||||
)
|
||||
|
||||
|
||||
def process_carriage_returns(text: str) -> str:
|
||||
"""
|
||||
Process terminal control sequences for buffer-captured output.
|
||||
|
||||
Rich's spinner uses ANSI sequences for:
|
||||
- `\r` - carriage return (move to start of line)
|
||||
- `\x1b[2K` - clear entire line
|
||||
- `\x1b[?25l/h` - hide/show cursor
|
||||
- `\x1b[1A` - move cursor up
|
||||
|
||||
When captured to a buffer (not a real terminal), we simulate this by:
|
||||
1. Handling `\r` as "replace from start of line"
|
||||
2. Stripping cursor movement/visibility sequences
|
||||
3. Keeping only meaningful content
|
||||
"""
|
||||
# First, strip cursor visibility and movement sequences
|
||||
text = ANSI_CURSOR_PATTERNS.sub("", text)
|
||||
|
||||
# Process line by line
|
||||
lines = text.split("\n")
|
||||
processed_lines = []
|
||||
|
||||
for line in lines:
|
||||
# Handle carriage returns - keep only content after the last \r
|
||||
if "\r" in line:
|
||||
segments = line.split("\r")
|
||||
# Keep the last non-empty segment (after stripping ANSI clear codes)
|
||||
for segment in reversed(segments):
|
||||
# Strip the clear line code if present
|
||||
cleaned = segment.replace("\x1b[2K", "").strip()
|
||||
if cleaned:
|
||||
# Keep the segment but preserve any color codes
|
||||
processed_lines.append(segment.replace("\x1b[2K", ""))
|
||||
break
|
||||
else:
|
||||
# All segments were empty, skip this line entirely
|
||||
pass
|
||||
else:
|
||||
# No carriage return, keep the line as-is
|
||||
if line.strip() or line == "": # Keep empty lines between content
|
||||
processed_lines.append(line)
|
||||
|
||||
# Remove trailing empty lines
|
||||
while processed_lines and not processed_lines[-1].strip():
|
||||
processed_lines.pop()
|
||||
|
||||
return "\n".join(processed_lines)
|
||||
|
||||
|
||||
class ThreadSafeConsole(Console):
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
|
@ -82,18 +24,11 @@ class ThreadSafeConsole(Console):
|
|||
self.optimizer_rationale = None
|
||||
|
||||
def get_output(self):
|
||||
"""
|
||||
Get the output from the buffer, processing carriage returns.
|
||||
|
||||
Rich's spinner uses carriage returns to overwrite lines in place.
|
||||
We process these to only keep the latest content, preventing
|
||||
duplicate spinner frames from flooding the output.
|
||||
"""
|
||||
# return self.export_text(styles=True)
|
||||
value = self.buffer.getvalue()
|
||||
self.buffer.truncate(0)
|
||||
self.buffer.seek(0)
|
||||
# Process carriage returns to handle spinner overwrites
|
||||
return process_carriage_returns(value)
|
||||
return value
|
||||
|
||||
def status(
|
||||
self,
|
||||
|
|
@ -101,15 +36,10 @@ class ThreadSafeConsole(Console):
|
|||
*,
|
||||
spinner: str = "dots",
|
||||
spinner_style: "StyleType" = "status.spinner",
|
||||
speed: float = 1.0,
|
||||
refresh_per_second: float = 4,
|
||||
speed: float = 0.1, # Much slower speed
|
||||
refresh_per_second: float = 0.5, # Much slower refresh rate (every 2 seconds)
|
||||
) -> "Status":
|
||||
"""
|
||||
Return a Rich Status with animation.
|
||||
|
||||
The carriage returns from the spinner animation are processed
|
||||
in get_output() to prevent duplicate lines.
|
||||
"""
|
||||
status_renderable = Status(
|
||||
status,
|
||||
console=self,
|
||||
|
|
|
|||
|
|
@ -465,8 +465,7 @@ class OpContainer:
|
|||
return cached_data, 0, curr_logs
|
||||
|
||||
# Try to load from checkpoint if available
|
||||
# Skip if this operation has bypass_cache: true
|
||||
if not is_build and not self.config.get("bypass_cache", False):
|
||||
if not is_build:
|
||||
attempted_input_data = self.runner._load_from_checkpoint_if_exists(
|
||||
self.name.split("/")[0], self.name.split("/")[-1]
|
||||
)
|
||||
|
|
|
File diff suppressed because it is too large
@ -1,735 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import yaml
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from docetl.reasoning_optimizer.directives import Directive
|
||||
from docetl.runner import DSLRunner
|
||||
from docetl.utils import extract_output_from_json
|
||||
|
||||
|
||||
class Node:
|
||||
"""
|
||||
A Node class for Monte Carlo Tree Search that represents a state in the search tree.
|
||||
|
||||
Each node holds:
|
||||
- YAML file path and parsed content
|
||||
- Visit count and value for UCB calculation
|
||||
- Parent and children relationships
|
||||
- Methods for tree traversal and expansion
|
||||
"""
|
||||
|
||||
# A class-level counter for unique IDs
|
||||
_id_counter = 0
|
||||
|
||||
@classmethod
|
||||
def get_next_id(cls) -> int:
|
||||
"""
|
||||
Get the next available ID from the counter without incrementing it.
|
||||
|
||||
Returns:
|
||||
int: The next ID that would be assigned
|
||||
"""
|
||||
return cls._id_counter
|
||||
|
||||
@classmethod
|
||||
def increment_id_counter(cls) -> int:
|
||||
"""
|
||||
Increment the ID counter and return the new ID.
|
||||
|
||||
Returns:
|
||||
int: The newly assigned ID
|
||||
"""
|
||||
new_id = cls._id_counter
|
||||
cls._id_counter += 1
|
||||
return new_id
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
yaml_file_path: str,
|
||||
parent: Optional[Node] = None,
|
||||
c: float = 1.414,
|
||||
message_history=[],
|
||||
id: Optional[int] = None,
|
||||
is_multi_instance: bool = False,
|
||||
console=None,
|
||||
):
|
||||
"""
|
||||
Initialize a Node with YAML file information.
|
||||
|
||||
Args:
|
||||
yaml_file_path: Path to the YAML configuration file
|
||||
parent: Parent node in the search tree
|
||||
c: Exploration constant for UCB calculation (default: sqrt(2))
|
||||
is_multi_instance: Whether this node is a multi-instance candidate (default: False)
|
||||
console: Console instance for logging (default: None, uses DOCETL_CONSOLE)
|
||||
"""
|
||||
from docetl.console import DOCETL_CONSOLE
|
||||
|
||||
self.console = console if console is not None else DOCETL_CONSOLE
|
||||
self.yaml_file_path = yaml_file_path
|
||||
self.parsed_yaml = self._load_yaml()
|
||||
# Where the JSON results will be written (if defined in the YAML). This is
|
||||
# useful later for evaluation without having to guess filenames.
|
||||
try:
|
||||
self.result_path: str | None = (
|
||||
self.parsed_yaml.get("pipeline", {}).get("output", {}).get("path")
|
||||
)
|
||||
except Exception:
|
||||
self.result_path = None
|
||||
self.on_frontier = False
|
||||
self.used_actions = {}
|
||||
|
||||
self.op_dict = {} # Dict: op_name -> op
|
||||
for op in self.parsed_yaml["operations"]:
|
||||
op_name = op["name"]
|
||||
self.op_dict[op_name] = op
|
||||
self.used_actions[op_name] = set()
|
||||
self.visits = 0
|
||||
self.value = 0
|
||||
self.parent = parent
|
||||
self.children = []
|
||||
self.c = c # Exploration constant for UCB
|
||||
self.cost = -1.0
|
||||
self.scaled_cost = -1.0 # Scaled cost in [0,1] range for reward calculations
|
||||
self.sample_result = []
|
||||
self.latest_action = None # Latest action that led to this node
|
||||
self.optimization_goal = (
|
||||
None # The optimization goal used for this node (e.g., 'acc', 'cost')
|
||||
)
|
||||
|
||||
# Message history from root to this node (accumulated LLM conversations)
|
||||
self.message_history = message_history
|
||||
|
||||
# Memo list to track (directive, target_operator) pairs from root to this node
|
||||
self.memo = []
|
||||
|
||||
# Flag to indicate if this is a multi-instance candidate
|
||||
self.is_multi_instance = is_multi_instance
|
||||
|
||||
# Assign a unique ID to this node
|
||||
if id:
|
||||
self.id = id
|
||||
else:
|
||||
self.id = Node._id_counter
|
||||
Node._id_counter += 1
|
||||
|
||||
def execute_plan(self, max_threads: Optional[int] = None) -> tuple[float, list]:
|
||||
"""
|
||||
This method execute the query plan by running the YAML file with docetl.
|
||||
|
||||
Args:
|
||||
max_threads (Optional[int]): Maximum number of threads to use for running operations.
|
||||
|
||||
Returns:
|
||||
tuple[float, list]: A tuple containing (total_cost, result_data)
|
||||
|
||||
Raises:
|
||||
Exception: If the pipeline execution fails.
|
||||
"""
|
||||
|
||||
self.console.log(f"[dim]EXECUTING PLAN:[/dim] {self.yaml_file_path}")
|
||||
|
||||
# Get the current working directory (where the user called the command)
|
||||
cwd = os.getcwd()
|
||||
|
||||
# Load .env file from the current working directory if it exists
|
||||
env_file = os.path.join(cwd, ".env")
|
||||
if os.path.exists(env_file):
|
||||
load_dotenv(env_file)
|
||||
|
||||
try:
|
||||
runner = DSLRunner.from_yaml(self.yaml_file_path, max_threads=max_threads)
|
||||
|
||||
# Print the query plan
|
||||
runner.print_query_plan()
|
||||
|
||||
# Load datasets and execute the pipeline
|
||||
runner.load()
|
||||
|
||||
# Execute the pipeline and get the result data
|
||||
if runner.last_op_container:
|
||||
result_data, _, _ = runner.last_op_container.next()
|
||||
runner.save(result_data)
|
||||
else:
|
||||
result_data = []
|
||||
|
||||
# Get the total cost
|
||||
total_cost = runner.total_cost
|
||||
|
||||
# Reset the environment
|
||||
runner.reset_env()
|
||||
|
||||
self.cost = total_cost
|
||||
|
||||
try:
|
||||
self.sample_result = extract_output_from_json(self.yaml_file_path)[:1]
|
||||
except Exception as e:
|
||||
self.console.log(
|
||||
f"[yellow]Error extracting output from JSON for {self.yaml_file_path}: {e}[/yellow]"
|
||||
)
|
||||
self.sample_result = []
|
||||
|
||||
return total_cost
|
||||
|
||||
except Exception as e:
|
||||
self.cost = -1 # Indicate failure
|
||||
self.value = -float("inf")
|
||||
|
||||
# Log -inf occurrence for debugging
|
||||
self._log_inf_occurrence("execution_failure", str(e), self.yaml_file_path)
|
||||
|
||||
raise Exception(f"Failed to execute plan {self.yaml_file_path}: {str(e)}")
|
||||
|
||||
def _load_yaml(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Load and parse the YAML file.
|
||||
|
||||
Returns:
|
||||
Parsed YAML content as a dictionary
|
||||
"""
|
||||
|
||||
try:
|
||||
with open(self.yaml_file_path, "r", encoding="utf-8") as file:
|
||||
return yaml.safe_load(file)
|
||||
except Exception as e:
|
||||
self.console.log(
|
||||
f"[yellow]Error loading YAML file {self.yaml_file_path}: {e}[/yellow]"
|
||||
)
|
||||
return {}
|
||||
|
||||
def best_child(self) -> Node:
|
||||
"""
|
||||
Return the child with the highest UCB (Upper Confidence Bound) value.
|
||||
If there are ties, randomly select among the tied children.
|
||||
|
||||
UCB formula: value/visits + c * sqrt(ln(parent_visits) / visits)
|
||||
|
||||
Returns:
|
||||
Child node with highest UCB, or None if no children exist
|
||||
"""
|
||||
|
||||
def ucb(child: Node) -> float:
|
||||
if child.cost == -1 or child.visits == 0:
|
||||
return float("-inf")
|
||||
exploitation = child.value / child.visits
|
||||
exploration = self.c * math.sqrt(math.log(self.visits) / child.visits)
|
||||
return exploitation + exploration
|
||||
|
||||
# Print visits and value for each child
|
||||
for child in self.children:
|
||||
self.console.log(
|
||||
f"[dim]Child {child.yaml_file_path}: visits = {child.visits}, value = {child.value}[/dim]"
|
||||
)
|
||||
|
||||
# Calculate UCB values for all children
|
||||
ucb_values = [(child, ucb(child)) for child in self.children]
|
||||
|
||||
# Find the maximum UCB value
|
||||
max_ucb = max(ucb_values, key=lambda x: x[1])[1]
|
||||
|
||||
# Find all children with the maximum UCB value (ties)
|
||||
tied_children = [child for child, ucb_val in ucb_values if ucb_val == max_ucb]
|
||||
|
||||
# Randomly select among tied children
|
||||
return random.choice(tied_children)
|
||||
|
||||
def add_child(self, child: Node):
|
||||
"""
|
||||
Add a new child node during tree expansion.
|
||||
|
||||
Args:
|
||||
yaml_file_path: Path to the YAML file for the new child node
|
||||
|
||||
Returns:
|
||||
The newly created child node
|
||||
"""
|
||||
|
||||
self.children.append(child)
|
||||
child.parent = self
|
||||
|
||||
def is_leaf(self) -> bool:
|
||||
"""
|
||||
Check if this node is a leaf (has no children).
|
||||
|
||||
Returns:
|
||||
True if the node has no children, False otherwise
|
||||
"""
|
||||
return len(self.children) == 0
|
||||
|
||||
def mark_action_used(self, op_name, action: Directive):
|
||||
"""
|
||||
Mark a rewrite action as used.
|
||||
|
||||
Args:
|
||||
action: The action identifier to mark as used
|
||||
"""
|
||||
self.used_actions[op_name].add(action)
|
||||
|
||||
def is_root(self) -> bool:
|
||||
"""
|
||||
Check if this node is the root (has no parent).
|
||||
|
||||
Returns:
|
||||
True if the node has no parent, False otherwise
|
||||
"""
|
||||
return self.parent is None
|
||||
|
||||
def update_value(self, value: float):
|
||||
"""
|
||||
Update the node's value (typically after a simulation).
|
||||
|
||||
Args:
|
||||
value: The value to add to the current node value
|
||||
"""
|
||||
# Guard against NaN and -inf values to prevent corruption of node.value
|
||||
# Don't backpropagate -inf (failed evaluations) or NaN to parent nodes
|
||||
if (
|
||||
value is None
|
||||
or (isinstance(value, float) and (value != value))
|
||||
or value == float("-inf")
|
||||
): # NaN or -inf check
|
||||
self.console.log(
|
||||
f"[yellow]⚠️ Skipping backpropagation of -inf / NaN value to node {self.get_id()}[/yellow]"
|
||||
)
|
||||
# Log -inf occurrence for debugging
|
||||
self._log_inf_occurrence(
|
||||
"backpropagation_skipped",
|
||||
f"Skipped backpropagation of value: {value}",
|
||||
self.yaml_file_path,
|
||||
)
|
||||
return
|
||||
self.value = self.value + value
|
||||
|
||||
def update_visit(self):
|
||||
"""
|
||||
Update the node's visit by 1 (typically after a simulation).
|
||||
"""
|
||||
self.visits += 1
|
||||
|
||||
def get_ucb(self) -> float:
|
||||
"""
|
||||
Calculate the UCB value for this node.
|
||||
|
||||
Returns:
|
||||
UCB value for this node
|
||||
"""
|
||||
if self.visits == 0:
|
||||
return float("inf")
|
||||
if self.parent is None:
|
||||
return self.value / self.visits
|
||||
|
||||
exploitation = self.value / self.visits
|
||||
exploration = self.c * math.sqrt(math.log(self.parent.visits) / self.visits)
|
||||
return exploitation + exploration
|
||||
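    # Illustrative worked example (not part of the original source): with
    # value = 3.0, visits = 4, parent.visits = 20, and c = 1.414, get_ucb()
    # returns 3.0/4 + 1.414 * sqrt(ln(20)/4) ≈ 0.75 + 1.22 ≈ 1.97.
    # Unvisited nodes return +inf, so they are always explored first.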
|
||||
def get_id(self) -> int:
|
||||
"""
|
||||
Return the unique identifier for this node.
|
||||
Returns:
|
||||
int: The unique ID of the node
|
||||
"""
|
||||
return self.id
|
||||
|
||||
def set_id_to_counter(self):
|
||||
"""
|
||||
Change this node's ID to the next available counter ID.
|
||||
This is used after selecting the best multi-instance candidate.
|
||||
Also renames the associated files to match the new ID.
|
||||
|
||||
Returns:
|
||||
int: The new ID assigned to this node
|
||||
"""
|
||||
old_id = self.id
|
||||
new_id = self.increment_id_counter()
|
||||
|
||||
# Rename files to match the new ID
|
||||
self._rename_files_for_new_id(old_id, new_id)
|
||||
|
||||
self.id = new_id
|
||||
return self.id
|
||||
|
||||
def _log_inf_occurrence(
|
||||
self, failure_type: str, error_message: str, yaml_path: str
|
||||
):
|
||||
"""
|
||||
Log -inf occurrences to a dedicated log file for debugging.
|
||||
|
||||
Args:
|
||||
failure_type: Type of failure (e.g., "execution_failure", "evaluation_failure")
|
||||
error_message: The error message that caused the failure
|
||||
yaml_path: Path to the YAML file that failed
|
||||
"""
|
||||
try:
|
||||
# Create log directory if it doesn't exist
|
||||
log_dir = os.path.join(os.path.dirname(yaml_path), "inf_logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
# Create log file path
|
||||
log_file = os.path.join(log_dir, "inf_occurrences.txt")
|
||||
|
||||
# Get current timestamp
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
# Get node information
|
||||
node_id = getattr(self, "id", "unknown")
|
||||
latest_action = getattr(self, "latest_action", "unknown")
|
||||
parent_id = getattr(self.parent, "id", "root") if self.parent else "root"
|
||||
|
||||
# Format log entry
|
||||
log_entry = f"""
|
||||
{'='*80}
|
||||
Timestamp: {timestamp}
|
||||
Node ID: {node_id}
|
||||
Parent ID: {parent_id}
|
||||
Latest Action: {latest_action}
|
||||
Failure Type: {failure_type}
|
||||
YAML Path: {yaml_path}
|
||||
Error Message: {error_message}
|
||||
{'='*80}
|
||||
"""
|
||||
|
||||
# Append to log file
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(log_entry)
|
||||
|
||||
except Exception as log_error:
|
||||
self.console.log(
|
||||
f"[yellow]Warning: Failed to log -inf occurrence: {log_error}[/yellow]"
|
||||
)
|
||||
|
||||
def _rename_files_for_new_id(self, old_id, new_id):
|
||||
"""
|
||||
Rename the YAML and output files to match the new node ID.
|
||||
|
||||
Args:
|
||||
old_id: The old node ID (e.g., "7-2")
|
||||
new_id: The new node ID (e.g., 8)
|
||||
"""
|
||||
try:
|
||||
# Rename YAML file
|
||||
if os.path.exists(self.yaml_file_path):
|
||||
old_yaml_path = self.yaml_file_path
|
||||
new_yaml_path = old_yaml_path.replace(
|
||||
f"_{old_id}.yaml", f"_{new_id}.yaml"
|
||||
)
|
||||
os.rename(old_yaml_path, new_yaml_path)
|
||||
self.yaml_file_path = new_yaml_path
|
||||
except Exception as e:
|
||||
self.console.log(
|
||||
f"[yellow]Warning: Could not rename YAML file from {old_id} to {new_id}: {e}[/yellow]"
|
||||
)
|
||||
|
||||
try:
|
||||
# Rename output JSON file
|
||||
if self.result_path and os.path.exists(self.result_path):
|
||||
old_result_path = self.result_path
|
||||
new_result_path = old_result_path.replace(
|
||||
f"_{old_id}.json", f"_{new_id}.json"
|
||||
)
|
||||
os.rename(old_result_path, new_result_path)
|
||||
self.result_path = new_result_path
|
||||
|
||||
# Also update the YAML to point to the new output file
|
||||
if hasattr(self, "parsed_yaml") and self.parsed_yaml:
|
||||
self.parsed_yaml["pipeline"]["output"]["path"] = new_result_path
|
||||
# Rewrite the YAML file with the updated output path
|
||||
with open(self.yaml_file_path, "w") as f:
|
||||
yaml.dump(
|
||||
self.parsed_yaml,
|
||||
f,
|
||||
default_flow_style=False,
|
||||
allow_unicode=True,
|
||||
sort_keys=False,
|
||||
)
|
||||
except Exception as e:
|
||||
self.console.log(
|
||||
f"[yellow]Warning: Could not rename result file from {old_id} to {new_id}: {e}[/yellow]"
|
||||
)
|
||||
|
||||
def add_memo_entry(self, directive_name: str, target_operator: str):
|
||||
"""
|
||||
Add a (directive, target_operator) pair to the memo list.
|
||||
|
||||
Args:
|
||||
directive_name: Name of the directive that was applied
|
||||
target_operator: Name of the target operator the directive was applied to
|
||||
"""
|
||||
self.memo.append((directive_name, target_operator))
|
||||
|
||||
def get_optimization_path(self) -> str:
|
||||
"""
|
||||
Get a formatted string showing the optimization path from root to this node.
|
||||
|
||||
Returns:
|
||||
Formatted path string like "ROOT → chaining(extract_clause) → gleaning(extract_entity)"
|
||||
"""
|
||||
if not self.memo:
|
||||
return "ROOT"
|
||||
|
||||
path_parts = ["ROOT"]
|
||||
for directive, target_op in self.memo:
|
||||
path_parts.append(f"{directive}({target_op})")
|
||||
|
||||
return " → ".join(path_parts)
|
||||
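    # Example (illustrative, assuming these memo entries were recorded):
    #   memo = [("chaining", "extract_clause"), ("gleaning", "extract_entity")]
    #   get_optimization_path() -> "ROOT → chaining(extract_clause) → gleaning(extract_entity)"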
|
||||
def get_exploration_tree_summary(
|
||||
self, root: Node, node_accuracies: Optional[Dict["Node", float]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Generate a comprehensive but concise summary of the entire exploration tree.
|
||||
This gives the LLM agent complete context about what has been tried.
|
||||
|
||||
Args:
|
||||
root: Root node of the tree
|
||||
node_accuracies: Optional dictionary mapping nodes to their accuracy values
|
||||
|
||||
Returns:
|
||||
Formatted tree summary optimized for LLM consumption
|
||||
"""
|
||||
|
||||
# Collect all exploration paths and their outcomes
|
||||
successful_paths = []
|
||||
failed_paths = []
|
||||
|
||||
def traverse_tree(node, current_path="ROOT"):
|
||||
# Add this node's path if it's not the root
|
||||
if node != root:
|
||||
if hasattr(node, "cost") and node.cost != -1:
|
||||
# Include accuracy if available
|
||||
if node_accuracies and node in node_accuracies:
|
||||
accuracy = node_accuracies[node]
|
||||
successful_paths.append(
|
||||
f"{current_path} (cost: ${node.cost:.2f}, accuracy: {accuracy:.3f})"
|
||||
)
|
||||
else:
|
||||
successful_paths.append(
|
||||
f"{current_path} (cost: ${node.cost:.2f})"
|
||||
)
|
||||
else:
|
||||
failed_paths.append(f"{current_path} (failed)")
|
||||
|
||||
# Traverse children
|
||||
for child in node.children:
|
||||
if child.memo:
|
||||
# Get the most recent directive-operator pair for this child
|
||||
latest_directive, latest_target = child.memo[-1]
|
||||
child_path = f"{current_path} → {latest_directive}({latest_target})"
|
||||
else:
|
||||
child_path = f"{current_path} → {child.latest_action.name if child.latest_action else 'unknown'}"
|
||||
traverse_tree(child, child_path)
|
||||
|
||||
traverse_tree(root)
|
||||
|
||||
# Group paths by directive patterns for better insights
|
||||
directive_patterns = {}
|
||||
for path in successful_paths + failed_paths:
|
||||
# Extract directive sequence
|
||||
directives = []
|
||||
parts = path.split(" → ")
|
||||
for part in parts[1:]: # Skip ROOT
|
||||
if "(" in part:
|
||||
directive = part.split("(")[0]
|
||||
directives.append(directive)
|
||||
|
||||
if directives:
|
||||
pattern_key = " → ".join(directives)
|
||||
if pattern_key not in directive_patterns:
|
||||
directive_patterns[pattern_key] = {"successful": [], "failed": []}
|
||||
|
||||
if "(failed)" in path:
|
||||
directive_patterns[pattern_key]["failed"].append(path)
|
||||
else:
|
||||
directive_patterns[pattern_key]["successful"].append(path)
|
||||
|
||||
# Build summary
|
||||
summary_parts = []
|
||||
|
||||
# Current position
|
||||
current_path = self.get_optimization_path()
|
||||
summary_parts.append(f"CURRENT POSITION: {current_path}")
|
||||
|
||||
# Successful explorations
|
||||
if successful_paths:
|
||||
summary_parts.append(
|
||||
f"\nSUCCESSFUL EXPLORATIONS ({len(successful_paths)} total):"
|
||||
)
|
||||
|
||||
# Show best performers first - sort by accuracy (highest first), then by cost (lowest first)
|
||||
def extract_sort_key(path):
|
||||
if "cost: $" not in path:
|
||||
return (
|
||||
0,
|
||||
float("inf"),
|
||||
) # lowest accuracy, highest cost for failed cases
|
||||
try:
|
||||
cost_part = path.split("cost: $")[1]
|
||||
if ", accuracy:" in cost_part:
|
||||
cost_str = cost_part.split(", accuracy:")[0]
|
||||
accuracy_str = cost_part.split(", accuracy:")[1].split(")")[0]
|
||||
return (
|
||||
-float(accuracy_str),
|
||||
float(cost_str),
|
||||
) # negative accuracy for descending order
|
||||
else:
|
||||
cost_str = cost_part.split(")")[0]
|
||||
return (
|
||||
0,
|
||||
float(cost_str),
|
||||
) # no accuracy info, sort by cost only
|
||||
except (ValueError, IndexError):
|
||||
return (0, float("inf"))
|
||||
|
||||
sorted_successful = sorted(successful_paths, key=extract_sort_key)
|
||||
for i, path in enumerate(sorted_successful):
|
||||
summary_parts.append(f" {i+1}. {path}")
|
||||
|
||||
return "\n".join(summary_parts)
|
||||
|
||||
def get_memo_for_llm(
|
||||
self, root_node: Node, node_accuracies: Optional[Dict["Node", float]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Get a comprehensive exploration summary formatted for LLM prompts.
|
||||
|
||||
Args:
|
||||
root_node: Root node of the tree
|
||||
node_accuracies: Optional dictionary mapping nodes to their accuracy values
|
||||
|
||||
Returns:
|
||||
Complete exploration context to guide decision making
|
||||
"""
|
||||
return self.get_exploration_tree_summary(root_node, node_accuracies)
|
||||
|
||||
def delete(self, selected_node_final_id=None):
|
||||
"""
|
||||
Delete this node and clean up its resources.
|
||||
For multi-instance candidates, moves files to backup_plans folder instead of deleting.
|
||||
|
||||
Args:
|
||||
selected_node_final_id: The final ID of the selected node (for backup naming)
|
||||
|
||||
This method:
|
||||
1. Removes the node from its parent's children list
|
||||
2. Moves multi-instance files to backup or deletes regular files
|
||||
3. Clears references to prevent memory leaks
|
||||
|
||||
The global node ID counter is maintained (not decremented) to ensure
|
||||
unique IDs across the lifetime of the program.
|
||||
"""
|
||||
# Remove from parent's children list
|
||||
if self.parent and self in self.parent.children:
|
||||
self.parent.children.remove(self)
|
||||
|
||||
# Check if this is a multi-instance candidate using the explicit flag
|
||||
if self.is_multi_instance and selected_node_final_id is not None:
|
||||
# Move files to backup_plans folder with renamed IDs
|
||||
self._backup_multi_instance_files(selected_node_final_id)
|
||||
else:
|
||||
# Regular deletion for non-multi-instance nodes
|
||||
self._delete_files_permanently()
|
||||
|
||||
# Clear references to help with garbage collection
|
||||
self.parent = None
|
||||
self.children = []
|
||||
self.parsed_yaml = {}
|
||||
self.message_history = []
|
||||
self.memo = []
|
||||
self.sample_result = []
|
||||
|
||||
def _backup_multi_instance_files(self, selected_node_final_id):
|
||||
"""
|
||||
Move multi-instance files to backup_plans folder with new naming scheme.
|
||||
|
||||
Args:
|
||||
selected_node_final_id: The final ID of the selected node
|
||||
"""
|
||||
try:
|
||||
# Extract the instantiation number from current ID (e.g., "7-2" -> "2")
|
||||
current_id_str = str(self.id)
|
||||
if "-" in current_id_str:
|
||||
instantiation_num = current_id_str.split("-")[1]
|
||||
new_backup_id = f"{selected_node_final_id}-{instantiation_num}"
|
||||
else:
|
||||
new_backup_id = f"{selected_node_final_id}-backup"
|
||||
|
||||
# Create backup_plans directory
|
||||
if os.path.exists(self.yaml_file_path):
|
||||
yaml_dir = os.path.dirname(self.yaml_file_path)
|
||||
backup_dir = os.path.join(yaml_dir, "backup_plans")
|
||||
os.makedirs(backup_dir, exist_ok=True)
|
||||
|
||||
# Move YAML file
|
||||
yaml_filename = os.path.basename(self.yaml_file_path)
|
||||
# Replace old ID with new backup ID in filename
|
||||
new_yaml_filename = yaml_filename.replace(
|
||||
f"_{current_id_str}.yaml", f"_{new_backup_id}.yaml"
|
||||
)
|
||||
backup_yaml_path = os.path.join(backup_dir, new_yaml_filename)
|
||||
|
||||
if os.path.exists(self.yaml_file_path):
|
||||
os.rename(self.yaml_file_path, backup_yaml_path)
|
||||
self.console.log(
|
||||
f"[dim]Moved YAML to backup:[/dim] {self.yaml_file_path} → {backup_yaml_path}"
|
||||
)
|
||||
|
||||
# Move result JSON file
|
||||
if self.result_path and os.path.exists(self.result_path):
|
||||
result_filename = os.path.basename(self.result_path)
|
||||
new_result_filename = result_filename.replace(
|
||||
f"_{current_id_str}.json", f"_{new_backup_id}.json"
|
||||
)
|
||||
backup_result_path = os.path.join(backup_dir, new_result_filename)
|
||||
|
||||
os.rename(self.result_path, backup_result_path)
|
||||
self.console.log(
|
||||
f"[dim]Moved result to backup:[/dim] {self.result_path} → {backup_result_path}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.console.log(
|
||||
f"[yellow]Warning: Could not backup files for multi-instance node {self.id}: {e}[/yellow]"
|
||||
)
|
||||
# Fall back to regular deletion if backup fails
|
||||
self._delete_files_permanently()
|
||||
|
||||
def _delete_files_permanently(self):
|
||||
"""
|
||||
Permanently delete files (for non-multi-instance nodes).
|
||||
"""
|
||||
try:
|
||||
if os.path.exists(self.yaml_file_path) and str(
|
||||
self.yaml_file_path
|
||||
).endswith((".yaml", ".yml")):
|
||||
# Only delete if it looks like a generated file (contains numbers)
|
||||
if any(
|
||||
char.isdigit()
|
||||
for char in os.path.basename(str(self.yaml_file_path))
|
||||
):
|
||||
os.remove(self.yaml_file_path)
|
||||
self.console.log(
|
||||
f"[dim]Deleted generated YAML file:[/dim] {self.yaml_file_path}"
|
||||
)
|
||||
except Exception as e:
|
||||
self.console.log(
|
||||
f"[yellow]Warning: Could not delete YAML file {self.yaml_file_path}: {e}[/yellow]"
|
||||
)
|
||||
|
||||
try:
|
||||
if self.result_path and os.path.exists(self.result_path):
|
||||
# Only delete if it looks like a generated file (contains numbers)
|
||||
if any(char.isdigit() for char in os.path.basename(self.result_path)):
|
||||
os.remove(self.result_path)
|
||||
self.console.log(
|
||||
f"[dim]Deleted generated result file:[/dim] {self.result_path}"
|
||||
)
|
||||
except Exception as e:
|
||||
self.console.log(
|
||||
f"[yellow]Warning: Could not delete result file {self.result_path}: {e}[/yellow]"
|
||||
)
|
||||
|
|
@@ -1,425 +0,0 @@
|
|||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import matplotlib
|
||||
|
||||
matplotlib.use("Agg") # Use non-interactive backend
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from .Node import Node
|
||||
|
||||
|
||||
class ParetoFrontier:
|
||||
"""
|
||||
Pareto Frontier class for managing cost-accuracy optimization.
|
||||
|
||||
This class maintains a collection of plans, estimates their accuracy through
|
||||
pairwise comparisons, constructs and updates the Pareto frontier, and provides
|
||||
value calculations for MCTS integration.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
action_rewards: Dict[str, float],
|
||||
action_cost_changes: Dict[str, float],
|
||||
action_accuracy_changes: Dict[str, float],
|
||||
dataset_name: str,
|
||||
evaluate_func: Callable[[str], Dict[str, Any]],
|
||||
console=None,
|
||||
):
|
||||
"""
|
||||
Initialize the Pareto Frontier.
|
||||
|
||||
Args:
|
||||
action_rewards: Reference to MCTS action_rewards dictionary
|
||||
action_cost_changes: Reference to MCTS action_cost_changes dictionary
|
||||
action_accuracy_changes: Reference to MCTS action_accuracy_changes dictionary
|
||||
dataset_name: Name of the dataset being optimized (for evaluation and metric selection)
|
||||
evaluate_func: Evaluation function (results_file_path: str) -> dict
|
||||
console: Console instance for logging (default: None, uses DOCETL_CONSOLE)
|
||||
"""
|
||||
from docetl.console import DOCETL_CONSOLE
|
||||
|
||||
self.console = console if console is not None else DOCETL_CONSOLE
|
||||
self.dataset_name = dataset_name
|
||||
self.evaluate_func = evaluate_func
|
||||
|
||||
# Dataset-to-primary-metric mapping
|
||||
self.dataset_metrics = {
|
||||
"cuad": "avg_f1",
|
||||
"blackvault": "avg_distinct_locations",
|
||||
"game_reviews": "weighted_score",
|
||||
"medec": "combined_score",
|
||||
"sustainability": "combined_score",
|
||||
"biodex": "avg_rp_at_5", # Optimize for RP@5 as specified
|
||||
"facility": "combined_score",
|
||||
}
|
||||
|
||||
# Internal state
|
||||
self.plans: List[Node] = []
|
||||
self.plans_accuracy: Dict[Node, float] = {}
|
||||
self.plans_cost: Dict[Node, float] = {} # Real costs
|
||||
self.frontier_plans: List[Node] = [] # List of nodes on frontier
|
||||
        self.frontier_data: List[List[float]] = []  # [acc, real_cost] for each node on the frontier
|
||||
self.action_rewards = action_rewards
|
||||
self.action_cost_changes = action_cost_changes
|
||||
self.action_accuracy_changes = action_accuracy_changes
|
||||
|
||||
# Distance to current Pareto frontier: positive for on-frontier, negative for off-frontier
|
||||
self.node_distances: Dict[Node, float] = {}
|
||||
|
||||
# Root plan reference point
|
||||
self.root_accuracy: Optional[float] = None
|
||||
self.root_cost: Optional[float] = None
|
||||
|
||||
def add_plan(self, node: Node) -> Dict[Node, int]:
|
||||
"""
|
||||
Add a new plan (Node) to the frontier and estimate its accuracy.
|
||||
|
||||
Args:
|
||||
node: Node object representing the plan
|
||||
|
||||
Returns:
|
||||
            Dict mapping affected nodes to their frontier reward adjustments (empty if the plan failed)
|
||||
"""
|
||||
if node.cost == -1: # Handle error case
|
||||
return {}
|
||||
|
||||
# Store plan information
|
||||
self.plans.append(node)
|
||||
self.plans_cost[node] = node.cost
|
||||
|
||||
# Estimate accuracy through pairwise comparisons
|
||||
if len(self.plans_accuracy) == 0:
|
||||
# First plan gets baseline accuracy
|
||||
estimated_accuracy = 0.5
|
||||
else:
|
||||
estimated_accuracy = self.estimate_accuracy_via_comparisons(node)
|
||||
|
||||
self.plans_accuracy[node] = estimated_accuracy
|
||||
|
||||
# Update Pareto frontier
|
||||
affected_nodes = self.update_pareto_frontier()
|
||||
if node not in affected_nodes:
|
||||
affected_nodes[node] = 0
|
||||
return affected_nodes
|
||||
|
||||
def add_plan_f1(self, node: Node, accuracy: float) -> Tuple[Dict[Node, int], bool]:
|
||||
"""
|
||||
Add a new plan (Node) to the frontier with pre-evaluated accuracy.
|
||||
|
||||
Args:
|
||||
node: Node object representing the plan
|
||||
accuracy: Pre-evaluated accuracy score for the node
|
||||
|
||||
Returns:
|
||||
            Tuple of (dict mapping affected nodes to reward adjustments, bool indicating whether the frontier was updated)
|
||||
"""
|
||||
if node.cost == -1: # Handle error case
|
||||
self.plans_accuracy[node] = float("-inf")
|
||||
return {}, False
|
||||
|
||||
# Store plan information
|
||||
self.plans.append(node)
|
||||
self.plans_cost[node] = node.cost # Store real cost
|
||||
# Scaled cost will be calculated in update_pareto_frontier_HV
|
||||
|
||||
# Store the pre-evaluated accuracy
|
||||
self.plans_accuracy[node] = accuracy
|
||||
|
||||
# Update Pareto frontier
|
||||
affected_nodes, is_frontier_updated = self.update_pareto_frontier_HV(node)
|
||||
return affected_nodes, is_frontier_updated
|
||||
|
||||
def get_all_plans_summary(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get summary of all plans with their metrics.
|
||||
|
||||
Returns:
|
||||
List of dictionaries containing plan information and metrics
|
||||
"""
|
||||
summaries = []
|
||||
for node in self.plans:
|
||||
summary = {
|
||||
"node": node.get_id(),
|
||||
"path": node.yaml_file_path,
|
||||
"cost": node.cost,
|
||||
"accuracy": self.plans_accuracy[node],
|
||||
"value": node.value,
|
||||
"is_frontier": node in self.frontier_plans,
|
||||
}
|
||||
summaries.append(summary)
|
||||
|
||||
return summaries
|
||||
|
||||
# Helper function to project point onto step function frontier
|
||||
def project_to_frontier(self, node_acc, node_cost, frontier_data):
|
||||
"""
|
||||
Project point onto the step function formed by frontier.
|
||||
For a step function interpretation, the reward is simply the vertical distance
|
||||
to the step function (accuracy distance only).
|
||||
"""
|
||||
if not frontier_data:
|
||||
return node_acc
|
||||
|
||||
# Sort frontier by cost (ascending)
|
||||
frontier_sorted = sorted(frontier_data, key=lambda x: x[1]) # Sort by cost
|
||||
|
||||
# Find the step function accuracy for this cost
|
||||
step_function_accuracy = (
|
||||
0.0 # Default if cost is lower than all frontier points
|
||||
)
|
||||
|
||||
for fp_acc, fp_cost in frontier_sorted:
|
||||
if node_cost >= fp_cost:
|
||||
# Cost is >= this frontier point's cost, so step function is at this accuracy
|
||||
step_function_accuracy = fp_acc
|
||||
else:
|
||||
# Cost is < this frontier point's cost, so we use the previous step
|
||||
break
|
||||
|
||||
# Return the vertical (accuracy) distance to the step function
|
||||
vertical_distance = abs(node_acc - step_function_accuracy)
|
||||
return vertical_distance
|
||||
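    # Illustrative example (values are made up): with frontier_data =
    # [[0.6, 1.0], [0.8, 3.0]] (i.e. [accuracy, cost] pairs), a node with
    # accuracy 0.7 and cost 2.0 falls on the step at accuracy 0.6 (the best
    # frontier accuracy reachable at cost <= 2.0), so the returned vertical
    # distance is |0.7 - 0.6| = 0.1.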
|
||||
def _update_action_rewards(self, node: Node, reward: float) -> None:
|
||||
"""
|
||||
Update action rewards and track cost/accuracy changes based on the reward received by a node.
|
||||
Updates the cumulative sum for the latest action that led to this node.
|
||||
|
||||
Args:
|
||||
node: The node that received the reward
|
||||
reward: The reward value to incorporate
|
||||
"""
|
||||
if not node.latest_action or not self.action_rewards:
|
||||
return
|
||||
action = node.latest_action
|
||||
if action in self.action_rewards:
|
||||
# Update cumulative reward sum
|
||||
self.action_rewards[action] += reward
|
||||
|
||||
# Track cost and accuracy changes
|
||||
if (
|
||||
node.parent
|
||||
and node.parent in self.plans_cost
|
||||
and node in self.plans_cost
|
||||
):
|
||||
cost_change = self.plans_cost[node] - self.plans_cost[node.parent]
|
||||
self.action_cost_changes[action] += cost_change
|
||||
|
||||
if (
|
||||
node.parent
|
||||
and node.parent in self.plans_accuracy
|
||||
and node in self.plans_accuracy
|
||||
):
|
||||
accuracy_change = (
|
||||
self.plans_accuracy[node] - self.plans_accuracy[node.parent]
|
||||
)
|
||||
self.action_accuracy_changes[action] += accuracy_change
|
||||
|
||||
def update_pareto_frontier_HV(self, new_node) -> Tuple[Dict[Node, int], bool]:
|
||||
"""
|
||||
Update the Pareto frontier based on current plans and calculate hyper-volume indicator.
|
||||
"""
|
||||
|
||||
valid_nodes = [node for node in self.plans if node.cost != -1]
|
||||
affected_nodes = {}
|
||||
|
||||
if not valid_nodes:
|
||||
self.frontier_plans = []
|
||||
self.frontier_data = []
|
||||
return affected_nodes, False
|
||||
|
||||
# Save old frontier nodes before updating
|
||||
old_frontier_nodes = self.frontier_plans
|
||||
|
||||
# Sort by real cost for frontier calculation
|
||||
valid_nodes.sort(key=lambda node: self.plans_cost[node])
|
||||
|
||||
# Reconstruct old frontier data using real costs
|
||||
archive_frontier_data = []
|
||||
for node in old_frontier_nodes:
|
||||
if node in valid_nodes: # Only include valid nodes
|
||||
acc = self.plans_accuracy.get(node, float("-inf"))
|
||||
real_cost = self.plans_cost[node]
|
||||
archive_frontier_data.append([acc, real_cost])
|
||||
else:
|
||||
self.console.log(
|
||||
f"[yellow]INVALID NODE:[/yellow] {node.id}, [dim]cost:[/dim] {node.cost}, [dim]in_valid_nodes:[/dim] {node in valid_nodes}"
|
||||
)
|
||||
|
||||
frontier = []
|
||||
max_accuracy_so_far = float("-inf")
|
||||
|
||||
for node in valid_nodes:
|
||||
accuracy = self.plans_accuracy.get(node, 0.0)
|
||||
|
||||
# Plan is on frontier if it has higher accuracy than all lower-cost plans
|
||||
if accuracy > max_accuracy_so_far:
|
||||
frontier.append(node)
|
||||
max_accuracy_so_far = accuracy
|
||||
|
||||
new_frontier_data = []
|
||||
for node in frontier:
|
||||
acc = self.plans_accuracy.get(node)
|
||||
real_cost = self.plans_cost[node] # Use real cost
|
||||
new_frontier_data.append([acc, real_cost])
|
||||
|
||||
# Check if frontier actually changed
|
||||
old_frontier_set = set(old_frontier_nodes)
|
||||
new_frontier_set = set(frontier)
|
||||
frontier_updated = old_frontier_set != new_frontier_set
|
||||
|
||||
# Update affected nodes based on frontier changes
|
||||
for node in valid_nodes:
|
||||
node_real_cost = self.plans_cost[node]
|
||||
node_acc = self.plans_accuracy[node]
|
||||
|
||||
if node in new_frontier_set and node not in old_frontier_set:
|
||||
# Newly on frontier - reward based on vertical distance to OLD frontier step function
|
||||
node.on_frontier = True
|
||||
vertical_distance_to_old = self.project_to_frontier(
|
||||
node_acc, node_real_cost, archive_frontier_data
|
||||
)
|
||||
affected_nodes[node] = vertical_distance_to_old
|
||||
# Update node distances - positive for on frontier
|
||||
self.node_distances[node] = vertical_distance_to_old
|
||||
# Update action rewards
|
||||
self._update_action_rewards(node, vertical_distance_to_old)
|
||||
|
||||
elif (node not in new_frontier_set and node in old_frontier_set) or (
|
||||
node.id == new_node.id
|
||||
):
|
||||
# Newly off frontier - give negative reward based on vertical distance to NEW frontier step function
|
||||
node.on_frontier = False
|
||||
vertical_distance = self.project_to_frontier(
|
||||
node_acc, node_real_cost, new_frontier_data
|
||||
)
|
||||
affected_nodes[node] = -vertical_distance
|
||||
# Update node distances - negative for off frontier
|
||||
self.node_distances[node] = -vertical_distance
|
||||
# Update action rewards
|
||||
if node.id == new_node.id:
|
||||
self._update_action_rewards(node, -vertical_distance)
|
||||
elif node not in new_frontier_set:
|
||||
# stay off frontier nodes - update the reward to be negative vertical distance to the NEW frontier step function
|
||||
node.on_frontier = False
|
||||
vertical_distance = self.project_to_frontier(
|
||||
node_acc, node_real_cost, new_frontier_data
|
||||
)
|
||||
old_distance = self.node_distances.get(node, 0)
|
||||
distance_diff = -vertical_distance - old_distance
|
||||
affected_nodes[node] = distance_diff
|
||||
# Update node distances - negative for off frontier
|
||||
self.node_distances[node] = -vertical_distance
|
||||
|
||||
self.frontier_plans = frontier
|
||||
self.frontier_data = new_frontier_data
|
||||
if new_node.id > 0:
|
||||
graph_dir = str(new_node.yaml_file_path).rsplit("/", 1)[0] + "/graph/"
|
||||
os.makedirs(graph_dir, exist_ok=True)
|
||||
save_path = graph_dir + f"plan_{new_node.id}.png"
|
||||
self.plot_plans(save_path, new_node.id, str(new_node.yaml_file_path))
|
||||
return affected_nodes, frontier_updated
|
||||
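    # Illustrative reward example (values invented): suppose the old frontier is
    # [[0.6, 1.0]] and a new node with accuracy 0.8 / cost 2.0 is added. The new
    # node joins the frontier and receives reward |0.8 - 0.6| = 0.2, its vertical
    # distance to the OLD step function; a node pushed off the frontier would
    # instead receive the negative vertical distance to the NEW frontier.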
|
||||
def plot_plans(self, save_path=None, plan_num=None, yaml_file=None):
|
||||
"""
|
||||
Plot all current plans as dots on a cost vs. accuracy graph, annotating each with its id.
|
||||
Frontier plans are blue, non-frontier plans are grey.
|
||||
|
||||
Args:
|
||||
save_path: If provided, save the plot to this path instead of showing it
|
||||
            plan_num: If provided, include the plan number in the title
            yaml_file: If provided, include the YAML file path in the title
|
||||
"""
|
||||
if plt is None:
|
||||
raise ImportError(
|
||||
"matplotlib is required for plotting. Please install it with 'pip install matplotlib'."
|
||||
)
|
||||
|
||||
plt.figure(figsize=(10, 8))
|
||||
|
||||
# Separate frontier and non-frontier plans
|
||||
frontier_nodes = [node for node in self.plans if node in self.frontier_plans]
|
||||
non_frontier_nodes = [
|
||||
node for node in self.plans if node not in self.frontier_plans
|
||||
]
|
||||
|
||||
# Plot non-frontier plans (grey)
|
||||
if non_frontier_nodes:
|
||||
costs = [self.plans_cost[node] for node in non_frontier_nodes]
|
||||
accuracies = [self.plans_accuracy[node] for node in non_frontier_nodes]
|
||||
ids = [node.get_id() for node in non_frontier_nodes]
|
||||
plt.scatter(costs, accuracies, color="grey", label="Off Frontier")
|
||||
for x, y, label in zip(costs, accuracies, ids):
|
||||
plt.annotate(
|
||||
str(label),
|
||||
(x, y),
|
||||
textcoords="offset points",
|
||||
xytext=(5, 5),
|
||||
ha="left",
|
||||
fontsize=9,
|
||||
color="grey",
|
||||
)
|
||||
|
||||
# Plot frontier plans (blue)
|
||||
if frontier_nodes:
|
||||
costs = [self.plans_cost[node] for node in frontier_nodes]
|
||||
accuracies = [self.plans_accuracy[node] for node in frontier_nodes]
|
||||
ids = [node.get_id() for node in frontier_nodes]
|
||||
plt.scatter(costs, accuracies, color="blue", label="Frontier")
|
||||
for x, y, label in zip(costs, accuracies, ids):
|
||||
plt.annotate(
|
||||
str(label),
|
||||
(x, y),
|
||||
textcoords="offset points",
|
||||
xytext=(5, 5),
|
||||
ha="left",
|
||||
fontsize=9,
|
||||
color="blue",
|
||||
)
|
||||
|
||||
plt.xlabel("Cost")
|
||||
plt.ylabel("Accuracy")
|
||||
|
||||
if plan_num is not None:
|
||||
plt.title(f"Plan {plan_num}: {yaml_file}")
|
||||
else:
|
||||
plt.title("Plans: Cost vs. Accuracy")
|
||||
|
||||
plt.grid(True, linestyle="--", alpha=0.5)
|
||||
plt.legend()
|
||||
plt.tight_layout()
|
||||
|
||||
if save_path:
|
||||
plt.savefig(save_path, dpi=150, bbox_inches="tight")
|
||||
plt.close()
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return number of plans in the frontier."""
|
||||
return len(self.plans)
|
||||
|
||||
def __contains__(self, node: Node) -> bool:
|
||||
"""Check if plan is managed by this frontier."""
|
||||
return node in self.plans
|
||||
|
||||
def delete_plan(self, node: Node) -> None:
|
||||
"""
|
||||
Completely delete a node from all ParetoFrontier data structures.
|
||||
"""
|
||||
if node in self.plans:
|
||||
self.plans.remove(node)
|
||||
|
||||
accuracy = self.plans_accuracy.pop(node, None)
|
||||
cost = self.plans_cost.pop(node, None)
|
||||
|
||||
if node in self.frontier_plans:
|
||||
self.frontier_plans.remove(node)
|
||||
if accuracy is not None and cost is not None:
|
||||
try:
|
||||
self.frontier_data.remove([accuracy, cost])
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
self.node_distances.pop(node, None)
|
||||
|
|
@@ -1,33 +0,0 @@
|
|||
"""
|
||||
MCTS (Monte Carlo Tree Search) module for DocETL optimization.
|
||||
|
||||
This module provides Monte Carlo Tree Search optimization for DocETL pipelines
|
||||
using Pareto frontier analysis and multi-objective optimization.
|
||||
"""
|
||||
|
||||
# Default list of available models for MOAR optimization
|
||||
# Defined here before imports to avoid circular import issues
|
||||
AVAILABLE_MODELS = [
|
||||
# "gpt-5",
|
||||
# "gpt-5-mini",
|
||||
# "gpt-5-nano",
|
||||
"gpt-4.1",
|
||||
"gpt-4.1-mini",
|
||||
"gpt-4.1-nano",
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
# "gemini-2.5-pro",
|
||||
# "gemini-2.5-flash",
|
||||
# "gemini-2.5-flash-lite"
|
||||
]
|
||||
|
||||
from .MOARSearch import MOARSearch # noqa: E402
|
||||
from .Node import Node # noqa: E402
|
||||
from .ParetoFrontier import ParetoFrontier # noqa: E402
|
||||
|
||||
__all__ = [
|
||||
"MOARSearch",
|
||||
"Node",
|
||||
"ParetoFrontier",
|
||||
"AVAILABLE_MODELS"
|
||||
]
|
||||
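# Typical usage sketch (illustrative; the exact MOARSearch constructor arguments
# are defined in MOARSearch itself and shown in the CLI helper below):
#   from docetl.moar import MOARSearch, ParetoFrontier, AVAILABLE_MODELS
#   search = MOARSearch(root_yaml_path="pipeline.yaml", ...)  # remaining arguments are hypothetical here
#   best_nodes = search.search()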
|
|
@@ -1,305 +0,0 @@
|
|||
"""
|
||||
Helper functions for running MOAR optimizer from CLI.
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import yaml
|
||||
|
||||
from docetl.console import DOCETL_CONSOLE
|
||||
from docetl.moar import MOARSearch
|
||||
from docetl.reasoning_optimizer.directives import ALL_DIRECTIVES
|
||||
from docetl.utils_dataset import get_dataset_stats
|
||||
from docetl.utils_evaluation import load_custom_evaluate_func
|
||||
|
||||
|
||||
def infer_dataset_info(yaml_path: str, config: dict) -> tuple[str, str]:
|
||||
"""
|
||||
Infer dataset path and name from YAML config.
|
||||
|
||||
Args:
|
||||
yaml_path: Path to YAML file
|
||||
config: Full YAML config dictionary
|
||||
|
||||
Returns:
|
||||
tuple: (dataset_path, dataset_name)
|
||||
|
||||
Raises:
|
||||
ValueError: If datasets section is missing or empty
|
||||
"""
|
||||
datasets = config.get("datasets", {})
|
||||
if not datasets:
|
||||
raise ValueError("YAML config must contain a 'datasets' section")
|
||||
|
||||
# Get the first dataset (assuming single dataset per config)
|
||||
dataset_name, dataset_config = next(iter(datasets.items()))
|
||||
dataset_path = dataset_config.get("path")
|
||||
|
||||
if not dataset_path:
|
||||
raise ValueError(f"Dataset '{dataset_name}' in config must have a 'path' field")
|
||||
|
||||
# Resolve relative paths - try as-is first, then relative to YAML file location
|
||||
if Path(dataset_path).is_absolute():
|
||||
# Already absolute, use as-is
|
||||
pass
|
||||
elif Path(dataset_path).exists():
|
||||
# Path exists as-is (relative to current working directory)
|
||||
dataset_path = str(Path(dataset_path).resolve())
|
||||
else:
|
||||
# Try resolving relative to YAML file location
|
||||
yaml_dir = Path(yaml_path).parent
|
||||
resolved_path = yaml_dir / dataset_path
|
||||
if resolved_path.exists():
|
||||
dataset_path = str(resolved_path.resolve())
|
||||
else:
|
||||
# Use the resolved path anyway (might be created later)
|
||||
dataset_path = str(resolved_path.resolve())
|
||||
|
||||
return dataset_path, dataset_name
|
||||
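# Example (illustrative): given a config whose datasets section is
#   datasets:
#     contracts:
#       path: data/contracts.json
# this returns ("<resolved>/data/contracts.json", "contracts"), trying the path
# as-is (relative to the CWD) first and falling back to the YAML file's directory.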
|
||||
|
||||
def load_evaluation_function(config: dict, dataset_file_path: str):
|
||||
"""
|
||||
Load evaluation function from optimizer_config.
|
||||
|
||||
Args:
|
||||
config: optimizer_config dictionary from YAML
|
||||
dataset_file_path: Path to the dataset file
|
||||
|
||||
Returns:
|
||||
callable: Evaluation function
|
||||
|
||||
Raises:
|
||||
ValueError: If required parameters are missing
|
||||
"""
|
||||
evaluation_file = config.get("evaluation_file")
|
||||
if not evaluation_file:
|
||||
raise ValueError(
|
||||
"optimizer_config must contain 'evaluation_file' (path to Python file with @docetl.register_eval decorated function)"
|
||||
)
|
||||
|
||||
DOCETL_CONSOLE.log(
|
||||
f"[bold blue]📊 Loading evaluation function from: {evaluation_file}[/bold blue]"
|
||||
)
|
||||
evaluate_func = load_custom_evaluate_func(evaluation_file, dataset_file_path)
|
||||
DOCETL_CONSOLE.log("[green]✅ Evaluation function loaded[/green]")
|
||||
return evaluate_func
|
||||
|
||||
|
||||
def run_moar_optimization(
|
||||
yaml_path: str,
|
||||
optimizer_config: dict,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Run MOAR optimization from CLI.
|
||||
|
||||
Args:
|
||||
yaml_path: Path to the YAML pipeline file
|
||||
optimizer_config: optimizer_config dictionary from YAML (must contain save_dir)
|
||||
|
||||
Returns:
|
||||
dict: Experiment summary
|
||||
"""
|
||||
# Load full config to infer dataset info
|
||||
with open(yaml_path, "r") as f:
|
||||
full_config = yaml.safe_load(f)
|
||||
|
||||
# Use dataset_path from optimizer_config if provided, otherwise infer from datasets section
|
||||
if optimizer_config.get("dataset_path"):
|
||||
dataset_path = optimizer_config.get("dataset_path")
|
||||
# Resolve relative paths
|
||||
if Path(dataset_path).is_absolute():
|
||||
dataset_path = str(Path(dataset_path).resolve())
|
||||
elif Path(dataset_path).exists():
|
||||
dataset_path = str(Path(dataset_path).resolve())
|
||||
else:
|
||||
yaml_dir = Path(yaml_path).parent
|
||||
dataset_path = str((yaml_dir / dataset_path).resolve())
|
||||
# Infer dataset name from datasets section
|
||||
_, dataset_name = infer_dataset_info(yaml_path, full_config)
|
||||
else:
|
||||
# Infer both dataset path and name from config
|
||||
dataset_path, dataset_name = infer_dataset_info(yaml_path, full_config)
|
||||
|
||||
# Extract MOAR parameters from optimizer_config (all required, no defaults)
|
||||
save_dir = optimizer_config.get("save_dir")
|
||||
if not save_dir:
|
||||
raise ValueError("optimizer_config must contain 'save_dir' for MOAR optimizer")
|
||||
|
||||
available_models = optimizer_config.get("available_models")
|
||||
if not available_models:
|
||||
raise ValueError(
|
||||
"optimizer_config must contain 'available_models' (list of model names) for MOAR optimizer"
|
||||
)
|
||||
|
||||
evaluation_file = optimizer_config.get("evaluation_file")
|
||||
if not evaluation_file:
|
||||
raise ValueError(
|
||||
"optimizer_config must contain 'evaluation_file' (path to Python file with @docetl.register_eval decorated function) for MOAR optimizer"
|
||||
)
|
||||
|
||||
metric_key = optimizer_config.get("metric_key")
|
||||
if not metric_key:
|
||||
raise ValueError(
|
||||
"optimizer_config must contain 'metric_key' (key to extract from evaluation results) for MOAR optimizer"
|
||||
)
|
||||
|
||||
max_iterations = optimizer_config.get("max_iterations")
|
||||
if max_iterations is None:
|
||||
raise ValueError(
|
||||
"optimizer_config must contain 'max_iterations' (number of MOARSearch iterations) for MOAR optimizer"
|
||||
)
|
||||
|
||||
# Use rewrite_agent_model for consistency with rest of codebase, fallback to model for backwards compatibility
|
||||
model = optimizer_config.get("rewrite_agent_model") or optimizer_config.get("model")
|
||||
if not model:
|
||||
raise ValueError(
|
||||
"optimizer_config must contain 'rewrite_agent_model' (LLM model name for directive instantiation) for MOAR optimizer"
|
||||
)
|
||||
|
||||
# Optional parameters
|
||||
exploration_weight = optimizer_config.get("exploration_weight", 1.414)
|
||||
build_first_layer = optimizer_config.get("build_first_layer", False)
|
||||
|
||||
# Resolve save directory (handle relative paths)
|
||||
save_dir = Path(save_dir)
|
||||
if not save_dir.is_absolute():
|
||||
# Resolve relative to current working directory
|
||||
save_dir = Path.cwd() / save_dir
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
DOCETL_CONSOLE.log("[bold blue]🌳 Running MOARSearch[/bold blue]")
|
||||
DOCETL_CONSOLE.log(f"[dim]Input Pipeline:[/dim] {yaml_path}")
|
||||
DOCETL_CONSOLE.log(f"[dim]Save Directory:[/dim] {save_dir}")
|
||||
DOCETL_CONSOLE.log(f"[dim]Max Iterations:[/dim] {max_iterations}")
|
||||
DOCETL_CONSOLE.log(f"[dim]Exploration Weight (c):[/dim] {exploration_weight}")
|
||||
DOCETL_CONSOLE.log(f"[dim]Model:[/dim] {model}")
|
||||
DOCETL_CONSOLE.log(f"[dim]Dataset:[/dim] {dataset_name}")
|
||||
DOCETL_CONSOLE.log()
|
||||
|
||||
# Load sample input data
|
||||
DOCETL_CONSOLE.log("[bold blue]🚀 Initializing MOARSearch...[/bold blue]")
|
||||
with open(dataset_path, "r") as f:
|
||||
dataset_data = json.load(f)
|
||||
|
||||
# Take only the first 5 documents for sample input
|
||||
if isinstance(dataset_data, list):
|
||||
sample_input_data = dataset_data[:5]
|
||||
else:
|
||||
sample_input_data = dataset_data
|
||||
|
||||
# Use all registered rewrite directives
|
||||
available_actions = set(ALL_DIRECTIVES)
|
||||
|
||||
# Get dataset statistics
|
||||
dataset_stats = get_dataset_stats(yaml_path, dataset_name)
|
||||
|
||||
# Load evaluation function (pass dataset_path so it can be provided to eval function)
|
||||
evaluate_func = load_evaluation_function(optimizer_config, dataset_path)
|
||||
|
||||
# Initialize MOARSearch
|
||||
moar = MOARSearch(
|
||||
root_yaml_path=yaml_path,
|
||||
available_actions=available_actions,
|
||||
sample_input=sample_input_data,
|
||||
dataset_stats=dataset_stats,
|
||||
dataset_name=dataset_name,
|
||||
available_models=available_models,
|
||||
evaluate_func=evaluate_func,
|
||||
exploration_constant=exploration_weight,
|
||||
max_iterations=max_iterations,
|
||||
model=model,
|
||||
output_dir=str(save_dir),
|
||||
build_first_layer=build_first_layer,
|
||||
custom_metric_key=metric_key,
|
||||
sample_dataset_path=dataset_path, # Use the dataset_path (which may be from optimizer_config)
|
||||
)
|
||||
|
||||
DOCETL_CONSOLE.log(
|
||||
f"[green]✅ MOARSearch initialized with root node: {yaml_path}[/green]"
|
||||
)
|
||||
|
||||
# Run MOARSearch optimization
|
||||
DOCETL_CONSOLE.log(
|
||||
f"[bold blue]\n🔍 Running MOARSearch optimization for {max_iterations} iterations...[/bold blue]"
|
||||
)
|
||||
start_time = datetime.now()
|
||||
best_nodes = moar.search()
|
||||
end_time = datetime.now()
|
||||
duration = (end_time - start_time).total_seconds()
|
||||
DOCETL_CONSOLE.log(
|
||||
f"[green]✅ MOARSearch optimization completed in {duration:.2f} seconds[/green]"
|
||||
)
|
||||
|
||||
# Run evaluation
|
||||
DOCETL_CONSOLE.log("[bold blue]📊 Running evaluation...[/bold blue]")
|
||||
|
||||
# Prepare nodes for evaluation
|
||||
nodes_for_evaluation = []
|
||||
for n in moar.pareto_frontier.plans:
|
||||
n.moar_accuracy = moar.pareto_frontier.plans_accuracy.get(n)
|
||||
n.on_frontier = n in moar.pareto_frontier.frontier_plans
|
||||
nodes_for_evaluation.append(n)
|
||||
|
||||
from docetl.utils_evaluation import run_evaluation
|
||||
|
||||
eval_results = run_evaluation(
|
||||
nodes_or_files=nodes_for_evaluation,
|
||||
evaluate_func=evaluate_func,
|
||||
metric_key=metric_key,
|
||||
output_path=save_dir,
|
||||
dataset_name=dataset_name,
|
||||
)
|
||||
|
||||
# Save experiment summary
|
||||
results = {
|
||||
"optimizer": "moar",
|
||||
"input_pipeline": yaml_path,
|
||||
"model": model,
|
||||
"max_iterations": max_iterations,
|
||||
"exploration_weight": exploration_weight,
|
||||
"save_dir": str(save_dir),
|
||||
"dataset": dataset_name,
|
||||
"start_time": start_time.isoformat(),
|
||||
"end_time": end_time.isoformat(),
|
||||
"duration_seconds": duration,
|
||||
"num_best_nodes": len(best_nodes) if best_nodes else 0,
|
||||
"total_nodes_explored": (
|
||||
len(moar.all_nodes) if hasattr(moar, "all_nodes") else 0
|
||||
),
|
||||
"total_search_cost": (
|
||||
moar.total_search_cost if hasattr(moar, "total_search_cost") else 0
|
||||
),
|
||||
}
|
||||
|
||||
if eval_results:
|
||||
results["evaluation_file"] = str(save_dir / "evaluation_metrics.json")
|
||||
|
||||
# Save Pareto frontier if available
|
||||
if hasattr(moar, "pareto_frontier") and moar.pareto_frontier.frontier_plans:
|
||||
pareto_file = save_dir / "pareto_frontier.json"
|
||||
pareto_data = []
|
||||
for node in moar.pareto_frontier.frontier_plans:
|
||||
pareto_data.append(
|
||||
{
|
||||
"node_id": node.get_id(),
|
||||
"yaml_path": node.yaml_file_path,
|
||||
"cost": node.cost,
|
||||
"accuracy": moar.pareto_frontier.plans_accuracy.get(node),
|
||||
}
|
||||
)
|
||||
with open(pareto_file, "w") as f:
|
||||
json.dump(pareto_data, f, indent=2)
|
||||
results["pareto_frontier_file"] = str(pareto_file)
|
||||
|
||||
# Save experiment summary
|
||||
summary_file = save_dir / "experiment_summary.json"
|
||||
with open(summary_file, "w") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
|
||||
DOCETL_CONSOLE.log(f"[green]✅ Experiment summary saved to: {summary_file}[/green]")
|
||||
|
||||
return results
|
||||
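# A minimal optimizer_config sketch covering the required keys validated above
# (values are illustrative, not defaults from the codebase):
#   optimizer_config:
#     save_dir: ./moar_runs/exp1
#     available_models: ["gpt-4o", "gpt-4o-mini"]
#     evaluation_file: eval/my_eval.py      # must contain an @docetl.register_eval decorated function
#     metric_key: combined_score
#     max_iterations: 20
#     rewrite_agent_model: gpt-4o
#     exploration_weight: 1.414             # optional
#     build_first_layer: false              # optional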
|
|
@@ -1,485 +0,0 @@
|
|||
"""
|
||||
Utility functions for MCTS implementation.
|
||||
|
||||
This module contains helper functions and utilities used by the MCTS algorithm
|
||||
but not core to the MCTS structure itself.
|
||||
"""
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from docetl.reasoning_optimizer.directives import (
|
||||
DIRECTIVE_GROUPS,
|
||||
Directive,
|
||||
get_all_cost_directive_strings,
|
||||
get_all_directive_strings,
|
||||
)
|
||||
from docetl.reasoning_optimizer.load_data import load_input_doc
|
||||
from docetl.reasoning_optimizer.op_descriptions import * # noqa: F403, F405
|
||||
|
||||
# Maximum number of tokens we will allow in the prompt we send to the model.
|
||||
# The Azure GPT-5 family allows 272,000 tokens.
|
||||
MAX_CONTEXT_TOKENS = 270_000
|
||||
|
||||
|
||||
class ExpandResponseFormat(BaseModel):
|
||||
directive: str
|
||||
operators: List[str]
|
||||
|
||||
|
||||
def count_tokens(messages):
|
||||
"""Count estimated tokens in messages list."""
|
||||
# messages should be a list of dicts, each with a "content" key
|
||||
total_chars = sum(
|
||||
len(m.get("content", "")) for m in messages if isinstance(m, dict)
|
||||
)
|
||||
return max(1, total_chars // 4)
|
||||
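# Rough illustration: two messages whose "content" strings total 400 characters
# are estimated at 400 // 4 = 100 tokens (a ~4 characters-per-token heuristic).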
|
||||
|
||||
def trim_history(history: list, keep_system_first: bool = True) -> list:
|
||||
"""Trim the conversation history in-place so its estimated token count
|
||||
(via ``count_tokens``) does not exceed ``MAX_CONTEXT_TOKENS``.
|
||||
|
||||
We always keep the very first system message and the first user message so the
|
||||
assistant retains the global instructions and the initial query context. After
|
||||
that we drop the oldest messages until the budget is satisfied. Returns the
|
||||
trimmed history list.
|
||||
"""
|
||||
|
||||
# Determine starting index to preserve the initial system message and first user message
|
||||
start_idx = 0
|
||||
if keep_system_first and history:
|
||||
if history[0].get("role") == "system":
|
||||
start_idx = 1
|
||||
# Find the first user message after the system message
|
||||
for i in range(1, len(history)):
|
||||
if history[i].get("role") == "user":
|
||||
start_idx = i + 1
|
||||
break
|
||||
elif history[0].get("role") == "user":
|
||||
# If first message is user, keep it and find the next user message
|
||||
start_idx = 1
|
||||
for i in range(1, len(history)):
|
||||
if history[i].get("role") == "user":
|
||||
start_idx = i + 1
|
||||
break
|
||||
|
||||
# Drop oldest messages (just after the preserved block) until within limit
|
||||
while len(history) > start_idx + 1 and count_tokens(history) > MAX_CONTEXT_TOKENS:
|
||||
history.pop(start_idx)
|
||||
|
||||
return history
|
||||
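# Illustrative use: trim_history(history) mutates and returns the same list;
# with history = [{"role": "system", ...}, {"role": "user", ...}, <many turns>],
# the system message and first user message survive and later turns are dropped
# oldest-first until count_tokens(history) <= MAX_CONTEXT_TOKENS.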
|
||||
|
||||
def get_directive_group(directive_name: str) -> str:
|
||||
"""
|
||||
Get the group name for a directive.
|
||||
|
||||
Args:
|
||||
directive_name: Name of the directive
|
||||
|
||||
Returns:
|
||||
Group name if found, None otherwise
|
||||
"""
|
||||
for group_name, directives in DIRECTIVE_GROUPS.items():
|
||||
for directive in directives:
|
||||
if directive.name == directive_name:
|
||||
return group_name
|
||||
return None
|
||||
|
||||
|
||||
def get_excluded_directives_for_operation(node, op_name: str) -> set:
|
||||
"""Get compression directives to exclude for code_map and extract operations."""
|
||||
op_type = node.op_dict[op_name].get("type")
|
||||
compression_exclusions = set()
|
||||
if op_type in ["code_map", "extract"]:
|
||||
compression_exclusions = set(DIRECTIVE_GROUPS.get("compression", []))
|
||||
return compression_exclusions
|
||||
|
||||
|
||||
def is_action_applicable(node, action: Directive) -> bool:
|
||||
"""Check if an action is applicable to a node."""
|
||||
return True
|
||||
|
||||
|
||||
def update_pipeline(orig_config, new_ops_list, target_ops):
|
||||
"""
|
||||
Update the pipeline configuration with new operations.
|
||||
|
||||
Args:
|
||||
orig_config (dict): The original pipeline configuration
|
||||
new_ops_list (list): The entire pipeline operations list (not a subset)
|
||||
target_ops (list): List of target operation names to replace
|
||||
|
||||
Returns:
|
||||
dict: Updated pipeline configuration
|
||||
"""
|
||||
if new_ops_list is not None:
|
||||
op_names = [op.get("name") for op in new_ops_list if "name" in op]
|
||||
|
||||
# Update the pipeline steps to use the new operation names
|
||||
if "pipeline" in orig_config and "steps" in orig_config["pipeline"]:
|
||||
for step in orig_config["pipeline"]["steps"]:
|
||||
if "operations" in step:
|
||||
new_ops = []
|
||||
for op in step["operations"]:
|
||||
if op == target_ops[0]:
|
||||
new_ops.extend(op_names)
|
||||
step["operations"] = new_ops
|
||||
|
||||
return orig_config
|
||||
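# Illustrative example (names are made up): with target_ops = ["extract_clause"]
# and new_ops_list = [{"name": "split_doc", ...}, {"name": "map_chunk", ...}],
# a step whose operations were ["extract_clause"] becomes
# ["split_doc", "map_chunk"] in the returned config.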
|
||||
|
||||
def fix_models(parsed_yaml):
|
||||
"""No-op: Model names should be specified correctly in the YAML."""
|
||||
pass
|
||||
|
||||
|
||||
def is_fully_explored(node, max_children_multiplier: float = 1.0) -> bool:
|
||||
"""Check if a node has been fully explored based on visit count."""
|
||||
if node.parent is None:
|
||||
return True
|
||||
allowed_children = max(
|
||||
2, 1 + math.floor(math.sqrt(float(node.visits)) * max_children_multiplier)
|
||||
)
|
||||
|
||||
# Not only check the number of children, but also ensure all children have been simulated
|
||||
# A child is considered simulated if it has been visited at least once (visits > 0)
|
||||
# This ensures that children created but not yet simulated won't cause the parent to be considered fully explored
|
||||
if len(node.children) < allowed_children:
|
||||
return False
|
||||
|
||||
# Check if all children have been simulated (visited at least once)
|
||||
# This prevents selecting children that were just created but not yet simulated
|
||||
for child in node.children:
|
||||
if child.visits == 0:
|
||||
# This child hasn't been selected/simulated yet, so parent is not fully explored
|
||||
return False
|
||||
|
||||
return True
|
||||
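# Illustrative: with visits = 9 and max_children_multiplier = 1.0,
# allowed_children = max(2, 1 + floor(sqrt(9))) = 4, so the node only counts as
# fully explored once it has at least 4 children and every child has visits > 0.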
|
||||
|
||||
def should_continue_search(
|
||||
iteration: int,
|
||||
max_iterations: int,
|
||||
start_time: float,
|
||||
max_time: Optional[float] = None,
|
||||
) -> bool:
|
||||
"""Determine if search should continue based on iteration count and time."""
|
||||
if iteration >= max_iterations:
|
||||
return False
|
||||
|
||||
if max_time is not None:
|
||||
elapsed_time = time.time() - start_time
|
||||
if elapsed_time >= max_time:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def calculate_ucb1(
|
||||
node, parent_visits: int, exploration_constant: float = math.sqrt(2)
|
||||
) -> float:
|
||||
"""Calculate UCB1 value for node selection."""
|
||||
if node.visits == 0:
|
||||
return float("inf")
|
||||
|
||||
exploitation = node.value / node.visits
|
||||
exploration = exploration_constant * math.sqrt(
|
||||
math.log(parent_visits) / node.visits
|
||||
)
|
||||
return exploitation + exploration
|
||||
|
||||
|
||||
def print_tree_visits_and_values(node=None, depth=0, file_handle=None, console=None):
|
||||
"""Print tree structure with visit counts and values."""
|
||||
if node is None:
|
||||
return
|
||||
|
||||
indent = " " * depth
|
||||
node_info = f"{indent}Node ID: {node.get_id()}, Visits: {node.visits}, Value: {node.value:.4f}"
|
||||
|
||||
if file_handle:
|
||||
file_handle.write(node_info + "\n")
|
||||
else:
|
||||
if console is None:
|
||||
from docetl.console import DOCETL_CONSOLE
|
||||
|
||||
console = DOCETL_CONSOLE
|
||||
console.log(node_info)
|
||||
|
||||
for child in node.children:
|
||||
print_tree_visits_and_values(child, depth + 1, file_handle, console)
|
||||
|
||||
|
||||
def log_tree_to_file(root_node, iteration_num, output_dir="./outputs", console=None):
|
||||
"""Log the tree structure to a file."""
|
||||
log_file_path = os.path.join(output_dir, f"moar_tree_iteration_{iteration_num}.txt")
|
||||
|
||||
with open(log_file_path, "w") as f:
|
||||
f.write(f"MOAR Tree Structure - Iteration {iteration_num}\n")
|
||||
f.write("=" * 50 + "\n")
|
||||
print_tree_visits_and_values(root_node, file_handle=f, console=console)
|
||||
|
||||
|
||||
def create_expansion_prompt_acc(
|
||||
node,
|
||||
action_options,
|
||||
input_query,
|
||||
available_actions,
|
||||
action_cost_changes,
|
||||
action_accuracy_changes,
|
||||
action_counts,
|
||||
sample_input,
|
||||
root_node,
|
||||
yaml_file_path,
|
||||
dataset=None,
|
||||
node_accuracies=None,
|
||||
model_stats=None,
|
||||
available_models=None,
|
||||
) -> tuple[str, str]:
|
||||
"""Create expansion prompt for accuracy optimization."""
|
||||
|
||||
# Use provided available_models list, or extract from model_stats as fallback
|
||||
if available_models:
|
||||
available_models_list = available_models
|
||||
elif model_stats:
|
||||
available_models_list = list(model_stats.keys())
|
||||
else:
|
||||
available_models_list = []
|
||||
|
||||
    available_actions_str = ""
|
||||
for item in action_options:
|
||||
op_name = item[0]
|
||||
action_name = item[1]
|
||||
action_str = f"Operator: {op_name}, Rewrite directive: {action_name}\n"
|
||||
        available_actions_str += action_str
|
||||
|
||||
action_stats = []
|
||||
for action in available_actions:
|
||||
cost_change = action_cost_changes.get(action, 0)
|
||||
accuracy_change = action_accuracy_changes.get(action, 0)
|
||||
count = action_counts.get(action, 0)
|
||||
|
||||
if count > 0:
|
||||
avg_cost_change = cost_change / count
|
||||
avg_accuracy_change = accuracy_change / count
|
||||
action_stats.append(
|
||||
f"- {action.name}: {count} uses, avg change in cost: {avg_cost_change:+.2f}, avg change in accuracy: {avg_accuracy_change:+.4f}"
|
||||
)
|
||||
else:
|
||||
action_stats.append(
|
||||
f"- {action.name}: {count} uses, avg change in cost: Unknown (never tried), avg change in accuracy: Unknown (never tried)"
|
||||
)
|
||||
|
||||
action_stats_str = "\n".join(action_stats)
|
||||
|
||||
input_schema = load_input_doc(yaml_file_path)
|
||||
|
||||
user_message = f"""
|
||||
I have a set of operations used to process long documents, along with a list of possible rewrite directives aimed at improving the quality of the query result.
|
||||
Given a query pipeline made up of these operations, recommend one specific rewrite directive (specify by its name) that would improve accuracy and specify which operators (specify by their names) in the pipeline the directive should be applied to.
|
||||
Make sure that your chosen directive is in the provided list of rewrite directives.
|
||||
|
||||
Pipeline:
|
||||
Pipelines in DocETL are the core structures that define the flow of data processing. A pipeline consists of five main components: \n
|
||||
- Default Model: The language model to use for the pipeline. Limit your choice of model to {available_models_list} \n
|
||||
- System Prompts: A description of your dataset and the "persona" you'd like the LLM to adopt when analyzing your data. \n
|
||||
- Datasets: The input data sources for your pipeline. \n
|
||||
- Operators: The processing steps that transform your data. \n
|
||||
- Pipeline Specification: The sequence of steps and the output configuration. \n
|
||||
|
||||
Operators:
|
||||
Operators form the building blocks of data processing pipelines. Below is the list of operators:
|
||||
{op_map.to_string()}\n
|
||||
{op_extract.to_string()}\n
|
||||
{op_parallel_map.to_string()}\n
|
||||
{op_filter.to_string()}\n
|
||||
{op_reduce.to_string()}\n
|
||||
{op_split.to_string()}\n
|
||||
{op_gather.to_string()}\n
|
||||
{op_unnest.to_string()}\n
|
||||
{op_sample.to_string()}\n
|
||||
{op_resolve.to_string()}\n
|
||||
|
||||
Rewrite directives:
|
||||
{get_all_directive_strings()}\n
|
||||
|
||||
Valid operator and rewrite-directive combinations (choose exactly one of these):\n
|
||||
{available_actions_str}
|
||||
|
||||
Action Performance History:
|
||||
Based on previous executions across DIFFERENT query pipelines, here's how each action has performed:\n
|
||||
{action_stats_str}
|
||||
|
||||
Note: These statistics come from applying actions to various other query pipelines, not the current one. Use this as general guidance about action effectiveness, but consider that performance may vary significantly for your specific pipeline structure and data.
|
||||
|
||||
Model Performance Reference:
|
||||
If you are considering a model change directive, here are model statistics on this specific dataset: \n {str(model_stats if model_stats else {})}
|
||||
These show the accuracy (acc) and cost for each model. Only reference this when evaluating model change options.
|
||||
|
||||
Selection Strategy:
|
||||
Consider the current query pipeline, which directive can best improve the accuracy.
|
||||
Prioritize exploration of untested actions while balancing with exploitation of proven performers:
|
||||
- Actions with 0 uses have unknown potential, so you should explore them if applicable.
|
||||
- fusion directives are helpful
|
||||
- Consider both immediate improvement and learning about the action space
|
||||
|
||||
{node.get_memo_for_llm(root_node, node_accuracies)}
|
||||
|
||||
Make sure you read every rewrite directive carefully.
|
||||
Make sure you only choose from the valid choices above and avoid already used combinations or approaches too similar to what has already been tried in the current optimization path.
|
||||
|
||||
Input document schema with token statistics: {input_schema} \n
|
||||
Input data sample: {json.dumps(sample_input, indent=2)[:5000]} \n
|
||||
The original query in YAML format using our operations: {input_query} \n
|
||||
The original query result: {json.dumps(node.sample_result, indent=2)[:3000]} \n
|
||||
"""
|
||||
|
||||
# Create a condensed version for message history (without full operator/directive descriptions)
|
||||
condensed_user_message = f"""
|
||||
Recommend one specific rewrite directive for accuracy optimization.
|
||||
|
||||
Valid choices:
|
||||
{availabel_actions_str}
|
||||
|
||||
Action Performance History:
|
||||
{action_stats_str}
|
||||
|
||||
Current pipeline: {input_query}
|
||||
"""
|
||||
|
||||
return user_message, condensed_user_message
|
||||
|
||||
|
||||
def create_expansion_prompt_cost(
    node,
    action_options,
    input_query,
    available_actions,
    action_cost_changes,
    action_accuracy_changes,
    action_counts,
    sample_input,
    root_node,
    yaml_file_path,
    dataset=None,
    node_accuracies=None,
    model_stats=None,
    available_models=None,
) -> tuple[str, str]:
    """Create expansion prompt for cost optimization."""

    # Use provided available_models list, or extract from model_stats as fallback
    if available_models:
        available_models_list = available_models
    elif model_stats:
        available_models_list = list(model_stats.keys())
    else:
        available_models_list = []

    available_actions_str = ""
    for item in action_options:
        op_name = item[0]
        action_name = item[1]
        action_str = f"Operator: {op_name}, Rewrite directive: {action_name}\n"
        available_actions_str += action_str

    action_stats = []
    for action in available_actions:
        cost_change = action_cost_changes.get(action, 0)
        accuracy_change = action_accuracy_changes.get(action, 0)
        count = action_counts.get(action, 0)

        if count > 0:
            avg_cost_change = cost_change / count
            avg_accuracy_change = accuracy_change / count
            action_stats.append(
                f"- {action.name}: {count} uses, avg change in cost: {avg_cost_change:+.2f}, avg change in accuracy: {avg_accuracy_change:+.4f}"
            )
        else:
            action_stats.append(
                f"- {action.name}: {count} uses, avg change in cost: Unknown (never tried), avg change in accuracy: Unknown (never tried)"
            )

    action_stats_str = "\n".join(action_stats)

    input_schema = load_input_doc(yaml_file_path)

    user_message = f"""
I have a set of operations used to process long documents, along with a list of possible rewrite directives designed to improve the cost effectiveness of the pipeline, while maintaining similar or better accuracy.
Given a query pipeline composed of these operations, recommend one specific rewrite directive (identified by its name from the provided list) that would improve cost effectiveness. Also, specify which operator(s) (by name) in the pipeline the directive should be applied to.
Make sure your recommended directive is selected from the provided list.

Pipeline:
Pipelines in DocETL are the core structures that define the flow of data processing. A pipeline consists of five main components: \n
- Default Model: The language model to use for the pipeline. Limit your choice of model to {available_models_list} \n
- System Prompts: A description of your dataset and the "persona" you'd like the LLM to adopt when analyzing your data. \n
- Datasets: The input data sources for your pipeline. \n
- Operators: The processing steps that transform your data. \n
- Pipeline Specification: The sequence of steps and the output configuration. \n

Operators:
Operators form the building blocks of data processing pipelines. Below is the list of operators:
{op_map.to_string()}\n
{op_extract.to_string()}\n
{op_parallel_map.to_string()}\n
{op_filter.to_string()}\n
{op_reduce.to_string()}\n
{op_split.to_string()}\n
{op_gather.to_string()}\n
{op_unnest.to_string()}\n
{op_sample.to_string()}\n
{op_resolve.to_string()}\n

Rewrite directives:
{get_all_cost_directive_strings()}\n

Valid operator and rewrite directive combinations. Only choose one of these:\n
{available_actions_str}

Action Performance History:
Based on previous executions across DIFFERENT query pipelines, here's how each action has performed:\n
{action_stats_str}

Note: These statistics come from applying actions to various other query pipelines, not the current one. Use this as general guidance about action effectiveness, but consider that performance may vary significantly for your specific pipeline structure and data.

Model Performance Reference:
If you are considering a model change directive, here are model statistics on this specific dataset: \n {str(model_stats if model_stats else {})}
These show the accuracy (acc) and cost for each model. Only reference this when evaluating model change options.

Selection Strategy:
Considering the current query pipeline, decide which directive can best improve cost effectiveness.
Prioritize exploration of untested actions while balancing with exploitation of proven performers:
- Actions with 0 uses have unknown potential, so you should explore them if applicable.
- Fusion directives are often helpful.
- Consider both immediate improvement and learning about the action space.

{node.get_memo_for_llm(root_node, node_accuracies)}

Make sure you only choose from the valid choices above and avoid already used combinations or approaches too similar to what has already been tried in the current optimization path.

Input document schema with token statistics: {input_schema} \n
Input data sample: {json.dumps(sample_input, indent=2)[:5000]} \n
The original query in YAML format using our operations: {input_query} \n
The original query result: {json.dumps(node.sample_result, indent=2)[:3000]} \n
"""

    condensed_user_message = f"""
Recommend one specific rewrite directive for cost optimization.

Valid choices:
{available_actions_str}

Action Performance History:
{action_stats_str}

Current pipeline: {input_query}
"""

    return user_message, condensed_user_message
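
For orientation, here is a minimal sketch of how these two prompt builders might be wired into an optimizer loop. The node, statistics variables, and the `litellm.completion` call are illustrative assumptions; only the keyword arguments and the returned (full, condensed) pair come from the function above.

```python
# Illustrative sketch only: argument values and the LLM client call are assumptions.
import litellm

full_prompt, condensed_prompt = create_expansion_prompt_cost(
    node=current_node,                # hypothetical search-tree node
    action_options=options,           # list of (operator_name, directive_name) pairs
    input_query=pipeline_yaml_str,
    available_actions=actions,
    action_cost_changes=cost_deltas,
    action_accuracy_changes=acc_deltas,
    action_counts=counts,
    sample_input=sample_docs,
    root_node=root,
    yaml_file_path="pipeline.yaml",
    model_stats=stats,
    available_models=["gpt-4o-mini", "gpt-4o"],
)

# Send the full prompt once; keep only the condensed variant in the running
# message history so the conversation does not grow with every expansion.
response = litellm.completion(
    model="gpt-4o",
    messages=[{"role": "user", "content": full_prompt}],
)
message_history.append({"role": "user", "content": condensed_prompt})
```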

@@ -4,8 +4,8 @@ from docetl.operations.code_operations import CodeFilterOperation, CodeMapOperat
from docetl.operations.equijoin import EquijoinOperation
from docetl.operations.filter import FilterOperation
from docetl.operations.gather import GatherOperation
from docetl.operations.map import MapOperation, ParallelMapOperation
from docetl.operations.link_resolve import LinkResolveOperation
from docetl.operations.map import MapOperation, ParallelMapOperation
from docetl.operations.reduce import ReduceOperation
from docetl.operations.resolve import ResolveOperation
from docetl.operations.rank import RankOperation

@@ -2,7 +2,6 @@
The BaseOperation class is an abstract base class for all operations in the docetl framework. It provides a common structure and interface for various data processing operations.
"""

import traceback
from abc import ABC, ABCMeta, abstractmethod
from typing import Any

@@ -88,7 +87,6 @@ class BaseOperation(ABC, metaclass=BaseOperationMeta):
        type: str
        skip_on_error: bool = False
        gleaning: GleaningConfig | None = None
        retriever: str | None = None

    @abstractmethod
    def execute(self, input_data: list[dict]) -> tuple[list[dict], float]:

@@ -111,26 +109,3 @@ class BaseOperation(ABC, metaclass=BaseOperationMeta):
        """Perform syntax checks on the operation configuration."""
        # Validate the configuration using Pydantic
        self.schema.model_validate(self.config, context=context)

    def _maybe_build_retrieval_context(self, context: dict[str, Any]) -> str:
        """Build retrieval context string if a retriever is configured."""
        retriever_name = self.config.get("retriever")
        if not retriever_name:
            return ""
        retrievers = getattr(self.runner, "retrievers", {})
        if retriever_name not in retrievers:
            raise ValueError(
                f"Retriever '{retriever_name}' not found in configuration."
            )
        retriever = retrievers[retriever_name]
        try:
            result = retriever.retrieve(context)
            return result.rendered_context or ""
        except Exception as e:
            # Soft-fail to avoid blocking the op
            self.console.log(
                f"[yellow]Warning: retrieval failed for '{retriever_name}': {e}[/yellow]"
            )
            # Print traceback to help debug
            self.console.log(traceback.format_exc())
            return "No extra context available."
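
The new `retriever` hook above only requires an object that exposes `retrieve(context)` and returns something with a `rendered_context` string; anything else about its shape is up to the integration. A minimal sketch of such an object, with hypothetical class and field names:

```python
# Sketch of the interface _maybe_build_retrieval_context relies on.
# Only `retrieve` and `rendered_context` are taken from the code above;
# everything else here is a hypothetical example.
from dataclasses import dataclass

@dataclass
class RetrievalResult:
    rendered_context: str

class KeywordRetriever:
    def __init__(self, index: dict[str, str]):
        self.index = index  # keyword -> snippet of extra context

    def retrieve(self, context: dict) -> RetrievalResult:
        # `context` is whatever the operation passes in, e.g. {"input": item}
        doc = str(context.get("input", ""))
        hits = [snippet for kw, snippet in self.index.items() if kw in doc]
        return RetrievalResult(rendered_context="\n".join(hits))

# The runner would expose it under the name referenced by the op config, e.g.:
#   runner.retrievers = {"my_index": KeywordRetriever(index)}
#   op config: {"name": "extract_terms", "type": "map", "retriever": "my_index", ...}
```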

@@ -7,7 +7,6 @@ from jinja2 import Template
from .base import BaseOperation
from .clustering_utils import get_embeddings_for_clustering
from .utils import RichLoopBar, strict_render
from docetl.utils import has_jinja_syntax, prompt_user_for_non_jinja_confirmation


class ClusterOperation(BaseOperation):

@@ -20,19 +19,6 @@ class ClusterOperation(BaseOperation):
        self.max_batch_size: int = self.config.get(
            "max_batch_size", kwargs.get("max_batch_size", float("inf"))
        )
        # Check for non-Jinja prompts and prompt user for confirmation
        if "summary_prompt" in self.config and not has_jinja_syntax(
            self.config["summary_prompt"]
        ):
            if not prompt_user_for_non_jinja_confirmation(
                self.config["summary_prompt"], self.config["name"], "summary_prompt"
            ):
                raise ValueError(
                    f"Operation '{self.config['name']}' cancelled by user. Please add Jinja2 template syntax to your summary_prompt."
                )
            # Mark that we need to append document statement (cluster uses inputs)
            self.config["_append_document_to_prompt"] = True
            self.config["_is_reduce_operation"] = True

    def syntax_check(self) -> None:
        """

@@ -62,16 +48,11 @@ class ClusterOperation(BaseOperation):
        if not isinstance(self.config["summary_prompt"], str):
            raise TypeError("'prompt' must be a string")

        # Check if the prompt has Jinja syntax
        if not has_jinja_syntax(self.config["summary_prompt"]):
            # This will be handled during initialization with user confirmation
            pass
        else:
            # Check if the prompt is a valid Jinja2 template
            try:
                Template(self.config["summary_prompt"])
            except Exception as e:
                raise ValueError(f"Invalid Jinja2 template in 'prompt': {str(e)}")
        # Check if the prompt is a valid Jinja2 template
        try:
            Template(self.config["summary_prompt"])
        except Exception as e:
            raise ValueError(f"Invalid Jinja2 template in 'prompt': {str(e)}")

        # Check optional parameters
        if "max_batch_size" in self.config:

@@ -4,8 +4,6 @@ This module contains utilities for clustering based on different methods.
We use these in map and reduce operations.
"""

import json

from docetl.operations.utils import APIWrapper
from docetl.utils import completion_cost

@@ -28,10 +26,10 @@ def get_embeddings_for_clustering(
    for i in range(0, len(items), batch_size):
        batch = items[i : i + batch_size]
        texts = [
            " ".join(str(item[key]) for key in embedding_keys if key in item)[:1000]
            " ".join(str(item[key]) for key in embedding_keys if key in item)[:10000]
            for item in batch
        ]
        response = api_wrapper.gen_embedding(embedding_model, json.dumps(texts))
        response = api_wrapper.gen_embedding(embedding_model, texts)
        embeddings.extend([data["embedding"] for data in response["data"]])
        cost += completion_cost(response)
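
The hunk above makes two changes: the per-item truncation window grows from 1,000 to 10,000 characters, and the batch is handed to `gen_embedding` as a list of texts instead of one JSON-encoded string. A small sketch of the difference, with illustrative values:

```python
import json

texts = ["first document ...", "second document ..."]

# Before: the whole batch was serialized into a single string, so the
# embedding call saw one input for the entire batch rather than one per item.
payload_old = json.dumps(texts)        # '["first document ...", "second document ..."]'

# After: one input per document (each already truncated to 10,000 characters
# when `texts` is built), so response["data"] lines up item-for-item with the batch.
payload_new = [t[:10000] for t in texts]
```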

@@ -1,9 +1,5 @@
import inspect
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Any

from pydantic import Field, field_validator

from docetl.operations.base import BaseOperation
from docetl.operations.utils import RichLoopBar

@@ -12,25 +8,9 @@ from docetl.operations.utils import RichLoopBar
class CodeMapOperation(BaseOperation):
    class schema(BaseOperation.schema):
        type: str = "code_map"
        code: Any
        code: str
        concurrent_thread_count: int = os.cpu_count()
        drop_keys: list[str] | None = None
        limit: int | None = Field(None, gt=0)

        @field_validator("code")
        @classmethod
        def validate_code(cls, v: Any) -> str:
            if isinstance(v, str):
                return v
            if callable(v):
                try:
                    src = inspect.getsource(v)
                except OSError as e:
                    raise ValueError(
                        "Unable to retrieve source for provided function. Please pass a normal def function."
                    ) from e
                return f"{src}\ntransform = {v.__name__}"
            raise TypeError("code must be a string or a callable")

    def syntax_check(self) -> None:
        config = self.schema(**self.config)

@@ -45,10 +25,6 @@ class CodeMapOperation(BaseOperation):
            raise ValueError(f"Invalid code configuration: {str(e)}")

    def execute(self, input_data: list[dict]) -> tuple[list[dict], float]:
        limit_value = self.config.get("limit")
        if limit_value is not None:
            input_data = input_data[:limit_value]

        namespace = {}
        exec(self.config["code"], namespace)
        transform_fn = namespace["transform"]

@@ -81,24 +57,8 @@ class CodeMapOperation(BaseOperation):
class CodeReduceOperation(BaseOperation):
    class schema(BaseOperation.schema):
        type: str = "code_reduce"
        code: Any
        code: str
        concurrent_thread_count: int = os.cpu_count()
        limit: int | None = Field(None, gt=0)

        @field_validator("code")
        @classmethod
        def validate_code(cls, v: Any) -> str:
            if isinstance(v, str):
                return v
            if callable(v):
                try:
                    src = inspect.getsource(v)
                except OSError as e:
                    raise ValueError(
                        "Unable to retrieve source for provided function. Please pass a normal def function."
                    ) from e
                return f"{src}\ntransform = {v.__name__}"
            raise TypeError("code must be a string or a callable")

    def syntax_check(self) -> None:
        config = self.schema(**self.config)

@@ -137,12 +97,6 @@ class CodeReduceOperation(BaseOperation):

        grouped_data = list(grouped_data.items())

        limit_value = self.config.get("limit")
        if limit_value is not None:
            # Sort by group size (smallest first) and take the limit
            grouped_data = sorted(grouped_data, key=lambda x: len(x[1]))
            grouped_data = grouped_data[:limit_value]

        results = []
        with ThreadPoolExecutor(
            max_workers=self.config.get("concurrent_thread_count", os.cpu_count())

@@ -178,24 +132,8 @@ class CodeReduceOperation(BaseOperation):
class CodeFilterOperation(BaseOperation):
    class schema(BaseOperation.schema):
        type: str = "code_filter"
        code: Any
        code: str
        concurrent_thread_count: int = os.cpu_count()
        limit: int | None = Field(None, gt=0)

        @field_validator("code")
        @classmethod
        def validate_code(cls, v: Any) -> str:
            if isinstance(v, str):
                return v
            if callable(v):
                try:
                    src = inspect.getsource(v)
                except OSError as e:
                    raise ValueError(
                        "Unable to retrieve source for provided function. Please pass a normal def function."
                    ) from e
                return f"{src}\ntransform = {v.__name__}"
            raise TypeError("code must be a string or a callable")

    def syntax_check(self) -> None:
        config = self.schema(**self.config)

@@ -214,7 +152,6 @@ class CodeFilterOperation(BaseOperation):
        exec(self.config["code"], namespace)
        filter_fn = namespace["transform"]

        limit_value = self.config.get("limit")
        results = []
        with ThreadPoolExecutor(
            max_workers=self.config.get("concurrent_thread_count", os.cpu_count())

@@ -229,6 +166,4 @@ class CodeFilterOperation(BaseOperation):
                should_keep = futures[i].result()
                if should_keep:
                    results.append(input_data[i])
                    if limit_value is not None and len(results) >= limit_value:
                        break
        return results, 0.0
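
Given the `validate_code` validator above, `code` may be either a source string or a plain function: callables are converted with `inspect.getsource` plus a trailing `transform = <name>` line, which the operation later `exec`s and looks up as `transform`. A sketch of a config that relies on this; names other than the keys shown in the diff are hypothetical:

```python
def normalize_title(doc: dict) -> dict:
    # Must be a normal `def` defined in a file; inspect.getsource raises
    # OSError for functions typed into a REPL, which the validator surfaces
    # as a ValueError.
    return {"title": doc.get("title", "").strip().lower()}

config = {
    "name": "normalize_titles",   # hypothetical operation name
    "type": "code_map",
    "code": normalize_title,      # stored as getsource(...) + "\ntransform = normalize_title"
    "limit": 100,                 # optional: only process the first 100 documents
}
```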
@ -17,13 +17,8 @@ from rich.prompt import Confirm
|
|||
|
||||
from docetl.operations.base import BaseOperation
|
||||
from docetl.operations.utils import strict_render
|
||||
from docetl.operations.utils.blocking import RuntimeBlockingOptimizer
|
||||
from docetl.operations.utils.progress import RichLoopBar
|
||||
from docetl.utils import (
|
||||
completion_cost,
|
||||
has_jinja_syntax,
|
||||
prompt_user_for_non_jinja_confirmation,
|
||||
)
|
||||
from docetl.utils import completion_cost
|
||||
|
||||
# Global variables to store shared data
|
||||
_right_data = None
|
||||
|
|
@ -64,7 +59,6 @@ class EquijoinOperation(BaseOperation):
|
|||
comparison_prompt: str
|
||||
output: dict[str, Any] | None = None
|
||||
blocking_threshold: float | None = None
|
||||
blocking_target_recall: float | None = None
|
||||
blocking_conditions: list[str] | None = None
|
||||
limits: dict[str, int] | None = None
|
||||
comparison_model: str | None = None
|
||||
|
|
@ -95,41 +89,6 @@ class EquijoinOperation(BaseOperation):
|
|||
)
|
||||
return v
|
||||
|
||||
@field_validator("comparison_prompt")
|
||||
def validate_comparison_prompt(cls, v):
|
||||
# Check if it has Jinja syntax
|
||||
if not has_jinja_syntax(v):
|
||||
# This will be handled during initialization with user confirmation
|
||||
return v
|
||||
# If it has Jinja syntax, validate it's a valid template
|
||||
from jinja2 import Template
|
||||
|
||||
try:
|
||||
Template(v)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Invalid Jinja2 template in 'comparison_prompt': {str(e)}"
|
||||
)
|
||||
return v
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# Check for non-Jinja prompts and prompt user for confirmation
|
||||
if "comparison_prompt" in self.config and not has_jinja_syntax(
|
||||
self.config["comparison_prompt"]
|
||||
):
|
||||
if not prompt_user_for_non_jinja_confirmation(
|
||||
self.config["comparison_prompt"],
|
||||
self.config["name"],
|
||||
"comparison_prompt",
|
||||
):
|
||||
raise ValueError(
|
||||
f"Operation '{self.config['name']}' cancelled by user. Please add Jinja2 template syntax to your comparison_prompt."
|
||||
)
|
||||
# Mark that we need to append document statement
|
||||
# Note: equijoin uses left and right, so we'll handle it in strict_render
|
||||
self.config["_append_document_to_comparison_prompt"] = True
|
||||
|
||||
def compare_pair(
|
||||
self,
|
||||
comparison_prompt: str,
|
||||
|
|
@ -252,58 +211,6 @@ class EquijoinOperation(BaseOperation):
|
|||
if self.status:
|
||||
self.status.stop()
|
||||
|
||||
# Track pre-computed embeddings from auto-optimization
|
||||
precomputed_left_embeddings = None
|
||||
precomputed_right_embeddings = None
|
||||
|
||||
# Auto-compute blocking threshold if no blocking configuration is provided
|
||||
if not blocking_threshold and not blocking_conditions and not limit_comparisons:
|
||||
# Get target recall from operation config (default 0.95)
|
||||
target_recall = self.config.get("blocking_target_recall", 0.95)
|
||||
self.console.log(
|
||||
f"[yellow]No blocking configuration. Auto-computing threshold (target recall: {target_recall:.0%})...[/yellow]"
|
||||
)
|
||||
|
||||
# Create comparison function for threshold optimization
|
||||
def compare_fn_for_optimization(left_item, right_item):
|
||||
return self.compare_pair(
|
||||
self.config["comparison_prompt"],
|
||||
self.config.get("comparison_model", self.default_model),
|
||||
left_item,
|
||||
right_item,
|
||||
timeout_seconds=self.config.get("timeout", 120),
|
||||
max_retries_per_timeout=self.config.get(
|
||||
"max_retries_per_timeout", 2
|
||||
),
|
||||
)
|
||||
|
||||
# Run threshold optimization
|
||||
optimizer = RuntimeBlockingOptimizer(
|
||||
runner=self.runner,
|
||||
config=self.config,
|
||||
default_model=self.default_model,
|
||||
max_threads=self.max_threads,
|
||||
console=self.console,
|
||||
target_recall=target_recall,
|
||||
sample_size=min(100, len(left_data) * len(right_data) // 4),
|
||||
)
|
||||
(
|
||||
blocking_threshold,
|
||||
precomputed_left_embeddings,
|
||||
precomputed_right_embeddings,
|
||||
optimization_cost,
|
||||
) = optimizer.optimize_equijoin(
|
||||
left_data,
|
||||
right_data,
|
||||
compare_fn_for_optimization,
|
||||
left_keys=left_keys,
|
||||
right_keys=right_keys,
|
||||
)
|
||||
total_cost += optimization_cost
|
||||
self.console.log(
|
||||
f"[green]Using auto-computed blocking threshold: {blocking_threshold}[/green]"
|
||||
)
|
||||
|
||||
# Initial blocking using multiprocessing
|
||||
num_processes = min(cpu_count(), len(left_data))
|
||||
|
||||
|
|
@ -352,60 +259,45 @@ class EquijoinOperation(BaseOperation):
|
|||
)
|
||||
|
||||
if blocking_threshold is not None:
|
||||
# Use precomputed embeddings if available from auto-optimization
|
||||
if (
|
||||
precomputed_left_embeddings is not None
|
||||
and precomputed_right_embeddings is not None
|
||||
):
|
||||
left_embeddings = precomputed_left_embeddings
|
||||
right_embeddings = precomputed_right_embeddings
|
||||
else:
|
||||
embedding_model = self.config.get("embedding_model", self.default_model)
|
||||
model_input_context_length = model_cost.get(embedding_model, {}).get(
|
||||
"max_input_tokens", 8192
|
||||
)
|
||||
batch_size = 2000
|
||||
embedding_model = self.config.get("embedding_model", self.default_model)
|
||||
model_input_context_length = model_cost.get(embedding_model, {}).get(
|
||||
"max_input_tokens", 8192
|
||||
)
|
||||
|
||||
def get_embeddings(
|
||||
input_data: list[dict[str, Any]], keys: list[str], name: str
|
||||
) -> tuple[list[list[float]], float]:
|
||||
texts = [
|
||||
" ".join(str(item[key]) for key in keys if key in item)[
|
||||
: model_input_context_length * 4
|
||||
]
|
||||
for item in input_data
|
||||
def get_embeddings(
|
||||
input_data: list[dict[str, Any]], keys: list[str], name: str
|
||||
) -> tuple[list[list[float]], float]:
|
||||
texts = [
|
||||
" ".join(str(item[key]) for key in keys if key in item)[
|
||||
: model_input_context_length * 4
|
||||
]
|
||||
embeddings = []
|
||||
embedding_cost = 0
|
||||
num_batches = (len(texts) + batch_size - 1) // batch_size
|
||||
for item in input_data
|
||||
]
|
||||
|
||||
for batch_idx, i in enumerate(range(0, len(texts), batch_size)):
|
||||
batch = texts[i : i + batch_size]
|
||||
if num_batches > 1:
|
||||
self.console.log(
|
||||
f"[dim]Creating {name} embeddings: batch {batch_idx + 1}/{num_batches} "
|
||||
f"({min(i + batch_size, len(texts))}/{len(texts)} items)[/dim]"
|
||||
)
|
||||
response = self.runner.api.gen_embedding(
|
||||
model=embedding_model,
|
||||
input=batch,
|
||||
)
|
||||
embeddings.extend(
|
||||
[data["embedding"] for data in response["data"]]
|
||||
)
|
||||
embedding_cost += completion_cost(response)
|
||||
return embeddings, embedding_cost
|
||||
embeddings = []
|
||||
total_cost = 0
|
||||
batch_size = 2000
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i : i + batch_size]
|
||||
self.console.log(
|
||||
f"On iteration {i} for creating embeddings for {name} data"
|
||||
)
|
||||
response = self.runner.api.gen_embedding(
|
||||
model=embedding_model,
|
||||
input=batch,
|
||||
)
|
||||
embeddings.extend([data["embedding"] for data in response["data"]])
|
||||
total_cost += completion_cost(response)
|
||||
return embeddings, total_cost
|
||||
|
||||
self.console.log(
|
||||
f"[cyan]Creating embeddings for {len(left_data)} left + {len(right_data)} right items...[/cyan]"
|
||||
)
|
||||
left_embeddings, left_cost = get_embeddings(
|
||||
left_data, left_keys, "left"
|
||||
)
|
||||
right_embeddings, right_cost = get_embeddings(
|
||||
right_data, right_keys, "right"
|
||||
)
|
||||
total_cost += left_cost + right_cost
|
||||
left_embeddings, left_cost = get_embeddings(left_data, left_keys, "left")
|
||||
right_embeddings, right_cost = get_embeddings(
|
||||
right_data, right_keys, "right"
|
||||
)
|
||||
total_cost += left_cost + right_cost
|
||||
self.console.log(
|
||||
f"Created embeddings for datasets. Total embedding creation cost: {total_cost}"
|
||||
)
|
||||
|
||||
# Compute all cosine similarities in one call
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ from pydantic import Field, field_validator
|
|||
|
||||
from docetl.operations.base import BaseOperation
|
||||
from docetl.operations.utils import RichLoopBar, strict_render
|
||||
from docetl.utils import has_jinja_syntax, prompt_user_for_non_jinja_confirmation
|
||||
|
||||
|
||||
class ExtractOperation(BaseOperation):
|
||||
|
|
@ -26,14 +25,9 @@ class ExtractOperation(BaseOperation):
|
|||
timeout: int | None = None
|
||||
skip_on_error: bool = False
|
||||
litellm_completion_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
limit: int | None = Field(None, gt=0)
|
||||
|
||||
@field_validator("prompt")
|
||||
def validate_prompt(cls, v):
|
||||
# Check if it has Jinja syntax
|
||||
if not has_jinja_syntax(v):
|
||||
# This will be handled during initialization with user confirmation
|
||||
return v
|
||||
try:
|
||||
Template(v)
|
||||
except Exception as e:
|
||||
|
|
@ -53,16 +47,6 @@ class ExtractOperation(BaseOperation):
|
|||
self.extraction_key_suffix = f"_extracted_{self.config['name']}"
|
||||
else:
|
||||
self.extraction_key_suffix = self.config["extraction_key_suffix"]
|
||||
# Check for non-Jinja prompts and prompt user for confirmation
|
||||
if "prompt" in self.config and not has_jinja_syntax(self.config["prompt"]):
|
||||
if not prompt_user_for_non_jinja_confirmation(
|
||||
self.config["prompt"], self.config["name"], "prompt"
|
||||
):
|
||||
raise ValueError(
|
||||
f"Operation '{self.config['name']}' cancelled by user. Please add Jinja2 template syntax to your prompt."
|
||||
)
|
||||
# Mark that we need to append document statement
|
||||
self.config["_append_document_to_prompt"] = True
|
||||
|
||||
def _reformat_text_with_line_numbers(self, text: str, line_width: int = 80) -> str:
|
||||
"""
|
||||
|
|
@ -119,7 +103,7 @@ class ExtractOperation(BaseOperation):
|
|||
|
||||
def _execute_line_number_strategy(
|
||||
self, item: dict, doc_key: str
|
||||
) -> tuple[list[str], float, str]:
|
||||
) -> tuple[list[dict[str, Any]], float]:
|
||||
"""
|
||||
Executes the line number extraction strategy for a single document key.
|
||||
|
||||
|
|
@ -148,18 +132,10 @@ class ExtractOperation(BaseOperation):
|
|||
formatted_text = self._reformat_text_with_line_numbers(text_content)
|
||||
|
||||
# Render the prompt
|
||||
# Retrieval context
|
||||
retrieval_context = self._maybe_build_retrieval_context({"input": item})
|
||||
extraction_instructions = strict_render(
|
||||
self.config["prompt"],
|
||||
{"input": item, "retrieval_context": retrieval_context},
|
||||
)
|
||||
extraction_instructions = strict_render(self.config["prompt"], {"input": item})
|
||||
augmented_prompt_template = """
|
||||
You are extracting specific content from text documents. Extract information according to these instructions: {{ extraction_instructions }}
|
||||
|
||||
Extra context (may be helpful):
|
||||
{{ retrieval_context }}
|
||||
|
||||
The text is formatted with line numbers as follows:
|
||||
{{ formatted_text }}
|
||||
|
||||
|
|
@ -186,7 +162,6 @@ Do not include explanatory text in your response, only the JSON object.
|
|||
{
|
||||
"extraction_instructions": extraction_instructions,
|
||||
"formatted_text": formatted_text,
|
||||
"retrieval_context": retrieval_context,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -227,7 +202,7 @@ Do not include explanatory text in your response, only the JSON object.
|
|||
self.console.log(
|
||||
f"[bold red]Error parsing LLM response: {llm_result.response}. Skipping.[/bold red]"
|
||||
)
|
||||
return [], llm_result.total_cost, retrieval_context
|
||||
return [], llm_result.total_cost
|
||||
|
||||
for line_range in parsed_output.get("line_ranges", []):
|
||||
start_line = line_range.get("start_line", 0)
|
||||
|
|
@ -255,20 +230,20 @@ Do not include explanatory text in your response, only the JSON object.
|
|||
|
||||
extracted_texts.append("".join(extracted_content))
|
||||
|
||||
return extracted_texts, llm_result.total_cost, retrieval_context
|
||||
return extracted_texts, llm_result.total_cost
|
||||
|
||||
except Exception as e:
|
||||
if self.config.get("skip_on_error", True):
|
||||
self.console.log(
|
||||
f"[bold red]Error parsing LLM response: {str(e)}. Skipping.[/bold red]"
|
||||
)
|
||||
return [], llm_result.total_cost, retrieval_context
|
||||
return [], llm_result.total_cost
|
||||
else:
|
||||
raise RuntimeError(f"Error parsing LLM response: {str(e)}") from e
|
||||
|
||||
def _execute_regex_strategy(
|
||||
self, item: dict, doc_key: str
|
||||
) -> tuple[list[str], float, str]:
|
||||
) -> tuple[list[str], float]:
|
||||
"""
|
||||
Executes the regex extraction strategy for a single document key.
|
||||
|
||||
|
|
@ -277,7 +252,7 @@ Do not include explanatory text in your response, only the JSON object.
|
|||
doc_key (str): The key of the document text to process.
|
||||
|
||||
Returns:
|
||||
tuple[list[str], float, str]: A tuple containing the extraction results, cost, and retrieval context.
|
||||
tuple[list[str], float]: A tuple containing the extraction results and the cost.
|
||||
"""
|
||||
import re
|
||||
|
||||
|
|
@ -287,7 +262,7 @@ Do not include explanatory text in your response, only the JSON object.
|
|||
self.console.log(
|
||||
f"[yellow]Warning: Key '{doc_key}' not found or not a string in document. Skipping.[/yellow]"
|
||||
)
|
||||
return [], 0.0, ""
|
||||
return [], 0.0
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Key '{doc_key}' not found or not a string in document"
|
||||
|
|
@ -296,21 +271,13 @@ Do not include explanatory text in your response, only the JSON object.
|
|||
text_content = item[doc_key]
|
||||
|
||||
# Prepare the context for prompt rendering
|
||||
retrieval_context = self._maybe_build_retrieval_context({"input": item})
|
||||
context = {
|
||||
"input": item,
|
||||
"text_content": text_content,
|
||||
"retrieval_context": retrieval_context,
|
||||
}
|
||||
context = {"input": item, "text_content": text_content}
|
||||
|
||||
# Render the prompt
|
||||
extraction_instructions = strict_render(self.config["prompt"], context)
|
||||
augmented_prompt_template = """
|
||||
You are creating regex patterns to extract specific content from text. Extract information according to these instructions: {{ extraction_instructions }}
|
||||
|
||||
Extra context (may be helpful):
|
||||
{{ retrieval_context }}
|
||||
|
||||
The text to analyze is:
|
||||
{{ text_content }}
|
||||
|
||||
|
|
@ -389,14 +356,14 @@ Return only the JSON object with your patterns, no explanatory text.
|
|||
else:
|
||||
raise ValueError(f"Invalid regex pattern '{pattern}': {str(e)}")
|
||||
|
||||
return extracted_texts, llm_result.total_cost, retrieval_context
|
||||
return extracted_texts, llm_result.total_cost
|
||||
|
||||
except Exception as e:
|
||||
if self.config.get("skip_on_error", True):
|
||||
self.console.log(
|
||||
f"[bold red]Error parsing LLM response: {str(e)}. Skipping.[/bold red]"
|
||||
)
|
||||
return [], llm_result.total_cost, retrieval_context
|
||||
return [], llm_result.total_cost
|
||||
else:
|
||||
raise RuntimeError(f"Error parsing LLM response: {str(e)}") from e
|
||||
|
||||
|
|
@ -410,10 +377,6 @@ Return only the JSON object with your patterns, no explanatory text.
|
|||
Returns:
|
||||
tuple[list[dict], float]: A tuple containing the processed data and the total cost of the operation.
|
||||
"""
|
||||
limit_value = self.config.get("limit")
|
||||
if limit_value is not None:
|
||||
input_data = input_data[:limit_value]
|
||||
|
||||
if not input_data:
|
||||
return [], 0.0
|
||||
|
||||
|
|
@ -468,7 +431,7 @@ Return only the JSON object with your patterns, no explanatory text.
|
|||
doc_key, future, output_item = futures[i]
|
||||
|
||||
try:
|
||||
extracted_texts_duped, cost, retrieval_context = future.result()
|
||||
extracted_texts_duped, cost = future.result()
|
||||
|
||||
# Remove duplicates and empty strings
|
||||
extracted_texts_duped = [
|
||||
|
|
@ -490,12 +453,6 @@ Return only the JSON object with your patterns, no explanatory text.
|
|||
else:
|
||||
output_item[output_key] = extracted_texts
|
||||
|
||||
# Save retrieved context if enabled
|
||||
if self.config.get("save_retriever_output", False):
|
||||
output_item[f"_{self.config['name']}_retrieved_context"] = (
|
||||
retrieval_context if retrieval_context else ""
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
if self.config.get("skip_on_error", True):
|
||||
self.console.log(
|
||||
|
|
|
|||
|
|

@@ -33,30 +33,6 @@ class FilterOperation(MapOperation):

        return self

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._filter_key = next(
            iter(
                [
                    k
                    for k in self.config["output"]["schema"].keys()
                    if k != "_short_explanation"
                ]
            )
        )
        self._filter_is_build = False

    def _limit_applies_to_inputs(self) -> bool:
        return False

    def _handle_result(self, result: dict[str, Any]) -> tuple[dict | None, bool]:
        keep_record = bool(result.get(self._filter_key))
        result.pop(self._filter_key, None)

        if self._filter_is_build or keep_record:
            return result, keep_record
        return None, False

    def execute(
        self, input_data: list[dict], is_build: bool = False
    ) -> tuple[list[dict], float]:

@@ -70,10 +46,55 @@ class FilterOperation(MapOperation):
        Returns:
            tuple[list[dict], float]: A tuple containing the filtered list of dictionaries
            and the total cost of the operation.

        This method performs the following steps:
        1. Processes each input item using an LLM model
        2. Validates the output
        3. Filters the results based on the specified filter key
        4. Calculates the total cost of the operation

        The method uses multi-threading to process items in parallel, improving performance
        for large datasets.

        Usage:
        ```python
        from docetl.operations import FilterOperation

        config = {
            "prompt": "Determine if the following item is important: {{input}}",
            "output": {
                "schema": {"is_important": "bool"}
            },
            "model": "gpt-3.5-turbo"
        }
        filter_op = FilterOperation(config)
        input_data = [
            {"id": 1, "text": "Critical update"},
            {"id": 2, "text": "Regular maintenance"}
        ]
        results, cost = filter_op.execute(input_data)
        print(f"Filtered results: {results}")
        print(f"Total cost: {cost}")
        ```
        """
        previous_state = self._filter_is_build
        self._filter_is_build = is_build
        try:
            return super().execute(input_data)
        finally:
            self._filter_is_build = previous_state
        filter_key = next(
            iter(
                [
                    k
                    for k in self.config["output"]["schema"].keys()
                    if k != "_short_explanation"
                ]
            )
        )

        results, total_cost = super().execute(input_data)

        # Drop records with filter_key values that are False
        if not is_build:
            results = [result for result in results if result[filter_key]]

        # Drop the filter_key from the results
        for result in results:
            result.pop(filter_key, None)

        return results, total_cost
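
The rewritten FilterOperation above pushes filtering into `_handle_result`: the filter key is popped from each record, only truthy records are kept (except during a build pass, where every record is kept), and `_limit_applies_to_inputs` returning False means a configured `limit` counts records that pass rather than records read. A plain-Python sketch of that behavior, independent of the class machinery:

```python
# Standalone sketch of the filter/limit semantics described above.
def apply_filter(records, filter_key, limit=None, is_build=False):
    kept, passed = [], 0
    for rec in records:
        passes = bool(rec.pop(filter_key, False))
        if is_build or passes:
            kept.append(rec)      # a build pass keeps every record
        if passes:
            passed += 1           # only passing records count toward the limit
        if limit is not None and passed >= limit:
            break
    return kept
```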

@@ -7,29 +7,11 @@ from sklearn.metrics.pairwise import cosine_similarity

from docetl.operations.base import BaseOperation
from docetl.operations.utils import RichLoopBar, strict_render
from docetl.utils import has_jinja_syntax, prompt_user_for_non_jinja_confirmation

from .clustering_utils import get_embeddings_for_clustering


class LinkResolveOperation(BaseOperation):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Check for non-Jinja prompts and prompt user for confirmation
        if "comparison_prompt" in self.config and not has_jinja_syntax(
            self.config["comparison_prompt"]
        ):
            if not prompt_user_for_non_jinja_confirmation(
                self.config["comparison_prompt"],
                self.config["name"],
                "comparison_prompt",
            ):
                raise ValueError(
                    f"Operation '{self.config['name']}' cancelled by user. Please add Jinja2 template syntax to your comparison_prompt."
                )
            # Mark that we need to append document statement
            # Note: link_resolve uses link_value, id_value, and item, so strict_render will handle it
            self.config["_append_document_to_comparison_prompt"] = True
    def execute(self, input_data: list[dict]) -> tuple[list[dict], float]:
        """
        Executes the resolve links operation on the provided dataset.
@ -17,7 +17,6 @@ from docetl.base_schemas import Tool, ToolFunction
|
|||
from docetl.operations.base import BaseOperation
|
||||
from docetl.operations.utils import RichLoopBar, strict_render, validate_output_types
|
||||
from docetl.operations.utils.api import OutputMode
|
||||
from docetl.utils import has_jinja_syntax, prompt_user_for_non_jinja_confirmation
|
||||
|
||||
|
||||
class MapOperation(BaseOperation):
|
||||
|
|
@ -43,7 +42,6 @@ class MapOperation(BaseOperation):
|
|||
litellm_completion_kwargs: dict[str, Any] = {}
|
||||
pdf_url_key: str | None = None
|
||||
flush_partial_result: bool = False
|
||||
limit: int | None = Field(None, gt=0)
|
||||
# Calibration parameters
|
||||
calibrate: bool = False
|
||||
num_calibration_docs: int = Field(10, gt=0)
|
||||
|
|
@ -51,11 +49,6 @@ class MapOperation(BaseOperation):
|
|||
@field_validator("batch_prompt")
|
||||
def validate_batch_prompt(cls, v):
|
||||
if v is not None:
|
||||
# Check if it has Jinja syntax
|
||||
if not has_jinja_syntax(v):
|
||||
# This will be handled during initialization with user confirmation
|
||||
# We'll mark it for later processing
|
||||
return v
|
||||
try:
|
||||
template = Template(v)
|
||||
# Test render with a minimal inputs list to validate template
|
||||
|
|
@ -69,11 +62,6 @@ class MapOperation(BaseOperation):
|
|||
@field_validator("prompt")
|
||||
def validate_prompt(cls, v):
|
||||
if v is not None:
|
||||
# Check if it has Jinja syntax
|
||||
if not has_jinja_syntax(v):
|
||||
# This will be handled during initialization with user confirmation
|
||||
# We'll mark it for later processing
|
||||
return v
|
||||
try:
|
||||
Template(v)
|
||||
except Exception as e:
|
||||
|
|
@ -130,33 +118,6 @@ class MapOperation(BaseOperation):
|
|||
"max_batch_size", kwargs.get("max_batch_size", None)
|
||||
)
|
||||
self.clustering_method = "random"
|
||||
# Check for non-Jinja prompts and prompt user for confirmation
|
||||
if "prompt" in self.config and not has_jinja_syntax(self.config["prompt"]):
|
||||
if not prompt_user_for_non_jinja_confirmation(
|
||||
self.config["prompt"], self.config["name"], "prompt"
|
||||
):
|
||||
raise ValueError(
|
||||
f"Operation '{self.config['name']}' cancelled by user. Please add Jinja2 template syntax to your prompt."
|
||||
)
|
||||
# Mark that we need to append document statement
|
||||
self.config["_append_document_to_prompt"] = True
|
||||
if "batch_prompt" in self.config and not has_jinja_syntax(
|
||||
self.config["batch_prompt"]
|
||||
):
|
||||
if not prompt_user_for_non_jinja_confirmation(
|
||||
self.config["batch_prompt"], self.config["name"], "batch_prompt"
|
||||
):
|
||||
raise ValueError(
|
||||
f"Operation '{self.config['name']}' cancelled by user. Please add Jinja2 template syntax to your batch_prompt."
|
||||
)
|
||||
# Mark that we need to append document statement
|
||||
self.config["_append_document_to_batch_prompt"] = True
|
||||
|
||||
def _limit_applies_to_inputs(self) -> bool:
|
||||
return True
|
||||
|
||||
def _handle_result(self, result: dict[str, Any]) -> tuple[dict | None, bool]:
|
||||
return result, True
|
||||
|
||||
def _generate_calibration_context(self, input_data: list[dict]) -> str:
|
||||
"""
|
||||
|
|
@ -278,27 +239,17 @@ Reference anchors:"""
|
|||
|
||||
The method uses parallel processing to improve performance.
|
||||
"""
|
||||
limit_value = self.config.get("limit")
|
||||
|
||||
# Check if there's no prompt and only drop_keys
|
||||
if "prompt" not in self.config and "drop_keys" in self.config:
|
||||
data_to_process = input_data
|
||||
if limit_value is not None and self._limit_applies_to_inputs():
|
||||
data_to_process = input_data[:limit_value]
|
||||
# If only drop_keys is specified, simply drop the keys and return
|
||||
dropped_results = []
|
||||
for item in data_to_process:
|
||||
for item in input_data:
|
||||
new_item = {
|
||||
k: v for k, v in item.items() if k not in self.config["drop_keys"]
|
||||
}
|
||||
dropped_results.append(new_item)
|
||||
if limit_value is not None and len(dropped_results) >= limit_value:
|
||||
break
|
||||
return dropped_results, 0.0 # Return the modified data with no cost
|
||||
|
||||
if limit_value is not None and self._limit_applies_to_inputs():
|
||||
input_data = input_data[:limit_value]
|
||||
|
||||
# Generate calibration context if enabled
|
||||
calibration_context = ""
|
||||
if self.config.get("calibrate", False) and "prompt" in self.config:
|
||||
|
|
@ -325,17 +276,7 @@ Reference anchors:"""
|
|||
item: dict, initial_result: dict | None = None
|
||||
) -> tuple[dict | None, float]:
|
||||
|
||||
# Build retrieval context (if configured)
|
||||
retrieval_context = self._maybe_build_retrieval_context({"input": item})
|
||||
ctx = {"input": item, "retrieval_context": retrieval_context}
|
||||
rendered = strict_render(self.config["prompt"], ctx)
|
||||
# If template didn't use retrieval_context, prepend a standard header
|
||||
prompt = (
|
||||
f"Here is some extra context:\n{retrieval_context}\n\n{rendered}"
|
||||
if retrieval_context
|
||||
and "retrieval_context" not in self.config["prompt"]
|
||||
else rendered
|
||||
)
|
||||
prompt = strict_render(self.config["prompt"], {"input": item})
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
if self.config.get("pdf_url_key", None):
|
||||
# Append the pdf to the prompt
|
||||
|
|
@ -443,12 +384,6 @@ Reference anchors:"""
|
|||
output[f"_observability_{self.config['name']}"] = {
|
||||
"prompt": prompt
|
||||
}
|
||||
# Add retrieved context if save_retriever_output is enabled
|
||||
if self.config.get("save_retriever_output", False):
|
||||
for output in outputs:
|
||||
output[f"_{self.config['name']}_retrieved_context"] = (
|
||||
retrieval_context if retrieval_context else ""
|
||||
)
|
||||
return outputs, llm_result.total_cost
|
||||
|
||||
return None, llm_result.total_cost
|
||||
|
|
@ -544,87 +479,40 @@ Reference anchors:"""
|
|||
|
||||
return all_results, total_cost
|
||||
|
||||
limit_counter = 0
|
||||
batch_size = self.max_batch_size if self.max_batch_size is not None else 1
|
||||
total_batches = (len(input_data) + batch_size - 1) // batch_size
|
||||
if total_batches == 0:
|
||||
if self.status:
|
||||
self.status.start()
|
||||
return [], 0.0
|
||||
|
||||
worker_limit = self.max_batch_size or self.max_threads or 1
|
||||
window_size = (
|
||||
total_batches
|
||||
if limit_value is None
|
||||
else max(1, (limit_value + batch_size - 1) // batch_size)
|
||||
)
|
||||
|
||||
results: list[dict] = []
|
||||
total_cost = 0.0
|
||||
limit_reached = False
|
||||
op_name = self.config["name"]
|
||||
|
||||
if limit_value is not None and not self._limit_applies_to_inputs():
|
||||
self.console.log(
|
||||
f"[yellow]Note: Operation will terminate early once {limit_value} items pass the filter condition.[/yellow]"
|
||||
)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=worker_limit) as executor:
|
||||
with RichLoopBar(
|
||||
total=total_batches,
|
||||
desc=f"Processing {op_name} (map) on all documents",
|
||||
with ThreadPoolExecutor(max_workers=self.max_batch_size) as executor:
|
||||
batch_size = self.max_batch_size if self.max_batch_size is not None else 1
|
||||
futures = []
|
||||
for i in range(0, len(input_data), batch_size):
|
||||
batch = input_data[i : i + batch_size]
|
||||
futures.append(executor.submit(_process_map_batch, batch))
|
||||
results = []
|
||||
total_cost = 0
|
||||
pbar = RichLoopBar(
|
||||
range(len(futures)),
|
||||
desc=f"Processing {self.config['name']} (map) on all documents",
|
||||
console=self.console,
|
||||
) as pbar:
|
||||
chunk_start = 0
|
||||
while chunk_start < total_batches and not limit_reached:
|
||||
chunk_end = min(total_batches, chunk_start + window_size)
|
||||
chunk_ordinals = list(range(chunk_start, chunk_end))
|
||||
futures = []
|
||||
for ordinal in chunk_ordinals:
|
||||
start_idx = ordinal * batch_size
|
||||
batch = input_data[start_idx : start_idx + batch_size]
|
||||
futures.append(executor.submit(_process_map_batch, batch))
|
||||
|
||||
for relative_idx, future in enumerate(futures):
|
||||
if limit_value is not None and limit_counter >= limit_value:
|
||||
limit_reached = True
|
||||
break
|
||||
|
||||
result_list, item_cost = future.result()
|
||||
total_cost += item_cost
|
||||
|
||||
if result_list:
|
||||
if "drop_keys" in self.config:
|
||||
result_list = [
|
||||
{
|
||||
k: v
|
||||
for k, v in result.items()
|
||||
if k not in self.config["drop_keys"]
|
||||
}
|
||||
for result in result_list
|
||||
]
|
||||
|
||||
if self.config.get("flush_partial_results", False):
|
||||
self.runner._flush_partial_results(
|
||||
op_name, chunk_ordinals[relative_idx], result_list
|
||||
)
|
||||
|
||||
for result in result_list:
|
||||
processed_result, counts_towards_limit = (
|
||||
self._handle_result(result)
|
||||
)
|
||||
if processed_result is not None:
|
||||
results.append(processed_result)
|
||||
|
||||
if limit_value is not None and counts_towards_limit:
|
||||
limit_counter += 1
|
||||
if limit_counter >= limit_value:
|
||||
limit_reached = True
|
||||
break
|
||||
|
||||
pbar.update()
|
||||
|
||||
chunk_start = chunk_end
|
||||
)
|
||||
for batch_index in pbar:
|
||||
result_list, item_cost = futures[batch_index].result()
|
||||
if result_list:
|
||||
if "drop_keys" in self.config:
|
||||
result_list = [
|
||||
{
|
||||
k: v
|
||||
for k, v in result.items()
|
||||
if k not in self.config["drop_keys"]
|
||||
}
|
||||
for result in result_list
|
||||
]
|
||||
results.extend(result_list)
|
||||
# --- BEGIN: Flush partial checkpoint ---
|
||||
if self.config.get("flush_partial_results", False):
|
||||
op_name = self.config["name"]
|
||||
self.runner._flush_partial_results(
|
||||
op_name, batch_index, result_list
|
||||
)
|
||||
# --- END: Flush partial checkpoint ---
|
||||
total_cost += item_cost
|
||||
|
||||
if self.status:
|
||||
self.status.start()
|
||||
|
|
|
|||
|
|
@ -32,11 +32,7 @@ from docetl.operations.utils import (
|
|||
|
||||
# Import OutputMode enum for structured output checks
|
||||
from docetl.operations.utils.api import OutputMode
|
||||
from docetl.utils import (
|
||||
completion_cost,
|
||||
has_jinja_syntax,
|
||||
prompt_user_for_non_jinja_confirmation,
|
||||
)
|
||||
from docetl.utils import completion_cost
|
||||
|
||||
|
||||
class ReduceOperation(BaseOperation):
|
||||
|
|
@ -67,15 +63,10 @@ class ReduceOperation(BaseOperation):
|
|||
timeout: int | None = None
|
||||
litellm_completion_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
enable_observability: bool = False
|
||||
limit: int | None = Field(None, gt=0)
|
||||
|
||||
@field_validator("prompt")
|
||||
def validate_prompt(cls, v):
|
||||
if v is not None:
|
||||
# Check if it has Jinja syntax
|
||||
if not has_jinja_syntax(v):
|
||||
# This will be handled during initialization with user confirmation
|
||||
return v
|
||||
try:
|
||||
template = Template(v)
|
||||
template_vars = template.environment.parse(v).find_all(
|
||||
|
|
@ -93,10 +84,6 @@ class ReduceOperation(BaseOperation):
|
|||
@field_validator("fold_prompt")
|
||||
def validate_fold_prompt(cls, v):
|
||||
if v is not None:
|
||||
# Check if it has Jinja syntax
|
||||
if not has_jinja_syntax(v):
|
||||
# This will be handled during initialization with user confirmation
|
||||
return v
|
||||
try:
|
||||
fold_template = Template(v)
|
||||
fold_template_vars = fold_template.environment.parse(v).find_all(
|
||||
|
|
@ -117,10 +104,6 @@ class ReduceOperation(BaseOperation):
|
|||
@field_validator("merge_prompt")
|
||||
def validate_merge_prompt(cls, v):
|
||||
if v is not None:
|
||||
# Check if it has Jinja syntax
|
||||
if not has_jinja_syntax(v):
|
||||
# This will be handled during initialization with user confirmation
|
||||
return v
|
||||
try:
|
||||
merge_template = Template(v)
|
||||
merge_template_vars = merge_template.environment.parse(v).find_all(
|
||||
|
|
@ -198,39 +181,6 @@ class ReduceOperation(BaseOperation):
|
|||
)
|
||||
self.intermediates = {}
|
||||
self.lineage_keys = self.config.get("output", {}).get("lineage", [])
|
||||
# Check for non-Jinja prompts and prompt user for confirmation
|
||||
if "prompt" in self.config and not has_jinja_syntax(self.config["prompt"]):
|
||||
if not prompt_user_for_non_jinja_confirmation(
|
||||
self.config["prompt"], self.config["name"], "prompt"
|
||||
):
|
||||
raise ValueError(
|
||||
f"Operation '{self.config['name']}' cancelled by user. Please add Jinja2 template syntax to your prompt."
|
||||
)
|
||||
# Mark that we need to append document statement (for reduce, use inputs)
|
||||
self.config["_append_document_to_prompt"] = True
|
||||
self.config["_is_reduce_operation"] = True
|
||||
if "fold_prompt" in self.config and not has_jinja_syntax(
|
||||
self.config["fold_prompt"]
|
||||
):
|
||||
if not prompt_user_for_non_jinja_confirmation(
|
||||
self.config["fold_prompt"], self.config["name"], "fold_prompt"
|
||||
):
|
||||
raise ValueError(
|
||||
f"Operation '{self.config['name']}' cancelled by user. Please add Jinja2 template syntax to your fold_prompt."
|
||||
)
|
||||
self.config["_append_document_to_fold_prompt"] = True
|
||||
self.config["_is_reduce_operation"] = True
|
||||
if "merge_prompt" in self.config and not has_jinja_syntax(
|
||||
self.config["merge_prompt"]
|
||||
):
|
||||
if not prompt_user_for_non_jinja_confirmation(
|
||||
self.config["merge_prompt"], self.config["name"], "merge_prompt"
|
||||
):
|
||||
raise ValueError(
|
||||
f"Operation '{self.config['name']}' cancelled by user. Please add Jinja2 template syntax to your merge_prompt."
|
||||
)
|
||||
self.config["_append_document_to_merge_prompt"] = True
|
||||
self.config["_is_reduce_operation"] = True
|
||||
|
||||
def execute(self, input_data: list[dict]) -> tuple[list[dict], float]:
|
||||
"""
|
||||
|
|
@ -286,12 +236,6 @@ class ReduceOperation(BaseOperation):
|
|||
# Convert the grouped data to a list of tuples
|
||||
grouped_data = list(grouped_data.items())
|
||||
|
||||
limit_value = self.config.get("limit")
|
||||
if limit_value is not None:
|
||||
# Sort by group size (smallest first) and take the limit
|
||||
grouped_data = sorted(grouped_data, key=lambda x: len(x[1]))
|
||||
grouped_data = grouped_data[:limit_value]
|
||||
|
||||
def process_group(
|
||||
key: tuple, group_elems: list[dict]
|
||||
) -> tuple[dict | None, float]:
|
||||
|
|
@ -304,16 +248,6 @@ class ReduceOperation(BaseOperation):
|
|||
group_list = group_elems
|
||||
|
||||
total_cost = 0.0
|
||||
# Build retrieval context once per group
|
||||
try:
|
||||
retrieval_context = self._maybe_build_retrieval_context(
|
||||
{
|
||||
"reduce_key": dict(zip(self.config["reduce_key"], key)),
|
||||
"inputs": group_list,
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
retrieval_context = "No extra context available."
|
||||
|
||||
# Apply value sampling if enabled
|
||||
value_sampling = self.config.get("value_sampling", {})
|
||||
|
|
@ -343,26 +277,18 @@ class ReduceOperation(BaseOperation):
|
|||
|
||||
# Only execute merge-based plans if associative = True
|
||||
if "merge_prompt" in self.config and self.config.get("associative", True):
|
||||
result, prompts, cost = self._parallel_fold_and_merge(
|
||||
key, group_list, retrieval_context
|
||||
)
|
||||
result, prompts, cost = self._parallel_fold_and_merge(key, group_list)
|
||||
elif self.config.get("fold_batch_size", None) and self.config.get(
|
||||
"fold_batch_size"
|
||||
) >= len(group_list):
|
||||
# If the fold batch size is greater than or equal to the number of items in the group,
|
||||
# we can just run a single fold operation
|
||||
result, prompt, cost = self._batch_reduce(
|
||||
key, group_list, None, retrieval_context
|
||||
)
|
||||
result, prompt, cost = self._batch_reduce(key, group_list)
|
||||
prompts = [prompt]
|
||||
elif "fold_prompt" in self.config:
|
||||
result, prompts, cost = self._incremental_reduce(
|
||||
key, group_list, retrieval_context
|
||||
)
|
||||
result, prompts, cost = self._incremental_reduce(key, group_list)
|
||||
else:
|
||||
result, prompt, cost = self._batch_reduce(
|
||||
key, group_list, None, retrieval_context
|
||||
)
|
||||
result, prompt, cost = self._batch_reduce(key, group_list)
|
||||
prompts = [prompt]
|
||||
|
||||
total_cost += cost
|
||||
|
|
@ -374,16 +300,6 @@ class ReduceOperation(BaseOperation):
|
|||
# Add the _observability_{self.config['name']} key to the result
|
||||
result[f"_observability_{self.config['name']}"] = {"prompts": prompts}
|
||||
|
||||
# Add retrieved context if save_retriever_output is enabled
|
||||
if self.config.get("save_retriever_output", False):
|
||||
ctx = (
|
||||
retrieval_context
|
||||
if retrieval_context
|
||||
and retrieval_context != "No extra context available."
|
||||
else ""
|
||||
)
|
||||
result[f"_{self.config['name']}_retrieved_context"] = ctx
|
||||
|
||||
# Apply pass-through at the group level
|
||||
if (
|
||||
result is not None
|
||||
|
|
@ -426,9 +342,6 @@ class ReduceOperation(BaseOperation):
|
|||
if output is not None:
|
||||
results.append(output)
|
||||
|
||||
if limit_value is not None and len(results) > limit_value:
|
||||
results = results[:limit_value]
|
||||
|
||||
if self.config.get("persist_intermediates", False):
|
||||
for result in results:
|
||||
key = tuple(result[k] for k in self.config["reduce_key"])
|
||||
|
|
@@ -505,7 +418,7 @@ class ReduceOperation(BaseOperation):
    return [group_list[i] for i in top_k_indices], cost

def _parallel_fold_and_merge(
    self, key: tuple, group_list: list[dict], retrieval_context: str
    self, key: tuple, group_list: list[dict]
) -> tuple[dict | None, float]:
    """
    Perform parallel folding and merging on a group of items.

@@ -670,7 +583,7 @@ class ReduceOperation(BaseOperation):
    )

def _incremental_reduce(
    self, key: tuple, group_list: list[dict], retrieval_context: str
    self, key: tuple, group_list: list[dict]
) -> tuple[dict | None, list[str], float]:
    """
    Perform an incremental reduce operation on a group of items.
@@ -766,7 +679,6 @@ class ReduceOperation(BaseOperation):
    batch: list[dict],
    current_output: dict | None,
    scratchpad: str | None = None,
    retrieval_context: str | None = None,
) -> tuple[dict | None, str, float]:
    """
    Perform an incremental fold operation on a batch of items.

@@ -783,7 +695,7 @@ class ReduceOperation(BaseOperation):
    the prompt used, and the cost of the fold operation.
    """
    if current_output is None:
        return self._batch_reduce(key, batch, scratchpad, retrieval_context)
        return self._batch_reduce(key, batch, scratchpad)

    start_time = time.time()
    fold_prompt = strict_render(
@@ -792,15 +704,8 @@ class ReduceOperation(BaseOperation):
        "inputs": batch,
        "output": current_output,
        "reduce_key": dict(zip(self.config["reduce_key"], key)),
        "retrieval_context": retrieval_context or "",
    },
)
if retrieval_context and "retrieval_context" not in self.config.get(
    "fold_prompt", ""
):
    fold_prompt = (
        f"Here is some extra context:\n{retrieval_context}\n\n{fold_prompt}"
    )

response = self.runner.api.call_llm(
    self.config.get("model", self.default_model),
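In the hunk above, when retrieval context is available but the configured template never mentions a `retrieval_context` variable, the context is simply prepended to the already-rendered prompt. A standalone sketch of that behavior (not DocETL's actual helper, just the idea):

```python
def with_retrieval_context(rendered_prompt: str, template: str,
                           retrieval_context: str | None) -> str:
    """Prepend extra context unless the template already consumes it."""
    if retrieval_context and "retrieval_context" not in template:
        return f"Here is some extra context:\n{retrieval_context}\n\n{rendered_prompt}"
    return rendered_prompt


print(with_retrieval_context("Summarize the reviews.",
                             "Summarize the reviews.",
                             "Top retrieved snippet: ..."))
```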
@@ -825,7 +730,6 @@ class ReduceOperation(BaseOperation):

end_time = time.time()
self._update_fold_time(end_time - start_time)
fold_cost = response.total_cost

if response.validated:
    structured_mode = (

@@ -847,7 +751,7 @@ class ReduceOperation(BaseOperation):
    return None, fold_prompt, fold_cost

def _merge_results(
    self, key: tuple, outputs: list[dict], retrieval_context: str | None = None
    self, key: tuple, outputs: list[dict]
) -> tuple[dict | None, str, float]:
    """
    Merge multiple outputs into a single result.
@@ -868,15 +772,8 @@ class ReduceOperation(BaseOperation):
    {
        "outputs": outputs,
        "reduce_key": dict(zip(self.config["reduce_key"], key)),
        "retrieval_context": retrieval_context or "",
    },
)
if retrieval_context and "retrieval_context" not in self.config.get(
    "merge_prompt", ""
):
    merge_prompt = (
        f"Here is some extra context:\n{retrieval_context}\n\n{merge_prompt}"
    )
response = self.runner.api.call_llm(
    self.config.get("model", self.default_model),
    "merge",
@@ -901,7 +798,6 @@ class ReduceOperation(BaseOperation):

end_time = time.time()
self._update_merge_time(end_time - start_time)
merge_cost = response.total_cost

if response.validated:
    structured_mode = (

@@ -971,11 +867,7 @@ class ReduceOperation(BaseOperation):
    self.merge_times.append(time)

def _batch_reduce(
    self,
    key: tuple,
    group_list: list[dict],
    scratchpad: str | None = None,
    retrieval_context: str | None = None,
    self, key: tuple, group_list: list[dict], scratchpad: str | None = None
) -> tuple[dict | None, str, float]:
    """
    Perform a batch reduce operation on a group of items.
@@ -995,13 +887,8 @@ class ReduceOperation(BaseOperation):
    {
        "reduce_key": dict(zip(self.config["reduce_key"], key)),
        "inputs": group_list,
        "retrieval_context": retrieval_context or "",
    },
)
if retrieval_context and "retrieval_context" not in self.config.get(
    "prompt", ""
):
    prompt = f"Here is some extra context:\n{retrieval_context}\n\n{prompt}"
item_cost = 0

response = self.runner.api.call_llm(
@@ -10,16 +10,11 @@ import jinja2
from jinja2 import Template
from litellm import model_cost
from pydantic import Field, ValidationInfo, field_validator, model_validator
from rich.prompt import Confirm

from docetl.operations.base import BaseOperation
from docetl.operations.utils import RichLoopBar, rich_as_completed, strict_render
from docetl.operations.utils.blocking import RuntimeBlockingOptimizer
from docetl.utils import (
    completion_cost,
    extract_jinja_variables,
    has_jinja_syntax,
    prompt_user_for_non_jinja_confirmation,
)
from docetl.utils import completion_cost, extract_jinja_variables


def find_cluster(item, cluster_map):
@ -40,7 +35,6 @@ class ResolveOperation(BaseOperation):
|
|||
comparison_model: str | None = None
|
||||
blocking_keys: list[str] | None = None
|
||||
blocking_threshold: float | None = Field(None, ge=0, le=1)
|
||||
blocking_target_recall: float | None = Field(None, ge=0, le=1)
|
||||
blocking_conditions: list[str] | None = None
|
||||
input: dict[str, Any] | None = None
|
||||
embedding_batch_size: int | None = Field(None, gt=0)
|
||||
|
|
@ -54,10 +48,6 @@ class ResolveOperation(BaseOperation):
|
|||
@field_validator("comparison_prompt")
|
||||
def validate_comparison_prompt(cls, v):
|
||||
if v is not None:
|
||||
# Check if it has Jinja syntax
|
||||
if not has_jinja_syntax(v):
|
||||
# This will be handled during initialization with user confirmation
|
||||
return v
|
||||
try:
|
||||
comparison_template = Template(v)
|
||||
comparison_vars = comparison_template.environment.parse(v).find_all(
|
||||
|
|
@ -80,10 +70,6 @@ class ResolveOperation(BaseOperation):
|
|||
@field_validator("resolution_prompt")
|
||||
def validate_resolution_prompt(cls, v):
|
||||
if v is not None:
|
||||
# Check if it has Jinja syntax
|
||||
if not has_jinja_syntax(v):
|
||||
# This will be handled during initialization with user confirmation
|
||||
return v
|
||||
try:
|
||||
reduction_template = Template(v)
|
||||
reduction_vars = reduction_template.environment.parse(v).find_all(
|
||||
|
|
@ -137,38 +123,6 @@ class ResolveOperation(BaseOperation):
|
|||
|
||||
return self
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# Check for non-Jinja prompts and prompt user for confirmation
|
||||
if "comparison_prompt" in self.config and not has_jinja_syntax(
|
||||
self.config["comparison_prompt"]
|
||||
):
|
||||
if not prompt_user_for_non_jinja_confirmation(
|
||||
self.config["comparison_prompt"],
|
||||
self.config["name"],
|
||||
"comparison_prompt",
|
||||
):
|
||||
raise ValueError(
|
||||
f"Operation '{self.config['name']}' cancelled by user. Please add Jinja2 template syntax to your comparison_prompt."
|
||||
)
|
||||
# Mark that we need to append document statement
|
||||
# Note: comparison_prompt uses input1 and input2, so we'll handle it specially in strict_render
|
||||
self.config["_append_document_to_comparison_prompt"] = True
|
||||
if "resolution_prompt" in self.config and not has_jinja_syntax(
|
||||
self.config["resolution_prompt"]
|
||||
):
|
||||
if not prompt_user_for_non_jinja_confirmation(
|
||||
self.config["resolution_prompt"],
|
||||
self.config["name"],
|
||||
"resolution_prompt",
|
||||
):
|
||||
raise ValueError(
|
||||
f"Operation '{self.config['name']}' cancelled by user. Please add Jinja2 template syntax to your resolution_prompt."
|
||||
)
|
||||
# Mark that we need to append document statement (resolution uses inputs)
|
||||
self.config["_append_document_to_resolution_prompt"] = True
|
||||
self.config["_is_reduce_operation"] = True
|
||||
|
||||
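The `__init__` shown above asks the user to confirm any prompt that contains no Jinja2 syntax and then marks the config so the documents get appended to the prompt automatically. A simplified illustration of just the detection step; the regex is an assumption and may differ from DocETL's `has_jinja_syntax`:

```python
import re

def looks_like_jinja(prompt: str) -> bool:
    # Treat {{ ... }} expressions or {% ... %} statements as Jinja syntax.
    return bool(re.search(r"\{\{.*?\}\}|\{%.*?%\}", prompt, re.DOTALL))

print(looks_like_jinja("Same company? {{ input1.name }} vs {{ input2.name }}"))  # True
print(looks_like_jinja("Are these the same company?"))  # False
```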
def compare_pair(
|
||||
self,
|
||||
comparison_prompt: str,
|
||||
|
|
@ -267,75 +221,26 @@ class ResolveOperation(BaseOperation):
|
|||
blocking_keys = self.config.get("blocking_keys", [])
|
||||
blocking_threshold = self.config.get("blocking_threshold")
|
||||
blocking_conditions = self.config.get("blocking_conditions", [])
|
||||
limit_comparisons = self.config.get("limit_comparisons")
|
||||
total_cost = 0
|
||||
if self.status:
|
||||
self.status.stop()
|
||||
|
||||
# Track pre-computed embeddings from auto-optimization
|
||||
precomputed_embeddings = None
|
||||
|
||||
# Auto-compute blocking threshold if no blocking configuration is provided
|
||||
if not blocking_threshold and not blocking_conditions and not limit_comparisons:
|
||||
# Get target recall from operation config (default 0.95)
|
||||
target_recall = self.config.get("blocking_target_recall", 0.95)
|
||||
self.console.log(
|
||||
f"[yellow]No blocking configuration. Auto-computing threshold (target recall: {target_recall:.0%})...[/yellow]"
|
||||
)
|
||||
# Determine blocking keys if not set
|
||||
auto_blocking_keys = blocking_keys if blocking_keys else None
|
||||
if not auto_blocking_keys:
|
||||
prompt_template = self.config.get("comparison_prompt", "")
|
||||
prompt_vars = extract_jinja_variables(prompt_template)
|
||||
prompt_vars = [
|
||||
var
|
||||
for var in prompt_vars
|
||||
if var not in ["input", "input1", "input2"]
|
||||
]
|
||||
auto_blocking_keys = list(
|
||||
set([var.split(".")[-1] for var in prompt_vars])
|
||||
)
|
||||
if not auto_blocking_keys:
|
||||
auto_blocking_keys = list(input_data[0].keys())
|
||||
blocking_keys = auto_blocking_keys
|
||||
|
||||
# Create comparison function for threshold optimization
|
||||
def compare_fn_for_optimization(item1, item2):
|
||||
return self.compare_pair(
|
||||
self.config["comparison_prompt"],
|
||||
self.config.get("comparison_model", self.default_model),
|
||||
item1,
|
||||
item2,
|
||||
blocking_keys=[], # Don't use key-based shortcut during optimization
|
||||
timeout_seconds=self.config.get("timeout", 120),
|
||||
max_retries_per_timeout=self.config.get(
|
||||
"max_retries_per_timeout", 2
|
||||
),
|
||||
)
|
||||
|
||||
# Run threshold optimization
|
||||
optimizer = RuntimeBlockingOptimizer(
|
||||
runner=self.runner,
|
||||
config=self.config,
|
||||
default_model=self.default_model,
|
||||
max_threads=self.max_threads,
|
||||
console=self.console,
|
||||
target_recall=target_recall,
|
||||
sample_size=min(100, len(input_data) * (len(input_data) - 1) // 4),
|
||||
)
|
||||
blocking_threshold, precomputed_embeddings, optimization_cost = (
|
||||
optimizer.optimize_resolve(
|
||||
input_data,
|
||||
compare_fn_for_optimization,
|
||||
blocking_keys=blocking_keys,
|
||||
)
|
||||
)
|
||||
total_cost += optimization_cost
|
||||
if not blocking_threshold and not blocking_conditions:
|
||||
# Prompt the user for confirmation
|
||||
if not Confirm.ask(
|
||||
"[yellow]Warning: No blocking keys or conditions specified. "
|
||||
"This may result in a large number of comparisons. "
|
||||
"We recommend specifying at least one blocking key or condition, or using the optimizer to automatically come up with these. "
|
||||
"Do you want to continue without blocking?[/yellow]",
|
||||
console=self.runner.console,
|
||||
):
|
||||
raise ValueError("Operation cancelled by user.")
|
||||
|
||||
input_schema = self.config.get("input", {}).get("schema", {})
|
||||
if not blocking_keys:
|
||||
# Set them to all keys in the input data
|
||||
blocking_keys = list(input_data[0].keys())
|
||||
limit_comparisons = self.config.get("limit_comparisons")
|
||||
total_cost = 0
|
||||
|
||||
def is_match(item1: dict[str, Any], item2: dict[str, Any]) -> bool:
|
||||
return any(
|
||||
|
|
@ -346,101 +251,120 @@ class ResolveOperation(BaseOperation):
|
|||
# Calculate embeddings if blocking_threshold is set
|
||||
embeddings = None
|
||||
if blocking_threshold is not None:
|
||||
# Use precomputed embeddings if available from auto-optimization
|
||||
if precomputed_embeddings is not None:
|
||||
embeddings = precomputed_embeddings
|
||||
else:
|
||||
self.console.log(
|
||||
f"[cyan]Creating embeddings for {len(input_data)} items...[/cyan]"
|
||||
)
|
||||
|
||||
def get_embeddings_batch(
|
||||
items: list[dict[str, Any]]
|
||||
) -> list[tuple[list[float], float]]:
|
||||
embedding_model = self.config.get(
|
||||
"embedding_model", "text-embedding-3-small"
|
||||
)
|
||||
model_input_context_length = model_cost.get(embedding_model, {}).get(
|
||||
"max_input_tokens", 8192
|
||||
)
|
||||
batch_size = self.config.get("embedding_batch_size", 1000)
|
||||
embeddings = []
|
||||
embedding_cost = 0.0
|
||||
num_batches = (len(input_data) + batch_size - 1) // batch_size
|
||||
|
||||
for batch_idx in range(num_batches):
|
||||
start_idx = batch_idx * batch_size
|
||||
end_idx = min(start_idx + batch_size, len(input_data))
|
||||
batch = input_data[start_idx:end_idx]
|
||||
|
||||
if num_batches > 1:
|
||||
self.console.log(
|
||||
f"[dim]Creating embeddings: batch {batch_idx + 1}/{num_batches} "
|
||||
f"({end_idx}/{len(input_data)} items)[/dim]"
|
||||
)
|
||||
|
||||
texts = [
|
||||
" ".join(
|
||||
str(item[key]) for key in blocking_keys if key in item
|
||||
)[: model_input_context_length * 3]
|
||||
for item in batch
|
||||
texts = [
|
||||
" ".join(str(item[key]) for key in blocking_keys if key in item)[
|
||||
: model_input_context_length * 3
|
||||
]
|
||||
response = self.runner.api.gen_embedding(
|
||||
model=embedding_model, input=texts
|
||||
)
|
||||
embeddings.extend([data["embedding"] for data in response["data"]])
|
||||
embedding_cost += completion_cost(response)
|
||||
for item in items
|
||||
]
|
||||
|
||||
total_cost += embedding_cost
|
||||
response = self.runner.api.gen_embedding(
|
||||
model=embedding_model, input=texts
|
||||
)
|
||||
return [
|
||||
(data["embedding"], completion_cost(response))
|
||||
for data in response["data"]
|
||||
]
|
||||
|
||||
# Build a mapping of blocking key values to indices
|
||||
# This is used later for cluster merging (when two items match, merge all items sharing their key values)
|
||||
value_to_indices: dict[tuple[str, ...], list[int]] = {}
|
||||
for i, item in enumerate(input_data):
|
||||
key = tuple(str(item.get(k, "")) for k in blocking_keys)
|
||||
if key not in value_to_indices:
|
||||
value_to_indices[key] = []
|
||||
value_to_indices[key].append(i)
|
||||
embeddings = []
|
||||
costs = []
|
||||
with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
|
||||
for i in range(
|
||||
0, len(input_data), self.config.get("embedding_batch_size", 1000)
|
||||
):
|
||||
batch = input_data[
|
||||
i : i + self.config.get("embedding_batch_size", 1000)
|
||||
]
|
||||
batch_results = list(executor.map(get_embeddings_batch, [batch]))
|
||||
|
||||
# Total number of pairs to potentially compare
|
||||
n = len(input_data)
|
||||
total_pairs = n * (n - 1) // 2
|
||||
for result in batch_results:
|
||||
embeddings.extend([r[0] for r in result])
|
||||
costs.extend([r[1] for r in result])
|
||||
|
||||
# Apply code-based blocking conditions (check all pairs)
|
||||
code_blocked_pairs = []
|
||||
if blocking_conditions:
|
||||
for i in range(n):
|
||||
for j in range(i + 1, n):
|
||||
if is_match(input_data[i], input_data[j]):
|
||||
code_blocked_pairs.append((i, j))
|
||||
total_cost += sum(costs)
|
||||
|
||||
# Generate all pairs to compare, ensuring no duplicate comparisons
|
||||
def get_unique_comparison_pairs() -> (
|
||||
tuple[list[tuple[int, int]], dict[tuple[str, ...], list[int]]]
|
||||
):
|
||||
# Create a mapping of values to their indices
|
||||
value_to_indices: dict[tuple[str, ...], list[int]] = {}
|
||||
for i, item in enumerate(input_data):
|
||||
# Create a hashable key from the blocking keys
|
||||
key = tuple(str(item.get(k, "")) for k in blocking_keys)
|
||||
if key not in value_to_indices:
|
||||
value_to_indices[key] = []
|
||||
value_to_indices[key].append(i)
|
||||
|
||||
# Generate pairs for comparison, comparing each unique value combination only once
|
||||
comparison_pairs = []
|
||||
keys = list(value_to_indices.keys())
|
||||
|
||||
# First, handle comparisons between different values
|
||||
for i in range(len(keys)):
|
||||
for j in range(i + 1, len(keys)):
|
||||
# Only need one comparison between different values
|
||||
idx1 = value_to_indices[keys[i]][0]
|
||||
idx2 = value_to_indices[keys[j]][0]
|
||||
if idx1 < idx2: # Maintain ordering to avoid duplicates
|
||||
comparison_pairs.append((idx1, idx2))
|
||||
|
||||
return comparison_pairs, value_to_indices
|
||||
|
||||
comparison_pairs, value_to_indices = get_unique_comparison_pairs()
|
||||
|
||||
# Filter pairs based on blocking conditions
|
||||
def meets_blocking_conditions(pair: tuple[int, int]) -> bool:
|
||||
i, j = pair
|
||||
return (
|
||||
is_match(input_data[i], input_data[j]) if blocking_conditions else False
|
||||
)
|
||||
|
||||
# Start with pairs that meet blocking conditions, or empty list if no conditions
|
||||
code_blocked_pairs = (
|
||||
list(filter(meets_blocking_conditions, comparison_pairs))
|
||||
if blocking_conditions
|
||||
else []
|
||||
)
|
||||
|
||||
# Apply cosine similarity blocking if threshold is specified
|
||||
embedding_blocked_pairs = []
|
||||
if blocking_threshold is not None and embeddings is not None:
|
||||
import numpy as np
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
|
||||
similarity_matrix = cosine_similarity(embeddings)
|
||||
|
||||
# Add pairs that meet the cosine similarity threshold and aren't already blocked
|
||||
code_blocked_set = set(code_blocked_pairs)
|
||||
|
||||
# Use numpy to efficiently find all pairs above threshold
|
||||
i_indices, j_indices = np.triu_indices(n, k=1)
|
||||
similarities = similarity_matrix[i_indices, j_indices]
|
||||
above_threshold_mask = similarities >= blocking_threshold
|
||||
for i, j in comparison_pairs:
|
||||
if (i, j) not in code_blocked_set:
|
||||
similarity = similarity_matrix[i, j]
|
||||
if similarity >= blocking_threshold:
|
||||
embedding_blocked_pairs.append((i, j))
|
||||
|
||||
# Get pairs above threshold
|
||||
above_threshold_i = i_indices[above_threshold_mask]
|
||||
above_threshold_j = j_indices[above_threshold_mask]
|
||||
self.console.log(
|
||||
f"Cosine similarity blocking: added {len(embedding_blocked_pairs)} pairs "
|
||||
f"(threshold: {blocking_threshold})"
|
||||
)
|
||||
|
||||
# Filter out pairs already in code_blocked_set
|
||||
embedding_blocked_pairs = [
|
||||
(int(i), int(j))
|
||||
for i, j in zip(above_threshold_i, above_threshold_j)
|
||||
if (i, j) not in code_blocked_set
|
||||
]
|
||||
|
||||
# Combine pairs from both blocking methods
|
||||
# Combine pairs with prioritization for sampling
|
||||
all_blocked_pairs = code_blocked_pairs + embedding_blocked_pairs
|
||||
|
||||
# If no blocking was applied, compare all pairs
|
||||
if not blocking_conditions and blocking_threshold is None:
|
||||
all_blocked_pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]
|
||||
# If no pairs are blocked at all, fall back to all comparison pairs
|
||||
if not all_blocked_pairs:
|
||||
all_blocked_pairs = comparison_pairs
|
||||
# Apply limit_comparisons with prioritization
|
||||
if limit_comparisons is not None and len(all_blocked_pairs) > limit_comparisons:
|
||||
# Prioritize code-based pairs, then sample from embedding pairs if needed
|
||||
|
|
@ -507,6 +431,18 @@ class ResolveOperation(BaseOperation):
|
|||
cluster_map[root_idx] = root1
|
||||
clusters[root_idx] = set()
|
||||
|
||||
# Calculate and print statistics
|
||||
total_possible_comparisons = len(input_data) * (len(input_data) - 1) // 2
|
||||
comparisons_made = len(blocked_pairs)
|
||||
comparisons_saved = total_possible_comparisons - comparisons_made
|
||||
self.console.log(
|
||||
f"[green]Comparisons saved by deduping and blocking: {comparisons_saved} "
|
||||
f"({(comparisons_saved / total_possible_comparisons) * 100:.2f}%)[/green]"
|
||||
)
|
||||
self.console.log(
|
||||
f"[blue]Number of pairs to compare: {len(blocked_pairs)}[/blue]"
|
||||
)
|
||||
|
||||
# Compute an auto-batch size based on the number of comparisons
|
||||
def auto_batch() -> int:
|
||||
# Maximum batch size limit for 4o-mini model
|
||||
|
|
@ -532,14 +468,7 @@ class ResolveOperation(BaseOperation):
|
|||
|
||||
# Compare pairs and update clusters in real-time
|
||||
batch_size = self.config.get("compare_batch_size", auto_batch())
|
||||
|
||||
# Log blocking summary
|
||||
total_possible_comparisons = len(input_data) * (len(input_data) - 1) // 2
|
||||
self.console.log(
|
||||
f"Comparing {len(blocked_pairs):,} pairs "
|
||||
f"({len(blocked_pairs)/total_possible_comparisons*100:.1f}% of {total_possible_comparisons:,} total, "
|
||||
f"batch size: {batch_size})"
|
||||
)
|
||||
self.console.log(f"Using compare batch size: {batch_size}")
|
||||
pair_costs = 0
|
||||
|
||||
pbar = RichLoopBar(
|
||||
|
|
|
|||
|
|
@@ -1,5 +1,4 @@
from .api import APIWrapper
from .blocking import RuntimeBlockingOptimizer
from .cache import (
    cache,
    cache_key,

@@ -16,7 +15,6 @@ from .validation import safe_eval, convert_val, convert_dict_schema_to_list_sche

__all__ = [
    'APIWrapper',
    'RuntimeBlockingOptimizer',
    'cache',
    'cache_key',
    'clear_cache',
|
|
|||
|
|
@ -86,35 +86,23 @@ class APIWrapper(object):
|
|||
from litellm import Router
|
||||
|
||||
# Build model list: operation model first, then fallbacks
|
||||
model_list = [
|
||||
{
|
||||
"model_name": operation_model,
|
||||
"litellm_params": {
|
||||
"model": operation_model,
|
||||
**(
|
||||
{"api_base": self.default_lm_api_base}
|
||||
if self.default_lm_api_base
|
||||
else {}
|
||||
),
|
||||
},
|
||||
model_list = [{
|
||||
"model_name": operation_model,
|
||||
"litellm_params": {
|
||||
"model": operation_model,
|
||||
**({"api_base": self.default_lm_api_base} if self.default_lm_api_base else {})
|
||||
}
|
||||
]
|
||||
}]
|
||||
model_names = [operation_model]
|
||||
|
||||
# Add fallback models, skipping duplicates
|
||||
seen = {operation_model}
|
||||
for cfg in self.fallback_models_config:
|
||||
name = (
|
||||
cfg.get("model_name")
|
||||
if isinstance(cfg, dict)
|
||||
else (cfg if isinstance(cfg, str) else None)
|
||||
)
|
||||
name = cfg.get("model_name") if isinstance(cfg, dict) else (cfg if isinstance(cfg, str) else None)
|
||||
if not name or name in seen:
|
||||
continue
|
||||
seen.add(name)
|
||||
params = (
|
||||
cfg.get("litellm_params", {}).copy() if isinstance(cfg, dict) else {}
|
||||
)
|
||||
params = cfg.get("litellm_params", {}).copy() if isinstance(cfg, dict) else {}
|
||||
params["model"] = name
|
||||
if self.default_lm_api_base and "api_base" not in params:
|
||||
params["api_base"] = self.default_lm_api_base
|
||||
|
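The hunk above builds a litellm Router `model_list` that starts with the operation's model and then appends de-duplicated fallback models, pinning an `api_base` when one is configured. A minimal sketch of assembling such a list, simplified to string fallbacks (model names here are placeholders):

```python
def build_model_list(operation_model: str, fallbacks: list[str],
                     api_base: str | None = None) -> list[dict]:
    entries, seen = [], set()
    for name in [operation_model, *fallbacks]:
        if not name or name in seen:
            continue  # skip blanks and duplicates, operation model stays first
        seen.add(name)
        params = {"model": name}
        if api_base:
            params["api_base"] = api_base
        entries.append({"model_name": name, "litellm_params": params})
    return entries


print(build_model_list("gpt-4o-mini", ["gpt-4o", "gpt-4o-mini"]))
```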
|
@ -183,11 +171,7 @@ class APIWrapper(object):
|
|||
extra_kwargs["api_base"] = self.default_embedding_api_base
|
||||
|
||||
# Use embedding router if available (for fallback models)
|
||||
embedding_fn = (
|
||||
self.embedding_router.embedding
|
||||
if self.embedding_router
|
||||
else embedding
|
||||
)
|
||||
embedding_fn = self.embedding_router.embedding if self.embedding_router else embedding
|
||||
result = embedding_fn(model=model, input=input, **extra_kwargs)
|
||||
# Cache the result
|
||||
c.set(key, result)
|
||||
|
|
@ -277,10 +261,6 @@ class APIWrapper(object):
|
|||
):
|
||||
model = "azure/" + model
|
||||
|
||||
# Pop off temperature if it's gpt-5 in the model name
|
||||
if "gpt-5" in model:
|
||||
litellm_completion_kwargs.pop("temperature", None)
|
||||
|
||||
total_cost = 0.0
|
||||
validated = False
|
||||
with cache as c:
|
||||
|
|
@ -381,9 +361,7 @@ class APIWrapper(object):
|
|||
# Use router if available (for fallback models), otherwise use direct completion
|
||||
# When using router, ensure gleaning model is tried first, then fallback models
|
||||
if self.router and self.fallback_models_config:
|
||||
completion_fn = self._get_router_with_operation_model(
|
||||
gleaning_model
|
||||
)
|
||||
completion_fn = self._get_router_with_operation_model(gleaning_model)
|
||||
else:
|
||||
completion_fn = completion
|
||||
|
||||
|
|
@ -870,10 +848,6 @@ Your main result must be sent via send_output. The updated_scratchpad is only fo
|
|||
else:
|
||||
completion_fn = completion
|
||||
|
||||
# Pop off temperature if it's gpt-5 in the model name
|
||||
if "gpt-5" in model:
|
||||
extra_litellm_kwargs.pop("temperature", None)
|
||||
|
||||
if use_structured_output:
|
||||
try:
|
||||
response = completion_fn(
|
||||
|
|
|
|||
|
|
@ -1,567 +0,0 @@
|
|||
"""
|
||||
Runtime blocking threshold optimization utilities.
|
||||
|
||||
This module provides functionality for automatically computing embedding-based
|
||||
blocking thresholds at runtime when no blocking configuration is provided.
|
||||
"""
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Any, Callable
|
||||
|
||||
import numpy as np
|
||||
from litellm import model_cost
|
||||
from rich.console import Console
|
||||
|
||||
from docetl.utils import completion_cost, extract_jinja_variables
|
||||
|
||||
|
||||
class RuntimeBlockingOptimizer:
|
||||
"""
|
||||
Computes optimal embedding-based blocking thresholds at runtime.
|
||||
|
||||
This class samples pairs from the dataset, performs LLM comparisons,
|
||||
and finds the optimal cosine similarity threshold that achieves a
|
||||
target recall rate.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
runner,
|
||||
config: dict[str, Any],
|
||||
default_model: str,
|
||||
max_threads: int,
|
||||
console: Console,
|
||||
target_recall: float = 0.95,
|
||||
sample_size: int = 100,
|
||||
sampling_weight: float = 20.0,
|
||||
):
|
||||
"""
|
||||
Initialize the RuntimeBlockingOptimizer.
|
||||
|
||||
Args:
|
||||
runner: The pipeline runner instance.
|
||||
config: Operation configuration.
|
||||
default_model: Default LLM model for comparisons.
|
||||
max_threads: Maximum threads for parallel processing.
|
||||
console: Rich console for logging.
|
||||
target_recall: Target recall rate (default 0.95).
|
||||
sample_size: Number of pairs to sample for threshold estimation.
|
||||
sampling_weight: Weight for exponential sampling towards higher similarities.
|
||||
"""
|
||||
self.runner = runner
|
||||
self.config = config
|
||||
self.default_model = default_model
|
||||
self.max_threads = max_threads
|
||||
self.console = console
|
||||
self.target_recall = target_recall
|
||||
self.sample_size = sample_size
|
||||
self.sampling_weight = sampling_weight
|
||||
|
||||
def compute_embeddings(
|
||||
self,
|
||||
input_data: list[dict[str, Any]],
|
||||
keys: list[str],
|
||||
embedding_model: str | None = None,
|
||||
batch_size: int = 1000,
|
||||
) -> tuple[list[list[float]], float]:
|
||||
"""
|
||||
Compute embeddings for the input data.
|
||||
|
||||
Args:
|
||||
input_data: List of input documents.
|
||||
keys: Keys to use for embedding text.
|
||||
embedding_model: Model to use for embeddings.
|
||||
batch_size: Batch size for embedding computation.
|
||||
|
||||
Returns:
|
||||
Tuple of (embeddings list, total cost).
|
||||
"""
|
||||
embedding_model = embedding_model or self.config.get(
|
||||
"embedding_model", "text-embedding-3-small"
|
||||
)
|
||||
model_input_context_length = model_cost.get(embedding_model, {}).get(
|
||||
"max_input_tokens", 8192
|
||||
)
|
||||
texts = [
|
||||
" ".join(str(item[key]) for key in keys if key in item)[
|
||||
: model_input_context_length * 3
|
||||
]
|
||||
for item in input_data
|
||||
]
|
||||
|
||||
self.console.log(f"[cyan]Creating embeddings for {len(texts)} items...[/cyan]")
|
||||
|
||||
embeddings = []
|
||||
total_cost = 0.0
|
||||
num_batches = (len(texts) + batch_size - 1) // batch_size
|
||||
for batch_idx, i in enumerate(range(0, len(texts), batch_size)):
|
||||
batch = texts[i : i + batch_size]
|
||||
if num_batches > 1:
|
||||
self.console.log(
|
||||
f"[dim] Batch {batch_idx + 1}/{num_batches} "
|
||||
f"({len(embeddings) + len(batch)}/{len(texts)} items)[/dim]"
|
||||
)
|
||||
response = self.runner.api.gen_embedding(
|
||||
model=embedding_model,
|
||||
input=batch,
|
||||
)
|
||||
embeddings.extend([data["embedding"] for data in response["data"]])
|
||||
total_cost += completion_cost(response)
|
||||
return embeddings, total_cost
|
||||
|
||||
def calculate_cosine_similarities_self(
|
||||
self, embeddings: list[list[float]]
|
||||
) -> list[tuple[int, int, float]]:
|
||||
"""
|
||||
Calculate pairwise cosine similarities for self-join.
|
||||
|
||||
Args:
|
||||
embeddings: List of embedding vectors.
|
||||
|
||||
Returns:
|
||||
List of (i, j, similarity) tuples for all pairs where i < j.
|
||||
"""
|
||||
embeddings_array = np.array(embeddings)
|
||||
norms = np.linalg.norm(embeddings_array, axis=1)
|
||||
# Avoid division by zero
|
||||
norms = np.where(norms == 0, 1e-10, norms)
|
||||
dot_products = np.dot(embeddings_array, embeddings_array.T)
|
||||
similarities_matrix = dot_products / np.outer(norms, norms)
|
||||
i, j = np.triu_indices(len(embeddings), k=1)
|
||||
similarities = list(
|
||||
zip(i.tolist(), j.tolist(), similarities_matrix[i, j].tolist())
|
||||
)
|
||||
return similarities
|
||||
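`calculate_cosine_similarities_self` above normalizes the embedding matrix and reads the upper triangle of the dot-product matrix so each unordered pair is scored once. The same computation in a small standalone form:

```python
import numpy as np

def pairwise_cosine(embeddings: list[list[float]]) -> list[tuple[int, int, float]]:
    arr = np.asarray(embeddings, dtype=float)
    norms = np.linalg.norm(arr, axis=1)
    norms = np.where(norms == 0, 1e-10, norms)      # guard against zero vectors
    sims = (arr @ arr.T) / np.outer(norms, norms)   # full cosine similarity matrix
    i, j = np.triu_indices(len(arr), k=1)           # each pair (i, j) with i < j once
    return list(zip(i.tolist(), j.tolist(), sims[i, j].tolist()))

print(pairwise_cosine([[1, 0], [0.9, 0.1], [0, 1]]))
```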
|
||||
def calculate_cosine_similarities_cross(
|
||||
self,
|
||||
left_embeddings: list[list[float]],
|
||||
right_embeddings: list[list[float]],
|
||||
) -> list[tuple[int, int, float]]:
|
||||
"""
|
||||
Calculate cosine similarities between two sets of embeddings.
|
||||
|
||||
Args:
|
||||
left_embeddings: Embeddings for left dataset.
|
||||
right_embeddings: Embeddings for right dataset.
|
||||
|
||||
Returns:
|
||||
List of (left_idx, right_idx, similarity) tuples.
|
||||
"""
|
||||
left_array = np.array(left_embeddings)
|
||||
right_array = np.array(right_embeddings)
|
||||
dot_product = np.dot(left_array, right_array.T)
|
||||
norm_left = np.linalg.norm(left_array, axis=1)
|
||||
norm_right = np.linalg.norm(right_array, axis=1)
|
||||
# Avoid division by zero
|
||||
norm_left = np.where(norm_left == 0, 1e-10, norm_left)
|
||||
norm_right = np.where(norm_right == 0, 1e-10, norm_right)
|
||||
similarities = dot_product / np.outer(norm_left, norm_right)
|
||||
return [
|
||||
(i, j, float(sim))
|
||||
for i, row in enumerate(similarities)
|
||||
for j, sim in enumerate(row)
|
||||
]
|
||||
|
||||
def sample_pairs(
|
||||
self,
|
||||
similarities: list[tuple[int, int, float]],
|
||||
num_bins: int = 10,
|
||||
stratified_fraction: float = 0.5,
|
||||
) -> list[tuple[int, int]]:
|
||||
"""
|
||||
Sample pairs using a hybrid of stratified and exponential-weighted sampling.
|
||||
|
||||
This ensures coverage across the similarity distribution while still
|
||||
focusing on high-similarity pairs where matches are more likely.
|
||||
|
||||
Args:
|
||||
similarities: List of (i, j, similarity) tuples.
|
||||
num_bins: Number of bins for stratified sampling.
|
||||
stratified_fraction: Fraction of samples to allocate to stratified sampling.
|
||||
|
||||
Returns:
|
||||
List of sampled (i, j) pairs.
|
||||
"""
|
||||
if len(similarities) == 0:
|
||||
return []
|
||||
|
||||
sample_count = min(self.sample_size, len(similarities))
|
||||
stratified_count = int(sample_count * stratified_fraction)
|
||||
exponential_count = sample_count - stratified_count
|
||||
|
||||
sampled_indices = set()
|
||||
sim_values = np.array([sim[2] for sim in similarities])
|
||||
|
||||
# Part 1: Stratified sampling across bins
|
||||
if stratified_count > 0:
|
||||
bin_edges = np.linspace(
|
||||
sim_values.min(), sim_values.max() + 1e-9, num_bins + 1
|
||||
)
|
||||
samples_per_bin = max(1, stratified_count // num_bins)
|
||||
|
||||
for bin_idx in range(num_bins):
|
||||
bin_mask = (sim_values >= bin_edges[bin_idx]) & (
|
||||
sim_values < bin_edges[bin_idx + 1]
|
||||
)
|
||||
bin_indices = np.where(bin_mask)[0]
|
||||
|
||||
if len(bin_indices) > 0:
|
||||
# Within each bin, use exponential weighting
|
||||
bin_sims = sim_values[bin_indices]
|
||||
bin_weights = np.exp(self.sampling_weight * bin_sims)
|
||||
bin_weights /= bin_weights.sum()
|
||||
|
||||
n_to_sample = min(samples_per_bin, len(bin_indices))
|
||||
chosen = np.random.choice(
|
||||
bin_indices,
|
||||
size=n_to_sample,
|
||||
replace=False,
|
||||
p=bin_weights,
|
||||
)
|
||||
sampled_indices.update(chosen.tolist())
|
||||
|
||||
# Part 2: Exponential-weighted sampling for remaining slots
|
||||
if exponential_count > 0:
|
||||
remaining_indices = [
|
||||
i for i in range(len(similarities)) if i not in sampled_indices
|
||||
]
|
||||
if remaining_indices:
|
||||
remaining_sims = sim_values[remaining_indices]
|
||||
weights = np.exp(self.sampling_weight * remaining_sims)
|
||||
weights /= weights.sum()
|
||||
|
||||
n_to_sample = min(exponential_count, len(remaining_indices))
|
||||
chosen = np.random.choice(
|
||||
remaining_indices,
|
||||
size=n_to_sample,
|
||||
replace=False,
|
||||
p=weights,
|
||||
)
|
||||
sampled_indices.update(chosen.tolist())
|
||||
|
||||
sampled_pairs = [
|
||||
(similarities[i][0], similarities[i][1]) for i in sampled_indices
|
||||
]
|
||||
return sampled_pairs
|
||||
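`sample_pairs` mixes stratified sampling across similarity bins with exponential weighting toward high-similarity pairs, where matches are more likely. The weighting step alone looks roughly like this (a sketch, not the full hybrid scheme):

```python
import numpy as np

def weighted_sample(similarities: list[float], k: int, weight: float = 20.0) -> list[int]:
    sims = np.asarray(similarities, dtype=float)
    probs = np.exp(weight * sims)
    probs /= probs.sum()          # softmax-style weights favor high-similarity pairs
    k = min(k, len(sims))
    return np.random.choice(len(sims), size=k, replace=False, p=probs).tolist()

print(weighted_sample([0.1, 0.5, 0.92, 0.95, 0.3], k=3))
```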
|
||||
def _print_similarity_histogram(
|
||||
self,
|
||||
similarities: list[tuple[int, int, float]],
|
||||
comparison_results: list[tuple[int, int, bool]],
|
||||
threshold: float | None = None,
|
||||
):
|
||||
"""
|
||||
Print a histogram of embedding cosine similarity distribution.
|
||||
|
||||
Args:
|
||||
similarities: List of (i, j, similarity) tuples.
|
||||
comparison_results: List of (i, j, is_match) from LLM comparisons.
|
||||
threshold: Optional threshold to highlight in the histogram.
|
||||
"""
|
||||
# Filter out self-similarities (similarity == 1)
|
||||
flat_similarities = [sim[2] for sim in similarities if sim[2] != 1]
|
||||
if not flat_similarities:
|
||||
return
|
||||
|
||||
hist, bin_edges = np.histogram(flat_similarities, bins=20)
|
||||
max_bar_width, max_count = 40, max(hist) if max(hist) > 0 else 1
|
||||
normalized_hist = [int(count / max_count * max_bar_width) for count in hist]
|
||||
|
||||
# Create a dictionary to store true labels
|
||||
true_labels = {(i, j): is_match for i, j, is_match in comparison_results}
|
||||
|
||||
# Count pairs above threshold
|
||||
pairs_above_threshold = (
|
||||
sum(1 for sim in flat_similarities if sim >= threshold) if threshold else 0
|
||||
)
|
||||
total_pairs = len(flat_similarities)
|
||||
|
||||
lines = []
|
||||
for i, count in enumerate(normalized_hist):
|
||||
bar = "█" * count
|
||||
bin_start, bin_end = bin_edges[i], bin_edges[i + 1]
|
||||
label = f"{bin_start:.2f}-{bin_end:.2f}"
|
||||
|
||||
# Count true matches and not matches in this bin
|
||||
true_matches = 0
|
||||
not_matches = 0
|
||||
labeled_count = 0
|
||||
for sim in similarities:
|
||||
if bin_start <= sim[2] < bin_end:
|
||||
if (sim[0], sim[1]) in true_labels:
|
||||
labeled_count += 1
|
||||
if true_labels[(sim[0], sim[1])]:
|
||||
true_matches += 1
|
||||
else:
|
||||
not_matches += 1
|
||||
|
||||
# Calculate percentages of labeled pairs
|
||||
if labeled_count > 0:
|
||||
true_match_percent = (true_matches / labeled_count) * 100
|
||||
label_info = f"[green]{true_match_percent:5.1f}%[/green] match"
|
||||
else:
|
||||
label_info = "[dim]--[/dim]"
|
||||
|
||||
# Highlight the bin containing the threshold
|
||||
if threshold is not None and bin_start <= threshold < bin_end:
|
||||
lines.append(
|
||||
f"[bold yellow]{label}[/bold yellow] {bar:<{max_bar_width}} "
|
||||
f"[dim]n={hist[i]:>5}[/dim] {label_info} [bold yellow]◀ threshold[/bold yellow]"
|
||||
)
|
||||
else:
|
||||
lines.append(
|
||||
f"{label} {bar:<{max_bar_width}} "
|
||||
f"[dim]n={hist[i]:>5}[/dim] {label_info}"
|
||||
)
|
||||
|
||||
from rich.panel import Panel
|
||||
|
||||
histogram_content = "\n".join(lines)
|
||||
title = f"Similarity Distribution ({pairs_above_threshold:,} of {total_pairs:,} pairs ≥ {threshold:.4f})"
|
||||
self.console.log(Panel(histogram_content, title=title, border_style="cyan"))
|
||||
|
||||
def find_optimal_threshold(
|
||||
self,
|
||||
comparisons: list[tuple[int, int, bool]],
|
||||
similarities: list[tuple[int, int, float]],
|
||||
) -> tuple[float, float]:
|
||||
"""
|
||||
Find the optimal similarity threshold that achieves target recall.
|
||||
|
||||
Args:
|
||||
comparisons: List of (i, j, is_match) from LLM comparisons.
|
||||
similarities: List of (i, j, similarity) tuples.
|
||||
|
||||
Returns:
|
||||
Tuple of (optimal_threshold, achieved_recall).
|
||||
"""
|
||||
if not comparisons or not any(comp[2] for comp in comparisons):
|
||||
# No matches found, use a high threshold to be conservative
|
||||
self.console.log(
|
||||
"[yellow]No matches found in sample. Using 99th percentile "
|
||||
"similarity as threshold.[/yellow]"
|
||||
)
|
||||
all_sims = [sim[2] for sim in similarities]
|
||||
threshold = float(np.percentile(all_sims, 99)) if all_sims else 0.9
|
||||
return threshold, 0.0
|
||||
|
||||
true_labels = np.array([comp[2] for comp in comparisons])
|
||||
sim_dict = {(i, j): sim for i, j, sim in similarities}
|
||||
sim_scores = np.array([sim_dict.get((i, j), 0.0) for i, j, _ in comparisons])
|
||||
thresholds = np.linspace(0, 1, 100)
|
||||
recalls = []
|
||||
for threshold in thresholds:
|
||||
predictions = sim_scores >= threshold
|
||||
tp = np.sum(predictions & true_labels)
|
||||
fn = np.sum(~predictions & true_labels)
|
||||
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
|
||||
recalls.append(recall)
|
||||
|
||||
# Find highest threshold that achieves target recall
|
||||
valid_indices = [i for i, r in enumerate(recalls) if r >= self.target_recall]
|
||||
if not valid_indices:
|
||||
# If no threshold achieves target recall, use the one with highest recall
|
||||
best_idx = int(np.argmax(recalls))
|
||||
optimal_threshold = float(thresholds[best_idx])
|
||||
achieved_recall = float(recalls[best_idx])
|
||||
self.console.log(
|
||||
f"[yellow]Warning: Could not achieve target recall {self.target_recall:.0%}. "
|
||||
f"Using threshold {optimal_threshold:.4f} with recall {achieved_recall:.2%}.[/yellow]"
|
||||
)
|
||||
else:
|
||||
best_idx = max(valid_indices)
|
||||
optimal_threshold = float(thresholds[best_idx])
|
||||
achieved_recall = float(recalls[best_idx])
|
||||
|
||||
return round(optimal_threshold, 4), achieved_recall
|
||||
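`find_optimal_threshold` sweeps candidate thresholds over the sampled, LLM-labeled pairs and keeps the highest threshold whose recall still meets the target. A compact reimplementation of that sweep with simplified inputs:

```python
import numpy as np

def pick_threshold(scores: list[float], labels: list[bool],
                   target_recall: float = 0.95) -> float:
    scores_arr = np.asarray(scores, dtype=float)
    labels_arr = np.asarray(labels, dtype=bool)
    best = 0.0
    for t in np.linspace(0, 1, 100):
        pred = scores_arr >= t
        tp = np.sum(pred & labels_arr)
        fn = np.sum(~pred & labels_arr)
        recall = tp / (tp + fn) if (tp + fn) else 0.0
        if recall >= target_recall:
            best = float(t)   # keep the highest threshold that still meets recall
    return round(best, 4)

print(pick_threshold([0.95, 0.9, 0.4, 0.85, 0.2], [True, True, False, True, False]))
```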
|
||||
def optimize_resolve(
|
||||
self,
|
||||
input_data: list[dict[str, Any]],
|
||||
compare_fn: Callable[[dict, dict], tuple[bool, float, str]],
|
||||
blocking_keys: list[str] | None = None,
|
||||
) -> tuple[float, list[list[float]], float]:
|
||||
"""
|
||||
Compute optimal blocking threshold for resolve operation.
|
||||
|
||||
Args:
|
||||
input_data: Input dataset.
|
||||
compare_fn: Function to compare two items, returns (is_match, cost, prompt).
|
||||
blocking_keys: Keys to use for blocking. If None, extracted from prompt.
|
||||
|
||||
Returns:
|
||||
Tuple of (threshold, embeddings, total_cost).
|
||||
"""
|
||||
from rich.panel import Panel
|
||||
|
||||
# Determine blocking keys
|
||||
if not blocking_keys:
|
||||
prompt_template = self.config.get("comparison_prompt", "")
|
||||
prompt_vars = extract_jinja_variables(prompt_template)
|
||||
prompt_vars = [
|
||||
var for var in prompt_vars if var not in ["input", "input1", "input2"]
|
||||
]
|
||||
blocking_keys = list(set([var.split(".")[-1] for var in prompt_vars]))
|
||||
if not blocking_keys:
|
||||
blocking_keys = list(input_data[0].keys())
|
||||
|
||||
# Compute embeddings
|
||||
embeddings, embedding_cost = self.compute_embeddings(input_data, blocking_keys)
|
||||
|
||||
# Calculate similarities
|
||||
similarities = self.calculate_cosine_similarities_self(embeddings)
|
||||
|
||||
# Sample pairs
|
||||
sampled_pairs = self.sample_pairs(similarities)
|
||||
if not sampled_pairs:
|
||||
self.console.log(
|
||||
"[yellow]No pairs to sample. Using default threshold 0.8.[/yellow]"
|
||||
)
|
||||
return 0.8, embeddings, embedding_cost
|
||||
|
||||
# Perform comparisons
|
||||
comparisons = []
|
||||
comparison_cost = 0.0
|
||||
matches_found = 0
|
||||
with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
|
||||
futures = {
|
||||
executor.submit(compare_fn, input_data[i], input_data[j]): (i, j)
|
||||
for i, j in sampled_pairs
|
||||
}
|
||||
for future in as_completed(futures):
|
||||
i, j = futures[future]
|
||||
try:
|
||||
is_match, cost, _ = future.result()
|
||||
comparisons.append((i, j, is_match))
|
||||
comparison_cost += cost
|
||||
if is_match:
|
||||
matches_found += 1
|
||||
except Exception as e:
|
||||
self.console.log(f"[red]Comparison error: {e}[/red]")
|
||||
comparisons.append((i, j, False))
|
||||
|
||||
# Find optimal threshold
|
||||
threshold, achieved_recall = self.find_optimal_threshold(
|
||||
comparisons, similarities
|
||||
)
|
||||
total_cost = embedding_cost + comparison_cost
|
||||
|
||||
# Print histogram visualization
|
||||
self._print_similarity_histogram(similarities, comparisons, threshold)
|
||||
|
||||
# Print summary
|
||||
n = len(input_data)
|
||||
total_pairs = n * (n - 1) // 2
|
||||
pairs_above = sum(1 for s in similarities if s[2] >= threshold)
|
||||
|
||||
summary = (
|
||||
f"[bold]Blocking keys:[/bold] {blocking_keys}\n"
|
||||
f"[bold]Sampled:[/bold] {len(sampled_pairs)} pairs → {matches_found} matches ({matches_found/len(sampled_pairs)*100:.1f}%)\n"
|
||||
f"[bold]Threshold:[/bold] {threshold:.4f} → {achieved_recall:.1%} recall (target: {self.target_recall:.0%})\n"
|
||||
f"[bold]Pairs to compare:[/bold] {pairs_above:,} of {total_pairs:,} ({pairs_above/total_pairs*100:.1f}%)\n"
|
||||
f"[bold]Optimization cost:[/bold] ${total_cost:.4f}"
|
||||
)
|
||||
self.console.log(
|
||||
Panel(
|
||||
summary, title="Blocking Threshold Optimization", border_style="green"
|
||||
)
|
||||
)
|
||||
|
||||
return threshold, embeddings, total_cost
|
||||
|
||||
def optimize_equijoin(
|
||||
self,
|
||||
left_data: list[dict[str, Any]],
|
||||
right_data: list[dict[str, Any]],
|
||||
compare_fn: Callable[[dict, dict], tuple[bool, float]],
|
||||
left_keys: list[str] | None = None,
|
||||
right_keys: list[str] | None = None,
|
||||
) -> tuple[float, list[list[float]], list[list[float]], float]:
|
||||
"""
|
||||
Compute optimal blocking threshold for equijoin operation.
|
||||
|
||||
Args:
|
||||
left_data: Left dataset.
|
||||
right_data: Right dataset.
|
||||
compare_fn: Function to compare two items, returns (is_match, cost).
|
||||
left_keys: Keys to use for left dataset embeddings.
|
||||
right_keys: Keys to use for right dataset embeddings.
|
||||
|
||||
Returns:
|
||||
Tuple of (threshold, left_embeddings, right_embeddings, total_cost).
|
||||
"""
|
||||
from rich.panel import Panel
|
||||
|
||||
# Determine keys
|
||||
if not left_keys:
|
||||
left_keys = list(left_data[0].keys()) if left_data else []
|
||||
if not right_keys:
|
||||
right_keys = list(right_data[0].keys()) if right_data else []
|
||||
|
||||
# Compute embeddings
|
||||
left_embeddings, left_cost = self.compute_embeddings(left_data, left_keys)
|
||||
right_embeddings, right_cost = self.compute_embeddings(right_data, right_keys)
|
||||
embedding_cost = left_cost + right_cost
|
||||
|
||||
# Calculate cross similarities
|
||||
similarities = self.calculate_cosine_similarities_cross(
|
||||
left_embeddings, right_embeddings
|
||||
)
|
||||
|
||||
# Sample pairs
|
||||
sampled_pairs = self.sample_pairs(similarities)
|
||||
if not sampled_pairs:
|
||||
self.console.log(
|
||||
"[yellow]No pairs to sample. Using default threshold 0.8.[/yellow]"
|
||||
)
|
||||
return 0.8, left_embeddings, right_embeddings, embedding_cost
|
||||
|
||||
# Perform comparisons
|
||||
comparisons = []
|
||||
comparison_cost = 0.0
|
||||
matches_found = 0
|
||||
with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
|
||||
futures = {
|
||||
executor.submit(compare_fn, left_data[i], right_data[j]): (i, j)
|
||||
for i, j in sampled_pairs
|
||||
}
|
||||
for future in as_completed(futures):
|
||||
i, j = futures[future]
|
||||
try:
|
||||
is_match, cost = future.result()
|
||||
comparisons.append((i, j, is_match))
|
||||
comparison_cost += cost
|
||||
if is_match:
|
||||
matches_found += 1
|
||||
except Exception as e:
|
||||
self.console.log(f"[red]Comparison error: {e}[/red]")
|
||||
comparisons.append((i, j, False))
|
||||
|
||||
# Find optimal threshold
|
||||
threshold, achieved_recall = self.find_optimal_threshold(
|
||||
comparisons, similarities
|
||||
)
|
||||
total_cost = embedding_cost + comparison_cost
|
||||
|
||||
# Print histogram visualization
|
||||
self._print_similarity_histogram(similarities, comparisons, threshold)
|
||||
|
||||
# Print summary
|
||||
total_pairs = len(left_data) * len(right_data)
|
||||
pairs_above = sum(1 for s in similarities if s[2] >= threshold)
|
||||
|
||||
summary = (
|
||||
f"[bold]Left keys:[/bold] {left_keys} [bold]Right keys:[/bold] {right_keys}\n"
|
||||
f"[bold]Sampled:[/bold] {len(sampled_pairs)} pairs → {matches_found} matches ({matches_found/len(sampled_pairs)*100:.1f}%)\n"
|
||||
f"[bold]Threshold:[/bold] {threshold:.4f} → {achieved_recall:.1%} recall (target: {self.target_recall:.0%})\n"
|
||||
f"[bold]Pairs to compare:[/bold] {pairs_above:,} of {total_pairs:,} ({pairs_above/total_pairs*100:.1f}%)\n"
|
||||
f"[bold]Optimization cost:[/bold] ${total_cost:.4f}"
|
||||
)
|
||||
self.console.log(
|
||||
Panel(
|
||||
summary, title="Blocking Threshold Optimization", border_style="green"
|
||||
)
|
||||
)
|
||||
|
||||
return threshold, left_embeddings, right_embeddings, total_cost
|
||||
|
|
@@ -7,8 +7,6 @@ from jinja2.exceptions import UndefinedError
from rich import print as rprint
from rich.prompt import Prompt

from docetl.utils import has_jinja_syntax

aeval = Interpreter()
@ -30,44 +28,17 @@ def strict_render(template: Template | str, context: dict[str, Any]) -> str:
|
|||
# Create strict environment
|
||||
env = Environment(undefined=StrictUndefined)
|
||||
|
||||
# Only process string templates for non-Jinja syntax check
|
||||
# Convert string to Template if needed
|
||||
if isinstance(template, str):
|
||||
template_string = template
|
||||
|
||||
# Check if template doesn't have Jinja syntax and append document statement
|
||||
if not has_jinja_syntax(template_string):
|
||||
# Determine the operation type based on context variables
|
||||
if "left" in context and "right" in context:
|
||||
# Equijoin operation - append both documents
|
||||
template_string = (
|
||||
f"{template_string}\n\nHere are the documents:\n"
|
||||
f"Left document: {{{{ left }}}}\n"
|
||||
f"Right document: {{{{ right }}}}"
|
||||
)
|
||||
elif "input1" in context and "input2" in context:
|
||||
# Comparison operation (resolve) - append both documents
|
||||
template_string = (
|
||||
f"{template_string}\n\nHere are the documents:\n"
|
||||
f"Document 1: {{{{ input1 }}}}\n"
|
||||
f"Document 2: {{{{ input2 }}}}"
|
||||
)
|
||||
elif "inputs" in context:
|
||||
# Reduce operation - append "Here are the documents: {{ inputs }}"
|
||||
template_string = (
|
||||
f"{template_string}\n\nHere are the documents: {{{{ inputs }}}}"
|
||||
)
|
||||
elif "input" in context:
|
||||
# Regular operation - append "Here is the document: {{ input }}"
|
||||
template_string = (
|
||||
f"{template_string}\n\nHere is the document: {{{{ input }}}}"
|
||||
)
|
||||
# # If "inputs" in the context, make sure they are not accessing some attribute of inputs
|
||||
# if "inputs" in context and "{{ inputs." in template:
|
||||
# raise UndefinedError("The inputs variable is a list, so you cannot access attributes of inputs. Use inputs[index].key instead.")
|
||||
|
||||
# Convert string template to Template object
|
||||
try:
|
||||
template = env.from_string(template_string)
|
||||
template = env.from_string(template)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid template: {str(e)}")
|
||||
# If template is already a Template object, use it as-is
|
||||
|
||||
try:
|
||||
return template.render(context)
|
||||
|
|
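In the longer of the two `strict_render` versions shown above, a prompt with no Jinja syntax gets a document placeholder appended, chosen from the variables present in the render context (`left`/`right`, `input1`/`input2`, `inputs`, or `input`). A toy version of that dispatch, under the assumption that `{{ }}` / `{% %}` is what counts as Jinja syntax:

```python
from jinja2 import Environment, StrictUndefined

def render_with_fallback(prompt: str, context: dict) -> str:
    env = Environment(undefined=StrictUndefined)
    if "{{" not in prompt and "{%" not in prompt:   # assumed non-Jinja check
        if "inputs" in context:
            prompt += "\n\nHere are the documents: {{ inputs }}"
        elif "input" in context:
            prompt += "\n\nHere is the document: {{ input }}"
    return env.from_string(prompt).render(context)

print(render_with_fallback("List the themes.", {"input": {"text": "a review"}}))
```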
@@ -151,7 +122,7 @@ def _is_integer(value: Any) -> bool:

def _is_number(value: Any) -> bool:
    """Return True if value is a real number (int or float) but not a bool."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)
    return (isinstance(value, (int, float)) and not isinstance(value, bool))


def _validate_scalar(value: Any, schema: dict[str, Any]) -> bool:
@@ -205,9 +176,7 @@ def _validate_value_against_schema(value: Any, schema: dict[str, Any]) -> bool:
        return False
    # Validate known properties
    for key, prop_schema in properties.items():
        if key in value and not _validate_value_against_schema(
            value[key], prop_schema
        ):
        if key in value and not _validate_value_against_schema(value[key], prop_schema):
            return False
    # additionalProperties constraint
    if not additional_props_allowed:
@@ -55,7 +55,7 @@ class Optimizer:
def __init__(
    self,
    runner: "DSLRunner",
    rewrite_agent_model: str = "gpt-5.1",
    rewrite_agent_model: str = "gpt-4o",
    judge_agent_model: str = "gpt-4o-mini",
    litellm_kwargs: dict[str, Any] = {},
    resume: bool = False,

@@ -67,7 +67,7 @@ class Optimizer:

Args:
    yaml_file (str): Path to the YAML configuration file.
    model (str): The name of the language model to use. Defaults to "gpt-5.1".
    model (str): The name of the language model to use. Defaults to "gpt-4o".
    resume (bool): Whether to resume optimization from a previous run. Defaults to False.
    timeout (int): Timeout in seconds for operations. Defaults to 60.
|
|
|||
|
|
@ -1,926 +0,0 @@
|
|||
"""
|
||||
Fast decomposition analyzer for map operations.
|
||||
|
||||
This module provides a fast way to decompose map operations using directives,
|
||||
running candidates on samples, and selecting the best via pairwise LLM comparison.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
from litellm import completion, model_cost
|
||||
from rich.console import Console
|
||||
|
||||
from docetl.console import get_console
|
||||
from docetl.reasoning_optimizer.directives import (
|
||||
ChainingDirective,
|
||||
ClarifyInstructionsDirective,
|
||||
DeterministicDocCompressionDirective,
|
||||
DocumentChunkingDirective,
|
||||
GleaningDirective,
|
||||
IsolatingSubtasksDirective,
|
||||
)
|
||||
from docetl.runner import DSLRunner
|
||||
from docetl.utils import completion_cost, count_tokens
|
||||
|
||||
# Drop unsupported params for models like gpt-5 that don't support temperature=0
|
||||
litellm.drop_params = True
|
||||
|
||||
# Base directives always applicable to map operations
|
||||
BASE_MAP_DIRECTIVES = [
|
||||
ChainingDirective(),
|
||||
IsolatingSubtasksDirective(),
|
||||
GleaningDirective(),
|
||||
ClarifyInstructionsDirective(),
|
||||
]
|
||||
|
||||
# Directives that depend on document size
|
||||
CHUNKING_DIRECTIVE = DocumentChunkingDirective()
|
||||
COMPRESSION_DIRECTIVE = DeterministicDocCompressionDirective()
|
||||
|
||||
# Threshold for enabling DeterministicDocCompression (in characters)
|
||||
DOC_COMPRESSION_CHAR_THRESHOLD = 1000
|
||||
|
||||
# Threshold for enabling DocumentChunking (as fraction of context window)
|
||||
DOC_CHUNKING_CONTEXT_THRESHOLD = 0.10 # 10%
|
||||
|
||||
|
||||
class FastDecomposer:
|
||||
"""
|
||||
Fast decomposition of map operations using directives and pairwise comparison.
|
||||
|
||||
Instead of the full optimizer flow, this:
|
||||
1. Tries multiple directives to generate candidate decompositions
|
||||
2. Runs each candidate on sample documents
|
||||
3. Uses LLM judge for pairwise comparison
|
||||
4. Returns the winning decomposition
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
yaml_config_path: str,
|
||||
optimizer_model: str = "gpt-5.1",
|
||||
sample_size: int = 5,
|
||||
litellm_kwargs: dict[str, Any] | None = None,
|
||||
console: Console | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the decomposer.
|
||||
|
||||
Args:
|
||||
yaml_config_path: Path to the pipeline YAML config file
|
||||
optimizer_model: LLM model to use for directive instantiation and judging
|
||||
sample_size: Number of sample documents to run candidates on
|
||||
litellm_kwargs: Additional kwargs to pass to litellm.completion
|
||||
console: Rich console for output (uses default if not provided)
|
||||
"""
|
||||
self.yaml_config_path = yaml_config_path
|
||||
self.optimizer_model = optimizer_model
|
||||
self.sample_size = sample_size
|
||||
self.litellm_kwargs = litellm_kwargs or {}
|
||||
if "temperature" not in self.litellm_kwargs:
|
||||
self.litellm_kwargs["temperature"] = 0.0
|
||||
|
||||
self.total_cost = 0.0
|
||||
self.console = console or get_console()
|
||||
|
||||
# Load the config
|
||||
import yaml
|
||||
|
||||
with open(yaml_config_path, "r") as f:
|
||||
self.config = yaml.safe_load(f)
|
||||
|
||||
self.operators = self.config.get("operations", [])
|
||||
self.default_model = self.config.get("default_model", "gpt-4o-mini")
|
||||
self.intermediate_dir = (
|
||||
self.config.get("pipeline", {}).get("output", {}).get("intermediate_dir")
|
||||
)
|
||||
|
||||
def _log(self, message: str) -> None:
|
||||
"""Log a message to the console."""
|
||||
self.console.log(message)
|
||||
|
||||
def get_model_context_limit(self, model: str) -> int:
|
||||
"""
|
||||
Get the context window limit for a model.
|
||||
|
||||
Args:
|
||||
model: The model name (e.g., 'gpt-4o', 'azure/gpt-4')
|
||||
|
||||
Returns:
|
||||
Maximum number of input tokens the model can handle
|
||||
"""
|
||||
model_info = model_cost.get(model, {})
|
||||
# Try without provider prefix if not found
|
||||
if not model_info:
|
||||
model_name = model.split("/")[-1]
|
||||
model_info = model_cost.get(model_name, {})
|
||||
return model_info.get("max_input_tokens", 128000) # Default to 128k
|
||||
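`get_model_context_limit` reads `max_input_tokens` from litellm's `model_cost` table, retrying without the provider prefix and defaulting to 128k when the model is unknown. The same lookup in isolation (the table keys are litellm's; the default matches the one used above):

```python
from litellm import model_cost

def context_limit(model: str, default: int = 128_000) -> int:
    info = model_cost.get(model) or model_cost.get(model.split("/")[-1], {})
    return info.get("max_input_tokens", default)

print(context_limit("azure/gpt-4o"))
print(context_limit("some-unknown-model"))  # falls back to 128000
```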
|
||||
def get_avg_doc_size(
|
||||
self, sample_data: list[dict[str, Any]], op_config: dict[str, Any]
|
||||
) -> tuple[float, float]:
|
||||
"""
|
||||
Calculate the average document size in characters and tokens.
|
||||
|
||||
Extracts the document content from sample data based on the operation's
|
||||
prompt template (looks for {{ input.field_name }} patterns).
|
||||
|
||||
Args:
|
||||
sample_data: List of sample documents
|
||||
op_config: The operation configuration
|
||||
|
||||
Returns:
|
||||
Tuple of (avg_chars, avg_tokens) for the document content
|
||||
"""
|
||||
import re
|
||||
|
||||
if not sample_data:
|
||||
return 0.0, 0.0
|
||||
|
||||
prompt = op_config.get("prompt", "")
|
||||
model = op_config.get("model", self.default_model)
|
||||
|
||||
# Extract field names from prompt template ({{ input.field_name }})
|
||||
field_pattern = r"\{\{\s*input\.(\w+)\s*\}\}"
|
||||
fields = re.findall(field_pattern, prompt)
|
||||
|
||||
if not fields:
|
||||
# Fallback: use all string values from the first document
|
||||
if sample_data:
|
||||
fields = [
|
||||
k
|
||||
for k, v in sample_data[0].items()
|
||||
if isinstance(v, str) and len(v) > 100
|
||||
]
|
||||
|
||||
total_chars = 0
|
||||
total_tokens = 0
|
||||
|
||||
for doc in sample_data:
|
||||
doc_content = ""
|
||||
for field in fields:
|
||||
if field in doc:
|
||||
value = doc[field]
|
||||
if isinstance(value, str):
|
||||
doc_content += value
|
||||
else:
|
||||
doc_content += str(value)
|
||||
|
||||
total_chars += len(doc_content)
|
||||
if doc_content:
|
||||
total_tokens += count_tokens(doc_content, model)
|
||||
|
||||
n = len(sample_data)
|
||||
return total_chars / n if n > 0 else 0.0, total_tokens / n if n > 0 else 0.0
|
||||
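`get_avg_doc_size` infers which input fields feed the prompt by scanning for `{{ input.field }}` references and then averages their sizes. The field-extraction part, isolated (regex copied from the snippet above; token counting omitted):

```python
import re

FIELD_PATTERN = r"\{\{\s*input\.(\w+)\s*\}\}"

def prompt_fields(prompt: str) -> list[str]:
    return re.findall(FIELD_PATTERN, prompt)

def avg_chars(docs: list[dict], prompt: str) -> float:
    fields = prompt_fields(prompt)
    totals = [sum(len(str(d.get(f, ""))) for f in fields) for d in docs]
    return sum(totals) / len(docs) if docs else 0.0

prompt = "Summarize {{ input.title }} and {{ input.body }}."
docs = [{"title": "A", "body": "x" * 500}, {"title": "B", "body": "y" * 700}]
print(prompt_fields(prompt), avg_chars(docs, prompt))
```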
|
||||
def get_applicable_directives(
|
||||
self,
|
||||
sample_data: list[dict[str, Any]],
|
||||
op_config: dict[str, Any],
|
||||
) -> list:
|
||||
"""
|
||||
Get the list of directives applicable based on data characteristics.
|
||||
|
||||
- DocumentChunkingDirective: only if avg doc size > 10% of context window
|
||||
- DeterministicDocCompressionDirective: only if avg doc size > 1000 chars
|
||||
|
||||
Args:
|
||||
sample_data: List of sample documents
|
||||
op_config: The operation configuration
|
||||
|
||||
Returns:
|
||||
List of applicable directive instances
|
||||
"""
|
||||
directives = [] # Build list with priority ordering
|
||||
|
||||
model = op_config.get("model", self.default_model)
|
||||
context_limit = self.get_model_context_limit(model)
|
||||
avg_chars, avg_tokens = self.get_avg_doc_size(sample_data, op_config)
|
||||
|
||||
self._log(
|
||||
f"Document analysis: avg_chars={avg_chars:.0f}, avg_tokens={avg_tokens:.0f}, "
|
||||
f"context_limit={context_limit}"
|
||||
)
|
||||
|
||||
# Add DeterministicDocCompression FIRST if doc size > 1000 chars (high priority)
|
||||
if avg_chars > DOC_COMPRESSION_CHAR_THRESHOLD:
|
||||
self._log(
|
||||
f"[cyan]Enabling DeterministicDocCompression (priority)[/cyan] "
|
||||
f"(avg {avg_chars:.0f} chars > {DOC_COMPRESSION_CHAR_THRESHOLD})"
|
||||
)
|
||||
directives.append(COMPRESSION_DIRECTIVE)
|
||||
|
||||
# Add base directives
|
||||
directives.extend(BASE_MAP_DIRECTIVES)
|
||||
|
||||
# Add DocumentChunking if avg tokens > 10% of context window
|
||||
token_threshold = context_limit * DOC_CHUNKING_CONTEXT_THRESHOLD
|
||||
if avg_tokens > token_threshold:
|
||||
self._log(
|
||||
f"[cyan]Enabling DocumentChunking[/cyan] "
|
||||
f"(avg {avg_tokens:.0f} tokens > {token_threshold:.0f} = 10% of {context_limit})"
|
||||
)
|
||||
directives.append(CHUNKING_DIRECTIVE)
|
||||
else:
|
||||
self._log(
|
||||
f"[dim]Skipping DocumentChunking[/dim] "
|
||||
f"(avg {avg_tokens:.0f} tokens <= {token_threshold:.0f} = 10% of {context_limit})"
|
||||
)
|
||||
|
||||
return directives
|
||||
|
||||
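    # Worked example (illustrative numbers): with a 128,000-token context window and the
    # 10% chunking threshold, DocumentChunking is enabled once avg_tokens exceeds ~12,800;
    # DeterministicDocCompression is enabled once avg_chars exceeds the 1,000-char
    # threshold described in the docstring above.
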
def load_sample_data(self, step_name: str, op_name: str) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Load sample input data for an operation.
|
||||
|
||||
For the first operation, loads from the dataset.
|
||||
For subsequent operations, loads from the previous operation's intermediate output.
|
||||
|
||||
Args:
|
||||
step_name: Name of the pipeline step
|
||||
op_name: Name of the operation
|
||||
|
||||
Returns:
|
||||
List of sample documents
|
||||
"""
|
||||
# Find the operation's position
|
||||
op_names = [op.get("name") for op in self.operators]
|
||||
try:
|
||||
op_idx = op_names.index(op_name)
|
||||
except ValueError:
|
||||
raise ValueError(f"Operation '{op_name}' not found in config")
|
||||
|
||||
if op_idx == 0:
|
||||
# First operation - load from dataset
|
||||
datasets = self.config.get("datasets", {})
|
||||
# Get the first dataset (or the one used by this step)
|
||||
for dataset_name, dataset_config in datasets.items():
|
||||
dataset_path = dataset_config.get("path")
|
||||
if dataset_path and os.path.exists(dataset_path):
|
||||
with open(dataset_path, "r") as f:
|
||||
data = json.load(f)
|
||||
return data[: self.sample_size]
|
||||
raise FileNotFoundError("No dataset found in config")
|
||||
else:
|
||||
# Load from previous operation's intermediate output
|
||||
prev_op_name = op_names[op_idx - 1]
|
||||
output_path = os.path.join(
|
||||
self.intermediate_dir, step_name, f"{prev_op_name}.json"
|
||||
)
|
||||
if not os.path.exists(output_path):
|
||||
raise FileNotFoundError(
|
||||
f"No intermediate output found at {output_path}. "
|
||||
"Run the previous operation first."
|
||||
)
|
||||
with open(output_path, "r") as f:
|
||||
data = json.load(f)
|
||||
return data[: self.sample_size]
|
||||
|
||||
def get_input_file_path(self) -> str:
|
||||
"""Get the input file path for the pipeline."""
|
||||
datasets = self.config.get("datasets", {})
|
||||
for dataset_name, dataset_config in datasets.items():
|
||||
path = dataset_config.get("path")
|
||||
if path:
|
||||
return path
|
||||
return ""
|
||||
|
||||
def generate_candidates(
|
||||
self,
|
||||
op_name: str,
|
||||
sample_data: list[dict[str, Any]],
|
||||
target_op: dict[str, Any],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Generate candidate decompositions using directives.
|
||||
|
||||
Directives are selected based on data characteristics:
|
||||
- DocumentChunkingDirective: only if avg doc size > 10% of context window
|
||||
- DeterministicDocCompressionDirective: only if avg doc size > 1000 chars
|
||||
|
||||
Args:
|
||||
op_name: Name of the operation to decompose
|
||||
sample_data: Sample data for analyzing document characteristics
|
||||
target_op: The target operation configuration
|
||||
|
||||
Returns:
|
||||
List of candidate dictionaries with 'name', 'ops', 'cost' keys
|
||||
"""
|
||||
candidates = []
|
||||
|
||||
# Add original as baseline
|
||||
self._log("Adding original operation as baseline candidate")
|
||||
candidates.append(
|
||||
{
|
||||
"name": "original",
|
||||
"ops": deepcopy(self.operators),
|
||||
"cost": 0.0,
|
||||
"error": None,
|
||||
}
|
||||
)
|
||||
|
||||
input_file_path = self.get_input_file_path()
|
||||
|
||||
# Get applicable directives based on data characteristics
|
||||
applicable_directives = self.get_applicable_directives(sample_data, target_op)
|
||||
|
||||
self._log(
|
||||
f"Generating candidates using {len(applicable_directives)} directives..."
|
||||
)
|
||||
for i, directive in enumerate(applicable_directives, 1):
|
||||
with self.console.status(
|
||||
f"[bold cyan]({i}/{len(applicable_directives)}) Trying directive: {directive.name}...[/bold cyan]",
|
||||
spinner="dots",
|
||||
):
|
||||
try:
|
||||
new_ops_list, _, cost = directive.instantiate(
|
||||
operators=deepcopy(self.operators),
|
||||
target_ops=[op_name],
|
||||
agent_llm=self.optimizer_model,
|
||||
message_history=[],
|
||||
global_default_model=self.default_model,
|
||||
input_file_path=input_file_path,
|
||||
)
|
||||
self.total_cost += cost
|
||||
self._log(
|
||||
f" [green]✓[/green] {directive.name} generated {len(new_ops_list)} operations (cost: ${cost:.4f})"
|
||||
)
|
||||
candidates.append(
|
||||
{
|
||||
"name": directive.name,
|
||||
"ops": new_ops_list,
|
||||
"cost": cost,
|
||||
"error": None,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
# Directive not applicable or failed - skip it
|
||||
self._log(f" [red]✗[/red] {directive.name} failed: {str(e)}")
|
||||
candidates.append(
|
||||
{
|
||||
"name": directive.name,
|
||||
"ops": None,
|
||||
"cost": 0.0,
|
||||
"error": str(e),
|
||||
}
|
||||
)
|
||||
|
||||
return candidates
|
||||
|
||||
def extract_ops_to_run(
|
||||
self, ops_list: list[dict], original_op_name: str
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Extract the operations that replaced the original operation.
|
||||
|
||||
Args:
|
||||
ops_list: The full transformed operations list
|
||||
original_op_name: Name of the original operation that was decomposed
|
||||
|
||||
Returns:
|
||||
List of operations that should be run on samples
|
||||
"""
|
||||
# Find ops that are new (not in original) or modified
|
||||
original_names = {op["name"] for op in self.operators}
|
||||
|
||||
# Find the position where the original op was
|
||||
original_idx = None
|
||||
for i, op in enumerate(self.operators):
|
||||
if op["name"] == original_op_name:
|
||||
original_idx = i
|
||||
break
|
||||
|
||||
if original_idx is None:
|
||||
return ops_list
|
||||
|
||||
# Find new ops (those not in original list)
|
||||
new_ops = []
|
||||
for op in ops_list:
|
||||
if op["name"] not in original_names or op["name"] == original_op_name:
|
||||
new_ops.append(op)
|
||||
|
||||
return new_ops if new_ops else [ops_list[original_idx]]
|
||||
|
||||
def run_candidate_on_samples(
|
||||
self,
|
||||
candidate: dict[str, Any],
|
||||
sample_data: list[dict[str, Any]],
|
||||
original_op_name: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Run a candidate's operations on sample data.
|
||||
|
||||
Args:
|
||||
candidate: Candidate dictionary with 'ops' key
|
||||
sample_data: List of sample documents
|
||||
original_op_name: Name of the original operation
|
||||
|
||||
Returns:
|
||||
List of output documents
|
||||
"""
|
||||
if candidate["ops"] is None:
|
||||
return []
|
||||
|
||||
# Extract ops to run
|
||||
ops_to_run = self.extract_ops_to_run(candidate["ops"], original_op_name)
|
||||
|
||||
if not ops_to_run:
|
||||
return []
|
||||
|
||||
# Create a minimal config for running these ops
|
||||
# Write sample data to a temp file
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||
json.dump(sample_data, f)
|
||||
temp_input_path = f.name
|
||||
|
||||
# Create a temp output file (required by DSLRunner validation)
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||
temp_output_path = f.name
|
||||
|
||||
try:
|
||||
# Create a minimal pipeline config
|
||||
temp_config = {
|
||||
"default_model": self.default_model,
|
||||
"operations": ops_to_run,
|
||||
"datasets": {
|
||||
"sample_data": {
|
||||
"type": "file",
|
||||
"path": temp_input_path,
|
||||
}
|
||||
},
|
||||
"pipeline": {
|
||||
"steps": [
|
||||
{
|
||||
"name": "decompose_test",
|
||||
"input": "sample_data",
|
||||
"operations": [op["name"] for op in ops_to_run],
|
||||
}
|
||||
],
|
||||
"output": {
|
||||
"type": "file",
|
||||
"path": temp_output_path,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# Create runner and execute
|
||||
runner = DSLRunner(temp_config, max_threads=4)
|
||||
|
||||
# Run operations sequentially on the data
|
||||
current_data = sample_data
|
||||
for op_config in ops_to_run:
|
||||
current_data = runner._run_operation(op_config, current_data)
|
||||
|
||||
self.total_cost += runner.total_cost
|
||||
|
||||
return current_data
|
||||
|
||||
finally:
|
||||
# Clean up temp files
|
||||
for temp_path in [temp_input_path, temp_output_path]:
|
||||
try:
|
||||
os.unlink(temp_path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def pairwise_compare(
|
||||
self,
|
||||
candidate_a: dict[str, Any],
|
||||
candidate_b: dict[str, Any],
|
||||
original_prompt: str,
|
||||
output_schema: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Compare two candidates using LLM judge.
|
||||
|
||||
Args:
|
||||
candidate_a: First candidate with 'name', 'outputs' keys
|
||||
candidate_b: Second candidate with 'name', 'outputs' keys
|
||||
original_prompt: The original operation's prompt
|
||||
output_schema: The expected output schema
|
||||
|
||||
Returns:
|
||||
The winning candidate
|
||||
"""
|
||||
# If one has no outputs, the other wins
|
||||
if not candidate_a.get("outputs"):
|
||||
return candidate_b
|
||||
if not candidate_b.get("outputs"):
|
||||
return candidate_a
|
||||
|
||||
system_prompt = """You are an expert judge comparing outputs from two data processing pipeline variants.
|
||||
|
||||
Your task is to determine which variant produces BETTER outputs based on:
|
||||
1. **Completeness**: Does the output contain all required information?
|
||||
2. **Accuracy**: Is the extracted/generated information correct?
|
||||
3. **Consistency**: Are the outputs consistent across different samples?
|
||||
4. **Quality**: Is the output well-structured and useful?
|
||||
|
||||
Be objective and focus on the actual output quality, not the approach used."""
|
||||
|
||||
user_prompt = f"""Compare outputs from two pipeline variants for this task:
|
||||
|
||||
## Original Task
|
||||
**Prompt:**
|
||||
```
|
||||
{original_prompt[:2000]}{"..." if len(original_prompt) > 2000 else ""}
|
||||
```
|
||||
|
||||
**Expected Output Schema:**
|
||||
```json
|
||||
{json.dumps(output_schema, indent=2)}
|
||||
```
|
||||
|
||||
## Variant A: {candidate_a["name"]}
|
||||
**Sample Outputs:**
|
||||
```json
|
||||
{json.dumps(candidate_a["outputs"][:3], indent=2, default=str)}
|
||||
```
|
||||
|
||||
## Variant B: {candidate_b["name"]}
|
||||
**Sample Outputs:**
|
||||
```json
|
||||
{json.dumps(candidate_b["outputs"][:3], indent=2, default=str)}
|
||||
```
|
||||
|
||||
## Your Task
|
||||
Which variant produces better outputs? Consider completeness, accuracy, consistency, and quality.
|
||||
Respond with your analysis and final choice."""
|
||||
|
||||
response = completion(
|
||||
model=self.optimizer_model,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "comparison_result",
|
||||
"strict": True,
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"winner": {
|
||||
"type": "string",
|
||||
"enum": ["A", "B"],
|
||||
"description": "Which variant is better: A or B",
|
||||
},
|
||||
"rationale": {
|
||||
"type": "string",
|
||||
"description": "Explanation of why this variant is better",
|
||||
},
|
||||
"a_strengths": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Strengths of variant A",
|
||||
},
|
||||
"b_strengths": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Strengths of variant B",
|
||||
},
|
||||
},
|
||||
"required": [
|
||||
"winner",
|
||||
"rationale",
|
||||
"a_strengths",
|
||||
"b_strengths",
|
||||
],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
},
|
||||
**self.litellm_kwargs,
|
||||
)
|
||||
|
||||
cost = completion_cost(response)
|
||||
self.total_cost += cost
|
||||
|
||||
result = json.loads(response.choices[0].message.content)
|
||||
|
||||
winner = candidate_a if result["winner"] == "A" else candidate_b
|
||||
winner["comparison_rationale"] = result["rationale"]
|
||||
|
||||
return winner
|
||||
|
||||
def decompose(
|
||||
self,
|
||||
step_name: str,
|
||||
op_name: str,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Decompose an operation using fast directive-based approach.
|
||||
|
||||
This is the main entry point. It:
|
||||
1. Generates candidate decompositions using directives
|
||||
2. Runs each on sample data
|
||||
3. Uses pairwise LLM comparison to select the best
|
||||
4. Returns the winning decomposition
|
||||
|
||||
Args:
|
||||
step_name: Name of the pipeline step
|
||||
op_name: Name of the operation to decompose
|
||||
|
||||
Returns:
|
||||
Dict with:
|
||||
- decomposed_ops: List of operations that replace the original
|
||||
- winning_directive: Name of the winning directive
|
||||
- candidates_evaluated: Number of candidates that were compared
|
||||
- original_outputs: Sample outputs from the original operation
|
||||
- decomposed_outputs: Sample outputs from the winning decomposition
|
||||
- comparison_rationale: LLM's explanation of why the winner was chosen
|
||||
- cost: Total LLM API cost
|
||||
"""
|
||||
self.console.rule("[bold blue]Fast Decomposition[/bold blue]")
|
||||
self._log(f"[bold]Target operation:[/bold] {op_name}")
|
||||
|
||||
# Find the target operation config
|
||||
target_op = None
|
||||
for op in self.operators:
|
||||
if op["name"] == op_name:
|
||||
target_op = op
|
||||
break
|
||||
|
||||
if target_op is None:
|
||||
raise ValueError(f"Operation '{op_name}' not found in config")
|
||||
|
||||
# Verify it's a map operation
|
||||
if target_op.get("type") != "map":
|
||||
raise ValueError(
|
||||
f"Operation '{op_name}' is type '{target_op.get('type')}', "
|
||||
"but fast decomposition only supports 'map' operations"
|
||||
)
|
||||
|
||||
# Load sample data
|
||||
self._log(f"Loading sample data from step '{step_name}'...")
|
||||
sample_data = self.load_sample_data(step_name, op_name)
|
||||
self._log(f"[green]✓[/green] Loaded {len(sample_data)} sample documents")
|
||||
|
||||
# Generate candidates
|
||||
self.console.rule("[bold]Generating Candidates[/bold]")
|
||||
candidates = self.generate_candidates(op_name, sample_data, target_op)
|
||||
|
||||
# Filter out failed candidates
|
||||
valid_candidates = [c for c in candidates if c["ops"] is not None]
|
||||
self._log(f"[bold]{len(valid_candidates)}[/bold] valid candidates generated")
|
||||
|
||||
if len(valid_candidates) < 2:
|
||||
# Only original (or nothing) - return original
|
||||
self._log(
|
||||
"[yellow]No alternative decompositions generated. Keeping original.[/yellow]"
|
||||
)
|
||||
return {
|
||||
"decomposed_ops": self.operators,
|
||||
"winning_directive": "original",
|
||||
"candidates_evaluated": len(valid_candidates),
|
||||
"original_outputs": [],
|
||||
"decomposed_outputs": [],
|
||||
"comparison_rationale": "No alternative decompositions were generated.",
|
||||
"cost": self.total_cost,
|
||||
}
|
||||
|
||||
# Run each candidate on samples IN PARALLEL
|
||||
self.console.rule("[bold]Running Candidates on Samples[/bold]")
|
||||
self._log(f"Running {len(valid_candidates)} candidates in parallel...")
|
||||
|
||||
def run_single_candidate(candidate):
|
||||
"""Run a single candidate and return results."""
|
||||
name = candidate["name"]
|
||||
try:
|
||||
outputs = self.run_candidate_on_samples(candidate, sample_data, op_name)
|
||||
return {"name": name, "outputs": outputs, "error": None}
|
||||
except Exception as e:
|
||||
error_tb = traceback.format_exc()
|
||||
return {
|
||||
"name": name,
|
||||
"outputs": [],
|
||||
"error": str(e),
|
||||
"traceback": error_tb,
|
||||
}
|
||||
|
||||
# Run all candidates in parallel
|
||||
with self.console.status(
|
||||
"[bold cyan]Running candidates on samples...[/bold cyan]", spinner="dots"
|
||||
):
|
||||
with ThreadPoolExecutor(max_workers=len(valid_candidates)) as executor:
|
||||
future_to_candidate = {
|
||||
executor.submit(run_single_candidate, c): c
|
||||
for c in valid_candidates
|
||||
}
|
||||
|
||||
for future in as_completed(future_to_candidate):
|
||||
candidate = future_to_candidate[future]
|
||||
result = future.result()
|
||||
|
||||
if result["error"]:
|
||||
candidate["outputs"] = []
|
||||
candidate["run_error"] = result["error"]
|
||||
self._log(
|
||||
f" [red]✗[/red] {result['name']} failed: {result['error']}"
|
||||
)
|
||||
else:
|
||||
candidate["outputs"] = result["outputs"]
|
||||
self._log(
|
||||
f" [green]✓[/green] {result['name']}: {len(result['outputs'])} outputs"
|
||||
)
|
||||
|
||||
# Filter to candidates with outputs
|
||||
candidates_with_outputs = [c for c in valid_candidates if c.get("outputs")]
|
||||
self._log(
|
||||
f"[bold]{len(candidates_with_outputs)}[/bold] candidates produced outputs"
|
||||
)
|
||||
|
||||
if not candidates_with_outputs:
|
||||
# All failed - return original
|
||||
self._log("[red]All decomposition candidates failed to execute.[/red]")
|
||||
# Log the errors for debugging
|
||||
for c in valid_candidates:
|
||||
if c.get("run_error"):
|
||||
self._log(f" [red]{c['name']}:[/red] {c['run_error']}")
|
||||
return {
|
||||
"decomposed_ops": self.operators,
|
||||
"winning_directive": "original",
|
||||
"candidates_evaluated": 0,
|
||||
"original_outputs": [],
|
||||
"decomposed_outputs": [],
|
||||
"comparison_rationale": "All decomposition candidates failed to execute.",
|
||||
"cost": self.total_cost,
|
||||
}
|
||||
|
||||
# Find the original candidate's outputs
|
||||
original_candidate = next(
|
||||
(c for c in candidates_with_outputs if c["name"] == "original"), None
|
||||
)
|
||||
original_outputs = original_candidate["outputs"] if original_candidate else []
|
||||
|
||||
# Parallel pairwise comparison: compare all candidates against original
|
||||
self.console.rule("[bold]Pairwise Comparison[/bold]")
|
||||
original_prompt = target_op.get("prompt", "")
|
||||
output_schema = target_op.get("output", {}).get("schema", {})
|
||||
|
||||
# If only original exists, it wins by default
|
||||
if len(candidates_with_outputs) == 1:
|
||||
winner = candidates_with_outputs[0]
|
||||
elif original_candidate is None:
|
||||
# No original - just pick the first candidate
|
||||
winner = candidates_with_outputs[0]
|
||||
else:
|
||||
# Compare all non-original candidates against original IN PARALLEL
|
||||
challengers = [
|
||||
c for c in candidates_with_outputs if c["name"] != "original"
|
||||
]
|
||||
|
||||
if not challengers:
|
||||
winner = original_candidate
|
||||
else:
|
||||
self._log(
|
||||
f"Comparing {len(challengers)} candidates against original in parallel..."
|
||||
)
|
||||
|
||||
def compare_against_original(challenger):
|
||||
"""Compare a single challenger against original."""
|
||||
try:
|
||||
result = self.pairwise_compare(
|
||||
original_candidate,
|
||||
challenger,
|
||||
original_prompt,
|
||||
output_schema,
|
||||
)
|
||||
won = result["name"] == challenger["name"]
|
||||
return {
|
||||
"challenger": challenger,
|
||||
"won": won,
|
||||
"rationale": result.get("comparison_rationale", ""),
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"challenger": challenger,
|
||||
"won": False,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
# Run all comparisons in parallel
|
||||
comparison_results = []
|
||||
with self.console.status(
|
||||
f"[bold cyan]Running {len(challengers)} comparisons in parallel...[/bold cyan]",
|
||||
spinner="dots",
|
||||
):
|
||||
with ThreadPoolExecutor(max_workers=len(challengers)) as executor:
|
||||
future_to_challenger = {
|
||||
executor.submit(compare_against_original, c): c
|
||||
for c in challengers
|
||||
}
|
||||
|
||||
for future in as_completed(future_to_challenger):
|
||||
result = future.result()
|
||||
comparison_results.append(result)
|
||||
challenger_name = result["challenger"]["name"]
|
||||
if result.get("error"):
|
||||
self._log(
|
||||
f" [red]✗[/red] {challenger_name} vs original: error - {result['error']}"
|
||||
)
|
||||
elif result["won"]:
|
||||
self._log(
|
||||
f" [green]✓[/green] {challenger_name} beat original"
|
||||
)
|
||||
else:
|
||||
self._log(
|
||||
f" [dim]○[/dim] {challenger_name} lost to original"
|
||||
)
|
||||
|
||||
# Find winners (candidates that beat original)
|
||||
winners = [r for r in comparison_results if r.get("won")]
|
||||
|
||||
if not winners:
|
||||
# Original beats all challengers
|
||||
winner = original_candidate
|
||||
self._log("[bold]Original wins against all challengers[/bold]")
|
||||
elif len(winners) == 1:
|
||||
# Single winner
|
||||
winner = winners[0]["challenger"]
|
||||
winner["comparison_rationale"] = winners[0].get("rationale", "")
|
||||
else:
|
||||
# Multiple winners beat original - run tiebreaker comparisons in parallel
|
||||
self._log(
|
||||
f"[bold]{len(winners)} candidates beat original - running tiebreaker...[/bold]"
|
||||
)
|
||||
|
||||
# Compare all winners against each other in parallel (round-robin)
|
||||
winner_candidates = [w["challenger"] for w in winners]
|
||||
win_counts = {c["name"]: 0 for c in winner_candidates}
|
||||
|
||||
# Generate all pairwise matchups
|
||||
matchups = []
|
||||
for i, a in enumerate(winner_candidates):
|
||||
for b in winner_candidates[i + 1 :]:
|
||||
matchups.append((a, b))
|
||||
|
||||
if matchups:
|
||||
|
||||
def run_matchup(matchup):
|
||||
a, b = matchup
|
||||
try:
|
||||
result = self.pairwise_compare(
|
||||
a, b, original_prompt, output_schema
|
||||
)
|
||||
return {
|
||||
"winner": result["name"],
|
||||
"a": a["name"],
|
||||
"b": b["name"],
|
||||
}
|
||||
except Exception:
|
||||
return {
|
||||
"winner": a["name"],
|
||||
"a": a["name"],
|
||||
"b": b["name"],
|
||||
} # Default to first
|
||||
|
||||
with self.console.status(
|
||||
f"[bold cyan]Running {len(matchups)} tiebreaker comparisons...[/bold cyan]",
|
||||
spinner="dots",
|
||||
):
|
||||
with ThreadPoolExecutor(
|
||||
max_workers=len(matchups)
|
||||
) as executor:
|
||||
for result in executor.map(run_matchup, matchups):
|
||||
win_counts[result["winner"]] += 1
|
||||
self._log(
|
||||
f" [dim]{result['a']} vs {result['b']} → {result['winner']}[/dim]"
|
||||
)
|
||||
|
||||
# Pick candidate with most wins
|
||||
best_name = max(win_counts, key=win_counts.get)
|
||||
winner = next(
|
||||
c for c in winner_candidates if c["name"] == best_name
|
||||
)
|
||||
self._log(
|
||||
f"[bold]Tiebreaker winner: {best_name} ({win_counts[best_name]} wins)[/bold]"
|
||||
)
|
||||
|
||||
# Extract the decomposed operations
|
||||
decomposed_ops = self.extract_ops_to_run(winner["ops"], op_name)
|
||||
|
||||
# Final summary
|
||||
self.console.rule("[bold green]Decomposition Complete[/bold green]")
|
||||
self._log(f"[bold]Winner:[/bold] [green]{winner['name']}[/green]")
|
||||
self._log(f"[bold]Candidates evaluated:[/bold] {len(candidates_with_outputs)}")
|
||||
self._log(f"[bold]New operations:[/bold] {len(decomposed_ops)}")
|
||||
self._log(f"[bold]Total cost:[/bold] ${self.total_cost:.4f}")
|
||||
|
||||
return {
|
||||
"decomposed_ops": decomposed_ops,
|
||||
"winning_directive": winner["name"],
|
||||
"candidates_evaluated": len(candidates_with_outputs),
|
||||
"original_outputs": original_outputs,
|
||||
"decomposed_outputs": winner.get("outputs", []),
|
||||
"comparison_rationale": winner.get("comparison_rationale", ""),
|
||||
"cost": self.total_cost,
|
||||
}

@@ -1,337 +0,0 @@
"""
Fast should_optimize analyzer using a single LLM call.

This module provides a lightweight alternative to the full MapOptimizer/ReduceOptimizer/JoinOptimizer
flow for quickly determining if an operation should be decomposed.
"""

import json
import os
from typing import Any

import litellm
from litellm import completion, model_cost

from docetl.utils import completion_cost, count_tokens

# Drop unsupported params for models like gpt-5 that don't support temperature=0
litellm.drop_params = True


class FastShouldOptimizeAnalyzer:
    """
    Analyzes whether an operation should be optimized using a single LLM call.

    Instead of running the operation on sample data and using complex evaluation logic,
    this reads cached outputs from intermediate files and makes one judgment call.
    """

    def __init__(
        self,
        intermediate_dir: str,
        optimizer_model: str = "gpt-5.1",
        litellm_kwargs: dict[str, Any] | None = None,
    ):
        """
        Initialize the analyzer.

        Args:
            intermediate_dir: Path to the directory containing intermediate outputs
            optimizer_model: The LLM model to use for analysis (default: gpt-5.1)
            litellm_kwargs: Additional kwargs to pass to litellm.completion
        """
        self.intermediate_dir = intermediate_dir
        self.optimizer_model = optimizer_model
        self.litellm_kwargs = litellm_kwargs or {}
        if "temperature" not in self.litellm_kwargs:
            self.litellm_kwargs["temperature"] = 0.0

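    # Usage sketch (illustrative; the path and step/operation names are hypothetical):
    #   analyzer = FastShouldOptimizeAnalyzer(intermediate_dir="intermediates")
    #   rationale, samples, n_docs, cost = analyzer.analyze(op_config, "step_1", "extract_findings")
    #   if rationale:  # an empty string means no decomposition is recommended
    #       print(rationale)
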
def load_operation_data(self, step_name: str, op_name: str) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Load data from the intermediate file for an operation.
|
||||
|
||||
Args:
|
||||
step_name: Name of the pipeline step
|
||||
op_name: Name of the operation
|
||||
|
||||
Returns:
|
||||
List of dictionaries
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the intermediate file doesn't exist
|
||||
"""
|
||||
output_path = os.path.join(self.intermediate_dir, step_name, f"{op_name}.json")
|
||||
if not os.path.exists(output_path):
|
||||
raise FileNotFoundError(
|
||||
f"No output file found at {output_path}. "
|
||||
"Run the operation first to generate outputs."
|
||||
)
|
||||
with open(output_path, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
def find_previous_operation(
|
||||
self, operations: list[dict[str, Any]], op_name: str
|
||||
) -> str | None:
|
||||
"""
|
||||
Find the operation that comes before op_name in the pipeline.
|
||||
|
||||
Args:
|
||||
operations: List of operation configs from the pipeline
|
||||
op_name: Name of the current operation
|
||||
|
||||
Returns:
|
||||
Name of the previous operation, or None if this is the first operation
|
||||
"""
|
||||
op_names = [op.get("name") for op in operations]
|
||||
try:
|
||||
idx = op_names.index(op_name)
|
||||
if idx > 0:
|
||||
return op_names[idx - 1]
|
||||
except ValueError:
|
||||
pass
|
||||
return None
|
||||
|
||||
def get_max_context_tokens(self) -> int:
|
||||
"""
|
||||
Get the maximum input tokens for the optimizer model.
|
||||
|
||||
Returns:
|
||||
Maximum number of input tokens the model can handle
|
||||
"""
|
||||
model_info = model_cost.get(self.optimizer_model, {})
|
||||
# Try without provider prefix if not found
|
||||
if not model_info:
|
||||
model_name = self.optimizer_model.split("/")[-1]
|
||||
model_info = model_cost.get(model_name, {})
|
||||
return model_info.get("max_input_tokens", 128000) # Default to 128k
|
||||
|
||||
def calculate_samples_that_fit(
|
||||
self,
|
||||
op_config: dict[str, Any],
|
||||
outputs: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Calculate how many output samples fit in the context window.
|
||||
|
||||
Reserves space for system prompt, operation config, and response buffer,
|
||||
then fills remaining space with as many samples as possible.
|
||||
|
||||
Args:
|
||||
op_config: The operation configuration dictionary
|
||||
outputs: List of all output documents
|
||||
|
||||
Returns:
|
||||
List of samples that fit in the context window
|
||||
"""
|
||||
max_tokens = self.get_max_context_tokens()
|
||||
|
||||
# Reserve tokens for fixed parts
|
||||
system_prompt_tokens = 500
|
||||
op_config_tokens = count_tokens(
|
||||
json.dumps(op_config, default=str), self.optimizer_model
|
||||
)
|
||||
response_buffer = 2000
|
||||
|
||||
available_for_samples = (
|
||||
max_tokens - system_prompt_tokens - op_config_tokens - response_buffer
|
||||
)
|
||||
|
||||
# Collect samples that fit
|
||||
samples_to_include = []
|
||||
tokens_used = 0
|
||||
|
||||
for output in outputs:
|
||||
sample_json = json.dumps(output, default=str)
|
||||
sample_tokens = count_tokens(sample_json, self.optimizer_model)
|
||||
if tokens_used + sample_tokens <= available_for_samples:
|
||||
samples_to_include.append(output)
|
||||
tokens_used += sample_tokens
|
||||
else:
|
||||
break
|
||||
|
||||
return samples_to_include
|
||||
|
||||
def build_analysis_prompt(
|
||||
self,
|
||||
op_config: dict[str, Any],
|
||||
samples: list[dict[str, Any]],
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Build the system prompt and user prompt for the analysis LLM call.
|
||||
|
||||
Args:
|
||||
op_config: The operation configuration dictionary
|
||||
samples: List of output samples to analyze
|
||||
|
||||
Returns:
|
||||
Tuple of (system_prompt, user_prompt)
|
||||
"""
|
||||
system_prompt = """You are an expert at analyzing LLM-powered data processing operations.
|
||||
Your task is to determine if an operation would benefit from being decomposed into multiple
|
||||
smaller, focused operations (also known as "optimization" or "decomposition").
|
||||
|
||||
An operation SHOULD be decomposed when:
|
||||
1. The prompt asks the LLM to do multiple distinct tasks that could be done separately
|
||||
2. The task is complex enough that breaking it into sequential steps would improve accuracy
|
||||
3. The outputs show inconsistency, incompleteness, or quality issues that iterative refinement could fix
|
||||
4. Long documents need to be processed in chunks rather than all at once
|
||||
5. The prompt asks for both extraction AND analysis/synthesis in one step
|
||||
|
||||
An operation should NOT be decomposed when:
|
||||
1. It performs a single, focused task well
|
||||
2. The outputs are consistently high quality and complete
|
||||
3. The task is simple and atomic (e.g., simple classification, single field extraction)
|
||||
4. The operation is already well-scoped and produces reliable results
|
||||
|
||||
Be conservative - only recommend decomposition if there's clear evidence it would help."""
|
||||
|
||||
output_schema = op_config.get("output", {}).get("schema", {})
|
||||
prompt_template = op_config.get("prompt", "No prompt specified")
|
||||
|
||||
# Truncate very long prompts for display
|
||||
if len(prompt_template) > 3000:
|
||||
prompt_template = prompt_template[:3000] + "\n... [truncated]"
|
||||
|
||||
user_prompt = f"""Analyze this data processing operation and its outputs:
|
||||
|
||||
## Operation Configuration
|
||||
|
||||
**Name:** {op_config.get('name', 'unknown')}
|
||||
**Type:** {op_config.get('type', 'unknown')}
|
||||
|
||||
**Prompt Template:**
|
||||
```
|
||||
{prompt_template}
|
||||
```
|
||||
|
||||
**Output Schema:**
|
||||
```json
|
||||
{json.dumps(output_schema, indent=2)}
|
||||
```
|
||||
|
||||
## Sample Outputs ({len(samples)} samples from the operation)
|
||||
|
||||
```json
|
||||
{json.dumps(samples, indent=2, default=str)}
|
||||
```
|
||||
|
||||
## Your Task
|
||||
|
||||
Based on the operation configuration and sample outputs, determine:
|
||||
1. Should this operation be decomposed/optimized?
|
||||
2. If yes, what specific improvements would help?
|
||||
|
||||
Consider the quality, completeness, and consistency of the outputs when making your assessment."""
|
||||
|
||||
return system_prompt, user_prompt
|
||||
|
||||
def analyze(
|
||||
self,
|
||||
op_config: dict[str, Any],
|
||||
step_name: str,
|
||||
op_name: str,
|
||||
) -> tuple[str, list[dict[str, Any]], int, float]:
|
||||
"""
|
||||
Analyze whether an operation should be optimized.
|
||||
|
||||
This is the main entry point. It loads outputs, builds the prompt,
|
||||
makes a single LLM call, and returns the assessment.
|
||||
|
||||
Args:
|
||||
op_config: The operation configuration dictionary
|
||||
step_name: Name of the pipeline step
|
||||
op_name: Name of the operation
|
||||
|
||||
Returns:
|
||||
Tuple of (rationale, output_samples, num_docs_analyzed, cost):
|
||||
- rationale: Empty string if no optimization needed, explanation if it should be optimized
|
||||
- output_samples: The samples that were analyzed
|
||||
- num_docs_analyzed: Number of documents that fit in the LLM prompt
|
||||
- cost: LLM API cost in USD
|
||||
|
||||
Raises:
|
||||
ValueError: If the operation is not an LLM-powered map operation
|
||||
"""
|
||||
# Validate operation type - only LLM-powered map operations
|
||||
op_type = op_config.get("type", "")
|
||||
if op_type != "map":
|
||||
raise ValueError(
|
||||
f"should_optimize only supports 'map' operations, got '{op_type}'. "
|
||||
"Only LLM-powered map operations can be analyzed for decomposition."
|
||||
)
|
||||
|
||||
# Load outputs from intermediate file
|
||||
outputs = self.load_operation_data(step_name, op_name)
|
||||
|
||||
if not outputs:
|
||||
return "No output samples available for analysis.", [], 0, 0.0
|
||||
|
||||
# Calculate samples that fit in context
|
||||
samples = self.calculate_samples_that_fit(op_config, outputs)
|
||||
|
||||
if not samples:
|
||||
return "Could not fit any samples in context window.", outputs[:5], 0, 0.0
|
||||
|
||||
# Build prompt
|
||||
system_prompt, user_prompt = self.build_analysis_prompt(op_config, samples)
|
||||
|
||||
# Make LLM call with structured output
|
||||
response = completion(
|
||||
model=self.optimizer_model,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "optimization_analysis",
|
||||
"strict": True,
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"should_optimize": {
|
||||
"type": "boolean",
|
||||
"description": "True if operation should be decomposed/optimized",
|
||||
},
|
||||
"rationale": {
|
||||
"type": "string",
|
||||
"description": "Explanation of why the operation should or should not be optimized",
|
||||
},
|
||||
"suggested_improvements": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Specific improvements if optimization is recommended (empty if not)",
|
||||
},
|
||||
},
|
||||
"required": [
|
||||
"should_optimize",
|
||||
"rationale",
|
||||
"suggested_improvements",
|
||||
],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
},
|
||||
**self.litellm_kwargs,
|
||||
)
|
||||
|
||||
# Calculate cost
|
||||
cost = completion_cost(response)
|
||||
|
||||
# Parse response
|
||||
result = json.loads(response.choices[0].message.content)
|
||||
|
||||
num_docs_analyzed = len(samples)
|
||||
|
||||
if result["should_optimize"]:
|
||||
# Build rationale string with improvements
|
||||
rationale_parts = [result["rationale"]]
|
||||
if result["suggested_improvements"]:
|
||||
rationale_parts.append("\n\nSuggested improvements:")
|
||||
for imp in result["suggested_improvements"]:
|
||||
rationale_parts.append(f"- {imp}")
|
||||
return "\n".join(rationale_parts), samples, num_docs_analyzed, cost
|
||||
else:
|
||||
# Return empty string to indicate no optimization needed
|
||||
return "", samples, num_docs_analyzed, cost
|
||||
|
|
@@ -30,7 +30,7 @@ class LLMClient:
        Initialize the LLMClient.

        Args:
            model (str, optional): The name of the LLM model to use. Defaults to "gpt-5.1".
            model (str, optional): The name of the LLM model to use. Defaults to "gpt-4o".
            **litellm_kwargs: Additional keyword arguments for the LLM model.
        """
        self.rewrite_agent_model = rewrite_agent_model


@@ -1,85 +0,0 @@
# DocETL Reasoning Optimizer - Agent Guide

## Overview
The reasoning optimizer uses a graph-search algorithm with LLM-based directives to rewrite and optimize DocETL pipelines. This directory contains the directive system for pipeline transformations.

## Creating New Directives - Interactive Process

As an AI agent, I can guide you through creating new rewrite directives step-by-step. Here's the process:

### 1. Initial Consultation
- **What I'll ask**: Describe what transformation you want the directive to perform
- **What you provide**: High-level description of the directive's purpose and when to use it
- **What I'll do**: Analyze existing directives and suggest the best approach

### 2. Schema Design Phase
- **What I'll ask**: Confirm the configuration parameters needed for the directive
- **What I'll propose**: The instantiate schema structure (Pydantic models) that the LLM agent will output
- **What you confirm**: Whether the schema captures all necessary parameters

### 3. Directive Specification
- **What I'll propose**:
  - `name`: Technical identifier for the directive
  - `formal_description`: Brief transformation pattern (e.g., "Op => Code Map -> Op")
  - `nl_description`: Natural language explanation
  - `when_to_use`: Specific use cases
- **What you confirm**: Whether these descriptions accurately capture the directive's purpose

### 4. Example Creation
- **What I'll propose**: Example showing original operation and expected instantiate schema output
- **What you confirm**: Whether the example demonstrates the directive correctly

### 5. Test Case Design
- **What I'll propose**: Test cases with input operations and expected behaviors
- **What you confirm**: Whether test cases cover important scenarios

### 6. Implementation Review
- **What I'll show**: Complete directive implementation (a rough skeleton is sketched below), including:
  - Schema classes in `instantiate_schemas.py`
  - Directive class with all required methods
  - Registration in `__init__.py`
  - Apply tests in `tests/reasoning_optimizer/test_directive_apply.py`
- **What you confirm**: Final approval before implementation
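
A rough skeleton of what a finished directive tends to look like (the class layout and method signature below are illustrative only, not the actual DocETL base class; follow an existing class in `directives/` for the real API):

```python
# Hypothetical sketch: mirror an existing directive in directives/ for the real interface.
class MyNewDirective:
    name = "my_new_directive"
    formal_description = "Op => Preprocess -> Op"
    nl_description = "Adds a preprocessing step before the target operation."
    when_to_use = "When inputs need cleanup before the main LLM call."

    def instantiate(self, operators, target_ops, agent_llm, message_history, **kwargs):
        # Ask the agent LLM to fill in the Pydantic schema from instantiate_schemas.py,
        # then rewrite `operators` and return (new_ops_list, message_history, cost).
        ...
```
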
## Existing Directive Patterns

### Single Operation Modification
- **Gleaning**: Adds validation loops (`validation_prompt`, `num_rounds`)
- **Change Model**: Switches LLM model (`model`)
- **Deterministic Doc Compression**: Adds regex-based preprocessing

### Operation Replacement
- **Chaining**: Replaces complex operation with sequential simpler ones
- **Isolating Subtasks**: Breaks operation into independent parallel tasks

### Pipeline Preprocessing
- **Doc Summarization**: Adds document summarization before main processing

## Key Files and Structure
- `directives/`: Individual directive implementations
- `instantiate_schemas.py`: Pydantic schemas for LLM outputs
- `agent.py`: Core MCTS agent that applies directives
- `op_descriptions.py`: Operation type descriptions for the agent

## Testing Commands

### Instantiation Tests
- Single directive: `python experiments/reasoning/run_tests.py --directive=directive_name`
- All directives: `python experiments/reasoning/run_tests.py`

### Apply Tests
- All directive apply methods: `python tests/reasoning_optimizer/test_directive_apply.py`

### Full MCTS
- Complete optimization: `python experiments/reasoning/run_mcts.py`

## Workflow for New Directive
1. Describe desired transformation → I analyze and propose approach
2. Confirm schema design → I implement Pydantic models
3. Confirm descriptions → I write directive metadata
4. Confirm examples → I create demonstration cases
5. Confirm test cases → I design validation scenarios
6. Review implementation → I write complete directive code
7. Test and iterate → We verify functionality together

**Ready to create a new directive? Describe what transformation you want to implement!**

@@ -1,521 +0,0 @@
import json
import os
import threading
import time
from typing import Dict, List

import litellm
import yaml
from pydantic import BaseModel

from docetl.reasoning_optimizer.directives import (
    get_all_directive_strings,
    instantiate_directive,
)
from docetl.reasoning_optimizer.load_data import load_input_doc
from docetl.utils import load_config

from .op_descriptions import *  # noqa: F403, F405

# argparse removed - use experiments/reasoning/run_baseline.py for CLI

# Global dictionary of rate limiters per model
model_rate_limiters: Dict[str, "TokenRateLimiter"] = {}
# Use environment variable or default to current directory
data_dir = os.environ.get("EXPERIMENT_DATA_DIR", "./data/")


def get_rate_limiter(model: str, max_tpm: int) -> "TokenRateLimiter":
    if model not in model_rate_limiters:
        model_rate_limiters[model] = TokenRateLimiter(max_tpm)
    return model_rate_limiters[model]

class TokenRateLimiter:
    def __init__(self, max_tpm):
        self.max_tpm = max_tpm
        self.tokens_used = 0
        self.lock = threading.Lock()
        self.reset_time = time.time() + 60  # 60 seconds window

    def allow(self, tokens):
        with self.lock:
            now = time.time()
            if now >= self.reset_time:
                self.tokens_used = 0
                self.reset_time = now + 60
            if self.tokens_used + tokens > self.max_tpm:
                return False
            self.tokens_used += tokens
            return True

    def wait_for_slot(self, tokens):
        while not self.allow(tokens):
            time_to_wait = max(0, self.reset_time - time.time())
            time.sleep(time_to_wait)


def count_tokens(messages):
    # messages should be a list of dicts, each with a "content" key
    total_chars = sum(
        len(m.get("content", "")) for m in messages if isinstance(m, dict)
    )
    return max(1, total_chars // 4)

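# Usage sketch (illustrative values; not part of the original module). Note that this
# count_tokens is a rough chars/4 heuristic, distinct from docetl.utils.count_tokens
# used by the optimizer modules above.
#   limiter = get_rate_limiter("gpt-4o-mini", max_tpm=5_000_000)
#   messages = [{"role": "user", "content": "Summarize this document ..."}]
#   limiter.wait_for_slot(count_tokens(messages))  # blocks until the 60s window has budget
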
# ------------------------------------------------------------------
|
||||
# 🔒 Context-window safety helpers
|
||||
# ------------------------------------------------------------------
|
||||
# Maximum number of tokens we will allow in the prompt we send to the model.
|
||||
# The Azure GPT-5 family allows 272,000 tokens.
|
||||
MAX_CONTEXT_TOKENS = 270_000
|
||||
|
||||
|
||||
def _trim_history(history: list, keep_system_first: bool = True) -> list:
|
||||
"""Trim the conversation history in-place so its estimated token count
|
||||
(via ``count_tokens``) does not exceed ``MAX_CONTEXT_TOKENS``.
|
||||
|
||||
We always keep the very first system message and the first user message so the
|
||||
assistant retains the global instructions and the initial query context. After
|
||||
that we drop the oldest messages until the budget is satisfied. Returns the
|
||||
trimmed history list.
|
||||
"""
|
||||
|
||||
# Determine starting index to preserve the initial system message and first user message
|
||||
start_idx = 0
|
||||
if keep_system_first and history:
|
||||
if history[0].get("role") == "system":
|
||||
start_idx = 1
|
||||
# Find the first user message after the system message
|
||||
for i in range(1, len(history)):
|
||||
if history[i].get("role") == "user":
|
||||
start_idx = i + 1
|
||||
break
|
||||
elif history[0].get("role") == "user":
|
||||
# If first message is user, keep it and find the next user message
|
||||
start_idx = 1
|
||||
for i in range(1, len(history)):
|
||||
if history[i].get("role") == "user":
|
||||
start_idx = i + 1
|
||||
break
|
||||
|
||||
# Drop oldest messages (just after the preserved block) until within limit
|
||||
while len(history) > start_idx + 1 and count_tokens(history) > MAX_CONTEXT_TOKENS:
|
||||
history.pop(start_idx)
|
||||
|
||||
return history
|
||||
|
||||
|
||||
class ResponseFormat(BaseModel):
|
||||
directive: str
|
||||
operators: List[str]
|
||||
|
||||
|
||||
def get_openai_response(
|
||||
input_query,
|
||||
input_schema,
|
||||
input_data_sample,
|
||||
model="o3",
|
||||
max_tpm=5000000,
|
||||
message_history=[],
|
||||
curr_plan_output="",
|
||||
prev_plan_cost: float = 0.0,
|
||||
iteration=1,
|
||||
):
|
||||
"""
|
||||
The first LLM call. Generates a rewrite plan given the rewrite directives.
|
||||
"""
|
||||
|
||||
if iteration == 1:
|
||||
user_message = f"""
|
||||
I have a set of operations used to process long documents, along with a list of possible rewrite directives aimed at improving the quality of the query result.
|
||||
Given a query pipeline made up of these operations, recommend one specific rewrite directive (specify by its name) that would improve accuracy and specify which operators (specify by the names) in the pipeline the directive should be applied to.
|
||||
Make sure that your chosen directive is in the provided list of rewrite directives.
|
||||
Pipeline:
|
||||
Pipelines in DocETL are the core structures that define the flow of data processing. A pipeline consists of five main components: \n
|
||||
- Default Model: The language model to use for the pipeline. Limit your choice of model to gpt-5-nano, gpt-4o-mini, gpt-5 \n
|
||||
- System Prompts: A description of your dataset and the "persona" you'd like the LLM to adopt when analyzing your data. \n
|
||||
- Datasets: The input data sources for your pipeline. \n
|
||||
- Operators: The processing steps that transform your data. \n
|
||||
- Pipeline Specification: The sequence of steps and the output configuration. \n
|
||||
|
||||
Operators:
|
||||
Operators form the building blocks of data processing pipelines. Below is the list of operators:
|
||||
{op_map.to_string()}\n
|
||||
{op_extract.to_string()}\n
|
||||
{op_parallel_map.to_string()}\n
|
||||
{op_filter.to_string()}\n
|
||||
{op_reduce.to_string()}\n
|
||||
{op_split.to_string()}\n
|
||||
{op_gather.to_string()}\n
|
||||
{op_unnest.to_string()}\n
|
||||
{op_sample.to_string()}\n
|
||||
{op_resolve.to_string()}\n
|
||||
|
||||
Rewrite directives:
|
||||
{get_all_directive_strings()}\n
|
||||
|
||||
Input document schema with token statistics: {input_schema} \n
|
||||
Cost of previous plan execution: ${prev_plan_cost:.4f} \n
|
||||
The original query in YAML format using our operations: {input_query} \n
|
||||
Input data sample: {json.dumps(input_data_sample, indent=2)[:3000]} \n
|
||||
Sample of the result from executing the original query: {json.dumps(curr_plan_output, indent=2)[:3000]} \n
|
||||
"""
|
||||
else:
|
||||
user_message = f"""
|
||||
Given the previously rewritten pipeline, recommend one specific rewrite directive (specify by its name) that would improve accuracy and specify which operator (specify by the name) in the pipeline the directive should be applied to.
|
||||
Make sure that your chosen directive is in the provided list of rewrite directives.
|
||||
Rewrite directives:
|
||||
{get_all_directive_strings()}\n
|
||||
|
||||
Cost of previous plan execution: ${prev_plan_cost:.4f} \n
|
||||
The original query in YAML format using our operations: {input_query} \n
|
||||
Sample of the result from executing the original query: {json.dumps(curr_plan_output, indent=2)[:3000]} \n
|
||||
"""
|
||||
|
||||
if len(message_history) == 0:
|
||||
message_history.extend(
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are an expert query optimization agent for document processing pipelines. Your role is to analyze user queries and apply rewrite directives to create more accurate execution plans. Your output must follow the structured output format.",
|
||||
},
|
||||
{"role": "user", "content": user_message},
|
||||
]
|
||||
)
|
||||
else:
|
||||
message_history.append({"role": "user", "content": user_message})
|
||||
|
||||
# Trim the history to prevent context window overflow before sending to the model
|
||||
message_history = _trim_history(message_history)
|
||||
|
||||
messages = message_history
|
||||
|
||||
# Enforce rate limit for the specified model
|
||||
if max_tpm > 0:
|
||||
limiter = get_rate_limiter(model, max_tpm)
|
||||
tokens = count_tokens(messages)
|
||||
limiter.wait_for_slot(tokens)
|
||||
|
||||
# Count the number of tokens in the messages for debugging/monitoring
|
||||
num_tokens = count_tokens(messages)
|
||||
print(f"Token count for current messages: {num_tokens}")
|
||||
# litellm._turn_on_debug()
|
||||
response = litellm.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
response_format=ResponseFormat,
|
||||
)
|
||||
assistant_response = response.choices[0].message.content
|
||||
|
||||
# Add user and assistant messages to message_history as dicts
|
||||
message_history.append({"role": "assistant", "content": assistant_response})
|
||||
return assistant_response, message_history
|
||||
|
||||
|
||||
def update_yaml_operations(input_file_path, output_file_path, new_operations):
|
||||
"""
|
||||
Load a YAML file, replace the operations section, and save to a new file.
|
||||
|
||||
Args:
|
||||
input_file_path (str): Path to the original YAML file
|
||||
output_file_path (str): Path where the modified YAML will be saved
|
||||
new_operations (list): List of operation dictionaries to replace the original operations
|
||||
"""
|
||||
# Load the original YAML file
|
||||
with open(input_file_path, "r") as file:
|
||||
config = yaml.safe_load(file)
|
||||
|
||||
# Replace the operations section
|
||||
config["operations"] = new_operations
|
||||
|
||||
# Write the modified config to a new YAML file
|
||||
with open(output_file_path, "w") as file:
|
||||
yaml.dump(
|
||||
config, file, default_flow_style=False, allow_unicode=True, sort_keys=False
|
||||
)
|
||||
|
||||
print(f"Modified YAML saved to: {output_file_path}")
|
||||
|
||||
|
||||
def update_pipeline(orig_config, new_ops_list, target_ops):
|
||||
"""
|
||||
Update the pipeline configuration with new operations.
|
||||
|
||||
Args:
|
||||
orig_config (dict): The original pipeline configuration
|
||||
new_ops_list (list): The entire pipeline operations list (not a subset)
|
||||
target_ops (list): List of target operation names to replace
|
||||
|
||||
Returns:
|
||||
dict: Updated pipeline configuration
|
||||
"""
|
||||
if new_ops_list is not None:
|
||||
op_names = [op.get("name") for op in new_ops_list if "name" in op]
|
||||
|
||||
# Update the pipeline steps to use the new operation names
|
||||
if "pipeline" in orig_config and "steps" in orig_config["pipeline"]:
|
||||
for step in orig_config["pipeline"]["steps"]:
|
||||
if "operations" in step:
|
||||
new_ops = []
|
||||
for op in step["operations"]:
|
||||
if op == target_ops[0]:
|
||||
new_ops.extend(op_names)
|
||||
step["operations"] = new_ops
|
||||
|
||||
return orig_config
|
||||
|
||||
|
||||
def fix_models(parsed_yaml):
|
||||
"""No-op: Model names should be specified correctly in the YAML."""
|
||||
pass
|
||||
|
||||
|
||||
def update_sample(new_ops_list, target_ops, orig_operators):
|
||||
"""
|
||||
Update sample settings in new operations based on original operators.
|
||||
|
||||
Args:
|
||||
new_ops_list (list): List of new operations to update
|
||||
target_ops (list): List of target operation names
|
||||
orig_operators (list): List of original operators
|
||||
|
||||
Returns:
|
||||
list: Updated new operations list with sample settings
|
||||
"""
|
||||
# Build a mapping from op name to op config in orig_operators
|
||||
op_name_to_config = {op.get("name"): op for op in orig_operators if "name" in op}
|
||||
|
||||
# For each op in new_ops_list, if the corresponding op in orig_operators has 'sample', add it
|
||||
|
||||
sample_size = -1
|
||||
for target_op_name in target_ops:
|
||||
target_op = op_name_to_config[target_op_name]
|
||||
if "sample" in target_op:
|
||||
sample_size = target_op["sample"]
|
||||
|
||||
print("SAMPLE SIZE: ", sample_size)
|
||||
|
||||
for op in new_ops_list:
|
||||
if sample_size != -1:
|
||||
op["sample"] = sample_size
|
||||
|
||||
return new_ops_list
|
||||
|
||||
|
||||
def save_message_history(message_history, filepath):
|
||||
"""
|
||||
Save message history to a JSON file.
|
||||
|
||||
Args:
|
||||
message_history (list): List of message dictionaries
|
||||
filepath (str): Path to save the message history
|
||||
"""
|
||||
with open(filepath, "w") as f:
|
||||
json.dump(message_history, f, indent=2)
|
||||
print(f"Message history saved to: {filepath}")
|
||||
|
||||
|
||||
def load_message_history(filepath):
|
||||
"""
|
||||
Load message history from a JSON file.
|
||||
|
||||
Args:
|
||||
filepath (str): Path to the message history file
|
||||
|
||||
Returns:
|
||||
list: List of message dictionaries, or empty list if file doesn't exist
|
||||
"""
|
||||
if os.path.exists(filepath):
|
||||
with open(filepath, "r") as f:
|
||||
return json.load(f)
|
||||
return []
|
||||
|
||||
|
||||
def run_single_iteration(
|
||||
yaml_path,
|
||||
model,
|
||||
max_tpm,
|
||||
message_history,
|
||||
iteration_num,
|
||||
orig_output_sample,
|
||||
prev_plan_cost: float,
|
||||
output_dir=None,
|
||||
dataset="cuad",
|
||||
sample_data=None,
|
||||
):
|
||||
"""
|
||||
Run a single iteration of the optimization process.
|
||||
|
||||
Args:
|
||||
yaml_path (str): Path to the YAML file
|
||||
model (str): Model name
|
||||
max_tpm (int): Tokens per minute limit
|
||||
message_history (list): Cumulative message history
|
||||
iteration_num (int): Current iteration number
|
||||
|
||||
Returns:
|
||||
tuple: (output_file_path, updated_message_history)
|
||||
"""
|
||||
print(f"\n=== Running Iteration {iteration_num} ===")
|
||||
print(f"Input file: {yaml_path}")
|
||||
|
||||
# Parse input yaml file to get the list of operations
|
||||
orig_config = load_config(yaml_path)
|
||||
orig_operators = orig_config["operations"]
|
||||
|
||||
# Use provided sample data
|
||||
random_sample = sample_data if sample_data is not None else []
|
||||
|
||||
with open(yaml_path, "r") as f:
|
||||
input_query = f.read()
|
||||
|
||||
with open(yaml_path, "r") as file:
|
||||
config = yaml.safe_load(file)
|
||||
|
||||
global_default_model = config.get("default_model")
|
||||
datasets = config.get("datasets", {})
|
||||
input_file_path = None
|
||||
if isinstance(datasets, dict) and datasets:
|
||||
first_dataset = next(iter(datasets.values()))
|
||||
if isinstance(first_dataset, dict):
|
||||
input_file_path = first_dataset.get("path")
|
||||
|
||||
input_schema = load_input_doc(yaml_path)
|
||||
|
||||
reply, message_history = get_openai_response(
|
||||
input_query,
|
||||
input_schema,
|
||||
random_sample,
|
||||
model=model,
|
||||
max_tpm=max_tpm,
|
||||
message_history=message_history,
|
||||
curr_plan_output=orig_output_sample,
|
||||
prev_plan_cost=prev_plan_cost,
|
||||
iteration=iteration_num,
|
||||
)
|
||||
|
||||
# Use output_dir if provided, otherwise fall back to data_dir
|
||||
save_dir = output_dir if output_dir else data_dir
|
||||
|
||||
# Parse agent response
|
||||
try:
|
||||
parsed = json.loads(reply)
|
||||
directive = parsed.get("directive")
|
||||
target_ops = parsed.get("operators")
|
||||
print(f"Directive: {directive}, Target ops: {target_ops}")
|
||||
|
||||
# Log directive and target ops to baseline_log.txt in the same directory as output YAML
|
||||
log_file_path = os.path.join(save_dir, "baseline_log.txt")
|
||||
log_message = f"Iteration {iteration_num}: Directive: {directive}, Target ops: {target_ops}\n"
|
||||
with open(log_file_path, "a") as log_file:
|
||||
log_file.write(log_message)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to parse agent response: {e}")
|
||||
return None, message_history
|
||||
|
||||
try:
|
||||
new_ops_list, message_history, cost = instantiate_directive(
|
||||
directive_name=directive,
|
||||
operators=orig_operators,
|
||||
target_ops=target_ops,
|
||||
agent_llm=model,
|
||||
message_history=message_history,
|
||||
global_default_model=global_default_model,
|
||||
input_file_path=input_file_path,
|
||||
pipeline_code=orig_config,
|
||||
dataset=dataset,
|
||||
)
|
||||
orig_config["operations"] = new_ops_list
|
||||
|
||||
# Update pipeline steps to reflect new operation names
|
||||
orig_config = update_pipeline(orig_config, new_ops_list, target_ops)
|
||||
|
||||
# Apply special post-processing for chaining directive
|
||||
if directive == "chaining":
|
||||
new_ops_list = update_sample(new_ops_list, target_ops, orig_operators)
|
||||
orig_config["operations"] = new_ops_list
|
||||
|
||||
# Ensure all model references start with 'azure/'
|
||||
fix_models(orig_config)
|
||||
|
||||
except ValueError as e:
|
||||
print(f"Failed to instantiate directive '{directive}': {e}")
|
||||
return None, message_history
|
||||
|
||||
output_file_path = os.path.join(
|
||||
save_dir,
|
||||
f"iteration_{iteration_num}.yaml",
|
||||
)
|
||||
|
||||
# Model names should be specified correctly in the YAML - no automatic prefixing
|
||||
|
||||
# Add bypass_cache: true at the top level
|
||||
orig_config["bypass_cache"] = True
|
||||
|
||||
# Save the modified config
|
||||
with open(output_file_path, "w") as file:
|
||||
yaml.dump(
|
||||
orig_config,
|
||||
file,
|
||||
default_flow_style=False,
|
||||
allow_unicode=True,
|
||||
sort_keys=False,
|
||||
)
|
||||
|
||||
print(f"Modified YAML saved to: {output_file_path}")
|
||||
|
||||
# Execute the pipeline to get cost and sample outputs for next iteration
|
||||
total_cost = 0.0
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from docetl.runner import DSLRunner
|
||||
|
||||
# Update output path if output_dir is provided
|
||||
if output_dir:
|
||||
json_output_path = os.path.join(
|
||||
output_dir, f"iteration_{iteration_num}_results.json"
|
||||
)
|
||||
orig_config["pipeline"]["output"]["path"] = json_output_path
|
||||
|
||||
# Save updated YAML with new output path
|
||||
with open(output_file_path, "w") as file:
|
||||
yaml.dump(
|
||||
orig_config,
|
||||
file,
|
||||
default_flow_style=False,
|
||||
allow_unicode=True,
|
||||
sort_keys=False,
|
||||
)
|
||||
|
||||
# Load environment
|
||||
cwd = os.getcwd()
|
||||
env_file = os.path.join(cwd, ".env")
|
||||
if os.path.exists(env_file):
|
||||
load_dotenv(env_file)
|
||||
|
||||
print("🔄 Executing pipeline to get cost and sample outputs...")
|
||||
runner = DSLRunner.from_yaml(output_file_path)
|
||||
runner.load()
|
||||
|
||||
if runner.last_op_container:
|
||||
result_data, _, _ = runner.last_op_container.next()
|
||||
runner.save(result_data)
|
||||
total_cost = runner.total_cost
|
||||
print(f"✅ Pipeline executed successfully, cost: ${total_cost:.4f}")
|
||||
else:
|
||||
print("⚠️ No results from pipeline execution")
|
||||
raise Exception("No results from pipeline execution")
|
||||
|
||||
runner.reset_env()
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Pipeline execution failed: {e}")
|
||||
raise e
|
||||
|
||||
return output_file_path, message_history, total_cost
|
||||
|
||||
|
||||
# Use experiments/reasoning/run_baseline.py to run experiments
|
||||
|
|
@ -1,458 +0,0 @@
|
|||
# Directives - How to Add a New Directive
|
||||
|
||||
This guide explains how to add a new directive to the reasoning optimizer. Directives are transformations that can be applied to pipeline operations to improve their effectiveness.
|
||||
|
||||
## What is a Directive?
|
||||
|
||||
A directive is a transformation rule that modifies pipeline operations. For example:
|
||||
- **Chaining**: Breaks complex operations into sequential steps
|
||||
- **Gleaning**: Adds validation loops to improve output quality
|
||||
- **Change Model**: Switches the LLM model for better performance
|
||||
- **Doc Summarization**: Adds preprocessing to summarize long documents
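
Conceptually, a directive takes operation configs and returns rewritten configs. As a rough illustration (the operation name and prompt here are invented; the `gleaning` keys follow the gleaning example shown later in this guide), applying gleaning to a map operation looks like this:

```python
# Hypothetical map operation before the directive is applied
original_op = {
    "name": "extract_insights",
    "type": "map",
    "prompt": "List the key insights in: {{ input.document }}",
    "output": {"schema": {"insights": "list[string]"}},
}

# After the gleaning directive: the same operation, plus a validation loop
rewritten_op = {
    **original_op,
    "gleaning": {
        "validation_prompt": "Check that at least 2 insights were extracted.",
        "num_rounds": 2,
        "model": "gpt-4o-mini",
    },
}
```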
|
||||
|
||||
## Key Concepts
|
||||
|
||||
### What is an Instantiate Schema?
|
||||
|
||||
An **instantiate schema** is a Pydantic model that defines the structured output an agent (LLM) must produce to apply a directive. It acts as the "configuration blueprint" that tells the system exactly how to transform an operation.
|
||||
|
||||
For example, when applying the **gleaning** directive:
|
||||
1. The agent looks at the original operation
|
||||
2. The agent outputs a `GleaningInstantiateSchema` containing:
|
||||
```python
|
||||
{
|
||||
"gleaning_config": {
|
||||
"validation_prompt": "Check that the output has at least 2 insights...",
|
||||
"num_rounds": 3,
|
||||
"model": "gpt-4o-mini"
|
||||
}
|
||||
}
|
||||
```
|
||||
3. The directive uses this schema to add gleaning configuration to the operation
|
||||
|
||||
The instantiate schema ensures the agent provides all required parameters in the correct format for the directive to work.
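
Concretely, such a schema is a small Pydantic model. A simplified sketch of what the gleaning schema might look like, based on the JSON example above (the real class lives in `instantiate_schemas.py` and may differ in details):

```python
from pydantic import BaseModel, Field


class GleaningConfig(BaseModel):
    validation_prompt: str = Field(..., description="Prompt used to check each output")
    num_rounds: int = Field(default=2, ge=1, description="Number of validation rounds")
    model: str = Field(default="gpt-4o-mini", description="Model used for validation")


class GleaningInstantiateSchema(BaseModel):
    gleaning_config: GleaningConfig = Field(
        ..., description="The gleaning configuration to attach to the target operation"
    )
```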
|
||||
|
||||
### Workflow Overview
|
||||
|
||||
1. **Agent Analysis**: Agent examines the original operation and determines how to apply the directive
|
||||
2. **Schema Generation**: Agent outputs structured configuration using the instantiate schema format
|
||||
3. **Directive Application**: The directive's `apply()` method uses this configuration to transform the pipeline
|
||||
4. **Validation**: Schema validators ensure the configuration is valid before application
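
From the caller's side, those four stages look roughly like this (a schematic sketch; the method names follow the generic pattern shown in the Quick Start below, and the exact optimizer API may differ):

```python
from docetl.reasoning_optimizer.directives import GleaningDirective

ops_list = [original_op]  # pipeline operations, e.g. the map op from the earlier example
directive = GleaningDirective()

# 1-2. Agent analysis and schema generation: the LLM returns an instantiate schema.
#      4. Validation happens when the agent's response is parsed into the Pydantic class.
rewrite, history = directive.llm_instantiate(original_op, agent_llm="gpt-4o-mini")

# 3. Directive application: rewrite the operations list using that configuration.
new_ops = directive.apply(ops_list, target_op="extract_insights", rewrite=rewrite)
```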
|
||||
|
||||
## Quick Start - Adding a New Directive
|
||||
|
||||
### 1. Define Your Instantiate Schema
|
||||
|
||||
First, add your schema classes to `docetl/reasoning_optimizer/instantiate_schemas.py`. The instantiate schema defines what the agent must output:
|
||||
|
||||
```python
|
||||
class MyDirectiveConfig(BaseModel):
|
||||
"""Configuration parameters for your directive."""
|
||||
param1: str = Field(..., description="Description of param1")
|
||||
param2: int = Field(default=3, description="Description of param2")
|
||||
model: str = Field(default="gpt-4o-mini", description="The LLM model to use")
|
||||
|
||||
class MyDirectiveInstantiateSchema(BaseModel):
|
||||
"""
|
||||
Schema that the agent must output to instantiate this directive.
|
||||
This is what gets returned by the LLM when asked to apply the directive.
|
||||
"""
|
||||
my_directive_config: MyDirectiveConfig = Field(
|
||||
..., description="The configuration to apply to the target operation"
|
||||
)
|
||||
|
||||
# Add validators if needed
|
||||
@field_validator("my_directive_config")
|
||||
@classmethod
|
||||
def validate_config(cls, v):
|
||||
# Add validation logic here
|
||||
return v
|
||||
```
|
||||
|
||||
### 2. Create Your Directive Class
|
||||
|
||||
Create a new file in this directory (e.g., `my_directive.py`):
|
||||
|
||||
```python
|
||||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
from litellm import completion
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from docetl.reasoning_optimizer.instantiate_schemas import MyDirectiveInstantiateSchema
|
||||
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
|
||||
|
||||
class MyDirective(Directive):
|
||||
name: str = Field(default="my_directive", description="The name of the directive")
|
||||
formal_description: str = Field(default="Op => Modified_Op")
|
||||
nl_description: str = Field(
|
||||
default="Natural language description of what this directive does"
|
||||
)
|
||||
when_to_use: str = Field(
|
||||
default="When to apply this directive (specific use cases)"
|
||||
)
|
||||
|
||||
# This tells the system what schema format the agent should output
|
||||
instantiate_schema_type: Type[BaseModel] = Field(default=MyDirectiveInstantiateSchema)
|
||||
|
||||
example: str = Field(
|
||||
default="""
|
||||
Original Op (MapOpConfig):
|
||||
- name: example_op
|
||||
type: map
|
||||
prompt: |
|
||||
Example prompt: {{ input.document }}
|
||||
output:
|
||||
schema:
|
||||
result: "string"
|
||||
|
||||
Example InstantiateSchema (what the agent should output):
|
||||
MyDirectiveConfig(
|
||||
param1="example_value",
|
||||
param2=5,
|
||||
model="gpt-4o-mini"
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
test_cases: List[DirectiveTestCase] = Field(
|
||||
default_factory=lambda: [
|
||||
DirectiveTestCase(
|
||||
name="basic_functionality",
|
||||
description="Should apply directive transformation correctly",
|
||||
input_config={
|
||||
"name": "test_op",
|
||||
"type": "map",
|
||||
"prompt": "Test prompt: {{ input.text }}",
|
||||
"output": {"schema": {"result": "string"}},
|
||||
},
|
||||
target_ops=["test_op"],
|
||||
expected_behavior="Should modify the operation with directive configuration",
|
||||
should_pass=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, MyDirective)
|
||||
|
||||
def __hash__(self):
|
||||
return hash("MyDirective")
|
||||
```
|
||||
|
||||
### 3. Implement Required Methods
|
||||
|
||||
```python
|
||||
def to_string_for_instantiate(self, original_op: Dict) -> str:
|
||||
"""
|
||||
Generate a prompt that asks the agent to output the instantiate schema.
|
||||
This prompt explains to the LLM what configuration it needs to generate.
|
||||
"""
|
||||
return (
|
||||
f"You are an expert at [specific domain expertise for your directive].\n\n"
|
||||
f"Original Operation:\n"
|
||||
f"{str(original_op)}\n\n"
|
||||
f"Directive: {self.name}\n"
|
||||
f"Your task is to instantiate this directive by generating a MyDirectiveConfig "
|
||||
f"that specifies [specific instructions for what the directive should do].\n\n"
|
||||
f"The agent must output the configuration in this exact format:\n"
|
||||
f"- param1: [explanation of how to set this]\n"
|
||||
f"- param2: [explanation of how to set this]\n"
|
||||
f"- model: [which model to use]\n\n"
|
||||
f"Example:\n"
|
||||
f"{self.example}\n\n"
|
||||
f"Please output only the InstantiateSchema (MyDirectiveConfig object) "
|
||||
f"that specifies how to apply this directive to the original operation."
|
||||
)
|
||||
|
||||
def llm_instantiate(
|
||||
self,
|
||||
original_op: Dict,
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
) -> tuple:
|
||||
"""
|
||||
Call the LLM to generate the instantiate schema.
|
||||
The LLM will output structured data matching MyDirectiveInstantiateSchema.
|
||||
"""
|
||||
|
||||
message_history.extend([
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.to_string_for_instantiate(original_op),
|
||||
},
|
||||
])
|
||||
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=MyDirectiveInstantiateSchema, # Forces structured output
|
||||
)
|
||||
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
if "my_directive_config" not in parsed_res:
|
||||
raise ValueError("Response missing required key 'my_directive_config'")
|
||||
|
||||
config = parsed_res["my_directive_config"]
|
||||
schema = MyDirectiveInstantiateSchema(my_directive_config=config)
|
||||
message_history.append({
|
||||
"role": "assistant",
|
||||
"content": resp.choices[0].message.content
|
||||
})
|
||||
return schema, message_history
|
||||
except Exception as err:
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts.")
|
||||
|
||||
def apply(
|
||||
self, ops_list: List[Dict], target_op: str, rewrite: MyDirectiveInstantiateSchema
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Apply the directive using the instantiate schema configuration.
|
||||
The 'rewrite' parameter contains the agent's generated configuration.
|
||||
"""
|
||||
new_ops_list = deepcopy(ops_list)
|
||||
|
||||
# Find the target operation
|
||||
pos_to_replace = [i for i, op in enumerate(ops_list) if op["name"] == target_op][0]
|
||||
target_operator = new_ops_list[pos_to_replace]
|
||||
|
||||
# Apply transformation using the agent's configuration
|
||||
target_operator["my_directive_param"] = rewrite.my_directive_config.param1
|
||||
target_operator["model"] = rewrite.my_directive_config.model
|
||||
|
||||
return new_ops_list
|
||||
|
||||
def instantiate(
|
||||
self,
|
||||
operators: List[Dict],
|
||||
target_ops: List[str],
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
**kwargs,
|
||||
) -> tuple:
|
||||
"""
|
||||
Main method that orchestrates directive instantiation:
|
||||
1. Get agent to generate instantiate schema
|
||||
2. Apply the transformation using that schema
|
||||
"""
|
||||
assert len(target_ops) == 1, "This directive requires exactly one target op"
|
||||
|
||||
target_op_config = [op for op in operators if op["name"] == target_ops[0]][0]
|
||||
|
||||
# Step 1: Agent generates the instantiate schema
|
||||
rewrite, message_history = self.llm_instantiate(
|
||||
target_op_config, agent_llm, message_history
|
||||
)
|
||||
|
||||
# Step 2: Apply transformation using the schema
|
||||
return self.apply(operators, target_ops[0], rewrite), message_history
|
||||
```
|
||||
|
||||
### 4. Register Your Directive
|
||||
|
||||
Add your directive to `__init__.py`:
|
||||
|
||||
```python
|
||||
from .my_directive import MyDirective
|
||||
|
||||
ALL_DIRECTIVES = [
|
||||
ChainingDirective(),
|
||||
GleaningDirective(),
|
||||
ChangeModelDirective(),
|
||||
DocSummarizationDirective(),
|
||||
MyDirective(), # Add your directive here
|
||||
]
|
||||
```
|
||||
|
||||
### 5. Update Test Runner
|
||||
|
||||
Add your directive to `tests/reasoning_optimizer/test_runner.py` in two places:
|
||||
|
||||
1. **Import section**:
|
||||
```python
|
||||
from docetl.reasoning_optimizer.directives import (
|
||||
ChainingDirective,
|
||||
GleaningDirective,
|
||||
ChangeModelDirective,
|
||||
DocSummarizationDirective,
|
||||
IsolatingSubtasksDirective,
|
||||
MyDirective, # Add your directive here
|
||||
TestResult
|
||||
)
|
||||
```
|
||||
|
||||
2. **Both directive lists**:
|
||||
```python
|
||||
# In run_all_directive_tests function
|
||||
directives = [
|
||||
ChainingDirective(),
|
||||
GleaningDirective(),
|
||||
ChangeModelDirective(),
|
||||
DocSummarizationDirective(),
|
||||
IsolatingSubtasksDirective(),
|
||||
MyDirective() # Add your directive here
|
||||
]
|
||||
|
||||
# In run_specific_directive_test function
|
||||
directive_map = {
|
||||
"chaining": ChainingDirective(),
|
||||
"gleaning": GleaningDirective(),
|
||||
"change_model": ChangeModelDirective(),
|
||||
"doc_summarization": DocSummarizationDirective(),
|
||||
"isolating_subtasks": IsolatingSubtasksDirective(),
|
||||
"my_directive": MyDirective() # Add your directive here
|
||||
}
|
||||
```
|
||||
|
||||
## Real Examples of Instantiate Schemas
|
||||
|
||||
### Gleaning Directive
|
||||
```python
|
||||
# What the agent outputs:
|
||||
{
|
||||
"gleaning_config": {
|
||||
"validation_prompt": "Check that the output contains at least 2 insights and each has supporting actions",
|
||||
"num_rounds": 3,
|
||||
"model": "gpt-4o-mini"
|
||||
}
|
||||
}
|
||||
|
||||
# How it's applied: Adds gleaning configuration to the operation
|
||||
target_operator["gleaning"] = {
|
||||
"validation_prompt": rewrite.gleaning_config.validation_prompt,
|
||||
"num_rounds": rewrite.gleaning_config.num_rounds,
|
||||
"model": rewrite.gleaning_config.model,
|
||||
}
|
||||
```
|
||||
|
||||
### Chaining Directive
|
||||
```python
|
||||
# What the agent outputs:
|
||||
{
|
||||
"new_ops": [
|
||||
{
|
||||
"name": "extract_conditions",
|
||||
"prompt": "Identify new medical conditions from: {{ input.summary }}",
|
||||
"output_keys": ["conditions"],
|
||||
"model": "gpt-4o-mini"
|
||||
},
|
||||
{
|
||||
"name": "extract_treatments",
|
||||
"prompt": "Extract treatments for conditions {{ input.conditions }} from {{ input.summary }}",
|
||||
"output_keys": ["treatments"],
|
||||
"model": "gpt-4o-mini"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# How it's applied: Replaces one operation with multiple chained operations
|
||||
```
|
||||
|
||||
### Change Model Directive
|
||||
```python
|
||||
# What the agent outputs:
|
||||
{
|
||||
"change_model_config": {
|
||||
"model": "gpt-4o"
|
||||
}
|
||||
}
|
||||
|
||||
# How it's applied: Changes the model field
|
||||
target_operator["model"] = rewrite.change_model_config.model
|
||||
```
|
||||
|
||||
## Testing Your Directive
|
||||
|
||||
### Individual Directive Testing
|
||||
|
||||
Test your directive by running its test cases:
|
||||
|
||||
```python
|
||||
from docetl.reasoning_optimizer.directives import MyDirective
|
||||
|
||||
directive = MyDirective()
|
||||
test_results = directive.run_tests(agent_llm="gpt-4o-mini")
|
||||
|
||||
for result in test_results:
|
||||
print(f"{result.test_name}: {'PASS' if result.passed else 'FAIL'}")
|
||||
print(f"Reason: {result.reason}")
|
||||
```
|
||||
|
||||
### Command Line Testing
|
||||
|
||||
Run directive instantiation tests from the command line:
|
||||
|
||||
```bash
|
||||
# Test a specific directive
|
||||
python experiments/reasoning/run_tests.py --directive=isolating_subtasks
|
||||
|
||||
# Test all directive instantiation tests
|
||||
python experiments/reasoning/run_tests.py
|
||||
```
|
||||
|
||||
### Apply Method Testing
|
||||
|
||||
Test that directive `apply()` methods work correctly:
|
||||
|
||||
```bash
|
||||
# Test all directive apply methods
|
||||
python tests/reasoning_optimizer/test_directive_apply.py
|
||||
```
|
||||
|
||||
This ensures the `apply()` method doesn't crash when given realistic pipeline configurations and rewrite schemas.
|
||||
|
||||
### Integration Testing
|
||||
|
||||
Full pipeline integration testing can be done via `experiments/reasoning/run_mcts.py`.
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### 1. Single Operation Modification
|
||||
Adds configuration to an existing operation:
|
||||
- **Gleaning**: Adds validation config
|
||||
- **Change Model**: Modifies model parameter
|
||||
|
||||
### 2. Operation Replacement
|
||||
Replaces one operation with multiple:
|
||||
- **Chaining**: Creates sequence of simpler operations
|
||||
|
||||
### 3. Pipeline Preprocessing
|
||||
Adds operations at pipeline start:
|
||||
- **Doc Summarization**: Adds summarization step before main processing
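
Each pattern corresponds to a different shape of `apply()`. A compressed, illustrative sketch (attribute names follow the instantiate-schema examples earlier in this guide; the real directives add more bookkeeping):

```python
from copy import deepcopy


def apply_modification(ops_list, target_op, rewrite):
    # Pattern 1: attach new configuration to the target operation (gleaning, change model)
    new_ops = deepcopy(ops_list)
    idx = next(i for i, op in enumerate(new_ops) if op["name"] == target_op)
    new_ops[idx]["gleaning"] = rewrite.gleaning_config.model_dump()
    return new_ops


def apply_replacement(ops_list, target_op, rewrite):
    # Pattern 2: replace the target with a sequence of simpler operations (chaining)
    new_ops = deepcopy(ops_list)
    idx = next(i for i, op in enumerate(new_ops) if op["name"] == target_op)
    return new_ops[:idx] + list(rewrite.new_ops) + new_ops[idx + 1:]


def apply_preprocessing(ops_list, summarize_op):
    # Pattern 3: prepend a preprocessing step before the rest of the pipeline (doc summarization)
    return [summarize_op] + deepcopy(ops_list)
```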
|
||||
|
||||
## File Structure Summary
|
||||
|
||||
```
|
||||
docetl/reasoning_optimizer/
|
||||
├── directives/
|
||||
│ ├── my_directive.py # Your directive implementation
|
||||
│ ├── __init__.py # Register in ALL_DIRECTIVES
|
||||
│ └── README.md # This file
|
||||
└── instantiate_schemas.py # Define your schema classes
|
||||
|
||||
experiments/reasoning/ # Testing framework
|
||||
└── run_mcts.py # Full pipeline testing
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Schema First**: Design the instantiate schema before implementing the directive - it defines the interface
|
||||
2. **Clear Agent Instructions**: The `to_string_for_instantiate()` method should clearly explain what the agent needs to output
|
||||
3. **Validation**: Use Pydantic validators in your schema to catch invalid configurations (see the sketch after this list)
|
||||
4. **Error Handling**: Handle LLM failures gracefully with retry logic
|
||||
5. **Comprehensive Testing**: Test edge cases where the agent might output invalid configurations
|
||||
6. **Documentation**: Clearly document what your instantiate schema fields mean and how they're used
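
As an example of item 3, a validator can reject configurations the directive cannot use (a sketch with an illustrative model whitelist; adapt the constraint to your own directive):

```python
from pydantic import BaseModel, Field, field_validator

ALLOWED_MODELS = {"gpt-4o", "gpt-4o-mini"}  # illustrative whitelist


class ChangeModelConfig(BaseModel):
    model: str = Field(..., description="The model to switch the operation to")

    @field_validator("model")
    @classmethod
    def check_model(cls, v: str) -> str:
        if v not in ALLOWED_MODELS:
            raise ValueError(f"model must be one of {sorted(ALLOWED_MODELS)}")
        return v
```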
|
||||
|
||||
The instantiate schema is the critical bridge between the agent's reasoning and your directive's implementation - design it carefully!
|
||||
|
|
@ -1,192 +0,0 @@
|
|||
from typing import Dict, List
|
||||
from .base import (
|
||||
Directive,
|
||||
DirectiveTestCase,
|
||||
TestResult,
|
||||
MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS,
|
||||
DEFAULT_MODEL,
|
||||
DEFAULT_MAX_TPM,
|
||||
DEFAULT_OUTPUT_DIR
|
||||
)
|
||||
from .chaining import ChainingDirective
|
||||
from .gleaning import GleaningDirective
|
||||
from .change_model import ChangeModelDirective
|
||||
from .change_model_acc import ChangeModelAccDirective # noqa: F401
|
||||
from .change_model_cost import ChangeModelCostDirective
|
||||
from .doc_summarization import DocSummarizationDirective
|
||||
from .isolating_subtasks import IsolatingSubtasksDirective
|
||||
# from .doc_compression import DocCompressionDirective
|
||||
from .deterministic_doc_compression import DeterministicDocCompressionDirective
|
||||
from .reduce_gleaning import ReduceGleaningDirective
|
||||
from .reduce_chaining import ReduceChainingDirective
|
||||
from .operator_fusion import OperatorFusionDirective
|
||||
from .doc_chunking import DocumentChunkingDirective
|
||||
from .doc_chunking_topk import DocumentChunkingTopKDirective
|
||||
from .chunk_header_summary import ChunkHeaderSummaryDirective
|
||||
from .take_head_tail import TakeHeadTailDirective
|
||||
from .clarify_instructions import ClarifyInstructionsDirective
|
||||
from .swap_with_code import SwapWithCodeDirective
|
||||
from .map_reduce_fusion import MapReduceFusionDirective
|
||||
from .hierarchical_reduce import HierarchicalReduceDirective
|
||||
from .cascade_filtering import CascadeFilteringDirective
|
||||
from .arbitrary_rewrite import ArbitraryRewriteDirective
|
||||
from .map_to_map_resolve_reduce import MapToMapResolveReduceDirective
|
||||
from .map_resolve_to_map_with_categories import MapResolveToMapWithCategoriesDirective
|
||||
|
||||
# Registry of all available directives
|
||||
ALL_DIRECTIVES = [
|
||||
ChainingDirective(),
|
||||
GleaningDirective(),
|
||||
ReduceGleaningDirective(),
|
||||
ReduceChainingDirective(),
|
||||
ChangeModelCostDirective(),
|
||||
DocSummarizationDirective(),
|
||||
IsolatingSubtasksDirective(),
|
||||
# DocCompressionDirective(),
|
||||
DeterministicDocCompressionDirective(),
|
||||
OperatorFusionDirective(),
|
||||
DocumentChunkingDirective(),
|
||||
DocumentChunkingTopKDirective(),
|
||||
ChunkHeaderSummaryDirective(),
|
||||
TakeHeadTailDirective(),
|
||||
ClarifyInstructionsDirective(),
|
||||
SwapWithCodeDirective(),
|
||||
MapReduceFusionDirective(),
|
||||
HierarchicalReduceDirective(),
|
||||
CascadeFilteringDirective(),
|
||||
ArbitraryRewriteDirective(),
|
||||
MapToMapResolveReduceDirective(),
|
||||
MapResolveToMapWithCategoriesDirective(),
|
||||
]
|
||||
|
||||
ALL_COST_DIRECTIVES = [
|
||||
ReduceChainingDirective(),
|
||||
DocSummarizationDirective(),
|
||||
# DocCompressionDirective(),
|
||||
DeterministicDocCompressionDirective(),
|
||||
DocumentChunkingTopKDirective(),
|
||||
OperatorFusionDirective(),
|
||||
TakeHeadTailDirective(),
|
||||
MapReduceFusionDirective(),
|
||||
CascadeFilteringDirective(),
|
||||
ArbitraryRewriteDirective(),
|
||||
]
|
||||
|
||||
DIRECTIVE_GROUPS = {
|
||||
"compression": [
|
||||
# DocCompressionDirective(),
|
||||
DocSummarizationDirective(),
|
||||
DeterministicDocCompressionDirective(),
|
||||
],
|
||||
"chunking": [
|
||||
DocumentChunkingDirective(),
|
||||
DocumentChunkingTopKDirective(),
|
||||
]
|
||||
}
|
||||
|
||||
MULTI_INSTANCE_DIRECTIVES = [
|
||||
DocumentChunkingDirective(),
|
||||
DocumentChunkingTopKDirective(),
|
||||
DeterministicDocCompressionDirective(),
|
||||
TakeHeadTailDirective(),
|
||||
CascadeFilteringDirective(),
|
||||
ClarifyInstructionsDirective(),
|
||||
]
|
||||
|
||||
# Create a mapping from directive names to directive instances
|
||||
DIRECTIVE_REGISTRY = {directive.name: directive for directive in ALL_DIRECTIVES}
|
||||
|
||||
def get_all_directive_strings() -> str:
|
||||
"""
|
||||
Generate string descriptions for all available directives for use in prompts.
|
||||
|
||||
Returns:
|
||||
str: Formatted string containing all directive descriptions
|
||||
"""
|
||||
return "\n".join([directive.to_string_for_plan() for directive in ALL_DIRECTIVES])
|
||||
|
||||
|
||||
def get_all_cost_directive_strings() -> str:
|
||||
return "\n".join([directive.to_string_for_plan() for directive in ALL_COST_DIRECTIVES])
|
||||
|
||||
|
||||
def instantiate_directive(
|
||||
directive_name: str,
|
||||
operators: List[Dict],
|
||||
target_ops: List[str],
|
||||
agent_llm: str,
|
||||
message_history: list,
|
||||
global_default_model: str = None,
|
||||
dataset: str = None,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Centralized method to instantiate any directive by name.
|
||||
|
||||
Args:
|
||||
directive_name: Name of the directive to instantiate
|
||||
operators: List of pipeline operators
|
||||
target_ops: List of target operation names
|
||||
agent_llm: LLM model to use
|
||||
message_history: Conversation history
|
||||
**kwargs: Additional arguments to pass to directive
|
||||
|
||||
Returns:
|
||||
Tuple of (new_ops_list, updated_message_history)
|
||||
|
||||
Raises:
|
||||
ValueError: If directive_name is not recognized
|
||||
"""
|
||||
if message_history is None:
|
||||
message_history = []
|
||||
|
||||
if directive_name not in DIRECTIVE_REGISTRY:
|
||||
available = list(DIRECTIVE_REGISTRY.keys())
|
||||
raise ValueError(f"Unknown directive '{directive_name}'. Available: {available}")
|
||||
|
||||
directive = DIRECTIVE_REGISTRY[directive_name]
|
||||
return directive.instantiate(
|
||||
operators=operators,
|
||||
target_ops=target_ops,
|
||||
agent_llm=agent_llm,
|
||||
message_history=message_history,
|
||||
global_default_model=global_default_model,
|
||||
dataset=dataset,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Directive",
|
||||
"DirectiveTestCase",
|
||||
"TestResult",
|
||||
"MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS",
|
||||
"DEFAULT_MODEL",
|
||||
"DEFAULT_MAX_TPM",
|
||||
"DEFAULT_OUTPUT_DIR",
|
||||
"ChainingDirective",
|
||||
"GleaningDirective",
|
||||
"ReduceGleaningDirective",
|
||||
"ReduceChainingDirective",
|
||||
"ChangeModelDirective",
|
||||
"DocSummarizationDirective",
|
||||
"IsolatingSubtasksDirective",
|
||||
# "DocCompressionDirective",
|
||||
"DeterministicDocCompressionDirective",
|
||||
"OperatorFusionDirective",
|
||||
"DocumentChunkingDirective",
|
||||
"DocumentChunkingTopKDirective",
|
||||
"ChunkHeaderSummaryDirective",
|
||||
"TakeHeadTailDirective",
|
||||
"ClarifyInstructionsDirective",
|
||||
"SwapWithCodeDirective",
|
||||
"MapReduceFusionDirective",
|
||||
"HierarchicalReduceDirective",
|
||||
"CascadeFilteringDirective",
|
||||
"ArbitraryRewriteDirective",
|
||||
"MapToMapResolveReduceDirective",
|
||||
"MapResolveToMapWithCategoriesDirective",
|
||||
"ALL_DIRECTIVES",
|
||||
"DIRECTIVE_REGISTRY",
|
||||
"get_all_directive_strings",
|
||||
"instantiate_directive"
|
||||
]
|
||||
|
|
@ -1,458 +0,0 @@
|
|||
import json
|
||||
import os
|
||||
from typing import Dict, List, Literal, Optional
|
||||
|
||||
from litellm import completion, model_cost
|
||||
from pydantic import BaseModel, Field
|
||||
from rich import print as rprint
|
||||
|
||||
from docetl.operations.utils.llm import count_tokens
|
||||
|
||||
|
||||
class AgentDecision(BaseModel):
|
||||
"""Schema for agent decision-making in agentic loops."""
|
||||
|
||||
action: Literal["read_next_docs", "read_operator_doc", "output_schema"] = Field(
|
||||
..., description="The action the agent wants to take"
|
||||
)
|
||||
reasoning: str = Field(
|
||||
...,
|
||||
description="Explanation of why the agent chose this action and what they learned from current samples",
|
||||
)
|
||||
operator_name: Optional[str] = Field(
|
||||
None,
|
||||
description="For read_operator_doc action: the operator name to read documentation for (e.g., 'map', 'filter', 'reduce')",
|
||||
)
|
||||
|
||||
|
||||
class ReadNextDocTool:
|
||||
"""Tool for iteratively reading documents from input data."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_data: List[Dict],
|
||||
context_window: int = 32000,
|
||||
docs_per_iteration: int = 3,
|
||||
):
|
||||
self.input_data = input_data
|
||||
self.current_index = 0
|
||||
self.context_window = context_window
|
||||
self.total_docs = len(input_data)
|
||||
self.docs_per_iteration = docs_per_iteration
|
||||
|
||||
def read_next_doc(self) -> Optional[Dict]:
|
||||
"""Read the next document from the input data."""
|
||||
if self.current_index >= len(self.input_data):
|
||||
return None
|
||||
|
||||
doc = self.input_data[self.current_index]
|
||||
self.current_index += 1
|
||||
return doc
|
||||
|
||||
def read_next_docs(self, count: int = None) -> List[Dict]:
|
||||
"""Read the next N documents from the input data."""
|
||||
if count is None:
|
||||
count = self.docs_per_iteration
|
||||
docs = []
|
||||
for _ in range(count):
|
||||
if self.current_index >= len(self.input_data):
|
||||
break
|
||||
docs.append(self.input_data[self.current_index])
|
||||
self.current_index += 1
|
||||
return docs
|
||||
|
||||
def has_more_docs(self) -> bool:
|
||||
"""Check if there are more documents to read."""
|
||||
return self.current_index < len(self.input_data)
|
||||
|
||||
def get_remaining_count(self) -> int:
|
||||
"""Get the number of remaining documents."""
|
||||
return len(self.input_data) - self.current_index
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the iterator to the beginning."""
|
||||
self.current_index = 0
|
||||
|
||||
|
||||
def estimate_token_count(text: str, model: str = "gpt-4.1-mini") -> int:
|
||||
"""Use proper token counting instead of rough estimation."""
|
||||
return count_tokens(text, model)
|
||||
|
||||
|
||||
def truncate_message_content(messages: List[Dict], max_tokens: int) -> List[Dict]:
|
||||
"""
|
||||
Truncate message content to fit within token limits.
|
||||
Preserves system message and latest user message, truncates middle content.
|
||||
"""
|
||||
if not messages:
|
||||
return messages
|
||||
|
||||
# Calculate total token count
|
||||
total_tokens = sum(estimate_token_count(msg.get("content", "")) for msg in messages)
|
||||
|
||||
if total_tokens <= max_tokens:
|
||||
return messages
|
||||
|
||||
# Keep system message and latest user message
|
||||
truncated_messages = []
|
||||
if messages[0].get("role") == "system":
|
||||
truncated_messages.append(messages[0])
|
||||
remaining_messages = messages[1:]
|
||||
else:
|
||||
remaining_messages = messages
|
||||
|
||||
# Always keep the latest message
|
||||
if remaining_messages:
|
||||
truncated_messages.append(remaining_messages[-1])
|
||||
middle_messages = remaining_messages[:-1]
|
||||
else:
|
||||
middle_messages = []
|
||||
|
||||
# Calculate available tokens for middle messages
|
||||
system_tokens = (
|
||||
estimate_token_count(truncated_messages[0].get("content", ""))
|
||||
if truncated_messages
|
||||
else 0
|
||||
)
|
||||
latest_tokens = (
|
||||
estimate_token_count(truncated_messages[-1].get("content", ""))
|
||||
if len(truncated_messages) > 1
|
||||
else 0
|
||||
)
|
||||
available_tokens = (
|
||||
max_tokens - system_tokens - latest_tokens - 1000
|
||||
) # Buffer for response
|
||||
|
||||
# Add middle messages until we hit the limit
|
||||
current_tokens = 0
|
||||
for msg in reversed(middle_messages): # Add most recent first
|
||||
msg_tokens = estimate_token_count(msg.get("content", ""))
|
||||
if current_tokens + msg_tokens <= available_tokens:
|
||||
current_tokens += msg_tokens
|
||||
truncated_messages.insert(-1, msg) # Insert before the latest message
|
||||
else:
|
||||
break
|
||||
|
||||
return truncated_messages
|
||||
|
||||
|
||||
class AgenticDirectiveRunner:
|
||||
"""
|
||||
Utility class for running agentic directives that iteratively process documents.
|
||||
Manages context windows, document iteration, and decision-making loops.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_data: List[Dict],
|
||||
agent_llm: str = "gpt-4.1-mini",
|
||||
validation_func: Optional[callable] = None,
|
||||
enable_operator_docs: bool = False,
|
||||
):
|
||||
self.input_data = input_data
|
||||
self.agent_llm = agent_llm
|
||||
self.context_window = self._get_model_context_window(agent_llm)
|
||||
self.enable_operator_docs = enable_operator_docs
|
||||
# Double the max iterations if operator docs are enabled to allow more exploration
|
||||
self.docs_per_iteration = 3
|
||||
self.max_iterations = 6 if enable_operator_docs else 3
|
||||
self.doc_reader = ReadNextDocTool(
|
||||
input_data, self.context_window, self.docs_per_iteration
|
||||
)
|
||||
self.message_history = []
|
||||
self.validation_func = validation_func
|
||||
self.docs_path = os.path.join(
|
||||
os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
),
|
||||
"docs",
|
||||
"operators",
|
||||
)
|
||||
|
||||
def _get_model_context_window(self, model: str) -> int:
|
||||
"""Get the context window size for the given model."""
|
||||
model_cost_info = model_cost.get(model, {})
|
||||
if not model_cost_info:
|
||||
# Try stripping the first part before the /
|
||||
split_model = model.split("/")
|
||||
if len(split_model) > 1:
|
||||
model_cost_info = model_cost.get("/".join(split_model[1:]), {})
|
||||
|
||||
if not model_cost_info:
|
||||
model_cost_info = model_cost.get(model.split("/")[-1], {})
|
||||
|
||||
return model_cost_info.get("max_input_tokens", 32768)
|
||||
|
||||
def _read_operator_doc(self, operator_name: str) -> Optional[str]:
|
||||
"""
|
||||
Read the documentation for a specific operator.
|
||||
|
||||
Args:
|
||||
operator_name: Name of the operator (e.g., 'map', 'filter', 'reduce')
|
||||
|
||||
Returns:
|
||||
The markdown documentation content or None if not found
|
||||
"""
|
||||
doc_file = os.path.join(self.docs_path, f"{operator_name}.md")
|
||||
if os.path.exists(doc_file):
|
||||
try:
|
||||
with open(doc_file, "r") as f:
|
||||
content = f.read()
|
||||
return content
|
||||
except Exception as e:
|
||||
return f"Error reading documentation for {operator_name}: {str(e)}"
|
||||
else:
|
||||
# Try alternative names (e.g., 'parallel-map' for 'parallel_map')
|
||||
alt_name = operator_name.replace("_", "-")
|
||||
doc_file = os.path.join(self.docs_path, f"{alt_name}.md")
|
||||
if os.path.exists(doc_file):
|
||||
try:
|
||||
with open(doc_file, "r") as f:
|
||||
content = f.read()
|
||||
return content
|
||||
except Exception as e:
|
||||
return f"Error reading documentation for {operator_name}: {str(e)}"
|
||||
return f"Documentation not found for operator: {operator_name}"
|
||||
|
||||
def _truncate_doc_to_tokens(self, doc: Dict, max_tokens: int) -> str:
|
||||
"""
|
||||
Truncate document content to fit within the specified token limit.
|
||||
|
||||
Args:
|
||||
doc: The document dictionary to truncate
|
||||
max_tokens: Maximum number of tokens to allow
|
||||
|
||||
Returns:
|
||||
Truncated document string
|
||||
"""
|
||||
doc_str = json.dumps(doc, indent=2)
|
||||
doc_tokens = estimate_token_count(doc_str, self.agent_llm)
|
||||
|
||||
if doc_tokens <= max_tokens:
|
||||
return doc_str
|
||||
|
||||
# If document is too long, truncate it
|
||||
# Estimate characters per token (rough approximation)
|
||||
chars_per_token = len(doc_str) / doc_tokens
|
||||
target_chars = int(max_tokens * chars_per_token)
|
||||
|
||||
truncated = doc_str[:target_chars]
|
||||
return truncated + "... [truncated]"
|
||||
|
||||
def run_agentic_loop(
|
||||
self, system_prompt: str, initial_user_message: str, response_schema: BaseModel
|
||||
):
|
||||
"""
|
||||
Run an agentic loop where the agent analyzes input data for directive instantiation.
|
||||
|
||||
Args:
|
||||
system_prompt: System message for the agent
|
||||
initial_user_message: Initial user message with task description
|
||||
response_schema: Pydantic schema for the expected response
|
||||
|
||||
Returns:
|
||||
Tuple of (parsed_response, message_history)
|
||||
"""
|
||||
call_cost = 0.0
|
||||
|
||||
# Initialize message history
|
||||
self.message_history = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": initial_user_message},
|
||||
]
|
||||
|
||||
max_iterations = min(
|
||||
self.max_iterations, len(self.input_data)
|
||||
) # Conservative limit for analysis
|
||||
|
||||
rprint(
|
||||
f"[blue]🤖 Determining rewrite instantiation with {len(self.input_data)} documents available[/blue]"
|
||||
)
|
||||
|
||||
for iteration in range(max_iterations):
|
||||
# Calculate remaining context
|
||||
current_tokens = sum(
|
||||
estimate_token_count(msg.get("content", ""), self.agent_llm)
|
||||
for msg in self.message_history
|
||||
)
|
||||
remaining_tokens = self.context_window - current_tokens - 2000 # Buffer
|
||||
|
||||
# Add context info for analysis
|
||||
context_info = f"""
|
||||
Analysis Progress:
|
||||
- Remaining context window: ~{remaining_tokens} tokens
|
||||
- Documents analyzed: {self.doc_reader.current_index}/{self.doc_reader.total_docs}
|
||||
- Documents remaining: {self.doc_reader.get_remaining_count()}
|
||||
|
||||
Analyze the input samples to understand patterns, edge cases, and data characteristics that will help you complete your task effectively.
|
||||
"""
|
||||
|
||||
# Create action guidance
|
||||
if self.enable_operator_docs:
|
||||
action_guidance = f"""
|
||||
Choose your next action:
|
||||
- read_next_docs: If you need more examples to understand data patterns, edge cases, or to gather more information for your analysis (reads ~{self.docs_per_iteration} documents at once)
|
||||
- read_operator_doc: If you need to understand how a specific operator works, its parameters, or see examples (specify operator_name like 'map', 'filter', 'reduce')
|
||||
- output_schema: If you have sufficient examples to complete your task based on the patterns and insights you've gathered from the data
|
||||
|
||||
Focus on quality over quantity - a few diverse, informative examples are better than many similar ones.
|
||||
"""
|
||||
else:
|
||||
action_guidance = f"""
|
||||
Choose your next action:
|
||||
- read_next_docs: If you need more examples to understand data patterns, edge cases, or to gather more information for your analysis (reads ~{self.docs_per_iteration} documents at once)
|
||||
- output_schema: If you have sufficient examples to complete your task based on the patterns and insights you've gathered from the data
|
||||
|
||||
Focus on quality over quantity - a few diverse, informative examples are better than many similar ones.
|
||||
"""
|
||||
|
||||
# Update the latest user message with context info
|
||||
if self.message_history[-1]["role"] == "user":
|
||||
self.message_history[-1]["content"] += context_info + action_guidance
|
||||
|
||||
# Truncate messages if needed
|
||||
self.message_history = truncate_message_content(
|
||||
self.message_history, self.context_window - 2000
|
||||
)
|
||||
|
||||
rprint(
|
||||
f"[yellow]🧠 Iteration {iteration + 1}/{max_iterations}: Asking {self.agent_llm} agent to decide next action (tokens: {remaining_tokens} remaining)[/yellow]"
|
||||
)
|
||||
|
||||
# Get structured agent decision
|
||||
response = completion(
|
||||
model=self.agent_llm,
|
||||
messages=self.message_history,
|
||||
response_format=AgentDecision,
|
||||
)
|
||||
call_cost += response._hidden_params["response_cost"]
|
||||
|
||||
try:
|
||||
decision_json = json.loads(response.choices[0].message.content)
|
||||
decision = AgentDecision(**decision_json)
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to parse agent decision: {str(e)}")
|
||||
|
||||
self.message_history.append(
|
||||
{"role": "assistant", "content": response.choices[0].message.content}
|
||||
)
|
||||
|
||||
# Handle agent's decision
|
||||
if decision.action == "read_operator_doc":
|
||||
# Agent wants to read operator documentation
|
||||
if not self.enable_operator_docs:
|
||||
user_message = "Operator documentation reading is not enabled for this directive."
|
||||
self.message_history.append(
|
||||
{"role": "user", "content": user_message}
|
||||
)
|
||||
elif not decision.operator_name:
|
||||
user_message = "Please specify which operator documentation you want to read (e.g., 'map', 'filter', 'reduce')."
|
||||
self.message_history.append(
|
||||
{"role": "user", "content": user_message}
|
||||
)
|
||||
else:
|
||||
rprint(
|
||||
f"[blue]📖 Agent reading documentation for operator: {decision.operator_name}[/blue]"
|
||||
)
|
||||
doc_content = self._read_operator_doc(decision.operator_name)
|
||||
user_message = f"Documentation for '{decision.operator_name}' operator:\n\n{doc_content}\n\nAnalyze this documentation to understand how to use this operator effectively."
|
||||
self.message_history.append(
|
||||
{"role": "user", "content": user_message}
|
||||
)
|
||||
|
||||
elif decision.action == "read_next_docs":
|
||||
# Agent wants to analyze more data
|
||||
next_docs = self.doc_reader.read_next_docs()
|
||||
if not next_docs:
|
||||
# No more documents - force output
|
||||
rprint(
|
||||
"[red]📄 No more documents available. Proceeding with schema generation.[/red]"
|
||||
)
|
||||
user_message = "No more documents available. Based on the samples you've analyzed, please complete your task."
|
||||
self.message_history.append(
|
||||
{"role": "user", "content": user_message}
|
||||
)
|
||||
break
|
||||
else:
|
||||
rprint(
|
||||
f"[green]📄 Agent reading {len(next_docs)} documents (up to {self.doc_reader.current_index}/{len(self.input_data)})[/green]"
|
||||
)
|
||||
docs_content = "\n".join(
|
||||
[
|
||||
f"Sample {self.doc_reader.current_index - len(next_docs) + i + 1}:\n{self._truncate_doc_to_tokens(doc, 1000)}"
|
||||
for i, doc in enumerate(next_docs)
|
||||
]
|
||||
)
|
||||
user_message = f"{docs_content}\n\nAnalyze these samples for patterns, edge cases, and characteristics that will help with your task."
|
||||
self.message_history.append(
|
||||
{"role": "user", "content": user_message}
|
||||
)
|
||||
|
||||
elif decision.action == "output_schema":
|
||||
# Agent is ready to create improved prompt
|
||||
rprint(
|
||||
f"[cyan]✨ Agent ready to generate final schema after analyzing {self.doc_reader.current_index} documents[/cyan]"
|
||||
)
|
||||
schema_prompt = f"""Based on your analysis of the input samples, complete your task using the patterns and insights you've gathered from the data.
|
||||
|
||||
Provide your response as a JSON object matching this schema: {response_schema.model_json_schema()}"""
|
||||
self.message_history.append({"role": "user", "content": schema_prompt})
|
||||
break
|
||||
|
||||
# Get the final schema response with validation and retries
|
||||
rprint("[magenta]🔧 Generating final rewrite schema...[/magenta]")
|
||||
|
||||
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS
|
||||
|
||||
error_message = ""
|
||||
|
||||
for attempt in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
schema_response = completion(
|
||||
model=self.agent_llm,
|
||||
messages=self.message_history,
|
||||
response_format=response_schema,
|
||||
)
|
||||
call_cost += schema_response._hidden_params["response_cost"]
|
||||
|
||||
try:
|
||||
parsed_response = json.loads(schema_response.choices[0].message.content)
|
||||
schema_instance = response_schema(**parsed_response)
|
||||
|
||||
# Add any additional validation if provided
|
||||
if self.validation_func:
|
||||
self.validation_func(schema_instance)
|
||||
|
||||
rprint(
|
||||
f"[green]✅ Schema validation passed on attempt {attempt + 1}[/green]"
|
||||
)
|
||||
self.message_history.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": schema_response.choices[0].message.content,
|
||||
}
|
||||
)
|
||||
return schema_instance, self.message_history, call_cost
|
||||
|
||||
except Exception as err:
|
||||
error_message = f"Validation error: {err}\nPlease try again with a corrected response."
|
||||
rprint(
|
||||
f"[red]❌ Schema validation failed on attempt {attempt + 1}: {str(err)}[/red]"
|
||||
)
|
||||
|
||||
if attempt < MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS - 1:
|
||||
rprint(
|
||||
f"[yellow]🔄 Retrying schema generation (attempt {attempt + 2}/{MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS})[/yellow]"
|
||||
)
|
||||
self.message_history.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": schema_response.choices[0].message.content,
|
||||
}
|
||||
)
|
||||
self.message_history.append(
|
||||
{"role": "user", "content": error_message}
|
||||
)
|
||||
|
||||
raise Exception(
|
||||
f"Failed to generate valid schema after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Error: {error_message}"
|
||||
)
|
||||
|
|
@ -1,361 +0,0 @@
|
|||
import json
|
||||
from typing import Dict, List, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from docetl.reasoning_optimizer.instantiate_schemas import (
|
||||
ArbitraryRewriteInstantiateSchema,
|
||||
)
|
||||
from docetl.reasoning_optimizer.op_descriptions import get_all_operator_descriptions
|
||||
|
||||
from .agent_utils import AgenticDirectiveRunner
|
||||
from .base import Directive, DirectiveTestCase
|
||||
|
||||
|
||||
class ArbitraryRewriteDirective(Directive):
|
||||
name: str = Field(
|
||||
default="arbitrary_rewrite", description="The name of the directive"
|
||||
)
|
||||
formal_description: str = Field(
|
||||
default="Pipeline => Modified Pipeline (search/replace edits)"
|
||||
)
|
||||
nl_description: str = Field(
|
||||
default="Allows the agent to make arbitrary edits to the pipeline using search/replace operations on the JSON representation. The agent can add, modify, remove, or replace operations to optimize for cost, accuracy, or both. This is a catch-all directive for optimizations that don't fit into other specific directive patterns."
|
||||
)
|
||||
when_to_use: str = Field(
|
||||
default="When you identify obvious optimizations that can make the pipeline cheaper or more accurate, but they don't fit into existing directive patterns. Use this for complex multi-operation changes, reordering operations, consolidating redundant operations, or making systematic improvements across the pipeline."
|
||||
)
|
||||
|
||||
instantiate_schema_type: Type[BaseModel] = Field(
|
||||
default=ArbitraryRewriteInstantiateSchema
|
||||
)
|
||||
|
||||
example: str = Field(
|
||||
default="""
|
||||
# Example 1: Consolidating Redundant Operations
|
||||
Original Pipeline JSON (formatted for readability):
|
||||
[
|
||||
{
|
||||
"name": "extract_names",
|
||||
"type": "map",
|
||||
"model": "gpt-4o",
|
||||
"prompt": "Extract all person names from: {{ input.text }}",
|
||||
"output": {"schema": {"names": "list[string]"}}
|
||||
},
|
||||
{
|
||||
"name": "extract_locations",
|
||||
"type": "map",
|
||||
"model": "gpt-4o",
|
||||
"prompt": "Extract all locations from: {{ input.text }}",
|
||||
"output": {"schema": {"locations": "list[string]"}}
|
||||
},
|
||||
{
|
||||
"name": "extract_dates",
|
||||
"type": "map",
|
||||
"model": "gpt-4o",
|
||||
"prompt": "Extract all dates from: {{ input.text }}",
|
||||
"output": {"schema": {"dates": "list[string]"}}
|
||||
}
|
||||
]
|
||||
|
||||
Example SearchReplaceEdits (agent recognizes redundant LLM calls):
|
||||
search_replace_edits=[
|
||||
SearchReplaceEdit(
|
||||
search=' {\\n "name": "extract_names",\\n "type": "map",\\n "model": "gpt-4o",\\n "prompt": "Extract all person names from: {{ input.text }}",\\n "output": {"schema": {"names": "list[string]"}}\\n },\\n {\\n "name": "extract_locations",\\n "type": "map",\\n "model": "gpt-4o",\\n "prompt": "Extract all locations from: {{ input.text }}",\\n "output": {"schema": {"locations": "list[string]"}}\\n },\\n {\\n "name": "extract_dates",\\n "type": "map",\\n "model": "gpt-4o",\\n "prompt": "Extract all dates from: {{ input.text }}",\\n "output": {"schema": {"dates": "list[string]"}}\\n }',
|
||||
replace=' {\\n "name": "extract_all_entities",\\n "type": "map",\\n "model": "gpt-4o-mini",\\n "prompt": "Extract all entities from the text:\\\\n{{ input.text }}\\\\n\\\\nReturn names, locations, and dates.",\\n "output": {\\n "schema": {\\n "names": "list[string]",\\n "locations": "list[string]",\\n "dates": "list[string]"\\n }\\n }\\n }',
|
||||
reasoning="Consolidate three separate GPT-4o calls into one GPT-4o-mini call"
|
||||
)
|
||||
]
|
||||
|
||||
# Example 2: Reordering for Efficiency
|
||||
Original Pipeline JSON:
|
||||
[
|
||||
{
|
||||
"name": "summarize_documents",
|
||||
"type": "map",
|
||||
"model": "gpt-4o",
|
||||
"prompt": "Summarize this document: {{ input.full_text }}",
|
||||
"output": {"schema": {"summary": "string"}}
|
||||
},
|
||||
{
|
||||
"name": "filter_relevant",
|
||||
"type": "filter",
|
||||
"model": "gpt-4o-mini",
|
||||
"prompt": "Is this document about technology? {{ input.full_text }}",
|
||||
"output": {"schema": {"is_relevant": "boolean"}}
|
||||
}
|
||||
]
|
||||
|
||||
Example SearchReplaceEdits (agent recognizes inefficient ordering):
|
||||
search_replace_edits=[
|
||||
SearchReplaceEdit(
|
||||
search='{\\n "name": "summarize_documents",\\n "type": "map",\\n "model": "gpt-4o",\\n "prompt": "Summarize this document: {{ input.full_text }}",\\n "output": {"schema": {"summary": "string"}}\\n },\\n {\\n "name": "filter_relevant",',
|
||||
replace='{\\n "name": "filter_relevant",',
|
||||
reasoning="Remove summarize operation from before filter"
|
||||
),
|
||||
SearchReplaceEdit(
|
||||
search='"output": {"schema": {"is_relevant": "boolean"}}\\n }',
|
||||
replace='"output": {"schema": {"is_relevant": "boolean"}}\\n },\\n {\\n "name": "summarize_documents",\\n "type": "map",\\n "model": "gpt-4o",\\n "prompt": "Summarize this technology document: {{ input.full_text }}",\\n "output": {"schema": {"summary": "string"}}\\n }',
|
||||
reasoning="Add summarization after filter to only process relevant documents"
|
||||
)
|
||||
]
|
||||
|
||||
Note: The agent must provide exact string matches including whitespace for the search strings.
|
||||
"""
|
||||
)
|
||||
|
||||
test_cases: List[DirectiveTestCase] = Field(
|
||||
default_factory=lambda: [
|
||||
DirectiveTestCase(
|
||||
name="consolidate_operations",
|
||||
description="Should identify and consolidate redundant operations",
|
||||
input_config=[
|
||||
{
|
||||
"name": "extract_a",
|
||||
"type": "map",
|
||||
"model": "gpt-4o",
|
||||
"prompt": "Extract A from: {{ input.text }}",
|
||||
"output": {"schema": {"a": "string"}},
|
||||
},
|
||||
{
|
||||
"name": "extract_b",
|
||||
"type": "map",
|
||||
"model": "gpt-4o",
|
||||
"prompt": "Extract B from: {{ input.text }}",
|
||||
"output": {"schema": {"b": "string"}},
|
||||
},
|
||||
],
|
||||
target_ops=[], # Analyzes entire pipeline
|
||||
expected_behavior="Should propose consolidating the two extraction operations into one",
|
||||
should_pass=True,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, ArbitraryRewriteDirective)
|
||||
|
||||
def __hash__(self):
|
||||
return hash("ArbitraryRewriteDirective")
|
||||
|
||||
def _order_ops_by_pipeline_steps(
|
||||
self, ops_list: List[Dict], pipeline_code: Dict = None
|
||||
) -> List[Dict]:
|
||||
"""Order operations according to pipeline steps if pipeline_code is provided."""
|
||||
if not (
|
||||
pipeline_code
|
||||
and "pipeline" in pipeline_code
|
||||
and "steps" in pipeline_code["pipeline"]
|
||||
):
|
||||
return ops_list
|
||||
|
||||
# Get the order from pipeline steps
|
||||
steps = pipeline_code["pipeline"]["steps"]
|
||||
step_order = []
|
||||
for step in steps:
|
||||
if isinstance(step, dict) and "operations" in step:
|
||||
step_order.extend(step["operations"])
|
||||
elif isinstance(step, str):
|
||||
step_order.append(step)
|
||||
|
||||
# Create a mapping of op names to ops
|
||||
ops_by_name = {op["name"]: op for op in ops_list}
|
||||
|
||||
# Reorder ops according to pipeline steps
|
||||
ordered_ops = []
|
||||
for op_name in step_order:
|
||||
if op_name in ops_by_name:
|
||||
ordered_ops.append(ops_by_name[op_name])
|
||||
|
||||
# Add any ops that weren't in steps (shouldn't happen, but just in case)
|
||||
for op in ops_list:
|
||||
if op not in ordered_ops:
|
||||
ordered_ops.append(op)
|
||||
|
||||
return ordered_ops
|
||||
|
||||
def to_string_for_instantiate(
|
||||
self, pipeline_ops: List[Dict], pipeline_code: Dict = None
|
||||
) -> str:
|
||||
"""Generate prompt for agent to analyze pipeline and propose edits."""
|
||||
|
||||
# Convert ops to pretty JSON string for display (already ordered by instantiate)
|
||||
pipeline_json = json.dumps(pipeline_ops, indent=4)
|
||||
|
||||
return (
|
||||
f"You are an expert at optimizing data processing pipelines for cost and accuracy.\n\n"
|
||||
f"Current Pipeline Operations (as JSON array):\n"
|
||||
f"{pipeline_json}\n\n"
|
||||
f"Your task is to analyze this pipeline and propose search/replace edits to optimize it.\n\n"
|
||||
f"IMPORTANT: The above JSON is ONLY the operations array. When creating search/replace edits, "
|
||||
f"work with this exact JSON structure. Do not include pipeline wrapper or other fields.\n\n"
|
||||
f"You have access to:\n"
|
||||
f"1. Sample input data through read_next_docs() - to understand data patterns and flow\n"
|
||||
f"2. Operator documentation through read_operator_doc(operator_name) - to learn about available operators\n\n"
|
||||
f"IMPORTANT: Read documentation for no more than 2 operators before analyzing sample data to get a sense for how best to rewrite the pipeline.\n\n"
|
||||
f"Use these tools to:\n"
|
||||
f"1. Understand the data flow through the pipeline\n"
|
||||
f"2. Identify inefficiencies or redundancies\n"
|
||||
f"3. Find opportunities for consolidation or reordering\n"
|
||||
f"4. Determine if cheaper models could be used\n"
|
||||
f"5. Discover new operators that might be more efficient\n\n"
|
||||
f"IMPORTANT: You will provide search/replace edits that work on the JSON string representation.\n"
|
||||
f"Each edit consists of:\n"
|
||||
f"- search: An exact string to find in the JSON (including whitespace)\n"
|
||||
f"- replace: The string to replace it with (can be empty to delete)\n"
|
||||
f"- reasoning: Why this edit improves the pipeline\n\n"
|
||||
f"The edits will be applied sequentially to the JSON string, then parsed back to operations.\n\n"
|
||||
f"Guidelines for search/replace:\n"
|
||||
f"- The search string must match EXACTLY including all whitespace, quotes, brackets, etc.\n"
|
||||
f"- You can delete operations by replacing them with empty string (but be careful with commas)\n"
|
||||
f"- You can add operations by replacing closing brackets with new operations\n"
|
||||
f"- You can reorder by using multiple search/replace operations\n"
|
||||
f"- Each edit operates on the result of the previous edit\n\n"
|
||||
f"Types of optimizations to look for:\n"
|
||||
f"- Redundant operations that could be consolidated\n"
|
||||
f"- Operations in suboptimal order (e.g., expensive operations before filters)\n"
|
||||
f"- Opportunities to use cheaper models\n"
|
||||
f"- Complex operations that could be broken into simpler steps\n"
|
||||
f"- Independent operations that could be parallelized\n\n"
|
||||
f"Examples:\n"
|
||||
f"{self.example}\n\n"
|
||||
f"Analyze the pipeline and sample data strategically. When you have identified optimizations, "
|
||||
f"output your proposed edits as an ArbitraryRewriteInstantiateSchema.\n\n"
|
||||
f"Remember: Your search strings must match the JSON exactly as it appears above."
|
||||
)
|
||||
|
||||
    def llm_instantiate(
        self,
        pipeline_ops: List[Dict],
        input_file_path: str,
        agent_llm: str,
        message_history: list = [],
        pipeline_code: Dict = None,
    ):
        """Use agentic approach to analyze pipeline and generate edits."""
        # Load sample input data
        try:
            with open(input_file_path, "r") as f:
                input_data = json.load(f)

            if not isinstance(input_data, list) or len(input_data) == 0:
                raise ValueError(
                    "Input file must contain a non-empty list of sample data"
                )

        except Exception as e:
            raise Exception(
                f"Failed to load input data from {input_file_path}: {str(e)}"
            )

        # Set up agentic runner with operator doc reading enabled
        runner = AgenticDirectiveRunner(
            input_data=input_data,
            agent_llm=agent_llm,
            enable_operator_docs=True,  # Enable reading operator documentation
        )

        # Create system prompt with operator descriptions
        operator_descriptions = get_all_operator_descriptions()
        system_prompt = (
            "You are an expert at optimizing data processing pipelines. "
            "You analyze pipeline structures and data flow to identify opportunities "
            "for cost reduction and accuracy improvement through strategic search/replace edits on the JSON representation.\n\n"
            f"{operator_descriptions}"
        )

        # Create initial user message
        initial_message = self.to_string_for_instantiate(pipeline_ops, pipeline_code)

        # Run the agentic loop
        try:
            schema, updated_message_history, call_cost = runner.run_agentic_loop(
                system_prompt=system_prompt,
                initial_user_message=initial_message,
                response_schema=ArbitraryRewriteInstantiateSchema,
            )

            # Update message history
            message_history.extend(updated_message_history)

            return schema, message_history, call_cost

        except Exception as e:
            raise Exception(
                f"Failed to instantiate arbitrary_rewrite directive: {str(e)}"
            )

    def apply(
        self,
        ops_list: List[Dict],
        rewrite: ArbitraryRewriteInstantiateSchema,
    ) -> List[Dict]:
        """Apply the search/replace edits to the pipeline."""
        # Convert operations list to JSON string with consistent formatting
        pipeline_json = json.dumps(ops_list, indent=4)

        # Apply each search/replace edit in sequence
        for i, edit in enumerate(rewrite.search_replace_edits):
            if edit.search in pipeline_json:
                pipeline_json = pipeline_json.replace(
                    edit.search, edit.replace, 1
                )  # Replace first occurrence only
            else:
                # Log warning but continue with other edits
                print(
                    f"Warning: Search string not found in edit {i+1}: {edit.search[:50]}..."
                )

        # Parse the modified JSON back to operations list
        try:
            new_ops_list = json.loads(pipeline_json)
            if not isinstance(new_ops_list, list):
                raise ValueError("Pipeline must be a list of operations")

            # Get rid of any empty operations in new_ops_list
            new_ops_list = [op for op in new_ops_list if op]
            return new_ops_list
        except json.JSONDecodeError as e:
            raise ValueError(
                f"Failed to parse pipeline after edits. The search/replace operations resulted in invalid JSON: {e}\n"
                f"Modified JSON:\n{pipeline_json[:500]}..."
            )

    def instantiate(
        self,
        operators: List[Dict],
        target_ops: List[str] = None,
        agent_llm: str = "gpt-4o-mini",
        message_history: list = [],
        **kwargs,
    ):
        """
        Main method that orchestrates directive instantiation.
        For ArbitraryRewrite, we analyze the entire pipeline rather than specific target ops.
        """
        input_file_path = kwargs.get("input_file_path", None)
        pipeline_code = kwargs.get("pipeline_code", None)

        if not input_file_path:
            raise ValueError(
                "input_file_path is required for ArbitraryRewrite directive"
            )

        # Order ops according to pipeline steps before everything else
        ordered_ops = self._order_ops_by_pipeline_steps(operators, pipeline_code)

        # Step 1: Agent analyzes pipeline and generates edits
        rewrite, message_history, call_cost = self.llm_instantiate(
            ordered_ops,
            input_file_path,
            agent_llm,
            message_history,
            pipeline_code,
        )

        # Step 2: Apply the edits
        return (
            self.apply(ordered_ops, rewrite),
            message_history,
            call_cost,
        )
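
# Illustrative sketch (not part of the original module): how sequential
# search/replace edits rewrite the JSON form of an ops list, mirroring what
# apply() does above. The ops and edit dicts below are hypothetical stand-ins
# for a real pipeline and ArbitraryRewriteInstantiateSchema edits.
if __name__ == "__main__":
    demo_ops = [
        {"name": "extract_topics", "type": "map", "model": "gpt-4o"},
        {"name": "filter_relevant", "type": "filter", "model": "gpt-4o"},
    ]
    demo_edits = [
        # Downgrade the extraction step to a cheaper model
        {"search": '"model": "gpt-4o"', "replace": '"model": "gpt-4o-mini"'},
    ]
    demo_json = json.dumps(demo_ops, indent=4)
    for edit in demo_edits:
        # Each edit operates on the result of the previous one (first match only)
        demo_json = demo_json.replace(edit["search"], edit["replace"], 1)
    print(json.loads(demo_json))  # only the first op's model is changed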
@ -1,231 +0,0 @@
import json
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

from litellm import completion
from pydantic import BaseModel, Field

# Configuration constants
MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS = 3
DEFAULT_MODEL = "gpt-4.1"
DEFAULT_MAX_TPM = 5000000
DEFAULT_OUTPUT_DIR = "./outputs"


class TestResult(BaseModel):
    test_name: str
    passed: bool
    reason: str
    actual_output: Any
    execution_error: Optional[str] = None


class DirectiveTestCase(BaseModel):
    name: str
    description: str
    input_config: Dict[str, Any] | List[Dict[str, Any]]
    target_ops: List[str]
    expected_behavior: str
    should_pass: bool = True


class Directive(BaseModel, ABC):
    name: str = Field(..., description="The name of the directive")
    formal_description: str = Field(
        ..., description="A description of the directive; e.g., map => map -> map"
    )
    nl_description: str = Field(
        ..., description="An English description of the directive"
    )
    when_to_use: str = Field(..., description="When to use the directive")
    instantiate_schema_type: BaseModel = Field(
        ...,
        description="The schema the agent must conform to when instantiating the directive",
    )
    example: str = Field(..., description="An example of the directive being used")
    test_cases: List[DirectiveTestCase] = Field(
        default_factory=list, description="Test cases for this directive"
    )

    def to_string_for_plan(self) -> str:
        """Serialize directive for prompts."""
        parts = [
            f"### {self.name}",
            f"**Format:** {self.formal_description}",
            f"**Description:** {self.nl_description}",
            f"**When to Use:** {self.when_to_use}",
        ]
        return "\n\n".join(parts)

    @abstractmethod
    def to_string_for_instantiate(self, *args, **kwargs) -> str:
        pass

    @abstractmethod
    def llm_instantiate(self, *args, **kwargs):
        pass

    @abstractmethod
    def apply(self, *args, **kwargs) -> list:
        pass

    @abstractmethod
    def instantiate(self, *args, **kwargs):
        pass

def run_tests(self, agent_llm: str = "gpt-4o-mini") -> List[TestResult]:
|
||||
"""Run all test cases for this directive using LLM judge"""
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
results = []
|
||||
|
||||
for test_case in self.test_cases:
|
||||
try:
|
||||
# Create fake sample data and pipeline for directives that need them
|
||||
sample_data = [
|
||||
{
|
||||
"text": "Sample document 1",
|
||||
"feedback": "Great product!",
|
||||
"doc1": "Document A",
|
||||
"doc2": "Document B",
|
||||
},
|
||||
{
|
||||
"text": "Sample document 2",
|
||||
"feedback": "Could be better",
|
||||
"doc1": "Report 1",
|
||||
"doc2": "Report 2",
|
||||
},
|
||||
{
|
||||
"text": "Sample document 3",
|
||||
"feedback": "Excellent service",
|
||||
"doc1": "Policy A",
|
||||
"doc2": "Policy B",
|
||||
},
|
||||
]
|
||||
|
||||
fake_pipeline = {
|
||||
"operations": (
|
||||
[test_case.input_config]
|
||||
if isinstance(test_case.input_config, dict)
|
||||
else test_case.input_config
|
||||
),
|
||||
"name": "test_pipeline",
|
||||
}
|
||||
|
||||
# Create temporary input file
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".json", delete=False
|
||||
) as f:
|
||||
json.dump(sample_data, f, indent=2)
|
||||
temp_file_path = f.name
|
||||
|
||||
try:
|
||||
# 1. Execute the directive
|
||||
# instantiate returns (new_ops_plan, message_history, call_cost)
|
||||
instantiate_result = self.instantiate(
|
||||
operators=(
|
||||
[test_case.input_config]
|
||||
if isinstance(test_case.input_config, dict)
|
||||
else test_case.input_config
|
||||
),
|
||||
target_ops=test_case.target_ops,
|
||||
agent_llm=agent_llm,
|
||||
input_file_path=temp_file_path,
|
||||
pipeline_code=fake_pipeline,
|
||||
)
|
||||
# Handle both 2-tuple and 3-tuple returns
|
||||
if isinstance(instantiate_result, tuple):
|
||||
actual_output = instantiate_result[0]
|
||||
else:
|
||||
actual_output = instantiate_result
|
||||
|
||||
# 2. Use LLM judge to evaluate
|
||||
judge_result = self._llm_judge_test(
|
||||
test_case=test_case,
|
||||
actual_output=actual_output,
|
||||
agent_llm=agent_llm,
|
||||
)
|
||||
|
||||
results.append(
|
||||
TestResult(
|
||||
test_name=test_case.name,
|
||||
passed=judge_result["passed"],
|
||||
reason=judge_result["reason"],
|
||||
actual_output=actual_output,
|
||||
)
|
||||
)
|
||||
finally:
|
||||
# Clean up temporary file
|
||||
try:
|
||||
os.unlink(temp_file_path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
# Handle execution errors
|
||||
expected_to_fail = not test_case.should_pass
|
||||
results.append(
|
||||
TestResult(
|
||||
test_name=test_case.name,
|
||||
passed=expected_to_fail,
|
||||
reason=f"Execution {'failed as expected' if expected_to_fail else 'failed unexpectedly'}: {str(e)}",
|
||||
actual_output=None,
|
||||
execution_error=str(e),
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
    def _llm_judge_test(
        self, test_case: DirectiveTestCase, actual_output: Any, agent_llm: str
    ) -> Dict[str, Any]:
        """Use LLM as judge to evaluate test results"""

        user_message = f"""
        Evaluate whether the directive execution meets the expected behavior.

        **Test Case:** {test_case.name}
        **Description:** {test_case.description}
        **Expected Behavior:** {test_case.expected_behavior}
        **Should Pass:** {test_case.should_pass}

        **Original Input Configuration:**
        {test_case.input_config}

        **Actual Output from Directive:**
        {actual_output}

        **Evaluation Criteria:**
        - If should_pass=True: Does the output demonstrate the expected behavior?
        - If should_pass=False: Did the directive appropriately reject/not modify the input?
        - Does the transformation make logical sense?
        - Are all required elements preserved?

        Provide your reasoning and a boolean pass/fail decision.
        """

        messages = [
            {
                "role": "system",
                "content": f"You are an expert judge evaluating whether a {self.name} directive execution meets the specified criteria. Focus on logical correctness and adherence to expected behavior.",
            },
            {"role": "user", "content": user_message},
        ]

        class JudgeResponse(BaseModel):
            passed: bool
            reason: str

        response = completion(
            model=agent_llm,
            messages=messages,
            response_format=JudgeResponse,
        )

        # Parse the JSON response
        parsed_content = json.loads(response.choices[0].message.content)

        return {"passed": parsed_content["passed"], "reason": parsed_content["reason"]}
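
# Illustrative sketch (not part of the original module): how the TestResult
# objects returned by Directive.run_tests() might be summarized. The results
# below are fabricated placeholders, not output from a real run.
if __name__ == "__main__":
    demo_results = [
        TestResult(
            test_name="single_filter_cascade",
            passed=True,
            reason="Pre-filters injected before expensive filter",
            actual_output=[],
        ),
        TestResult(
            test_name="bad_config",
            passed=False,
            reason="Schema validation failed",
            actual_output=None,
            execution_error="ValueError",
        ),
    ]
    for r in demo_results:
        status = "PASS" if r.passed else "FAIL"
        print(f"[{status}] {r.test_name}: {r.reason}")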
@ -1,443 +0,0 @@
|
|||
import json
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from docetl.reasoning_optimizer.instantiate_schemas import (
|
||||
CascadeFilteringInstantiateSchema,
|
||||
)
|
||||
|
||||
from .agent_utils import AgenticDirectiveRunner
|
||||
from .base import Directive, DirectiveTestCase
|
||||
|
||||
|
||||
class CascadeFilteringDirective(Directive):
|
||||
name: str = Field(
|
||||
default="cascade_filtering", description="The name of the directive"
|
||||
)
|
||||
formal_description: str = Field(
|
||||
default="Filter => (Code Filter* -> Filter(gpt-5-nano)* ->) Filter"
|
||||
)
|
||||
nl_description: str = Field(
|
||||
default="Optimizes filtering costs by injecting a cascade of cheaper filters before the main filter. Starts with deterministic code filters (cheapest), then gpt-5-nano filters (ordered by prompt length), before the original expensive filter. Pre-filters prioritize high recall (rarely rejecting valid documents) and can have lower precision (okay to let through some invalid docs that the main filter will catch)."
|
||||
)
|
||||
when_to_use: str = Field(
|
||||
default="When you have an expensive Filter operation (using costly models or complex prompts) and the data contains patterns that allow for cheaper pre-filtering. The pre-filters must have high recall (not dropping valid documents) but can have lower precision, as the final filter provides the actual precision."
|
||||
)
|
||||
|
||||
instantiate_schema_type: Type[BaseModel] = Field(
|
||||
default=CascadeFilteringInstantiateSchema
|
||||
)
|
||||
|
||||
example: str = Field(
|
||||
default="""
|
||||
# Example 1: Disjunction - Legal Document Relevance Filter
|
||||
Target Operation:
|
||||
- name: filter_litigation_relevant_docs
|
||||
type: filter
|
||||
model: gpt-4o
|
||||
prompt: |
|
||||
Determine if this document is relevant to our patent litigation case.
|
||||
|
||||
The document is relevant if it contains ANY of:
|
||||
1. Prior art references dated before 2015-03-15 describing similar technology
|
||||
2. Internal emails discussing the patent claims or invalidity arguments
|
||||
3. Technical specifications that contradict our patent claims
|
||||
4. License agreements mentioning the patents-in-suit
|
||||
5. Expert testimony or declarations about the technology
|
||||
6. Financial damages calculations or royalty discussions
|
||||
|
||||
Patent numbers in suit: 8,123,456 and 9,234,567
|
||||
Technology area: adaptive bitrate streaming for video delivery
|
||||
|
||||
Document: {{ input.document_text }}
|
||||
Metadata: {{ input.document_metadata }}
|
||||
output:
|
||||
schema:
|
||||
is_relevant: boolean
|
||||
|
||||
Example InstantiateSchema (agent recognizes the OR conditions can be split):
|
||||
CascadeFilteringInstantiateSchema(
|
||||
code_pre_filters=[
|
||||
CodePreFilter(
|
||||
name="has_patent_numbers",
|
||||
code=\"\"\"def transform(input_doc):
|
||||
text = (input_doc.get('document_text', '') + ' ' + str(input_doc.get('document_metadata', {}))).lower()
|
||||
# If patent numbers are mentioned, likely relevant
|
||||
return '8,123,456' in text or '9,234,567' in text or 'patent' in text\"\"\",
|
||||
reasoning="Documents mentioning our patent numbers or patents in general might be relevant"
|
||||
),
|
||||
CodePreFilter(
|
||||
name="has_tech_or_legal_terms",
|
||||
code=\"\"\"def transform(input_doc):
|
||||
text = input_doc.get('document_text', '').lower()
|
||||
# Must have streaming tech or legal terms to possibly be relevant
|
||||
terms = ['streaming', 'bitrate', 'video', 'codec', 'prior art', 'license',
|
||||
'damages', 'royalty', 'testimony', 'expert', 'invalidity']
|
||||
return any(term in text for term in terms)\"\"\",
|
||||
reasoning="Documents without technical or legal terminology cannot be relevant"
|
||||
)
|
||||
],
|
||||
llm_pre_filters=[
|
||||
LLMPreFilter(
|
||||
name="check_prior_art_date",
|
||||
prompt="Does this mention technology from before March 2015? {{ input.document_metadata }} If yes; set 'keep' to true. If no; set 'keep' to false.",
|
||||
reasoning="Prior art must predate the patent filing"
|
||||
),
|
||||
LLMPreFilter(
|
||||
name="check_email_discussion",
|
||||
prompt="Is this an email or discussion about patents? {{ input.document_text }} If yes; set 'keep' to true. If no; set 'keep' to false.",
|
||||
reasoning="Internal emails about patents might be relevant"
|
||||
),
|
||||
LLMPreFilter(
|
||||
name="check_technical_specs",
|
||||
prompt="Does this contain technical specifications? {{ input.document_text }} If yes; set 'keep' to true. If no; set 'keep' to false.",
|
||||
reasoning="Technical specs might contradict our claims"
|
||||
)
|
||||
],
|
||||
analysis_summary="Filter has 6 OR criteria. Pre-filters check each criterion cheaply, eliminating 70% of documents before expensive analysis"
|
||||
)
|
||||
|
||||
# Example 2: Complex Reasoning - Misinformation Detection Filter
|
||||
Target Operation:
|
||||
- name: filter_misinformation
|
||||
type: filter
|
||||
model: gpt-4o
|
||||
prompt: |
|
||||
Analyze this social media post to determine if it contains health misinformation.
|
||||
|
||||
Consider:
|
||||
- Claims that contradict scientific consensus
|
||||
- Misleading statistics or cherry-picked data
|
||||
- Conspiracy theories about health organizations
|
||||
- Dangerous medical advice without proper qualifications
|
||||
- Manipulation of legitimate studies to support false conclusions
|
||||
- Emotional manipulation to spread fear about vaccines/treatments
|
||||
|
||||
Requires nuanced understanding of:
|
||||
- Context and speaker credibility
|
||||
- Difference between opinion and factual claims
|
||||
- Satirical vs serious content
|
||||
- Cultural/religious beliefs vs dangerous misinformation
|
||||
|
||||
Post: {{ input.post_content }}
|
||||
Author Profile: {{ input.author_info }}
|
||||
Engagement Metrics: {{ input.engagement_stats }}
|
||||
output:
|
||||
schema:
|
||||
is_misinformation: boolean
|
||||
|
||||
Example InstantiateSchema (agent deduces proxy filters for complex reasoning):
|
||||
CascadeFilteringInstantiateSchema(
|
||||
code_pre_filters=[
|
||||
CodePreFilter(
|
||||
name="has_health_content",
|
||||
code=\"\"\"def transform(input_doc):
|
||||
text = input_doc.get('post_content', '').lower()
|
||||
# Must mention health/medical topics to potentially be health misinfo
|
||||
health_terms = ['vaccine', 'covid', 'cancer', 'cure', 'treatment', 'doctor',
|
||||
'medical', 'health', 'disease', 'virus', 'immune', 'pharmaceutical']
|
||||
return any(term in text for term in health_terms)\"\"\",
|
||||
reasoning="Posts without health-related content cannot be health misinformation"
|
||||
),
|
||||
CodePreFilter(
|
||||
name="has_claims_language",
|
||||
code=\"\"\"def transform(input_doc):
|
||||
text = input_doc.get('post_content', '').lower()
|
||||
# Look for language that makes claims rather than just sharing experience
|
||||
claim_markers = ['proven', 'studies show', 'research', 'scientists', 'they don\\'t want',
|
||||
'truth about', 'actually', 'fact', 'evidence', 'causes', 'prevents']
|
||||
return any(marker in text for marker in claim_markers) or '!' in text\"\"\",
|
||||
reasoning="Posts without claim-making language are less likely to be misinformation"
|
||||
)
|
||||
],
|
||||
llm_pre_filters=[
|
||||
LLMPreFilter(
|
||||
name="check_medical_claims",
|
||||
prompt="Does this post make medical or health claims? {{ input.post_content }} If yes; set 'keep' to true. If no; set 'keep' to false.",
|
||||
reasoning="Posts making medical claims need detailed analysis"
|
||||
),
|
||||
LLMPreFilter(
|
||||
name="check_high_engagement",
|
||||
prompt="Is this a high-engagement post (>1000 shares or viral)? {{ input.engagement_stats }} If yes; set 'keep' to true. If no; set 'keep' to false.",
|
||||
reasoning="High-engagement posts have more potential for harm if they contain misinformation"
|
||||
)
|
||||
],
|
||||
analysis_summary="Complex task requiring reasoning about credibility, context, and scientific accuracy. Pre-filters identify posts that definitely need review (30% of total) by checking for health content, claim-making language, and viral spread potential"
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
test_cases: List[DirectiveTestCase] = Field(
|
||||
default_factory=lambda: [
|
||||
DirectiveTestCase(
|
||||
name="single_filter_cascade",
|
||||
description="Should create cascade of filters before expensive filter",
|
||||
input_config={
|
||||
"name": "filter_quality",
|
||||
"type": "filter",
|
||||
"model": "gpt-4o",
|
||||
"prompt": "Is this a high-quality research paper? Paper: {{ input.paper }}",
|
||||
"output": {"schema": {"is_quality": "boolean"}},
|
||||
},
|
||||
target_ops=["filter_quality"],
|
||||
expected_behavior="Should inject cheaper pre-filters (code and/or gpt-5-nano) before the expensive gpt-4o filter",
|
||||
should_pass=True,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
    def __eq__(self, other):
        return isinstance(other, CascadeFilteringDirective)

    def __hash__(self):
        return hash("CascadeFilteringDirective")

    def _extract_input_keys(self, prompt: str) -> List[str]:
        """Extract input field names from a Jinja2 prompt template."""
        # Match {{ input.fieldname }} patterns
        pattern = r"\{\{\s*input\.(\w+)\s*\}\}"
        matches = re.findall(pattern, prompt)
        return list(set(matches))  # Return unique field names

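    # Illustrative note (not in the original source): a minimal sketch of what
    # _extract_input_keys returns for a typical filter prompt. The prompt text
    # below is hypothetical.
    #
    #   >>> directive._extract_input_keys("Is {{ input.title }} about {{ input.topic }}?")
    #   ['title', 'topic']   # order may vary, since a set is used internally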
def to_string_for_instantiate(
|
||||
self, target_ops_configs: List[Dict], pipeline_code: Dict = None
|
||||
) -> str:
|
||||
"""Generate prompt for agent to analyze data and create cascade filters."""
|
||||
assert (
|
||||
len(target_ops_configs) == 1
|
||||
), "CascadeFiltering directive only supports single target operation"
|
||||
assert (
|
||||
target_ops_configs[0]["type"] == "filter"
|
||||
), "Target operation must be a filter"
|
||||
|
||||
op = target_ops_configs[0]
|
||||
input_keys = self._extract_input_keys(op.get("prompt", ""))
|
||||
|
||||
pipeline_context = ""
|
||||
if pipeline_code:
|
||||
pipeline_context = f"""
|
||||
Pipeline Context:
|
||||
{json.dumps(pipeline_code, indent=2)}
|
||||
|
||||
The target filter '{op['name']}' fits into this broader pipeline. Consider what types of documents flow into this filter and what the downstream operations expect.
|
||||
"""
|
||||
|
||||
return (
|
||||
f"You are an expert at optimizing filter operations by creating efficient filter cascades.\n\n"
|
||||
f"Target Filter Operation:\n"
|
||||
f"{json.dumps(op, indent=2)}\n\n"
|
||||
f"Input fields used in the original prompt: {input_keys}\n\n"
|
||||
f"{pipeline_context}\n"
|
||||
f"Your task is to analyze sample input data and create a cascade of cheaper pre-filters that can eliminate many documents before they reach the expensive main filter.\n\n"
|
||||
f"You will be given access to sample input data through a read_next_docs() function. Use this to:\n"
|
||||
f"1. Understand what documents should pass vs. fail the main filter\n"
|
||||
f"2. Identify simple patterns that can predict filter outcomes\n"
|
||||
f"3. Design code-based filters for deterministic patterns (cheapest)\n"
|
||||
f"4. Design gpt-5-nano filters for simple semantic patterns (still cheap)\n\n"
|
||||
f"Guidelines for code_pre_filters:\n"
|
||||
f"- Must be Python functions with signature: def transform(input_doc): return bool\n"
|
||||
f"- Should use regex, keyword matching, length checks, or simple logic\n"
|
||||
f"- Must be deterministic and fast\n"
|
||||
f"- Return True to keep the document, False to filter it out\n"
|
||||
f"- Access document fields using input_doc.get('fieldname', default_value)\n\n"
|
||||
f"Guidelines for llm_pre_filters:\n"
|
||||
f"- Must be Jinja2 templates that reference input fields using {{{{ input.fieldname }}}} syntax\n"
|
||||
f"- Must reference the same input fields as the original prompt\n"
|
||||
f"- Should be simple, focused prompts that elicit a yes/no or true/false response\n"
|
||||
f"- Keep prompts SHORT - they will be ordered by length automatically\n"
|
||||
f"- Must use at least one of these input fields: {input_keys}\n\n"
|
||||
f"General Guidelines:\n"
|
||||
f"- Pre-filters MUST have HIGH RECALL (rarely reject documents that would pass the main filter)\n"
|
||||
f"- Pre-filters can have LOWER PRECISION (okay to let through documents that fail - the main filter will catch them)\n"
|
||||
f"- When in doubt, let the document through (return True) - false negatives are worse than false positives\n\n"
|
||||
f"- All documents will have the same keys. So don't write code that checks if a particular key exists in a document or not.\n"
|
||||
f"Example transformation:\n"
|
||||
f"{self.example}\n\n"
|
||||
f"Analyze samples strategically - look for patterns that distinguish documents that pass vs fail the filter.\n"
|
||||
f"When you have identified enough patterns to create effective pre-filters, output your result.\n\n"
|
||||
f"Remember: The goal is to reduce costs by filtering out documents early with cheaper methods while maintaining the same final accuracy."
|
||||
)
|
||||
|
||||
def llm_instantiate(
|
||||
self,
|
||||
target_ops_configs: List[Dict],
|
||||
input_file_path: str,
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
pipeline_code: Dict = None,
|
||||
):
|
||||
"""Use agentic approach to analyze data and generate cascade filters."""
|
||||
# Load sample input data
|
||||
try:
|
||||
with open(input_file_path, "r") as f:
|
||||
input_data = json.load(f)
|
||||
|
||||
if not isinstance(input_data, list) or len(input_data) == 0:
|
||||
raise ValueError(
|
||||
"Input file must contain a non-empty list of sample data"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"Failed to load input data from {input_file_path}: {str(e)}"
|
||||
)
|
||||
|
||||
# Extract input keys from original prompt for validation
|
||||
original_prompt = target_ops_configs[0].get("prompt", "")
|
||||
expected_input_keys = self._extract_input_keys(original_prompt)
|
||||
|
||||
def validate_llm_prompts(schema_instance):
|
||||
"""Validate that LLM prompts use correct input fields."""
|
||||
for llm_filter in schema_instance.llm_pre_filters:
|
||||
used_keys = set(
|
||||
re.findall(r"\{\{\s*input\.(\w+)\s*\}\}", llm_filter.prompt)
|
||||
)
|
||||
invalid_keys = used_keys - set(expected_input_keys)
|
||||
if invalid_keys:
|
||||
raise ValueError(
|
||||
f"LLM filter '{llm_filter.name}' uses invalid input fields: {invalid_keys}. "
|
||||
f"Available fields from original prompt: {expected_input_keys}"
|
||||
)
|
||||
if not used_keys:
|
||||
raise ValueError(
|
||||
f"LLM filter '{llm_filter.name}' must reference at least one input field "
|
||||
f"from the original prompt: {expected_input_keys}"
|
||||
)
|
||||
|
||||
# Set up agentic runner with validation
|
||||
runner = AgenticDirectiveRunner(
|
||||
input_data=input_data,
|
||||
agent_llm=agent_llm,
|
||||
validation_func=validate_llm_prompts,
|
||||
)
|
||||
|
||||
# Create system prompt
|
||||
system_prompt = (
|
||||
"You are an expert at optimizing data processing pipelines by creating efficient filter cascades. "
|
||||
"You analyze data to identify patterns that allow for cheap pre-filtering before expensive operations. "
|
||||
"Your goal is to reduce costs while maintaining accuracy by using progressively more expensive filters."
|
||||
)
|
||||
|
||||
# Create initial user message
|
||||
initial_message = self.to_string_for_instantiate(
|
||||
target_ops_configs, pipeline_code
|
||||
)
|
||||
|
||||
# Run the agentic loop
|
||||
try:
|
||||
schema, updated_message_history, call_cost = runner.run_agentic_loop(
|
||||
system_prompt=system_prompt,
|
||||
initial_user_message=initial_message,
|
||||
response_schema=CascadeFilteringInstantiateSchema,
|
||||
)
|
||||
|
||||
# Update message history
|
||||
message_history.extend(updated_message_history)
|
||||
|
||||
return schema, message_history, call_cost
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"Failed to instantiate cascade_filtering directive: {str(e)}"
|
||||
)
|
||||
|
||||
    def apply(
        self,
        global_default_model: str,
        ops_list: List[Dict],
        target_ops: List[str],
        rewrite: CascadeFilteringInstantiateSchema,
    ) -> List[Dict]:
        """Apply the directive by injecting pre-filters before the target filter."""
        new_ops_list = []

        for op in ops_list:
            if op["name"] in target_ops:
                # First, add all code filters
                for i, code_filter in enumerate(rewrite.code_pre_filters):
                    code_filter_op = {
                        "name": f"{code_filter.name}_{op['name']}",
                        "type": "code_filter",
                        "code": code_filter.code,
                    }
                    new_ops_list.append(code_filter_op)

                # Sort LLM filters by prompt length (shortest first for cost efficiency)
                sorted_llm_filters = sorted(
                    rewrite.llm_pre_filters, key=lambda f: len(f.prompt)
                )

                # Then, add all LLM filters (using gpt-5-nano)
                for i, llm_filter in enumerate(sorted_llm_filters):
                    llm_filter_op = {
                        "name": f"{llm_filter.name}_{op['name']}",
                        "type": "filter",
                        "model": "gpt-5-nano",
                        "prompt": llm_filter.prompt,
                        "output": {"schema": {"keep": "boolean"}},
                    }
                    new_ops_list.append(llm_filter_op)

                # Finally, add the original filter
                new_ops_list.append(deepcopy(op))
            else:
                # Keep other operations as-is
                new_ops_list.append(deepcopy(op))

        return new_ops_list

def instantiate(
|
||||
self,
|
||||
operators: List[Dict],
|
||||
target_ops: List[str],
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
global_default_model: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Main method that orchestrates directive instantiation:
|
||||
1. Use agentic approach to analyze data and generate cascade filters
|
||||
2. Apply the transformation by injecting pre-filters
|
||||
"""
|
||||
assert (
|
||||
len(target_ops) == 1
|
||||
), "CascadeFiltering directive requires exactly one target operation"
|
||||
|
||||
input_file_path = kwargs.get("input_file_path", None)
|
||||
pipeline_code = kwargs.get("pipeline_code", None)
|
||||
|
||||
if not input_file_path:
|
||||
raise ValueError(
|
||||
"input_file_path is required for CascadeFiltering directive"
|
||||
)
|
||||
|
||||
# Get configuration for target operation
|
||||
target_ops_configs = [op for op in operators if op["name"] in target_ops]
|
||||
|
||||
if not target_ops_configs:
|
||||
raise ValueError(f"Target operation {target_ops[0]} not found in operators")
|
||||
|
||||
if target_ops_configs[0]["type"] != "filter":
|
||||
raise ValueError(
|
||||
f"Target operation {target_ops[0]} must be a filter operation"
|
||||
)
|
||||
|
||||
# Step 1: Agent analyzes data and generates cascade filters
|
||||
rewrite, message_history, call_cost = self.llm_instantiate(
|
||||
target_ops_configs,
|
||||
input_file_path,
|
||||
agent_llm,
|
||||
message_history,
|
||||
pipeline_code,
|
||||
)
|
||||
|
||||
# Step 2: Apply transformation
|
||||
return (
|
||||
self.apply(global_default_model, operators, target_ops, rewrite),
|
||||
message_history,
|
||||
call_cost,
|
||||
)
|
||||
|
|
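
# Illustrative sketch (not part of the original module): the operator order a
# cascade rewrite produces - code pre-filters first, then gpt-5-nano filters
# sorted by prompt length, then the original expensive filter. All names and
# prompts here are hypothetical.
if __name__ == "__main__":
    llm_pre_filters = [
        {"name": "check_prior_art_date", "prompt": "Mentions pre-2015 tech? {{ input.document_metadata }}"},
        {"name": "check_email", "prompt": "Is this an email? {{ input.document_text }}"},
    ]
    cascade = (
        [{"name": "has_patent_numbers", "type": "code_filter"}]
        + [
            {"name": f["name"], "type": "filter", "model": "gpt-5-nano"}
            for f in sorted(llm_pre_filters, key=lambda f: len(f["prompt"]))
        ]
        + [{"name": "filter_litigation_relevant_docs", "type": "filter", "model": "gpt-4o"}]
    )
    print([op["name"] for op in cascade])
    # ['has_patent_numbers', 'check_email', 'check_prior_art_date', 'filter_litigation_relevant_docs']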
@ -1,313 +0,0 @@
|
|||
import json
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
from litellm import completion
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from docetl.reasoning_optimizer.instantiate_schemas import ChainingInstantiateSchema
|
||||
|
||||
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
|
||||
|
||||
|
||||
class ChainingDirective(Directive):
|
||||
name: str = Field(default="chaining", description="The name of the directive")
|
||||
formal_description: str = Field(default="Op => Map* -> Op")
|
||||
nl_description: str = Field(
|
||||
default="Decompose a complex operation into a sequence by inserting one or more Map steps that rewrite the input for the next operation. Each Map step outputs a 'result' string, and the downstream operation uses this result in its prompt."
|
||||
)
|
||||
when_to_use: str = Field(
|
||||
default="When the original task is too complex for one step and should be split into a series (e.g., first extract key facts, then generate a summary based on those facts)."
|
||||
)
|
||||
instantiate_schema_type: Type[BaseModel] = ChainingInstantiateSchema
|
||||
|
||||
example: str = Field(
|
||||
default=(
|
||||
"Original Op (MapOpConfig):\n"
|
||||
"- name: extract_newly_prescribed_treatments\n"
|
||||
" type: map\n"
|
||||
" prompt: |\n"
|
||||
" For a hospital discharge summary, extract every treatment that was prescribed specifically for newly diagnosed conditions.\n"
|
||||
' Discharge summary: "{{ input.summary }}"\n'
|
||||
" output:\n"
|
||||
" schema:\n"
|
||||
" treatments: list[str]\n"
|
||||
"\n"
|
||||
"Example InstantiateSchema (must refer to the same input document keys as the original Op, and subsequent Map operators must refer to the output of the previous Map operator):\n"
|
||||
"[\n"
|
||||
" MapOpConfig(\n"
|
||||
" name='identify_new_conditions',\n"
|
||||
" prompt='''Review the following hospital discharge summary:\n"
|
||||
"{{ input.summary }}\n"
|
||||
"Identify all medical conditions that are explicitly marked as new diagnoses (e.g., 'new diagnosis of atrial fibrillation', 'recent onset heart failure').\n"
|
||||
"Return a list of newly diagnosed conditions.''',\n"
|
||||
" output_keys=['new_conditions'],\n"
|
||||
" ),\n"
|
||||
" MapOpConfig(\n"
|
||||
" name='extract_treatments_for_new_conditions',\n"
|
||||
" prompt='''For each newly diagnosed condition listed below, extract every treatment or medication prescribed for that specific condition from the discharge summary.\n"
|
||||
"Discharge summary: {{ input.summary }}\n"
|
||||
"Newly diagnosed conditions: {{ input.new_conditions }}\n"
|
||||
"Return a list of prescribed treatments or medications for each condition.''',\n"
|
||||
" output_keys=['treatments'],\n"
|
||||
" ),\n"
|
||||
"]"
|
||||
),
|
||||
)
|
||||
|
||||
test_cases: List[DirectiveTestCase] = Field(
|
||||
default_factory=lambda: [
|
||||
DirectiveTestCase(
|
||||
name="complex_contract_extraction",
|
||||
description="Should decompose complex contract extraction into separate ops for each term type and a final unification op",
|
||||
input_config={
|
||||
"name": "extract_contract_terms",
|
||||
"type": "map",
|
||||
"prompt": "Extract all payment terms, liability clauses, and termination conditions from: {{ input.contract }}",
|
||||
"output": {"schema": {"terms": "list[str]"}},
|
||||
},
|
||||
target_ops=["extract_contract_terms"],
|
||||
expected_behavior="Should create one op for each of: payment terms, liability clauses, and termination conditions, then a final op to unify all results into 'terms'",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="medical_treatment_analysis",
|
||||
description="Should chain complex medical analysis into steps",
|
||||
input_config={
|
||||
"name": "analyze_patient_care",
|
||||
"type": "map",
|
||||
"prompt": "From this medical record, identify all diagnoses, treatments prescribed, and patient outcomes: {{ input.medical_record }}",
|
||||
"output": {"schema": {"analysis": "string"}},
|
||||
},
|
||||
target_ops=["analyze_patient_care"],
|
||||
expected_behavior="Should decompose into separate steps for diagnoses identification, treatment extraction, and outcome analysis",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="research_paper_analysis",
|
||||
description="Should chain research paper analysis into structured steps",
|
||||
input_config={
|
||||
"name": "analyze_research_paper",
|
||||
"type": "map",
|
||||
"prompt": "From this research paper, extract the methodology, key findings, limitations, and future work directions: {{ input.paper }}",
|
||||
"output": {"schema": {"analysis": "string"}},
|
||||
},
|
||||
target_ops=["analyze_research_paper"],
|
||||
expected_behavior="Should decompose into separate steps for methodology extraction, findings identification, limitation analysis, and future work extraction",
|
||||
should_pass=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, ChainingDirective)
|
||||
|
||||
def __hash__(self):
|
||||
return hash("ChainingDirective")
|
||||
|
||||
def to_string_for_instantiate(self, original_op: Dict) -> str:
|
||||
"""
|
||||
Generate a prompt for an agent to instantiate this directive.
|
||||
|
||||
Args:
|
||||
original_op (str): The YAML or string representation of the original operation.
|
||||
|
||||
Returns:
|
||||
str: The agent prompt for instantiating the directive.
|
||||
"""
|
||||
return (
|
||||
f"You are an expert at decomposing complex data processing operations into modular steps.\n\n"
|
||||
f"Original Operation:\n"
|
||||
f"{str(original_op)}\n"
|
||||
f"Directive: {self.name}\n"
|
||||
f"Your task is to instantiate this directive by generating a list of new Map operators (as MapOpConfig objects) that decompose the original operation into a sequence of simpler steps. "
|
||||
f"Each Map step should output a 'result' or other relevant keys, and downstream steps should use the outputs of previous steps as their input. "
|
||||
f"Ensure that the chain of Map operators together accomplishes the intent of the original operation, but in a more modular and stepwise fashion.\n\n"
|
||||
f"""Key Issues to ensure:\n
|
||||
1. Ensure the new prompts don't reference categories, lists, or criteria without providing them or making them available from previous steps.\n
|
||||
2. Every detail, category, instruction, requirement, and definition from the original must be present in the new configuration.\n
|
||||
3. Confirm that each step can access all required information either from the original document or from outputs of preceding steps.\n"""
|
||||
f"Example:\n"
|
||||
f"{self.example}\n\n"
|
||||
f"Please output only the InstantiateSchema (a list of MapOpConfig objects) for the new chain, referring to the same input document keys as the original operation and chaining outputs appropriately."
|
||||
)
|
||||
|
||||
def llm_instantiate(
|
||||
self,
|
||||
original_op: Dict,
|
||||
expected_input_keys: List[str],
|
||||
expected_output_keys: List[str],
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
):
|
||||
"""
|
||||
Use LLM to instantiate this directive by decomposing the original operation.
|
||||
|
||||
Args:
|
||||
original_op (Dict): The original operation.
|
||||
expected_input_keys (List[str]): A list of input keys that the operation is expected to reference in its prompt. Each key should correspond to a field in the input document that must be used by the operator.
|
||||
expected_output_keys (List[str]): A list of output keys that the last operation is expected to produce.
|
||||
agent_llm (str): The LLM model to use.
|
||||
message_history (List, optional): Conversation history for context.
|
||||
|
||||
Returns:
|
||||
ChainingInstantiateSchema: The structured output from the LLM.
|
||||
"""
|
||||
|
||||
message_history.extend(
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful AI assistant for document processing pipelines.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.to_string_for_instantiate(original_op),
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
response_format=ChainingInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
if "new_ops" not in parsed_res:
|
||||
raise ValueError(
|
||||
"Response from LLM is missing required key 'new_ops'"
|
||||
)
|
||||
new_ops = parsed_res["new_ops"]
|
||||
schema = ChainingInstantiateSchema(new_ops=new_ops)
|
||||
# Validate the chain with required input/output keys
|
||||
ChainingInstantiateSchema.validate_chain(
|
||||
new_ops=schema.new_ops,
|
||||
required_input_keys=expected_input_keys,
|
||||
expected_output_keys=expected_output_keys,
|
||||
)
|
||||
message_history.append(
|
||||
{"role": "assistant", "content": resp.choices[0].message.content}
|
||||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
    def apply(
        self,
        global_default_model,
        ops_list: List[Dict],
        target_op: str,
        rewrite: ChainingInstantiateSchema,
    ) -> List[Dict]:
        """
        Apply the directive to the pipeline config.
        """
        # Create a copy of the pipeline config
        new_ops_list = deepcopy(ops_list)

        # Find position of the target op to replace
        pos_to_replace = None
        orig_op = None
        for i, op in enumerate(ops_list):
            if op["name"] == target_op:
                pos_to_replace = i
                orig_op = op
                break

        if pos_to_replace is None:
            raise ValueError(f"Target operation {target_op} not found in ops list")

        # Create the new ops from the rewrite
        new_ops = []

        default_model = global_default_model
        if "model" in orig_op:
            default_model = orig_op["model"]

        for i, op in enumerate(rewrite.new_ops):
            if i < len(rewrite.new_ops) - 1:
                new_ops.append(
                    {
                        "name": op.name,
                        "type": "map",
                        "prompt": op.prompt,
                        "model": default_model,
                        "litellm_completion_kwargs": {"temperature": 0},
                        "output": {"schema": {key: "string" for key in op.output_keys}},
                    }
                )
            else:
                # Last op in the chain keeps the original output schema
                new_ops.append(
                    {
                        "name": op.name,
                        "type": "map",
                        "prompt": op.prompt,
                        "model": default_model,
                        "litellm_completion_kwargs": {"temperature": 0},
                        "output": new_ops_list[pos_to_replace]["output"],
                    }
                )

        # Remove the target op and insert the new ops in its place
        new_ops_list.pop(pos_to_replace)
        new_ops_list[pos_to_replace:pos_to_replace] = new_ops

        return new_ops_list

def instantiate(
|
||||
self,
|
||||
operators: List[Dict],
|
||||
target_ops: List[str],
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
optimize_goal="acc",
|
||||
global_default_model: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Instantiate the directive for a list of operators.
|
||||
"""
|
||||
# Assert that there is only one target op
|
||||
assert (
|
||||
len(target_ops) == 1
|
||||
), "There must be exactly one target op to instantiate this chaining directive"
|
||||
target_op_config = [op for op in operators if op["name"] == target_ops[0]][0]
|
||||
|
||||
# Get the expected input/output keys
|
||||
expected_output_keys = list(target_op_config["output"]["schema"].keys())
|
||||
|
||||
# Extract expected input keys from the target op's prompt template
|
||||
prompt_template = target_op_config["prompt"]
|
||||
# Find all occurrences of {{ input.key }} in the prompt
|
||||
input_key_pattern = r"\{\{\s*input\.([^\}\s]+)\s*\}\}"
|
||||
expected_input_keys = list(set(re.findall(input_key_pattern, prompt_template)))
|
||||
|
||||
print("input key: ", expected_input_keys)
|
||||
print("output key: ", expected_output_keys)
|
||||
|
||||
# Instantiate the directive
|
||||
rewrite, message_history, call_cost = self.llm_instantiate(
|
||||
target_op_config,
|
||||
expected_input_keys,
|
||||
expected_output_keys,
|
||||
agent_llm,
|
||||
message_history,
|
||||
)
|
||||
|
||||
# Apply the rewrite to the operators
|
||||
new_ops_plan = self.apply(
|
||||
global_default_model, operators, target_ops[0], rewrite
|
||||
)
|
||||
return new_ops_plan, message_history, call_cost
|
||||
|
|
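
# Illustrative sketch (not part of the original module): the list surgery that
# apply() performs - the target op is removed and the generated chain of map
# ops is spliced in at the same position. Op names here are hypothetical.
if __name__ == "__main__":
    ops = [{"name": "load"}, {"name": "extract_contract_terms"}, {"name": "report"}]
    chain = [{"name": "extract_payment_terms"}, {"name": "unify_terms"}]
    pos = [i for i, op in enumerate(ops) if op["name"] == "extract_contract_terms"][0]
    ops.pop(pos)
    ops[pos:pos] = chain
    print([op["name"] for op in ops])
    # ['load', 'extract_payment_terms', 'unify_terms', 'report']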
@ -1,311 +0,0 @@
|
|||
import json
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
from litellm import completion
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from docetl.reasoning_optimizer.instantiate_schemas import ChangeModelInstantiateSchema
|
||||
|
||||
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
|
||||
|
||||
|
||||
class ChangeModelDirective(Directive):
|
||||
name: str = Field(default="change model", description="The name of the directive")
|
||||
formal_description: str = Field(
|
||||
default="Op => Op* (same operation with a different model choice)"
|
||||
)
|
||||
nl_description: str = Field(
|
||||
default="Rewrites an operator to use a different LLM model based on task requirements. Generally, simpler tasks like extraction or classification may work well with cheaper models (gpt-4o-mini, gpt-5-nano), while complex reasoning tasks often benefit from more powerful models (gpt-5), though actual performance and constraints should guide the choice."
|
||||
)
|
||||
when_to_use: str = Field(
|
||||
default="When the current model choice may not be optimal for the task requirements, considering factors like task complexity, performance needs, cost constraints, and quality requirements."
|
||||
)
|
||||
instantiate_schema_type: Type[BaseModel] = ChangeModelInstantiateSchema
|
||||
|
||||
example: str = Field(
|
||||
default=(
|
||||
"Original Op (MapOpConfig):\n"
|
||||
"- name: extract_insights\n"
|
||||
" type: map\n"
|
||||
" prompt: |\n"
|
||||
" From the user log below, list 2-3 concise insights (1-2 words each) and 1-2 supporting actions per insight.\n"
|
||||
" Return as a list of dictionaries with 'insight' and 'supporting_actions'.\n"
|
||||
" Log: {{ input.log }}\n"
|
||||
" output:\n"
|
||||
" schema:\n"
|
||||
' insights_summary: "string"\n'
|
||||
" model: gpt-5\n"
|
||||
"\n"
|
||||
"Example InstantiateSchema:\n"
|
||||
"{\n"
|
||||
' "model": "gpt-4o-mini"\n'
|
||||
"}"
|
||||
),
|
||||
)
|
||||
|
||||
allowed_model_list: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="The allowed list of models to choose from",
|
||||
)
|
||||
|
||||
model_info: str = Field(
|
||||
default=(
|
||||
"""
|
||||
OpenAI-MRCR evaluates a model's ability to locate and disambiguate multiple well-hidden "needles" within a large context.
|
||||
Below are the actual performance scores for the 8-needle retrieval task at various context lengths. Use these results to compare the retrieval capabilities of each model.
|
||||
The below results are the mean match ratio.
|
||||
|
||||
Input Tokens (1000s) | GPT-5 | GPT-5 nano | GPT-4o mini
|
||||
---------------------|---------|--------------|-------------------|
|
||||
8 | 99% | 69% | 32% |
|
||||
16 | 100% | 64% | 30% |
|
||||
32 | 96% | 55% | 27% |
|
||||
64 | 98% | 45% | 25% |
|
||||
128 | 97% | 44% | 25% |
|
||||
256 | 92% | 40% | - |
|
||||
|
||||
|
||||
The context window and pricing details for each model are shown below (token prices are per 1 million tokens):
|
||||
Model | GPT-5-nano | GPT-4o-mini | GPT-5
|
||||
-------------------|--------------|-------------|----------
|
||||
Context Window | 400,000 | 128,000 | 400,000
|
||||
Max Output Tokens | 128,000 | 16,384 | 128,000
|
||||
Input Token Price | $0.05 | $0.15 | $1.25
|
||||
Output Token Price | $0.40 | $0.60 | $10
|
||||
"""
|
||||
),
|
||||
)
|
||||
|
||||
test_cases: List[DirectiveTestCase] = Field(
|
||||
default_factory=lambda: [
|
||||
DirectiveTestCase(
|
||||
name="cost_optimization_simple_task",
|
||||
description="Should suggest cheaper model for simple extraction task",
|
||||
input_config={
|
||||
"name": "extract_names",
|
||||
"type": "map",
|
||||
"prompt": "Extract person names from: {{ input.text }}",
|
||||
"output": {"schema": {"names": "list[str]"}},
|
||||
"model": "gpt-5",
|
||||
},
|
||||
target_ops=["extract_names"],
|
||||
expected_behavior="Should recommend gpt-4o-mini or gpt-5-nano for cost savings on simple task",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="complex_analysis_needs_powerful_model",
|
||||
description="Should recommend powerful model for complex reasoning task",
|
||||
input_config={
|
||||
"name": "analyze_legal_implications",
|
||||
"type": "map",
|
||||
"prompt": "Analyze the legal implications and potential risks in this complex contract: {{ input.contract }}",
|
||||
"output": {"schema": {"analysis": "string"}},
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
target_ops=["analyze_legal_implications"],
|
||||
expected_behavior="Should recommend gpt-5 for complex legal analysis requiring strong reasoning",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="high_volume_processing_optimization",
|
||||
description="Should optimize model for high-volume document processing",
|
||||
input_config={
|
||||
"name": "batch_document_summary",
|
||||
"type": "map",
|
||||
"prompt": "Summarize this document in 2-3 sentences: {{ input.document }}",
|
||||
"output": {"schema": {"summary": "str"}},
|
||||
"model": "gpt-5",
|
||||
},
|
||||
target_ops=["batch_document_summary"],
|
||||
expected_behavior="Should recommend faster/cheaper model like gpt-4o-mini or gpt-5-nano for high-volume simple summarization",
|
||||
should_pass=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, ChangeModelDirective)
|
||||
|
||||
def __hash__(self):
|
||||
return hash("ChangeModelDirective")
|
||||
|
||||
def to_string_for_instantiate(self, original_op: Dict, optimize_goal) -> str:
|
||||
"""
|
||||
Generate a prompt for an agent to instantiate this directive.
|
||||
|
||||
Args:
|
||||
original_op (str): The YAML or string representation of the original operation.
|
||||
|
||||
Returns:
|
||||
str: The agent prompt for instantiating the directive.
|
||||
"""
|
||||
if optimize_goal == "acc":
|
||||
return (
|
||||
f"You are an expert at choosing the most suitable model for a given task based on complexity and cost considerations.\n\n"
|
||||
f"Original Operation:\n"
|
||||
f"{str(original_op)}\n\n"
|
||||
f"Directive: {self.name}\n"
|
||||
f"Your task is to instantiate this directive by suggesting a better model for executing the original operation.\n\n"
|
||||
f"MODEL SELECTION CONSIDERATIONS:\n"
|
||||
f"• Generally, simpler tasks (extraction, basic classification, straightforward summarization) may work well with cheaper models like gpt-4o-mini or gpt-5-nano\n"
|
||||
f"• Complex reasoning tasks (analysis, interpretation, multi-step thinking, legal/medical analysis) often benefit from more powerful models like gpt-5\n"
|
||||
f"• However, consider actual performance needs, quality requirements, and cost constraints when making the choice\n"
|
||||
f"• Sometimes a powerful model may be needed for seemingly simple tasks if quality is critical, or a cheaper model may suffice for complex tasks if budget is constrained\n\n"
|
||||
f"You have a list of allowed models to choose from: {str(self.allowed_model_list)}.\n\n"
|
||||
f"Consider the information about the allowed models: \n {self.model_info}\n"
|
||||
f"Your response should include the new model choice for the operation."
|
||||
f"Ensure that your chosen model is in the list of allowed models."
|
||||
f"Example:\n"
|
||||
f"{self.example}\n\n"
|
||||
f"Please output only the ChangeModelInstantiateSchema as JSON."
|
||||
)
|
||||
else:
|
||||
return (
|
||||
f"You are an expert at choosing the most suitable model for a given task to optimize cost.\n\n"
|
||||
f"Original Operation:\n"
|
||||
f"{str(original_op)}\n\n"
|
||||
f"Directive: {self.name}\n"
|
||||
f"Your task is to instantiate this directive by generating a ChangeModelConfig that suggests a more cost-effective model for executing the original operation."
|
||||
f"You have a list of allowed models to choose from: {str(self.allowed_model_list)}.\n\n"
|
||||
f"Consider the information about the allowed models: \n {self.model_info}\n"
|
||||
f"The ChangeModelConfig should include the new model choice for the operation that reduces costs while maintaining adequate performance."
|
||||
f"Ensure that your chosen model is in the list of allowed models."
|
||||
f"Example:\n"
|
||||
f"{self.example}\n\n"
|
||||
f"Please output only the InstantiateSchema (a ChangeModelConfig object)."
|
||||
)
|
||||
|
||||
def llm_instantiate(
|
||||
self,
|
||||
global_default_model: str,
|
||||
original_op: Dict,
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
optimize_goal="acc",
|
||||
):
|
||||
"""
|
||||
Use LLM to instantiate this directive.
|
||||
|
||||
Args:
|
||||
original_op (Dict): The original operation.
|
||||
agent_llm (str): The LLM model to use.
|
||||
message_history (List, optional): Conversation history for context.
|
||||
|
||||
Returns:
|
||||
ChangeModelInstantiateSchema: The structured output from the LLM.
|
||||
"""
|
||||
|
||||
message_history.extend(
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful AI assistant for document processing pipelines.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.to_string_for_instantiate(
|
||||
original_op, optimize_goal
|
||||
),
|
||||
},
|
||||
]
|
||||
)
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
response_format=ChangeModelInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = ChangeModelInstantiateSchema(**parsed_res)
|
||||
orig_model = global_default_model
|
||||
if "model" in original_op:
|
||||
orig_model = original_op.get("model")
|
||||
# Validate the model is in the allowed model list
|
||||
ChangeModelInstantiateSchema.validate_diff_model_in_list(
|
||||
orig_model=orig_model,
|
||||
model=schema.model,
|
||||
list_of_model=self.allowed_model_list,
|
||||
)
|
||||
message_history.append(
|
||||
{"role": "assistant", "content": resp.choices[0].message.content}
|
||||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
    def apply(
        self,
        global_default_model: str,
        ops_list: List[Dict],
        target_op: str,
        rewrite: ChangeModelInstantiateSchema,
    ) -> List[Dict]:
        """
        Apply the directive to the pipeline config by setting the new model on the target operator.
        """
        # Create a copy of the pipeline config
        new_ops_list = deepcopy(ops_list)

        # Find position of the target op to modify
        pos_to_replace = [
            i for i, op in enumerate(ops_list) if op["name"] == target_op
        ][0]

        # Add change model configuration to the target operator
        target_operator = new_ops_list[pos_to_replace]
        target_operator["model"] = rewrite.model

        return new_ops_list

    def instantiate(
        self,
        operators: List[Dict],
        target_ops: List[str],
        agent_llm: str,
        message_history: list = [],
        optimize_goal="acc",
        global_default_model: str = None,
        allowed_model_list: List[str] = None,
        **kwargs,
    ):
        """
        Instantiate the directive for a list of operators.
        """
        # Update allowed_model_list if provided
        if allowed_model_list is not None:
            self.allowed_model_list = allowed_model_list

        new_ops_list = deepcopy(operators)
        inst_error = 0
        call_cost = 0
        for target_op in target_ops:
            target_op_config = [op for op in operators if op["name"] == target_op][0]
            # Instantiate the directive
            try:
                rewrite, message_history, call_cost = self.llm_instantiate(
                    global_default_model,
                    target_op_config,
                    agent_llm,
                    message_history,
                    optimize_goal=optimize_goal,
                )
            except Exception:
                # Skip this target op if instantiation fails
                inst_error += 1
                continue
            new_ops_list = self.apply(
                global_default_model, new_ops_list, target_op, rewrite
            )

        if inst_error == len(target_ops):
            print("CHANGE MODEL ERROR: failed to instantiate for all target ops")
            return None, message_history, call_cost
        return new_ops_list, message_history, call_cost
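
# Illustrative sketch (not part of the original module): rough per-document cost
# comparison using the per-1M-token prices quoted in model_info above. Token
# counts are hypothetical; this is back-of-the-envelope arithmetic, not billing.
if __name__ == "__main__":
    prices = {  # (input $/1M tokens, output $/1M tokens)
        "gpt-5": (1.25, 10.00),
        "gpt-4o-mini": (0.15, 0.60),
    }
    input_tokens, output_tokens = 2_000, 300  # per document (assumed)
    for model, (p_in, p_out) in prices.items():
        cost = input_tokens / 1e6 * p_in + output_tokens / 1e6 * p_out
        print(f"{model}: ~${cost:.5f} per document")
    # gpt-5: ~$0.00550 per document, gpt-4o-mini: ~$0.00048 per document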
@ -1,303 +0,0 @@
|
|||
import json
from copy import deepcopy
from typing import Dict, List, Type

from litellm import completion
from pydantic import BaseModel, Field

from docetl.reasoning_optimizer.instantiate_schemas import ChangeModelInstantiateSchema

from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase


class ChangeModelAccDirective(Directive):
    name: str = Field(
        default="change model acc", description="The name of the directive"
    )
    formal_description: str = Field(
        default="Op => Op* (same operation with a more accurate model choice)"
    )
    nl_description: str = Field(
        default="Rewrites an operator to use a more powerful LLM model to optimize accuracy. Prioritizes model performance and quality over cost considerations, typically suggesting more capable models like gpt-5 for complex reasoning tasks."
    )
    when_to_use: str = Field(
        default="When accuracy and quality are the primary concerns, and cost is secondary. Suitable for complex reasoning tasks, critical analysis, or when maximum model performance is needed. Usually changing to a more expensive model will improve accuracy, so you should try this directive if it has not been used in the past iterations."
    )
    instantiate_schema_type: Type[BaseModel] = ChangeModelInstantiateSchema

    example: str = Field(
        default=(
            "Original Op (MapOpConfig):\n"
            "- name: analyze_complex_data\n"
            "  type: map\n"
            "  prompt: |\n"
            "    Analyze this complex financial data and provide detailed insights with risk assessment.\n"
            "    Data: {{ input.financial_data }}\n"
            "  output:\n"
            "    schema:\n"
            '      analysis: "string"\n'
            "  model: gpt-4o-mini\n"
            "\n"
            "Example InstantiateSchema:\n"
            "{\n"
            '  "model": "gpt-5"\n'
            "}"
        ),
    )

    allowed_model_list: List[str] = Field(
        default_factory=list,
        description="The allowed list of models to choose from",
    )

    model_info: str = Field(
        default=(
            """
OpenAI-MRCR evaluates a model's ability to locate and disambiguate multiple well-hidden "needles" within a large context.
Below are the actual performance scores for the 2-needle retrieval task at various context lengths. Use these results to compare the retrieval capabilities of each model.
The below results are the mean match ratio.

Context Length | gpt-5 | gpt-5-nano | gpt-4o-mini | gemini-2.5-pro | gemini-2.5-flash | gpt-4.1 | gpt-4.1-mini | gpt-4.1-nano | gemini-2.5-flash-lite
---------------|-------|------------|-------------|----------------|------------------|---------|--------------|--------------|----------------------
128k           | 97%   | 44%        | 25%         | 83.6%          | 86.2%            | 61.3%   | 47.1%        | 36.7%        | 39.9%
1M             | -     | -          | -           | 62.8%          | 60.0%            | 45.9%   | 34.6%        | 14.2%        | 18.1%

The context window and pricing details for each model are shown below (token prices are per 1 million tokens):
| Family         | Model                        | Input Price /1M               | Output Price /1M                | Context Window (API)      |
|----------------|------------------------------|-------------------------------|---------------------------------|---------------------------|
| **GPT-5**      | azure/gpt-5                  | $1.25                         | $10.00                          | 400K (272K in + 128K out) |
|                | azure/gpt-5-mini             | $0.25                         | $2.00                           | 400K                      |
|                | azure/gpt-5-nano             | $0.05                         | $0.40                           | 400K                      |
| **GPT-4.1**    | azure/gpt-4.1                | $2.00                         | $8.00                           | 1M                        |
|                | azure/gpt-4.1-mini           | $0.40                         | $1.60                           | 1M                        |
|                | azure/gpt-4.1-nano           | $0.10                         | $0.40                           | 1M                        |
| **GPT-4o**     | azure/gpt-4o                 | $2.50                         | $10.00                          | 128K                      |
|                | azure/gpt-4o-mini            | $0.15                         | $0.60                           | 128K (≈16K output cap)    |
| **Gemini 2.5** | gemini/gemini-2.5-pro        | $1.25 (≤200K) / $2.50 (>200K) | $10.00 (≤200K) / $15.00 (>200K) | 1M (2M soon)              |
|                | gemini/gemini-2.5-flash      | $0.30                         | $2.50                           | 1M                        |
|                | gemini/gemini-2.5-flash-lite | $0.10                         | $0.40                           | 1M                        |
"""
        ),
    )
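
    # Rough, illustrative cost arithmetic from the table above (the token counts are
    # assumed for the example, not measured): a call with 10K input and 1K output tokens
    # costs about 10_000/1e6 * $1.25 + 1_000/1e6 * $10.00 ≈ $0.0225 on azure/gpt-5,
    # versus about 10_000/1e6 * $0.15 + 1_000/1e6 * $0.60 ≈ $0.0021 on azure/gpt-4o-mini.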

    test_cases: List[DirectiveTestCase] = Field(
        default_factory=lambda: [
            DirectiveTestCase(
                name="complex_reasoning_needs_powerful_model",
                description="Should recommend most powerful model for complex reasoning task",
                input_config={
                    "name": "analyze_legal_implications",
                    "type": "map",
                    "prompt": "Analyze the legal implications and potential risks in this complex contract: {{ input.contract }}",
                    "output": {"schema": {"analysis": "string"}},
                    "model": "gpt-4o-mini",
                },
                target_ops=["analyze_legal_implications"],
                expected_behavior="Should recommend gpt-5 for complex legal analysis requiring strong reasoning",
                should_pass=True,
            ),
            DirectiveTestCase(
                name="accuracy_over_cost_scientific_analysis",
                description="Should prioritize accuracy for scientific analysis task",
                input_config={
                    "name": "analyze_research_data",
                    "type": "map",
                    "prompt": "Perform detailed statistical analysis and provide research insights: {{ input.data }}",
                    "output": {"schema": {"insights": "string"}},
                    "model": "gpt-5-nano",
                },
                target_ops=["analyze_research_data"],
                expected_behavior="Should recommend gpt-5 for accurate scientific analysis despite higher cost",
                should_pass=True,
            ),
            DirectiveTestCase(
                name="critical_medical_analysis_high_accuracy",
                description="Should recommend most accurate model for critical medical analysis",
                input_config={
                    "name": "medical_diagnosis_support",
                    "type": "map",
                    "prompt": "Analyze medical symptoms and provide diagnostic insights: {{ input.symptoms }}",
                    "output": {"schema": {"diagnosis": "string"}},
                    "model": "gpt-4o-mini",
                },
                target_ops=["medical_diagnosis_support"],
                expected_behavior="Should recommend gpt-5 for critical medical analysis requiring highest accuracy",
                should_pass=True,
            ),
        ]
    )

    def __eq__(self, other):
        return isinstance(other, ChangeModelAccDirective)

    def __hash__(self):
        return hash("ChangeModelAccDirective")

    def to_string_for_instantiate(self, original_op: Dict) -> str:
        """
        Generate a prompt for an agent to instantiate this directive for accuracy optimization.

        Args:
            original_op (str): The YAML or string representation of the original operation.

        Returns:
            str: The agent prompt for instantiating the directive.
        """
        return (
            f"You are an expert at choosing the most accurate and powerful model for a given task, prioritizing quality over cost.\n\n"
            f"Original Operation:\n"
            f"{str(original_op)}\n\n"
            f"Directive: {self.name}\n"
            f"Your task is to instantiate this directive by suggesting the most accurate model for executing the original operation.\n\n"
            f"TASK COMPLEXITY ANALYSIS AND MODEL SELECTION:\n"
            f"First, carefully analyze the original operation to understand:\n"
            f"• What specific task is being performed (extraction, analysis, transformation, reasoning, etc.)\n"
            f"• How much cognitive complexity and intelligence is required\n"
            f"• Whether the task involves simple pattern matching or sophisticated reasoning\n"
            f"• If the task requires domain expertise, multi-step thinking, or nuanced understanding\n"
            f"• The criticality and precision requirements of the output\n"
            f"• If the task requires processing very long context (1M+ tokens)\n\n"
            f"Based on your analysis of task complexity, select the model that will provide the most accurate response:\n"
            f"• For simple extraction or formatting tasks: Consider efficient models from the available options\n"
            f"• For moderate complexity tasks requiring some reasoning: Use capable models from gpt-5 or gemini series\n"
            f"• For complex reasoning, analysis, interpretation, legal/medical tasks, or critical decisions: Strongly prefer the most advanced models from gpt-5 or gemini series\n"
            f"• For tasks requiring very long context (1M+ tokens): Consider models with extended context like gpt-4.1 series or gemini models\n"
            f"• For highly specialized or extremely complex cognitive tasks: Use the most powerful model available\n\n"
            f"Remember: The goal is maximum accuracy given the intelligence and context requirements of the specific task.\n"
            f"You have a list of allowed models to choose from: {str(self.allowed_model_list)}.\n\n"
            f"Consider the information about the allowed models: \n {self.model_info}\n"
            f"Your response should include the new model choice for the operation that maximizes accuracy given the task complexity and context requirements. "
            f"Ensure that your chosen model is in the list of allowed models.\n"
            f"Example:\n"
            f"{self.example}\n\n"
            f"Please output only the ChangeModelInstantiateSchema as JSON."
        )

    def llm_instantiate(
        self,
        global_default_model: str,
        original_op: Dict,
        agent_llm: str,
        message_history: list = [],
    ):
        """
        Use LLM to instantiate this directive for accuracy optimization.

        Args:
            original_op (Dict): The original operation.
            agent_llm (str): The LLM model to use.
            message_history (List, optional): Conversation history for context.

        Returns:
            ChangeModelInstantiateSchema: The structured output from the LLM.
        """

        message_history.extend(
            [
                {
                    "role": "system",
                    "content": "You are a helpful AI assistant for document processing pipelines focused on accuracy optimization.",
                },
                {
                    "role": "user",
                    "content": self.to_string_for_instantiate(original_op),
                },
            ]
        )
        last_error = None
        for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
            resp = completion(
                model=agent_llm,
                messages=message_history,
                response_format=ChangeModelInstantiateSchema,
            )
            call_cost = resp._hidden_params.get("response_cost", 0)
            try:
                parsed_res = json.loads(resp.choices[0].message.content)
                schema = ChangeModelInstantiateSchema(**parsed_res)
                orig_model = global_default_model
                if "model" in original_op:
                    orig_model = original_op.get("model")
                # Validate the model is in the allowed model list
                ChangeModelInstantiateSchema.validate_diff_model_in_list(
                    orig_model=orig_model,
                    model=schema.model,
                    list_of_model=self.allowed_model_list,
                )
                message_history.append(
                    {"role": "assistant", "content": resp.choices[0].message.content}
                )
                return schema, message_history, call_cost
            except Exception as err:
                last_error = err
                error_message = f"Validation error: {err}\nPlease try again."
                message_history.append({"role": "user", "content": error_message})

        raise Exception(
            f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
        )

    def apply(
        self,
        global_default_model: str,
        ops_list: List[Dict],
        target_op: str,
        rewrite: ChangeModelInstantiateSchema,
    ) -> List[Dict]:
        """
        Apply the directive to the pipeline config by changing the model of the target operator.
        """
        # Create a copy of the pipeline config
        new_ops_list = deepcopy(ops_list)

        # Find position of the target op to modify
        pos_to_replace = [
            i for i, op in enumerate(ops_list) if op["name"] == target_op
        ][0]

        # Add change model configuration to the target operator
        target_operator = new_ops_list[pos_to_replace]
        target_operator["model"] = rewrite.model

        return new_ops_list

    def instantiate(
        self,
        operators: List[Dict],
        target_ops: List[str],
        agent_llm: str,
        message_history: list = [],
        global_default_model: str = None,
        allowed_model_list: List[str] = None,
        **kwargs,
    ):
        """
        Instantiate the directive for a list of operators.
        """
        # Update allowed_model_list if provided
        if allowed_model_list is not None:
            self.allowed_model_list = allowed_model_list

        new_ops_list = deepcopy(operators)
        inst_error = 0
        for target_op in target_ops:
            target_op_config = [op for op in operators if op["name"] == target_op][0]
            # Instantiate the directive
            try:
                rewrite, message_history, call_cost = self.llm_instantiate(
                    global_default_model,
                    target_op_config,
                    agent_llm,
                    message_history,
                )
                print(rewrite)
            except Exception:
                inst_error += 1
                # Skip applying the rewrite for this operator; it was never produced
                continue
            new_ops_list = self.apply(
                global_default_model, new_ops_list, target_op, rewrite
            )

        if inst_error == len(target_ops):
            print("CHANGE MODEL ACC ERROR")
            return None, message_history
        return new_ops_list, message_history, call_cost
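
# Minimal usage sketch (illustrative only; the surrounding optimizer objects such as
# `pipeline_ops` and the exact model names are assumptions, not part of this module):
#
#     directive = ChangeModelAccDirective(
#         allowed_model_list=["azure/gpt-5", "azure/gpt-4o-mini"]
#     )
#     new_ops, history, cost = directive.instantiate(
#         operators=pipeline_ops,
#         target_ops=["analyze_complex_data"],
#         agent_llm="azure/gpt-4.1",
#         global_default_model="azure/gpt-4o-mini",
#     )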
@@ -1,395 +0,0 @@
import json
from copy import deepcopy
from typing import Dict, List, Type

from litellm import completion
from pydantic import BaseModel, Field

from docetl.reasoning_optimizer.instantiate_schemas import ChangeModelInstantiateSchema


def get_cheaper_models(current_model: str, model_stats: Dict = None) -> List[str]:
    """
    Get list of models that are cheaper than the current model.

    Args:
        current_model: The current model name (may include "azure/" or "gemini/" prefix)
        model_stats: Dictionary of model statistics from MOARSearch

    Returns:
        List of model names that are cheaper than current_model, sorted by cost (cheapest first)
    """
    if model_stats is None or not model_stats:
        return []

    current_model_stats = model_stats.get(current_model)
    if current_model_stats is None:
        return []

    current_cost = current_model_stats.get("cost")
    if current_cost is None:
        return []

    cheaper_models = []
    for model, stats in model_stats.items():
        if not isinstance(stats, dict):
            continue
        model_cost = stats.get("cost")
        if model_cost is not None and model_cost < current_cost:
            cheaper_models.append(model)

    # Sort so the cheapest options come first, as the docstring promises
    cheaper_models.sort(key=lambda m: model_stats[m]["cost"])
    return cheaper_models


from .base import (  # noqa: E402
    MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS,
    Directive,
    DirectiveTestCase,
)


def create_model_specific_directives(
    current_model: str, allowed_model_list: List[str] = None
):
    """Create model-specific directives for the current model."""
    directive = ChangeModelCostDirective(target_model=current_model)
    directive.name = f"change to {current_model}"
    directive.nl_description = f"Rewrites an operator to use the {current_model} model to optimize expenses while maintaining adequate performance."
    if allowed_model_list is not None:
        directive.allowed_model_list = allowed_model_list
    else:
        directive.allowed_model_list = [
            current_model
        ]  # Only allow this specific model if no list provided

    return directive

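# Illustrative sketch of get_cheaper_models (the model_stats payload shape is assumed
# from the docstring above, not taken from MOARSearch itself):
#
#     stats = {
#         "azure/gpt-5": {"cost": 1.80},
#         "azure/gpt-4o-mini": {"cost": 0.02},
#         "azure/gpt-5-nano": {"cost": 0.06},
#     }
#     get_cheaper_models("azure/gpt-5", stats)
#     # -> ["azure/gpt-4o-mini", "azure/gpt-5-nano"] (cheapest first)
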
class ChangeModelCostDirective(Directive):
|
||||
name: str = Field(
|
||||
default="change model cost", description="The name of the directive"
|
||||
)
|
||||
formal_description: str = Field(
|
||||
default="Op => Op* (same operation with a more cost-effective model choice)"
|
||||
)
|
||||
nl_description: str = Field(
|
||||
default="Rewrites an operator to use a more cost-effective LLM model to optimize expenses. Prioritizes cost savings while maintaining adequate performance, typically suggesting cheaper models like gpt-4o-mini or gpt-5-nano for simpler tasks."
|
||||
)
|
||||
when_to_use: str = Field(
|
||||
default="When cost optimization is the primary concern and the task can be performed adequately by a less expensive model. Suitable for simple extraction, basic classification, or high-volume processing where budget constraints are important."
|
||||
)
|
||||
instantiate_schema_type: Type[BaseModel] = ChangeModelInstantiateSchema
|
||||
|
||||
example: str = Field(
|
||||
default=(
|
||||
"Original Op (MapOpConfig):\n"
|
||||
"- name: extract_names\n"
|
||||
" type: map\n"
|
||||
" prompt: |\n"
|
||||
" Extract person names from the following text.\n"
|
||||
" Text: {{ input.text }}\n"
|
||||
" output:\n"
|
||||
" schema:\n"
|
||||
' names: "list[str]"\n'
|
||||
" model: gpt-5\n"
|
||||
"\n"
|
||||
"Example InstantiateSchema:\n"
|
||||
"{\n"
|
||||
' "model": "gpt-4o-mini"\n'
|
||||
"}"
|
||||
),
|
||||
)
|
||||
|
||||
allowed_model_list: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="The allowed list of models to choose from",
|
||||
)
|
||||
|
||||
target_model: str = Field(
|
||||
default="",
|
||||
description="The specific target model for this directive instance",
|
||||
)
|
||||
|
||||
model_info: str = Field(
|
||||
default=(
|
||||
"""
|
||||
OpenAI-MRCR evaluates a model's ability to locate and disambiguate multiple well-hidden "needles" within a large context.
|
||||
Below are the actual performance scores for the 2-needle retrieval task at various context lengths. Use these results to compare the retrieval capabilities of each model.
|
||||
The below results are the mean match ratio.
|
||||
|
||||
Context Length | gpt-5 | gpt-5-nano | gpt-4o-mini | gemini-2.5-pro | gemini-2.5-flash | gpt-4.1 | gpt-4.1-mini | gpt-4.1-nano | gemini-2.5-flash-lite
|
||||
---------------|-------|------------|-------------|----------------|------------------|---------|--------------|--------------|----------------------
|
||||
128k | 97% | 44% | 25% | 83.6% | 86.2% | 61.3% | 47.1% | 36.7% | 39.9%
|
||||
1M | - | - | - | 62.8% | 60.0% | 45.9% | 34.6% | 14.2% | 18.1%
|
||||
|
||||
The context window and pricing details for each model are shown below (token prices are per 1 million tokens):
|
||||
| Family | Model | Input Price /1M | Output Price /1M | Context Window (API) |
|
||||
|----------------|----------------------------|---------------------------------|-----------------------------------|-----------------------------|
|
||||
| **GPT-5** | azure/gpt-5 | $1.25 | $10.00 | 400K (272K in + 128K out) |
|
||||
| | azure/gpt-5-mini | $0.25 | $2.00 | 400K |
|
||||
| | azure/gpt-5-nano | $0.05 | $0.40 | 400K |
|
||||
| **GPT-4.1** | azure/gpt-4.1 | $2.00 | $8.00 | 1M |
|
||||
| | azure/gpt-4.1-mini | $0.40 | $1.60 | 1M |
|
||||
| | azure/gpt-4.1-nano | $0.10 | $0.40 | 1M |
|
||||
| **GPT-4o** | azure/gpt-4o | $2.50 | $10.00 | 128K |
|
||||
| | azure/gpt-4o-mini | $0.15 | $0.60 | 128K (≈16K output cap) |
|
||||
| **Gemini 2.5** | gemini/gemini-2.5-pro | $1.25 (≤200K) / $2.50 (>200K) | $10.00 (≤200K) / $15.00 (>200K) | 1M (2M soon) |
|
||||
| | gemini/gemini-2.5-flash | $0.30 | $2.50 | 1M |
|
||||
| | gemini/gemini-2.5-flash-lite | $0.10 | $0.40 | 1M |
|
||||
"""
|
||||
),
|
||||
)
|
||||
|
||||
test_cases: List[DirectiveTestCase] = Field(
|
||||
default_factory=lambda: [
|
||||
DirectiveTestCase(
|
||||
name="cost_optimization_simple_extraction",
|
||||
description="Should suggest cheaper model for simple extraction task",
|
||||
input_config={
|
||||
"name": "extract_names",
|
||||
"type": "map",
|
||||
"prompt": "Extract person names from: {{ input.text }}",
|
||||
"output": {"schema": {"names": "list[str]"}},
|
||||
"model": "gpt-5",
|
||||
},
|
||||
target_ops=["extract_names"],
|
||||
expected_behavior="Should recommend gpt-4o-mini or gpt-5-nano for cost savings on simple task",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="high_volume_processing_cost_optimization",
|
||||
description="Should optimize model for high-volume document processing to reduce costs",
|
||||
input_config={
|
||||
"name": "batch_document_summary",
|
||||
"type": "map",
|
||||
"prompt": "Summarize this document in 2-3 sentences: {{ input.document }}",
|
||||
"output": {"schema": {"summary": "str"}},
|
||||
"model": "gpt-5",
|
||||
},
|
||||
target_ops=["batch_document_summary"],
|
||||
expected_behavior="Should recommend cheaper model like gpt-4o-mini or gpt-5-nano for high-volume simple summarization",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="basic_classification_cost_savings",
|
||||
description="Should recommend cost-effective model for basic classification",
|
||||
input_config={
|
||||
"name": "classify_sentiment",
|
||||
"type": "map",
|
||||
"prompt": "Classify the sentiment of this text as positive, negative, or neutral: {{ input.text }}",
|
||||
"output": {"schema": {"sentiment": "string"}},
|
||||
"model": "gpt-5",
|
||||
},
|
||||
target_ops=["classify_sentiment"],
|
||||
expected_behavior="Should recommend cheaper model for basic sentiment classification to reduce costs",
|
||||
should_pass=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
isinstance(other, ChangeModelCostDirective)
|
||||
and self.target_model == other.target_model
|
||||
)
|
||||
|
||||
def __hash__(self):
|
||||
return hash(f"ChangeModelCostDirective_{self.target_model}")
|
||||
|
||||
def to_string_for_instantiate(
|
||||
self, original_op: Dict, dataset: str, model_stats: Dict = None
|
||||
) -> str:
|
||||
"""
|
||||
Generate a prompt for an agent to instantiate this directive for cost optimization.
|
||||
|
||||
Args:
|
||||
original_op (str): The YAML or string representation of the original operation.
|
||||
dataset: The dataset name
|
||||
model_stats: Dictionary of model statistics from MOARSearch (optional)
|
||||
|
||||
Returns:
|
||||
str: The agent prompt for instantiating the directive.
|
||||
"""
|
||||
model_stats_str = ""
|
||||
if model_stats:
|
||||
model_stats_str = f"You have a list of model statistics on the task with the original query pipeline: \n {str(model_stats)}\n"
|
||||
|
||||
return (
|
||||
f"You are an expert at choosing the most cost-effective model for a given task while maintaining adequate performance.\n\n"
|
||||
f"Original Operation:\n"
|
||||
f"{str(original_op)}\n\n"
|
||||
f"Directive: {self.name}\n"
|
||||
f"Your task is to instantiate this directive by suggesting the cheapest model that meets the requirements for executing the original operation.\n\n"
|
||||
f"COST OPTIMIZATION STRATEGY:\n"
|
||||
f"• Choose the cheapest model that meets the task requirements\n"
|
||||
f"• For tasks requiring ultra-long context (1M+ context window), use gpt-4.1 series or gemini models\n"
|
||||
f"• For tasks that can work with document samples or fit within 272k context window, use gpt-5-nano\n"
|
||||
f"• For simple tasks (extraction, basic classification, straightforward summarization), use the cheapest available model like gpt-4o-mini or gpt-5-nano\n"
|
||||
f"• For high-volume processing where cost accumulates quickly, prioritize the most economical option\n"
|
||||
f"• Only use expensive models if the task absolutely requires capabilities not available in cheaper alternatives\n"
|
||||
f"• Consider document length and context requirements when selecting models\n\n"
|
||||
f"You have a list of allowed models to choose from: {str(self.allowed_model_list)}.\n\n"
|
||||
f"Consider the information about the allowed models: \n {self.model_info}\n"
|
||||
f"{model_stats_str}"
|
||||
f"Your response should include the cheapest model choice that meets the operation requirements."
|
||||
f"Ensure that your chosen model is in the list of allowed models."
|
||||
f"Example:\n"
|
||||
f"{self.example}\n\n"
|
||||
f"Please output only the ChangeModelInstantiateSchema as JSON."
|
||||
)
|
||||
|
||||
def llm_instantiate(
|
||||
self,
|
||||
global_default_model: str,
|
||||
original_op: Dict,
|
||||
agent_llm: str,
|
||||
dataset: str,
|
||||
message_history: list = [],
|
||||
model_stats: Dict = None,
|
||||
):
|
||||
"""
|
||||
Use LLM to instantiate this directive for cost optimization.
|
||||
|
||||
Args:
|
||||
original_op (Dict): The original operation.
|
||||
agent_llm (str): The LLM model to use.
|
||||
dataset: The dataset name
|
||||
message_history (List, optional): Conversation history for context.
|
||||
model_stats: Dictionary of model statistics from MOARSearch (optional)
|
||||
|
||||
Returns:
|
||||
ChangeModelInstantiateSchema: The structured output from the LLM.
|
||||
"""
|
||||
|
||||
# If target_model is specified, use it directly without LLM call
|
||||
if self.target_model:
|
||||
schema = ChangeModelInstantiateSchema(model=self.target_model)
|
||||
return schema, message_history, 0.0
|
||||
|
||||
# Otherwise, use LLM to choose the model
|
||||
message_history.extend(
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful AI assistant for document processing pipelines focused on cost optimization.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.to_string_for_instantiate(
|
||||
original_op, dataset, model_stats
|
||||
),
|
||||
},
|
||||
]
|
||||
)
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
response_format=ChangeModelInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = ChangeModelInstantiateSchema(**parsed_res)
|
||||
orig_model = global_default_model
|
||||
if "model" in original_op:
|
||||
orig_model = original_op.get("model")
|
||||
# Validate the model is in the allowed model list
|
||||
ChangeModelInstantiateSchema.validate_diff_model_in_list(
|
||||
orig_model=orig_model,
|
||||
model=schema.model,
|
||||
list_of_model=self.allowed_model_list,
|
||||
)
|
||||
message_history.append(
|
||||
{"role": "assistant", "content": resp.choices[0].message.content}
|
||||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
global_default_model: str,
|
||||
ops_list: List[Dict],
|
||||
target_op: str,
|
||||
rewrite: ChangeModelInstantiateSchema,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Apply the directive to the pipeline config by changing the model of the target operator.
|
||||
"""
|
||||
# Create a copy of the pipeline config
|
||||
new_ops_list = deepcopy(ops_list)
|
||||
|
||||
# Find position of the target op to modify
|
||||
pos_to_replace = [
|
||||
i for i, op in enumerate(ops_list) if op["name"] == target_op
|
||||
][0]
|
||||
|
||||
# Add change model configuration to the target operator
|
||||
target_operator = new_ops_list[pos_to_replace]
|
||||
target_operator["model"] = rewrite.model
|
||||
|
||||
return new_ops_list
|
||||
|
||||
def instantiate(
|
||||
self,
|
||||
operators: List[Dict],
|
||||
target_ops: List[str],
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
global_default_model: str = None,
|
||||
dataset: str = None,
|
||||
model_stats: Dict = None,
|
||||
allowed_model_list: List[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Instantiate the directive for a list of operators.
|
||||
|
||||
Args:
|
||||
operators: List of operator configurations
|
||||
target_ops: List of target operation names
|
||||
agent_llm: LLM model to use for instantiation
|
||||
message_history: Conversation history
|
||||
global_default_model: Default model for the pipeline
|
||||
dataset: Dataset name
|
||||
model_stats: Dictionary of model statistics from MOARSearch (optional)
|
||||
allowed_model_list: List of allowed models (optional)
|
||||
**kwargs: Additional keyword arguments
|
||||
"""
|
||||
# Update allowed_model_list if provided
|
||||
if allowed_model_list is not None:
|
||||
self.allowed_model_list = allowed_model_list
|
||||
|
||||
new_ops_list = deepcopy(operators)
|
||||
inst_error = 0
|
||||
for target_op in target_ops:
|
||||
target_op_config = [op for op in operators if op["name"] == target_op][0]
|
||||
# Instantiate the directive
|
||||
try:
|
||||
rewrite, message_history, call_cost = self.llm_instantiate(
|
||||
global_default_model,
|
||||
target_op_config,
|
||||
agent_llm,
|
||||
dataset,
|
||||
message_history,
|
||||
model_stats,
|
||||
)
|
||||
print(rewrite)
|
||||
            except Exception:
                inst_error += 1
                # Skip applying the rewrite for this operator; it was never produced
                continue
            new_ops_list = self.apply(
                global_default_model, new_ops_list, target_op, rewrite
            )

        if inst_error == len(target_ops):
            print("CHANGE MODEL COST ERROR")
            return None, message_history
        return new_ops_list, message_history, call_cost
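
# Minimal usage sketch (illustrative only; `pipeline_ops` and the model names are
# assumptions, not part of this module). Pinning target_model via
# create_model_specific_directives makes llm_instantiate return the schema directly,
# with no agent LLM call:
#
#     directive = create_model_specific_directives(
#         "azure/gpt-4o-mini", allowed_model_list=["azure/gpt-4o-mini"]
#     )
#     new_ops, history, cost = directive.instantiate(
#         operators=pipeline_ops,
#         target_ops=["extract_names"],
#         agent_llm="azure/gpt-4.1",
#         global_default_model="azure/gpt-5",
#     )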
@@ -1,348 +0,0 @@
import json
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
from litellm import completion
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from docetl.reasoning_optimizer.instantiate_schemas import (
|
||||
ChunkHeaderSummaryInstantiateSchema,
|
||||
)
|
||||
|
||||
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
|
||||
|
||||
|
||||
class ChunkHeaderSummaryDirective(Directive):
|
||||
name: str = Field(
|
||||
default="chunk_header_summary", description="The name of the directive"
|
||||
)
|
||||
formal_description: str = Field(default="Split -> Gather => Split -> Map -> Gather")
|
||||
nl_description: str = Field(
|
||||
default="Transforms an existing Split -> Gather pipeline by inserting a Map operation between them that extracts headers and creates summaries from each chunk. The Gather operation is then modified to use summaries for middle chunks and headers for document structure. This directive enhances chunking pipelines with header extraction and chunk summarization capabilities. Only use this if it is clear that chunk-level analysis is insufficient because the chunk requires headers and summaries from other chunks to be interpreted correctly."
|
||||
)
|
||||
when_to_use: str = Field(
|
||||
default="Use only when you have an existing chunking pipeline (Split -> Gather) processing documents with clear hierarchical structure (legal contracts, technical manuals, research papers), and it is evident that chunk-level analysis is not accurate because the chunk needs headers and summaries from other chunks to make sense. This is beneficial when full chunk content in gather would be too verbose, and summarized or structured context is required for correct downstream processing. The target operators should be the split and gather. Make sure you specify these two operators when choosing this directive."
|
||||
)
|
||||
|
||||
instantiate_schema_type: Type[BaseModel] = Field(
|
||||
default=ChunkHeaderSummaryInstantiateSchema
|
||||
)
|
||||
|
||||
example: str = Field(
|
||||
default="""
|
||||
Original Pipeline (Split -> Gather):
|
||||
- name: split_contract_analysis
|
||||
type: split
|
||||
split_key: contract_text
|
||||
method: token_count
|
||||
method_kwargs:
|
||||
num_tokens: 1000
|
||||
|
||||
- name: gather_contract_context
|
||||
type: gather
|
||||
content_key: contract_text_chunk
|
||||
doc_id_key: split_contract_analysis_id
|
||||
order_key: split_contract_analysis_chunk_num
|
||||
peripheral_chunks:
|
||||
previous:
|
||||
tail:
|
||||
count: 1
|
||||
|
||||
Example InstantiateSchema Options (what the agent should output):
|
||||
|
||||
# Basic header and summary extraction:
|
||||
{
|
||||
"header_extraction_prompt": "Extract any section headers or subsection titles from this contract chunk: {{ input.contract_text_chunk }}. Return the headers with their hierarchical levels.",
|
||||
"summary_prompt": "Summarize the key legal concepts and clause types in this contract chunk: {{ input.contract_text_chunk }}. Focus on liability, indemnification, and related contractual obligations.",
|
||||
}
|
||||
""",
|
||||
)
|
||||
|
||||
test_cases: List[DirectiveTestCase] = Field(
|
||||
default_factory=lambda: [
|
||||
DirectiveTestCase(
|
||||
name="legal_document_with_structure",
|
||||
description="Should transform split->gather pipeline to include header extraction and summarization",
|
||||
input_config=[
|
||||
{
|
||||
"name": "split_legal_docs",
|
||||
"type": "split",
|
||||
"split_key": "agreement_text",
|
||||
"method": "token_count",
|
||||
"method_kwargs": {"num_tokens": 1000},
|
||||
},
|
||||
{
|
||||
"name": "gather_legal_context",
|
||||
"type": "gather",
|
||||
"content_key": "agreement_text_chunk",
|
||||
"doc_id_key": "split_legal_docs_id",
|
||||
"order_key": "split_legal_docs_chunk_num",
|
||||
"peripheral_chunks": {
|
||||
"previous": {"tail": {"count": 1}},
|
||||
"next": {"head": {"count": 1}},
|
||||
},
|
||||
},
|
||||
],
|
||||
target_ops=["split_legal_docs", "gather_legal_context"],
|
||||
expected_behavior="Should insert parallel_map between split and gather for header extraction and summarization, with gather using doc_header_key",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="technical_manual_analysis",
|
||||
description="Should transform split->gather pipeline for technical documentation with header and summary context",
|
||||
input_config=[
|
||||
{
|
||||
"name": "split_manual",
|
||||
"type": "split",
|
||||
"split_key": "manual_text",
|
||||
"method": "token_count",
|
||||
"method_kwargs": {"num_tokens": 800},
|
||||
},
|
||||
{
|
||||
"name": "gather_manual_context",
|
||||
"type": "gather",
|
||||
"content_key": "manual_text_chunk",
|
||||
"doc_id_key": "split_manual_id",
|
||||
"order_key": "split_manual_chunk_num",
|
||||
"peripheral_chunks": {
|
||||
"previous": {"tail": {"count": 2}},
|
||||
"next": {"head": {"count": 1}},
|
||||
},
|
||||
},
|
||||
],
|
||||
target_ops=["split_manual", "gather_manual_context"],
|
||||
expected_behavior="Should insert parallel_map between split and gather for header extraction and chunk summarization, enhancing technical documentation processing",
|
||||
should_pass=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, ChunkHeaderSummaryDirective)
|
||||
|
||||
def __hash__(self):
|
||||
return hash("ChunkHeaderSummaryDirective")
|
||||
|
||||
def to_string_for_instantiate(self, original_op: Dict) -> str:
|
||||
"""
|
||||
Generate a prompt for an agent to instantiate this directive.
|
||||
|
||||
Args:
|
||||
original_op (Dict): The original operation.
|
||||
|
||||
Returns:
|
||||
str: The agent prompt for instantiating the directive.
|
||||
"""
|
||||
return (
|
||||
f"You are an expert at enhancing document processing pipelines with header extraction and chunk summarization.\n\n"
|
||||
f"Original Pipeline:\n"
|
||||
f"{str(original_op)}\n\n"
|
||||
f"Directive: {self.name}\n"
|
||||
f"Your task is to instantiate this directive by creating a configuration that enhances an existing Split -> Gather pipeline "
|
||||
f"by inserting a Map operation between them for header extraction and summarization.\n\n"
|
||||
f"Key requirements:\n"
|
||||
f"1. header_extraction_prompt: Create a prompt to extract headers/section titles from each chunk:\n"
|
||||
f" - Use '{{{{ input.<split_key>_chunk }}}}' to reference chunk content (you'll know the split_key from the existing split operation)\n"
|
||||
f" - Focus on document structure (titles, headings, section numbers)\n"
|
||||
f" - Should output 'headers' field with hierarchical level information\n"
|
||||
f"2. summary_prompt: Create a prompt to summarize each chunk:\n"
|
||||
f" - Use '{{{{ input.<split_key>_chunk }}}}' to reference chunk content\n"
|
||||
f" - Focus on key concepts relevant to the downstream processing\n"
|
||||
f" - Should output '<split_key>_summary' field\n"
|
||||
f" - Keep summaries concise but informative for context\n"
|
||||
f"The header extraction helps maintain document structure context.\n"
|
||||
f"The summary provides condensed context from surrounding chunks.\n"
|
||||
f"The gather operation combines both for comprehensive context.\n\n"
|
||||
f"Example:\n"
|
||||
f"{self.example}\n\n"
|
||||
f"Please output only the ChunkHeaderSummaryInstantiateSchema object as JSON."
|
||||
)
|
||||
|
||||
def llm_instantiate(
|
||||
self,
|
||||
original_op: Dict,
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
):
|
||||
"""
|
||||
Use LLM to instantiate this directive by creating chunking configuration.
|
||||
|
||||
Args:
|
||||
original_op (Dict): The original operation.
|
||||
agent_llm (str): The LLM model to use.
|
||||
message_history (List, optional): Conversation history for context.
|
||||
|
||||
Returns:
|
||||
ChunkHeaderSummaryInstantiateSchema: The structured output from the LLM.
|
||||
"""
|
||||
|
||||
message_history.extend(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.to_string_for_instantiate(original_op),
|
||||
},
|
||||
]
|
||||
)
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
response_format=ChunkHeaderSummaryInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = ChunkHeaderSummaryInstantiateSchema(**parsed_res)
|
||||
message_history.append(
|
||||
{"role": "assistant", "content": resp.choices[0].message.content}
|
||||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
global_default_model: str,
|
||||
ops_list: List[Dict],
|
||||
target_ops: List[str],
|
||||
rewrite: ChunkHeaderSummaryInstantiateSchema,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Apply the directive to the pipeline config by inserting a parallel_map operation
|
||||
between the existing split and gather operations.
|
||||
"""
|
||||
# Create a copy of the pipeline config
|
||||
new_ops_list = deepcopy(ops_list)
|
||||
|
||||
# Find the split and gather operations
|
||||
split_op = None
|
||||
gather_op = None
|
||||
split_idx = None
|
||||
gather_idx = None
|
||||
|
||||
for i, op in enumerate(new_ops_list):
|
||||
if op["name"] in target_ops:
|
||||
if op["type"] == "split":
|
||||
split_op = op
|
||||
split_idx = i
|
||||
elif op["type"] == "gather":
|
||||
gather_op = op
|
||||
gather_idx = i
|
||||
|
||||
if not split_op or not gather_op:
|
||||
raise ValueError(
|
||||
"Both split and gather operations must be provided as target operations"
|
||||
)
|
||||
|
||||
if split_idx >= gather_idx:
|
||||
raise ValueError(
|
||||
"Split operation must come before gather operation in the pipeline"
|
||||
)
|
||||
|
||||
if gather_idx - split_idx > 1:
|
||||
raise ValueError("There should not be operators between split and gather")
|
||||
|
||||
# Get the split_key from the split operation
|
||||
split_key = split_op.get("split_key")
|
||||
if not split_key:
|
||||
raise ValueError(
|
||||
f"Split operation '{split_op['name']}' must have a 'split_key' field"
|
||||
)
|
||||
|
||||
# Create the parallel map operation for header extraction and summarization
|
||||
parallel_map_name = f"parallel_map_{split_op['name']}_header_summary"
|
||||
parallel_map_op = {
|
||||
"name": parallel_map_name,
|
||||
"type": "parallel_map",
|
||||
"prompts": [
|
||||
{
|
||||
"name": f"header_extraction_{split_op['name']}",
|
||||
"prompt": rewrite.header_extraction_prompt,
|
||||
"output_keys": ["headers"],
|
||||
},
|
||||
{
|
||||
"name": f"summary_{split_op['name']}",
|
||||
"prompt": rewrite.summary_prompt,
|
||||
"output_keys": [f"{split_key}_summary"],
|
||||
},
|
||||
],
|
||||
"model": global_default_model,
|
||||
"output": {
|
||||
"schema": {
|
||||
"headers": "list[{header: string, level: integer}]",
|
||||
f"{split_key}_summary": "string",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# Add doc_header_key to the gather operation to use extracted headers
|
||||
new_ops_list[gather_idx]["doc_header_key"] = "headers"
|
||||
|
||||
# Insert the parallel map operation between split and gather
|
||||
new_ops_list.insert(gather_idx, parallel_map_op)
|
||||
|
||||
return new_ops_list
|
||||
|
||||
def instantiate(
|
||||
self,
|
||||
operators: List[Dict],
|
||||
target_ops: List[str],
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
global_default_model: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Instantiate the directive for a list of operators.
|
||||
"""
|
||||
# Assert that there are exactly two target ops
|
||||
assert (
|
||||
len(target_ops) == 2
|
||||
), "There must be exactly two target ops (split and gather) to instantiate this chunk header summary directive"
|
||||
|
||||
# Find split and gather operations
|
||||
split_op = None
|
||||
gather_op = None
|
||||
for op in operators:
|
||||
if op["name"] in target_ops:
|
||||
if op["type"] == "split":
|
||||
split_op = op
|
||||
elif op["type"] == "gather":
|
||||
gather_op = op
|
||||
|
||||
if not split_op:
|
||||
raise ValueError(
|
||||
f"Chunk header summary directive requires a split operation among target operations, but none found in {target_ops}"
|
||||
)
|
||||
|
||||
if not gather_op:
|
||||
raise ValueError(
|
||||
f"Chunk header summary directive requires a gather operation among target operations, but none found in {target_ops}"
|
||||
)
|
||||
|
||||
# Create a combined context for instantiation
|
||||
pipeline_context = {
|
||||
"split_op": split_op,
|
||||
"gather_op": gather_op,
|
||||
"target_ops": target_ops,
|
||||
}
|
||||
|
||||
# Instantiate the directive
|
||||
rewrite, message_history, call_cost = self.llm_instantiate(
|
||||
pipeline_context, agent_llm, message_history
|
||||
)
|
||||
|
||||
# Apply the rewrite to the operators
|
||||
return (
|
||||
self.apply(global_default_model, operators, target_ops, rewrite),
|
||||
message_history,
|
||||
call_cost,
|
||||
)
|
||||
|
|
@@ -1,282 +0,0 @@
|
|||
import json
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from docetl.reasoning_optimizer.instantiate_schemas import (
|
||||
ClarifyInstructionsInstantiateSchema,
|
||||
)
|
||||
|
||||
from .agent_utils import AgenticDirectiveRunner
|
||||
from .base import Directive, DirectiveTestCase
|
||||
|
||||
|
||||
class ClarifyInstructionsDirective(Directive):
|
||||
name: str = Field(
|
||||
default="clarify_instructions", description="The name of the directive"
|
||||
)
|
||||
formal_description: str = Field(default="Single-op => Op")
|
||||
nl_description: str = Field(
|
||||
default="Improves a single operation's prompt clarity and specificity by analyzing sample input data to identify patterns, resolve ambiguities, and create more precise instructions that reduce LLM confusion and improve output consistency"
|
||||
)
|
||||
when_to_use: str = Field(
|
||||
default="When a single operation has a vague or ambiguous prompt that could benefit from more specific instructions based on actual data patterns. Particularly useful when you have multiple samples of input data and want to create a prompt for one specific operation that handles the patterns and edge cases present in your dataset"
|
||||
)
|
||||
|
||||
instantiate_schema_type: Type[BaseModel] = Field(
|
||||
default=ClarifyInstructionsInstantiateSchema
|
||||
)
|
||||
|
||||
example: str = Field(
|
||||
default="""
|
||||
Target Operation:
|
||||
- name: extract_key_findings
|
||||
type: map
|
||||
prompt: |
|
||||
Extract the key findings from: {{ input.research_paper }}
|
||||
output:
|
||||
schema:
|
||||
findings: "list[str]"
|
||||
|
||||
After analyzing sample research papers, the agent might discover papers contain
|
||||
abstracts, conclusions, and results sections with different formats.
|
||||
|
||||
Example InstantiateSchema (what the agent should output):
|
||||
ClarifyInstructionsInstantiateSchema(
|
||||
clarified_prompt="Extract the key findings from the research paper: {{ input.research_paper }}. Focus on: 1) Main experimental results and statistical significance from Results sections, 2) Primary conclusions from Abstract and Conclusion sections, 3) Novel contributions explicitly stated by authors. Ignore methodological details, related work summaries, and future work suggestions. Return 3-7 concise findings as bullet points."
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
test_cases: List[DirectiveTestCase] = Field(
|
||||
default_factory=lambda: [
|
||||
DirectiveTestCase(
|
||||
name="single_target_clarification",
|
||||
description="Should create clarified prompt for single target operation",
|
||||
input_config={
|
||||
"name": "analyze_feedback",
|
||||
"type": "map",
|
||||
"prompt": "Analyze the feedback: {{ input.feedback }}",
|
||||
"output": {
|
||||
"schema": {"sentiment": "string", "issues": "list[str]"}
|
||||
},
|
||||
},
|
||||
target_ops=["analyze_feedback"],
|
||||
expected_behavior="Should replace the original prompt with a more specific version based on analysis of sample feedback data. The clarified prompt should reference {{ input.feedback }} and provide specific guidance on what to analyze.",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="preserve_input_variables",
|
||||
description="Should preserve all input variable references from original prompt",
|
||||
input_config={
|
||||
"name": "compare_documents",
|
||||
"type": "map",
|
||||
"prompt": "Compare {{ input.doc1 }} with {{ input.doc2 }} and identify differences",
|
||||
"output": {"schema": {"differences": "list[str]"}},
|
||||
},
|
||||
target_ops=["compare_documents"],
|
||||
expected_behavior="Should create clarified prompt that still references both {{ input.doc1 }} and {{ input.doc2 }} while providing more specific comparison instructions.",
|
||||
should_pass=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, ClarifyInstructionsDirective)
|
||||
|
||||
def __hash__(self):
|
||||
return hash("ClarifyInstructionsDirective")
|
||||
|
||||
def to_string_for_instantiate(
|
||||
self, target_ops_configs: List[Dict], pipeline_code: Dict = None
|
||||
) -> str:
|
||||
"""
|
||||
Generate a prompt that asks the agent to analyze sample data and create an improved prompt.
|
||||
"""
|
||||
assert (
|
||||
len(target_ops_configs) == 1
|
||||
), "ClarifyInstructions directive only supports single target operation"
|
||||
|
||||
op = target_ops_configs[0]
|
||||
original_prompt = op.get("prompt", "")
|
||||
|
||||
# Build pipeline context
|
||||
pipeline_context = ""
|
||||
if pipeline_code:
|
||||
pipeline_context = f"""
|
||||
Pipeline Context:
|
||||
{json.dumps(pipeline_code, indent=2)}
|
||||
|
||||
The target operation '{op['name']}' fits into this broader pipeline. Consider:
|
||||
- What data flows into this operation from previous steps
|
||||
- How this operation's output will be used by subsequent operations
|
||||
- The overall goal of the pipeline when creating your improved prompt
|
||||
"""
|
||||
|
||||
return (
|
||||
f"You are an expert at analyzing data patterns and creating precise, effective prompts.\n\n"
|
||||
f"Target Operation:\n"
|
||||
f"{json.dumps(op, indent=2)}\n\n"
|
||||
f"Original Prompt: {original_prompt}\n\n"
|
||||
f"{pipeline_context}\n"
|
||||
f"Your task is to analyze sample input data and create a significantly improved version of this prompt.\n\n"
|
||||
f"You will be given access to sample input data through a read_next_doc() function. Use this to:\n"
|
||||
f"1. Understand the actual structure and patterns in the input data\n"
|
||||
f"2. Identify ambiguities in the original prompt that could be clarified\n"
|
||||
f"3. Discover specific patterns, formats, or edge cases that should be addressed\n"
|
||||
f"4. Consider how this operation fits into the broader pipeline context\n"
|
||||
f"5. Create a more specific, actionable prompt based on these insights\n\n"
|
||||
f"Guidelines for the improved prompt:\n"
|
||||
f"- Must preserve ALL original input variable references (like {{{{ input.fieldname }}}})\n"
|
||||
f"- Should be significantly more specific than the original\n"
|
||||
f"- Include concrete instructions based on patterns observed in the data\n"
|
||||
f"- Address potential ambiguities or edge cases discovered in samples\n"
|
||||
f"- Consider the pipeline context and how this operation contributes to the overall goal\n"
|
||||
f"- Maintain the same general task but with much clearer execution details\n\n"
|
||||
f"Example transformation:\n"
|
||||
f"{self.example}\n\n"
|
||||
f"Analyze samples strategically - focus on diversity and understanding patterns rather than reading every document.\n"
|
||||
f"When you have enough information to create a substantially improved prompt, output your result.\n\n"
|
||||
f"Remember: Your goal is to make the prompt so clear and specific that it produces more consistent, higher-quality results."
|
||||
)
|
||||
|
||||
def llm_instantiate(
|
||||
self,
|
||||
target_ops_configs: List[Dict],
|
||||
input_file_path: str,
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
pipeline_code: Dict = None,
|
||||
):
|
||||
"""
|
||||
Use agentic approach to analyze sample data and generate improved prompt.
|
||||
"""
|
||||
# Load sample input data
|
||||
try:
|
||||
with open(input_file_path, "r") as f:
|
||||
input_data = json.load(f)
|
||||
|
||||
if not isinstance(input_data, list) or len(input_data) == 0:
|
||||
raise ValueError(
|
||||
"Input file must contain a non-empty list of sample data"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"Failed to load input data from {input_file_path}: {str(e)}"
|
||||
)
|
||||
|
||||
# Create validation function for input variable preservation
|
||||
original_prompt = target_ops_configs[0].get("prompt", "")
|
||||
|
||||
def validate_input_variables(schema_instance):
|
||||
ClarifyInstructionsInstantiateSchema.validate_input_variables_preserved(
|
||||
schema_instance.clarified_prompt, original_prompt
|
||||
)
|
||||
|
||||
# Set up agentic runner with validation
|
||||
runner = AgenticDirectiveRunner(
|
||||
input_data=input_data,
|
||||
agent_llm=agent_llm,
|
||||
validation_func=validate_input_variables,
|
||||
)
|
||||
|
||||
# Create system prompt for the agentic runner
|
||||
system_prompt = (
|
||||
"You are an expert prompt engineer who analyzes data to create better, more specific prompts. "
|
||||
"Your goal is to examine input samples to understand patterns, identify ambiguities, and create "
|
||||
"significantly improved prompts that are more specific and actionable than the original. "
|
||||
"You also consider the broader pipeline context to ensure the improved prompt serves the overall data processing goal."
|
||||
)
|
||||
|
||||
# Create initial user message
|
||||
initial_message = self.to_string_for_instantiate(
|
||||
target_ops_configs, pipeline_code
|
||||
)
|
||||
|
||||
# Run the agentic loop (validation is handled internally)
|
||||
try:
|
||||
schema, updated_message_history, call_cost = runner.run_agentic_loop(
|
||||
system_prompt=system_prompt,
|
||||
initial_user_message=initial_message,
|
||||
response_schema=ClarifyInstructionsInstantiateSchema,
|
||||
)
|
||||
|
||||
# Update message history
|
||||
message_history.extend(updated_message_history)
|
||||
|
||||
return schema, message_history, call_cost
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"Failed to instantiate clarify_instructions directive: {str(e)}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
global_default_model: str,
|
||||
ops_list: List[Dict],
|
||||
target_ops: List[str],
|
||||
rewrite: ClarifyInstructionsInstantiateSchema,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Apply the directive by replacing the target operation's prompt with the clarified version.
|
||||
"""
|
||||
new_ops_list = deepcopy(ops_list)
|
||||
|
||||
# Find and update the target operation
|
||||
for i, op in enumerate(new_ops_list):
|
||||
if op["name"] in target_ops:
|
||||
# Update the prompt with the clarified version
|
||||
new_ops_list[i]["prompt"] = rewrite.clarified_prompt
|
||||
break
|
||||
|
||||
return new_ops_list
|
||||
|
||||
def instantiate(
|
||||
self,
|
||||
operators: List[Dict],
|
||||
target_ops: List[str],
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
global_default_model: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Main method that orchestrates directive instantiation:
|
||||
1. Use agentic approach to analyze data and generate improved prompt
|
||||
2. Apply the transformation using that improved prompt
|
||||
"""
|
||||
assert (
|
||||
len(target_ops) == 1
|
||||
), "ClarifyInstructions directive requires exactly one target operation"
|
||||
input_file_path = kwargs.get("input_file_path", None)
|
||||
pipeline_code = kwargs.get("pipeline_code", None)
|
||||
|
||||
if not input_file_path:
|
||||
raise ValueError(
|
||||
"input_file_path is required for ClarifyInstructions directive"
|
||||
)
|
||||
|
||||
# Get configuration for target operation
|
||||
target_ops_configs = [op for op in operators if op["name"] in target_ops]
|
||||
|
||||
if not target_ops_configs:
|
||||
raise ValueError(f"Target operation {target_ops[0]} not found in operators")
|
||||
|
||||
# Step 1: Agent analyzes data and generates improved prompt
|
||||
rewrite, message_history, call_cost = self.llm_instantiate(
|
||||
target_ops_configs,
|
||||
input_file_path,
|
||||
agent_llm,
|
||||
message_history,
|
||||
pipeline_code,
|
||||
)
|
||||
|
||||
# Step 2: Apply transformation using the improved prompt
|
||||
return (
|
||||
self.apply(global_default_model, operators, target_ops, rewrite),
|
||||
message_history,
|
||||
call_cost,
|
||||
)
|
||||
|
|
@@ -1,350 +0,0 @@
|
|||
import json
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
from litellm import completion
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from docetl.reasoning_optimizer.instantiate_schemas import (
|
||||
DeterministicDocCompressionInstantiateSchema,
|
||||
)
|
||||
|
||||
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
|
||||
|
||||
# TODO: For the agent instantiating the rewrite directive,
|
||||
# we might want to allow it to look at some example documents /
|
||||
# enough documents until it feels like it has a good understanding of the data.
|
||||
|
||||
|
||||
class DeterministicDocCompressionDirective(Directive):
|
||||
name: str = Field(
|
||||
default="deterministic_doc_compression", description="The name of the directive"
|
||||
)
|
||||
formal_description: str = Field(default="Op => Code Map -> Op")
|
||||
nl_description: str = Field(
|
||||
default="Reduces LLM processing costs by using deterministic logic (regex, patterns) to compress documents before expensive downstream operations, removing irrelevant content that could distract the LLM"
|
||||
)
|
||||
when_to_use: str = Field(
|
||||
default="When documents contain identifiable patterns or keywords and you want to reduce token costs for downstream LLM operations while improving accuracy by eliminating distracting irrelevant content"
|
||||
)
|
||||
|
||||
instantiate_schema_type: Type[BaseModel] = Field(
|
||||
default=DeterministicDocCompressionInstantiateSchema
|
||||
)
|
||||
|
||||
    example: str = Field(
        default="""
Target Operations:
- name: analyze_regulatory_compliance
  type: map
  prompt: |
    Analyze regulatory compliance issues in this legal document: {{ input.legal_document }}
    Focus on identifying violations, required actions, and compliance deadlines.
  output:
    schema:
      violations: "list[str]"
      required_actions: "list[str]"
      deadlines: "list[str]"

Example InstantiateSchema (what the agent should output):
{
    "name": "extract_compliance_sections",
    "code": '''
def transform(input_doc):
    import re

    legal_document = input_doc.get('legal_document', '')

    # Patterns to identify compliance-related content
    compliance_patterns = [
        r'(?i)(violat(?:e|ion)|breach|non-complian(?:ce|t))',
        r'(?i)(deadline|due date|expir(?:e|ation)|within.*days?)',
        r'(?i)(shall|must|required|mandatory|obligation)',
        r'(?i)(section|article|clause)\\s+\\d+.*(complian(?:ce|t)|regulat(?:e|ory))'
    ]

    relevant_spans = []

    # Find all matches and extract context around them
    for pattern in compliance_patterns:
        for match in re.finditer(pattern, legal_document):
            start_pos = match.start()
            end_pos = match.end()

            # Extract 300 chars before and 800 chars after the match
            context_start = max(0, start_pos - 300)
            context_end = min(len(legal_document), end_pos + 800)

            # Extract the context around the match
            context = legal_document[context_start:context_end]
            relevant_spans.append((context_start, context_end, context))

    # Merge overlapping spans and remove duplicates
    if relevant_spans:
        # Sort by start position
        relevant_spans.sort(key=lambda x: x[0])
        merged_spans = [relevant_spans[0]]

        for current_start, current_end, current_text in relevant_spans[1:]:
            last_start, last_end, last_text = merged_spans[-1]

            if current_start <= last_end + 100:  # Merge if close enough
                # Extend the last span
                new_end = max(last_end, current_end)
                new_text = legal_document[last_start:new_end]
                merged_spans[-1] = (last_start, new_end, new_text)
            else:
                merged_spans.append((current_start, current_end, current_text))

        # Extract just the text portions
        compressed_text = '\\n\\n--- SECTION BREAK ---\\n\\n'.join([span[2] for span in merged_spans])
    else:
        compressed_text = legal_document  # Fallback if no matches

    return {
        'legal_document': compressed_text
    }
'''
}
"""
    )

test_cases: List[DirectiveTestCase] = Field(
|
||||
default_factory=lambda: [
|
||||
DirectiveTestCase(
|
||||
name="detailed_merger_agreement_analysis",
|
||||
description="Should compress merger agreement for comprehensive legal analysis",
|
||||
input_config={
|
||||
"name": "analyze_merger_agreement_terms",
|
||||
"type": "map",
|
||||
"prompt": """Perform a comprehensive legal analysis of this merger agreement: {{ input.merger_agreement }}
|
||||
|
||||
Analyze and extract the following:
|
||||
1. Purchase price structure and payment mechanisms (cash, stock, earnouts, escrow arrangements)
|
||||
2. Material adverse change (MAC) definitions and carve-outs that could affect deal completion
|
||||
3. Representations and warranties with survival periods and liability caps
|
||||
4. Closing conditions precedent, including regulatory approvals and third-party consents
|
||||
5. Termination rights and associated breakup fees or reverse breakup fees
|
||||
6. Indemnification provisions including baskets, caps, and survival periods
|
||||
7. Employee retention arrangements and change-in-control provisions
|
||||
8. Integration planning requirements and operational restrictions during pendency
|
||||
9. Dispute resolution mechanisms and governing law provisions
|
||||
10. Post-closing adjustments and working capital mechanisms
|
||||
|
||||
For each area, provide specific clause references, dollar amounts where applicable,
|
||||
time periods, and risk assessment (High/Medium/Low) with justification.""",
|
||||
"output": {
|
||||
"schema": {
|
||||
"purchase_price_analysis": "string",
|
||||
"mac_provisions": "list[str]",
|
||||
"representations_warranties": "list[str]",
|
||||
"closing_conditions": "list[str]",
|
||||
"termination_rights": "string",
|
||||
"indemnification_terms": "string",
|
||||
"employee_provisions": "list[str]",
|
||||
"integration_restrictions": "list[str]",
|
||||
"dispute_resolution": "string",
|
||||
"post_closing_adjustments": "string",
|
||||
"risk_assessment": "string",
|
||||
}
|
||||
},
|
||||
},
|
||||
target_ops=["analyze_merger_agreement_terms"],
|
||||
expected_behavior="Should add Code Map operation that extracts merger agreement sections using regex patterns for legal terms, financial provisions, and risk-related clauses. The return dictionary of the transform function should be {'merger_agreement': ....} only.",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="multi_document_analysis_compression",
|
||||
description="Should compress document for multiple analysis operations",
|
||||
input_config=[
|
||||
{
|
||||
"name": "extract_financial_metrics",
|
||||
"type": "map",
|
||||
"prompt": "Extract revenue, profit, and expense figures from: {{ input.earnings_report }}",
|
||||
"output": {
|
||||
"schema": {
|
||||
"revenue": "string",
|
||||
"profit": "string",
|
||||
"expenses": "list[str]",
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "assess_financial_risks",
|
||||
"type": "map",
|
||||
"prompt": "Identify financial risks and warning signs in: {{ input.earnings_report }}",
|
||||
"output": {
|
||||
"schema": {
|
||||
"risks": "list[str]",
|
||||
"warning_signs": "list[str]",
|
||||
}
|
||||
},
|
||||
},
|
||||
],
|
||||
target_ops=["extract_financial_metrics", "assess_financial_risks"],
|
||||
expected_behavior="Should add Code Map operation that extracts financial content needed for both operations. The return dictionary of the transform function should be {'earnings_report': ....} only.",
|
||||
should_pass=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
    def __eq__(self, other):
        return isinstance(other, DeterministicDocCompressionDirective)

    def __hash__(self):
        return hash("DeterministicDocCompressionDirective")

def to_string_for_instantiate(self, target_ops_configs: List[Dict]) -> str:
|
||||
"""
|
||||
Generate a prompt that asks the agent to output the instantiate schema.
|
||||
This prompt explains to the LLM what configuration it needs to generate.
|
||||
"""
|
||||
ops_str = "\n".join(
|
||||
[
|
||||
f"Operation {i+1}:\n{str(op)}\n"
|
||||
for i, op in enumerate(target_ops_configs)
|
||||
]
|
||||
)
|
||||
|
||||
return (
|
||||
f"You are an expert at document processing and Python programming.\n\n"
|
||||
f"Target Operations:\n"
|
||||
f"{ops_str}\n\n"
|
||||
f"Directive: {self.name}\n"
|
||||
f"Your task is to instantiate this directive by specifying a Code Map operation "
|
||||
f"that specifies how to compress the input document using deterministic logic.\n\n"
|
||||
f"The directive will insert a Code Map operation that:\n"
|
||||
f"1. Takes document field(s) from the input\n"
|
||||
f"2. Uses deterministic logic (regex, keyword matching, pattern extraction) to compress them\n"
|
||||
f"3. Returns a dictionary with the EXACT SAME document field key and compressed content\n"
|
||||
f"4. Reduces token usage and improves focus for the downstream operations\n\n"
|
||||
f"The agent must output the configuration specifying:\n"
|
||||
f"- name: A descriptive name for the Code Map operation\n"
|
||||
f"- code: Python code defining a 'transform' function that:\n"
|
||||
f" * Takes input_doc as parameter\n"
|
||||
f" * Imports 're' and other standard libraries WITHIN the function itself\n"
|
||||
f" * Only uses standard Python libraries (re, string, json, etc.) - no external packages\n"
|
||||
f" * Uses deterministic logic to extract relevant content patterns\n"
|
||||
f" * For each pattern match, extracts context around it (e.g., ±500 chars, or -300 to +800 chars)\n"
|
||||
f" * Use your judgment to determine appropriate character counts that capture enough context\n"
|
||||
f" * Merges overlapping context spans to avoid duplication\n"
|
||||
f" * Returns a dictionary with the EXACT document key and compressed value: {{document_key: compressed_content}}\n\n"
|
||||
f"CRITICAL: The returned dictionary must use the exact same document field names as the original, "
|
||||
f"not modified versions like 'document_key_compressed'. The downstream operations expect the exact same field names.\n\n"
|
||||
f"IMPORTANT: Focus on extracting the minimal content necessary for ALL target operations "
|
||||
f"using deterministic pattern matching. Analyze each operation's prompt to identify what types "
|
||||
f"of content patterns to look for.\n\n"
|
||||
f"Example:\n"
|
||||
f"{self.example}\n\n"
|
||||
f"Please output only the DeterministicDocCompressionInstantiateSchema as JSON "
|
||||
f"that specifies how to apply this directive to the target operations."
|
||||
)
|
||||
|
||||
    def llm_instantiate(
        self,
        target_ops_configs: List[Dict],
        agent_llm: str,
        message_history: list = [],
    ):
        """
        Call the LLM to generate the instantiate schema.
        The LLM will output structured data matching DeterministicDocCompressionInstantiateSchema.
        """

        message_history.extend(
            [
                {
                    "role": "user",
                    "content": self.to_string_for_instantiate(target_ops_configs),
                },
            ]
        )

        last_error = None
        for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
            resp = completion(
                model=agent_llm,
                messages=message_history,
                response_format=DeterministicDocCompressionInstantiateSchema,
            )
            call_cost = resp._hidden_params.get("response_cost", 0)
            try:
                parsed_res = json.loads(resp.choices[0].message.content)
                schema = DeterministicDocCompressionInstantiateSchema(**parsed_res)

                # Validate against target operations
                schema.validate_against_target_ops(target_ops_configs)

                message_history.append(
                    {"role": "assistant", "content": resp.choices[0].message.content}
                )
                return schema, message_history, call_cost
            except Exception as err:
                last_error = err
                error_message = f"Validation error: {err}\nPlease try again."
                message_history.append({"role": "user", "content": error_message})

        raise Exception(
            f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
        )

    def apply(
        self,
        global_default_model: str,
        ops_list: List[Dict],
        target_ops: List[str],
        rewrite: DeterministicDocCompressionInstantiateSchema,
    ) -> List[Dict]:
        """
        Apply the directive using the instantiate schema configuration.
        Inserts a Code Map operation before the first target operation.
        """
        new_ops_list = deepcopy(ops_list)

        # Find the position of the first target operation
        first_target_pos = min(
            [i for i, op in enumerate(ops_list) if op["name"] in target_ops]
        )

        # Create the Code Map operation
        code_map_op = {
            "name": rewrite.name,
            "type": "code_map",
            "code": rewrite.code,
        }

        # Insert the Code Map operation before the first target operation
        new_ops_list.insert(first_target_pos, code_map_op)

        return new_ops_list

    def instantiate(
        self,
        operators: List[Dict],
        target_ops: List[str],
        agent_llm: str,
        message_history: list = [],
        global_default_model: str = None,
        **kwargs,
    ):
        """
        Main method that orchestrates directive instantiation:
        1. Get agent to generate instantiate schema for all target operations
        2. Apply the transformation using that schema
        """
        assert len(target_ops) >= 1, "This directive requires at least one target op"

        # Get configurations for all target operations
        target_ops_configs = [op for op in operators if op["name"] in target_ops]

        # Step 1: Agent generates the instantiate schema considering all target ops
        rewrite, message_history, call_cost = self.llm_instantiate(
            target_ops_configs, agent_llm, message_history
        )

        # Step 2: Apply transformation using the schema
        return (
            self.apply(global_default_model, operators, target_ops, rewrite),
            message_history,
            call_cost,
        )

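To make the effect of `apply` concrete, here is a small before/after sketch of the operator list. It is illustrative only: the operator names, prompts, and the generated `transform` stub are assumed values following the pattern shown in the class's example field, not output from a real run.

```python
# Before: a single expensive map op over long legal documents.
ops_before = [
    {"name": "analyze_regulatory_compliance", "type": "map",
     "prompt": "... {{ input.legal_document }} ...",
     "output": {"schema": {"violations": "list[str]"}}},
]

# After apply(): a code_map op is inserted before the first target op,
# compressing `legal_document` deterministically and returning the SAME key.
ops_after = [
    {"name": "extract_compliance_sections", "type": "code_map",
     "code": "def transform(input_doc):\n    ...  # regex-based compression, same key returned"},
    {"name": "analyze_regulatory_compliance", "type": "map",
     "prompt": "... {{ input.legal_document }} ...",
     "output": {"schema": {"violations": "list[str]"}}},
]
```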
@ -1,466 +0,0 @@
|
|||
import json
from copy import deepcopy
from typing import Dict, List, Type

from litellm import completion
from pydantic import BaseModel, Field

from docetl.reasoning_optimizer.instantiate_schemas import (
    DocumentChunkingInstantiateSchema,
)

from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase


class DocumentChunkingDirective(Directive):
    name: str = Field(default="doc_chunking", description="The name of the directive")
    formal_description: str = Field(
        default="Map => Split -> Gather -> [Sample] -> Map -> Reduce"
    )
    nl_description: str = Field(
        default="Transforms a single Map operation into a chunking pipeline: splits long documents into chunks, gathers context around each chunk, optionally samples a subset of chunks for efficiency, processes chunks with a new Map operation, then reduces the results. By default, sampling is applied unless the task requires processing ALL chunks. This directive can only be applied to a top-level Map operation, not to a sub-map within a pipeline that already contains a split, gather, or reduce sequence."
    )
    when_to_use: str = Field(
        default="Use when you need to process long documents to extract information, and the document is too long for a single Map operation. The agent will automatically decide whether to sample chunks (for tasks like categorization, theme extraction) or process all chunks (for comprehensive extraction of all instances). Do not apply if the target operation is already part of a split -> gather -> map -> reduce pipeline. Use different gather configs: 'previous.head' for documents with key metadata/definitions at the start, 'previous.tail' for maintaining references, and 'next.head' only for tables/clauses spanning chunks."
    )

    instantiate_schema_type: Type[BaseModel] = Field(
        default=DocumentChunkingInstantiateSchema
    )

example: str = Field(
|
||||
default="""
|
||||
Original Op (MapOpConfig):
|
||||
- name: extract_contract_terms
|
||||
type: map
|
||||
prompt: |
|
||||
Extract all payment terms, deadlines, and penalty clauses from the contract:
|
||||
{{ input.contract_text }}
|
||||
Return a comprehensive list of all terms found.
|
||||
output:
|
||||
schema:
|
||||
contract_terms: list[str]
|
||||
|
||||
Example InstantiateSchema Options (what the agent should output):
|
||||
|
||||
# Basic context - good for most cases:
|
||||
{
|
||||
"chunk_size": 10000,
|
||||
"split_key": "contract_text",
|
||||
"sub_prompt": "You are analyzing a chunk of a larger document. Extract all payment terms, deadlines, and penalty clauses from this contract chunk: {{ input.contract_text_chunk_rendered }}. Return a comprehensive list of all terms found.",
|
||||
"reduce_prompt": "Combine results from multiple document chunks: Extract all payment terms, deadlines, and penalty clauses by combining the results from each chunk: {% for input in inputs %}{{ input.contract_terms | join(', ') }}{% if not loop.last %}, {% endif %}{% endfor %}. Remove duplicates and return a comprehensive list of all terms found.",
|
||||
"gather_config": {
|
||||
"previous": {
|
||||
"tail": {
|
||||
"count": 0.5
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# Rich context - for complex documents needing document-level metadata:
|
||||
{
|
||||
"gather_config": {
|
||||
"previous": {
|
||||
"head": {
|
||||
"count": 1,
|
||||
"content_key": "contract_text_chunk"
|
||||
},
|
||||
"tail": {
|
||||
"count": 2,
|
||||
"content_key": "contract_text_chunk"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Forward context - for tables or clauses spanning chunks:
|
||||
{
|
||||
"gather_config": {
|
||||
"previous": {
|
||||
"tail": {
|
||||
"count": 1,
|
||||
"content_key": "contract_text_chunk"
|
||||
}
|
||||
},
|
||||
"next": {
|
||||
"head": {
|
||||
"count": 1,
|
||||
"content_key": "contract_text_chunk"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
""",
|
||||
)
|
||||
|
||||
test_cases: List[DirectiveTestCase] = Field(
|
||||
default_factory=lambda: [
|
||||
DirectiveTestCase(
|
||||
name="comprehensive_legal_analysis",
|
||||
description="Should transform complex legal document analysis into chunking pipeline",
|
||||
input_config={
|
||||
"name": "analyze_legal_document",
|
||||
"type": "map",
|
||||
"prompt": "From this legal document, extract all liability clauses with risk ratings (1-10), identify all parties and their obligations, find all monetary amounts with currencies, extract all dates and deadlines with legal consequences, and list all governing laws or jurisdictions mentioned. For each liability clause, assess the risk level considering industry standards and provide specific reasoning. Group findings by document section if clearly indicated. Return comprehensive analysis ensuring no critical legal elements are missed: {{ input.legal_document }}",
|
||||
"output": {
|
||||
"schema": {
|
||||
"liability_analysis": "list[str]",
|
||||
"parties_obligations": "list[str]",
|
||||
"financial_terms": "list[str]",
|
||||
"critical_dates": "list[str]",
|
||||
"governing_laws": "list[str]",
|
||||
"risk_assessment": "str",
|
||||
}
|
||||
},
|
||||
},
|
||||
target_ops=["analyze_legal_document"],
|
||||
expected_behavior="Should create chunking pipeline where sub_prompt covers all extraction tasks (liability, parties, financial, dates, laws) with same risk assessment criteria, and reduce_prompt aggregates all findings maintaining the complete analytical framework and output schema from original prompt",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="clinical_trial_comprehensive_extraction",
|
||||
description="Should transform detailed clinical research analysis into chunking pipeline",
|
||||
input_config={
|
||||
"name": "extract_clinical_data",
|
||||
"type": "map",
|
||||
"prompt": "Analyze this clinical trial document and extract: primary and secondary endpoints with measurement criteria, all adverse events categorized by severity (mild/moderate/severe), patient demographics including inclusion/exclusion criteria, statistical significance results with p-values and confidence intervals, drug dosages and administration protocols, and study methodology details. For each adverse event, determine if it's treatment-related based on temporal relationship and biological plausibility. Calculate overall safety profile score (1-10) considering frequency and severity of events. Ensure all regulatory compliance elements are captured: {{ input.trial_document }}",
|
||||
"output": {
|
||||
"schema": {
|
||||
"endpoints": "list[str]",
|
||||
"adverse_events": "list[str]",
|
||||
"demographics": "str",
|
||||
"statistical_results": "list[str]",
|
||||
"protocols": "list[str]",
|
||||
"safety_assessment": "str",
|
||||
"compliance_status": "str",
|
||||
}
|
||||
},
|
||||
},
|
||||
target_ops=["extract_clinical_data"],
|
||||
expected_behavior="Should create chunking pipeline where sub_prompt preserves all clinical analysis requirements (endpoints, adverse events, demographics, statistics, protocols) with same assessment criteria, and reduce_prompt combines results maintaining complete clinical framework and safety scoring methodology from original prompt",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="financial_comprehensive_analysis",
|
||||
description="Should transform complex financial document analysis into chunking pipeline",
|
||||
input_config={
|
||||
"name": "analyze_financial_report",
|
||||
"type": "map",
|
||||
"prompt": "From this annual financial report, extract all revenue streams with growth rates and market segments, identify all risk factors with impact assessments (low/medium/high), find all forward-looking statements and their associated uncertainties, extract key financial ratios and calculate trend analysis over mentioned periods, identify all subsidiaries with their contribution to consolidated results, and analyze competitive positioning statements. For each risk factor, assess potential financial impact in dollar ranges and likelihood percentages. Ensure all material information affecting investor decisions is captured and categorized by urgency level: {{ input.financial_report }}",
|
||||
"output": {
|
||||
"schema": {
|
||||
"revenue_analysis": "list[str]",
|
||||
"risk_factors": "list[str]",
|
||||
"forward_statements": "list[str]",
|
||||
"financial_ratios": "str",
|
||||
"subsidiaries": "list[str]",
|
||||
"competitive_analysis": "str",
|
||||
"material_disclosures": "list[str]",
|
||||
}
|
||||
},
|
||||
},
|
||||
target_ops=["analyze_financial_report"],
|
||||
expected_behavior="Should create chunking pipeline where sub_prompt maintains all financial analysis requirements (revenue, risks, statements, ratios, subsidiaries, competitive analysis) with same impact assessment methodology, and reduce_prompt aggregates preserving complete financial analytical framework and materiality assessments from original prompt",
|
||||
should_pass=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
    def __eq__(self, other):
        return isinstance(other, DocumentChunkingDirective)

    def __hash__(self):
        return hash("DocumentChunkingDirective")

def to_string_for_instantiate(self, original_op: Dict) -> str:
|
||||
"""
|
||||
Generate a prompt for an agent to instantiate this directive.
|
||||
|
||||
Args:
|
||||
original_op (Dict): The original operation.
|
||||
|
||||
Returns:
|
||||
str: The agent prompt for instantiating the directive.
|
||||
"""
|
||||
return (
|
||||
f"You are an expert at transforming document processing operations into chunking pipelines.\n\n"
|
||||
f"Original Operation:\n"
|
||||
f"{str(original_op)}\n\n"
|
||||
f"Directive: {self.name}\n"
|
||||
f"Your task is to instantiate this directive by creating a configuration that transforms the original Map operation "
|
||||
f"into a Split -> Gather -> Map -> Reduce pipeline for processing long documents in chunks.\n\n"
|
||||
f"Key requirements:\n"
|
||||
f"1. chunk_size: Choose an appropriate token count (typically 10000-15000) based on the complexity of the task\n"
|
||||
f"2. split_key: Identify the document field to split from the original operation's prompt. Make sure to use the same field as the original operation's prompt.\n"
|
||||
f"3. sub_prompt: Take the original prompt exactly but:\n"
|
||||
f" - Add instruction at start: 'You are analyzing a chunk of a larger document.'\n"
|
||||
f" - Replace '{{{{ input.<split_key> }}}}' with '{{{{ input.<split_key>_chunk_rendered }}}}'\n"
|
||||
f" - Keep everything else identical (same task, same output schema)\n"
|
||||
f"4. reduce_prompt: Take original task instructions but adapt for aggregation:\n"
|
||||
f" - Start with: 'Combine results from multiple document chunks:'\n"
|
||||
f" - Include same task context/requirements as original prompt\n"
|
||||
f" - Use '{{% for input in inputs %}}' to iterate over chunk results\n"
|
||||
f" - Combine/deduplicate to match original output schema exactly\n"
|
||||
f"5. sampling_config: IMPORTANT - Include sampling by default UNLESS the task requires ALL chunks:\n"
|
||||
f" - ALWAYS use sampling for: categorization, theme identification, sentiment analysis, document type classification\n"
|
||||
f" - NEVER use sampling for: comprehensive extraction ('extract ALL instances'), complete analysis requiring every chunk\n"
|
||||
f" - Default sampling: method='uniform' with stratify_key, samples=5-10 chunks\n"
|
||||
f" - For simple tasks (categorization): samples=1-3 chunks\n"
|
||||
f" - For complex analysis: samples=5-15 chunks\n"
|
||||
f" - For stratified sampling: specify method='uniform' and stratify_key (note: split document ID is automatically included)\n"
|
||||
f" - Set sampling_config=null only if you need to process every single chunk\n"
|
||||
f"6. gather_config: Configure context from surrounding chunks. Structure:\n"
|
||||
f" gather_config:\n"
|
||||
f" previous: # chunks before current chunk\n"
|
||||
f" head: # first chunk(s) in document\n"
|
||||
f" count: 1\n"
|
||||
f" content_key: full_content_chunk # optional, defaults to main key chunk\n"
|
||||
f" tail: # chunk(s) immediately before current\n"
|
||||
f" count: 2\n"
|
||||
f" content_key: full_content_chunk # optional\n"
|
||||
f" next: # chunks after current chunk\n"
|
||||
f" head: # chunk(s) immediately after current\n"
|
||||
f" count: 1\n"
|
||||
f" content_key: full_content_chunk # optional\n"
|
||||
f" Usage guidelines:\n"
|
||||
f" - Use 'previous.head' when document has important metadata/definitions at start\n"
|
||||
f" - Use 'previous.tail' to maintain references and immediate context\n"
|
||||
f" - Use 'next.head' only for tables/clauses that span chunk boundaries\n"
|
||||
f" - Minimize context: only include what's needed for accurate operation\n"
|
||||
f" - Count can be float (e.g., 0.5 for half chunk, 1.5 for chunk and a half)\n"
|
||||
f" - More context increases token usage and cost - be judicious\n"
|
||||
f" - Default to 0.5 previous tail if unsure about context needs\n"
|
||||
f"The sub_prompt should focus on the main chunk content and extract the same type of information as the original.\n"
|
||||
f"The reduce_prompt must produce the same output schema as the original operation.\n\n"
|
||||
f"Example:\n"
|
||||
f"{self.example}\n\n"
|
||||
f"Please output only the DocumentChunkingInstantiateSchema object as JSON."
|
||||
)
|
||||
|
||||
    def llm_instantiate(
        self,
        operators,
        input_file_path,
        original_op: Dict,
        agent_llm: str,
        message_history: list = [],
    ) -> tuple:
        """
        Use LLM to instantiate this directive by creating chunking configuration.

        Args:
            original_op (Dict): The original operation.
            agent_llm (str): The LLM model to use.
            message_history (List, optional): Conversation history for context.

        Returns:
            tuple: (DocumentChunkingInstantiateSchema, message_history, call_cost).
        """

        message_history.extend(
            [
                {
                    "role": "user",
                    "content": self.to_string_for_instantiate(original_op),
                },
            ]
        )

        last_error = None
        for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
            call_cost = 0.0
            resp = completion(
                model=agent_llm,
                messages=message_history,
                response_format=DocumentChunkingInstantiateSchema,
            )
            call_cost = resp._hidden_params.get("response_cost", 0)
            try:
                parsed_res = json.loads(resp.choices[0].message.content)
                schema = DocumentChunkingInstantiateSchema(**parsed_res)
                schema.validate_stratify_key_in_pipeline(operators)
                schema.validate_split_key_exists_in_input(input_file_path)
                message_history.append(
                    {"role": "assistant", "content": resp.choices[0].message.content}
                )
                return schema, message_history, call_cost
            except Exception as err:
                last_error = err
                error_message = f"Validation error: {err}\nPlease try again."
                message_history.append({"role": "user", "content": error_message})

        raise Exception(
            f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
        )

def apply(
|
||||
self,
|
||||
global_default_model: str,
|
||||
ops_list: List[Dict],
|
||||
target_op: str,
|
||||
rewrite: DocumentChunkingInstantiateSchema,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Apply the directive to the pipeline config by replacing the target operation
|
||||
with a split -> gather -> map -> reduce sequence.
|
||||
"""
|
||||
|
||||
# Create a copy of the pipeline config
|
||||
new_ops_list = deepcopy(ops_list)
|
||||
target_op_config = [op for op in new_ops_list if op["name"] == target_op][0]
|
||||
|
||||
# Find position of the target op to replace
|
||||
pos_to_replace = [
|
||||
i for i, op in enumerate(ops_list) if op["name"] == target_op
|
||||
][0]
|
||||
|
||||
original_op = ops_list[pos_to_replace]
|
||||
original_model = target_op_config.get("model", global_default_model)
|
||||
|
||||
# Create operation names based on the original operation name
|
||||
split_name = f"split_{target_op}"
|
||||
gather_name = f"gather_{target_op}"
|
||||
sample_name = f"sample_{target_op}_chunks"
|
||||
map_name = f"map_{target_op}_chunks"
|
||||
reduce_name = f"reduce_{target_op}"
|
||||
|
||||
# Create the split operation
|
||||
split_op = {
|
||||
"name": split_name,
|
||||
"type": "split",
|
||||
"split_key": rewrite.split_key,
|
||||
"method": "token_count",
|
||||
"method_kwargs": {
|
||||
"num_tokens": rewrite.chunk_size,
|
||||
"model": original_model,
|
||||
},
|
||||
}
|
||||
|
||||
# Create the gather operation with agent-configured context
|
||||
# Convert Pydantic model to dict, excluding None values
|
||||
gather_config_dict = (
|
||||
rewrite.gather_config.model_dump(exclude_none=True)
|
||||
if rewrite.gather_config
|
||||
else {}
|
||||
)
|
||||
# Use default config if the gather_config is empty (all fields were None)
|
||||
if not gather_config_dict:
|
||||
gather_config_dict = {"previous": {"tail": {"count": 1}}}
|
||||
|
||||
gather_op = {
|
||||
"name": gather_name,
|
||||
"type": "gather",
|
||||
"content_key": f"{rewrite.split_key}_chunk",
|
||||
"doc_id_key": f"{split_name}_id",
|
||||
"order_key": f"{split_name}_chunk_num",
|
||||
"peripheral_chunks": gather_config_dict,
|
||||
}
|
||||
|
||||
# Create the map operation for processing chunks
|
||||
map_op = {
|
||||
"name": map_name,
|
||||
"type": "map",
|
||||
"prompt": rewrite.sub_prompt,
|
||||
"model": original_model,
|
||||
"litellm_completion_kwargs": {"temperature": 0},
|
||||
"output": deepcopy(original_op["output"]), # Same output schema as original
|
||||
}
|
||||
|
||||
# Create the reduce operation
|
||||
reduce_op = {
|
||||
"name": reduce_name,
|
||||
"type": "reduce",
|
||||
"reduce_key": f"{split_name}_id",
|
||||
"prompt": rewrite.reduce_prompt,
|
||||
"model": original_model,
|
||||
"litellm_completion_kwargs": {"temperature": 0},
|
||||
"output": deepcopy(original_op["output"]), # Same output schema as original
|
||||
"associative": False, # Order matters for chunks,
|
||||
"pass_through": True,
|
||||
}
|
||||
|
||||
# Construct operation sequence
|
||||
ops_sequence = [split_op, gather_op]
|
||||
|
||||
# Add sample operation if sampling config is provided
|
||||
if rewrite.sampling_config:
|
||||
sample_op = {
|
||||
"name": sample_name,
|
||||
"type": "sample",
|
||||
"method": rewrite.sampling_config.method,
|
||||
"samples": rewrite.sampling_config.samples,
|
||||
}
|
||||
|
||||
# Always stratify by split document ID and set samples_per_group
|
||||
stratify_keys = [f"{split_name}_id"]
|
||||
|
||||
# Add agent-specified stratify key if provided
|
||||
if rewrite.sampling_config.stratify_key:
|
||||
stratify_keys.append(rewrite.sampling_config.stratify_key)
|
||||
sample_op["stratify_key"] = stratify_keys
|
||||
|
||||
if rewrite.sampling_config.samples_per_group:
|
||||
sample_op["samples_per_group"] = (
|
||||
rewrite.sampling_config.samples_per_group
|
||||
)
|
||||
|
||||
# Add optional fields if provided
|
||||
if rewrite.sampling_config.random_state is not None:
|
||||
sample_op["random_state"] = rewrite.sampling_config.random_state
|
||||
|
||||
if rewrite.sampling_config.method_kwargs is not None:
|
||||
try:
|
||||
sample_op["method_kwargs"] = json.loads(
|
||||
rewrite.sampling_config.method_kwargs
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid method_kwargs: {e}")
|
||||
|
||||
ops_sequence.append(sample_op)
|
||||
|
||||
# Add map and reduce operations
|
||||
ops_sequence.extend([map_op, reduce_op])
|
||||
|
||||
# Replace the target operation with the new sequence
|
||||
new_ops_list[pos_to_replace : pos_to_replace + 1] = ops_sequence
|
||||
|
||||
return new_ops_list
|
||||
|
||||
    def instantiate(
        self,
        operators: List[Dict],
        target_ops: List[str],
        agent_llm: str,
        message_history: list = [],
        global_default_model: str = None,
        **kwargs,
    ):
        """
        Instantiate the directive for a list of operators.
        """
        # Assert that there is only one target op
        assert (
            len(target_ops) == 1
        ), "There must be exactly one target op to instantiate this document chunking directive"
        target_op_config = [op for op in operators if op["name"] == target_ops[0]][0]

        input_file_path = kwargs.get("input_file_path", None)

        # Validate that the target operation is a map operation
        if target_op_config.get("type") != "map":
            raise ValueError(
                f"Document chunking directive can only be applied to map operations, but target operation '{target_ops[0]}' is of type '{target_op_config.get('type')}'"
            )

        # Instantiate the directive
        rewrite, message_history, call_cost = self.llm_instantiate(
            operators, input_file_path, target_op_config, agent_llm, message_history
        )

        # Apply the rewrite to the operators
        return (
            self.apply(global_default_model, operators, target_ops[0], rewrite),
            message_history,
            call_cost,
        )

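For reference, the sketch below shows the rough shape of the operator sequence that this directive's `apply` substitutes for a single map op named `extract_contract_terms`. The field values are illustrative assumptions derived from the naming pattern in `apply` (split_/gather_/map_/reduce_ prefixes, `<split_key>_chunk` content keys), not a captured output.

```python
# Hypothetical shape of the rewritten pipeline for target op "extract_contract_terms".
ops_sequence = [
    {"name": "split_extract_contract_terms", "type": "split",
     "split_key": "contract_text", "method": "token_count",
     "method_kwargs": {"num_tokens": 10000, "model": "gpt-4o-mini"}},  # model illustrative
    {"name": "gather_extract_contract_terms", "type": "gather",
     "content_key": "contract_text_chunk",
     "doc_id_key": "split_extract_contract_terms_id",
     "order_key": "split_extract_contract_terms_chunk_num",
     "peripheral_chunks": {"previous": {"tail": {"count": 0.5}}}},
    # optional, inserted only when a sampling_config is provided:
    # {"name": "sample_extract_contract_terms_chunks", "type": "sample", ...},
    {"name": "map_extract_contract_terms_chunks", "type": "map",
     "prompt": "<sub_prompt over {{ input.contract_text_chunk_rendered }}>",
     "output": {"schema": {"contract_terms": "list[str]"}}},
    {"name": "reduce_extract_contract_terms", "type": "reduce",
     "reduce_key": "split_extract_contract_terms_id",
     "prompt": "<reduce_prompt over inputs>",
     "associative": False, "pass_through": True,
     "output": {"schema": {"contract_terms": "list[str]"}}},
]
```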
@ -1,619 +0,0 @@
|
|||
import json
from copy import deepcopy
from typing import Dict, List, Type

from litellm import completion
from pydantic import BaseModel, Field

from docetl.reasoning_optimizer.instantiate_schemas import (
    DocumentChunkingTopKInstantiateSchema,
)

from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase


class DocumentChunkingTopKDirective(Directive):
    name: str = Field(
        default="doc_chunking_topk", description="The name of the directive"
    )
    formal_description: str = Field(
        default="Map/Filter => Split -> TopK -> Reduce (-> Code Filter if original was Filter)"
    )
    nl_description: str = Field(
        default="Cost optimization directive for documents where only certain portions are relevant to the task (when at least half the document is irrelevant). Works with both Map and Filter operations. Transforms into a retrieval-augmented pipeline: splits documents into chunks, uses topk to retrieve the most relevant chunks, processes them in a reduce operation. For Filter operations, adds a final code_filter step to return boolean results. Ideal when processing full documents would be wasteful due to irrelevant content."
    )
    when_to_use: str = Field(
        default="Use when only certain portions of documents are relevant to the task and at least half of the document content is irrelevant. Perfect for complex filters (e.g., 'does this review mention competitor products more favorably?') or targeted extraction from documents with localized relevant sections. Works with both Map (extraction) and Filter (boolean decision) operations. The retrieval step (embedding or FTS) finds the relevant chunks, avoiding processing irrelevant content. For filters, the final code_filter converts the reduce output to True/False."
    )

    instantiate_schema_type: Type[BaseModel] = Field(
        default=DocumentChunkingTopKInstantiateSchema
    )

example: str = Field(
|
||||
default="""
|
||||
# Example 1: Complex Filter on Long Customer Reviews
|
||||
Original Op:
|
||||
- name: filter_competitor_mentions
|
||||
type: filter
|
||||
prompt: |
|
||||
Analyze this customer review to determine if it mentions competitor products
|
||||
more positively than our product.
|
||||
|
||||
Our Product: {{ input.our_product }}
|
||||
Review: {{ input.review_text }}
|
||||
Review ID: {{ input.review_id }}
|
||||
|
||||
Return true if the review speaks more favorably about competitor products than ours.
|
||||
Consider: feature comparisons, performance mentions, value assessments, recommendations.
|
||||
output:
|
||||
schema:
|
||||
mentions_competitors_more_positively: bool
|
||||
|
||||
InstantiateSchema (filter with embedding search):
|
||||
{
|
||||
"chunk_size": 10000,
|
||||
"split_key": "review_text",
|
||||
"reduce_prompt": "Analyze this customer review to determine if it mentions competitor products more positively than our product.\\n\\nOur Product: {{ inputs[0].our_product }}\\nReview: the top {{ inputs|length }} most relevant chunks from the document (ordered by relevance):\\n{% for input in inputs|sort(attribute='_topk_filter_competitor_mentions_chunks_rank') %}\\nChunk (Rank {{ input._topk_filter_competitor_mentions_chunks_rank }}, Score {{ input._topk_filter_competitor_mentions_chunks_score }}):\\n{{ input.review_text_chunk }}\\n{% endfor %}\\nReview ID: {{ inputs[0].review_id }}\\n\\nReturn true if the review speaks more favorably about competitor products than ours.\\nConsider: feature comparisons, performance mentions, value assessments, recommendations.",
|
||||
"topk_config": {
|
||||
"method": "embedding",
|
||||
"k": 5,
|
||||
"query": "competitor comparison versus alternative better than superior inferior worse features performance value recommendation prefer instead",
|
||||
"keys": ["review_text_chunk"],
|
||||
"embedding_model": "text-embedding-3-small"
|
||||
}
|
||||
}
|
||||
|
||||
# Example 2: Map Operation - Extract Specific Sections from Long Documents
|
||||
Original Op:
|
||||
- name: extract_methodology_from_paper
|
||||
type: map
|
||||
prompt: |
|
||||
Extract detailed methodology from this research paper:
|
||||
|
||||
Paper: {{ input.paper_content }}
|
||||
Title: {{ input.title }}
|
||||
|
||||
Extract: study design, sample size, data collection methods,
|
||||
statistical analyses, and validation approaches.
|
||||
output:
|
||||
schema:
|
||||
study_design: str
|
||||
sample_size: dict
|
||||
data_collection: list[str]
|
||||
statistical_methods: list[str]
|
||||
validation: str
|
||||
|
||||
InstantiateSchema (map with embedding search):
|
||||
{
|
||||
"chunk_size": 15000,
|
||||
"split_key": "paper_content",
|
||||
"reduce_prompt": "Extract detailed methodology from this research paper:\\n\\nPaper: the top {{ inputs|length }} most relevant chunks from the document (ordered by relevance):\\n{% for input in inputs|sort(attribute='_topk_extract_methodology_from_paper_chunks_rank') %}\\nChunk (Rank {{ input._topk_extract_methodology_from_paper_chunks_rank }}, Score {{ input._topk_extract_methodology_from_paper_chunks_score }}):\\n{{ input.paper_content_chunk }}\\n{% endfor %}\\nTitle: {{ inputs[0].title }}\\n\\nExtract: study design, sample size, data collection methods, statistical analyses, and validation approaches.",
|
||||
"topk_config": {
|
||||
"method": "embedding",
|
||||
"k": 8,
|
||||
"query": "methodology methods study design sample size participants data collection statistical analysis validation procedure protocol experimental",
|
||||
"keys": ["paper_content_chunk"],
|
||||
"embedding_model": "text-embedding-3-small"
|
||||
}
|
||||
}
|
||||
|
||||
# Example 3: Filter with FTS - Check Contract Compliance
|
||||
Original Op:
|
||||
- name: filter_contracts_with_liability_caps
|
||||
type: filter
|
||||
prompt: |
|
||||
Determine if this contract contains liability cap provisions
|
||||
that limit damages to less than $1 million.
|
||||
|
||||
Contract: {{ input.contract_text }}
|
||||
Contract ID: {{ input.contract_id }}
|
||||
Party: {{ input.counterparty }}
|
||||
|
||||
Return true if contract caps liability below $1M, false otherwise.
|
||||
output:
|
||||
schema:
|
||||
has_low_liability_cap: bool
|
||||
|
||||
InstantiateSchema (filter with FTS for legal terms):
|
||||
{
|
||||
"chunk_size": 12000,
|
||||
"split_key": "contract_text",
|
||||
"reduce_prompt": "Determine if this contract contains liability cap provisions that limit damages to less than $1 million.\\n\\nContract: the top {{ inputs|length }} most relevant chunks from the document (ordered by relevance):\\n{% for input in inputs|sort(attribute='_topk_filter_contracts_with_liability_caps_chunks_rank') %}\\nSection (Rank {{ input._topk_filter_contracts_with_liability_caps_chunks_rank }}, Score {{ input._topk_filter_contracts_with_liability_caps_chunks_score }}):\\n{{ input.contract_text_chunk }}\\n{% endfor %}\\nContract ID: {{ inputs[0].contract_id }}\\nParty: {{ inputs[0].counterparty }}\\n\\nReturn true if contract caps liability below $1M, false otherwise.",
|
||||
"topk_config": {
|
||||
"method": "fts",
|
||||
"k": 10,
|
||||
"query": "liability limitation cap maximum damages indirect consequential million dollars aggregate total exposure indemnification",
|
||||
"keys": ["contract_text_chunk"]
|
||||
}
|
||||
}
|
||||
""",
|
||||
)
|
||||
|
||||
test_cases: List[DirectiveTestCase] = Field(
|
||||
default_factory=lambda: [
|
||||
DirectiveTestCase(
|
||||
name="clinical_trial_adverse_events_extraction",
|
||||
description="Should transform clinical trial safety analysis into chunking pipeline with embedding-based topk for thematic content",
|
||||
input_config={
|
||||
"name": "extract_clinical_trial_safety",
|
||||
"type": "map",
|
||||
"prompt": """Analyze this clinical trial protocol and safety report to extract comprehensive safety information:
|
||||
|
||||
Protocol Number: {{ input.protocol_id }}
|
||||
Study Phase: {{ input.study_phase }}
|
||||
Document: {{ input.trial_document }}
|
||||
|
||||
Extract and analyze:
|
||||
1. ALL adverse events (AEs) with:
|
||||
- Event description and medical terminology (MedDRA preferred terms)
|
||||
- Severity grade (1-5 per CTCAE v5.0)
|
||||
- Relationship to study drug (definitely, probably, possibly, unlikely, not related)
|
||||
- Onset timing relative to treatment start
|
||||
- Resolution status and duration
|
||||
- Actions taken (dose reduced, interrupted, discontinued)
|
||||
|
||||
2. Serious adverse events (SAEs) with additional details:
|
||||
- Hospitalization requirements
|
||||
- Life-threatening classification
|
||||
- Death outcomes with causality assessment
|
||||
- Expedited reporting timeline compliance
|
||||
|
||||
3. Laboratory abnormalities:
|
||||
- Clinically significant lab value shifts
|
||||
- Grade 3/4 laboratory toxicities
|
||||
- Liver function test elevations (ALT, AST, bilirubin)
|
||||
- Renal function changes (creatinine, eGFR)
|
||||
- Hematologic abnormalities
|
||||
|
||||
4. Dose-limiting toxicities (DLTs) and maximum tolerated dose (MTD) determination
|
||||
|
||||
5. Safety run-in period results if applicable
|
||||
|
||||
6. Data safety monitoring board (DSMB) recommendations and protocol modifications
|
||||
|
||||
Ensure all safety data is captured with appropriate medical coding and regulatory compliance.""",
|
||||
"output": {
|
||||
"schema": {
|
||||
"adverse_events": "list[dict]",
|
||||
"serious_adverse_events": "list[dict]",
|
||||
"lab_abnormalities": "list[dict]",
|
||||
"dose_limiting_toxicities": "list[dict]",
|
||||
"dsmb_recommendations": "list[str]",
|
||||
"safety_summary": "dict",
|
||||
}
|
||||
},
|
||||
},
|
||||
target_ops=["extract_clinical_trial_safety"],
|
||||
expected_behavior="Should create chunking pipeline with topk using embedding search to find sections discussing adverse events, safety data, laboratory results, and DSMB recommendations. Chunks should be 5-8k tokens with k=10-15 to capture all safety-related sections",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="sec_filing_risk_factors_extraction",
|
||||
description="Should transform SEC filing analysis into chunking pipeline with FTS-based topk for specific regulatory terms",
|
||||
input_config={
|
||||
"name": "extract_sec_risk_disclosures",
|
||||
"type": "map",
|
||||
"prompt": """Extract and analyze all risk factor disclosures from this SEC 10-K filing:
|
||||
|
||||
Company: {{ input.company_ticker }}
|
||||
Filing Period: {{ input.filing_period }}
|
||||
Document: {{ input.form_10k }}
|
||||
Industry: {{ input.industry_classification }}
|
||||
|
||||
Identify and categorize:
|
||||
|
||||
1. BUSINESS AND OPERATIONAL RISKS:
|
||||
- Supply chain vulnerabilities and dependencies
|
||||
- Key customer concentration (customers >10% revenue)
|
||||
- Competition and market share threats
|
||||
- Product obsolescence and innovation risks
|
||||
- Manufacturing and quality control risks
|
||||
- Intellectual property disputes and patent expirations
|
||||
|
||||
2. FINANCIAL AND MARKET RISKS:
|
||||
- Liquidity and cash flow concerns
|
||||
- Debt covenants and refinancing risks
|
||||
- Foreign exchange exposure by currency
|
||||
- Interest rate sensitivity analysis
|
||||
- Credit risk and counterparty exposure
|
||||
- Goodwill and intangible asset impairment risks
|
||||
|
||||
3. REGULATORY AND COMPLIANCE RISKS:
|
||||
- SEC investigation disclosures
|
||||
- FDA/regulatory approval dependencies
|
||||
- Environmental liabilities and remediation costs
|
||||
- Tax disputes and uncertain tax positions
|
||||
- FCPA and anti-corruption compliance
|
||||
- Data privacy (GDPR, CCPA) obligations
|
||||
|
||||
4. CYBERSECURITY AND TECHNOLOGY RISKS:
|
||||
- Data breach history and potential impacts
|
||||
- IT system dependencies and modernization needs
|
||||
- Third-party technology provider risks
|
||||
- Business continuity and disaster recovery
|
||||
|
||||
5. LITIGATION AND LEGAL RISKS:
|
||||
- Material pending litigation with potential damages
|
||||
- Class action lawsuit exposure
|
||||
- Warranty and product liability claims
|
||||
- Employment and labor disputes
|
||||
|
||||
6. ESG AND REPUTATIONAL RISKS:
|
||||
- Climate change physical and transition risks
|
||||
- Social license to operate concerns
|
||||
- Executive succession planning
|
||||
- Related party transaction risks
|
||||
|
||||
For each risk, extract:
|
||||
- Risk description and specific company exposure
|
||||
- Quantitative impact estimates if disclosed
|
||||
- Mitigation strategies mentioned
|
||||
- Changes from prior year disclosure
|
||||
- Forward-looking statements and warnings""",
|
||||
"output": {
|
||||
"schema": {
|
||||
"business_operational_risks": "list[dict]",
|
||||
"financial_market_risks": "list[dict]",
|
||||
"regulatory_compliance_risks": "list[dict]",
|
||||
"cybersecurity_technology_risks": "list[dict]",
|
||||
"litigation_legal_risks": "list[dict]",
|
||||
"esg_reputational_risks": "list[dict]",
|
||||
"risk_factor_changes": "list[dict]",
|
||||
"material_risk_summary": "dict",
|
||||
}
|
||||
},
|
||||
},
|
||||
target_ops=["extract_sec_risk_disclosures"],
|
||||
expected_behavior="Should create chunking pipeline with topk using FTS to search for specific risk-related keywords and sections (Item 1A, risk factors, legal proceedings, etc.). Chunks should be 6-10k tokens with k=15-20 to ensure comprehensive risk coverage",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="insurance_claim_analysis_with_dynamic_query",
|
||||
description="Should transform insurance claim analysis with Jinja template query based on claim type",
|
||||
input_config={
|
||||
"name": "analyze_insurance_claim",
|
||||
"type": "map",
|
||||
"prompt": """Analyze this insurance claim file for coverage determination and fraud indicators:
|
||||
|
||||
Claim Number: {{ input.claim_id }}
|
||||
Policy Number: {{ input.policy_number }}
|
||||
Claim Type: {{ input.claim_type }}
|
||||
Claimed Amount: {{ input.claimed_amount }}
|
||||
Policy Documents: {{ input.policy_documents }}
|
||||
Claim Documents: {{ input.claim_submission }}
|
||||
Prior Claims History: {{ input.claims_history }}
|
||||
|
||||
Perform comprehensive analysis:
|
||||
|
||||
1. COVERAGE DETERMINATION:
|
||||
- Verify incident date falls within policy period
|
||||
- Check specific peril coverage for {{ input.claim_type }} claims
|
||||
- Identify applicable policy limits and sublimits
|
||||
- Calculate deductibles and co-insurance
|
||||
- Review exclusions that may apply
|
||||
- Assess pre-existing condition clauses (if medical)
|
||||
- Verify additional living expense limits (if property)
|
||||
|
||||
2. CLAIM VALIDATION:
|
||||
- Cross-reference damage description with photos/evidence
|
||||
- Verify repair estimates against market rates
|
||||
- Validate medical treatment necessity and coding
|
||||
- Check for duplicate submissions or double-dipping
|
||||
- Verify loss circumstances match policy terms
|
||||
|
||||
3. FRAUD INDICATORS ASSESSMENT:
|
||||
- Pattern analysis against known fraud schemes
|
||||
- Inconsistencies in statements or documentation
|
||||
- Suspicious timing (policy inception, premium issues)
|
||||
- Inflated valuations or treatment costs
|
||||
- Missing or altered documentation
|
||||
- Prior suspicious claims pattern
|
||||
|
||||
4. THIRD-PARTY LIABILITY:
|
||||
- Subrogation opportunities
|
||||
- Other insurance coverage available
|
||||
- Responsible party identification
|
||||
- Coordination of benefits requirements
|
||||
|
||||
5. REGULATORY COMPLIANCE:
|
||||
- State-specific claim handling requirements
|
||||
- Unfair claim settlement practices act compliance
|
||||
- Required notices and timelines
|
||||
- Bad faith claim indicators
|
||||
|
||||
6. SETTLEMENT RECOMMENDATION:
|
||||
- Covered amount calculation
|
||||
- Recommended settlement range
|
||||
- Payment breakdown by category
|
||||
- Reserve recommendations
|
||||
- Special investigation unit (SIU) referral if warranted""",
|
||||
"output": {
|
||||
"schema": {
|
||||
"coverage_analysis": "dict",
|
||||
"claim_validation": "dict",
|
||||
"fraud_indicators": "list[dict]",
|
||||
"third_party_liability": "dict",
|
||||
"regulatory_compliance": "dict",
|
||||
"settlement_recommendation": "dict",
|
||||
"siu_referral": "bool",
|
||||
"reserve_amount": "float",
|
||||
}
|
||||
},
|
||||
},
|
||||
target_ops=["analyze_insurance_claim"],
|
||||
expected_behavior="Should create chunking pipeline with topk using dynamic Jinja query that incorporates claim_type to search for relevant policy sections and prior claims. Query should adapt based on whether it's property, auto, medical, or liability claim. Chunks should be 5-7k tokens with k=12-18",
|
||||
should_pass=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
    def __eq__(self, other):
        return isinstance(other, DocumentChunkingTopKDirective)

    def __hash__(self):
        return hash("DocumentChunkingTopKDirective")

def to_string_for_instantiate(self, original_op: Dict) -> str:
|
||||
"""
|
||||
Generate a prompt for an agent to instantiate this directive.
|
||||
|
||||
Args:
|
||||
original_op (Dict): The original operation.
|
||||
|
||||
Returns:
|
||||
str: The agent prompt for instantiating the directive.
|
||||
"""
|
||||
op_type = original_op.get("type", "map")
|
||||
return (
|
||||
f"You are an expert at transforming document processing operations into chunking pipelines with intelligent topk-based retrieval.\n\n"
|
||||
f"Original Operation:\n"
|
||||
f"{str(original_op)}\n\n"
|
||||
f"Directive: {self.name}\n"
|
||||
f"Your task is to instantiate this directive by creating a configuration that transforms the original {op_type.capitalize()} operation "
|
||||
f"into a Split -> TopK -> Reduce pipeline for processing very long documents with intelligent chunk selection.\n"
|
||||
f"{'For Filter operations, a final code_filter step will be automatically added to return boolean results.' if op_type == 'filter' else ''}\n\n"
|
||||
f"Key requirements:\n"
|
||||
f"1. chunk_size: Choose an appropriate token count (typically 10000-15000) for cost-effective processing of long documents\n"
|
||||
f"2. split_key: Identify the document field to split from the original operation's prompt (the longest text field)\n"
|
||||
f"3. reduce_prompt: Use the EXACT SAME prompt as the original, with ONE change:\n"
|
||||
f" - Where the original references '{{{{ input.<split_key> }}}}', replace it with:\n"
|
||||
f" 'the top {{{{ inputs|length }}}} most relevant chunks from the document (ordered by relevance):\\n{{% for input in inputs|sort(attribute='_<topk_name>_rank') %}}\\nChunk (Rank {{{{ input._<topk_name>_rank }}}}, Score {{{{ input._<topk_name>_score }}}}):\\n{{{{ input.<split_key>_chunk }}}}\\n{{% endfor %}}'\n"
|
||||
f" - Keep EVERYTHING else identical - same instructions, same output requirements\n"
|
||||
f" - For other context fields (non-document fields), use {{{{ inputs[0].field_name }}}} instead of {{{{ input.field_name }}}}\n"
|
||||
f"4. topk_config: REQUIRED - Configure intelligent chunk selection:\n"
|
||||
f" - method: Choose 'embedding' for semantic similarity or 'fts' for keyword matching\n"
|
||||
f" * Use 'embedding' when: looking for conceptual comparisons, themes, or abstract relationships\n"
|
||||
f" * Use 'fts' when: searching for specific terms, legal clauses, technical codes, or exact phrases\n"
|
||||
f" - k: Number of chunks to retrieve (typically 5-10 for comprehensive coverage)\n"
|
||||
f" * For complex tasks needing most of the document as context: k=10\n"
|
||||
f" * For targeted extraction from specific sections: k=5\n"
|
||||
f" - query: Craft carefully to find chunks relevant to the {'filter decision' if op_type == 'filter' else 'extraction task'}\n"
|
||||
f" * For embedding: use terms related to the comparison/decision criteria\n"
|
||||
f" * For fts: use specific keywords that appear in relevant sections\n"
|
||||
f" * Can use Jinja: '{{{{ input.competitor_name }}}} comparison versus {{{{ input.our_product }}}}'\n"
|
||||
f" - keys: Always use the chunk key, typically ['<split_key>_chunk']\n"
|
||||
f" - embedding_model: (optional, only for embedding method) defaults to 'text-embedding-3-small'\n"
|
||||
f"The topk query should be carefully crafted to find the most relevant chunks.\n"
|
||||
f"The reduce_prompt must process chunks directly and {'output a boolean decision' if op_type == 'filter' else 'preserve the original output schema'}.\n\n"
|
||||
f"Example:\n"
|
||||
f"{self.example}\n\n"
|
||||
f"Please output only the DocumentChunkingTopKInstantiateSchema object as JSON."
|
||||
)
|
||||
|
||||
    def llm_instantiate(
        self,
        operators,
        input_file_path,
        original_op: Dict,
        agent_llm: str,
        message_history: list = [],
    ):
        """
        Use LLM to instantiate this directive by creating chunking configuration with topk.

        Args:
            original_op (Dict): The original operation.
            agent_llm (str): The LLM model to use.
            message_history (List, optional): Conversation history for context.

        Returns:
            tuple: (DocumentChunkingTopKInstantiateSchema, message_history, call_cost).
        """

        message_history.extend(
            [
                {
                    "role": "user",
                    "content": self.to_string_for_instantiate(original_op),
                },
            ]
        )

        last_error = None
        for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
            resp = completion(
                model=agent_llm,
                messages=message_history,
                response_format=DocumentChunkingTopKInstantiateSchema,
            )
            call_cost = resp._hidden_params.get("response_cost", 0)
            try:
                parsed_res = json.loads(resp.choices[0].message.content)
                schema = DocumentChunkingTopKInstantiateSchema(**parsed_res)
                schema.validate_split_key_exists_in_input(input_file_path)
                message_history.append(
                    {"role": "assistant", "content": resp.choices[0].message.content}
                )
                return schema, message_history, call_cost
            except Exception as err:
                last_error = err
                error_message = f"Validation error: {err}\nPlease try again."
                message_history.append({"role": "user", "content": error_message})

        raise Exception(
            f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
        )

def apply(
|
||||
self,
|
||||
global_default_model: str,
|
||||
ops_list: List[Dict],
|
||||
target_op: str,
|
||||
rewrite: DocumentChunkingTopKInstantiateSchema,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Apply the directive to the pipeline config by replacing the target operation
|
||||
with a split -> topk -> reduce sequence.
|
||||
"""
|
||||
|
||||
# Create a copy of the pipeline config
|
||||
new_ops_list = deepcopy(ops_list)
|
||||
target_op_config = [op for op in new_ops_list if op["name"] == target_op][0]
|
||||
original_model = target_op_config.get("model", global_default_model)
|
||||
|
||||
# Find position of the target op to replace
|
||||
pos_to_replace = [
|
||||
i for i, op in enumerate(ops_list) if op["name"] == target_op
|
||||
][0]
|
||||
|
||||
original_op = ops_list[pos_to_replace]
|
||||
|
||||
# Create operation names based on the original operation name
|
||||
split_name = f"split_{target_op}"
|
||||
topk_name = f"topk_{target_op}_chunks"
|
||||
reduce_name = f"reduce_{target_op}"
|
||||
|
||||
# Create the split operation
|
||||
split_op = {
|
||||
"name": split_name,
|
||||
"type": "split",
|
||||
"split_key": rewrite.split_key,
|
||||
"method": "token_count",
|
||||
"method_kwargs": {
|
||||
"num_tokens": rewrite.chunk_size,
|
||||
"model": original_model,
|
||||
},
|
||||
}
|
||||
|
||||
# Create the topk operation
|
||||
topk_op = {
|
||||
"name": topk_name,
|
||||
"type": "topk",
|
||||
"method": rewrite.topk_config.method,
|
||||
"k": rewrite.topk_config.k,
|
||||
"keys": rewrite.topk_config.keys,
|
||||
"query": rewrite.topk_config.query,
|
||||
"stratify_key": [f"{split_name}_id"],
|
||||
}
|
||||
|
||||
# Add stratify_key if specified
|
||||
if rewrite.topk_config.stratify_key:
|
||||
topk_op["stratify_key"] = topk_op["stratify_key"] + [
|
||||
rewrite.topk_config.stratify_key
|
||||
]
|
||||
|
||||
# Add embedding_model for embedding method
|
||||
if (
|
||||
rewrite.topk_config.method == "embedding"
|
||||
and rewrite.topk_config.embedding_model
|
||||
):
|
||||
topk_op["embedding_model"] = rewrite.topk_config.embedding_model
|
||||
|
||||
# Check if original operation is a filter
|
||||
is_filter = original_op.get("type") == "filter"
|
||||
|
||||
# Create the reduce operation that directly processes the chunks
|
||||
if is_filter:
|
||||
# For filter operations, reduce should output a boolean field
|
||||
reduce_output = deepcopy(original_op["output"])
|
||||
# Ensure we have a boolean field in the output schema
|
||||
if "schema" in reduce_output:
|
||||
# Get the first boolean field name from the schema
|
||||
bool_field = None
|
||||
for field_name, field_type in reduce_output["schema"].items():
|
||||
if "bool" in field_type.lower():
|
||||
bool_field = field_name
|
||||
break
|
||||
if not bool_field:
|
||||
# If no boolean field found, create one
|
||||
bool_field = "filter_result"
|
||||
reduce_output["schema"] = {bool_field: "bool"}
|
||||
else:
|
||||
reduce_output = deepcopy(original_op["output"])
|
||||
|
||||
reduce_op = {
|
||||
"name": reduce_name,
|
||||
"type": "reduce",
|
||||
"reduce_key": f"{split_name}_id",
|
||||
"prompt": rewrite.reduce_prompt,
|
||||
"model": original_model,
|
||||
"litellm_completion_kwargs": {"temperature": 0},
|
||||
"output": reduce_output,
|
||||
"associative": False, # Order matters for chunks
|
||||
"pass_through": True,
|
||||
}
|
||||
|
||||
# Construct operation sequence
|
||||
ops_sequence = [split_op, topk_op, reduce_op]
|
||||
|
||||
# If original was a filter, add a code_filter operation
|
||||
if is_filter:
|
||||
# Find the boolean field name in the output schema
|
||||
bool_field = None
|
||||
if "schema" in reduce_output:
|
||||
for field_name, field_type in reduce_output["schema"].items():
|
||||
if "bool" in field_type.lower():
|
||||
bool_field = field_name
|
||||
break
|
||||
|
||||
if bool_field:
|
||||
code_filter_op = {
|
||||
"name": f"code_filter_{target_op}",
|
||||
"type": "code_filter",
|
||||
"code": f"def transform(input_doc):\n return input_doc.get('{bool_field}', False)",
|
||||
}
|
||||
ops_sequence.append(code_filter_op)
|
||||
|
||||
# Replace the target operation with the new sequence
|
||||
new_ops_list[pos_to_replace : pos_to_replace + 1] = ops_sequence
|
||||
|
||||
return new_ops_list
|
||||
|
||||
def instantiate(
|
||||
self,
|
||||
operators: List[Dict],
|
||||
target_ops: List[str],
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
global_default_model: str = None,
|
||||
**kwargs,
|
||||
) -> tuple:
|
||||
"""
|
||||
Instantiate the directive for a list of operators.
|
||||
"""
|
||||
# Assert that there is only one target op
|
||||
assert (
|
||||
len(target_ops) == 1
|
||||
), "There must be exactly one target op to instantiate this document chunking topk directive"
|
||||
target_op_config = [op for op in operators if op["name"] == target_ops[0]][0]
|
||||
|
||||
input_file_path = kwargs.get("input_file_path", None)
|
||||
|
||||
# Validate that the target operation is a map or filter operation
|
||||
if target_op_config.get("type") not in ["map", "filter"]:
|
||||
raise ValueError(
|
||||
f"Document chunking topk directive can only be applied to map or filter operations, but target operation '{target_ops[0]}' is of type '{target_op_config.get('type')}'"
|
||||
)
|
||||
|
||||
# Instantiate the directive
|
||||
rewrite, message_history, call_cost = self.llm_instantiate(
|
||||
operators, input_file_path, target_op_config, agent_llm, message_history
|
||||
)
|
||||
|
||||
# Apply the rewrite to the operators
|
||||
return (
|
||||
self.apply(global_default_model, operators, target_ops[0], rewrite),
|
||||
message_history,
|
||||
call_cost,
|
||||
)
@ -1,254 +0,0 @@
import json
from copy import deepcopy
from typing import Dict, List, Type

from litellm import completion
from pydantic import BaseModel, Field

from docetl.reasoning_optimizer.instantiate_schemas import (
    DocCompressionInstantiateSchema,
)

from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase


class DocCompressionDirective(Directive):
    name: str = Field(
        default="doc_compression", description="The name of the directive"
    )
    formal_description: str = Field(default="Op => Extract -> Op")
    nl_description: str = Field(
        default="Reduces LLM processing costs by using an Extract operator to intelligently compress documents before expensive downstream operations, removing irrelevant content that could distract the LLM"
    )
    when_to_use: str = Field(
        default="When documents contain irrelevant content and you want to reduce token costs for downstream LLM operations while improving accuracy by having the LLM focus on only the essential content"
    )

    instantiate_schema_type: Type[BaseModel] = Field(
        default=DocCompressionInstantiateSchema
    )
|
||||
|
||||
example: str = Field(
|
||||
default="""
|
||||
Target Operations:
|
||||
- name: analyze_regulatory_impact
|
||||
type: map
|
||||
prompt: |
|
||||
Analyze the potential regulatory impact described in: {{ input.legal_document }}
|
||||
Consider stakeholder groups, compliance burdens, and implementation feasibility.
|
||||
output:
|
||||
schema:
|
||||
stakeholder_impacts: "list[str]"
|
||||
compliance_changes: "string"
|
||||
|
||||
- name: extract_key_dates
|
||||
type: map
|
||||
prompt: |
|
||||
Extract important deadlines and dates from: {{ input.legal_document }}
|
||||
output:
|
||||
schema:
|
||||
deadlines: "list[str]"
|
||||
|
||||
Example InstantiateSchema (what the agent should output):
|
||||
DocCompressionConfig(
|
||||
name="extract_regulatory_content",
|
||||
document_key="legal_document",
|
||||
prompt="Extract the minimal content necessary spanning: sections defining new regulatory requirements, stakeholder obligations, compliance deadlines, implementation timelines, and enforcement mechanisms. Focus only on substantive regulatory changes and specific dates, not background or procedural text.",
|
||||
model="gpt-4o-mini"
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
test_cases: List[DirectiveTestCase] = Field(
|
||||
default_factory=lambda: [
|
||||
DirectiveTestCase(
|
||||
name="single_target_compression",
|
||||
description="Should insert Extract operation before single target operation",
|
||||
input_config={
|
||||
"name": "analyze_document",
|
||||
"type": "map",
|
||||
"prompt": "Analyze this document: {{ input.document }}",
|
||||
"output": {"schema": {"analysis": "string"}},
|
||||
},
|
||||
target_ops=["analyze_document"],
|
||||
expected_behavior="Should add an Extract operation that compresses the document field before the analysis. The extract operation document_keys should be 'document' only.",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="multiple_target_compression",
|
||||
description="Should insert Extract operation before first target operation and consider all targets",
|
||||
input_config=[
|
||||
{
|
||||
"name": "extract_findings",
|
||||
"type": "map",
|
||||
"prompt": "Extract key findings from: {{ input.report }}",
|
||||
"output": {"schema": {"findings": "list[str]"}},
|
||||
},
|
||||
{
|
||||
"name": "analyze_impact",
|
||||
"type": "map",
|
||||
"prompt": "Analyze the business impact in: {{ input.report }}",
|
||||
"output": {"schema": {"impact": "string"}},
|
||||
},
|
||||
],
|
||||
target_ops=["extract_findings", "analyze_impact"],
|
||||
expected_behavior="Should add Extract operation before extract_findings that considers content needed for both operations. The extract operation document_keys should be 'report' only.",
|
||||
should_pass=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, DocCompressionDirective)
|
||||
|
||||
def __hash__(self):
|
||||
return hash("DocCompressionDirective")
|
||||
|
||||
def to_string_for_instantiate(self, target_ops_configs: List[Dict]) -> str:
|
||||
"""
|
||||
Generate a prompt that asks the agent to output the instantiate schema.
|
||||
This prompt explains to the LLM what configuration it needs to generate.
|
||||
"""
|
||||
ops_str = "\n".join(
|
||||
[
|
||||
f"Operation {i+1}:\n{str(op)}\n"
|
||||
for i, op in enumerate(target_ops_configs)
|
||||
]
|
||||
)
|
||||
|
||||
return (
|
||||
f"You are an expert at document analysis and optimization.\n\n"
|
||||
f"Target Operations:\n"
|
||||
f"{ops_str}\n\n"
|
||||
f"Directive: {self.name}\n"
|
||||
f"Your task is to instantiate this directive by generating a DocCompressionConfig "
|
||||
f"that specifies how to compress the input document before processing.\n\n"
|
||||
f"The directive will insert an Extract operation that:\n"
|
||||
f"1. Takes a long document field from the input\n"
|
||||
f"2. Extracts only the MINIMAL content relevant to ALL the target operations\n"
|
||||
f"3. Replaces the original document field with the compressed content\n"
|
||||
f"4. Reduces token usage and improves focus for the downstream operations\n\n"
|
||||
f"The agent must output the configuration specifying:\n"
|
||||
f"- name: A descriptive name for the Extract operation\n"
|
||||
f"- document_key: Which document field contains the long content to compress\n"
|
||||
f"- prompt: Plain text instructions for what MINIMAL content to extract (NOT a Jinja template)\n"
|
||||
f"- model: Which model to use for extraction (typically a cheaper model like gpt-4o-mini)\n\n"
|
||||
f"IMPORTANT: The extraction prompt should focus on extracting the minimal content necessary "
|
||||
f"for ALL target operations. Analyze each operation's prompt to identify the "
|
||||
f"specific information types needed across all operations, then design an extraction prompt "
|
||||
f"that gets just those essential pieces while removing all irrelevant material.\n\n"
|
||||
f"Example:\n"
|
||||
f"{self.example}\n\n"
|
||||
f"Please output only the InstantiateSchema (DocCompressionConfig object) "
|
||||
f"that specifies how to apply this directive to the target operations."
|
||||
)
|
||||
|
||||
def llm_instantiate(
|
||||
self,
|
||||
target_ops_configs: List[Dict],
|
||||
input_file_path: str,
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
):
|
||||
"""
|
||||
Call the LLM to generate the instantiate schema.
|
||||
The LLM will output structured data matching DocCompressionInstantiateSchema.
|
||||
"""
|
||||
|
||||
message_history.extend(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.to_string_for_instantiate(target_ops_configs),
|
||||
},
|
||||
]
|
||||
)
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
response_format=DocCompressionInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = DocCompressionInstantiateSchema(**parsed_res)
|
||||
schema.validate_document_keys_exists_in_input(input_file_path)
|
||||
message_history.append(
|
||||
{"role": "assistant", "content": resp.choices[0].message.content}
|
||||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
    def apply(
        self,
        global_default_model: str,
        ops_list: List[Dict],
        target_ops: List[str],
        rewrite: DocCompressionInstantiateSchema,
    ) -> List[Dict]:
        """
        Apply the directive using the instantiate schema configuration.
        Inserts an Extract operation before the first target operation.
        """
        new_ops_list = deepcopy(ops_list)

        # Find the position of the first target operation
        first_target_pos = min(
            [i for i, op in enumerate(ops_list) if op["name"] in target_ops]
        )

        extract_op = {
            "name": rewrite.name,
            "type": "extract",
            "prompt": rewrite.prompt,
            "document_keys": [rewrite.document_key],
            "litellm_completion_kwargs": {"temperature": 0},
            "model": rewrite.model,
        }

        # Insert the Extract operation before the first target operation
        new_ops_list.insert(first_target_pos, extract_op)

        return new_ops_list
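    # Hedged sketch of the effect, reusing the example above (values illustrative): a
    # pipeline [analyze_regulatory_impact] becomes
    #   [{"name": "extract_regulatory_content", "type": "extract",
    #     "document_keys": ["legal_document"], "model": "gpt-4o-mini", ...},
    #    analyze_regulatory_impact]
    # so the downstream map reads the compressed "legal_document" content instead of
    # the full document.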
|
||||
|
||||
def instantiate(
|
||||
self,
|
||||
operators: List[Dict],
|
||||
target_ops: List[str],
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
global_default_model: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Main method that orchestrates directive instantiation:
|
||||
1. Get agent to generate instantiate schema for all target operations
|
||||
2. Apply the transformation using that schema
|
||||
"""
|
||||
assert len(target_ops) >= 1, "This directive requires at least one target op"
|
||||
input_file_path = kwargs.get("input_file_path", None)
|
||||
|
||||
# Get configurations for all target operations
|
||||
target_ops_configs = [op for op in operators if op["name"] in target_ops]
|
||||
|
||||
# Step 1: Agent generates the instantiate schema considering all target ops
|
||||
rewrite, message_history, call_cost = self.llm_instantiate(
|
||||
target_ops_configs, input_file_path, agent_llm, message_history
|
||||
)
|
||||
|
||||
# Step 2: Apply transformation using the schema
|
||||
return (
|
||||
self.apply(global_default_model, operators, target_ops, rewrite),
|
||||
message_history,
|
||||
call_cost,
|
||||
)
@ -1,345 +0,0 @@
import json
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
from litellm import completion
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from docetl.reasoning_optimizer.instantiate_schemas import (
|
||||
DocSummarizationInstantiateSchema,
|
||||
)
|
||||
|
||||
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
|
||||
|
||||
|
||||
class DocSummarizationDirective(Directive):
|
||||
name: str = Field(
|
||||
default="doc_summarization", description="The name of the directive"
|
||||
)
|
||||
formal_description: str = Field(default="Op => Map -> Op")
|
||||
nl_description: str = Field(
|
||||
default=(
|
||||
"Adds a Map summarization operator at the very beginning of the pipeline to shorten the document before any downstream operations. "
|
||||
"This reduces the number of tokens processed in later steps, saving cost and improving efficiency. "
|
||||
"The summary is constructed to include all information required by any downstream operator. "
|
||||
"Target operations should be all operators that reference the document key being summarized "
|
||||
"(e.g., all ops using {{ input.transcript }})."
|
||||
)
|
||||
)
|
||||
when_to_use: str = Field(
|
||||
default=(
|
||||
"Use when documents are too long or detailed for the downstream pipeline. "
|
||||
"Summarization should preserve essential information and make subsequent tasks more efficient. "
|
||||
"Target ops should include all operators that use the document key being summarized, and the summary model should be cheap."
|
||||
)
|
||||
)
|
||||
|
||||
instantiate_schema_type: Type[BaseModel] = Field(
|
||||
default=DocSummarizationInstantiateSchema
|
||||
)
|
||||
|
||||
example: str = Field(
|
||||
default="""
|
||||
"Original Pipeline - Complex medical reasoning across multiple operators:\n"
|
||||
"[\n"
|
||||
" {\n"
|
||||
" name: 'assess_drug_interactions',\n"
|
||||
" type: 'map',\n"
|
||||
" prompt: 'Analyze potential drug interactions from this consultation: {{ input.transcript }}. Consider contraindications.',\n"
|
||||
" output: { schema: { interaction_risks: 'list[dict]' } }\n"
|
||||
" },\n"
|
||||
" {\n"
|
||||
" name: 'predict_side_effects',\n"
|
||||
" type: 'map', \n"
|
||||
" prompt: 'Based on the transcript, predict likely side effects for this patient: {{ input.transcript }}. Consider patient demographics and medical history.',\n"
|
||||
" output: { schema: { predicted_effects: 'list[dict]' } }\n"
|
||||
" },\n"
|
||||
" {\n"
|
||||
" name: 'generate_monitoring_plan',\n"
|
||||
" type: 'map',\n"
|
||||
" prompt: 'Create a monitoring plan based on this consultation: {{ input.transcript }}. Focus on symptoms to watch for.',\n"
|
||||
" output: { schema: { monitoring_plan: 'string' } }\n"
|
||||
" }\n"
|
||||
"]\n"
|
||||
"\n"
|
||||
"Problem: 50-page transcript with scheduling, insurance, small talk distracts from medical reasoning.\n"
|
||||
"\n"
|
||||
"Example InstantiateSchema:\n"
|
||||
"[\n"
|
||||
" DocSummarizationConfig(\n"
|
||||
" name='extract_medical_essentials',\n"
|
||||
" document_key='transcript',\n"
|
||||
" prompt='Extract a summary of medical information from this consultation transcript for drug interaction analysis, side effect prediction, and monitoring plan creation. Include: all medications with dosages, patient complaints/symptoms, medical history, current conditions, patient demographics (age, weight), allergies, and any contraindications mentioned. Exclude scheduling, insurance, and casual conversation: {{ input.transcript }}',\n"
|
||||
" model='gpt-4o-mini'\n"
|
||||
" )\n"
|
||||
"]\n"
|
||||
"\n"
|
||||
"Result: All three reasoning operations work on focused medical facts, dramatically improving accuracy."
|
||||
""",
|
||||
)
|
||||
|
||||
test_cases: List[DirectiveTestCase] = Field(
|
||||
default_factory=lambda: [
|
||||
DirectiveTestCase(
|
||||
name="medical_pipeline_analysis",
|
||||
description="Should add summarization for multi-step medical reasoning pipeline",
|
||||
input_config=[
|
||||
{
|
||||
"name": "assess_drug_interactions",
|
||||
"type": "map",
|
||||
"prompt": "Analyze potential drug interactions from this consultation: {{ input.transcript }}. Consider contraindications and patient allergies.",
|
||||
"output": {"schema": {"interaction_risks": "list[str]"}},
|
||||
},
|
||||
{
|
||||
"name": "predict_side_effects",
|
||||
"type": "map",
|
||||
"prompt": "Predict likely side effects for this patient based on: {{ input.transcript }}. Consider age, weight, and medical history.",
|
||||
"output": {"schema": {"predicted_effects": "list[str]"}},
|
||||
},
|
||||
{
|
||||
"name": "create_monitoring_plan",
|
||||
"type": "map",
|
||||
"prompt": "Create patient monitoring plan from: {{ input.transcript }}. Focus on symptoms to watch and lab work needed.",
|
||||
"output": {"schema": {"monitoring_plan": "string"}},
|
||||
},
|
||||
],
|
||||
target_ops=[
|
||||
"assess_drug_interactions",
|
||||
"predict_side_effects",
|
||||
"create_monitoring_plan",
|
||||
],
|
||||
expected_behavior="Should add summarization that extracts comprehensive medical information (medications, dosages, allergies, demographics, symptoms, history) needed for all three downstream reasoning tasks",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="legal_contract_pipeline",
|
||||
description="Should add summarization for multi-step legal analysis pipeline",
|
||||
input_config=[
|
||||
{
|
||||
"name": "identify_liability_risks",
|
||||
"type": "map",
|
||||
"prompt": "Identify liability and indemnification risks in: {{ input.contract_document }}. Focus on limitation of liability clauses.",
|
||||
"output": {"schema": {"liability_risks": "list[str]"}},
|
||||
},
|
||||
{
|
||||
"name": "analyze_termination_terms",
|
||||
"type": "map",
|
||||
"prompt": "Analyze termination and breach conditions in: {{ input.contract_document }}. Include notice requirements and penalties.",
|
||||
"output": {"schema": {"termination_analysis": "string"}},
|
||||
},
|
||||
{
|
||||
"name": "assess_compliance_requirements",
|
||||
"type": "map",
|
||||
"prompt": "Assess regulatory compliance obligations from: {{ input.contract_document }}. Include data protection and industry standards.",
|
||||
"output": {"schema": {"compliance_requirements": "list[str]"}},
|
||||
},
|
||||
],
|
||||
target_ops=[
|
||||
"identify_liability_risks",
|
||||
"analyze_termination_terms",
|
||||
"assess_compliance_requirements",
|
||||
],
|
||||
expected_behavior="Should add summarization that extracts all legally relevant clauses, terms, obligations, penalties, and compliance requirements needed across the legal analysis pipeline",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="financial_analysis_pipeline",
|
||||
description="Should add summarization for multi-step investment analysis pipeline",
|
||||
input_config=[
|
||||
{
|
||||
"name": "evaluate_financial_health",
|
||||
"type": "map",
|
||||
"prompt": "Evaluate company financial health from: {{ input.annual_report }}. Calculate key ratios and assess profitability trends.",
|
||||
"output": {
|
||||
"schema": {
|
||||
"financial_metrics": "str",
|
||||
"health_score": "float",
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "assess_market_position",
|
||||
"type": "map",
|
||||
"prompt": "Assess competitive market position using: {{ input.annual_report }}. Analyze market share and competitive advantages.",
|
||||
"output": {
|
||||
"schema": {
|
||||
"market_analysis": "string",
|
||||
"competitive_strengths": "list[str]",
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "identify_growth_risks",
|
||||
"type": "map",
|
||||
"prompt": "Identify growth opportunities and risks from: {{ input.annual_report }}. Include regulatory and market risks.",
|
||||
"output": {
|
||||
"schema": {
|
||||
"growth_opportunities": "list[str]",
|
||||
"risk_factors": "list[str]",
|
||||
}
|
||||
},
|
||||
},
|
||||
],
|
||||
target_ops=[
|
||||
"evaluate_financial_health",
|
||||
"assess_market_position",
|
||||
"identify_growth_risks",
|
||||
],
|
||||
expected_behavior="Should add summarization that extracts financial data, market information, competitive landscape, and risk factors needed for comprehensive investment analysis",
|
||||
should_pass=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, DocSummarizationDirective)
|
||||
|
||||
def __hash__(self):
|
||||
return hash("DocSummarizationDirective")
|
||||
|
||||
def to_string_for_instantiate(
|
||||
self, operators: List[Dict], target_ops: List[str]
|
||||
) -> str:
|
||||
"""
|
||||
Generate a prompt for an agent to instantiate this directive.
|
||||
|
||||
Args:
|
||||
operators: List of all operators in the pipeline
|
||||
target_ops: List of target operation names that need the summarized content
|
||||
|
||||
Returns:
|
||||
str: The agent prompt for instantiating the directive.
|
||||
"""
|
||||
return (
|
||||
f"You are an expert at adding document summarization to data processing pipelines.\n\n"
|
||||
f"Full Pipeline Context:\n"
|
||||
f"{str(operators)}\n\n"
|
||||
f"Target Operations (that will use summarized content): {target_ops}\n\n"
|
||||
f"Directive: {self.name}\n"
|
||||
f"Your task is to instantiate this directive by generating a DocSummarizationConfig that adds a Map summarization operator "
|
||||
f"at the start of the pipeline.\n\n"
|
||||
f"Analysis steps:\n"
|
||||
f"1. Identify which input field contains long documents that could benefit from summarization\n"
|
||||
f"2. Analyze ALL target operations' prompts to understand what information each needs\n"
|
||||
f"3. Create a comprehensive summarization prompt that preserves ALL information needed by ANY target operation\n"
|
||||
f"4. Ensure the summary contains sufficient detail for all downstream reasoning tasks\n\n"
|
||||
f"The document_key should be the field name containing the long content to summarize.\n"
|
||||
f"The prompt should instruct the LLM to extract and preserve ALL information types needed across ALL target operations.\n"
|
||||
f"The output will replace the original document field with the summarized version.\n\n"
|
||||
f"Example:\n"
|
||||
f"{self.example}\n\n"
|
||||
f"Please output only the InstantiateSchema (a DocSummarizationConfig object)."
|
||||
)
|
||||
|
||||
def llm_instantiate(
|
||||
self,
|
||||
operators: List[Dict],
|
||||
target_ops: List[str],
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
):
|
||||
"""
|
||||
Use LLM to instantiate this directive by creating a summarization operation.
|
||||
|
||||
Args:
|
||||
operators (List[Dict]): All operators in the pipeline.
|
||||
target_ops (List[str]): Target operation names that need summarized content.
|
||||
agent_llm (str): The LLM model to use.
|
||||
message_history (List, optional): Conversation history for context.
|
||||
|
||||
Returns:
|
||||
DocSummarizationInstantiateSchema: The structured output from the LLM.
|
||||
"""
|
||||
|
||||
message_history.extend(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.to_string_for_instantiate(operators, target_ops),
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
response_format=DocSummarizationInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = DocSummarizationInstantiateSchema(**parsed_res)
|
||||
message_history.append(
|
||||
{"role": "assistant", "content": resp.choices[0].message.content}
|
||||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
    def apply(
        self,
        global_default_model: str,
        ops_list: List[Dict],
        target_op: str,
        rewrite: DocSummarizationInstantiateSchema,
    ) -> List[Dict]:
        """
        Apply the directive to the pipeline config by adding a summarization Map operator at the start.
        """
        # Create a copy of the pipeline config
        new_ops_list = deepcopy(ops_list)

        # Create the summarization Map operator using the LLM-generated name
        summarization_op = {
            "name": rewrite.name,
            "type": "map",
            "prompt": rewrite.prompt,
            "model": rewrite.model,
            "litellm_completion_kwargs": {"temperature": 0},
            "output": {"schema": {rewrite.document_key: "string"}},
        }

        # Insert the summarization operator at the beginning of the pipeline
        new_ops_list.insert(0, summarization_op)

        return new_ops_list
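    # Hedged sketch, reusing the medical example above (illustrative values): the three
    # transcript maps become
    #   [{"name": "extract_medical_essentials", "type": "map",
    #     "output": {"schema": {"transcript": "string"}}, ...},
    #    assess_drug_interactions, predict_side_effects, generate_monitoring_plan]
    # Because the output schema reuses document_key, the summary overwrites the original
    # "transcript" field for every downstream operator.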
|
||||
|
||||
def instantiate(
|
||||
self,
|
||||
operators: List[Dict],
|
||||
target_ops: List[str],
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
global_default_model: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Instantiate the directive for a list of operators.
|
||||
"""
|
||||
# Validate that target ops exist in the pipeline
|
||||
operator_names = {op["name"] for op in operators}
|
||||
for target_op in target_ops:
|
||||
if target_op not in operator_names:
|
||||
raise ValueError(
|
||||
f"Target operation '{target_op}' not found in pipeline"
|
||||
)
|
||||
|
||||
# Instantiate the directive using full pipeline context
|
||||
rewrite, message_history, call_cost = self.llm_instantiate(
|
||||
operators, target_ops, agent_llm, message_history
|
||||
)
|
||||
|
||||
# Apply the rewrite to the operators (use first target op for compatibility with apply method)
|
||||
return (
|
||||
self.apply(global_default_model, operators, target_ops[0], rewrite),
|
||||
message_history,
|
||||
call_cost,
|
||||
)
@ -1,231 +0,0 @@
import json
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
from litellm import completion
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from docetl.reasoning_optimizer.instantiate_schemas import GleaningInstantiateSchema
|
||||
|
||||
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
|
||||
|
||||
|
||||
class GleaningDirective(Directive):
|
||||
name: str = Field(default="gleaning", description="The name of the directive")
|
||||
formal_description: str = Field(default="Map => Map_m (with gleaning config)")
|
||||
nl_description: str = Field(
|
||||
default="""Adds a validation loop to Map: after each LLM generation, a separate "judge" LLM evaluates the output using a yes/no validation prompt. If the output fails, the original LLM refines its answer and repeats until the output passes or the max number of rounds is reached."""
|
||||
)
|
||||
when_to_use: str = Field(
|
||||
default="When initial Map outputs may not meet quality criteria and must be checked or improved automatically (e.g., too short, missing required info)."
|
||||
)
|
||||
|
||||
# Remove from Pydantic fields, make it a plain class variable
|
||||
instantiate_schema_type: Type[BaseModel] = Field(default=GleaningInstantiateSchema)
|
||||
|
||||
example: str = Field(
|
||||
default="""
|
||||
Original Op (MapOpConfig):
|
||||
- name: extract_insights
|
||||
type: map
|
||||
prompt: |
|
||||
From the user log below, list 2-3 concise insights (1-2 words each) and 1-2 supporting actions per insight.
|
||||
Return as a list of dictionaries with 'insight' and 'supporting_actions'.
|
||||
Log: {{ input.log }}
|
||||
output:
|
||||
schema:
|
||||
insights_summary: "string"
|
||||
|
||||
Example InstantiateSchema (what the agent should output):
|
||||
{
|
||||
"validation_prompt": "There should be at least 2 insights, and each insight should have at least 1 supporting action.",
|
||||
"num_rounds": 2,
|
||||
"model": "gpt-4o-mini"
|
||||
}
|
||||
""",
|
||||
)
|
||||
|
||||
test_cases: List[DirectiveTestCase] = Field(
|
||||
default_factory=lambda: [
|
||||
DirectiveTestCase(
|
||||
name="quality_validation_insights",
|
||||
description="Should add validation for insight extraction quality",
|
||||
input_config={
|
||||
"name": "extract_insights",
|
||||
"type": "map",
|
||||
"prompt": "From the user log below, list 2-3 concise insights and supporting actions: {{ input.log }}",
|
||||
"output": {"schema": {"insights_summary": "string"}},
|
||||
},
|
||||
target_ops=["extract_insights"],
|
||||
expected_behavior="Should add gleaning config with validation prompt checking for minimum number of insights and supporting actions",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="contract_term_validation",
|
||||
description="Should add validation for contract term extraction completeness",
|
||||
input_config={
|
||||
"name": "extract_contract_terms",
|
||||
"type": "map",
|
||||
"prompt": "Extract payment terms, deadlines, and penalties from: {{ input.contract }}",
|
||||
"output": {"schema": {"terms": "list[str]"}},
|
||||
},
|
||||
target_ops=["extract_contract_terms"],
|
||||
expected_behavior="Should add gleaning validation to ensure all required term types are extracted",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="financial_report_analysis",
|
||||
description="Should add validation for financial analysis accuracy",
|
||||
input_config={
|
||||
"name": "analyze_financial_report",
|
||||
"type": "map",
|
||||
"prompt": "Extract revenue, expenses, profit margins, and key financial ratios from: {{ input.report }}",
|
||||
"output": {"schema": {"financial_analysis": "string"}},
|
||||
},
|
||||
target_ops=["analyze_financial_report"],
|
||||
expected_behavior="Should add gleaning validation to ensure all financial metrics are accurately extracted and calculated",
|
||||
should_pass=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, GleaningDirective)
|
||||
|
||||
def __hash__(self):
|
||||
return hash("GleaningDirective")
|
||||
|
||||
    def to_string_for_instantiate(self, original_op: Dict) -> str:
        """
        Generate a prompt for an agent to instantiate this directive.

        Args:
            original_op (Dict): The original operation.

        Returns:
            str: The agent prompt for instantiating the directive.
        """
        return (
            f"You are an expert at adding validation and refinement loops to data processing operations.\n\n"
            f"Original Operation:\n"
            f"{str(original_op)}\n\n"
            f"Directive: {self.name}\n"
            f"Your task is to instantiate this directive by generating a configuration that adds validation loops to the original operation. "
            f"The gleaning configuration should include a validation prompt that evaluates the output quality and provides feedback for improvement, "
            f"along with the number of refinement rounds to attempt. In the prompt, you should not use any Jinja variables.\n"
            f"Example:\n"
            f"{self.example}\n\n"
            f"Please output only the GleaningInstantiateSchema object that specifies how to validate and refine the output of the original operation."
        )
|
||||
|
||||
def llm_instantiate(
|
||||
self,
|
||||
original_op: Dict,
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
):
|
||||
"""
|
||||
Use LLM to instantiate this directive by decomposing the original operation.
|
||||
|
||||
Args:
|
||||
original_op (Dict): The original operation.
|
||||
agent_llm (str): The LLM model to use.
|
||||
message_history (List, optional): Conversation history for context.
|
||||
|
||||
Returns:
|
||||
GleaningInstantiateSchema: The structured output from the LLM.
|
||||
"""
|
||||
|
||||
message_history.extend(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.to_string_for_instantiate(original_op),
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
response_format=GleaningInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = GleaningInstantiateSchema(**parsed_res)
|
||||
GleaningInstantiateSchema.check_no_jinja_variables(
|
||||
schema.validation_prompt
|
||||
)
|
||||
message_history.append(
|
||||
{"role": "assistant", "content": resp.choices[0].message.content}
|
||||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
    def apply(
        self,
        default_model: str,
        ops_list: List[Dict],
        target_op: str,
        rewrite: GleaningInstantiateSchema,
    ) -> List[Dict]:
        """
        Apply the directive to the pipeline config by adding gleaning configuration to the target operator.
        """
        # Create a copy of the pipeline config
        new_ops_list = deepcopy(ops_list)

        # Find position of the target op to modify
        pos_to_replace = [
            i for i, op in enumerate(ops_list) if op["name"] == target_op
        ][0]

        # Add gleaning configuration to the target operator
        target_operator = new_ops_list[pos_to_replace]
        target_operator["gleaning"] = {
            "validation_prompt": rewrite.validation_prompt,
            "num_rounds": rewrite.num_rounds,
            "model": rewrite.model,
        }

        return new_ops_list
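    # Hedged sketch, reusing the example above (illustrative values): after apply(),
    # the target op carries
    #   extract_insights["gleaning"] = {
    #       "validation_prompt": "There should be at least 2 insights, and each insight "
    #                            "should have at least 1 supporting action.",
    #       "num_rounds": 2,
    #       "model": "gpt-4o-mini",
    #   }
    # which the runtime uses to judge each output and request refinements up to num_rounds.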
|
||||
|
||||
def instantiate(
|
||||
self,
|
||||
operators: List[Dict],
|
||||
target_ops: List[str],
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
global_default_model: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Instantiate the directive for a list of operators.
|
||||
"""
|
||||
# Assert that there is only one target op
|
||||
        assert (
            len(target_ops) == 1
        ), "There must be exactly one target op to instantiate this gleaning directive"
|
||||
target_op_config = [op for op in operators if op["name"] == target_ops[0]][0]
|
||||
|
||||
# Instantiate the directive
|
||||
rewrite, message_history, call_cost = self.llm_instantiate(
|
||||
target_op_config, agent_llm, message_history
|
||||
)
|
||||
|
||||
# Apply the rewrite to the operators
|
||||
return (
|
||||
self.apply(global_default_model, operators, target_ops[0], rewrite),
|
||||
message_history,
|
||||
call_cost,
|
||||
)
@ -1,322 +0,0 @@
import json
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
from litellm import completion
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from docetl.reasoning_optimizer.instantiate_schemas import (
|
||||
HierarchicalReduceInstantiateSchema,
|
||||
)
|
||||
|
||||
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
|
||||
|
||||
|
||||
class HierarchicalReduceDirective(Directive):
|
||||
name: str = Field(
|
||||
default="hierarchical_reduce", description="The name of the directive"
|
||||
)
|
||||
formal_description: str = Field(
|
||||
default="Reduce => (Map* ->) Reduce -> Reduce (optionally with Map before first Reduce)"
|
||||
)
|
||||
nl_description: str = Field(
|
||||
default="Transform a reduce operation that aggregates large groups of documents by first aggregating at a finer granularity (reduce_key + additional_key), then rolling up to the desired level (reduce_key only). This hierarchical approach can capture nuances that might be lost in a single large-scale aggregation and allows for intermediate validation."
|
||||
)
|
||||
when_to_use: str = Field(
|
||||
default="When a reduce operation processes many documents per group and it would be beneficial to first aggregate at a finer granularity before rolling up. Useful when there's a semantic hierarchy in the data (e.g., aggregate by state+city first, then by state only) or when you want to prevent information loss in large-scale aggregations. The target operator must be a reduce operator."
|
||||
)
|
||||
instantiate_schema_type: Type[BaseModel] = HierarchicalReduceInstantiateSchema
|
||||
|
||||
example: str = Field(
|
||||
default=(
|
||||
"Original Reduce Op:\n"
|
||||
"- name: summarize_by_state\n"
|
||||
" type: reduce\n"
|
||||
" reduce_key: state\n"
|
||||
" prompt: |\n"
|
||||
" Summarize voting patterns from these social media posts:\n"
|
||||
" {% for input in inputs %}\n"
|
||||
" Post: {{ input.content }}\n"
|
||||
" {% endfor %}\n"
|
||||
" Return a summary of voting patterns.\n"
|
||||
" output:\n"
|
||||
" schema:\n"
|
||||
" summary: string\n"
|
||||
"\n"
|
||||
"Example InstantiateSchema (with Map for synthetic key):\n"
|
||||
"HierarchicalReduceInstantiateSchema(\n"
|
||||
" map_config=MapOpConfig(\n"
|
||||
" name='extract_city',\n"
|
||||
" prompt='Extract the city mentioned in this post:\\n{{ input.content }} made in this state:\\n{{ input.state }}\\nReturn the city name or \"Unknown\" if not found.',\n"
|
||||
" output_keys=['city'],\n"
|
||||
" ),\n"
|
||||
" additional_key='city',\n"
|
||||
" reduce_1_name='summarize_by_state_city',\n"
|
||||
" # First reduce: Process raw posts at city level\n"
|
||||
" reduce_1_prompt='Goal: Summarize voting patterns from social media posts to understand public sentiment.\\n\\nFor this state and city, analyze these posts:\\n{% for input in inputs %}\\nPost: {{ input.content }}\\n{% endfor %}\\nReturn a summary of voting patterns and key themes.',\n"
|
||||
" # Second reduce: Explicitly work with summaries from first reduce\n"
|
||||
" reduce_2_prompt='Goal: Summarize voting patterns from social media posts to understand public sentiment.\\n\\nWe have already summarized voting patterns at the city level. Your task is to combine these city-level summaries into a comprehensive state-level analysis:\\n{% for input in inputs %}\\nCity: {{ input.city }}\\nCity-Level Summary: {{ input.summary }}\\n{% endfor %}\\nSynthesize these city summaries into a unified state-level summary of voting patterns.',\n"
|
||||
")\n"
|
||||
"\n"
|
||||
"Example InstantiateSchema (using existing key):\n"
|
||||
"HierarchicalReduceInstantiateSchema(\n"
|
||||
" map_config=None, # No Map needed when using existing key\n"
|
||||
" additional_key='county', # Assuming 'county' already exists in the data\n"
|
||||
" reduce_1_name='summarize_by_state_county',\n"
|
||||
" # First reduce: Process raw posts at county level\n"
|
||||
" reduce_1_prompt='Goal: Summarize voting patterns from social media posts.\\n\\nAnalyze posts for this state and county:\\n{% for input in inputs %}\\nPost: {{ input.content }}\\n{% endfor %}\\nReturn voting pattern summary for this county.',\n"
|
||||
" # Second reduce: Explicitly work with county summaries\n"
|
||||
" reduce_2_prompt='Goal: Summarize voting patterns from social media posts.\\n\\nWe have already analyzed voting patterns at the county level. Your task is to synthesize these county-level summaries into a state-level overview:\\n{% for input in inputs %}\\nCounty: {{ input.county }}\\nCounty Analysis: {{ input.summary }}\\n{% endfor %}\\nCombine these county analyses into a comprehensive state voting pattern summary.',\n"
|
||||
")"
|
||||
),
|
||||
)
|
||||
|
||||
test_cases: List[DirectiveTestCase] = Field(
|
||||
default_factory=lambda: [
|
||||
DirectiveTestCase(
|
||||
name="voting_pattern_aggregation",
|
||||
description="Should create hierarchical aggregation for voting patterns",
|
||||
input_config={
|
||||
"name": "analyze_voting",
|
||||
"type": "reduce",
|
||||
"reduce_key": "state",
|
||||
"prompt": "Analyze voting sentiments from these posts:\n{% for input in inputs %}\nPost: {{ input.post }}\n{% endfor %}\nReturn voting sentiment analysis.",
|
||||
"output": {"schema": {"analysis": "string"}},
|
||||
},
|
||||
target_ops=["analyze_voting"],
|
||||
expected_behavior="Should create two reduce operations: first by state+additional_key, then by state only, optionally with a Map to create synthetic keys",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="sales_hierarchical_aggregation",
|
||||
description="Should create hierarchical aggregation for sales data",
|
||||
input_config={
|
||||
"name": "aggregate_sales",
|
||||
"type": "reduce",
|
||||
"reduce_key": "region",
|
||||
"prompt": "Aggregate sales data from these records:\n{% for input in inputs %}\nSales: {{ input.sales_data }}\n{% endfor %}\nReturn total sales metrics.",
|
||||
"output": {"schema": {"metrics": "object"}},
|
||||
},
|
||||
target_ops=["aggregate_sales"],
|
||||
expected_behavior="Should create hierarchical reduce pattern for sales aggregation",
|
||||
should_pass=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, HierarchicalReduceDirective)
|
||||
|
||||
def __hash__(self):
|
||||
return hash("HierarchicalReduceDirective")
|
||||
|
||||
def to_string_for_instantiate(self, original_op: Dict) -> str:
|
||||
"""
|
||||
Generate a prompt for an agent to instantiate this directive.
|
||||
|
||||
Args:
|
||||
original_op (Dict): The original reduce operation.
|
||||
|
||||
Returns:
|
||||
str: The agent prompt for instantiating the directive.
|
||||
"""
|
||||
return (
|
||||
f"You are an expert at optimizing data processing operations using hierarchical aggregation patterns.\n\n"
|
||||
f"Original Operation:\n"
|
||||
f"{str(original_op)}\n"
|
||||
f"Directive: {self.name}\n"
|
||||
f"Your task is to instantiate this directive by creating a hierarchical reduce pattern that:\n"
|
||||
f"1. First aggregates data at a finer granularity (original reduce_key + additional_key)\n"
|
||||
f"2. Then rolls up these finer-grained aggregations to the desired level (original reduce_key only)\n\n"
|
||||
f"Key Requirements:\n"
|
||||
f"1. Identify or create an appropriate additional key for finer granularity:\n"
|
||||
f" - Check if there's an existing key in the data that provides meaningful sub-grouping\n"
|
||||
f" - If no suitable key exists, create a Map operation to synthesize one (e.g., extract city from text)\n"
|
||||
f"2. Adapt the original reduce prompt for both aggregation levels:\n"
|
||||
f" - reduce_1_prompt: Should aggregate at the finer granularity (both keys) from raw data\n"
|
||||
f" - reduce_2_prompt: Should combine the outputs from reduce_1 to the target granularity\n"
|
||||
f"3. IMPORTANT: Both reduce prompts should understand the overall goal from the original operation:\n"
|
||||
f" - Each reduce operation runs independently without access to the original prompt\n"
|
||||
f" - Both prompts should share a common context/prefix that explains the overall goal\n"
|
||||
f" - The second reduce MUST explicitly acknowledge it's working with summaries/aggregations from the first reduce\n"
|
||||
f" - Example: 'We have already summarized X at the Y level. Your task is to combine these Y-level summaries...'\n"
|
||||
f"4. Both reduce operations should produce the same output schema as the original\n"
|
||||
f"5. The reduce_2_prompt must reference the outputs from reduce_1, not the original documents\n\n"
|
||||
f"This hierarchical approach is especially useful when:\n"
|
||||
f"- There are many documents per group in the original reduce\n"
|
||||
f"- There's a natural hierarchy in the data (geographic, temporal, categorical)\n"
|
||||
f"- You want to prevent information loss in large-scale aggregations\n\n"
|
||||
f"Example:\n"
|
||||
f"{self.example}\n\n"
|
||||
f"Please output the HierarchicalReduceInstantiateSchema with the hierarchical aggregation details."
|
||||
)
|
||||
|
||||
def llm_instantiate(
|
||||
self, original_op: Dict, agent_llm: str, message_history: list = []
|
||||
):
|
||||
"""
|
||||
Use LLM to instantiate this directive by creating a hierarchical reduce pattern.
|
||||
|
||||
Args:
|
||||
original_op (Dict): The original reduce operation.
|
||||
agent_llm (str): The LLM model to use.
|
||||
message_history (List, optional): Conversation history for context.
|
||||
|
||||
Returns:
|
||||
HierarchicalReduceInstantiateSchema: The structured output from the LLM.
|
||||
"""
|
||||
|
||||
message_history.extend(
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful AI assistant for document processing pipelines specializing in hierarchical aggregation patterns.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.to_string_for_instantiate(original_op),
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
response_format=HierarchicalReduceInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = HierarchicalReduceInstantiateSchema(**parsed_res)
|
||||
|
||||
# Validate that if map_config is provided, additional_key should match one of the output keys
|
||||
if schema.map_config:
|
||||
if len(schema.map_config.output_keys) != 1:
|
||||
raise ValueError(
|
||||
"Map config must have exactly one output key for hierarchical reduce"
|
||||
)
|
||||
if schema.additional_key != schema.map_config.output_keys[0]:
|
||||
raise ValueError(
|
||||
f"When creating a synthetic key with Map, additional_key must match the map output key '{schema.map_config.output_keys[0]}'"
|
||||
)
|
||||
|
||||
message_history.append(
|
||||
{"role": "assistant", "content": resp.choices[0].message.content}
|
||||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
global_default_model,
|
||||
ops_list: List[Dict],
|
||||
target_op: str,
|
||||
rewrite: HierarchicalReduceInstantiateSchema,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Apply the directive to the pipeline config.
|
||||
"""
|
||||
# Create a copy of the pipeline config
|
||||
new_ops_list = deepcopy(ops_list)
|
||||
|
||||
# Find position of the target reduce op to modify
|
||||
pos_to_modify = None
|
||||
orig_op = None
|
||||
for i, op in enumerate(ops_list):
|
||||
if op["name"] == target_op:
|
||||
pos_to_modify = i
|
||||
orig_op = op
|
||||
break
|
||||
|
||||
if pos_to_modify is None:
|
||||
raise ValueError(
|
||||
f"Target operation '{target_op}' not found in operations list"
|
||||
)
|
||||
|
||||
# Determine the model to use
|
||||
default_model = orig_op.get("model", global_default_model)
|
||||
|
||||
operations_to_insert = []
|
||||
|
||||
# Create the optional Map operation if specified
|
||||
if rewrite.map_config:
|
||||
new_map_op = {
|
||||
"name": rewrite.map_config.name,
|
||||
"type": "map",
|
||||
"prompt": rewrite.map_config.prompt,
|
||||
"model": default_model,
|
||||
"litellm_completion_kwargs": {"temperature": 0},
|
||||
"output": {"schema": {rewrite.map_config.output_keys[0]: "string"}},
|
||||
}
|
||||
operations_to_insert.append(new_map_op)
|
||||
|
||||
        # Create the first reduce operation (finer granularity)
        first_reduce_op = deepcopy(orig_op)
        first_reduce_op["name"] = rewrite.reduce_1_name
        first_reduce_op["reduce_key"] = (
            [orig_op["reduce_key"], rewrite.additional_key]
            if isinstance(orig_op["reduce_key"], str)
            else orig_op["reduce_key"] + [rewrite.additional_key]
        )
        first_reduce_op["prompt"] = rewrite.reduce_1_prompt
        first_reduce_op["model"] = default_model
        operations_to_insert.append(first_reduce_op)

        # Create the second reduce operation (target granularity)
        second_reduce_op = deepcopy(orig_op)
        second_reduce_op["prompt"] = rewrite.reduce_2_prompt
        second_reduce_op["model"] = default_model
        operations_to_insert.append(second_reduce_op)

        # Replace the original operation with the new operations
        new_ops_list[pos_to_modify : pos_to_modify + 1] = operations_to_insert

        return new_ops_list
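    # Hedged sketch, reusing the voting example above (illustrative names): the single
    # summarize_by_state reduce becomes
    #   [extract_city            (map, only when a synthetic key is needed),
    #    summarize_by_state_city (reduce, reduce_key=["state", "city"]),
    #    summarize_by_state      (reduce, reduce_key="state", prompt=reduce_2_prompt)]
    # The second reduce keeps the original name, reduce_key, and output schema, so
    # downstream operators are unaffected.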
|
||||
|
||||
def instantiate(
|
||||
self,
|
||||
operators: List[Dict],
|
||||
target_ops: List[str],
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
optimize_goal="acc",
|
||||
global_default_model: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Instantiate the directive for a list of operators.
|
||||
"""
|
||||
# Assert that there is only one target op
|
||||
assert (
|
||||
len(target_ops) == 1
|
||||
), "There must be exactly one target op to instantiate this hierarchical reduce directive"
|
||||
target_op_config = [op for op in operators if op["name"] == target_ops[0]][0]
|
||||
|
||||
# Ensure it's a reduce operation
|
||||
if target_op_config.get("type") != "reduce":
|
||||
raise ValueError(
|
||||
f"Target operation '{target_ops[0]}' must be a reduce operation"
|
||||
)
|
||||
|
||||
# Instantiate the directive
|
||||
rewrite, message_history, call_cost = self.llm_instantiate(
|
||||
target_op_config,
|
||||
agent_llm,
|
||||
message_history,
|
||||
)
|
||||
|
||||
# Apply the rewrite to the operators
|
||||
new_ops_plan = self.apply(
|
||||
global_default_model, operators, target_ops[0], rewrite
|
||||
)
|
||||
return new_ops_plan, message_history, call_cost
@ -1,412 +0,0 @@
import json
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
from litellm import completion
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from docetl.reasoning_optimizer.instantiate_schemas import (
|
||||
IsolatingSubtasksInstantiateSchema,
|
||||
)
|
||||
|
||||
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
|
||||
|
||||
|
||||
class IsolatingSubtasksDirective(Directive):
|
||||
name: str = Field(
|
||||
default="isolating_subtasks", description="The name of the directive"
|
||||
)
|
||||
formal_description: str = Field(default="Map => Parallel Map -> Map")
|
||||
nl_description: str = Field(
|
||||
default="Rewrites a single Map into a Parallel Map that isolates subtasks and generates separate outputs for each, followed by a Map that aggregates or synthesizes the results."
|
||||
)
|
||||
when_to_use: str = Field(
|
||||
default="When the original Map is overloaded—either the prompt asks for many different things OR the output schema has many fields—and subtasks are better handled independently (e.g., extract each attribute in parallel, then combine into a unified output)."
|
||||
)
|
||||
|
||||
instantiate_schema_type: Type[BaseModel] = Field(
|
||||
default=IsolatingSubtasksInstantiateSchema
|
||||
)
|
||||
|
||||
example: str = Field(
|
||||
default="""
|
||||
Original Op (MapOpConfig):
|
||||
- name: extract_contract_info
|
||||
type: map
|
||||
prompt: |
|
||||
Extract the following from this contract: {{ input.document }}
|
||||
- party names
|
||||
- agreement date
|
||||
- governing law
|
||||
- termination clauses
|
||||
output:
|
||||
schema:
|
||||
parties: "string"
|
||||
agreement_date: "string"
|
||||
governing_law: "string"
|
||||
termination_clauses: "string"
|
||||
|
||||
Example InstantiateSchema (what the agent should output):
|
||||
IsolatingSubtasksConfig(
|
||||
subtasks=[
|
||||
SubtaskConfig(
|
||||
name="Extract Basic Contract Info",
|
||||
prompt="Extract party names and agreement date from: {{ input.document }}",
|
||||
output_keys=["parties", "agreement_date"]
|
||||
),
|
||||
SubtaskConfig(
|
||||
name="Extract Legal Terms",
|
||||
prompt="Extract governing law and termination clauses from: {{ input.document }}",
|
||||
output_keys=["governing_law", "termination_clauses"]
|
||||
)
|
||||
],
|
||||
aggregation_prompt="Combine the basic info {{ input.subtask_1_output }} with legal terms {{ input.subtask_2_output }} into the final contract summary."
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
test_cases: List[DirectiveTestCase] = Field(
|
||||
default_factory=lambda: [
|
||||
DirectiveTestCase(
|
||||
name="complex_prompt_simple_output",
|
||||
description="Complex multi-task prompt but simple list output - should isolate by prompt complexity",
|
||||
input_config={
|
||||
"name": "analyze_document",
|
||||
"type": "map",
|
||||
"prompt": "Analyze this document for: 1) sentiment and emotional tone, 2) key topics and themes, 3) factual accuracy and bias, 4) writing quality and readability, 5) actionable insights and recommendations. Document: {{ input.text }}",
|
||||
"output": {"schema": {"results": "list[string]"}},
|
||||
},
|
||||
target_ops=["analyze_document"],
|
||||
expected_behavior="Should create >1 parallel map prompts covering all analysis aspects (sentiment, topics, bias, quality, insights) and aggregation prompt referencing all subtask outputs",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="contract_analysis_many_fields",
|
||||
description="Legal contract extraction with many specific fields",
|
||||
input_config={
|
||||
"name": "extract_contract_terms",
|
||||
"type": "map",
|
||||
"prompt": "Extract contract information from: {{ input.contract_text }}",
|
||||
"output": {
|
||||
"schema": {
|
||||
"parties": "string",
|
||||
"agreement_date": "string",
|
||||
"governing_law": "string",
|
||||
"termination_clause": "string",
|
||||
"payment_terms": "string",
|
||||
"liability_cap": "string",
|
||||
}
|
||||
},
|
||||
},
|
||||
target_ops=["extract_contract_terms"],
|
||||
expected_behavior="Should create >1 parallel map prompts covering all 6 fields (parties, agreement_date, governing_law, termination_clause, payment_terms, liability_cap) and aggregation prompt referencing all subtask outputs",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="medical_transcript_processing",
|
||||
description="Medical data extraction - different subtasks for different medical info types",
|
||||
input_config={
|
||||
"name": "process_medical_record",
|
||||
"type": "map",
|
||||
"prompt": "Extract patient demographics, symptoms, diagnosis, and treatment plan from: {{ input.transcript }}",
|
||||
"output": {
|
||||
"schema": {
|
||||
"patient_info": "string",
|
||||
"symptoms": "string",
|
||||
"diagnosis": "string",
|
||||
"treatment": "string",
|
||||
}
|
||||
},
|
||||
},
|
||||
target_ops=["process_medical_record"],
|
||||
expected_behavior="Should create >1 parallel map prompts covering all medical aspects (demographics, symptoms, diagnosis, treatment) and aggregation prompt referencing all subtask outputs",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="research_paper_summary",
|
||||
description="Academic paper analysis with focus on different aspects",
|
||||
input_config={
|
||||
"name": "summarize_research",
|
||||
"type": "map",
|
||||
"prompt": "Analyze this research paper for methodology, key findings, limitations, and practical applications: {{ input.paper_text }}",
|
||||
"output": {
|
||||
"schema": {"summary": "string", "key_points": "list[string]"}
|
||||
},
|
||||
},
|
||||
target_ops=["summarize_research"],
|
||||
expected_behavior="Should create >1 parallel map prompts covering all research aspects (methodology, findings, limitations, applications) and aggregation prompt referencing all subtask outputs",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="customer_feedback_analysis",
|
||||
description="Multi-aspect customer feedback analysis with simple output",
|
||||
input_config={
|
||||
"name": "analyze_feedback",
|
||||
"type": "map",
|
||||
"prompt": "Analyze customer feedback for: product quality issues, service experience problems, pricing concerns, feature requests, and overall satisfaction. Feedback: {{ input.feedback_text }}",
|
||||
"output": {
|
||||
"schema": {
|
||||
"insights": "list[string]",
|
||||
"priority_score": "string",
|
||||
}
|
||||
},
|
||||
},
|
||||
target_ops=["analyze_feedback"],
|
||||
expected_behavior="Should create >1 parallel map prompts covering all feedback aspects (quality, service, pricing, features, satisfaction) and aggregation prompt referencing all subtask outputs",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="financial_report_extraction",
|
||||
description="Financial document with many specific metrics to extract",
|
||||
input_config={
|
||||
"name": "extract_financial_data",
|
||||
"type": "map",
|
||||
"prompt": "Extract financial metrics from earnings report: {{ input.report }}",
|
||||
"output": {
|
||||
"schema": {
|
||||
"revenue": "string",
|
||||
"profit_margin": "string",
|
||||
"cash_flow": "string",
|
||||
"debt_ratio": "string",
|
||||
"growth_rate": "string",
|
||||
"market_share": "string",
|
||||
}
|
||||
},
|
||||
},
|
||||
target_ops=["extract_financial_data"],
|
||||
expected_behavior="Should create >1 parallel map prompts covering all 6 financial metrics (revenue, profit_margin, cash_flow, debt_ratio, growth_rate, market_share) and aggregation prompt referencing all subtask outputs",
|
||||
should_pass=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, IsolatingSubtasksDirective)
|
||||
|
||||
def __hash__(self):
|
||||
return hash("IsolatingSubtasksDirective")
|
||||
|
||||
def to_string_for_instantiate(self, original_op: Dict) -> str:
|
||||
"""
|
||||
Generate a prompt that asks the agent to output the instantiate schema.
|
||||
"""
|
||||
# Extract original operation details
|
||||
original_name = original_op.get("name", "unknown")
|
||||
original_prompt = original_op.get("prompt", "")
|
||||
original_output_schema = original_op.get("output", {}).get("schema", {})
|
||||
original_output_keys = (
|
||||
list(original_output_schema.keys()) if original_output_schema else []
|
||||
)
|
||||
|
||||
# Find the input key from the original prompt (look for {{ input.XXX }} pattern)
|
||||
|
||||
input_matches = re.findall(r"\{\{\s*input\.([^}\s]+)\s*\}\}", original_prompt)
|
||||
input_key = input_matches[0] if input_matches else "document"
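# Illustrative note (not in the original source): for a prompt such as
#   "Extract the following from this contract: {{ input.document }}"
# the regex above captures "document", so input_key == "document"; if no
# "{{ input.<key> }}" reference is found, it falls back to "document".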
|
||||
|
||||
return (
|
||||
f"You are an expert at analyzing overloaded map operations and breaking them into focused subtasks.\n\n"
|
||||
f"Original Operation:\n"
|
||||
f"Name: {original_name}\n"
|
||||
f"Prompt: {original_prompt}\n"
|
||||
f"Output Schema: {original_output_schema}\n"
|
||||
f"Output Keys: {original_output_keys}\n\n"
|
||||
f"This map operation is overloaded - either the prompt asks for many different things "
|
||||
f"OR it has {len(original_output_keys)} output fields to generate. "
|
||||
f"Your task is to create an IsolatingSubtasksConfig with:\n\n"
|
||||
f"1. **SUBTASKS**: Group the {len(original_output_keys)} output fields into 2-4 logical subtasks "
|
||||
f"where each subtask handles related fields that can be processed independently:\n"
|
||||
f" - Each subtask needs a descriptive 'name'\n"
|
||||
f" - Each subtask needs a focused Jinja 'prompt' that uses {{{{ input.{input_key} }}}} (same as original)\n"
|
||||
f" - Each subtask needs 'output_keys' listing which fields it extracts\n"
|
||||
f" - Every original output key must appear in exactly one subtask's output_keys\n\n"
|
||||
f"2. **AGGREGATION_PROMPT**: A Jinja template that combines all subtask results:\n"
|
||||
f" - Must reference ALL subtask outputs as {{{{ input.subtask_1_output }}}}, {{{{ input.subtask_2_output }}}}, etc.\n"
|
||||
f" - Should synthesize/combine the subtask results into the final output\n"
|
||||
f" - Final result must match the original output schema exactly\n"
|
||||
f" - Example: 'Combine basic info {{{{ input.subtask_1_output }}}} with details {{{{ input.subtask_2_output }}}} into complete result.'\n\n"
|
||||
f"CRITICAL REQUIREMENTS:\n"
|
||||
f"- All {len(original_output_keys)} original output keys must be covered by subtasks\n"
|
||||
f"- Subtask prompts must use {{{{ input.{input_key} }}}} (same input as original)\n"
|
||||
f"- Aggregation prompt must reference {{{{ input.subtask_N_output }}}} for each subtask\n"
|
||||
f"- No information should be lost in the transformation\n\n"
|
||||
f"Example logical grouping for contract extraction:\n"
|
||||
f"- Subtask 1: Basic info (parties, dates) \n"
|
||||
f"- Subtask 2: Legal terms (governing law, clauses)\n"
|
||||
f"- Subtask 3: Commercial terms (pricing, commitments)\n\n"
|
||||
f"Please output the IsolatingSubtasksConfig that transforms this overloaded operation."
|
||||
)
|
||||
|
||||
def llm_instantiate(
|
||||
self,
|
||||
original_op: Dict,
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
):
|
||||
"""
|
||||
Call the LLM to generate the instantiate schema with validation.
|
||||
"""
|
||||
message_history.extend(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.to_string_for_instantiate(original_op),
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
original_output_keys = list(
|
||||
original_op.get("output", {}).get("schema", {}).keys()
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
response_format=IsolatingSubtasksInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = IsolatingSubtasksInstantiateSchema(**parsed_res)
|
||||
|
||||
# Use the schema's validation methods
|
||||
schema.validate_subtasks_coverage(original_output_keys)
|
||||
schema.validate_aggregation_references_all_subtasks()
|
||||
|
||||
message_history.append(
|
||||
{"role": "assistant", "content": resp.choices[0].message.content}
|
||||
)
|
||||
return schema, message_history, call_cost
|
||||
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease ensure all original output keys are covered by subtasks and try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
global_default_model,
|
||||
ops_list: List[Dict],
|
||||
target_op: str,
|
||||
rewrite: IsolatingSubtasksInstantiateSchema,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Apply the directive by replacing the target map with parallel map + aggregation map.
|
||||
"""
|
||||
new_ops_list = deepcopy(ops_list)
|
||||
|
||||
# Find the target operation
|
||||
pos_to_replace = None
|
||||
original_op = None
|
||||
for i, op in enumerate(ops_list):
|
||||
if op["name"] == target_op:
|
||||
pos_to_replace = i
|
||||
original_op = op
|
||||
break
|
||||
|
||||
if pos_to_replace is None:
|
||||
raise ValueError(f"Target operation '{target_op}' not found")
|
||||
|
||||
# Create the parallel map operation
|
||||
parallel_map_op = {
|
||||
"name": f"{target_op}_parallel",
|
||||
"type": "parallel_map",
|
||||
"litellm_completion_kwargs": {"temperature": 0},
|
||||
"prompts": [],
|
||||
}
|
||||
|
||||
assert original_op
|
||||
# Copy over other fields from original operation (sample, random_sample, etc.)
|
||||
for key, value in original_op.items():
|
||||
if key not in ["name", "type", "prompt", "output"]:
|
||||
parallel_map_op[key] = value
|
||||
|
||||
# Set up output schema for parallel map
|
||||
parallel_output_schema = {}
|
||||
|
||||
# Add each subtask as a prompt in the parallel map
|
||||
for i, subtask in enumerate(rewrite.subtasks, 1):
|
||||
subtask_output_key = f"subtask_{i}_output"
|
||||
|
||||
prompt_config = {
|
||||
"name": subtask.name,
|
||||
"output_keys": [subtask_output_key],
|
||||
"prompt": subtask.prompt,
|
||||
}
|
||||
|
||||
parallel_map_op["prompts"].append(prompt_config)
|
||||
parallel_output_schema[subtask_output_key] = "string"
|
||||
|
||||
# Set the output schema for parallel map
|
||||
parallel_map_op["output"] = {"schema": parallel_output_schema}
|
||||
|
||||
# Use the same model as the original operation
|
||||
default_model = original_op.get("model", global_default_model)
|
||||
parallel_map_op["model"] = default_model
|
||||
|
||||
# Check if aggregation is needed by comparing subtask output keys with original keys
|
||||
subtask_output_keys = set()
|
||||
for subtask in rewrite.subtasks:
|
||||
subtask_output_keys.update(subtask.output_keys)
|
||||
|
||||
# Check if aggregation is needed: aggregation_prompt is empty
|
||||
if not rewrite.aggregation_prompt.strip():
|
||||
# Just return the parallel map - it already produces the right output
|
||||
parallel_map_op["output"] = original_op.get("output", {})
|
||||
new_ops_list[pos_to_replace : pos_to_replace + 1] = [parallel_map_op]
|
||||
else:
|
||||
# Need aggregation step
|
||||
aggregation_map_op = {
|
||||
"name": f"{target_op}_aggregate",
|
||||
"type": "map",
|
||||
"prompt": rewrite.aggregation_prompt,
|
||||
"litellm_completion_kwargs": {"temperature": 0},
|
||||
"output": original_op.get(
|
||||
"output", {}
|
||||
), # Same output schema as original
|
||||
}
|
||||
|
||||
# Use the same model as the original operation
|
||||
aggregation_map_op["model"] = default_model
|
||||
|
||||
# Replace the original operation with both operations
|
||||
new_ops_list[pos_to_replace : pos_to_replace + 1] = [
|
||||
parallel_map_op,
|
||||
aggregation_map_op,
|
||||
]
|
||||
|
||||
return new_ops_list
|
||||
|
||||
def instantiate(
|
||||
self,
|
||||
operators: List[Dict],
|
||||
target_ops: List[str],
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
global_default_model: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Main method that orchestrates directive instantiation.
|
||||
"""
|
||||
assert len(target_ops) == 1, "This directive requires exactly one target op"
|
||||
|
||||
target_op_config = [op for op in operators if op["name"] == target_ops[0]][0]
|
||||
|
||||
# Step 1: Agent generates the instantiate schema
|
||||
rewrite, message_history, call_cost = self.llm_instantiate(
|
||||
target_op_config, agent_llm, message_history
|
||||
)
|
||||
|
||||
# Step 2: Apply transformation using the schema
|
||||
return (
|
||||
self.apply(global_default_model, operators, target_ops[0], rewrite),
|
||||
message_history,
|
||||
call_cost,
|
||||
)
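# Illustrative sketch (not part of the original file): assuming the contract
# example above, apply() replaces the single overloaded map with a parallel_map
# followed by an aggregation map, shaped roughly like this (names, prompts,
# and the omitted model field are hypothetical, for illustration only):
example_rewritten_ops = [
    {
        "name": "extract_contract_info_parallel",
        "type": "parallel_map",
        "litellm_completion_kwargs": {"temperature": 0},
        "prompts": [
            {
                "name": "Extract Basic Contract Info",
                "output_keys": ["subtask_1_output"],
                "prompt": "Extract party names and agreement date from: {{ input.document }}",
            },
            {
                "name": "Extract Legal Terms",
                "output_keys": ["subtask_2_output"],
                "prompt": "Extract governing law and termination clauses from: {{ input.document }}",
            },
        ],
        "output": {"schema": {"subtask_1_output": "string", "subtask_2_output": "string"}},
    },
    {
        "name": "extract_contract_info_aggregate",
        "type": "map",
        "prompt": "Combine the basic info {{ input.subtask_1_output }} with legal "
        "terms {{ input.subtask_2_output }} into the final contract summary.",
        "litellm_completion_kwargs": {"temperature": 0},
        "output": {
            "schema": {
                "parties": "string",
                "agreement_date": "string",
                "governing_law": "string",
                "termination_clauses": "string",
            }
        },
    },
]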
@ -1,387 +0,0 @@
import json
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
from litellm import completion
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from docetl.reasoning_optimizer.instantiate_schemas import (
|
||||
MapReduceFusionInstantiateSchema,
|
||||
)
|
||||
|
||||
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
|
||||
|
||||
|
||||
class MapReduceFusionDirective(Directive):
    name: str = Field(
        default="map_reduce_fusion", description="The name of the directive"
    )
    formal_description: str = Field(
        default="Map -> Reduce => Map (with new prompt) -> Reduce (with new prompt)"
    )
    nl_description: str = Field(
        default="Transform a Map operation followed by a Reduce operation over long documents by updating the Map prompt to first extract or process only the relevant information from each document. Then, modify the Reduce prompt to operate on these processed outputs instead of the full document content."
    )
    when_to_use: str = Field(
        default="There is a Map operation followed by a Reduce operation, and the Reduce step iterates over long documents to extract specific information (e.g., locations, entities, themes) that could be pre-extracted per document in the Map step. This makes the Reduce operation more efficient and focused. The target operators must be a Map operator followed by a Reduce operator. When selecting this directive, specify the Map and Reduce operators as the target operators."
    )
    instantiate_schema_type: Type[BaseModel] = MapReduceFusionInstantiateSchema
|
||||
|
||||
example: str = Field(
|
||||
default=(
|
||||
"Original Pipeline:\\n"
|
||||
"Map Op (classify_document):\\n"
|
||||
"- name: classify_document\\n"
|
||||
" type: map\\n"
|
||||
" prompt: |\\n"
|
||||
" Classify the following document into a category:\\n"
|
||||
" {{ input.content }}\\n"
|
||||
" Choose from: news, research, legal, business\\n"
|
||||
" output:\\n"
|
||||
" schema:\\n"
|
||||
" category: str\\n"
|
||||
"\\n"
|
||||
"Reduce Op (extract_organizations):\\n"
|
||||
"- name: extract_organizations\\n"
|
||||
" type: reduce\\n"
|
||||
" reduce_key: category\\n"
|
||||
" prompt: |\\n"
|
||||
' For each category \\"{{ reduce_key }}\\", extract all organization names from these documents:\\n'
|
||||
" {% for input in inputs %}\\n"
|
||||
" Document {{ loop.index }}: {{ input.content }}\\n"
|
||||
" {% endfor %}\\n"
|
||||
" Return a list of unique organization names.\\n"
|
||||
" output:\\n"
|
||||
" schema:\\n"
|
||||
" organizations: list[str]\\n"
|
||||
"\\n"
|
||||
"Example InstantiateSchema Output:\\n"
|
||||
"MapReduceFusionInstantiateSchema(\\n"
|
||||
" new_map_name='fused_classify_extract_organizations',\\n"
|
||||
" new_map_prompt='Analyze the following document:\\n{{ input.content }}\\n\\n1. Classify the document into a category (news, research, legal, business)\\n\\n2. Extract all organization names mentioned in the document\\n\\nProvide both the category and list of organizations.',\\n"
|
||||
" new_key='organizations',\\n"
|
||||
" new_reduce_prompt='For each category \\\"{{ reduce_key }}\\\", combine all organization names from these pre-extracted lists:\\n{% for input in inputs %}\\nOrganizations from document {{ loop.index }}: {{ input.organizations }}\\n{% endfor %}\\nReturn a single deduplicated list of all unique organization names.'\\n"
|
||||
")"
|
||||
),
|
||||
)
|
||||
|
||||
test_cases: List[DirectiveTestCase] = Field(
|
||||
default_factory=lambda: [
|
||||
DirectiveTestCase(
|
||||
name="classification_organization_fusion",
|
||||
description="Should fuse document classification with organization extraction",
|
||||
input_config=[
|
||||
{
|
||||
"name": "classify_document",
|
||||
"type": "map",
|
||||
"prompt": "Classify this document: {{ input.content }}",
|
||||
"output": {"schema": {"category": "str"}},
|
||||
},
|
||||
{
|
||||
"name": "extract_organizations",
|
||||
"type": "reduce",
|
||||
"reduce_key": "category",
|
||||
"prompt": "For each category '{{ reduce_key }}', extract organizations from: {% for input in inputs %}{{ input.content }}{% endfor %}",
|
||||
"output": {"schema": {"organizations": "list[str]"}},
|
||||
},
|
||||
],
|
||||
target_ops=["classify_document", "extract_organizations"],
|
||||
expected_behavior="Should modify map to classify AND extract organizations per document, then reduce to aggregate pre-extracted organizations",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="analysis_entity_fusion",
|
||||
description="Should fuse document analysis with entity extraction per category",
|
||||
input_config=[
|
||||
{
|
||||
"name": "analyze_document",
|
||||
"type": "map",
|
||||
"prompt": "Analyze document type: {{ input.text }}",
|
||||
"output": {"schema": {"doc_type": "str"}},
|
||||
},
|
||||
{
|
||||
"name": "find_people",
|
||||
"type": "reduce",
|
||||
"reduce_key": "doc_type",
|
||||
"prompt": "For each document type '{{ reduce_key }}', find all people mentioned: {% for input in inputs %}Document: {{ input.text }}{% endfor %}",
|
||||
"output": {"schema": {"people": "list[str]"}},
|
||||
},
|
||||
],
|
||||
target_ops=["analyze_document", "find_people"],
|
||||
expected_behavior="Should modify map to analyze type AND extract people per document, then reduce to combine pre-extracted people lists",
|
||||
should_pass=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, MapReduceFusionDirective)
|
||||
|
||||
def __hash__(self):
|
||||
return hash("MapReduceFusionDirective")
|
||||
|
||||
def to_string_for_instantiate(self, original_ops: List[Dict]) -> str:
|
||||
"""
|
||||
Generate a prompt for an agent to instantiate this directive.
|
||||
|
||||
Args:
|
||||
original_ops (List[Dict]): List containing the map and reduce operations.
|
||||
|
||||
Returns:
|
||||
str: The agent prompt for instantiating the directive.
|
||||
"""
|
||||
map_op, reduce_op = original_ops[0], original_ops[1]
|
||||
return (
|
||||
f"You are an expert at optimizing data processing pipelines for efficiency.\\n\\n"
|
||||
f"Original Map Operation:\\n"
|
||||
f"{str(map_op)}\\n\\n"
|
||||
f"Original Reduce Operation:\\n"
|
||||
f"{str(reduce_op)}\\n\\n"
|
||||
f"Directive: {self.name}\\n"
|
||||
f"Your task is to apply map-reduce fusion: modify the Map operation to pre-extract the information that the Reduce operation needs, "
|
||||
f"then update the Reduce operation to work with these pre-extracted results instead of processing full documents.\\n\\n"
|
||||
f"Key Requirements:\\n"
|
||||
f"1. Analyze what specific information the Reduce operation extracts from documents\\n"
|
||||
f"2. Create a new Map prompt that does BOTH the original Map task AND extracts the information needed by Reduce\\n"
|
||||
f"3. Create a new key name for the pre-extracted information that the Reduce will reference\\n"
|
||||
f"4. Create a new Reduce prompt that aggregates the pre-extracted information instead of processing full documents\\n"
|
||||
f"5. The Reduce operation should reference the new key (e.g., {{ input.new_key }}) instead of full document content\\n\\n"
|
||||
f"Output Format:\\n"
|
||||
f"Return a MapReduceFusionInstantiateSchema with:\\n"
|
||||
f"- new_map_name: Combined name for the fused map operation\\n"
|
||||
f"- new_map_prompt: Prompt that does both original map task + extraction\\n"
|
||||
f"- new_key: Key name for the extracted information\\n"
|
||||
f"- new_reduce_prompt: Prompt that works with pre-extracted data\\n\\n"
|
||||
f"Example:\\n"
|
||||
f"{self.example}\\n\\n"
|
||||
f"Please output only the MapReduceFusionInstantiateSchema with the four required fields."
|
||||
)
|
||||
|
||||
def llm_instantiate(
|
||||
self,
|
||||
map_op: Dict,
|
||||
reduce_op: Dict,
|
||||
expected_document_key,
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
):
|
||||
"""
|
||||
Use LLM to instantiate this directive by transforming the map and reduce operations.
|
||||
|
||||
Args:
|
||||
map_op (Dict): The original map operation.
|
||||
reduce_op (Dict): The original reduce operation.
|
||||
agent_llm (str): The LLM model to use.
|
||||
message_history (List, optional): Conversation history for context.
|
||||
|
||||
Returns:
|
||||
MapReduceFusionInstantiateSchema: The structured output from the LLM.
|
||||
"""
|
||||
|
||||
message_history.extend(
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful AI assistant for optimizing document processing pipelines.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.to_string_for_instantiate([map_op, reduce_op]),
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
response_format=MapReduceFusionInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
if (
|
||||
"new_map_name" not in parsed_res
|
||||
or "new_map_prompt" not in parsed_res
|
||||
or "new_key" not in parsed_res
|
||||
or "new_reduce_prompt" not in parsed_res
|
||||
):
|
||||
raise ValueError(
|
||||
"Response from LLM is missing required keys: 'new_map_name', 'new_map_prompt', 'new_key', or 'new_reduce_prompt'"
|
||||
)
|
||||
|
||||
schema = MapReduceFusionInstantiateSchema(
|
||||
new_map_name=parsed_res["new_map_name"],
|
||||
new_map_prompt=parsed_res["new_map_prompt"],
|
||||
new_key=parsed_res["new_key"],
|
||||
new_reduce_prompt=parsed_res["new_reduce_prompt"],
|
||||
)
|
||||
|
||||
# Validate the schema
|
||||
MapReduceFusionInstantiateSchema.validate_reduce_prompt_references_new_key(
|
||||
schema.new_reduce_prompt, schema.new_key, expected_document_key
|
||||
)
|
||||
|
||||
message_history.append(
|
||||
{"role": "assistant", "content": resp.choices[0].message.content}
|
||||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
global_default_model,
|
||||
ops_list: List[Dict],
|
||||
map_target: str,
|
||||
reduce_target: str,
|
||||
rewrite: MapReduceFusionInstantiateSchema,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Apply the directive to the pipeline config.
|
||||
"""
|
||||
# Create a copy of the pipeline config
|
||||
new_ops_list = deepcopy(ops_list)
|
||||
|
||||
# Find positions of the target ops
|
||||
map_pos = None
|
||||
reduce_pos = None
|
||||
orig_map_op = None
|
||||
orig_reduce_op = None
|
||||
|
||||
for i, op in enumerate(ops_list):
|
||||
if op["name"] == map_target:
|
||||
map_pos = i
|
||||
orig_map_op = op
|
||||
elif op["name"] == reduce_target:
|
||||
reduce_pos = i
|
||||
orig_reduce_op = op
|
||||
|
||||
if map_pos is None or reduce_pos is None:
|
||||
raise ValueError(
|
||||
f"Could not find target operations: {map_target}, {reduce_target}"
|
||||
)
|
||||
|
||||
# Get default model
|
||||
default_model = orig_map_op.get("model", global_default_model)
|
||||
|
||||
# Create the new map operation with fused functionality
|
||||
new_map_op = {
|
||||
"name": rewrite.new_map_name,
|
||||
"type": "map",
|
||||
"prompt": rewrite.new_map_prompt,
|
||||
"model": default_model,
|
||||
"litellm_completion_kwargs": {"temperature": 0},
|
||||
"output": {
|
||||
"schema": {
|
||||
**orig_map_op.get("output", {}).get("schema", {}),
|
||||
rewrite.new_key: "list[str]",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# Create the new reduce operation that works with pre-extracted data
|
||||
new_reduce_op = {
|
||||
"name": orig_reduce_op["name"],
|
||||
"type": "reduce",
|
||||
"reduce_key": orig_reduce_op.get("reduce_key"),
|
||||
"prompt": rewrite.new_reduce_prompt,
|
||||
"model": default_model,
|
||||
"litellm_completion_kwargs": {"temperature": 0},
|
||||
"output": orig_reduce_op.get("output", {}),
|
||||
}
|
||||
|
||||
# Replace the operations
|
||||
new_ops_list[map_pos] = new_map_op
|
||||
new_ops_list[reduce_pos] = new_reduce_op
|
||||
|
||||
return new_ops_list
|
||||
|
||||
def instantiate(
|
||||
self,
|
||||
operators: List[Dict],
|
||||
target_ops: List[str],
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
optimize_goal="acc",
|
||||
global_default_model: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Instantiate the directive for a list of operators.
|
||||
"""
|
||||
# Assert that there are exactly two target ops (map and reduce)
|
||||
if len(target_ops) != 2:
|
||||
raise ValueError(
|
||||
"map_reduce_fusion directive requires exactly two target operators: a map and a reduce operation"
|
||||
)
|
||||
|
||||
# Find the map and reduce operations
|
||||
first_op = None
|
||||
second_op = None
|
||||
|
||||
for op in operators:
|
||||
if op["name"] == target_ops[0]:
|
||||
first_op = op
|
||||
elif op["name"] == target_ops[1]:
|
||||
second_op = op
|
||||
|
||||
if first_op is None or second_op is None:
|
||||
raise ValueError(f"Could not find target operations: {target_ops}")
|
||||
|
||||
if first_op.get("type") == "map" and second_op.get("type") == "reduce":
|
||||
map_op = first_op
|
||||
reduce_op = second_op
|
||||
map_target = target_ops[0]
|
||||
reduce_target = target_ops[1]
|
||||
else:
|
||||
raise ValueError(
|
||||
"Target operators must be one map operation followed by one reduce operation!"
|
||||
)
|
||||
|
||||
# Extract expected document key from the reduce prompt template
|
||||
prompt_template = reduce_op["prompt"]
|
||||
# Find all occurrences of {{ input.key }} in the prompt
|
||||
input_key_pattern = r"\{\{\s*([^\}\s]+)\s*\}\}"
|
||||
input_keys = list(set(re.findall(input_key_pattern, prompt_template)))
|
||||
print("input_keys: ", input_keys)
|
||||
# Heuristic: pick the key that's most likely to contain document content
|
||||
# Look for common document field names
|
||||
document_key_candidates = [
|
||||
key
|
||||
for key in input_keys
|
||||
if any(
|
||||
doc_word in key.lower()
|
||||
for doc_word in ["document", "text", "content", "body", "description"]
|
||||
)
|
||||
]
|
||||
|
||||
if document_key_candidates:
|
||||
expected_document_key = document_key_candidates[
|
||||
0
|
||||
] # Pick the first candidate
|
||||
elif input_keys:
|
||||
expected_document_key = input_keys[0] # Fall back to the first input key
|
||||
else:
|
||||
raise ValueError("No input keys found in the reduce operation prompt")
|
||||
|
||||
print(f"Detected document key: {expected_document_key}")
|
||||
|
||||
# Instantiate the directive
|
||||
rewrite, message_history, call_cost = self.llm_instantiate(
|
||||
map_op, reduce_op, expected_document_key, agent_llm, message_history
|
||||
)
|
||||
|
||||
# Apply the rewrite to the operators
|
||||
new_ops_plan = self.apply(
|
||||
global_default_model, operators, map_target, reduce_target, rewrite
|
||||
)
|
||||
return new_ops_plan, message_history, call_cost
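# Illustrative sketch (not part of the original file): assuming the
# classify_document -> extract_organizations example above, apply() would
# rewrite the two target ops roughly as follows (names, prompts, and the
# omitted model field are hypothetical, for illustration only):
example_fused_ops = [
    {
        "name": "fused_classify_extract_organizations",
        "type": "map",
        "prompt": "Classify this document into a category AND extract all "
        "organization names mentioned: {{ input.content }}",
        "litellm_completion_kwargs": {"temperature": 0},
        # Original map schema plus the new pre-extracted key, typed list[str]
        "output": {"schema": {"category": "str", "organizations": "list[str]"}},
    },
    {
        "name": "extract_organizations",
        "type": "reduce",
        "reduce_key": "category",
        "prompt": "For each category '{{ reduce_key }}', combine the pre-extracted "
        "lists: {% for input in inputs %}{{ input.organizations }}{% endfor %}",
        "litellm_completion_kwargs": {"temperature": 0},
        "output": {"schema": {"organizations": "list[str]"}},
    },
]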
@ -1,361 +0,0 @@
import json
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
from litellm import completion
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from docetl.reasoning_optimizer.instantiate_schemas import (
|
||||
MapResolveToMapWithCategoriesInstantiateSchema,
|
||||
)
|
||||
|
||||
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
|
||||
|
||||
|
||||
class MapResolveToMapWithCategoriesDirective(Directive):
    name: str = Field(
        default="map_resolve_to_map_with_categories",
        description="The name of the directive",
    )
    formal_description: str = Field(default="Map -> Resolve => Map (with categories)")
    nl_description: str = Field(
        default="Replace a Map -> Resolve pattern with a single Map operation that has predefined categories. The agent analyzes the data and task to propose a set of canonical categories, and the new Map forces outputs into one of these categories (or 'none of the above'). This effectively performs entity resolution deterministically by standardizing outputs upfront, avoiding the need for pairwise comparisons."
    )
    when_to_use: str = Field(
        default="When a Map operation produces outputs that are then resolved/deduplicated, and the set of valid output categories can be enumerated upfront. This is more efficient than Resolve when the category space is small and well-defined (e.g., standardizing company types, product categories, sentiment labels). The target must be a Map operation followed by a Resolve operation."
    )
    instantiate_schema_type: Type[BaseModel] = (
        MapResolveToMapWithCategoriesInstantiateSchema
    )
|
||||
|
||||
example: str = Field(
|
||||
default=(
|
||||
"Original Pipeline:\n"
|
||||
"- name: extract_sentiment\n"
|
||||
" type: map\n"
|
||||
" prompt: |\n"
|
||||
" What is the sentiment of this review?\n"
|
||||
" {{ input.review }}\n"
|
||||
" output:\n"
|
||||
" schema:\n"
|
||||
" sentiment: string\n"
|
||||
"\n"
|
||||
"- name: normalize_sentiment\n"
|
||||
" type: resolve\n"
|
||||
" comparison_prompt: |\n"
|
||||
" Are these sentiments equivalent?\n"
|
||||
" Sentiment 1: {{ input1.sentiment }}\n"
|
||||
" Sentiment 2: {{ input2.sentiment }}\n"
|
||||
" resolution_prompt: |\n"
|
||||
" Normalize these sentiments:\n"
|
||||
" {% for input in inputs %}\n"
|
||||
" - {{ input.sentiment }}\n"
|
||||
" {% endfor %}\n"
|
||||
" output:\n"
|
||||
" schema:\n"
|
||||
" sentiment: string\n"
|
||||
"\n"
|
||||
"Example InstantiateSchema:\n"
|
||||
"MapResolveToMapWithCategoriesInstantiateSchema(\n"
|
||||
" categories=['Positive', 'Negative', 'Neutral', 'Mixed'],\n"
|
||||
" category_key='sentiment',\n"
|
||||
" new_prompt='''Analyze the sentiment of this review and classify it into one of the following categories:\n"
|
||||
"\n"
|
||||
"Categories:\n"
|
||||
"- Positive: Clearly positive sentiment, satisfaction, praise\n"
|
||||
"- Negative: Clearly negative sentiment, complaints, criticism\n"
|
||||
"- Neutral: No strong sentiment, factual statements\n"
|
||||
"- Mixed: Contains both positive and negative elements\n"
|
||||
"- None of the above: If the review doesn't fit any category\n"
|
||||
"\n"
|
||||
"Review: {{ input.review }}\n"
|
||||
"\n"
|
||||
"Return exactly one of: Positive, Negative, Neutral, Mixed, or None of the above.''',\n"
|
||||
" include_none_of_above=True,\n"
|
||||
")"
|
||||
),
|
||||
)
|
||||
|
||||
test_cases: List[DirectiveTestCase] = Field(
|
||||
default_factory=lambda: [
|
||||
DirectiveTestCase(
|
||||
name="sentiment_categorization",
|
||||
description="Should replace map+resolve with categorized map for sentiment",
|
||||
input_config=[
|
||||
{
|
||||
"name": "extract_sentiment",
|
||||
"type": "map",
|
||||
"prompt": "What is the sentiment? {{ input.text }}",
|
||||
"output": {"schema": {"sentiment": "string"}},
|
||||
},
|
||||
{
|
||||
"name": "normalize_sentiment",
|
||||
"type": "resolve",
|
||||
"comparison_prompt": "Same sentiment? {{ input1.sentiment }} vs {{ input2.sentiment }}",
|
||||
"resolution_prompt": "Normalize: {% for input in inputs %}{{ input.sentiment }}{% endfor %}",
|
||||
"output": {"schema": {"sentiment": "string"}},
|
||||
},
|
||||
],
|
||||
target_ops=["extract_sentiment", "normalize_sentiment"],
|
||||
expected_behavior="Should create a single map with predefined sentiment categories",
|
||||
should_pass=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, MapResolveToMapWithCategoriesDirective)
|
||||
|
||||
def __hash__(self):
|
||||
return hash("MapResolveToMapWithCategoriesDirective")
|
||||
|
||||
def to_string_for_instantiate(
|
||||
self, map_op: Dict, resolve_op: Dict, sample_data: List[Dict] = None
|
||||
) -> str:
|
||||
"""
|
||||
Generate a prompt for an agent to instantiate this directive.
|
||||
|
||||
Args:
|
||||
map_op (Dict): The map operation configuration.
|
||||
resolve_op (Dict): The resolve operation configuration.
|
||||
sample_data (List[Dict], optional): Sample data to help identify categories.
|
||||
|
||||
Returns:
|
||||
str: The agent prompt for instantiating the directive.
|
||||
"""
|
||||
sample_str = ""
|
||||
if sample_data:
|
||||
sample_str = f"\n\nSample Input Data (first 5 items):\n{json.dumps(sample_data[:5], indent=2)}"
|
||||
|
||||
return (
|
||||
f"You are an expert at optimizing data processing pipelines by replacing entity resolution with categorical constraints.\n\n"
|
||||
f"Map Operation:\n"
|
||||
f"{str(map_op)}\n\n"
|
||||
f"Resolve Operation:\n"
|
||||
f"{str(resolve_op)}\n\n"
|
||||
f"Directive: {self.name}\n"
|
||||
f"Your task is to replace the Map -> Resolve pattern with a single Map that uses predefined categories.\n\n"
|
||||
f"Key Requirements:\n"
|
||||
f"1. Analyze the map's output field and the resolve operation to understand what values are being normalized:\n"
|
||||
f" - Look at the comparison_prompt to understand what variations are being matched\n"
|
||||
f" - Look at the resolution_prompt to understand the canonical form\n\n"
|
||||
f"2. Propose a set of categories that cover all expected outputs:\n"
|
||||
f" - Categories should be mutually exclusive\n"
|
||||
f" - Categories should be exhaustive (cover all realistic cases)\n"
|
||||
f" - Consider including 'None of the above' for edge cases\n\n"
|
||||
f"3. Optionally provide descriptions for each category to help the LLM classify correctly\n\n"
|
||||
f"4. Create a new_prompt that:\n"
|
||||
f" - Lists all valid categories with their descriptions\n"
|
||||
f" - Instructs the LLM to output exactly one category\n"
|
||||
f" - References the input using {{{{ input.key }}}} syntax\n"
|
||||
f" - Includes any context from the original map prompt\n\n"
|
||||
f"5. Identify the category_key (the output field that will contain the category)\n\n"
|
||||
f"Benefits of this approach:\n"
|
||||
f"- Eliminates O(n^2) pairwise comparisons from Resolve\n"
|
||||
f"- Produces consistent, standardized outputs\n"
|
||||
f"- Reduces cost by removing the Resolve operation entirely\n"
|
||||
f"{sample_str}\n\n"
|
||||
f"Example:\n"
|
||||
f"{self.example}\n\n"
|
||||
f"Please analyze the operations and propose appropriate categories. Output the MapResolveToMapWithCategoriesInstantiateSchema."
|
||||
)
|
||||
|
||||
def llm_instantiate(
|
||||
self,
|
||||
map_op: Dict,
|
||||
resolve_op: Dict,
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
sample_data: List[Dict] = None,
|
||||
):
|
||||
"""
|
||||
Use LLM to instantiate this directive.
|
||||
|
||||
Args:
|
||||
map_op (Dict): The map operation configuration.
|
||||
resolve_op (Dict): The resolve operation configuration.
|
||||
agent_llm (str): The LLM model to use.
|
||||
message_history (List, optional): Conversation history for context.
|
||||
sample_data (List[Dict], optional): Sample data to help identify categories.
|
||||
|
||||
Returns:
|
||||
MapResolveToMapWithCategoriesInstantiateSchema: The structured output from the LLM.
|
||||
"""
|
||||
|
||||
message_history.extend(
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful AI assistant for document processing pipelines specializing in categorical classification.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.to_string_for_instantiate(
|
||||
map_op, resolve_op, sample_data
|
||||
),
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
response_format=MapResolveToMapWithCategoriesInstantiateSchema,
|
||||
)
|
||||
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = MapResolveToMapWithCategoriesInstantiateSchema(**parsed_res)
|
||||
|
||||
message_history.append(
|
||||
{"role": "assistant", "content": resp.choices[0].message.content}
|
||||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
global_default_model,
|
||||
ops_list: List[Dict],
|
||||
map_op_name: str,
|
||||
resolve_op_name: str,
|
||||
rewrite: MapResolveToMapWithCategoriesInstantiateSchema,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Apply the directive to the pipeline config.
|
||||
"""
|
||||
new_ops_list = deepcopy(ops_list)
|
||||
|
||||
# Find position of the map and resolve ops
|
||||
map_pos = None
|
||||
resolve_pos = None
|
||||
map_op = None
|
||||
|
||||
for i, op in enumerate(new_ops_list):
|
||||
if op["name"] == map_op_name:
|
||||
map_pos = i
|
||||
map_op = op
|
||||
elif op["name"] == resolve_op_name:
|
||||
resolve_pos = i
|
||||
|
||||
if map_pos is None or resolve_pos is None:
|
||||
raise ValueError(
|
||||
f"Could not find map '{map_op_name}' and resolve '{resolve_op_name}' operations"
|
||||
)
|
||||
|
||||
# Determine the model to use
|
||||
default_model = map_op.get("model", global_default_model)
|
||||
|
||||
# Build the list of valid values for validation
|
||||
valid_values = list(rewrite.categories)
|
||||
if rewrite.include_none_of_above:
|
||||
valid_values.append("None of the above")
|
||||
|
||||
# Modify the map operation with the new prompt and add validation
|
||||
new_ops_list[map_pos]["prompt"] = rewrite.new_prompt
|
||||
new_ops_list[map_pos]["model"] = default_model
|
||||
|
||||
# Add validation to ensure output is one of the categories
|
||||
if "validate" not in new_ops_list[map_pos]:
|
||||
new_ops_list[map_pos]["validate"] = []
|
||||
|
||||
# Add validation rule for the category key
|
||||
validation_rule = f"output['{rewrite.category_key}'] in {valid_values}"
|
||||
new_ops_list[map_pos]["validate"].append(validation_rule)
|
||||
|
||||
# Update the output schema to reflect the category key
|
||||
if "output" not in new_ops_list[map_pos]:
|
||||
new_ops_list[map_pos]["output"] = {"schema": {}}
|
||||
new_ops_list[map_pos]["output"]["schema"][rewrite.category_key] = "string"
|
||||
|
||||
# Remove the resolve operation
|
||||
new_ops_list.pop(resolve_pos)
|
||||
|
||||
return new_ops_list
|
||||
|
||||
def instantiate(
|
||||
self,
|
||||
operators: List[Dict],
|
||||
target_ops: List[str],
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
optimize_goal="acc",
|
||||
global_default_model: str = None,
|
||||
dataset: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Instantiate the directive for a list of operators.
|
||||
"""
|
||||
# Assert that there are exactly two target ops (map and resolve)
|
||||
assert (
|
||||
len(target_ops) == 2
|
||||
), "There must be exactly two target ops (map and resolve) to instantiate this directive"
|
||||
|
||||
# Find the map and resolve operations
|
||||
map_op = None
|
||||
resolve_op = None
|
||||
|
||||
for op in operators:
|
||||
if op["name"] == target_ops[0]:
|
||||
if op.get("type") == "map":
|
||||
map_op = op
|
||||
elif op.get("type") == "resolve":
|
||||
resolve_op = op
|
||||
elif op["name"] == target_ops[1]:
|
||||
if op.get("type") == "map":
|
||||
map_op = op
|
||||
elif op.get("type") == "resolve":
|
||||
resolve_op = op
|
||||
|
||||
if map_op is None or resolve_op is None:
|
||||
raise ValueError(
|
||||
f"Could not find both a map and resolve operation in target_ops: {target_ops}"
|
||||
)
|
||||
|
||||
# Verify the map comes before resolve
|
||||
map_idx = next(
|
||||
i for i, op in enumerate(operators) if op["name"] == map_op["name"]
|
||||
)
|
||||
resolve_idx = next(
|
||||
i for i, op in enumerate(operators) if op["name"] == resolve_op["name"]
|
||||
)
|
||||
|
||||
if map_idx >= resolve_idx:
|
||||
raise ValueError(
|
||||
f"Map operation '{map_op['name']}' must come before resolve operation '{resolve_op['name']}'"
|
||||
)
|
||||
|
||||
# Load sample data if available
|
||||
sample_data = None
|
||||
if dataset:
|
||||
try:
|
||||
with open(dataset, "r") as f:
|
||||
sample_data = json.load(f)
|
||||
except Exception:
|
||||
pass # Ignore if we can't load sample data
|
||||
|
||||
# Instantiate the directive
|
||||
rewrite, message_history, call_cost = self.llm_instantiate(
|
||||
map_op,
|
||||
resolve_op,
|
||||
agent_llm,
|
||||
message_history,
|
||||
sample_data,
|
||||
)
|
||||
|
||||
# Apply the rewrite to the operators
|
||||
new_ops_plan = self.apply(
|
||||
global_default_model, operators, map_op["name"], resolve_op["name"], rewrite
|
||||
)
|
||||
return new_ops_plan, message_history, call_cost
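# Illustrative sketch (not part of the original file): assuming the sentiment
# example above, apply() keeps only the map op, swaps in the category prompt,
# and appends a validation rule; the resolve op is dropped. Values shown are
# hypothetical, copied from the example for illustration only:
example_categorized_map = {
    "name": "extract_sentiment",
    "type": "map",
    "prompt": "Classify the sentiment of {{ input.review }} as one of: "
    "Positive, Negative, Neutral, Mixed, or None of the above.",
    "output": {"schema": {"sentiment": "string"}},
    # Added by apply(): output must be one of the proposed categories
    "validate": [
        "output['sentiment'] in ['Positive', 'Negative', 'Neutral', 'Mixed', 'None of the above']"
    ],
}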
@ -1,335 +0,0 @@
import json
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
from litellm import completion
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from docetl.reasoning_optimizer.instantiate_schemas import (
|
||||
MapToMapResolveReduceInstantiateSchema,
|
||||
)
|
||||
|
||||
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
|
||||
|
||||
|
||||
class MapToMapResolveReduceDirective(Directive):
    name: str = Field(
        default="map_to_map_resolve_reduce", description="The name of the directive"
    )
    formal_description: str = Field(default="Map -> Reduce => Map -> Resolve -> Reduce")
    nl_description: str = Field(
        default="Insert a Resolve operation between Map and Reduce to deduplicate or normalize entities before aggregation. The Resolve operation uses code-powered blocking conditions to efficiently identify which pairs to compare, avoiding O(n^2) comparisons. This is useful when the Map output contains duplicate or near-duplicate entities that should be merged before the Reduce step."
    )
    when_to_use: str = Field(
        default="When a Map operation produces outputs that may contain duplicates, variations, or near-duplicates (e.g., different spellings of names, similar categories), and these should be normalized before the Reduce aggregation step. The target must be a Map operation followed by a Reduce operation."
    )
    instantiate_schema_type: Type[BaseModel] = MapToMapResolveReduceInstantiateSchema
|
||||
|
||||
example: str = Field(
|
||||
default=(
|
||||
"Original Pipeline:\n"
|
||||
"- name: extract_companies\n"
|
||||
" type: map\n"
|
||||
" prompt: |\n"
|
||||
" Extract company names from this news article:\n"
|
||||
" {{ input.article }}\n"
|
||||
" output:\n"
|
||||
" schema:\n"
|
||||
" company_name: string\n"
|
||||
"\n"
|
||||
"- name: aggregate_companies\n"
|
||||
" type: reduce\n"
|
||||
" reduce_key: sector\n"
|
||||
" prompt: |\n"
|
||||
" List all unique companies in this sector:\n"
|
||||
" {% for input in inputs %}\n"
|
||||
" - {{ input.company_name }}\n"
|
||||
" {% endfor %}\n"
|
||||
" output:\n"
|
||||
" schema:\n"
|
||||
" companies: list[str]\n"
|
||||
"\n"
|
||||
"Example InstantiateSchema:\n"
|
||||
"MapToMapResolveReduceInstantiateSchema(\n"
|
||||
" resolve_name='normalize_company_names',\n"
|
||||
" comparison_prompt='''Are these two company names referring to the same company?\n"
|
||||
"Company 1: {{ input1.company_name }}\n"
|
||||
"Company 2: {{ input2.company_name }}\n"
|
||||
"Consider variations like abbreviations (IBM vs International Business Machines), \n"
|
||||
"different legal suffixes (Inc, Corp, LLC), and common misspellings.''',\n"
|
||||
" resolution_prompt='''Given these variations of a company name:\n"
|
||||
"{% for input in inputs %}\n"
|
||||
"- {{ input.company_name }}\n"
|
||||
"{% endfor %}\n"
|
||||
"Return the canonical/official company name.''',\n"
|
||||
" blocking_conditions=[\n"
|
||||
" \"input1['company_name'][:3].lower() == input2['company_name'][:3].lower()\",\n"
|
||||
" \"input1['company_name'].split()[0].lower() == input2['company_name'].split()[0].lower()\",\n"
|
||||
" ],\n"
|
||||
" blocking_keys=['company_name'],\n"
|
||||
" limit_comparisons=1000,\n"
|
||||
" output_schema={'company_name': 'string'},\n"
|
||||
")"
|
||||
),
|
||||
)
|
||||
|
||||
test_cases: List[DirectiveTestCase] = Field(
|
||||
default_factory=lambda: [
|
||||
DirectiveTestCase(
|
||||
name="company_name_normalization",
|
||||
description="Should insert resolve between map and reduce for company names",
|
||||
input_config=[
|
||||
{
|
||||
"name": "extract_companies",
|
||||
"type": "map",
|
||||
"prompt": "Extract company name from: {{ input.text }}",
|
||||
"output": {"schema": {"company_name": "string"}},
|
||||
},
|
||||
{
|
||||
"name": "aggregate_by_sector",
|
||||
"type": "reduce",
|
||||
"reduce_key": "sector",
|
||||
"prompt": "List companies:\n{% for input in inputs %}\n- {{ input.company_name }}\n{% endfor %}",
|
||||
"output": {"schema": {"companies": "list[str]"}},
|
||||
},
|
||||
],
|
||||
target_ops=["extract_companies", "aggregate_by_sector"],
|
||||
expected_behavior="Should create a resolve operation between the map and reduce with appropriate blocking conditions",
|
||||
should_pass=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, MapToMapResolveReduceDirective)
|
||||
|
||||
def __hash__(self):
|
||||
return hash("MapToMapResolveReduceDirective")
|
||||
|
||||
def to_string_for_instantiate(self, map_op: Dict, reduce_op: Dict) -> str:
|
||||
"""
|
||||
Generate a prompt for an agent to instantiate this directive.
|
||||
|
||||
Args:
|
||||
map_op (Dict): The map operation configuration.
|
||||
reduce_op (Dict): The reduce operation configuration.
|
||||
|
||||
Returns:
|
||||
str: The agent prompt for instantiating the directive.
|
||||
"""
|
||||
return (
|
||||
f"You are an expert at optimizing data processing pipelines by inserting entity resolution steps.\n\n"
|
||||
f"Map Operation:\n"
|
||||
f"{str(map_op)}\n\n"
|
||||
f"Reduce Operation:\n"
|
||||
f"{str(reduce_op)}\n\n"
|
||||
f"Directive: {self.name}\n"
|
||||
f"Your task is to insert a Resolve operation between the Map and Reduce to deduplicate/normalize entities.\n\n"
|
||||
f"Key Requirements:\n"
|
||||
f"1. Create a comparison_prompt that determines if two items from the Map output are duplicates/variations:\n"
|
||||
f" - Must reference {{ input1.key }} and {{ input2.key }} for comparing fields\n"
|
||||
f" - Should handle common variations (abbreviations, misspellings, formatting differences)\n\n"
|
||||
f"2. Create a resolution_prompt that merges matched items into a canonical form:\n"
|
||||
f" - Must use {{% for input in inputs %}} to iterate over matched items\n"
|
||||
f" - Should produce the most authoritative/complete version\n\n"
|
||||
f"3. Create blocking_conditions to avoid O(n^2) comparisons:\n"
|
||||
f" - These are Python expressions with access to 'input1' and 'input2' dicts\n"
|
||||
f" - They should filter pairs to only those likely to match\n"
|
||||
f" - Examples:\n"
|
||||
f" * \"input1['name'][:3].lower() == input2['name'][:3].lower()\" (first 3 chars match)\n"
|
||||
f" * \"input1['name'].split()[0].lower() == input2['name'].split()[0].lower()\" (first word matches)\n"
|
||||
f" * \"abs(len(input1['name']) - len(input2['name'])) < 10\" (similar length)\n"
|
||||
f" - Multiple conditions are OR'd together\n\n"
|
||||
f"4. Set limit_comparisons to cap the number of pairs (recommended: 500-2000)\n\n"
|
||||
f"5. The output_schema should match what the Reduce operation expects from each input\n\n"
|
||||
f"Example:\n"
|
||||
f"{self.example}\n\n"
|
||||
f"Please output the MapToMapResolveReduceInstantiateSchema."
|
||||
)
|
||||
|
||||
def llm_instantiate(
|
||||
self, map_op: Dict, reduce_op: Dict, agent_llm: str, message_history: list = []
|
||||
):
|
||||
"""
|
||||
Use LLM to instantiate this directive.
|
||||
|
||||
Args:
|
||||
map_op (Dict): The map operation configuration.
|
||||
reduce_op (Dict): The reduce operation configuration.
|
||||
agent_llm (str): The LLM model to use.
|
||||
message_history (List, optional): Conversation history for context.
|
||||
|
||||
Returns:
|
||||
MapToMapResolveReduceInstantiateSchema: The structured output from the LLM.
|
||||
"""
|
||||
|
||||
message_history.extend(
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful AI assistant for document processing pipelines specializing in entity resolution.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.to_string_for_instantiate(map_op, reduce_op),
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
response_format=MapToMapResolveReduceInstantiateSchema,
|
||||
)
|
||||
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = MapToMapResolveReduceInstantiateSchema(**parsed_res)
|
||||
|
||||
message_history.append(
|
||||
{"role": "assistant", "content": resp.choices[0].message.content}
|
||||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
global_default_model,
|
||||
ops_list: List[Dict],
|
||||
map_op_name: str,
|
||||
reduce_op_name: str,
|
||||
rewrite: MapToMapResolveReduceInstantiateSchema,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Apply the directive to the pipeline config.
|
||||
"""
|
||||
new_ops_list = deepcopy(ops_list)
|
||||
|
||||
# Find position of the map and reduce ops
|
||||
map_pos = None
|
||||
reduce_pos = None
|
||||
map_op = None
|
||||
|
||||
for i, op in enumerate(ops_list):
|
||||
if op["name"] == map_op_name:
|
||||
map_pos = i
|
||||
map_op = op
|
||||
elif op["name"] == reduce_op_name:
|
||||
reduce_pos = i
|
||||
|
||||
if map_pos is None or reduce_pos is None:
|
||||
raise ValueError(
|
||||
f"Could not find map '{map_op_name}' and reduce '{reduce_op_name}' operations"
|
||||
)
|
||||
|
||||
# Determine the model to use
|
||||
default_model = map_op.get("model", global_default_model)
|
||||
|
||||
# Find the reduce operation to get the reduce_key
|
||||
reduce_op = None
|
||||
for op in ops_list:
|
||||
if op["name"] == reduce_op_name:
|
||||
reduce_op = op
|
||||
break
|
||||
|
||||
# Derive output schema from reduce_key - that's what's being grouped/resolved
|
||||
reduce_key = reduce_op.get("reduce_key", []) if reduce_op else []
|
||||
if isinstance(reduce_key, str):
|
||||
reduce_key = [reduce_key]
|
||||
|
||||
# Build output schema from reduce_key fields
|
||||
output_schema = {key: "string" for key in reduce_key}
|
||||
|
||||
# Create the resolve operation
|
||||
resolve_op = {
|
||||
"name": rewrite.resolve_name,
|
||||
"type": "resolve",
|
||||
"comparison_prompt": rewrite.comparison_prompt,
|
||||
"resolution_prompt": rewrite.resolution_prompt,
|
||||
"blocking_conditions": rewrite.blocking_conditions,
|
||||
"blocking_keys": rewrite.blocking_keys,
|
||||
"limit_comparisons": rewrite.limit_comparisons,
|
||||
"model": default_model,
|
||||
"output": {"schema": output_schema},
|
||||
}
|
||||
|
||||
# Insert resolve operation after the map operation
|
||||
new_ops_list.insert(map_pos + 1, resolve_op)
|
||||
|
||||
return new_ops_list
|
||||
|
||||
def instantiate(
|
||||
self,
|
||||
operators: List[Dict],
|
||||
target_ops: List[str],
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
optimize_goal="acc",
|
||||
global_default_model: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Instantiate the directive for a list of operators.
|
||||
"""
|
||||
# Assert that there are exactly two target ops (map and reduce)
|
||||
assert (
|
||||
len(target_ops) == 2
|
||||
), "There must be exactly two target ops (map and reduce) to instantiate this directive"
|
||||
|
||||
# Find the map and reduce operations
|
||||
map_op = None
|
||||
reduce_op = None
|
||||
|
||||
for op in operators:
|
||||
if op["name"] == target_ops[0]:
|
||||
if op.get("type") == "map":
|
||||
map_op = op
|
||||
elif op.get("type") == "reduce":
|
||||
reduce_op = op
|
||||
elif op["name"] == target_ops[1]:
|
||||
if op.get("type") == "map":
|
||||
map_op = op
|
||||
elif op.get("type") == "reduce":
|
||||
reduce_op = op
|
||||
|
||||
if map_op is None or reduce_op is None:
|
||||
raise ValueError(
|
||||
f"Could not find both a map and reduce operation in target_ops: {target_ops}"
|
||||
)
|
||||
|
||||
# Verify the map comes before reduce
|
||||
map_idx = next(
|
||||
i for i, op in enumerate(operators) if op["name"] == map_op["name"]
|
||||
)
|
||||
reduce_idx = next(
|
||||
i for i, op in enumerate(operators) if op["name"] == reduce_op["name"]
|
||||
)
|
||||
|
||||
if map_idx >= reduce_idx:
|
||||
raise ValueError(
|
||||
f"Map operation '{map_op['name']}' must come before reduce operation '{reduce_op['name']}'"
|
||||
)
|
||||
|
||||
# Instantiate the directive
|
||||
rewrite, message_history, call_cost = self.llm_instantiate(
|
||||
map_op,
|
||||
reduce_op,
|
||||
agent_llm,
|
||||
message_history,
|
||||
)
|
||||
|
||||
# Apply the rewrite to the operators
|
||||
new_ops_plan = self.apply(
|
||||
global_default_model, operators, map_op["name"], reduce_op["name"], rewrite
|
||||
)
|
||||
return new_ops_plan, message_history, call_cost
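# Illustrative sketch (not part of the original file): assuming the
# extract_companies -> aggregate_by_sector example above, apply() leaves the
# map and reduce in place and inserts a resolve op between them, shaped
# roughly like this (prompts and conditions are hypothetical, copied from the
# example for illustration only):
example_inserted_resolve = {
    "name": "normalize_company_names",
    "type": "resolve",
    "comparison_prompt": "Are these the same company? "
    "{{ input1.company_name }} vs {{ input2.company_name }}",
    "resolution_prompt": "Return the canonical name for: "
    "{% for input in inputs %}{{ input.company_name }} {% endfor %}",
    "blocking_conditions": [
        "input1['company_name'][:3].lower() == input2['company_name'][:3].lower()"
    ],
    "blocking_keys": ["company_name"],
    "limit_comparisons": 1000,
    # Output schema is derived from the downstream reduce_key ("sector"),
    # since that is the field the reduce groups on.
    "output": {"schema": {"sector": "string"}},
}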
@ -1,329 +0,0 @@
import json
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
from litellm import completion
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from docetl.reasoning_optimizer.instantiate_schemas import (
|
||||
OperatorFusionInstantiateSchema,
|
||||
)
|
||||
|
||||
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
|
||||
|
||||
|
||||
class OperatorFusionDirective(Directive):
    name: str = Field(
        default="operator_fusion", description="The name of the directive"
    )
    formal_description: str = Field(default="Op1 -> Op2 => Op2")
    nl_description: str = Field(
        default="Combines two sequential operations into a single operation to reduce LLM processing costs by avoiding duplicate document reads and API calls"
    )
    when_to_use: str = Field(
        default="When you have two sequential LLM operations processing the same document keys and want to optimize cost by combining them into one operation that performs both tasks. The target operators should be two consecutive operations. Make sure you specify two operators when choosing this directive."
    )

    instantiate_schema_type: Type[BaseModel] = Field(
        default=OperatorFusionInstantiateSchema
    )
|
||||
|
||||
example: str = Field(
|
||||
default="""
|
||||
Example 1 - Map + Filter Fusion:
|
||||
Original: extract_sentiment (map) → filter_positive (filter)
|
||||
Agent output: {"fused_prompt": "Extract sentiment from {{ input.review }} AND determine if it's positive (true/false)"}
|
||||
|
||||
Example 2 - Map + Map Fusion:
|
||||
Original: extract_entities (map) → classify_urgency (map)
|
||||
Agent output: {"fused_prompt": "Extract entities from {{ input.text }} AND classify urgency level"}
|
||||
|
||||
Example 3 - Map + Reduce Fusion:
|
||||
Original: extract_themes (map) → summarize_themes (reduce)
|
||||
Agent output: {"fused_prompt": "For each group of feedback, extract themes and summarize them: {% for item in inputs %}{{ item.feedback }}{% endfor %}"}
|
||||
"""
|
||||
)
|
||||
|
||||
test_cases: List[DirectiveTestCase] = Field(
|
||||
default_factory=lambda: [
|
||||
DirectiveTestCase(
|
||||
name="sentiment_analysis_fusion",
|
||||
description="Fuse sentiment extraction + quality filter",
|
||||
input_config=[
|
||||
{
|
||||
"name": "extract_sentiment",
|
||||
"type": "map",
|
||||
"prompt": "What is the sentiment of this review? {{ input.review_text }}",
|
||||
"output": {
|
||||
"schema": {"sentiment": "string", "confidence": "float"}
|
||||
},
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
{
|
||||
"name": "filter_confident",
|
||||
"type": "filter",
|
||||
"prompt": "Is this sentiment confident (>0.8)? Confidence: {{ input.confidence }}",
|
||||
"output": {"schema": {"is_confident": "boolean"}},
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
],
|
||||
target_ops=["extract_sentiment", "filter_confident"],
|
||||
expected_behavior="Should create map + code_filter for sentiment extraction with confidence filtering",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="document_processing_fusion",
|
||||
description="Fuse document classification + summarization",
|
||||
input_config=[
|
||||
{
|
||||
"name": "classify_doc",
|
||||
"type": "map",
|
||||
"prompt": "Classify this document: {{ input.content }}",
|
||||
"output": {"schema": {"category": "string"}},
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
{
|
||||
"name": "summarize_doc",
|
||||
"type": "map",
|
||||
"prompt": "Summarize this document: {{ input.content }}",
|
||||
"output": {"schema": {"summary": "string"}},
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
],
|
||||
target_ops=["classify_doc", "summarize_doc"],
|
||||
expected_behavior="Should create single map combining classification and summarization",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="extract_and_aggregate_fusion",
|
||||
description="Fuse entity extraction + aggregation",
|
||||
input_config=[
|
||||
{
|
||||
"name": "extract_mentions",
|
||||
"type": "map",
|
||||
"prompt": "Extract company mentions from: {{ input.article }}",
|
||||
"output": {"schema": {"companies": "list[str]"}},
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
{
|
||||
"name": "count_mentions",
|
||||
"type": "reduce",
|
||||
"reduce_key": "topic",
|
||||
"prompt": "Count company mentions: {% for item in inputs %}{{ item.companies }}{% endfor %}",
|
||||
"output": {"schema": {"mention_counts": "str"}},
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
],
|
||||
target_ops=["extract_mentions", "count_mentions"],
|
||||
expected_behavior="Should create single reduce combining extraction and counting",
|
||||
should_pass=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, OperatorFusionDirective)
|
||||
|
||||
def __hash__(self):
|
||||
return hash("OperatorFusionDirective")
|
||||
|
||||
def to_string_for_instantiate(self, original_ops: List[Dict]) -> str:
|
||||
"""
|
||||
Generate a prompt that asks the agent to output the instantiate schema.
|
||||
"""
|
||||
op1, op2 = original_ops[0], original_ops[1]
|
||||
|
||||
return (
|
||||
f"You are an expert at optimizing data processing pipelines for cost efficiency.\n\n"
|
||||
f"Two Sequential Operations:\n"
|
||||
f"Operation 1: {op1}\n"
|
||||
f"Operation 2: {op2}\n\n"
|
||||
f"Your task is to fuse these two operations into a single operation that performs both tasks, "
|
||||
f"reducing LLM API calls and processing costs.\n\n"
|
||||
f"Create a combined prompt that:\n"
|
||||
f"1. Performs the logic of both operations in one LLM call\n"
|
||||
f"2. Uses the same input references as the original operations\n"
|
||||
f"3. If either operation is a filter, include boolean logic for filtering\n"
|
||||
f"4. Maintains the same output schema requirements\n\n"
|
||||
f"IMPORTANT: If either operation is a filter, your fused prompt MUST include logic that "
|
||||
f"outputs a boolean field for filtering purposes. A code_filter will be automatically added.\n\n"
|
||||
f"Example outputs:\n"
|
||||
f"{self.example}\n\n"
|
||||
f"Please output only the OperatorFusionInstantiateSchema with 'fused_prompt' field "
|
||||
f"that specifies how to combine these operations efficiently."
|
||||
)
|
||||
|
||||
def llm_instantiate(
|
||||
self,
|
||||
original_ops: List[Dict],
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
) -> tuple:
|
||||
"""
|
||||
Call the LLM to generate the instantiate schema.
|
||||
"""
|
||||
message_history.extend(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.to_string_for_instantiate(original_ops),
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
response_format=OperatorFusionInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = OperatorFusionInstantiateSchema(**parsed_res)
|
||||
message_history.append(
|
||||
{"role": "assistant", "content": resp.choices[0].message.content}
|
||||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
global_default_model: str,
|
||||
ops_list: List[Dict],
|
||||
target_ops: List[str],
|
||||
rewrite: OperatorFusionInstantiateSchema,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Apply the fusion directive by combining two operations and optionally adding code_filter.
|
||||
"""
|
||||
assert (
|
||||
len(target_ops) == 2
|
||||
), "Operator fusion requires exactly two target operations"
|
||||
|
||||
new_ops_list = deepcopy(ops_list)
|
||||
op1_name, op2_name = target_ops[0], target_ops[1]
|
||||
|
||||
# Find the operations
|
||||
op1_idx = next(i for i, op in enumerate(ops_list) if op["name"] == op1_name)
|
||||
op2_idx = next(i for i, op in enumerate(ops_list) if op["name"] == op2_name)
|
||||
op1, op2 = ops_list[op1_idx], ops_list[op2_idx]
|
||||
|
||||
# Determine fused operation type and schema based on the combination
|
||||
op1_type, op2_type = op1.get("type"), op2.get("type")
|
||||
|
||||
assert (
|
||||
op1_type != "reduce" and op2_type != "reduce"
|
||||
), "Cannot apply fusion on reduce"
|
||||
|
||||
default_model = op1.get("model", global_default_model)
|
||||
|
||||
# Create base fused operation
|
||||
fused_op = {
|
||||
"name": f"fused_{op1_name}_{op2_name}",
|
||||
"prompt": rewrite.fused_prompt,
|
||||
"model": default_model,
|
||||
"litellm_completion_kwargs": {"temperature": 0},
|
||||
}
|
||||
|
||||
needs_code_filter = False
|
||||
|
||||
# Determine type, schema, and code_filter need based on combination
|
||||
if op1_type == "map" and op2_type == "map":
|
||||
# map + map => fuse into one map
|
||||
fused_op["type"] = "map"
|
||||
fused_op["output"] = {
|
||||
"schema": {**op1["output"]["schema"], **op2["output"]["schema"]}
|
||||
}
|
||||
needs_code_filter = False
|
||||
elif (op1_type == "map" and op2_type == "filter") or (
|
||||
op1_type == "filter" and op2_type == "map"
|
||||
):
|
||||
# map + filter OR filter + map => fuse into map (with union of schemas) + code filter
|
||||
fused_op["type"] = "map"
|
||||
fused_op["output"] = {
|
||||
"schema": {**op1["output"]["schema"], **op2["output"]["schema"]}
|
||||
}
|
||||
needs_code_filter = True
|
||||
|
||||
elif op1_type == "filter" and op2_type == "filter":
|
||||
# filter + filter => fuse into one filter with bool output
|
||||
fused_op["type"] = "filter"
|
||||
fused_op["output"] = {"schema": {"_bool": "bool"}}
|
||||
needs_code_filter = False
|
||||
|
||||
# Replace the original operations
|
||||
if op1_idx < op2_idx:
|
||||
# Remove in reverse order to maintain indices
|
||||
new_ops_list.pop(op2_idx)
|
||||
new_ops_list.pop(op1_idx)
|
||||
new_ops_list.insert(op1_idx, fused_op)
|
||||
else:
|
||||
new_ops_list.pop(op1_idx)
|
||||
new_ops_list.pop(op2_idx)
|
||||
new_ops_list.insert(op2_idx, fused_op)
|
||||
|
||||
# Add code_filter if needed
|
||||
if needs_code_filter:
|
||||
# Get the filter field name from the filter operation
|
||||
filter_op = op1 if op1.get("type") == "filter" else op2
|
||||
filter_field = list(filter_op["output"]["schema"].keys())[0]
|
||||
|
||||
code_filter_op = {
|
||||
"name": f"filter_{fused_op['name']}",
|
||||
"type": "code_filter",
|
||||
"code": f"""
|
||||
def transform(input_doc):
|
||||
return input_doc.get('{filter_field}', False)
|
||||
""",
|
||||
}
|
||||
|
||||
# Insert code_filter after the fused operation
|
||||
fused_idx = next(
|
||||
i for i, op in enumerate(new_ops_list) if op["name"] == fused_op["name"]
|
||||
)
|
||||
new_ops_list.insert(fused_idx + 1, code_filter_op)
|
||||
|
||||
return new_ops_list
|
||||
|
||||
def instantiate(
|
||||
self,
|
||||
operators: List[Dict],
|
||||
target_ops: List[str],
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
global_default_model: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Main method that orchestrates directive instantiation.
|
||||
"""
|
||||
assert (
|
||||
len(target_ops) == 2
|
||||
), "Operator fusion requires exactly two target operations"
|
||||
|
||||
# Get the two operations to fuse
|
||||
target_op_configs = [op for op in operators if op["name"] in target_ops]
|
||||
|
||||
# Ensure they are in the correct order
|
||||
target_op_configs.sort(key=lambda op: target_ops.index(op["name"]))
|
||||
|
||||
# Step 1: Agent generates the instantiate schema
|
||||
rewrite, message_history, call_cost = self.llm_instantiate(
|
||||
target_op_configs, agent_llm, message_history
|
||||
)
|
||||
|
||||
# Step 2: Apply transformation using the schema
|
||||
return (
|
||||
self.apply(global_default_model, operators, target_ops, rewrite),
|
||||
message_history,
|
||||
call_cost,
|
||||
)
|
||||
|
|
@ -1,304 +0,0 @@
|
|||
import json
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
from litellm import completion
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from docetl.reasoning_optimizer.instantiate_schemas import (
|
||||
ReduceChainingInstantiateSchema,
|
||||
)
|
||||
|
||||
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
|
||||
|
||||
|
||||
class ReduceChainingDirective(Directive):
|
||||
name: str = Field(
|
||||
default="reduce_chaining", description="The name of the directive"
|
||||
)
|
||||
formal_description: str = Field(default="Reduce => Map -> Reduce")
|
||||
nl_description: str = Field(
|
||||
default="Transform a reduce operation that processes long documents by inserting a Map step that extracts/processes relevant information from each document first, then modifying the reduce prompt to work with the processed results instead of full document content."
|
||||
)
|
||||
when_to_use: str = Field(
|
||||
default="When a reduce operation iterates through long documents to extract specific information (e.g., locations, entities, themes) that could be pre-extracted per document to make the reduce step more efficient and focused. The target operator must be a reduce operator. You should specify a reduce operator as the target operator when choosing this directive."
|
||||
)
|
||||
instantiate_schema_type: Type[BaseModel] = ReduceChainingInstantiateSchema
|
||||
|
||||
example: str = Field(
|
||||
default=(
|
||||
"Original Reduce Op:\n"
|
||||
"- name: extract_all_locations\n"
|
||||
" type: reduce\n"
|
||||
" reduce_key: document_collection\n"
|
||||
" prompt: |\n"
|
||||
" Extract all distinct locations mentioned across these documents:\n"
|
||||
" {% for input in inputs %}\n"
|
||||
" Document: {{ input.document }}\n"
|
||||
" {% endfor %}\n"
|
||||
" Return a list of unique location names.\n"
|
||||
" output:\n"
|
||||
" schema:\n"
|
||||
" locations: list[str]\n"
|
||||
"\n"
|
||||
"Example InstantiateSchema:\n"
|
||||
"ReduceChainingInstantiateSchema(\n"
|
||||
" map_name='extract_document_locations',\n"
|
||||
" map_prompt='Extract all location names mentioned in this document:\\n{{ input.document }}\\nReturn a list of locations.',\n"
|
||||
" new_key='locations',\n"
|
||||
" modified_reduce_prompt='Combine and deduplicate all locations from these documents:\\n{% for input in inputs %}\\nLocations from document: {{ input.locations }}\\n{% endfor %}\\nReturn a list of unique location names.',\n"
|
||||
")"
|
||||
),
|
||||
)
|
||||
|
||||
test_cases: List[DirectiveTestCase] = Field(
|
||||
default_factory=lambda: [
|
||||
DirectiveTestCase(
|
||||
name="extract_distinct_entities",
|
||||
description="Should decompose entity extraction into map-reduce pattern",
|
||||
input_config={
|
||||
"name": "extract_all_people",
|
||||
"type": "reduce",
|
||||
"reduce_key": "doc_id",
|
||||
"prompt": "Extract all distinct person names mentioned across these documents:\n{% for input in inputs %}\nDocument: {{ input.text }}\n{% endfor %}\nReturn a list of unique person names.",
|
||||
"output": {"schema": {"people": "list[str]"}},
|
||||
},
|
||||
target_ops=["extract_all_people"],
|
||||
expected_behavior="Should create a map op to extract people from each document, then modify reduce to work with extracted lists",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="theme_analysis",
|
||||
description="Should decompose theme analysis into map-reduce pattern",
|
||||
input_config={
|
||||
"name": "analyze_themes",
|
||||
"type": "reduce",
|
||||
"reduce_key": "category",
|
||||
"prompt": "Identify common themes across these research papers:\n{% for input in inputs %}\nPaper: {{ input.content }}\n{% endfor %}\nReturn the main themes.",
|
||||
"output": {"schema": {"themes": "list[str]"}},
|
||||
},
|
||||
target_ops=["analyze_themes"],
|
||||
expected_behavior="Should create a map op to extract themes from each paper, then reduce to identify common ones",
|
||||
should_pass=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, ReduceChainingDirective)
|
||||
|
||||
def __hash__(self):
|
||||
return hash("ReduceChainingDirective")
|
||||
|
||||
def to_string_for_instantiate(self, original_op: Dict) -> str:
|
||||
"""
|
||||
Generate a prompt for an agent to instantiate this directive.
|
||||
|
||||
Args:
|
||||
original_op (str): The YAML or string representation of the original operation.
|
||||
|
||||
Returns:
|
||||
str: The agent prompt for instantiating the directive.
|
||||
"""
|
||||
return (
|
||||
f"You are an expert at optimizing data processing operations by decomposing complex reduce operations.\n\n"
|
||||
f"Original Operation:\n"
|
||||
f"{str(original_op)}\n"
|
||||
f"Directive: {self.name}\n"
|
||||
f"Your task is to instantiate this directive by creating a Map operation that preprocesses individual documents, "
|
||||
f"and then modifying the Reduce operation to work with the preprocessed results instead of raw document content.\n\n"
|
||||
f"The goal is to make the reduce operation more efficient by having the map operation extract or process "
|
||||
f"the specific information needed from each document, rather than having the reduce operation process the full document content.\n\n"
|
||||
f"Key Requirements:\n"
|
||||
f"1. Create a Map operation that processes individual documents and extracts the relevant information\n"
|
||||
f"2. Choose an appropriate new key name for the Map operation's output\n"
|
||||
f"3. Modify the original reduce prompt to work with the processed results instead of the original document content\n"
|
||||
f"4. Ensure the final output schema and semantics remain the same\n"
|
||||
f"5. The modified reduce prompt should reference the new key, not the original document key\n\n"
|
||||
f"Example:\n"
|
||||
f"{self.example}\n\n"
|
||||
f"Please output the ReduceChainingInstantiateSchema with the map operation details and modified reduce prompt."
|
||||
)
|
||||
|
||||
def llm_instantiate(
|
||||
self,
|
||||
original_op: Dict,
|
||||
expected_document_key: str,
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
):
|
||||
"""
|
||||
Use LLM to instantiate this directive by decomposing the reduce operation.
|
||||
|
||||
Args:
|
||||
original_op (Dict): The original reduce operation.
|
||||
expected_document_key (str): The key that contains the document content to be processed.
|
||||
agent_llm (str): The LLM model to use.
|
||||
message_history (List, optional): Conversation history for context.
|
||||
|
||||
Returns:
|
||||
ReduceChainingInstantiateSchema: The structured output from the LLM.
|
||||
"""
|
||||
|
||||
message_history.extend(
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful AI assistant for document processing pipelines.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.to_string_for_instantiate(original_op),
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
response_format=ReduceChainingInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = ReduceChainingInstantiateSchema(**parsed_res)
|
||||
|
||||
# Validate the schema
|
||||
ReduceChainingInstantiateSchema.validate_reduce_prompt_references_new_key(
|
||||
schema.modified_reduce_prompt, schema.new_key, expected_document_key
|
||||
)
|
||||
|
||||
message_history.append(
|
||||
{"role": "assistant", "content": resp.choices[0].message.content}
|
||||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
global_default_model,
|
||||
ops_list: List[Dict],
|
||||
target_op: str,
|
||||
rewrite: ReduceChainingInstantiateSchema,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Apply the directive to the pipeline config.
|
||||
"""
|
||||
# Create a copy of the pipeline config
|
||||
new_ops_list = deepcopy(ops_list)
|
||||
|
||||
# Find position of the target reduce op to modify
|
||||
pos_to_modify = None
|
||||
orig_op = None
|
||||
for i, op in enumerate(ops_list):
|
||||
if op["name"] == target_op:
|
||||
pos_to_modify = i
|
||||
orig_op = op
|
||||
break
|
||||
|
||||
if pos_to_modify is None:
|
||||
raise ValueError(
|
||||
f"Target operation '{target_op}' not found in operations list"
|
||||
)
|
||||
|
||||
# Determine the model to use
|
||||
default_model = orig_op.get("model", global_default_model)
|
||||
|
||||
# Create the new map operation
|
||||
new_map_op = {
|
||||
"name": rewrite.map_name,
|
||||
"type": "map",
|
||||
"prompt": rewrite.map_prompt,
|
||||
"model": default_model,
|
||||
"litellm_completion_kwargs": {"temperature": 0},
|
||||
"output": {"schema": {rewrite.new_key: "string"}},
|
||||
}
|
||||
|
||||
# Modify the reduce operation
|
||||
modified_reduce_op = deepcopy(orig_op)
|
||||
modified_reduce_op["prompt"] = rewrite.modified_reduce_prompt
|
||||
|
||||
# Insert the map operation before the reduce operation
|
||||
new_ops_list.insert(pos_to_modify, new_map_op)
|
||||
# Update the reduce operation (now at pos_to_modify + 1)
|
||||
new_ops_list[pos_to_modify + 1] = modified_reduce_op
|
||||
|
||||
return new_ops_list
|
||||
|
||||
def instantiate(
|
||||
self,
|
||||
operators: List[Dict],
|
||||
target_ops: List[str],
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
optimize_goal="acc",
|
||||
global_default_model: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Instantiate the directive for a list of operators.
|
||||
"""
|
||||
# Assert that there is only one target op
|
||||
assert (
|
||||
len(target_ops) == 1
|
||||
), "There must be exactly one target op to instantiate this reduce chaining directive"
|
||||
target_op_config = [op for op in operators if op["name"] == target_ops[0]][0]
|
||||
|
||||
# Ensure it's a reduce operation
|
||||
if target_op_config.get("type") != "reduce":
|
||||
raise ValueError(
|
||||
f"Target operation '{target_ops[0]}' must be a reduce operation"
|
||||
)
|
||||
|
||||
# Extract expected document key from the reduce prompt template
|
||||
prompt_template = target_op_config["prompt"]
|
||||
# Find all occurrences of {{ input.key }} in the prompt
|
||||
input_key_pattern = r"\{\{\s*([^\}\s]+)\s*\}\}"
|
||||
input_keys = list(set(re.findall(input_key_pattern, prompt_template)))
|
||||
print("input_keys: ", input_keys)
|
||||
# Heuristic: pick the key that's most likely to contain document content
|
||||
# Look for common document field names
|
||||
document_key_candidates = [
|
||||
key
|
||||
for key in input_keys
|
||||
if any(
|
||||
doc_word in key.lower()
|
||||
for doc_word in ["document", "text", "content", "body", "description"]
|
||||
)
|
||||
]
|
||||
|
||||
if document_key_candidates:
|
||||
expected_document_key = document_key_candidates[
|
||||
0
|
||||
] # Pick the first candidate
|
||||
elif input_keys:
|
||||
expected_document_key = input_keys[0] # Fall back to the first input key
|
||||
else:
|
||||
raise ValueError("No input keys found in the reduce operation prompt")
|
||||
|
||||
print(f"Detected document key: {expected_document_key}")
|
||||
|
||||
# Instantiate the directive
|
||||
rewrite, message_history, call_cost = self.llm_instantiate(
|
||||
target_op_config,
|
||||
expected_document_key,
|
||||
agent_llm,
|
||||
message_history,
|
||||
)
|
||||
|
||||
# Apply the rewrite to the operators
|
||||
new_ops_plan = self.apply(
|
||||
global_default_model, operators, target_ops[0], rewrite
|
||||
)
|
||||
return new_ops_plan, message_history, call_cost
|
||||
|
|
@ -1,343 +0,0 @@
|
|||
import json
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
from litellm import completion
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from docetl.reasoning_optimizer.instantiate_schemas import GleaningInstantiateSchema
|
||||
|
||||
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
|
||||
|
||||
|
||||
class ReduceGleaningDirective(Directive):
|
||||
name: str = Field(
|
||||
default="reduce_gleaning", description="The name of the directive"
|
||||
)
|
||||
formal_description: str = Field(default="Reduce => Reduce_m (with gleaning config)")
|
||||
nl_description: str = Field(
|
||||
default="""Adds a validation loop to Reduce operations: after each LLM generation during the reduce process, a separate "judge" LLM evaluates the output using a yes/no validation prompt. If the output fails, the original LLM refines its answer and repeats until the output passes or the max number of rounds is reached."""
|
||||
)
|
||||
when_to_use: str = Field(
|
||||
default="When reduce operations process complex documents that require comprehensive analysis and synthesis (e.g., research paper analysis, customer feedback consolidation, legal document review, literature synthesis) where outputs must be validated for completeness, accuracy, and proper coverage of all input materials."
|
||||
)
|
||||
|
||||
instantiate_schema_type: Type[BaseModel] = Field(default=GleaningInstantiateSchema)
|
||||
|
||||
example: str = Field(
|
||||
default="""
|
||||
Original Op (ReduceOpConfig):
|
||||
- name: synthesize_research_findings
|
||||
type: reduce
|
||||
reduce_key: research_domain
|
||||
prompt: |
|
||||
You are analyzing research papers in the {{ research_domain }} domain.
|
||||
Synthesize the following research findings into a comprehensive domain overview:
|
||||
|
||||
{% for paper in inputs %}
|
||||
**Paper {{ loop.index }}:**
|
||||
- Title: {{ paper.title }}
|
||||
- Key Findings: {{ paper.key_findings }}
|
||||
- Methodology: {{ paper.methodology }}
|
||||
- Limitations: {{ paper.limitations }}
|
||||
- Future Work: {{ paper.future_work }}
|
||||
|
||||
{% endfor %}
|
||||
|
||||
Generate a synthesis with the following structure:
|
||||
- **Domain Overview**: 2-3 sentences describing the field
|
||||
- **Major Findings**: List of 4-6 key insights across all papers
|
||||
- **Methodological Approaches**: Summary of research methods used
|
||||
- **Research Gaps**: Identified limitations and areas needing investigation
|
||||
- **Future Directions**: Consolidated recommendations for future research
|
||||
output:
|
||||
schema:
|
||||
domain_overview: "string"
|
||||
major_findings: "list[str]"
|
||||
methodological_approaches: "string"
|
||||
research_gaps: "list[str]"
|
||||
future_directions: "list[str]"
|
||||
|
||||
Example InstantiateSchema (what the agent should output):
|
||||
{
|
||||
"validation_prompt": "Verify that the synthesis includes all required sections (domain overview, major findings, methodological approaches, research gaps, future directions). Each major finding should be supported by evidence from multiple papers. Research gaps should be specific and actionable. The domain overview should accurately represent the scope covered by the input papers.",
|
||||
"num_rounds": 3,
|
||||
"model": "gpt-4o-mini"
|
||||
}
|
||||
""",
|
||||
)
|
||||
|
||||
test_cases: List[DirectiveTestCase] = Field(
|
||||
default_factory=lambda: [
|
||||
DirectiveTestCase(
|
||||
name="customer_feedback_analysis",
|
||||
description="Should add validation for comprehensive customer feedback analysis by product category",
|
||||
input_config={
|
||||
"name": "analyze_feedback_by_product",
|
||||
"type": "reduce",
|
||||
"reduce_key": "product_category",
|
||||
"prompt": """Analyze customer feedback for {{ product_category }} products.
|
||||
Create a comprehensive analysis from the following feedback:
|
||||
|
||||
{% for review in inputs %}
|
||||
Review {{ loop.index }}:
|
||||
- Rating: {{ review.rating }}/5
|
||||
- Comment: {{ review.comment }}
|
||||
- Customer Type: {{ review.customer_type }}
|
||||
- Date: {{ review.date }}
|
||||
{% endfor %}
|
||||
|
||||
Provide analysis with:
|
||||
- Overall sentiment and satisfaction level
|
||||
- Top 3 most praised features
|
||||
- Top 3 most criticized issues
|
||||
- Recommendations for product improvements
|
||||
- Customer segment insights""",
|
||||
"output": {
|
||||
"schema": {
|
||||
"overall_sentiment": "string",
|
||||
"satisfaction_level": "string",
|
||||
"top_praised_features": "list[str]",
|
||||
"top_criticized_issues": "list[str]",
|
||||
"improvement_recommendations": "list[str]",
|
||||
"segment_insights": "string",
|
||||
}
|
||||
},
|
||||
},
|
||||
target_ops=["analyze_feedback_by_product"],
|
||||
expected_behavior="Should add gleaning validation to ensure all feedback is considered, sentiment analysis is accurate, and recommendations are actionable",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="legal_contract_analysis",
|
||||
description="Should add validation for thorough legal contract analysis by contract type",
|
||||
input_config={
|
||||
"name": "analyze_contracts_by_type",
|
||||
"type": "reduce",
|
||||
"reduce_key": "contract_type",
|
||||
"prompt": """Analyze {{ contract_type }} contracts and extract key legal provisions.
|
||||
|
||||
Review the following contracts:
|
||||
{% for contract in inputs %}
|
||||
Contract {{ loop.index }}:
|
||||
- Party Names: {{ contract.parties }}
|
||||
- Key Terms: {{ contract.key_terms }}
|
||||
- Obligations: {{ contract.obligations }}
|
||||
- Termination Clauses: {{ contract.termination }}
|
||||
- Risk Factors: {{ contract.risks }}
|
||||
{% endfor %}
|
||||
|
||||
Provide comprehensive analysis including:
|
||||
- Common contractual patterns across all contracts
|
||||
- Standard terms and deviations
|
||||
- Risk assessment summary
|
||||
- Compliance requirements identified
|
||||
- Recommendations for contract standardization""",
|
||||
"output": {
|
||||
"schema": {
|
||||
"common_patterns": "list[str]",
|
||||
"standard_terms": "list[str]",
|
||||
"risk_assessment": "string",
|
||||
"compliance_requirements": "list[str]",
|
||||
"standardization_recommendations": "list[str]",
|
||||
}
|
||||
},
|
||||
},
|
||||
target_ops=["analyze_contracts_by_type"],
|
||||
expected_behavior="Should add gleaning validation to ensure all contract provisions are captured, risk analysis is thorough, and recommendations are legally sound",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="research_literature_synthesis",
|
||||
description="Should add validation for comprehensive literature synthesis by research topic",
|
||||
input_config={
|
||||
"name": "synthesize_literature_by_topic",
|
||||
"type": "reduce",
|
||||
"reduce_key": "research_topic",
|
||||
"prompt": """Synthesize research literature on {{ research_topic }}.
|
||||
|
||||
Analyze the following academic papers:
|
||||
{% for paper in inputs %}
|
||||
Paper {{ loop.index }}:
|
||||
- Title: {{ paper.title }}
|
||||
- Abstract: {{ paper.abstract }}
|
||||
- Methodology: {{ paper.methodology }}
|
||||
- Key Results: {{ paper.results }}
|
||||
- Conclusions: {{ paper.conclusions }}
|
||||
- Limitations: {{ paper.limitations }}
|
||||
{% endfor %}
|
||||
|
||||
Create a literature synthesis with:
|
||||
- Theoretical frameworks identified across papers
|
||||
- Consensus findings and contradictory results
|
||||
- Methodological approaches comparison
|
||||
- Research gaps and limitations summary
|
||||
- Future research directions and recommendations""",
|
||||
"output": {
|
||||
"schema": {
|
||||
"theoretical_frameworks": "list[str]",
|
||||
"consensus_findings": "list[str]",
|
||||
"contradictory_results": "list[str]",
|
||||
"methodological_comparison": "string",
|
||||
"research_gaps": "list[str]",
|
||||
"future_directions": "list[str]",
|
||||
}
|
||||
},
|
||||
},
|
||||
target_ops=["synthesize_literature_by_topic"],
|
||||
expected_behavior="Should add gleaning validation to ensure all papers are properly synthesized, contradictions are identified, and research gaps are accurately captured",
|
||||
should_pass=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, ReduceGleaningDirective)
|
||||
|
||||
def __hash__(self):
|
||||
return hash("ReduceGleaningDirective")
|
||||
|
||||
def to_string_for_instantiate(self, original_op: Dict) -> str:
|
||||
"""
|
||||
Generate a prompt for an agent to instantiate this directive.
|
||||
|
||||
Args:
|
||||
original_op (str): The YAML or string representation of the original operation.
|
||||
|
||||
Returns:
|
||||
str: The agent prompt for instantiating the directive.
|
||||
"""
|
||||
return (
|
||||
f"You are an expert at adding validation and refinement loops to reduce operations for document processing tasks.\n\n"
|
||||
f"Original Operation:\n"
|
||||
f"{str(original_op)}\n\n"
|
||||
f"Directive: {self.name}\n"
|
||||
f"Your task is to instantiate this directive by generating a configuration that adds validation loops to the reduce operation. "
|
||||
f"The gleaning configuration should include a validation prompt that evaluates the quality of the reduce output, "
|
||||
f"focusing on document analysis criteria such as:\n"
|
||||
f"- Completeness: Are all input documents/items properly considered and synthesized?\n"
|
||||
f"- Accuracy: Are the extracted insights, patterns, and conclusions accurate?\n"
|
||||
f"- Structure: Does the output follow the required format and include all requested fields?\n"
|
||||
f"- Comprehensiveness: Are key themes, patterns, and insights captured across all inputs?\n"
|
||||
f"- Consistency: Are the analysis and recommendations internally consistent?\n\n"
|
||||
f"For reduce operations, the LLM processes groups of related documents and creates consolidated, synthesized outputs. "
|
||||
f"The validation should ensure proper document analysis, synthesis quality, and adherence to output requirements.\n\n"
|
||||
f"Example:\n"
|
||||
f"{self.example}\n\n"
|
||||
f"Please output only the GleaningInstantiateSchema object that specifies how to validate and refine the output of the reduce operation."
|
||||
)
|
||||
|
||||
def llm_instantiate(
|
||||
self,
|
||||
original_op: Dict,
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
):
|
||||
"""
|
||||
Use LLM to instantiate this directive by decomposing the original operation.
|
||||
|
||||
Args:
|
||||
original_op (Dict): The original operation.
|
||||
agent_llm (str): The LLM model to use.
|
||||
message_history (List, optional): Conversation history for context.
|
||||
|
||||
Returns:
|
||||
GleaningInstantiateSchema: The structured output from the LLM.
|
||||
"""
|
||||
|
||||
message_history.extend(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.to_string_for_instantiate(original_op),
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
response_format=GleaningInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = GleaningInstantiateSchema(**parsed_res)
|
||||
message_history.append(
|
||||
{"role": "assistant", "content": resp.choices[0].message.content}
|
||||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
global_default_model: str,
|
||||
ops_list: List[Dict],
|
||||
target_op: str,
|
||||
rewrite: GleaningInstantiateSchema,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Apply the directive to the pipeline config by adding gleaning configuration to the target reduce operator.
|
||||
"""
|
||||
# Create a copy of the pipeline config
|
||||
new_ops_list = deepcopy(ops_list)
|
||||
|
||||
# Find position of the target op to modify
|
||||
pos_to_replace = [
|
||||
i for i, op in enumerate(ops_list) if op["name"] == target_op
|
||||
][0]
|
||||
|
||||
# Add gleaning configuration to the target operator
|
||||
target_operator = new_ops_list[pos_to_replace]
|
||||
target_operator["gleaning"] = {
|
||||
"validation_prompt": rewrite.validation_prompt,
|
||||
"num_rounds": rewrite.num_rounds,
|
||||
"model": rewrite.model,
|
||||
}
|
||||
|
||||
return new_ops_list
|
||||
|
||||
def instantiate(
|
||||
self,
|
||||
operators: List[Dict],
|
||||
target_ops: List[str],
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
global_default_model: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Instantiate the directive for a list of operators.
|
||||
"""
|
||||
# Assert that there is only one target op
|
||||
assert (
|
||||
len(target_ops) == 1
|
||||
), "There must be exactly one target op to instantiate this reduce gleaning directive"
|
||||
target_op_config = [op for op in operators if op["name"] == target_ops[0]][0]
|
||||
|
||||
# Verify the target operation is a reduce operation
|
||||
if target_op_config.get("type") != "reduce":
|
||||
raise ValueError(
|
||||
f"ReduceGleaningDirective can only be applied to reduce operations, got {target_op_config.get('type')}"
|
||||
)
|
||||
|
||||
# Instantiate the directive
|
||||
rewrite, message_history, call_cost = self.llm_instantiate(
|
||||
target_op_config, agent_llm, message_history
|
||||
)
|
||||
|
||||
# Apply the rewrite to the operators
|
||||
return (
|
||||
self.apply(global_default_model, operators, target_ops[0], rewrite),
|
||||
message_history,
|
||||
call_cost,
|
||||
)
|
||||
|
|
@ -1,339 +0,0 @@
|
|||
import json
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from docetl.reasoning_optimizer.instantiate_schemas import SwapWithCodeInstantiateSchema
|
||||
|
||||
from .agent_utils import AgenticDirectiveRunner
|
||||
from .base import Directive, DirectiveTestCase
|
||||
|
||||
|
||||
class SwapWithCodeDirective(Directive):
|
||||
name: str = Field(default="swap_with_code", description="The name of the directive")
|
||||
formal_description: str = Field(default="Reduce => Code Reduce + Map")
|
||||
nl_description: str = Field(
|
||||
default="Replaces a Reduce operation with a Code Reduce operation for deterministic logic plus an optional Map operation to format the output. The Code Reduce handles the core reduction logic (like counting, collecting, or aggregating) while the optional Map operation converts the result to match the expected schema format."
|
||||
)
|
||||
when_to_use: str = Field(
|
||||
default="When a reduce operation performs logic that can be implemented more efficiently or deterministically with code rather than an LLM. Examples include: counting distinct values, finding most common elements, basic aggregations, set operations, or mathematical computations. Particularly useful when the reduction logic is straightforward but the output needs to be formatted in a specific way for downstream operations."
|
||||
)
|
||||
|
||||
instantiate_schema_type: Type[BaseModel] = Field(
|
||||
default=SwapWithCodeInstantiateSchema
|
||||
)
|
||||
|
||||
example: str = Field(
|
||||
default="""
|
||||
Target Operation:
|
||||
- name: summarize_locations
|
||||
type: reduce
|
||||
reduce_key: "_all"
|
||||
prompt: |
|
||||
Summarize all distinct locations from these documents:
|
||||
{% for input in inputs %}{{ input.locations }}{% endfor %}
|
||||
output:
|
||||
schema:
|
||||
summary: "str"
|
||||
distinct_locations: "list[str]"
|
||||
|
||||
The agent might convert this to:
|
||||
1. Code Reduce that collects distinct locations: {"locations": ["NYC", "SF", "LA"]}
|
||||
2. Optional Map that formats this as: {"summary": "Found 3 distinct locations: NYC, SF, LA", "distinct_locations": ["NYC", "SF", "LA"]}
|
||||
|
||||
Example InstantiateSchema (what the agent should output):
|
||||
SwapWithCodeInstantiateSchema(
|
||||
code_reduce_name="collect_distinct_locations",
|
||||
code="def transform(inputs):\n locations = set()\n for item in inputs:\n if 'locations' in item and isinstance(item['locations'], list):\n locations.update(item['locations'])\n return {'distinct_locations': sorted(list(locations))}",
|
||||
map_prompt="Create a summary of the locations: {{ input.distinct_locations }}. Output format: summary (string describing the count and locations), distinct_locations (the original list)."
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
test_cases: List[DirectiveTestCase] = Field(
|
||||
default_factory=lambda: [
|
||||
DirectiveTestCase(
|
||||
name="basic_reduce_to_code_reduce",
|
||||
description="Should convert reduce operation to code reduce + optional map",
|
||||
input_config={
|
||||
"name": "count_items",
|
||||
"type": "reduce",
|
||||
"reduce_key": "_all",
|
||||
"prompt": "Count the total items: {% for input in inputs %}{{ input.count }}{% endfor %}",
|
||||
"output": {"schema": {"total": "int"}},
|
||||
},
|
||||
target_ops=["count_items"],
|
||||
expected_behavior="Should replace reduce with code reduce that performs counting logic and optional map to format output",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="reduce_with_grouping",
|
||||
description="Should handle reduce operations with grouping keys",
|
||||
input_config={
|
||||
"name": "group_by_category",
|
||||
"type": "reduce",
|
||||
"reduce_key": "category",
|
||||
"prompt": "List all items for this category: {% for input in inputs %}{{ input.name }}{% endfor %}",
|
||||
"output": {"schema": {"category": "str", "items": "list[str]"}},
|
||||
},
|
||||
target_ops=["group_by_category"],
|
||||
expected_behavior="Should preserve reduce_key grouping in the code reduce operation",
|
||||
should_pass=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, SwapWithCodeDirective)
|
||||
|
||||
def __hash__(self):
|
||||
return hash("SwapWithCodeDirective")
|
||||
|
||||
def to_string_for_instantiate(
|
||||
self, target_ops_configs: List[Dict], pipeline_code: Dict = None
|
||||
) -> str:
|
||||
"""
|
||||
Generate a prompt that asks the agent to analyze sample data and create a code reduce + optional map replacement.
|
||||
"""
|
||||
assert (
|
||||
len(target_ops_configs) == 1
|
||||
), "SwapWithCode directive only supports single target operation"
|
||||
|
||||
op = target_ops_configs[0]
|
||||
original_prompt = op.get("prompt", "")
|
||||
reduce_key = op.get("reduce_key", "_all")
|
||||
output_schema = op.get("output", {}).get("schema", {})
|
||||
|
||||
# Build pipeline context
|
||||
pipeline_context = ""
|
||||
if pipeline_code:
|
||||
pipeline_context = f"""
|
||||
Pipeline Context:
|
||||
{json.dumps(pipeline_code, indent=2)}
|
||||
|
||||
The target reduce operation '{op['name']}' fits into this broader pipeline. Consider:
|
||||
- What data flows into this operation from previous steps
|
||||
- How this operation's output will be used by subsequent operations
|
||||
- The overall goal of the pipeline when designing your code reduce + map solution
|
||||
"""
|
||||
|
||||
return (
|
||||
f"You are an expert at analyzing reduce operations and implementing efficient code-based alternatives.\n\n"
|
||||
f"Target Reduce Operation:\n"
|
||||
f"{json.dumps(op, indent=2)}\n\n"
|
||||
f"Original Prompt: {original_prompt}\n"
|
||||
f"Reduce Key: {reduce_key}\n"
|
||||
f"Expected Output Schema: {json.dumps(output_schema, indent=2)}\n\n"
|
||||
f"{pipeline_context}\n"
|
||||
f"Your task is to replace this reduce operation with:\n"
|
||||
f"1. A Code Reduce operation that implements the core reduction logic deterministically\n"
|
||||
f"2. An optional Map operation that formats the code reduce output to match the expected schema\n\n"
|
||||
f"You will be given access to sample input data through a read_next_docs() function. Use this to:\n"
|
||||
f"1. Understand the actual structure and patterns in the input data\n"
|
||||
f"2. Identify what the reduce operation is trying to accomplish\n"
|
||||
f"3. Design efficient Python code that performs the same reduction logic\n"
|
||||
f"4. Determine if a follow-up map operation is needed to format the output correctly\n"
|
||||
f"5. Consider edge cases and data variations in your implementation\n\n"
|
||||
f"Guidelines for the replacement:\n"
|
||||
f"- The Code Reduce must implement a 'transform' function that takes a list of inputs and returns a dictionary\n"
|
||||
f"- Include all necessary imports within the transform function\n"
|
||||
f"- The code should handle the same reduce_key grouping as the original operation\n"
|
||||
f"- If the code reduce output doesn't match the expected schema, provide a map_prompt to format it\n"
|
||||
f"- The map operation (if needed) should reference fields from the code reduce output using {{{{ input.field_name }}}}\n"
|
||||
f"- Focus on correctness, efficiency, and handling edge cases found in the sample data\n\n"
|
||||
f"Examples of good candidates for code reduce:\n"
|
||||
f"- Counting, summing, or basic mathematical operations\n"
|
||||
f"- Collecting distinct values or creating sets\n"
|
||||
f"- Finding min/max values or sorting\n"
|
||||
f"- Simple aggregations or list operations\n"
|
||||
f"- Deterministic text processing\n\n"
|
||||
f"Example transformation:\n"
|
||||
f"{self.example}\n\n"
|
||||
f"Analyze samples strategically to understand the data patterns and reduction requirements.\n"
|
||||
f"When you have enough information to create an efficient code-based solution, output your result."
|
||||
)
|
||||
|
||||
def llm_instantiate(
|
||||
self,
|
||||
target_ops_configs: List[Dict],
|
||||
input_file_path: str,
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
pipeline_code: Dict = None,
|
||||
):
|
||||
"""
|
||||
Use agentic approach to analyze sample data and generate code reduce + optional map replacement.
|
||||
"""
|
||||
# Load sample input data
|
||||
try:
|
||||
with open(input_file_path, "r") as f:
|
||||
input_data = json.load(f)
|
||||
|
||||
if not isinstance(input_data, list) or len(input_data) == 0:
|
||||
raise ValueError(
|
||||
"Input file must contain a non-empty list of sample data"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"Failed to load input data from {input_file_path}: {str(e)}"
|
||||
)
|
||||
|
||||
# Create validation function
|
||||
def validate_code_reduce_schema(schema_instance):
|
||||
# Basic validation is handled by Pydantic validators
|
||||
# Could add additional validation here if needed
|
||||
pass
|
||||
|
||||
# Set up agentic runner with validation
|
||||
runner = AgenticDirectiveRunner(
|
||||
input_data=input_data,
|
||||
agent_llm=agent_llm,
|
||||
validation_func=validate_code_reduce_schema,
|
||||
)
|
||||
|
||||
# Create system prompt for the agentic runner
|
||||
system_prompt = (
|
||||
"You are an expert at analyzing reduce operations and designing efficient code-based alternatives. "
|
||||
"Your goal is to examine input samples to understand the reduction logic, then implement it as "
|
||||
"efficient Python code with optional formatting. You consider both performance and correctness "
|
||||
"while ensuring the output matches the expected schema through code reduce + optional map pattern."
|
||||
)
|
||||
|
||||
# Create initial user message
|
||||
initial_message = self.to_string_for_instantiate(
|
||||
target_ops_configs, pipeline_code
|
||||
)
|
||||
|
||||
# Run the agentic loop
|
||||
try:
|
||||
schema, updated_message_history, call_cost = runner.run_agentic_loop(
|
||||
system_prompt=system_prompt,
|
||||
initial_user_message=initial_message,
|
||||
response_schema=SwapWithCodeInstantiateSchema,
|
||||
)
|
||||
|
||||
# Update message history
|
||||
message_history.extend(updated_message_history)
|
||||
|
||||
return schema, message_history, call_cost
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to instantiate swap_with_code directive: {str(e)}")
|
||||
|
||||
def apply(
|
||||
self,
|
||||
global_default_model: str,
|
||||
ops_list: List[Dict],
|
||||
target_ops: List[str],
|
||||
rewrite: SwapWithCodeInstantiateSchema,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Apply the directive by replacing the reduce operation with code reduce + optional map.
|
||||
"""
|
||||
new_ops_list = deepcopy(ops_list)
|
||||
|
||||
# Find the target operation
|
||||
target_pos = None
|
||||
target_op = None
|
||||
for i, op in enumerate(new_ops_list):
|
||||
if op["name"] in target_ops:
|
||||
target_pos = i
|
||||
target_op = op
|
||||
break
|
||||
|
||||
if target_pos is None:
|
||||
raise ValueError(f"Target operation {target_ops[0]} not found")
|
||||
|
||||
# Get model from original reduce operation or use global default
|
||||
default_model = target_op.get("model", global_default_model)
|
||||
|
||||
# Create the code reduce operation
|
||||
code_reduce_op = {
|
||||
"name": rewrite.code_reduce_name,
|
||||
"type": "code_reduce",
|
||||
"code": rewrite.code,
|
||||
"reduce_key": target_op.get("reduce_key", "_all"),
|
||||
}
|
||||
|
||||
# Start with just the code reduce operation
|
||||
replacement_ops = [code_reduce_op]
|
||||
|
||||
# Add optional map operation if specified
|
||||
if rewrite.map_prompt is not None and rewrite.map_prompt.strip():
|
||||
map_op = {
|
||||
"name": target_op[
|
||||
"name"
|
||||
], # Keep the original name for the final output
|
||||
"type": "map",
|
||||
"prompt": rewrite.map_prompt,
|
||||
"model": default_model,
|
||||
"output": target_op.get("output", {}),
|
||||
}
|
||||
replacement_ops.append(map_op)
|
||||
else:
|
||||
# If no map operation, rename the code reduce to match original name
|
||||
code_reduce_op["name"] = target_op["name"]
|
||||
# Add output schema if it exists in original
|
||||
if "output" in target_op:
|
||||
code_reduce_op["output"] = target_op["output"]
|
||||
|
||||
# Replace the target operation with the replacement operations
|
||||
new_ops_list = (
|
||||
new_ops_list[:target_pos] + replacement_ops + new_ops_list[target_pos + 1 :]
|
||||
)
|
||||
|
||||
return new_ops_list
|
||||
|
||||
def instantiate(
|
||||
self,
|
||||
operators: List[Dict],
|
||||
target_ops: List[str],
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
global_default_model: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Main method that orchestrates directive instantiation:
|
||||
1. Use agentic approach to analyze data and generate code reduce + optional map
|
||||
2. Apply the transformation using that configuration
|
||||
"""
|
||||
assert (
|
||||
len(target_ops) == 1
|
||||
), "SwapWithCode directive requires exactly one target operation"
|
||||
|
||||
input_file_path = kwargs.get("input_file_path", None)
|
||||
pipeline_code = kwargs.get("pipeline_code", None)
|
||||
|
||||
if not input_file_path:
|
||||
raise ValueError("input_file_path is required for SwapWithCode directive")
|
||||
|
||||
# Get configuration for target operation
|
||||
target_ops_configs = [op for op in operators if op["name"] in target_ops]
|
||||
|
||||
if not target_ops_configs:
|
||||
raise ValueError(f"Target operation {target_ops[0]} not found in operators")
|
||||
|
||||
# Validate that target operation is a reduce operation
|
||||
target_op = target_ops_configs[0]
|
||||
if target_op.get("type") != "reduce":
|
||||
raise ValueError(
|
||||
f"SwapWithCode directive can only be applied to reduce operations, but {target_ops[0]} is of type {target_op.get('type')}"
|
||||
)
|
||||
|
||||
# Step 1: Agent analyzes data and generates code reduce + optional map solution
|
||||
rewrite, message_history, call_cost = self.llm_instantiate(
|
||||
target_ops_configs,
|
||||
input_file_path,
|
||||
agent_llm,
|
||||
message_history,
|
||||
pipeline_code,
|
||||
)
|
||||
|
||||
# Step 2: Apply transformation using the generated configuration
|
||||
return (
|
||||
self.apply(global_default_model, operators, target_ops, rewrite),
|
||||
message_history, call_cost
|
||||
)
|
||||
|
|
@ -1,321 +0,0 @@
|
|||
import json
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
from litellm import completion
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from docetl.reasoning_optimizer.instantiate_schemas import TakeHeadTailInstantiateSchema
|
||||
|
||||
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
|
||||
|
||||
|
||||
class TakeHeadTailDirective(Directive):
|
||||
name: str = Field(default="take_head_tail", description="The name of the directive")
|
||||
formal_description: str = Field(
|
||||
default="LLM_Op => Code Map -> LLM_Op",
|
||||
description="Inserts a Code Map operation before any LLM-powered operation (Map, Filter, Reduce) to truncate document content to head and tail words",
|
||||
)
|
||||
nl_description: str = Field(
|
||||
default="Reduces document length by keeping only the first k words and optionally the last l words of the longest document field. This improves cost efficiency and can enhance accuracy for tasks that only require document beginnings (like classification)."
|
||||
)
|
||||
when_to_use: str = Field(
|
||||
default="When any LLM operation (Map, Filter, Reduce) only needs the beginning (and optionally end) of documents, such as classification tasks, filtering by document type, reducing document summaries, or when full document content causes accuracy issues due to too much context."
|
||||
)
|
||||
|
||||
instantiate_schema_type: Type[BaseModel] = Field(
|
||||
default=TakeHeadTailInstantiateSchema
|
||||
)
|
||||
|
||||
example: str = Field(
|
||||
default="""
|
||||
Original Map Operation (Research Paper Classification):
|
||||
- name: classify_research_domain
|
||||
type: map
|
||||
prompt: |
|
||||
What research domain does this paper belong to? Classify as Computer Science, Biology, Physics, Chemistry, or Other based on: {{ input.paper_text }}
|
||||
output:
|
||||
schema:
|
||||
domain: "string"
|
||||
confidence: "float"
|
||||
model: gpt-4o-mini
|
||||
|
||||
TakeHeadTailInstantiateSchema:
|
||||
{
|
||||
"name": "extract_paper_abstract",
|
||||
"document_key": "paper_text",
|
||||
"head_words": 150,
|
||||
"tail_words": 0
|
||||
}
|
||||
|
||||
Resulting Pipeline:
|
||||
- name: extract_paper_abstract
|
||||
type: code_map
|
||||
code: |
|
||||
def transform(input_doc):
|
||||
paper_text_content = input_doc.get('paper_text', '')
|
||||
words = paper_text_content.split()
|
||||
if len(words) <= 150:
|
||||
truncated = paper_text_content
|
||||
else:
|
||||
head = ' '.join(words[:150])
|
||||
truncated = head
|
||||
return {'paper_text': truncated}
|
||||
- name: classify_research_domain
|
||||
type: map
|
||||
prompt: |
|
||||
What research domain does this paper belong to? Classify as Computer Science, Biology, Physics, Chemistry, or Other based on: {{ input.paper_text }}
|
||||
output:
|
||||
schema:
|
||||
domain: "string"
|
||||
confidence: "float"
|
||||
model: gpt-4o-mini
|
||||
"""
|
||||
)
|
||||
|
||||
test_cases: List[DirectiveTestCase] = Field(
|
||||
default_factory=lambda: [
|
||||
DirectiveTestCase(
|
||||
name="research_paper_classification",
|
||||
description="Classify research papers by domain using only abstract/introduction",
|
||||
input_config={
|
||||
"name": "classify_paper_domain",
|
||||
"type": "map",
|
||||
"prompt": "What research domain does this paper belong to (CS, Biology, Physics, etc.)? Base your classification on the content: {{ input.full_text }}",
|
||||
"output": {"schema": {"domain": "string", "confidence": "float"}},
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
target_ops=["classify_paper_domain"],
|
||||
expected_behavior="Should truncate full_text to first ~150 words (abstract/intro) since paper classification only needs the beginning, not the full methodology/results sections",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="document_metadata_extraction",
|
||||
description="Extract metadata from document headers/footers for indexing",
|
||||
input_config={
|
||||
"name": "extract_document_metadata",
|
||||
"type": "map",
|
||||
"prompt": "Extract the title, author, date, and document type from this document: {{ input.content }}",
|
||||
"output": {
|
||||
"schema": {
|
||||
"title": "string",
|
||||
"author": "string",
|
||||
"date": "string",
|
||||
"doc_type": "string",
|
||||
}
|
||||
},
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
target_ops=["extract_document_metadata"],
|
||||
expected_behavior="Should keep both head (~100 words for headers/title) and tail (~50 words for footers/signatures) since metadata appears at document beginning and end",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="email_priority_classification",
|
||||
description="Classify email priority using subject and first paragraph",
|
||||
input_config={
|
||||
"name": "classify_email_priority",
|
||||
"type": "map",
|
||||
"prompt": "Classify this email as HIGH, MEDIUM, or LOW priority based on urgency indicators: {{ input.email_body }}",
|
||||
"output": {"schema": {"priority": "string", "reasoning": "string"}},
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
target_ops=["classify_email_priority"],
|
||||
expected_behavior="Should truncate email_body to first ~75 words since email priority is determined by subject line and opening, not full conversation thread",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="legal_document_type_identification",
|
||||
description="Identify legal document type from contract headers and signature blocks",
|
||||
input_config={
|
||||
"name": "identify_legal_doc_type",
|
||||
"type": "map",
|
||||
"prompt": "What type of legal document is this (contract, agreement, policy, etc.)? Analyze: {{ input.legal_text }}",
|
||||
"output": {
|
||||
"schema": {
|
||||
"document_type": "string",
|
||||
"parties_involved": "list[string]",
|
||||
}
|
||||
},
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
target_ops=["identify_legal_doc_type"],
|
||||
expected_behavior="Should keep head (~200 words for title/parties) and tail (~100 words for signature blocks) since legal doc type is indicated at beginning and parties sign at end",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="spam_email_filtering",
|
||||
description="Filter out spam emails based on subject line and opening content",
|
||||
input_config={
|
||||
"name": "filter_spam_emails",
|
||||
"type": "filter",
|
||||
"prompt": "Is this email spam? Look for suspicious patterns in: {{ input.email_content }}",
|
||||
"output": {"schema": {"_bool": "bool"}},
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
target_ops=["filter_spam_emails"],
|
||||
expected_behavior="Should truncate email_content to first ~100 words since spam detection relies on subject, sender, and opening content, not full email thread",
|
||||
should_pass=True,
|
||||
),
|
||||
DirectiveTestCase(
|
||||
name="research_findings_synthesis",
|
||||
description="Reduce multiple research papers into a unified findings summary",
|
||||
input_config={
|
||||
"name": "synthesize_research_findings",
|
||||
"type": "reduce",
|
||||
"prompt": "Synthesize the key findings from these research abstracts and conclusions: {% for doc in inputs %}{{ doc.paper_content }}{% endfor %}",
|
||||
"output": {
|
||||
"schema": {"synthesis": "string", "key_themes": "list[string]"}
|
||||
},
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
target_ops=["synthesize_research_findings"],
|
||||
expected_behavior="Should keep head (~200 words for abstracts) and tail (~150 words for conclusions) from each paper since synthesis needs both research goals and outcomes",
|
||||
should_pass=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, TakeHeadTailDirective)
|
||||
|
||||
def __hash__(self):
|
||||
return hash("TakeHeadTailDirective")
|
||||
|
||||
def to_string_for_instantiate(self, original_op: Dict) -> str:
|
||||
op_type = original_op.get("type", "operation")
|
||||
op_type_caps = op_type.capitalize()
|
||||
|
||||
return (
|
||||
f"You are an expert at optimizing document processing pipelines for cost and accuracy.\n\n"
|
||||
f"Original {op_type_caps} Operation:\n"
|
||||
f"{str(original_op)}\n\n"
|
||||
f"Directive: {self.name}\n"
|
||||
f"Your task is to instantiate this directive by generating a TakeHeadTailInstantiateSchema "
|
||||
f"that specifies how to truncate document content to improve efficiency.\n\n"
|
||||
f"The directive will insert a Code Map operation before the target {op_type_caps} that:\n"
|
||||
f"1. Identifies the document key with the longest text content\n"
|
||||
f"2. Keeps only the first 'head_words' words\n"
|
||||
f"3. Optionally keeps the last 'tail_words' words (default 0)\n"
|
||||
f"4. Returns the truncated content in the same key\n\n"
|
||||
f"Guidelines:\n"
|
||||
f"- Choose head_words based on how much context the task likely needs\n"
|
||||
f"- Set tail_words > 0 only if the task benefits from document endings (e.g., conclusions, signatures)\n"
|
||||
f"- For classification/filtering: typically 50-150 head_words, tail_words=0\n"
|
||||
f"- For summarization/reduction: typically 100-300 head_words, tail_words=50-100\n"
|
||||
f"- For metadata extraction: head=100-200, tail=50-100 (headers + footers)\n"
|
||||
f"- Identify the document_key by looking at {{{{ input.KEY }}}} references in the prompt\n\n"
|
||||
f"Operation Type Considerations:\n"
|
||||
f"- Map: Focus on what information is needed for transformation\n"
|
||||
f"- Filter: Focus on what information is needed for the boolean decision\n"
|
||||
f"- Reduce: Focus on what information is needed for aggregation/synthesis\n\n"
|
||||
f"Example configuration:\n"
|
||||
f"{self.example}\n\n"
|
||||
f"Please output only the TakeHeadTailInstantiateSchema object that specifies:\n"
|
||||
f"- name: descriptive name for the truncation operation\n"
|
||||
f"- document_key: the key containing text to truncate\n"
|
||||
f"- head_words: number of words to keep from start\n"
|
||||
f"- tail_words: number of words to keep from end (0 if not needed)"
|
||||
)
|
||||
|
||||
def llm_instantiate(
|
||||
self,
|
||||
original_op: Dict,
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
):
|
||||
message_history.extend(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.to_string_for_instantiate(original_op),
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
response_format=TakeHeadTailInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = TakeHeadTailInstantiateSchema(**parsed_res)
|
||||
message_history.append(
|
||||
{"role": "assistant", "content": resp.choices[0].message.content}
|
||||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
ops_list: List[Dict],
|
||||
target_op: str,
|
||||
rewrite: TakeHeadTailInstantiateSchema,
|
||||
) -> List[Dict]:
|
||||
new_ops_list = deepcopy(ops_list)
|
||||
|
||||
pos_to_replace = [
|
||||
i for i, op in enumerate(ops_list) if op["name"] == target_op
|
||||
][0]
|
||||
|
||||
head_words = rewrite.head_words
|
||||
tail_words = rewrite.tail_words
|
||||
document_key = rewrite.document_key
|
||||
|
||||
code_map_function = f"""def transform(input_doc):
|
||||
{document_key}_content = input_doc.get('{document_key}', '')
|
||||
words = {document_key}_content.split()
|
||||
|
||||
if len(words) <= {head_words + tail_words}:
|
||||
# Document is short enough, keep as is
|
||||
truncated = {document_key}_content
|
||||
else:
|
||||
head = ' '.join(words[:{head_words}])
|
||||
if {tail_words} > 0:
|
||||
tail = ' '.join(words[-{tail_words}:])
|
||||
truncated = head + ' ... ' + tail
|
||||
else:
|
||||
truncated = head
|
||||
|
||||
return {{'{document_key}': truncated}}"""
|
||||
|
||||
code_map_op = {
|
||||
"name": rewrite.name,
|
||||
"type": "code_map",
|
||||
"code": code_map_function,
|
||||
}
|
||||
|
||||
new_ops_list.insert(pos_to_replace, code_map_op)
|
||||
|
||||
return new_ops_list
|
||||
|
||||
def instantiate(
|
||||
self,
|
||||
operators: List[Dict],
|
||||
target_ops: List[str],
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
**kwargs,
|
||||
):
|
||||
assert (
|
||||
len(target_ops) == 1
|
||||
), "TakeHeadTail directive requires exactly one target op"
|
||||
|
||||
target_op_config = [op for op in operators if op["name"] == target_ops[0]][0]
|
||||
|
||||
rewrite, message_history, call_cost = self.llm_instantiate(
|
||||
target_op_config, agent_llm, message_history
|
||||
)
|
||||
|
||||
return self.apply(operators, target_ops[0], rewrite), message_history, call_cost
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,180 +0,0 @@
|
|||
import json
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import tiktoken
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from docetl.dataset import Dataset, create_parsing_tool_map
|
||||
from docetl.utils import load_config
|
||||
|
||||
|
||||
def extract_input_schema(
|
||||
data: List[Dict[str, Any]], max_samples: int = 10
|
||||
) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
Extract the input schema from JSON data by analyzing the structure.
|
||||
|
||||
Args:
|
||||
data: List of dictionaries containing the dataset
|
||||
max_samples: Maximum number of records to analyze (default: 10)
|
||||
|
||||
Returns:
|
||||
Dict[str, Dict[str, Any]]: Schema mapping field names to their type and token info
|
||||
"""
|
||||
if not data:
|
||||
return {}
|
||||
|
||||
# Sample records for analysis (to avoid processing entire large datasets)
|
||||
sample_size = min(max_samples, len(data))
|
||||
sample_data = data[:sample_size]
|
||||
|
||||
schema = {}
|
||||
|
||||
def count_field_tokens(value: Any, model="gpt-4o") -> int:
|
||||
"""Count tokens for a single field value."""
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
if value is None:
|
||||
return 0
|
||||
elif isinstance(value, (dict, list)):
|
||||
value_str = json.dumps(value, ensure_ascii=False)
|
||||
else:
|
||||
value_str = str(value)
|
||||
tokens = encoding.encode(value_str)
|
||||
return len(tokens)
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
def infer_type(value: Any) -> str:
|
||||
"""Infer the type of a value."""
|
||||
if value is None:
|
||||
return "string" # Default to string for null values
|
||||
elif isinstance(value, bool):
|
||||
return "boolean"
|
||||
elif isinstance(value, int):
|
||||
return "integer"
|
||||
elif isinstance(value, float):
|
||||
return "number"
|
||||
elif isinstance(value, list):
|
||||
if not value:
|
||||
return "list[string]" # Default for empty lists
|
||||
# Check first few elements to determine list type
|
||||
sample_elements = value[:3]
|
||||
element_types = [infer_type(elem) for elem in sample_elements]
|
||||
# If all elements are the same type, use that type
|
||||
if len(set(element_types)) == 1:
|
||||
return f"list[{element_types[0]}]"
|
||||
else:
|
||||
return "list[string]" # Mixed types default to string
|
||||
elif isinstance(value, dict):
|
||||
return "dict"
|
||||
else:
|
||||
return "string"
|
||||
|
||||
# Analyze each field across all sample records
|
||||
all_fields = set()
|
||||
for record in sample_data:
|
||||
all_fields.update(record.keys())
|
||||
|
||||
# For each field, determine the most common type and token statistics
|
||||
for field in sorted(all_fields):
|
||||
field_values = []
|
||||
field_tokens = []
|
||||
|
||||
for record in sample_data:
|
||||
if field in record:
|
||||
value = record[field]
|
||||
field_values.append(value)
|
||||
token_count = count_field_tokens(value)
|
||||
field_tokens.append(token_count)
|
||||
|
||||
if not field_values:
|
||||
schema[field] = {"type": "string", "avg_tokens": 0}
|
||||
continue
|
||||
|
||||
# Count type occurrences
|
||||
type_counts = {}
|
||||
for value in field_values:
|
||||
value_type = infer_type(value)
|
||||
type_counts[value_type] = type_counts.get(value_type, 0) + 1
|
||||
|
||||
# Use the most common type
|
||||
most_common_type = max(type_counts.items(), key=lambda x: x[1])[0]
|
||||
avg_tokens = sum(field_tokens) / len(field_tokens)
|
||||
|
||||
schema[field] = {"type": most_common_type, "avg_tokens": round(avg_tokens, 1)}
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def count_tokens_in_data(data, model="gpt-4o"):
|
||||
"""
|
||||
Count the total number of tokens in the data using tiktoken.
|
||||
|
||||
Args:
|
||||
data: List of dictionaries containing the dataset
|
||||
model: The model to use for tokenization (default: gpt-4o)
|
||||
|
||||
Returns:
|
||||
int: Total number of tokens
|
||||
"""
|
||||
try:
|
||||
# Get the encoding for the specified model
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
|
||||
total_tokens = 0
|
||||
|
||||
for item in data:
|
||||
# Convert the entire item to a JSON string for tokenization
|
||||
item_str = json.dumps(item, ensure_ascii=False)
|
||||
tokens = encoding.encode(item_str)
|
||||
total_tokens += len(tokens)
|
||||
return total_tokens
|
||||
|
||||
except Exception as e:
|
||||
print(f" [WARNING] Could not count tokens: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def load_input_doc(yaml_path):
|
||||
doc_info = ""
|
||||
load_dotenv()
|
||||
try:
|
||||
config = load_config(yaml_path)
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Failed to load config: {e}")
|
||||
return doc_info
|
||||
|
||||
parsing_tool_map = create_parsing_tool_map(config.get("parsing_tools", None))
|
||||
datasets_config = config.get("datasets", {})
|
||||
if not datasets_config:
|
||||
print("[ERROR] No datasets found in config.")
|
||||
return doc_info
|
||||
|
||||
for name, dataset_config in datasets_config.items():
|
||||
doc_info += f"Dataset: {name}\n"
|
||||
try:
|
||||
ds = Dataset(
|
||||
runner=None,
|
||||
type=dataset_config["type"],
|
||||
path_or_data=dataset_config["path"],
|
||||
source=dataset_config.get("source", "local"),
|
||||
parsing=dataset_config.get("parsing", []),
|
||||
user_defined_parsing_tool_map=parsing_tool_map,
|
||||
)
|
||||
data = ds.load()
|
||||
|
||||
if data:
|
||||
doc_info += f" Type: {ds.type}\n"
|
||||
doc_info += f" Records loaded: {len(data)}\n"
|
||||
schema = extract_input_schema(data)
|
||||
doc_info += " Input schema:\n"
|
||||
for field, field_info in schema.items():
|
||||
doc_info += f" {field}: {field_info['type']} (avg: {field_info['avg_tokens']} tokens)\n"
|
||||
token_count = count_tokens_in_data(data)
|
||||
if token_count is not None:
|
||||
doc_info += f" Total tokens: {token_count:,}\n"
|
||||
|
||||
except Exception as e:
|
||||
doc_info += f" [ERROR] Failed to load dataset '{name}': {e}\n"
|
||||
return doc_info
|
||||
|
|
@ -1,638 +0,0 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Operator(BaseModel):
|
||||
"""Operator model for each DocETL operator."""
|
||||
|
||||
# Fields matching your spreadsheet columns
|
||||
name: str = Field(..., description="Operation name")
|
||||
type_llm_or_not: str = Field(
|
||||
..., description="Type (LLM-powered or not LLM-powered)"
|
||||
)
|
||||
description: str = Field(..., description="Description")
|
||||
when_to_use: str = Field(..., description="When to Use")
|
||||
required_parameters: str = Field(..., description="Required Parameters")
|
||||
optional_parameters: Optional[str] = Field(None, description="Optional Parameters")
|
||||
returns: str = Field(..., description="Returns")
|
||||
minimal_example_configuration: str = Field(
|
||||
..., description="Minimal Example Configuration"
|
||||
)
|
||||
|
||||
def to_string(self) -> str:
|
||||
"""Serialize operator for prompts."""
|
||||
parts = [
|
||||
f"## {self.name} ({self.type_llm_or_not})",
|
||||
f"**Description:** {self.description}",
|
||||
f"**When to Use:** {self.when_to_use}",
|
||||
f"**Required Parameters:**\n{self.required_parameters}",
|
||||
]
|
||||
|
||||
if self.optional_parameters:
|
||||
parts.append(f"**Optional Parameters:**\n{self.optional_parameters}")
|
||||
|
||||
parts.append(f"**Returns:** {self.returns}")
|
||||
parts.append(
|
||||
f"**Example Configuration:**\n{self.minimal_example_configuration}\n"
|
||||
)
|
||||
return "\n\n".join(parts)
|
||||
|
||||
|
||||
op_map = Operator(
|
||||
name="Map",
|
||||
type_llm_or_not="LLM-powered",
|
||||
description="Processes each document independently by making an LLM call with your prompt template. Creates one output for each input document, with the output conforming to your specified schema.",
|
||||
when_to_use="Use when you need to process each document individually - like extracting, summarizing, classifying, or generating new fields based on document content.",
|
||||
required_parameters="""
|
||||
name: Unique name for the operation
|
||||
type: Must be "map"
|
||||
model: LLM to use to execute the prompt
|
||||
prompt: Jinja2 template with {{ input.key }} for each document field you want to reference
|
||||
output.schema: Dictionary defining the structure and types of the output""",
|
||||
optional_parameters="""
|
||||
gleaning: Iteratively refines outputs that don't meet quality criteria. The LLM reviews its initial output and improves it based on validation feedback. Config requires:
|
||||
- if: Python expression for when to refine; e.g., "len(output[key] == 0)" (optional)
|
||||
- num_rounds: Max refinement iterations
|
||||
- validation_prompt: What to improve
|
||||
- model: LLM to use to execute the validation prompt and provide feedback (defaults to same model as operation model)
|
||||
(Default: gleaning is not enabled)
|
||||
|
||||
calibrate: (bool) Processes a sample of documents first to create reference examples, then uses those examples in all subsequent prompts to ensure consistent outputs across the dataset
|
||||
(Default: calibrate is False)
|
||||
|
||||
num_calibration_docs: Number of docs to use for calibration (default: 10)""",
|
||||
returns="Each original document, augmented with new keys specified in output_schema",
|
||||
minimal_example_configuration="""
|
||||
name: gen_insights
|
||||
type: map
|
||||
model: gpt-4o-mini
|
||||
prompt: From the user log below, list 2-3 concise insights (1-2 words each) and 1-2 supporting actions per insight. Return as a list of dictionaries with 'insight' and 'supporting_actions'. Log: {{ input.log }}
|
||||
output:
|
||||
schema:
|
||||
insights_summary: "string"
|
||||
""",
|
||||
)
|
||||
|
||||
op_extract = Operator(
|
||||
name="Extract",
|
||||
type_llm_or_not="LLM-powered",
|
||||
description="""Pulls out specific portions of text exactly as they appear in the source document. The LLM identifies which parts to extract by providing line number ranges or regex patterns. Extracted text is saved to the original field name with "_extracted" suffix (e.g., report_text → report_text_extracted).""",
|
||||
when_to_use="Use when you need exact text from documents - like pulling out direct quotes, specific contract clauses, key findings, or any content that must be preserved word-for-word without LLM paraphrasing.",
|
||||
required_parameters="""
|
||||
name: Unique name for the operation
|
||||
type: Must be "extract"
|
||||
prompt: Instructions for what to extract (this is NOT a Jinja template; we will run the prompt independently for each key in document_keys)
|
||||
document_keys: List of document fields to extract from
|
||||
model: LLM to use to execute the extraction
|
||||
""",
|
||||
optional_parameters="""
|
||||
extraction_method: How the LLM specifies what to extract from each document:
|
||||
- "line_number" (default): The LLM outputs the line numbers or ranges in the document to extract. Use when relevant information is best identified by its position within each document.
|
||||
- "regex": The LLM generates a custom regex pattern for each document to match the target text. Use when the information varies in location but follows identifiable text patterns (e.g., emails, dates, or structured phrases).
|
||||
""",
|
||||
returns="""Each original document, augmented with {key}_extracted for each key in the specified document_keys""",
|
||||
minimal_example_configuration="""
|
||||
name: findings
|
||||
type: extract
|
||||
prompt: Extract all sections that discuss key findings, results, or conclusions from this research report. Focus on paragraphs that:
|
||||
- Summarize experimental outcomes
|
||||
- Present statistical results
|
||||
- Describe discovered insights
|
||||
- State conclusions drawn from the research
|
||||
Only extract the most important and substantive findings.
|
||||
document_keys: ["report_text"]
|
||||
model: gpt-4o-mini
|
||||
""",
|
||||
)
|
||||
|
||||
op_parallel_map = Operator(
|
||||
name="Parallel Map",
|
||||
type_llm_or_not="LLM-powered",
|
||||
description="Runs multiple independent map operations concurrently on each document. Each prompt generates specific fields, and all outputs are combined into a single result per input document. More efficient than sequential maps when transformations are independent.",
|
||||
when_to_use="Use when you need multiple independent analyses of the same document - like extracting different types of information, running multiple classifications, or generating various summaries from the same input.",
|
||||
required_parameters="""
|
||||
name: Unique name for the operation
|
||||
type: Must be "parallel_map"
|
||||
prompts: List of prompt configs, each with:
|
||||
- prompt: Jinja2 template (should reference document fields using {{ input.key }})
|
||||
- output_keys: List of fields this prompt generates
|
||||
output.schema: Combined schema for all outputs. This must be the union of the output_keys from all prompts—every key in the output schema should be generated by at least one prompt, and no keys should be missing.
|
||||
""",
|
||||
optional_parameters="""
|
||||
model: Default LLM for all prompts
|
||||
Per-prompt options:
|
||||
- model: (optional) Override the default LLM for this specific prompt
|
||||
- gleaning: (optional) Validation and refinement configuration for this prompt, akin to gleaning in map operation
|
||||
""",
|
||||
returns="Each original document augmented with new keys specified in output_schema",
|
||||
minimal_example_configuration="""
|
||||
name: analyze_resume
|
||||
type: parallel_map
|
||||
prompts:
|
||||
- prompt: Extract skills from: {{ input.resume }}
|
||||
output_keys: [skills]
|
||||
model: gpt-4o-mini
|
||||
- prompt: Calculate years of experience from: {{ input.resume }}
|
||||
output_keys: [years_exp]
|
||||
- prompt: Rate writing quality 1-10: {{ input.cover_letter }}
|
||||
output_keys: [writing_score]
|
||||
output:
|
||||
schema:
|
||||
skills: list[string]
|
||||
years_exp: float
|
||||
writing_score: integer
|
||||
""",
|
||||
)
|
||||
|
||||
op_filter = Operator(
|
||||
name="Filter",
|
||||
type_llm_or_not="LLM-powered",
|
||||
description="Evaluates each document with an LLM prompt and only keeps documents where the output is true. Documents evaluating to false are removed from the dataset entirely.",
|
||||
when_to_use="Use when you need to keep only documents meeting specific criteria - like filtering high-impact articles, relevant records, quality content, or documents matching complex conditions that require LLM judgment.",
|
||||
required_parameters="""
|
||||
name: Unique name for the operation
|
||||
type: Must be "filter"
|
||||
model: LLM to use to execute the prompt
|
||||
prompt: Jinja2 template that guides the LLM to return true or false (should reference document fields using {{ input.key }})
|
||||
output.schema: Must contain exactly one boolean field
|
||||
""",
|
||||
optional_parameters="""
|
||||
gleaning: Iteratively refines outputs that don't meet quality criteria. The LLM reviews its initial output and improves it based on validation feedback. Config requires:
|
||||
- if: Python expression for when to refine; e.g., "len(output[key] == 0)" (optional)
|
||||
- num_rounds: Max refinement iterations
|
||||
- validation_prompt: What to improve
|
||||
- model: LLM to use to execute the validation prompt and provide feedback (defaults to same model as operation model)
|
||||
(Default: gleaning is not enabled)
|
||||
|
||||
calibrate: (bool) Processes a sample of documents first to create reference examples, then uses those examples in all subsequent prompts to ensure consistent outputs across the dataset
|
||||
(Default: calibrate is False)
|
||||
|
||||
num_calibration_docs: Number of docs to use for calibration (default: 10)
|
||||
""",
|
||||
returns="Subset of documents; each document has same keys as before",
|
||||
minimal_example_configuration="""
|
||||
name: filter_insightful_comments
|
||||
type: filter
|
||||
prompt: Is this comment insightful?
|
||||
Comment: {{ input.comment }}
|
||||
Consider whether the comment adds a new perspective, explains reasoning, or deepens the discussion. Return true if it is insightful, false otherwise.
|
||||
output:
|
||||
schema:
|
||||
is_insightful: boolean
|
||||
model: gpt-4o-mini
|
||||
""",
|
||||
)
|
||||
|
||||
op_reduce = Operator(
|
||||
name="Reduce",
|
||||
type_llm_or_not="LLM-powered",
|
||||
description="Aggregates multiple documents with the same key value(s) into a single output. Groups documents by reduce_key, then applies an LLM prompt to each group to create one aggregated output per unique key combination.",
|
||||
when_to_use="Use when you need to summarize, consolidate, or analyze groups of related documents - like combining all reviews for a product, summarizing feedback by department, or aggregating patient records by ID.",
|
||||
required_parameters="""
|
||||
name: Unique name for the operation
|
||||
type: Must be "reduce"
|
||||
model: LLM used to execute the prompts
|
||||
reduce_key: Key or keys to group by (can be a string, a list of strings, or "_all" to aggregate over all documents)
|
||||
prompt: Jinja2 template that references {{ inputs }} (the list of grouped documents, where each document is a dictionary) and {{ reduce_key }}
|
||||
|
||||
fold_prompt: Template for incrementally processing large groups by folding batches into the current state. The template should reference:
|
||||
- {{ inputs }}: The current batch of documents to process
|
||||
- {{ output }}: The current aggregated state (matches output.schema)
|
||||
The LLM processes the data in batches, updating the state after each batch using the fold_prompt.
|
||||
|
||||
fold_batch_size: Number of documents to process in each fold batch
|
||||
|
||||
output.schema: Schema for the aggregated output
|
||||
""",
|
||||
optional_parameters="""
|
||||
value_sampling: Run the reduce operation only on a sample, to reduce processing cost and time. Specify a method:
|
||||
- random: Randomly select a subset of items from each group.
|
||||
- first_n: Select the first N items in each group.
|
||||
- cluster: Use clustering (e.g., K-means) to select a diverse, representative subset.
|
||||
- semantic_similarity: Select items most relevant to a provided query, based on embeddings.
|
||||
|
||||
Optional parameters for value_sampling:
|
||||
- enabled: true to activate value sampling (default: false)
|
||||
- method: One of the above sampling methods
|
||||
- sample_size: Number of items to sample from each group
|
||||
- For cluster: no additional parameters required
|
||||
- For semantic_similarity:
|
||||
- embedding_model: Embedding model to use (e.g., text-embedding-3-small)
|
||||
- embedding_keys: List of fields to embed (e.g., [review])
|
||||
- query_text: Text to focus sampling (e.g., "battery life and performance")
|
||||
""",
|
||||
returns="""A set of fewer documents; each document has reduce_keys and all the keys specified in the output schema. Note that many keys in the original documents, if not specified in reduce_keys (i.e., the groupby), will get dropped.""",
|
||||
minimal_example_configuration="""
|
||||
name: summarize_feedback
|
||||
type: reduce
|
||||
reduce_key: department
|
||||
model: gpt-4o-mini
|
||||
prompt: Summarize the customer feedback for {{ reduce_key.department }}:
|
||||
{% for item in inputs %}
|
||||
Feedback {{ loop.index }}: {{ item.feedback }}
|
||||
{% endfor %}
|
||||
Provide main points and overall sentiment.
|
||||
|
||||
fold_prompt: Incrementally update the summary and sentiment based on a batch of new feedback for {{ reduce_key.department }}.
|
||||
Current summary: {{ output.summary }}
|
||||
Current sentiment: {{ output.sentiment }}
|
||||
New feedback batch:
|
||||
{% for item in inputs %}
|
||||
Feedback {{ loop.index }}: {{ item.feedback }}
|
||||
{% endfor %}
|
||||
Return the updated summary and sentiment after incorporating the new feedback.
|
||||
|
||||
fold_batch_size: 25
|
||||
|
||||
output:
|
||||
schema:
|
||||
summary: string
|
||||
sentiment: string
|
||||
""",
|
||||
)
|
||||
|
||||
op_split = Operator(
|
||||
name="Split",
|
||||
type_llm_or_not="Not LLM-powered",
|
||||
description="Divides long text into smaller chunks based on token count or delimiters. Creates multiple output documents from each input, one per chunk. Adds chunk metadata including chunk ID and sequence number.",
|
||||
when_to_use="""Use when documents exceed LLM token limits, when processing long transcripts/reports/contracts and we need to read every portion of the document for the task. E.g., "extract all mentions of X from this document." """,
|
||||
required_parameters="""
|
||||
name: Unique name for the operation
|
||||
type: Must be "split"
|
||||
split_key: Field containing the text to split
|
||||
method: "token_count" or "delimiter" (defaults to "token_count")
|
||||
method_kwargs:
|
||||
- For "token_count": num_tokens (integer) — Number of tokens per split
|
||||
- For "delimiter": delimiter (string) — Delimiter to use for splitting the text
|
||||
""",
|
||||
optional_parameters=None,
|
||||
returns="""
|
||||
The Split operation generates multiple output items for each input item. Each output includes:
|
||||
- All original key-value pairs from the input item
|
||||
- {split_key}_chunk: The content of the split chunk
|
||||
- {op_name}_id: A unique identifier for the original document
|
||||
- {op_name}_chunk_num: The sequential number of the chunk within its original document
|
||||
""",
|
||||
minimal_example_configuration="""
|
||||
name: split_transcript
|
||||
type: split
|
||||
split_key: transcript
|
||||
method: token_count
|
||||
method_kwargs:
|
||||
num_tokens: 500
|
||||
""",
|
||||
)
|
||||
|
||||
op_gather = Operator(
|
||||
name="Gather",
|
||||
type_llm_or_not="Not LLM-powered",
|
||||
description="""Adds context from surrounding chunks to each chunk after splitting. Includes content from previous/next chunks and maintains document structure through header hierarchies. Creates a "rendered" version of each chunk with its context.""",
|
||||
when_to_use="Use after Split when chunks need context for accurate processing - essential for legal documents, technical manuals, or any content where references span chunks. Helps maintain document structure and cross-references.",
|
||||
required_parameters="""
|
||||
name: Unique name for the operation
|
||||
type: Must be "gather"
|
||||
content_key: Field containing the chunk content to gather
|
||||
doc_id_key: Field linking all chunks from the same original document
|
||||
order_key: Field specifying the order/sequence of each chunk within the document
|
||||
|
||||
peripheral_chunks: Configuration for including context from surrounding (previous and/or next) chunks. This helps each chunk retain important context that may be necessary for accurate downstream processing.
|
||||
- You can specify previous and/or next context, and control how much content to include before and after each chunk.
|
||||
- For each (previous/next), you can define:
|
||||
- head: The first chunk(s) in the section (e.g., the chunk immediately before/after the current one)
|
||||
- middle: Chunks between head and tail (e.g., summarized versions of further-away chunks)
|
||||
- tail: The last chunk(s) in the section (e.g., furthest before/after the current chunk you want to include)
|
||||
- For each section, specify:
|
||||
- count: Number of chunks to include (head and tail only)
|
||||
- content_key: Which field to use for context (e.g., full text or summary)
|
||||
|
||||
Example:
|
||||
peripheral_chunks:
|
||||
previous:
|
||||
head:
|
||||
count: 1
|
||||
content_key: full_content
|
||||
middle:
|
||||
content_key: summary_content
|
||||
tail:
|
||||
count: 2
|
||||
content_key: full_content
|
||||
next:
|
||||
head:
|
||||
count: 1
|
||||
content_key: full_content
|
||||
|
||||
- This config means:
|
||||
• Include the full content of 1 chunk before, summaries of all in-between previous chunks, and the full content of the 2 furthest-back previous chunks.
|
||||
• Include the full content of the 1 chunk immediately after the current chunk.
|
||||
- Use full content for immediate context (head/tail) and summaries for middle sections to balance completeness and token efficiency.
|
||||
- Only include next chunks if future context is important; by default, focus on previous for most text documents.
|
||||
""",
|
||||
optional_parameters="""
|
||||
doc_header_key: (optional) Field containing extracted headers for each chunk.
|
||||
- This field provides the hierarchical structure of document sections, enabling the Gather operation to reconstruct header context for each chunk.
|
||||
- To use this, you must first run a map operation that extracts headers from each chunk, using the following schema:
|
||||
headers: list of {header: string, level: integer}
|
||||
- Example map operation:
|
||||
|
||||
name: extract_headers
|
||||
type: map
|
||||
input:
|
||||
- agreement_text_chunk
|
||||
prompt: Extract any section headers from the following merger agreement chunk:
|
||||
{{ input.agreement_text_chunk }}
|
||||
Return the headers as a list, preserving their hierarchy.
|
||||
output.schema:
|
||||
headers: list[{header: string, level: integer}]
|
||||
""",
|
||||
returns="""
|
||||
The Gather operation produces one output item for each input chunk. Each output includes:
|
||||
|
||||
- All original key-value pairs from the input document
|
||||
- {content_key}_rendered: The content of the chunk, enriched with:
|
||||
• Reconstructed header hierarchy (if doc_header_key is provided)
|
||||
• Previous context (chunks before the current chunk, if configured)
|
||||
• The main chunk, clearly marked
|
||||
• Next context (chunks after the current chunk, if configured)
|
||||
• Indications of skipped content between included sections (e.g., "[... 500 characters skipped ...]")
|
||||
|
||||
For example, if your content_key is agreement_text_chunk, the Gather operation adds:
|
||||
agreement_text_chunk_rendered
|
||||
|
||||
Note: No additional unique identifier or chunk number fields are created by Gather. (Those are typically added by Split.) Gather focuses on adding the rendered, context-enriched content field.
|
||||
""",
|
||||
minimal_example_configuration="""
|
||||
name: add_context
|
||||
type: gather
|
||||
content_key: text_chunk
|
||||
doc_id_key: split_doc_id
|
||||
order_key: split_chunk_num
|
||||
|
||||
peripheral_chunks:
|
||||
previous:
|
||||
head:
|
||||
count: 1
|
||||
content_key: text_chunk
|
||||
tail:
|
||||
count: 2
|
||||
content_key: text_chunk
|
||||
""",
|
||||
)
|
||||
|
||||
op_unnest = Operator(
|
||||
name="Unnest",
|
||||
type_llm_or_not="Not LLM-powered",
|
||||
description="Expands arrays into multiple documents (one per element) or flattens dictionary fields into the parent document. For arrays, replaces the array with individual elements. For dicts, adds specified fields to parent while keeping original dict.",
|
||||
when_to_use="Use when you need to process array elements individually, or when flattening nested data structures for easier analysis. Essentially, this operation is for normalizing nested data.",
|
||||
required_parameters="""
|
||||
name: Unique name for the operation
|
||||
type: Must be "unnest"
|
||||
unnest_key: Field containing the array or dictionary to expand
|
||||
""",
|
||||
optional_parameters="""
|
||||
keep_empty: If true, empty arrays being exploded will be kept in the output (with value None). Default: false
|
||||
recursive: If true, the unnest operation will be applied recursively to nested arrays. Default: false
|
||||
depth: The maximum depth for recursive unnesting (only applicable if recursive is true)
|
||||
""",
|
||||
returns="Returns one output document for each element in the unnested array or dictionary. Each output preserves all original fields from the input document, and adds fields from the expanded element. For arrays, each item becomes its own document. For dictionaries, specified expand_fields are added as top-level fields in the output.",
|
||||
minimal_example_configuration="""
|
||||
name: expand_user
|
||||
type: unnest
|
||||
unnest_key: user_info
|
||||
""",
|
||||
)
|
||||
|
||||
op_sample = Operator(
|
||||
name="Sample",
|
||||
type_llm_or_not="Not LLM-powered",
|
||||
description="Selects a subset of documents from the input according to the specified sampling method. Used to generate a representative sample for further analysis or processing.",
|
||||
when_to_use="Use when you want to work with a smaller subset of your data for debugging, rapid prototyping, or to reduce compute cost. Also useful for sampling before downstream processing. Stratification can be applied to uniform, first, outliers, top_embedding, and top_fts methods. It ensures that the sample maintains the distribution of specified key(s) in the data or retrieves top items from each stratum.",
|
||||
required_parameters="""
|
||||
name: Unique name for the operation
|
||||
type: Must be "sample"
|
||||
method: The sampling method to use ("uniform", "outliers", "custom", "first", "top_embedding", or "top_fts")
|
||||
samples: Either a list of key-value pairs representing document ids and values, an integer count of samples, or a float fraction of samples.
|
||||
""",
|
||||
optional_parameters="""
|
||||
method: The sampling method to use. Options:
|
||||
- uniform: Randomly select the specified number or fraction of documents. When combined with stratification, maintains the distribution of the stratified groups.
|
||||
- first: Select the first N documents from the dataset. When combined with stratification, takes proportionally from each group.
|
||||
- top_embedding: Select top documents based on embedding similarity to a query. Requires the following in method_kwargs: keys: A list of keys to use for creating embeddings, query: The query string to match against (supports Jinja templates), embedding_model: (Optional) The embedding model to use. Defaults to "text-embedding-3-small".
|
||||
- top_fts: Retrieves the top N items using full-text search with BM25 algorithm. Requires the following in method_kwargs: keys: A list of keys to search within, query: The query string for keyword matching (supports Jinja templates).
|
||||
- outliers: Select or remove documents considered outliers based on embedding distance. Requiresthe following in method_kwargs: embedding_keys (fields to embed), std (standard deviation cutoff) or samples (number/fraction of outlier samples), and keep (true to keep, false to remove outliers; default false). Optionally, method_kwargs.center can specify the center point.
|
||||
- custom: Samples specific items by matching key-value pairs. Stratification is not supported with custom sampling.
|
||||
|
||||
samples: The number of samples to select (integer), fraction of documents to sample (float), or explicit list of document IDs (for custom).
|
||||
|
||||
random_state: Integer to seed the random generator for reproducible results (default: random each run).
|
||||
|
||||
stratify_key: Field or list of fields to stratify by (for uniform method stratified sampling)
|
||||
samples_per_group: When stratifying, sample N items per group vs. dividing total (for uniform method)
|
||||
|
||||
method_kwargs: Additional parameters required by the chosen sampling method, such as:
|
||||
- embedding_keys: List of fields to embed (for outliers)
|
||||
- std: Number of standard deviations for outlier cutoff (for outliers)
|
||||
- keep: true to keep or false to remove outliers (for outliers; default false)
|
||||
- center: Dictionary specifying the center for distance calculations (for outliers)
|
||||
""",
|
||||
returns="A subset of input documents, with the same schema as the original input.",
|
||||
minimal_example_configuration="""
|
||||
name: stratified_sample
|
||||
type: sample
|
||||
method: uniform
|
||||
samples: 0.2
|
||||
stratify_key: category
|
||||
""",
|
||||
)
|
||||
|
||||
op_resolve = Operator(
|
||||
name="Resolve",
|
||||
type_llm_or_not="LLM-powered",
|
||||
description="Identifies and canonicalizes duplicate or matching entities across your dataset using LLM-driven pairwise comparison and resolution prompts. Useful for data cleaning, deduplication, and standardizing variations created by LLMs in preceding map operations.",
|
||||
when_to_use="Use when you need to standardize documents that may refer to the same real-world entity but have inconsistent or duplicated fields (e.g., names, product titles, organizations) due to extraction, human error, or LLM variation.",
|
||||
required_parameters="""
|
||||
name: Unique name for the operation
|
||||
type: Must be "resolve"
|
||||
comparison_prompt: Jinja2 template for comparing two candidate records (refer to as {{ input1 }}, {{ input2 }})
|
||||
resolution_prompt: Jinja2 template for consolidating/mapping a set of matched records (refer to as {{ inputs }})
|
||||
output.schema: Dictionary defining the structure and types of the resolved output
|
||||
embedding_model: Model to use for creating embeddings for blocking (default: falls back to default_model)
|
||||
comparison_model: LLM to use for comparisons (default: falls back to default_model)
|
||||
resolution_model: LLM to use for final resolution (default: falls back to default_model)
|
||||
""",
|
||||
optional_parameters="""
|
||||
blocking_keys: List of fields for blocking—records must match on at least one key to be compared (default: all input keys)
|
||||
blocking_threshold: Embedding similarity threshold for blocking (only compare above this value)
|
||||
blocking_conditions: List of Python expressions for custom blocking logic (e.g., "left['ssn'][-4:] == right['ssn'][-4:]")
|
||||
embedding_batch_size: Number of records sent to embedding model per batch (default: 1000)
|
||||
compare_batch_size: Number of record pairs compared per batch (default: 500)
|
||||
limit_comparisons: Maximum number of pairwise comparisons (default: no limit)
|
||||
""",
|
||||
returns="One output document per input document, preserving the original document structure, but with specified fields in the output schema updated to their resolved (standardized) values. All other fields remain unchanged.",
|
||||
minimal_example_configuration="""
|
||||
name: standardize_patient_names
|
||||
type: resolve
|
||||
comparison_model: gpt-4o-mini
|
||||
resolution_model: gpt-4o
|
||||
embedding_model: text-embedding-3-small
|
||||
comparison_prompt: |
|
||||
Compare the following two patient name entries:
|
||||
Patient 1: {{ input1.patient_name }}
|
||||
Date of Birth 1: {{ input1.date_of_birth }}
|
||||
Patient 2: {{ input2.patient_name }}
|
||||
Date of Birth 2: {{ input2.date_of_birth }}
|
||||
Are these entries likely referring to the same patient? Respond "True" or "False".
|
||||
resolution_prompt: |
|
||||
Standardize these patient names into a single, consistent format:
|
||||
{% for entry in inputs %}
|
||||
Patient Name {{ loop.index }}: {{ entry.patient_name }}
|
||||
{% endfor %}
|
||||
Provide a single, standardized patient name.
|
||||
output:
|
||||
schema:
|
||||
patient_name: string
|
||||
blocking_keys:
|
||||
- last_name
|
||||
- date_of_birth
|
||||
blocking_threshold: 0.8
|
||||
blocking_conditions:
|
||||
- "left['last_name'][:2].lower() == right['last_name'][:2].lower()"
|
||||
- "left['date_of_birth'] == right['date_of_birth']"
|
||||
""",
|
||||
)
|
||||
|
||||
op_code_map = Operator(
|
||||
name="Code Map",
|
||||
type_llm_or_not="not LLM-powered",
|
||||
description="Applies a Python function to each document independently using custom code. Returns a dictionary of key-value pairs to UPDATE the original document with. Useful for deterministic transformations, regex processing, calculations, or leveraging external Python libraries.",
|
||||
when_to_use="Use when you need deterministic processing, complex calculations, regex/pattern matching, or want to leverage existing Python libraries. Ideal for structured data transformations that don't require LLM reasoning.",
|
||||
required_parameters="""
|
||||
name: Unique name for the operation
|
||||
type: Must be "code_map"
|
||||
code: Python code defining a function named 'transform' that takes an input document and returns a dictionary of updates. Must include all necessary imports within the function. Format: def transform(input_doc): ...""",
|
||||
optional_parameters="""
|
||||
drop_keys: List of keys to remove from output (default: None)
|
||||
concurrent_thread_count: Number of threads to use (default: number of logical CPU cores)""",
|
||||
returns="Each original document, updated with the key-value pairs returned by the transform function",
|
||||
minimal_example_configuration="""
|
||||
name: extract_keywords_deterministic
|
||||
type: code_map
|
||||
code: |
|
||||
def transform(input_doc):
|
||||
import re
|
||||
text = input_doc.get('content', '')
|
||||
# Extract words that are all caps (potential keywords)
|
||||
keywords = re.findall(r'\\b[A-Z]{2,}\\b', text)
|
||||
# Extract email addresses
|
||||
emails = re.findall(r'\\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Z|a-z]{2,}\\b', text)
|
||||
return {
|
||||
'keywords': list(set(keywords)),
|
||||
'email_count': len(emails),
|
||||
'word_count': len(text.split())
|
||||
}
|
||||
""",
|
||||
)
|
||||
|
||||
op_code_filter = Operator(
|
||||
name="Code Filter",
|
||||
type_llm_or_not="not LLM-powered",
|
||||
description="Filters documents based on custom Python logic. Uses a Python function that returns True to keep documents and False to filter them out. Useful for deterministic filtering based on calculations, regex patterns, or structured data conditions.",
|
||||
when_to_use="Use when you need deterministic filtering logic that doesn't require LLM reasoning - like filtering based on numeric thresholds, text patterns, data completeness, or complex boolean conditions.",
|
||||
required_parameters="""
|
||||
name: Unique name for the operation
|
||||
type: Must be "code_filter"
|
||||
code: Python code defining a function named 'transform' that takes an input document and returns a boolean (True to keep, False to filter out). Must include all necessary imports within the function. Format: def transform(input_doc): ...""",
|
||||
optional_parameters="""
|
||||
concurrent_thread_count: Number of threads to use (default: number of logical CPU cores)""",
|
||||
returns="Subset of input documents where the transform function returned True. Documents retain all original fields.",
|
||||
minimal_example_configuration="""
|
||||
name: filter_valid_scores
|
||||
type: code_filter
|
||||
code: |
|
||||
def transform(input_doc):
|
||||
score = input_doc.get('confidence_score', 0)
|
||||
text_length = len(input_doc.get('content', ''))
|
||||
# Keep documents with high confidence and sufficient content
|
||||
return score >= 0.8 and text_length >= 100
|
||||
""",
|
||||
)
|
||||
|
||||
op_topk = Operator(
|
||||
name="TopK",
|
||||
type_llm_or_not="LLM-powered or not LLM-powered",
|
||||
description="Retrieves the most relevant items from your dataset using semantic similarity, full-text search, or LLM-based comparison. Provides a specialized interface for retrieval tasks where you need to find and rank the best matching documents based on specific criteria.",
|
||||
when_to_use="Use when you need to find the most relevant documents for a query, filter large datasets to the most important items, implement retrieval-augmented generation (RAG) pipelines, or build recommendation systems. Choose this over general sampling when you specifically need the 'best' matches according to some criteria.",
|
||||
required_parameters="""
|
||||
name: Unique name for the operation
|
||||
type: Must be "topk"
|
||||
method: Retrieval method to use ("embedding" for semantic similarity, "fts" for full-text search, or "llm_compare" for LLM-based ranking)
|
||||
k: Number of items to retrieve (integer) or percentage (float between 0 and 1)
|
||||
keys: List of document fields to use for matching/comparison
|
||||
query: Query or ranking criteria (Jinja templates supported for embedding and fts methods only)
|
||||
""",
|
||||
optional_parameters="""
|
||||
embedding_model: Model for embeddings (default: "text-embedding-3-small"). Used for embedding and llm_compare methods.
|
||||
|
||||
model: LLM model for comparisons (required for llm_compare method)
|
||||
|
||||
batch_size: Batch size for LLM ranking (default: 10, only for llm_compare method)
|
||||
|
||||
stratify_key: Key(s) for stratified retrieval - ensures you retrieve top items from each group (string or list of strings). Not supported with llm_compare method.
|
||||
|
||||
Method-specific notes:
|
||||
- embedding: Uses semantic similarity via embeddings. Supports Jinja templates in query.
|
||||
- fts: Uses BM25 full-text search algorithm. No API costs. Supports Jinja templates in query.
|
||||
- llm_compare: Uses LLM for complex ranking based on multiple criteria. Most expensive but most flexible. Does NOT support Jinja templates in query (ranking criteria must be consistent across all documents).
|
||||
""",
|
||||
returns="Top k documents based on the specified method and query, with the same schema as the original input",
|
||||
minimal_example_configuration="""
|
||||
name: find_relevant_tickets
|
||||
type: topk
|
||||
method: embedding
|
||||
k: 5
|
||||
keys:
|
||||
- subject
|
||||
- description
|
||||
- customer_feedback
|
||||
query: "payment processing errors with international transactions"
|
||||
embedding_model: text-embedding-3-small
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
# List of all operators
|
||||
ALL_OPERATORS = [
|
||||
op_map,
|
||||
op_extract,
|
||||
op_parallel_map,
|
||||
op_filter,
|
||||
op_reduce,
|
||||
op_split,
|
||||
op_gather,
|
||||
op_unnest,
|
||||
op_sample,
|
||||
op_resolve,
|
||||
op_code_map,
|
||||
op_code_filter,
|
||||
op_topk,
|
||||
]
|
||||
|
||||
|
||||
def get_all_operator_descriptions() -> str:
|
||||
"""
|
||||
Generate a comprehensive string containing all operator descriptions.
|
||||
This is useful for providing context about available operators in prompts.
|
||||
|
||||
Returns:
|
||||
str: Formatted string containing all operator descriptions
|
||||
"""
|
||||
descriptions = []
|
||||
descriptions.append("# Available DocETL Operators\n")
|
||||
descriptions.append(
|
||||
"Below are all the operators available in the DocETL pipeline:\n"
|
||||
)
|
||||
|
||||
for op in ALL_OPERATORS:
|
||||
descriptions.append(op.to_string())
|
||||
|
||||
return "\n".join(descriptions)
|
||||
|
|
@ -1,6 +0,0 @@
|
|||
from .base import Retriever
|
||||
from .lancedb import LanceDBRetriever
|
||||
|
||||
__all__ = ["Retriever", "LanceDBRetriever"]
|
||||
|
||||
|
||||
|
|
@ -1,35 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievalResult:
|
||||
"""Container for retrieval outputs."""
|
||||
|
||||
docs: list[dict]
|
||||
rendered_context: str
|
||||
meta: dict[str, Any]
|
||||
|
||||
|
||||
class Retriever(ABC):
|
||||
"""Abstract base class for retrievers."""
|
||||
|
||||
name: str
|
||||
|
||||
def __init__(self, runner, name: str, config: dict[str, Any]):
|
||||
self.runner = runner
|
||||
self.name = name
|
||||
self.config = config
|
||||
|
||||
@abstractmethod
|
||||
def ensure_index(self) -> None:
|
||||
"""Create or verify the underlying index/table."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def retrieve(self, context: dict[str, Any]) -> RetrievalResult:
|
||||
"""Execute retrieval based on the provided Jinja context."""
|
||||
raise NotImplementedError
|
||||
|
|
@ -1,358 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
from jinja2 import Template
|
||||
|
||||
from .base import RetrievalResult, Retriever
|
||||
|
||||
|
||||
class LanceDBRetriever(Retriever):
|
||||
"""OSS LanceDB-backed retriever supporting FTS, vector, and hybrid search."""
|
||||
|
||||
_index_lock = threading.Lock()
|
||||
_ensured = False
|
||||
|
||||
def _connect(self):
|
||||
# Lazy import to avoid hard dependency during linting/tests without install
|
||||
try:
|
||||
import lancedb # type: ignore
|
||||
except Exception as e:
|
||||
raise ImportError(
|
||||
"lancedb is required for LanceDBRetriever. Please add it to dependencies."
|
||||
) from e
|
||||
index_dir = self.config["index_dir"]
|
||||
os.makedirs(index_dir, exist_ok=True)
|
||||
return lancedb.connect(index_dir)
|
||||
|
||||
def _table_name(self) -> str:
|
||||
# Stable table name based on retriever name
|
||||
return f"docetl_{self.name}"
|
||||
|
||||
def _iter_dataset_rows(self) -> list[dict]:
|
||||
"""Load dataset referenced by this retriever."""
|
||||
dataset_name = self.config["dataset"]
|
||||
# Ensure datasets are loaded
|
||||
if not getattr(self.runner, "datasets", None):
|
||||
self.runner.load()
|
||||
if dataset_name not in self.runner.datasets:
|
||||
raise ValueError(
|
||||
f"Retriever '{self.name}' references unknown dataset '{dataset_name}'."
|
||||
)
|
||||
return self.runner.datasets[dataset_name].load()
|
||||
|
||||
def _render_input_phrase(self, tmpl: str | None, input_obj: dict) -> str:
|
||||
if not tmpl:
|
||||
return ""
|
||||
try:
|
||||
return Template(tmpl).render(input=input_obj)
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
def _batch_embed(self, texts: list[str]) -> list[list[float]]:
|
||||
model = self.config.get("embedding", {}).get("model")
|
||||
if not model:
|
||||
# If no embedding model, return empty vectors
|
||||
return []
|
||||
# Use DocETL embedding router + cache
|
||||
resp = self.runner.api.gen_embedding(model, json.dumps(texts))
|
||||
# litellm embedding response
|
||||
vectors = [d["embedding"] for d in resp["data"]]
|
||||
return vectors
|
||||
|
||||
def _index_types(self) -> set[str]:
|
||||
idx = self.config.get("index_types", [])
|
||||
if isinstance(idx, str):
|
||||
idx = [idx]
|
||||
idx = set(idx or [])
|
||||
# Interpret hybrid as both
|
||||
if "hybrid" in idx:
|
||||
idx.update({"fts", "embedding"})
|
||||
return idx
|
||||
|
||||
def ensure_index(self) -> None:
|
||||
build_policy = self.config.get("build_index", "if_missing")
|
||||
with self._index_lock:
|
||||
# Only build once per process; 'always' means rebuild once at startup
|
||||
if self._ensured:
|
||||
return
|
||||
|
||||
db = self._connect()
|
||||
table_name = self._table_name()
|
||||
dataset_name = self.config.get("dataset", "<unknown>")
|
||||
idx_types = self._index_types()
|
||||
try:
|
||||
self.runner.console.log(
|
||||
f"[cyan]LanceDB[/cyan] ensure_index: table='{table_name}', dataset='{dataset_name}', "
|
||||
f"index_types={sorted(list(idx_types))}, build_policy='{build_policy}'"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
table = None
|
||||
try:
|
||||
existing_names = db.table_names()
|
||||
except Exception as exc:
|
||||
# Surface the real error instead of sidestepping
|
||||
raise RuntimeError(f"Failed to list LanceDB tables: {exc}") from exc
|
||||
if build_policy == "never":
|
||||
try:
|
||||
self.runner.console.log(
|
||||
f"[yellow]LanceDB[/yellow] skipping index build (build_index=never) for '{table_name}'"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
self._ensured = True
|
||||
return
|
||||
if table_name in existing_names:
|
||||
if build_policy == "always":
|
||||
try:
|
||||
self.runner.console.log(
|
||||
f"[cyan]LanceDB[/cyan] dropping existing table '{table_name}' (build_index=always)"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
db.drop_table(table_name)
|
||||
else:
|
||||
# if_missing and table exists -> nothing to do
|
||||
try:
|
||||
self.runner.console.log(
|
||||
f"[green]LanceDB[/green] index exists for '{table_name}', skipping build (build_index=if_missing)"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
self._ensured = True
|
||||
return
|
||||
|
||||
if table is None:
|
||||
try:
|
||||
rows = self._iter_dataset_rows()
|
||||
except Exception:
|
||||
# Dataset may not be available yet (e.g., produced by a prior step)
|
||||
try:
|
||||
self.runner.console.log(
|
||||
f"[yellow]LanceDB[/yellow] dataset '{dataset_name}' not available; "
|
||||
f"deferring index build for '{table_name}'"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
idx_types = self._index_types()
|
||||
|
||||
# Build per-row phrases for FTS and Embedding
|
||||
fts_index_tmpl = self.config.get("fts", {}).get("index_phrase", None)
|
||||
emb_index_tmpl = self.config.get("embedding", {}).get(
|
||||
"index_phrase", None
|
||||
)
|
||||
|
||||
fts_texts: list[str] = []
|
||||
if "fts" in idx_types:
|
||||
fts_texts = [
|
||||
self._render_input_phrase(fts_index_tmpl, r) for r in rows
|
||||
]
|
||||
|
||||
emb_texts: list[str] = []
|
||||
if "embedding" in idx_types:
|
||||
emb_texts = [
|
||||
self._render_input_phrase(emb_index_tmpl, r) for r in rows
|
||||
]
|
||||
# If no template, fall back to fts texts
|
||||
if not any(emb_texts) and fts_texts:
|
||||
emb_texts = fts_texts
|
||||
# Compute embeddings for embedding index
|
||||
vectors: list[list[float]] = []
|
||||
if "embedding" in idx_types:
|
||||
# Batch embeddings
|
||||
batch = 128
|
||||
for i in range(0, len(emb_texts), batch):
|
||||
chunk = emb_texts[i : i + batch]
|
||||
if not chunk:
|
||||
continue
|
||||
vectors.extend(self._batch_embed(chunk))
|
||||
|
||||
# Construct records for LanceDB
|
||||
data = []
|
||||
for idx, r in enumerate(rows):
|
||||
rec = {
|
||||
"id": r.get("id", idx),
|
||||
"text": (
|
||||
fts_texts[idx]
|
||||
if fts_texts
|
||||
else (emb_texts[idx] if emb_texts else "")
|
||||
),
|
||||
}
|
||||
if vectors and idx < len(vectors):
|
||||
rec["vector"] = vectors[idx]
|
||||
# keep some fields for rendering
|
||||
data.append(rec)
|
||||
|
||||
# Create table
|
||||
table = db.create_table(table_name, data=data, mode="overwrite")
|
||||
try:
|
||||
self.runner.console.log(
|
||||
f"[green]LanceDB[/green] created table '{table_name}' with {len(data)} rows "
|
||||
f"(fts={'fts' in idx_types}, embedding={'embedding' in idx_types})"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Create FTS index if fts enabled
|
||||
if "fts" in idx_types:
|
||||
try:
|
||||
table.create_fts_index("text")
|
||||
try:
|
||||
self.runner.console.log(
|
||||
f"[green]LanceDB[/green] created FTS index on column 'text' for '{table_name}'"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
# Index may already exist or backend unavailable; continue
|
||||
try:
|
||||
self.runner.console.log(
|
||||
f"[yellow]LanceDB[/yellow] FTS index creation skipped/failed for '{table_name}'"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self._ensured = True
|
||||
try:
|
||||
self.runner.console.log(
|
||||
f"[green]LanceDB[/green] index ready for '{table_name}'"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _render_query_phrase(self, tmpl: str | None, context: dict[str, Any]) -> str:
|
||||
if not tmpl:
|
||||
return ""
|
||||
try:
|
||||
return Template(tmpl).render(**context)
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
def _select_mode(self) -> str:
|
||||
qmode = self.config.get("query", {}).get("mode", None)
|
||||
if qmode:
|
||||
return qmode
|
||||
# Auto: hybrid if both index types present
|
||||
idx = self._index_types()
|
||||
if "fts" in idx and "embedding" in idx:
|
||||
return "hybrid"
|
||||
if "fts" in idx:
|
||||
return "fts"
|
||||
if "embedding" in idx:
|
||||
return "embedding"
|
||||
return "fts"
|
||||
|
||||
def _reranker(self):
|
||||
# Lazy import and instantiate RRFReranker when needed
|
||||
try:
|
||||
from lancedb.rerankers import RRFReranker # type: ignore
|
||||
|
||||
rr = RRFReranker()
|
||||
return rr
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _limit_and_format(self, rows: list[dict]) -> list[dict]:
|
||||
top_k = int(self.config.get("query", {}).get("top_k", 5))
|
||||
return rows[:top_k]
|
||||
|
||||
def _fetch(self, context: dict[str, Any]) -> list[dict]:
|
||||
db = self._connect()
|
||||
table = db.open_table(self._table_name())
|
||||
mode = self._select_mode()
|
||||
top_k = int(self.config.get("query", {}).get("top_k", 5))
|
||||
|
||||
# Build queries
|
||||
text_query = self._render_query_phrase(
|
||||
self.config.get("fts", {}).get("query_phrase", None), context
|
||||
)
|
||||
vector_query: list[float] | None = None
|
||||
emb_qtext = self._render_query_phrase(
|
||||
self.config.get("embedding", {}).get("query_phrase", None), context
|
||||
)
|
||||
if emb_qtext:
|
||||
vecs = self._batch_embed([emb_qtext])
|
||||
vector_query = vecs[0] if vecs else None
|
||||
|
||||
# Execute
|
||||
try:
|
||||
if mode == "hybrid" and text_query and vector_query is not None:
|
||||
q = (
|
||||
table.search(query_type="hybrid")
|
||||
.vector(vector_query)
|
||||
.text(text_query)
|
||||
)
|
||||
rr = self._reranker()
|
||||
if rr:
|
||||
q = q.rerank(rr)
|
||||
df = q.limit(top_k).to_pandas()
|
||||
elif mode == "fts" and text_query:
|
||||
# FTS-only
|
||||
q = table.search(text_query, query_type="fts", fts_columns="text")
|
||||
df = q.limit(top_k).to_pandas()
|
||||
elif mode == "embedding" and vector_query is not None:
|
||||
q = table.search(vector_query)
|
||||
df = q.limit(top_k).to_pandas()
|
||||
else:
|
||||
return []
|
||||
except Exception:
|
||||
# Fallbacks
|
||||
try:
|
||||
q = (
|
||||
table.search(text_query, query_type="fts", fts_columns="text")
|
||||
if text_query
|
||||
else table.search(vector_query)
|
||||
)
|
||||
df = q.limit(top_k).to_pandas()
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
# Convert to list of dicts
|
||||
rows = df.to_dict(orient="records") if hasattr(df, "to_dict") else []
|
||||
return self._limit_and_format(rows)
|
||||
|
||||
def _render_docs(self, docs: list[dict]) -> str:
|
||||
# Minimal, opinionated rendering: bullets of text (truncated)
|
||||
max_chars = 1000
|
||||
lines: list[str] = []
|
||||
for i, d in enumerate(docs, start=1):
|
||||
snippet = str(d.get("text", ""))[:max_chars]
|
||||
if snippet:
|
||||
lines.append(f"- [{i}] {snippet}")
|
||||
return "\n".join(lines)
|
||||
|
||||
@staticmethod
|
||||
@lru_cache(maxsize=256)
|
||||
def _cache_key(name: str, ctx_key: str) -> str:
|
||||
return f"{name}:{ctx_key}"
|
||||
|
||||
def _ctx_fingerprint(self, context: dict[str, Any]) -> str:
|
||||
# Make a stable string key from context, using only parts used by templates
|
||||
if "input" in context:
|
||||
base = context["input"]
|
||||
elif "reduce_key" in context:
|
||||
base = {
|
||||
"reduce_key": context.get("reduce_key", {}),
|
||||
"inputs_len": len(context.get("inputs", [])),
|
||||
}
|
||||
else:
|
||||
base = context
|
||||
try:
|
||||
return json.dumps(base, sort_keys=True, ensure_ascii=False)[:2000]
|
||||
except Exception:
|
||||
return str(base)[:2000]
|
||||
|
||||
def retrieve(self, context: dict[str, Any]) -> RetrievalResult:
|
||||
self.ensure_index()
|
||||
# Build a fingerprint (reserved for future caching)
|
||||
_ = self._ctx_fingerprint(context)
|
||||
docs = self._fetch(context)
|
||||
rendered = self._render_docs(docs) if docs else "No extra context available."
|
||||
meta = {"retriever": self.name, "num_docs": len(docs)}
|
||||
return RetrievalResult(docs=docs, rendered_context=rendered, meta=meta)
|
||||
|
|
@@ -93,7 +93,6 @@ class DSLRunner(ConfigWrapper):
|
|||
config: dict[str, Any] | None
|
||||
parsing_tools: list[schemas.ParsingTool] | None
|
||||
datasets: dict[str, schemas.Dataset]
|
||||
retrievers: dict[str, Any] | None
|
||||
operations: list[OpType]
|
||||
pipeline: schemas.PipelineSpec
|
||||
|
||||
|
|
@@ -120,7 +119,6 @@ class DSLRunner(ConfigWrapper):
|
|||
self.total_cost = 0
|
||||
self._initialize_state()
|
||||
self._setup_parsing_tools()
|
||||
self._setup_retrievers()
|
||||
self._build_operation_graph(config)
|
||||
self._compute_operation_hashes()
|
||||
|
||||
|
|
@@ -142,31 +140,6 @@ class DSLRunner(ConfigWrapper):
|
|||
self.config.get("parsing_tools", None)
|
||||
)
|
||||
|
||||
def _setup_retrievers(self) -> None:
|
||||
"""Instantiate retrievers from configuration (lazy index creation)."""
|
||||
from docetl.retrievers.lancedb import LanceDBRetriever
|
||||
|
||||
self.retrievers: dict[str, Any] = {}
|
||||
retrievers_cfg = self.config.get("retrievers", {}) or {}
|
||||
for name, rconf in retrievers_cfg.items():
|
||||
if not isinstance(rconf, dict):
|
||||
raise ValueError(f"Invalid retriever '{name}' configuration")
|
||||
if rconf.get("type") != "lancedb":
|
||||
raise ValueError(
|
||||
f"Unsupported retriever type '{rconf.get('type')}' for '{name}'. Only 'lancedb' is supported."
|
||||
)
|
||||
required = ["dataset", "index_dir", "index_types"]
|
||||
for key in required:
|
||||
if key not in rconf:
|
||||
raise ValueError(
|
||||
f"Retriever '{name}' missing required key '{key}'."
|
||||
)
|
||||
# Defaults
|
||||
rconf.setdefault("query", {"top_k": 5})
|
||||
rconf.setdefault("build_index", "if_missing")
|
||||
|
||||
self.retrievers[name] = LanceDBRetriever(self, name, rconf)
|
||||
|
||||
def _build_operation_graph(self, config: dict) -> None:
|
||||
"""Build the DAG of operations from configuration"""
|
||||
self.config = config
|
||||
|
|
@@ -699,7 +672,7 @@ class DSLRunner(ConfigWrapper):
|
|||
"litellm_kwargs", {}
|
||||
)
|
||||
kwargs["rewrite_agent_model"] = self.config.get("optimizer_config", {}).get(
|
||||
"rewrite_agent_model", "gpt-5.1"
|
||||
"rewrite_agent_model", "gpt-4o"
|
||||
)
|
||||
kwargs["judge_agent_model"] = self.config.get("optimizer_config", {}).get(
|
||||
"judge_agent_model", "gpt-4o-mini"
|
||||
|
|
@@ -727,7 +700,7 @@ class DSLRunner(ConfigWrapper):
|
|||
"litellm_kwargs", {}
|
||||
)
|
||||
kwargs["rewrite_agent_model"] = self.config.get("optimizer_config", {}).get(
|
||||
"rewrite_agent_model", "gpt-5.1"
|
||||
"rewrite_agent_model", "gpt-4o"
|
||||
)
|
||||
kwargs["judge_agent_model"] = self.config.get("optimizer_config", {}).get(
|
||||
"judge_agent_model", "gpt-4o-mini"
|
||||
|
|
|
|||
|
|
@@ -4,9 +4,7 @@ from .base_schemas import * # noqa: F403
|
|||
# ruff: noqa: F403
|
||||
from .operations import (
|
||||
cluster,
|
||||
code_operations,
|
||||
equijoin,
|
||||
extract,
|
||||
filter,
|
||||
gather,
|
||||
map,
|
||||
|
|
@@ -28,10 +26,6 @@ GatherOp = gather.GatherOperation.schema
|
|||
UnnestOp = unnest.UnnestOperation.schema
|
||||
ClusterOp = cluster.ClusterOperation.schema
|
||||
SampleOp = sample.SampleOperation.schema
|
||||
CodeMapOp = code_operations.CodeMapOperation.schema
|
||||
CodeReduceOp = code_operations.CodeReduceOperation.schema
|
||||
CodeFilterOp = code_operations.CodeFilterOperation.schema
|
||||
ExtractOp = extract.ExtractOperation.schema
|
||||
|
||||
OpType = (
|
||||
MapOp
|
||||
|
|
@@ -43,10 +37,6 @@ OpType = (
|
|||
| SplitOp
|
||||
| GatherOp
|
||||
| UnnestOp
|
||||
| CodeMapOp
|
||||
| CodeReduceOp
|
||||
| CodeFilterOp
|
||||
| ExtractOp
|
||||
)
|
||||
|
||||
Dataset = dataset.Dataset.schema
|
||||
|
|
|
|||
152 docetl/utils.py
|
|
@@ -10,7 +10,6 @@ from jinja2 import Environment, meta
|
|||
from litellm import ModelResponse
|
||||
from litellm import completion_cost as lcc
|
||||
from lzstring import LZString
|
||||
from rich.prompt import Confirm
|
||||
|
||||
|
||||
class Decryptor:
|
||||
|
|
@@ -80,72 +79,6 @@ class CapturedOutput:
|
|||
self.optimizer_output[self.step][stage_type] = output
|
||||
|
||||
|
||||
def has_jinja_syntax(template_string: str) -> bool:
|
||||
"""
|
||||
Check if a string contains Jinja2 template syntax.
|
||||
|
||||
Args:
|
||||
template_string (str): The string to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the string contains Jinja2 syntax ({{ }} or {% %}), False otherwise.
|
||||
"""
|
||||
# Check for Jinja2 expression syntax {{ }}
|
||||
if re.search(r"\{\{.*?\}\}", template_string):
|
||||
return True
|
||||
# Check for Jinja2 statement syntax {% %}
|
||||
if re.search(r"\{%.*?%\}", template_string):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def prompt_user_for_non_jinja_confirmation(
|
||||
prompt_text: str, operation_name: str, prompt_field: str = "prompt"
|
||||
) -> bool:
|
||||
"""
|
||||
Prompt the user for confirmation when a prompt doesn't contain Jinja syntax.
|
||||
|
||||
Args:
|
||||
prompt_text (str): The prompt text that doesn't contain Jinja syntax.
|
||||
operation_name (str): The name of the operation.
|
||||
prompt_field (str): The name of the prompt field (e.g., "prompt", "batch_prompt").
|
||||
|
||||
Returns:
|
||||
bool: True if user confirms, False otherwise.
|
||||
"""
|
||||
from docetl.console import DOCETL_CONSOLE
|
||||
|
||||
console = DOCETL_CONSOLE
|
||||
console.print(
|
||||
f"\n[bold yellow]⚠ Warning:[/bold yellow] The '{prompt_field}' in operation '{operation_name}' "
|
||||
f"does not appear to be a Jinja2 template (no {{}} or {{% %}} syntax found)."
|
||||
)
|
||||
console.print(
|
||||
f"[dim]Prompt:[/dim] {prompt_text[:100]}{'...' if len(prompt_text) > 100 else ''}"
|
||||
)
|
||||
console.print(
|
||||
"\n[bold]We will automatically append the document(s) to your prompt during execution:[/bold]"
|
||||
)
|
||||
console.print(
|
||||
" • For single-document operations: 'Here is the document: {{ input }}'"
|
||||
)
|
||||
console.print(" • For reduce operations: 'Here are the documents: {{ inputs }}'")
|
||||
console.print()
|
||||
|
||||
try:
|
||||
return Confirm.ask(
|
||||
"Do you want to proceed with inserting all documents as-is?",
|
||||
default=True,
|
||||
console=console,
|
||||
)
|
||||
except Exception:
|
||||
# If Confirm fails (e.g., in non-interactive mode), default to True
|
||||
console.print(
|
||||
"[dim]Non-interactive mode: proceeding with document insertion[/dim]"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def extract_jinja_variables(template_string: str) -> list[str]:
|
||||
"""
|
||||
Extract variables from a Jinja2 template string.
|
||||
|
|
@@ -366,88 +299,3 @@ class classproperty:
|
|||
|
||||
def __get__(self, obj: Any | None, owner: type) -> Any:
|
||||
return self.f(owner)
|
||||
|
||||
|
||||
def extract_output_from_json(yaml_file_path, json_output_path=None):
|
||||
"""
|
||||
Extract output fields from JSON file based on the output schema defined in the YAML file.
|
||||
|
||||
If the last operation doesn't have an output schema, returns all keys from the output data.
|
||||
|
||||
Args:
|
||||
yaml_file_path (str): Path to the YAML configuration file
|
||||
json_output_path (str): Path to the JSON output file to extract from
|
||||
|
||||
Returns:
|
||||
List[Dict]: Extracted data containing only the fields specified in the output schema,
|
||||
or all fields if no output schema is defined
|
||||
"""
|
||||
# Load YAML configuration
|
||||
with open(yaml_file_path, "r") as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
if json_output_path is None:
|
||||
json_output_path = config.get("pipeline", {}).get("output", {}).get("path")
|
||||
if json_output_path is None:
|
||||
raise ValueError("No output path found in YAML file")
|
||||
|
||||
# Load JSON output data
|
||||
with open(json_output_path, "r") as f:
|
||||
output_data = json.load(f)
|
||||
|
||||
# Find the last operation in the pipeline
|
||||
pipeline = config.get("pipeline", {})
|
||||
steps = pipeline.get("steps", [])
|
||||
if not steps:
|
||||
raise ValueError("No pipeline steps found in YAML file")
|
||||
|
||||
# Get the last step and its operations
|
||||
last_step = steps[-1]
|
||||
last_step_operations = last_step.get("operations", [])
|
||||
if not last_step_operations:
|
||||
raise ValueError("No operations found in the last pipeline step")
|
||||
|
||||
# Get the name of the last operation in the last step
|
||||
last_operation_name = last_step_operations[-1]
|
||||
|
||||
# Find this operation in the operations list
|
||||
operations = config.get("operations", [])
|
||||
last_operation = None
|
||||
for op in operations:
|
||||
if op.get("name") == last_operation_name:
|
||||
last_operation = op
|
||||
break
|
||||
|
||||
if not last_operation:
|
||||
raise ValueError(
|
||||
f"Operation '{last_operation_name}' not found in operations list"
|
||||
)
|
||||
|
||||
output_schema = last_operation.get("output", {}).get("schema", {})
|
||||
if not output_schema:
|
||||
# If no output schema, return all keys from the output data
|
||||
if isinstance(output_data, list) and len(output_data) > 0:
|
||||
# Get all unique keys from all items
|
||||
all_keys = set()
|
||||
for item in output_data:
|
||||
if isinstance(item, dict):
|
||||
all_keys.update(item.keys())
|
||||
# Return all data with all keys
|
||||
return output_data
|
||||
else:
|
||||
# If output_data is not a list or is empty, return as-is
|
||||
return output_data if isinstance(output_data, list) else [output_data]
|
||||
|
||||
# Extract the field names from the schema
|
||||
schema_fields = list(output_schema.keys())
|
||||
|
||||
# Extract only the specified fields from each item in the output data
|
||||
extracted_data = []
|
||||
for item in output_data:
|
||||
extracted_item = {}
|
||||
for field in schema_fields:
|
||||
if field in item:
|
||||
extracted_item[field] = item[field]
|
||||
extracted_data.append(extracted_item)
|
||||
|
||||
return extracted_data
|
||||
|
|
|
|||
|
|
@@ -1,129 +0,0 @@
|
|||
"""
|
||||
Dataset utility functions for DocETL.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
def compute_dataset_stats(
|
||||
data: List[Dict[str, Any]], dataset_name: str = "data"
|
||||
) -> str:
|
||||
"""
|
||||
Compute statistics for a dataset by analyzing the actual data.
|
||||
|
||||
Args:
|
||||
data: List of data records
|
||||
dataset_name: Name of the dataset
|
||||
|
||||
Returns:
|
||||
str: Formatted dataset statistics
|
||||
"""
|
||||
if not data:
|
||||
return (
|
||||
f"Dataset: {dataset_name}\nType: file\nRecords loaded: 0\nNo data available"
|
||||
)
|
||||
|
||||
num_records = len(data)
|
||||
total_tokens = 0
|
||||
field_stats = {}
|
||||
|
||||
# Analyze each record
|
||||
for record in data:
|
||||
if isinstance(record, dict):
|
||||
for key, value in record.items():
|
||||
# Skip if key starts with "GT "
|
||||
if key.startswith("GT "):
|
||||
continue
|
||||
|
||||
if key not in field_stats:
|
||||
field_stats[key] = {
|
||||
"total_chars": 0,
|
||||
"count": 0,
|
||||
"type": type(value).__name__,
|
||||
}
|
||||
|
||||
if isinstance(value, str):
|
||||
char_count = len(value)
|
||||
field_stats[key]["total_chars"] += char_count
|
||||
field_stats[key]["count"] += 1
|
||||
total_tokens += (
|
||||
char_count / 4
|
||||
) # 4 characters per token approximation
|
||||
elif isinstance(value, (int, float)):
|
||||
# Numbers are typically short, estimate as ~5 characters
|
||||
field_stats[key]["total_chars"] += 5
|
||||
field_stats[key]["count"] += 1
|
||||
total_tokens += 1.25
|
||||
elif isinstance(value, list):
|
||||
# For lists, estimate based on string representation
|
||||
str_repr = str(value)
|
||||
char_count = len(str_repr)
|
||||
field_stats[key]["total_chars"] += char_count
|
||||
field_stats[key]["count"] += 1
|
||||
total_tokens += char_count / 4
|
||||
|
||||
# Format the output
|
||||
stats_lines = [
|
||||
f"Dataset: {dataset_name}",
|
||||
"Type: file",
|
||||
f"Records loaded: {num_records}",
|
||||
"Input schema:",
|
||||
]
|
||||
|
||||
for field, stats in field_stats.items():
|
||||
if stats["count"] > 0:
|
||||
avg_tokens = (stats["total_chars"] / stats["count"]) / 4
|
||||
field_type = "string" if stats["type"] in ["str"] else stats["type"]
|
||||
stats_lines.append(
|
||||
f" {field}: {field_type} (avg: {avg_tokens:.1f} tokens)"
|
||||
)
|
||||
|
||||
stats_lines.append(f"Total tokens: {int(total_tokens):,}")
|
||||
|
||||
return "\n ".join(stats_lines)
|
||||
|
||||
|
||||
def get_dataset_stats(yaml_path: str, dataset_name: str | None = None) -> str:
|
||||
"""
|
||||
Get dataset statistics by loading and analyzing the actual data from YAML config.
|
||||
|
||||
Args:
|
||||
yaml_path: Path to the YAML configuration file
|
||||
dataset_name: Optional name of the dataset (if not provided, uses first dataset in config)
|
||||
|
||||
Returns:
|
||||
str: Formatted dataset statistics
|
||||
"""
|
||||
# Load the YAML config to get the data path
|
||||
with open(yaml_path, "r") as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
# Extract dataset info from config
|
||||
datasets = config.get("datasets", {})
|
||||
if not datasets:
|
||||
return f"Dataset: {dataset_name or 'unknown'}\nType: file\nRecords loaded: 0\nNo datasets found in config"
|
||||
|
||||
# Get the first dataset (or specified dataset)
|
||||
if dataset_name and dataset_name in datasets:
|
||||
dataset_config = datasets[dataset_name]
|
||||
actual_dataset_name = dataset_name
|
||||
else:
|
||||
actual_dataset_name, dataset_config = next(iter(datasets.items()))
|
||||
|
||||
data_path = dataset_config.get("path")
|
||||
|
||||
if not data_path:
|
||||
return f"Dataset: {actual_dataset_name}\nType: file\nRecords loaded: 0\nNo data path found"
|
||||
|
||||
# Load the data
|
||||
try:
|
||||
with open(data_path, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
return compute_dataset_stats(data, actual_dataset_name)
|
||||
|
||||
except Exception as e:
|
||||
return f"Dataset: {actual_dataset_name}\nType: file\nRecords loaded: 0\nError loading data: {e}"
|
||||
|
|
@@ -1,441 +0,0 @@
|
|||
"""
|
||||
Evaluation utility functions for DocETL.
|
||||
"""
|
||||
|
||||
import importlib.util
|
||||
import inspect
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from docetl.console import DOCETL_CONSOLE
|
||||
|
||||
|
||||
def register_eval(
|
||||
func: Callable[[str, str], Dict[str, Any]]
|
||||
) -> Callable[[str, str], Dict[str, Any]]:
|
||||
"""
|
||||
Decorator to mark a function as a DocETL evaluation function.
|
||||
|
||||
The decorated function should take two arguments (dataset_file_path, results_file_path) and return
|
||||
a dictionary of evaluation metrics.
|
||||
|
||||
Example:
|
||||
@docetl.register_eval
|
||||
def evaluate_results(dataset_file_path: str, results_file_path: str) -> Dict[str, Any]:
|
||||
# ... evaluation logic ...
|
||||
return {"score": 0.95}
|
||||
"""
|
||||
func._docetl_eval = True
|
||||
return func
|
||||
|
||||
|
||||
def load_custom_evaluate_func(
|
||||
evaluation_file_path: str, dataset_file_path: str
|
||||
) -> Callable[[str], Dict[str, Any]]:
|
||||
"""
|
||||
Load a custom evaluation function from a Python file and wrap it to pass dataset_file_path.
|
||||
|
||||
The file should contain a function decorated with @docetl.register_eval.
|
||||
If multiple functions are decorated, an error is raised.
|
||||
|
||||
Args:
|
||||
evaluation_file_path: Path to a Python file containing a function decorated with @docetl.register_eval
|
||||
dataset_file_path: Path to the dataset file to pass to the evaluation function
|
||||
|
||||
Returns:
|
||||
callable: Wrapped evaluation function that takes (results_file_path: str) -> dict
|
||||
|
||||
Raises:
|
||||
ValueError: If the file doesn't exist, doesn't contain a decorated function, or has multiple decorated functions
|
||||
"""
|
||||
func_path = Path(evaluation_file_path)
|
||||
if not func_path.exists():
|
||||
raise ValueError(f"Evaluation file not found: {evaluation_file_path}")
|
||||
|
||||
# Use a unique module name based on the file path to avoid conflicts
|
||||
module_name = f"docetl_eval_{func_path.stem}_{hash(str(func_path))}"
|
||||
spec = importlib.util.spec_from_file_location(module_name, func_path)
|
||||
if spec is None or spec.loader is None:
|
||||
raise ValueError(f"Could not load module from: {evaluation_file_path}")
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
# Find all functions decorated with @docetl.register_eval
|
||||
eval_functions = []
|
||||
for name, obj in inspect.getmembers(module, inspect.isfunction):
|
||||
if hasattr(obj, "_docetl_eval") and obj._docetl_eval:
|
||||
eval_functions.append((name, obj))
|
||||
|
||||
if len(eval_functions) == 0:
|
||||
raise ValueError(
|
||||
f"Module {evaluation_file_path} must contain a function decorated with @docetl.register_eval. "
|
||||
f"Found functions: {[name for name, _ in inspect.getmembers(module, inspect.isfunction)]}"
|
||||
)
|
||||
|
||||
if len(eval_functions) > 1:
|
||||
function_names = [name for name, _ in eval_functions]
|
||||
raise ValueError(
|
||||
f"Module {evaluation_file_path} contains multiple functions decorated with @docetl.register_eval: {function_names}. "
|
||||
f"Only one evaluation function is allowed per file."
|
||||
)
|
||||
|
||||
# Wrap the function to pass dataset_file_path
|
||||
original_func = eval_functions[0][1]
|
||||
|
||||
def wrapped_func(results_file_path: str) -> Dict[str, Any]:
|
||||
return original_func(dataset_file_path, results_file_path)
|
||||
|
||||
return wrapped_func
|
||||
|
||||
|
||||
def _extract_node_data(item: Any) -> tuple[Optional[str], Dict[str, Any]]:
|
||||
"""Extract node data from either a node object or a dict/file path."""
|
||||
if hasattr(item, "result_path"):
|
||||
jf = item.result_path
|
||||
node_data = {
|
||||
"node_id": item.get_id(),
|
||||
"cost": item.cost,
|
||||
"visits": getattr(item, "visits", 0),
|
||||
"value": getattr(item, "value", 0),
|
||||
}
|
||||
else:
|
||||
jf = item.get("file_path") if isinstance(item, dict) else item
|
||||
node_data = {
|
||||
"node_id": (
|
||||
item.get("node_id", "unknown") if isinstance(item, dict) else "unknown"
|
||||
),
|
||||
"cost": item.get("cost", 0.0) if isinstance(item, dict) else 0.0,
|
||||
"visits": item.get("visits", 0) if isinstance(item, dict) else 0,
|
||||
"value": item.get("value", 0) if isinstance(item, dict) else 0,
|
||||
}
|
||||
return jf, node_data
|
||||
|
||||
|
||||
def _get_display_path(jf: str, output_path: Path) -> str:
|
||||
"""Get display path for a result file."""
|
||||
jp = Path(jf).resolve()
|
||||
op_root = output_path.resolve()
|
||||
if hasattr(jp, "is_relative_to") and jp.is_relative_to(op_root):
|
||||
return str(jp.relative_to(op_root))
|
||||
else:
|
||||
return jp.name
|
||||
|
||||
|
||||
def _add_frontier_info(result: Dict[str, Any], item: Any) -> Dict[str, Any]:
|
||||
"""Add frontier information if available."""
|
||||
if hasattr(item, "result_path"):
|
||||
result.update(
|
||||
{
|
||||
"moar_accuracy": getattr(item, "moar_accuracy", None),
|
||||
"on_frontier": getattr(item, "on_frontier", False),
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def identify_pareto_frontier(
|
||||
eval_results: List[Dict[str, Any]], metric_key: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Identify the Pareto frontier for evaluation results based on accuracy vs cost.
|
||||
|
||||
Args:
|
||||
eval_results: List of evaluation results with cost and accuracy metrics
|
||||
metric_key: Key to use for accuracy metric
|
||||
|
||||
Returns:
|
||||
Updated eval_results with 'on_frontier' field set to True/False
|
||||
|
||||
Raises:
|
||||
KeyError: If required metrics are missing from results
|
||||
"""
|
||||
if not eval_results:
|
||||
return eval_results
|
||||
|
||||
# Filter out results that don't have the required metrics
|
||||
valid_results = [r for r in eval_results if metric_key in r and "cost" in r]
|
||||
if not valid_results:
|
||||
DOCETL_CONSOLE.log(
|
||||
f"[yellow]⚠️ No valid results with {metric_key} and cost metrics[/yellow]"
|
||||
)
|
||||
return eval_results
|
||||
|
||||
# Validate that all results have the required metrics
|
||||
for r in valid_results:
|
||||
if metric_key not in r:
|
||||
raise KeyError(
|
||||
f"Missing required accuracy metric '{metric_key}' in evaluation result. "
|
||||
f"Available keys: {list(r.keys())}. "
|
||||
f"This metric is required for Pareto frontier identification."
|
||||
)
|
||||
if "cost" not in r:
|
||||
raise KeyError(
|
||||
f"Missing required 'cost' metric in evaluation result. "
|
||||
f"Available keys: {list(r.keys())}"
|
||||
)
|
||||
|
||||
# Sort by cost (ascending) and accuracy (descending for maximization)
|
||||
valid_results.sort(key=lambda x: (x["cost"], -x[metric_key]))
|
||||
|
||||
# Identify Pareto frontier: points that are not dominated by any other point
|
||||
frontier = []
|
||||
for i, candidate in enumerate(valid_results):
|
||||
is_dominated = False
|
||||
for j, other in enumerate(valid_results):
|
||||
if i == j:
|
||||
continue
|
||||
# Check if other point dominates candidate
|
||||
# Dominated if: other has lower cost AND higher/equal accuracy, OR same cost AND higher accuracy
|
||||
if (
|
||||
other["cost"] < candidate["cost"]
|
||||
and other[metric_key] >= candidate[metric_key]
|
||||
) or (
|
||||
other["cost"] == candidate["cost"]
|
||||
and other[metric_key] > candidate[metric_key]
|
||||
):
|
||||
is_dominated = True
|
||||
break
|
||||
|
||||
if not is_dominated:
|
||||
frontier.append(candidate)
|
||||
|
||||
# Mark all results with frontier status
|
||||
frontier_set = set(id(f) for f in frontier)
|
||||
for r in eval_results:
|
||||
r["on_frontier"] = id(r) in frontier_set
|
||||
|
||||
return eval_results
|
||||
|
||||
|
||||
def print_pareto_frontier_summary(
|
||||
eval_results: List[Dict[str, Any]],
|
||||
metric_key: str,
|
||||
dataset_name: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Print a summary of the Pareto frontier points.
|
||||
|
||||
Args:
|
||||
eval_results: List of evaluation results with 'on_frontier' field
|
||||
metric_key: Key to use for accuracy metric
|
||||
dataset_name: Optional dataset name for display purposes
|
||||
|
||||
Raises:
|
||||
KeyError: If required metrics are missing from frontier points
|
||||
"""
|
||||
frontier_points = [r for r in eval_results if r.get("on_frontier", False)]
|
||||
|
||||
if not frontier_points:
|
||||
dataset_str = f" for {dataset_name}" if dataset_name else ""
|
||||
DOCETL_CONSOLE.log(
|
||||
f"[yellow]📊 No Pareto frontier points found{dataset_str}[/yellow]"
|
||||
)
|
||||
return
|
||||
|
||||
# Sort frontier points by cost for better display
|
||||
frontier_points.sort(key=lambda x: x["cost"])
|
||||
|
||||
dataset_str = f" for {dataset_name.upper()}" if dataset_name else ""
|
||||
DOCETL_CONSOLE.log(f"\n🏆 Pareto Frontier Summary{dataset_str}:")
|
||||
DOCETL_CONSOLE.log("=" * 60)
|
||||
DOCETL_CONSOLE.log(
|
||||
f"{'Rank':<4} {'Cost ($)':<10} {metric_key.upper():<15} {'File':<30}"
|
||||
)
|
||||
DOCETL_CONSOLE.log("=" * 60)
|
||||
|
||||
for i, point in enumerate(frontier_points, 1):
|
||||
if "cost" not in point:
|
||||
raise KeyError(
|
||||
f"Missing required 'cost' metric in frontier point. "
|
||||
f"Available keys: {list(point.keys())}"
|
||||
)
|
||||
if metric_key not in point:
|
||||
raise KeyError(
|
||||
f"Missing required accuracy metric '{metric_key}' in frontier point. "
|
||||
f"Available keys: {list(point.keys())}. "
|
||||
f"This metric is required for Pareto frontier summary."
|
||||
)
|
||||
cost = point["cost"]
|
||||
accuracy = point[metric_key]
|
||||
file_name = point.get("file", "unknown")
|
||||
DOCETL_CONSOLE.log(f"{i:<4} ${cost:<9.4f} {accuracy:<15.4f} {file_name:<30}")
|
||||
|
||||
DOCETL_CONSOLE.log("=" * 60)
|
||||
DOCETL_CONSOLE.log(f"Total frontier points: {len(frontier_points)}")
|
||||
|
||||
|
||||
def save_pareto_frontier_results(
|
||||
eval_results: List[Dict[str, Any]],
|
||||
output_path: Path,
|
||||
metric_key: str,
|
||||
dataset_name: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Save Pareto frontier results to a separate JSON file for analysis.
|
||||
|
||||
Args:
|
||||
eval_results: List of evaluation results with 'on_frontier' field
|
||||
output_path: Output directory path
|
||||
metric_key: Key to use for accuracy metric
|
||||
dataset_name: Optional dataset name
|
||||
|
||||
Raises:
|
||||
KeyError: If required metrics are missing from frontier points
|
||||
"""
|
||||
frontier_points = [r for r in eval_results if r.get("on_frontier", False)]
|
||||
|
||||
if not frontier_points:
|
||||
return
|
||||
|
||||
# Sort frontier points by cost
|
||||
frontier_points.sort(key=lambda x: x["cost"])
|
||||
|
||||
# Add rank and accuracy metric information
|
||||
for i, point in enumerate(frontier_points):
|
||||
point["rank"] = i + 1
|
||||
point["accuracy_metric"] = metric_key
|
||||
|
||||
# Calculate cost-effectiveness ratios between consecutive points
|
||||
cost_effectiveness_analysis = []
|
||||
for i in range(len(frontier_points) - 1):
|
||||
curr = frontier_points[i]
|
||||
next_point = frontier_points[i + 1]
|
||||
|
||||
if "cost" not in curr or "cost" not in next_point:
|
||||
raise KeyError(
|
||||
f"Missing required 'cost' metric in frontier point. "
|
||||
f"Available keys: {list(curr.keys() if 'cost' not in curr else next_point.keys())}"
|
||||
)
|
||||
if metric_key not in curr or metric_key not in next_point:
|
||||
raise KeyError(
|
||||
f"Missing required accuracy metric '{metric_key}' in frontier point. "
|
||||
f"Available keys: {list(curr.keys() if metric_key not in curr else next_point.keys())}. "
|
||||
f"This metric is required for cost-effectiveness analysis."
|
||||
)
|
||||
cost_diff = next_point["cost"] - curr["cost"]
|
||||
accuracy_diff = next_point[metric_key] - curr[metric_key]
|
||||
|
||||
if cost_diff > 0 and accuracy_diff > 0:
|
||||
cost_effectiveness = cost_diff / accuracy_diff
|
||||
cost_effectiveness_analysis.append(
|
||||
{
|
||||
"from_file": curr["file"],
|
||||
"to_file": next_point["file"],
|
||||
"cost_increase": cost_diff,
|
||||
"accuracy_increase": accuracy_diff,
|
||||
"cost_per_unit_improvement": cost_effectiveness,
|
||||
}
|
||||
)
|
||||
|
||||
# Create frontier summary
|
||||
frontier_summary = {
|
||||
"accuracy_metric": metric_key,
|
||||
"total_frontier_points": len(frontier_points),
|
||||
"frontier_points": frontier_points,
|
||||
"cost_effectiveness_analysis": cost_effectiveness_analysis,
|
||||
}
|
||||
if dataset_name:
|
||||
frontier_summary["dataset"] = dataset_name
|
||||
|
||||
# Save to file
|
||||
frontier_file = output_path / "pareto_frontier.json"
|
||||
with open(frontier_file, "w") as f:
|
||||
json.dump(frontier_summary, f, indent=2)
|
||||
|
||||
DOCETL_CONSOLE.log(f"📊 Pareto frontier results written to {frontier_file}")
|
||||
|
||||
|
||||
def run_evaluation(
|
||||
nodes_or_files: List[Any],
|
||||
evaluate_func: Callable[[str], Dict[str, Any]],
|
||||
metric_key: str,
|
||||
output_path: Path,
|
||||
dataset_name: Optional[str] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Run evaluation on a set of nodes or files using a custom evaluation function.
|
||||
|
||||
This is a general-purpose evaluation function that does not depend on
|
||||
experiment-specific datasets. It processes nodes/files, extracts metrics,
|
||||
identifies the Pareto frontier, and saves results.
|
||||
|
||||
Args:
|
||||
nodes_or_files: List of node objects (with result_path) or file paths
|
||||
evaluate_func: Evaluation function (results_file_path: str) -> dict
|
||||
metric_key: Key to extract from evaluation results for accuracy metric
|
||||
output_path: Path to save evaluation results
|
||||
dataset_name: Optional dataset name for display purposes
|
||||
|
||||
Returns:
|
||||
List of evaluation results with 'on_frontier' field set
|
||||
|
||||
Raises:
|
||||
ValueError: If metric_key is not provided
|
||||
KeyError: If required metrics are missing from evaluation results
|
||||
"""
|
||||
if not metric_key:
|
||||
raise ValueError("metric_key must be provided")
|
||||
|
||||
eval_results = []
|
||||
|
||||
# Process evaluation items
|
||||
for item in nodes_or_files:
|
||||
jf, node_data = _extract_node_data(item)
|
||||
if jf is None or not Path(jf).exists():
|
||||
continue
|
||||
|
||||
try:
|
||||
metrics = evaluate_func(jf)
|
||||
display_path = _get_display_path(jf, output_path)
|
||||
|
||||
# Extract the custom metric
|
||||
accuracy_value = metrics.get(metric_key)
|
||||
if accuracy_value is None:
|
||||
DOCETL_CONSOLE.log(
|
||||
f"[yellow]⚠️ Warning: Metric key '{metric_key}' not found in evaluation results for {jf}. "
|
||||
f"Available keys: {list(metrics.keys())}[/yellow]"
|
||||
)
|
||||
# Try to find a numeric value as fallback
|
||||
accuracy_value = next(
|
||||
(v for v in metrics.values() if isinstance(v, (int, float))), None
|
||||
)
|
||||
if accuracy_value is None:
|
||||
DOCETL_CONSOLE.log(
|
||||
f"[red]❌ Skipping {jf}: No valid accuracy metric found[/red]"
|
||||
)
|
||||
continue
|
||||
|
||||
result = {
|
||||
"file": display_path,
|
||||
metric_key: accuracy_value,
|
||||
**metrics, # Include all metrics from custom function
|
||||
**node_data,
|
||||
}
|
||||
result = _add_frontier_info(result, item)
|
||||
eval_results.append(result)
|
||||
except Exception as e:
|
||||
DOCETL_CONSOLE.log(f"[red] ⚠️ Evaluation failed for {jf}: {e}[/red]")
|
||||
|
||||
# Identify Pareto frontier
|
||||
if eval_results:
|
||||
DOCETL_CONSOLE.log("\n🔍 Identifying Pareto frontier...")
|
||||
eval_results = identify_pareto_frontier(eval_results, metric_key)
|
||||
|
||||
# Print Pareto frontier summary
|
||||
print_pareto_frontier_summary(eval_results, metric_key, dataset_name)
|
||||
|
||||
# Save Pareto frontier results to separate file
|
||||
save_pareto_frontier_results(
|
||||
eval_results, output_path, metric_key, dataset_name
|
||||
)
|
||||
|
||||
# Save evaluation results
|
||||
if eval_results:
|
||||
eval_out_file = output_path / "evaluation_metrics.json"
|
||||
with open(eval_out_file, "w") as f:
|
||||
json.dump(eval_results, f, indent=2)
|
||||
DOCETL_CONSOLE.log(f"📊 Evaluation results written to {eval_out_file}")
|
||||
|
||||
return eval_results
|
||||
|
|
@@ -101,63 +101,6 @@
|
|||
ignore_init_summary: false
|
||||
trim_doctest_flags: true
|
||||
|
||||
::: docetl.schemas.CodeMapOp
|
||||
options:
|
||||
show_root_heading: true
|
||||
heading_level: 3
|
||||
show_if_no_docstring: false
|
||||
docstring_options:
|
||||
ignore_init_summary: false
|
||||
trim_doctest_flags: true
|
||||
|
||||
::: docetl.schemas.CodeReduceOp
|
||||
options:
|
||||
show_root_heading: true
|
||||
heading_level: 3
|
||||
show_if_no_docstring: false
|
||||
docstring_options:
|
||||
ignore_init_summary: false
|
||||
trim_doctest_flags: true
|
||||
|
||||
::: docetl.schemas.CodeFilterOp
|
||||
options:
|
||||
show_root_heading: true
|
||||
heading_level: 3
|
||||
show_if_no_docstring: false
|
||||
docstring_options:
|
||||
ignore_init_summary: false
|
||||
trim_doctest_flags: true
|
||||
|
||||
::: docetl.schemas.ExtractOp
|
||||
options:
|
||||
show_root_heading: true
|
||||
heading_level: 3
|
||||
show_if_no_docstring: false
|
||||
docstring_options:
|
||||
ignore_init_summary: false
|
||||
trim_doctest_flags: true
|
||||
|
||||
### Callable support for code ops

Code operations (`code_map`, `code_reduce`, `code_filter`) accept either a string containing Python code that defines a `transform` function, or a regular Python function. When you pass a function, it does not need to be named `transform`; DocETL binds it internally.

Example:

```python
from docetl.api import CodeMapOp

def my_map(doc: dict) -> dict:
    return {"double": doc["x"] * 2}

code_map = CodeMapOp(name="double_x", type="code_map", code=my_map)
```

- Map: `fn(doc: dict) -> dict`
- Filter: `fn(doc: dict) -> bool`
- Reduce: `fn(group: list[dict]) -> dict`

See also: [Code Operators](../operators/code.md), [Extract Operator](../operators/extract.md)

## Dataset and Pipeline

::: docetl.schemas.Dataset
@@ -56,31 +56,13 @@ The DocETL optimizer operates using the following mechanism:

### Using the Optimizer

DocETL provides two optimizer options:

#### MOAR Optimizer (Recommended)

The MOAR optimizer uses Monte Carlo Tree Search to find Pareto-optimal solutions balancing accuracy and cost:
You can invoke the optimizer using the following command:

```bash
docetl build your_pipeline.yaml --optimizer moar
docetl build your_pipeline.yaml
```

See the [MOAR Optimizer Guide](../optimization/moar.md) for detailed instructions.

#### V1 Optimizer (Deprecated)

!!! warning "Deprecated"
    The V1 optimizer is deprecated and no longer recommended. Use MOAR instead.

The V1 optimizer uses a greedy approach with validation. It's still available for backward compatibility:

```bash
docetl build your_pipeline.yaml --optimizer v1
```

!!! warning "Not Recommended"
    The V1 optimizer should not be used for new projects. Use MOAR instead.
This command will save the optimized pipeline to `your_pipeline_opt.yaml`. Note that the optimizer will only rewrite operators where you've set `optimize: true`. Leaving this field unset will skip optimization for that operator.
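For instance, to make a single operation eligible for optimization you would flag it like this; the operation name and type below are illustrative, only the `optimize: true` flag is the point of the sketch.

```yaml
operations:
  - name: extract_topics      # illustrative operation
    type: map
    optimize: true            # only operators flagged like this are rewritten by the optimizer
    # ... prompt and output schema ...
```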
<!-- ### Automatic Entity Resolution
|
||||
|
||||
|
|
|
|||
|
|
@@ -34,9 +34,6 @@ To get started with DocETL:

2. Define your pipeline in a YAML file. Want to use an LLM like ChatGPT or Claude to help you write your pipeline? See [docetl.org/llms.txt](https://docetl.org/llms.txt) for a big prompt you can copy-paste into ChatGPT or Claude before describing your task.
3. Run your pipeline using the DocETL command-line interface

!!! tip "Fastest Way: Claude Code"
    Clone this repo and run `claude` to use the built-in DocETL skill. Just describe your data processing task and Claude will create and run the pipeline for you. See [Quick Start (Claude Code)](quickstart-claude-code.md) for details.

## 🏛️ Project Origin

DocETL was created by members of the EPIC Data Lab and Data Systems and Foundations group at UC Berkeley. The EPIC (Effective Programming, Interaction, and Computation with Data) Lab focuses on developing low-code and no-code interfaces for data work, powered by next-generation predictive programming techniques. DocETL is one of the projects that emerged from our research efforts to streamline complex document processing tasks.
|||
|
|
@@ -55,6 +55,8 @@ If you want to use only the parsing extra:
uv sync --extra parsing
```

This will create a virtual environment and install all the required dependencies.

4. Set up your OpenAI API key:

Create a `.env` file in the project root and add your OpenAI API key:
|
|
|
|||
|
|
@@ -91,10 +91,3 @@ The transform function should return True for items to keep and False for items

| reduce_key | Key(s) to group by (code_reduce only) | "_all" |
| pass_through | Pass through unmodified keys from first item in group (code_reduce only) | false |
| concurrent_thread_count | The number of threads to start | the number of logical CPU cores (os.cpu_count()) |
| limit | Maximum number of outputs to produce before stopping | Processes all data |

The `limit` parameter behaves differently for each operation type (a configuration sketch follows the list):

- **code_map**: Limits the number of input documents to process
- **code_filter**: Limits the number of documents that pass the filter (outputs)
- **code_reduce**: Limits the number of groups to reduce, selecting the smallest groups first (by document count)
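To make the `code_reduce` behavior concrete, here is a minimal sketch; the operation name, grouping key, and transform body are illustrative rather than taken from the repository.

```yaml
operations:
  - name: count_docs_per_department   # illustrative name
    type: code_reduce
    reduce_key: department            # illustrative grouping key
    limit: 10                         # only the 10 smallest groups are reduced
    code: |
      def transform(group: list[dict]) -> dict:
          return {"department": group[0]["department"], "num_docs": len(group)}
```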
|
|||
|
|
@@ -34,9 +34,9 @@ This Equijoin operation matches job candidates to job postings:

The prompt template uses Jinja2 syntax, allowing you to reference input fields directly (e.g., `left.skills`). You can reference the left and right documents using `left` and `right` respectively.

!!! info "Automatic Blocking"
!!! warning "Performance Consideration"

    If you don't specify any blocking configuration (`blocking_threshold`, `blocking_conditions`, or `limit_comparisons`), the Equijoin operation will automatically compute an optimal embedding-based blocking threshold at runtime. It samples pairs from your data, runs LLM comparisons on the sample, and finds a threshold that achieves 95% recall by default. You can adjust this with the `blocking_target_recall` parameter.
    For large datasets, running comparisons with an LLM can be time-consuming. It's recommended to optimize your pipeline using `docetl build pipeline.yaml` to generate efficient blocking rules for the operation.

## Blocking
|
||||
|
|
@@ -95,19 +95,10 @@ A full Equijoin step combining both ideas might look like:

Equijoin shares many parameters with the Resolve operation. For a detailed list of required and optional parameters, please see the [Parameters section in the Resolve operation documentation](resolve.md#required-parameters).

### Equijoin-Specific Parameters

| Parameter | Description | Default |
| ------------------------- | --------------------------------------------------------------------- | --------------------------- |
| `limits` | Maximum matches for each left/right item: `{"left": n, "right": m}` | No limit |
| `blocking_keys` | Keys for embedding blocking: `{"left": [...], "right": [...]}` | All keys from each dataset |
| `blocking_threshold` | Embedding similarity threshold for considering pairs | Auto-computed if not set |
| `blocking_target_recall` | Target recall when auto-computing blocking threshold (0.0 to 1.0) | 0.95 |

Key differences from Resolve:
Key differences for Equijoin include:

- `resolution_prompt` is not used in Equijoin.
- `blocking_keys` uses a dict with `left` and `right` keys instead of a simple list.
- `limits` parameter is specific to Equijoin, allowing you to set maximum matches for each left and right item (see the sketch below).
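For illustration, an Equijoin configuration combining these fields might look like the following sketch. The dataset fields, limits, and prompt text are hypothetical, and the prompt field name is assumed to follow the comparison parameters Equijoin shares with Resolve.

```yaml
- name: match_candidates_to_jobs
  type: equijoin
  blocking_keys:
    left: [skills]          # hypothetical field on the left dataset
    right: [requirements]   # hypothetical field on the right dataset
  blocking_target_recall: 0.95
  limits:
    left: 3                 # each candidate matches at most 3 postings
    right: 10               # each posting matches at most 10 candidates
  comparison_prompt: |
    Candidate skills: {{ left.skills }}
    Job requirements: {{ right.requirements }}
    Do these match? Answer True or False.
```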
## Incorporating Into a Pipeline
|
||||
|
||||
|
|
|
|||
|
|
@@ -140,11 +140,6 @@ This strategy asks the LLM to generate regex patterns matching the desired conte

| `timeout` | Timeout for LLM calls in seconds | 120 |
| `skip_on_error` | Continue processing if errors occur | false |
| `litellm_completion_kwargs` | Additional parameters for LiteLLM calls | {} |
| `limit` | Maximum number of documents to extract from before stopping | Processes all data |
| `retriever` | Name of a retriever to use for RAG. See [Retrievers](../retrievers.md). | None |
| `save_retriever_output` | If true, saves the retrieved context to `_<operation_name>_retrieved_context` in the output. | False |

When `limit` is set, Extract only reformats and submits the first _N_ documents. This is handy when the upstream dataset is large and you want to cap cost while previewing results.
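A hedged sketch of capping an Extract operation follows; the operation name is illustrative and the extraction prompt/output fields are omitted, since only the parameters listed in the table above are assumed here.

```yaml
- name: extract_key_findings       # illustrative
  type: extract
  timeout: 120
  skip_on_error: true
  limit: 50                        # only the first 50 documents are submitted
  # ... prompt/output configuration for the extraction itself ...
```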
## Best Practices
|
||||
|
||||
|
|
|
|||
|
|
@@ -83,10 +83,6 @@ This example demonstrates how the Filter operation distinguishes between high-im

See [map optional parameters](./map.md#optional-parameters) for additional configuration options, including `batch_prompt` and `max_batch_size`.

### Limiting filtered outputs

`limit` behaves slightly differently for filter operations than for map operations. Because filter drops documents whose predicate evaluates to `false`, the limit counts only the documents that would be retained (i.e., the ones whose boolean output is `true`). DocETL will continue evaluating additional inputs until it has collected `limit` passing documents and then stop scheduling further LLM calls. This ensures you can request "the first N matches" without paying to score the entire dataset.
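As an illustration of this counting behavior, the sketch below asks for the first 25 documents that pass the filter; the operation name, field, and prompt are hypothetical.

```yaml
- name: keep_high_impact_feedback  # hypothetical name
  type: filter
  prompt: |
    Is the following feedback high-impact? Answer true or false.
    {{ input.feedback }}
  # ... boolean output schema as described in the Filter docs ...
  limit: 25                        # stop once 25 documents have passed the filter
```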
!!! info "Validation"

    For more details on validation techniques and implementation, see [operators](../concepts/operators.md#validation).
|
|
|
|||
|
|
@@ -140,7 +140,6 @@ This example demonstrates how the Map operation can transform long, unstructured

| `optimize` | Flag to enable operation optimization | `True` |
| `recursively_optimize` | Flag to enable recursive optimization of operators synthesized as part of rewrite rules | `false` |
| `sample` | Number of samples to use for the operation | Processes all data |
| `limit` | Maximum number of outputs to produce before stopping | Processes all data |
| `tools` | List of tool definitions for LLM use | None |
| `validate` | List of Python expressions to validate the output | None |
| `flush_partial_results` | Write results of individual batches of map operation to disk for faster inspection | False |
|
|
@@ -156,15 +155,9 @@ This example demonstrates how the Map operation can transform long, unstructured

| `pdf_url_key` | If specified, the key in the input that contains the URL of the PDF to process. | None |
| `calibrate` | Improve consistency across documents by using sample data as reference anchors. | False |
| `num_calibration_docs` | Number of documents to sample and generate outputs for during calibration. | 10 |
| `retriever` | Name of a retriever to use for RAG. See [Retrievers](../retrievers.md). | None |
| `save_retriever_output` | If true, saves the retrieved context to `_<operation_name>_retrieved_context` in the output. | False |

Note: If `drop_keys` is specified, `prompt` and `output` become optional parameters.

### Limiting execution

Set `limit` when you only need the first _N_ map results or want to cap LLM spend. The operation slices the processed dataset to the first `limit` entries and also stops scheduling new prompts once that many outputs have been produced, even if a prompt returns multiple records. Filter operations inherit this behavior but redefine the count so the limit only applies to records whose filter predicate evaluates to `true` (see [Filter](./filter.md#optional-parameters)).
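A minimal sketch of capping a map operation follows; the operation name, prompt, and output field are illustrative.

```yaml
- name: summarize_ticket           # illustrative
  type: map
  prompt: |
    Summarize this support ticket in one sentence: {{ input.text }}
  output:
    schema:
      summary: string
  limit: 100                       # stop scheduling prompts after 100 outputs
```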
!!! info "Validation and Gleaning"
|
||||
|
||||
|
|
|
|||
|
|
@@ -52,7 +52,6 @@ This Reduce operation processes customer feedback grouped by department:

| Parameter | Description | Default |
| ------------------------- | ------------------------------------------------------------------------- | --------------------------- |
| `sample` | Number of samples to use for the operation | None |
| `limit` | Maximum number of groups to process before stopping | All groups |
| `synthesize_resolve` | If false, won't synthesize a resolve operation between map and reduce | true |
| `model` | The language model to use | Falls back to default_model |
| `input` | Specifies the schema or keys to subselect from each item | All keys from input items |
|
|
@@ -67,12 +66,6 @@ This Reduce operation processes customer feedback grouped by department:

| `max_retries_per_timeout` | Maximum number of retries per timeout | 2 |
| `litellm_completion_kwargs` | Additional parameters to pass to LiteLLM completion calls. | {} |
| `bypass_cache` | If true, bypass the cache for this operation. | False |
| `retriever` | Name of a retriever to use for RAG. See [Retrievers](../retrievers.md). | None |
| `save_retriever_output` | If true, saves the retrieved context to `_<operation_name>_retrieved_context` in the output. | False |

### Limiting group processing

Set `limit` to short-circuit the reduce phase after _N_ groups have been aggregated. When `limit` is set, groups are sorted by size (smallest first) and only the _N_ smallest groups are processed. This is useful for previewing results or capping LLM usage while minimizing cost by processing groups with fewer documents. Groups beyond the limit are never scheduled, so you avoid extra fold/merge calls. If a grouped reduce returns more than one record per group, the final output list is truncated to `limit`.
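For example, a reduce capped at the 5 smallest groups might be sketched as follows; the grouping key, prompt, and output field are illustrative.

```yaml
- name: summarize_feedback_by_department  # illustrative
  type: reduce
  reduce_key: department
  prompt: |
    Summarize the following feedback items:
    {% for item in inputs %}
    - {{ item.feedback }}
    {% endfor %}
  output:
    schema:
      department_summary: string
  limit: 5                                # only the 5 smallest groups are aggregated
```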
## Advanced Features
|
||||
|
||||
|
|
|
|||
|
|
@@ -44,9 +44,9 @@ This Resolve operation processes patient names to identify and standardize dupli

Note: The prompt templates use Jinja2 syntax, allowing you to reference input fields directly (e.g., `input1.patient_name`).

!!! info "Automatic Blocking"
!!! warning "Performance Consideration"

    If you don't specify any blocking configuration (`blocking_threshold`, `blocking_conditions`, or `limit_comparisons`), the Resolve operation will automatically compute an optimal embedding-based blocking threshold at runtime. It samples pairs from your data, runs LLM comparisons on the sample, and finds a threshold that achieves 95% recall by default. You can adjust this with the `blocking_target_recall` parameter.
    You should not run this operation as-is unless your dataset is small! Running O(n^2) comparisons with an LLM can be extremely time-consuming for large datasets. Instead, optimize your pipeline first using `docetl build pipeline.yaml` and run the optimized version, which will generate efficient blocking rules for the operation. Make sure you've set `optimize: true` in your resolve operation config.

## Blocking
|
||||
|
|
@@ -132,8 +132,7 @@ After determining eligible pairs for comparison, the Resolve operation uses a Un

| `resolution_model` | The language model to use for reducing matched entries | Falls back to `default_model` |
| `comparison_model` | The language model to use for comparing potential matches | Falls back to `default_model` |
| `blocking_keys` | List of keys to use for initial blocking | All keys in the input data |
| `blocking_threshold` | Embedding similarity threshold for considering entries as potential matches | Auto-computed if not set |
| `blocking_target_recall` | Target recall when auto-computing blocking threshold (0.0 to 1.0) | 0.95 |
| `blocking_threshold` | Embedding similarity threshold for considering entries as potential matches | None |
| `blocking_conditions` | List of conditions for initial blocking | [] |
| `input` | Specifies the schema or keys to subselect from each item to pass into the prompts | All keys from input items |
| `embedding_batch_size` | The number of entries to send to the embedding model at a time | 1000 |
|
|
@@ -141,9 +140,9 @@ After determining eligible pairs for comparison, the Resolve operation uses a Un

| `limit_comparisons` | Maximum number of comparisons to perform | None |
| `timeout` | Timeout for each LLM call in seconds | 120 |
| `max_retries_per_timeout` | Maximum number of retries per timeout | 2 |
| `sample` | Number of samples to use for the operation | None |
| `litellm_completion_kwargs` | Additional parameters to pass to LiteLLM completion calls. | {} |
| `bypass_cache` | If true, bypass the cache for this operation. | False |
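Putting the blocking-related parameters together, a Resolve configuration might be sketched like this; the operation name, keys, threshold, and prompt text are illustrative, and the prompt templates are assumptions rather than copies of the library's examples.

```yaml
- name: standardize_patient_names  # illustrative
  type: resolve
  optimize: true
  blocking_keys:
    - patient_name
  blocking_threshold: 0.82         # or omit to have a threshold auto-computed
  blocking_target_recall: 0.95
  embedding_batch_size: 1000
  comparison_prompt: |
    Are {{ input1.patient_name }} and {{ input2.patient_name }} the same patient? Answer True or False.
  resolution_prompt: |
    Choose a canonical patient name for these records:
    {% for entry in inputs %}
    - {{ entry.patient_name }}
    {% endfor %}
  # ... output schema for the resolved fields ...
```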
## Best Practices
|
||||
|
||||
|
|
|
|||
|
|
@@ -1,34 +0,0 @@
# MOAR Optimizer

The MOAR (Multi-Objective Agentic Rewrites) optimizer explores different ways to optimize your pipeline, finding solutions that balance accuracy and cost.

## What is MOAR?

When optimizing pipelines, you trade off cost and accuracy. MOAR explores many different pipeline configurations (like changing models, adding validation steps, combining operations, etc.) and evaluates each one to find the best trade-offs. It returns a frontier of plans that balance cost and accuracy, giving you multiple optimized options to choose from based on your budget and accuracy requirements.

## Quick Navigation

- **[Getting Started](moar/getting-started.md)** - Step-by-step guide to run your first MOAR optimization
- **[Configuration](moar/configuration.md)** - Complete reference for all configuration options
- **[Evaluation Functions](moar/evaluation.md)** - How to write and use evaluation functions
- **[Understanding Results](moar/results.md)** - What MOAR outputs and how to interpret it
- **[Examples](moar/examples.md)** - Complete working examples
- **[Troubleshooting](moar/troubleshooting.md)** - Common issues and solutions

## When to Use MOAR

!!! success "Good for"
    - Finding cost-accuracy trade-offs across different models
    - When you want multiple optimization options to choose from
    - Custom evaluation metrics specific to your use case
    - Exploring different pipeline configurations automatically

## Basic Workflow

1. **Create your pipeline YAML** - Define your DocETL pipeline
2. **Write an evaluation function** - Create a Python function to measure accuracy
3. **Configure MOAR** - Set up `optimizer_config` in your YAML
4. **Run optimization** - Execute `docetl build pipeline.yaml --optimizer moar`
5. **Review results** - Choose from the cost-accuracy frontier

Ready to get started? Head to the [Getting Started guide](moar/getting-started.md).
|
|
@ -1,116 +0,0 @@
|
|||
# MOAR Configuration Reference
|
||||
|
||||
Complete reference for all MOAR configuration options.
|
||||
|
||||
## Required Fields
|
||||
|
||||
All fields in `optimizer_config` are required (no defaults):
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `type` | `str` | Must be `"moar"` |
|
||||
| `save_dir` | `str` | Directory where MOAR results will be saved |
|
||||
| `available_models` | `list[str]` | List of LiteLLM model names to explore (e.g., `["gpt-4o-mini", "gpt-4o"]`). Make sure your API keys are set in your environment for these models. |
|
||||
| `evaluation_file` | `str` | Path to Python file containing `@register_eval` decorated function |
|
||||
| `metric_key` | `str` | Key in evaluation results dictionary to use as accuracy metric |
|
||||
| `max_iterations` | `int` | Maximum number of MOARSearch iterations to run |
|
||||
| `rewrite_agent_model` | `str` | LLM model to use for directive instantiation during search |
|
||||
|
||||
!!! warning "All Fields Required"
|
||||
MOAR will error if any required field is missing. There are no defaults.
|
||||
|
||||
## Optional Fields
|
||||
|
||||
| Field | Type | Default | Description |
|
||||
|-------|------|---------|-------------|
|
||||
| `dataset_path` | `str` | Inferred from `datasets` | Path to dataset file to use for optimization. Use a sample/hold-out dataset to avoid optimizing on your test set. |
|
||||
| `exploration_weight` | `float` | `1.414` | UCB exploration constant (higher = more exploration) |
|
||||
| `build_first_layer` | `bool` | `False` | Whether to build initial model-specific nodes |
|
||||
| `ground_truth_path` | `str` | `None` | Path to ground truth file (for evaluation) |
|
||||
|
||||
## Dataset Path
|
||||
|
||||
### Automatic Inference
|
||||
|
||||
If `dataset_path` is not specified, MOAR will automatically infer it from the `datasets` section of your YAML:
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
transcripts:
|
||||
path: data/full_dataset.json # This will be used if dataset_path not specified
|
||||
type: file
|
||||
|
||||
optimizer_config:
|
||||
# dataset_path not specified - will use data/full_dataset.json
|
||||
# ... other config ...
|
||||
```
|
||||
|
||||
### Using Sample/Hold-Out Datasets
|
||||
|
||||
!!! tip "Best Practice"
|
||||
Use a sample or hold-out dataset for optimization to avoid optimizing on your test set.
|
||||
|
||||
```yaml
|
||||
optimizer_config:
|
||||
dataset_path: data/sample_dataset.json # Use sample/hold-out for optimization
|
||||
# ... other config ...
|
||||
|
||||
datasets:
|
||||
transcripts:
|
||||
path: data/full_dataset.json # Full dataset for final pipeline
|
||||
```
|
||||
|
||||
The optimizer will use the sample dataset, but your final pipeline uses the full dataset. This ensures you don't overfit to your test set during optimization.
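If you don't already have a hold-out file, you can carve one out of the full dataset with a few lines of Python. A minimal sketch (the paths and the sample size of 100 are placeholders, not required values):

```python
import json
import random

# Load the full dataset and write a reproducible random sample for optimization.
with open("data/full_dataset.json") as f:
    docs = json.load(f)

random.seed(42)
sample = random.sample(docs, min(100, len(docs)))

with open("data/sample_dataset.json", "w") as f:
    json.dump(sample, f, indent=2)
```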
|
||||
|
||||
## Model Configuration
|
||||
|
||||
### Available Models
|
||||
|
||||
!!! info "LiteLLM Model Names"
|
||||
Use LiteLLM model names (e.g., `gpt-4o-mini`, `gpt-4o`, `gpt-5.1`). Make sure your API keys are set in your environment.
|
||||
|
||||
```yaml
|
||||
available_models: # LiteLLM model names - ensure API keys are set
|
||||
- gpt-5.1-nano # Cheapest, lower accuracy
|
||||
- gpt-5.1-mini # Low cost, decent accuracy
|
||||
- gpt-5.1 # Balanced
|
||||
- gpt-4o # Higher cost, better accuracy
|
||||
```
|
||||
|
||||
### Model for Directive Instantiation
|
||||
|
||||
The `rewrite_agent_model` field specifies which LLM to use for generating optimization directives during the search process. This doesn't affect the models tested in `available_models`.
|
||||
|
||||
!!! tip "Cost Consideration"
|
||||
Use a cheaper model (like `gpt-4o-mini`) for directive instantiation to reduce search costs.
|
||||
|
||||
## Iteration Count
|
||||
|
||||
The `max_iterations` parameter controls how many pipeline configurations MOAR explores:
|
||||
|
||||
- **10-20 iterations**: Quick exploration, good for testing
|
||||
- **40 iterations**: Recommended for most use cases
|
||||
- **100+ iterations**: For complex pipelines or when you need the absolute best results
|
||||
|
||||
!!! note "Time vs Quality"
|
||||
More iterations give better results but take longer and cost more.
|
||||
|
||||
## Complete Example
|
||||
|
||||
```yaml
|
||||
optimizer_config:
|
||||
type: moar
|
||||
save_dir: results/moar_optimization
|
||||
available_models:
|
||||
- gpt-4o-mini
|
||||
- gpt-4o
|
||||
- gpt-5.1-mini
|
||||
- gpt-5.1
|
||||
evaluation_file: evaluate_medications.py
|
||||
metric_key: medication_extraction_score
|
||||
max_iterations: 40
|
||||
rewrite_agent_model: gpt-5.1
|
||||
dataset_path: data/sample.json # Optional
|
||||
exploration_weight: 1.414 # Optional
|
||||
```
|
||||
|
||||
|
|
@ -1,193 +0,0 @@
|
|||
# Evaluation Functions
|
||||
|
||||
How to write evaluation functions for MOAR optimization.
|
||||
|
||||
## How Evaluation Functions Work
|
||||
|
||||
Your evaluation function receives the pipeline output and computes metrics by comparing it to the original dataset. MOAR uses one specific metric from your returned dictionary (specified by `metric_key`) to optimize for accuracy.
|
||||
|
||||
!!! info "Function Signature"
|
||||
Your function must have exactly this signature:
|
||||
```python
|
||||
def evaluate_results(dataset_file_path: str, results_file_path: str) -> Dict[str, Any]:
|
||||
```
|
||||
|
||||
### What You Receive
|
||||
|
||||
- **`results_file_path`**: Path to JSON file containing your pipeline's output
|
||||
- **`dataset_file_path`**: Path to JSON file containing the original dataset
|
||||
|
||||
### What You Return
|
||||
|
||||
A dictionary with numeric metrics. The key specified in `optimizer_config.metric_key` will be used as the accuracy metric for optimization.
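For instance, if your config sets `metric_key: f1`, a return value like the one below works; MOAR optimizes on `f1`, and the remaining keys are simply extra metrics you can inspect later (the numbers here are made up):

```python
{
    "f1": 0.82,         # used by MOAR because metric_key == "f1"
    "precision": 0.86,  # extra metrics, kept for your own analysis
    "recall": 0.79,
}
```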
|
||||
|
||||
!!! tip "Using Original Input Data"
|
||||
Pipeline output includes the original input data. For example, if your dataset has a `src` attribute, it will be available in the output. You can use this directly for comparison without loading the dataset file separately.
|
||||
|
||||
## Basic Example
|
||||
|
||||
```python
|
||||
import json
|
||||
from typing import Any, Dict
|
||||
from docetl.utils_evaluation import register_eval
|
||||
|
||||
@register_eval
|
||||
def evaluate_results(dataset_file_path: str, results_file_path: str) -> Dict[str, Any]:
|
||||
# Load pipeline output
|
||||
with open(results_file_path, 'r') as f:
|
||||
output = json.load(f)
|
||||
|
||||
total_correct = 0
|
||||
for result in output:
|
||||
# For example, if your dataset has a 'src' attribute, it's available in the output
|
||||
original_text = result.get("src", "").lower()
|
||||
# Replace "your_extraction_key" with the actual key from your pipeline output
|
||||
extracted_items = result.get("your_extraction_key", [])
|
||||
|
||||
# Check if extracted items appear in original text
|
||||
for item in extracted_items:
|
||||
if str(item).lower() in original_text:
|
||||
total_correct += 1
|
||||
|
||||
return {
|
||||
"extraction_score": total_correct, # This key is used if metric_key="extraction_score"
|
||||
"total_extracted": sum(len(r.get("your_extraction_key", [])) for r in output),
|
||||
}
|
||||
```
|
||||
|
||||
## Requirements
|
||||
|
||||
!!! warning "Critical Requirements"
|
||||
- The function must be decorated with `@register_eval` (imported from `docetl.utils_evaluation`)
|
||||
- It must take exactly two arguments: `dataset_file_path` and `results_file_path`
|
||||
- It must return a dictionary with numeric metrics
|
||||
- The `metric_key` in your `optimizer_config` must match one of the keys in this dictionary
|
||||
- Only one function per file can be decorated with `@register_eval`
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
!!! tip "Keep It Fast"
|
||||
Your evaluation function will be called many times during optimization. Make sure it's efficient:
|
||||
|
||||
- Avoid expensive computations
|
||||
- Cache results if possible (see the sketch after this list)
|
||||
- Keep the function simple and fast
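As an example of the caching point above, if your function re-reads a ground-truth file on every call, a cached loader keeps repeated evaluations cheap. A minimal sketch, assuming a hypothetical `ground_truth.json` stored next to the dataset file:

```python
import json
from functools import lru_cache
from pathlib import Path

@lru_cache(maxsize=None)
def load_ground_truth(dataset_file_path: str) -> list:
    # Parsed once per dataset path, then reused across evaluation calls.
    # Treat the returned list as read-only, since it is shared between calls.
    gt_path = Path(dataset_file_path).parent / "ground_truth.json"
    with open(gt_path) as f:
        return json.load(f)
```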
|
||||
|
||||
## Common Evaluation Patterns
|
||||
|
||||
### Pattern 1: Extraction Verification with Recall
|
||||
|
||||
Check if extracted items appear in the document text and compute recall:
|
||||
|
||||
```python
|
||||
@register_eval
|
||||
def evaluate_results(dataset_file_path: str, results_file_path: str) -> Dict[str, Any]:
|
||||
with open(results_file_path, 'r') as f:
|
||||
output = json.load(f)
|
||||
|
||||
# For example, if your dataset has a 'src' attribute, it's available in the output
|
||||
total_correct = 0
|
||||
total_extracted = 0
|
||||
total_expected = 0
|
||||
|
||||
for result in output:
|
||||
# Replace "src" with the actual key from your dataset
|
||||
original_text = result.get("src", "").lower()
|
||||
extracted_items = result.get("your_extraction_key", []) # Replace with your key
|
||||
|
||||
# Count correct extractions (items that appear in text)
|
||||
for item in extracted_items:
|
||||
total_extracted += 1
|
||||
if str(item).lower() in original_text:
|
||||
total_correct += 1
|
||||
|
||||
# Count expected items (if you have ground truth)
|
||||
# total_expected += len(expected_items)
|
||||
|
||||
precision = total_correct / total_extracted if total_extracted > 0 else 0.0
|
||||
recall = total_correct / total_expected if total_expected > 0 else 0.0
|
||||
|
||||
return {
|
||||
"extraction_score": total_correct, # Use this as metric_key
|
||||
"precision": precision,
|
||||
"recall": recall,
|
||||
}
|
||||
```
|
||||
|
||||
### Pattern 2: Comparing Against Ground Truth
|
||||
|
||||
Load ground truth from the dataset file and compare:
|
||||
|
||||
```python
|
||||
@register_eval
|
||||
def evaluate_results(dataset_file_path: str, results_file_path: str) -> Dict[str, Any]:
|
||||
with open(results_file_path, 'r') as f:
|
||||
predictions = json.load(f)
|
||||
|
||||
with open(dataset_file_path, 'r') as f:
|
||||
ground_truth = json.load(f)
|
||||
|
||||
# Compare predictions with ground truth
|
||||
# Adjust keys based on your data structure
|
||||
correct = 0
|
||||
total = len(predictions)
|
||||
|
||||
for pred, truth in zip(predictions, ground_truth):
|
||||
# Example: compare classification labels
|
||||
if pred.get("predicted_label") == truth.get("true_label"):
|
||||
correct += 1
|
||||
|
||||
return {
|
||||
"accuracy": correct / total if total > 0 else 0.0,
|
||||
"correct": correct,
|
||||
"total": total,
|
||||
}
|
||||
```
|
||||
|
||||
### Pattern 3: External Evaluation (File or API)
|
||||
|
||||
Load additional data or call an API for evaluation:
|
||||
|
||||
```python
|
||||
import requests
|
||||
from pathlib import Path
|
||||
|
||||
@register_eval
|
||||
def evaluate_results(dataset_file_path: str, results_file_path: str) -> Dict[str, Any]:
|
||||
with open(results_file_path, 'r') as f:
|
||||
output = json.load(f)
|
||||
|
||||
# Option A: Load ground truth from a separate file
|
||||
ground_truth_path = Path(dataset_file_path).parent / "ground_truth.json"
|
||||
with open(ground_truth_path, 'r') as f:
|
||||
ground_truth = json.load(f)
|
||||
|
||||
# Option B: Call an API for evaluation
|
||||
# response = requests.post("https://api.example.com/evaluate", json=output)
|
||||
# api_score = response.json()["score"]
|
||||
|
||||
# Evaluate using ground truth
|
||||
scores = []
|
||||
for result, truth in zip(output, ground_truth):
|
||||
# Your evaluation logic here
|
||||
score = compute_score(result, truth)
|
||||
scores.append(score)
|
||||
|
||||
return {
|
||||
"average_score": sum(scores) / len(scores) if scores else 0.0,
|
||||
"scores": scores,
|
||||
}
|
||||
```
|
||||
|
||||
## Testing Your Function
|
||||
|
||||
!!! tip "Test Before Running"
|
||||
Test your evaluation function independently before running MOAR:
|
||||
|
||||
```python
|
||||
result = evaluate_results("dataset.json", "results.json")
|
||||
print(result) # Check that your metric_key is present
|
||||
```
|
||||
|
||||
This helps catch errors early and ensures your function works correctly.
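A slightly stricter pre-flight check is to assert that the configured `metric_key` is present and numeric, so a mismatch fails before you launch a long optimization run (using `extraction_score` as the example key):

```python
metrics = evaluate_results("dataset.json", "results.json")

assert "extraction_score" in metrics, "metric_key is missing from the returned metrics"
assert isinstance(metrics["extraction_score"], (int, float)), "metric value must be numeric"
```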
|
||||
|
||||
|
|
@ -1,133 +0,0 @@
|
|||
# MOAR Examples
|
||||
|
||||
Complete working examples for MOAR optimization.
|
||||
|
||||
## Medication Extraction Example
|
||||
|
||||
This example extracts medications from medical transcripts and evaluates extraction accuracy.
|
||||
|
||||
!!! note "Metric Key"
|
||||
The `metric_key` in the `optimizer_config` section specifies which key from your evaluation function's return dictionary will be used as the accuracy metric. In this example, `metric_key: medication_extraction_score` means MOAR will optimize using the `medication_extraction_score` value returned by the evaluation function.
|
||||
|
||||
### pipeline.yaml
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
transcripts:
|
||||
path: workloads/medical/raw.json
|
||||
type: file
|
||||
|
||||
default_model: gpt-4o-mini
|
||||
bypass_cache: true
|
||||
|
||||
optimizer_config:
|
||||
type: moar
|
||||
dataset_path: workloads/medical/raw_sample.json # Use sample for faster optimization
|
||||
save_dir: workloads/medical/moar_results
|
||||
available_models: # LiteLLM model names - ensure API keys are set in your environment
|
||||
- gpt-5.1-nano
|
||||
- gpt-5.1-mini
|
||||
- gpt-5.1
|
||||
- gpt-4o
|
||||
- gpt-4o-mini
|
||||
evaluation_file: workloads/medical/evaluate_medications.py
|
||||
metric_key: medication_extraction_score
|
||||
max_iterations: 40
|
||||
rewrite_agent_model: gpt-5.1
|
||||
|
||||
system_prompt:
|
||||
dataset_description: a collection of transcripts of doctor visits
|
||||
persona: a medical practitioner analyzing patient symptoms and reactions to medications
|
||||
|
||||
operations:
|
||||
- name: extract_medications
|
||||
type: map
|
||||
output:
|
||||
schema:
|
||||
medication: list[str]
|
||||
prompt: |
|
||||
Analyze the following transcript of a conversation between a doctor and a patient:
|
||||
{{ input.src }}
|
||||
Extract and list all medications mentioned in the transcript.
|
||||
If no medications are mentioned, return an empty list.
|
||||
|
||||
pipeline:
|
||||
steps:
|
||||
- name: medication_extraction
|
||||
input: transcripts
|
||||
operations:
|
||||
- extract_medications
|
||||
output:
|
||||
type: file
|
||||
path: workloads/medical/extracted_medications_results.json
|
||||
```
|
||||
|
||||
### evaluate_medications.py
|
||||
|
||||
```python
|
||||
import json
|
||||
from typing import Any, Dict
|
||||
from docetl.utils_evaluation import register_eval
|
||||
|
||||
@register_eval
|
||||
def evaluate_results(dataset_file_path: str, results_file_path: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Evaluate medication extraction results.
|
||||
|
||||
Checks if each extracted medication appears verbatim in the original transcript.
|
||||
In this example, the dataset has a 'src' attribute with the original input text.
|
||||
"""
|
||||
# Load pipeline output
|
||||
with open(results_file_path, 'r') as f:
|
||||
output = json.load(f)
|
||||
|
||||
total_correct_medications = 0
|
||||
total_extracted_medications = 0
|
||||
|
||||
# Evaluate each result
|
||||
for result in output:
|
||||
# In this example, the dataset has a 'src' attribute with the original transcript
|
||||
original_transcript = result.get("src", "").lower()
|
||||
extracted_medications = result.get("medication", [])
|
||||
|
||||
# Check each extracted medication
|
||||
for medication in extracted_medications:
|
||||
total_extracted_medications += 1
|
||||
medication_lower = str(medication).lower().strip()
|
||||
|
||||
# Check if medication appears in transcript
|
||||
if medication_lower in original_transcript:
|
||||
total_correct_medications += 1
|
||||
|
||||
# Calculate metrics
|
||||
precision = total_correct_medications / total_extracted_medications if total_extracted_medications > 0 else 0.0
|
||||
|
||||
return {
|
||||
"medication_extraction_score": total_correct_medications, # This is used as the accuracy metric
|
||||
"total_correct_medications": total_correct_medications,
|
||||
"total_extracted_medications": total_extracted_medications,
|
||||
"precision": precision,
|
||||
}
|
||||
```
|
||||
|
||||
### Running the Optimization
|
||||
|
||||
```bash
|
||||
docetl build workloads/medical/pipeline_medication_extraction.yaml --optimizer moar
|
||||
```
|
||||
|
||||
!!! tip "Using Sample Datasets"
|
||||
Notice that `dataset_path` points to `raw_sample.json` for optimization, while the main pipeline uses `raw.json`. This prevents optimizing on your test set.
|
||||
|
||||
## Key Points
|
||||
|
||||
!!! info "Evaluation Function"
|
||||
- In this example, uses the `src` attribute from output items (no need to load dataset separately)
|
||||
- Checks if extracted medications appear verbatim in the transcript
|
||||
- Returns multiple metrics, with `medication_extraction_score` as the primary one
|
||||
|
||||
!!! info "Configuration"
|
||||
- Uses a sample dataset for optimization (`dataset_path`)
|
||||
- Includes multiple models in `available_models` to explore trade-offs
|
||||
- Sets `max_iterations` to 40 for a good balance of exploration and time
|
||||
|
||||
|
|
@ -1,157 +0,0 @@
|
|||
# Getting Started with MOAR
|
||||
|
||||
This guide walks you through running your first MOAR optimization step by step.
|
||||
|
||||
## Step 1: Create Your Pipeline YAML
|
||||
|
||||
Start with a standard DocETL pipeline YAML file:
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
transcripts:
|
||||
path: data/transcripts.json
|
||||
type: file
|
||||
|
||||
default_model: gpt-4o-mini
|
||||
|
||||
operations:
|
||||
- name: extract_medications
|
||||
type: map
|
||||
output:
|
||||
schema:
|
||||
medication: list[str]
|
||||
prompt: |
|
||||
Extract all medications mentioned in: {{ input.src }}
|
||||
|
||||
pipeline:
|
||||
steps:
|
||||
- name: medication_extraction
|
||||
input: transcripts
|
||||
operations:
|
||||
- extract_medications
|
||||
output:
|
||||
type: file
|
||||
path: results.json
|
||||
```
|
||||
|
||||
!!! note "Standard Pipeline"
|
||||
Your pipeline doesn't need any special configuration for MOAR. Just create a normal DocETL pipeline.
|
||||
|
||||
## Step 2: Create an Evaluation Function
|
||||
|
||||
Create a Python file with an evaluation function. This function will be called for each pipeline configuration that MOAR explores.
|
||||
|
||||
!!! info "How Evaluation Works"
|
||||
- Your function receives the pipeline output and the original dataset
|
||||
- You compute evaluation metrics by comparing the output to the dataset
|
||||
- You return a dictionary of metrics
|
||||
- MOAR uses one specific key from this dictionary (specified by `metric_key`) as the accuracy metric to optimize
|
||||
|
||||
```python
|
||||
# evaluate_medications.py
|
||||
import json
|
||||
from typing import Any, Dict
|
||||
from docetl.utils_evaluation import register_eval
|
||||
|
||||
@register_eval
|
||||
def evaluate_results(dataset_file_path: str, results_file_path: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Evaluate pipeline output against the original dataset.
|
||||
"""
|
||||
# Load pipeline output
|
||||
with open(results_file_path, 'r') as f:
|
||||
output = json.load(f)
|
||||
|
||||
# Load original dataset for comparison
|
||||
with open(dataset_file_path, 'r') as f:
|
||||
dataset = json.load(f)
|
||||
|
||||
# Compute your evaluation metrics
|
||||
correct_count = 0
|
||||
total_count = len(output)
|
||||
|
||||
for idx, result in enumerate(output):
|
||||
# Compare result with original data
|
||||
# For example, if your dataset has a 'src' attribute, it's available in the output
|
||||
original_text = result.get("src", "").lower()
|
||||
extracted_items = result.get("medication", [])
|
||||
|
||||
# Check if extracted items appear in original text
|
||||
for item in extracted_items:
|
||||
if item.lower() in original_text:
|
||||
correct_count += 1
|
||||
|
||||
# Return dictionary of metrics
|
||||
return {
|
||||
"medication_extraction_score": correct_count, # This key will be used if metric_key matches
|
||||
"total_extracted": total_count,
|
||||
"precision": correct_count / total_count if total_count > 0 else 0.0,
|
||||
}
|
||||
```
|
||||
|
||||
!!! warning "Important Requirements"
|
||||
- The function must be decorated with `@register_eval` (imported from `docetl.utils_evaluation`)
|
||||
- It must take exactly two arguments: `dataset_file_path` and `results_file_path`
|
||||
- It must return a dictionary with numeric metrics
|
||||
- The `metric_key` in your `optimizer_config` must match one of the keys in this dictionary
|
||||
- Only one function per file can be decorated with `@register_eval`
|
||||
|
||||
For more details on evaluation functions, see the [Evaluation Functions guide](evaluation.md).
|
||||
|
||||
## Step 3: Configure the Optimizer
|
||||
|
||||
Add an `optimizer_config` section to your YAML. The `metric_key` specifies which key from your evaluation function's return dictionary will be used as the accuracy metric for optimization:
|
||||
|
||||
```yaml
|
||||
optimizer_config:
|
||||
type: moar
|
||||
save_dir: results/moar_optimization
|
||||
available_models: # LiteLLM model names - ensure API keys are set in your environment
|
||||
- gpt-4o-mini
|
||||
- gpt-4o
|
||||
- gpt-5.1-mini
|
||||
- gpt-5.1
|
||||
evaluation_file: evaluate_medications.py
|
||||
metric_key: medication_extraction_score # This must match a key in your evaluation function's return dictionary
|
||||
max_iterations: 40
|
||||
rewrite_agent_model: gpt-5.1
|
||||
dataset_path: data/transcripts_sample.json # Optional: use sample/hold-out dataset
|
||||
```
|
||||
|
||||
!!! tip "Using Sample Datasets"
|
||||
Use `dataset_path` to specify a sample or hold-out dataset for optimization. This prevents optimizing on your test set. The main pipeline will still use the full dataset from the `datasets` section.
|
||||
|
||||
For complete configuration details, see the [Configuration Reference](configuration.md).
|
||||
|
||||
## Step 4: Run the Optimizer
|
||||
|
||||
Run MOAR optimization using the CLI:
|
||||
|
||||
```bash
|
||||
docetl build pipeline.yaml --optimizer moar
|
||||
```
|
||||
|
||||
!!! success "What Happens Next"
|
||||
MOAR will:
|
||||
1. Explore different pipeline configurations
|
||||
2. Evaluate each configuration using your evaluation function
|
||||
3. Build a cost-accuracy frontier of optimal solutions
|
||||
4. Save results to your `save_dir`
|
||||
|
||||
## Step 5: Review Results
|
||||
|
||||
After optimization completes, check your `save_dir` for:
|
||||
|
||||
- **`experiment_summary.json`** - High-level summary of the run
|
||||
- **`pareto_frontier.json`** - List of optimal solutions
|
||||
- **`evaluation_metrics.json`** - Detailed evaluation results
|
||||
- **`pipeline_*.yaml`** - Optimized pipeline configurations
|
||||
|
||||
For details on interpreting results, see [Understanding Results](results.md).
|
||||
|
||||
## Next Steps
|
||||
|
||||
- Learn about [configuration options](configuration.md)
|
||||
- See [complete examples](examples.md)
|
||||
- Read [troubleshooting tips](troubleshooting.md)
|
||||
|
||||
|
|
@ -1,93 +0,0 @@
|
|||
# Understanding MOAR Results
|
||||
|
||||
What MOAR outputs and how to interpret the results.
|
||||
|
||||
## Output Files
|
||||
|
||||
After running MOAR optimization, you'll find several files in your `save_dir`:
|
||||
|
||||
- **`experiment_summary.json`** - High-level summary
|
||||
- **`pareto_frontier.json`** - Optimal solutions
|
||||
- **`evaluation_metrics.json`** - Detailed evaluation results
|
||||
- **`pipeline_*.yaml`** - Optimized pipeline configurations
|
||||
|
||||
## experiment_summary.json
|
||||
|
||||
High-level summary of the optimization run:
|
||||
|
||||
```json
|
||||
{
|
||||
"optimizer": "moar",
|
||||
"input_pipeline": "pipeline.yaml",
|
||||
"rewrite_agent_model": "gpt-5.1",
|
||||
"max_iterations": 40,
|
||||
"save_dir": "results/moar_optimization",
|
||||
"dataset": "transcripts",
|
||||
"start_time": "2024-01-15T10:30:00",
|
||||
"end_time": "2024-01-15T11:15:00",
|
||||
"duration_seconds": 2700,
|
||||
"num_best_nodes": 5,
|
||||
"total_nodes_explored": 120,
|
||||
"total_search_cost": 15.50
|
||||
}
|
||||
```
|
||||
|
||||
!!! info "Key Metrics"
|
||||
- `num_best_nodes`: Number of solutions on the Pareto frontier
|
||||
- `total_nodes_explored`: Total configurations tested
|
||||
- `total_search_cost`: Total cost of the optimization search
|
||||
|
||||
## pareto_frontier.json
|
||||
|
||||
List of Pareto-optimal solutions (the cost-accuracy frontier):
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"node_id": 5,
|
||||
"yaml_path": "results/moar_optimization/pipeline_5.yaml",
|
||||
"cost": 0.05,
|
||||
"accuracy": 0.92
|
||||
},
|
||||
{
|
||||
"node_id": 12,
|
||||
"yaml_path": "results/moar_optimization/pipeline_12.yaml",
|
||||
"cost": 0.08,
|
||||
"accuracy": 0.95
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
!!! tip "Choosing a Solution"
|
||||
Review the Pareto frontier to find solutions that match your priorities:
|
||||
|
||||
- **Low cost priority**: Choose solutions with lower cost
|
||||
- **High accuracy priority**: Choose solutions with higher accuracy
|
||||
- **Balanced**: Choose solutions in the middle
|
||||
|
||||
Each solution includes a `yaml_path` pointing to the optimized pipeline configuration.
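If you want to pick a plan programmatically rather than by eye, one option is to take the cheapest frontier entry that clears a minimum accuracy. A small sketch against the `pareto_frontier.json` format shown above (the 0.90 threshold and the `save_dir` path are just examples):

```python
import json

with open("results/moar_optimization/pareto_frontier.json") as f:
    frontier = json.load(f)

# Cheapest solution that still meets the accuracy bar; fall back to the most
# accurate plan if nothing clears the threshold.
good_enough = [s for s in frontier if s["accuracy"] >= 0.90]
chosen = (
    min(good_enough, key=lambda s: s["cost"])
    if good_enough
    else max(frontier, key=lambda s: s["accuracy"])
)

print(f"Use {chosen['yaml_path']} (cost={chosen['cost']}, accuracy={chosen['accuracy']})")
```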
|
||||
|
||||
## evaluation_metrics.json
|
||||
|
||||
Detailed evaluation results for all explored configurations. This file contains comprehensive metrics for every pipeline configuration tested during optimization.
|
||||
|
||||
## Pipeline Configurations
|
||||
|
||||
Each solution on the Pareto frontier has a corresponding YAML file (e.g., `pipeline_5.yaml`) containing the optimized pipeline configuration. You can:
|
||||
|
||||
1. Review the changes MOAR made
|
||||
2. Test the pipeline on your full dataset
|
||||
3. Use it in production
|
||||
|
||||
## Next Steps
|
||||
|
||||
After reviewing the results:
|
||||
|
||||
1. **Review the Pareto frontier** - See available options
|
||||
2. **Choose a solution** - Based on your accuracy/cost priorities
|
||||
3. **Test the chosen pipeline** - Run it on your full dataset
|
||||
4. **Integrate into production** - Use the optimized configuration
|
||||
|
||||
!!! success "Success"
|
||||
You now have multiple optimized pipeline options to choose from, each representing a different point on the cost-accuracy trade-off curve.
|
||||
|
||||
|
|
@ -1,113 +0,0 @@
|
|||
# Troubleshooting MOAR
|
||||
|
||||
Common issues and solutions when using MOAR optimization.
|
||||
|
||||
## Error: Missing required accuracy metric
|
||||
|
||||
!!! error "Error Message"
|
||||
`KeyError: Missing required accuracy metric 'your_metric_key'`
|
||||
|
||||
**Solution:**
|
||||
|
||||
Check that:
|
||||
|
||||
1. Your evaluation function returns a dictionary with the `metric_key` you specified
|
||||
2. The `metric_key` in `optimizer_config` matches the key in your evaluation results
|
||||
3. Your evaluation function is working correctly (test it independently)
|
||||
|
||||
```python
|
||||
# Test your function
|
||||
result = evaluate_results("dataset.json", "results.json")
|
||||
print(result) # Verify your metric_key is present
|
||||
```
|
||||
|
||||
## Error: Evaluation function takes wrong number of arguments
|
||||
|
||||
!!! error "Error Message"
|
||||
`TypeError: evaluate_results() takes 1 positional argument but 2 were given`
|
||||
|
||||
**Solution:**
|
||||
|
||||
Make sure your evaluation function has exactly this signature:
|
||||
|
||||
```python
|
||||
def evaluate_results(dataset_file_path: str, results_file_path: str) -> Dict[str, Any]:
|
||||
```
|
||||
|
||||
And that it's decorated with `@register_eval` (imported from `docetl.utils_evaluation`).
|
||||
|
||||
## All accuracies showing as 0.0
|
||||
|
||||
!!! warning "Symptom"
|
||||
All solutions show 0.0 accuracy in the Pareto frontier.
|
||||
|
||||
**Possible causes:**
|
||||
|
||||
1. **Evaluation function failing silently** - Check the error logs
|
||||
2. **Result files don't exist** - Make sure pipelines are executing successfully
|
||||
3. **Metric key doesn't match** - Verify `metric_key` matches what your function returns
|
||||
|
||||
**Solution:**
|
||||
|
||||
Test your evaluation function independently and check MOAR logs for errors.
|
||||
|
||||
## Optimization taking too long
|
||||
|
||||
!!! tip "Speed Up Optimization"
|
||||
If optimization is taking too long, try:
|
||||
|
||||
- Reduce `max_iterations` (e.g., from 40 to 20)
|
||||
- Use a smaller sample dataset via `dataset_path`
|
||||
- Reduce the number of models in `available_models`
|
||||
- Use a cheaper, faster model for directive instantiation (`rewrite_agent_model`)
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Using Sample/Hold-Out Datasets
|
||||
|
||||
!!! tip "Avoid Overfitting"
|
||||
Always use a sample or hold-out dataset for optimization to avoid optimizing on your test set:
|
||||
|
||||
```yaml
|
||||
optimizer_config:
|
||||
dataset_path: data/sample_100.json # Use sample/hold-out for optimization
|
||||
```
|
||||
|
||||
### Choosing Models
|
||||
|
||||
!!! tip "Model Selection"
|
||||
Include a range of models in `available_models` to explore cost-accuracy trade-offs:
|
||||
|
||||
```yaml
|
||||
available_models:
|
||||
- gpt-5.1-nano # Cheapest, lower accuracy
|
||||
- gpt-5.1-mini # Low cost, decent accuracy
|
||||
- gpt-5.1 # Balanced
|
||||
- gpt-4o # Higher cost, better accuracy
|
||||
```
|
||||
|
||||
### Iteration Count
|
||||
|
||||
!!! tip "Iteration Guidelines"
|
||||
- **10-20 iterations**: Quick exploration, good for testing
|
||||
- **40 iterations**: Recommended for most use cases
|
||||
- **100+ iterations**: For complex pipelines or when you need the absolute best results
|
||||
|
||||
### Evaluation Function Performance
|
||||
|
||||
!!! tip "Keep Functions Fast"
|
||||
Your evaluation function will be called many times. Make sure it's efficient:
|
||||
|
||||
- Avoid expensive computations
|
||||
- Cache results if possible
|
||||
- Keep the function simple and fast
|
||||
|
||||
## Getting Help
|
||||
|
||||
If you're still experiencing issues:
|
||||
|
||||
1. Check the MOAR logs for detailed error messages
|
||||
2. Verify your evaluation function works independently
|
||||
3. Test with a smaller `max_iterations` to isolate issues
|
||||
4. Review the [Configuration Reference](configuration.md) to ensure all required fields are set
|
||||
|
||||
|
|
@ -1,29 +1,6 @@
|
|||
# DocETL Optimizer
|
||||
|
||||
DocETL provides two optimizer options to improve your document processing pipelines:
|
||||
|
||||
## MOAR Optimizer (Recommended)
|
||||
|
||||
The **MOAR (Multi-Objective Agentic Rewrites)** optimizer uses Monte Carlo Tree Search to explore the optimization space and find Pareto-optimal solutions that balance accuracy and cost. It's the recommended optimizer for most use cases.
|
||||
|
||||
**Key Features:**
|
||||
|
||||
- Multi-objective optimization (accuracy + cost)
|
||||
- Returns multiple Pareto-optimal solutions
|
||||
- Automatic model exploration
|
||||
- Custom evaluation functions
|
||||
- Intelligent search using MCTS
|
||||
|
||||
See the [MOAR Optimizer Guide](moar.md) for detailed documentation and examples.
|
||||
|
||||
## V1 Optimizer (Deprecated)
|
||||
|
||||
!!! warning "Deprecated"
|
||||
The V1 optimizer is deprecated and no longer recommended. Use MOAR instead for all new optimizations.
|
||||
|
||||
The V1 optimizer uses a greedy approach with validation to find improved pipeline configurations. It's still available for backward compatibility but should not be used for new projects.
|
||||
|
||||
The rest of this page describes the general optimization concepts that apply to both optimizers.
|
||||
The DocETL optimizer finds a plan that improves the accuracy of your document processing pipelines. It works by analyzing operations marked for optimization and potentially rewriting them, searching for the best execution plan.
|
||||
|
||||
## Key Features
|
||||
|
||||
|
|
@ -104,17 +81,3 @@ After applying the optimizer, your pipeline could be transformed into a more eff
|
|||
3. **Reduce Operation**: For each contract, combine the extracted and tagged clauses from each chunk.
|
||||
|
||||
The goal of the DocETL optimizer is to try many ways of rewriting your pipeline and then select the best one. This may take some time (20-30 minutes for very complex tasks and large documents). But the optimizer's ability to break down complex tasks into more manageable sub-steps can lead to more accurate and reliable results.
|
||||
|
||||
## Choosing an Optimizer
|
||||
|
||||
**Use MOAR if:**
|
||||
|
||||
- You want to explore cost-accuracy trade-offs
|
||||
- You need multiple solution options (Pareto frontier)
|
||||
- You have custom evaluation metrics
|
||||
- You want automatic model exploration
|
||||
|
||||
!!! warning "V1 Optimizer Deprecated"
|
||||
The V1 optimizer is deprecated. Use MOAR instead. If you have existing V1-optimized pipelines, they will continue to work, but new optimizations should use MOAR.
|
||||
|
||||
For detailed MOAR usage, see the [MOAR Optimizer Guide](moar.md).
|
||||
|
|
|
|||
|
|
@ -1,45 +0,0 @@
|
|||
# Quick Start with Claude Code
|
||||
|
||||
The fastest way to build DocETL pipelines is with [Claude Code](https://claude.ai/download), Anthropic's agentic coding tool. DocETL includes a built-in Claude Code skill that helps you create, run, and debug pipelines interactively.
|
||||
|
||||
## Option 1: Clone the Repository (Recommended)
|
||||
|
||||
This gives you the full development environment with the skill already configured.
|
||||
|
||||
1. Follow the [Installation from Source](installation.md#installation-from-source) instructions
|
||||
2. Run `claude` in the repository directory
|
||||
|
||||
The skill is located at `.claude/skills/docetl/SKILL.md`.
|
||||
|
||||
## Option 2: Install via pip
|
||||
|
||||
If you already have DocETL installed via pip, you can install the skill separately:
|
||||
|
||||
```bash
|
||||
pip install docetl
|
||||
docetl install-skill
|
||||
```
|
||||
|
||||
This copies the skill to `~/.claude/skills/docetl/`. Then run `claude` in any directory.
|
||||
|
||||
To uninstall: `docetl install-skill --uninstall`
|
||||
|
||||
## Usage
|
||||
|
||||
Simply describe what you want to do with your data. The skill activates automatically when you mention "docetl" or describe unstructured data processing tasks:
|
||||
|
||||
```
|
||||
> I have a folder of customer support tickets in JSON format.
|
||||
> I want to extract the main issue, sentiment, and suggested resolution for each.
|
||||
```
|
||||
|
||||
Claude will:
|
||||
|
||||
1. **Read your data** to understand its structure
|
||||
2. **Write a tailored pipeline** with prompts specific to your documents
|
||||
3. **Run the pipeline** and show you the results
|
||||
4. **Debug issues** if any operations fail
|
||||
|
||||
## Alternative: Manual Pipeline Authoring
|
||||
|
||||
If you prefer not to use Claude Code, see the [Quick Start Tutorial](tutorial.md) for writing pipelines by hand.
|
||||
|
|
@ -1,396 +0,0 @@
|
|||
## Retrievers (LanceDB OSS)
|
||||
|
||||
Retrievers let you augment LLM operations with retrieved context from a LanceDB index built over a DocETL dataset. You define retrievers once at the top-level, then attach them to any LLM-powered operation using `retriever: <name>`. At runtime, DocETL performs full-text, vector, or hybrid search and injects the results into your prompt as `{{ retrieval_context }}`.
|
||||
|
||||
LanceDB supports built-in full-text search, vector search, and hybrid with RRF reranking. See the official docs: [LanceDB Hybrid Search docs](https://lancedb.com/docs/search/hybrid-search/).
|
||||
|
||||
### Key points
|
||||
|
||||
- Always OSS LanceDB (local `index_dir`).
|
||||
- A retriever references a dataset from the pipeline config, or the output of a previous pipeline step.
|
||||
- Operations do not override retriever settings. One source of truth = consistency.
|
||||
- `{{ retrieval_context }}` is available to your prompt; if not used, DocETL prepends a short "extra context" section automatically.
|
||||
|
||||
## Configuration
|
||||
|
||||
Add a top-level `retrievers` section. Each retriever has:
|
||||
|
||||
- `dataset`: dataset name to index (can be a dataset or output of a previous pipeline step)
|
||||
- `index_dir`: LanceDB path
|
||||
- `index_types`: which indexes to build: `fts`, `embedding`, or `hybrid` (both)
|
||||
- `fts.index_phrase`: Jinja template for indexing each row for full-text search
|
||||
- `fts.query_phrase`: Jinja template for building the FTS query at runtime
|
||||
- `embedding.model`: embedding model for vector index and queries
|
||||
- `embedding.index_phrase`: Jinja template for indexing each row for embeddings
|
||||
- `embedding.query_phrase`: Jinja template for building the embedding query
|
||||
- `query.mode`: `fts` | `embedding` | `hybrid` (defaults to `hybrid` when both indexes exist)
|
||||
- `query.top_k`: number of results to retrieve
|
||||
|
||||
### Basic example
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
transcripts:
|
||||
type: file
|
||||
path: workloads/medical/raw.json
|
||||
|
||||
default_model: gpt-4o-mini
|
||||
|
||||
retrievers:
|
||||
medical_r:
|
||||
type: lancedb
|
||||
dataset: transcripts
|
||||
index_dir: workloads/medical/lance_index
|
||||
build_index: if_missing # if_missing | always | never
|
||||
index_types: ["fts", "embedding"]
|
||||
fts:
|
||||
index_phrase: "{{ input.src }}"
|
||||
query_phrase: "{{ input.src[:1000] }}"
|
||||
embedding:
|
||||
model: openai/text-embedding-3-small
|
||||
index_phrase: "{{ input.src }}"
|
||||
query_phrase: "{{ input.src[:1000] }}"
|
||||
query:
|
||||
mode: hybrid
|
||||
top_k: 8
|
||||
```
|
||||
|
||||
## Multi-step pipelines with retrieval
|
||||
|
||||
Most pipelines have a single step, but you can define multiple steps where **the output of one step becomes the input (and retriever source) for the next**. This is powerful for patterns like:
|
||||
|
||||
1. Extract structured data from documents
|
||||
2. Build a retrieval index on that extracted data
|
||||
3. Use retrieval to find related items and process them
|
||||
|
||||
### Example: Extract facts, then find conflicts
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
articles:
|
||||
type: file
|
||||
path: workloads/wiki/articles.json
|
||||
|
||||
default_model: gpt-4o-mini
|
||||
|
||||
# Retriever indexes output of step 1 (extract_facts_step)
|
||||
retrievers:
|
||||
facts_index:
|
||||
type: lancedb
|
||||
dataset: extract_facts_step # References output of a pipeline step!
|
||||
index_dir: workloads/wiki/facts_lance_index
|
||||
build_index: if_missing
|
||||
index_types: ["fts", "embedding"]
|
||||
fts:
|
||||
index_phrase: "{{ input.fact }} from {{ input.title }}"
|
||||
query_phrase: "{{ input.fact }}"
|
||||
embedding:
|
||||
model: openai/text-embedding-3-small
|
||||
index_phrase: "{{ input.fact }}"
|
||||
query_phrase: "{{ input.fact }}"
|
||||
query:
|
||||
mode: hybrid
|
||||
top_k: 5
|
||||
|
||||
operations:
|
||||
- name: extract_facts
|
||||
type: map
|
||||
prompt: |
|
||||
Extract factual claims from this article.
|
||||
Article: {{ input.title }}
|
||||
Text: {{ input.text }}
|
||||
output:
|
||||
schema:
|
||||
facts: list[string]
|
||||
|
||||
- name: unnest_facts
|
||||
type: unnest
|
||||
unnest_key: facts
|
||||
|
||||
- name: find_conflicts
|
||||
type: map
|
||||
retriever: facts_index # Uses the retriever
|
||||
prompt: |
|
||||
Check if this fact conflicts with similar facts from other articles.
|
||||
|
||||
Current fact: {{ input.facts }} (from {{ input.title }})
|
||||
|
||||
Similar facts from other articles:
|
||||
{{ retrieval_context }}
|
||||
|
||||
Return true only if there's a genuine contradiction.
|
||||
output:
|
||||
schema:
|
||||
has_conflict: boolean
|
||||
|
||||
pipeline:
|
||||
steps:
|
||||
# Step 1: Extract and unnest facts
|
||||
- name: extract_facts_step
|
||||
input: articles
|
||||
operations:
|
||||
- extract_facts
|
||||
- unnest_facts
|
||||
|
||||
# Step 2: Use retrieval to find conflicts
|
||||
- name: find_conflicts_step
|
||||
input: extract_facts_step # Input is output of step 1
|
||||
operations:
|
||||
- find_conflicts
|
||||
|
||||
output:
|
||||
type: file
|
||||
path: workloads/wiki/conflicts.json
|
||||
intermediate_dir: workloads/wiki/intermediates
|
||||
```
|
||||
|
||||
In this example:
|
||||
- **Step 1** (`extract_facts_step`) extracts facts from articles
|
||||
- The **retriever** (`facts_index`) indexes the output of step 1
|
||||
- **Step 2** (`find_conflicts_step`) processes each fact, using retrieval to find similar facts from other articles
|
||||
|
||||
## Configuration reference
|
||||
|
||||
### Minimal example
|
||||
|
||||
Here's the simplest possible retriever config (FTS only):
|
||||
|
||||
```yaml
|
||||
retrievers:
|
||||
my_search: # name can be anything you want
|
||||
type: lancedb
|
||||
dataset: my_dataset # must match a dataset name or pipeline step
|
||||
index_dir: ./my_lance_index
|
||||
index_types: ["fts"]
|
||||
fts:
|
||||
index_phrase: "{{ input.text }}" # what to index from each row
|
||||
query_phrase: "{{ input.query }}" # what to search for at runtime
|
||||
```
|
||||
|
||||
### Full example with all options
|
||||
|
||||
```yaml
|
||||
retrievers:
|
||||
my_search:
|
||||
type: lancedb
|
||||
dataset: my_dataset
|
||||
index_dir: ./my_lance_index
|
||||
build_index: if_missing # optional, default: if_missing
|
||||
index_types: ["fts", "embedding"] # can be ["fts"], ["embedding"], or both
|
||||
fts:
|
||||
index_phrase: "{{ input.text }}"
|
||||
query_phrase: "{{ input.query }}"
|
||||
embedding:
|
||||
model: openai/text-embedding-3-small
|
||||
index_phrase: "{{ input.text }}" # optional, falls back to fts.index_phrase
|
||||
query_phrase: "{{ input.query }}"
|
||||
query: # optional section
|
||||
mode: hybrid # optional, auto-selects based on index_types
|
||||
top_k: 10 # optional, default: 5
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Required fields
|
||||
|
||||
| Field | Description |
|
||||
| --- | --- |
|
||||
| `type` | Must be `lancedb` |
|
||||
| `dataset` | Name of a dataset or pipeline step to index |
|
||||
| `index_dir` | Path where LanceDB stores the index (created if missing) |
|
||||
| `index_types` | List of index types: `["fts"]`, `["embedding"]`, or `["fts", "embedding"]` |
|
||||
|
||||
---
|
||||
|
||||
### Optional fields
|
||||
|
||||
| Field | Default | Description |
|
||||
| --- | --- | --- |
|
||||
| `build_index` | `if_missing` | When to build: `if_missing`, `always`, or `never` |
|
||||
| `query.mode` | auto | `fts`, `embedding`, or `hybrid`. Auto-selects based on what indexes exist |
|
||||
| `query.top_k` | 5 | Number of results to return |
|
||||
|
||||
---
|
||||
|
||||
### The `fts` section
|
||||
|
||||
Required if `"fts"` is in `index_types`. Configures full-text search.
|
||||
|
||||
| Field | Required | Description |
|
||||
| --- | --- | --- |
|
||||
| `index_phrase` | yes | Jinja template: what text to index from each dataset row |
|
||||
| `query_phrase` | yes | Jinja template: what text to search for at query time |
|
||||
|
||||
**Jinja variables available:**
|
||||
|
||||
| Template | Variables | When it runs |
|
||||
| --- | --- | --- |
|
||||
| `index_phrase` | `input` = the dataset row | Once per row when building the index |
|
||||
| `query_phrase` | `input` = current item (map/filter/extract) | At query time for each item processed |
|
||||
| `query_phrase` | `reduce_key`, `inputs` (reduce operations) | At query time for each group |
|
||||
|
||||
**Example - Medical knowledge base:**
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
drugs:
|
||||
type: file
|
||||
path: drugs.json # [{"name": "Aspirin", "uses": "pain, fever"}, ...]
|
||||
|
||||
patient_notes:
|
||||
type: file
|
||||
path: notes.json # [{"symptoms": "headache and fever"}, ...]
|
||||
|
||||
retrievers:
|
||||
drug_lookup:
|
||||
type: lancedb
|
||||
dataset: drugs # index the drugs dataset
|
||||
index_dir: ./drug_index
|
||||
index_types: ["fts"]
|
||||
fts:
|
||||
index_phrase: "{{ input.name }}: {{ input.uses }}" # index: "Aspirin: pain, fever"
|
||||
query_phrase: "{{ input.symptoms }}" # search with patient symptoms
|
||||
|
||||
operations:
|
||||
- name: find_treatment
|
||||
type: map
|
||||
retriever: drug_lookup # attach the retriever
|
||||
prompt: |
|
||||
Patient symptoms: {{ input.symptoms }}
|
||||
|
||||
Relevant drugs from knowledge base:
|
||||
{{ retrieval_context }}
|
||||
|
||||
Recommend a treatment.
|
||||
output:
|
||||
schema:
|
||||
recommendation: string
|
||||
```
|
||||
|
||||
When processing `{"symptoms": "headache and fever"}`:
|
||||
|
||||
1. `query_phrase` renders to `"headache and fever"`
|
||||
2. FTS searches the index and finds `"Aspirin: pain, fever"` as a match
|
||||
3. `{{ retrieval_context }}` in your prompt contains the matched results
|
||||
|
||||
---
|
||||
|
||||
### The `embedding` section
|
||||
|
||||
Required if `"embedding"` is in `index_types`. Configures vector/semantic search.
|
||||
|
||||
| Field | Required | Description |
|
||||
| --- | --- | --- |
|
||||
| `model` | yes | Embedding model, e.g. `openai/text-embedding-3-small` |
|
||||
| `index_phrase` | no | Jinja template for text to embed. Falls back to `fts.index_phrase` |
|
||||
| `query_phrase` | yes | Jinja template for query text to embed |
|
||||
|
||||
**Jinja variables:** Same as FTS section.
|
||||
|
||||
**Example - Semantic search:**
|
||||
|
||||
```yaml
|
||||
retrievers:
|
||||
semantic_docs:
|
||||
type: lancedb
|
||||
dataset: documentation
|
||||
index_dir: ./docs_index
|
||||
index_types: ["embedding"]
|
||||
embedding:
|
||||
model: openai/text-embedding-3-small
|
||||
index_phrase: "{{ input.content }}"
|
||||
query_phrase: "{{ input.question }}"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### The `query` section (optional)
|
||||
|
||||
Controls search behavior. You can omit this entire section.
|
||||
|
||||
| Field | Default | Description |
|
||||
| --- | --- | --- |
|
||||
| `mode` | auto | `fts`, `embedding`, or `hybrid`. Auto-selects `hybrid` if both indexes exist |
|
||||
| `top_k` | 5 | Number of results to retrieve |
|
||||
|
||||
**Example - Override defaults:**
|
||||
|
||||
```yaml
|
||||
retrievers:
|
||||
my_search:
|
||||
# ... other config ...
|
||||
query:
|
||||
mode: fts # force FTS even if embedding index exists
|
||||
top_k: 20 # return more results
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Using a retriever in operations
|
||||
|
||||
Attach a retriever to any LLM operation (map, filter, reduce, extract) with `retriever: <retriever_name>`. The retrieved results are available as `{{ retrieval_context }}` in your prompt.
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
| --- | --- | --- | --- |
|
||||
| retriever | string | - | Name of the retriever to use (must match a key in `retrievers`). |
|
||||
| save_retriever_output | bool | false | If true, saves retrieved context to `_<operation_name>_retrieved_context` in output. |
|
||||
|
||||
### Map example
|
||||
|
||||
```yaml
|
||||
- name: tag_visit
|
||||
type: map
|
||||
retriever: medical_r
|
||||
save_retriever_output: true
|
||||
output:
|
||||
schema:
|
||||
tag: string
|
||||
prompt: |
|
||||
Classify this medical visit. Related context:
|
||||
{{ retrieval_context }}
|
||||
|
||||
Transcript: {{ input.src }}
|
||||
```
|
||||
|
||||
### Filter example
|
||||
|
||||
```yaml
|
||||
- name: filter_relevant
|
||||
type: filter
|
||||
retriever: medical_r
|
||||
prompt: |
|
||||
Is this transcript relevant to medication counseling?
|
||||
Context: {{ retrieval_context }}
|
||||
Transcript: {{ input.src }}
|
||||
output:
|
||||
schema:
|
||||
is_relevant: boolean
|
||||
```
|
||||
|
||||
### Reduce example
|
||||
|
||||
When using reduce, the retrieval context is computed per group.
|
||||
|
||||
```yaml
|
||||
- name: summarize_by_medication
|
||||
type: reduce
|
||||
retriever: medical_r
|
||||
reduce_key: medication
|
||||
output:
|
||||
schema:
|
||||
summary: string
|
||||
prompt: |
|
||||
Summarize key points for medication '{{ reduce_key.medication }}'.
|
||||
Related context: {{ retrieval_context }}
|
||||
|
||||
Inputs:
|
||||
{% for item in inputs %}
|
||||
- {{ item.src }}
|
||||
{% endfor %}
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
- **No results**: the retriever injects "No extra context available." and continues.
|
||||
- **Index issues**: set `build_index: always` to rebuild; ensure `index_dir` exists and is writable.
|
||||
- **Token limits**: `retrieval_context` is truncated to ~1000 chars per retrieved doc.
|
||||
|
|
@ -246,43 +246,4 @@ print(f"Optimized pipeline execution completed. Total cost: ${cost:.2f}")
|
|||
# Continue processing with other pandas operations
|
||||
```
|
||||
|
||||
Learn more about the pandas integration in the [pandas documentation](pandas/index.md).
|
||||
|
||||
## Using a Code Operation to Normalize Medication Names (Python API)
|
||||
|
||||
You can insert a code operation in the Python API pipeline to perform per-document transformations without calling an LLM. Code ops accept a Python function (it does not need to be named `transform`).
|
||||
|
||||
```python
|
||||
from docetl.api import CodeMapOp
|
||||
|
||||
# Normalize medication names (lowercase/trim)
|
||||
def normalize_medication(doc: dict) -> dict:
|
||||
med = doc.get("medication", "")
|
||||
return {"medication_norm": med.lower().strip() if isinstance(med, str) else med}
|
||||
|
||||
# Add this op after unnesting and before resolve
|
||||
operations = [
|
||||
# ... existing extract_medications (MapOp), unnest_medications (UnnestOp)
|
||||
CodeMapOp(
|
||||
name="normalize_medication",
|
||||
type="code_map",
|
||||
code=normalize_medication,
|
||||
),
|
||||
# Optionally, update Resolve/Reduce to use "medication_norm" instead of "medication"
|
||||
]
|
||||
|
||||
# And include it in the step order, e.g.
|
||||
step = PipelineStep(
|
||||
name="medical_info_extraction",
|
||||
input="transcripts",
|
||||
operations=[
|
||||
"extract_medications",
|
||||
"unnest_medications",
|
||||
"normalize_medication", # new code op here
|
||||
"resolve_medications",
|
||||
"summarize_prescriptions",
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
This keeps your deterministic preprocessing in Python while still leveraging DocETL for the LLM-powered stages.
|
||||
Learn more about the pandas integration in the [pandas documentation](pandas/index.md).
|
||||
|
|
@ -36,7 +36,7 @@ DocETL uses [LiteLLM](https://github.com/BerriAI/litellm) under the hood, which
|
|||
|
||||
## Preparing the Data
|
||||
|
||||
Organize your medical transcript data in a JSON file as a list of objects. Each object should have a "src" key containing the transcript text. You can download the example dataset [here](assets/medical_transcripts.json).
|
||||
Organize your medical transcript data in a JSON file as a list of objects. Each object should have a "src" key containing the transcript text. You can download the example dataset [here](../assets/medical_transcripts.json).
|
||||
|
||||
!!! example "Sample Data Structure"
|
||||
|
||||
|
|
|
|||
|
|
@ -1,246 +0,0 @@
|
|||
# Running MOARSearch and Simple Agent Experiments
|
||||
|
||||
This guide explains how to run the MOAR optimizer (`run_moar.py`) and the simple agent optimizer (`run_simple_agent.py`).
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Prerequisites](#prerequisites)
|
||||
- [MOARSearch](#moarsearch-run_moarpy)
|
||||
- [Simple Agent](#simple-agent-run_simple_agentpy)
|
||||
- [Supported Datasets](#supported-datasets)
|
||||
- [Custom Datasets with User-Authored Accuracy Functions](#custom-datasets-with-user-authored-accuracy-functions)
|
||||
- [Available Models](#available-models)
|
||||
- [Output Files](#output-files)
|
||||
|
||||
## Prerequisites
|
||||
|
||||
### Dataset and Pipeline Files
|
||||
|
||||
Download the 6 workload experiment datasets and initial pipelines from the [Google Drive Link](https://drive.google.com/drive/folders/1pNFqYCguAZL3iGYHd-jDoialrtf73fbR?usp=drive_link). Extract the `data/` and `pipelines/` folders into the `experiments/reasoning/` directory.
|
||||
|
||||
### Python Dependencies
|
||||
|
||||
Install the required Python packages:
|
||||
|
||||
```bash
|
||||
pip install -r experiments/reasoning/requirements.txt
|
||||
```
|
||||
|
||||
### Environment Variables
|
||||
|
||||
#### Required Environment Variables
|
||||
|
||||
Set up your Azure API credentials (required for LLM calls):
|
||||
|
||||
```bash
|
||||
export AZURE_API_KEY="your_key_here"
|
||||
export AZURE_API_BASE="your_endpoint_here"
|
||||
export AZURE_API_VERSION="your_version_here"
|
||||
```
|
||||
|
||||
#### Optional Environment Variables
|
||||
|
||||
- **`EXPERIMENT_OUTPUT_DIR`**: Directory to save experiment outputs. If not set, defaults are used:
|
||||
- `run_moar.py`: `./outputs`
|
||||
- `run_simple_agent.py`: `./outputs/simple_agent`
|
||||
|
||||
|
||||
- **`EXPERIMENT_DATA_DIR`**: Directory containing input data files. Defaults to `./data/` if not set. Can also be set via `--data_dir` parameter in `run_moar.py`.
|
||||
|
||||
|
||||
## MOARSearch (`run_moar.py`)
|
||||
|
||||
MOARSearch is a multi-objective optimization algorithm that uses graph search to find Pareto-optimal pipelines balancing cost and accuracy.
|
||||
|
||||
### Local Execution
|
||||
|
||||
#### Basic Usage
|
||||
|
||||
```bash
|
||||
python experiments/reasoning/run_moar.py \
|
||||
--yaml_path experiments/reasoning/pipelines/cuad.yaml \
|
||||
--dataset_path experiments/reasoning/data/train/cuad.json \
|
||||
--experiment_name cuad_moar \
|
||||
--dataset cuad
|
||||
```
|
||||
|
||||
### Modal Execution (Recommended)
|
||||
|
||||
Modal execution is recommended for longer-running experiments as it provides better resource management and persistence.
|
||||
|
||||
```bash
|
||||
modal run experiments/reasoning/run_moar.py \
|
||||
--yaml-path=experiments/reasoning/pipelines/medec.yaml \
|
||||
--dataset-path=experiments/reasoning/data/train/medec.json \
|
||||
--experiment-name=medec_moar \
|
||||
--dataset=medec \
|
||||
--max-iterations=30
|
||||
```
|
||||
|
||||
Outputs are written to `/mnt/docetl-ro-experiments/outputs/{experiment_name}` in the shared Modal volume.
|
||||
|
||||
### Parameters
|
||||
|
||||
| Parameter | Required | Default | Description |
|
||||
|-----------|----------|---------|-------------|
|
||||
| `--yaml_path` | ✅ Yes | - | Path to the user-authored input YAML pipeline file |
|
||||
| `--dataset_path` | ✅ Yes | - | Path to the dataset file for sample input data |
|
||||
| `--experiment_name` | ✅ Yes | - | Unique experiment identifier |
|
||||
| `--dataset` | Conditional | `cuad` | Dataset name for evaluation. **Required** if `--accuracy_function` is not provided. Must be one of: cuad, blackvault, medec, biodex, sustainability, game_reviews. If using `--accuracy_function`, can be any string (used for naming). |
|
||||
| `--max_iterations` | No | `100` | Maximum MOARSearch iterations |
|
||||
| `--exploration_weight` | No | `1.414` | UCB exploration parameter c (controls exploration vs exploitation) |
|
||||
| `--model` | No | `gpt-5` | LLM model to use for directive instantiation |
|
||||
| `--available_models` | No | All 11 default models | Space-separated list of models for first layer testing. Example: `gpt-5 gpt-5-mini gpt-4o` |
|
||||
| `--data_dir` | No | `EXPERIMENT_DATA_DIR` env var or `./data/` | Directory containing input data files |
|
||||
| `--output_dir` | No | `EXPERIMENT_OUTPUT_DIR` env var or `./outputs` | Directory to save experiment outputs |
|
||||
| `--ground_truth` | No | - | Path to ground-truth file |
|
||||
| `--accuracy_function` | No | - | Path to Python file containing custom `evaluate_results` function (for user datasets) |
|
||||
| `--accuracy_metric_key` | No | - | Key to extract from evaluation results dict for accuracy metric (required with `--accuracy_function`) |
|
||||
|
||||
## Simple Agent (`run_simple_agent.py`)

The Simple Agent is a baseline optimizer that uses basic tool calling to generate and test pipeline configurations iteratively.

### Local Execution

#### Basic Usage

```bash
python experiments/reasoning/run_simple_agent.py \
  --dataset medec \
  --model gpt-4o-mini
```

### Modal Execution (Recommended)

```bash
modal run experiments/reasoning/run_simple_agent.py \
  --dataset=medec \
  --model=gpt-4o-mini \
  --experiment-name=medec_simple
```

### Parameters

| Parameter | Required | Default | Description |
|-----------|----------|---------|-------------|
| `--dataset` | Conditional | - | Dataset name. **Required** if `--accuracy_function` is not provided. Must be one of: cuad, blackvault, medec, biodex, sustainability, game_reviews. If using `--accuracy_function`, can be any string (used for naming). |
| `--model` | No | `gpt-5` | LLM model to use for optimization |
| `--experiment_name` | No | `simple_agent_{dataset}` | Unique experiment identifier |
| `--output_dir` | No | `outputs/simple_agent` | Output directory |
| `--ground_truth` | No | - | Path to ground truth file for evaluation |
| `--available_models` | No | All 11 default models | Space-separated list of available models. Example: `gpt-5 gpt-5-mini gpt-4o` |
| `--accuracy_function` | No | - | Path to a Python file containing a custom `evaluate_results` function (for user datasets) |
| `--accuracy_metric_key` | No | - | Key to extract from the evaluation results dict for the accuracy metric (required with `--accuracy_function`) |

## Supported Datasets

Both `run_moar.py` and `run_simple_agent.py` support the following datasets:

- `cuad` - Legal clause extraction
- `blackvault` - UFO sighting analysis
- `medec` - Medical error detection
- `biodex` - Adverse drug reaction prediction
- `sustainability` - Sustainability analysis
- `game_reviews` - Game review sentiment analysis

### Custom Datasets with User-Authored Accuracy Functions

For datasets not listed above, you can provide your own accuracy evaluation function using the `--accuracy_function` and `--accuracy_metric_key` parameters.

#### Creating a Custom Accuracy Function

Create a Python file (e.g., `my_evaluate.py`) with an `evaluate_results` function:

```python
# my_evaluate.py
import json


def evaluate_results(method_name, results_file_path):
    """
    Evaluate pipeline results and return metrics.

    Args:
        method_name: Name of the method being evaluated
        results_file_path: Path to the JSON file containing pipeline results

    Returns:
        dict: Dictionary containing evaluation metrics. Must include the metric
              specified by --accuracy_metric_key.
    """
    # Load the pipeline results produced by DocETL
    with open(results_file_path, 'r') as f:
        results = json.load(f)

    # Your evaluation logic here:
    # calculate metrics based on your dataset's requirements.

    metrics = {
        "my_accuracy_metric": 0.95,  # Primary accuracy metric (specify its key with --accuracy_metric_key)
        # Add any other metrics you want to track
    }

    return metrics
```

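As a concrete (purely hypothetical) illustration, an exact-match variant might look like the sketch below. The `prediction` and `label` field names are assumptions about your pipeline's output records, not fields DocETL guarantees; adapt them to your own schema.

```python
# my_evaluate.py - hypothetical exact-match implementation
import json


def evaluate_results(method_name, results_file_path):
    """Score each result record by exact match between prediction and label."""
    with open(results_file_path, "r") as f:
        results = json.load(f)

    # Assumes each record carries `prediction` and `label` string fields.
    correct = sum(
        1
        for r in results
        if str(r.get("prediction", "")).strip().lower()
        == str(r.get("label", "")).strip().lower()
    )
    total = len(results) or 1  # guard against an empty results file

    return {
        "my_accuracy_metric": correct / total,  # select with --accuracy_metric_key my_accuracy_metric
        "num_evaluated": len(results),
    }
```
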
#### Using Custom Accuracy Functions

**MOARSearch:**

```bash
python experiments/reasoning/run_moar.py \
  --yaml_path my_pipeline.yaml \
  --dataset_path my_data.json \
  --experiment_name my_custom_experiment \
  --dataset my_custom_dataset \
  --accuracy_function my_evaluate.py \
  --accuracy_metric_key my_accuracy_metric
```

**Simple Agent:**

```bash
python experiments/reasoning/run_simple_agent.py \
  --dataset my_custom_dataset \
  --accuracy_function my_evaluate.py \
  --accuracy_metric_key my_accuracy_metric
```

**Note:** When using custom accuracy functions, the `--dataset` parameter can be any string (it's used for naming/organization). The actual evaluation logic comes from your custom function.

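Conceptually, the runner loads your module, calls `evaluate_results`, and reads out the single value named by `--accuracy_metric_key`. The sketch below only illustrates that contract; the module-loading details and file paths are assumptions, not the scripts' actual implementation.

```python
import importlib.util

def load_evaluate_results(path):
    # Dynamically import the user-supplied evaluation module.
    spec = importlib.util.spec_from_file_location("user_eval", path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module.evaluate_results

evaluate_results = load_evaluate_results("my_evaluate.py")
metrics = evaluate_results("my_method", "outputs/my_custom_experiment/results.json")  # hypothetical path
accuracy = metrics["my_accuracy_metric"]  # the key passed as --accuracy_metric_key
print(f"accuracy = {accuracy:.3f}")
```
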
## Available Models

Both scripts support the `--available_models` parameter to specify which models should be tested. If it is not provided, the following default model list is used:

- `gpt-5`
- `gpt-5-mini`
- `gpt-5-nano`
- `gpt-4.1`
- `gpt-4.1-mini`
- `gpt-4.1-nano`
- `gpt-4o`
- `gpt-4o-mini`
- `gemini-2.5-pro`
- `gemini-2.5-flash`
- `gemini-2.5-flash-lite`

## Output Files

### MOARSearch Output

Results are saved to `outputs/{experiment_name}/` containing:

- `experiment_summary.json` - Experiment metadata and results
- `pareto_frontier.json` - All Pareto-optimal solutions with accuracy and cost (see the loading sketch after this list)
- `evaluation_metrics.json` - Detailed evaluation results
- `moar_tree_log.txt` - Search tree structure and visit counts
- `cost_vs_{metric}.png` - Plot showing cost vs. performance (dataset-specific)
- Pipeline YAML files for each explored configuration

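After a run completes, the frontier can be inspected programmatically. A minimal sketch, assuming `pareto_frontier.json` holds a list of solution objects with accuracy- and cost-like fields (the experiment name is hypothetical; check the file for its actual schema):

```python
import json
from pathlib import Path

run_dir = Path("outputs/cuad_moar")  # hypothetical experiment name
frontier = json.loads((run_dir / "pareto_frontier.json").read_text())

# The top-level structure and field names below are assumptions; adjust once
# you have inspected the file produced by your own run.
entries = frontier if isinstance(frontier, list) else frontier.get("solutions", [])
for solution in entries:
    summary = {k: solution.get(k) for k in ("accuracy", "cost") if k in solution}
    print(summary)
```
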
### Simple Agent Output

Results are saved to `outputs/simple_agent/{experiment_name}/` containing:

- `final_pipeline.yaml` - Final optimized pipeline configuration (see the sketch after this list)
- `iteration_{n}_output.json` - Pipeline outputs for each iteration
- `evaluation_metrics.json` - Performance evaluation results
- `cost_vs_{metric}.png` - Plot showing cost vs. performance across iterations

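The final pipeline is plain YAML and the per-iteration outputs are plain JSON, so both are easy to inspect. A minimal sketch, assuming PyYAML is installed and using a hypothetical experiment name:

```python
import yaml  # PyYAML
from pathlib import Path

run_dir = Path("outputs/simple_agent/medec_simple")  # hypothetical experiment name

# The final optimized pipeline configuration, ready to re-run with DocETL.
pipeline = yaml.safe_load((run_dir / "final_pipeline.yaml").read_text())
print("top-level keys:", sorted(pipeline))

# One JSON file per optimization iteration.
iteration_files = sorted(run_dir.glob("iteration_*_output.json"))
print(f"{len(iteration_files)} iteration outputs recorded")
```
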