Compare commits
3 Commits
main
...
dwoptimize
| Author | SHA1 | Date |
|---|---|---|
|
|
4e9f4532ad | |
|
|
ed2f1fd0a1 | |
|
|
c0445abba9 |
|
|
@ -1,11 +1,15 @@
|
|||
__version__ = "0.2.6"
|
||||
|
||||
import warnings
|
||||
import litellm
|
||||
|
||||
from docetl.runner import DSLRunner
|
||||
from docetl.optimizer import Optimizer
|
||||
from docetl.apis.pd_accessors import SemanticAccessor
|
||||
|
||||
# Drop unsupported params for models like gpt-5 that don't support temperature=0
|
||||
litellm.drop_params = True
|
||||
|
||||
# TODO: Remove after https://github.com/BerriAI/litellm/issues/7560 is fixed
|
||||
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._config")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from io import StringIO
|
||||
|
|
@ -12,6 +13,63 @@ from docetl.utils import StageType, get_stage_description
|
|||
|
||||
install(show_locals=False)
|
||||
|
||||
# ANSI escape sequences to strip (cursor movement, line clearing, cursor visibility)
|
||||
ANSI_CURSOR_PATTERNS = re.compile(
|
||||
r"\x1b\[\?25[lh]" # Hide/show cursor
|
||||
r"|\x1b\[\d*[ABCD]" # Cursor up/down/forward/back
|
||||
r"|\x1b\[\d*[JK]" # Clear screen/line (but we handle \x1b[2K specially)
|
||||
r"|\x1b\[s|\x1b\[u" # Save/restore cursor position
|
||||
)
|
||||
|
||||
|
||||
def process_carriage_returns(text: str) -> str:
|
||||
"""
|
||||
Process terminal control sequences for buffer-captured output.
|
||||
|
||||
Rich's spinner uses ANSI sequences for:
|
||||
- `\r` - carriage return (move to start of line)
|
||||
- `\x1b[2K` - clear entire line
|
||||
- `\x1b[?25l/h` - hide/show cursor
|
||||
- `\x1b[1A` - move cursor up
|
||||
|
||||
When captured to a buffer (not a real terminal), we simulate this by:
|
||||
1. Handling `\r` as "replace from start of line"
|
||||
2. Stripping cursor movement/visibility sequences
|
||||
3. Keeping only meaningful content
|
||||
"""
|
||||
# First, strip cursor visibility and movement sequences
|
||||
text = ANSI_CURSOR_PATTERNS.sub("", text)
|
||||
|
||||
# Process line by line
|
||||
lines = text.split("\n")
|
||||
processed_lines = []
|
||||
|
||||
for line in lines:
|
||||
# Handle carriage returns - keep only content after the last \r
|
||||
if "\r" in line:
|
||||
segments = line.split("\r")
|
||||
# Keep the last non-empty segment (after stripping ANSI clear codes)
|
||||
for segment in reversed(segments):
|
||||
# Strip the clear line code if present
|
||||
cleaned = segment.replace("\x1b[2K", "").strip()
|
||||
if cleaned:
|
||||
# Keep the segment but preserve any color codes
|
||||
processed_lines.append(segment.replace("\x1b[2K", ""))
|
||||
break
|
||||
else:
|
||||
# All segments were empty, skip this line entirely
|
||||
pass
|
||||
else:
|
||||
# No carriage return, keep the line as-is
|
||||
if line.strip() or line == "": # Keep empty lines between content
|
||||
processed_lines.append(line)
|
||||
|
||||
# Remove trailing empty lines
|
||||
while processed_lines and not processed_lines[-1].strip():
|
||||
processed_lines.pop()
|
||||
|
||||
return "\n".join(processed_lines)
|
||||
|
||||
|
||||
class ThreadSafeConsole(Console):
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
|
@ -24,11 +82,18 @@ class ThreadSafeConsole(Console):
|
|||
self.optimizer_rationale = None
|
||||
|
||||
def get_output(self):
|
||||
# return self.export_text(styles=True)
|
||||
"""
|
||||
Get the output from the buffer, processing carriage returns.
|
||||
|
||||
Rich's spinner uses carriage returns to overwrite lines in place.
|
||||
We process these to only keep the latest content, preventing
|
||||
duplicate spinner frames from flooding the output.
|
||||
"""
|
||||
value = self.buffer.getvalue()
|
||||
self.buffer.truncate(0)
|
||||
self.buffer.seek(0)
|
||||
return value
|
||||
# Process carriage returns to handle spinner overwrites
|
||||
return process_carriage_returns(value)
|
||||
|
||||
def status(
|
||||
self,
|
||||
|
|
@ -36,10 +101,15 @@ class ThreadSafeConsole(Console):
|
|||
*,
|
||||
spinner: str = "dots",
|
||||
spinner_style: "StyleType" = "status.spinner",
|
||||
speed: float = 0.1, # Much slower speed
|
||||
refresh_per_second: float = 0.5, # Much slower refresh rate (every 2 seconds)
|
||||
speed: float = 1.0,
|
||||
refresh_per_second: float = 4,
|
||||
) -> "Status":
|
||||
"""
|
||||
Return a Rich Status with animation.
|
||||
|
||||
The carriage returns from the spinner animation are processed
|
||||
in get_output() to prevent duplicate lines.
|
||||
"""
|
||||
status_renderable = Status(
|
||||
status,
|
||||
console=self,
|
||||
|
|
|
|||
|
|
@ -849,13 +849,9 @@ class MOARSearch:
|
|||
response = litellm.completion(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=ExpandResponseFormat,
|
||||
)
|
||||
call_cost = response._hidden_params["response_cost"]
|
||||
call_cost = response._hidden_params.get("response_cost", 0)
|
||||
with self.tree_lock:
|
||||
self.console.log(
|
||||
f"[green]💰 Adding LLM call cost:[/green] ${call_cost:.4f} (total before: ${self.total_search_cost:.4f})"
|
||||
|
|
|
|||
|
|
@ -153,10 +153,11 @@ def run_moar_optimization(
|
|||
"optimizer_config must contain 'max_iterations' (number of MOARSearch iterations) for MOAR optimizer"
|
||||
)
|
||||
|
||||
model = optimizer_config.get("model")
|
||||
# Use rewrite_agent_model for consistency with rest of codebase, fallback to model for backwards compatibility
|
||||
model = optimizer_config.get("rewrite_agent_model") or optimizer_config.get("model")
|
||||
if not model:
|
||||
raise ValueError(
|
||||
"optimizer_config must contain 'model' (LLM model name for directive instantiation) for MOAR optimizer"
|
||||
"optimizer_config must contain 'rewrite_agent_model' (LLM model name for directive instantiation) for MOAR optimizer"
|
||||
)
|
||||
|
||||
# Optional parameters
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ class Optimizer:
|
|||
def __init__(
|
||||
self,
|
||||
runner: "DSLRunner",
|
||||
rewrite_agent_model: str = "gpt-4o",
|
||||
rewrite_agent_model: str = "gpt-5.1",
|
||||
judge_agent_model: str = "gpt-4o-mini",
|
||||
litellm_kwargs: dict[str, Any] = {},
|
||||
resume: bool = False,
|
||||
|
|
@ -67,7 +67,7 @@ class Optimizer:
|
|||
|
||||
Args:
|
||||
yaml_file (str): Path to the YAML configuration file.
|
||||
model (str): The name of the language model to use. Defaults to "gpt-4o".
|
||||
model (str): The name of the language model to use. Defaults to "gpt-5.1".
|
||||
resume (bool): Whether to resume optimization from a previous run. Defaults to False.
|
||||
timeout (int): Timeout in seconds for operations. Defaults to 60.
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,926 @@
|
|||
"""
|
||||
Fast decomposition analyzer for map operations.
|
||||
|
||||
This module provides a fast way to decompose map operations using directives,
|
||||
running candidates on samples, and selecting the best via pairwise LLM comparison.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
from litellm import completion, model_cost
|
||||
from rich.console import Console
|
||||
|
||||
from docetl.console import get_console
|
||||
from docetl.reasoning_optimizer.directives import (
|
||||
ChainingDirective,
|
||||
ClarifyInstructionsDirective,
|
||||
DeterministicDocCompressionDirective,
|
||||
DocumentChunkingDirective,
|
||||
GleaningDirective,
|
||||
IsolatingSubtasksDirective,
|
||||
)
|
||||
from docetl.runner import DSLRunner
|
||||
from docetl.utils import completion_cost, count_tokens
|
||||
|
||||
# Drop unsupported params for models like gpt-5 that don't support temperature=0
|
||||
litellm.drop_params = True
|
||||
|
||||
# Base directives always applicable to map operations
|
||||
BASE_MAP_DIRECTIVES = [
|
||||
ChainingDirective(),
|
||||
IsolatingSubtasksDirective(),
|
||||
GleaningDirective(),
|
||||
ClarifyInstructionsDirective(),
|
||||
]
|
||||
|
||||
# Directives that depend on document size
|
||||
CHUNKING_DIRECTIVE = DocumentChunkingDirective()
|
||||
COMPRESSION_DIRECTIVE = DeterministicDocCompressionDirective()
|
||||
|
||||
# Threshold for enabling DeterministicDocCompression (in characters)
|
||||
DOC_COMPRESSION_CHAR_THRESHOLD = 1000
|
||||
|
||||
# Threshold for enabling DocumentChunking (as fraction of context window)
|
||||
DOC_CHUNKING_CONTEXT_THRESHOLD = 0.10 # 10%
|
||||
|
||||
|
||||
class FastDecomposer:
|
||||
"""
|
||||
Fast decomposition of map operations using directives and pairwise comparison.
|
||||
|
||||
Instead of the full optimizer flow, this:
|
||||
1. Tries multiple directives to generate candidate decompositions
|
||||
2. Runs each candidate on sample documents
|
||||
3. Uses LLM judge for pairwise comparison
|
||||
4. Returns the winning decomposition
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
yaml_config_path: str,
|
||||
optimizer_model: str = "gpt-5.1",
|
||||
sample_size: int = 5,
|
||||
litellm_kwargs: dict[str, Any] | None = None,
|
||||
console: Console | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the decomposer.
|
||||
|
||||
Args:
|
||||
yaml_config_path: Path to the pipeline YAML config file
|
||||
optimizer_model: LLM model to use for directive instantiation and judging
|
||||
sample_size: Number of sample documents to run candidates on
|
||||
litellm_kwargs: Additional kwargs to pass to litellm.completion
|
||||
console: Rich console for output (uses default if not provided)
|
||||
"""
|
||||
self.yaml_config_path = yaml_config_path
|
||||
self.optimizer_model = optimizer_model
|
||||
self.sample_size = sample_size
|
||||
self.litellm_kwargs = litellm_kwargs or {}
|
||||
if "temperature" not in self.litellm_kwargs:
|
||||
self.litellm_kwargs["temperature"] = 0.0
|
||||
|
||||
self.total_cost = 0.0
|
||||
self.console = console or get_console()
|
||||
|
||||
# Load the config
|
||||
import yaml
|
||||
|
||||
with open(yaml_config_path, "r") as f:
|
||||
self.config = yaml.safe_load(f)
|
||||
|
||||
self.operators = self.config.get("operations", [])
|
||||
self.default_model = self.config.get("default_model", "gpt-4o-mini")
|
||||
self.intermediate_dir = (
|
||||
self.config.get("pipeline", {}).get("output", {}).get("intermediate_dir")
|
||||
)
|
||||
|
||||
def _log(self, message: str) -> None:
|
||||
"""Log a message to the console."""
|
||||
self.console.log(message)
|
||||
|
||||
def get_model_context_limit(self, model: str) -> int:
|
||||
"""
|
||||
Get the context window limit for a model.
|
||||
|
||||
Args:
|
||||
model: The model name (e.g., 'gpt-4o', 'azure/gpt-4')
|
||||
|
||||
Returns:
|
||||
Maximum number of input tokens the model can handle
|
||||
"""
|
||||
model_info = model_cost.get(model, {})
|
||||
# Try without provider prefix if not found
|
||||
if not model_info:
|
||||
model_name = model.split("/")[-1]
|
||||
model_info = model_cost.get(model_name, {})
|
||||
return model_info.get("max_input_tokens", 128000) # Default to 128k
|
||||
|
||||
def get_avg_doc_size(
|
||||
self, sample_data: list[dict[str, Any]], op_config: dict[str, Any]
|
||||
) -> tuple[float, float]:
|
||||
"""
|
||||
Calculate the average document size in characters and tokens.
|
||||
|
||||
Extracts the document content from sample data based on the operation's
|
||||
prompt template (looks for {{ input.field_name }} patterns).
|
||||
|
||||
Args:
|
||||
sample_data: List of sample documents
|
||||
op_config: The operation configuration
|
||||
|
||||
Returns:
|
||||
Tuple of (avg_chars, avg_tokens) for the document content
|
||||
"""
|
||||
import re
|
||||
|
||||
if not sample_data:
|
||||
return 0.0, 0.0
|
||||
|
||||
prompt = op_config.get("prompt", "")
|
||||
model = op_config.get("model", self.default_model)
|
||||
|
||||
# Extract field names from prompt template ({{ input.field_name }})
|
||||
field_pattern = r"\{\{\s*input\.(\w+)\s*\}\}"
|
||||
fields = re.findall(field_pattern, prompt)
|
||||
|
||||
if not fields:
|
||||
# Fallback: use all string values from the first document
|
||||
if sample_data:
|
||||
fields = [
|
||||
k
|
||||
for k, v in sample_data[0].items()
|
||||
if isinstance(v, str) and len(v) > 100
|
||||
]
|
||||
|
||||
total_chars = 0
|
||||
total_tokens = 0
|
||||
|
||||
for doc in sample_data:
|
||||
doc_content = ""
|
||||
for field in fields:
|
||||
if field in doc:
|
||||
value = doc[field]
|
||||
if isinstance(value, str):
|
||||
doc_content += value
|
||||
else:
|
||||
doc_content += str(value)
|
||||
|
||||
total_chars += len(doc_content)
|
||||
if doc_content:
|
||||
total_tokens += count_tokens(doc_content, model)
|
||||
|
||||
n = len(sample_data)
|
||||
return total_chars / n if n > 0 else 0.0, total_tokens / n if n > 0 else 0.0
|
||||
|
||||
def get_applicable_directives(
|
||||
self,
|
||||
sample_data: list[dict[str, Any]],
|
||||
op_config: dict[str, Any],
|
||||
) -> list:
|
||||
"""
|
||||
Get the list of directives applicable based on data characteristics.
|
||||
|
||||
- DocumentChunkingDirective: only if avg doc size > 10% of context window
|
||||
- DeterministicDocCompressionDirective: only if avg doc size > 1000 chars
|
||||
|
||||
Args:
|
||||
sample_data: List of sample documents
|
||||
op_config: The operation configuration
|
||||
|
||||
Returns:
|
||||
List of applicable directive instances
|
||||
"""
|
||||
directives = [] # Build list with priority ordering
|
||||
|
||||
model = op_config.get("model", self.default_model)
|
||||
context_limit = self.get_model_context_limit(model)
|
||||
avg_chars, avg_tokens = self.get_avg_doc_size(sample_data, op_config)
|
||||
|
||||
self._log(
|
||||
f"Document analysis: avg_chars={avg_chars:.0f}, avg_tokens={avg_tokens:.0f}, "
|
||||
f"context_limit={context_limit}"
|
||||
)
|
||||
|
||||
# Add DeterministicDocCompression FIRST if doc size > 1000 chars (high priority)
|
||||
if avg_chars > DOC_COMPRESSION_CHAR_THRESHOLD:
|
||||
self._log(
|
||||
f"[cyan]Enabling DeterministicDocCompression (priority)[/cyan] "
|
||||
f"(avg {avg_chars:.0f} chars > {DOC_COMPRESSION_CHAR_THRESHOLD})"
|
||||
)
|
||||
directives.append(COMPRESSION_DIRECTIVE)
|
||||
|
||||
# Add base directives
|
||||
directives.extend(BASE_MAP_DIRECTIVES)
|
||||
|
||||
# Add DocumentChunking if avg tokens > 10% of context window
|
||||
token_threshold = context_limit * DOC_CHUNKING_CONTEXT_THRESHOLD
|
||||
if avg_tokens > token_threshold:
|
||||
self._log(
|
||||
f"[cyan]Enabling DocumentChunking[/cyan] "
|
||||
f"(avg {avg_tokens:.0f} tokens > {token_threshold:.0f} = 10% of {context_limit})"
|
||||
)
|
||||
directives.append(CHUNKING_DIRECTIVE)
|
||||
else:
|
||||
self._log(
|
||||
f"[dim]Skipping DocumentChunking[/dim] "
|
||||
f"(avg {avg_tokens:.0f} tokens <= {token_threshold:.0f} = 10% of {context_limit})"
|
||||
)
|
||||
|
||||
return directives
|
||||
|
||||
def load_sample_data(self, step_name: str, op_name: str) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Load sample input data for an operation.
|
||||
|
||||
For the first operation, loads from the dataset.
|
||||
For subsequent operations, loads from the previous operation's intermediate output.
|
||||
|
||||
Args:
|
||||
step_name: Name of the pipeline step
|
||||
op_name: Name of the operation
|
||||
|
||||
Returns:
|
||||
List of sample documents
|
||||
"""
|
||||
# Find the operation's position
|
||||
op_names = [op.get("name") for op in self.operators]
|
||||
try:
|
||||
op_idx = op_names.index(op_name)
|
||||
except ValueError:
|
||||
raise ValueError(f"Operation '{op_name}' not found in config")
|
||||
|
||||
if op_idx == 0:
|
||||
# First operation - load from dataset
|
||||
datasets = self.config.get("datasets", {})
|
||||
# Get the first dataset (or the one used by this step)
|
||||
for dataset_name, dataset_config in datasets.items():
|
||||
dataset_path = dataset_config.get("path")
|
||||
if dataset_path and os.path.exists(dataset_path):
|
||||
with open(dataset_path, "r") as f:
|
||||
data = json.load(f)
|
||||
return data[: self.sample_size]
|
||||
raise FileNotFoundError("No dataset found in config")
|
||||
else:
|
||||
# Load from previous operation's intermediate output
|
||||
prev_op_name = op_names[op_idx - 1]
|
||||
output_path = os.path.join(
|
||||
self.intermediate_dir, step_name, f"{prev_op_name}.json"
|
||||
)
|
||||
if not os.path.exists(output_path):
|
||||
raise FileNotFoundError(
|
||||
f"No intermediate output found at {output_path}. "
|
||||
"Run the previous operation first."
|
||||
)
|
||||
with open(output_path, "r") as f:
|
||||
data = json.load(f)
|
||||
return data[: self.sample_size]
|
||||
|
||||
def get_input_file_path(self) -> str:
|
||||
"""Get the input file path for the pipeline."""
|
||||
datasets = self.config.get("datasets", {})
|
||||
for dataset_name, dataset_config in datasets.items():
|
||||
path = dataset_config.get("path")
|
||||
if path:
|
||||
return path
|
||||
return ""
|
||||
|
||||
def generate_candidates(
|
||||
self,
|
||||
op_name: str,
|
||||
sample_data: list[dict[str, Any]],
|
||||
target_op: dict[str, Any],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Generate candidate decompositions using directives.
|
||||
|
||||
Directives are selected based on data characteristics:
|
||||
- DocumentChunkingDirective: only if avg doc size > 10% of context window
|
||||
- DeterministicDocCompressionDirective: only if avg doc size > 1000 chars
|
||||
|
||||
Args:
|
||||
op_name: Name of the operation to decompose
|
||||
sample_data: Sample data for analyzing document characteristics
|
||||
target_op: The target operation configuration
|
||||
|
||||
Returns:
|
||||
List of candidate dictionaries with 'name', 'ops', 'cost' keys
|
||||
"""
|
||||
candidates = []
|
||||
|
||||
# Add original as baseline
|
||||
self._log("Adding original operation as baseline candidate")
|
||||
candidates.append(
|
||||
{
|
||||
"name": "original",
|
||||
"ops": deepcopy(self.operators),
|
||||
"cost": 0.0,
|
||||
"error": None,
|
||||
}
|
||||
)
|
||||
|
||||
input_file_path = self.get_input_file_path()
|
||||
|
||||
# Get applicable directives based on data characteristics
|
||||
applicable_directives = self.get_applicable_directives(sample_data, target_op)
|
||||
|
||||
self._log(
|
||||
f"Generating candidates using {len(applicable_directives)} directives..."
|
||||
)
|
||||
for i, directive in enumerate(applicable_directives, 1):
|
||||
with self.console.status(
|
||||
f"[bold cyan]({i}/{len(applicable_directives)}) Trying directive: {directive.name}...[/bold cyan]",
|
||||
spinner="dots",
|
||||
):
|
||||
try:
|
||||
new_ops_list, _, cost = directive.instantiate(
|
||||
operators=deepcopy(self.operators),
|
||||
target_ops=[op_name],
|
||||
agent_llm=self.optimizer_model,
|
||||
message_history=[],
|
||||
global_default_model=self.default_model,
|
||||
input_file_path=input_file_path,
|
||||
)
|
||||
self.total_cost += cost
|
||||
self._log(
|
||||
f" [green]✓[/green] {directive.name} generated {len(new_ops_list)} operations (cost: ${cost:.4f})"
|
||||
)
|
||||
candidates.append(
|
||||
{
|
||||
"name": directive.name,
|
||||
"ops": new_ops_list,
|
||||
"cost": cost,
|
||||
"error": None,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
# Directive not applicable or failed - skip it
|
||||
self._log(f" [red]✗[/red] {directive.name} failed: {str(e)}")
|
||||
candidates.append(
|
||||
{
|
||||
"name": directive.name,
|
||||
"ops": None,
|
||||
"cost": 0.0,
|
||||
"error": str(e),
|
||||
}
|
||||
)
|
||||
|
||||
return candidates
|
||||
|
||||
def extract_ops_to_run(
|
||||
self, ops_list: list[dict], original_op_name: str
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Extract the operations that replaced the original operation.
|
||||
|
||||
Args:
|
||||
ops_list: The full transformed operations list
|
||||
original_op_name: Name of the original operation that was decomposed
|
||||
|
||||
Returns:
|
||||
List of operations that should be run on samples
|
||||
"""
|
||||
# Find ops that are new (not in original) or modified
|
||||
original_names = {op["name"] for op in self.operators}
|
||||
|
||||
# Find the position where the original op was
|
||||
original_idx = None
|
||||
for i, op in enumerate(self.operators):
|
||||
if op["name"] == original_op_name:
|
||||
original_idx = i
|
||||
break
|
||||
|
||||
if original_idx is None:
|
||||
return ops_list
|
||||
|
||||
# Find new ops (those not in original list)
|
||||
new_ops = []
|
||||
for op in ops_list:
|
||||
if op["name"] not in original_names or op["name"] == original_op_name:
|
||||
new_ops.append(op)
|
||||
|
||||
return new_ops if new_ops else [ops_list[original_idx]]
|
||||
|
||||
def run_candidate_on_samples(
|
||||
self,
|
||||
candidate: dict[str, Any],
|
||||
sample_data: list[dict[str, Any]],
|
||||
original_op_name: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Run a candidate's operations on sample data.
|
||||
|
||||
Args:
|
||||
candidate: Candidate dictionary with 'ops' key
|
||||
sample_data: List of sample documents
|
||||
original_op_name: Name of the original operation
|
||||
|
||||
Returns:
|
||||
List of output documents
|
||||
"""
|
||||
if candidate["ops"] is None:
|
||||
return []
|
||||
|
||||
# Extract ops to run
|
||||
ops_to_run = self.extract_ops_to_run(candidate["ops"], original_op_name)
|
||||
|
||||
if not ops_to_run:
|
||||
return []
|
||||
|
||||
# Create a minimal config for running these ops
|
||||
# Write sample data to a temp file
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||
json.dump(sample_data, f)
|
||||
temp_input_path = f.name
|
||||
|
||||
# Create a temp output file (required by DSLRunner validation)
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||
temp_output_path = f.name
|
||||
|
||||
try:
|
||||
# Create a minimal pipeline config
|
||||
temp_config = {
|
||||
"default_model": self.default_model,
|
||||
"operations": ops_to_run,
|
||||
"datasets": {
|
||||
"sample_data": {
|
||||
"type": "file",
|
||||
"path": temp_input_path,
|
||||
}
|
||||
},
|
||||
"pipeline": {
|
||||
"steps": [
|
||||
{
|
||||
"name": "decompose_test",
|
||||
"input": "sample_data",
|
||||
"operations": [op["name"] for op in ops_to_run],
|
||||
}
|
||||
],
|
||||
"output": {
|
||||
"type": "file",
|
||||
"path": temp_output_path,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# Create runner and execute
|
||||
runner = DSLRunner(temp_config, max_threads=4)
|
||||
|
||||
# Run operations sequentially on the data
|
||||
current_data = sample_data
|
||||
for op_config in ops_to_run:
|
||||
current_data = runner._run_operation(op_config, current_data)
|
||||
|
||||
self.total_cost += runner.total_cost
|
||||
|
||||
return current_data
|
||||
|
||||
finally:
|
||||
# Clean up temp files
|
||||
for temp_path in [temp_input_path, temp_output_path]:
|
||||
try:
|
||||
os.unlink(temp_path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def pairwise_compare(
|
||||
self,
|
||||
candidate_a: dict[str, Any],
|
||||
candidate_b: dict[str, Any],
|
||||
original_prompt: str,
|
||||
output_schema: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Compare two candidates using LLM judge.
|
||||
|
||||
Args:
|
||||
candidate_a: First candidate with 'name', 'outputs' keys
|
||||
candidate_b: Second candidate with 'name', 'outputs' keys
|
||||
original_prompt: The original operation's prompt
|
||||
output_schema: The expected output schema
|
||||
|
||||
Returns:
|
||||
The winning candidate
|
||||
"""
|
||||
# If one has no outputs, the other wins
|
||||
if not candidate_a.get("outputs"):
|
||||
return candidate_b
|
||||
if not candidate_b.get("outputs"):
|
||||
return candidate_a
|
||||
|
||||
system_prompt = """You are an expert judge comparing outputs from two data processing pipeline variants.
|
||||
|
||||
Your task is to determine which variant produces BETTER outputs based on:
|
||||
1. **Completeness**: Does the output contain all required information?
|
||||
2. **Accuracy**: Is the extracted/generated information correct?
|
||||
3. **Consistency**: Are the outputs consistent across different samples?
|
||||
4. **Quality**: Is the output well-structured and useful?
|
||||
|
||||
Be objective and focus on the actual output quality, not the approach used."""
|
||||
|
||||
user_prompt = f"""Compare outputs from two pipeline variants for this task:
|
||||
|
||||
## Original Task
|
||||
**Prompt:**
|
||||
```
|
||||
{original_prompt[:2000]}{"..." if len(original_prompt) > 2000 else ""}
|
||||
```
|
||||
|
||||
**Expected Output Schema:**
|
||||
```json
|
||||
{json.dumps(output_schema, indent=2)}
|
||||
```
|
||||
|
||||
## Variant A: {candidate_a["name"]}
|
||||
**Sample Outputs:**
|
||||
```json
|
||||
{json.dumps(candidate_a["outputs"][:3], indent=2, default=str)}
|
||||
```
|
||||
|
||||
## Variant B: {candidate_b["name"]}
|
||||
**Sample Outputs:**
|
||||
```json
|
||||
{json.dumps(candidate_b["outputs"][:3], indent=2, default=str)}
|
||||
```
|
||||
|
||||
## Your Task
|
||||
Which variant produces better outputs? Consider completeness, accuracy, consistency, and quality.
|
||||
Respond with your analysis and final choice."""
|
||||
|
||||
response = completion(
|
||||
model=self.optimizer_model,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "comparison_result",
|
||||
"strict": True,
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"winner": {
|
||||
"type": "string",
|
||||
"enum": ["A", "B"],
|
||||
"description": "Which variant is better: A or B",
|
||||
},
|
||||
"rationale": {
|
||||
"type": "string",
|
||||
"description": "Explanation of why this variant is better",
|
||||
},
|
||||
"a_strengths": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Strengths of variant A",
|
||||
},
|
||||
"b_strengths": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Strengths of variant B",
|
||||
},
|
||||
},
|
||||
"required": [
|
||||
"winner",
|
||||
"rationale",
|
||||
"a_strengths",
|
||||
"b_strengths",
|
||||
],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
},
|
||||
**self.litellm_kwargs,
|
||||
)
|
||||
|
||||
cost = completion_cost(response)
|
||||
self.total_cost += cost
|
||||
|
||||
result = json.loads(response.choices[0].message.content)
|
||||
|
||||
winner = candidate_a if result["winner"] == "A" else candidate_b
|
||||
winner["comparison_rationale"] = result["rationale"]
|
||||
|
||||
return winner
|
||||
|
||||
def decompose(
|
||||
self,
|
||||
step_name: str,
|
||||
op_name: str,
|
||||
) -> tuple[list[dict[str, Any]], str, int, float]:
|
||||
"""
|
||||
Decompose an operation using fast directive-based approach.
|
||||
|
||||
This is the main entry point. It:
|
||||
1. Generates candidate decompositions using directives
|
||||
2. Runs each on sample data
|
||||
3. Uses pairwise LLM comparison to select the best
|
||||
4. Returns the winning decomposition
|
||||
|
||||
Args:
|
||||
step_name: Name of the pipeline step
|
||||
op_name: Name of the operation to decompose
|
||||
|
||||
Returns:
|
||||
Dict with:
|
||||
- decomposed_ops: List of operations that replace the original
|
||||
- winning_directive: Name of the winning directive
|
||||
- candidates_evaluated: Number of candidates that were compared
|
||||
- original_outputs: Sample outputs from the original operation
|
||||
- decomposed_outputs: Sample outputs from the winning decomposition
|
||||
- comparison_rationale: LLM's explanation of why the winner was chosen
|
||||
- cost: Total LLM API cost
|
||||
"""
|
||||
self.console.rule("[bold blue]Fast Decomposition[/bold blue]")
|
||||
self._log(f"[bold]Target operation:[/bold] {op_name}")
|
||||
|
||||
# Find the target operation config
|
||||
target_op = None
|
||||
for op in self.operators:
|
||||
if op["name"] == op_name:
|
||||
target_op = op
|
||||
break
|
||||
|
||||
if target_op is None:
|
||||
raise ValueError(f"Operation '{op_name}' not found in config")
|
||||
|
||||
# Verify it's a map operation
|
||||
if target_op.get("type") != "map":
|
||||
raise ValueError(
|
||||
f"Operation '{op_name}' is type '{target_op.get('type')}', "
|
||||
"but fast decomposition only supports 'map' operations"
|
||||
)
|
||||
|
||||
# Load sample data
|
||||
self._log(f"Loading sample data from step '{step_name}'...")
|
||||
sample_data = self.load_sample_data(step_name, op_name)
|
||||
self._log(f"[green]✓[/green] Loaded {len(sample_data)} sample documents")
|
||||
|
||||
# Generate candidates
|
||||
self.console.rule("[bold]Generating Candidates[/bold]")
|
||||
candidates = self.generate_candidates(op_name, sample_data, target_op)
|
||||
|
||||
# Filter out failed candidates
|
||||
valid_candidates = [c for c in candidates if c["ops"] is not None]
|
||||
self._log(f"[bold]{len(valid_candidates)}[/bold] valid candidates generated")
|
||||
|
||||
if len(valid_candidates) < 2:
|
||||
# Only original (or nothing) - return original
|
||||
self._log(
|
||||
"[yellow]No alternative decompositions generated. Keeping original.[/yellow]"
|
||||
)
|
||||
return {
|
||||
"decomposed_ops": self.operators,
|
||||
"winning_directive": "original",
|
||||
"candidates_evaluated": len(valid_candidates),
|
||||
"original_outputs": [],
|
||||
"decomposed_outputs": [],
|
||||
"comparison_rationale": "No alternative decompositions were generated.",
|
||||
"cost": self.total_cost,
|
||||
}
|
||||
|
||||
# Run each candidate on samples IN PARALLEL
|
||||
self.console.rule("[bold]Running Candidates on Samples[/bold]")
|
||||
self._log(f"Running {len(valid_candidates)} candidates in parallel...")
|
||||
|
||||
def run_single_candidate(candidate):
|
||||
"""Run a single candidate and return results."""
|
||||
name = candidate["name"]
|
||||
try:
|
||||
outputs = self.run_candidate_on_samples(candidate, sample_data, op_name)
|
||||
return {"name": name, "outputs": outputs, "error": None}
|
||||
except Exception as e:
|
||||
error_tb = traceback.format_exc()
|
||||
return {
|
||||
"name": name,
|
||||
"outputs": [],
|
||||
"error": str(e),
|
||||
"traceback": error_tb,
|
||||
}
|
||||
|
||||
# Run all candidates in parallel
|
||||
with self.console.status(
|
||||
"[bold cyan]Running candidates on samples...[/bold cyan]", spinner="dots"
|
||||
):
|
||||
with ThreadPoolExecutor(max_workers=len(valid_candidates)) as executor:
|
||||
future_to_candidate = {
|
||||
executor.submit(run_single_candidate, c): c
|
||||
for c in valid_candidates
|
||||
}
|
||||
|
||||
for future in as_completed(future_to_candidate):
|
||||
candidate = future_to_candidate[future]
|
||||
result = future.result()
|
||||
|
||||
if result["error"]:
|
||||
candidate["outputs"] = []
|
||||
candidate["run_error"] = result["error"]
|
||||
self._log(
|
||||
f" [red]✗[/red] {result['name']} failed: {result['error']}"
|
||||
)
|
||||
else:
|
||||
candidate["outputs"] = result["outputs"]
|
||||
self._log(
|
||||
f" [green]✓[/green] {result['name']}: {len(result['outputs'])} outputs"
|
||||
)
|
||||
|
||||
# Filter to candidates with outputs
|
||||
candidates_with_outputs = [c for c in valid_candidates if c.get("outputs")]
|
||||
self._log(
|
||||
f"[bold]{len(candidates_with_outputs)}[/bold] candidates produced outputs"
|
||||
)
|
||||
|
||||
if not candidates_with_outputs:
|
||||
# All failed - return original
|
||||
self._log("[red]All decomposition candidates failed to execute.[/red]")
|
||||
# Log the errors for debugging
|
||||
for c in valid_candidates:
|
||||
if c.get("run_error"):
|
||||
self._log(f" [red]{c['name']}:[/red] {c['run_error']}")
|
||||
return {
|
||||
"decomposed_ops": self.operators,
|
||||
"winning_directive": "original",
|
||||
"candidates_evaluated": 0,
|
||||
"original_outputs": [],
|
||||
"decomposed_outputs": [],
|
||||
"comparison_rationale": "All decomposition candidates failed to execute.",
|
||||
"cost": self.total_cost,
|
||||
}
|
||||
|
||||
# Find the original candidate's outputs
|
||||
original_candidate = next(
|
||||
(c for c in candidates_with_outputs if c["name"] == "original"), None
|
||||
)
|
||||
original_outputs = original_candidate["outputs"] if original_candidate else []
|
||||
|
||||
# Parallel pairwise comparison: compare all candidates against original
|
||||
self.console.rule("[bold]Pairwise Comparison[/bold]")
|
||||
original_prompt = target_op.get("prompt", "")
|
||||
output_schema = target_op.get("output", {}).get("schema", {})
|
||||
|
||||
# If only original exists, it wins by default
|
||||
if len(candidates_with_outputs) == 1:
|
||||
winner = candidates_with_outputs[0]
|
||||
elif original_candidate is None:
|
||||
# No original - just pick the first candidate
|
||||
winner = candidates_with_outputs[0]
|
||||
else:
|
||||
# Compare all non-original candidates against original IN PARALLEL
|
||||
challengers = [
|
||||
c for c in candidates_with_outputs if c["name"] != "original"
|
||||
]
|
||||
|
||||
if not challengers:
|
||||
winner = original_candidate
|
||||
else:
|
||||
self._log(
|
||||
f"Comparing {len(challengers)} candidates against original in parallel..."
|
||||
)
|
||||
|
||||
def compare_against_original(challenger):
|
||||
"""Compare a single challenger against original."""
|
||||
try:
|
||||
result = self.pairwise_compare(
|
||||
original_candidate,
|
||||
challenger,
|
||||
original_prompt,
|
||||
output_schema,
|
||||
)
|
||||
won = result["name"] == challenger["name"]
|
||||
return {
|
||||
"challenger": challenger,
|
||||
"won": won,
|
||||
"rationale": result.get("comparison_rationale", ""),
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"challenger": challenger,
|
||||
"won": False,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
# Run all comparisons in parallel
|
||||
comparison_results = []
|
||||
with self.console.status(
|
||||
f"[bold cyan]Running {len(challengers)} comparisons in parallel...[/bold cyan]",
|
||||
spinner="dots",
|
||||
):
|
||||
with ThreadPoolExecutor(max_workers=len(challengers)) as executor:
|
||||
future_to_challenger = {
|
||||
executor.submit(compare_against_original, c): c
|
||||
for c in challengers
|
||||
}
|
||||
|
||||
for future in as_completed(future_to_challenger):
|
||||
result = future.result()
|
||||
comparison_results.append(result)
|
||||
challenger_name = result["challenger"]["name"]
|
||||
if result.get("error"):
|
||||
self._log(
|
||||
f" [red]✗[/red] {challenger_name} vs original: error - {result['error']}"
|
||||
)
|
||||
elif result["won"]:
|
||||
self._log(
|
||||
f" [green]✓[/green] {challenger_name} beat original"
|
||||
)
|
||||
else:
|
||||
self._log(
|
||||
f" [dim]○[/dim] {challenger_name} lost to original"
|
||||
)
|
||||
|
||||
# Find winners (candidates that beat original)
|
||||
winners = [r for r in comparison_results if r.get("won")]
|
||||
|
||||
if not winners:
|
||||
# Original beats all challengers
|
||||
winner = original_candidate
|
||||
self._log("[bold]Original wins against all challengers[/bold]")
|
||||
elif len(winners) == 1:
|
||||
# Single winner
|
||||
winner = winners[0]["challenger"]
|
||||
winner["comparison_rationale"] = winners[0].get("rationale", "")
|
||||
else:
|
||||
# Multiple winners beat original - run tiebreaker comparisons in parallel
|
||||
self._log(
|
||||
f"[bold]{len(winners)} candidates beat original - running tiebreaker...[/bold]"
|
||||
)
|
||||
|
||||
# Compare all winners against each other in parallel (round-robin)
|
||||
winner_candidates = [w["challenger"] for w in winners]
|
||||
win_counts = {c["name"]: 0 for c in winner_candidates}
|
||||
|
||||
# Generate all pairwise matchups
|
||||
matchups = []
|
||||
for i, a in enumerate(winner_candidates):
|
||||
for b in winner_candidates[i + 1 :]:
|
||||
matchups.append((a, b))
|
||||
|
||||
if matchups:
|
||||
|
||||
def run_matchup(matchup):
|
||||
a, b = matchup
|
||||
try:
|
||||
result = self.pairwise_compare(
|
||||
a, b, original_prompt, output_schema
|
||||
)
|
||||
return {
|
||||
"winner": result["name"],
|
||||
"a": a["name"],
|
||||
"b": b["name"],
|
||||
}
|
||||
except Exception:
|
||||
return {
|
||||
"winner": a["name"],
|
||||
"a": a["name"],
|
||||
"b": b["name"],
|
||||
} # Default to first
|
||||
|
||||
with self.console.status(
|
||||
f"[bold cyan]Running {len(matchups)} tiebreaker comparisons...[/bold cyan]",
|
||||
spinner="dots",
|
||||
):
|
||||
with ThreadPoolExecutor(
|
||||
max_workers=len(matchups)
|
||||
) as executor:
|
||||
for result in executor.map(run_matchup, matchups):
|
||||
win_counts[result["winner"]] += 1
|
||||
self._log(
|
||||
f" [dim]{result['a']} vs {result['b']} → {result['winner']}[/dim]"
|
||||
)
|
||||
|
||||
# Pick candidate with most wins
|
||||
best_name = max(win_counts, key=win_counts.get)
|
||||
winner = next(
|
||||
c for c in winner_candidates if c["name"] == best_name
|
||||
)
|
||||
self._log(
|
||||
f"[bold]Tiebreaker winner: {best_name} ({win_counts[best_name]} wins)[/bold]"
|
||||
)
|
||||
|
||||
# Extract the decomposed operations
|
||||
decomposed_ops = self.extract_ops_to_run(winner["ops"], op_name)
|
||||
|
||||
# Final summary
|
||||
self.console.rule("[bold green]Decomposition Complete[/bold green]")
|
||||
self._log(f"[bold]Winner:[/bold] [green]{winner['name']}[/green]")
|
||||
self._log(f"[bold]Candidates evaluated:[/bold] {len(candidates_with_outputs)}")
|
||||
self._log(f"[bold]New operations:[/bold] {len(decomposed_ops)}")
|
||||
self._log(f"[bold]Total cost:[/bold] ${self.total_cost:.4f}")
|
||||
|
||||
return {
|
||||
"decomposed_ops": decomposed_ops,
|
||||
"winning_directive": winner["name"],
|
||||
"candidates_evaluated": len(candidates_with_outputs),
|
||||
"original_outputs": original_outputs,
|
||||
"decomposed_outputs": winner.get("outputs", []),
|
||||
"comparison_rationale": winner.get("comparison_rationale", ""),
|
||||
"cost": self.total_cost,
|
||||
}
|
||||
|
|
@ -0,0 +1,337 @@
|
|||
"""
|
||||
Fast should_optimize analyzer using a single LLM call.
|
||||
|
||||
This module provides a lightweight alternative to the full MapOptimizer/ReduceOptimizer/JoinOptimizer
|
||||
flow for quickly determining if an operation should be decomposed.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
from litellm import completion, model_cost
|
||||
|
||||
from docetl.utils import completion_cost, count_tokens
|
||||
|
||||
# Drop unsupported params for models like gpt-5 that don't support temperature=0
|
||||
litellm.drop_params = True
|
||||
|
||||
|
||||
class FastShouldOptimizeAnalyzer:
|
||||
"""
|
||||
Analyzes whether an operation should be optimized using a single LLM call.
|
||||
|
||||
Instead of running the operation on sample data and using complex evaluation logic,
|
||||
this reads cached outputs from intermediate files and makes one judgment call.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
intermediate_dir: str,
|
||||
optimizer_model: str = "gpt-5.1",
|
||||
litellm_kwargs: dict[str, Any] | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the analyzer.
|
||||
|
||||
Args:
|
||||
intermediate_dir: Path to the directory containing intermediate outputs
|
||||
optimizer_model: The LLM model to use for analysis (default: gpt-5.1)
|
||||
litellm_kwargs: Additional kwargs to pass to litellm.completion
|
||||
"""
|
||||
self.intermediate_dir = intermediate_dir
|
||||
self.optimizer_model = optimizer_model
|
||||
self.litellm_kwargs = litellm_kwargs or {}
|
||||
if "temperature" not in self.litellm_kwargs:
|
||||
self.litellm_kwargs["temperature"] = 0.0
|
||||
|
||||
def load_operation_data(self, step_name: str, op_name: str) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Load data from the intermediate file for an operation.
|
||||
|
||||
Args:
|
||||
step_name: Name of the pipeline step
|
||||
op_name: Name of the operation
|
||||
|
||||
Returns:
|
||||
List of dictionaries
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the intermediate file doesn't exist
|
||||
"""
|
||||
output_path = os.path.join(self.intermediate_dir, step_name, f"{op_name}.json")
|
||||
if not os.path.exists(output_path):
|
||||
raise FileNotFoundError(
|
||||
f"No output file found at {output_path}. "
|
||||
"Run the operation first to generate outputs."
|
||||
)
|
||||
with open(output_path, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
def find_previous_operation(
|
||||
self, operations: list[dict[str, Any]], op_name: str
|
||||
) -> str | None:
|
||||
"""
|
||||
Find the operation that comes before op_name in the pipeline.
|
||||
|
||||
Args:
|
||||
operations: List of operation configs from the pipeline
|
||||
op_name: Name of the current operation
|
||||
|
||||
Returns:
|
||||
Name of the previous operation, or None if this is the first operation
|
||||
"""
|
||||
op_names = [op.get("name") for op in operations]
|
||||
try:
|
||||
idx = op_names.index(op_name)
|
||||
if idx > 0:
|
||||
return op_names[idx - 1]
|
||||
except ValueError:
|
||||
pass
|
||||
return None
|
||||
|
||||
def get_max_context_tokens(self) -> int:
|
||||
"""
|
||||
Get the maximum input tokens for the optimizer model.
|
||||
|
||||
Returns:
|
||||
Maximum number of input tokens the model can handle
|
||||
"""
|
||||
model_info = model_cost.get(self.optimizer_model, {})
|
||||
# Try without provider prefix if not found
|
||||
if not model_info:
|
||||
model_name = self.optimizer_model.split("/")[-1]
|
||||
model_info = model_cost.get(model_name, {})
|
||||
return model_info.get("max_input_tokens", 128000) # Default to 128k
|
||||
|
||||
def calculate_samples_that_fit(
|
||||
self,
|
||||
op_config: dict[str, Any],
|
||||
outputs: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Calculate how many output samples fit in the context window.
|
||||
|
||||
Reserves space for system prompt, operation config, and response buffer,
|
||||
then fills remaining space with as many samples as possible.
|
||||
|
||||
Args:
|
||||
op_config: The operation configuration dictionary
|
||||
outputs: List of all output documents
|
||||
|
||||
Returns:
|
||||
List of samples that fit in the context window
|
||||
"""
|
||||
max_tokens = self.get_max_context_tokens()
|
||||
|
||||
# Reserve tokens for fixed parts
|
||||
system_prompt_tokens = 500
|
||||
op_config_tokens = count_tokens(
|
||||
json.dumps(op_config, default=str), self.optimizer_model
|
||||
)
|
||||
response_buffer = 2000
|
||||
|
||||
available_for_samples = (
|
||||
max_tokens - system_prompt_tokens - op_config_tokens - response_buffer
|
||||
)
|
||||
|
||||
# Collect samples that fit
|
||||
samples_to_include = []
|
||||
tokens_used = 0
|
||||
|
||||
for output in outputs:
|
||||
sample_json = json.dumps(output, default=str)
|
||||
sample_tokens = count_tokens(sample_json, self.optimizer_model)
|
||||
if tokens_used + sample_tokens <= available_for_samples:
|
||||
samples_to_include.append(output)
|
||||
tokens_used += sample_tokens
|
||||
else:
|
||||
break
|
||||
|
||||
return samples_to_include
|
||||
|
||||
def build_analysis_prompt(
|
||||
self,
|
||||
op_config: dict[str, Any],
|
||||
samples: list[dict[str, Any]],
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Build the system prompt and user prompt for the analysis LLM call.
|
||||
|
||||
Args:
|
||||
op_config: The operation configuration dictionary
|
||||
samples: List of output samples to analyze
|
||||
|
||||
Returns:
|
||||
Tuple of (system_prompt, user_prompt)
|
||||
"""
|
||||
system_prompt = """You are an expert at analyzing LLM-powered data processing operations.
|
||||
Your task is to determine if an operation would benefit from being decomposed into multiple
|
||||
smaller, focused operations (also known as "optimization" or "decomposition").
|
||||
|
||||
An operation SHOULD be decomposed when:
|
||||
1. The prompt asks the LLM to do multiple distinct tasks that could be done separately
|
||||
2. The task is complex enough that breaking it into sequential steps would improve accuracy
|
||||
3. The outputs show inconsistency, incompleteness, or quality issues that iterative refinement could fix
|
||||
4. Long documents need to be processed in chunks rather than all at once
|
||||
5. The prompt asks for both extraction AND analysis/synthesis in one step
|
||||
|
||||
An operation should NOT be decomposed when:
|
||||
1. It performs a single, focused task well
|
||||
2. The outputs are consistently high quality and complete
|
||||
3. The task is simple and atomic (e.g., simple classification, single field extraction)
|
||||
4. The operation is already well-scoped and produces reliable results
|
||||
|
||||
Be conservative - only recommend decomposition if there's clear evidence it would help."""
|
||||
|
||||
output_schema = op_config.get("output", {}).get("schema", {})
|
||||
prompt_template = op_config.get("prompt", "No prompt specified")
|
||||
|
||||
# Truncate very long prompts for display
|
||||
if len(prompt_template) > 3000:
|
||||
prompt_template = prompt_template[:3000] + "\n... [truncated]"
|
||||
|
||||
user_prompt = f"""Analyze this data processing operation and its outputs:
|
||||
|
||||
## Operation Configuration
|
||||
|
||||
**Name:** {op_config.get('name', 'unknown')}
|
||||
**Type:** {op_config.get('type', 'unknown')}
|
||||
|
||||
**Prompt Template:**
|
||||
```
|
||||
{prompt_template}
|
||||
```
|
||||
|
||||
**Output Schema:**
|
||||
```json
|
||||
{json.dumps(output_schema, indent=2)}
|
||||
```
|
||||
|
||||
## Sample Outputs ({len(samples)} samples from the operation)
|
||||
|
||||
```json
|
||||
{json.dumps(samples, indent=2, default=str)}
|
||||
```
|
||||
|
||||
## Your Task
|
||||
|
||||
Based on the operation configuration and sample outputs, determine:
|
||||
1. Should this operation be decomposed/optimized?
|
||||
2. If yes, what specific improvements would help?
|
||||
|
||||
Consider the quality, completeness, and consistency of the outputs when making your assessment."""
|
||||
|
||||
return system_prompt, user_prompt
|
||||
|
||||
def analyze(
|
||||
self,
|
||||
op_config: dict[str, Any],
|
||||
step_name: str,
|
||||
op_name: str,
|
||||
) -> tuple[str, list[dict[str, Any]], int, float]:
|
||||
"""
|
||||
Analyze whether an operation should be optimized.
|
||||
|
||||
This is the main entry point. It loads outputs, builds the prompt,
|
||||
makes a single LLM call, and returns the assessment.
|
||||
|
||||
Args:
|
||||
op_config: The operation configuration dictionary
|
||||
step_name: Name of the pipeline step
|
||||
op_name: Name of the operation
|
||||
|
||||
Returns:
|
||||
Tuple of (rationale, output_samples, num_docs_analyzed, cost):
|
||||
- rationale: Empty string if no optimization needed, explanation if it should be optimized
|
||||
- output_samples: The samples that were analyzed
|
||||
- num_docs_analyzed: Number of documents that fit in the LLM prompt
|
||||
- cost: LLM API cost in USD
|
||||
|
||||
Raises:
|
||||
ValueError: If the operation is not an LLM-powered map operation
|
||||
"""
|
||||
# Validate operation type - only LLM-powered map operations
|
||||
op_type = op_config.get("type", "")
|
||||
if op_type != "map":
|
||||
raise ValueError(
|
||||
f"should_optimize only supports 'map' operations, got '{op_type}'. "
|
||||
"Only LLM-powered map operations can be analyzed for decomposition."
|
||||
)
|
||||
|
||||
# Load outputs from intermediate file
|
||||
outputs = self.load_operation_data(step_name, op_name)
|
||||
|
||||
if not outputs:
|
||||
return "No output samples available for analysis.", [], 0, 0.0
|
||||
|
||||
# Calculate samples that fit in context
|
||||
samples = self.calculate_samples_that_fit(op_config, outputs)
|
||||
|
||||
if not samples:
|
||||
return "Could not fit any samples in context window.", outputs[:5], 0, 0.0
|
||||
|
||||
# Build prompt
|
||||
system_prompt, user_prompt = self.build_analysis_prompt(op_config, samples)
|
||||
|
||||
# Make LLM call with structured output
|
||||
response = completion(
|
||||
model=self.optimizer_model,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "optimization_analysis",
|
||||
"strict": True,
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"should_optimize": {
|
||||
"type": "boolean",
|
||||
"description": "True if operation should be decomposed/optimized",
|
||||
},
|
||||
"rationale": {
|
||||
"type": "string",
|
||||
"description": "Explanation of why the operation should or should not be optimized",
|
||||
},
|
||||
"suggested_improvements": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Specific improvements if optimization is recommended (empty if not)",
|
||||
},
|
||||
},
|
||||
"required": [
|
||||
"should_optimize",
|
||||
"rationale",
|
||||
"suggested_improvements",
|
||||
],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
},
|
||||
**self.litellm_kwargs,
|
||||
)
|
||||
|
||||
# Calculate cost
|
||||
cost = completion_cost(response)
|
||||
|
||||
# Parse response
|
||||
result = json.loads(response.choices[0].message.content)
|
||||
|
||||
num_docs_analyzed = len(samples)
|
||||
|
||||
if result["should_optimize"]:
|
||||
# Build rationale string with improvements
|
||||
rationale_parts = [result["rationale"]]
|
||||
if result["suggested_improvements"]:
|
||||
rationale_parts.append("\n\nSuggested improvements:")
|
||||
for imp in result["suggested_improvements"]:
|
||||
rationale_parts.append(f"- {imp}")
|
||||
return "\n".join(rationale_parts), samples, num_docs_analyzed, cost
|
||||
else:
|
||||
# Return empty string to indicate no optimization needed
|
||||
return "", samples, num_docs_analyzed, cost
|
||||
|
|
@ -30,7 +30,7 @@ class LLMClient:
|
|||
Initialize the LLMClient.
|
||||
|
||||
Args:
|
||||
model (str, optional): The name of the LLM model to use. Defaults to "gpt-4o".
|
||||
model (str, optional): The name of the LLM model to use. Defaults to "gpt-5.1".
|
||||
**litellm_kwargs: Additional keyword arguments for the LLM model.
|
||||
"""
|
||||
self.rewrite_agent_model = rewrite_agent_model
|
||||
|
|
|
|||
|
|
@ -204,10 +204,6 @@ def get_openai_response(
|
|||
response = litellm.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=ResponseFormat,
|
||||
)
|
||||
assistant_response = response.choices[0].message.content
|
||||
|
|
|
|||
|
|
@ -324,10 +324,6 @@ Focus on quality over quantity - a few diverse, informative examples are better
|
|||
model=self.agent_llm,
|
||||
messages=self.message_history,
|
||||
response_format=AgentDecision,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
)
|
||||
call_cost += response._hidden_params["response_cost"]
|
||||
|
||||
|
|
@ -415,10 +411,6 @@ Provide your response as a JSON object matching this schema: {response_schema.mo
|
|||
model=self.agent_llm,
|
||||
messages=self.message_history,
|
||||
response_format=response_schema,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
)
|
||||
call_cost += schema_response._hidden_params["response_cost"]
|
||||
|
||||
|
|
|
|||
|
|
@ -222,7 +222,6 @@ class Directive(BaseModel, ABC):
|
|||
model=agent_llm,
|
||||
messages=messages,
|
||||
response_format=JudgeResponse,
|
||||
azure=True,
|
||||
)
|
||||
|
||||
# Parse the JSON response
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
|
@ -169,19 +168,15 @@ class ChainingDirective(Directive):
|
|||
]
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
# api_key=os.environ["GEMINI_API_KEY"],
|
||||
azure=True,
|
||||
response_format=ChainingInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
|
|
@ -202,11 +197,12 @@ class ChainingDirective(Directive):
|
|||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
|
|
@ -211,18 +210,14 @@ class ChangeModelDirective(Directive):
|
|||
},
|
||||
]
|
||||
)
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
# api_key=os.environ["GEMINI_API_KEY"],
|
||||
azure=True,
|
||||
response_format=ChangeModelInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = ChangeModelInstantiateSchema(**parsed_res)
|
||||
|
|
@ -240,11 +235,12 @@ class ChangeModelDirective(Directive):
|
|||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
|
|
@ -205,17 +204,14 @@ class ChangeModelAccDirective(Directive):
|
|||
},
|
||||
]
|
||||
)
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=ChangeModelInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = ChangeModelInstantiateSchema(**parsed_res)
|
||||
|
|
@ -233,11 +229,12 @@ class ChangeModelAccDirective(Directive):
|
|||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
|
|
@ -282,17 +281,14 @@ class ChangeModelCostDirective(Directive):
|
|||
},
|
||||
]
|
||||
)
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=ChangeModelInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = ChangeModelInstantiateSchema(**parsed_res)
|
||||
|
|
@ -310,11 +306,12 @@ class ChangeModelCostDirective(Directive):
|
|||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
|
|
@ -185,17 +184,14 @@ class ChunkHeaderSummaryDirective(Directive):
|
|||
},
|
||||
]
|
||||
)
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=ChunkHeaderSummaryInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = ChunkHeaderSummaryInstantiateSchema(**parsed_res)
|
||||
|
|
@ -204,11 +200,12 @@ class ChunkHeaderSummaryDirective(Directive):
|
|||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
|
@ -251,9 +248,7 @@ class ChunkHeaderSummaryDirective(Directive):
|
|||
)
|
||||
|
||||
if gather_idx - split_idx > 1:
|
||||
raise ValueError(
|
||||
"There should not be operators between split and gather"
|
||||
)
|
||||
raise ValueError("There should not be operators between split and gather")
|
||||
|
||||
# Get the split_key from the split operation
|
||||
split_key = split_op.get("split_key")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
|
|
@ -261,17 +260,14 @@ def transform(input_doc):
|
|||
]
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=DeterministicDocCompressionInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = DeterministicDocCompressionInstantiateSchema(**parsed_res)
|
||||
|
|
@ -284,11 +280,12 @@ def transform(input_doc):
|
|||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
|
|
@ -268,18 +267,15 @@ class DocumentChunkingDirective(Directive):
|
|||
]
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
call_cost = 0.0
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=DocumentChunkingInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = DocumentChunkingInstantiateSchema(**parsed_res)
|
||||
|
|
@ -290,11 +286,12 @@ class DocumentChunkingDirective(Directive):
|
|||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
|
@ -404,16 +401,19 @@ class DocumentChunkingDirective(Directive):
|
|||
sample_op["stratify_key"] = stratify_keys
|
||||
|
||||
if rewrite.sampling_config.samples_per_group:
|
||||
sample_op["samples_per_group"] = rewrite.sampling_config.samples_per_group
|
||||
|
||||
|
||||
sample_op["samples_per_group"] = (
|
||||
rewrite.sampling_config.samples_per_group
|
||||
)
|
||||
|
||||
# Add optional fields if provided
|
||||
if rewrite.sampling_config.random_state is not None:
|
||||
sample_op["random_state"] = rewrite.sampling_config.random_state
|
||||
|
||||
|
||||
if rewrite.sampling_config.method_kwargs is not None:
|
||||
try:
|
||||
sample_op["method_kwargs"] = json.loads(rewrite.sampling_config.method_kwargs)
|
||||
sample_op["method_kwargs"] = json.loads(
|
||||
rewrite.sampling_config.method_kwargs
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid method_kwargs: {e}")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
|
|
@ -431,17 +430,14 @@ class DocumentChunkingTopKDirective(Directive):
|
|||
]
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=DocumentChunkingTopKInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = DocumentChunkingTopKInstantiateSchema(**parsed_res)
|
||||
|
|
@ -451,11 +447,12 @@ class DocumentChunkingTopKDirective(Directive):
|
|||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
|
|
@ -164,19 +163,14 @@ class DocCompressionDirective(Directive):
|
|||
},
|
||||
]
|
||||
)
|
||||
error_message = ""
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=DocCompressionInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
|
|
@ -187,11 +181,12 @@ class DocCompressionDirective(Directive):
|
|||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Error: {error_message}"
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
|
|
@ -261,17 +260,14 @@ class DocSummarizationDirective(Directive):
|
|||
]
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=DocSummarizationInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = DocSummarizationInstantiateSchema(**parsed_res)
|
||||
|
|
@ -280,11 +276,12 @@ class DocSummarizationDirective(Directive):
|
|||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
|
|
@ -146,32 +145,31 @@ class GleaningDirective(Directive):
|
|||
]
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
# api_key=os.environ["GEMINI_API_KEY"],
|
||||
azure=True,
|
||||
response_format=GleaningInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = GleaningInstantiateSchema(**parsed_res)
|
||||
GleaningInstantiateSchema.check_no_jinja_variables(schema.validation_prompt)
|
||||
GleaningInstantiateSchema.check_no_jinja_variables(
|
||||
schema.validation_prompt
|
||||
)
|
||||
message_history.append(
|
||||
{"role": "assistant", "content": resp.choices[0].message.content}
|
||||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
|
|
@ -180,19 +179,14 @@ class HierarchicalReduceDirective(Directive):
|
|||
]
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=HierarchicalReduceInstantiateSchema,
|
||||
)
|
||||
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
|
|
@ -214,11 +208,12 @@ class HierarchicalReduceDirective(Directive):
|
|||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
|
@ -259,7 +254,7 @@ class HierarchicalReduceDirective(Directive):
|
|||
"name": rewrite.map_config.name,
|
||||
"type": "map",
|
||||
"prompt": rewrite.map_config.prompt,
|
||||
"model": default_model,
|
||||
"model": default_model,
|
||||
"litellm_completion_kwargs": {"temperature": 0},
|
||||
"output": {"schema": {rewrite.map_config.output_keys[0]: "string"}},
|
||||
}
|
||||
|
|
@ -324,4 +319,4 @@ class HierarchicalReduceDirective(Directive):
|
|||
new_ops_plan = self.apply(
|
||||
global_default_model, operators, target_ops[0], rewrite
|
||||
)
|
||||
return new_ops_plan, message_history, call_cost
|
||||
return new_ops_plan, message_history, call_cost
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
|
@ -260,17 +259,14 @@ class IsolatingSubtasksDirective(Directive):
|
|||
original_op.get("output", {}).get("schema", {}).keys()
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=IsolatingSubtasksInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = IsolatingSubtasksInstantiateSchema(**parsed_res)
|
||||
|
|
@ -285,11 +281,12 @@ class IsolatingSubtasksDirective(Directive):
|
|||
return schema, message_history, call_cost
|
||||
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease ensure all original output keys are covered by subtasks and try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
|
@ -7,14 +6,20 @@ from typing import Dict, List, Type
|
|||
from litellm import completion
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from docetl.reasoning_optimizer.instantiate_schemas import MapReduceFusionInstantiateSchema
|
||||
from docetl.reasoning_optimizer.instantiate_schemas import (
|
||||
MapReduceFusionInstantiateSchema,
|
||||
)
|
||||
|
||||
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
|
||||
|
||||
|
||||
class MapReduceFusionDirective(Directive):
|
||||
name: str = Field(default="map_reduce_fusion", description="The name of the directive")
|
||||
formal_description: str = Field(default="Map -> Reduce => Map (with new prompt) -> Reduce (with new propmt)")
|
||||
name: str = Field(
|
||||
default="map_reduce_fusion", description="The name of the directive"
|
||||
)
|
||||
formal_description: str = Field(
|
||||
default="Map -> Reduce => Map (with new prompt) -> Reduce (with new propmt)"
|
||||
)
|
||||
nl_description: str = Field(
|
||||
default="Transform a Map operation followed by a Reduce operation over long documents by updating the Map prompt to first extract or process only the relevant information from each document. Then, modify the Reduce prompt to operate on these processed outputs instead of the full document content."
|
||||
)
|
||||
|
|
@ -42,7 +47,7 @@ class MapReduceFusionDirective(Directive):
|
|||
" type: reduce\\n"
|
||||
" reduce_key: category\\n"
|
||||
" prompt: |\\n"
|
||||
" For each category \\\"{{ reduce_key }}\\\", extract all organization names from these documents:\\n"
|
||||
' For each category \\"{{ reduce_key }}\\", extract all organization names from these documents:\\n'
|
||||
" {% for input in inputs %}\\n"
|
||||
" Document {{ loop.index }}: {{ input.content }}\\n"
|
||||
" {% endfor %}\\n"
|
||||
|
|
@ -79,7 +84,7 @@ class MapReduceFusionDirective(Directive):
|
|||
"reduce_key": "category",
|
||||
"prompt": "For each category '{{ reduce_key }}', extract organizations from: {% for input in inputs %}{{ input.content }}{% endfor %}",
|
||||
"output": {"schema": {"organizations": "list[str]"}},
|
||||
}
|
||||
},
|
||||
],
|
||||
target_ops=["classify_document", "extract_organizations"],
|
||||
expected_behavior="Should modify map to classify AND extract organizations per document, then reduce to aggregate pre-extracted organizations",
|
||||
|
|
@ -101,7 +106,7 @@ class MapReduceFusionDirective(Directive):
|
|||
"reduce_key": "doc_type",
|
||||
"prompt": "For each document type '{{ reduce_key }}', find all people mentioned: {% for input in inputs %}Document: {{ input.text }}{% endfor %}",
|
||||
"output": {"schema": {"people": "list[str]"}},
|
||||
}
|
||||
},
|
||||
],
|
||||
target_ops=["analyze_document", "find_people"],
|
||||
expected_behavior="Should modify map to analyze type AND extract people per document, then reduce to combine pre-extracted people lists",
|
||||
|
|
@ -159,7 +164,7 @@ class MapReduceFusionDirective(Directive):
|
|||
reduce_op: Dict,
|
||||
expected_document_key,
|
||||
agent_llm: str,
|
||||
message_history: list = []
|
||||
message_history: list = [],
|
||||
):
|
||||
"""
|
||||
Use LLM to instantiate this directive by transforming the map and reduce operations.
|
||||
|
|
@ -173,66 +178,85 @@ class MapReduceFusionDirective(Directive):
|
|||
Returns:
|
||||
MapReduceFusionInstantiateSchema: The structured output from the LLM.
|
||||
"""
|
||||
|
||||
message_history.extend([
|
||||
{"role": "system", "content": "You are a helpful AI assistant for optimizing document processing pipelines."},
|
||||
{"role": "user", "content": self.to_string_for_instantiate([map_op, reduce_op])},
|
||||
])
|
||||
|
||||
message_history.extend(
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful AI assistant for optimizing document processing pipelines.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.to_string_for_instantiate([map_op, reduce_op]),
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=MapReduceFusionInstantiateSchema
|
||||
response_format=MapReduceFusionInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
if "new_map_name" not in parsed_res or "new_map_prompt" not in parsed_res or "new_key" not in parsed_res or "new_reduce_prompt" not in parsed_res:
|
||||
if (
|
||||
"new_map_name" not in parsed_res
|
||||
or "new_map_prompt" not in parsed_res
|
||||
or "new_key" not in parsed_res
|
||||
or "new_reduce_prompt" not in parsed_res
|
||||
):
|
||||
raise ValueError(
|
||||
"Response from LLM is missing required keys: 'new_map_name', 'new_map_prompt', 'new_key', or 'new_reduce_prompt'"
|
||||
)
|
||||
|
||||
|
||||
schema = MapReduceFusionInstantiateSchema(
|
||||
new_map_name=parsed_res["new_map_name"],
|
||||
new_map_prompt=parsed_res["new_map_prompt"],
|
||||
new_key=parsed_res["new_key"],
|
||||
new_reduce_prompt=parsed_res["new_reduce_prompt"]
|
||||
new_reduce_prompt=parsed_res["new_reduce_prompt"],
|
||||
)
|
||||
|
||||
# Validate the schema
|
||||
MapReduceFusionInstantiateSchema.validate_reduce_prompt_references_new_key(
|
||||
schema.new_reduce_prompt, schema.new_key, expected_document_key
|
||||
)
|
||||
|
||||
|
||||
message_history.append(
|
||||
{"role": "assistant", "content": resp.choices[0].message.content}
|
||||
)
|
||||
return schema, message_history, call_cost
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts.")
|
||||
|
||||
def apply(self, global_default_model, ops_list: List[Dict], map_target: str, reduce_target: str, rewrite: MapReduceFusionInstantiateSchema) -> List[Dict]:
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
global_default_model,
|
||||
ops_list: List[Dict],
|
||||
map_target: str,
|
||||
reduce_target: str,
|
||||
rewrite: MapReduceFusionInstantiateSchema,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Apply the directive to the pipeline config.
|
||||
"""
|
||||
# Create a copy of the pipeline config
|
||||
new_ops_list = deepcopy(ops_list)
|
||||
|
||||
|
||||
# Find positions of the target ops
|
||||
map_pos = None
|
||||
reduce_pos = None
|
||||
orig_map_op = None
|
||||
orig_reduce_op = None
|
||||
|
||||
|
||||
for i, op in enumerate(ops_list):
|
||||
if op["name"] == map_target:
|
||||
map_pos = i
|
||||
|
|
@ -240,13 +264,15 @@ class MapReduceFusionDirective(Directive):
|
|||
elif op["name"] == reduce_target:
|
||||
reduce_pos = i
|
||||
orig_reduce_op = op
|
||||
|
||||
|
||||
if map_pos is None or reduce_pos is None:
|
||||
raise ValueError(f"Could not find target operations: {map_target}, {reduce_target}")
|
||||
|
||||
raise ValueError(
|
||||
f"Could not find target operations: {map_target}, {reduce_target}"
|
||||
)
|
||||
|
||||
# Get default model
|
||||
default_model = orig_map_op.get("model", global_default_model)
|
||||
|
||||
|
||||
# Create the new map operation with fused functionality
|
||||
new_map_op = {
|
||||
"name": rewrite.new_map_name,
|
||||
|
|
@ -257,11 +283,11 @@ class MapReduceFusionDirective(Directive):
|
|||
"output": {
|
||||
"schema": {
|
||||
**orig_map_op.get("output", {}).get("schema", {}),
|
||||
rewrite.new_key: "list[str]"
|
||||
rewrite.new_key: "list[str]",
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# Create the new reduce operation that works with pre-extracted data
|
||||
new_reduce_op = {
|
||||
"name": orig_reduce_op["name"],
|
||||
|
|
@ -270,54 +296,58 @@ class MapReduceFusionDirective(Directive):
|
|||
"prompt": rewrite.new_reduce_prompt,
|
||||
"model": default_model,
|
||||
"litellm_completion_kwargs": {"temperature": 0},
|
||||
"output": orig_reduce_op.get("output", {})
|
||||
"output": orig_reduce_op.get("output", {}),
|
||||
}
|
||||
|
||||
|
||||
# Replace the operations
|
||||
new_ops_list[map_pos] = new_map_op
|
||||
new_ops_list[reduce_pos] = new_reduce_op
|
||||
|
||||
|
||||
return new_ops_list
|
||||
|
||||
|
||||
def instantiate(
|
||||
self,
|
||||
operators: List[Dict],
|
||||
target_ops: List[str],
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
optimize_goal="acc",
|
||||
global_default_model: str = None,
|
||||
**kwargs
|
||||
self,
|
||||
operators: List[Dict],
|
||||
target_ops: List[str],
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
optimize_goal="acc",
|
||||
global_default_model: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Instantiate the directive for a list of operators.
|
||||
"""
|
||||
# Assert that there are exactly two target ops (map and reduce)
|
||||
if len(target_ops) != 2:
|
||||
raise ValueError("map_reduce_fusion directive requires exactly two target operators: a map and a reduce operation")
|
||||
|
||||
raise ValueError(
|
||||
"map_reduce_fusion directive requires exactly two target operators: a map and a reduce operation"
|
||||
)
|
||||
|
||||
# Find the map and reduce operations
|
||||
first_op = None
|
||||
second_op = None
|
||||
|
||||
|
||||
for op in operators:
|
||||
if op["name"] == target_ops[0]:
|
||||
first_op = op
|
||||
elif op["name"] == target_ops[1]:
|
||||
second_op = op
|
||||
|
||||
|
||||
if first_op is None or second_op is None:
|
||||
raise ValueError(f"Could not find target operations: {target_ops}")
|
||||
|
||||
|
||||
if first_op.get("type") == "map" and second_op.get("type") == "reduce":
|
||||
map_op = first_op
|
||||
reduce_op = second_op
|
||||
map_target = target_ops[0]
|
||||
reduce_target = target_ops[1]
|
||||
else:
|
||||
raise ValueError("Target operators must be one map operation followed by one reduce operation!")
|
||||
|
||||
# Extract expected document key from the reduce prompt template
|
||||
raise ValueError(
|
||||
"Target operators must be one map operation followed by one reduce operation!"
|
||||
)
|
||||
|
||||
# Extract expected document key from the reduce prompt template
|
||||
prompt_template = reduce_op["prompt"]
|
||||
# Find all occurrences of {{ input.key }} in the prompt
|
||||
input_key_pattern = r"\{\{\s*([^\}\s]+)\s*\}\}"
|
||||
|
|
@ -335,7 +365,9 @@ class MapReduceFusionDirective(Directive):
|
|||
]
|
||||
|
||||
if document_key_candidates:
|
||||
expected_document_key = document_key_candidates[0] # Pick the first candidate
|
||||
expected_document_key = document_key_candidates[
|
||||
0
|
||||
] # Pick the first candidate
|
||||
elif input_keys:
|
||||
expected_document_key = input_keys[0] # Fall back to the first input key
|
||||
else:
|
||||
|
|
@ -344,8 +376,12 @@ class MapReduceFusionDirective(Directive):
|
|||
print(f"Detected document key: {expected_document_key}")
|
||||
|
||||
# Instantiate the directive
|
||||
rewrite, message_history, call_cost = self.llm_instantiate(map_op, reduce_op, expected_document_key, agent_llm, message_history)
|
||||
|
||||
rewrite, message_history, call_cost = self.llm_instantiate(
|
||||
map_op, reduce_op, expected_document_key, agent_llm, message_history
|
||||
)
|
||||
|
||||
# Apply the rewrite to the operators
|
||||
new_ops_plan = self.apply(global_default_model, operators, map_target, reduce_target, rewrite)
|
||||
return new_ops_plan, message_history, call_cost
|
||||
new_ops_plan = self.apply(
|
||||
global_default_model, operators, map_target, reduce_target, rewrite
|
||||
)
|
||||
return new_ops_plan, message_history, call_cost
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
|
|
@ -203,14 +202,10 @@ class MapResolveToMapWithCategoriesDirective(Directive):
|
|||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=MapResolveToMapWithCategoriesInstantiateSchema,
|
||||
)
|
||||
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
|
|
@ -182,14 +181,10 @@ class MapToMapResolveReduceDirective(Directive):
|
|||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=MapToMapResolveReduceInstantiateSchema,
|
||||
)
|
||||
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
|
|
@ -172,17 +171,14 @@ class OperatorFusionDirective(Directive):
|
|||
]
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=OperatorFusionInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = OperatorFusionInstantiateSchema(**parsed_res)
|
||||
|
|
@ -191,11 +187,12 @@ class OperatorFusionDirective(Directive):
|
|||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
|
@ -215,7 +212,6 @@ class OperatorFusionDirective(Directive):
|
|||
new_ops_list = deepcopy(ops_list)
|
||||
op1_name, op2_name = target_ops[0], target_ops[1]
|
||||
|
||||
|
||||
# Find the operations
|
||||
op1_idx = next(i for i, op in enumerate(ops_list) if op["name"] == op1_name)
|
||||
op2_idx = next(i for i, op in enumerate(ops_list) if op["name"] == op2_name)
|
||||
|
|
@ -329,5 +325,5 @@ def transform(input_doc):
|
|||
return (
|
||||
self.apply(global_default_model, operators, target_ops, rewrite),
|
||||
message_history,
|
||||
call_cost
|
||||
call_cost,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
|
@ -127,7 +126,7 @@ class ReduceChainingDirective(Directive):
|
|||
original_op: Dict,
|
||||
expected_document_key: str,
|
||||
agent_llm: str,
|
||||
message_history: list = []
|
||||
message_history: list = [],
|
||||
):
|
||||
"""
|
||||
Use LLM to instantiate this directive by decomposing the reduce operation.
|
||||
|
|
@ -155,19 +154,15 @@ class ReduceChainingDirective(Directive):
|
|||
]
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
|
||||
call_cost = 0
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=ReduceChainingInstantiateSchema
|
||||
response_format=ReduceChainingInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = ReduceChainingInstantiateSchema(**parsed_res)
|
||||
|
|
@ -182,11 +177,12 @@ class ReduceChainingDirective(Directive):
|
|||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
|
@ -283,7 +279,9 @@ class ReduceChainingDirective(Directive):
|
|||
]
|
||||
|
||||
if document_key_candidates:
|
||||
expected_document_key = document_key_candidates[0] # Pick the first candidate
|
||||
expected_document_key = document_key_candidates[
|
||||
0
|
||||
] # Pick the first candidate
|
||||
elif input_keys:
|
||||
expected_document_key = input_keys[0] # Fall back to the first input key
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
|
|
@ -255,17 +254,14 @@ class ReduceGleaningDirective(Directive):
|
|||
]
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=GleaningInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = GleaningInstantiateSchema(**parsed_res)
|
||||
|
|
@ -274,11 +270,12 @@ class ReduceGleaningDirective(Directive):
|
|||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
|
@ -341,5 +338,6 @@ class ReduceGleaningDirective(Directive):
|
|||
# Apply the rewrite to the operators
|
||||
return (
|
||||
self.apply(global_default_model, operators, target_ops[0], rewrite),
|
||||
message_history, call_cost
|
||||
message_history,
|
||||
call_cost,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
|
|
@ -234,18 +233,14 @@ class TakeHeadTailDirective(Directive):
|
|||
]
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
call_cost = 0
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=TakeHeadTailInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = TakeHeadTailInstantiateSchema(**parsed_res)
|
||||
|
|
@ -254,11 +249,12 @@ class TakeHeadTailDirective(Directive):
|
|||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
|
|
|||
|
|
@@ -445,8 +445,12 @@ class DeterministicDocCompressionInstantiateSchema(BaseModel):
            raise ValueError(
                "Code must define a function named 'transform' that takes input_doc as parameter"
            )
        if "return {" not in v and "return dict(" not in v:
            raise ValueError("Code must return a dictionary")
        if (
            "return {" not in v
            and "return dict(" not in v
            and "output = dict(" not in v
        ):
            raise ValueError(f"Code must return a dictionary. Instead the code is: {v}")
        return v

    def validate_code_returns_target_keys(self, target_ops_configs: List[Dict]) -> None:
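For orientation, a user-supplied `transform` that satisfies this relaxed validator might look like the sketch below; the field names and truncation length are illustrative assumptions, not taken from the repository.

```python
def transform(input_doc):
    # Illustrative only: keep a truncated copy of the document content.
    compressed = input_doc.get("content", "")[:2000]
    # Building the result via `output = dict(...)` is one of the accepted return patterns.
    output = dict(input_doc)
    output["content"] = compressed
    return output
```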
@@ -699,7 +699,7 @@ class DSLRunner(ConfigWrapper):
            "litellm_kwargs", {}
        )
        kwargs["rewrite_agent_model"] = self.config.get("optimizer_config", {}).get(
            "rewrite_agent_model", "gpt-4o"
            "rewrite_agent_model", "gpt-5.1"
        )
        kwargs["judge_agent_model"] = self.config.get("optimizer_config", {}).get(
            "judge_agent_model", "gpt-4o-mini"

@@ -727,7 +727,7 @@ class DSLRunner(ConfigWrapper):
            "litellm_kwargs", {}
        )
        kwargs["rewrite_agent_model"] = self.config.get("optimizer_config", {}).get(
            "rewrite_agent_model", "gpt-4o"
            "rewrite_agent_model", "gpt-5.1"
        )
        kwargs["judge_agent_model"] = self.config.get("optimizer_config", {}).get(
            "judge_agent_model", "gpt-4o-mini"

@@ -14,7 +14,7 @@ All fields in `optimizer_config` are required (no defaults):
| `evaluation_file` | `str` | Path to Python file containing `@register_eval` decorated function |
| `metric_key` | `str` | Key in evaluation results dictionary to use as accuracy metric |
| `max_iterations` | `int` | Maximum number of MOARSearch iterations to run |
| `model` | `str` | LLM model to use for directive instantiation during search |
| `rewrite_agent_model` | `str` | LLM model to use for directive instantiation during search |

!!! warning "All Fields Required"
    MOAR will error if any required field is missing. There are no defaults.

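For orientation, an evaluation file of the kind these fields point at could look roughly like the sketch below; the decorator import path, function signature, and scoring logic are assumptions for illustration, not the documented API.

```python
# evaluate_medications.py -- illustrative sketch only
from docetl.reasoning_optimizer import register_eval  # import path is an assumption


@register_eval
def evaluate(outputs: list[dict]) -> dict:
    # Return a dict whose key matches `metric_key` in optimizer_config.
    scores = [1.0 if doc.get("medications") else 0.0 for doc in outputs]
    return {"medication_extraction_score": sum(scores) / max(len(scores), 1)}
```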
@@ -67,19 +67,19 @@ The optimizer will use the sample dataset, but your final pipeline uses the full
### Available Models

!!! info "LiteLLM Model Names"
    Use LiteLLM model names (e.g., `gpt-4o-mini`, `gpt-4o`, `gpt-4.1`). Make sure your API keys are set in your environment.
    Use LiteLLM model names (e.g., `gpt-4o-mini`, `gpt-4o`, `gpt-5.1`). Make sure your API keys are set in your environment.

```yaml
available_models:  # LiteLLM model names - ensure API keys are set
  - gpt-4.1-nano   # Cheapest, lower accuracy
  - gpt-4.1-mini   # Low cost, decent accuracy
  - gpt-4.1        # Balanced
  - gpt-5.1-nano   # Cheapest, lower accuracy
  - gpt-5.1-mini   # Low cost, decent accuracy
  - gpt-5.1        # Balanced
  - gpt-4o         # Higher cost, better accuracy
```

### Model for Directive Instantiation

The `model` field specifies which LLM to use for generating optimization directives during the search process. This doesn't affect the models tested in `available_models`.
The `rewrite_agent_model` field specifies which LLM to use for generating optimization directives during the search process. This doesn't affect the models tested in `available_models`.

!!! tip "Cost Consideration"
    Use a cheaper model (like `gpt-4o-mini`) for directive instantiation to reduce search costs.

@@ -104,12 +104,12 @@ optimizer_config:
  available_models:
    - gpt-4o-mini
    - gpt-4o
    - gpt-4.1-mini
    - gpt-4.1
    - gpt-5.1-mini
    - gpt-5.1
  evaluation_file: evaluate_medications.py
  metric_key: medication_extraction_score
  max_iterations: 40
  model: gpt-4.1
  rewrite_agent_model: gpt-5.1
  dataset_path: data/sample.json  # Optional
  exploration_weight: 1.414  # Optional
```

@@ -25,15 +25,15 @@ optimizer_config:
  dataset_path: workloads/medical/raw_sample.json  # Use sample for faster optimization
  save_dir: workloads/medical/moar_results
  available_models:  # LiteLLM model names - ensure API keys are set in your environment
    - gpt-4.1-nano
    - gpt-4.1-mini
    - gpt-4.1
    - gpt-5.1-nano
    - gpt-5.1-mini
    - gpt-5.1
    - gpt-4o
    - gpt-4o-mini
  evaluation_file: workloads/medical/evaluate_medications.py
  metric_key: medication_extraction_score
  max_iterations: 40
  model: gpt-4.1
  rewrite_agent_model: gpt-5.1

system_prompt:
  dataset_description: a collection of transcripts of doctor visits

@@ -109,12 +109,12 @@ optimizer_config:
  available_models:  # LiteLLM model names - ensure API keys are set in your environment
    - gpt-4o-mini
    - gpt-4o
    - gpt-4.1-mini
    - gpt-4.1
    - gpt-5.1-mini
    - gpt-5.1
  evaluation_file: evaluate_medications.py
  metric_key: medication_extraction_score  # This must match a key in your evaluation function's return dictionary
  max_iterations: 40
  model: gpt-4.1
  rewrite_agent_model: gpt-5.1
  dataset_path: data/transcripts_sample.json  # Optional: use sample/hold-out dataset
```

@@ -19,7 +19,7 @@ High-level summary of the optimization run:
{
  "optimizer": "moar",
  "input_pipeline": "pipeline.yaml",
  "model": "gpt-4.1",
  "rewrite_agent_model": "gpt-5.1",
  "max_iterations": 40,
  "save_dir": "results/moar_optimization",
  "dataset": "transcripts",

@@ -80,9 +80,9 @@ Test your evaluation function independently and check MOAR logs for errors.

```yaml
available_models:
  - gpt-4.1-nano   # Cheapest, lower accuracy
  - gpt-4.1-mini   # Low cost, decent accuracy
  - gpt-4.1        # Balanced
  - gpt-5.1-nano   # Cheapest, lower accuracy
  - gpt-5.1-mini   # Low cost, decent accuracy
  - gpt-5.1        # Balanced
  - gpt-4o         # Higher cost, better accuracy
```

@@ -267,10 +267,6 @@ class AgentCommunicator:
        response = completion(
            model=self.model,
            messages=messages,
            azure=True,
            api_key=os.environ.get("AZURE_API_KEY"),
            api_base=os.environ.get("AZURE_API_BASE"),
            api_version=os.environ.get("AZURE_API_VERSION"),
            response_format={"type": "json_object"}
        )

@@ -290,10 +286,6 @@ class AgentCommunicator:
        response = completion(
            model=self.model,
            messages=messages + [{"role": "user", "content": request_msg}],
            azure=True,
            api_key=os.environ.get("AZURE_API_KEY"),
            api_base=os.environ.get("AZURE_API_BASE"),
            api_version=os.environ.get("AZURE_API_VERSION"),
            response_format={"type": "json_object"}
        )

@@ -27,6 +27,7 @@ class OptimizeResult(BaseModel):
    should_optimize: str | None = None
    input_data: list[dict[str, Any]] | None = None
    output_data: list[dict[str, Any]] | None = None
    num_docs_analyzed: int | None = None
    cost: float | None = None
    error: str | None = None
    created_at: datetime

@@ -35,4 +36,25 @@ class OptimizeResult(BaseModel):
class OptimizeRequest(BaseModel):
    yaml_config: str
    step_name: str
    op_name: str
    op_name: str


class DecomposeRequest(BaseModel):
    yaml_config: str
    step_name: str
    op_name: str


class DecomposeResult(BaseModel):
    task_id: str
    status: TaskStatus
    decomposed_operations: list[dict[str, Any]] | None = None
    winning_directive: str | None = None
    candidates_evaluated: int | None = None
    original_outputs: list[dict[str, Any]] | None = None
    decomposed_outputs: list[dict[str, Any]] | None = None
    comparison_rationale: str | None = None
    cost: float | None = None
    error: str | None = None
    created_at: datetime
    completed_at: datetime | None = None

@@ -2,13 +2,23 @@ from typing import Any
import uuid
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
from docetl.runner import DSLRunner
from docetl.optimizers.fast_should_optimize import FastShouldOptimizeAnalyzer
from docetl.optimizers.fast_decomposer import FastDecomposer
import asyncio
from asyncio import Task
from rich.logging import RichHandler
import logging
import yaml
from datetime import datetime, timedelta
from enum import Enum
from server.app.models import OptimizeResult, TaskStatus, OptimizeRequest, PipelineRequest
from server.app.models import (
    OptimizeResult,
    TaskStatus,
    OptimizeRequest,
    PipelineRequest,
    DecomposeRequest,
    DecomposeResult,
)

# Setup logging
FORMAT = "%(message)s"

@ -18,10 +28,14 @@ logging.basicConfig(
|
|||
|
||||
router = APIRouter()
|
||||
|
||||
# Task storage
|
||||
# Task storage for optimize tasks
|
||||
tasks: dict[str, OptimizeResult] = {}
|
||||
asyncio_tasks: dict[str, Task] = {}
|
||||
|
||||
# Task storage for decompose tasks
|
||||
decompose_tasks: dict[str, DecomposeResult] = {}
|
||||
decompose_asyncio_tasks: dict[str, Task] = {}
|
||||
|
||||
# Configuration
|
||||
COMPLETED_TASK_TTL = timedelta(hours=1)
|
||||
|
||||
|
|
@ -30,50 +44,111 @@ async def cleanup_old_tasks():
|
|||
while True:
|
||||
try:
|
||||
current_time = datetime.now()
|
||||
task_ids_to_remove = []
|
||||
|
||||
# Clean up optimize tasks
|
||||
task_ids_to_remove = []
|
||||
for task_id, task in tasks.items():
|
||||
if (task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED] and
|
||||
task.completed_at and
|
||||
task.completed_at and
|
||||
current_time - task.completed_at > COMPLETED_TASK_TTL):
|
||||
task_ids_to_remove.append(task_id)
|
||||
|
||||
for task_id in task_ids_to_remove:
|
||||
del tasks[task_id]
|
||||
|
||||
|
||||
# Clean up decompose tasks
|
||||
decompose_ids_to_remove = []
|
||||
for task_id, task in decompose_tasks.items():
|
||||
if (task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED] and
|
||||
task.completed_at and
|
||||
current_time - task.completed_at > COMPLETED_TASK_TTL):
|
||||
decompose_ids_to_remove.append(task_id)
|
||||
for task_id in decompose_ids_to_remove:
|
||||
del decompose_tasks[task_id]
|
||||
|
||||
await asyncio.sleep(60)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error in cleanup task: {e}")
|
||||
await asyncio.sleep(60)
|
||||
|
||||
async def run_optimization(task_id: str, yaml_config: str, step_name: str, op_name: str):
|
||||
"""Execute the optimization task"""
|
||||
"""Execute the optimization task using fast single-LLM-call analysis."""
|
||||
try:
|
||||
tasks[task_id].status = TaskStatus.PROCESSING
|
||||
|
||||
# Run the actual optimization in a separate thread to not block
|
||||
runner = DSLRunner.from_yaml(yaml_config)
|
||||
should_optimize, input_data, output_data, cost = await asyncio.to_thread(
|
||||
runner.should_optimize,
|
||||
|
||||
# yaml_config is a file path, not YAML content - read and parse the file
|
||||
if yaml_config.endswith(".yaml") or yaml_config.endswith(".yml"):
|
||||
with open(yaml_config, "r") as f:
|
||||
config = yaml.safe_load(f)
|
||||
else:
|
||||
# Fallback: try parsing as YAML string
|
||||
config = yaml.safe_load(yaml_config)
|
||||
|
||||
# Validate that we got a dict
|
||||
if not isinstance(config, dict):
|
||||
raise ValueError(
|
||||
f"Invalid yaml_config: expected dict after parsing, got {type(config).__name__}. "
|
||||
f"Input: {str(yaml_config)[:200]}"
|
||||
)
|
||||
|
||||
# Get intermediate directory from config
|
||||
intermediate_dir = (
|
||||
config.get("pipeline", {})
|
||||
.get("output", {})
|
||||
.get("intermediate_dir")
|
||||
)
|
||||
|
||||
if not intermediate_dir:
|
||||
raise ValueError("No intermediate_dir configured in pipeline output")
|
||||
|
||||
# Find the operation config
|
||||
op_config = None
|
||||
for op in config.get("operations", []):
|
||||
if op.get("name") == op_name:
|
||||
op_config = op
|
||||
break
|
||||
|
||||
if not op_config:
|
||||
raise ValueError(f"Operation '{op_name}' not found in config")
|
||||
|
||||
# Get optimizer settings from config
|
||||
optimizer_model = (
|
||||
config.get("optimizer_config", {})
|
||||
.get("rewrite_agent_model", "gpt-5.1")
|
||||
)
|
||||
litellm_kwargs = (
|
||||
config.get("optimizer_config", {})
|
||||
.get("litellm_kwargs", {})
|
||||
)
|
||||
|
||||
# Create analyzer and run in thread pool
|
||||
analyzer = FastShouldOptimizeAnalyzer(
|
||||
intermediate_dir=intermediate_dir,
|
||||
optimizer_model=optimizer_model,
|
||||
litellm_kwargs=litellm_kwargs,
|
||||
)
|
||||
|
||||
should_optimize, output_data, num_docs_analyzed, cost = await asyncio.to_thread(
|
||||
analyzer.analyze,
|
||||
op_config,
|
||||
step_name,
|
||||
op_name,
|
||||
)
|
||||
|
||||
|
||||
# Update task result
|
||||
tasks[task_id].status = TaskStatus.COMPLETED
|
||||
tasks[task_id].should_optimize = should_optimize
|
||||
tasks[task_id].input_data = input_data
|
||||
tasks[task_id].input_data = None # We don't have input_data in fast mode
|
||||
tasks[task_id].output_data = output_data
|
||||
tasks[task_id].num_docs_analyzed = num_docs_analyzed
|
||||
tasks[task_id].cost = cost
|
||||
tasks[task_id].completed_at = datetime.now()
|
||||
|
||||
|
||||
except asyncio.CancelledError:
|
||||
runner.is_cancelled = True
|
||||
tasks[task_id].status = TaskStatus.CANCELLED
|
||||
tasks[task_id].completed_at = datetime.now()
|
||||
raise
|
||||
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
error_traceback = traceback.format_exc()
|
||||
|
|
@ -81,11 +156,10 @@ async def run_optimization(task_id: str, yaml_config: str, step_name: str, op_na
|
|||
tasks[task_id].error = f"{str(e)}\n{error_traceback}"
|
||||
tasks[task_id].completed_at = datetime.now()
|
||||
raise
|
||||
|
||||
|
||||
finally:
|
||||
if task_id in asyncio_tasks:
|
||||
del asyncio_tasks[task_id]
|
||||
runner.reset_env()
|
||||
|
||||
@router.on_event("startup")
|
||||
async def startup_event():
|
||||
|
|
@ -153,6 +227,140 @@ async def cancel_optimize_task(task_id: str):
|
|||
|
||||
return {"message": "Task cancelled successfully"}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Decompose endpoints
|
||||
# ============================================================================
|
||||
|
||||
async def run_decomposition(task_id: str, yaml_config: str, step_name: str, op_name: str):
|
||||
"""Execute the decomposition task using fast directive-based approach."""
|
||||
try:
|
||||
decompose_tasks[task_id].status = TaskStatus.PROCESSING
|
||||
|
||||
# yaml_config is a file path
|
||||
if not (yaml_config.endswith(".yaml") or yaml_config.endswith(".yml")):
|
||||
raise ValueError("yaml_config must be a path to a YAML file")
|
||||
|
||||
# Get optimizer settings from config
|
||||
with open(yaml_config, "r") as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
optimizer_model = (
|
||||
config.get("optimizer_config", {})
|
||||
.get("rewrite_agent_model", "gpt-5.1")
|
||||
)
|
||||
litellm_kwargs = (
|
||||
config.get("optimizer_config", {})
|
||||
.get("litellm_kwargs", {})
|
||||
)
|
||||
|
||||
# Create decomposer and run in thread pool
|
||||
decomposer = FastDecomposer(
|
||||
yaml_config_path=yaml_config,
|
||||
optimizer_model=optimizer_model,
|
||||
sample_size=5,
|
||||
litellm_kwargs=litellm_kwargs,
|
||||
)
|
||||
|
||||
result = await asyncio.to_thread(
|
||||
decomposer.decompose,
|
||||
step_name,
|
||||
op_name,
|
||||
)
|
||||
|
||||
# Update task result
|
||||
decompose_tasks[task_id].status = TaskStatus.COMPLETED
|
||||
decompose_tasks[task_id].decomposed_operations = result["decomposed_ops"]
|
||||
decompose_tasks[task_id].winning_directive = result["winning_directive"]
|
||||
decompose_tasks[task_id].candidates_evaluated = result["candidates_evaluated"]
|
||||
decompose_tasks[task_id].original_outputs = result["original_outputs"]
|
||||
decompose_tasks[task_id].decomposed_outputs = result["decomposed_outputs"]
|
||||
decompose_tasks[task_id].comparison_rationale = result["comparison_rationale"]
|
||||
decompose_tasks[task_id].cost = result["cost"]
|
||||
decompose_tasks[task_id].completed_at = datetime.now()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
decompose_tasks[task_id].status = TaskStatus.CANCELLED
|
||||
decompose_tasks[task_id].completed_at = datetime.now()
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
error_traceback = traceback.format_exc()
|
||||
decompose_tasks[task_id].status = TaskStatus.FAILED
|
||||
decompose_tasks[task_id].error = f"{str(e)}\n{error_traceback}"
|
||||
decompose_tasks[task_id].completed_at = datetime.now()
|
||||
raise
|
||||
|
||||
finally:
|
||||
if task_id in decompose_asyncio_tasks:
|
||||
del decompose_asyncio_tasks[task_id]
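
The background task above assumes that `FastDecomposer.decompose(step_name, op_name)` returns a plain dict; the decomposer itself is not part of this diff, but from the keys read when the task completes, the return value would look roughly like the sketch below (field names taken from the code above, values purely illustrative).

```python
# Hypothetical sketch of the dict run_decomposition expects back from
# FastDecomposer.decompose(); only the keys are grounded in the handler above.
example_decompose_result = {
    "decomposed_ops": [                 # replacement operation configs, in pipeline order
        {"name": "extract_claims", "type": "map", "prompt": "..."},
        {"name": "summarize_claims", "type": "reduce", "prompt": "..."},
    ],
    "winning_directive": "chaining",    # or "original" if no rewrite beat the baseline
    "candidates_evaluated": 4,
    "original_outputs": [{"doc_id": 1, "summary": "..."}],
    "decomposed_outputs": [{"doc_id": 1, "summary": "..."}],
    "comparison_rationale": "The chained version preserved more detail per document.",
    "cost": 0.042,                      # accumulated LLM spend, in dollars
}
```
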
|
||||
|
||||
|
||||
@router.post("/decompose", status_code=202)
|
||||
async def submit_decompose_task(request: DecomposeRequest):
|
||||
"""Submit a new decomposition task"""
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
# Create task record
|
||||
decompose_tasks[task_id] = DecomposeResult(
|
||||
task_id=task_id,
|
||||
status=TaskStatus.PENDING,
|
||||
created_at=datetime.now()
|
||||
)
|
||||
|
||||
# Create and store the asyncio task
|
||||
task = asyncio.create_task(
|
||||
run_decomposition(
|
||||
task_id,
|
||||
request.yaml_config,
|
||||
request.step_name,
|
||||
request.op_name
|
||||
)
|
||||
)
|
||||
decompose_asyncio_tasks[task_id] = task
|
||||
|
||||
return {"task_id": task_id}
|
||||
|
||||
|
||||
@router.get("/decompose/{task_id}")
|
||||
async def get_decompose_status(task_id: str) -> DecomposeResult:
|
||||
"""Get the current status of a decomposition task"""
|
||||
if task_id not in decompose_tasks:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Task not found or has been cleaned up"
|
||||
)
|
||||
|
||||
return decompose_tasks[task_id]
|
||||
|
||||
|
||||
@router.post("/decompose/{task_id}/cancel")
|
||||
async def cancel_decompose_task(task_id: str):
|
||||
"""Cancel a running decomposition task"""
|
||||
if task_id not in decompose_tasks:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Task not found or has been cleaned up"
|
||||
)
|
||||
|
||||
if task_id not in decompose_asyncio_tasks:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Task already finished or cannot be cancelled"
|
||||
)
|
||||
|
||||
asyncio_task = decompose_asyncio_tasks[task_id]
|
||||
asyncio_task.cancel()
|
||||
|
||||
try:
|
||||
await asyncio_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
return {"message": "Decomposition task cancelled successfully"}
|
||||
|
||||
|
||||
# Keep the original run_pipeline endpoint
|
||||
@router.post("/run_pipeline")
|
||||
def run_pipeline(request: PipelineRequest) -> dict[str, Any]:
|
||||
|
|
@ -169,6 +377,12 @@ def run_pipeline(request: PipelineRequest) -> dict[str, Any]:
|
|||
|
||||
@router.websocket("/ws/run_pipeline/{client_id}")
|
||||
async def websocket_run_pipeline(websocket: WebSocket, client_id: str):
|
||||
"""
|
||||
WebSocket endpoint for running pipelines with real-time output streaming.
|
||||
|
||||
Note: The old 'optimize' flag for full pipeline optimization has been removed.
|
||||
Use the /decompose endpoint for fast operation decomposition instead.
|
||||
"""
|
||||
await websocket.accept()
|
||||
runner = None
|
||||
try:
|
||||
|
|
@ -178,19 +392,8 @@ async def websocket_run_pipeline(websocket: WebSocket, client_id: str):
|
|||
if config.get("clear_intermediate", False):
|
||||
runner.clear_intermediate()
|
||||
|
||||
if config.get("optimize", False):
|
||||
logging.info(f"Optimizing pipeline with model {config.get('optimizer_model', 'gpt-4o')}")
|
||||
|
||||
# Set the runner config to the optimizer config
|
||||
runner.config["optimizer_config"]["rewrite_agent_model"] = config.get("optimizer_model", "gpt-4o")
|
||||
runner.config["optimizer_config"]["judge_agent_model"] = config.get("optimizer_model", "gpt-4o-mini")
|
||||
|
||||
async def run_pipeline():
|
||||
return await asyncio.to_thread(runner.optimize, return_pipeline=False)
|
||||
|
||||
else:
|
||||
async def run_pipeline():
|
||||
return await asyncio.to_thread(runner.load_run_save)
|
||||
async def run_pipeline():
|
||||
return await asyncio.to_thread(runner.load_run_save)
|
||||
|
||||
pipeline_task = asyncio.create_task(run_pipeline())
|
||||
|
||||
|
|
@ -198,18 +401,6 @@ async def websocket_run_pipeline(websocket: WebSocket, client_id: str):
|
|||
console_output = runner.console.file.getvalue()
|
||||
await websocket.send_json({"type": "output", "data": console_output})
|
||||
|
||||
if config.get("optimize", False):
|
||||
optimizer_progress = runner.console.get_optimizer_progress()
|
||||
rationale = runner.console.optimizer_rationale
|
||||
await websocket.send_json({
|
||||
"type": "optimizer_progress",
|
||||
"status": optimizer_progress[0],
|
||||
"progress": optimizer_progress[1],
|
||||
"rationale": rationale[1] if rationale is not None else "",
|
||||
"should_optimize": rationale[0] if rationale is not None else False,
|
||||
"validator_prompt": rationale[2] if rationale is not None else ""
|
||||
})
|
||||
|
||||
# Check for incoming messages from the user
|
||||
try:
|
||||
user_message = await asyncio.wait_for(
|
||||
|
|
@ -219,8 +410,7 @@ async def websocket_run_pipeline(websocket: WebSocket, client_id: str):
|
|||
if user_message == "kill":
|
||||
runner.console.log("Stopping process...")
|
||||
runner.is_cancelled = True
|
||||
|
||||
|
||||
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"message": "Process stopped by user request"
|
||||
|
|
@ -250,41 +440,16 @@ async def websocket_run_pipeline(websocket: WebSocket, client_id: str):
|
|||
# Sleep for a short duration to ensure all output is captured
|
||||
await asyncio.sleep(3)
|
||||
|
||||
# If optimize is true, send back the optimized operations
|
||||
if config.get("optimize", False):
|
||||
optimized_config, cost = result
|
||||
|
||||
# Send the operations back in order
|
||||
new_pipeline_steps = optimized_config["pipeline"]["steps"]
|
||||
new_pipeline_op_name_to_op_map = {op["name"]: op for op in optimized_config["operations"]}
|
||||
new_ops_in_order = []
|
||||
for new_step in new_pipeline_steps:
|
||||
for op in new_step.get("operations", []):
|
||||
if op not in new_ops_in_order:
|
||||
new_ops_in_order.append(new_pipeline_op_name_to_op_map[op])
|
||||
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "result",
|
||||
"data": {
|
||||
"message": "Pipeline executed successfully",
|
||||
"cost": cost,
|
||||
"optimized_ops": new_ops_in_order,
|
||||
"yaml_config": config["yaml_config"],
|
||||
},
|
||||
}
|
||||
)
|
||||
else:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "result",
|
||||
"data": {
|
||||
"message": "Pipeline executed successfully",
|
||||
"cost": result,
|
||||
"yaml_config": config["yaml_config"],
|
||||
},
|
||||
}
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "result",
|
||||
"data": {
|
||||
"message": "Pipeline executed successfully",
|
||||
"cost": result,
|
||||
"yaml_config": config["yaml_config"],
|
||||
},
|
||||
}
|
||||
)
|
||||
except WebSocketDisconnect:
|
||||
if runner is not None:
|
||||
runner.reset_env()
|
||||
|
|
@ -299,3 +464,159 @@ async def websocket_run_pipeline(websocket: WebSocket, client_id: str):
|
|||
if runner is not None:
|
||||
runner.reset_env()
|
||||
await websocket.close()
|
||||
|
||||
|
||||
@router.websocket("/ws/decompose/{client_id}")
|
||||
async def websocket_decompose(websocket: WebSocket, client_id: str):
|
||||
"""
|
||||
WebSocket endpoint for fast operation decomposition with real-time output streaming.
|
||||
|
||||
Expects JSON message with:
|
||||
- yaml_config: Path to the pipeline YAML config file
|
||||
- step_name: Name of the pipeline step
|
||||
- op_name: Name of the operation to decompose
|
||||
"""
|
||||
await websocket.accept()
|
||||
decomposer = None
|
||||
|
||||
try:
|
||||
config = await websocket.receive_json()
|
||||
yaml_config = config["yaml_config"]
|
||||
step_name = config["step_name"]
|
||||
op_name = config["op_name"]
|
||||
|
||||
# Validate yaml_config is a path
|
||||
if not (yaml_config.endswith(".yaml") or yaml_config.endswith(".yml")):
|
||||
raise ValueError("yaml_config must be a path to a YAML file")
|
||||
|
||||
# Get optimizer settings from config
|
||||
with open(yaml_config, "r") as f:
|
||||
pipeline_config = yaml.safe_load(f)
|
||||
|
||||
optimizer_model = (
|
||||
pipeline_config.get("optimizer_config", {})
|
||||
.get("rewrite_agent_model", "gpt-5.1")
|
||||
)
|
||||
litellm_kwargs = (
|
||||
pipeline_config.get("optimizer_config", {})
|
||||
.get("litellm_kwargs", {})
|
||||
)
|
||||
|
||||
# Create a ThreadSafeConsole for streaming output
|
||||
from docetl.console import ThreadSafeConsole
|
||||
console = ThreadSafeConsole(
|
||||
force_terminal=True,
|
||||
soft_wrap=True,
|
||||
highlight=False,
|
||||
log_path=False,
|
||||
color_system="truecolor",
|
||||
width=120,
|
||||
style="bright_white on black",
|
||||
record=True,
|
||||
)
|
||||
|
||||
# Create decomposer with the console
|
||||
decomposer = FastDecomposer(
|
||||
yaml_config_path=yaml_config,
|
||||
optimizer_model=optimizer_model,
|
||||
sample_size=5,
|
||||
litellm_kwargs=litellm_kwargs,
|
||||
console=console,
|
||||
)
|
||||
|
||||
# Run decomposition in a thread
|
||||
async def run_decomposition():
|
||||
return await asyncio.to_thread(
|
||||
decomposer.decompose,
|
||||
step_name,
|
||||
op_name,
|
||||
)
|
||||
|
||||
decompose_task = asyncio.create_task(run_decomposition())
|
||||
|
||||
# Stream console output while decomposition runs
|
||||
accumulated_output = ""
|
||||
while not decompose_task.done():
|
||||
# get_output() processes carriage returns and clears the buffer
|
||||
new_output = console.get_output()
|
||||
if new_output:
|
||||
# Handle spinner updates: if new output doesn't start with newline
|
||||
# and accumulated doesn't end with newline, replace the last line
|
||||
if accumulated_output and not accumulated_output.endswith('\n') and not new_output.startswith('\n'):
|
||||
# Find the last newline in accumulated output
|
||||
last_newline = accumulated_output.rfind('\n')
|
||||
if last_newline >= 0:
|
||||
# Replace everything after the last newline
|
||||
accumulated_output = accumulated_output[:last_newline + 1] + new_output
|
||||
else:
|
||||
# No newline in accumulated, replace everything
|
||||
accumulated_output = new_output
|
||||
else:
|
||||
accumulated_output += new_output
|
||||
await websocket.send_json({"type": "output", "data": accumulated_output})
|
||||
|
||||
# Check for kill message
|
||||
try:
|
||||
user_message = await asyncio.wait_for(
|
||||
websocket.receive_json(), timeout=0.1
|
||||
)
|
||||
if user_message == "kill":
|
||||
console.log("[red]Stopping decomposition...[/red]")
|
||||
decompose_task.cancel()
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"message": "Decomposition stopped by user request"
|
||||
})
|
||||
raise Exception("Process stopped by user request")
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
except asyncio.CancelledError:
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"message": "Decomposition stopped by user request"
|
||||
})
|
||||
raise
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Get final result
|
||||
result = await decompose_task
|
||||
|
||||
# Send any remaining console output
|
||||
final_output = console.get_output()
|
||||
if final_output:
|
||||
accumulated_output += final_output
|
||||
await websocket.send_json({"type": "output", "data": accumulated_output})
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Send the result
|
||||
await websocket.send_json({
|
||||
"type": "result",
|
||||
"data": {
|
||||
"decomposed_operations": result["decomposed_ops"],
|
||||
"winning_directive": result["winning_directive"],
|
||||
"candidates_evaluated": result["candidates_evaluated"],
|
||||
"original_outputs": result["original_outputs"],
|
||||
"decomposed_outputs": result["decomposed_outputs"],
|
||||
"comparison_rationale": result["comparison_rationale"],
|
||||
"cost": result["cost"],
|
||||
},
|
||||
})
|
||||
|
||||
except WebSocketDisconnect:
|
||||
print(f"Decompose client {client_id} disconnected")
|
||||
except Exception as e:
|
||||
import traceback
|
||||
error_traceback = traceback.format_exc()
|
||||
print(f"Decompose error:\n{error_traceback}")
|
||||
try:
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"data": str(e),
|
||||
"traceback": error_traceback
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
await websocket.close()
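
For the streaming variant, a client connects to `/ws/decompose/{client_id}`, sends one JSON config message, and then receives `output` frames (cumulative console text) until a final `result` or `error` frame arrives. A rough client sketch using the third-party `websockets` package is shown below; the library choice and the URI are assumptions, any async WebSocket client would work the same way.

```python
import asyncio
import json
import websockets  # assumed client library, not part of this diff


async def decompose_over_ws(uri: str = "ws://localhost:8000/ws/decompose/demo-client"):
    async with websockets.connect(uri) as ws:
        # A single config message kicks off the decomposition.
        await ws.send(json.dumps({
            "yaml_config": "/path/to/pipeline.yaml",
            "step_name": "data_processing",
            "op_name": "extract_entities",
        }))
        # To abort early, the handler above accepts: await ws.send(json.dumps("kill"))
        while True:
            msg = json.loads(await ws.recv())
            if msg["type"] == "output":
                # Each frame carries the full accumulated console text,
                # so a renderer should replace its display rather than append.
                print(msg["data"][-200:])
            elif msg["type"] == "result":
                return msg["data"]
            elif msg["type"] == "error":
                # The server uses "message" for user-initiated stops and
                # "data" for unexpected exceptions.
                raise RuntimeError(msg.get("data") or msg.get("message"))


# asyncio.run(decompose_over_ws())
```
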
|
||||
|
|
|
|||
|
|
@ -5,4 +5,9 @@ export const API_ROUTES = {
|
|||
CANCEL: (taskId: string) =>
|
||||
`/api/shouldOptimize?taskId=${taskId}&cancel=true`,
|
||||
},
|
||||
DECOMPOSE: {
|
||||
SUBMIT: "/api/decompose",
|
||||
STATUS: (taskId: string) => `/api/decompose?taskId=${taskId}`,
|
||||
CANCEL: (taskId: string) => `/api/decompose?taskId=${taskId}&cancel=true`,
|
||||
},
|
||||
} as const;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,79 @@
|
|||
import { NextRequest, NextResponse } from "next/server";
|
||||
|
||||
const FASTAPI_URL = `${
|
||||
process.env.NEXT_PUBLIC_BACKEND_HTTPS ? "https" : "http"
|
||||
}://${process.env.NEXT_PUBLIC_BACKEND_HOST}:${
|
||||
process.env.NEXT_PUBLIC_BACKEND_PORT
|
||||
}`;
|
||||
|
||||
// Helper to handle errors consistently
|
||||
const handleError = (error: unknown, status = 500) => {
|
||||
const message =
|
||||
error instanceof Error ? error.message : "Internal server error";
|
||||
return NextResponse.json({ error: message }, { status });
|
||||
};
|
||||
|
||||
// Helper to proxy requests to FastAPI
|
||||
async function proxyRequest(path: string, init?: RequestInit) {
|
||||
const response = await fetch(`${FASTAPI_URL}${path}`, {
|
||||
...init,
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
...init?.headers,
|
||||
},
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const error = await response.text();
|
||||
throw new Error(`FastAPI server error: ${error}`);
|
||||
}
|
||||
|
||||
return response.json();
|
||||
}
|
||||
|
||||
export async function POST(request: NextRequest): Promise<NextResponse> {
|
||||
try {
|
||||
// Extract task ID from the URL if it exists
|
||||
const taskId = request.nextUrl.searchParams.get("taskId");
|
||||
const isCancel = request.nextUrl.searchParams.get("cancel") === "true";
|
||||
|
||||
// Handle different POST scenarios
|
||||
if (taskId) {
|
||||
if (isCancel) {
|
||||
// Cancel task
|
||||
const data = await proxyRequest(`/decompose/${taskId}/cancel`, {
|
||||
method: "POST",
|
||||
});
|
||||
return NextResponse.json(data);
|
||||
}
|
||||
// Invalid request with taskId but no cancel
|
||||
return handleError(new Error("Invalid request"), 400);
|
||||
}
|
||||
|
||||
// Submit new task
|
||||
const body = await request.json();
|
||||
const data = await proxyRequest("/decompose", {
|
||||
method: "POST",
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
return NextResponse.json(data, { status: 202 });
|
||||
} catch (error) {
|
||||
return handleError(error);
|
||||
}
|
||||
}
|
||||
|
||||
export async function GET(request: NextRequest): Promise<NextResponse> {
|
||||
try {
|
||||
// Extract task ID from the URL
|
||||
const taskId = request.nextUrl.searchParams.get("taskId");
|
||||
|
||||
if (!taskId) {
|
||||
return handleError(new Error("Task ID is required"), 400);
|
||||
}
|
||||
|
||||
const data = await proxyRequest(`/decompose/${taskId}`);
|
||||
return NextResponse.json(data);
|
||||
} catch (error) {
|
||||
return handleError(error);
|
||||
}
|
||||
}
|
||||
|
|
@ -33,6 +33,77 @@ export function getNamespaceDir(homeDir: string, namespace: string) {
|
|||
return path.join(homeDir, ".docetl", namespace);
|
||||
}
|
||||
|
||||
/**
|
||||
* Safely parse a JSON value - returns the parsed value if it's a string,
|
||||
* or the original value if it's already an object.
|
||||
* @param value - The value to parse
|
||||
* @param fieldName - Name of the field (for error messages)
|
||||
* @returns The parsed value
|
||||
*/
|
||||
function safeParseJSON(
|
||||
value: unknown,
|
||||
fieldName: string
|
||||
): Record<string, unknown> | unknown[] {
|
||||
if (value === null || value === undefined) {
|
||||
throw new Error(`Field '${fieldName}' is null or undefined`);
|
||||
}
|
||||
|
||||
// If it's already an object or array, return as-is
|
||||
if (typeof value === "object") {
|
||||
return value as Record<string, unknown> | unknown[];
|
||||
}
|
||||
|
||||
// If it's a string, try to parse it
|
||||
if (typeof value === "string") {
|
||||
try {
|
||||
return JSON.parse(value);
|
||||
} catch (e) {
|
||||
throw new Error(
|
||||
`Field '${fieldName}' contains invalid JSON: ${value.slice(0, 100)}${value.length > 100 ? "..." : ""}`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
throw new Error(
|
||||
`Field '${fieldName}' has unexpected type '${typeof value}': ${String(value).slice(0, 100)}`
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Recursively sanitize an object for YAML serialization.
|
||||
* Converts any nested [object Object] strings to proper objects.
|
||||
*/
|
||||
function sanitizeForYaml(obj: unknown, path: string = ""): unknown {
|
||||
if (obj === null || obj === undefined) {
|
||||
return obj;
|
||||
}
|
||||
|
||||
if (typeof obj === "string") {
|
||||
// Check for [object Object] which indicates a serialization issue
|
||||
if (obj === "[object Object]") {
|
||||
console.warn(
|
||||
`Warning: Found '[object Object]' string at path '${path}'. This indicates a serialization issue.`
|
||||
);
|
||||
return {}; // Return empty object as fallback
|
||||
}
|
||||
return obj;
|
||||
}
|
||||
|
||||
if (Array.isArray(obj)) {
|
||||
return obj.map((item, idx) => sanitizeForYaml(item, `${path}[${idx}]`));
|
||||
}
|
||||
|
||||
if (typeof obj === "object") {
|
||||
const result: Record<string, unknown> = {};
|
||||
for (const [key, value] of Object.entries(obj)) {
|
||||
result[key] = sanitizeForYaml(value, path ? `${path}.${key}` : key);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
return obj;
|
||||
}
|
||||
|
||||
export function generatePipelineConfig(
|
||||
namespace: string,
|
||||
default_model: string,
|
||||
|
|
@ -82,9 +153,18 @@ export function generatePipelineConfig(
|
|||
const litellm_completion_kwargs =
|
||||
op.otherKwargs?.litellm_completion_kwargs;
|
||||
if (litellm_completion_kwargs) {
|
||||
op.otherKwargs.litellm_completion_kwargs = JSON.parse(
|
||||
litellm_completion_kwargs
|
||||
);
|
||||
try {
|
||||
op.otherKwargs.litellm_completion_kwargs = safeParseJSON(
|
||||
litellm_completion_kwargs,
|
||||
`${op.name}.litellm_completion_kwargs`
|
||||
);
|
||||
} catch (e) {
|
||||
console.warn(
|
||||
`Warning: Could not parse litellm_completion_kwargs for operation '${op.name}':`,
|
||||
e
|
||||
);
|
||||
// Keep the original value if parsing fails
|
||||
}
|
||||
}
|
||||
|
||||
const newOp: Record<string, unknown> = {
|
||||
|
|
@ -175,10 +255,14 @@ export function generatePipelineConfig(
|
|||
// If it's a sample operation with custom method, parse the samples as key-value pairs
|
||||
if (op.type === "sample" && op.otherKwargs?.method === "custom") {
|
||||
try {
|
||||
newOp.samples = JSON.parse(op.otherKwargs.samples);
|
||||
newOp.samples = safeParseJSON(
|
||||
op.otherKwargs.samples,
|
||||
`${op.name}.samples`
|
||||
);
|
||||
} catch (error) {
|
||||
console.warn(
|
||||
"Failed to parse custom samples as JSON, using raw value"
|
||||
`Failed to parse custom samples for operation '${op.name}':`,
|
||||
error
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -217,14 +301,12 @@ export function generatePipelineConfig(
|
|||
// Fix type errors by asserting the pipeline config type
|
||||
let pipelineConfig: any = {
|
||||
from_docwrangler: true,
|
||||
optimizer_model: optimizerModel,
|
||||
datasets,
|
||||
default_model,
|
||||
...(enable_observability && {
|
||||
optimizer_config: {
|
||||
force_decompose: true,
|
||||
},
|
||||
}),
|
||||
optimizer_config: {
|
||||
rewrite_agent_model: optimizerModel,
|
||||
...(enable_observability && { force_decompose: true }),
|
||||
},
|
||||
operations: updatedOperations,
|
||||
pipeline: {
|
||||
steps: [
|
||||
|
|
@ -308,7 +390,9 @@ export function generatePipelineConfig(
|
|||
const outputOpName = operationsToRun[currentOpIndex].name;
|
||||
outputPath = path.join(outputBase, "data_processing", outputOpName + ".json");
|
||||
|
||||
const yamlString = yaml.dump(pipelineConfig);
|
||||
// Sanitize the config before serializing to YAML to catch any [object Object] issues
|
||||
const sanitizedConfig = sanitizeForYaml(pipelineConfig, "pipelineConfig");
|
||||
const yamlString = yaml.dump(sanitizedConfig);
|
||||
|
||||
console.log(yamlString);
|
||||
|
||||
|
|
|
|||
|
|
@ -104,6 +104,7 @@ export interface OptimizeResult {
|
|||
should_optimize?: string;
|
||||
input_data?: Array<Record<string, unknown>>;
|
||||
output_data?: Array<Record<string, unknown>>;
|
||||
num_docs_analyzed?: number;
|
||||
cost?: number;
|
||||
error?: string;
|
||||
created_at: string;
|
||||
|
|
@ -114,3 +115,24 @@ export interface APIKey {
|
|||
name: string;
|
||||
value: string;
|
||||
}
|
||||
|
||||
export interface DecomposeRequest {
|
||||
yaml_config: string;
|
||||
step_name: string;
|
||||
op_name: string;
|
||||
}
|
||||
|
||||
export interface DecomposeResult {
|
||||
task_id: string;
|
||||
status: TaskStatus;
|
||||
decomposed_operations?: Array<Record<string, unknown>>;
|
||||
winning_directive?: string;
|
||||
candidates_evaluated?: number;
|
||||
original_outputs?: Array<Record<string, unknown>>;
|
||||
decomposed_outputs?: Array<Record<string, unknown>>;
|
||||
comparison_rationale?: string;
|
||||
cost?: number;
|
||||
error?: string;
|
||||
created_at: string;
|
||||
completed_at?: string;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,12 +24,14 @@ interface AnsiRendererProps {
|
|||
text: string;
|
||||
readyState: number;
|
||||
setTerminalOutput: (text: string) => void;
|
||||
isDecomposing?: boolean;
|
||||
}
|
||||
|
||||
const AnsiRenderer: React.FC<AnsiRendererProps> = ({
|
||||
text,
|
||||
readyState,
|
||||
setTerminalOutput,
|
||||
isDecomposing = false,
|
||||
}) => {
|
||||
const html = convert.toHtml(text);
|
||||
const scrollRef = useRef<HTMLDivElement>(null);
|
||||
|
|
@ -51,7 +53,9 @@ const AnsiRenderer: React.FC<AnsiRendererProps> = ({
|
|||
}
|
||||
};
|
||||
|
||||
const isWebSocketClosed = readyState === WebSocket.CLOSED;
|
||||
// Consider connected if either WebSocket is open OR decomposing is in progress
|
||||
const isConnected = readyState === WebSocket.OPEN || isDecomposing;
|
||||
const isWebSocketClosed = readyState === WebSocket.CLOSED && !isDecomposing;
|
||||
|
||||
return (
|
||||
<div
|
||||
|
|
@ -92,9 +96,11 @@ const AnsiRenderer: React.FC<AnsiRendererProps> = ({
|
|||
)}
|
||||
</div>
|
||||
<div className="flex justify-between items-center text-xs text-gray-500">
|
||||
<div className={isWebSocketClosed ? "text-red-500" : ""}>
|
||||
<div className={isWebSocketClosed ? "text-red-500" : isDecomposing ? "text-blue-400" : ""}>
|
||||
Status:{" "}
|
||||
{readyState === WebSocket.CONNECTING
|
||||
{isDecomposing
|
||||
? "Decomposing..."
|
||||
: readyState === WebSocket.CONNECTING
|
||||
? "Connecting"
|
||||
: readyState === WebSocket.OPEN
|
||||
? "Connected"
|
||||
|
|
@ -106,10 +112,7 @@ const AnsiRenderer: React.FC<AnsiRendererProps> = ({
|
|||
</div>
|
||||
<button
|
||||
onClick={() => setTerminalOutput("")}
|
||||
className={`hover:text-white transition-colors ${
|
||||
isWebSocketClosed ? "cursor-not-allowed opacity-50" : ""
|
||||
}`}
|
||||
disabled={isWebSocketClosed}
|
||||
className="hover:text-white transition-colors"
|
||||
>
|
||||
Clear
|
||||
</button>
|
||||
|
|
|
|||
|
|
@ -0,0 +1,271 @@
|
|||
import React from "react";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
DialogDescription,
|
||||
} from "@/components/ui/dialog";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/ui/table";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { ChevronLeft, ChevronRight, Check, X } from "lucide-react";
|
||||
import { DecomposeResult } from "@/app/types";
|
||||
|
||||
interface DecompositionComparisonDialogProps {
|
||||
isOpen: boolean;
|
||||
result: DecomposeResult | null;
|
||||
operationName?: string;
|
||||
onOpenChange: (open: boolean) => void;
|
||||
onApply: () => void;
|
||||
onCancel: () => void;
|
||||
}
|
||||
|
||||
export const DecompositionComparisonDialog: React.FC<
|
||||
DecompositionComparisonDialogProps
|
||||
> = ({ isOpen, result, operationName, onOpenChange, onApply, onCancel }) => {
|
||||
const [currentPage, setCurrentPage] = React.useState(1);
|
||||
|
||||
const originalOutputs = result?.original_outputs || [];
|
||||
const decomposedOutputs = result?.decomposed_outputs || [];
|
||||
const maxRows = Math.max(originalOutputs.length, decomposedOutputs.length);
|
||||
|
||||
const renderOutputTable = (
|
||||
data: Array<Record<string, unknown>>,
|
||||
title: string
|
||||
) => {
|
||||
if (!data.length) {
|
||||
return (
|
||||
<div className="text-center text-muted-foreground py-8">
|
||||
No outputs available
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const columns = Object.keys(data[0]).filter(
|
||||
(column) => !column.startsWith("_observability")
|
||||
);
|
||||
|
||||
const currentRow = data[currentPage - 1];
|
||||
if (!currentRow) return null;
|
||||
|
||||
return (
|
||||
<div className="space-y-2">
|
||||
<h4 className="font-medium text-sm text-muted-foreground">{title}</h4>
|
||||
<div className="border rounded-md">
|
||||
<div className="max-h-[250px] overflow-auto">
|
||||
<Table className="relative w-full border-collapse">
|
||||
<TableHeader>
|
||||
<TableRow className="sticky top-0 bg-background z-10 border-b">
|
||||
{columns.map((column) => (
|
||||
<TableHead
|
||||
key={column}
|
||||
className="h-8 px-3 text-left align-middle bg-background text-xs"
|
||||
>
|
||||
{column}
|
||||
</TableHead>
|
||||
))}
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
<TableRow>
|
||||
{columns.map((column) => (
|
||||
<TableCell key={column} className="p-2 align-top">
|
||||
<pre className="whitespace-pre-wrap font-mono text-xs text-left max-w-[300px]">
|
||||
{typeof currentRow[column] === "object"
|
||||
? JSON.stringify(currentRow[column], null, 2)
|
||||
: String(currentRow[column] ?? "")}
|
||||
</pre>
|
||||
</TableCell>
|
||||
))}
|
||||
</TableRow>
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
React.useEffect(() => {
|
||||
if (isOpen) {
|
||||
setCurrentPage(1);
|
||||
}
|
||||
}, [isOpen]);
|
||||
|
||||
if (!result) return null;
|
||||
|
||||
const isOriginalWinner = result.winning_directive === "original";
|
||||
|
||||
return (
|
||||
<Dialog open={isOpen} onOpenChange={onOpenChange}>
|
||||
<DialogContent className="max-w-7xl max-h-[90vh] flex flex-col">
|
||||
<DialogHeader className="flex-shrink-0 border-b pb-4">
|
||||
<DialogTitle className="text-xl">
|
||||
Review Decomposition Results
|
||||
</DialogTitle>
|
||||
<DialogDescription>
|
||||
Compare the original operation outputs with the decomposed version
|
||||
before applying changes.
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<div className="flex-1 overflow-y-auto py-4 space-y-6">
|
||||
{/* Summary */}
|
||||
<div className="p-4 bg-muted rounded-lg space-y-2">
|
||||
<div className="flex items-center gap-4 flex-wrap">
|
||||
{operationName && (
|
||||
<div className="flex items-center">
|
||||
<span className="font-medium text-sm mr-2">Operation:</span>
|
||||
<span className="bg-primary/15 text-primary rounded-md px-2 py-1 text-sm">
|
||||
{operationName}
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
<div className="flex items-center">
|
||||
<span className="font-medium text-sm mr-2">
|
||||
Winning Strategy:
|
||||
</span>
|
||||
<span
|
||||
className={`rounded-md px-2 py-1 text-sm ${
|
||||
isOriginalWinner
|
||||
? "bg-yellow-100 text-yellow-800"
|
||||
: "bg-green-100 text-green-800"
|
||||
}`}
|
||||
>
|
||||
{result.winning_directive}
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex items-center">
|
||||
<span className="font-medium text-sm mr-2">
|
||||
Candidates Evaluated:
|
||||
</span>
|
||||
<span className="text-sm">{result.candidates_evaluated}</span>
|
||||
</div>
|
||||
{result.cost && (
|
||||
<div className="flex items-center">
|
||||
<span className="font-medium text-sm mr-2">Cost:</span>
|
||||
<span className="text-sm">${result.cost.toFixed(4)}</span>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Comparison Rationale */}
|
||||
{result.comparison_rationale && (
|
||||
<div className="space-y-2">
|
||||
<h3 className="font-medium text-base">Why This Was Chosen</h3>
|
||||
<div className="p-3 bg-blue-50 border border-blue-200 rounded-lg text-sm">
|
||||
{result.comparison_rationale}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Side by side comparison */}
|
||||
<div className="space-y-4">
|
||||
<div className="flex items-center justify-between">
|
||||
<h3 className="font-medium text-base">
|
||||
Output Comparison (Sample {currentPage} of {maxRows || 1})
|
||||
</h3>
|
||||
<div className="flex items-center space-x-1">
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() =>
|
||||
setCurrentPage((prev) => Math.max(prev - 1, 1))
|
||||
}
|
||||
disabled={currentPage === 1}
|
||||
className="px-2 py-1"
|
||||
>
|
||||
<ChevronLeft className="h-4 w-4" />
|
||||
Previous
|
||||
</Button>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() =>
|
||||
setCurrentPage((prev) => Math.min(prev + 1, maxRows))
|
||||
}
|
||||
disabled={currentPage === maxRows || maxRows === 0}
|
||||
className="px-2 py-1"
|
||||
>
|
||||
Next
|
||||
<ChevronRight className="h-4 w-4" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
{renderOutputTable(originalOutputs, "Original Output")}
|
||||
{renderOutputTable(
|
||||
decomposedOutputs,
|
||||
`Decomposed Output (${result.winning_directive})`
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Decomposed Operations Preview */}
|
||||
{result.decomposed_operations &&
|
||||
result.decomposed_operations.length > 0 &&
|
||||
!isOriginalWinner && (
|
||||
<div className="space-y-2">
|
||||
<h3 className="font-medium text-base">
|
||||
New Operations ({result.decomposed_operations.length})
|
||||
</h3>
|
||||
<div className="space-y-2">
|
||||
{result.decomposed_operations.map((op, idx) => (
|
||||
<div
|
||||
key={idx}
|
||||
className="p-3 border rounded-lg bg-muted/50"
|
||||
>
|
||||
<div className="flex items-center gap-2 mb-2">
|
||||
<span className="font-medium text-sm">
|
||||
{(op.name as string) || `Operation ${idx + 1}`}
|
||||
</span>
|
||||
<span className="text-xs bg-primary/10 text-primary px-2 py-0.5 rounded">
|
||||
{op.type as string}
|
||||
</span>
|
||||
</div>
|
||||
{op.prompt && (
|
||||
<pre className="text-xs font-mono bg-background p-2 rounded border max-h-[100px] overflow-auto">
|
||||
{String(op.prompt).slice(0, 500)}
|
||||
{String(op.prompt).length > 500 ? "..." : ""}
|
||||
</pre>
|
||||
)}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="flex justify-end items-center gap-3 pt-4 border-t mt-4">
|
||||
<Button variant="outline" onClick={onCancel}>
|
||||
<X className="h-4 w-4 mr-2" />
|
||||
Keep Original
|
||||
</Button>
|
||||
<Button
|
||||
onClick={onApply}
|
||||
disabled={isOriginalWinner}
|
||||
title={
|
||||
isOriginalWinner
|
||||
? "Original was determined to be the best option"
|
||||
: undefined
|
||||
}
|
||||
>
|
||||
<Check className="h-4 w-4 mr-2" />
|
||||
Apply Decomposition
|
||||
</Button>
|
||||
</div>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
);
|
||||
};
|
||||
|
||||
DecompositionComparisonDialog.displayName = "DecompositionComparisonDialog";
|
||||
|
|
@ -39,7 +39,6 @@ import { Skeleton } from "@/components/ui/skeleton";
|
|||
import { debounce } from "lodash";
|
||||
import { Guardrails, GleaningConfig } from "./operations/args";
|
||||
import createOperationComponent from "./operations/components";
|
||||
import { useWebSocket } from "@/contexts/WebSocketContext";
|
||||
import { Badge } from "./ui/badge";
|
||||
import {
|
||||
Popover,
|
||||
|
|
@ -161,7 +160,7 @@ const OperationHeader: React.FC<OperationHeaderProps> = React.memo(
|
|||
}`}
|
||||
/>
|
||||
</HoverCardTrigger>
|
||||
<HoverCardContent className="w-72" side="bottom" align="start">
|
||||
<HoverCardContent className="w-96 max-h-80" side="bottom" align="start">
|
||||
<div className="flex flex-col space-y-1">
|
||||
<p className="text-sm font-medium">
|
||||
{optimizeResult === undefined || optimizeResult === null
|
||||
|
|
@ -170,13 +169,15 @@ const OperationHeader: React.FC<OperationHeaderProps> = React.memo(
|
|||
? "Decomposition Status"
|
||||
: "Decomposition Recommended"}
|
||||
</p>
|
||||
<p className="text-sm text-muted-foreground">
|
||||
{optimizeResult === undefined || optimizeResult === null
|
||||
? "Analyzing operation complexity..."
|
||||
: optimizeResult === ""
|
||||
? "No decomposition needed for this operation"
|
||||
: "Recommended decomposition: " + optimizeResult}
|
||||
</p>
|
||||
<div className="max-h-64 overflow-y-auto">
|
||||
<p className="text-sm text-muted-foreground whitespace-pre-wrap">
|
||||
{optimizeResult === undefined || optimizeResult === null
|
||||
? "Analyzing operation complexity..."
|
||||
: optimizeResult === ""
|
||||
? "No decomposition needed for this operation"
|
||||
: optimizeResult}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</HoverCardContent>
|
||||
</HoverCard>
|
||||
|
|
@ -367,7 +368,7 @@ const OperationHeader: React.FC<OperationHeaderProps> = React.memo(
|
|||
disabled={disabled}
|
||||
>
|
||||
<Zap className="mr-2 h-4 w-4" />
|
||||
Optimize Operation
|
||||
Decompose Operation
|
||||
</Button>
|
||||
)}
|
||||
|
||||
|
|
@ -720,7 +721,6 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
|
|||
output: pipelineOutput,
|
||||
setOutput,
|
||||
isLoadingOutputs,
|
||||
setIsLoadingOutputs,
|
||||
numOpRun,
|
||||
setNumOpRun,
|
||||
currentFile,
|
||||
|
|
@ -730,18 +730,14 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
|
|||
sampleSize,
|
||||
setCost,
|
||||
defaultModel,
|
||||
optimizerModel,
|
||||
setTerminalOutput,
|
||||
namespace,
|
||||
apiKeys,
|
||||
systemPrompt,
|
||||
extraPipelineSettings,
|
||||
onRequestDecompositionRef,
|
||||
} = usePipelineContext();
|
||||
const { toast } = useToast();
|
||||
|
||||
const operationRef = useRef(operation);
|
||||
const { connect, sendMessage, lastMessage, readyState, disconnect } =
|
||||
useWebSocket();
|
||||
|
||||
useEffect(() => {
|
||||
operationRef.current = operation;
|
||||
|
|
@ -808,75 +804,19 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
|
|||
const handleOptimizeConfirm = useCallback(async () => {
|
||||
if (!operation) return;
|
||||
|
||||
try {
|
||||
// Clear the output
|
||||
setTerminalOutput("");
|
||||
setIsLoadingOutputs(true);
|
||||
setShowOptimizeDialog(false);
|
||||
|
||||
// Write pipeline config
|
||||
const response = await fetch("/api/writePipelineConfig", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
default_model: defaultModel,
|
||||
data: { path: currentFile?.path || "" },
|
||||
operations,
|
||||
operation_id: operation.id,
|
||||
name: pipelineName,
|
||||
sample_size: sampleSize,
|
||||
optimize: true,
|
||||
clear_intermediate: false,
|
||||
system_prompt: systemPrompt,
|
||||
namespace: namespace,
|
||||
apiKeys: apiKeys,
|
||||
optimizerModel: optimizerModel,
|
||||
extraPipelineSettings: extraPipelineSettings,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(await response.text());
|
||||
}
|
||||
|
||||
const { filePath } = await response.json();
|
||||
|
||||
// Ensure WebSocket is connected
|
||||
await connect();
|
||||
|
||||
// Send message to run the pipeline
|
||||
sendMessage({
|
||||
yaml_config: filePath,
|
||||
optimize: true,
|
||||
});
|
||||
} catch (error) {
|
||||
console.error("Error optimizing operation:", error);
|
||||
// Use the decomposition flow from context if available
|
||||
if (onRequestDecompositionRef.current) {
|
||||
onRequestDecompositionRef.current(operation.id, operation.name);
|
||||
} else {
|
||||
toast({
|
||||
title: "Error",
|
||||
description: error.message,
|
||||
description: "Decomposition handler not available",
|
||||
variant: "destructive",
|
||||
});
|
||||
// Close the WebSocket connection
|
||||
disconnect();
|
||||
} finally {
|
||||
setShowOptimizeDialog(false);
|
||||
}
|
||||
}, [
|
||||
operation,
|
||||
defaultModel,
|
||||
currentFile,
|
||||
operations,
|
||||
pipelineName,
|
||||
sampleSize,
|
||||
optimizerModel,
|
||||
connect,
|
||||
sendMessage,
|
||||
systemPrompt,
|
||||
namespace,
|
||||
apiKeys,
|
||||
extraPipelineSettings,
|
||||
]);
|
||||
}, [operation, onRequestDecompositionRef, toast]);
|
||||
|
||||
const onShowOutput = useCallback(async () => {
|
||||
if (!operation) return;
|
||||
|
|
@ -1224,7 +1164,7 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
|
|||
>
|
||||
<AlertDialogContent>
|
||||
<AlertDialogHeader>
|
||||
<AlertDialogTitle>Optimize Operation</AlertDialogTitle>
|
||||
<AlertDialogTitle>Decompose Operation</AlertDialogTitle>
|
||||
<AlertDialogDescription>
|
||||
{!hasOpenAIKey && !isLocalMode ? (
|
||||
<div className="space-y-2">
|
||||
|
|
@ -1232,8 +1172,8 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
|
|||
OpenAI API Key Required
|
||||
</p>
|
||||
<p>
|
||||
To use the optimizer, please add your OpenAI API key in Edit{" "}
|
||||
{">"}
|
||||
To decompose operations, please add your OpenAI API key in
|
||||
Edit {">"}
|
||||
Edit API Keys.
|
||||
</p>
|
||||
<button
|
||||
|
|
@ -1245,11 +1185,10 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
|
|||
</div>
|
||||
) : (
|
||||
<p>
|
||||
This will analyze the operation and replace it with another
|
||||
pipeline that has higher accuracy (as determined by an
|
||||
LLM-as-a-judge), if it can be found. Do you want to proceed?
|
||||
The process may take between 2 and 10 minutes, depending on
|
||||
how complex your data is.
|
||||
This will analyze the operation and try different
|
||||
decomposition strategies (chaining, chunking, gleaning, etc.)
|
||||
to improve accuracy. The best strategy will be selected via
|
||||
LLM-as-a-judge comparison. Do you want to proceed?
|
||||
</p>
|
||||
)}
|
||||
</AlertDialogDescription>
|
||||
|
|
@ -1260,7 +1199,7 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
|
|||
onClick={handleOptimizeConfirm}
|
||||
disabled={!hasOpenAIKey && !isLocalMode}
|
||||
>
|
||||
Proceed
|
||||
Decompose
|
||||
</AlertDialogAction>
|
||||
</AlertDialogFooter>
|
||||
</AlertDialogContent>
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ interface OptimizationDialogProps {
|
|||
operationName?: string;
|
||||
inputData?: Array<Record<string, unknown>>;
|
||||
outputData?: Array<Record<string, unknown>>;
|
||||
numDocsAnalyzed?: number;
|
||||
onOpenChange: (open: boolean) => void;
|
||||
onDecompose?: () => void;
|
||||
}
|
||||
|
|
@ -34,6 +35,7 @@ export const OptimizationDialog: React.FC<OptimizationDialogProps> = ({
|
|||
inputData,
|
||||
outputData,
|
||||
operationName,
|
||||
numDocsAnalyzed,
|
||||
onOpenChange,
|
||||
onDecompose,
|
||||
}) => {
|
||||
|
|
@ -140,9 +142,9 @@ export const OptimizationDialog: React.FC<OptimizationDialogProps> = ({
|
|||
<DialogHeader className="flex-shrink-0 border-b pb-4">
|
||||
<DialogTitle className="text-xl">Operation Too Complex</DialogTitle>
|
||||
<p className="text-base mt-2">
|
||||
This operation might be too complex for the LLM to handle
|
||||
efficiently. We recommend breaking it down into smaller, more
|
||||
manageable steps.
|
||||
After analyzing {numDocsAnalyzed ?? "several"} documents, this
|
||||
operation might be too complex for the LLM to handle efficiently. We
|
||||
recommend breaking it down into smaller, more manageable steps.
|
||||
</p>
|
||||
</DialogHeader>
|
||||
|
||||
|
|
|
|||
|
|
@ -81,6 +81,7 @@ const useOutputContext = () => {
|
|||
const {
|
||||
output,
|
||||
isLoadingOutputs,
|
||||
isDecomposing,
|
||||
terminalOutput,
|
||||
setTerminalOutput,
|
||||
optimizerProgress,
|
||||
|
|
@ -91,6 +92,7 @@ const useOutputContext = () => {
|
|||
return {
|
||||
output,
|
||||
isLoadingOutputs,
|
||||
isDecomposing,
|
||||
terminalOutput,
|
||||
setTerminalOutput,
|
||||
optimizerProgress,
|
||||
|
|
@ -385,7 +387,7 @@ VisualizeContent.displayName = "VisualizeContent";
|
|||
|
||||
// Move ConsoleContent outside
|
||||
export const ConsoleContent = memo(() => {
|
||||
const { terminalOutput, setTerminalOutput, optimizerProgress } =
|
||||
const { terminalOutput, setTerminalOutput, optimizerProgress, isDecomposing } =
|
||||
useOutputContext();
|
||||
const { readyState } = useWebSocket();
|
||||
|
||||
|
|
@ -469,6 +471,7 @@ export const ConsoleContent = memo(() => {
|
|||
text={terminalOutput || ""}
|
||||
readyState={readyState}
|
||||
setTerminalOutput={setTerminalOutput}
|
||||
isDecomposing={isDecomposing}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
|
@ -478,7 +481,7 @@ ConsoleContent.displayName = "ConsoleContent";
|
|||
|
||||
// Main Output component
|
||||
export const Output = memo(() => {
|
||||
const { output, isLoadingOutputs, sampleSize, operations } =
|
||||
const { output, isLoadingOutputs, isDecomposing, sampleSize, operations } =
|
||||
useOutputContext();
|
||||
const operation = useOperation(output?.operationId);
|
||||
|
||||
|
|
@ -500,10 +503,14 @@ export const Output = memo(() => {
|
|||
);
|
||||
}, [operation]);
|
||||
|
||||
// Effect for tab changes
|
||||
// Effect for tab changes - switch to console when loading or decomposing
|
||||
useEffect(() => {
|
||||
setActiveTab(isLoadingOutputs ? "console" : "table");
|
||||
}, [isLoadingOutputs]);
|
||||
if (isLoadingOutputs || isDecomposing) {
|
||||
setActiveTab("console");
|
||||
} else {
|
||||
setActiveTab("table");
|
||||
}
|
||||
}, [isLoadingOutputs, isDecomposing]);
|
||||
|
||||
// Memoize columns
|
||||
const columns = useMemo(() => {
|
||||
|
|
|
|||
|
|
@ -40,9 +40,12 @@ import { Input } from "@/components/ui/input";
|
|||
import { schemaDictToItemSet } from "./utils";
|
||||
import { v4 as uuidv4 } from "uuid";
|
||||
import { useOptimizeCheck } from "@/hooks/useOptimizeCheck";
|
||||
import { useDecomposeWebSocket } from "@/hooks/useDecomposeWebSocket";
|
||||
import { canBeOptimized } from "@/lib/utils";
|
||||
import { Textarea } from "./ui/textarea";
|
||||
import { OptimizationDialog } from "@/components/OptimizationDialog";
|
||||
import { DecompositionComparisonDialog } from "@/components/DecompositionComparisonDialog";
|
||||
import { DecomposeResult } from "@/app/types";
|
||||
import {
|
||||
Popover,
|
||||
PopoverContent,
|
||||
|
|
@ -240,6 +243,9 @@ const PipelineGUI: React.FC = () => {
|
|||
apiKeys,
|
||||
extraPipelineSettings,
|
||||
setExtraPipelineSettings,
|
||||
isDecomposing,
|
||||
setIsDecomposing,
|
||||
onRequestDecompositionRef,
|
||||
} = usePipelineContext();
|
||||
const [isSettingsOpen, setIsSettingsOpen] = useState(false);
|
||||
const { toast } = useToast();
|
||||
|
|
@ -253,12 +259,14 @@ const PipelineGUI: React.FC = () => {
|
|||
outputData?: Array<Record<string, unknown>>;
|
||||
operationName?: string;
|
||||
operationId?: string;
|
||||
numDocsAnalyzed?: number;
|
||||
}>({
|
||||
isOpen: false,
|
||||
content: "",
|
||||
prompt: undefined,
|
||||
operationName: undefined,
|
||||
operationId: undefined,
|
||||
numDocsAnalyzed: undefined,
|
||||
});
|
||||
const [isEditingName, setIsEditingName] = useState(false);
|
||||
const [editedPipelineName, setEditedPipelineName] = useState(pipelineName);
|
||||
|
|
@ -277,14 +285,13 @@ const PipelineGUI: React.FC = () => {
|
|||
setCost((prev) => prev + result.cost);
|
||||
|
||||
if (result.should_optimize) {
|
||||
toast({
|
||||
title: `Hey! Consider decomposing ${operations[operations.length - 1].name
|
||||
}`,
|
||||
const lastOp = operations[operations.length - 1];
|
||||
const { dismiss } = toast({
|
||||
title: `Hey! Consider decomposing ${lastOp.name}`,
|
||||
description: (
|
||||
<span
|
||||
className="cursor-pointer text-blue-500 hover:text-blue-700"
|
||||
onClick={() => {
|
||||
const lastOp = operations[operations.length - 1];
|
||||
setOptimizationDialog({
|
||||
isOpen: true,
|
||||
content: result.should_optimize,
|
||||
|
|
@ -293,7 +300,9 @@ const PipelineGUI: React.FC = () => {
|
|||
operationId: lastOp.id,
|
||||
inputData: result.input_data,
|
||||
outputData: result.output_data,
|
||||
numDocsAnalyzed: result.num_docs_analyzed,
|
||||
});
|
||||
dismiss();
|
||||
}}
|
||||
>
|
||||
Click here to see why.
|
||||
|
|
@ -312,6 +321,144 @@ const PipelineGUI: React.FC = () => {
|
|||
},
|
||||
});
|
||||
|
||||
const [decomposeDialog, setDecomposeDialog] = useState<{
|
||||
isOpen: boolean;
|
||||
result: DecomposeResult | null;
|
||||
targetOpId: string | null;
|
||||
operationName: string | null;
|
||||
}>({
|
||||
isOpen: false,
|
||||
result: null,
|
||||
targetOpId: null,
|
||||
operationName: null,
|
||||
});
|
||||
|
||||
// Ref to store the current decomposition target (avoids stale closure issue)
|
||||
const decompositionTargetRef = useRef<{
|
||||
operationId: string;
|
||||
operationName: string;
|
||||
} | null>(null);
|
||||
|
||||
const { startDecomposition, cancel: cancelDecomposition } = useDecomposeWebSocket({
|
||||
namespace,
|
||||
onOutput: (output) => {
|
||||
// Stream output to the terminal
|
||||
setTerminalOutput(output);
|
||||
},
|
||||
onComplete: (result) => {
|
||||
setIsDecomposing(false);
|
||||
setCost((prev) => prev + (result.cost || 0));
|
||||
|
||||
// Read from ref to avoid stale closure
|
||||
const target = decompositionTargetRef.current;
|
||||
console.log("Decomposition complete:", {
|
||||
result,
|
||||
target,
|
||||
decomposed_operations: result.decomposed_operations,
|
||||
});
|
||||
|
||||
// Show the comparison dialog instead of auto-applying
|
||||
setDecomposeDialog({
|
||||
isOpen: true,
|
||||
result,
|
||||
targetOpId: target?.operationId || null,
|
||||
operationName: target?.operationName || null,
|
||||
});
|
||||
},
|
||||
onError: (error) => {
|
||||
setIsDecomposing(false);
|
||||
toast({
|
||||
title: "Decomposition Failed",
|
||||
description: error,
|
||||
variant: "destructive",
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
const handleApplyDecomposition = () => {
|
||||
const { result, targetOpId } = decomposeDialog;
|
||||
console.log("handleApplyDecomposition called:", { result, targetOpId, decomposeDialog });
|
||||
if (!result || !targetOpId) {
|
||||
console.warn("handleApplyDecomposition: missing result or targetOpId", { result, targetOpId });
|
||||
return;
|
||||
}
|
||||
|
||||
if (
|
||||
result.decomposed_operations &&
|
||||
result.decomposed_operations.length > 0 &&
|
||||
result.winning_directive !== "original"
|
||||
) {
|
||||
setOperations((prev) => {
|
||||
// Find the index of the target operation
|
||||
const targetIdx = prev.findIndex((op) => op.id === targetOpId);
|
||||
if (targetIdx === -1) return prev;
|
||||
|
||||
// Convert decomposed operations to our Operation format
|
||||
const newOps: Operation[] = result.decomposed_operations!.map(
|
||||
(opConfig) => ({
|
||||
id: uuidv4(),
|
||||
llmType:
|
||||
opConfig.type === "map" ||
|
||||
opConfig.type === "reduce" ||
|
||||
opConfig.type === "filter" ||
|
||||
opConfig.type === "parallel_map"
|
||||
? "LLM"
|
||||
: "non-LLM",
|
||||
type: opConfig.type as Operation["type"],
|
||||
name: opConfig.name as string,
|
||||
prompt: opConfig.prompt as string | undefined,
|
||||
output: opConfig.output
|
||||
? {
|
||||
schema: schemaDictToItemSet(
|
||||
(opConfig.output as { schema: Record<string, string> })
|
||||
.schema
|
||||
),
|
||||
}
|
||||
: undefined,
|
||||
gleaning: opConfig.gleaning as Operation["gleaning"],
|
||||
otherKwargs: Object.fromEntries(
|
||||
Object.entries(opConfig).filter(
|
||||
([key]) =>
|
||||
!["name", "type", "prompt", "output", "gleaning"].includes(key)
|
||||
)
|
||||
),
|
||||
visibility: true,
|
||||
})
|
||||
);
|
||||
|
||||
// Replace the target operation with the new operations
|
||||
const newOperations = [...prev];
|
||||
newOperations.splice(targetIdx, 1, ...newOps);
|
||||
return newOperations;
|
||||
});
|
||||
|
||||
toast({
|
||||
title: "Decomposition Applied",
|
||||
description: `Applied ${result.winning_directive} directive. Created ${result.decomposed_operations.length} operation(s).`,
|
||||
});
|
||||
}
|
||||
|
||||
setDecomposeDialog({
|
||||
isOpen: false,
|
||||
result: null,
|
||||
targetOpId: null,
|
||||
operationName: null,
|
||||
});
|
||||
};
|
||||
|
||||
const handleCancelDecomposition = () => {
|
||||
toast({
|
||||
title: "Decomposition Cancelled",
|
||||
description: "Keeping the original operation.",
|
||||
});
|
||||
setDecomposeDialog({
|
||||
isOpen: false,
|
||||
result: null,
|
||||
targetOpId: null,
|
||||
operationName: null,
|
||||
});
|
||||
};
|
||||
|
||||
const { restoreFromYAML } = useRestorePipeline({
|
||||
setOperations,
|
||||
setPipelineName,
|
||||
|
|
@ -328,79 +475,18 @@ const PipelineGUI: React.FC = () => {
|
|||
if (lastMessage) {
|
||||
if (lastMessage.type === "output") {
|
||||
setTerminalOutput(lastMessage.data);
|
||||
} else if (lastMessage.type === "optimizer_progress") {
|
||||
setOptimizerProgress({
|
||||
status: lastMessage.status,
|
||||
progress: lastMessage.progress,
|
||||
shouldOptimize: lastMessage.should_optimize,
|
||||
rationale: lastMessage.rationale,
|
||||
validatorPrompt: lastMessage.validator_prompt,
|
||||
});
|
||||
} else if (lastMessage.type === "result") {
|
||||
const runCost = lastMessage.data.cost || 0;
|
||||
setOptimizerProgress(null);
|
||||
|
||||
// See if there was an optimized operation
|
||||
const optimizedOps = lastMessage.data.optimized_ops;
|
||||
if (optimizedOps) {
|
||||
const newOperations = optimizedOps.map((optimizedOp) => {
|
||||
const {
|
||||
id,
|
||||
type,
|
||||
name,
|
||||
prompt,
|
||||
output,
|
||||
validate,
|
||||
gleaning,
|
||||
sample,
|
||||
...otherKwargs
|
||||
} = optimizedOp;
|
||||
|
||||
// Find matching operation in previous operations list
|
||||
const existingOp = operations.find((op) => op.name === name);
|
||||
|
||||
return {
|
||||
id: id || uuidv4(),
|
||||
llmType:
|
||||
type === "map" ||
|
||||
type === "reduce" ||
|
||||
type === "resolve" ||
|
||||
type === "filter" ||
|
||||
type === "parallel_map" ||
|
||||
type === "rank" ||
|
||||
type === "extract"
|
||||
? "LLM"
|
||||
: "non-LLM",
|
||||
type: type,
|
||||
name: name || "Untitled Operation",
|
||||
prompt: prompt,
|
||||
output: output
|
||||
? {
|
||||
schema: schemaDictToItemSet(output.schema),
|
||||
}
|
||||
: undefined,
|
||||
validate: validate,
|
||||
gleaning: gleaning,
|
||||
sample: sample,
|
||||
otherKwargs: otherKwargs || {},
|
||||
...(existingOp?.runIndex && { runIndex: existingOp.runIndex }),
|
||||
visibility: true,
|
||||
} as Operation;
|
||||
});
|
||||
|
||||
setOperations(newOperations);
|
||||
} else {
|
||||
// No optimized operations, so we need to check if we should optimize the last operation
|
||||
// Trigger should optimize for the last operation
|
||||
if (autoOptimizeCheck) {
|
||||
const lastOp = operations[operations.length - 1];
|
||||
if (lastOp && canBeOptimized(lastOp.type)) {
|
||||
submitTask({
|
||||
yaml_config: lastMessage.data.yaml_config,
|
||||
step_name: "data_processing", // TODO: Make this a constant
|
||||
op_name: lastOp.name,
|
||||
});
|
||||
}
|
||||
// Trigger should_optimize check for the last operation if enabled
|
||||
if (autoOptimizeCheck) {
|
||||
const lastOp = operations[operations.length - 1];
|
||||
if (lastOp && canBeOptimized(lastOp.type)) {
|
||||
submitTask({
|
||||
yaml_config: lastMessage.data.yaml_config,
|
||||
step_name: "data_processing",
|
||||
op_name: lastOp.name,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -692,20 +778,35 @@ const PipelineGUI: React.FC = () => {
|
|||
};
|
||||
|
||||
const handleStop = () => {
|
||||
sendMessage("kill");
|
||||
// Stop pipeline run if running
|
||||
if (isLoadingOutputs) {
|
||||
sendMessage("kill");
|
||||
if (readyState === WebSocket.CLOSED) {
|
||||
setIsLoadingOutputs(false);
|
||||
}
|
||||
}
|
||||
|
||||
if (readyState === WebSocket.CLOSED && isLoadingOutputs) {
|
||||
setIsLoadingOutputs(false);
|
||||
// Stop decomposition if running
|
||||
if (isDecomposing) {
|
||||
cancelDecomposition();
|
||||
setIsDecomposing(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleOptimizeFromDialog = async () => {
|
||||
if (!optimizationDialog.operationId) return;
|
||||
if (!optimizationDialog.operationId || !optimizationDialog.operationName) return;
|
||||
|
||||
// Store target in ref so onComplete callback can access it
|
||||
decompositionTargetRef.current = {
|
||||
operationId: optimizationDialog.operationId,
|
||||
operationName: optimizationDialog.operationName,
|
||||
};
|
||||
|
||||
try {
|
||||
setTerminalOutput("");
|
||||
setIsLoadingOutputs(true);
|
||||
setIsDecomposing(true);
|
||||
setTerminalOutput(""); // Clear terminal output
|
||||
|
||||
      // Write the pipeline config to get a YAML file path
      const response = await fetch("/api/writePipelineConfig", {
        method: "POST",
        headers: {

@@ -732,24 +833,115 @@ const PipelineGUI: React.FC = () => {
      const { filePath } = await response.json();

      await connect();
      // Use the WebSocket decompose endpoint for streaming output
      await startDecomposition(
        filePath,
        "data_processing", // Matches the step name in generatePipelineConfig
        optimizationDialog.operationName
      );

      sendMessage({
        yaml_config: filePath,
        optimize: true,
      toast({
        title: "Decomposing Operation",
        description: "Analyzing and generating candidate decompositions. Watch the console for progress.",
      });
    } catch (error) {
      console.error("Error optimizing operation:", error);
      console.error("Error starting decomposition:", error);
      setIsDecomposing(false);
      toast({
        title: "Error",
        description: error.message,
        description: error instanceof Error ? error.message : "Failed to start decomposition",
        variant: "destructive",
      });
      disconnect();
      setIsLoadingOutputs(false);
    }
  };

  // Handler for decomposition requests from OperationCard dropdown
  const handleDecompositionRequest = useCallback(
    async (operationId: string, operationName: string) => {
      // Store target in ref so onComplete callback can access it
      decompositionTargetRef.current = {
        operationId,
        operationName,
      };

      try {
        setIsDecomposing(true);
        setTerminalOutput(""); // Clear terminal output

        // Set the dialog state so we know which operation we're decomposing
        setDecomposeDialog((prev) => ({
          ...prev,
          targetOpId: operationId,
          operationName: operationName,
        }));

        // Write the pipeline config to get a YAML file path
        const response = await fetch("/api/writePipelineConfig", {
          method: "POST",
          headers: {
            "Content-Type": "application/json",
          },
          body: JSON.stringify({
            default_model: defaultModel,
            data: { path: currentFile?.path || "" },
            operations,
            operation_id: operationId,
            name: pipelineName,
            sample_size: sampleSize,
            optimize: true,
            namespace: namespace,
            apiKeys: apiKeys,
            optimizerModel: optimizerModel,
            extraPipelineSettings: extraPipelineSettings,
          }),
        });

        if (!response.ok) {
          throw new Error(await response.text());
        }

        const { filePath } = await response.json();

        // Start decomposition via WebSocket
        startDecomposition(filePath, "data_processing", operationName);

        toast({
          title: "Decomposing Operation",
          description:
            "Analyzing and generating candidate decompositions. Watch the console for progress.",
        });
      } catch (error) {
        console.error("Error starting decomposition:", error);
        setIsDecomposing(false);
        toast({
          title: "Error",
          description:
            error instanceof Error ? error.message : "Failed to start decomposition",
          variant: "destructive",
        });
      }
    },
    [
      defaultModel,
      currentFile,
      operations,
      pipelineName,
      sampleSize,
      namespace,
      apiKeys,
      optimizerModel,
      extraPipelineSettings,
      startDecomposition,
      setIsDecomposing,
      setTerminalOutput,
      toast,
    ]
  );

  // Register the decomposition handler in context so OperationCard can use it
  // Using ref assignment instead of state to avoid infinite loops
  onRequestDecompositionRef.current = handleDecompositionRequest;

  return (
    <div className="flex flex-col h-full">
      <div
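A note on the registration line above: the handler is published through `onRequestDecompositionRef` (a ref, not state) so that re-assigning it on every render cannot trigger render loops. Below is a minimal sketch of a consumer; the import path, `usePipelineContext()` hook name, and the menu-item component are assumptions for illustration only.

```typescript
// Hypothetical consumer of the registered handler; not code from this branch.
import React from "react";
import { usePipelineContext } from "@/contexts/PipelineContext"; // path assumed

export function DecomposeMenuItem({
  operationId,
  operationName,
}: {
  operationId: string;
  operationName: string;
}) {
  const { onRequestDecompositionRef } = usePipelineContext();

  return (
    <button
      onClick={() =>
        // May be null until PipelineGUI has mounted and assigned the handler.
        onRequestDecompositionRef.current?.(operationId, operationName)
      }
    >
      Decompose operation
    </button>
  );
}
```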
@@ -1053,7 +1245,7 @@ const PipelineGUI: React.FC = () => {
              variant="destructive"
              className="rounded-sm whitespace-nowrap"
              onClick={handleStop}
              disabled={!isLoadingOutputs}
              disabled={!isLoadingOutputs && !isDecomposing}
            >
              <StopCircle size={16} className="mr-2" />
              Stop
@@ -1144,11 +1336,22 @@ const PipelineGUI: React.FC = () => {
        operationName={optimizationDialog.operationName}
        inputData={optimizationDialog.inputData}
        outputData={optimizationDialog.outputData}
        numDocsAnalyzed={optimizationDialog.numDocsAnalyzed}
        onOpenChange={(open) =>
          setOptimizationDialog((prev) => ({ ...prev, isOpen: open }))
        }
        onDecompose={handleOptimizeFromDialog}
      />
      <DecompositionComparisonDialog
        isOpen={decomposeDialog.isOpen}
        result={decomposeDialog.result}
        operationName={decomposeDialog.operationName || undefined}
        onOpenChange={(open) =>
          setDecomposeDialog((prev) => ({ ...prev, isOpen: open }))
        }
        onApply={handleApplyDecomposition}
        onCancel={handleCancelDecomposition}
      />
    </div>
  );
};
@@ -31,6 +31,7 @@ interface PromptInputProps {
  onChange: (value: string) => void;
  disableValidation?: boolean;
  placeholder?: string;
  operationType?: "map" | "reduce" | "filter" | "resolve" | "other";
}

export const PromptInput: React.FC<PromptInputProps> = React.memo(
@@ -38,31 +39,62 @@ export const PromptInput: React.FC<PromptInputProps> = React.memo(
    prompt,
    onChange,
    disableValidation = false,
    placeholder = "Enter prompt (must be a Jinja2 template)",
    placeholder = "Enter prompt (Jinja2 template optional)",
    operationType = "map",
  }) => {
    const validateJinjaTemplate = (value: string) => {
    const hasJinjaTemplate = (value: string) => {
      if (disableValidation) return true;
      const hasOpenBrace = value.includes("{{");
      const hasCloseBrace = value.includes("}}");
      return hasOpenBrace && hasCloseBrace;
    };

    const isReduce = operationType === "reduce";
    const templateVar = isReduce ? "inputs" : "input";
    const exampleKey = isReduce ? "inputs[0].content" : "input.content";

    return (
      <>
        <Textarea
          placeholder={placeholder}
          className={`mb-1 rounded-sm text-sm font-mono ${
            !validateJinjaTemplate(prompt) ? "border-red-500" : ""
            !hasJinjaTemplate(prompt) ? "border-amber-400" : ""
          }`}
          rows={Math.max(3, Math.ceil(prompt.split("\n").length))}
          value={prompt}
          onChange={(e) => onChange(e.target.value)}
        />
        {!validateJinjaTemplate(prompt) && (
          <div className="text-red-500 text-sm mb-1">
            Prompt must contain Jinja2 template syntax {"{"}
            {"{"} and {"}"}
            {"}"}
        {!hasJinjaTemplate(prompt) && (
          <div className="text-amber-600 text-sm mb-1 space-y-1">
            <div>
              <strong>Warning:</strong> No Jinja2 template found.
            </div>
            <div className="text-xs text-muted-foreground">
              {isReduce ? (
                <>
                  {"{"}
                  {"{"} inputs {"}"}
                  {"}"} (the list of all documents) will be auto-appended. For
                  example, if documents have keys {'"'}id{'"'} and {'"'}content
                  {'"'}, all docs with all keys will be included. To reference
                  specific keys, use a for loop: {"{"}% for doc in inputs %{"}"}{" "}
                  {"{"}
                  {"{"} doc.content {"}"}
                  {"}"} {"{"}% endfor %{"}"}.
                </>
              ) : (
                <>
                  {"{"}
                  {"{"} input {"}"}
                  {"}"} (the full document) will be auto-appended. For example,
                  if documents have keys {'"'}id{'"'}, {'"'}title{'"'}, and {'"'}
                  content{'"'}, all three will be included. To use only specific
                  keys (e.g., just {'"'}content{'"'}), add {"{"}
                  {"{"} input.content {"}"}
                  {"}"} to your prompt.
                </>
              )}
            </div>
          </div>
        )}
      </>
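To make the warning above concrete, here are invented example prompts showing what `hasJinjaTemplate` does and does not flag; none of these strings come from the repository.

```typescript
// Illustrative prompt strings only.
const explicitMapPrompt =
  "Summarize this section: {{ input.content }}"; // references a specific key -> no warning

const implicitMapPrompt =
  "Summarize the document."; // no {{ }} -> amber warning; {{ input }} is auto-appended

const reducePrompt =
  "Combine these notes: {% for doc in inputs %}{{ doc.content }}{% endfor %}"; // reduce-style loop
```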
@@ -77,6 +109,7 @@ PromptInput.propTypes = {
  onChange: PropTypes.func.isRequired,
  disableValidation: PropTypes.bool,
  placeholder: PropTypes.string,
  operationType: PropTypes.oneOf(["map", "reduce", "filter", "resolve", "other"]),
};

interface SchemaFormProps {
@@ -314,6 +314,7 @@ export const ReduceOperationComponent: React.FC<OperationComponentProps> = ({
      <PromptInput
        prompt={operation.prompt || ""}
        onChange={handlePromptChange}
        operationType="reduce"
      />
      <OutputSchema
        schema={schemaItems}
@@ -26,7 +26,7 @@ const ToastViewport = React.forwardRef<
ToastViewport.displayName = ToastPrimitives.Viewport.displayName;

const toastVariants = cva(
  "group pointer-events-auto relative flex w-full items-center justify-between space-x-2 overflow-y-auto max-h-[80vh] whitespace-pre-wrap break-words rounded-md border p-4 pr-6 shadow-lg transition-all data-[swipe=cancel]:translate-x-0 data-[swipe=end]:translate-x-[var(--radix-toast-swipe-end-x)] data-[swipe=move]:translate-x-[var(--radix-toast-swipe-move-x)] data-[swipe=move]:transition-none data-[state=open]:animate-in data-[state=closed]:animate-out data-[swipe=end]:animate-out data-[state=closed]:fade-out-80 data-[state=closed]:slide-out-to-right-full data-[state=open]:slide-in-from-top-full data-[state=open]:sm:slide-in-from-bottom-full",
  "group pointer-events-auto relative flex w-full items-start justify-between space-x-2 overflow-hidden max-h-[80vh] whitespace-pre-wrap break-words rounded-md border p-4 pr-6 shadow-lg transition-all data-[swipe=cancel]:translate-x-0 data-[swipe=end]:translate-x-[var(--radix-toast-swipe-end-x)] data-[swipe=move]:translate-x-[var(--radix-toast-swipe-move-x)] data-[swipe=move]:transition-none data-[state=open]:animate-in data-[state=closed]:animate-out data-[swipe=end]:animate-out data-[state=closed]:fade-out-80 data-[state=closed]:slide-out-to-right-full data-[state=open]:slide-in-from-top-full data-[state=open]:sm:slide-in-from-bottom-full",
  {
    variants: {
      variant: {
@@ -22,29 +22,33 @@ export function Toaster() {
        const descId = `toast-desc-${id}`;
        return (
          <Toast key={id} {...props}>
            <div className="grid gap-1">
              {title && <ToastTitle>{title}</ToastTitle>}
            <div className="grid gap-1 flex-1 overflow-hidden">
              <div className="flex items-center gap-2 shrink-0">
                {title && <ToastTitle>{title}</ToastTitle>}
                {description && (
                  <Button
                    variant="ghost"
                    size="sm"
                    className="h-5 w-5 p-0 opacity-70 hover:opacity-100"
                    onClick={() => {
                      const el = document.getElementById(descId);
                      const text = el?.textContent ?? "";
                      if (text) {
                        navigator.clipboard.writeText(text);
                      }
                    }}
                    title="Copy message"
                  >
                    <Copy size={12} />
                  </Button>
                )}
              </div>
              {description && (
                <ToastDescription id={descId}>{description}</ToastDescription>
                <div className="overflow-y-auto max-h-[60vh]">
                  <ToastDescription id={descId}>{description}</ToastDescription>
                </div>
              )}
            </div>
            {description && (
              <Button
                variant="ghost"
                size="sm"
                className="h-6 px-2 text-xs"
                onClick={() => {
                  const el = document.getElementById(descId);
                  const text = el?.textContent ?? "";
                  if (text) {
                    navigator.clipboard.writeText(text);
                  }
                }}
                title="Copy message"
              >
                <Copy size={12} />
              </Button>
            )}
            {action}
            <ToastClose />
          </Toast>
@@ -29,6 +29,7 @@ interface PipelineState {
    validatorPrompt: string;
  } | null;
  isLoadingOutputs: boolean;
  isDecomposing: boolean;
  numOpRun: number;
  pipelineName: string;
  sampleSize: number | null;

@@ -59,6 +60,7 @@ interface PipelineContextType extends PipelineState {
    } | null>
  >;
  setIsLoadingOutputs: React.Dispatch<React.SetStateAction<boolean>>;
  setIsDecomposing: React.Dispatch<React.SetStateAction<boolean>>;
  setNumOpRun: React.Dispatch<React.SetStateAction<number>>;
  setPipelineName: React.Dispatch<React.SetStateAction<string>>;
  setSampleSize: React.Dispatch<React.SetStateAction<number | null>>;

@@ -83,6 +85,10 @@ interface PipelineContextType extends PipelineState {
  setExtraPipelineSettings: React.Dispatch<
    React.SetStateAction<Record<string, unknown> | null>
  >;
  // Ref for triggering decomposition from OperationCard (using ref to avoid infinite loops)
  onRequestDecompositionRef: React.MutableRefObject<
    ((operationId: string, operationName: string) => void) | null
  >;
}

const PipelineContext = createContext<PipelineContextType | undefined>(

@@ -290,6 +296,7 @@ export const PipelineProvider: React.FC<{ children: React.ReactNode }> = ({
      localStorageKeys.IS_LOADING_OUTPUTS_KEY,
      false
    ),
    isDecomposing: false, // Transient state, not persisted
    numOpRun: loadFromLocalStorage(localStorageKeys.NUM_OP_RUN_KEY, 0),
    pipelineName: loadFromLocalStorage(
      localStorageKeys.PIPELINE_NAME_KEY,

@@ -333,6 +340,11 @@ export const PipelineProvider: React.FC<{ children: React.ReactNode }> = ({
  const stateRef = useRef(state);
  const [isMounted, setIsMounted] = useState(false);

  // Ref for triggering decomposition from OperationCard (using ref to avoid infinite loops)
  const onRequestDecompositionRef = useRef<
    ((operationId: string, operationName: string) => void) | null
  >(null);

  useEffect(() => {
    stateRef.current = state;
  }, [state]);

@@ -423,6 +435,7 @@ export const PipelineProvider: React.FC<{ children: React.ReactNode }> = ({
    output: null,
    terminalOutput: "",
    isLoadingOutputs: false,
    isDecomposing: false,
    numOpRun: 0,
    pipelineName: mockPipelineName,
    sampleSize: mockSampleSize,

@@ -529,6 +542,10 @@ export const PipelineProvider: React.FC<{ children: React.ReactNode }> = ({
      (value) => setStateAndUpdate("isLoadingOutputs", value),
      [setStateAndUpdate]
    ),
    setIsDecomposing: useCallback(
      (value) => setStateAndUpdate("isDecomposing", value),
      [setStateAndUpdate]
    ),
    setNumOpRun: useCallback(
      (value) => setStateAndUpdate("numOpRun", value),
      [setStateAndUpdate]

@@ -589,6 +606,7 @@ export const PipelineProvider: React.FC<{ children: React.ReactNode }> = ({
      (value) => setStateAndUpdate("extraPipelineSettings", value),
      [setStateAndUpdate]
    ),
    onRequestDecompositionRef,
  };

  return (
@@ -0,0 +1,104 @@
import { useState, useEffect } from "react";
import axios from "axios";
import { DecomposeResult, DecomposeRequest } from "@/app/types";
import { API_ROUTES } from "@/app/api/constants";

interface UseDecomposeProps {
  onComplete?: (result: DecomposeResult) => void;
  onError?: (error: string) => void;
  pollInterval?: number;
}

export function useDecompose({
  onComplete,
  onError,
  pollInterval = 2000,
}: UseDecomposeProps = {}) {
  const [taskId, setTaskId] = useState<string | null>(null);
  const [result, setResult] = useState<DecomposeResult | null>(null);
  const [error, setError] = useState<string | null>(null);
  const [isLoading, setIsLoading] = useState(false);

  const submitTask = async (request: DecomposeRequest) => {
    try {
      setIsLoading(true);
      setError(null);
      setResult(null);

      const response = await axios.post<{ task_id: string }>(
        API_ROUTES.DECOMPOSE.SUBMIT,
        request
      );

      setTaskId(response.data.task_id);
    } catch (err) {
      const errorMessage =
        err instanceof Error ? err.message : "Failed to submit decompose task";
      setError(errorMessage);
      onError?.(errorMessage);
      setIsLoading(false);
    }
  };

  const cancelTask = async () => {
    if (!taskId) return;

    try {
      await axios.post(API_ROUTES.DECOMPOSE.CANCEL(taskId));
      setTaskId(null);
      setIsLoading(false);
    } catch (err) {
      const errorMessage =
        err instanceof Error ? err.message : "Failed to cancel task";
      setError(errorMessage);
      onError?.(errorMessage);
    }
  };

  useEffect(() => {
    if (!taskId) return;

    const pollTask = async () => {
      try {
        const response = await axios.get<DecomposeResult>(
          API_ROUTES.DECOMPOSE.STATUS(taskId)
        );

        setResult(response.data);

        if (
          ["completed", "failed", "cancelled"].includes(response.data.status)
        ) {
          setTaskId(null);
          setIsLoading(false);

          if (response.data.status === "completed") {
            onComplete?.(response.data);
          } else if (response.data.status === "failed" && response.data.error) {
            setError(response.data.error);
            onError?.(response.data.error);
          }
        }
      } catch (err) {
        const errorMessage =
          err instanceof Error ? err.message : "Failed to fetch task status";
        setError(errorMessage);
        onError?.(errorMessage);
        setTaskId(null);
        setIsLoading(false);
      }
    };

    const interval = setInterval(pollTask, pollInterval);
    return () => clearInterval(interval);
  }, [taskId, onComplete, onError, pollInterval]);

  return {
    submitTask,
    cancelTask,
    result,
    error,
    isLoading,
    isRunning: !!taskId,
  };
}
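For orientation, here is a minimal sketch of how a component might drive this polling hook. The import path, request fields, and handler bodies are assumptions, not code from this branch; only the hook's returned API is taken from the file above.

```typescript
import React from "react";
import { useDecompose } from "@/hooks/useDecompose"; // path assumed
import { DecomposeRequest } from "@/app/types";

// Hypothetical usage: submit a task, then the hook polls its status every 2s
// until it completes, fails, or is cancelled.
function DecomposePanel() {
  const { submitTask, cancelTask, result, error, isRunning } = useDecompose({
    pollInterval: 2000,
    onComplete: (res) => console.log("Winning directive:", res.winning_directive),
    onError: (msg) => console.error("Decomposition failed:", msg),
  });

  const start = () =>
    submitTask({
      yaml_config: "/tmp/pipeline.yaml", // placeholder path
      step_name: "data_processing",
      op_name: "extract_themes", // invented operation name
    } as DecomposeRequest); // cast because the real DecomposeRequest shape may differ

  // Render buttons wired to start() / cancelTask(), plus result/error display, as needed.
  return null;
}
```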
@@ -0,0 +1,145 @@
import { useState, useRef, useCallback } from "react";
import { DecomposeResult } from "@/app/types";

interface DecomposeWebSocketMessage {
  type: "output" | "result" | "error";
  data?: any;
  message?: string;
  traceback?: string;
}

interface UseDecomposeWebSocketProps {
  namespace: string;
  onOutput?: (output: string) => void;
  onComplete?: (result: DecomposeResult) => void;
  onError?: (error: string) => void;
}

export function useDecomposeWebSocket({
  namespace,
  onOutput,
  onComplete,
  onError,
}: UseDecomposeWebSocketProps) {
  const [isConnected, setIsConnected] = useState(false);
  const [isRunning, setIsRunning] = useState(false);
  const ws = useRef<WebSocket | null>(null);

  const connect = useCallback((): Promise<void> => {
    return new Promise((resolve, reject) => {
      if (ws.current?.readyState === WebSocket.OPEN) {
        resolve();
        return;
      }

      if (!namespace) {
        reject(new Error("Namespace is required for WebSocket connection"));
        return;
      }

      const wsUrl = `${
        process.env.NEXT_PUBLIC_BACKEND_HTTPS === "true" ? "wss://" : "ws://"
      }${process.env.NEXT_PUBLIC_BACKEND_HOST}:${
        process.env.NEXT_PUBLIC_BACKEND_PORT
      }/ws/decompose/${namespace}`;

      ws.current = new WebSocket(wsUrl);

      ws.current.onopen = () => {
        setIsConnected(true);
        resolve();
      };

      ws.current.onclose = () => {
        setIsConnected(false);
        setIsRunning(false);
      };

      ws.current.onerror = (error: Event) => {
        console.error("Decompose WebSocket error:", error);
        onError?.("WebSocket connection failed");
        reject(new Error("WebSocket connection failed"));
      };

      ws.current.onmessage = (event) => {
        try {
          const message: DecomposeWebSocketMessage = JSON.parse(event.data);

          if (message.type === "output") {
            onOutput?.(message.data);
          } else if (message.type === "result") {
            setIsRunning(false);
            // Convert the result data to DecomposeResult format
            const result: DecomposeResult = {
              task_id: "", // Not used in WebSocket mode
              status: "completed",
              decomposed_operations: message.data.decomposed_operations,
              winning_directive: message.data.winning_directive,
              candidates_evaluated: message.data.candidates_evaluated,
              original_outputs: message.data.original_outputs,
              decomposed_outputs: message.data.decomposed_outputs,
              comparison_rationale: message.data.comparison_rationale,
              cost: message.data.cost,
              created_at: new Date().toISOString(),
              completed_at: new Date().toISOString(),
            };
            onComplete?.(result);
            // Close the connection after receiving result
            ws.current?.close();
          } else if (message.type === "error") {
            setIsRunning(false);
            const errorMsg = message.data || message.message || "Unknown error";
            onError?.(errorMsg);
            ws.current?.close();
          }
        } catch (error) {
          console.error("Error parsing WebSocket message:", error);
          onError?.("Failed to parse WebSocket message");
        }
      };
    });
  }, [namespace, onOutput, onComplete, onError]);

  const startDecomposition = useCallback(
    async (yamlConfig: string, stepName: string, opName: string) => {
      try {
        await connect();
        setIsRunning(true);

        if (ws.current?.readyState === WebSocket.OPEN) {
          ws.current.send(
            JSON.stringify({
              yaml_config: yamlConfig,
              step_name: stepName,
              op_name: opName,
            })
          );
        }
      } catch (error) {
        setIsRunning(false);
        throw error;
      }
    },
    [connect]
  );

  const cancel = useCallback(() => {
    if (ws.current?.readyState === WebSocket.OPEN) {
      ws.current.send(JSON.stringify("kill"));
    }
  }, []);

  const disconnect = useCallback(() => {
    if (ws.current) {
      ws.current.close();
    }
  }, []);

  return {
    isConnected,
    isRunning,
    startDecomposition,
    cancel,
    disconnect,
  };
}
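Similarly, a minimal sketch of wiring the streaming variant into a console view. The import path, component, and local state are placeholders; only the hook's own callbacks and return values come from the file above.

```typescript
import React, { useState } from "react";
import { useDecomposeWebSocket } from "@/hooks/useDecomposeWebSocket"; // path assumed

// Hypothetical host component that streams decomposition output into a <pre>.
function DecomposeConsole({ namespace }: { namespace: string }) {
  const [log, setLog] = useState("");

  const { startDecomposition, cancel, isRunning } = useDecomposeWebSocket({
    namespace,
    onOutput: (chunk) => setLog((prev) => prev + chunk), // append raw console output
    onComplete: (result) => console.log("Done:", result.winning_directive),
    onError: (message) => console.error("Decomposition error:", message),
  });

  // After /api/writePipelineConfig returns a filePath, a caller would run:
  // startDecomposition(filePath, "data_processing", operationName);
  return <pre>{log}</pre>;
}
```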
@@ -6,7 +6,8 @@ export function cn(...inputs: ClassValue[]) {
}

export function canBeOptimized(operationType: string) {
  return ["resolve", "map", "reduce", "filter"].includes(operationType);
  // Only map operations can be decomposed
  return operationType === "map";
}

export const generateId = () => {
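The narrowed helper is presumably what the UI consults before offering the decompose action; a quick illustration of the behavioral change on this branch (import path assumed):

```typescript
import { canBeOptimized } from "@/lib/utils"; // path assumed

canBeOptimized("map");    // true  -> "Decompose" stays available for map operations
canBeOptimized("reduce"); // false -> allowed on main, hidden on this branch
canBeOptimized("filter"); // false
```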