Compare commits

...

3 Commits

59 changed files with 3148 additions and 596 deletions

View File

@ -1,11 +1,15 @@
__version__ = "0.2.6"
import warnings
import litellm
from docetl.runner import DSLRunner
from docetl.optimizer import Optimizer
from docetl.apis.pd_accessors import SemanticAccessor
# Drop unsupported params for models like gpt-5 that don't support temperature=0
litellm.drop_params = True
# TODO: Remove after https://github.com/BerriAI/litellm/issues/7560 is fixed
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._config")

View File

@ -1,4 +1,5 @@
import os
import re
import threading
import time
from io import StringIO
@ -12,6 +13,63 @@ from docetl.utils import StageType, get_stage_description
install(show_locals=False)
# ANSI escape sequences to strip (cursor movement, line clearing, cursor visibility)
ANSI_CURSOR_PATTERNS = re.compile(
r"\x1b\[\?25[lh]" # Hide/show cursor
r"|\x1b\[\d*[ABCD]" # Cursor up/down/forward/back
r"|\x1b\[\d*[JK]" # Clear screen/line (but we handle \x1b[2K specially)
r"|\x1b\[s|\x1b\[u" # Save/restore cursor position
)
def process_carriage_returns(text: str) -> str:
"""
Process terminal control sequences for buffer-captured output.
Rich's spinner uses ANSI sequences for:
- `\r` - carriage return (move to start of line)
- `\x1b[2K` - clear entire line
- `\x1b[?25l/h` - hide/show cursor
- `\x1b[1A` - move cursor up
When captured to a buffer (not a real terminal), we simulate this by:
1. Handling `\r` as "replace from start of line"
2. Stripping cursor movement/visibility sequences
3. Keeping only meaningful content
"""
# First, strip cursor visibility and movement sequences
text = ANSI_CURSOR_PATTERNS.sub("", text)
# Process line by line
lines = text.split("\n")
processed_lines = []
for line in lines:
# Handle carriage returns - keep only content after the last \r
if "\r" in line:
segments = line.split("\r")
# Keep the last non-empty segment (after stripping ANSI clear codes)
for segment in reversed(segments):
# Strip the clear line code if present
cleaned = segment.replace("\x1b[2K", "").strip()
if cleaned:
# Keep the segment but preserve any color codes
processed_lines.append(segment.replace("\x1b[2K", ""))
break
else:
# All segments were empty, skip this line entirely
pass
else:
# No carriage return, keep the line as-is
if line.strip() or line == "": # Keep empty lines between content
processed_lines.append(line)
# Remove trailing empty lines
while processed_lines and not processed_lines[-1].strip():
processed_lines.pop()
return "\n".join(processed_lines)
class ThreadSafeConsole(Console):
def __init__(self, *args, **kwargs):
@ -24,11 +82,18 @@ class ThreadSafeConsole(Console):
self.optimizer_rationale = None
def get_output(self):
# return self.export_text(styles=True)
"""
Get the output from the buffer, processing carriage returns.
Rich's spinner uses carriage returns to overwrite lines in place.
We process these to only keep the latest content, preventing
duplicate spinner frames from flooding the output.
"""
value = self.buffer.getvalue()
self.buffer.truncate(0)
self.buffer.seek(0)
return value
# Process carriage returns to handle spinner overwrites
return process_carriage_returns(value)
def status(
self,
@ -36,10 +101,15 @@ class ThreadSafeConsole(Console):
*,
spinner: str = "dots",
spinner_style: "StyleType" = "status.spinner",
speed: float = 0.1, # Much slower speed
refresh_per_second: float = 0.5, # Much slower refresh rate (every 2 seconds)
speed: float = 1.0,
refresh_per_second: float = 4,
) -> "Status":
"""
Return a Rich Status with animation.
The carriage returns from the spinner animation are processed
in get_output() to prevent duplicate lines.
"""
status_renderable = Status(
status,
console=self,

View File

@ -849,13 +849,9 @@ class MOARSearch:
response = litellm.completion(
model=self.model,
messages=messages,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=ExpandResponseFormat,
)
call_cost = response._hidden_params["response_cost"]
call_cost = response._hidden_params.get("response_cost", 0)
with self.tree_lock:
self.console.log(
f"[green]💰 Adding LLM call cost:[/green] ${call_cost:.4f} (total before: ${self.total_search_cost:.4f})"

View File

@ -153,10 +153,11 @@ def run_moar_optimization(
"optimizer_config must contain 'max_iterations' (number of MOARSearch iterations) for MOAR optimizer"
)
model = optimizer_config.get("model")
# Use rewrite_agent_model for consistency with rest of codebase, fallback to model for backwards compatibility
model = optimizer_config.get("rewrite_agent_model") or optimizer_config.get("model")
if not model:
raise ValueError(
"optimizer_config must contain 'model' (LLM model name for directive instantiation) for MOAR optimizer"
"optimizer_config must contain 'rewrite_agent_model' (LLM model name for directive instantiation) for MOAR optimizer"
)
# Optional parameters

View File

@ -55,7 +55,7 @@ class Optimizer:
def __init__(
self,
runner: "DSLRunner",
rewrite_agent_model: str = "gpt-4o",
rewrite_agent_model: str = "gpt-5.1",
judge_agent_model: str = "gpt-4o-mini",
litellm_kwargs: dict[str, Any] = {},
resume: bool = False,
@ -67,7 +67,7 @@ class Optimizer:
Args:
yaml_file (str): Path to the YAML configuration file.
model (str): The name of the language model to use. Defaults to "gpt-4o".
model (str): The name of the language model to use. Defaults to "gpt-5.1".
resume (bool): Whether to resume optimization from a previous run. Defaults to False.
timeout (int): Timeout in seconds for operations. Defaults to 60.

View File

@ -0,0 +1,926 @@
"""
Fast decomposition analyzer for map operations.
This module provides a fast way to decompose map operations using directives,
running candidates on samples, and selecting the best via pairwise LLM comparison.
"""
import json
import os
import tempfile
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from copy import deepcopy
from typing import Any
import litellm
from litellm import completion, model_cost
from rich.console import Console
from docetl.console import get_console
from docetl.reasoning_optimizer.directives import (
ChainingDirective,
ClarifyInstructionsDirective,
DeterministicDocCompressionDirective,
DocumentChunkingDirective,
GleaningDirective,
IsolatingSubtasksDirective,
)
from docetl.runner import DSLRunner
from docetl.utils import completion_cost, count_tokens
# Drop unsupported params for models like gpt-5 that don't support temperature=0
litellm.drop_params = True
# Base directives always applicable to map operations
BASE_MAP_DIRECTIVES = [
ChainingDirective(),
IsolatingSubtasksDirective(),
GleaningDirective(),
ClarifyInstructionsDirective(),
]
# Directives that depend on document size
CHUNKING_DIRECTIVE = DocumentChunkingDirective()
COMPRESSION_DIRECTIVE = DeterministicDocCompressionDirective()
# Threshold for enabling DeterministicDocCompression (in characters)
DOC_COMPRESSION_CHAR_THRESHOLD = 1000
# Threshold for enabling DocumentChunking (as fraction of context window)
DOC_CHUNKING_CONTEXT_THRESHOLD = 0.10 # 10%
class FastDecomposer:
"""
Fast decomposition of map operations using directives and pairwise comparison.
Instead of the full optimizer flow, this:
1. Tries multiple directives to generate candidate decompositions
2. Runs each candidate on sample documents
3. Uses LLM judge for pairwise comparison
4. Returns the winning decomposition
"""
def __init__(
self,
yaml_config_path: str,
optimizer_model: str = "gpt-5.1",
sample_size: int = 5,
litellm_kwargs: dict[str, Any] | None = None,
console: Console | None = None,
):
"""
Initialize the decomposer.
Args:
yaml_config_path: Path to the pipeline YAML config file
optimizer_model: LLM model to use for directive instantiation and judging
sample_size: Number of sample documents to run candidates on
litellm_kwargs: Additional kwargs to pass to litellm.completion
console: Rich console for output (uses default if not provided)
"""
self.yaml_config_path = yaml_config_path
self.optimizer_model = optimizer_model
self.sample_size = sample_size
self.litellm_kwargs = litellm_kwargs or {}
if "temperature" not in self.litellm_kwargs:
self.litellm_kwargs["temperature"] = 0.0
self.total_cost = 0.0
self.console = console or get_console()
# Load the config
import yaml
with open(yaml_config_path, "r") as f:
self.config = yaml.safe_load(f)
self.operators = self.config.get("operations", [])
self.default_model = self.config.get("default_model", "gpt-4o-mini")
self.intermediate_dir = (
self.config.get("pipeline", {}).get("output", {}).get("intermediate_dir")
)
def _log(self, message: str) -> None:
"""Log a message to the console."""
self.console.log(message)
def get_model_context_limit(self, model: str) -> int:
"""
Get the context window limit for a model.
Args:
model: The model name (e.g., 'gpt-4o', 'azure/gpt-4')
Returns:
Maximum number of input tokens the model can handle
"""
model_info = model_cost.get(model, {})
# Try without provider prefix if not found
if not model_info:
model_name = model.split("/")[-1]
model_info = model_cost.get(model_name, {})
return model_info.get("max_input_tokens", 128000) # Default to 128k
def get_avg_doc_size(
self, sample_data: list[dict[str, Any]], op_config: dict[str, Any]
) -> tuple[float, float]:
"""
Calculate the average document size in characters and tokens.
Extracts the document content from sample data based on the operation's
prompt template (looks for {{ input.field_name }} patterns).
Args:
sample_data: List of sample documents
op_config: The operation configuration
Returns:
Tuple of (avg_chars, avg_tokens) for the document content
"""
import re
if not sample_data:
return 0.0, 0.0
prompt = op_config.get("prompt", "")
model = op_config.get("model", self.default_model)
# Extract field names from prompt template ({{ input.field_name }})
field_pattern = r"\{\{\s*input\.(\w+)\s*\}\}"
fields = re.findall(field_pattern, prompt)
if not fields:
# Fallback: use all string values from the first document
if sample_data:
fields = [
k
for k, v in sample_data[0].items()
if isinstance(v, str) and len(v) > 100
]
total_chars = 0
total_tokens = 0
for doc in sample_data:
doc_content = ""
for field in fields:
if field in doc:
value = doc[field]
if isinstance(value, str):
doc_content += value
else:
doc_content += str(value)
total_chars += len(doc_content)
if doc_content:
total_tokens += count_tokens(doc_content, model)
n = len(sample_data)
return total_chars / n if n > 0 else 0.0, total_tokens / n if n > 0 else 0.0
def get_applicable_directives(
self,
sample_data: list[dict[str, Any]],
op_config: dict[str, Any],
) -> list:
"""
Get the list of directives applicable based on data characteristics.
- DocumentChunkingDirective: only if avg doc size > 10% of context window
- DeterministicDocCompressionDirective: only if avg doc size > 1000 chars
Args:
sample_data: List of sample documents
op_config: The operation configuration
Returns:
List of applicable directive instances
"""
directives = [] # Build list with priority ordering
model = op_config.get("model", self.default_model)
context_limit = self.get_model_context_limit(model)
avg_chars, avg_tokens = self.get_avg_doc_size(sample_data, op_config)
self._log(
f"Document analysis: avg_chars={avg_chars:.0f}, avg_tokens={avg_tokens:.0f}, "
f"context_limit={context_limit}"
)
# Add DeterministicDocCompression FIRST if doc size > 1000 chars (high priority)
if avg_chars > DOC_COMPRESSION_CHAR_THRESHOLD:
self._log(
f"[cyan]Enabling DeterministicDocCompression (priority)[/cyan] "
f"(avg {avg_chars:.0f} chars > {DOC_COMPRESSION_CHAR_THRESHOLD})"
)
directives.append(COMPRESSION_DIRECTIVE)
# Add base directives
directives.extend(BASE_MAP_DIRECTIVES)
# Add DocumentChunking if avg tokens > 10% of context window
token_threshold = context_limit * DOC_CHUNKING_CONTEXT_THRESHOLD
if avg_tokens > token_threshold:
self._log(
f"[cyan]Enabling DocumentChunking[/cyan] "
f"(avg {avg_tokens:.0f} tokens > {token_threshold:.0f} = 10% of {context_limit})"
)
directives.append(CHUNKING_DIRECTIVE)
else:
self._log(
f"[dim]Skipping DocumentChunking[/dim] "
f"(avg {avg_tokens:.0f} tokens <= {token_threshold:.0f} = 10% of {context_limit})"
)
return directives
def load_sample_data(self, step_name: str, op_name: str) -> list[dict[str, Any]]:
"""
Load sample input data for an operation.
For the first operation, loads from the dataset.
For subsequent operations, loads from the previous operation's intermediate output.
Args:
step_name: Name of the pipeline step
op_name: Name of the operation
Returns:
List of sample documents
"""
# Find the operation's position
op_names = [op.get("name") for op in self.operators]
try:
op_idx = op_names.index(op_name)
except ValueError:
raise ValueError(f"Operation '{op_name}' not found in config")
if op_idx == 0:
# First operation - load from dataset
datasets = self.config.get("datasets", {})
# Get the first dataset (or the one used by this step)
for dataset_name, dataset_config in datasets.items():
dataset_path = dataset_config.get("path")
if dataset_path and os.path.exists(dataset_path):
with open(dataset_path, "r") as f:
data = json.load(f)
return data[: self.sample_size]
raise FileNotFoundError("No dataset found in config")
else:
# Load from previous operation's intermediate output
prev_op_name = op_names[op_idx - 1]
output_path = os.path.join(
self.intermediate_dir, step_name, f"{prev_op_name}.json"
)
if not os.path.exists(output_path):
raise FileNotFoundError(
f"No intermediate output found at {output_path}. "
"Run the previous operation first."
)
with open(output_path, "r") as f:
data = json.load(f)
return data[: self.sample_size]
def get_input_file_path(self) -> str:
"""Get the input file path for the pipeline."""
datasets = self.config.get("datasets", {})
for dataset_name, dataset_config in datasets.items():
path = dataset_config.get("path")
if path:
return path
return ""
def generate_candidates(
self,
op_name: str,
sample_data: list[dict[str, Any]],
target_op: dict[str, Any],
) -> list[dict[str, Any]]:
"""
Generate candidate decompositions using directives.
Directives are selected based on data characteristics:
- DocumentChunkingDirective: only if avg doc size > 10% of context window
- DeterministicDocCompressionDirective: only if avg doc size > 1000 chars
Args:
op_name: Name of the operation to decompose
sample_data: Sample data for analyzing document characteristics
target_op: The target operation configuration
Returns:
List of candidate dictionaries with 'name', 'ops', 'cost' keys
"""
candidates = []
# Add original as baseline
self._log("Adding original operation as baseline candidate")
candidates.append(
{
"name": "original",
"ops": deepcopy(self.operators),
"cost": 0.0,
"error": None,
}
)
input_file_path = self.get_input_file_path()
# Get applicable directives based on data characteristics
applicable_directives = self.get_applicable_directives(sample_data, target_op)
self._log(
f"Generating candidates using {len(applicable_directives)} directives..."
)
for i, directive in enumerate(applicable_directives, 1):
with self.console.status(
f"[bold cyan]({i}/{len(applicable_directives)}) Trying directive: {directive.name}...[/bold cyan]",
spinner="dots",
):
try:
new_ops_list, _, cost = directive.instantiate(
operators=deepcopy(self.operators),
target_ops=[op_name],
agent_llm=self.optimizer_model,
message_history=[],
global_default_model=self.default_model,
input_file_path=input_file_path,
)
self.total_cost += cost
self._log(
f" [green]✓[/green] {directive.name} generated {len(new_ops_list)} operations (cost: ${cost:.4f})"
)
candidates.append(
{
"name": directive.name,
"ops": new_ops_list,
"cost": cost,
"error": None,
}
)
except Exception as e:
# Directive not applicable or failed - skip it
self._log(f" [red]✗[/red] {directive.name} failed: {str(e)}")
candidates.append(
{
"name": directive.name,
"ops": None,
"cost": 0.0,
"error": str(e),
}
)
return candidates
def extract_ops_to_run(
self, ops_list: list[dict], original_op_name: str
) -> list[dict]:
"""
Extract the operations that replaced the original operation.
Args:
ops_list: The full transformed operations list
original_op_name: Name of the original operation that was decomposed
Returns:
List of operations that should be run on samples
"""
# Find ops that are new (not in original) or modified
original_names = {op["name"] for op in self.operators}
# Find the position where the original op was
original_idx = None
for i, op in enumerate(self.operators):
if op["name"] == original_op_name:
original_idx = i
break
if original_idx is None:
return ops_list
# Find new ops (those not in original list)
new_ops = []
for op in ops_list:
if op["name"] not in original_names or op["name"] == original_op_name:
new_ops.append(op)
return new_ops if new_ops else [ops_list[original_idx]]
def run_candidate_on_samples(
self,
candidate: dict[str, Any],
sample_data: list[dict[str, Any]],
original_op_name: str,
) -> list[dict[str, Any]]:
"""
Run a candidate's operations on sample data.
Args:
candidate: Candidate dictionary with 'ops' key
sample_data: List of sample documents
original_op_name: Name of the original operation
Returns:
List of output documents
"""
if candidate["ops"] is None:
return []
# Extract ops to run
ops_to_run = self.extract_ops_to_run(candidate["ops"], original_op_name)
if not ops_to_run:
return []
# Create a minimal config for running these ops
# Write sample data to a temp file
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
json.dump(sample_data, f)
temp_input_path = f.name
# Create a temp output file (required by DSLRunner validation)
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
temp_output_path = f.name
try:
# Create a minimal pipeline config
temp_config = {
"default_model": self.default_model,
"operations": ops_to_run,
"datasets": {
"sample_data": {
"type": "file",
"path": temp_input_path,
}
},
"pipeline": {
"steps": [
{
"name": "decompose_test",
"input": "sample_data",
"operations": [op["name"] for op in ops_to_run],
}
],
"output": {
"type": "file",
"path": temp_output_path,
},
},
}
# Create runner and execute
runner = DSLRunner(temp_config, max_threads=4)
# Run operations sequentially on the data
current_data = sample_data
for op_config in ops_to_run:
current_data = runner._run_operation(op_config, current_data)
self.total_cost += runner.total_cost
return current_data
finally:
# Clean up temp files
for temp_path in [temp_input_path, temp_output_path]:
try:
os.unlink(temp_path)
except Exception:
pass
def pairwise_compare(
self,
candidate_a: dict[str, Any],
candidate_b: dict[str, Any],
original_prompt: str,
output_schema: dict[str, Any],
) -> dict[str, Any]:
"""
Compare two candidates using LLM judge.
Args:
candidate_a: First candidate with 'name', 'outputs' keys
candidate_b: Second candidate with 'name', 'outputs' keys
original_prompt: The original operation's prompt
output_schema: The expected output schema
Returns:
The winning candidate
"""
# If one has no outputs, the other wins
if not candidate_a.get("outputs"):
return candidate_b
if not candidate_b.get("outputs"):
return candidate_a
system_prompt = """You are an expert judge comparing outputs from two data processing pipeline variants.
Your task is to determine which variant produces BETTER outputs based on:
1. **Completeness**: Does the output contain all required information?
2. **Accuracy**: Is the extracted/generated information correct?
3. **Consistency**: Are the outputs consistent across different samples?
4. **Quality**: Is the output well-structured and useful?
Be objective and focus on the actual output quality, not the approach used."""
user_prompt = f"""Compare outputs from two pipeline variants for this task:
## Original Task
**Prompt:**
```
{original_prompt[:2000]}{"..." if len(original_prompt) > 2000 else ""}
```
**Expected Output Schema:**
```json
{json.dumps(output_schema, indent=2)}
```
## Variant A: {candidate_a["name"]}
**Sample Outputs:**
```json
{json.dumps(candidate_a["outputs"][:3], indent=2, default=str)}
```
## Variant B: {candidate_b["name"]}
**Sample Outputs:**
```json
{json.dumps(candidate_b["outputs"][:3], indent=2, default=str)}
```
## Your Task
Which variant produces better outputs? Consider completeness, accuracy, consistency, and quality.
Respond with your analysis and final choice."""
response = completion(
model=self.optimizer_model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
response_format={
"type": "json_schema",
"json_schema": {
"name": "comparison_result",
"strict": True,
"schema": {
"type": "object",
"properties": {
"winner": {
"type": "string",
"enum": ["A", "B"],
"description": "Which variant is better: A or B",
},
"rationale": {
"type": "string",
"description": "Explanation of why this variant is better",
},
"a_strengths": {
"type": "array",
"items": {"type": "string"},
"description": "Strengths of variant A",
},
"b_strengths": {
"type": "array",
"items": {"type": "string"},
"description": "Strengths of variant B",
},
},
"required": [
"winner",
"rationale",
"a_strengths",
"b_strengths",
],
"additionalProperties": False,
},
},
},
**self.litellm_kwargs,
)
cost = completion_cost(response)
self.total_cost += cost
result = json.loads(response.choices[0].message.content)
winner = candidate_a if result["winner"] == "A" else candidate_b
winner["comparison_rationale"] = result["rationale"]
return winner
def decompose(
self,
step_name: str,
op_name: str,
) -> tuple[list[dict[str, Any]], str, int, float]:
"""
Decompose an operation using fast directive-based approach.
This is the main entry point. It:
1. Generates candidate decompositions using directives
2. Runs each on sample data
3. Uses pairwise LLM comparison to select the best
4. Returns the winning decomposition
Args:
step_name: Name of the pipeline step
op_name: Name of the operation to decompose
Returns:
Dict with:
- decomposed_ops: List of operations that replace the original
- winning_directive: Name of the winning directive
- candidates_evaluated: Number of candidates that were compared
- original_outputs: Sample outputs from the original operation
- decomposed_outputs: Sample outputs from the winning decomposition
- comparison_rationale: LLM's explanation of why the winner was chosen
- cost: Total LLM API cost
"""
self.console.rule("[bold blue]Fast Decomposition[/bold blue]")
self._log(f"[bold]Target operation:[/bold] {op_name}")
# Find the target operation config
target_op = None
for op in self.operators:
if op["name"] == op_name:
target_op = op
break
if target_op is None:
raise ValueError(f"Operation '{op_name}' not found in config")
# Verify it's a map operation
if target_op.get("type") != "map":
raise ValueError(
f"Operation '{op_name}' is type '{target_op.get('type')}', "
"but fast decomposition only supports 'map' operations"
)
# Load sample data
self._log(f"Loading sample data from step '{step_name}'...")
sample_data = self.load_sample_data(step_name, op_name)
self._log(f"[green]✓[/green] Loaded {len(sample_data)} sample documents")
# Generate candidates
self.console.rule("[bold]Generating Candidates[/bold]")
candidates = self.generate_candidates(op_name, sample_data, target_op)
# Filter out failed candidates
valid_candidates = [c for c in candidates if c["ops"] is not None]
self._log(f"[bold]{len(valid_candidates)}[/bold] valid candidates generated")
if len(valid_candidates) < 2:
# Only original (or nothing) - return original
self._log(
"[yellow]No alternative decompositions generated. Keeping original.[/yellow]"
)
return {
"decomposed_ops": self.operators,
"winning_directive": "original",
"candidates_evaluated": len(valid_candidates),
"original_outputs": [],
"decomposed_outputs": [],
"comparison_rationale": "No alternative decompositions were generated.",
"cost": self.total_cost,
}
# Run each candidate on samples IN PARALLEL
self.console.rule("[bold]Running Candidates on Samples[/bold]")
self._log(f"Running {len(valid_candidates)} candidates in parallel...")
def run_single_candidate(candidate):
"""Run a single candidate and return results."""
name = candidate["name"]
try:
outputs = self.run_candidate_on_samples(candidate, sample_data, op_name)
return {"name": name, "outputs": outputs, "error": None}
except Exception as e:
error_tb = traceback.format_exc()
return {
"name": name,
"outputs": [],
"error": str(e),
"traceback": error_tb,
}
# Run all candidates in parallel
with self.console.status(
"[bold cyan]Running candidates on samples...[/bold cyan]", spinner="dots"
):
with ThreadPoolExecutor(max_workers=len(valid_candidates)) as executor:
future_to_candidate = {
executor.submit(run_single_candidate, c): c
for c in valid_candidates
}
for future in as_completed(future_to_candidate):
candidate = future_to_candidate[future]
result = future.result()
if result["error"]:
candidate["outputs"] = []
candidate["run_error"] = result["error"]
self._log(
f" [red]✗[/red] {result['name']} failed: {result['error']}"
)
else:
candidate["outputs"] = result["outputs"]
self._log(
f" [green]✓[/green] {result['name']}: {len(result['outputs'])} outputs"
)
# Filter to candidates with outputs
candidates_with_outputs = [c for c in valid_candidates if c.get("outputs")]
self._log(
f"[bold]{len(candidates_with_outputs)}[/bold] candidates produced outputs"
)
if not candidates_with_outputs:
# All failed - return original
self._log("[red]All decomposition candidates failed to execute.[/red]")
# Log the errors for debugging
for c in valid_candidates:
if c.get("run_error"):
self._log(f" [red]{c['name']}:[/red] {c['run_error']}")
return {
"decomposed_ops": self.operators,
"winning_directive": "original",
"candidates_evaluated": 0,
"original_outputs": [],
"decomposed_outputs": [],
"comparison_rationale": "All decomposition candidates failed to execute.",
"cost": self.total_cost,
}
# Find the original candidate's outputs
original_candidate = next(
(c for c in candidates_with_outputs if c["name"] == "original"), None
)
original_outputs = original_candidate["outputs"] if original_candidate else []
# Parallel pairwise comparison: compare all candidates against original
self.console.rule("[bold]Pairwise Comparison[/bold]")
original_prompt = target_op.get("prompt", "")
output_schema = target_op.get("output", {}).get("schema", {})
# If only original exists, it wins by default
if len(candidates_with_outputs) == 1:
winner = candidates_with_outputs[0]
elif original_candidate is None:
# No original - just pick the first candidate
winner = candidates_with_outputs[0]
else:
# Compare all non-original candidates against original IN PARALLEL
challengers = [
c for c in candidates_with_outputs if c["name"] != "original"
]
if not challengers:
winner = original_candidate
else:
self._log(
f"Comparing {len(challengers)} candidates against original in parallel..."
)
def compare_against_original(challenger):
"""Compare a single challenger against original."""
try:
result = self.pairwise_compare(
original_candidate,
challenger,
original_prompt,
output_schema,
)
won = result["name"] == challenger["name"]
return {
"challenger": challenger,
"won": won,
"rationale": result.get("comparison_rationale", ""),
}
except Exception as e:
return {
"challenger": challenger,
"won": False,
"error": str(e),
}
# Run all comparisons in parallel
comparison_results = []
with self.console.status(
f"[bold cyan]Running {len(challengers)} comparisons in parallel...[/bold cyan]",
spinner="dots",
):
with ThreadPoolExecutor(max_workers=len(challengers)) as executor:
future_to_challenger = {
executor.submit(compare_against_original, c): c
for c in challengers
}
for future in as_completed(future_to_challenger):
result = future.result()
comparison_results.append(result)
challenger_name = result["challenger"]["name"]
if result.get("error"):
self._log(
f" [red]✗[/red] {challenger_name} vs original: error - {result['error']}"
)
elif result["won"]:
self._log(
f" [green]✓[/green] {challenger_name} beat original"
)
else:
self._log(
f" [dim]○[/dim] {challenger_name} lost to original"
)
# Find winners (candidates that beat original)
winners = [r for r in comparison_results if r.get("won")]
if not winners:
# Original beats all challengers
winner = original_candidate
self._log("[bold]Original wins against all challengers[/bold]")
elif len(winners) == 1:
# Single winner
winner = winners[0]["challenger"]
winner["comparison_rationale"] = winners[0].get("rationale", "")
else:
# Multiple winners beat original - run tiebreaker comparisons in parallel
self._log(
f"[bold]{len(winners)} candidates beat original - running tiebreaker...[/bold]"
)
# Compare all winners against each other in parallel (round-robin)
winner_candidates = [w["challenger"] for w in winners]
win_counts = {c["name"]: 0 for c in winner_candidates}
# Generate all pairwise matchups
matchups = []
for i, a in enumerate(winner_candidates):
for b in winner_candidates[i + 1 :]:
matchups.append((a, b))
if matchups:
def run_matchup(matchup):
a, b = matchup
try:
result = self.pairwise_compare(
a, b, original_prompt, output_schema
)
return {
"winner": result["name"],
"a": a["name"],
"b": b["name"],
}
except Exception:
return {
"winner": a["name"],
"a": a["name"],
"b": b["name"],
} # Default to first
with self.console.status(
f"[bold cyan]Running {len(matchups)} tiebreaker comparisons...[/bold cyan]",
spinner="dots",
):
with ThreadPoolExecutor(
max_workers=len(matchups)
) as executor:
for result in executor.map(run_matchup, matchups):
win_counts[result["winner"]] += 1
self._log(
f" [dim]{result['a']} vs {result['b']}{result['winner']}[/dim]"
)
# Pick candidate with most wins
best_name = max(win_counts, key=win_counts.get)
winner = next(
c for c in winner_candidates if c["name"] == best_name
)
self._log(
f"[bold]Tiebreaker winner: {best_name} ({win_counts[best_name]} wins)[/bold]"
)
# Extract the decomposed operations
decomposed_ops = self.extract_ops_to_run(winner["ops"], op_name)
# Final summary
self.console.rule("[bold green]Decomposition Complete[/bold green]")
self._log(f"[bold]Winner:[/bold] [green]{winner['name']}[/green]")
self._log(f"[bold]Candidates evaluated:[/bold] {len(candidates_with_outputs)}")
self._log(f"[bold]New operations:[/bold] {len(decomposed_ops)}")
self._log(f"[bold]Total cost:[/bold] ${self.total_cost:.4f}")
return {
"decomposed_ops": decomposed_ops,
"winning_directive": winner["name"],
"candidates_evaluated": len(candidates_with_outputs),
"original_outputs": original_outputs,
"decomposed_outputs": winner.get("outputs", []),
"comparison_rationale": winner.get("comparison_rationale", ""),
"cost": self.total_cost,
}

View File

@ -0,0 +1,337 @@
"""
Fast should_optimize analyzer using a single LLM call.
This module provides a lightweight alternative to the full MapOptimizer/ReduceOptimizer/JoinOptimizer
flow for quickly determining if an operation should be decomposed.
"""
import json
import os
from typing import Any
import litellm
from litellm import completion, model_cost
from docetl.utils import completion_cost, count_tokens
# Drop unsupported params for models like gpt-5 that don't support temperature=0
litellm.drop_params = True
class FastShouldOptimizeAnalyzer:
"""
Analyzes whether an operation should be optimized using a single LLM call.
Instead of running the operation on sample data and using complex evaluation logic,
this reads cached outputs from intermediate files and makes one judgment call.
"""
def __init__(
self,
intermediate_dir: str,
optimizer_model: str = "gpt-5.1",
litellm_kwargs: dict[str, Any] | None = None,
):
"""
Initialize the analyzer.
Args:
intermediate_dir: Path to the directory containing intermediate outputs
optimizer_model: The LLM model to use for analysis (default: gpt-5.1)
litellm_kwargs: Additional kwargs to pass to litellm.completion
"""
self.intermediate_dir = intermediate_dir
self.optimizer_model = optimizer_model
self.litellm_kwargs = litellm_kwargs or {}
if "temperature" not in self.litellm_kwargs:
self.litellm_kwargs["temperature"] = 0.0
def load_operation_data(self, step_name: str, op_name: str) -> list[dict[str, Any]]:
"""
Load data from the intermediate file for an operation.
Args:
step_name: Name of the pipeline step
op_name: Name of the operation
Returns:
List of dictionaries
Raises:
FileNotFoundError: If the intermediate file doesn't exist
"""
output_path = os.path.join(self.intermediate_dir, step_name, f"{op_name}.json")
if not os.path.exists(output_path):
raise FileNotFoundError(
f"No output file found at {output_path}. "
"Run the operation first to generate outputs."
)
with open(output_path, "r") as f:
return json.load(f)
def find_previous_operation(
self, operations: list[dict[str, Any]], op_name: str
) -> str | None:
"""
Find the operation that comes before op_name in the pipeline.
Args:
operations: List of operation configs from the pipeline
op_name: Name of the current operation
Returns:
Name of the previous operation, or None if this is the first operation
"""
op_names = [op.get("name") for op in operations]
try:
idx = op_names.index(op_name)
if idx > 0:
return op_names[idx - 1]
except ValueError:
pass
return None
def get_max_context_tokens(self) -> int:
"""
Get the maximum input tokens for the optimizer model.
Returns:
Maximum number of input tokens the model can handle
"""
model_info = model_cost.get(self.optimizer_model, {})
# Try without provider prefix if not found
if not model_info:
model_name = self.optimizer_model.split("/")[-1]
model_info = model_cost.get(model_name, {})
return model_info.get("max_input_tokens", 128000) # Default to 128k
def calculate_samples_that_fit(
self,
op_config: dict[str, Any],
outputs: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""
Calculate how many output samples fit in the context window.
Reserves space for system prompt, operation config, and response buffer,
then fills remaining space with as many samples as possible.
Args:
op_config: The operation configuration dictionary
outputs: List of all output documents
Returns:
List of samples that fit in the context window
"""
max_tokens = self.get_max_context_tokens()
# Reserve tokens for fixed parts
system_prompt_tokens = 500
op_config_tokens = count_tokens(
json.dumps(op_config, default=str), self.optimizer_model
)
response_buffer = 2000
available_for_samples = (
max_tokens - system_prompt_tokens - op_config_tokens - response_buffer
)
# Collect samples that fit
samples_to_include = []
tokens_used = 0
for output in outputs:
sample_json = json.dumps(output, default=str)
sample_tokens = count_tokens(sample_json, self.optimizer_model)
if tokens_used + sample_tokens <= available_for_samples:
samples_to_include.append(output)
tokens_used += sample_tokens
else:
break
return samples_to_include
def build_analysis_prompt(
self,
op_config: dict[str, Any],
samples: list[dict[str, Any]],
) -> tuple[str, str]:
"""
Build the system prompt and user prompt for the analysis LLM call.
Args:
op_config: The operation configuration dictionary
samples: List of output samples to analyze
Returns:
Tuple of (system_prompt, user_prompt)
"""
system_prompt = """You are an expert at analyzing LLM-powered data processing operations.
Your task is to determine if an operation would benefit from being decomposed into multiple
smaller, focused operations (also known as "optimization" or "decomposition").
An operation SHOULD be decomposed when:
1. The prompt asks the LLM to do multiple distinct tasks that could be done separately
2. The task is complex enough that breaking it into sequential steps would improve accuracy
3. The outputs show inconsistency, incompleteness, or quality issues that iterative refinement could fix
4. Long documents need to be processed in chunks rather than all at once
5. The prompt asks for both extraction AND analysis/synthesis in one step
An operation should NOT be decomposed when:
1. It performs a single, focused task well
2. The outputs are consistently high quality and complete
3. The task is simple and atomic (e.g., simple classification, single field extraction)
4. The operation is already well-scoped and produces reliable results
Be conservative - only recommend decomposition if there's clear evidence it would help."""
output_schema = op_config.get("output", {}).get("schema", {})
prompt_template = op_config.get("prompt", "No prompt specified")
# Truncate very long prompts for display
if len(prompt_template) > 3000:
prompt_template = prompt_template[:3000] + "\n... [truncated]"
user_prompt = f"""Analyze this data processing operation and its outputs:
## Operation Configuration
**Name:** {op_config.get('name', 'unknown')}
**Type:** {op_config.get('type', 'unknown')}
**Prompt Template:**
```
{prompt_template}
```
**Output Schema:**
```json
{json.dumps(output_schema, indent=2)}
```
## Sample Outputs ({len(samples)} samples from the operation)
```json
{json.dumps(samples, indent=2, default=str)}
```
## Your Task
Based on the operation configuration and sample outputs, determine:
1. Should this operation be decomposed/optimized?
2. If yes, what specific improvements would help?
Consider the quality, completeness, and consistency of the outputs when making your assessment."""
return system_prompt, user_prompt
def analyze(
self,
op_config: dict[str, Any],
step_name: str,
op_name: str,
) -> tuple[str, list[dict[str, Any]], int, float]:
"""
Analyze whether an operation should be optimized.
This is the main entry point. It loads outputs, builds the prompt,
makes a single LLM call, and returns the assessment.
Args:
op_config: The operation configuration dictionary
step_name: Name of the pipeline step
op_name: Name of the operation
Returns:
Tuple of (rationale, output_samples, num_docs_analyzed, cost):
- rationale: Empty string if no optimization needed, explanation if it should be optimized
- output_samples: The samples that were analyzed
- num_docs_analyzed: Number of documents that fit in the LLM prompt
- cost: LLM API cost in USD
Raises:
ValueError: If the operation is not an LLM-powered map operation
"""
# Validate operation type - only LLM-powered map operations
op_type = op_config.get("type", "")
if op_type != "map":
raise ValueError(
f"should_optimize only supports 'map' operations, got '{op_type}'. "
"Only LLM-powered map operations can be analyzed for decomposition."
)
# Load outputs from intermediate file
outputs = self.load_operation_data(step_name, op_name)
if not outputs:
return "No output samples available for analysis.", [], 0, 0.0
# Calculate samples that fit in context
samples = self.calculate_samples_that_fit(op_config, outputs)
if not samples:
return "Could not fit any samples in context window.", outputs[:5], 0, 0.0
# Build prompt
system_prompt, user_prompt = self.build_analysis_prompt(op_config, samples)
# Make LLM call with structured output
response = completion(
model=self.optimizer_model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
response_format={
"type": "json_schema",
"json_schema": {
"name": "optimization_analysis",
"strict": True,
"schema": {
"type": "object",
"properties": {
"should_optimize": {
"type": "boolean",
"description": "True if operation should be decomposed/optimized",
},
"rationale": {
"type": "string",
"description": "Explanation of why the operation should or should not be optimized",
},
"suggested_improvements": {
"type": "array",
"items": {"type": "string"},
"description": "Specific improvements if optimization is recommended (empty if not)",
},
},
"required": [
"should_optimize",
"rationale",
"suggested_improvements",
],
"additionalProperties": False,
},
},
},
**self.litellm_kwargs,
)
# Calculate cost
cost = completion_cost(response)
# Parse response
result = json.loads(response.choices[0].message.content)
num_docs_analyzed = len(samples)
if result["should_optimize"]:
# Build rationale string with improvements
rationale_parts = [result["rationale"]]
if result["suggested_improvements"]:
rationale_parts.append("\n\nSuggested improvements:")
for imp in result["suggested_improvements"]:
rationale_parts.append(f"- {imp}")
return "\n".join(rationale_parts), samples, num_docs_analyzed, cost
else:
# Return empty string to indicate no optimization needed
return "", samples, num_docs_analyzed, cost

View File

@ -30,7 +30,7 @@ class LLMClient:
Initialize the LLMClient.
Args:
model (str, optional): The name of the LLM model to use. Defaults to "gpt-4o".
model (str, optional): The name of the LLM model to use. Defaults to "gpt-5.1".
**litellm_kwargs: Additional keyword arguments for the LLM model.
"""
self.rewrite_agent_model = rewrite_agent_model

View File

@ -204,10 +204,6 @@ def get_openai_response(
response = litellm.completion(
model=model,
messages=messages,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=ResponseFormat,
)
assistant_response = response.choices[0].message.content

View File

@ -324,10 +324,6 @@ Focus on quality over quantity - a few diverse, informative examples are better
model=self.agent_llm,
messages=self.message_history,
response_format=AgentDecision,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
)
call_cost += response._hidden_params["response_cost"]
@ -415,10 +411,6 @@ Provide your response as a JSON object matching this schema: {response_schema.mo
model=self.agent_llm,
messages=self.message_history,
response_format=response_schema,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
)
call_cost += schema_response._hidden_params["response_cost"]

View File

@ -222,7 +222,6 @@ class Directive(BaseModel, ABC):
model=agent_llm,
messages=messages,
response_format=JudgeResponse,
azure=True,
)
# Parse the JSON response

View File

@ -1,5 +1,4 @@
import json
import os
import re
from copy import deepcopy
from typing import Dict, List, Type
@ -169,19 +168,15 @@ class ChainingDirective(Directive):
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
# api_key=os.environ["GEMINI_API_KEY"],
azure=True,
response_format=ChainingInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
@ -202,11 +197,12 @@ class ChainingDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(

View File

@ -1,5 +1,4 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -211,18 +210,14 @@ class ChangeModelDirective(Directive):
},
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
# api_key=os.environ["GEMINI_API_KEY"],
azure=True,
response_format=ChangeModelInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = ChangeModelInstantiateSchema(**parsed_res)
@ -240,11 +235,12 @@ class ChangeModelDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(

View File

@ -1,5 +1,4 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -205,17 +204,14 @@ class ChangeModelAccDirective(Directive):
},
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=ChangeModelInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = ChangeModelInstantiateSchema(**parsed_res)
@ -233,11 +229,12 @@ class ChangeModelAccDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(

View File

@ -1,5 +1,4 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -282,17 +281,14 @@ class ChangeModelCostDirective(Directive):
},
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=ChangeModelInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = ChangeModelInstantiateSchema(**parsed_res)
@ -310,11 +306,12 @@ class ChangeModelCostDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(

View File

@ -1,5 +1,4 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -185,17 +184,14 @@ class ChunkHeaderSummaryDirective(Directive):
},
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=ChunkHeaderSummaryInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = ChunkHeaderSummaryInstantiateSchema(**parsed_res)
@ -204,11 +200,12 @@ class ChunkHeaderSummaryDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(
@ -251,9 +248,7 @@ class ChunkHeaderSummaryDirective(Directive):
)
if gather_idx - split_idx > 1:
raise ValueError(
"There should not be operators between split and gather"
)
raise ValueError("There should not be operators between split and gather")
# Get the split_key from the split operation
split_key = split_op.get("split_key")

View File

@ -1,5 +1,4 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -261,17 +260,14 @@ def transform(input_doc):
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=DeterministicDocCompressionInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = DeterministicDocCompressionInstantiateSchema(**parsed_res)
@ -284,11 +280,12 @@ def transform(input_doc):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(

View File

@ -1,5 +1,4 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -268,18 +267,15 @@ class DocumentChunkingDirective(Directive):
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
call_cost = 0.0
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=DocumentChunkingInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = DocumentChunkingInstantiateSchema(**parsed_res)
@ -290,11 +286,12 @@ class DocumentChunkingDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(
@ -404,16 +401,19 @@ class DocumentChunkingDirective(Directive):
sample_op["stratify_key"] = stratify_keys
if rewrite.sampling_config.samples_per_group:
sample_op["samples_per_group"] = rewrite.sampling_config.samples_per_group
sample_op["samples_per_group"] = (
rewrite.sampling_config.samples_per_group
)
# Add optional fields if provided
if rewrite.sampling_config.random_state is not None:
sample_op["random_state"] = rewrite.sampling_config.random_state
if rewrite.sampling_config.method_kwargs is not None:
try:
sample_op["method_kwargs"] = json.loads(rewrite.sampling_config.method_kwargs)
sample_op["method_kwargs"] = json.loads(
rewrite.sampling_config.method_kwargs
)
except Exception as e:
raise ValueError(f"Invalid method_kwargs: {e}")

View File

@ -1,5 +1,4 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -431,17 +430,14 @@ class DocumentChunkingTopKDirective(Directive):
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=DocumentChunkingTopKInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = DocumentChunkingTopKInstantiateSchema(**parsed_res)
@ -451,11 +447,12 @@ class DocumentChunkingTopKDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(

View File

@ -1,5 +1,4 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -164,19 +163,14 @@ class DocCompressionDirective(Directive):
},
]
)
error_message = ""
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=DocCompressionInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
@ -187,11 +181,12 @@ class DocCompressionDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Error: {error_message}"
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(

View File

@ -1,5 +1,4 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -261,17 +260,14 @@ class DocSummarizationDirective(Directive):
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=DocSummarizationInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = DocSummarizationInstantiateSchema(**parsed_res)
@ -280,11 +276,12 @@ class DocSummarizationDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(

View File

@ -1,5 +1,4 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -146,32 +145,31 @@ class GleaningDirective(Directive):
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
# api_key=os.environ["GEMINI_API_KEY"],
azure=True,
response_format=GleaningInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = GleaningInstantiateSchema(**parsed_res)
GleaningInstantiateSchema.check_no_jinja_variables(schema.validation_prompt)
GleaningInstantiateSchema.check_no_jinja_variables(
schema.validation_prompt
)
message_history.append(
{"role": "assistant", "content": resp.choices[0].message.content}
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(

View File

@ -1,5 +1,4 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -180,19 +179,14 @@ class HierarchicalReduceDirective(Directive):
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=HierarchicalReduceInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
@ -214,11 +208,12 @@ class HierarchicalReduceDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(
@ -259,7 +254,7 @@ class HierarchicalReduceDirective(Directive):
"name": rewrite.map_config.name,
"type": "map",
"prompt": rewrite.map_config.prompt,
"model": default_model,
"model": default_model,
"litellm_completion_kwargs": {"temperature": 0},
"output": {"schema": {rewrite.map_config.output_keys[0]: "string"}},
}
@ -324,4 +319,4 @@ class HierarchicalReduceDirective(Directive):
new_ops_plan = self.apply(
global_default_model, operators, target_ops[0], rewrite
)
return new_ops_plan, message_history, call_cost
return new_ops_plan, message_history, call_cost

View File

@ -1,5 +1,4 @@
import json
import os
import re
from copy import deepcopy
from typing import Dict, List, Type
@ -260,17 +259,14 @@ class IsolatingSubtasksDirective(Directive):
original_op.get("output", {}).get("schema", {}).keys()
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=IsolatingSubtasksInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = IsolatingSubtasksInstantiateSchema(**parsed_res)
@ -285,11 +281,12 @@ class IsolatingSubtasksDirective(Directive):
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease ensure all original output keys are covered by subtasks and try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(

View File

@ -1,5 +1,4 @@
import json
import os
import re
from copy import deepcopy
from typing import Dict, List, Type
@ -7,14 +6,20 @@ from typing import Dict, List, Type
from litellm import completion
from pydantic import BaseModel, Field
from docetl.reasoning_optimizer.instantiate_schemas import MapReduceFusionInstantiateSchema
from docetl.reasoning_optimizer.instantiate_schemas import (
MapReduceFusionInstantiateSchema,
)
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
class MapReduceFusionDirective(Directive):
name: str = Field(default="map_reduce_fusion", description="The name of the directive")
formal_description: str = Field(default="Map -> Reduce => Map (with new prompt) -> Reduce (with new propmt)")
name: str = Field(
default="map_reduce_fusion", description="The name of the directive"
)
formal_description: str = Field(
default="Map -> Reduce => Map (with new prompt) -> Reduce (with new propmt)"
)
nl_description: str = Field(
default="Transform a Map operation followed by a Reduce operation over long documents by updating the Map prompt to first extract or process only the relevant information from each document. Then, modify the Reduce prompt to operate on these processed outputs instead of the full document content."
)
@ -42,7 +47,7 @@ class MapReduceFusionDirective(Directive):
" type: reduce\\n"
" reduce_key: category\\n"
" prompt: |\\n"
" For each category \\\"{{ reduce_key }}\\\", extract all organization names from these documents:\\n"
' For each category \\"{{ reduce_key }}\\", extract all organization names from these documents:\\n'
" {% for input in inputs %}\\n"
" Document {{ loop.index }}: {{ input.content }}\\n"
" {% endfor %}\\n"
@ -79,7 +84,7 @@ class MapReduceFusionDirective(Directive):
"reduce_key": "category",
"prompt": "For each category '{{ reduce_key }}', extract organizations from: {% for input in inputs %}{{ input.content }}{% endfor %}",
"output": {"schema": {"organizations": "list[str]"}},
}
},
],
target_ops=["classify_document", "extract_organizations"],
expected_behavior="Should modify map to classify AND extract organizations per document, then reduce to aggregate pre-extracted organizations",
@ -101,7 +106,7 @@ class MapReduceFusionDirective(Directive):
"reduce_key": "doc_type",
"prompt": "For each document type '{{ reduce_key }}', find all people mentioned: {% for input in inputs %}Document: {{ input.text }}{% endfor %}",
"output": {"schema": {"people": "list[str]"}},
}
},
],
target_ops=["analyze_document", "find_people"],
expected_behavior="Should modify map to analyze type AND extract people per document, then reduce to combine pre-extracted people lists",
@ -159,7 +164,7 @@ class MapReduceFusionDirective(Directive):
reduce_op: Dict,
expected_document_key,
agent_llm: str,
message_history: list = []
message_history: list = [],
):
"""
Use LLM to instantiate this directive by transforming the map and reduce operations.
@ -173,66 +178,85 @@ class MapReduceFusionDirective(Directive):
Returns:
MapReduceFusionInstantiateSchema: The structured output from the LLM.
"""
message_history.extend([
{"role": "system", "content": "You are a helpful AI assistant for optimizing document processing pipelines."},
{"role": "user", "content": self.to_string_for_instantiate([map_op, reduce_op])},
])
message_history.extend(
[
{
"role": "system",
"content": "You are a helpful AI assistant for optimizing document processing pipelines.",
},
{
"role": "user",
"content": self.to_string_for_instantiate([map_op, reduce_op]),
},
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=MapReduceFusionInstantiateSchema
response_format=MapReduceFusionInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
if "new_map_name" not in parsed_res or "new_map_prompt" not in parsed_res or "new_key" not in parsed_res or "new_reduce_prompt" not in parsed_res:
if (
"new_map_name" not in parsed_res
or "new_map_prompt" not in parsed_res
or "new_key" not in parsed_res
or "new_reduce_prompt" not in parsed_res
):
raise ValueError(
"Response from LLM is missing required keys: 'new_map_name', 'new_map_prompt', 'new_key', or 'new_reduce_prompt'"
)
schema = MapReduceFusionInstantiateSchema(
new_map_name=parsed_res["new_map_name"],
new_map_prompt=parsed_res["new_map_prompt"],
new_key=parsed_res["new_key"],
new_reduce_prompt=parsed_res["new_reduce_prompt"]
new_reduce_prompt=parsed_res["new_reduce_prompt"],
)
# Validate the schema
MapReduceFusionInstantiateSchema.validate_reduce_prompt_references_new_key(
schema.new_reduce_prompt, schema.new_key, expected_document_key
)
message_history.append(
{"role": "assistant", "content": resp.choices[0].message.content}
)
return schema, message_history, call_cost
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts.")
def apply(self, global_default_model, ops_list: List[Dict], map_target: str, reduce_target: str, rewrite: MapReduceFusionInstantiateSchema) -> List[Dict]:
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(
self,
global_default_model,
ops_list: List[Dict],
map_target: str,
reduce_target: str,
rewrite: MapReduceFusionInstantiateSchema,
) -> List[Dict]:
"""
Apply the directive to the pipeline config.
"""
# Create a copy of the pipeline config
new_ops_list = deepcopy(ops_list)
# Find positions of the target ops
map_pos = None
reduce_pos = None
orig_map_op = None
orig_reduce_op = None
for i, op in enumerate(ops_list):
if op["name"] == map_target:
map_pos = i
@ -240,13 +264,15 @@ class MapReduceFusionDirective(Directive):
elif op["name"] == reduce_target:
reduce_pos = i
orig_reduce_op = op
if map_pos is None or reduce_pos is None:
raise ValueError(f"Could not find target operations: {map_target}, {reduce_target}")
raise ValueError(
f"Could not find target operations: {map_target}, {reduce_target}"
)
# Get default model
default_model = orig_map_op.get("model", global_default_model)
# Create the new map operation with fused functionality
new_map_op = {
"name": rewrite.new_map_name,
@ -257,11 +283,11 @@ class MapReduceFusionDirective(Directive):
"output": {
"schema": {
**orig_map_op.get("output", {}).get("schema", {}),
rewrite.new_key: "list[str]"
rewrite.new_key: "list[str]",
}
}
},
}
# Create the new reduce operation that works with pre-extracted data
new_reduce_op = {
"name": orig_reduce_op["name"],
@ -270,54 +296,58 @@ class MapReduceFusionDirective(Directive):
"prompt": rewrite.new_reduce_prompt,
"model": default_model,
"litellm_completion_kwargs": {"temperature": 0},
"output": orig_reduce_op.get("output", {})
"output": orig_reduce_op.get("output", {}),
}
# Replace the operations
new_ops_list[map_pos] = new_map_op
new_ops_list[reduce_pos] = new_reduce_op
return new_ops_list
def instantiate(
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
optimize_goal="acc",
global_default_model: str = None,
**kwargs
self,
operators: List[Dict],
target_ops: List[str],
agent_llm: str,
message_history: list = [],
optimize_goal="acc",
global_default_model: str = None,
**kwargs,
):
"""
Instantiate the directive for a list of operators.
"""
# Assert that there are exactly two target ops (map and reduce)
if len(target_ops) != 2:
raise ValueError("map_reduce_fusion directive requires exactly two target operators: a map and a reduce operation")
raise ValueError(
"map_reduce_fusion directive requires exactly two target operators: a map and a reduce operation"
)
# Find the map and reduce operations
first_op = None
second_op = None
for op in operators:
if op["name"] == target_ops[0]:
first_op = op
elif op["name"] == target_ops[1]:
second_op = op
if first_op is None or second_op is None:
raise ValueError(f"Could not find target operations: {target_ops}")
if first_op.get("type") == "map" and second_op.get("type") == "reduce":
map_op = first_op
reduce_op = second_op
map_target = target_ops[0]
reduce_target = target_ops[1]
else:
raise ValueError("Target operators must be one map operation followed by one reduce operation!")
# Extract expected document key from the reduce prompt template
raise ValueError(
"Target operators must be one map operation followed by one reduce operation!"
)
# Extract expected document key from the reduce prompt template
prompt_template = reduce_op["prompt"]
# Find all occurrences of {{ input.key }} in the prompt
input_key_pattern = r"\{\{\s*([^\}\s]+)\s*\}\}"
@ -335,7 +365,9 @@ class MapReduceFusionDirective(Directive):
]
if document_key_candidates:
expected_document_key = document_key_candidates[0] # Pick the first candidate
expected_document_key = document_key_candidates[
0
] # Pick the first candidate
elif input_keys:
expected_document_key = input_keys[0] # Fall back to the first input key
else:
@ -344,8 +376,12 @@ class MapReduceFusionDirective(Directive):
print(f"Detected document key: {expected_document_key}")
# Instantiate the directive
rewrite, message_history, call_cost = self.llm_instantiate(map_op, reduce_op, expected_document_key, agent_llm, message_history)
rewrite, message_history, call_cost = self.llm_instantiate(
map_op, reduce_op, expected_document_key, agent_llm, message_history
)
# Apply the rewrite to the operators
new_ops_plan = self.apply(global_default_model, operators, map_target, reduce_target, rewrite)
return new_ops_plan, message_history, call_cost
new_ops_plan = self.apply(
global_default_model, operators, map_target, reduce_target, rewrite
)
return new_ops_plan, message_history, call_cost

View File

@ -1,5 +1,4 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -203,14 +202,10 @@ class MapResolveToMapWithCategoriesDirective(Directive):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=MapResolveToMapWithCategoriesInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)

View File

@ -1,5 +1,4 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -182,14 +181,10 @@ class MapToMapResolveReduceDirective(Directive):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=MapToMapResolveReduceInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)

View File

@ -1,5 +1,4 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -172,17 +171,14 @@ class OperatorFusionDirective(Directive):
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=OperatorFusionInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = OperatorFusionInstantiateSchema(**parsed_res)
@ -191,11 +187,12 @@ class OperatorFusionDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(
@ -215,7 +212,6 @@ class OperatorFusionDirective(Directive):
new_ops_list = deepcopy(ops_list)
op1_name, op2_name = target_ops[0], target_ops[1]
# Find the operations
op1_idx = next(i for i, op in enumerate(ops_list) if op["name"] == op1_name)
op2_idx = next(i for i, op in enumerate(ops_list) if op["name"] == op2_name)
@ -329,5 +325,5 @@ def transform(input_doc):
return (
self.apply(global_default_model, operators, target_ops, rewrite),
message_history,
call_cost
call_cost,
)

View File

@ -1,5 +1,4 @@
import json
import os
import re
from copy import deepcopy
from typing import Dict, List, Type
@ -127,7 +126,7 @@ class ReduceChainingDirective(Directive):
original_op: Dict,
expected_document_key: str,
agent_llm: str,
message_history: list = []
message_history: list = [],
):
"""
Use LLM to instantiate this directive by decomposing the reduce operation.
@ -155,19 +154,15 @@ class ReduceChainingDirective(Directive):
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
call_cost = 0
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=ReduceChainingInstantiateSchema
response_format=ReduceChainingInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = ReduceChainingInstantiateSchema(**parsed_res)
@ -182,11 +177,12 @@ class ReduceChainingDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(
@ -283,7 +279,9 @@ class ReduceChainingDirective(Directive):
]
if document_key_candidates:
expected_document_key = document_key_candidates[0] # Pick the first candidate
expected_document_key = document_key_candidates[
0
] # Pick the first candidate
elif input_keys:
expected_document_key = input_keys[0] # Fall back to the first input key
else:

View File

@ -1,5 +1,4 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -255,17 +254,14 @@ class ReduceGleaningDirective(Directive):
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=GleaningInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = GleaningInstantiateSchema(**parsed_res)
@ -274,11 +270,12 @@ class ReduceGleaningDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(
@ -341,5 +338,6 @@ class ReduceGleaningDirective(Directive):
# Apply the rewrite to the operators
return (
self.apply(global_default_model, operators, target_ops[0], rewrite),
message_history, call_cost
message_history,
call_cost,
)

View File

@ -1,5 +1,4 @@
import json
import os
from copy import deepcopy
from typing import Dict, List, Type
@ -234,18 +233,14 @@ class TakeHeadTailDirective(Directive):
]
)
last_error = None
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
call_cost = 0
resp = completion(
model=agent_llm,
messages=message_history,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
azure=True,
response_format=TakeHeadTailInstantiateSchema,
)
call_cost = resp._hidden_params["response_cost"]
call_cost = resp._hidden_params.get("response_cost", 0)
try:
parsed_res = json.loads(resp.choices[0].message.content)
schema = TakeHeadTailInstantiateSchema(**parsed_res)
@ -254,11 +249,12 @@ class TakeHeadTailDirective(Directive):
)
return schema, message_history, call_cost
except Exception as err:
last_error = err
error_message = f"Validation error: {err}\nPlease try again."
message_history.append({"role": "user", "content": error_message})
raise Exception(
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
)
def apply(

View File

@ -445,8 +445,12 @@ class DeterministicDocCompressionInstantiateSchema(BaseModel):
raise ValueError(
"Code must define a function named 'transform' that takes input_doc as parameter"
)
if "return {" not in v and "return dict(" not in v:
raise ValueError("Code must return a dictionary")
if (
"return {" not in v
and "return dict(" not in v
and "output = dict(" not in v
):
raise ValueError(f"Code must return a dictionary. Instead the code is: {v}")
return v
def validate_code_returns_target_keys(self, target_ops_configs: List[Dict]) -> None:

View File

@ -699,7 +699,7 @@ class DSLRunner(ConfigWrapper):
"litellm_kwargs", {}
)
kwargs["rewrite_agent_model"] = self.config.get("optimizer_config", {}).get(
"rewrite_agent_model", "gpt-4o"
"rewrite_agent_model", "gpt-5.1"
)
kwargs["judge_agent_model"] = self.config.get("optimizer_config", {}).get(
"judge_agent_model", "gpt-4o-mini"
@ -727,7 +727,7 @@ class DSLRunner(ConfigWrapper):
"litellm_kwargs", {}
)
kwargs["rewrite_agent_model"] = self.config.get("optimizer_config", {}).get(
"rewrite_agent_model", "gpt-4o"
"rewrite_agent_model", "gpt-5.1"
)
kwargs["judge_agent_model"] = self.config.get("optimizer_config", {}).get(
"judge_agent_model", "gpt-4o-mini"

View File

@ -14,7 +14,7 @@ All fields in `optimizer_config` are required (no defaults):
| `evaluation_file` | `str` | Path to Python file containing `@register_eval` decorated function |
| `metric_key` | `str` | Key in evaluation results dictionary to use as accuracy metric |
| `max_iterations` | `int` | Maximum number of MOARSearch iterations to run |
| `model` | `str` | LLM model to use for directive instantiation during search |
| `rewrite_agent_model` | `str` | LLM model to use for directive instantiation during search |
!!! warning "All Fields Required"
MOAR will error if any required field is missing. There are no defaults.
@ -67,19 +67,19 @@ The optimizer will use the sample dataset, but your final pipeline uses the full
### Available Models
!!! info "LiteLLM Model Names"
Use LiteLLM model names (e.g., `gpt-4o-mini`, `gpt-4o`, `gpt-4.1`). Make sure your API keys are set in your environment.
Use LiteLLM model names (e.g., `gpt-4o-mini`, `gpt-4o`, `gpt-5.1`). Make sure your API keys are set in your environment.
```yaml
available_models: # LiteLLM model names - ensure API keys are set
- gpt-4.1-nano # Cheapest, lower accuracy
- gpt-4.1-mini # Low cost, decent accuracy
- gpt-4.1 # Balanced
- gpt-5.1-nano # Cheapest, lower accuracy
- gpt-5.1-mini # Low cost, decent accuracy
- gpt-5.1 # Balanced
- gpt-4o # Higher cost, better accuracy
```
### Model for Directive Instantiation
The `model` field specifies which LLM to use for generating optimization directives during the search process. This doesn't affect the models tested in `available_models`.
The `rewrite_agent_model` field specifies which LLM to use for generating optimization directives during the search process. This doesn't affect the models tested in `available_models`.
!!! tip "Cost Consideration"
Use a cheaper model (like `gpt-4o-mini`) for directive instantiation to reduce search costs.
@ -104,12 +104,12 @@ optimizer_config:
available_models:
- gpt-4o-mini
- gpt-4o
- gpt-4.1-mini
- gpt-4.1
- gpt-5.1-mini
- gpt-5.1
evaluation_file: evaluate_medications.py
metric_key: medication_extraction_score
max_iterations: 40
model: gpt-4.1
rewrite_agent_model: gpt-5.1
dataset_path: data/sample.json # Optional
exploration_weight: 1.414 # Optional
```

View File

@ -25,15 +25,15 @@ optimizer_config:
dataset_path: workloads/medical/raw_sample.json # Use sample for faster optimization
save_dir: workloads/medical/moar_results
available_models: # LiteLLM model names - ensure API keys are set in your environment
- gpt-4.1-nano
- gpt-4.1-mini
- gpt-4.1
- gpt-5.1-nano
- gpt-5.1-mini
- gpt-5.1
- gpt-4o
- gpt-4o-mini
evaluation_file: workloads/medical/evaluate_medications.py
metric_key: medication_extraction_score
max_iterations: 40
model: gpt-4.1
rewrite_agent_model: gpt-5.1
system_prompt:
dataset_description: a collection of transcripts of doctor visits

View File

@ -109,12 +109,12 @@ optimizer_config:
available_models: # LiteLLM model names - ensure API keys are set in your environment
- gpt-4o-mini
- gpt-4o
- gpt-4.1-mini
- gpt-4.1
- gpt-5.1-mini
- gpt-5.1
evaluation_file: evaluate_medications.py
metric_key: medication_extraction_score # This must match a key in your evaluation function's return dictionary
max_iterations: 40
model: gpt-4.1
rewrite_agent_model: gpt-5.1
dataset_path: data/transcripts_sample.json # Optional: use sample/hold-out dataset
```

View File

@ -19,7 +19,7 @@ High-level summary of the optimization run:
{
"optimizer": "moar",
"input_pipeline": "pipeline.yaml",
"model": "gpt-4.1",
"rewrite_agent_model": "gpt-5.1",
"max_iterations": 40,
"save_dir": "results/moar_optimization",
"dataset": "transcripts",

View File

@ -80,9 +80,9 @@ Test your evaluation function independently and check MOAR logs for errors.
```yaml
available_models:
- gpt-4.1-nano # Cheapest, lower accuracy
- gpt-4.1-mini # Low cost, decent accuracy
- gpt-4.1 # Balanced
- gpt-5.1-nano # Cheapest, lower accuracy
- gpt-5.1-mini # Low cost, decent accuracy
- gpt-5.1 # Balanced
- gpt-4o # Higher cost, better accuracy
```

View File

@ -267,10 +267,6 @@ class AgentCommunicator:
response = completion(
model=self.model,
messages=messages,
azure=True,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
response_format={"type": "json_object"}
)
@ -290,10 +286,6 @@ class AgentCommunicator:
response = completion(
model=self.model,
messages=messages + [{"role": "user", "content": request_msg}],
azure=True,
api_key=os.environ.get("AZURE_API_KEY"),
api_base=os.environ.get("AZURE_API_BASE"),
api_version=os.environ.get("AZURE_API_VERSION"),
response_format={"type": "json_object"}
)

View File

@ -27,6 +27,7 @@ class OptimizeResult(BaseModel):
should_optimize: str | None = None
input_data: list[dict[str, Any]] | None = None
output_data: list[dict[str, Any]] | None = None
num_docs_analyzed: int | None = None
cost: float | None = None
error: str | None = None
created_at: datetime
@ -35,4 +36,25 @@ class OptimizeResult(BaseModel):
class OptimizeRequest(BaseModel):
yaml_config: str
step_name: str
op_name: str
op_name: str
class DecomposeRequest(BaseModel):
yaml_config: str
step_name: str
op_name: str
class DecomposeResult(BaseModel):
task_id: str
status: TaskStatus
decomposed_operations: list[dict[str, Any]] | None = None
winning_directive: str | None = None
candidates_evaluated: int | None = None
original_outputs: list[dict[str, Any]] | None = None
decomposed_outputs: list[dict[str, Any]] | None = None
comparison_rationale: str | None = None
cost: float | None = None
error: str | None = None
created_at: datetime
completed_at: datetime | None = None

View File

@ -2,13 +2,23 @@ from typing import Any
import uuid
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
from docetl.runner import DSLRunner
from docetl.optimizers.fast_should_optimize import FastShouldOptimizeAnalyzer
from docetl.optimizers.fast_decomposer import FastDecomposer
import asyncio
from asyncio import Task
from rich.logging import RichHandler
import logging
import yaml
from datetime import datetime, timedelta
from enum import Enum
from server.app.models import OptimizeResult, TaskStatus, OptimizeRequest, PipelineRequest
from server.app.models import (
OptimizeResult,
TaskStatus,
OptimizeRequest,
PipelineRequest,
DecomposeRequest,
DecomposeResult,
)
# Setup logging
FORMAT = "%(message)s"
@ -18,10 +28,14 @@ logging.basicConfig(
router = APIRouter()
# Task storage
# Task storage for optimize tasks
tasks: dict[str, OptimizeResult] = {}
asyncio_tasks: dict[str, Task] = {}
# Task storage for decompose tasks
decompose_tasks: dict[str, DecomposeResult] = {}
decompose_asyncio_tasks: dict[str, Task] = {}
# Configuration
COMPLETED_TASK_TTL = timedelta(hours=1)
@ -30,50 +44,111 @@ async def cleanup_old_tasks():
while True:
try:
current_time = datetime.now()
task_ids_to_remove = []
# Clean up optimize tasks
task_ids_to_remove = []
for task_id, task in tasks.items():
if (task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED] and
task.completed_at and
task.completed_at and
current_time - task.completed_at > COMPLETED_TASK_TTL):
task_ids_to_remove.append(task_id)
for task_id in task_ids_to_remove:
del tasks[task_id]
# Clean up decompose tasks
decompose_ids_to_remove = []
for task_id, task in decompose_tasks.items():
if (task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED] and
task.completed_at and
current_time - task.completed_at > COMPLETED_TASK_TTL):
decompose_ids_to_remove.append(task_id)
for task_id in decompose_ids_to_remove:
del decompose_tasks[task_id]
await asyncio.sleep(60)
except Exception as e:
logging.error(f"Error in cleanup task: {e}")
await asyncio.sleep(60)
async def run_optimization(task_id: str, yaml_config: str, step_name: str, op_name: str):
"""Execute the optimization task"""
"""Execute the optimization task using fast single-LLM-call analysis."""
try:
tasks[task_id].status = TaskStatus.PROCESSING
# Run the actual optimization in a separate thread to not block
runner = DSLRunner.from_yaml(yaml_config)
should_optimize, input_data, output_data, cost = await asyncio.to_thread(
runner.should_optimize,
# yaml_config is a file path, not YAML content - read and parse the file
if yaml_config.endswith(".yaml") or yaml_config.endswith(".yml"):
with open(yaml_config, "r") as f:
config = yaml.safe_load(f)
else:
# Fallback: try parsing as YAML string
config = yaml.safe_load(yaml_config)
# Validate that we got a dict
if not isinstance(config, dict):
raise ValueError(
f"Invalid yaml_config: expected dict after parsing, got {type(config).__name__}. "
f"Input: {str(yaml_config)[:200]}"
)
# Get intermediate directory from config
intermediate_dir = (
config.get("pipeline", {})
.get("output", {})
.get("intermediate_dir")
)
if not intermediate_dir:
raise ValueError("No intermediate_dir configured in pipeline output")
# Find the operation config
op_config = None
for op in config.get("operations", []):
if op.get("name") == op_name:
op_config = op
break
if not op_config:
raise ValueError(f"Operation '{op_name}' not found in config")
# Get optimizer settings from config
optimizer_model = (
config.get("optimizer_config", {})
.get("rewrite_agent_model", "gpt-5.1")
)
litellm_kwargs = (
config.get("optimizer_config", {})
.get("litellm_kwargs", {})
)
# Create analyzer and run in thread pool
analyzer = FastShouldOptimizeAnalyzer(
intermediate_dir=intermediate_dir,
optimizer_model=optimizer_model,
litellm_kwargs=litellm_kwargs,
)
should_optimize, output_data, num_docs_analyzed, cost = await asyncio.to_thread(
analyzer.analyze,
op_config,
step_name,
op_name,
)
# Update task result
tasks[task_id].status = TaskStatus.COMPLETED
tasks[task_id].should_optimize = should_optimize
tasks[task_id].input_data = input_data
tasks[task_id].input_data = None # We don't have input_data in fast mode
tasks[task_id].output_data = output_data
tasks[task_id].num_docs_analyzed = num_docs_analyzed
tasks[task_id].cost = cost
tasks[task_id].completed_at = datetime.now()
except asyncio.CancelledError:
runner.is_cancelled = True
tasks[task_id].status = TaskStatus.CANCELLED
tasks[task_id].completed_at = datetime.now()
raise
except Exception as e:
import traceback
error_traceback = traceback.format_exc()
@ -81,11 +156,10 @@ async def run_optimization(task_id: str, yaml_config: str, step_name: str, op_na
tasks[task_id].error = f"{str(e)}\n{error_traceback}"
tasks[task_id].completed_at = datetime.now()
raise
finally:
if task_id in asyncio_tasks:
del asyncio_tasks[task_id]
runner.reset_env()
@router.on_event("startup")
async def startup_event():
@ -153,6 +227,140 @@ async def cancel_optimize_task(task_id: str):
return {"message": "Task cancelled successfully"}
# ============================================================================
# Decompose endpoints
# ============================================================================
async def run_decomposition(task_id: str, yaml_config: str, step_name: str, op_name: str):
"""Execute the decomposition task using fast directive-based approach."""
try:
decompose_tasks[task_id].status = TaskStatus.PROCESSING
# yaml_config is a file path
if not (yaml_config.endswith(".yaml") or yaml_config.endswith(".yml")):
raise ValueError("yaml_config must be a path to a YAML file")
# Get optimizer settings from config
with open(yaml_config, "r") as f:
config = yaml.safe_load(f)
optimizer_model = (
config.get("optimizer_config", {})
.get("rewrite_agent_model", "gpt-5.1")
)
litellm_kwargs = (
config.get("optimizer_config", {})
.get("litellm_kwargs", {})
)
# Create decomposer and run in thread pool
decomposer = FastDecomposer(
yaml_config_path=yaml_config,
optimizer_model=optimizer_model,
sample_size=5,
litellm_kwargs=litellm_kwargs,
)
result = await asyncio.to_thread(
decomposer.decompose,
step_name,
op_name,
)
# Update task result
decompose_tasks[task_id].status = TaskStatus.COMPLETED
decompose_tasks[task_id].decomposed_operations = result["decomposed_ops"]
decompose_tasks[task_id].winning_directive = result["winning_directive"]
decompose_tasks[task_id].candidates_evaluated = result["candidates_evaluated"]
decompose_tasks[task_id].original_outputs = result["original_outputs"]
decompose_tasks[task_id].decomposed_outputs = result["decomposed_outputs"]
decompose_tasks[task_id].comparison_rationale = result["comparison_rationale"]
decompose_tasks[task_id].cost = result["cost"]
decompose_tasks[task_id].completed_at = datetime.now()
except asyncio.CancelledError:
decompose_tasks[task_id].status = TaskStatus.CANCELLED
decompose_tasks[task_id].completed_at = datetime.now()
raise
except Exception as e:
import traceback
error_traceback = traceback.format_exc()
decompose_tasks[task_id].status = TaskStatus.FAILED
decompose_tasks[task_id].error = f"{str(e)}\n{error_traceback}"
decompose_tasks[task_id].completed_at = datetime.now()
raise
finally:
if task_id in decompose_asyncio_tasks:
del decompose_asyncio_tasks[task_id]
@router.post("/decompose", status_code=202)
async def submit_decompose_task(request: DecomposeRequest):
"""Submit a new decomposition task"""
task_id = str(uuid.uuid4())
# Create task record
decompose_tasks[task_id] = DecomposeResult(
task_id=task_id,
status=TaskStatus.PENDING,
created_at=datetime.now()
)
# Create and store the asyncio task
task = asyncio.create_task(
run_decomposition(
task_id,
request.yaml_config,
request.step_name,
request.op_name
)
)
decompose_asyncio_tasks[task_id] = task
return {"task_id": task_id}
@router.get("/decompose/{task_id}")
async def get_decompose_status(task_id: str) -> DecomposeResult:
"""Get the current status of a decomposition task"""
if task_id not in decompose_tasks:
raise HTTPException(
status_code=404,
detail="Task not found or has been cleaned up"
)
return decompose_tasks[task_id]
@router.post("/decompose/{task_id}/cancel")
async def cancel_decompose_task(task_id: str):
"""Cancel a running decomposition task"""
if task_id not in decompose_tasks:
raise HTTPException(
status_code=404,
detail="Task not found or has been cleaned up"
)
if task_id not in decompose_asyncio_tasks:
raise HTTPException(
status_code=400,
detail="Task already finished or cannot be cancelled"
)
asyncio_task = decompose_asyncio_tasks[task_id]
asyncio_task.cancel()
try:
await asyncio_task
except asyncio.CancelledError:
pass
return {"message": "Decomposition task cancelled successfully"}
# Keep the original run_pipeline endpoint
@router.post("/run_pipeline")
def run_pipeline(request: PipelineRequest) -> dict[str, Any]:
@ -169,6 +377,12 @@ def run_pipeline(request: PipelineRequest) -> dict[str, Any]:
@router.websocket("/ws/run_pipeline/{client_id}")
async def websocket_run_pipeline(websocket: WebSocket, client_id: str):
"""
WebSocket endpoint for running pipelines with real-time output streaming.
Note: The old 'optimize' flag for full pipeline optimization has been removed.
Use the /decompose endpoint for fast operation decomposition instead.
"""
await websocket.accept()
runner = None
try:
@ -178,19 +392,8 @@ async def websocket_run_pipeline(websocket: WebSocket, client_id: str):
if config.get("clear_intermediate", False):
runner.clear_intermediate()
if config.get("optimize", False):
logging.info(f"Optimizing pipeline with model {config.get('optimizer_model', 'gpt-4o')}")
# Set the runner config to the optimizer config
runner.config["optimizer_config"]["rewrite_agent_model"] = config.get("optimizer_model", "gpt-4o")
runner.config["optimizer_config"]["judge_agent_model"] = config.get("optimizer_model", "gpt-4o-mini")
async def run_pipeline():
return await asyncio.to_thread(runner.optimize, return_pipeline=False)
else:
async def run_pipeline():
return await asyncio.to_thread(runner.load_run_save)
async def run_pipeline():
return await asyncio.to_thread(runner.load_run_save)
pipeline_task = asyncio.create_task(run_pipeline())
@ -198,18 +401,6 @@ async def websocket_run_pipeline(websocket: WebSocket, client_id: str):
console_output = runner.console.file.getvalue()
await websocket.send_json({"type": "output", "data": console_output})
if config.get("optimize", False):
optimizer_progress = runner.console.get_optimizer_progress()
rationale = runner.console.optimizer_rationale
await websocket.send_json({
"type": "optimizer_progress",
"status": optimizer_progress[0],
"progress": optimizer_progress[1],
"rationale": rationale[1] if rationale is not None else "",
"should_optimize": rationale[0] if rationale is not None else False,
"validator_prompt": rationale[2] if rationale is not None else ""
})
# Check for incoming messages from the user
try:
user_message = await asyncio.wait_for(
@ -219,8 +410,7 @@ async def websocket_run_pipeline(websocket: WebSocket, client_id: str):
if user_message == "kill":
runner.console.log("Stopping process...")
runner.is_cancelled = True
await websocket.send_json({
"type": "error",
"message": "Process stopped by user request"
@ -250,41 +440,16 @@ async def websocket_run_pipeline(websocket: WebSocket, client_id: str):
# Sleep for a short duration to ensure all output is captured
await asyncio.sleep(3)
# If optimize is true, send back the optimized operations
if config.get("optimize", False):
optimized_config, cost = result
# Send the operations back in order
new_pipeline_steps = optimized_config["pipeline"]["steps"]
new_pipeline_op_name_to_op_map = {op["name"]: op for op in optimized_config["operations"]}
new_ops_in_order = []
for new_step in new_pipeline_steps:
for op in new_step.get("operations", []):
if op not in new_ops_in_order:
new_ops_in_order.append(new_pipeline_op_name_to_op_map[op])
await websocket.send_json(
{
"type": "result",
"data": {
"message": "Pipeline executed successfully",
"cost": cost,
"optimized_ops": new_ops_in_order,
"yaml_config": config["yaml_config"],
},
}
)
else:
await websocket.send_json(
{
"type": "result",
"data": {
"message": "Pipeline executed successfully",
"cost": result,
"yaml_config": config["yaml_config"],
},
}
)
await websocket.send_json(
{
"type": "result",
"data": {
"message": "Pipeline executed successfully",
"cost": result,
"yaml_config": config["yaml_config"],
},
}
)
except WebSocketDisconnect:
if runner is not None:
runner.reset_env()
@ -299,3 +464,159 @@ async def websocket_run_pipeline(websocket: WebSocket, client_id: str):
if runner is not None:
runner.reset_env()
await websocket.close()
@router.websocket("/ws/decompose/{client_id}")
async def websocket_decompose(websocket: WebSocket, client_id: str):
"""
WebSocket endpoint for fast operation decomposition with real-time output streaming.
Expects JSON message with:
- yaml_config: Path to the pipeline YAML config file
- step_name: Name of the pipeline step
- op_name: Name of the operation to decompose
"""
await websocket.accept()
decomposer = None
try:
config = await websocket.receive_json()
yaml_config = config["yaml_config"]
step_name = config["step_name"]
op_name = config["op_name"]
# Validate yaml_config is a path
if not (yaml_config.endswith(".yaml") or yaml_config.endswith(".yml")):
raise ValueError("yaml_config must be a path to a YAML file")
# Get optimizer settings from config
with open(yaml_config, "r") as f:
pipeline_config = yaml.safe_load(f)
optimizer_model = (
pipeline_config.get("optimizer_config", {})
.get("rewrite_agent_model", "gpt-5.1")
)
litellm_kwargs = (
pipeline_config.get("optimizer_config", {})
.get("litellm_kwargs", {})
)
# Create a ThreadSafeConsole for streaming output
from docetl.console import ThreadSafeConsole
console = ThreadSafeConsole(
force_terminal=True,
soft_wrap=True,
highlight=False,
log_path=False,
color_system="truecolor",
width=120,
style="bright_white on black",
record=True,
)
# Create decomposer with the console
decomposer = FastDecomposer(
yaml_config_path=yaml_config,
optimizer_model=optimizer_model,
sample_size=5,
litellm_kwargs=litellm_kwargs,
console=console,
)
# Run decomposition in a thread
async def run_decomposition():
return await asyncio.to_thread(
decomposer.decompose,
step_name,
op_name,
)
decompose_task = asyncio.create_task(run_decomposition())
# Stream console output while decomposition runs
accumulated_output = ""
while not decompose_task.done():
# get_output() processes carriage returns and clears the buffer
new_output = console.get_output()
if new_output:
# Handle spinner updates: if new output doesn't start with newline
# and accumulated doesn't end with newline, replace the last line
if accumulated_output and not accumulated_output.endswith('\n') and not new_output.startswith('\n'):
# Find the last newline in accumulated output
last_newline = accumulated_output.rfind('\n')
if last_newline >= 0:
# Replace everything after the last newline
accumulated_output = accumulated_output[:last_newline + 1] + new_output
else:
# No newline in accumulated, replace everything
accumulated_output = new_output
else:
accumulated_output += new_output
await websocket.send_json({"type": "output", "data": accumulated_output})
# Check for kill message
try:
user_message = await asyncio.wait_for(
websocket.receive_json(), timeout=0.1
)
if user_message == "kill":
console.log("[red]Stopping decomposition...[/red]")
decompose_task.cancel()
await websocket.send_json({
"type": "error",
"message": "Decomposition stopped by user request"
})
raise Exception("Process stopped by user request")
except asyncio.TimeoutError:
pass
except asyncio.CancelledError:
await websocket.send_json({
"type": "error",
"message": "Decomposition stopped by user request"
})
raise
await asyncio.sleep(0.5)
# Get final result
result = await decompose_task
# Send any remaining console output
final_output = console.get_output()
if final_output:
accumulated_output += final_output
await websocket.send_json({"type": "output", "data": accumulated_output})
await asyncio.sleep(1)
# Send the result
await websocket.send_json({
"type": "result",
"data": {
"decomposed_operations": result["decomposed_ops"],
"winning_directive": result["winning_directive"],
"candidates_evaluated": result["candidates_evaluated"],
"original_outputs": result["original_outputs"],
"decomposed_outputs": result["decomposed_outputs"],
"comparison_rationale": result["comparison_rationale"],
"cost": result["cost"],
},
})
except WebSocketDisconnect:
print(f"Decompose client {client_id} disconnected")
except Exception as e:
import traceback
error_traceback = traceback.format_exc()
print(f"Decompose error:\n{error_traceback}")
try:
await websocket.send_json({
"type": "error",
"data": str(e),
"traceback": error_traceback
})
except Exception:
pass
finally:
await websocket.close()

View File

@ -5,4 +5,9 @@ export const API_ROUTES = {
CANCEL: (taskId: string) =>
`/api/shouldOptimize?taskId=${taskId}&cancel=true`,
},
DECOMPOSE: {
SUBMIT: "/api/decompose",
STATUS: (taskId: string) => `/api/decompose?taskId=${taskId}`,
CANCEL: (taskId: string) => `/api/decompose?taskId=${taskId}&cancel=true`,
},
} as const;

View File

@ -0,0 +1,79 @@
import { NextRequest, NextResponse } from "next/server";
const FASTAPI_URL = `${
process.env.NEXT_PUBLIC_BACKEND_HTTPS ? "https" : "http"
}://${process.env.NEXT_PUBLIC_BACKEND_HOST}:${
process.env.NEXT_PUBLIC_BACKEND_PORT
}`;
// Helper to handle errors consistently
const handleError = (error: unknown, status = 500) => {
const message =
error instanceof Error ? error.message : "Internal server error";
return NextResponse.json({ error: message }, { status });
};
// Helper to proxy requests to FastAPI
async function proxyRequest(path: string, init?: RequestInit) {
const response = await fetch(`${FASTAPI_URL}${path}`, {
...init,
headers: {
"Content-Type": "application/json",
...init?.headers,
},
});
if (!response.ok) {
const error = await response.text();
throw new Error(`FastAPI server error: ${error}`);
}
return response.json();
}
export async function POST(request: NextRequest): Promise<NextResponse> {
try {
// Extract task ID from the URL if it exists
const taskId = request.nextUrl.searchParams.get("taskId");
const isCancel = request.nextUrl.searchParams.get("cancel") === "true";
// Handle different POST scenarios
if (taskId) {
if (isCancel) {
// Cancel task
const data = await proxyRequest(`/decompose/${taskId}/cancel`, {
method: "POST",
});
return NextResponse.json(data);
}
// Invalid request with taskId but no cancel
return handleError(new Error("Invalid request"), 400);
}
// Submit new task
const body = await request.json();
const data = await proxyRequest("/decompose", {
method: "POST",
body: JSON.stringify(body),
});
return NextResponse.json(data, { status: 202 });
} catch (error) {
return handleError(error);
}
}
export async function GET(request: NextRequest): Promise<NextResponse> {
try {
// Extract task ID from the URL
const taskId = request.nextUrl.searchParams.get("taskId");
if (!taskId) {
return handleError(new Error("Task ID is required"), 400);
}
const data = await proxyRequest(`/decompose/${taskId}`);
return NextResponse.json(data);
} catch (error) {
return handleError(error);
}
}

View File

@ -33,6 +33,77 @@ export function getNamespaceDir(homeDir: string, namespace: string) {
return path.join(homeDir, ".docetl", namespace);
}
/**
* Safely parse a JSON value - returns the parsed value if it's a string,
* or the original value if it's already an object.
* @param value - The value to parse
* @param fieldName - Name of the field (for error messages)
* @returns The parsed value
*/
function safeParseJSON(
value: unknown,
fieldName: string
): Record<string, unknown> | unknown[] {
if (value === null || value === undefined) {
throw new Error(`Field '${fieldName}' is null or undefined`);
}
// If it's already an object or array, return as-is
if (typeof value === "object") {
return value as Record<string, unknown> | unknown[];
}
// If it's a string, try to parse it
if (typeof value === "string") {
try {
return JSON.parse(value);
} catch (e) {
throw new Error(
`Field '${fieldName}' contains invalid JSON: ${value.slice(0, 100)}${value.length > 100 ? "..." : ""}`
);
}
}
throw new Error(
`Field '${fieldName}' has unexpected type '${typeof value}': ${String(value).slice(0, 100)}`
);
}
/**
* Recursively sanitize an object for YAML serialization.
* Converts any nested [object Object] strings to proper objects.
*/
function sanitizeForYaml(obj: unknown, path: string = ""): unknown {
if (obj === null || obj === undefined) {
return obj;
}
if (typeof obj === "string") {
// Check for [object Object] which indicates a serialization issue
if (obj === "[object Object]") {
console.warn(
`Warning: Found '[object Object]' string at path '${path}'. This indicates a serialization issue.`
);
return {}; // Return empty object as fallback
}
return obj;
}
if (Array.isArray(obj)) {
return obj.map((item, idx) => sanitizeForYaml(item, `${path}[${idx}]`));
}
if (typeof obj === "object") {
const result: Record<string, unknown> = {};
for (const [key, value] of Object.entries(obj)) {
result[key] = sanitizeForYaml(value, path ? `${path}.${key}` : key);
}
return result;
}
return obj;
}
export function generatePipelineConfig(
namespace: string,
default_model: string,
@ -82,9 +153,18 @@ export function generatePipelineConfig(
const litellm_completion_kwargs =
op.otherKwargs?.litellm_completion_kwargs;
if (litellm_completion_kwargs) {
op.otherKwargs.litellm_completion_kwargs = JSON.parse(
litellm_completion_kwargs
);
try {
op.otherKwargs.litellm_completion_kwargs = safeParseJSON(
litellm_completion_kwargs,
`${op.name}.litellm_completion_kwargs`
);
} catch (e) {
console.warn(
`Warning: Could not parse litellm_completion_kwargs for operation '${op.name}':`,
e
);
// Keep the original value if parsing fails
}
}
const newOp: Record<string, unknown> = {
@ -175,10 +255,14 @@ export function generatePipelineConfig(
// If it's a sample operation with custom method, parse the samples as key-value pairs
if (op.type === "sample" && op.otherKwargs?.method === "custom") {
try {
newOp.samples = JSON.parse(op.otherKwargs.samples);
newOp.samples = safeParseJSON(
op.otherKwargs.samples,
`${op.name}.samples`
);
} catch (error) {
console.warn(
"Failed to parse custom samples as JSON, using raw value"
`Failed to parse custom samples for operation '${op.name}':`,
error
);
}
}
@ -217,14 +301,12 @@ export function generatePipelineConfig(
// Fix type errors by asserting the pipeline config type
let pipelineConfig: any = {
from_docwrangler: true,
optimizer_model: optimizerModel,
datasets,
default_model,
...(enable_observability && {
optimizer_config: {
force_decompose: true,
},
}),
optimizer_config: {
rewrite_agent_model: optimizerModel,
...(enable_observability && { force_decompose: true }),
},
operations: updatedOperations,
pipeline: {
steps: [
@ -308,7 +390,9 @@ export function generatePipelineConfig(
const outputOpName = operationsToRun[currentOpIndex].name;
outputPath = path.join(outputBase, "data_processing", outputOpName + ".json");
const yamlString = yaml.dump(pipelineConfig);
// Sanitize the config before serializing to YAML to catch any [object Object] issues
const sanitizedConfig = sanitizeForYaml(pipelineConfig, "pipelineConfig");
const yamlString = yaml.dump(sanitizedConfig);
console.log(yamlString);

View File

@ -104,6 +104,7 @@ export interface OptimizeResult {
should_optimize?: string;
input_data?: Array<Record<string, unknown>>;
output_data?: Array<Record<string, unknown>>;
num_docs_analyzed?: number;
cost?: number;
error?: string;
created_at: string;
@ -114,3 +115,24 @@ export interface APIKey {
name: string;
value: string;
}
export interface DecomposeRequest {
yaml_config: string;
step_name: string;
op_name: string;
}
export interface DecomposeResult {
task_id: string;
status: TaskStatus;
decomposed_operations?: Array<Record<string, unknown>>;
winning_directive?: string;
candidates_evaluated?: number;
original_outputs?: Array<Record<string, unknown>>;
decomposed_outputs?: Array<Record<string, unknown>>;
comparison_rationale?: string;
cost?: number;
error?: string;
created_at: string;
completed_at?: string;
}

View File

@ -24,12 +24,14 @@ interface AnsiRendererProps {
text: string;
readyState: number;
setTerminalOutput: (text: string) => void;
isDecomposing?: boolean;
}
const AnsiRenderer: React.FC<AnsiRendererProps> = ({
text,
readyState,
setTerminalOutput,
isDecomposing = false,
}) => {
const html = convert.toHtml(text);
const scrollRef = useRef<HTMLDivElement>(null);
@ -51,7 +53,9 @@ const AnsiRenderer: React.FC<AnsiRendererProps> = ({
}
};
const isWebSocketClosed = readyState === WebSocket.CLOSED;
// Consider connected if either WebSocket is open OR decomposing is in progress
const isConnected = readyState === WebSocket.OPEN || isDecomposing;
const isWebSocketClosed = readyState === WebSocket.CLOSED && !isDecomposing;
return (
<div
@ -92,9 +96,11 @@ const AnsiRenderer: React.FC<AnsiRendererProps> = ({
)}
</div>
<div className="flex justify-between items-center text-xs text-gray-500">
<div className={isWebSocketClosed ? "text-red-500" : ""}>
<div className={isWebSocketClosed ? "text-red-500" : isDecomposing ? "text-blue-400" : ""}>
Status:{" "}
{readyState === WebSocket.CONNECTING
{isDecomposing
? "Decomposing..."
: readyState === WebSocket.CONNECTING
? "Connecting"
: readyState === WebSocket.OPEN
? "Connected"
@ -106,10 +112,7 @@ const AnsiRenderer: React.FC<AnsiRendererProps> = ({
</div>
<button
onClick={() => setTerminalOutput("")}
className={`hover:text-white transition-colors ${
isWebSocketClosed ? "cursor-not-allowed opacity-50" : ""
}`}
disabled={isWebSocketClosed}
className="hover:text-white transition-colors"
>
Clear
</button>

View File

@ -0,0 +1,271 @@
import React from "react";
import {
Dialog,
DialogContent,
DialogHeader,
DialogTitle,
DialogDescription,
} from "@/components/ui/dialog";
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/ui/table";
import { Button } from "@/components/ui/button";
import { ChevronLeft, ChevronRight, Check, X } from "lucide-react";
import { DecomposeResult } from "@/app/types";
interface DecompositionComparisonDialogProps {
isOpen: boolean;
result: DecomposeResult | null;
operationName?: string;
onOpenChange: (open: boolean) => void;
onApply: () => void;
onCancel: () => void;
}
export const DecompositionComparisonDialog: React.FC<
DecompositionComparisonDialogProps
> = ({ isOpen, result, operationName, onOpenChange, onApply, onCancel }) => {
const [currentPage, setCurrentPage] = React.useState(1);
const originalOutputs = result?.original_outputs || [];
const decomposedOutputs = result?.decomposed_outputs || [];
const maxRows = Math.max(originalOutputs.length, decomposedOutputs.length);
const renderOutputTable = (
data: Array<Record<string, unknown>>,
title: string
) => {
if (!data.length) {
return (
<div className="text-center text-muted-foreground py-8">
No outputs available
</div>
);
}
const columns = Object.keys(data[0]).filter(
(column) => !column.startsWith("_observability")
);
const currentRow = data[currentPage - 1];
if (!currentRow) return null;
return (
<div className="space-y-2">
<h4 className="font-medium text-sm text-muted-foreground">{title}</h4>
<div className="border rounded-md">
<div className="max-h-[250px] overflow-auto">
<Table className="relative w-full border-collapse">
<TableHeader>
<TableRow className="sticky top-0 bg-background z-10 border-b">
{columns.map((column) => (
<TableHead
key={column}
className="h-8 px-3 text-left align-middle bg-background text-xs"
>
{column}
</TableHead>
))}
</TableRow>
</TableHeader>
<TableBody>
<TableRow>
{columns.map((column) => (
<TableCell key={column} className="p-2 align-top">
<pre className="whitespace-pre-wrap font-mono text-xs text-left max-w-[300px]">
{typeof currentRow[column] === "object"
? JSON.stringify(currentRow[column], null, 2)
: String(currentRow[column] ?? "")}
</pre>
</TableCell>
))}
</TableRow>
</TableBody>
</Table>
</div>
</div>
</div>
);
};
React.useEffect(() => {
if (isOpen) {
setCurrentPage(1);
}
}, [isOpen]);
if (!result) return null;
const isOriginalWinner = result.winning_directive === "original";
return (
<Dialog open={isOpen} onOpenChange={onOpenChange}>
<DialogContent className="max-w-7xl max-h-[90vh] flex flex-col">
<DialogHeader className="flex-shrink-0 border-b pb-4">
<DialogTitle className="text-xl">
Review Decomposition Results
</DialogTitle>
<DialogDescription>
Compare the original operation outputs with the decomposed version
before applying changes.
</DialogDescription>
</DialogHeader>
<div className="flex-1 overflow-y-auto py-4 space-y-6">
{/* Summary */}
<div className="p-4 bg-muted rounded-lg space-y-2">
<div className="flex items-center gap-4 flex-wrap">
{operationName && (
<div className="flex items-center">
<span className="font-medium text-sm mr-2">Operation:</span>
<span className="bg-primary/15 text-primary rounded-md px-2 py-1 text-sm">
{operationName}
</span>
</div>
)}
<div className="flex items-center">
<span className="font-medium text-sm mr-2">
Winning Strategy:
</span>
<span
className={`rounded-md px-2 py-1 text-sm ${
isOriginalWinner
? "bg-yellow-100 text-yellow-800"
: "bg-green-100 text-green-800"
}`}
>
{result.winning_directive}
</span>
</div>
<div className="flex items-center">
<span className="font-medium text-sm mr-2">
Candidates Evaluated:
</span>
<span className="text-sm">{result.candidates_evaluated}</span>
</div>
{result.cost && (
<div className="flex items-center">
<span className="font-medium text-sm mr-2">Cost:</span>
<span className="text-sm">${result.cost.toFixed(4)}</span>
</div>
)}
</div>
</div>
{/* Comparison Rationale */}
{result.comparison_rationale && (
<div className="space-y-2">
<h3 className="font-medium text-base">Why This Was Chosen</h3>
<div className="p-3 bg-blue-50 border border-blue-200 rounded-lg text-sm">
{result.comparison_rationale}
</div>
</div>
)}
{/* Side by side comparison */}
<div className="space-y-4">
<div className="flex items-center justify-between">
<h3 className="font-medium text-base">
Output Comparison (Sample {currentPage} of {maxRows || 1})
</h3>
<div className="flex items-center space-x-1">
<Button
variant="outline"
size="sm"
onClick={() =>
setCurrentPage((prev) => Math.max(prev - 1, 1))
}
disabled={currentPage === 1}
className="px-2 py-1"
>
<ChevronLeft className="h-4 w-4" />
Previous
</Button>
<Button
variant="outline"
size="sm"
onClick={() =>
setCurrentPage((prev) => Math.min(prev + 1, maxRows))
}
disabled={currentPage === maxRows || maxRows === 0}
className="px-2 py-1"
>
Next
<ChevronRight className="h-4 w-4" />
</Button>
</div>
</div>
<div className="grid grid-cols-2 gap-4">
{renderOutputTable(originalOutputs, "Original Output")}
{renderOutputTable(
decomposedOutputs,
`Decomposed Output (${result.winning_directive})`
)}
</div>
</div>
{/* Decomposed Operations Preview */}
{result.decomposed_operations &&
result.decomposed_operations.length > 0 &&
!isOriginalWinner && (
<div className="space-y-2">
<h3 className="font-medium text-base">
New Operations ({result.decomposed_operations.length})
</h3>
<div className="space-y-2">
{result.decomposed_operations.map((op, idx) => (
<div
key={idx}
className="p-3 border rounded-lg bg-muted/50"
>
<div className="flex items-center gap-2 mb-2">
<span className="font-medium text-sm">
{(op.name as string) || `Operation ${idx + 1}`}
</span>
<span className="text-xs bg-primary/10 text-primary px-2 py-0.5 rounded">
{op.type as string}
</span>
</div>
{op.prompt && (
<pre className="text-xs font-mono bg-background p-2 rounded border max-h-[100px] overflow-auto">
{String(op.prompt).slice(0, 500)}
{String(op.prompt).length > 500 ? "..." : ""}
</pre>
)}
</div>
))}
</div>
</div>
)}
</div>
<div className="flex justify-end items-center gap-3 pt-4 border-t mt-4">
<Button variant="outline" onClick={onCancel}>
<X className="h-4 w-4 mr-2" />
Keep Original
</Button>
<Button
onClick={onApply}
disabled={isOriginalWinner}
title={
isOriginalWinner
? "Original was determined to be the best option"
: undefined
}
>
<Check className="h-4 w-4 mr-2" />
Apply Decomposition
</Button>
</div>
</DialogContent>
</Dialog>
);
};
DecompositionComparisonDialog.displayName = "DecompositionComparisonDialog";

View File

@ -39,7 +39,6 @@ import { Skeleton } from "@/components/ui/skeleton";
import { debounce } from "lodash";
import { Guardrails, GleaningConfig } from "./operations/args";
import createOperationComponent from "./operations/components";
import { useWebSocket } from "@/contexts/WebSocketContext";
import { Badge } from "./ui/badge";
import {
Popover,
@ -161,7 +160,7 @@ const OperationHeader: React.FC<OperationHeaderProps> = React.memo(
}`}
/>
</HoverCardTrigger>
<HoverCardContent className="w-72" side="bottom" align="start">
<HoverCardContent className="w-96 max-h-80" side="bottom" align="start">
<div className="flex flex-col space-y-1">
<p className="text-sm font-medium">
{optimizeResult === undefined || optimizeResult === null
@ -170,13 +169,15 @@ const OperationHeader: React.FC<OperationHeaderProps> = React.memo(
? "Decomposition Status"
: "Decomposition Recommended"}
</p>
<p className="text-sm text-muted-foreground">
{optimizeResult === undefined || optimizeResult === null
? "Analyzing operation complexity..."
: optimizeResult === ""
? "No decomposition needed for this operation"
: "Recommended decomposition: " + optimizeResult}
</p>
<div className="max-h-64 overflow-y-auto">
<p className="text-sm text-muted-foreground whitespace-pre-wrap">
{optimizeResult === undefined || optimizeResult === null
? "Analyzing operation complexity..."
: optimizeResult === ""
? "No decomposition needed for this operation"
: optimizeResult}
</p>
</div>
</div>
</HoverCardContent>
</HoverCard>
@ -367,7 +368,7 @@ const OperationHeader: React.FC<OperationHeaderProps> = React.memo(
disabled={disabled}
>
<Zap className="mr-2 h-4 w-4" />
Optimize Operation
Decompose Operation
</Button>
)}
@ -720,7 +721,6 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
output: pipelineOutput,
setOutput,
isLoadingOutputs,
setIsLoadingOutputs,
numOpRun,
setNumOpRun,
currentFile,
@ -730,18 +730,14 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
sampleSize,
setCost,
defaultModel,
optimizerModel,
setTerminalOutput,
namespace,
apiKeys,
systemPrompt,
extraPipelineSettings,
onRequestDecompositionRef,
} = usePipelineContext();
const { toast } = useToast();
const operationRef = useRef(operation);
const { connect, sendMessage, lastMessage, readyState, disconnect } =
useWebSocket();
useEffect(() => {
operationRef.current = operation;
@ -808,75 +804,19 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
const handleOptimizeConfirm = useCallback(async () => {
if (!operation) return;
try {
// Clear the output
setTerminalOutput("");
setIsLoadingOutputs(true);
setShowOptimizeDialog(false);
// Write pipeline config
const response = await fetch("/api/writePipelineConfig", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
default_model: defaultModel,
data: { path: currentFile?.path || "" },
operations,
operation_id: operation.id,
name: pipelineName,
sample_size: sampleSize,
optimize: true,
clear_intermediate: false,
system_prompt: systemPrompt,
namespace: namespace,
apiKeys: apiKeys,
optimizerModel: optimizerModel,
extraPipelineSettings: extraPipelineSettings,
}),
});
if (!response.ok) {
throw new Error(await response.text());
}
const { filePath } = await response.json();
// Ensure WebSocket is connected
await connect();
// Send message to run the pipeline
sendMessage({
yaml_config: filePath,
optimize: true,
});
} catch (error) {
console.error("Error optimizing operation:", error);
// Use the decomposition flow from context if available
if (onRequestDecompositionRef.current) {
onRequestDecompositionRef.current(operation.id, operation.name);
} else {
toast({
title: "Error",
description: error.message,
description: "Decomposition handler not available",
variant: "destructive",
});
// Close the WebSocket connection
disconnect();
} finally {
setShowOptimizeDialog(false);
}
}, [
operation,
defaultModel,
currentFile,
operations,
pipelineName,
sampleSize,
optimizerModel,
connect,
sendMessage,
systemPrompt,
namespace,
apiKeys,
extraPipelineSettings,
]);
}, [operation, onRequestDecompositionRef, toast]);
const onShowOutput = useCallback(async () => {
if (!operation) return;
@ -1224,7 +1164,7 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
>
<AlertDialogContent>
<AlertDialogHeader>
<AlertDialogTitle>Optimize Operation</AlertDialogTitle>
<AlertDialogTitle>Decompose Operation</AlertDialogTitle>
<AlertDialogDescription>
{!hasOpenAIKey && !isLocalMode ? (
<div className="space-y-2">
@ -1232,8 +1172,8 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
OpenAI API Key Required
</p>
<p>
To use the optimizer, please add your OpenAI API key in Edit{" "}
{">"}
To decompose operations, please add your OpenAI API key in
Edit {">"}
Edit API Keys.
</p>
<button
@ -1245,11 +1185,10 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
</div>
) : (
<p>
This will analyze the operation and replace it with another
pipeline that has higher accuracy (as determined by an
LLM-as-a-judge), if it can be found. Do you want to proceed?
The process may take between 2 and 10 minutes, depending on
how complex your data is.
This will analyze the operation and try different
decomposition strategies (chaining, chunking, gleaning, etc.)
to improve accuracy. The best strategy will be selected via
LLM-as-a-judge comparison. Do you want to proceed?
</p>
)}
</AlertDialogDescription>
@ -1260,7 +1199,7 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
onClick={handleOptimizeConfirm}
disabled={!hasOpenAIKey && !isLocalMode}
>
Proceed
Decompose
</AlertDialogAction>
</AlertDialogFooter>
</AlertDialogContent>

View File

@ -23,6 +23,7 @@ interface OptimizationDialogProps {
operationName?: string;
inputData?: Array<Record<string, unknown>>;
outputData?: Array<Record<string, unknown>>;
numDocsAnalyzed?: number;
onOpenChange: (open: boolean) => void;
onDecompose?: () => void;
}
@ -34,6 +35,7 @@ export const OptimizationDialog: React.FC<OptimizationDialogProps> = ({
inputData,
outputData,
operationName,
numDocsAnalyzed,
onOpenChange,
onDecompose,
}) => {
@ -140,9 +142,9 @@ export const OptimizationDialog: React.FC<OptimizationDialogProps> = ({
<DialogHeader className="flex-shrink-0 border-b pb-4">
<DialogTitle className="text-xl">Operation Too Complex</DialogTitle>
<p className="text-base mt-2">
This operation might be too complex for the LLM to handle
efficiently. We recommend breaking it down into smaller, more
manageable steps.
After analyzing {numDocsAnalyzed ?? "several"} documents, this
operation might be too complex for the LLM to handle efficiently. We
recommend breaking it down into smaller, more manageable steps.
</p>
</DialogHeader>

View File

@ -81,6 +81,7 @@ const useOutputContext = () => {
const {
output,
isLoadingOutputs,
isDecomposing,
terminalOutput,
setTerminalOutput,
optimizerProgress,
@ -91,6 +92,7 @@ const useOutputContext = () => {
return {
output,
isLoadingOutputs,
isDecomposing,
terminalOutput,
setTerminalOutput,
optimizerProgress,
@ -385,7 +387,7 @@ VisualizeContent.displayName = "VisualizeContent";
// Move ConsoleContent outside
export const ConsoleContent = memo(() => {
const { terminalOutput, setTerminalOutput, optimizerProgress } =
const { terminalOutput, setTerminalOutput, optimizerProgress, isDecomposing } =
useOutputContext();
const { readyState } = useWebSocket();
@ -469,6 +471,7 @@ export const ConsoleContent = memo(() => {
text={terminalOutput || ""}
readyState={readyState}
setTerminalOutput={setTerminalOutput}
isDecomposing={isDecomposing}
/>
</div>
</div>
@ -478,7 +481,7 @@ ConsoleContent.displayName = "ConsoleContent";
// Main Output component
export const Output = memo(() => {
const { output, isLoadingOutputs, sampleSize, operations } =
const { output, isLoadingOutputs, isDecomposing, sampleSize, operations } =
useOutputContext();
const operation = useOperation(output?.operationId);
@ -500,10 +503,14 @@ export const Output = memo(() => {
);
}, [operation]);
// Effect for tab changes
// Effect for tab changes - switch to console when loading or decomposing
useEffect(() => {
setActiveTab(isLoadingOutputs ? "console" : "table");
}, [isLoadingOutputs]);
if (isLoadingOutputs || isDecomposing) {
setActiveTab("console");
} else {
setActiveTab("table");
}
}, [isLoadingOutputs, isDecomposing]);
// Memoize columns
const columns = useMemo(() => {

View File

@ -40,9 +40,12 @@ import { Input } from "@/components/ui/input";
import { schemaDictToItemSet } from "./utils";
import { v4 as uuidv4 } from "uuid";
import { useOptimizeCheck } from "@/hooks/useOptimizeCheck";
import { useDecomposeWebSocket } from "@/hooks/useDecomposeWebSocket";
import { canBeOptimized } from "@/lib/utils";
import { Textarea } from "./ui/textarea";
import { OptimizationDialog } from "@/components/OptimizationDialog";
import { DecompositionComparisonDialog } from "@/components/DecompositionComparisonDialog";
import { DecomposeResult } from "@/app/types";
import {
Popover,
PopoverContent,
@ -240,6 +243,9 @@ const PipelineGUI: React.FC = () => {
apiKeys,
extraPipelineSettings,
setExtraPipelineSettings,
isDecomposing,
setIsDecomposing,
onRequestDecompositionRef,
} = usePipelineContext();
const [isSettingsOpen, setIsSettingsOpen] = useState(false);
const { toast } = useToast();
@ -253,12 +259,14 @@ const PipelineGUI: React.FC = () => {
outputData?: Array<Record<string, unknown>>;
operationName?: string;
operationId?: string;
numDocsAnalyzed?: number;
}>({
isOpen: false,
content: "",
prompt: undefined,
operationName: undefined,
operationId: undefined,
numDocsAnalyzed: undefined,
});
const [isEditingName, setIsEditingName] = useState(false);
const [editedPipelineName, setEditedPipelineName] = useState(pipelineName);
@ -277,14 +285,13 @@ const PipelineGUI: React.FC = () => {
setCost((prev) => prev + result.cost);
if (result.should_optimize) {
toast({
title: `Hey! Consider decomposing ${operations[operations.length - 1].name
}`,
const lastOp = operations[operations.length - 1];
const { dismiss } = toast({
title: `Hey! Consider decomposing ${lastOp.name}`,
description: (
<span
className="cursor-pointer text-blue-500 hover:text-blue-700"
onClick={() => {
const lastOp = operations[operations.length - 1];
setOptimizationDialog({
isOpen: true,
content: result.should_optimize,
@ -293,7 +300,9 @@ const PipelineGUI: React.FC = () => {
operationId: lastOp.id,
inputData: result.input_data,
outputData: result.output_data,
numDocsAnalyzed: result.num_docs_analyzed,
});
dismiss();
}}
>
Click here to see why.
@ -312,6 +321,144 @@ const PipelineGUI: React.FC = () => {
},
});
const [decomposeDialog, setDecomposeDialog] = useState<{
isOpen: boolean;
result: DecomposeResult | null;
targetOpId: string | null;
operationName: string | null;
}>({
isOpen: false,
result: null,
targetOpId: null,
operationName: null,
});
// Ref to store the current decomposition target (avoids stale closure issue)
const decompositionTargetRef = useRef<{
operationId: string;
operationName: string;
} | null>(null);
const { startDecomposition, cancel: cancelDecomposition } = useDecomposeWebSocket({
namespace,
onOutput: (output) => {
// Stream output to the terminal
setTerminalOutput(output);
},
onComplete: (result) => {
setIsDecomposing(false);
setCost((prev) => prev + (result.cost || 0));
// Read from ref to avoid stale closure
const target = decompositionTargetRef.current;
console.log("Decomposition complete:", {
result,
target,
decomposed_operations: result.decomposed_operations,
});
// Show the comparison dialog instead of auto-applying
setDecomposeDialog({
isOpen: true,
result,
targetOpId: target?.operationId || null,
operationName: target?.operationName || null,
});
},
onError: (error) => {
setIsDecomposing(false);
toast({
title: "Decomposition Failed",
description: error,
variant: "destructive",
});
},
});
const handleApplyDecomposition = () => {
const { result, targetOpId } = decomposeDialog;
console.log("handleApplyDecomposition called:", { result, targetOpId, decomposeDialog });
if (!result || !targetOpId) {
console.warn("handleApplyDecomposition: missing result or targetOpId", { result, targetOpId });
return;
}
if (
result.decomposed_operations &&
result.decomposed_operations.length > 0 &&
result.winning_directive !== "original"
) {
setOperations((prev) => {
// Find the index of the target operation
const targetIdx = prev.findIndex((op) => op.id === targetOpId);
if (targetIdx === -1) return prev;
// Convert decomposed operations to our Operation format
const newOps: Operation[] = result.decomposed_operations!.map(
(opConfig) => ({
id: uuidv4(),
llmType:
opConfig.type === "map" ||
opConfig.type === "reduce" ||
opConfig.type === "filter" ||
opConfig.type === "parallel_map"
? "LLM"
: "non-LLM",
type: opConfig.type as Operation["type"],
name: opConfig.name as string,
prompt: opConfig.prompt as string | undefined,
output: opConfig.output
? {
schema: schemaDictToItemSet(
(opConfig.output as { schema: Record<string, string> })
.schema
),
}
: undefined,
gleaning: opConfig.gleaning as Operation["gleaning"],
otherKwargs: Object.fromEntries(
Object.entries(opConfig).filter(
([key]) =>
!["name", "type", "prompt", "output", "gleaning"].includes(key)
)
),
visibility: true,
})
);
// Replace the target operation with the new operations
const newOperations = [...prev];
newOperations.splice(targetIdx, 1, ...newOps);
return newOperations;
});
toast({
title: "Decomposition Applied",
description: `Applied ${result.winning_directive} directive. Created ${result.decomposed_operations.length} operation(s).`,
});
}
setDecomposeDialog({
isOpen: false,
result: null,
targetOpId: null,
operationName: null,
});
};
const handleCancelDecomposition = () => {
toast({
title: "Decomposition Cancelled",
description: "Keeping the original operation.",
});
setDecomposeDialog({
isOpen: false,
result: null,
targetOpId: null,
operationName: null,
});
};
const { restoreFromYAML } = useRestorePipeline({
setOperations,
setPipelineName,
@ -328,79 +475,18 @@ const PipelineGUI: React.FC = () => {
if (lastMessage) {
if (lastMessage.type === "output") {
setTerminalOutput(lastMessage.data);
} else if (lastMessage.type === "optimizer_progress") {
setOptimizerProgress({
status: lastMessage.status,
progress: lastMessage.progress,
shouldOptimize: lastMessage.should_optimize,
rationale: lastMessage.rationale,
validatorPrompt: lastMessage.validator_prompt,
});
} else if (lastMessage.type === "result") {
const runCost = lastMessage.data.cost || 0;
setOptimizerProgress(null);
// See if there was an optimized operation
const optimizedOps = lastMessage.data.optimized_ops;
if (optimizedOps) {
const newOperations = optimizedOps.map((optimizedOp) => {
const {
id,
type,
name,
prompt,
output,
validate,
gleaning,
sample,
...otherKwargs
} = optimizedOp;
// Find matching operation in previous operations list
const existingOp = operations.find((op) => op.name === name);
return {
id: id || uuidv4(),
llmType:
type === "map" ||
type === "reduce" ||
type === "resolve" ||
type === "filter" ||
type === "parallel_map" ||
type === "rank" ||
type === "extract"
? "LLM"
: "non-LLM",
type: type,
name: name || "Untitled Operation",
prompt: prompt,
output: output
? {
schema: schemaDictToItemSet(output.schema),
}
: undefined,
validate: validate,
gleaning: gleaning,
sample: sample,
otherKwargs: otherKwargs || {},
...(existingOp?.runIndex && { runIndex: existingOp.runIndex }),
visibility: true,
} as Operation;
});
setOperations(newOperations);
} else {
// No optimized operations, so we need to check if we should optimize the last operation
// Trigger should optimize for the last operation
if (autoOptimizeCheck) {
const lastOp = operations[operations.length - 1];
if (lastOp && canBeOptimized(lastOp.type)) {
submitTask({
yaml_config: lastMessage.data.yaml_config,
step_name: "data_processing", // TODO: Make this a constant
op_name: lastOp.name,
});
}
// Trigger should_optimize check for the last operation if enabled
if (autoOptimizeCheck) {
const lastOp = operations[operations.length - 1];
if (lastOp && canBeOptimized(lastOp.type)) {
submitTask({
yaml_config: lastMessage.data.yaml_config,
step_name: "data_processing",
op_name: lastOp.name,
});
}
}
@ -692,20 +778,35 @@ const PipelineGUI: React.FC = () => {
};
const handleStop = () => {
sendMessage("kill");
// Stop pipeline run if running
if (isLoadingOutputs) {
sendMessage("kill");
if (readyState === WebSocket.CLOSED) {
setIsLoadingOutputs(false);
}
}
if (readyState === WebSocket.CLOSED && isLoadingOutputs) {
setIsLoadingOutputs(false);
// Stop decomposition if running
if (isDecomposing) {
cancelDecomposition();
setIsDecomposing(false);
}
};
const handleOptimizeFromDialog = async () => {
if (!optimizationDialog.operationId) return;
if (!optimizationDialog.operationId || !optimizationDialog.operationName) return;
// Store target in ref so onComplete callback can access it
decompositionTargetRef.current = {
operationId: optimizationDialog.operationId,
operationName: optimizationDialog.operationName,
};
try {
setTerminalOutput("");
setIsLoadingOutputs(true);
setIsDecomposing(true);
setTerminalOutput(""); // Clear terminal output
// Write the pipeline config to get a YAML file path
const response = await fetch("/api/writePipelineConfig", {
method: "POST",
headers: {
@ -732,24 +833,115 @@ const PipelineGUI: React.FC = () => {
const { filePath } = await response.json();
await connect();
// Use the WebSocket decompose endpoint for streaming output
await startDecomposition(
filePath,
"data_processing", // Matches the step name in generatePipelineConfig
optimizationDialog.operationName
);
sendMessage({
yaml_config: filePath,
optimize: true,
toast({
title: "Decomposing Operation",
description: "Analyzing and generating candidate decompositions. Watch the console for progress.",
});
} catch (error) {
console.error("Error optimizing operation:", error);
console.error("Error starting decomposition:", error);
setIsDecomposing(false);
toast({
title: "Error",
description: error.message,
description: error instanceof Error ? error.message : "Failed to start decomposition",
variant: "destructive",
});
disconnect();
setIsLoadingOutputs(false);
}
};
// Handler for decomposition requests from OperationCard dropdown
const handleDecompositionRequest = useCallback(
async (operationId: string, operationName: string) => {
// Store target in ref so onComplete callback can access it
decompositionTargetRef.current = {
operationId,
operationName,
};
try {
setIsDecomposing(true);
setTerminalOutput(""); // Clear terminal output
// Set the dialog state so we know which operation we're decomposing
setDecomposeDialog((prev) => ({
...prev,
targetOpId: operationId,
operationName: operationName,
}));
// Write the pipeline config to get a YAML file path
const response = await fetch("/api/writePipelineConfig", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
default_model: defaultModel,
data: { path: currentFile?.path || "" },
operations,
operation_id: operationId,
name: pipelineName,
sample_size: sampleSize,
optimize: true,
namespace: namespace,
apiKeys: apiKeys,
optimizerModel: optimizerModel,
extraPipelineSettings: extraPipelineSettings,
}),
});
if (!response.ok) {
throw new Error(await response.text());
}
const { filePath } = await response.json();
// Start decomposition via WebSocket
startDecomposition(filePath, "data_processing", operationName);
toast({
title: "Decomposing Operation",
description:
"Analyzing and generating candidate decompositions. Watch the console for progress.",
});
} catch (error) {
console.error("Error starting decomposition:", error);
setIsDecomposing(false);
toast({
title: "Error",
description:
error instanceof Error ? error.message : "Failed to start decomposition",
variant: "destructive",
});
}
},
[
defaultModel,
currentFile,
operations,
pipelineName,
sampleSize,
namespace,
apiKeys,
optimizerModel,
extraPipelineSettings,
startDecomposition,
setIsDecomposing,
setTerminalOutput,
toast,
]
);
// Register the decomposition handler in context so OperationCard can use it
// Using ref assignment instead of state to avoid infinite loops
onRequestDecompositionRef.current = handleDecompositionRequest;
return (
<div className="flex flex-col h-full">
<div
@ -1053,7 +1245,7 @@ const PipelineGUI: React.FC = () => {
variant="destructive"
className="rounded-sm whitespace-nowrap"
onClick={handleStop}
disabled={!isLoadingOutputs}
disabled={!isLoadingOutputs && !isDecomposing}
>
<StopCircle size={16} className="mr-2" />
Stop
@ -1144,11 +1336,22 @@ const PipelineGUI: React.FC = () => {
operationName={optimizationDialog.operationName}
inputData={optimizationDialog.inputData}
outputData={optimizationDialog.outputData}
numDocsAnalyzed={optimizationDialog.numDocsAnalyzed}
onOpenChange={(open) =>
setOptimizationDialog((prev) => ({ ...prev, isOpen: open }))
}
onDecompose={handleOptimizeFromDialog}
/>
<DecompositionComparisonDialog
isOpen={decomposeDialog.isOpen}
result={decomposeDialog.result}
operationName={decomposeDialog.operationName || undefined}
onOpenChange={(open) =>
setDecomposeDialog((prev) => ({ ...prev, isOpen: open }))
}
onApply={handleApplyDecomposition}
onCancel={handleCancelDecomposition}
/>
</div>
);
};

View File

@ -31,6 +31,7 @@ interface PromptInputProps {
onChange: (value: string) => void;
disableValidation?: boolean;
placeholder?: string;
operationType?: "map" | "reduce" | "filter" | "resolve" | "other";
}
export const PromptInput: React.FC<PromptInputProps> = React.memo(
@ -38,31 +39,62 @@ export const PromptInput: React.FC<PromptInputProps> = React.memo(
prompt,
onChange,
disableValidation = false,
placeholder = "Enter prompt (must be a Jinja2 template)",
placeholder = "Enter prompt (Jinja2 template optional)",
operationType = "map",
}) => {
const validateJinjaTemplate = (value: string) => {
const hasJinjaTemplate = (value: string) => {
if (disableValidation) return true;
const hasOpenBrace = value.includes("{{");
const hasCloseBrace = value.includes("}}");
return hasOpenBrace && hasCloseBrace;
};
const isReduce = operationType === "reduce";
const templateVar = isReduce ? "inputs" : "input";
const exampleKey = isReduce ? "inputs[0].content" : "input.content";
return (
<>
<Textarea
placeholder={placeholder}
className={`mb-1 rounded-sm text-sm font-mono ${
!validateJinjaTemplate(prompt) ? "border-red-500" : ""
!hasJinjaTemplate(prompt) ? "border-amber-400" : ""
}`}
rows={Math.max(3, Math.ceil(prompt.split("\n").length))}
value={prompt}
onChange={(e) => onChange(e.target.value)}
/>
{!validateJinjaTemplate(prompt) && (
<div className="text-red-500 text-sm mb-1">
Prompt must contain Jinja2 template syntax {"{"}
{"{"} and {"}"}
{"}"}
{!hasJinjaTemplate(prompt) && (
<div className="text-amber-600 text-sm mb-1 space-y-1">
<div>
<strong>Warning:</strong> No Jinja2 template found.
</div>
<div className="text-xs text-muted-foreground">
{isReduce ? (
<>
{"{"}
{"{"} inputs {"}"}
{"}"} (the list of all documents) will be auto-appended. For
example, if documents have keys {'"'}id{'"'} and {'"'}content
{'"'}, all docs with all keys will be included. To reference
specific keys, use a for loop: {"{"}% for doc in inputs %{"}"}{" "}
{"{"}
{"{"} doc.content {"}"}
{"}"} {"{"}% endfor %{"}"}.
</>
) : (
<>
{"{"}
{"{"} input {"}"}
{"}"} (the full document) will be auto-appended. For example,
if documents have keys {'"'}id{'"'}, {'"'}title{'"'}, and {'"'}
content{'"'}, all three will be included. To use only specific
keys (e.g., just {'"'}content{'"'}), add {"{"}
{"{"} input.content {"}"}
{"}"} to your prompt.
</>
)}
</div>
</div>
)}
</>
@ -77,6 +109,7 @@ PromptInput.propTypes = {
onChange: PropTypes.func.isRequired,
disableValidation: PropTypes.bool,
placeholder: PropTypes.string,
operationType: PropTypes.oneOf(["map", "reduce", "filter", "resolve", "other"]),
};
interface SchemaFormProps {

View File

@ -314,6 +314,7 @@ export const ReduceOperationComponent: React.FC<OperationComponentProps> = ({
<PromptInput
prompt={operation.prompt || ""}
onChange={handlePromptChange}
operationType="reduce"
/>
<OutputSchema
schema={schemaItems}

View File

@ -26,7 +26,7 @@ const ToastViewport = React.forwardRef<
ToastViewport.displayName = ToastPrimitives.Viewport.displayName;
const toastVariants = cva(
"group pointer-events-auto relative flex w-full items-center justify-between space-x-2 overflow-y-auto max-h-[80vh] whitespace-pre-wrap break-words rounded-md border p-4 pr-6 shadow-lg transition-all data-[swipe=cancel]:translate-x-0 data-[swipe=end]:translate-x-[var(--radix-toast-swipe-end-x)] data-[swipe=move]:translate-x-[var(--radix-toast-swipe-move-x)] data-[swipe=move]:transition-none data-[state=open]:animate-in data-[state=closed]:animate-out data-[swipe=end]:animate-out data-[state=closed]:fade-out-80 data-[state=closed]:slide-out-to-right-full data-[state=open]:slide-in-from-top-full data-[state=open]:sm:slide-in-from-bottom-full",
"group pointer-events-auto relative flex w-full items-start justify-between space-x-2 overflow-hidden max-h-[80vh] whitespace-pre-wrap break-words rounded-md border p-4 pr-6 shadow-lg transition-all data-[swipe=cancel]:translate-x-0 data-[swipe=end]:translate-x-[var(--radix-toast-swipe-end-x)] data-[swipe=move]:translate-x-[var(--radix-toast-swipe-move-x)] data-[swipe=move]:transition-none data-[state=open]:animate-in data-[state=closed]:animate-out data-[swipe=end]:animate-out data-[state=closed]:fade-out-80 data-[state=closed]:slide-out-to-right-full data-[state=open]:slide-in-from-top-full data-[state=open]:sm:slide-in-from-bottom-full",
{
variants: {
variant: {

View File

@ -22,29 +22,33 @@ export function Toaster() {
const descId = `toast-desc-${id}`;
return (
<Toast key={id} {...props}>
<div className="grid gap-1">
{title && <ToastTitle>{title}</ToastTitle>}
<div className="grid gap-1 flex-1 overflow-hidden">
<div className="flex items-center gap-2 shrink-0">
{title && <ToastTitle>{title}</ToastTitle>}
{description && (
<Button
variant="ghost"
size="sm"
className="h-5 w-5 p-0 opacity-70 hover:opacity-100"
onClick={() => {
const el = document.getElementById(descId);
const text = el?.textContent ?? "";
if (text) {
navigator.clipboard.writeText(text);
}
}}
title="Copy message"
>
<Copy size={12} />
</Button>
)}
</div>
{description && (
<ToastDescription id={descId}>{description}</ToastDescription>
<div className="overflow-y-auto max-h-[60vh]">
<ToastDescription id={descId}>{description}</ToastDescription>
</div>
)}
</div>
{description && (
<Button
variant="ghost"
size="sm"
className="h-6 px-2 text-xs"
onClick={() => {
const el = document.getElementById(descId);
const text = el?.textContent ?? "";
if (text) {
navigator.clipboard.writeText(text);
}
}}
title="Copy message"
>
<Copy size={12} />
</Button>
)}
{action}
<ToastClose />
</Toast>

View File

@ -29,6 +29,7 @@ interface PipelineState {
validatorPrompt: string;
} | null;
isLoadingOutputs: boolean;
isDecomposing: boolean;
numOpRun: number;
pipelineName: string;
sampleSize: number | null;
@ -59,6 +60,7 @@ interface PipelineContextType extends PipelineState {
} | null>
>;
setIsLoadingOutputs: React.Dispatch<React.SetStateAction<boolean>>;
setIsDecomposing: React.Dispatch<React.SetStateAction<boolean>>;
setNumOpRun: React.Dispatch<React.SetStateAction<number>>;
setPipelineName: React.Dispatch<React.SetStateAction<string>>;
setSampleSize: React.Dispatch<React.SetStateAction<number | null>>;
@ -83,6 +85,10 @@ interface PipelineContextType extends PipelineState {
setExtraPipelineSettings: React.Dispatch<
React.SetStateAction<Record<string, unknown> | null>
>;
// Ref for triggering decomposition from OperationCard (using ref to avoid infinite loops)
onRequestDecompositionRef: React.MutableRefObject<
((operationId: string, operationName: string) => void) | null
>;
}
const PipelineContext = createContext<PipelineContextType | undefined>(
@ -290,6 +296,7 @@ export const PipelineProvider: React.FC<{ children: React.ReactNode }> = ({
localStorageKeys.IS_LOADING_OUTPUTS_KEY,
false
),
isDecomposing: false, // Transient state, not persisted
numOpRun: loadFromLocalStorage(localStorageKeys.NUM_OP_RUN_KEY, 0),
pipelineName: loadFromLocalStorage(
localStorageKeys.PIPELINE_NAME_KEY,
@ -333,6 +340,11 @@ export const PipelineProvider: React.FC<{ children: React.ReactNode }> = ({
const stateRef = useRef(state);
const [isMounted, setIsMounted] = useState(false);
// Ref for triggering decomposition from OperationCard (using ref to avoid infinite loops)
const onRequestDecompositionRef = useRef<
((operationId: string, operationName: string) => void) | null
>(null);
useEffect(() => {
stateRef.current = state;
}, [state]);
@ -423,6 +435,7 @@ export const PipelineProvider: React.FC<{ children: React.ReactNode }> = ({
output: null,
terminalOutput: "",
isLoadingOutputs: false,
isDecomposing: false,
numOpRun: 0,
pipelineName: mockPipelineName,
sampleSize: mockSampleSize,
@ -529,6 +542,10 @@ export const PipelineProvider: React.FC<{ children: React.ReactNode }> = ({
(value) => setStateAndUpdate("isLoadingOutputs", value),
[setStateAndUpdate]
),
setIsDecomposing: useCallback(
(value) => setStateAndUpdate("isDecomposing", value),
[setStateAndUpdate]
),
setNumOpRun: useCallback(
(value) => setStateAndUpdate("numOpRun", value),
[setStateAndUpdate]
@ -589,6 +606,7 @@ export const PipelineProvider: React.FC<{ children: React.ReactNode }> = ({
(value) => setStateAndUpdate("extraPipelineSettings", value),
[setStateAndUpdate]
),
onRequestDecompositionRef,
};
return (

View File

@ -0,0 +1,104 @@
import { useState, useEffect } from "react";
import axios from "axios";
import { DecomposeResult, DecomposeRequest } from "@/app/types";
import { API_ROUTES } from "@/app/api/constants";
interface UseDecomposeProps {
onComplete?: (result: DecomposeResult) => void;
onError?: (error: string) => void;
pollInterval?: number;
}
export function useDecompose({
onComplete,
onError,
pollInterval = 2000,
}: UseDecomposeProps = {}) {
const [taskId, setTaskId] = useState<string | null>(null);
const [result, setResult] = useState<DecomposeResult | null>(null);
const [error, setError] = useState<string | null>(null);
const [isLoading, setIsLoading] = useState(false);
const submitTask = async (request: DecomposeRequest) => {
try {
setIsLoading(true);
setError(null);
setResult(null);
const response = await axios.post<{ task_id: string }>(
API_ROUTES.DECOMPOSE.SUBMIT,
request
);
setTaskId(response.data.task_id);
} catch (err) {
const errorMessage =
err instanceof Error ? err.message : "Failed to submit decompose task";
setError(errorMessage);
onError?.(errorMessage);
setIsLoading(false);
}
};
const cancelTask = async () => {
if (!taskId) return;
try {
await axios.post(API_ROUTES.DECOMPOSE.CANCEL(taskId));
setTaskId(null);
setIsLoading(false);
} catch (err) {
const errorMessage =
err instanceof Error ? err.message : "Failed to cancel task";
setError(errorMessage);
onError?.(errorMessage);
}
};
useEffect(() => {
if (!taskId) return;
const pollTask = async () => {
try {
const response = await axios.get<DecomposeResult>(
API_ROUTES.DECOMPOSE.STATUS(taskId)
);
setResult(response.data);
if (
["completed", "failed", "cancelled"].includes(response.data.status)
) {
setTaskId(null);
setIsLoading(false);
if (response.data.status === "completed") {
onComplete?.(response.data);
} else if (response.data.status === "failed" && response.data.error) {
setError(response.data.error);
onError?.(response.data.error);
}
}
} catch (err) {
const errorMessage =
err instanceof Error ? err.message : "Failed to fetch task status";
setError(errorMessage);
onError?.(errorMessage);
setTaskId(null);
setIsLoading(false);
}
};
const interval = setInterval(pollTask, pollInterval);
return () => clearInterval(interval);
}, [taskId, onComplete, onError, pollInterval]);
return {
submitTask,
cancelTask,
result,
error,
isLoading,
isRunning: !!taskId,
};
}

View File

@ -0,0 +1,145 @@
import { useState, useRef, useCallback } from "react";
import { DecomposeResult } from "@/app/types";
interface DecomposeWebSocketMessage {
type: "output" | "result" | "error";
data?: any;
message?: string;
traceback?: string;
}
interface UseDecomposeWebSocketProps {
namespace: string;
onOutput?: (output: string) => void;
onComplete?: (result: DecomposeResult) => void;
onError?: (error: string) => void;
}
export function useDecomposeWebSocket({
namespace,
onOutput,
onComplete,
onError,
}: UseDecomposeWebSocketProps) {
const [isConnected, setIsConnected] = useState(false);
const [isRunning, setIsRunning] = useState(false);
const ws = useRef<WebSocket | null>(null);
const connect = useCallback((): Promise<void> => {
return new Promise((resolve, reject) => {
if (ws.current?.readyState === WebSocket.OPEN) {
resolve();
return;
}
if (!namespace) {
reject(new Error("Namespace is required for WebSocket connection"));
return;
}
const wsUrl = `${
process.env.NEXT_PUBLIC_BACKEND_HTTPS === "true" ? "wss://" : "ws://"
}${process.env.NEXT_PUBLIC_BACKEND_HOST}:${
process.env.NEXT_PUBLIC_BACKEND_PORT
}/ws/decompose/${namespace}`;
ws.current = new WebSocket(wsUrl);
ws.current.onopen = () => {
setIsConnected(true);
resolve();
};
ws.current.onclose = () => {
setIsConnected(false);
setIsRunning(false);
};
ws.current.onerror = (error: Event) => {
console.error("Decompose WebSocket error:", error);
onError?.("WebSocket connection failed");
reject(new Error("WebSocket connection failed"));
};
ws.current.onmessage = (event) => {
try {
const message: DecomposeWebSocketMessage = JSON.parse(event.data);
if (message.type === "output") {
onOutput?.(message.data);
} else if (message.type === "result") {
setIsRunning(false);
// Convert the result data to DecomposeResult format
const result: DecomposeResult = {
task_id: "", // Not used in WebSocket mode
status: "completed",
decomposed_operations: message.data.decomposed_operations,
winning_directive: message.data.winning_directive,
candidates_evaluated: message.data.candidates_evaluated,
original_outputs: message.data.original_outputs,
decomposed_outputs: message.data.decomposed_outputs,
comparison_rationale: message.data.comparison_rationale,
cost: message.data.cost,
created_at: new Date().toISOString(),
completed_at: new Date().toISOString(),
};
onComplete?.(result);
// Close the connection after receiving result
ws.current?.close();
} else if (message.type === "error") {
setIsRunning(false);
const errorMsg = message.data || message.message || "Unknown error";
onError?.(errorMsg);
ws.current?.close();
}
} catch (error) {
console.error("Error parsing WebSocket message:", error);
onError?.("Failed to parse WebSocket message");
}
};
});
}, [namespace, onOutput, onComplete, onError]);
const startDecomposition = useCallback(
async (yamlConfig: string, stepName: string, opName: string) => {
try {
await connect();
setIsRunning(true);
if (ws.current?.readyState === WebSocket.OPEN) {
ws.current.send(
JSON.stringify({
yaml_config: yamlConfig,
step_name: stepName,
op_name: opName,
})
);
}
} catch (error) {
setIsRunning(false);
throw error;
}
},
[connect]
);
const cancel = useCallback(() => {
if (ws.current?.readyState === WebSocket.OPEN) {
ws.current.send(JSON.stringify("kill"));
}
}, []);
const disconnect = useCallback(() => {
if (ws.current) {
ws.current.close();
}
}, []);
return {
isConnected,
isRunning,
startDecomposition,
cancel,
disconnect,
};
}

View File

@ -6,7 +6,8 @@ export function cn(...inputs: ClassValue[]) {
}
export function canBeOptimized(operationType: string) {
return ["resolve", "map", "reduce", "filter"].includes(operationType);
// Only map operations can be decomposed
return operationType === "map";
}
export const generateId = () => {