Compare commits

...

3 Commits

Author SHA1 Message Date
Shreya Shankar bcac6872f5
Embedding blocking threshold optimization (#473)
* feat: Add runtime blocking threshold optimization

Co-authored-by: ss.shankar505 <ss.shankar505@gmail.com>

* Checkpoint before follow-up message

Co-authored-by: ss.shankar505 <ss.shankar505@gmail.com>

* Refactor: Simplify target_recall retrieval in Equijoin and Resolve

Co-authored-by: ss.shankar505 <ss.shankar505@gmail.com>

* Refactor: Improve blocking documentation and add auto-blocking

Co-authored-by: ss.shankar505 <ss.shankar505@gmail.com>

* allow resolve and equijoin to figure out blocking thresholds on the fly.

---------

Co-authored-by: Cursor Agent <cursoragent@cursor.com>
2025-12-29 19:47:21 -06:00
Shreya Shankar 57a284bcb1
Fast Decomposition for Map Operations in DocWrangler (#472)
* refactor: docwrangler to use a faster decomposition flow, only if the last operation in a pipeline is a map operation.

* refactor: docwrangler to use a faster decomposition flow, only if the last operation in a pipeline is a map operation.

* refactor: update MOAR documentation
2025-12-29 18:22:02 -06:00
Shreya Shankar cfbb64470a
optimizer: add directives for resolve operator (#470) 2025-12-29 13:57:51 -06:00
69 changed files with 5058 additions and 759 deletions

View File

@ -1,11 +1,15 @@
__version__ = "0.2.6"
import warnings
import litellm
from docetl.runner import DSLRunner
from docetl.optimizer import Optimizer
from docetl.apis.pd_accessors import SemanticAccessor
# Drop unsupported params for models like gpt-5 that don't support temperature=0
litellm.drop_params = True
# TODO: Remove after https://github.com/BerriAI/litellm/issues/7560 is fixed
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._config")
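
A quick illustration (not part of the diff) of what the litellm.drop_params flag enabled above does: keyword arguments a given model does not support are silently dropped instead of raising, so the pipeline can keep passing temperature=0 everywhere. The model name below is only a placeholder.

import litellm

litellm.drop_params = True  # unsupported kwargs are dropped instead of raising

# With the flag on, this call still works even if the target model rejects
# temperature=0; litellm strips the parameter before sending the request.
response = litellm.completion(
    model="gpt-5",  # placeholder model name
    messages=[{"role": "user", "content": "Say hello."}],
    temperature=0,
)
print(response.choices[0].message.content)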

View File

@ -1,4 +1,5 @@
import os
import re
import threading
import time
from io import StringIO
@ -12,6 +13,63 @@ from docetl.utils import StageType, get_stage_description
install(show_locals=False)
# ANSI escape sequences to strip (cursor movement, line clearing, cursor visibility)
ANSI_CURSOR_PATTERNS = re.compile(
r"\x1b\[\?25[lh]" # Hide/show cursor
r"|\x1b\[\d*[ABCD]" # Cursor up/down/forward/back
r"|\x1b\[\d*[JK]" # Clear screen/line (but we handle \x1b[2K specially)
r"|\x1b\[s|\x1b\[u" # Save/restore cursor position
)
def process_carriage_returns(text: str) -> str:
"""
Process terminal control sequences for buffer-captured output.
Rich's spinner uses ANSI sequences for:
- `\r` - carriage return (move to start of line)
- `\x1b[2K` - clear entire line
- `\x1b[?25l/h` - hide/show cursor
- `\x1b[1A` - move cursor up
When captured to a buffer (not a real terminal), we simulate this by:
1. Handling `\r` as "replace from start of line"
2. Stripping cursor movement/visibility sequences
3. Keeping only meaningful content
"""
# First, strip cursor visibility and movement sequences
text = ANSI_CURSOR_PATTERNS.sub("", text)
# Process line by line
lines = text.split("\n")
processed_lines = []
for line in lines:
# Handle carriage returns - keep only content after the last \r
if "\r" in line:
segments = line.split("\r")
# Keep the last non-empty segment (after stripping ANSI clear codes)
for segment in reversed(segments):
# Strip the clear line code if present
cleaned = segment.replace("\x1b[2K", "").strip()
if cleaned:
# Keep the segment but preserve any color codes
processed_lines.append(segment.replace("\x1b[2K", ""))
break
else:
# All segments were empty, skip this line entirely
pass
else:
# No carriage return, keep the line as-is
if line.strip() or line == "": # Keep empty lines between content
processed_lines.append(line)
# Remove trailing empty lines
while processed_lines and not processed_lines[-1].strip():
processed_lines.pop()
return "\n".join(processed_lines)
class ThreadSafeConsole(Console):
def __init__(self, *args, **kwargs):
@ -24,11 +82,18 @@ class ThreadSafeConsole(Console):
self.optimizer_rationale = None
def get_output(self):
# return self.export_text(styles=True)
"""
Get the output from the buffer, processing carriage returns.
Rich's spinner uses carriage returns to overwrite lines in place.
We process these to only keep the latest content, preventing
duplicate spinner frames from flooding the output.
"""
value = self.buffer.getvalue()
self.buffer.truncate(0)
self.buffer.seek(0)
return value
# Process carriage returns to handle spinner overwrites
return process_carriage_returns(value)
def status(
self,
@ -36,10 +101,15 @@ class ThreadSafeConsole(Console):
*,
spinner: str = "dots",
spinner_style: "StyleType" = "status.spinner",
speed: float = 0.1, # Much slower speed
refresh_per_second: float = 0.5, # Much slower refresh rate (every 2 seconds)
speed: float = 1.0,
refresh_per_second: float = 4,
) -> "Status":
"""
Return a Rich Status with animation.
The carriage returns from the spinner animation are processed
in get_output() to prevent duplicate lines.
"""
status_renderable = Status(
status,
console=self,
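
As a rough sketch of the new carriage-return handling (synthetic input, not from the diff): successive spinner frames separated by \r and \x1b[2K collapse to the final frame, and cursor-visibility codes are stripped. The spinner glyphs below are placeholders.

captured = (
    "\x1b[?25l"                  # hide-cursor code (stripped)
    "⠋ Optimizing...\r\x1b[2K"   # first frame, overwritten
    "⠙ Optimizing...\r\x1b[2K"   # second frame, overwritten
    "⠹ Optimizing... done\n"     # final frame
)
print(process_carriage_returns(captured))
# Only the last frame survives: "⠹ Optimizing... done"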

View File

@ -465,7 +465,8 @@ class OpContainer:
return cached_data, 0, curr_logs
# Try to load from checkpoint if available
if not is_build:
# Skip if this operation has bypass_cache: true
if not is_build and not self.config.get("bypass_cache", False):
attempted_input_data = self.runner._load_from_checkpoint_if_exists(
self.name.split("/")[0], self.name.split("/")[-1]
)

View File

@ -849,13 +849,9 @@ class MOARSearch:
response = litellm.completion(
model=self.model,
messages=messages,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=ExpandResponseFormat,
)
call_cost = response._hidden_params["response_cost"]
call_cost = response._hidden_params.get("response_cost", 0)
with self.tree_lock:
self.console.log(
f"[green]💰 Adding LLM call cost:[/green] ${call_cost:.4f} (total before: ${self.total_search_cost:.4f})"

View File

@ -153,10 +153,11 @@ def run_moar_optimization(
"optimizer_config must contain 'max_iterations' (number of MOARSearch iterations) for MOAR optimizer"
)
model = optimizer_config.get("model")
# Use rewrite_agent_model for consistency with rest of codebase, fallback to model for backwards compatibility
model = optimizer_config.get("rewrite_agent_model") or optimizer_config.get("model")
if not model:
raise ValueError(
"optimizer_config must contain 'model' (LLM model name for directive instantiation) for MOAR optimizer"
"optimizer_config must contain 'rewrite_agent_model' (LLM model name for directive instantiation) for MOAR optimizer"
)
# Optional parameters
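
A hedged sketch of the optimizer_config shape this code now expects: max_iterations is required, rewrite_agent_model is the preferred key, and model is only read as a backwards-compatible fallback. The model name is a placeholder.

optimizer_config = {
    "max_iterations": 10,              # required by the MOAR optimizer
    "rewrite_agent_model": "gpt-5.1",  # preferred key for the rewrite agent
    # "model": "gpt-5.1",              # legacy key; used only if the above is absent
}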

View File

@ -17,6 +17,7 @@ from rich.prompt import Confirm
from docetl.operations.base import BaseOperation
from docetl.operations.utils import strict_render
from docetl.operations.utils.blocking import RuntimeBlockingOptimizer
from docetl.operations.utils.progress import RichLoopBar
from docetl.utils import (
completion_cost,
@ -63,6 +64,7 @@ class EquijoinOperation(BaseOperation):
comparison_prompt: str
output: dict[str, Any] | None = None
blocking_threshold: float | None = None
blocking_target_recall: float | None = None
blocking_conditions: list[str] | None = None
limits: dict[str, int] | None = None
comparison_model: str | None = None
@ -250,6 +252,58 @@ class EquijoinOperation(BaseOperation):
if self.status:
self.status.stop()
# Track pre-computed embeddings from auto-optimization
precomputed_left_embeddings = None
precomputed_right_embeddings = None
# Auto-compute blocking threshold if no blocking configuration is provided
if not blocking_threshold and not blocking_conditions and not limit_comparisons:
# Get target recall from operation config (default 0.95)
target_recall = self.config.get("blocking_target_recall", 0.95)
self.console.log(
f"[yellow]No blocking configuration. Auto-computing threshold (target recall: {target_recall:.0%})...[/yellow]"
)
# Create comparison function for threshold optimization
def compare_fn_for_optimization(left_item, right_item):
return self.compare_pair(
self.config["comparison_prompt"],
self.config.get("comparison_model", self.default_model),
left_item,
right_item,
timeout_seconds=self.config.get("timeout", 120),
max_retries_per_timeout=self.config.get(
"max_retries_per_timeout", 2
),
)
# Run threshold optimization
optimizer = RuntimeBlockingOptimizer(
runner=self.runner,
config=self.config,
default_model=self.default_model,
max_threads=self.max_threads,
console=self.console,
target_recall=target_recall,
sample_size=min(100, len(left_data) * len(right_data) // 4),
)
(
blocking_threshold,
precomputed_left_embeddings,
precomputed_right_embeddings,
optimization_cost,
) = optimizer.optimize_equijoin(
left_data,
right_data,
compare_fn_for_optimization,
left_keys=left_keys,
right_keys=right_keys,
)
total_cost += optimization_cost
self.console.log(
f"[green]Using auto-computed blocking threshold: {blocking_threshold}[/green]"
)
# Initial blocking using multiprocessing
num_processes = min(cpu_count(), len(left_data))
@ -298,45 +352,60 @@ class EquijoinOperation(BaseOperation):
)
if blocking_threshold is not None:
embedding_model = self.config.get("embedding_model", self.default_model)
model_input_context_length = model_cost.get(embedding_model, {}).get(
"max_input_tokens", 8192
)
def get_embeddings(
input_data: list[dict[str, Any]], keys: list[str], name: str
) -> tuple[list[list[float]], float]:
texts = [
" ".join(str(item[key]) for key in keys if key in item)[
: model_input_context_length * 4
]
for item in input_data
]
embeddings = []
total_cost = 0
# Use precomputed embeddings if available from auto-optimization
if (
precomputed_left_embeddings is not None
and precomputed_right_embeddings is not None
):
left_embeddings = precomputed_left_embeddings
right_embeddings = precomputed_right_embeddings
else:
embedding_model = self.config.get("embedding_model", self.default_model)
model_input_context_length = model_cost.get(embedding_model, {}).get(
"max_input_tokens", 8192
)
batch_size = 2000
for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
self.console.log(
f"On iteration {i} for creating embeddings for {name} data"
)
response = self.runner.api.gen_embedding(
model=embedding_model,
input=batch,
)
embeddings.extend([data["embedding"] for data in response["data"]])
total_cost += completion_cost(response)
return embeddings, total_cost
left_embeddings, left_cost = get_embeddings(left_data, left_keys, "left")
right_embeddings, right_cost = get_embeddings(
right_data, right_keys, "right"
)
total_cost += left_cost + right_cost
self.console.log(
f"Created embeddings for datasets. Total embedding creation cost: {total_cost}"
)
def get_embeddings(
input_data: list[dict[str, Any]], keys: list[str], name: str
) -> tuple[list[list[float]], float]:
texts = [
" ".join(str(item[key]) for key in keys if key in item)[
: model_input_context_length * 4
]
for item in input_data
]
embeddings = []
embedding_cost = 0
num_batches = (len(texts) + batch_size - 1) // batch_size
for batch_idx, i in enumerate(range(0, len(texts), batch_size)):
batch = texts[i : i + batch_size]
if num_batches > 1:
self.console.log(
f"[dim]Creating {name} embeddings: batch {batch_idx + 1}/{num_batches} "
f"({min(i + batch_size, len(texts))}/{len(texts)} items)[/dim]"
)
response = self.runner.api.gen_embedding(
model=embedding_model,
input=batch,
)
embeddings.extend(
[data["embedding"] for data in response["data"]]
)
embedding_cost += completion_cost(response)
return embeddings, embedding_cost
self.console.log(
f"[cyan]Creating embeddings for {len(left_data)} left + {len(right_data)} right items...[/cyan]"
)
left_embeddings, left_cost = get_embeddings(
left_data, left_keys, "left"
)
right_embeddings, right_cost = get_embeddings(
right_data, right_keys, "right"
)
total_cost += left_cost + right_cost
# Compute all cosine similarities in one call
from sklearn.metrics.pairwise import cosine_similarity
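
For context, a hypothetical equijoin configuration that exercises the new path above: with no blocking_threshold, blocking_conditions, or limit_comparisons set, the operation now derives a threshold at runtime, aiming for blocking_target_recall (0.95 if omitted).

equijoin_config = {
    "name": "match_transactions",    # placeholder name
    "type": "equijoin",
    "comparison_prompt": "...",      # pairwise comparison prompt (elided)
    "blocking_target_recall": 0.90,  # optional; defaults to 0.95
    # no blocking_threshold / blocking_conditions / limit_comparisons
    # -> RuntimeBlockingOptimizer picks the threshold automatically
}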

View File

@ -10,10 +10,10 @@ import jinja2
from jinja2 import Template
from litellm import model_cost
from pydantic import Field, ValidationInfo, field_validator, model_validator
from rich.prompt import Confirm
from docetl.operations.base import BaseOperation
from docetl.operations.utils import RichLoopBar, rich_as_completed, strict_render
from docetl.operations.utils.blocking import RuntimeBlockingOptimizer
from docetl.utils import (
completion_cost,
extract_jinja_variables,
@ -40,6 +40,7 @@ class ResolveOperation(BaseOperation):
comparison_model: str | None = None
blocking_keys: list[str] | None = None
blocking_threshold: float | None = Field(None, ge=0, le=1)
blocking_target_recall: float | None = Field(None, ge=0, le=1)
blocking_conditions: list[str] | None = None
input: dict[str, Any] | None = None
embedding_batch_size: int | None = Field(None, gt=0)
@ -266,26 +267,75 @@ class ResolveOperation(BaseOperation):
blocking_keys = self.config.get("blocking_keys", [])
blocking_threshold = self.config.get("blocking_threshold")
blocking_conditions = self.config.get("blocking_conditions", [])
limit_comparisons = self.config.get("limit_comparisons")
total_cost = 0
if self.status:
self.status.stop()
if not blocking_threshold and not blocking_conditions:
# Prompt the user for confirmation
if not Confirm.ask(
"[yellow]Warning: No blocking keys or conditions specified. "
"This may result in a large number of comparisons. "
"We recommend specifying at least one blocking key or condition, or using the optimizer to automatically come up with these. "
"Do you want to continue without blocking?[/yellow]",
console=self.runner.console,
):
raise ValueError("Operation cancelled by user.")
# Track pre-computed embeddings from auto-optimization
precomputed_embeddings = None
# Auto-compute blocking threshold if no blocking configuration is provided
if not blocking_threshold and not blocking_conditions and not limit_comparisons:
# Get target recall from operation config (default 0.95)
target_recall = self.config.get("blocking_target_recall", 0.95)
self.console.log(
f"[yellow]No blocking configuration. Auto-computing threshold (target recall: {target_recall:.0%})...[/yellow]"
)
# Determine blocking keys if not set
auto_blocking_keys = blocking_keys if blocking_keys else None
if not auto_blocking_keys:
prompt_template = self.config.get("comparison_prompt", "")
prompt_vars = extract_jinja_variables(prompt_template)
prompt_vars = [
var
for var in prompt_vars
if var not in ["input", "input1", "input2"]
]
auto_blocking_keys = list(
set([var.split(".")[-1] for var in prompt_vars])
)
if not auto_blocking_keys:
auto_blocking_keys = list(input_data[0].keys())
blocking_keys = auto_blocking_keys
# Create comparison function for threshold optimization
def compare_fn_for_optimization(item1, item2):
return self.compare_pair(
self.config["comparison_prompt"],
self.config.get("comparison_model", self.default_model),
item1,
item2,
blocking_keys=[], # Don't use key-based shortcut during optimization
timeout_seconds=self.config.get("timeout", 120),
max_retries_per_timeout=self.config.get(
"max_retries_per_timeout", 2
),
)
# Run threshold optimization
optimizer = RuntimeBlockingOptimizer(
runner=self.runner,
config=self.config,
default_model=self.default_model,
max_threads=self.max_threads,
console=self.console,
target_recall=target_recall,
sample_size=min(100, len(input_data) * (len(input_data) - 1) // 4),
)
blocking_threshold, precomputed_embeddings, optimization_cost = (
optimizer.optimize_resolve(
input_data,
compare_fn_for_optimization,
blocking_keys=blocking_keys,
)
)
total_cost += optimization_cost
input_schema = self.config.get("input", {}).get("schema", {})
if not blocking_keys:
# Set them to all keys in the input data
blocking_keys = list(input_data[0].keys())
limit_comparisons = self.config.get("limit_comparisons")
total_cost = 0
def is_match(item1: dict[str, Any], item2: dict[str, Any]) -> bool:
return any(
@ -296,120 +346,101 @@ class ResolveOperation(BaseOperation):
# Calculate embeddings if blocking_threshold is set
embeddings = None
if blocking_threshold is not None:
def get_embeddings_batch(
items: list[dict[str, Any]]
) -> list[tuple[list[float], float]]:
# Use precomputed embeddings if available from auto-optimization
if precomputed_embeddings is not None:
embeddings = precomputed_embeddings
else:
self.console.log(
f"[cyan]Creating embeddings for {len(input_data)} items...[/cyan]"
)
embedding_model = self.config.get(
"embedding_model", "text-embedding-3-small"
)
model_input_context_length = model_cost.get(embedding_model, {}).get(
"max_input_tokens", 8192
)
batch_size = self.config.get("embedding_batch_size", 1000)
embeddings = []
embedding_cost = 0.0
num_batches = (len(input_data) + batch_size - 1) // batch_size
texts = [
" ".join(str(item[key]) for key in blocking_keys if key in item)[
: model_input_context_length * 3
for batch_idx in range(num_batches):
start_idx = batch_idx * batch_size
end_idx = min(start_idx + batch_size, len(input_data))
batch = input_data[start_idx:end_idx]
if num_batches > 1:
self.console.log(
f"[dim]Creating embeddings: batch {batch_idx + 1}/{num_batches} "
f"({end_idx}/{len(input_data)} items)[/dim]"
)
texts = [
" ".join(
str(item[key]) for key in blocking_keys if key in item
)[: model_input_context_length * 3]
for item in batch
]
for item in items
]
response = self.runner.api.gen_embedding(
model=embedding_model, input=texts
)
embeddings.extend([data["embedding"] for data in response["data"]])
embedding_cost += completion_cost(response)
response = self.runner.api.gen_embedding(
model=embedding_model, input=texts
)
return [
(data["embedding"], completion_cost(response))
for data in response["data"]
]
total_cost += embedding_cost
embeddings = []
costs = []
with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
for i in range(
0, len(input_data), self.config.get("embedding_batch_size", 1000)
):
batch = input_data[
i : i + self.config.get("embedding_batch_size", 1000)
]
batch_results = list(executor.map(get_embeddings_batch, [batch]))
# Build a mapping of blocking key values to indices
# This is used later for cluster merging (when two items match, merge all items sharing their key values)
value_to_indices: dict[tuple[str, ...], list[int]] = {}
for i, item in enumerate(input_data):
key = tuple(str(item.get(k, "")) for k in blocking_keys)
if key not in value_to_indices:
value_to_indices[key] = []
value_to_indices[key].append(i)
for result in batch_results:
embeddings.extend([r[0] for r in result])
costs.extend([r[1] for r in result])
# Total number of pairs to potentially compare
n = len(input_data)
total_pairs = n * (n - 1) // 2
total_cost += sum(costs)
# Generate all pairs to compare, ensuring no duplicate comparisons
def get_unique_comparison_pairs() -> (
tuple[list[tuple[int, int]], dict[tuple[str, ...], list[int]]]
):
# Create a mapping of values to their indices
value_to_indices: dict[tuple[str, ...], list[int]] = {}
for i, item in enumerate(input_data):
# Create a hashable key from the blocking keys
key = tuple(str(item.get(k, "")) for k in blocking_keys)
if key not in value_to_indices:
value_to_indices[key] = []
value_to_indices[key].append(i)
# Generate pairs for comparison, comparing each unique value combination only once
comparison_pairs = []
keys = list(value_to_indices.keys())
# First, handle comparisons between different values
for i in range(len(keys)):
for j in range(i + 1, len(keys)):
# Only need one comparison between different values
idx1 = value_to_indices[keys[i]][0]
idx2 = value_to_indices[keys[j]][0]
if idx1 < idx2: # Maintain ordering to avoid duplicates
comparison_pairs.append((idx1, idx2))
return comparison_pairs, value_to_indices
comparison_pairs, value_to_indices = get_unique_comparison_pairs()
# Filter pairs based on blocking conditions
def meets_blocking_conditions(pair: tuple[int, int]) -> bool:
i, j = pair
return (
is_match(input_data[i], input_data[j]) if blocking_conditions else False
)
# Start with pairs that meet blocking conditions, or empty list if no conditions
code_blocked_pairs = (
list(filter(meets_blocking_conditions, comparison_pairs))
if blocking_conditions
else []
)
# Apply code-based blocking conditions (check all pairs)
code_blocked_pairs = []
if blocking_conditions:
for i in range(n):
for j in range(i + 1, n):
if is_match(input_data[i], input_data[j]):
code_blocked_pairs.append((i, j))
# Apply cosine similarity blocking if threshold is specified
embedding_blocked_pairs = []
if blocking_threshold is not None and embeddings is not None:
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
similarity_matrix = cosine_similarity(embeddings)
# Add pairs that meet the cosine similarity threshold and aren't already blocked
code_blocked_set = set(code_blocked_pairs)
for i, j in comparison_pairs:
if (i, j) not in code_blocked_set:
similarity = similarity_matrix[i, j]
if similarity >= blocking_threshold:
embedding_blocked_pairs.append((i, j))
# Use numpy to efficiently find all pairs above threshold
i_indices, j_indices = np.triu_indices(n, k=1)
similarities = similarity_matrix[i_indices, j_indices]
above_threshold_mask = similarities >= blocking_threshold
self.console.log(
f"Cosine similarity blocking: added {len(embedding_blocked_pairs)} pairs "
f"(threshold: {blocking_threshold})"
)
# Get pairs above threshold
above_threshold_i = i_indices[above_threshold_mask]
above_threshold_j = j_indices[above_threshold_mask]
# Combine pairs with prioritization for sampling
# Filter out pairs already in code_blocked_set
embedding_blocked_pairs = [
(int(i), int(j))
for i, j in zip(above_threshold_i, above_threshold_j)
if (i, j) not in code_blocked_set
]
# Combine pairs from both blocking methods
all_blocked_pairs = code_blocked_pairs + embedding_blocked_pairs
# If no pairs are blocked at all, fall back to all comparison pairs
if not all_blocked_pairs:
all_blocked_pairs = comparison_pairs
# If no blocking was applied, compare all pairs
if not blocking_conditions and blocking_threshold is None:
all_blocked_pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]
# Apply limit_comparisons with prioritization
if limit_comparisons is not None and len(all_blocked_pairs) > limit_comparisons:
# Prioritize code-based pairs, then sample from embedding pairs if needed
@ -476,18 +507,6 @@ class ResolveOperation(BaseOperation):
cluster_map[root_idx] = root1
clusters[root_idx] = set()
# Calculate and print statistics
total_possible_comparisons = len(input_data) * (len(input_data) - 1) // 2
comparisons_made = len(blocked_pairs)
comparisons_saved = total_possible_comparisons - comparisons_made
self.console.log(
f"[green]Comparisons saved by deduping and blocking: {comparisons_saved} "
f"({(comparisons_saved / total_possible_comparisons) * 100:.2f}%)[/green]"
)
self.console.log(
f"[blue]Number of pairs to compare: {len(blocked_pairs)}[/blue]"
)
# Compute an auto-batch size based on the number of comparisons
def auto_batch() -> int:
# Maximum batch size limit for 4o-mini model
@ -513,7 +532,14 @@ class ResolveOperation(BaseOperation):
# Compare pairs and update clusters in real-time
batch_size = self.config.get("compare_batch_size", auto_batch())
self.console.log(f"Using compare batch size: {batch_size}")
# Log blocking summary
total_possible_comparisons = len(input_data) * (len(input_data) - 1) // 2
self.console.log(
f"Comparing {len(blocked_pairs):,} pairs "
f"({len(blocked_pairs)/total_possible_comparisons*100:.1f}% of {total_possible_comparisons:,} total, "
f"batch size: {batch_size})"
)
pair_costs = 0
pbar = RichLoopBar(
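
A small self-contained sketch (dummy data, arbitrary threshold) of the vectorized pair selection used above: np.triu_indices enumerates each unordered pair once, and a boolean mask keeps only pairs whose cosine similarity clears the blocking threshold.

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

embeddings = np.random.rand(6, 8)        # 6 items, 8-dim embeddings (dummy)
similarity_matrix = cosine_similarity(embeddings)

i_indices, j_indices = np.triu_indices(6, k=1)          # all i < j pairs
similarities = similarity_matrix[i_indices, j_indices]
mask = similarities >= 0.8                               # arbitrary threshold

blocked_pairs = list(zip(i_indices[mask].tolist(), j_indices[mask].tolist()))
print(len(blocked_pairs), "pairs survive blocking")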

View File

@ -1,4 +1,5 @@
from .api import APIWrapper
from .blocking import RuntimeBlockingOptimizer
from .cache import (
cache,
cache_key,
@ -15,6 +16,7 @@ from .validation import safe_eval, convert_val, convert_dict_schema_to_list_sche
__all__ = [
'APIWrapper',
'RuntimeBlockingOptimizer',
'cache',
'cache_key',
'clear_cache',
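
With the re-export above in place, the optimizer can be imported from the utils package directly, e.g.:

from docetl.operations.utils import RuntimeBlockingOptimizer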

View File

@ -0,0 +1,567 @@
"""
Runtime blocking threshold optimization utilities.
This module provides functionality for automatically computing embedding-based
blocking thresholds at runtime when no blocking configuration is provided.
"""
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Callable
import numpy as np
from litellm import model_cost
from rich.console import Console
from docetl.utils import completion_cost, extract_jinja_variables
class RuntimeBlockingOptimizer:
"""
Computes optimal embedding-based blocking thresholds at runtime.
This class samples pairs from the dataset, performs LLM comparisons,
and finds the optimal cosine similarity threshold that achieves a
target recall rate.
"""
def __init__(
self,
runner,
config: dict[str, Any],
default_model: str,
max_threads: int,
console: Console,
target_recall: float = 0.95,
sample_size: int = 100,
sampling_weight: float = 20.0,
):
"""
Initialize the RuntimeBlockingOptimizer.
Args:
runner: The pipeline runner instance.
config: Operation configuration.
default_model: Default LLM model for comparisons.
max_threads: Maximum threads for parallel processing.
console: Rich console for logging.
target_recall: Target recall rate (default 0.95).
sample_size: Number of pairs to sample for threshold estimation.
sampling_weight: Weight for exponential sampling towards higher similarities.
"""
self.runner = runner
self.config = config
self.default_model = default_model
self.max_threads = max_threads
self.console = console
self.target_recall = target_recall
self.sample_size = sample_size
self.sampling_weight = sampling_weight
def compute_embeddings(
self,
input_data: list[dict[str, Any]],
keys: list[str],
embedding_model: str | None = None,
batch_size: int = 1000,
) -> tuple[list[list[float]], float]:
"""
Compute embeddings for the input data.
Args:
input_data: List of input documents.
keys: Keys to use for embedding text.
embedding_model: Model to use for embeddings.
batch_size: Batch size for embedding computation.
Returns:
Tuple of (embeddings list, total cost).
"""
embedding_model = embedding_model or self.config.get(
"embedding_model", "text-embedding-3-small"
)
model_input_context_length = model_cost.get(embedding_model, {}).get(
"max_input_tokens", 8192
)
texts = [
" ".join(str(item[key]) for key in keys if key in item)[
: model_input_context_length * 3
]
for item in input_data
]
self.console.log(f"[cyan]Creating embeddings for {len(texts)} items...[/cyan]")
embeddings = []
total_cost = 0.0
num_batches = (len(texts) + batch_size - 1) // batch_size
for batch_idx, i in enumerate(range(0, len(texts), batch_size)):
batch = texts[i : i + batch_size]
if num_batches > 1:
self.console.log(
f"[dim] Batch {batch_idx + 1}/{num_batches} "
f"({len(embeddings) + len(batch)}/{len(texts)} items)[/dim]"
)
response = self.runner.api.gen_embedding(
model=embedding_model,
input=batch,
)
embeddings.extend([data["embedding"] for data in response["data"]])
total_cost += completion_cost(response)
return embeddings, total_cost
def calculate_cosine_similarities_self(
self, embeddings: list[list[float]]
) -> list[tuple[int, int, float]]:
"""
Calculate pairwise cosine similarities for self-join.
Args:
embeddings: List of embedding vectors.
Returns:
List of (i, j, similarity) tuples for all pairs where i < j.
"""
embeddings_array = np.array(embeddings)
norms = np.linalg.norm(embeddings_array, axis=1)
# Avoid division by zero
norms = np.where(norms == 0, 1e-10, norms)
dot_products = np.dot(embeddings_array, embeddings_array.T)
similarities_matrix = dot_products / np.outer(norms, norms)
i, j = np.triu_indices(len(embeddings), k=1)
similarities = list(
zip(i.tolist(), j.tolist(), similarities_matrix[i, j].tolist())
)
return similarities
def calculate_cosine_similarities_cross(
self,
left_embeddings: list[list[float]],
right_embeddings: list[list[float]],
) -> list[tuple[int, int, float]]:
"""
Calculate cosine similarities between two sets of embeddings.
Args:
left_embeddings: Embeddings for left dataset.
right_embeddings: Embeddings for right dataset.
Returns:
List of (left_idx, right_idx, similarity) tuples.
"""
left_array = np.array(left_embeddings)
right_array = np.array(right_embeddings)
dot_product = np.dot(left_array, right_array.T)
norm_left = np.linalg.norm(left_array, axis=1)
norm_right = np.linalg.norm(right_array, axis=1)
# Avoid division by zero
norm_left = np.where(norm_left == 0, 1e-10, norm_left)
norm_right = np.where(norm_right == 0, 1e-10, norm_right)
similarities = dot_product / np.outer(norm_left, norm_right)
return [
(i, j, float(sim))
for i, row in enumerate(similarities)
for j, sim in enumerate(row)
]
def sample_pairs(
self,
similarities: list[tuple[int, int, float]],
num_bins: int = 10,
stratified_fraction: float = 0.5,
) -> list[tuple[int, int]]:
"""
Sample pairs using a hybrid of stratified and exponential-weighted sampling.
This ensures coverage across the similarity distribution while still
focusing on high-similarity pairs where matches are more likely.
Args:
similarities: List of (i, j, similarity) tuples.
num_bins: Number of bins for stratified sampling.
stratified_fraction: Fraction of samples to allocate to stratified sampling.
Returns:
List of sampled (i, j) pairs.
"""
if len(similarities) == 0:
return []
sample_count = min(self.sample_size, len(similarities))
stratified_count = int(sample_count * stratified_fraction)
exponential_count = sample_count - stratified_count
sampled_indices = set()
sim_values = np.array([sim[2] for sim in similarities])
# Part 1: Stratified sampling across bins
if stratified_count > 0:
bin_edges = np.linspace(
sim_values.min(), sim_values.max() + 1e-9, num_bins + 1
)
samples_per_bin = max(1, stratified_count // num_bins)
for bin_idx in range(num_bins):
bin_mask = (sim_values >= bin_edges[bin_idx]) & (
sim_values < bin_edges[bin_idx + 1]
)
bin_indices = np.where(bin_mask)[0]
if len(bin_indices) > 0:
# Within each bin, use exponential weighting
bin_sims = sim_values[bin_indices]
bin_weights = np.exp(self.sampling_weight * bin_sims)
bin_weights /= bin_weights.sum()
n_to_sample = min(samples_per_bin, len(bin_indices))
chosen = np.random.choice(
bin_indices,
size=n_to_sample,
replace=False,
p=bin_weights,
)
sampled_indices.update(chosen.tolist())
# Part 2: Exponential-weighted sampling for remaining slots
if exponential_count > 0:
remaining_indices = [
i for i in range(len(similarities)) if i not in sampled_indices
]
if remaining_indices:
remaining_sims = sim_values[remaining_indices]
weights = np.exp(self.sampling_weight * remaining_sims)
weights /= weights.sum()
n_to_sample = min(exponential_count, len(remaining_indices))
chosen = np.random.choice(
remaining_indices,
size=n_to_sample,
replace=False,
p=weights,
)
sampled_indices.update(chosen.tolist())
sampled_pairs = [
(similarities[i][0], similarities[i][1]) for i in sampled_indices
]
return sampled_pairs
def _print_similarity_histogram(
self,
similarities: list[tuple[int, int, float]],
comparison_results: list[tuple[int, int, bool]],
threshold: float | None = None,
):
"""
Print a histogram of embedding cosine similarity distribution.
Args:
similarities: List of (i, j, similarity) tuples.
comparison_results: List of (i, j, is_match) from LLM comparisons.
threshold: Optional threshold to highlight in the histogram.
"""
# Filter out self-similarities (similarity == 1)
flat_similarities = [sim[2] for sim in similarities if sim[2] != 1]
if not flat_similarities:
return
hist, bin_edges = np.histogram(flat_similarities, bins=20)
max_bar_width, max_count = 40, max(hist) if max(hist) > 0 else 1
normalized_hist = [int(count / max_count * max_bar_width) for count in hist]
# Create a dictionary to store true labels
true_labels = {(i, j): is_match for i, j, is_match in comparison_results}
# Count pairs above threshold
pairs_above_threshold = (
sum(1 for sim in flat_similarities if sim >= threshold) if threshold else 0
)
total_pairs = len(flat_similarities)
lines = []
for i, count in enumerate(normalized_hist):
bar = "" * count
bin_start, bin_end = bin_edges[i], bin_edges[i + 1]
label = f"{bin_start:.2f}-{bin_end:.2f}"
# Count true matches and not matches in this bin
true_matches = 0
not_matches = 0
labeled_count = 0
for sim in similarities:
if bin_start <= sim[2] < bin_end:
if (sim[0], sim[1]) in true_labels:
labeled_count += 1
if true_labels[(sim[0], sim[1])]:
true_matches += 1
else:
not_matches += 1
# Calculate percentages of labeled pairs
if labeled_count > 0:
true_match_percent = (true_matches / labeled_count) * 100
label_info = f"[green]{true_match_percent:5.1f}%[/green] match"
else:
label_info = "[dim]--[/dim]"
# Highlight the bin containing the threshold
if threshold is not None and bin_start <= threshold < bin_end:
lines.append(
f"[bold yellow]{label}[/bold yellow] {bar:<{max_bar_width}} "
f"[dim]n={hist[i]:>5}[/dim] {label_info} [bold yellow]◀ threshold[/bold yellow]"
)
else:
lines.append(
f"{label} {bar:<{max_bar_width}} "
f"[dim]n={hist[i]:>5}[/dim] {label_info}"
)
from rich.panel import Panel
histogram_content = "\n".join(lines)
title = f"Similarity Distribution ({pairs_above_threshold:,} of {total_pairs:,} pairs ≥ {threshold:.4f})"
self.console.log(Panel(histogram_content, title=title, border_style="cyan"))
def find_optimal_threshold(
self,
comparisons: list[tuple[int, int, bool]],
similarities: list[tuple[int, int, float]],
) -> tuple[float, float]:
"""
Find the optimal similarity threshold that achieves target recall.
Args:
comparisons: List of (i, j, is_match) from LLM comparisons.
similarities: List of (i, j, similarity) tuples.
Returns:
Tuple of (optimal_threshold, achieved_recall).
"""
if not comparisons or not any(comp[2] for comp in comparisons):
# No matches found, use a high threshold to be conservative
self.console.log(
"[yellow]No matches found in sample. Using 99th percentile "
"similarity as threshold.[/yellow]"
)
all_sims = [sim[2] for sim in similarities]
threshold = float(np.percentile(all_sims, 99)) if all_sims else 0.9
return threshold, 0.0
true_labels = np.array([comp[2] for comp in comparisons])
sim_dict = {(i, j): sim for i, j, sim in similarities}
sim_scores = np.array([sim_dict.get((i, j), 0.0) for i, j, _ in comparisons])
thresholds = np.linspace(0, 1, 100)
recalls = []
for threshold in thresholds:
predictions = sim_scores >= threshold
tp = np.sum(predictions & true_labels)
fn = np.sum(~predictions & true_labels)
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
recalls.append(recall)
# Find highest threshold that achieves target recall
valid_indices = [i for i, r in enumerate(recalls) if r >= self.target_recall]
if not valid_indices:
# If no threshold achieves target recall, use the one with highest recall
best_idx = int(np.argmax(recalls))
optimal_threshold = float(thresholds[best_idx])
achieved_recall = float(recalls[best_idx])
self.console.log(
f"[yellow]Warning: Could not achieve target recall {self.target_recall:.0%}. "
f"Using threshold {optimal_threshold:.4f} with recall {achieved_recall:.2%}.[/yellow]"
)
else:
best_idx = max(valid_indices)
optimal_threshold = float(thresholds[best_idx])
achieved_recall = float(recalls[best_idx])
return round(optimal_threshold, 4), achieved_recall
def optimize_resolve(
self,
input_data: list[dict[str, Any]],
compare_fn: Callable[[dict, dict], tuple[bool, float, str]],
blocking_keys: list[str] | None = None,
) -> tuple[float, list[list[float]], float]:
"""
Compute optimal blocking threshold for resolve operation.
Args:
input_data: Input dataset.
compare_fn: Function to compare two items, returns (is_match, cost, prompt).
blocking_keys: Keys to use for blocking. If None, extracted from prompt.
Returns:
Tuple of (threshold, embeddings, total_cost).
"""
from rich.panel import Panel
# Determine blocking keys
if not blocking_keys:
prompt_template = self.config.get("comparison_prompt", "")
prompt_vars = extract_jinja_variables(prompt_template)
prompt_vars = [
var for var in prompt_vars if var not in ["input", "input1", "input2"]
]
blocking_keys = list(set([var.split(".")[-1] for var in prompt_vars]))
if not blocking_keys:
blocking_keys = list(input_data[0].keys())
# Compute embeddings
embeddings, embedding_cost = self.compute_embeddings(input_data, blocking_keys)
# Calculate similarities
similarities = self.calculate_cosine_similarities_self(embeddings)
# Sample pairs
sampled_pairs = self.sample_pairs(similarities)
if not sampled_pairs:
self.console.log(
"[yellow]No pairs to sample. Using default threshold 0.8.[/yellow]"
)
return 0.8, embeddings, embedding_cost
# Perform comparisons
comparisons = []
comparison_cost = 0.0
matches_found = 0
with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
futures = {
executor.submit(compare_fn, input_data[i], input_data[j]): (i, j)
for i, j in sampled_pairs
}
for future in as_completed(futures):
i, j = futures[future]
try:
is_match, cost, _ = future.result()
comparisons.append((i, j, is_match))
comparison_cost += cost
if is_match:
matches_found += 1
except Exception as e:
self.console.log(f"[red]Comparison error: {e}[/red]")
comparisons.append((i, j, False))
# Find optimal threshold
threshold, achieved_recall = self.find_optimal_threshold(
comparisons, similarities
)
total_cost = embedding_cost + comparison_cost
# Print histogram visualization
self._print_similarity_histogram(similarities, comparisons, threshold)
# Print summary
n = len(input_data)
total_pairs = n * (n - 1) // 2
pairs_above = sum(1 for s in similarities if s[2] >= threshold)
summary = (
f"[bold]Blocking keys:[/bold] {blocking_keys}\n"
f"[bold]Sampled:[/bold] {len(sampled_pairs)} pairs → {matches_found} matches ({matches_found/len(sampled_pairs)*100:.1f}%)\n"
f"[bold]Threshold:[/bold] {threshold:.4f}{achieved_recall:.1%} recall (target: {self.target_recall:.0%})\n"
f"[bold]Pairs to compare:[/bold] {pairs_above:,} of {total_pairs:,} ({pairs_above/total_pairs*100:.1f}%)\n"
f"[bold]Optimization cost:[/bold] ${total_cost:.4f}"
)
self.console.log(
Panel(
summary, title="Blocking Threshold Optimization", border_style="green"
)
)
return threshold, embeddings, total_cost
def optimize_equijoin(
self,
left_data: list[dict[str, Any]],
right_data: list[dict[str, Any]],
compare_fn: Callable[[dict, dict], tuple[bool, float]],
left_keys: list[str] | None = None,
right_keys: list[str] | None = None,
) -> tuple[float, list[list[float]], list[list[float]], float]:
"""
Compute optimal blocking threshold for equijoin operation.
Args:
left_data: Left dataset.
right_data: Right dataset.
compare_fn: Function to compare two items, returns (is_match, cost).
left_keys: Keys to use for left dataset embeddings.
right_keys: Keys to use for right dataset embeddings.
Returns:
Tuple of (threshold, left_embeddings, right_embeddings, total_cost).
"""
from rich.panel import Panel
# Determine keys
if not left_keys:
left_keys = list(left_data[0].keys()) if left_data else []
if not right_keys:
right_keys = list(right_data[0].keys()) if right_data else []
# Compute embeddings
left_embeddings, left_cost = self.compute_embeddings(left_data, left_keys)
right_embeddings, right_cost = self.compute_embeddings(right_data, right_keys)
embedding_cost = left_cost + right_cost
# Calculate cross similarities
similarities = self.calculate_cosine_similarities_cross(
left_embeddings, right_embeddings
)
# Sample pairs
sampled_pairs = self.sample_pairs(similarities)
if not sampled_pairs:
self.console.log(
"[yellow]No pairs to sample. Using default threshold 0.8.[/yellow]"
)
return 0.8, left_embeddings, right_embeddings, embedding_cost
# Perform comparisons
comparisons = []
comparison_cost = 0.0
matches_found = 0
with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
futures = {
executor.submit(compare_fn, left_data[i], right_data[j]): (i, j)
for i, j in sampled_pairs
}
for future in as_completed(futures):
i, j = futures[future]
try:
is_match, cost = future.result()
comparisons.append((i, j, is_match))
comparison_cost += cost
if is_match:
matches_found += 1
except Exception as e:
self.console.log(f"[red]Comparison error: {e}[/red]")
comparisons.append((i, j, False))
# Find optimal threshold
threshold, achieved_recall = self.find_optimal_threshold(
comparisons, similarities
)
total_cost = embedding_cost + comparison_cost
# Print histogram visualization
self._print_similarity_histogram(similarities, comparisons, threshold)
# Print summary
total_pairs = len(left_data) * len(right_data)
pairs_above = sum(1 for s in similarities if s[2] >= threshold)
summary = (
f"[bold]Left keys:[/bold] {left_keys} [bold]Right keys:[/bold] {right_keys}\n"
f"[bold]Sampled:[/bold] {len(sampled_pairs)} pairs → {matches_found} matches ({matches_found/len(sampled_pairs)*100:.1f}%)\n"
f"[bold]Threshold:[/bold] {threshold:.4f}{achieved_recall:.1%} recall (target: {self.target_recall:.0%})\n"
f"[bold]Pairs to compare:[/bold] {pairs_above:,} of {total_pairs:,} ({pairs_above/total_pairs*100:.1f}%)\n"
f"[bold]Optimization cost:[/bold] ${total_cost:.4f}"
)
self.console.log(
Panel(
summary, title="Blocking Threshold Optimization", border_style="green"
)
)
return threshold, left_embeddings, right_embeddings, total_cost
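
To make the selection criterion concrete, here is a standalone sketch (synthetic similarities and labels) of the sweep performed in find_optimal_threshold: evaluate recall on the labeled sample at 100 candidate thresholds and keep the highest threshold that still meets the target recall.

import numpy as np

sim_scores = np.array([0.96, 0.91, 0.88, 0.72, 0.61, 0.40])      # sampled pair similarities (made up)
true_labels = np.array([True, True, True, False, False, False])  # LLM match labels (made up)
target_recall = 0.95

thresholds = np.linspace(0, 1, 100)
recalls = []
for t in thresholds:
    predictions = sim_scores >= t
    tp = np.sum(predictions & true_labels)
    fn = np.sum(~predictions & true_labels)
    recalls.append(tp / (tp + fn) if (tp + fn) > 0 else 0)

valid = [i for i, r in enumerate(recalls) if r >= target_recall]
best_idx = max(valid) if valid else int(np.argmax(recalls))
print(round(float(thresholds[best_idx]), 4))  # 0.8788 here: highest threshold keeping recall >= 0.95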

View File

@ -55,7 +55,7 @@ class Optimizer:
def __init__(
self,
runner: "DSLRunner",
rewrite_agent_model: str = "gpt-4o",
rewrite_agent_model: str = "gpt-5.1",
judge_agent_model: str = "gpt-4o-mini",
litellm_kwargs: dict[str, Any] = {},
resume: bool = False,
@ -67,7 +67,7 @@ class Optimizer:
Args:
yaml_file (str): Path to the YAML configuration file.
model (str): The name of the language model to use. Defaults to "gpt-4o".
model (str): The name of the language model to use. Defaults to "gpt-5.1".
resume (bool): Whether to resume optimization from a previous run. Defaults to False.
timeout (int): Timeout in seconds for operations. Defaults to 60.

View File

@ -0,0 +1,926 @@
"""
Fast decomposition analyzer for map operations.
This module provides a fast way to decompose map operations using directives,
running candidates on samples, and selecting the best via pairwise LLM comparison.
"""
import json
import os
import tempfile
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from copy import deepcopy
from typing import Any
import litellm
from litellm import completion, model_cost
from rich.console import Console
from docetl.console import get_console
from docetl.reasoning_optimizer.directives import (
ChainingDirective,
ClarifyInstructionsDirective,
DeterministicDocCompressionDirective,
DocumentChunkingDirective,
GleaningDirective,
IsolatingSubtasksDirective,
)
from docetl.runner import DSLRunner
from docetl.utils import completion_cost, count_tokens
# Drop unsupported params for models like gpt-5 that don't support temperature=0
litellm.drop_params = True
# Base directives always applicable to map operations
BASE_MAP_DIRECTIVES = [
ChainingDirective(),
IsolatingSubtasksDirective(),
GleaningDirective(),
ClarifyInstructionsDirective(),
]
# Directives that depend on document size
CHUNKING_DIRECTIVE = DocumentChunkingDirective()
COMPRESSION_DIRECTIVE = DeterministicDocCompressionDirective()
# Threshold for enabling DeterministicDocCompression (in characters)
DOC_COMPRESSION_CHAR_THRESHOLD = 1000
# Threshold for enabling DocumentChunking (as fraction of context window)
DOC_CHUNKING_CONTEXT_THRESHOLD = 0.10 # 10%
class FastDecomposer:
"""
Fast decomposition of map operations using directives and pairwise comparison.
Instead of the full optimizer flow, this:
1. Tries multiple directives to generate candidate decompositions
2. Runs each candidate on sample documents
3. Uses LLM judge for pairwise comparison
4. Returns the winning decomposition
"""
def __init__(
self,
yaml_config_path: str,
optimizer_model: str = "gpt-5.1",
sample_size: int = 5,
litellm_kwargs: dict[str, Any] | None = None,
console: Console | None = None,
):
"""
Initialize the decomposer.
Args:
yaml_config_path: Path to the pipeline YAML config file
optimizer_model: LLM model to use for directive instantiation and judging
sample_size: Number of sample documents to run candidates on
litellm_kwargs: Additional kwargs to pass to litellm.completion
console: Rich console for output (uses default if not provided)
"""
self.yaml_config_path = yaml_config_path
self.optimizer_model = optimizer_model
self.sample_size = sample_size
self.litellm_kwargs = litellm_kwargs or {}
if "temperature" not in self.litellm_kwargs:
self.litellm_kwargs["temperature"] = 0.0
self.total_cost = 0.0
self.console = console or get_console()
# Load the config
import yaml
with open(yaml_config_path, "r") as f:
self.config = yaml.safe_load(f)
self.operators = self.config.get("operations", [])
self.default_model = self.config.get("default_model", "gpt-4o-mini")
self.intermediate_dir = (
self.config.get("pipeline", {}).get("output", {}).get("intermediate_dir")
)
def _log(self, message: str) -> None:
"""Log a message to the console."""
self.console.log(message)
def get_model_context_limit(self, model: str) -> int:
"""
Get the context window limit for a model.
Args:
model: The model name (e.g., 'gpt-4o', 'azure/gpt-4')
Returns:
Maximum number of input tokens the model can handle
"""
model_info = model_cost.get(model, {})
# Try without provider prefix if not found
if not model_info:
model_name = model.split("/")[-1]
model_info = model_cost.get(model_name, {})
return model_info.get("max_input_tokens", 128000) # Default to 128k
def get_avg_doc_size(
self, sample_data: list[dict[str, Any]], op_config: dict[str, Any]
) -> tuple[float, float]:
"""
Calculate the average document size in characters and tokens.
Extracts the document content from sample data based on the operation's
prompt template (looks for {{ input.field_name }} patterns).
Args:
sample_data: List of sample documents
op_config: The operation configuration
Returns:
Tuple of (avg_chars, avg_tokens) for the document content
"""
import re
if not sample_data:
return 0.0, 0.0
prompt = op_config.get("prompt", "")
model = op_config.get("model", self.default_model)
# Extract field names from prompt template ({{ input.field_name }})
field_pattern = r"\{\{\s*input\.(\w+)\s*\}\}"
fields = re.findall(field_pattern, prompt)
if not fields:
# Fallback: use all string values from the first document
if sample_data:
fields = [
k
for k, v in sample_data[0].items()
if isinstance(v, str) and len(v) > 100
]
total_chars = 0
total_tokens = 0
for doc in sample_data:
doc_content = ""
for field in fields:
if field in doc:
value = doc[field]
if isinstance(value, str):
doc_content += value
else:
doc_content += str(value)
total_chars += len(doc_content)
if doc_content:
total_tokens += count_tokens(doc_content, model)
n = len(sample_data)
return total_chars / n if n > 0 else 0.0, total_tokens / n if n > 0 else 0.0
def get_applicable_directives(
self,
sample_data: list[dict[str, Any]],
op_config: dict[str, Any],
) -> list:
"""
Get the list of directives applicable based on data characteristics.
- DocumentChunkingDirective: only if avg doc size > 10% of context window
- DeterministicDocCompressionDirective: only if avg doc size > 1000 chars
Args:
sample_data: List of sample documents
op_config: The operation configuration
Returns:
List of applicable directive instances
"""
directives = [] # Build list with priority ordering
model = op_config.get("model", self.default_model)
context_limit = self.get_model_context_limit(model)
avg_chars, avg_tokens = self.get_avg_doc_size(sample_data, op_config)
self._log(
f"Document analysis: avg_chars={avg_chars:.0f}, avg_tokens={avg_tokens:.0f}, "
f"context_limit={context_limit}"
)
# Add DeterministicDocCompression FIRST if doc size > 1000 chars (high priority)
if avg_chars > DOC_COMPRESSION_CHAR_THRESHOLD:
self._log(
f"[cyan]Enabling DeterministicDocCompression (priority)[/cyan] "
f"(avg {avg_chars:.0f} chars > {DOC_COMPRESSION_CHAR_THRESHOLD})"
)
directives.append(COMPRESSION_DIRECTIVE)
# Add base directives
directives.extend(BASE_MAP_DIRECTIVES)
# Add DocumentChunking if avg tokens > 10% of context window
token_threshold = context_limit * DOC_CHUNKING_CONTEXT_THRESHOLD
if avg_tokens > token_threshold:
self._log(
f"[cyan]Enabling DocumentChunking[/cyan] "
f"(avg {avg_tokens:.0f} tokens > {token_threshold:.0f} = 10% of {context_limit})"
)
directives.append(CHUNKING_DIRECTIVE)
else:
self._log(
f"[dim]Skipping DocumentChunking[/dim] "
f"(avg {avg_tokens:.0f} tokens <= {token_threshold:.0f} = 10% of {context_limit})"
)
return directives
def load_sample_data(self, step_name: str, op_name: str) -> list[dict[str, Any]]:
"""
Load sample input data for an operation.
For the first operation, loads from the dataset.
For subsequent operations, loads from the previous operation's intermediate output.
Args:
step_name: Name of the pipeline step
op_name: Name of the operation
Returns:
List of sample documents
"""
# Find the operation's position
op_names = [op.get("name") for op in self.operators]
try:
op_idx = op_names.index(op_name)
except ValueError:
raise ValueError(f"Operation '{op_name}' not found in config")
if op_idx == 0:
# First operation - load from dataset
datasets = self.config.get("datasets", {})
# Get the first dataset (or the one used by this step)
for dataset_name, dataset_config in datasets.items():
dataset_path = dataset_config.get("path")
if dataset_path and os.path.exists(dataset_path):
with open(dataset_path, "r") as f:
data = json.load(f)
return data[: self.sample_size]
raise FileNotFoundError("No dataset found in config")
else:
# Load from previous operation's intermediate output
prev_op_name = op_names[op_idx - 1]
output_path = os.path.join(
self.intermediate_dir, step_name, f"{prev_op_name}.json"
)
if not os.path.exists(output_path):
raise FileNotFoundError(
f"No intermediate output found at {output_path}. "
"Run the previous operation first."
)
with open(output_path, "r") as f:
data = json.load(f)
return data[: self.sample_size]
def get_input_file_path(self) -> str:
"""Get the input file path for the pipeline."""
datasets = self.config.get("datasets", {})
for dataset_name, dataset_config in datasets.items():
path = dataset_config.get("path")
if path:
return path
return ""
def generate_candidates(
self,
op_name: str,
sample_data: list[dict[str, Any]],
target_op: dict[str, Any],
) -> list[dict[str, Any]]:
"""
Generate candidate decompositions using directives.
Directives are selected based on data characteristics:
- DocumentChunkingDirective: only if avg doc size > 10% of context window
- DeterministicDocCompressionDirective: only if avg doc size > 1000 chars
Args:
op_name: Name of the operation to decompose
sample_data: Sample data for analyzing document characteristics
target_op: The target operation configuration
Returns:
List of candidate dictionaries with 'name', 'ops', 'cost' keys
"""
candidates = []
# Add original as baseline
self._log("Adding original operation as baseline candidate")
candidates.append(
{
"name": "original",
"ops": deepcopy(self.operators),
"cost": 0.0,
"error": None,
}
)
input_file_path = self.get_input_file_path()
# Get applicable directives based on data characteristics
applicable_directives = self.get_applicable_directives(sample_data, target_op)
self._log(
f"Generating candidates using {len(applicable_directives)} directives..."
)
for i, directive in enumerate(applicable_directives, 1):
with self.console.status(
f"[bold cyan]({i}/{len(applicable_directives)}) Trying directive: {directive.name}...[/bold cyan]",
spinner="dots",
):
try:
new_ops_list, _, cost = directive.instantiate(
operators=deepcopy(self.operators),
target_ops=[op_name],
agent_llm=self.optimizer_model,
message_history=[],
global_default_model=self.default_model,
input_file_path=input_file_path,
)
self.total_cost += cost
self._log(
f" [green]✓[/green] {directive.name} generated {len(new_ops_list)} operations (cost: ${cost:.4f})"
)
candidates.append(
{
"name": directive.name,
"ops": new_ops_list,
"cost": cost,
"error": None,
}
)
except Exception as e:
# Directive not applicable or failed - skip it
self._log(f" [red]✗[/red] {directive.name} failed: {str(e)}")
candidates.append(
{
"name": directive.name,
"ops": None,
"cost": 0.0,
"error": str(e),
}
)
return candidates
def extract_ops_to_run(
self, ops_list: list[dict], original_op_name: str
) -> list[dict]:
"""
Extract the operations that replaced the original operation.
Args:
ops_list: The full transformed operations list
original_op_name: Name of the original operation that was decomposed
Returns:
List of operations that should be run on samples
"""
# Find ops that are new (not in original) or modified
original_names = {op["name"] for op in self.operators}
# Find the position where the original op was
original_idx = None
for i, op in enumerate(self.operators):
if op["name"] == original_op_name:
original_idx = i
break
if original_idx is None:
return ops_list
# Find new ops (those not in original list)
new_ops = []
for op in ops_list:
if op["name"] not in original_names or op["name"] == original_op_name:
new_ops.append(op)
return new_ops if new_ops else [ops_list[original_idx]]
def run_candidate_on_samples(
self,
candidate: dict[str, Any],
sample_data: list[dict[str, Any]],
original_op_name: str,
) -> list[dict[str, Any]]:
"""
Run a candidate's operations on sample data.
Args:
candidate: Candidate dictionary with 'ops' key
sample_data: List of sample documents
original_op_name: Name of the original operation
Returns:
List of output documents
"""
if candidate["ops"] is None:
return []
# Extract ops to run
ops_to_run = self.extract_ops_to_run(candidate["ops"], original_op_name)
if not ops_to_run:
return []
# Create a minimal config for running these ops
# Write sample data to a temp file
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
json.dump(sample_data, f)
temp_input_path = f.name
# Create a temp output file (required by DSLRunner validation)
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
temp_output_path = f.name
try:
# Create a minimal pipeline config
temp_config = {
"default_model": self.default_model,
"operations": ops_to_run,
"datasets": {
"sample_data": {
"type": "file",
"path": temp_input_path,
}
},
"pipeline": {
"steps": [
{
"name": "decompose_test",
"input": "sample_data",
"operations": [op["name"] for op in ops_to_run],
}
],
"output": {
"type": "file",
"path": temp_output_path,
},
},
}
# Create runner and execute
runner = DSLRunner(temp_config, max_threads=4)
# Run operations sequentially on the data
current_data = sample_data
for op_config in ops_to_run:
current_data = runner._run_operation(op_config, current_data)
self.total_cost += runner.total_cost
return current_data
finally:
# Clean up temp files
for temp_path in [temp_input_path, temp_output_path]:
try:
os.unlink(temp_path)
except Exception:
pass
def pairwise_compare(
self,
candidate_a: dict[str, Any],
candidate_b: dict[str, Any],
original_prompt: str,
output_schema: dict[str, Any],
) -> dict[str, Any]:
"""
Compare two candidates using LLM judge.
Args:
candidate_a: First candidate with 'name', 'outputs' keys
candidate_b: Second candidate with 'name', 'outputs' keys
original_prompt: The original operation's prompt
output_schema: The expected output schema
Returns:
The winning candidate
"""
# If one has no outputs, the other wins
if not candidate_a.get("outputs"):
return candidate_b
if not candidate_b.get("outputs"):
return candidate_a
system_prompt = """You are an expert judge comparing outputs from two data processing pipeline variants.
Your task is to determine which variant produces BETTER outputs based on:
1. **Completeness**: Does the output contain all required information?
2. **Accuracy**: Is the extracted/generated information correct?
3. **Consistency**: Are the outputs consistent across different samples?
4. **Quality**: Is the output well-structured and useful?
Be objective and focus on the actual output quality, not the approach used."""
user_prompt = f"""Compare outputs from two pipeline variants for this task:
## Original Task
**Prompt:**
```
{original_prompt[:2000]}{"..." if len(original_prompt) > 2000 else ""}
```
**Expected Output Schema:**
```json
{json.dumps(output_schema, indent=2)}
```
## Variant A: {candidate_a["name"]}
**Sample Outputs:**
```json
{json.dumps(candidate_a["outputs"][:3], indent=2, default=str)}
```
## Variant B: {candidate_b["name"]}
**Sample Outputs:**
```json
{json.dumps(candidate_b["outputs"][:3], indent=2, default=str)}
```
## Your Task
Which variant produces better outputs? Consider completeness, accuracy, consistency, and quality.
Respond with your analysis and final choice."""
response = completion(
model=self.optimizer_model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
response_format={
"type": "json_schema",
"json_schema": {
"name": "comparison_result",
"strict": True,
"schema": {
"type": "object",
"properties": {
"winner": {
"type": "string",
"enum": ["A", "B"],
"description": "Which variant is better: A or B",
},
"rationale": {
"type": "string",
"description": "Explanation of why this variant is better",
},
"a_strengths": {
"type": "array",
"items": {"type": "string"},
"description": "Strengths of variant A",
},
"b_strengths": {
"type": "array",
"items": {"type": "string"},
"description": "Strengths of variant B",
},
},
"required": [
"winner",
"rationale",
"a_strengths",
"b_strengths",
],
"additionalProperties": False,
},
},
},
**self.litellm_kwargs,
)
cost = completion_cost(response)
self.total_cost += cost
result = json.loads(response.choices[0].message.content)
winner = candidate_a if result["winner"] == "A" else candidate_b
winner["comparison_rationale"] = result["rationale"]
return winner
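    # Illustrative call of the judge above (a sketch, not executed anywhere in this
    # module; the candidate dicts and keyword values are made-up examples):
    #
    #   winner = self.pairwise_compare(
    #       candidate_a={"name": "original", "outputs": original_outputs},
    #       candidate_b={"name": "chaining", "outputs": chained_outputs},
    #       original_prompt=target_op["prompt"],
    #       output_schema=target_op["output"]["schema"],
    #   )
    #   winner["comparison_rationale"]  # judge's explanation for the pick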
def decompose(
self,
step_name: str,
op_name: str,
) -> dict[str, Any]:
"""
Decompose an operation using fast directive-based approach.
This is the main entry point. It:
1. Generates candidate decompositions using directives
2. Runs each on sample data
3. Uses pairwise LLM comparison to select the best
4. Returns the winning decomposition
Args:
step_name: Name of the pipeline step
op_name: Name of the operation to decompose
Returns:
Dict with:
- decomposed_ops: List of operations that replace the original
- winning_directive: Name of the winning directive
- candidates_evaluated: Number of candidates that were compared
- original_outputs: Sample outputs from the original operation
- decomposed_outputs: Sample outputs from the winning decomposition
- comparison_rationale: LLM's explanation of why the winner was chosen
- cost: Total LLM API cost
"""
self.console.rule("[bold blue]Fast Decomposition[/bold blue]")
self._log(f"[bold]Target operation:[/bold] {op_name}")
# Find the target operation config
target_op = None
for op in self.operators:
if op["name"] == op_name:
target_op = op
break
if target_op is None:
raise ValueError(f"Operation '{op_name}' not found in config")
# Verify it's a map operation
if target_op.get("type") != "map":
raise ValueError(
f"Operation '{op_name}' is type '{target_op.get('type')}', "
"but fast decomposition only supports 'map' operations"
)
# Load sample data
self._log(f"Loading sample data from step '{step_name}'...")
sample_data = self.load_sample_data(step_name, op_name)
self._log(f"[green]✓[/green] Loaded {len(sample_data)} sample documents")
# Generate candidates
self.console.rule("[bold]Generating Candidates[/bold]")
candidates = self.generate_candidates(op_name, sample_data, target_op)
# Filter out failed candidates
valid_candidates = [c for c in candidates if c["ops"] is not None]
self._log(f"[bold]{len(valid_candidates)}[/bold] valid candidates generated")
if len(valid_candidates) < 2:
# Only original (or nothing) - return original
self._log(
"[yellow]No alternative decompositions generated. Keeping original.[/yellow]"
)
return {
"decomposed_ops": self.operators,
"winning_directive": "original",
"candidates_evaluated": len(valid_candidates),
"original_outputs": [],
"decomposed_outputs": [],
"comparison_rationale": "No alternative decompositions were generated.",
"cost": self.total_cost,
}
# Run each candidate on samples IN PARALLEL
self.console.rule("[bold]Running Candidates on Samples[/bold]")
self._log(f"Running {len(valid_candidates)} candidates in parallel...")
def run_single_candidate(candidate):
"""Run a single candidate and return results."""
name = candidate["name"]
try:
outputs = self.run_candidate_on_samples(candidate, sample_data, op_name)
return {"name": name, "outputs": outputs, "error": None}
except Exception as e:
error_tb = traceback.format_exc()
return {
"name": name,
"outputs": [],
"error": str(e),
"traceback": error_tb,
}
# Run all candidates in parallel
with self.console.status(
"[bold cyan]Running candidates on samples...[/bold cyan]", spinner="dots"
):
with ThreadPoolExecutor(max_workers=len(valid_candidates)) as executor:
future_to_candidate = {
executor.submit(run_single_candidate, c): c
for c in valid_candidates
}
for future in as_completed(future_to_candidate):
candidate = future_to_candidate[future]
result = future.result()
if result["error"]:
candidate["outputs"] = []
candidate["run_error"] = result["error"]
self._log(
f" [red]✗[/red] {result['name']} failed: {result['error']}"
)
else:
candidate["outputs"] = result["outputs"]
self._log(
f" [green]✓[/green] {result['name']}: {len(result['outputs'])} outputs"
)
# Filter to candidates with outputs
candidates_with_outputs = [c for c in valid_candidates if c.get("outputs")]
self._log(
f"[bold]{len(candidates_with_outputs)}[/bold] candidates produced outputs"
)
if not candidates_with_outputs:
# All failed - return original
self._log("[red]All decomposition candidates failed to execute.[/red]")
# Log the errors for debugging
for c in valid_candidates:
if c.get("run_error"):
self._log(f" [red]{c['name']}:[/red] {c['run_error']}")
return {
"decomposed_ops": self.operators,
"winning_directive": "original",
"candidates_evaluated": 0,
"original_outputs": [],
"decomposed_outputs": [],
"comparison_rationale": "All decomposition candidates failed to execute.",
"cost": self.total_cost,
}
# Find the original candidate's outputs
original_candidate = next(
(c for c in candidates_with_outputs if c["name"] == "original"), None
)
original_outputs = original_candidate["outputs"] if original_candidate else []
# Parallel pairwise comparison: compare all candidates against original
self.console.rule("[bold]Pairwise Comparison[/bold]")
original_prompt = target_op.get("prompt", "")
output_schema = target_op.get("output", {}).get("schema", {})
# If only original exists, it wins by default
if len(candidates_with_outputs) == 1:
winner = candidates_with_outputs[0]
elif original_candidate is None:
# No original - just pick the first candidate
winner = candidates_with_outputs[0]
else:
# Compare all non-original candidates against original IN PARALLEL
challengers = [
c for c in candidates_with_outputs if c["name"] != "original"
]
if not challengers:
winner = original_candidate
else:
self._log(
f"Comparing {len(challengers)} candidates against original in parallel..."
)
def compare_against_original(challenger):
"""Compare a single challenger against original."""
try:
result = self.pairwise_compare(
original_candidate,
challenger,
original_prompt,
output_schema,
)
won = result["name"] == challenger["name"]
return {
"challenger": challenger,
"won": won,
"rationale": result.get("comparison_rationale", ""),
}
except Exception as e:
return {
"challenger": challenger,
"won": False,
"error": str(e),
}
# Run all comparisons in parallel
comparison_results = []
with self.console.status(
f"[bold cyan]Running {len(challengers)} comparisons in parallel...[/bold cyan]",
spinner="dots",
):
with ThreadPoolExecutor(max_workers=len(challengers)) as executor:
future_to_challenger = {
executor.submit(compare_against_original, c): c
for c in challengers
}
for future in as_completed(future_to_challenger):
result = future.result()
comparison_results.append(result)
challenger_name = result["challenger"]["name"]
if result.get("error"):
self._log(
f" [red]✗[/red] {challenger_name} vs original: error - {result['error']}"
)
elif result["won"]:
self._log(
f" [green]✓[/green] {challenger_name} beat original"
)
else:
self._log(
f" [dim]○[/dim] {challenger_name} lost to original"
)
# Find winners (candidates that beat original)
winners = [r for r in comparison_results if r.get("won")]
if not winners:
# Original beats all challengers
winner = original_candidate
self._log("[bold]Original wins against all challengers[/bold]")
elif len(winners) == 1:
# Single winner
winner = winners[0]["challenger"]
winner["comparison_rationale"] = winners[0].get("rationale", "")
else:
# Multiple winners beat original - run tiebreaker comparisons in parallel
self._log(
f"[bold]{len(winners)} candidates beat original - running tiebreaker...[/bold]"
)
# Compare all winners against each other in parallel (round-robin)
winner_candidates = [w["challenger"] for w in winners]
win_counts = {c["name"]: 0 for c in winner_candidates}
# Generate all pairwise matchups
matchups = []
for i, a in enumerate(winner_candidates):
for b in winner_candidates[i + 1 :]:
matchups.append((a, b))
if matchups:
def run_matchup(matchup):
a, b = matchup
try:
result = self.pairwise_compare(
a, b, original_prompt, output_schema
)
return {
"winner": result["name"],
"a": a["name"],
"b": b["name"],
}
except Exception:
return {
"winner": a["name"],
"a": a["name"],
"b": b["name"],
} # Default to first
with self.console.status(
f"[bold cyan]Running {len(matchups)} tiebreaker comparisons...[/bold cyan]",
spinner="dots",
):
with ThreadPoolExecutor(
max_workers=len(matchups)
) as executor:
for result in executor.map(run_matchup, matchups):
win_counts[result["winner"]] += 1
self._log(
f" [dim]{result['a']} vs {result['b']}{result['winner']}[/dim]"
)
# Pick candidate with most wins
best_name = max(win_counts, key=win_counts.get)
winner = next(
c for c in winner_candidates if c["name"] == best_name
)
self._log(
f"[bold]Tiebreaker winner: {best_name} ({win_counts[best_name]} wins)[/bold]"
)
# Extract the decomposed operations
decomposed_ops = self.extract_ops_to_run(winner["ops"], op_name)
# Final summary
self.console.rule("[bold green]Decomposition Complete[/bold green]")
self._log(f"[bold]Winner:[/bold] [green]{winner['name']}[/green]")
self._log(f"[bold]Candidates evaluated:[/bold] {len(candidates_with_outputs)}")
self._log(f"[bold]New operations:[/bold] {len(decomposed_ops)}")
self._log(f"[bold]Total cost:[/bold] ${self.total_cost:.4f}")
return {
"decomposed_ops": decomposed_ops,
"winning_directive": winner["name"],
"candidates_evaluated": len(candidates_with_outputs),
"original_outputs": original_outputs,
"decomposed_outputs": winner.get("outputs", []),
"comparison_rationale": winner.get("comparison_rationale", ""),
"cost": self.total_cost,
}
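# --- Illustrative usage (a sketch, not part of the module) ---
# decompose() is a method on the decomposer class defined earlier in this file; the
# class name and constructor arguments below are assumptions for illustration only.
#
#   decomposer = FastMapDecomposer(
#       yaml_file="pipeline.yaml",
#       intermediate_dir="intermediates/",
#       optimizer_model="gpt-5.1",
#   )
#   result = decomposer.decompose(step_name="extract_step", op_name="extract_entities")
#   if result["winning_directive"] != "original":
#       new_ops = result["decomposed_ops"]  # splice these in place of the original map
#       print(f"{len(new_ops)} ops, cost ${result['cost']:.4f}: {result['comparison_rationale']}")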

View File

@ -0,0 +1,337 @@
"""
Fast should_optimize analyzer using a single LLM call.
This module provides a lightweight alternative to the full MapOptimizer/ReduceOptimizer/JoinOptimizer
flow for quickly determining if an operation should be decomposed.
"""
import json
import os
from typing import Any
import litellm
from litellm import completion, model_cost
from docetl.utils import completion_cost, count_tokens
# Drop unsupported params for models like gpt-5 that don't support temperature=0
litellm.drop_params = True
class FastShouldOptimizeAnalyzer:
"""
Analyzes whether an operation should be optimized using a single LLM call.
Instead of running the operation on sample data and using complex evaluation logic,
this reads cached outputs from intermediate files and makes one judgment call.
"""
def __init__(
self,
intermediate_dir: str,
optimizer_model: str = "gpt-5.1",
litellm_kwargs: dict[str, Any] | None = None,
):
"""
Initialize the analyzer.
Args:
intermediate_dir: Path to the directory containing intermediate outputs
optimizer_model: The LLM model to use for analysis (default: gpt-5.1)
litellm_kwargs: Additional kwargs to pass to litellm.completion
"""
self.intermediate_dir = intermediate_dir
self.optimizer_model = optimizer_model
self.litellm_kwargs = litellm_kwargs or {}
if "temperature" not in self.litellm_kwargs:
self.litellm_kwargs["temperature"] = 0.0
def load_operation_data(self, step_name: str, op_name: str) -> list[dict[str, Any]]:
"""
Load data from the intermediate file for an operation.
Args:
step_name: Name of the pipeline step
op_name: Name of the operation
Returns:
List of dictionaries
Raises:
FileNotFoundError: If the intermediate file doesn't exist
"""
output_path = os.path.join(self.intermediate_dir, step_name, f"{op_name}.json")
if not os.path.exists(output_path):
raise FileNotFoundError(
f"No output file found at {output_path}. "
"Run the operation first to generate outputs."
)
with open(output_path, "r") as f:
return json.load(f)
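    # For example, load_operation_data above with intermediate_dir="intermediates/",
    # step_name="my_step" and op_name="extract_entities" reads
    # intermediates/my_step/extract_entities.json (a JSON list of row dicts).
    # The names in this example are illustrative only.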
def find_previous_operation(
self, operations: list[dict[str, Any]], op_name: str
) -> str | None:
"""
Find the operation that comes before op_name in the pipeline.
Args:
operations: List of operation configs from the pipeline
op_name: Name of the current operation
Returns:
Name of the previous operation, or None if this is the first operation
"""
op_names = [op.get("name") for op in operations]
try:
idx = op_names.index(op_name)
if idx > 0:
return op_names[idx - 1]
except ValueError:
pass
return None
def get_max_context_tokens(self) -> int:
"""
Get the maximum input tokens for the optimizer model.
Returns:
Maximum number of input tokens the model can handle
"""
model_info = model_cost.get(self.optimizer_model, {})
# Try without provider prefix if not found
if not model_info:
model_name = self.optimizer_model.split("/")[-1]
model_info = model_cost.get(model_name, {})
return model_info.get("max_input_tokens", 128000) # Default to 128k
def calculate_samples_that_fit(
self,
op_config: dict[str, Any],
outputs: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""
Calculate how many output samples fit in the context window.
Reserves space for system prompt, operation config, and response buffer,
then fills remaining space with as many samples as possible.
Args:
op_config: The operation configuration dictionary
outputs: List of all output documents
Returns:
List of samples that fit in the context window
"""
max_tokens = self.get_max_context_tokens()
# Reserve tokens for fixed parts
system_prompt_tokens = 500
op_config_tokens = count_tokens(
json.dumps(op_config, default=str), self.optimizer_model
)
response_buffer = 2000
available_for_samples = (
max_tokens - system_prompt_tokens - op_config_tokens - response_buffer
)
# Collect samples that fit
samples_to_include = []
tokens_used = 0
for output in outputs:
sample_json = json.dumps(output, default=str)
sample_tokens = count_tokens(sample_json, self.optimizer_model)
if tokens_used + sample_tokens <= available_for_samples:
samples_to_include.append(output)
tokens_used += sample_tokens
else:
break
return samples_to_include
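    # Worked example of the budget above (illustrative numbers): with a 128,000-token
    # model and an op_config of ~1,500 tokens, 128000 - 500 - 1500 - 2000 = 124,000
    # tokens remain for samples, so roughly 124 outputs of ~1,000 tokens each fit
    # before the greedy loop stops.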
def build_analysis_prompt(
self,
op_config: dict[str, Any],
samples: list[dict[str, Any]],
) -> tuple[str, str]:
"""
Build the system prompt and user prompt for the analysis LLM call.
Args:
op_config: The operation configuration dictionary
samples: List of output samples to analyze
Returns:
Tuple of (system_prompt, user_prompt)
"""
system_prompt = """You are an expert at analyzing LLM-powered data processing operations.
Your task is to determine if an operation would benefit from being decomposed into multiple
smaller, focused operations (also known as "optimization" or "decomposition").
An operation SHOULD be decomposed when:
1. The prompt asks the LLM to do multiple distinct tasks that could be done separately
2. The task is complex enough that breaking it into sequential steps would improve accuracy
3. The outputs show inconsistency, incompleteness, or quality issues that iterative refinement could fix
4. Long documents need to be processed in chunks rather than all at once
5. The prompt asks for both extraction AND analysis/synthesis in one step
An operation should NOT be decomposed when:
1. It performs a single, focused task well
2. The outputs are consistently high quality and complete
3. The task is simple and atomic (e.g., simple classification, single field extraction)
4. The operation is already well-scoped and produces reliable results
Be conservative - only recommend decomposition if there's clear evidence it would help."""
output_schema = op_config.get("output", {}).get("schema", {})
prompt_template = op_config.get("prompt", "No prompt specified")
# Truncate very long prompts for display
if len(prompt_template) > 3000:
prompt_template = prompt_template[:3000] + "\n... [truncated]"
user_prompt = f"""Analyze this data processing operation and its outputs:
## Operation Configuration
**Name:** {op_config.get('name', 'unknown')}
**Type:** {op_config.get('type', 'unknown')}
**Prompt Template:**
```
{prompt_template}
```
**Output Schema:**
```json
{json.dumps(output_schema, indent=2)}
```
## Sample Outputs ({len(samples)} samples from the operation)
```json
{json.dumps(samples, indent=2, default=str)}
```
## Your Task
Based on the operation configuration and sample outputs, determine:
1. Should this operation be decomposed/optimized?
2. If yes, what specific improvements would help?
Consider the quality, completeness, and consistency of the outputs when making your assessment."""
return system_prompt, user_prompt
def analyze(
self,
op_config: dict[str, Any],
step_name: str,
op_name: str,
) -> tuple[str, list[dict[str, Any]], int, float]:
"""
Analyze whether an operation should be optimized.
This is the main entry point. It loads outputs, builds the prompt,
makes a single LLM call, and returns the assessment.
Args:
op_config: The operation configuration dictionary
step_name: Name of the pipeline step
op_name: Name of the operation
Returns:
Tuple of (rationale, output_samples, num_docs_analyzed, cost):
- rationale: Empty string if no optimization needed, explanation if it should be optimized
- output_samples: The samples that were analyzed
- num_docs_analyzed: Number of documents that fit in the LLM prompt
- cost: LLM API cost in USD
Raises:
ValueError: If the operation is not an LLM-powered map operation
"""
# Validate operation type - only LLM-powered map operations
op_type = op_config.get("type", "")
if op_type != "map":
raise ValueError(
f"should_optimize only supports 'map' operations, got '{op_type}'. "
"Only LLM-powered map operations can be analyzed for decomposition."
)
# Load outputs from intermediate file
outputs = self.load_operation_data(step_name, op_name)
if not outputs:
return "No output samples available for analysis.", [], 0, 0.0
# Calculate samples that fit in context
samples = self.calculate_samples_that_fit(op_config, outputs)
if not samples:
return "Could not fit any samples in context window.", outputs[:5], 0, 0.0
# Build prompt
system_prompt, user_prompt = self.build_analysis_prompt(op_config, samples)
# Make LLM call with structured output
response = completion(
model=self.optimizer_model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
response_format={
"type": "json_schema",
"json_schema": {
"name": "optimization_analysis",
"strict": True,
"schema": {
"type": "object",
"properties": {
"should_optimize": {
"type": "boolean",
"description": "True if operation should be decomposed/optimized",
},
"rationale": {
"type": "string",
"description": "Explanation of why the operation should or should not be optimized",
},
"suggested_improvements": {
"type": "array",
"items": {"type": "string"},
"description": "Specific improvements if optimization is recommended (empty if not)",
},
},
"required": [
"should_optimize",
"rationale",
"suggested_improvements",
],
"additionalProperties": False,
},
},
},
**self.litellm_kwargs,
)
# Calculate cost
cost = completion_cost(response)
# Parse response
result = json.loads(response.choices[0].message.content)
num_docs_analyzed = len(samples)
if result["should_optimize"]:
# Build rationale string with improvements
rationale_parts = [result["rationale"]]
if result["suggested_improvements"]:
rationale_parts.append("\n\nSuggested improvements:")
for imp in result["suggested_improvements"]:
rationale_parts.append(f"- {imp}")
return "\n".join(rationale_parts), samples, num_docs_analyzed, cost
else:
# Return empty string to indicate no optimization needed
return "", samples, num_docs_analyzed, cost

View File

@ -30,7 +30,7 @@ class LLMClient:
Initialize the LLMClient.
Args:
model (str, optional): The name of the LLM model to use. Defaults to "gpt-4o".
model (str, optional): The name of the LLM model to use. Defaults to "gpt-5.1".
**litellm_kwargs: Additional keyword arguments for the LLM model.
"""
self.rewrite_agent_model = rewrite_agent_model

View File

@ -204,10 +204,6 @@ def get_openai_response(
response = litellm.completion(
model=model,
messages=messages,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=ResponseFormat,
)
assistant_response = response.choices[0].message.content

View File

@ -30,6 +30,8 @@ from .map_reduce_fusion import MapReduceFusionDirective
from .hierarchical_reduce import HierarchicalReduceDirective
from .cascade_filtering import CascadeFilteringDirective
from .arbitrary_rewrite import ArbitraryRewriteDirective
from .map_to_map_resolve_reduce import MapToMapResolveReduceDirective
from .map_resolve_to_map_with_categories import MapResolveToMapWithCategoriesDirective
# Registry of all available directives
ALL_DIRECTIVES = [
@ -53,6 +55,8 @@ ALL_DIRECTIVES = [
HierarchicalReduceDirective(),
CascadeFilteringDirective(),
ArbitraryRewriteDirective(),
MapToMapResolveReduceDirective(),
MapResolveToMapWithCategoriesDirective(),
]
ALL_COST_DIRECTIVES = [
@ -179,8 +183,10 @@ __all__ = [
"HierarchicalReduceDirective",
"CascadeFilteringDirective",
"ArbitraryRewriteDirective",
"MapToMapResolveReduceDirective",
"MapResolveToMapWithCategoriesDirective",
"ALL_DIRECTIVES",
"DIRECTIVE_REGISTRY",
"DIRECTIVE_REGISTRY",
"get_all_directive_strings",
"instantiate_directive"
]

View File

@ -324,10 +324,6 @@ Focus on quality over quantity - a few diverse, informative examples are better
model=self.agent_llm,
messages=self.message_history,
response_format=AgentDecision,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
)
call_cost += response._hidden_params["response_cost"]
@ -415,10 +411,6 @@ Provide your response as a JSON object matching this schema: {response_schema.mo
model=self.agent_llm,
messages=self.message_history,
response_format=response_schema,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
)
call_cost += schema_response._hidden_params["response_cost"]

View File

@ -123,7 +123,8 @@ class Directive(BaseModel, ABC):
try:
# 1. Execute the directive
actual_output, _ = self.instantiate(
# instantiate returns (new_ops_plan, message_history, call_cost)
instantiate_result = self.instantiate(
operators=(
[test_case.input_config]
if isinstance(test_case.input_config, dict)
@ -134,6 +135,11 @@ class Directive(BaseModel, ABC):
input_file_path=temp_file_path,
pipeline_code=fake_pipeline,
)
# Handle both 2-tuple and 3-tuple returns
if isinstance(instantiate_result, tuple):
actual_output = instantiate_result[0]
else:
actual_output = instantiate_result
# 2. Use LLM judge to evaluate
judge_result = self._llm_judge_test(
@ -216,7 +222,6 @@ class Directive(BaseModel, ABC):
model=agent_llm,
messages=messages,
response_format=JudgeResponse,
azure=True,
)
# Parse the JSON response

View File

@ -1,5 +1,4 @@
import json
import os
import re
from copy import deepcopy
from typing import Dict, List, Type
@ -169,19 +168,15 @@ class ChainingDirective(Directive):
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
# api_key=os.environ["GEMINI_API_KEY"],
azure=True,
response_format=ChainingInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
@ -202,11 +197,12 @@ class ChainingDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(

View File

@ -1,5 +1,4 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -211,18 +210,14 @@ class ChangeModelDirective(Directive):
},
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
# api_key=os.environ["GEMINI_API_KEY"],
azure=True,
response_format=ChangeModelInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = ChangeModelInstantiateSchema(**parsed_res)
@ -240,11 +235,12 @@ class ChangeModelDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(

View File

@ -1,5 +1,4 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -205,17 +204,14 @@ class ChangeModelAccDirective(Directive):
},
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=ChangeModelInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = ChangeModelInstantiateSchema(**parsed_res)
@ -233,11 +229,12 @@ class ChangeModelAccDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(

View File

@ -1,5 +1,4 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -282,17 +281,14 @@ class ChangeModelCostDirective(Directive):
},
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=ChangeModelInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = ChangeModelInstantiateSchema(**parsed_res)
@ -310,11 +306,12 @@ class ChangeModelCostDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(

View File

@ -1,5 +1,4 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -185,17 +184,14 @@ class ChunkHeaderSummaryDirective(Directive):
},
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=ChunkHeaderSummaryInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = ChunkHeaderSummaryInstantiateSchema(**parsed_res)
@ -204,11 +200,12 @@ class ChunkHeaderSummaryDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(
@ -251,9 +248,7 @@ class ChunkHeaderSummaryDirective(Directive):
)
if gather_idx - split_idx > 1:
raise ValueError(
"There should not be operators between split and gather"
)
raise ValueError("There should not be operators between split and gather")
# Get the split_key from the split operation
split_key = split_op.get("split_key")

View File

@ -1,5 +1,4 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -261,17 +260,14 @@ def transform(input_doc):
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=DeterministicDocCompressionInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = DeterministicDocCompressionInstantiateSchema(**parsed_res)
@ -284,11 +280,12 @@ def transform(input_doc):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(

View File

@ -1,5 +1,4 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -268,18 +267,15 @@ class DocumentChunkingDirective(Directive):
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
call_cost = 0.0
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=DocumentChunkingInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = DocumentChunkingInstantiateSchema(**parsed_res)
@ -290,11 +286,12 @@ class DocumentChunkingDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(
@ -404,16 +401,19 @@ class DocumentChunkingDirective(Directive):
sample_op["stratify_key"] = stratify_keys
if rewrite.sampling_config.samples_per_group:
sample_op["samples_per_group"] = rewrite.sampling_config.samples_per_group
sample_op["samples_per_group"] = (
rewrite.sampling_config.samples_per_group
)
# Add optional fields if provided
if rewrite.sampling_config.random_state is not None:
sample_op["random_state"] = rewrite.sampling_config.random_state
if rewrite.sampling_config.method_kwargs is not None:
try:
sample_op["method_kwargs"] = json.loads(rewrite.sampling_config.method_kwargs)
sample_op["method_kwargs"] = json.loads(
rewrite.sampling_config.method_kwargs
)
except Exception as e:
raise ValueError(f"Invalid method_kwargs: {e}")

View File

@ -1,5 +1,4 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -431,17 +430,14 @@ class DocumentChunkingTopKDirective(Directive):
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=DocumentChunkingTopKInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = DocumentChunkingTopKInstantiateSchema(**parsed_res)
@ -451,11 +447,12 @@ class DocumentChunkingTopKDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(

View File

@ -1,5 +1,4 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -164,19 +163,14 @@ class DocCompressionDirective(Directive):
},
]
)
error_message = ""
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=DocCompressionInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
@ -187,11 +181,12 @@ class DocCompressionDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Error: {error_message}"
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(

View File

@ -1,5 +1,4 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -261,17 +260,14 @@ class DocSummarizationDirective(Directive):
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=DocSummarizationInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = DocSummarizationInstantiateSchema(**parsed_res)
@ -280,11 +276,12 @@ class DocSummarizationDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(

View File

@ -1,5 +1,4 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -146,32 +145,31 @@ class GleaningDirective(Directive):
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
# api_key=os.environ["GEMINI_API_KEY"],
azure=True,
response_format=GleaningInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = GleaningInstantiateSchema(**parsed_res)
GleaningInstantiateSchema.check_no_jinja_variables(schema.validation_prompt)
GleaningInstantiateSchema.check_no_jinja_variables(
schema.validation_prompt
)
message_history.append(
{"role": "assistant", "content": resp.choices[0].message.content}
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(

View File

@ -1,5 +1,4 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -180,19 +179,14 @@ class HierarchicalReduceDirective(Directive):
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=HierarchicalReduceInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
@ -214,11 +208,12 @@ class HierarchicalReduceDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(
@ -259,7 +254,7 @@ class HierarchicalReduceDirective(Directive):
"name": rewrite.map_config.name,
"type": "map",
"prompt": rewrite.map_config.prompt,
"model": default_model,
"model": default_model,
"litellm_completion_kwargs": {"temperature": 0},
"output": {"schema": {rewrite.map_config.output_keys[0]: "string"}},
}
@ -324,4 +319,4 @@ class HierarchicalReduceDirective(Directive):
new_ops_plan = self.apply(
global_default_model, operators, target_ops[0], rewrite
)
return new_ops_plan, message_history, call_cost
return new_ops_plan, message_history, call_cost

View File

@ -1,5 +1,4 @@
import json
import os
import re
from copy import deepcopy
from typing import Dict, List, Type
@ -260,17 +259,14 @@ class IsolatingSubtasksDirective(Directive):
original_op.get("output", {}).get("schema", {}).keys()
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=IsolatingSubtasksInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = IsolatingSubtasksInstantiateSchema(**parsed_res)
@ -285,11 +281,12 @@ class IsolatingSubtasksDirective(Directive):
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease ensure all original output keys are covered by subtasks and try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(

View File

@ -1,5 +1,4 @@
import json
import os
import re
from copy import deepcopy
from typing import Dict, List, Type
@ -7,14 +6,20 @@ from typing import Dict, List, Type
from litellm import completion
from pydantic import BaseModel, Field
from docetl.reasoning_optimizer.instantiate_schemas import MapReduceFusionInstantiateSchema
from docetl.reasoning_optimizer.instantiate_schemas import (
MapReduceFusionInstantiateSchema,
)
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
class MapReduceFusionDirective(Directive):
name: str = Field(default="map_reduce_fusion", description="The name of the directive")
formal_description: str = Field(default="Map -> Reduce => Map (with new prompt) -> Reduce (with new prompt)")
name: str = Field(
default="map_reduce_fusion", description="The name of the directive"
)
formal_description: str = Field(
default="Map -> Reduce => Map (with new prompt) -> Reduce (with new propmt)"
)
nl_description: str = Field(
default="Transform a Map operation followed by a Reduce operation over long documents by updating the Map prompt to first extract or process only the relevant information from each document. Then, modify the Reduce prompt to operate on these processed outputs instead of the full document content."
)
@ -42,7 +47,7 @@ class MapReduceFusionDirective(Directive):
" type: reduce\\n"
" reduce_key: category\\n"
" prompt: |\\n"
" For each category \\\"{{ reduce_key }}\\\", extract all organization names from these documents:\\n"
' For each category \\"{{ reduce_key }}\\", extract all organization names from these documents:\\n'
" {% for input in inputs %}\\n"
" Document {{ loop.index }}: {{ input.content }}\\n"
" {% endfor %}\\n"
@ -79,7 +84,7 @@ class MapReduceFusionDirective(Directive):
"reduce_key": "category",
"prompt": "For each category '{{ reduce_key }}', extract organizations from: {% for input in inputs %}{{ input.content }}{% endfor %}",
"output": {"schema": {"organizations": "list[str]"}},
}
},
],
target_ops=["classify_document", "extract_organizations"],
expected_behavior="Should modify map to classify AND extract organizations per document, then reduce to aggregate pre-extracted organizations",
@ -101,7 +106,7 @@ class MapReduceFusionDirective(Directive):
"reduce_key": "doc_type",
"prompt": "For each document type '{{ reduce_key }}', find all people mentioned: {% for input in inputs %}Document: {{ input.text }}{% endfor %}",
"output": {"schema": {"people": "list[str]"}},
}
},
],
target_ops=["analyze_document", "find_people"],
expected_behavior="Should modify map to analyze type AND extract people per document, then reduce to combine pre-extracted people lists",
@ -159,7 +164,7 @@ class MapReduceFusionDirective(Directive):
reduce_op: Dict,
expected_document_key,
agent_llm: str,
message_history: list = []
message_history: list = [],
):
"""
Use LLM to instantiate this directive by transforming the map and reduce operations.
@ -173,66 +178,85 @@ class MapReduceFusionDirective(Directive):
Returns:
MapReduceFusionInstantiateSchema: The structured output from the LLM.
"""
message_history.extend([
{"role": "system", "content": "You are a helpful AI assistant for optimizing document processing pipelines."},
{"role": "user", "content": self.to_string_for_instantiate([map_op, reduce_op])},
])
message_history.extend(
[
{
"role": "system",
"content": "You are a helpful AI assistant for optimizing document processing pipelines.",
},
{
"role": "user",
"content": self.to_string_for_instantiate([map_op, reduce_op]),
},
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=MapReduceFusionInstantiateSchema
response_format=MapReduceFusionInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
if "new_map_name" not in parsed_res or "new_map_prompt" not in parsed_res or "new_key" not in parsed_res or "new_reduce_prompt" not in parsed_res:
if (
"new_map_name" not in parsed_res
or "new_map_prompt" not in parsed_res
or "new_key" not in parsed_res
or "new_reduce_prompt" not in parsed_res
):
raise ValueError(
"Response from LLM is missing required keys: 'new_map_name', 'new_map_prompt', 'new_key', or 'new_reduce_prompt'"
)
schema = MapReduceFusionInstantiateSchema(
new_map_name=parsed_res["new_map_name"],
new_map_prompt=parsed_res["new_map_prompt"],
new_key=parsed_res["new_key"],
new_reduce_prompt=parsed_res["new_reduce_prompt"]
new_reduce_prompt=parsed_res["new_reduce_prompt"],
)
# Validate the schema
MapReduceFusionInstantiateSchema.validate_reduce_prompt_references_new_key(
schema.new_reduce_prompt, schema.new_key, expected_document_key
)
message_history.append(
{"role": "assistant", "content": resp.choices[0].message.content}
)
return schema, message_history, call_cost
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts.")
def apply(self, global_default_model, ops_list: List[Dict], map_target: str, reduce_target: str, rewrite: MapReduceFusionInstantiateSchema) -> List[Dict]:
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(
self,
global_default_model,
ops_list: List[Dict],
map_target: str,
reduce_target: str,
rewrite: MapReduceFusionInstantiateSchema,
) -> List[Dict]:
"""
Apply the directive to the pipeline config.
"""
# Create a copy of the pipeline config
new_ops_list = deepcopy(ops_list)
# Find positions of the target ops
map_pos = None
reduce_pos = None
orig_map_op = None
orig_reduce_op = None
for i, op in enumerate(ops_list):
if op["name"] == map_target:
map_pos = i
@ -240,13 +264,15 @@ class MapReduceFusionDirective(Directive):
elif op["name"] == reduce_target:
reduce_pos = i
orig_reduce_op = op
if map_pos is None or reduce_pos is None:
raise ValueError(f"Could not find target operations: {map_target}, {reduce_target}")
raise ValueError(
f"Could not find target operations: {map_target}, {reduce_target}"
)
# Get default model
default_model = orig_map_op.get("model", global_default_model)
# Create the new map operation with fused functionality
new_map_op = {
"name": rewrite.new_map_name,
@ -257,11 +283,11 @@ class MapReduceFusionDirective(Directive):
"output": {
"schema": {
**orig_map_op.get("output", {}).get("schema", {}),
rewrite.new_key: "list[str]"
rewrite.new_key: "list[str]",
}
}
},
}
# Create the new reduce operation that works with pre-extracted data
new_reduce_op = {
"name": orig_reduce_op["name"],
@ -270,54 +296,58 @@ class MapReduceFusionDirective(Directive):
"prompt": rewrite.new_reduce_prompt,
"model": default_model,
"litellm_completion_kwargs": {"temperature": 0},
"output": orig_reduce_op.get("output", {})
"output": orig_reduce_op.get("output", {}),
}
# Replace the operations
new_ops_list[map_pos] = new_map_op
new_ops_list[reduce_pos] = new_reduce_op
return new_ops_list
def instantiate(
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
optimize_goal="acc",
global_default_model: str = None,
**kwargs
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
optimize_goal="acc",
global_default_model: str = None,
**kwargs,
):
"""
Instantiate the directive for a list of operators.
"""
# Assert that there are exactly two target ops (map and reduce)
if len(target_ops) != 2:
raise ValueError("map_reduce_fusion directive requires exactly two target operators: a map and a reduce operation")
raise ValueError(
"map_reduce_fusion directive requires exactly two target operators: a map and a reduce operation"
)
# Find the map and reduce operations
first_op = None
second_op = None
for op in operators:
if op["name"] == target_ops[0]:
first_op = op
elif op["name"] == target_ops[1]:
second_op = op
if first_op is None or second_op is None:
raise ValueError(f"Could not find target operations: {target_ops}")
if first_op.get("type") == "map" and second_op.get("type") == "reduce":
map_op = first_op
reduce_op = second_op
map_target = target_ops[0]
reduce_target = target_ops[1]
else:
raise ValueError("Target operators must be one map operation followed by one reduce operation!")
# Extract expected document key from the reduce prompt template
raise ValueError(
"Target operators must be one map operation followed by one reduce operation!"
)
# Extract expected document key from the reduce prompt template
prompt_template = reduce_op["prompt"]
# Find all occurrences of {{ input.key }} in the prompt
input_key_pattern = r"\{\{\s*([^\}\s]+)\s*\}\}"
@ -335,7 +365,9 @@ class MapReduceFusionDirective(Directive):
]
if document_key_candidates:
expected_document_key = document_key_candidates[0] # Pick the first candidate
expected_document_key = document_key_candidates[
0
] # Pick the first candidate
elif input_keys:
expected_document_key = input_keys[0] # Fall back to the first input key
else:
@ -344,8 +376,12 @@ class MapReduceFusionDirective(Directive):
print(f"Detected document key: {expected_document_key}")
# Instantiate the directive
rewrite, message_history, call_cost = self.llm_instantiate(map_op, reduce_op, expected_document_key, agent_llm, message_history)
rewrite, message_history, call_cost = self.llm_instantiate(
map_op, reduce_op, expected_document_key, agent_llm, message_history
)
# Apply the rewrite to the operators
new_ops_plan = self.apply(global_default_model, operators, map_target, reduce_target, rewrite)
return new_ops_plan, message_history, call_cost
new_ops_plan = self.apply(
global_default_model, operators, map_target, reduce_target, rewrite
)
return new_ops_plan, message_history, call_cost

View File

@ -0,0 +1,361 @@
import json
from copy import deepcopy
from typing import Dict, List, Type
from litellm import completion
from pydantic import BaseModel, Field
from docetl.reasoning_optimizer.instantiate_schemas import (
MapResolveToMapWithCategoriesInstantiateSchema,
)
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
class MapResolveToMapWithCategoriesDirective(Directive):
name: str = Field(
default="map_resolve_to_map_with_categories",
description="The name of the directive",
)
formal_description: str = Field(default="Map -> Resolve => Map (with categories)")
nl_description: str = Field(
default="Replace a Map -> Resolve pattern with a single Map operation that has predefined categories. The agent analyzes the data and task to propose a set of canonical categories, and the new Map forces outputs into one of these categories (or 'none of the above'). This effectively performs entity resolution deterministically by standardizing outputs upfront, avoiding the need for pairwise comparisons."
)
when_to_use: str = Field(
default="When a Map operation produces outputs that are then resolved/deduplicated, and the set of valid output categories can be enumerated upfront. This is more efficient than Resolve when the category space is small and well-defined (e.g., standardizing company types, product categories, sentiment labels). The target must be a Map operation followed by a Resolve operation."
)
instantiate_schema_type: Type[BaseModel] = (
MapResolveToMapWithCategoriesInstantiateSchema
)
example: str = Field(
default=(
"Original Pipeline:\n"
"- name: extract_sentiment\n"
" type: map\n"
" prompt: |\n"
" What is the sentiment of this review?\n"
" {{ input.review }}\n"
" output:\n"
" schema:\n"
" sentiment: string\n"
"\n"
"- name: normalize_sentiment\n"
" type: resolve\n"
" comparison_prompt: |\n"
" Are these sentiments equivalent?\n"
" Sentiment 1: {{ input1.sentiment }}\n"
" Sentiment 2: {{ input2.sentiment }}\n"
" resolution_prompt: |\n"
" Normalize these sentiments:\n"
" {% for input in inputs %}\n"
" - {{ input.sentiment }}\n"
" {% endfor %}\n"
" output:\n"
" schema:\n"
" sentiment: string\n"
"\n"
"Example InstantiateSchema:\n"
"MapResolveToMapWithCategoriesInstantiateSchema(\n"
" categories=['Positive', 'Negative', 'Neutral', 'Mixed'],\n"
" category_key='sentiment',\n"
" new_prompt='''Analyze the sentiment of this review and classify it into one of the following categories:\n"
"\n"
"Categories:\n"
"- Positive: Clearly positive sentiment, satisfaction, praise\n"
"- Negative: Clearly negative sentiment, complaints, criticism\n"
"- Neutral: No strong sentiment, factual statements\n"
"- Mixed: Contains both positive and negative elements\n"
"- None of the above: If the review doesn't fit any category\n"
"\n"
"Review: {{ input.review }}\n"
"\n"
"Return exactly one of: Positive, Negative, Neutral, Mixed, or None of the above.''',\n"
" include_none_of_above=True,\n"
")"
),
)
test_cases: List[DirectiveTestCase] = Field(
default_factory=lambda: [
DirectiveTestCase(
name="sentiment_categorization",
description="Should replace map+resolve with categorized map for sentiment",
input_config=[
{
"name": "extract_sentiment",
"type": "map",
"prompt": "What is the sentiment? {{ input.text }}",
"output": {"schema": {"sentiment": "string"}},
},
{
"name": "normalize_sentiment",
"type": "resolve",
"comparison_prompt": "Same sentiment? {{ input1.sentiment }} vs {{ input2.sentiment }}",
"resolution_prompt": "Normalize: {% for input in inputs %}{{ input.sentiment }}{% endfor %}",
"output": {"schema": {"sentiment": "string"}},
},
],
target_ops=["extract_sentiment", "normalize_sentiment"],
expected_behavior="Should create a single map with predefined sentiment categories",
should_pass=True,
),
]
)
def __eq__(self, other):
return isinstance(other, MapResolveToMapWithCategoriesDirective)
def __hash__(self):
return hash("MapResolveToMapWithCategoriesDirective")
def to_string_for_instantiate(
self, map_op: Dict, resolve_op: Dict, sample_data: List[Dict] = None
) -> str:
"""
Generate a prompt for an agent to instantiate this directive.
Args:
map_op (Dict): The map operation configuration.
resolve_op (Dict): The resolve operation configuration.
sample_data (List[Dict], optional): Sample data to help identify categories.
Returns:
str: The agent prompt for instantiating the directive.
"""
sample_str = ""
if sample_data:
sample_str = f"\n\nSample Input Data (first 5 items):\n{json.dumps(sample_data[:5], indent=2)}"
return (
f"You are an expert at optimizing data processing pipelines by replacing entity resolution with categorical constraints.\n\n"
f"Map Operation:\n"
f"{str(map_op)}\n\n"
f"Resolve Operation:\n"
f"{str(resolve_op)}\n\n"
f"Directive: {self.name}\n"
f"Your task is to replace the Map -> Resolve pattern with a single Map that uses predefined categories.\n\n"
f"Key Requirements:\n"
f"1. Analyze the map's output field and the resolve operation to understand what values are being normalized:\n"
f" - Look at the comparison_prompt to understand what variations are being matched\n"
f" - Look at the resolution_prompt to understand the canonical form\n\n"
f"2. Propose a set of categories that cover all expected outputs:\n"
f" - Categories should be mutually exclusive\n"
f" - Categories should be exhaustive (cover all realistic cases)\n"
f" - Consider including 'None of the above' for edge cases\n\n"
f"3. Optionally provide descriptions for each category to help the LLM classify correctly\n\n"
f"4. Create a new_prompt that:\n"
f" - Lists all valid categories with their descriptions\n"
f" - Instructs the LLM to output exactly one category\n"
f" - References the input using {{{{ input.key }}}} syntax\n"
f" - Includes any context from the original map prompt\n\n"
f"5. Identify the category_key (the output field that will contain the category)\n\n"
f"Benefits of this approach:\n"
f"- Eliminates O(n^2) pairwise comparisons from Resolve\n"
f"- Produces consistent, standardized outputs\n"
f"- Reduces cost by removing the Resolve operation entirely\n"
f"{sample_str}\n\n"
f"Example:\n"
f"{self.example}\n\n"
f"Please analyze the operations and propose appropriate categories. Output the MapResolveToMapWithCategoriesInstantiateSchema."
)
def llm_instantiate(
self,
map_op: Dict,
resolve_op: Dict,
agent_llm: str,
message_history: list = [],
sample_data: List[Dict] = None,
):
"""
Use LLM to instantiate this directive.
Args:
map_op (Dict): The map operation configuration.
resolve_op (Dict): The resolve operation configuration.
agent_llm (str): The LLM model to use.
message_history (List, optional): Conversation history for context.
sample_data (List[Dict], optional): Sample data to help identify categories.
Returns:
MapResolveToMapWithCategoriesInstantiateSchema: The structured output from the LLM.
"""
message_history.extend(
[
{
"role": "system",
"content": "You are a helpful AI assistant for document processing pipelines specializing in categorical classification.",
},
{
"role": "user",
"content": self.to_string_for_instantiate(
map_op, resolve_op, sample_data
),
},
]
)
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
response_format=MapResolveToMapWithCategoriesInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = MapResolveToMapWithCategoriesInstantiateSchema(**parsed_res)
message_history.append(
{"role": "assistant", "content": resp.choices[0].message.content}
)
return schema, message_history, call_cost
except Exception as err:
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
)
def apply(
self,
global_default_model,
ops_list: List[Dict],
map_op_name: str,
resolve_op_name: str,
rewrite: MapResolveToMapWithCategoriesInstantiateSchema,
) -> List[Dict]:
"""
Apply the directive to the pipeline config.
"""
new_ops_list = deepcopy(ops_list)
# Find position of the map and resolve ops
map_pos = None
resolve_pos = None
map_op = None
for i, op in enumerate(new_ops_list):
if op["name"] == map_op_name:
map_pos = i
map_op = op
elif op["name"] == resolve_op_name:
resolve_pos = i
if map_pos is None or resolve_pos is None:
raise ValueError(
f"Could not find map '{map_op_name}' and resolve '{resolve_op_name}' operations"
)
# Determine the model to use
default_model = map_op.get("model", global_default_model)
# Build the list of valid values for validation
valid_values = list(rewrite.categories)
if rewrite.include_none_of_above:
valid_values.append("None of the above")
# Modify the map operation with the new prompt and add validation
new_ops_list[map_pos]["prompt"] = rewrite.new_prompt
new_ops_list[map_pos]["model"] = default_model
# Add validation to ensure output is one of the categories
if "validate" not in new_ops_list[map_pos]:
new_ops_list[map_pos]["validate"] = []
# Add validation rule for the category key
validation_rule = f"output['{rewrite.category_key}'] in {valid_values}"
new_ops_list[map_pos]["validate"].append(validation_rule)
# Update the output schema to reflect the category key
if "output" not in new_ops_list[map_pos]:
new_ops_list[map_pos]["output"] = {"schema": {}}
new_ops_list[map_pos]["output"]["schema"][rewrite.category_key] = "string"
# Remove the resolve operation
new_ops_list.pop(resolve_pos)
return new_ops_list
def instantiate(
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
optimize_goal="acc",
global_default_model: str = None,
dataset: str = None,
**kwargs,
):
"""
Instantiate the directive for a list of operators.
"""
# Assert that there are exactly two target ops (map and resolve)
assert (
len(target_ops) == 2
), "There must be exactly two target ops (map and resolve) to instantiate this directive"
# Find the map and resolve operations
map_op = None
resolve_op = None
for op in operators:
if op["name"] == target_ops[0]:
if op.get("type") == "map":
map_op = op
elif op.get("type") == "resolve":
resolve_op = op
elif op["name"] == target_ops[1]:
if op.get("type") == "map":
map_op = op
elif op.get("type") == "resolve":
resolve_op = op
if map_op is None or resolve_op is None:
raise ValueError(
f"Could not find both a map and resolve operation in target_ops: {target_ops}"
)
# Verify the map comes before resolve
map_idx = next(
i for i, op in enumerate(operators) if op["name"] == map_op["name"]
)
resolve_idx = next(
i for i, op in enumerate(operators) if op["name"] == resolve_op["name"]
)
if map_idx >= resolve_idx:
raise ValueError(
f"Map operation '{map_op['name']}' must come before resolve operation '{resolve_op['name']}'"
)
# Load sample data if available
sample_data = None
if dataset:
try:
with open(dataset, "r") as f:
sample_data = json.load(f)
except Exception:
pass # Ignore if we can't load sample data
# Instantiate the directive
rewrite, message_history, call_cost = self.llm_instantiate(
map_op,
resolve_op,
agent_llm,
message_history,
sample_data,
)
# Apply the rewrite to the operators
new_ops_plan = self.apply(
global_default_model, operators, map_op["name"], resolve_op["name"], rewrite
)
return new_ops_plan, message_history, call_cost
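For a concrete sense of what this directive does, here is a minimal sketch that hand-builds the rewrite schema (skipping `llm_instantiate`) and applies it to a toy map + resolve pipeline. The import paths follow the test module later in this diff; the model name and prompts are placeholders.

```python
# Sketch only: apply the categories rewrite without any LLM call.
from docetl.reasoning_optimizer.directives import MapResolveToMapWithCategoriesDirective
from docetl.reasoning_optimizer.instantiate_schemas import (
    MapResolveToMapWithCategoriesInstantiateSchema,
)

ops = [
    {
        "name": "extract_sentiment",
        "type": "map",
        "prompt": "What is the sentiment? {{ input.text }}",
        "output": {"schema": {"sentiment": "string"}},
    },
    {
        "name": "normalize_sentiment",
        "type": "resolve",
        "comparison_prompt": "Same sentiment? {{ input1.sentiment }} vs {{ input2.sentiment }}",
        "resolution_prompt": "Normalize: {% for input in inputs %}{{ input.sentiment }}{% endfor %}",
        "output": {"schema": {"sentiment": "string"}},
    },
]

rewrite = MapResolveToMapWithCategoriesInstantiateSchema(
    categories=["Positive", "Negative", "Neutral", "Mixed"],
    category_key="sentiment",
    new_prompt="Classify the sentiment of {{ input.text }} as Positive, Negative, Neutral, or Mixed.",
    include_none_of_above=True,
)

new_ops = MapResolveToMapWithCategoriesDirective().apply(
    "gpt-4o-mini", ops, "extract_sentiment", "normalize_sentiment", rewrite
)

assert len(new_ops) == 1       # the resolve op is removed
print(new_ops[0]["validate"])  # membership check over the categories plus "None of the above"
```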

View File

@ -0,0 +1,335 @@
import json
from copy import deepcopy
from typing import Dict, List, Type
from litellm import completion
from pydantic import BaseModel, Field
from docetl.reasoning_optimizer.instantiate_schemas import (
MapToMapResolveReduceInstantiateSchema,
)
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
class MapToMapResolveReduceDirective(Directive):
name: str = Field(
default="map_to_map_resolve_reduce", description="The name of the directive"
)
formal_description: str = Field(default="Map -> Reduce => Map -> Resolve -> Reduce")
nl_description: str = Field(
default="Insert a Resolve operation between Map and Reduce to deduplicate or normalize entities before aggregation. The Resolve operation uses code-powered blocking conditions to efficiently identify which pairs to compare, avoiding O(n^2) comparisons. This is useful when the Map output contains duplicate or near-duplicate entities that should be merged before the Reduce step."
)
when_to_use: str = Field(
default="When a Map operation produces outputs that may contain duplicates, variations, or near-duplicates (e.g., different spellings of names, similar categories), and these should be normalized before the Reduce aggregation step. The target must be a Map operation followed by a Reduce operation."
)
instantiate_schema_type: Type[BaseModel] = MapToMapResolveReduceInstantiateSchema
example: str = Field(
default=(
"Original Pipeline:\n"
"- name: extract_companies\n"
" type: map\n"
" prompt: |\n"
" Extract company names from this news article:\n"
" {{ input.article }}\n"
" output:\n"
" schema:\n"
" company_name: string\n"
"\n"
"- name: aggregate_companies\n"
" type: reduce\n"
" reduce_key: sector\n"
" prompt: |\n"
" List all unique companies in this sector:\n"
" {% for input in inputs %}\n"
" - {{ input.company_name }}\n"
" {% endfor %}\n"
" output:\n"
" schema:\n"
" companies: list[str]\n"
"\n"
"Example InstantiateSchema:\n"
"MapToMapResolveReduceInstantiateSchema(\n"
" resolve_name='normalize_company_names',\n"
" comparison_prompt='''Are these two company names referring to the same company?\n"
"Company 1: {{ input1.company_name }}\n"
"Company 2: {{ input2.company_name }}\n"
"Consider variations like abbreviations (IBM vs International Business Machines), \n"
"different legal suffixes (Inc, Corp, LLC), and common misspellings.''',\n"
" resolution_prompt='''Given these variations of a company name:\n"
"{% for input in inputs %}\n"
"- {{ input.company_name }}\n"
"{% endfor %}\n"
"Return the canonical/official company name.''',\n"
" blocking_conditions=[\n"
" \"input1['company_name'][:3].lower() == input2['company_name'][:3].lower()\",\n"
" \"input1['company_name'].split()[0].lower() == input2['company_name'].split()[0].lower()\",\n"
" ],\n"
" blocking_keys=['company_name'],\n"
" limit_comparisons=1000,\n"
" output_schema={'company_name': 'string'},\n"
")"
),
)
test_cases: List[DirectiveTestCase] = Field(
default_factory=lambda: [
DirectiveTestCase(
name="company_name_normalization",
description="Should insert resolve between map and reduce for company names",
input_config=[
{
"name": "extract_companies",
"type": "map",
"prompt": "Extract company name from: {{ input.text }}",
"output": {"schema": {"company_name": "string"}},
},
{
"name": "aggregate_by_sector",
"type": "reduce",
"reduce_key": "sector",
"prompt": "List companies:\n{% for input in inputs %}\n- {{ input.company_name }}\n{% endfor %}",
"output": {"schema": {"companies": "list[str]"}},
},
],
target_ops=["extract_companies", "aggregate_by_sector"],
expected_behavior="Should create a resolve operation between the map and reduce with appropriate blocking conditions",
should_pass=True,
),
]
)
def __eq__(self, other):
return isinstance(other, MapToMapResolveReduceDirective)
def __hash__(self):
return hash("MapToMapResolveReduceDirective")
def to_string_for_instantiate(self, map_op: Dict, reduce_op: Dict) -> str:
"""
Generate a prompt for an agent to instantiate this directive.
Args:
map_op (Dict): The map operation configuration.
reduce_op (Dict): The reduce operation configuration.
Returns:
str: The agent prompt for instantiating the directive.
"""
return (
f"You are an expert at optimizing data processing pipelines by inserting entity resolution steps.\n\n"
f"Map Operation:\n"
f"{str(map_op)}\n\n"
f"Reduce Operation:\n"
f"{str(reduce_op)}\n\n"
f"Directive: {self.name}\n"
f"Your task is to insert a Resolve operation between the Map and Reduce to deduplicate/normalize entities.\n\n"
f"Key Requirements:\n"
f"1. Create a comparison_prompt that determines if two items from the Map output are duplicates/variations:\n"
f" - Must reference {{ input1.key }} and {{ input2.key }} for comparing fields\n"
f" - Should handle common variations (abbreviations, misspellings, formatting differences)\n\n"
f"2. Create a resolution_prompt that merges matched items into a canonical form:\n"
f" - Must use {{% for input in inputs %}} to iterate over matched items\n"
f" - Should produce the most authoritative/complete version\n\n"
f"3. Create blocking_conditions to avoid O(n^2) comparisons:\n"
f" - These are Python expressions with access to 'input1' and 'input2' dicts\n"
f" - They should filter pairs to only those likely to match\n"
f" - Examples:\n"
f" * \"input1['name'][:3].lower() == input2['name'][:3].lower()\" (first 3 chars match)\n"
f" * \"input1['name'].split()[0].lower() == input2['name'].split()[0].lower()\" (first word matches)\n"
f" * \"abs(len(input1['name']) - len(input2['name'])) < 10\" (similar length)\n"
f" - Multiple conditions are OR'd together\n\n"
f"4. Set limit_comparisons to cap the number of pairs (recommended: 500-2000)\n\n"
f"5. The output_schema should match what the Reduce operation expects from each input\n\n"
f"Example:\n"
f"{self.example}\n\n"
f"Please output the MapToMapResolveReduceInstantiateSchema."
)
def llm_instantiate(
self, map_op: Dict, reduce_op: Dict, agent_llm: str, message_history: list = []
):
"""
Use LLM to instantiate this directive.
Args:
map_op (Dict): The map operation configuration.
reduce_op (Dict): The reduce operation configuration.
agent_llm (str): The LLM model to use.
message_history (List, optional): Conversation history for context.
Returns:
MapToMapResolveReduceInstantiateSchema: The structured output from the LLM.
"""
message_history.extend(
[
{
"role": "system",
"content": "You are a helpful AI assistant for document processing pipelines specializing in entity resolution.",
},
{
"role": "user",
"content": self.to_string_for_instantiate(map_op, reduce_op),
},
]
)
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
response_format=MapToMapResolveReduceInstantiateSchema,
)
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = MapToMapResolveReduceInstantiateSchema(**parsed_res)
message_history.append(
{"role": "assistant", "content": resp.choices[0].message.content}
)
return schema, message_history, call_cost
except Exception as err:
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
)
def apply(
self,
global_default_model,
ops_list: List[Dict],
map_op_name: str,
reduce_op_name: str,
rewrite: MapToMapResolveReduceInstantiateSchema,
) -> List[Dict]:
"""
Apply the directive to the pipeline config.
"""
new_ops_list = deepcopy(ops_list)
# Find position of the map and reduce ops
map_pos = None
reduce_pos = None
map_op = None
for i, op in enumerate(ops_list):
if op["name"] == map_op_name:
map_pos = i
map_op = op
elif op["name"] == reduce_op_name:
reduce_pos = i
if map_pos is None or reduce_pos is None:
raise ValueError(
f"Could not find map '{map_op_name}' and reduce '{reduce_op_name}' operations"
)
# Determine the model to use
default_model = map_op.get("model", global_default_model)
# Find the reduce operation to get the reduce_key
reduce_op = None
for op in ops_list:
if op["name"] == reduce_op_name:
reduce_op = op
break
# Derive output schema from reduce_key - that's what's being grouped/resolved
reduce_key = reduce_op.get("reduce_key", []) if reduce_op else []
if isinstance(reduce_key, str):
reduce_key = [reduce_key]
# Build output schema from reduce_key fields
output_schema = {key: "string" for key in reduce_key}
# Create the resolve operation
resolve_op = {
"name": rewrite.resolve_name,
"type": "resolve",
"comparison_prompt": rewrite.comparison_prompt,
"resolution_prompt": rewrite.resolution_prompt,
"blocking_conditions": rewrite.blocking_conditions,
"blocking_keys": rewrite.blocking_keys,
"limit_comparisons": rewrite.limit_comparisons,
"model": default_model,
"output": {"schema": output_schema},
}
# Insert resolve operation after the map operation
new_ops_list.insert(map_pos + 1, resolve_op)
return new_ops_list
def instantiate(
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
optimize_goal="acc",
global_default_model: str = None,
**kwargs,
):
"""
Instantiate the directive for a list of operators.
"""
# Assert that there are exactly two target ops (map and reduce)
assert (
len(target_ops) == 2
), "There must be exactly two target ops (map and reduce) to instantiate this directive"
# Find the map and reduce operations
map_op = None
reduce_op = None
for op in operators:
if op["name"] == target_ops[0]:
if op.get("type") == "map":
map_op = op
elif op.get("type") == "reduce":
reduce_op = op
elif op["name"] == target_ops[1]:
if op.get("type") == "map":
map_op = op
elif op.get("type") == "reduce":
reduce_op = op
if map_op is None or reduce_op is None:
raise ValueError(
f"Could not find both a map and reduce operation in target_ops: {target_ops}"
)
# Verify the map comes before reduce
map_idx = next(
i for i, op in enumerate(operators) if op["name"] == map_op["name"]
)
reduce_idx = next(
i for i, op in enumerate(operators) if op["name"] == reduce_op["name"]
)
if map_idx >= reduce_idx:
raise ValueError(
f"Map operation '{map_op['name']}' must come before reduce operation '{reduce_op['name']}'"
)
# Instantiate the directive
rewrite, message_history, call_cost = self.llm_instantiate(
map_op,
reduce_op,
agent_llm,
message_history,
)
# Apply the rewrite to the operators
new_ops_plan = self.apply(
global_default_model, operators, map_op["name"], reduce_op["name"], rewrite
)
return new_ops_plan, message_history, call_cost
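The `blocking_conditions` produced by this directive are plain Python expressions that are OR'd together over candidate pairs. The snippet below is purely illustrative (it is not docetl's internal blocking code); it only shows how a couple of such expressions shrink the comparison set before any LLM calls are made.

```python
# Illustration only: evaluate OR'd blocking conditions over all candidate pairs.
from itertools import combinations

items = [
    {"company_name": "IBM Corp"},
    {"company_name": "IBM Corporation"},
    {"company_name": "Intl Business Machines"},
    {"company_name": "Acme LLC"},
]

blocking_conditions = [
    "input1['company_name'][:3].lower() == input2['company_name'][:3].lower()",
    "input1['company_name'].split()[0].lower() == input2['company_name'].split()[0].lower()",
]

survivors = [
    (a["company_name"], b["company_name"])
    for a, b in combinations(items, 2)
    if any(eval(cond, {}, {"input1": a, "input2": b}) for cond in blocking_conditions)
]
print(survivors)  # only ('IBM Corp', 'IBM Corporation') survives out of the 6 possible pairs
```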

View File

@ -1,5 +1,4 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -172,17 +171,14 @@ class OperatorFusionDirective(Directive):
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=OperatorFusionInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = OperatorFusionInstantiateSchema(**parsed_res)
@ -191,11 +187,12 @@ class OperatorFusionDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(
@ -215,7 +212,6 @@ class OperatorFusionDirective(Directive):
new_ops_list = deepcopy(ops_list)
op1_name, op2_name = target_ops[0], target_ops[1]
# Find the operations
op1_idx = next(i for i, op in enumerate(ops_list) if op["name"] == op1_name)
op2_idx = next(i for i, op in enumerate(ops_list) if op["name"] == op2_name)
@ -329,5 +325,5 @@ def transform(input_doc):
return (
self.apply(global_default_model, operators, target_ops, rewrite),
message_history,
call_cost
call_cost,
)

View File

@ -1,5 +1,4 @@
import json
import os
import re
from copy import deepcopy
from typing import Dict, List, Type
@ -127,7 +126,7 @@ class ReduceChainingDirective(Directive):
original_op: Dict,
expected_document_key: str,
agent_llm: str,
message_history: list = []
message_history: list = [],
):
"""
Use LLM to instantiate this directive by decomposing the reduce operation.
@ -155,19 +154,15 @@ class ReduceChainingDirective(Directive):
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
call_cost = 0
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=ReduceChainingInstantiateSchema
response_format=ReduceChainingInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = ReduceChainingInstantiateSchema(**parsed_res)
@ -182,11 +177,12 @@ class ReduceChainingDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(
@ -283,7 +279,9 @@ class ReduceChainingDirective(Directive):
]
if document_key_candidates:
expected_document_key = document_key_candidates[0] # Pick the first candidate
expected_document_key = document_key_candidates[
0
] # Pick the first candidate
elif input_keys:
expected_document_key = input_keys[0] # Fall back to the first input key
else:

View File

@ -1,5 +1,4 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -255,17 +254,14 @@ class ReduceGleaningDirective(Directive):
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=GleaningInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = GleaningInstantiateSchema(**parsed_res)
@ -274,11 +270,12 @@ class ReduceGleaningDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(
@ -341,5 +338,6 @@ class ReduceGleaningDirective(Directive):
# Apply the rewrite to the operators
return (
self.apply(global_default_model, operators, target_ops[0], rewrite),
message_history, call_cost
message_history,
call_cost,
)

View File

@ -1,5 +1,4 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -234,18 +233,14 @@ class TakeHeadTailDirective(Directive):
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
call_cost = 0
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=TakeHeadTailInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = TakeHeadTailInstantiateSchema(**parsed_res)
@ -254,11 +249,12 @@ class TakeHeadTailDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(

View File

@ -61,7 +61,7 @@ class MapOpConfig(BaseModel):
...,
description="The keys of the output of the Map operator, to be referenced in the downstream operator's prompt. Can be a single key or a list of keys. Can be new keys or existing keys from the map operator we are rewriting.",
)
@classmethod
def validate_prompt_contains_input_key(cls, value: str) -> str:
"""
@ -445,8 +445,12 @@ class DeterministicDocCompressionInstantiateSchema(BaseModel):
raise ValueError(
"Code must define a function named 'transform' that takes input_doc as parameter"
)
if "return {" not in v and "return dict(" not in v:
raise ValueError("Code must return a dictionary")
if (
"return {" not in v
and "return dict(" not in v
and "output = dict(" not in v
):
raise ValueError(f"Code must return a dictionary. Instead the code is: {v}")
return v
def validate_code_returns_target_keys(self, target_ops_configs: List[Dict]) -> None:
@ -563,7 +567,6 @@ class ChunkHeaderSummaryInstantiateSchema(BaseModel):
return v
class SamplingConfig(BaseModel):
"""Configuration for optional sampling in document chunking."""
@ -634,7 +637,7 @@ class DocumentChunkingInstantiateSchema(BaseModel):
default=None,
description="Optional sampling configuration. If provided, inserts a Sample operation between Gather and Map. Use by default UNLESS task requires processing ALL chunks (like comprehensive extraction of all instances).",
)
def validate_split_key_exists_in_input(self, input_file_path: str) -> None:
"""
Validates that the split_key exists in the input JSON file items.
@ -812,7 +815,6 @@ class DocumentChunkingTopKInstantiateSchema(BaseModel):
description="Configuration for the topk operation to select relevant chunks",
)
def validate_split_key_exists_in_input(self, input_file_path: str) -> None:
"""
Validates that the split_key exists in the input JSON file items.
@ -1211,6 +1213,109 @@ class SearchReplaceEdit(BaseModel):
# )
class MapToMapResolveReduceInstantiateSchema(BaseModel):
"""
Schema for inserting a Resolve operation between Map and Reduce.
Transforms Map -> Reduce => Map -> Resolve -> Reduce pattern.
The Resolve operation deduplicates/normalizes entities before aggregation.
"""
resolve_name: str = Field(..., description="The name of the new Resolve operator")
comparison_prompt: str = Field(
...,
description="Jinja prompt template for comparing two items. Must use {{ input1.key }} and {{ input2.key }} to reference fields from both items being compared.",
)
resolution_prompt: str = Field(
...,
description="Jinja prompt template for resolving matched items into a canonical form. Must use {% for input in inputs %} to iterate over matched items.",
)
blocking_conditions: List[str] = Field(
...,
description="List of Python expressions that determine if two items should be compared. Each expression has access to 'input1' and 'input2' dicts. Example: \"input1['category'].lower() == input2['category'].lower()\"",
)
blocking_keys: List[str] = Field(
...,
description="Keys to use for blocking/grouping items before comparison. Must include at least the reduce_key from the downstream Reduce operation, plus any additional context helpful for resolution.",
)
limit_comparisons: int = Field(
default=15000,
description="Maximum number of pairs to compare. Code-based blocked pairs are prioritized. Defaults to 15000 to avoid O(n^2) comparisons.",
gt=0,
)
@field_validator("comparison_prompt")
@classmethod
def check_comparison_prompt(cls, v: str) -> str:
if "input1" not in v or "input2" not in v:
raise ValueError(
"comparison_prompt must reference both 'input1' and 'input2' variables"
)
return v
@field_validator("resolution_prompt")
@classmethod
def check_resolution_prompt(cls, v: str) -> str:
if "inputs" not in v:
raise ValueError(
"resolution_prompt must reference 'inputs' variable for iterating over matched items"
)
return v
@field_validator("blocking_conditions")
@classmethod
def check_blocking_conditions(cls, v: List[str]) -> List[str]:
if not v:
raise ValueError(
"At least one blocking condition must be provided to avoid O(n^2) comparisons"
)
for condition in v:
if "input1" not in condition or "input2" not in condition:
raise ValueError(
f"Blocking condition must reference both 'input1' and 'input2': {condition}"
)
return v
class MapResolveToMapWithCategoriesInstantiateSchema(BaseModel):
"""
Schema for replacing Map -> Resolve with a single Map that has predefined categories.
The agent proposes categories based on analysis of the data/task, and the new Map
forces outputs into one of these categories (or 'none of the above'), effectively
doing entity resolution deterministically.
"""
categories: List[str] = Field(
...,
description="List of valid category values that the map output should be constrained to. Should include all distinct canonical values discovered from analyzing the data/task.",
)
category_key: str = Field(
...,
description="The key in the output schema that will contain the category value",
)
new_prompt: str = Field(
...,
description="The new prompt for the Map operation that includes the category list and instructs the LLM to output one of the predefined categories. Must reference {{ input.key }} for input fields.",
)
include_none_of_above: bool = Field(
default=True,
description="Whether to include 'None of the above' as a valid category option for items that don't fit any category",
)
@field_validator("categories")
@classmethod
def check_categories(cls, v: List[str]) -> List[str]:
if len(v) < 2:
raise ValueError("At least 2 categories must be provided")
if len(v) != len(set(v)):
raise ValueError("Categories must be unique")
return v
@field_validator("new_prompt")
@classmethod
def check_new_prompt(cls, v: str) -> str:
return MapOpConfig.validate_prompt_contains_input_key(v)
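Both of the new schemas above validate their prompt templates eagerly, so a malformed agent response is bounced back for a retry instead of silently producing a broken operator. A quick illustrative check (sketch only):

```python
from pydantic import ValidationError

from docetl.reasoning_optimizer.instantiate_schemas import (
    MapToMapResolveReduceInstantiateSchema,
)

try:
    MapToMapResolveReduceInstantiateSchema(
        resolve_name="normalize_names",
        comparison_prompt="Same company? {{ input1.company_name }}",  # missing input2 -> rejected
        resolution_prompt="{% for input in inputs %}{{ input.company_name }}{% endfor %}",
        blocking_conditions=["input1['company_name'][:3] == input2['company_name'][:3]"],
        blocking_keys=["company_name"],
    )
except ValidationError as exc:
    print(exc)  # comparison_prompt must reference both 'input1' and 'input2' variables
```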
class ArbitraryRewriteInstantiateSchema(BaseModel):
"""
Schema for arbitrary pipeline rewrites using search/replace edits.

View File

@ -699,7 +699,7 @@ class DSLRunner(ConfigWrapper):
"litellm_kwargs", {}
)
kwargs["rewrite_agent_model"] = self.config.get("optimizer_config", {}).get(
"rewrite_agent_model", "gpt-4o"
"rewrite_agent_model", "gpt-5.1"
)
kwargs["judge_agent_model"] = self.config.get("optimizer_config", {}).get(
"judge_agent_model", "gpt-4o-mini"
@ -727,7 +727,7 @@ class DSLRunner(ConfigWrapper):
"litellm_kwargs", {}
)
kwargs["rewrite_agent_model"] = self.config.get("optimizer_config", {}).get(
"rewrite_agent_model", "gpt-4o"
"rewrite_agent_model", "gpt-5.1"
)
kwargs["judge_agent_model"] = self.config.get("optimizer_config", {}).get(
"judge_agent_model", "gpt-4o-mini"

View File

@ -34,9 +34,9 @@ This Equijoin operation matches job candidates to job postings:
The prompt template uses Jinja2 syntax, allowing you to reference input fields directly (e.g., `left.skills`). You can reference the left and right documents using `left` and `right` respectively.
!!! warning "Performance Consideration"
!!! info "Automatic Blocking"
For large datasets, running comparisons with an LLM can be time-consuming. It's recommended to optimize your pipeline using `docetl build pipeline.yaml` to generate efficient blocking rules for the operation.
If you don't specify any blocking configuration (`blocking_threshold`, `blocking_conditions`, or `limit_comparisons`), the Equijoin operation will automatically compute an optimal embedding-based blocking threshold at runtime. It samples pairs from your data, runs LLM comparisons on the sample, and finds a threshold that achieves 95% recall by default. You can adjust this with the `blocking_target_recall` parameter.
## Blocking
@ -95,10 +95,19 @@ A full Equijoin step combining both ideas might look like:
Equijoin shares many parameters with the Resolve operation. For a detailed list of required and optional parameters, please see the [Parameters section in the Resolve operation documentation](resolve.md#required-parameters).
Key differences for Equijoin include:
### Equijoin-Specific Parameters
| Parameter | Description | Default |
| ------------------------- | --------------------------------------------------------------------------------- | ----------------------------- |
| `limits` | Maximum matches for each left/right item: `{"left": n, "right": m}` | No limit |
| `blocking_keys` | Keys for embedding blocking: `{"left": [...], "right": [...]}` | All keys from each dataset |
| `blocking_threshold` | Embedding similarity threshold for considering pairs | Auto-computed if not set |
| `blocking_target_recall` | Target recall when auto-computing blocking threshold (0.0 to 1.0) | 0.95 |
Key differences from Resolve:
- `resolution_prompt` is not used in Equijoin.
- `limits` parameter is specific to Equijoin, allowing you to set maximum matches for each left and right item.
- `blocking_keys` uses a dict with `left` and `right` keys instead of a simple list.
## Incorporating Into a Pipeline

View File

@ -44,9 +44,9 @@ This Resolve operation processes patient names to identify and standardize dupli
Note: The prompt templates use Jinja2 syntax, allowing you to reference input fields directly (e.g., `input1.patient_name`).
!!! warning "Performance Consideration"
!!! info "Automatic Blocking"
You should not run this operation as-is unless your dataset is small! Running O(n^2) comparisons with an LLM can be extremely time-consuming for large datasets. Instead, optimize your pipeline first using `docetl build pipeline.yaml` and run the optimized version, which will generate efficient blocking rules for the operation. Make sure you've set `optimize: true` in your resolve operation config.
If you don't specify any blocking configuration (`blocking_threshold`, `blocking_conditions`, or `limit_comparisons`), the Resolve operation will automatically compute an optimal embedding-based blocking threshold at runtime. It samples pairs from your data, runs LLM comparisons on the sample, and finds a threshold that achieves 95% recall by default. You can adjust this with the `blocking_target_recall` parameter.
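Conceptually, the auto-computed threshold is the highest embedding similarity that still retains the target fraction of LLM-confirmed matches in the sample. A rough sketch of the idea (not the exact implementation) is shown below.

```python
# Illustrative sketch: pick the highest threshold that keeps target_recall of sampled matches.
def pick_blocking_threshold(sampled_pairs, target_recall=0.95):
    # sampled_pairs: list of (embedding_similarity, llm_says_match) tuples
    match_sims = sorted(sim for sim, is_match in sampled_pairs if is_match)
    if not match_sims:
        return 0.0  # nothing matched in the sample; keep every pair
    allowed_misses = int(len(match_sims) * (1 - target_recall))
    return match_sims[allowed_misses]

pairs = [(0.91, True), (0.83, True), (0.62, True), (0.40, False), (0.35, False)]
print(pick_blocking_threshold(pairs))  # 0.62, low enough to keep all sampled matches
```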
## Blocking
@ -132,7 +132,8 @@ After determining eligible pairs for comparison, the Resolve operation uses a Un
| `resolution_model` | The language model to use for reducing matched entries | Falls back to `default_model` |
| `comparison_model` | The language model to use for comparing potential matches | Falls back to `default_model` |
| `blocking_keys` | List of keys to use for initial blocking | All keys in the input data |
| `blocking_threshold` | Embedding similarity threshold for considering entries as potential matches | None |
| `blocking_threshold` | Embedding similarity threshold for considering entries as potential matches | Auto-computed if not set |
| `blocking_target_recall` | Target recall when auto-computing blocking threshold (0.0 to 1.0) | 0.95 |
| `blocking_conditions` | List of conditions for initial blocking | [] |
| `input` | Specifies the schema or keys to subselect from each item to pass into the prompts | All keys from input items |
| `embedding_batch_size` | The number of entries to send to the embedding model at a time | 1000 |
@ -140,9 +141,9 @@ After determining eligible pairs for comparison, the Resolve operation uses a Un
| `limit_comparisons` | Maximum number of comparisons to perform | None |
| `timeout` | Timeout for each LLM call in seconds | 120 |
| `max_retries_per_timeout` | Maximum number of retries per timeout | 2 |
| `sample` | Number of samples to use for the operation | None |
| `litellm_completion_kwargs` | Additional parameters to pass to LiteLLM completion calls. | {} |
| `bypass_cache` | If true, bypass the cache for this operation. | False |
| `sample` | Number of samples to use for the operation | None |
| `litellm_completion_kwargs` | Additional parameters to pass to LiteLLM completion calls. | {} |
| `bypass_cache` | If true, bypass the cache for this operation. | False |
## Best Practices

View File

@ -14,7 +14,7 @@ All fields in `optimizer_config` are required (no defaults):
| `evaluation_file` | `str` | Path to Python file containing `@register_eval` decorated function |
| `metric_key` | `str` | Key in evaluation results dictionary to use as accuracy metric |
| `max_iterations` | `int` | Maximum number of MOARSearch iterations to run |
| `model` | `str` | LLM model to use for directive instantiation during search |
| `rewrite_agent_model` | `str` | LLM model to use for directive instantiation during search |
!!! warning "All Fields Required"
MOAR will error if any required field is missing. There are no defaults.
@ -67,19 +67,19 @@ The optimizer will use the sample dataset, but your final pipeline uses the full
### Available Models
!!! info "LiteLLM Model Names"
Use LiteLLM model names (e.g., `gpt-4o-mini`, `gpt-4o`, `gpt-4.1`). Make sure your API keys are set in your environment.
Use LiteLLM model names (e.g., `gpt-4o-mini`, `gpt-4o`, `gpt-5.1`). Make sure your API keys are set in your environment.
```yaml
available_models: # LiteLLM model names - ensure API keys are set
- gpt-4.1-nano # Cheapest, lower accuracy
- gpt-4.1-mini # Low cost, decent accuracy
- gpt-4.1 # Balanced
- gpt-5.1-nano # Cheapest, lower accuracy
- gpt-5.1-mini # Low cost, decent accuracy
- gpt-5.1 # Balanced
- gpt-4o # Higher cost, better accuracy
```
### Model for Directive Instantiation
The `model` field specifies which LLM to use for generating optimization directives during the search process. This doesn't affect the models tested in `available_models`.
The `rewrite_agent_model` field specifies which LLM to use for generating optimization directives during the search process. This doesn't affect the models tested in `available_models`.
!!! tip "Cost Consideration"
Use a cheaper model (like `gpt-4o-mini`) for directive instantiation to reduce search costs.
@ -104,12 +104,12 @@ optimizer_config:
available_models:
- gpt-4o-mini
- gpt-4o
- gpt-4.1-mini
- gpt-4.1
- gpt-5.1-mini
- gpt-5.1
evaluation_file: evaluate_medications.py
metric_key: medication_extraction_score
max_iterations: 40
model: gpt-4.1
rewrite_agent_model: gpt-5.1
dataset_path: data/sample.json # Optional
exploration_weight: 1.414 # Optional
```

View File

@ -25,15 +25,15 @@ optimizer_config:
dataset_path: workloads/medical/raw_sample.json # Use sample for faster optimization
save_dir: workloads/medical/moar_results
available_models: # LiteLLM model names - ensure API keys are set in your environment
- gpt-4.1-nano
- gpt-4.1-mini
- gpt-4.1
- gpt-5.1-nano
- gpt-5.1-mini
- gpt-5.1
- gpt-4o
- gpt-4o-mini
evaluation_file: workloads/medical/evaluate_medications.py
metric_key: medication_extraction_score
max_iterations: 40
model: gpt-4.1
rewrite_agent_model: gpt-5.1
system_prompt:
dataset_description: a collection of transcripts of doctor visits

View File

@ -109,12 +109,12 @@ optimizer_config:
available_models: # LiteLLM model names - ensure API keys are set in your environment
- gpt-4o-mini
- gpt-4o
- gpt-4.1-mini
- gpt-4.1
- gpt-5.1-mini
- gpt-5.1
evaluation_file: evaluate_medications.py
metric_key: medication_extraction_score # This must match a key in your evaluation function's return dictionary
max_iterations: 40
model: gpt-4.1
rewrite_agent_model: gpt-5.1
dataset_path: data/transcripts_sample.json # Optional: use sample/hold-out dataset
```

View File

@ -19,7 +19,7 @@ High-level summary of the optimization run:
{
"optimizer": "moar",
"input_pipeline": "pipeline.yaml",
"model": "gpt-4.1",
"rewrite_agent_model": "gpt-5.1",
"max_iterations": 40,
"save_dir": "results/moar_optimization",
"dataset": "transcripts",

View File

@ -80,9 +80,9 @@ Test your evaluation function independently and check MOAR logs for errors.
```yaml
available_models:
- gpt-4.1-nano # Cheapest, lower accuracy
- gpt-4.1-mini # Low cost, decent accuracy
- gpt-4.1 # Balanced
- gpt-5.1-nano # Cheapest, lower accuracy
- gpt-5.1-mini # Low cost, decent accuracy
- gpt-5.1 # Balanced
- gpt-4o # Higher cost, better accuracy
```

View File

@ -267,10 +267,6 @@ class AgentCommunicator:
response = completion(
model=self.model,
messages=messages,
azure=True,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
response_format={"type": "json_object"}
)
@ -290,10 +286,6 @@ class AgentCommunicator:
response = completion(
model=self.model,
messages=messages + [{"role": "user", "content": request_msg}],
azure=True,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
response_format={"type": "json_object"}
)

View File

@ -27,6 +27,7 @@ class OptimizeResult(BaseModel):
should_optimize: str | None = None
input_data: list[dict[str, Any]] | None = None
output_data: list[dict[str, Any]] | None = None
num_docs_analyzed: int | None = None
cost: float | None = None
error: str | None = None
created_at: datetime
@ -35,4 +36,25 @@ class OptimizeResult(BaseModel):
class OptimizeRequest(BaseModel):
yaml_config: str
step_name: str
op_name: str
op_name: str
class DecomposeRequest(BaseModel):
yaml_config: str
step_name: str
op_name: str
class DecomposeResult(BaseModel):
task_id: str
status: TaskStatus
decomposed_operations: list[dict[str, Any]] | None = None
winning_directive: str | None = None
candidates_evaluated: int | None = None
original_outputs: list[dict[str, Any]] | None = None
decomposed_outputs: list[dict[str, Any]] | None = None
comparison_rationale: str | None = None
cost: float | None = None
error: str | None = None
created_at: datetime
completed_at: datetime | None = None
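For reference, a completed decompose task record might look like the following. Values are placeholders, not taken from a real run.

```python
# Illustrative only; field values are placeholders.
from datetime import datetime

from server.app.models import DecomposeResult, TaskStatus

record = DecomposeResult(
    task_id="2f1c0c1e-0000-0000-0000-000000000000",
    status=TaskStatus.COMPLETED,
    decomposed_operations=[{"name": "extract_entities", "type": "map"}],
    winning_directive="chaining",  # placeholder directive name
    candidates_evaluated=3,
    comparison_rationale="Decomposed outputs covered more entities per document.",
    cost=0.42,
    created_at=datetime.now(),
    completed_at=datetime.now(),
)
print(record.model_dump_json(indent=2))  # pydantic v2 serialization
```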

View File

@ -2,13 +2,23 @@ from typing import Any
import uuid
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
from docetl.runner import DSLRunner
from docetl.optimizers.fast_should_optimize import FastShouldOptimizeAnalyzer
from docetl.optimizers.fast_decomposer import FastDecomposer
import asyncio
from asyncio import Task
from rich.logging import RichHandler
import logging
import yaml
from datetime import datetime, timedelta
from enum import Enum
from server.app.models import OptimizeResult, TaskStatus, OptimizeRequest, PipelineRequest
from server.app.models import (
OptimizeResult,
TaskStatus,
OptimizeRequest,
PipelineRequest,
DecomposeRequest,
DecomposeResult,
)
# Setup logging
FORMAT = "%(message)s"
@ -18,10 +28,14 @@ logging.basicConfig(
router = APIRouter()
# Task storage
# Task storage for optimize tasks
tasks: dict[str, OptimizeResult] = {}
asyncio_tasks: dict[str, Task] = {}
# Task storage for decompose tasks
decompose_tasks: dict[str, DecomposeResult] = {}
decompose_asyncio_tasks: dict[str, Task] = {}
# Configuration
COMPLETED_TASK_TTL = timedelta(hours=1)
@ -30,50 +44,111 @@ async def cleanup_old_tasks():
while True:
try:
current_time = datetime.now()
task_ids_to_remove = []
# Clean up optimize tasks
task_ids_to_remove = []
for task_id, task in tasks.items():
if (task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED] and
task.completed_at and
task.completed_at and
current_time - task.completed_at > COMPLETED_TASK_TTL):
task_ids_to_remove.append(task_id)
for task_id in task_ids_to_remove:
del tasks[task_id]
# Clean up decompose tasks
decompose_ids_to_remove = []
for task_id, task in decompose_tasks.items():
if (task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED] and
task.completed_at and
current_time - task.completed_at > COMPLETED_TASK_TTL):
decompose_ids_to_remove.append(task_id)
for task_id in decompose_ids_to_remove:
del decompose_tasks[task_id]
await asyncio.sleep(60)
except Exception as e:
logging.error(f"Error in cleanup task: {e}")
await asyncio.sleep(60)
async def run_optimization(task_id: str, yaml_config: str, step_name: str, op_name: str):
"""Execute the optimization task"""
"""Execute the optimization task using fast single-LLM-call analysis."""
try:
tasks[task_id].status = TaskStatus.PROCESSING
# Run the actual optimization in a separate thread to not block
runner = DSLRunner.from_yaml(yaml_config)
should_optimize, input_data, output_data, cost = await asyncio.to_thread(
runner.should_optimize,
# yaml_config is a file path, not YAML content - read and parse the file
if yaml_config.endswith(".yaml") or yaml_config.endswith(".yml"):
with open(yaml_config, "r") as f:
config = yaml.safe_load(f)
else:
# Fallback: try parsing as YAML string
config = yaml.safe_load(yaml_config)
# Validate that we got a dict
if not isinstance(config, dict):
raise ValueError(
f"Invalid yaml_config: expected dict after parsing, got {type(config).__name__}. "
f"Input: {str(yaml_config)[:200]}"
)
# Get intermediate directory from config
intermediate_dir = (
config.get("pipeline", {})
.get("output", {})
.get("intermediate_dir")
)
if not intermediate_dir:
raise ValueError("No intermediate_dir configured in pipeline output")
# Find the operation config
op_config = None
for op in config.get("operations", []):
if op.get("name") == op_name:
op_config = op
break
if not op_config:
raise ValueError(f"Operation '{op_name}' not found in config")
# Get optimizer settings from config
optimizer_model = (
config.get("optimizer_config", {})
.get("rewrite_agent_model", "gpt-5.1")
)
litellm_kwargs = (
config.get("optimizer_config", {})
.get("litellm_kwargs", {})
)
# Create analyzer and run in thread pool
analyzer = FastShouldOptimizeAnalyzer(
intermediate_dir=intermediate_dir,
optimizer_model=optimizer_model,
litellm_kwargs=litellm_kwargs,
)
should_optimize, output_data, num_docs_analyzed, cost = await asyncio.to_thread(
analyzer.analyze,
op_config,
step_name,
op_name,
)
# Update task result
tasks[task_id].status = TaskStatus.COMPLETED
tasks[task_id].should_optimize = should_optimize
tasks[task_id].input_data = input_data
tasks[task_id].input_data = None # We don't have input_data in fast mode
tasks[task_id].output_data = output_data
tasks[task_id].num_docs_analyzed = num_docs_analyzed
tasks[task_id].cost = cost
tasks[task_id].completed_at = datetime.now()
except asyncio.CancelledError:
runner.is_cancelled = True
tasks[task_id].status = TaskStatus.CANCELLED
tasks[task_id].completed_at = datetime.now()
raise
except Exception as e:
import traceback
error_traceback = traceback.format_exc()
@ -81,11 +156,10 @@ async def run_optimization(task_id: str, yaml_config: str, step_name: str, op_na
tasks[task_id].error = f"{str(e)}\n{error_traceback}"
tasks[task_id].completed_at = datetime.now()
raise
finally:
if task_id in asyncio_tasks:
del asyncio_tasks[task_id]
runner.reset_env()
@router.on_event("startup")
async def startup_event():
@ -153,6 +227,140 @@ async def cancel_optimize_task(task_id: str):
return {"message": "Task cancelled successfully"}
# ============================================================================
# Decompose endpoints
# ============================================================================
async def run_decomposition(task_id: str, yaml_config: str, step_name: str, op_name: str):
"""Execute the decomposition task using fast directive-based approach."""
try:
decompose_tasks[task_id].status = TaskStatus.PROCESSING
# yaml_config is a file path
if not (yaml_config.endswith(".yaml") or yaml_config.endswith(".yml")):
raise ValueError("yaml_config must be a path to a YAML file")
# Get optimizer settings from config
with open(yaml_config, "r") as f:
config = yaml.safe_load(f)
optimizer_model = (
config.get("optimizer_config", {})
.get("rewrite_agent_model", "gpt-5.1")
)
litellm_kwargs = (
config.get("optimizer_config", {})
.get("litellm_kwargs", {})
)
# Create decomposer and run in thread pool
decomposer = FastDecomposer(
yaml_config_path=yaml_config,
optimizer_model=optimizer_model,
sample_size=5,
litellm_kwargs=litellm_kwargs,
)
result = await asyncio.to_thread(
decomposer.decompose,
step_name,
op_name,
)
# Update task result
decompose_tasks[task_id].status = TaskStatus.COMPLETED
decompose_tasks[task_id].decomposed_operations = result["decomposed_ops"]
decompose_tasks[task_id].winning_directive = result["winning_directive"]
decompose_tasks[task_id].candidates_evaluated = result["candidates_evaluated"]
decompose_tasks[task_id].original_outputs = result["original_outputs"]
decompose_tasks[task_id].decomposed_outputs = result["decomposed_outputs"]
decompose_tasks[task_id].comparison_rationale = result["comparison_rationale"]
decompose_tasks[task_id].cost = result["cost"]
decompose_tasks[task_id].completed_at = datetime.now()
except asyncio.CancelledError:
decompose_tasks[task_id].status = TaskStatus.CANCELLED
decompose_tasks[task_id].completed_at = datetime.now()
raise
except Exception as e:
import traceback
error_traceback = traceback.format_exc()
decompose_tasks[task_id].status = TaskStatus.FAILED
decompose_tasks[task_id].error = f"{str(e)}\n{error_traceback}"
decompose_tasks[task_id].completed_at = datetime.now()
raise
finally:
if task_id in decompose_asyncio_tasks:
del decompose_asyncio_tasks[task_id]
@router.post("/decompose", status_code=202)
async def submit_decompose_task(request: DecomposeRequest):
"""Submit a new decomposition task"""
task_id = str(uuid.uuid4())
# Create task record
decompose_tasks[task_id] = DecomposeResult(
task_id=task_id,
status=TaskStatus.PENDING,
created_at=datetime.now()
)
# Create and store the asyncio task
task = asyncio.create_task(
run_decomposition(
task_id,
request.yaml_config,
request.step_name,
request.op_name
)
)
decompose_asyncio_tasks[task_id] = task
return {"task_id": task_id}
@router.get("/decompose/{task_id}")
async def get_decompose_status(task_id: str) -> DecomposeResult:
"""Get the current status of a decomposition task"""
if task_id not in decompose_tasks:
raise HTTPException(
status_code=404,
detail="Task not found or has been cleaned up"
)
return decompose_tasks[task_id]
@router.post("/decompose/{task_id}/cancel")
async def cancel_decompose_task(task_id: str):
"""Cancel a running decomposition task"""
if task_id not in decompose_tasks:
raise HTTPException(
status_code=404,
detail="Task not found or has been cleaned up"
)
if task_id not in decompose_asyncio_tasks:
raise HTTPException(
status_code=400,
detail="Task already finished or cannot be cancelled"
)
asyncio_task = decompose_asyncio_tasks[task_id]
asyncio_task.cancel()
try:
await asyncio_task
except asyncio.CancelledError:
pass
return {"message": "Decomposition task cancelled successfully"}
# Keep the original run_pipeline endpoint
@router.post("/run_pipeline")
def run_pipeline(request: PipelineRequest) -> dict[str, Any]:
@ -169,6 +377,12 @@ def run_pipeline(request: PipelineRequest) -> dict[str, Any]:
@router.websocket("/ws/run_pipeline/{client_id}")
async def websocket_run_pipeline(websocket: WebSocket, client_id: str):
"""
WebSocket endpoint for running pipelines with real-time output streaming.
Note: The old 'optimize' flag for full pipeline optimization has been removed.
Use the /decompose endpoint for fast operation decomposition instead.
"""
await websocket.accept()
runner = None
try:
@ -178,19 +392,8 @@ async def websocket_run_pipeline(websocket: WebSocket, client_id: str):
if config.get("clear_intermediate", False):
runner.clear_intermediate()
if config.get("optimize", False):
logging.info(f"Optimizing pipeline with model {config.get('optimizer_model', 'gpt-4o')}")
# Set the runner config to the optimizer config
runner.config["optimizer_config"]["rewrite_agent_model"] = config.get("optimizer_model", "gpt-4o")
runner.config["optimizer_config"]["judge_agent_model"] = config.get("optimizer_model", "gpt-4o-mini")
async def run_pipeline():
return await asyncio.to_thread(runner.optimize, return_pipeline=False)
else:
async def run_pipeline():
return await asyncio.to_thread(runner.load_run_save)
async def run_pipeline():
return await asyncio.to_thread(runner.load_run_save)
pipeline_task = asyncio.create_task(run_pipeline())
@ -198,18 +401,6 @@ async def websocket_run_pipeline(websocket: WebSocket, client_id: str):
console_output = runner.console.file.getvalue()
await websocket.send_json({"type": "output", "data": console_output})
if config.get("optimize", False):
optimizer_progress = runner.console.get_optimizer_progress()
rationale = runner.console.optimizer_rationale
await websocket.send_json({
"type": "optimizer_progress",
"status": optimizer_progress[0],
"progress": optimizer_progress[1],
"rationale": rationale[1] if rationale is not None else "",
"should_optimize": rationale[0] if rationale is not None else False,
"validator_prompt": rationale[2] if rationale is not None else ""
})
# Check for incoming messages from the user
try:
user_message = await asyncio.wait_for(
@ -219,8 +410,7 @@ async def websocket_run_pipeline(websocket: WebSocket, client_id: str):
if user_message == "kill":
runner.console.log("Stopping process...")
runner.is_cancelled = True
await websocket.send_json({
"type": "error",
"message": "Process stopped by user request"
@ -250,41 +440,16 @@ async def websocket_run_pipeline(websocket: WebSocket, client_id: str):
# Sleep for a short duration to ensure all output is captured
await asyncio.sleep(3)
# If optimize is true, send back the optimized operations
if config.get("optimize", False):
optimized_config, cost = result
# Send the operations back in order
new_pipeline_steps = optimized_config["pipeline"]["steps"]
new_pipeline_op_name_to_op_map = {op["name"]: op for op in optimized_config["operations"]}
new_ops_in_order = []
for new_step in new_pipeline_steps:
for op in new_step.get("operations", []):
if op not in new_ops_in_order:
new_ops_in_order.append(new_pipeline_op_name_to_op_map[op])
await websocket.send_json(
{
"type": "result",
"data": {
"message": "Pipeline executed successfully",
"cost": cost,
"optimized_ops": new_ops_in_order,
"yaml_config": config["yaml_config"],
},
}
)
else:
await websocket.send_json(
{
"type": "result",
"data": {
"message": "Pipeline executed successfully",
"cost": result,
"yaml_config": config["yaml_config"],
},
}
)
await websocket.send_json(
{
"type": "result",
"data": {
"message": "Pipeline executed successfully",
"cost": result,
"yaml_config": config["yaml_config"],
},
}
)
except WebSocketDisconnect:
if runner is not None:
runner.reset_env()
@ -299,3 +464,159 @@ async def websocket_run_pipeline(websocket: WebSocket, client_id: str):
if runner is not None:
runner.reset_env()
await websocket.close()
@router.websocket("/ws/decompose/{client_id}")
async def websocket_decompose(websocket: WebSocket, client_id: str):
"""
WebSocket endpoint for fast operation decomposition with real-time output streaming.
Expects JSON message with:
- yaml_config: Path to the pipeline YAML config file
- step_name: Name of the pipeline step
- op_name: Name of the operation to decompose
"""
await websocket.accept()
decomposer = None
try:
config = await websocket.receive_json()
yaml_config = config["yaml_config"]
step_name = config["step_name"]
op_name = config["op_name"]
# Validate yaml_config is a path
if not (yaml_config.endswith(".yaml") or yaml_config.endswith(".yml")):
raise ValueError("yaml_config must be a path to a YAML file")
# Get optimizer settings from config
with open(yaml_config, "r") as f:
pipeline_config = yaml.safe_load(f)
optimizer_model = (
pipeline_config.get("optimizer_config", {})
.get("rewrite_agent_model", "gpt-5.1")
)
litellm_kwargs = (
pipeline_config.get("optimizer_config", {})
.get("litellm_kwargs", {})
)
# Create a ThreadSafeConsole for streaming output
from docetl.console import ThreadSafeConsole
console = ThreadSafeConsole(
force_terminal=True,
soft_wrap=True,
highlight=False,
log_path=False,
color_system="truecolor",
width=120,
style="bright_white on black",
record=True,
)
# Create decomposer with the console
decomposer = FastDecomposer(
yaml_config_path=yaml_config,
optimizer_model=optimizer_model,
sample_size=5,
litellm_kwargs=litellm_kwargs,
console=console,
)
# Run decomposition in a thread
async def run_decomposition():
return await asyncio.to_thread(
decomposer.decompose,
step_name,
op_name,
)
decompose_task = asyncio.create_task(run_decomposition())
# Stream console output while decomposition runs
accumulated_output = ""
while not decompose_task.done():
# get_output() processes carriage returns and clears the buffer
new_output = console.get_output()
if new_output:
# Handle spinner updates: if new output doesn't start with newline
# and accumulated doesn't end with newline, replace the last line
if accumulated_output and not accumulated_output.endswith('\n') and not new_output.startswith('\n'):
# Find the last newline in accumulated output
last_newline = accumulated_output.rfind('\n')
if last_newline >= 0:
# Replace everything after the last newline
accumulated_output = accumulated_output[:last_newline + 1] + new_output
else:
# No newline in accumulated, replace everything
accumulated_output = new_output
else:
accumulated_output += new_output
await websocket.send_json({"type": "output", "data": accumulated_output})
# Check for kill message
try:
user_message = await asyncio.wait_for(
websocket.receive_json(), timeout=0.1
)
if user_message == "kill":
console.log("[red]Stopping decomposition...[/red]")
decompose_task.cancel()
await websocket.send_json({
"type": "error",
"message": "Decomposition stopped by user request"
})
raise Exception("Process stopped by user request")
except asyncio.TimeoutError:
pass
except asyncio.CancelledError:
await websocket.send_json({
"type": "error",
"message": "Decomposition stopped by user request"
})
raise
await asyncio.sleep(0.5)
# Get final result
result = await decompose_task
# Send any remaining console output
final_output = console.get_output()
if final_output:
accumulated_output += final_output
await websocket.send_json({"type": "output", "data": accumulated_output})
await asyncio.sleep(1)
# Send the result
await websocket.send_json({
"type": "result",
"data": {
"decomposed_operations": result["decomposed_ops"],
"winning_directive": result["winning_directive"],
"candidates_evaluated": result["candidates_evaluated"],
"original_outputs": result["original_outputs"],
"decomposed_outputs": result["decomposed_outputs"],
"comparison_rationale": result["comparison_rationale"],
"cost": result["cost"],
},
})
except WebSocketDisconnect:
print(f"Decompose client {client_id} disconnected")
except Exception as e:
import traceback
error_traceback = traceback.format_exc()
print(f"Decompose error:\n{error_traceback}")
try:
await websocket.send_json({
"type": "error",
"data": str(e),
"traceback": error_traceback
})
except Exception:
pass
finally:
await websocket.close()
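To round out the new endpoints, a hypothetical client for the polling-based decompose API might look like the sketch below. The host, router prefix, and step/operation names are assumptions, not taken from this diff; the terminal check uses `completed_at`, which the handlers above set for completed, failed, and cancelled tasks.

```python
# Hypothetical client sketch; not part of the server code.
import time

import requests  # any HTTP client works

BASE = "http://localhost:8000"  # assumed host and router prefix

task_id = requests.post(
    f"{BASE}/decompose",
    json={
        "yaml_config": "pipeline.yaml",    # must be a path ending in .yaml or .yml
        "step_name": "extraction",         # placeholder step name
        "op_name": "extract_medications",  # placeholder op name
    },
).json()["task_id"]

while True:
    status = requests.get(f"{BASE}/decompose/{task_id}").json()
    if status.get("completed_at"):  # set on completion, failure, or cancellation
        break
    time.sleep(2)

print(status.get("winning_directive"), status.get("cost"), status.get("error"))
```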

View File

@ -6,7 +6,7 @@ Simple apply tests for directive testing - just ensure apply() doesn't crash.
# Simple apply tests - no pytest needed
from docetl.reasoning_optimizer.directives import (
ChainingDirective,
ChainingDirective,
GleaningDirective,
ReduceGleaningDirective,
ReduceChainingDirective,
@ -14,7 +14,6 @@ from docetl.reasoning_optimizer.directives import (
OperatorFusionDirective,
DocSummarizationDirective,
IsolatingSubtasksDirective,
DocCompressionDirective,
DeterministicDocCompressionDirective,
DocumentChunkingDirective,
DocumentChunkingTopKDirective,
@ -22,8 +21,12 @@ from docetl.reasoning_optimizer.directives import (
TakeHeadTailDirective,
ClarifyInstructionsDirective,
SwapWithCodeDirective,
HierarchicalReduceDirective
HierarchicalReduceDirective,
MapToMapResolveReduceDirective,
MapResolveToMapWithCategoriesDirective,
)
# DocCompressionDirective is commented out in __init__.py, import directly
from docetl.reasoning_optimizer.directives.doc_compression import DocCompressionDirective
def test_chaining_apply():
@ -1012,6 +1015,237 @@ def test_cascade_filtering_apply():
assert "high-quality research paper" in result[5]["prompt"]
def test_map_to_map_resolve_reduce_apply():
"""Test that map_to_map_resolve_reduce apply doesn't crash"""
directive = MapToMapResolveReduceDirective()
# Pipeline with map followed by reduce
ops_list = [
{
"name": "extract_companies",
"type": "map",
"prompt": "Extract company name from: {{ input.article }}",
"model": "gpt-4o-mini",
"output": {"schema": {"company_name": "string"}},
},
{
"name": "aggregate_by_sector",
"type": "reduce",
"reduce_key": "sector",
"prompt": "List companies:\n{% for input in inputs %}\n- {{ input.company_name }}\n{% endfor %}",
"model": "gpt-4o-mini",
"output": {"schema": {"companies": "list[str]"}},
},
]
from docetl.reasoning_optimizer.instantiate_schemas import (
MapToMapResolveReduceInstantiateSchema,
)
rewrite = MapToMapResolveReduceInstantiateSchema(
resolve_name="normalize_company_names",
comparison_prompt="""Are these two company names referring to the same company?
Company 1: {{ input1.company_name }}
Company 2: {{ input2.company_name }}
Consider variations like abbreviations (IBM vs International Business Machines).""",
resolution_prompt="""Given these variations of a company name:
{% for input in inputs %}
- {{ input.company_name }}
{% endfor %}
Return the canonical/official company name.""",
blocking_conditions=[
"input1['company_name'][:3].lower() == input2['company_name'][:3].lower()",
"input1['company_name'].split()[0].lower() == input2['company_name'].split()[0].lower()",
],
blocking_keys=["sector", "company_name"], # Must include reduce_key (sector)
limit_comparisons=1000,
)
result = directive.apply(
"gpt-4o-mini", ops_list, "extract_companies", "aggregate_by_sector", rewrite
)
assert isinstance(result, list)
assert len(result) == 3 # map + resolve + reduce
# Check the resolve operation was inserted correctly
assert result[0]["name"] == "extract_companies"
assert result[0]["type"] == "map"
assert result[1]["name"] == "normalize_company_names"
assert result[1]["type"] == "resolve"
assert "comparison_prompt" in result[1]
assert "resolution_prompt" in result[1]
assert "blocking_conditions" in result[1]
assert "blocking_keys" in result[1]
assert result[1]["blocking_keys"] == ["sector", "company_name"]
assert result[1]["limit_comparisons"] == 1000
assert "input1" in result[1]["comparison_prompt"]
assert "input2" in result[1]["comparison_prompt"]
assert "inputs" in result[1]["resolution_prompt"]
# Output schema should be derived from reduce_key
assert result[1]["output"]["schema"] == {"sector": "string"}
assert result[2]["name"] == "aggregate_by_sector"
assert result[2]["type"] == "reduce"
def test_map_to_map_resolve_reduce_apply_with_multiple_reduce_keys():
"""Test map_to_map_resolve_reduce with multiple reduce_keys"""
directive = MapToMapResolveReduceDirective()
ops_list = [
{
"name": "extract_products",
"type": "map",
"prompt": "Extract product info: {{ input.description }}",
"model": "gpt-4o-mini",
"output": {"schema": {"product_name": "string", "category": "string"}},
},
{
"name": "aggregate_products",
"type": "reduce",
"reduce_key": ["brand", "region"], # Multiple reduce keys
"prompt": "List products:\n{% for input in inputs %}{{ input.product_name }}{% endfor %}",
"model": "gpt-4o-mini",
"output": {"schema": {"products": "list[str]"}},
},
]
from docetl.reasoning_optimizer.instantiate_schemas import (
MapToMapResolveReduceInstantiateSchema,
)
rewrite = MapToMapResolveReduceInstantiateSchema(
resolve_name="normalize_products",
comparison_prompt="Same product? {{ input1.product_name }} vs {{ input2.product_name }}",
resolution_prompt="Normalize: {% for input in inputs %}{{ input.product_name }}{% endfor %}",
blocking_conditions=[
"input1['category'] == input2['category']",
],
blocking_keys=["brand", "region", "product_name"], # Must include reduce_keys
limit_comparisons=500,
)
result = directive.apply(
"gpt-4o-mini", ops_list, "extract_products", "aggregate_products", rewrite
)
assert result[1]["type"] == "resolve"
assert "blocking_keys" in result[1]
assert result[1]["blocking_keys"] == ["brand", "region", "product_name"]
# Output schema should include all reduce_keys
assert result[1]["output"]["schema"] == {"brand": "string", "region": "string"}
def test_map_resolve_to_map_with_categories_apply():
"""Test that map_resolve_to_map_with_categories apply doesn't crash"""
directive = MapResolveToMapWithCategoriesDirective()
# Pipeline with map followed by resolve
ops_list = [
{
"name": "extract_sentiment",
"type": "map",
"prompt": "What is the sentiment? {{ input.review }}",
"model": "gpt-4o-mini",
"output": {"schema": {"sentiment": "string"}},
},
{
"name": "normalize_sentiment",
"type": "resolve",
"comparison_prompt": "Same sentiment? {{ input1.sentiment }} vs {{ input2.sentiment }}",
"resolution_prompt": "Normalize: {% for input in inputs %}{{ input.sentiment }}{% endfor %}",
"model": "gpt-4o-mini",
"output": {"schema": {"sentiment": "string"}},
},
]
from docetl.reasoning_optimizer.instantiate_schemas import (
MapResolveToMapWithCategoriesInstantiateSchema,
)
rewrite = MapResolveToMapWithCategoriesInstantiateSchema(
categories=["Positive", "Negative", "Neutral", "Mixed"],
category_key="sentiment",
new_prompt="""Classify the sentiment of this review into one of these categories:
- Positive: Clearly positive sentiment
- Negative: Clearly negative sentiment
- Neutral: No strong sentiment
- Mixed: Both positive and negative
- None of the above
Review: {{ input.review }}
Return exactly one category.""",
include_none_of_above=True,
)
result = directive.apply(
"gpt-4o-mini", ops_list, "extract_sentiment", "normalize_sentiment", rewrite
)
assert isinstance(result, list)
assert len(result) == 1 # map only, resolve removed
# Check the map operation was modified correctly
assert result[0]["name"] == "extract_sentiment"
assert result[0]["type"] == "map"
assert "Positive" in result[0]["prompt"]
assert "Negative" in result[0]["prompt"]
assert "None of the above" in result[0]["prompt"]
# Check validation was added
assert "validate" in result[0]
assert len(result[0]["validate"]) == 1
assert "Positive" in result[0]["validate"][0]
assert "None of the above" in result[0]["validate"][0]
def test_map_resolve_to_map_with_categories_no_none_of_above():
"""Test map_resolve_to_map_with_categories without 'None of the above' option"""
directive = MapResolveToMapWithCategoriesDirective()
ops_list = [
{
"name": "classify_type",
"type": "map",
"prompt": "What type is this? {{ input.text }}",
"model": "gpt-4o-mini",
"output": {"schema": {"item_type": "string"}},
},
{
"name": "normalize_type",
"type": "resolve",
"comparison_prompt": "Same type? {{ input1.item_type }} vs {{ input2.item_type }}",
"resolution_prompt": "Normalize: {% for input in inputs %}{{ input.item_type }}{% endfor %}",
"model": "gpt-4o-mini",
"output": {"schema": {"item_type": "string"}},
},
]
from docetl.reasoning_optimizer.instantiate_schemas import (
MapResolveToMapWithCategoriesInstantiateSchema,
)
rewrite = MapResolveToMapWithCategoriesInstantiateSchema(
categories=["TypeA", "TypeB", "TypeC"],
category_key="item_type",
new_prompt="""Classify into: TypeA, TypeB, or TypeC
Text: {{ input.text }}""",
include_none_of_above=False,
)
result = directive.apply(
"gpt-4o-mini", ops_list, "classify_type", "normalize_type", rewrite
)
assert len(result) == 1
# Check validation does NOT include 'None of the above'
assert "validate" in result[0]
assert "None of the above" not in result[0]["validate"][0]
assert "TypeA" in result[0]["validate"][0]
assert "TypeB" in result[0]["validate"][0]
assert "TypeC" in result[0]["validate"][0]
if __name__ == "__main__":
# Run all tests
test_chaining_apply()
@ -1078,4 +1312,16 @@ if __name__ == "__main__":
test_cascade_filtering_apply()
print("✅ Cascade filtering apply test passed")
test_map_to_map_resolve_reduce_apply()
print("✅ Map to map resolve reduce apply test passed")
test_map_to_map_resolve_reduce_apply_with_multiple_reduce_keys()
print("✅ Map to map resolve reduce with multiple reduce keys apply test passed")
test_map_resolve_to_map_with_categories_apply()
print("✅ Map resolve to map with categories apply test passed")
test_map_resolve_to_map_with_categories_no_none_of_above()
print("✅ Map resolve to map with categories (no none of above) apply test passed")
print("\n🎉 All directive apply tests passed!")

View File

@ -10,14 +10,13 @@ from typing import Dict, List
from datetime import datetime
from docetl.reasoning_optimizer.directives import (
    ChainingDirective,
GleaningDirective,
ReduceGleaningDirective,
ReduceChainingDirective,
ChangeModelDirective,
DocSummarizationDirective,
IsolatingSubtasksDirective,
DocCompressionDirective,
DeterministicDocCompressionDirective,
OperatorFusionDirective,
DocumentChunkingDirective,
@ -28,8 +27,12 @@ from docetl.reasoning_optimizer.directives import (
SwapWithCodeDirective,
HierarchicalReduceDirective,
CascadeFilteringDirective,
TestResult
MapToMapResolveReduceDirective,
MapResolveToMapWithCategoriesDirective,
TestResult,
)
# DocCompressionDirective is commented out in __init__.py, import directly
from docetl.reasoning_optimizer.directives.doc_compression import DocCompressionDirective
def run_all_directive_tests(agent_llm: str = "gpt-4.1") -> Dict[str, List[TestResult]]:
"""
@ -68,6 +71,8 @@ def run_all_directive_tests(agent_llm: str = "gpt-4.1") -> Dict[str, List[TestRe
SwapWithCodeDirective(),
HierarchicalReduceDirective(),
CascadeFilteringDirective(),
MapToMapResolveReduceDirective(),
MapResolveToMapWithCategoriesDirective(),
]
all_results = {}
@ -176,6 +181,8 @@ def run_specific_directive_test(directive_name: str, agent_llm: str = "gpt-4o-mi
"clarify_instructions": ClarifyInstructionsDirective(),
"swap_with_code": SwapWithCodeDirective(),
"cascade_filtering": CascadeFilteringDirective(),
"map_to_map_resolve_reduce": MapToMapResolveReduceDirective(),
"map_resolve_to_map_with_categories": MapResolveToMapWithCategoriesDirective(),
}
if directive_name.lower() not in directive_map:

View File

@ -5,4 +5,9 @@ export const API_ROUTES = {
CANCEL: (taskId: string) =>
`/api/shouldOptimize?taskId=${taskId}&cancel=true`,
},
DECOMPOSE: {
SUBMIT: "/api/decompose",
STATUS: (taskId: string) => `/api/decompose?taskId=${taskId}`,
CANCEL: (taskId: string) => `/api/decompose?taskId=${taskId}&cancel=true`,
},
} as const;

View File

@ -0,0 +1,79 @@
import { NextRequest, NextResponse } from "next/server";
const FASTAPI_URL = `${
process.env.NEXT_PUBLIC_BACKEND_HTTPS ? "https" : "http"
}://${process.env.NEXT_PUBLIC_BACKEND_HOST}:${
process.env.NEXT_PUBLIC_BACKEND_PORT
}`;
// Helper to handle errors consistently
const handleError = (error: unknown, status = 500) => {
const message =
error instanceof Error ? error.message : "Internal server error";
return NextResponse.json({ error: message }, { status });
};
// Helper to proxy requests to FastAPI
async function proxyRequest(path: string, init?: RequestInit) {
const response = await fetch(`${FASTAPI_URL}${path}`, {
...init,
headers: {
"Content-Type": "application/json",
...init?.headers,
},
});
if (!response.ok) {
const error = await response.text();
throw new Error(`FastAPI server error: ${error}`);
}
return response.json();
}
export async function POST(request: NextRequest): Promise<NextResponse> {
try {
// Extract task ID from the URL if it exists
const taskId = request.nextUrl.searchParams.get("taskId");
const isCancel = request.nextUrl.searchParams.get("cancel") === "true";
// Handle different POST scenarios
if (taskId) {
if (isCancel) {
// Cancel task
const data = await proxyRequest(`/decompose/${taskId}/cancel`, {
method: "POST",
});
return NextResponse.json(data);
}
// Invalid request with taskId but no cancel
return handleError(new Error("Invalid request"), 400);
}
// Submit new task
const body = await request.json();
const data = await proxyRequest("/decompose", {
method: "POST",
body: JSON.stringify(body),
});
return NextResponse.json(data, { status: 202 });
} catch (error) {
return handleError(error);
}
}
export async function GET(request: NextRequest): Promise<NextResponse> {
try {
// Extract task ID from the URL
const taskId = request.nextUrl.searchParams.get("taskId");
if (!taskId) {
return handleError(new Error("Task ID is required"), 400);
}
const data = await proxyRequest(`/decompose/${taskId}`);
return NextResponse.json(data);
} catch (error) {
return handleError(error);
}
}
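A hedged sketch of exercising the proxy route above from the browser; the yaml_config path and operation names are placeholders, and the task_id field is assumed from the DecomposeResult type and the polling hook later in this diff.

// Illustrative only: submit, poll, and cancel via the Next.js proxy route above.
async function runDecomposeViaProxy() {
  const submit = await fetch("/api/decompose", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({
      yaml_config: "/path/to/pipeline.yaml",
      step_name: "data_processing",
      op_name: "extract_companies",
    }),
  });
  const { task_id } = await submit.json(); // proxied FastAPI response, returned with status 202

  // Poll task status (completed / failed / cancelled)
  const statusRes = await fetch(`/api/decompose?taskId=${task_id}`);
  const status = await statusRes.json();

  // Cancel if needed
  await fetch(`/api/decompose?taskId=${task_id}&cancel=true`, { method: "POST" });
  return status;
}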

View File

@ -33,6 +33,77 @@ export function getNamespaceDir(homeDir: string, namespace: string) {
return path.join(homeDir, ".docetl", namespace);
}
/**
* Safely parse a JSON value - returns the parsed value if it's a string,
* or the original value if it's already an object.
* @param value - The value to parse
* @param fieldName - Name of the field (for error messages)
* @returns The parsed value
*/
function safeParseJSON(
value: unknown,
fieldName: string
): Record<string, unknown> | unknown[] {
if (value === null || value === undefined) {
throw new Error(`Field '${fieldName}' is null or undefined`);
}
// If it's already an object or array, return as-is
if (typeof value === "object") {
return value as Record<string, unknown> | unknown[];
}
// If it's a string, try to parse it
if (typeof value === "string") {
try {
return JSON.parse(value);
} catch (e) {
throw new Error(
`Field '${fieldName}' contains invalid JSON: ${value.slice(0, 100)}${value.length > 100 ? "..." : ""}`
);
}
}
throw new Error(
`Field '${fieldName}' has unexpected type '${typeof value}': ${String(value).slice(0, 100)}`
);
}
/**
* Recursively sanitize an object for YAML serialization.
* Converts any nested [object Object] strings to proper objects.
*/
function sanitizeForYaml(obj: unknown, path: string = ""): unknown {
if (obj === null || obj === undefined) {
return obj;
}
if (typeof obj === "string") {
// Check for [object Object] which indicates a serialization issue
if (obj === "[object Object]") {
console.warn(
`Warning: Found '[object Object]' string at path '${path}'. This indicates a serialization issue.`
);
return {}; // Return empty object as fallback
}
return obj;
}
if (Array.isArray(obj)) {
return obj.map((item, idx) => sanitizeForYaml(item, `${path}[${idx}]`));
}
if (typeof obj === "object") {
const result: Record<string, unknown> = {};
for (const [key, value] of Object.entries(obj)) {
result[key] = sanitizeForYaml(value, path ? `${path}.${key}` : key);
}
return result;
}
return obj;
}
export function generatePipelineConfig(
namespace: string,
default_model: string,
@ -82,9 +153,18 @@ export function generatePipelineConfig(
const litellm_completion_kwargs =
op.otherKwargs?.litellm_completion_kwargs;
if (litellm_completion_kwargs) {
op.otherKwargs.litellm_completion_kwargs = JSON.parse(
litellm_completion_kwargs
);
try {
op.otherKwargs.litellm_completion_kwargs = safeParseJSON(
litellm_completion_kwargs,
`${op.name}.litellm_completion_kwargs`
);
} catch (e) {
console.warn(
`Warning: Could not parse litellm_completion_kwargs for operation '${op.name}':`,
e
);
// Keep the original value if parsing fails
}
}
const newOp: Record<string, unknown> = {
@ -175,10 +255,14 @@ export function generatePipelineConfig(
// If it's a sample operation with custom method, parse the samples as key-value pairs
if (op.type === "sample" && op.otherKwargs?.method === "custom") {
try {
newOp.samples = JSON.parse(op.otherKwargs.samples);
newOp.samples = safeParseJSON(
op.otherKwargs.samples,
`${op.name}.samples`
);
} catch (error) {
console.warn(
"Failed to parse custom samples as JSON, using raw value"
`Failed to parse custom samples for operation '${op.name}':`,
error
);
}
}
@ -217,14 +301,12 @@ export function generatePipelineConfig(
// Fix type errors by asserting the pipeline config type
let pipelineConfig: any = {
from_docwrangler: true,
optimizer_model: optimizerModel,
datasets,
default_model,
...(enable_observability && {
optimizer_config: {
force_decompose: true,
},
}),
optimizer_config: {
rewrite_agent_model: optimizerModel,
...(enable_observability && { force_decompose: true }),
},
operations: updatedOperations,
pipeline: {
steps: [
@ -308,7 +390,9 @@ export function generatePipelineConfig(
const outputOpName = operationsToRun[currentOpIndex].name;
outputPath = path.join(outputBase, "data_processing", outputOpName + ".json");
const yamlString = yaml.dump(pipelineConfig);
// Sanitize the config before serializing to YAML to catch any [object Object] issues
const sanitizedConfig = sanitizeForYaml(pipelineConfig, "pipelineConfig");
const yamlString = yaml.dump(sanitizedConfig);
console.log(yamlString);
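An illustrative call pattern for the safeParseJSON and sanitizeForYaml helpers added above. The sample values are invented, and since both helpers are module-private, this assumes code living in the same file.

// Illustrative only: expected behavior of the helpers above.
const alreadyParsed = safeParseJSON({ temperature: 0 }, "op.litellm_completion_kwargs");
// -> returned as-is, since it is already an object

const parsedFromString = safeParseJSON('{"top_p": 0.9}', "op.litellm_completion_kwargs");
// -> { top_p: 0.9 }

// safeParseJSON("not json", "op.samples") throws, including a truncated preview of the bad value.

const cleaned = sanitizeForYaml(
  { output: { schema: "[object Object]" }, name: "extract" },
  "pipelineConfig"
);
// -> { output: { schema: {} }, name: "extract" }, with a console warning at path
//    "pipelineConfig.output.schema" about the serialization issue.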

View File

@ -104,6 +104,7 @@ export interface OptimizeResult {
should_optimize?: string;
input_data?: Array<Record<string, unknown>>;
output_data?: Array<Record<string, unknown>>;
num_docs_analyzed?: number;
cost?: number;
error?: string;
created_at: string;
@ -114,3 +115,24 @@ export interface APIKey {
name: string;
value: string;
}
export interface DecomposeRequest {
yaml_config: string;
step_name: string;
op_name: string;
}
export interface DecomposeResult {
task_id: string;
status: TaskStatus;
decomposed_operations?: Array<Record<string, unknown>>;
winning_directive?: string;
candidates_evaluated?: number;
original_outputs?: Array<Record<string, unknown>>;
decomposed_outputs?: Array<Record<string, unknown>>;
comparison_rationale?: string;
cost?: number;
error?: string;
created_at: string;
completed_at?: string;
}

View File

@ -24,12 +24,14 @@ interface AnsiRendererProps {
text: string;
readyState: number;
setTerminalOutput: (text: string) => void;
isDecomposing?: boolean;
}
const AnsiRenderer: React.FC<AnsiRendererProps> = ({
text,
readyState,
setTerminalOutput,
isDecomposing = false,
}) => {
const html = convert.toHtml(text);
const scrollRef = useRef<HTMLDivElement>(null);
@ -51,7 +53,9 @@ const AnsiRenderer: React.FC<AnsiRendererProps> = ({
}
};
const isWebSocketClosed = readyState === WebSocket.CLOSED;
// Consider connected if either WebSocket is open OR decomposing is in progress
const isConnected = readyState === WebSocket.OPEN || isDecomposing;
const isWebSocketClosed = readyState === WebSocket.CLOSED && !isDecomposing;
return (
<div
@ -92,9 +96,11 @@ const AnsiRenderer: React.FC<AnsiRendererProps> = ({
)}
</div>
<div className="flex justify-between items-center text-xs text-gray-500">
<div className={isWebSocketClosed ? "text-red-500" : ""}>
<div className={isWebSocketClosed ? "text-red-500" : isDecomposing ? "text-blue-400" : ""}>
Status:{" "}
{readyState === WebSocket.CONNECTING
{isDecomposing
? "Decomposing..."
: readyState === WebSocket.CONNECTING
? "Connecting"
: readyState === WebSocket.OPEN
? "Connected"
@ -106,10 +112,7 @@ const AnsiRenderer: React.FC<AnsiRendererProps> = ({
</div>
<button
onClick={() => setTerminalOutput("")}
className={`hover:text-white transition-colors ${
isWebSocketClosed ? "cursor-not-allowed opacity-50" : ""
}`}
disabled={isWebSocketClosed}
className="hover:text-white transition-colors"
>
Clear
</button>

View File

@ -0,0 +1,271 @@
import React from "react";
import {
Dialog,
DialogContent,
DialogHeader,
DialogTitle,
DialogDescription,
} from "@/components/ui/dialog";
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/ui/table";
import { Button } from "@/components/ui/button";
import { ChevronLeft, ChevronRight, Check, X } from "lucide-react";
import { DecomposeResult } from "@/app/types";
interface DecompositionComparisonDialogProps {
isOpen: boolean;
result: DecomposeResult | null;
operationName?: string;
onOpenChange: (open: boolean) => void;
onApply: () => void;
onCancel: () => void;
}
export const DecompositionComparisonDialog: React.FC<
DecompositionComparisonDialogProps
> = ({ isOpen, result, operationName, onOpenChange, onApply, onCancel }) => {
const [currentPage, setCurrentPage] = React.useState(1);
const originalOutputs = result?.original_outputs || [];
const decomposedOutputs = result?.decomposed_outputs || [];
const maxRows = Math.max(originalOutputs.length, decomposedOutputs.length);
const renderOutputTable = (
data: Array<Record<string, unknown>>,
title: string
) => {
if (!data.length) {
return (
<div className="text-center text-muted-foreground py-8">
No outputs available
</div>
);
}
const columns = Object.keys(data[0]).filter(
(column) => !column.startsWith("_observability")
);
const currentRow = data[currentPage - 1];
if (!currentRow) return null;
return (
<div className="space-y-2">
<h4 className="font-medium text-sm text-muted-foreground">{title}</h4>
<div className="border rounded-md">
<div className="max-h-[250px] overflow-auto">
<Table className="relative w-full border-collapse">
<TableHeader>
<TableRow className="sticky top-0 bg-background z-10 border-b">
{columns.map((column) => (
<TableHead
key={column}
className="h-8 px-3 text-left align-middle bg-background text-xs"
>
{column}
</TableHead>
))}
</TableRow>
</TableHeader>
<TableBody>
<TableRow>
{columns.map((column) => (
<TableCell key={column} className="p-2 align-top">
<pre className="whitespace-pre-wrap font-mono text-xs text-left max-w-[300px]">
{typeof currentRow[column] === "object"
? JSON.stringify(currentRow[column], null, 2)
: String(currentRow[column] ?? "")}
</pre>
</TableCell>
))}
</TableRow>
</TableBody>
</Table>
</div>
</div>
</div>
);
};
React.useEffect(() => {
if (isOpen) {
setCurrentPage(1);
}
}, [isOpen]);
if (!result) return null;
const isOriginalWinner = result.winning_directive === "original";
return (
<Dialog open={isOpen} onOpenChange={onOpenChange}>
<DialogContent className="max-w-7xl max-h-[90vh] flex flex-col">
<DialogHeader className="flex-shrink-0 border-b pb-4">
<DialogTitle className="text-xl">
Review Decomposition Results
</DialogTitle>
<DialogDescription>
Compare the original operation outputs with the decomposed version
before applying changes.
</DialogDescription>
</DialogHeader>
<div className="flex-1 overflow-y-auto py-4 space-y-6">
{/* Summary */}
<div className="p-4 bg-muted rounded-lg space-y-2">
<div className="flex items-center gap-4 flex-wrap">
{operationName && (
<div className="flex items-center">
<span className="font-medium text-sm mr-2">Operation:</span>
<span className="bg-primary/15 text-primary rounded-md px-2 py-1 text-sm">
{operationName}
</span>
</div>
)}
<div className="flex items-center">
<span className="font-medium text-sm mr-2">
Winning Strategy:
</span>
<span
className={`rounded-md px-2 py-1 text-sm ${
isOriginalWinner
? "bg-yellow-100 text-yellow-800"
: "bg-green-100 text-green-800"
}`}
>
{result.winning_directive}
</span>
</div>
<div className="flex items-center">
<span className="font-medium text-sm mr-2">
Candidates Evaluated:
</span>
<span className="text-sm">{result.candidates_evaluated}</span>
</div>
{result.cost && (
<div className="flex items-center">
<span className="font-medium text-sm mr-2">Cost:</span>
<span className="text-sm">${result.cost.toFixed(4)}</span>
</div>
)}
</div>
</div>
{/* Comparison Rationale */}
{result.comparison_rationale && (
<div className="space-y-2">
<h3 className="font-medium text-base">Why This Was Chosen</h3>
<div className="p-3 bg-blue-50 border border-blue-200 rounded-lg text-sm">
{result.comparison_rationale}
</div>
</div>
)}
{/* Side by side comparison */}
<div className="space-y-4">
<div className="flex items-center justify-between">
<h3 className="font-medium text-base">
Output Comparison (Sample {currentPage} of {maxRows || 1})
</h3>
<div className="flex items-center space-x-1">
<Button
variant="outline"
size="sm"
onClick={() =>
setCurrentPage((prev) => Math.max(prev - 1, 1))
}
disabled={currentPage === 1}
className="px-2 py-1"
>
<ChevronLeft className="h-4 w-4" />
Previous
</Button>
<Button
variant="outline"
size="sm"
onClick={() =>
setCurrentPage((prev) => Math.min(prev + 1, maxRows))
}
disabled={currentPage === maxRows || maxRows === 0}
className="px-2 py-1"
>
Next
<ChevronRight className="h-4 w-4" />
</Button>
</div>
</div>
<div className="grid grid-cols-2 gap-4">
{renderOutputTable(originalOutputs, "Original Output")}
{renderOutputTable(
decomposedOutputs,
`Decomposed Output (${result.winning_directive})`
)}
</div>
</div>
{/* Decomposed Operations Preview */}
{result.decomposed_operations &&
result.decomposed_operations.length > 0 &&
!isOriginalWinner && (
<div className="space-y-2">
<h3 className="font-medium text-base">
New Operations ({result.decomposed_operations.length})
</h3>
<div className="space-y-2">
{result.decomposed_operations.map((op, idx) => (
<div
key={idx}
className="p-3 border rounded-lg bg-muted/50"
>
<div className="flex items-center gap-2 mb-2">
<span className="font-medium text-sm">
{(op.name as string) || `Operation ${idx + 1}`}
</span>
<span className="text-xs bg-primary/10 text-primary px-2 py-0.5 rounded">
{op.type as string}
</span>
</div>
{op.prompt && (
<pre className="text-xs font-mono bg-background p-2 rounded border max-h-[100px] overflow-auto">
{String(op.prompt).slice(0, 500)}
{String(op.prompt).length > 500 ? "..." : ""}
</pre>
)}
</div>
))}
</div>
</div>
)}
</div>
<div className="flex justify-end items-center gap-3 pt-4 border-t mt-4">
<Button variant="outline" onClick={onCancel}>
<X className="h-4 w-4 mr-2" />
Keep Original
</Button>
<Button
onClick={onApply}
disabled={isOriginalWinner}
title={
isOriginalWinner
? "Original was determined to be the best option"
: undefined
}
>
<Check className="h-4 w-4 mr-2" />
Apply Decomposition
</Button>
</div>
</DialogContent>
</Dialog>
);
};
DecompositionComparisonDialog.displayName = "DecompositionComparisonDialog";

View File

@ -39,7 +39,6 @@ import { Skeleton } from "@/components/ui/skeleton";
import { debounce } from "lodash";
import { Guardrails, GleaningConfig } from "./operations/args";
import createOperationComponent from "./operations/components";
import { useWebSocket } from "@/contexts/WebSocketContext";
import { Badge } from "./ui/badge";
import {
Popover,
@ -161,7 +160,7 @@ const OperationHeader: React.FC<OperationHeaderProps> = React.memo(
}`}
/>
</HoverCardTrigger>
<HoverCardContent className="w-72" side="bottom" align="start">
<HoverCardContent className="w-96 max-h-80" side="bottom" align="start">
<div className="flex flex-col space-y-1">
<p className="text-sm font-medium">
{optimizeResult === undefined || optimizeResult === null
@ -170,13 +169,15 @@ const OperationHeader: React.FC<OperationHeaderProps> = React.memo(
? "Decomposition Status"
: "Decomposition Recommended"}
</p>
<p className="text-sm text-muted-foreground">
{optimizeResult === undefined || optimizeResult === null
? "Analyzing operation complexity..."
: optimizeResult === ""
? "No decomposition needed for this operation"
: "Recommended decomposition: " + optimizeResult}
</p>
<div className="max-h-64 overflow-y-auto">
<p className="text-sm text-muted-foreground whitespace-pre-wrap">
{optimizeResult === undefined || optimizeResult === null
? "Analyzing operation complexity..."
: optimizeResult === ""
? "No decomposition needed for this operation"
: optimizeResult}
</p>
</div>
</div>
</HoverCardContent>
</HoverCard>
@ -367,7 +368,7 @@ const OperationHeader: React.FC<OperationHeaderProps> = React.memo(
disabled={disabled}
>
<Zap className="mr-2 h-4 w-4" />
Optimize Operation
Decompose Operation
</Button>
)}
@ -720,7 +721,6 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
output: pipelineOutput,
setOutput,
isLoadingOutputs,
setIsLoadingOutputs,
numOpRun,
setNumOpRun,
currentFile,
@ -730,18 +730,14 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
sampleSize,
setCost,
defaultModel,
optimizerModel,
setTerminalOutput,
namespace,
apiKeys,
systemPrompt,
extraPipelineSettings,
onRequestDecompositionRef,
} = usePipelineContext();
const { toast } = useToast();
const operationRef = useRef(operation);
const { connect, sendMessage, lastMessage, readyState, disconnect } =
useWebSocket();
useEffect(() => {
operationRef.current = operation;
@ -808,75 +804,19 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
const handleOptimizeConfirm = useCallback(async () => {
if (!operation) return;
try {
// Clear the output
setTerminalOutput("");
setIsLoadingOutputs(true);
setShowOptimizeDialog(false);
// Write pipeline config
const response = await fetch("/api/writePipelineConfig", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
default_model: defaultModel,
data: { path: currentFile?.path || "" },
operations,
operation_id: operation.id,
name: pipelineName,
sample_size: sampleSize,
optimize: true,
clear_intermediate: false,
system_prompt: systemPrompt,
namespace: namespace,
apiKeys: apiKeys,
optimizerModel: optimizerModel,
extraPipelineSettings: extraPipelineSettings,
}),
});
if (!response.ok) {
throw new Error(await response.text());
}
const { filePath } = await response.json();
// Ensure WebSocket is connected
await connect();
// Send message to run the pipeline
sendMessage({
yaml_config: filePath,
optimize: true,
});
} catch (error) {
console.error("Error optimizing operation:", error);
// Use the decomposition flow from context if available
if (onRequestDecompositionRef.current) {
onRequestDecompositionRef.current(operation.id, operation.name);
} else {
toast({
title: "Error",
description: error.message,
description: "Decomposition handler not available",
variant: "destructive",
});
// Close the WebSocket connection
disconnect();
} finally {
setShowOptimizeDialog(false);
}
}, [
operation,
defaultModel,
currentFile,
operations,
pipelineName,
sampleSize,
optimizerModel,
connect,
sendMessage,
systemPrompt,
namespace,
apiKeys,
extraPipelineSettings,
]);
}, [operation, onRequestDecompositionRef, toast]);
const onShowOutput = useCallback(async () => {
if (!operation) return;
@ -1224,7 +1164,7 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
>
<AlertDialogContent>
<AlertDialogHeader>
<AlertDialogTitle>Optimize Operation</AlertDialogTitle>
<AlertDialogTitle>Decompose Operation</AlertDialogTitle>
<AlertDialogDescription>
{!hasOpenAIKey && !isLocalMode ? (
<div className="space-y-2">
@ -1232,8 +1172,8 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
OpenAI API Key Required
</p>
<p>
To use the optimizer, please add your OpenAI API key in Edit{" "}
{">"}
To decompose operations, please add your OpenAI API key in
Edit {">"}
Edit API Keys.
</p>
<button
@ -1245,11 +1185,10 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
</div>
) : (
<p>
This will analyze the operation and replace it with another
pipeline that has higher accuracy (as determined by an
LLM-as-a-judge), if it can be found. Do you want to proceed?
The process may take between 2 and 10 minutes, depending on
how complex your data is.
This will analyze the operation and try different
decomposition strategies (chaining, chunking, gleaning, etc.)
to improve accuracy. The best strategy will be selected via
LLM-as-a-judge comparison. Do you want to proceed?
</p>
)}
</AlertDialogDescription>
@ -1260,7 +1199,7 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
onClick={handleOptimizeConfirm}
disabled={!hasOpenAIKey && !isLocalMode}
>
Proceed
Decompose
</AlertDialogAction>
</AlertDialogFooter>
</AlertDialogContent>

View File

@ -23,6 +23,7 @@ interface OptimizationDialogProps {
operationName?: string;
inputData?: Array<Record<string, unknown>>;
outputData?: Array<Record<string, unknown>>;
numDocsAnalyzed?: number;
onOpenChange: (open: boolean) => void;
onDecompose?: () => void;
}
@ -34,6 +35,7 @@ export const OptimizationDialog: React.FC<OptimizationDialogProps> = ({
inputData,
outputData,
operationName,
numDocsAnalyzed,
onOpenChange,
onDecompose,
}) => {
@ -140,9 +142,9 @@ export const OptimizationDialog: React.FC<OptimizationDialogProps> = ({
<DialogHeader className="flex-shrink-0 border-b pb-4">
<DialogTitle className="text-xl">Operation Too Complex</DialogTitle>
<p className="text-base mt-2">
This operation might be too complex for the LLM to handle
efficiently. We recommend breaking it down into smaller, more
manageable steps.
After analyzing {numDocsAnalyzed ?? "several"} documents, this
operation might be too complex for the LLM to handle efficiently. We
recommend breaking it down into smaller, more manageable steps.
</p>
</DialogHeader>

View File

@ -81,6 +81,7 @@ const useOutputContext = () => {
const {
output,
isLoadingOutputs,
isDecomposing,
terminalOutput,
setTerminalOutput,
optimizerProgress,
@ -91,6 +92,7 @@ const useOutputContext = () => {
return {
output,
isLoadingOutputs,
isDecomposing,
terminalOutput,
setTerminalOutput,
optimizerProgress,
@ -385,7 +387,7 @@ VisualizeContent.displayName = "VisualizeContent";
// Move ConsoleContent outside
export const ConsoleContent = memo(() => {
const { terminalOutput, setTerminalOutput, optimizerProgress } =
const { terminalOutput, setTerminalOutput, optimizerProgress, isDecomposing } =
useOutputContext();
const { readyState } = useWebSocket();
@ -469,6 +471,7 @@ export const ConsoleContent = memo(() => {
text={terminalOutput || ""}
readyState={readyState}
setTerminalOutput={setTerminalOutput}
isDecomposing={isDecomposing}
/>
</div>
</div>
@ -478,7 +481,7 @@ ConsoleContent.displayName = "ConsoleContent";
// Main Output component
export const Output = memo(() => {
const { output, isLoadingOutputs, sampleSize, operations } =
const { output, isLoadingOutputs, isDecomposing, sampleSize, operations } =
useOutputContext();
const operation = useOperation(output?.operationId);
@ -500,10 +503,14 @@ export const Output = memo(() => {
);
}, [operation]);
// Effect for tab changes
// Effect for tab changes - switch to console when loading or decomposing
useEffect(() => {
setActiveTab(isLoadingOutputs ? "console" : "table");
}, [isLoadingOutputs]);
if (isLoadingOutputs || isDecomposing) {
setActiveTab("console");
} else {
setActiveTab("table");
}
}, [isLoadingOutputs, isDecomposing]);
// Memoize columns
const columns = useMemo(() => {

View File

@ -40,9 +40,12 @@ import { Input } from "@/components/ui/input";
import { schemaDictToItemSet } from "./utils";
import { v4 as uuidv4 } from "uuid";
import { useOptimizeCheck } from "@/hooks/useOptimizeCheck";
import { useDecomposeWebSocket } from "@/hooks/useDecomposeWebSocket";
import { canBeOptimized } from "@/lib/utils";
import { Textarea } from "./ui/textarea";
import { OptimizationDialog } from "@/components/OptimizationDialog";
import { DecompositionComparisonDialog } from "@/components/DecompositionComparisonDialog";
import { DecomposeResult } from "@/app/types";
import {
Popover,
PopoverContent,
@ -240,6 +243,9 @@ const PipelineGUI: React.FC = () => {
apiKeys,
extraPipelineSettings,
setExtraPipelineSettings,
isDecomposing,
setIsDecomposing,
onRequestDecompositionRef,
} = usePipelineContext();
const [isSettingsOpen, setIsSettingsOpen] = useState(false);
const { toast } = useToast();
@ -253,12 +259,14 @@ const PipelineGUI: React.FC = () => {
outputData?: Array<Record<string, unknown>>;
operationName?: string;
operationId?: string;
numDocsAnalyzed?: number;
}>({
isOpen: false,
content: "",
prompt: undefined,
operationName: undefined,
operationId: undefined,
numDocsAnalyzed: undefined,
});
const [isEditingName, setIsEditingName] = useState(false);
const [editedPipelineName, setEditedPipelineName] = useState(pipelineName);
@ -277,14 +285,13 @@ const PipelineGUI: React.FC = () => {
setCost((prev) => prev + result.cost);
if (result.should_optimize) {
toast({
title: `Hey! Consider decomposing ${operations[operations.length - 1].name
}`,
const lastOp = operations[operations.length - 1];
const { dismiss } = toast({
title: `Hey! Consider decomposing ${lastOp.name}`,
description: (
<span
className="cursor-pointer text-blue-500 hover:text-blue-700"
onClick={() => {
const lastOp = operations[operations.length - 1];
setOptimizationDialog({
isOpen: true,
content: result.should_optimize,
@ -293,7 +300,9 @@ const PipelineGUI: React.FC = () => {
operationId: lastOp.id,
inputData: result.input_data,
outputData: result.output_data,
numDocsAnalyzed: result.num_docs_analyzed,
});
dismiss();
}}
>
Click here to see why.
@ -312,6 +321,144 @@ const PipelineGUI: React.FC = () => {
},
});
const [decomposeDialog, setDecomposeDialog] = useState<{
isOpen: boolean;
result: DecomposeResult | null;
targetOpId: string | null;
operationName: string | null;
}>({
isOpen: false,
result: null,
targetOpId: null,
operationName: null,
});
// Ref to store the current decomposition target (avoids stale closure issue)
const decompositionTargetRef = useRef<{
operationId: string;
operationName: string;
} | null>(null);
const { startDecomposition, cancel: cancelDecomposition } = useDecomposeWebSocket({
namespace,
onOutput: (output) => {
// Stream output to the terminal
setTerminalOutput(output);
},
onComplete: (result) => {
setIsDecomposing(false);
setCost((prev) => prev + (result.cost || 0));
// Read from ref to avoid stale closure
const target = decompositionTargetRef.current;
console.log("Decomposition complete:", {
result,
target,
decomposed_operations: result.decomposed_operations,
});
// Show the comparison dialog instead of auto-applying
setDecomposeDialog({
isOpen: true,
result,
targetOpId: target?.operationId || null,
operationName: target?.operationName || null,
});
},
onError: (error) => {
setIsDecomposing(false);
toast({
title: "Decomposition Failed",
description: error,
variant: "destructive",
});
},
});
const handleApplyDecomposition = () => {
const { result, targetOpId } = decomposeDialog;
console.log("handleApplyDecomposition called:", { result, targetOpId, decomposeDialog });
if (!result || !targetOpId) {
console.warn("handleApplyDecomposition: missing result or targetOpId", { result, targetOpId });
return;
}
if (
result.decomposed_operations &&
result.decomposed_operations.length > 0 &&
result.winning_directive !== "original"
) {
setOperations((prev) => {
// Find the index of the target operation
const targetIdx = prev.findIndex((op) => op.id === targetOpId);
if (targetIdx === -1) return prev;
// Convert decomposed operations to our Operation format
const newOps: Operation[] = result.decomposed_operations!.map(
(opConfig) => ({
id: uuidv4(),
llmType:
opConfig.type === "map" ||
opConfig.type === "reduce" ||
opConfig.type === "filter" ||
opConfig.type === "parallel_map"
? "LLM"
: "non-LLM",
type: opConfig.type as Operation["type"],
name: opConfig.name as string,
prompt: opConfig.prompt as string | undefined,
output: opConfig.output
? {
schema: schemaDictToItemSet(
(opConfig.output as { schema: Record<string, string> })
.schema
),
}
: undefined,
gleaning: opConfig.gleaning as Operation["gleaning"],
otherKwargs: Object.fromEntries(
Object.entries(opConfig).filter(
([key]) =>
!["name", "type", "prompt", "output", "gleaning"].includes(key)
)
),
visibility: true,
})
);
// Replace the target operation with the new operations
const newOperations = [...prev];
newOperations.splice(targetIdx, 1, ...newOps);
return newOperations;
});
toast({
title: "Decomposition Applied",
description: `Applied ${result.winning_directive} directive. Created ${result.decomposed_operations.length} operation(s).`,
});
}
setDecomposeDialog({
isOpen: false,
result: null,
targetOpId: null,
operationName: null,
});
};
const handleCancelDecomposition = () => {
toast({
title: "Decomposition Cancelled",
description: "Keeping the original operation.",
});
setDecomposeDialog({
isOpen: false,
result: null,
targetOpId: null,
operationName: null,
});
};
const { restoreFromYAML } = useRestorePipeline({
setOperations,
setPipelineName,
@ -328,79 +475,18 @@ const PipelineGUI: React.FC = () => {
if (lastMessage) {
if (lastMessage.type === "output") {
setTerminalOutput(lastMessage.data);
} else if (lastMessage.type === "optimizer_progress") {
setOptimizerProgress({
status: lastMessage.status,
progress: lastMessage.progress,
shouldOptimize: lastMessage.should_optimize,
rationale: lastMessage.rationale,
validatorPrompt: lastMessage.validator_prompt,
});
} else if (lastMessage.type === "result") {
const runCost = lastMessage.data.cost || 0;
setOptimizerProgress(null);
// See if there was an optimized operation
const optimizedOps = lastMessage.data.optimized_ops;
if (optimizedOps) {
const newOperations = optimizedOps.map((optimizedOp) => {
const {
id,
type,
name,
prompt,
output,
validate,
gleaning,
sample,
...otherKwargs
} = optimizedOp;
// Find matching operation in previous operations list
const existingOp = operations.find((op) => op.name === name);
return {
id: id || uuidv4(),
llmType:
type === "map" ||
type === "reduce" ||
type === "resolve" ||
type === "filter" ||
type === "parallel_map" ||
type === "rank" ||
type === "extract"
? "LLM"
: "non-LLM",
type: type,
name: name || "Untitled Operation",
prompt: prompt,
output: output
? {
schema: schemaDictToItemSet(output.schema),
}
: undefined,
validate: validate,
gleaning: gleaning,
sample: sample,
otherKwargs: otherKwargs || {},
...(existingOp?.runIndex && { runIndex: existingOp.runIndex }),
visibility: true,
} as Operation;
});
setOperations(newOperations);
} else {
// No optimized operations, so we need to check if we should optimize the last operation
// Trigger should optimize for the last operation
if (autoOptimizeCheck) {
const lastOp = operations[operations.length - 1];
if (lastOp && canBeOptimized(lastOp.type)) {
submitTask({
yaml_config: lastMessage.data.yaml_config,
step_name: "data_processing", // TODO: Make this a constant
op_name: lastOp.name,
});
}
// Trigger should_optimize check for the last operation if enabled
if (autoOptimizeCheck) {
const lastOp = operations[operations.length - 1];
if (lastOp && canBeOptimized(lastOp.type)) {
submitTask({
yaml_config: lastMessage.data.yaml_config,
step_name: "data_processing",
op_name: lastOp.name,
});
}
}
@ -692,20 +778,35 @@ const PipelineGUI: React.FC = () => {
};
const handleStop = () => {
sendMessage("kill");
// Stop pipeline run if running
if (isLoadingOutputs) {
sendMessage("kill");
if (readyState === WebSocket.CLOSED) {
setIsLoadingOutputs(false);
}
}
if (readyState === WebSocket.CLOSED && isLoadingOutputs) {
setIsLoadingOutputs(false);
// Stop decomposition if running
if (isDecomposing) {
cancelDecomposition();
setIsDecomposing(false);
}
};
const handleOptimizeFromDialog = async () => {
if (!optimizationDialog.operationId) return;
if (!optimizationDialog.operationId || !optimizationDialog.operationName) return;
// Store target in ref so onComplete callback can access it
decompositionTargetRef.current = {
operationId: optimizationDialog.operationId,
operationName: optimizationDialog.operationName,
};
try {
setTerminalOutput("");
setIsLoadingOutputs(true);
setIsDecomposing(true);
setTerminalOutput(""); // Clear terminal output
// Write the pipeline config to get a YAML file path
const response = await fetch("/api/writePipelineConfig", {
method: "POST",
headers: {
@ -732,24 +833,115 @@ const PipelineGUI: React.FC = () => {
const { filePath } = await response.json();
await connect();
// Use the WebSocket decompose endpoint for streaming output
await startDecomposition(
filePath,
"data_processing", // Matches the step name in generatePipelineConfig
optimizationDialog.operationName
);
sendMessage({
yaml_config: filePath,
optimize: true,
toast({
title: "Decomposing Operation",
description: "Analyzing and generating candidate decompositions. Watch the console for progress.",
});
} catch (error) {
console.error("Error optimizing operation:", error);
console.error("Error starting decomposition:", error);
setIsDecomposing(false);
toast({
title: "Error",
description: error.message,
description: error instanceof Error ? error.message : "Failed to start decomposition",
variant: "destructive",
});
disconnect();
setIsLoadingOutputs(false);
}
};
// Handler for decomposition requests from OperationCard dropdown
const handleDecompositionRequest = useCallback(
async (operationId: string, operationName: string) => {
// Store target in ref so onComplete callback can access it
decompositionTargetRef.current = {
operationId,
operationName,
};
try {
setIsDecomposing(true);
setTerminalOutput(""); // Clear terminal output
// Set the dialog state so we know which operation we're decomposing
setDecomposeDialog((prev) => ({
...prev,
targetOpId: operationId,
operationName: operationName,
}));
// Write the pipeline config to get a YAML file path
const response = await fetch("/api/writePipelineConfig", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
default_model: defaultModel,
data: { path: currentFile?.path || "" },
operations,
operation_id: operationId,
name: pipelineName,
sample_size: sampleSize,
optimize: true,
namespace: namespace,
apiKeys: apiKeys,
optimizerModel: optimizerModel,
extraPipelineSettings: extraPipelineSettings,
}),
});
if (!response.ok) {
throw new Error(await response.text());
}
const { filePath } = await response.json();
// Start decomposition via WebSocket
startDecomposition(filePath, "data_processing", operationName);
toast({
title: "Decomposing Operation",
description:
"Analyzing and generating candidate decompositions. Watch the console for progress.",
});
} catch (error) {
console.error("Error starting decomposition:", error);
setIsDecomposing(false);
toast({
title: "Error",
description:
error instanceof Error ? error.message : "Failed to start decomposition",
variant: "destructive",
});
}
},
[
defaultModel,
currentFile,
operations,
pipelineName,
sampleSize,
namespace,
apiKeys,
optimizerModel,
extraPipelineSettings,
startDecomposition,
setIsDecomposing,
setTerminalOutput,
toast,
]
);
// Register the decomposition handler in context so OperationCard can use it
// Using ref assignment instead of state to avoid infinite loops
onRequestDecompositionRef.current = handleDecompositionRequest;
return (
<div className="flex flex-col h-full">
<div
@ -1053,7 +1245,7 @@ const PipelineGUI: React.FC = () => {
variant="destructive"
className="rounded-sm whitespace-nowrap"
onClick={handleStop}
disabled={!isLoadingOutputs}
disabled={!isLoadingOutputs && !isDecomposing}
>
<StopCircle size={16} className="mr-2" />
Stop
@ -1144,11 +1336,22 @@ const PipelineGUI: React.FC = () => {
operationName={optimizationDialog.operationName}
inputData={optimizationDialog.inputData}
outputData={optimizationDialog.outputData}
numDocsAnalyzed={optimizationDialog.numDocsAnalyzed}
onOpenChange={(open) =>
setOptimizationDialog((prev) => ({ ...prev, isOpen: open }))
}
onDecompose={handleOptimizeFromDialog}
/>
<DecompositionComparisonDialog
isOpen={decomposeDialog.isOpen}
result={decomposeDialog.result}
operationName={decomposeDialog.operationName || undefined}
onOpenChange={(open) =>
setDecomposeDialog((prev) => ({ ...prev, isOpen: open }))
}
onApply={handleApplyDecomposition}
onCancel={handleCancelDecomposition}
/>
</div>
);
};

View File

@ -31,6 +31,7 @@ interface PromptInputProps {
onChange: (value: string) => void;
disableValidation?: boolean;
placeholder?: string;
operationType?: "map" | "reduce" | "filter" | "resolve" | "other";
}
export const PromptInput: React.FC<PromptInputProps> = React.memo(
@ -38,31 +39,62 @@ export const PromptInput: React.FC<PromptInputProps> = React.memo(
prompt,
onChange,
disableValidation = false,
placeholder = "Enter prompt (must be a Jinja2 template)",
placeholder = "Enter prompt (Jinja2 template optional)",
operationType = "map",
}) => {
const validateJinjaTemplate = (value: string) => {
const hasJinjaTemplate = (value: string) => {
if (disableValidation) return true;
const hasOpenBrace = value.includes("{{");
const hasCloseBrace = value.includes("}}");
return hasOpenBrace && hasCloseBrace;
};
const isReduce = operationType === "reduce";
const templateVar = isReduce ? "inputs" : "input";
const exampleKey = isReduce ? "inputs[0].content" : "input.content";
return (
<>
<Textarea
placeholder={placeholder}
className={`mb-1 rounded-sm text-sm font-mono ${
!validateJinjaTemplate(prompt) ? "border-red-500" : ""
!hasJinjaTemplate(prompt) ? "border-amber-400" : ""
}`}
rows={Math.max(3, Math.ceil(prompt.split("\n").length))}
value={prompt}
onChange={(e) => onChange(e.target.value)}
/>
{!validateJinjaTemplate(prompt) && (
<div className="text-red-500 text-sm mb-1">
Prompt must contain Jinja2 template syntax {"{"}
{"{"} and {"}"}
{"}"}
{!hasJinjaTemplate(prompt) && (
<div className="text-amber-600 text-sm mb-1 space-y-1">
<div>
<strong>Warning:</strong> No Jinja2 template found.
</div>
<div className="text-xs text-muted-foreground">
{isReduce ? (
<>
{"{"}
{"{"} inputs {"}"}
{"}"} (the list of all documents) will be auto-appended. For
example, if documents have keys {'"'}id{'"'} and {'"'}content
{'"'}, all docs with all keys will be included. To reference
specific keys, use a for loop: {"{"}% for doc in inputs %{"}"}{" "}
{"{"}
{"{"} doc.content {"}"}
{"}"} {"{"}% endfor %{"}"}.
</>
) : (
<>
{"{"}
{"{"} input {"}"}
{"}"} (the full document) will be auto-appended. For example,
if documents have keys {'"'}id{'"'}, {'"'}title{'"'}, and {'"'}
content{'"'}, all three will be included. To use only specific
keys (e.g., just {'"'}content{'"'}), add {"{"}
{"{"} input.content {"}"}
{"}"} to your prompt.
</>
)}
</div>
</div>
)}
</>
@ -77,6 +109,7 @@ PromptInput.propTypes = {
onChange: PropTypes.func.isRequired,
disableValidation: PropTypes.bool,
placeholder: PropTypes.string,
operationType: PropTypes.oneOf(["map", "reduce", "filter", "resolve", "other"]),
};
interface SchemaFormProps {

View File

@ -314,6 +314,7 @@ export const ReduceOperationComponent: React.FC<OperationComponentProps> = ({
<PromptInput
prompt={operation.prompt || ""}
onChange={handlePromptChange}
operationType="reduce"
/>
<OutputSchema
schema={schemaItems}

View File

@ -26,7 +26,7 @@ const ToastViewport = React.forwardRef<
ToastViewport.displayName = ToastPrimitives.Viewport.displayName;
const toastVariants = cva(
"group pointer-events-auto relative flex w-full items-center justify-between space-x-2 overflow-y-auto max-h-[80vh] whitespace-pre-wrap break-words rounded-md border p-4 pr-6 shadow-lg transition-all data-[swipe=cancel]:translate-x-0 data-[swipe=end]:translate-x-[var(--radix-toast-swipe-end-x)] data-[swipe=move]:translate-x-[var(--radix-toast-swipe-move-x)] data-[swipe=move]:transition-none data-[state=open]:animate-in data-[state=closed]:animate-out data-[swipe=end]:animate-out data-[state=closed]:fade-out-80 data-[state=closed]:slide-out-to-right-full data-[state=open]:slide-in-from-top-full data-[state=open]:sm:slide-in-from-bottom-full",
"group pointer-events-auto relative flex w-full items-start justify-between space-x-2 overflow-hidden max-h-[80vh] whitespace-pre-wrap break-words rounded-md border p-4 pr-6 shadow-lg transition-all data-[swipe=cancel]:translate-x-0 data-[swipe=end]:translate-x-[var(--radix-toast-swipe-end-x)] data-[swipe=move]:translate-x-[var(--radix-toast-swipe-move-x)] data-[swipe=move]:transition-none data-[state=open]:animate-in data-[state=closed]:animate-out data-[swipe=end]:animate-out data-[state=closed]:fade-out-80 data-[state=closed]:slide-out-to-right-full data-[state=open]:slide-in-from-top-full data-[state=open]:sm:slide-in-from-bottom-full",
{
variants: {
variant: {

View File

@ -22,29 +22,33 @@ export function Toaster() {
const descId = `toast-desc-${id}`;
return (
<Toast key={id} {...props}>
<div className="grid gap-1">
{title && <ToastTitle>{title}</ToastTitle>}
<div className="grid gap-1 flex-1 overflow-hidden">
<div className="flex items-center gap-2 shrink-0">
{title && <ToastTitle>{title}</ToastTitle>}
{description && (
<Button
variant="ghost"
size="sm"
className="h-5 w-5 p-0 opacity-70 hover:opacity-100"
onClick={() => {
const el = document.getElementById(descId);
const text = el?.textContent ?? "";
if (text) {
navigator.clipboard.writeText(text);
}
}}
title="Copy message"
>
<Copy size={12} />
</Button>
)}
</div>
{description && (
<ToastDescription id={descId}>{description}</ToastDescription>
<div className="overflow-y-auto max-h-[60vh]">
<ToastDescription id={descId}>{description}</ToastDescription>
</div>
)}
</div>
{description && (
<Button
variant="ghost"
size="sm"
className="h-6 px-2 text-xs"
onClick={() => {
const el = document.getElementById(descId);
const text = el?.textContent ?? "";
if (text) {
navigator.clipboard.writeText(text);
}
}}
title="Copy message"
>
<Copy size={12} />
</Button>
)}
{action}
<ToastClose />
</Toast>

View File

@ -29,6 +29,7 @@ interface PipelineState {
validatorPrompt: string;
} | null;
isLoadingOutputs: boolean;
isDecomposing: boolean;
numOpRun: number;
pipelineName: string;
sampleSize: number | null;
@ -59,6 +60,7 @@ interface PipelineContextType extends PipelineState {
} | null>
>;
setIsLoadingOutputs: React.Dispatch<React.SetStateAction<boolean>>;
setIsDecomposing: React.Dispatch<React.SetStateAction<boolean>>;
setNumOpRun: React.Dispatch<React.SetStateAction<number>>;
setPipelineName: React.Dispatch<React.SetStateAction<string>>;
setSampleSize: React.Dispatch<React.SetStateAction<number | null>>;
@ -83,6 +85,10 @@ interface PipelineContextType extends PipelineState {
setExtraPipelineSettings: React.Dispatch<
React.SetStateAction<Record<string, unknown> | null>
>;
// Ref for triggering decomposition from OperationCard (using ref to avoid infinite loops)
onRequestDecompositionRef: React.MutableRefObject<
((operationId: string, operationName: string) => void) | null
>;
}
const PipelineContext = createContext<PipelineContextType | undefined>(
@ -290,6 +296,7 @@ export const PipelineProvider: React.FC<{ children: React.ReactNode }> = ({
localStorageKeys.IS_LOADING_OUTPUTS_KEY,
false
),
isDecomposing: false, // Transient state, not persisted
numOpRun: loadFromLocalStorage(localStorageKeys.NUM_OP_RUN_KEY, 0),
pipelineName: loadFromLocalStorage(
localStorageKeys.PIPELINE_NAME_KEY,
@ -333,6 +340,11 @@ export const PipelineProvider: React.FC<{ children: React.ReactNode }> = ({
const stateRef = useRef(state);
const [isMounted, setIsMounted] = useState(false);
// Ref for triggering decomposition from OperationCard (using ref to avoid infinite loops)
const onRequestDecompositionRef = useRef<
((operationId: string, operationName: string) => void) | null
>(null);
useEffect(() => {
stateRef.current = state;
}, [state]);
@ -423,6 +435,7 @@ export const PipelineProvider: React.FC<{ children: React.ReactNode }> = ({
output: null,
terminalOutput: "",
isLoadingOutputs: false,
isDecomposing: false,
numOpRun: 0,
pipelineName: mockPipelineName,
sampleSize: mockSampleSize,
@ -529,6 +542,10 @@ export const PipelineProvider: React.FC<{ children: React.ReactNode }> = ({
(value) => setStateAndUpdate("isLoadingOutputs", value),
[setStateAndUpdate]
),
setIsDecomposing: useCallback(
(value) => setStateAndUpdate("isDecomposing", value),
[setStateAndUpdate]
),
setNumOpRun: useCallback(
(value) => setStateAndUpdate("numOpRun", value),
[setStateAndUpdate]
@ -589,6 +606,7 @@ export const PipelineProvider: React.FC<{ children: React.ReactNode }> = ({
(value) => setStateAndUpdate("extraPipelineSettings", value),
[setStateAndUpdate]
),
onRequestDecompositionRef,
};
return (

View File

@ -0,0 +1,104 @@
import { useState, useEffect } from "react";
import axios from "axios";
import { DecomposeResult, DecomposeRequest } from "@/app/types";
import { API_ROUTES } from "@/app/api/constants";
interface UseDecomposeProps {
onComplete?: (result: DecomposeResult) => void;
onError?: (error: string) => void;
pollInterval?: number;
}
export function useDecompose({
onComplete,
onError,
pollInterval = 2000,
}: UseDecomposeProps = {}) {
const [taskId, setTaskId] = useState<string | null>(null);
const [result, setResult] = useState<DecomposeResult | null>(null);
const [error, setError] = useState<string | null>(null);
const [isLoading, setIsLoading] = useState(false);
const submitTask = async (request: DecomposeRequest) => {
try {
setIsLoading(true);
setError(null);
setResult(null);
const response = await axios.post<{ task_id: string }>(
API_ROUTES.DECOMPOSE.SUBMIT,
request
);
setTaskId(response.data.task_id);
} catch (err) {
const errorMessage =
err instanceof Error ? err.message : "Failed to submit decompose task";
setError(errorMessage);
onError?.(errorMessage);
setIsLoading(false);
}
};
const cancelTask = async () => {
if (!taskId) return;
try {
await axios.post(API_ROUTES.DECOMPOSE.CANCEL(taskId));
setTaskId(null);
setIsLoading(false);
} catch (err) {
const errorMessage =
err instanceof Error ? err.message : "Failed to cancel task";
setError(errorMessage);
onError?.(errorMessage);
}
};
useEffect(() => {
if (!taskId) return;
const pollTask = async () => {
try {
const response = await axios.get<DecomposeResult>(
API_ROUTES.DECOMPOSE.STATUS(taskId)
);
setResult(response.data);
if (
["completed", "failed", "cancelled"].includes(response.data.status)
) {
setTaskId(null);
setIsLoading(false);
if (response.data.status === "completed") {
onComplete?.(response.data);
} else if (response.data.status === "failed" && response.data.error) {
setError(response.data.error);
onError?.(response.data.error);
}
}
} catch (err) {
const errorMessage =
err instanceof Error ? err.message : "Failed to fetch task status";
setError(errorMessage);
onError?.(errorMessage);
setTaskId(null);
setIsLoading(false);
}
};
const interval = setInterval(pollTask, pollInterval);
return () => clearInterval(interval);
}, [taskId, onComplete, onError, pollInterval]);
return {
submitTask,
cancelTask,
result,
error,
isLoading,
isRunning: !!taskId,
};
}
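
A minimal usage sketch of this polling hook (illustrative only, not part of this diff). The import path for the hook and the way the DecomposeRequest is built are assumptions; DecomposeRequest itself comes from @/app/types as in the hook above.

// Hypothetical usage of useDecompose; the hook's import path and the request construction are assumed.
import { useDecompose } from "@/hooks/useDecompose";
import { DecomposeRequest } from "@/app/types";

function DecomposeButton({ request }: { request: DecomposeRequest }) {
  const { submitTask, cancelTask, result, error, isRunning } = useDecompose({
    // Poll the status endpoint every 2 seconds (the hook's default).
    pollInterval: 2000,
    onComplete: (res) => console.log("Winning directive:", res.winning_directive),
    onError: (msg) => console.error("Decomposition failed:", msg),
  });

  return (
    <div>
      <button onClick={() => submitTask(request)} disabled={isRunning}>
        {isRunning ? "Decomposing..." : "Decompose"}
      </button>
      {isRunning && <button onClick={cancelTask}>Cancel</button>}
      {error && <p>{error}</p>}
      {result?.status === "completed" && <p>{result.comparison_rationale}</p>}
    </div>
  );
}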

View File

@@ -0,0 +1,145 @@
import { useState, useRef, useCallback } from "react";
import { DecomposeResult } from "@/app/types";
interface DecomposeWebSocketMessage {
type: "output" | "result" | "error";
data?: any;
message?: string;
traceback?: string;
}
interface UseDecomposeWebSocketProps {
namespace: string;
onOutput?: (output: string) => void;
onComplete?: (result: DecomposeResult) => void;
onError?: (error: string) => void;
}
export function useDecomposeWebSocket({
namespace,
onOutput,
onComplete,
onError,
}: UseDecomposeWebSocketProps) {
const [isConnected, setIsConnected] = useState(false);
const [isRunning, setIsRunning] = useState(false);
const ws = useRef<WebSocket | null>(null);
const connect = useCallback((): Promise<void> => {
return new Promise((resolve, reject) => {
if (ws.current?.readyState === WebSocket.OPEN) {
resolve();
return;
}
if (!namespace) {
reject(new Error("Namespace is required for WebSocket connection"));
return;
}
const wsUrl = `${
process.env.NEXT_PUBLIC_BACKEND_HTTPS === "true" ? "wss://" : "ws://"
}${process.env.NEXT_PUBLIC_BACKEND_HOST}:${
process.env.NEXT_PUBLIC_BACKEND_PORT
}/ws/decompose/${namespace}`;
ws.current = new WebSocket(wsUrl);
ws.current.onopen = () => {
setIsConnected(true);
resolve();
};
ws.current.onclose = () => {
setIsConnected(false);
setIsRunning(false);
};
ws.current.onerror = (error: Event) => {
console.error("Decompose WebSocket error:", error);
onError?.("WebSocket connection failed");
reject(new Error("WebSocket connection failed"));
};
ws.current.onmessage = (event) => {
try {
const message: DecomposeWebSocketMessage = JSON.parse(event.data);
if (message.type === "output") {
onOutput?.(message.data);
} else if (message.type === "result") {
setIsRunning(false);
// Convert the result data to DecomposeResult format
const result: DecomposeResult = {
task_id: "", // Not used in WebSocket mode
status: "completed",
decomposed_operations: message.data.decomposed_operations,
winning_directive: message.data.winning_directive,
candidates_evaluated: message.data.candidates_evaluated,
original_outputs: message.data.original_outputs,
decomposed_outputs: message.data.decomposed_outputs,
comparison_rationale: message.data.comparison_rationale,
cost: message.data.cost,
created_at: new Date().toISOString(),
completed_at: new Date().toISOString(),
};
onComplete?.(result);
// Close the connection after receiving result
ws.current?.close();
} else if (message.type === "error") {
setIsRunning(false);
const errorMsg = message.data || message.message || "Unknown error";
onError?.(errorMsg);
ws.current?.close();
}
} catch (error) {
console.error("Error parsing WebSocket message:", error);
onError?.("Failed to parse WebSocket message");
}
};
});
}, [namespace, onOutput, onComplete, onError]);
const startDecomposition = useCallback(
async (yamlConfig: string, stepName: string, opName: string) => {
try {
await connect();
setIsRunning(true);
if (ws.current?.readyState === WebSocket.OPEN) {
ws.current.send(
JSON.stringify({
yaml_config: yamlConfig,
step_name: stepName,
op_name: opName,
})
);
}
} catch (error) {
setIsRunning(false);
throw error;
}
},
[connect]
);
const cancel = useCallback(() => {
if (ws.current?.readyState === WebSocket.OPEN) {
ws.current.send(JSON.stringify("kill"));
}
}, []);
const disconnect = useCallback(() => {
if (ws.current) {
ws.current.close();
}
}, []);
return {
isConnected,
isRunning,
startDecomposition,
cancel,
disconnect,
};
}
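
A hypothetical usage sketch of the streaming WebSocket variant (not part of this diff). The import path and the wiring of namespace, yamlConfig, stepName, and opName through props are assumptions; the hook's API matches the file above.

// Hypothetical usage of useDecomposeWebSocket; import path and prop wiring are assumed.
import { useState } from "react";
import { useDecomposeWebSocket } from "@/hooks/useDecomposeWebSocket";

function DecomposePanel(props: {
  namespace: string;
  yamlConfig: string;
  stepName: string;
  opName: string;
}) {
  const [log, setLog] = useState("");
  const { isRunning, startDecomposition, cancel } = useDecomposeWebSocket({
    namespace: props.namespace,
    // Stream intermediate optimizer output into local state as it arrives.
    onOutput: (chunk) => setLog((prev) => prev + chunk),
    onComplete: (result) => console.log("Winning directive:", result.winning_directive),
    onError: (msg) => console.error(msg),
  });

  return (
    <div>
      <button
        onClick={() =>
          startDecomposition(props.yamlConfig, props.stepName, props.opName).catch(
            () => {} // connection failures are already surfaced via onError
          )
        }
        disabled={isRunning}
      >
        Decompose (streaming)
      </button>
      {isRunning && <button onClick={cancel}>Cancel</button>}
      <pre>{log}</pre>
    </div>
  );
}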

View File

@@ -6,7 +6,8 @@ export function cn(...inputs: ClassValue[]) {
}
export function canBeOptimized(operationType: string) {
return ["resolve", "map", "reduce", "filter"].includes(operationType);
// Only map operations can be decomposed
return operationType === "map";
}
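
A small sketch (not from this diff) of how the narrowed gate might be combined with the decomposition ref added to the pipeline context above. The usePipelineContext accessor, the import paths, and the operation shape are assumed names for illustration.

// Hypothetical sketch: gate the decompose action on the operation type and trigger it
// through the context ref from this PR. usePipelineContext and the paths are assumed.
import { canBeOptimized } from "@/lib/utils";
import { usePipelineContext } from "@/contexts/PipelineContext";

function DecomposeAction({ operation }: { operation: { id: string; name: string; type: string } }) {
  const { onRequestDecompositionRef } = usePipelineContext();
  // Only map operations qualify for decomposition now.
  if (!canBeOptimized(operation.type)) return null;
  return (
    <button onClick={() => onRequestDecompositionRef.current?.(operation.id, operation.name)}>
      Decompose
    </button>
  );
}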
export const generateId = () => {