Compare commits
3 Commits
cursor/emb ... main

| Author | SHA1 | Date |
|---|---|---|
|  | bcac6872f5 |  |
|  | 57a284bcb1 |  |
|  | cfbb64470a |  |
@@ -1,11 +1,15 @@
__version__ = "0.2.6"

import warnings
import litellm

from docetl.runner import DSLRunner
from docetl.optimizer import Optimizer
from docetl.apis.pd_accessors import SemanticAccessor

# Drop unsupported params for models like gpt-5 that don't support temperature=0
litellm.drop_params = True

# TODO: Remove after https://github.com/BerriAI/litellm/issues/7560 is fixed
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._config")
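As an illustration of why the new `litellm.drop_params` flag matters (the model name below is only an example, not taken from this diff): with the flag enabled, litellm silently drops arguments a model does not accept instead of raising an error.

```python
import litellm

litellm.drop_params = True

# Hypothetical call: if the chosen model rejects temperature=0, litellm
# drops the parameter rather than failing the request.
response = litellm.completion(
    model="gpt-5",  # example model that does not support temperature=0
    messages=[{"role": "user", "content": "Say hi"}],
    temperature=0,
)
```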
@@ -1,4 +1,5 @@
import os
import re
import threading
import time
from io import StringIO
@@ -12,6 +13,63 @@ from docetl.utils import StageType, get_stage_description

install(show_locals=False)

# ANSI escape sequences to strip (cursor movement, line clearing, cursor visibility)
ANSI_CURSOR_PATTERNS = re.compile(
    r"\x1b\[\?25[lh]" # Hide/show cursor
    r"|\x1b\[\d*[ABCD]" # Cursor up/down/forward/back
    r"|\x1b\[\d*[JK]" # Clear screen/line (but we handle \x1b[2K specially)
    r"|\x1b\[s|\x1b\[u" # Save/restore cursor position
)


def process_carriage_returns(text: str) -> str:
    """
    Process terminal control sequences for buffer-captured output.

    Rich's spinner uses ANSI sequences for:
    - `\r` - carriage return (move to start of line)
    - `\x1b[2K` - clear entire line
    - `\x1b[?25l/h` - hide/show cursor
    - `\x1b[1A` - move cursor up

    When captured to a buffer (not a real terminal), we simulate this by:
    1. Handling `\r` as "replace from start of line"
    2. Stripping cursor movement/visibility sequences
    3. Keeping only meaningful content
    """
    # First, strip cursor visibility and movement sequences
    text = ANSI_CURSOR_PATTERNS.sub("", text)

    # Process line by line
    lines = text.split("\n")
    processed_lines = []

    for line in lines:
        # Handle carriage returns - keep only content after the last \r
        if "\r" in line:
            segments = line.split("\r")
            # Keep the last non-empty segment (after stripping ANSI clear codes)
            for segment in reversed(segments):
                # Strip the clear line code if present
                cleaned = segment.replace("\x1b[2K", "").strip()
                if cleaned:
                    # Keep the segment but preserve any color codes
                    processed_lines.append(segment.replace("\x1b[2K", ""))
                    break
            else:
                # All segments were empty, skip this line entirely
                pass
        else:
            # No carriage return, keep the line as-is
            if line.strip() or line == "":  # Keep empty lines between content
                processed_lines.append(line)

    # Remove trailing empty lines
    while processed_lines and not processed_lines[-1].strip():
        processed_lines.pop()

    return "\n".join(processed_lines)


class ThreadSafeConsole(Console):
    def __init__(self, *args, **kwargs):
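A minimal sketch of the intended effect of the new helper on buffer-captured spinner output; the spinner frames and message below are invented for illustration.

```python
captured = (
    "\x1b[?25l\r\x1b[2K⠋ Running map operation...\r\x1b[2K⠙ Running map operation...\n"
    "Map operation complete\n"
    "\x1b[?25h"
)
print(process_carriage_returns(captured))
# Only the last spinner frame survives, followed by the final message:
# ⠙ Running map operation...
# Map operation complete
```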
@@ -24,11 +82,18 @@ class ThreadSafeConsole(Console):
        self.optimizer_rationale = None

    def get_output(self):
-       # return self.export_text(styles=True)
+       """
+       Get the output from the buffer, processing carriage returns.
+
+       Rich's spinner uses carriage returns to overwrite lines in place.
+       We process these to only keep the latest content, preventing
+       duplicate spinner frames from flooding the output.
+       """
        value = self.buffer.getvalue()
        self.buffer.truncate(0)
        self.buffer.seek(0)
-       return value
+       # Process carriage returns to handle spinner overwrites
+       return process_carriage_returns(value)

    def status(
        self,
@@ -36,10 +101,15 @@ class ThreadSafeConsole(Console):
        *,
        spinner: str = "dots",
        spinner_style: "StyleType" = "status.spinner",
-       speed: float = 0.1, # Much slower speed
-       refresh_per_second: float = 0.5, # Much slower refresh rate (every 2 seconds)
+       speed: float = 1.0,
+       refresh_per_second: float = 4,
    ) -> "Status":
+       """
+       Return a Rich Status with animation.
+
+       The carriage returns from the spinner animation are processed
+       in get_output() to prevent duplicate lines.
+       """
        status_renderable = Status(
            status,
            console=self,
@@ -465,7 +465,8 @@ class OpContainer:
            return cached_data, 0, curr_logs

        # Try to load from checkpoint if available
-       if not is_build:
+       # Skip if this operation has bypass_cache: true
+       if not is_build and not self.config.get("bypass_cache", False):
            attempted_input_data = self.runner._load_from_checkpoint_if_exists(
                self.name.split("/")[0], self.name.split("/")[-1]
            )
@@ -849,13 +849,9 @@ class MOARSearch:
        response = litellm.completion(
            model=self.model,
            messages=messages,
-           api_key=os.environ.get("AZURE_API_KEY"),
-           api_base=os.environ.get("AZURE_API_BASE"),
-           api_version=os.environ.get("AZURE_API_VERSION"),
-           azure=True,
            response_format=ExpandResponseFormat,
        )
-       call_cost = response._hidden_params["response_cost"]
+       call_cost = response._hidden_params.get("response_cost", 0)
        with self.tree_lock:
            self.console.log(
                f"[green]💰 Adding LLM call cost:[/green] ${call_cost:.4f} (total before: ${self.total_search_cost:.4f})"
@@ -153,10 +153,11 @@ def run_moar_optimization(
            "optimizer_config must contain 'max_iterations' (number of MOARSearch iterations) for MOAR optimizer"
        )

-   model = optimizer_config.get("model")
+   # Use rewrite_agent_model for consistency with rest of codebase, fallback to model for backwards compatibility
+   model = optimizer_config.get("rewrite_agent_model") or optimizer_config.get("model")
    if not model:
        raise ValueError(
-           "optimizer_config must contain 'model' (LLM model name for directive instantiation) for MOAR optimizer"
+           "optimizer_config must contain 'rewrite_agent_model' (LLM model name for directive instantiation) for MOAR optimizer"
        )

    # Optional parameters
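For clarity, a small sketch of how the renamed key is resolved; the surrounding values are invented for illustration.

```python
optimizer_config = {
    "max_iterations": 10,
    # New preferred key; "model" is still honored as a fallback.
    "rewrite_agent_model": "gpt-4o",
}

model = optimizer_config.get("rewrite_agent_model") or optimizer_config.get("model")
assert model == "gpt-4o"
```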
@@ -17,6 +17,7 @@ from rich.prompt import Confirm

from docetl.operations.base import BaseOperation
from docetl.operations.utils import strict_render
+from docetl.operations.utils.blocking import RuntimeBlockingOptimizer
from docetl.operations.utils.progress import RichLoopBar
from docetl.utils import (
    completion_cost,
@@ -63,6 +64,7 @@ class EquijoinOperation(BaseOperation):
        comparison_prompt: str
        output: dict[str, Any] | None = None
        blocking_threshold: float | None = None
+       blocking_target_recall: float | None = None
        blocking_conditions: list[str] | None = None
        limits: dict[str, int] | None = None
        comparison_model: str | None = None
@@ -250,6 +252,58 @@ class EquijoinOperation(BaseOperation):
        if self.status:
            self.status.stop()

        # Track pre-computed embeddings from auto-optimization
        precomputed_left_embeddings = None
        precomputed_right_embeddings = None

        # Auto-compute blocking threshold if no blocking configuration is provided
        if not blocking_threshold and not blocking_conditions and not limit_comparisons:
            # Get target recall from operation config (default 0.95)
            target_recall = self.config.get("blocking_target_recall", 0.95)
            self.console.log(
                f"[yellow]No blocking configuration. Auto-computing threshold (target recall: {target_recall:.0%})...[/yellow]"
            )

            # Create comparison function for threshold optimization
            def compare_fn_for_optimization(left_item, right_item):
                return self.compare_pair(
                    self.config["comparison_prompt"],
                    self.config.get("comparison_model", self.default_model),
                    left_item,
                    right_item,
                    timeout_seconds=self.config.get("timeout", 120),
                    max_retries_per_timeout=self.config.get(
                        "max_retries_per_timeout", 2
                    ),
                )

            # Run threshold optimization
            optimizer = RuntimeBlockingOptimizer(
                runner=self.runner,
                config=self.config,
                default_model=self.default_model,
                max_threads=self.max_threads,
                console=self.console,
                target_recall=target_recall,
                sample_size=min(100, len(left_data) * len(right_data) // 4),
            )
            (
                blocking_threshold,
                precomputed_left_embeddings,
                precomputed_right_embeddings,
                optimization_cost,
            ) = optimizer.optimize_equijoin(
                left_data,
                right_data,
                compare_fn_for_optimization,
                left_keys=left_keys,
                right_keys=right_keys,
            )
            total_cost += optimization_cost
            self.console.log(
                f"[green]Using auto-computed blocking threshold: {blocking_threshold}[/green]"
            )

        # Initial blocking using multiprocessing
        num_processes = min(cpu_count(), len(left_data))
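Assuming the equijoin schema shown earlier in this diff, a configuration like the following (names invented for illustration) would now trigger the automatic threshold search instead of requiring a hand-tuned `blocking_threshold`.

```python
# Hypothetical operation config; only "blocking_target_recall" is the new knob.
equijoin_config = {
    "name": "match_products",  # illustrative operation name
    "type": "equijoin",
    "comparison_prompt": "Are {{ left.title }} and {{ right.name }} the same product?",
    # No blocking_threshold, blocking_conditions, or limit_comparisons:
    # the operation samples pairs, labels them with the comparison model,
    # and picks the cosine-similarity threshold that reaches this recall.
    "blocking_target_recall": 0.95,
}
```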
@ -298,45 +352,60 @@ class EquijoinOperation(BaseOperation):
|
|||
)
|
||||
|
||||
if blocking_threshold is not None:
|
||||
embedding_model = self.config.get("embedding_model", self.default_model)
|
||||
model_input_context_length = model_cost.get(embedding_model, {}).get(
|
||||
"max_input_tokens", 8192
|
||||
)
|
||||
|
||||
def get_embeddings(
|
||||
input_data: list[dict[str, Any]], keys: list[str], name: str
|
||||
) -> tuple[list[list[float]], float]:
|
||||
texts = [
|
||||
" ".join(str(item[key]) for key in keys if key in item)[
|
||||
: model_input_context_length * 4
|
||||
]
|
||||
for item in input_data
|
||||
]
|
||||
|
||||
embeddings = []
|
||||
total_cost = 0
|
||||
# Use precomputed embeddings if available from auto-optimization
|
||||
if (
|
||||
precomputed_left_embeddings is not None
|
||||
and precomputed_right_embeddings is not None
|
||||
):
|
||||
left_embeddings = precomputed_left_embeddings
|
||||
right_embeddings = precomputed_right_embeddings
|
||||
else:
|
||||
embedding_model = self.config.get("embedding_model", self.default_model)
|
||||
model_input_context_length = model_cost.get(embedding_model, {}).get(
|
||||
"max_input_tokens", 8192
|
||||
)
|
||||
batch_size = 2000
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i : i + batch_size]
|
||||
self.console.log(
|
||||
f"On iteration {i} for creating embeddings for {name} data"
|
||||
)
|
||||
response = self.runner.api.gen_embedding(
|
||||
model=embedding_model,
|
||||
input=batch,
|
||||
)
|
||||
embeddings.extend([data["embedding"] for data in response["data"]])
|
||||
total_cost += completion_cost(response)
|
||||
return embeddings, total_cost
|
||||
|
||||
left_embeddings, left_cost = get_embeddings(left_data, left_keys, "left")
|
||||
right_embeddings, right_cost = get_embeddings(
|
||||
right_data, right_keys, "right"
|
||||
)
|
||||
total_cost += left_cost + right_cost
|
||||
self.console.log(
|
||||
f"Created embeddings for datasets. Total embedding creation cost: {total_cost}"
|
||||
)
|
||||
def get_embeddings(
|
||||
input_data: list[dict[str, Any]], keys: list[str], name: str
|
||||
) -> tuple[list[list[float]], float]:
|
||||
texts = [
|
||||
" ".join(str(item[key]) for key in keys if key in item)[
|
||||
: model_input_context_length * 4
|
||||
]
|
||||
for item in input_data
|
||||
]
|
||||
embeddings = []
|
||||
embedding_cost = 0
|
||||
num_batches = (len(texts) + batch_size - 1) // batch_size
|
||||
|
||||
for batch_idx, i in enumerate(range(0, len(texts), batch_size)):
|
||||
batch = texts[i : i + batch_size]
|
||||
if num_batches > 1:
|
||||
self.console.log(
|
||||
f"[dim]Creating {name} embeddings: batch {batch_idx + 1}/{num_batches} "
|
||||
f"({min(i + batch_size, len(texts))}/{len(texts)} items)[/dim]"
|
||||
)
|
||||
response = self.runner.api.gen_embedding(
|
||||
model=embedding_model,
|
||||
input=batch,
|
||||
)
|
||||
embeddings.extend(
|
||||
[data["embedding"] for data in response["data"]]
|
||||
)
|
||||
embedding_cost += completion_cost(response)
|
||||
return embeddings, embedding_cost
|
||||
|
||||
self.console.log(
|
||||
f"[cyan]Creating embeddings for {len(left_data)} left + {len(right_data)} right items...[/cyan]"
|
||||
)
|
||||
left_embeddings, left_cost = get_embeddings(
|
||||
left_data, left_keys, "left"
|
||||
)
|
||||
right_embeddings, right_cost = get_embeddings(
|
||||
right_data, right_keys, "right"
|
||||
)
|
||||
total_cost += left_cost + right_cost
|
||||
|
||||
# Compute all cosine similarities in one call
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
|
|
|
|||
|
|
@@ -10,10 +10,10 @@ import jinja2
from jinja2 import Template
from litellm import model_cost
from pydantic import Field, ValidationInfo, field_validator, model_validator
from rich.prompt import Confirm

from docetl.operations.base import BaseOperation
from docetl.operations.utils import RichLoopBar, rich_as_completed, strict_render
+from docetl.operations.utils.blocking import RuntimeBlockingOptimizer
from docetl.utils import (
    completion_cost,
    extract_jinja_variables,

@@ -40,6 +40,7 @@ class ResolveOperation(BaseOperation):
        comparison_model: str | None = None
        blocking_keys: list[str] | None = None
        blocking_threshold: float | None = Field(None, ge=0, le=1)
+       blocking_target_recall: float | None = Field(None, ge=0, le=1)
        blocking_conditions: list[str] | None = None
        input: dict[str, Any] | None = None
        embedding_batch_size: int | None = Field(None, gt=0)
@ -266,26 +267,75 @@ class ResolveOperation(BaseOperation):
|
|||
blocking_keys = self.config.get("blocking_keys", [])
|
||||
blocking_threshold = self.config.get("blocking_threshold")
|
||||
blocking_conditions = self.config.get("blocking_conditions", [])
|
||||
limit_comparisons = self.config.get("limit_comparisons")
|
||||
total_cost = 0
|
||||
if self.status:
|
||||
self.status.stop()
|
||||
|
||||
if not blocking_threshold and not blocking_conditions:
|
||||
# Prompt the user for confirmation
|
||||
if not Confirm.ask(
|
||||
"[yellow]Warning: No blocking keys or conditions specified. "
|
||||
"This may result in a large number of comparisons. "
|
||||
"We recommend specifying at least one blocking key or condition, or using the optimizer to automatically come up with these. "
|
||||
"Do you want to continue without blocking?[/yellow]",
|
||||
console=self.runner.console,
|
||||
):
|
||||
raise ValueError("Operation cancelled by user.")
|
||||
# Track pre-computed embeddings from auto-optimization
|
||||
precomputed_embeddings = None
|
||||
|
||||
# Auto-compute blocking threshold if no blocking configuration is provided
|
||||
if not blocking_threshold and not blocking_conditions and not limit_comparisons:
|
||||
# Get target recall from operation config (default 0.95)
|
||||
target_recall = self.config.get("blocking_target_recall", 0.95)
|
||||
self.console.log(
|
||||
f"[yellow]No blocking configuration. Auto-computing threshold (target recall: {target_recall:.0%})...[/yellow]"
|
||||
)
|
||||
# Determine blocking keys if not set
|
||||
auto_blocking_keys = blocking_keys if blocking_keys else None
|
||||
if not auto_blocking_keys:
|
||||
prompt_template = self.config.get("comparison_prompt", "")
|
||||
prompt_vars = extract_jinja_variables(prompt_template)
|
||||
prompt_vars = [
|
||||
var
|
||||
for var in prompt_vars
|
||||
if var not in ["input", "input1", "input2"]
|
||||
]
|
||||
auto_blocking_keys = list(
|
||||
set([var.split(".")[-1] for var in prompt_vars])
|
||||
)
|
||||
if not auto_blocking_keys:
|
||||
auto_blocking_keys = list(input_data[0].keys())
|
||||
blocking_keys = auto_blocking_keys
|
||||
|
||||
# Create comparison function for threshold optimization
|
||||
def compare_fn_for_optimization(item1, item2):
|
||||
return self.compare_pair(
|
||||
self.config["comparison_prompt"],
|
||||
self.config.get("comparison_model", self.default_model),
|
||||
item1,
|
||||
item2,
|
||||
blocking_keys=[], # Don't use key-based shortcut during optimization
|
||||
timeout_seconds=self.config.get("timeout", 120),
|
||||
max_retries_per_timeout=self.config.get(
|
||||
"max_retries_per_timeout", 2
|
||||
),
|
||||
)
|
||||
|
||||
# Run threshold optimization
|
||||
optimizer = RuntimeBlockingOptimizer(
|
||||
runner=self.runner,
|
||||
config=self.config,
|
||||
default_model=self.default_model,
|
||||
max_threads=self.max_threads,
|
||||
console=self.console,
|
||||
target_recall=target_recall,
|
||||
sample_size=min(100, len(input_data) * (len(input_data) - 1) // 4),
|
||||
)
|
||||
blocking_threshold, precomputed_embeddings, optimization_cost = (
|
||||
optimizer.optimize_resolve(
|
||||
input_data,
|
||||
compare_fn_for_optimization,
|
||||
blocking_keys=blocking_keys,
|
||||
)
|
||||
)
|
||||
total_cost += optimization_cost
|
||||
|
||||
input_schema = self.config.get("input", {}).get("schema", {})
|
||||
if not blocking_keys:
|
||||
# Set them to all keys in the input data
|
||||
blocking_keys = list(input_data[0].keys())
|
||||
limit_comparisons = self.config.get("limit_comparisons")
|
||||
total_cost = 0
|
||||
|
||||
def is_match(item1: dict[str, Any], item2: dict[str, Any]) -> bool:
|
||||
return any(
|
||||
|
|
@ -296,120 +346,101 @@ class ResolveOperation(BaseOperation):
|
|||
# Calculate embeddings if blocking_threshold is set
|
||||
embeddings = None
|
||||
if blocking_threshold is not None:
|
||||
|
||||
def get_embeddings_batch(
|
||||
items: list[dict[str, Any]]
|
||||
) -> list[tuple[list[float], float]]:
|
||||
# Use precomputed embeddings if available from auto-optimization
|
||||
if precomputed_embeddings is not None:
|
||||
embeddings = precomputed_embeddings
|
||||
else:
|
||||
self.console.log(
|
||||
f"[cyan]Creating embeddings for {len(input_data)} items...[/cyan]"
|
||||
)
|
||||
embedding_model = self.config.get(
|
||||
"embedding_model", "text-embedding-3-small"
|
||||
)
|
||||
model_input_context_length = model_cost.get(embedding_model, {}).get(
|
||||
"max_input_tokens", 8192
|
||||
)
|
||||
batch_size = self.config.get("embedding_batch_size", 1000)
|
||||
embeddings = []
|
||||
embedding_cost = 0.0
|
||||
num_batches = (len(input_data) + batch_size - 1) // batch_size
|
||||
|
||||
texts = [
|
||||
" ".join(str(item[key]) for key in blocking_keys if key in item)[
|
||||
: model_input_context_length * 3
|
||||
for batch_idx in range(num_batches):
|
||||
start_idx = batch_idx * batch_size
|
||||
end_idx = min(start_idx + batch_size, len(input_data))
|
||||
batch = input_data[start_idx:end_idx]
|
||||
|
||||
if num_batches > 1:
|
||||
self.console.log(
|
||||
f"[dim]Creating embeddings: batch {batch_idx + 1}/{num_batches} "
|
||||
f"({end_idx}/{len(input_data)} items)[/dim]"
|
||||
)
|
||||
|
||||
texts = [
|
||||
" ".join(
|
||||
str(item[key]) for key in blocking_keys if key in item
|
||||
)[: model_input_context_length * 3]
|
||||
for item in batch
|
||||
]
|
||||
for item in items
|
||||
]
|
||||
response = self.runner.api.gen_embedding(
|
||||
model=embedding_model, input=texts
|
||||
)
|
||||
embeddings.extend([data["embedding"] for data in response["data"]])
|
||||
embedding_cost += completion_cost(response)
|
||||
|
||||
response = self.runner.api.gen_embedding(
|
||||
model=embedding_model, input=texts
|
||||
)
|
||||
return [
|
||||
(data["embedding"], completion_cost(response))
|
||||
for data in response["data"]
|
||||
]
|
||||
total_cost += embedding_cost
|
||||
|
||||
embeddings = []
|
||||
costs = []
|
||||
with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
|
||||
for i in range(
|
||||
0, len(input_data), self.config.get("embedding_batch_size", 1000)
|
||||
):
|
||||
batch = input_data[
|
||||
i : i + self.config.get("embedding_batch_size", 1000)
|
||||
]
|
||||
batch_results = list(executor.map(get_embeddings_batch, [batch]))
|
||||
# Build a mapping of blocking key values to indices
|
||||
# This is used later for cluster merging (when two items match, merge all items sharing their key values)
|
||||
value_to_indices: dict[tuple[str, ...], list[int]] = {}
|
||||
for i, item in enumerate(input_data):
|
||||
key = tuple(str(item.get(k, "")) for k in blocking_keys)
|
||||
if key not in value_to_indices:
|
||||
value_to_indices[key] = []
|
||||
value_to_indices[key].append(i)
|
||||
|
||||
for result in batch_results:
|
||||
embeddings.extend([r[0] for r in result])
|
||||
costs.extend([r[1] for r in result])
|
||||
# Total number of pairs to potentially compare
|
||||
n = len(input_data)
|
||||
total_pairs = n * (n - 1) // 2
|
||||
|
||||
total_cost += sum(costs)
|
||||
|
||||
# Generate all pairs to compare, ensuring no duplicate comparisons
|
||||
def get_unique_comparison_pairs() -> (
|
||||
tuple[list[tuple[int, int]], dict[tuple[str, ...], list[int]]]
|
||||
):
|
||||
# Create a mapping of values to their indices
|
||||
value_to_indices: dict[tuple[str, ...], list[int]] = {}
|
||||
for i, item in enumerate(input_data):
|
||||
# Create a hashable key from the blocking keys
|
||||
key = tuple(str(item.get(k, "")) for k in blocking_keys)
|
||||
if key not in value_to_indices:
|
||||
value_to_indices[key] = []
|
||||
value_to_indices[key].append(i)
|
||||
|
||||
# Generate pairs for comparison, comparing each unique value combination only once
|
||||
comparison_pairs = []
|
||||
keys = list(value_to_indices.keys())
|
||||
|
||||
# First, handle comparisons between different values
|
||||
for i in range(len(keys)):
|
||||
for j in range(i + 1, len(keys)):
|
||||
# Only need one comparison between different values
|
||||
idx1 = value_to_indices[keys[i]][0]
|
||||
idx2 = value_to_indices[keys[j]][0]
|
||||
if idx1 < idx2: # Maintain ordering to avoid duplicates
|
||||
comparison_pairs.append((idx1, idx2))
|
||||
|
||||
return comparison_pairs, value_to_indices
|
||||
|
||||
comparison_pairs, value_to_indices = get_unique_comparison_pairs()
|
||||
|
||||
# Filter pairs based on blocking conditions
|
||||
def meets_blocking_conditions(pair: tuple[int, int]) -> bool:
|
||||
i, j = pair
|
||||
return (
|
||||
is_match(input_data[i], input_data[j]) if blocking_conditions else False
|
||||
)
|
||||
|
||||
# Start with pairs that meet blocking conditions, or empty list if no conditions
|
||||
code_blocked_pairs = (
|
||||
list(filter(meets_blocking_conditions, comparison_pairs))
|
||||
if blocking_conditions
|
||||
else []
|
||||
)
|
||||
# Apply code-based blocking conditions (check all pairs)
|
||||
code_blocked_pairs = []
|
||||
if blocking_conditions:
|
||||
for i in range(n):
|
||||
for j in range(i + 1, n):
|
||||
if is_match(input_data[i], input_data[j]):
|
||||
code_blocked_pairs.append((i, j))
|
||||
|
||||
# Apply cosine similarity blocking if threshold is specified
|
||||
embedding_blocked_pairs = []
|
||||
if blocking_threshold is not None and embeddings is not None:
|
||||
import numpy as np
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
|
||||
similarity_matrix = cosine_similarity(embeddings)
|
||||
|
||||
# Add pairs that meet the cosine similarity threshold and aren't already blocked
|
||||
code_blocked_set = set(code_blocked_pairs)
|
||||
|
||||
for i, j in comparison_pairs:
|
||||
if (i, j) not in code_blocked_set:
|
||||
similarity = similarity_matrix[i, j]
|
||||
if similarity >= blocking_threshold:
|
||||
embedding_blocked_pairs.append((i, j))
|
||||
# Use numpy to efficiently find all pairs above threshold
|
||||
i_indices, j_indices = np.triu_indices(n, k=1)
|
||||
similarities = similarity_matrix[i_indices, j_indices]
|
||||
above_threshold_mask = similarities >= blocking_threshold
|
||||
|
||||
self.console.log(
|
||||
f"Cosine similarity blocking: added {len(embedding_blocked_pairs)} pairs "
|
||||
f"(threshold: {blocking_threshold})"
|
||||
)
|
||||
# Get pairs above threshold
|
||||
above_threshold_i = i_indices[above_threshold_mask]
|
||||
above_threshold_j = j_indices[above_threshold_mask]
|
||||
|
||||
# Combine pairs with prioritization for sampling
|
||||
# Filter out pairs already in code_blocked_set
|
||||
embedding_blocked_pairs = [
|
||||
(int(i), int(j))
|
||||
for i, j in zip(above_threshold_i, above_threshold_j)
|
||||
if (i, j) not in code_blocked_set
|
||||
]
|
||||
|
||||
# Combine pairs from both blocking methods
|
||||
all_blocked_pairs = code_blocked_pairs + embedding_blocked_pairs
|
||||
|
||||
# If no pairs are blocked at all, fall back to all comparison pairs
|
||||
if not all_blocked_pairs:
|
||||
all_blocked_pairs = comparison_pairs
|
||||
# If no blocking was applied, compare all pairs
|
||||
if not blocking_conditions and blocking_threshold is None:
|
||||
all_blocked_pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]
|
||||
# Apply limit_comparisons with prioritization
|
||||
if limit_comparisons is not None and len(all_blocked_pairs) > limit_comparisons:
|
||||
# Prioritize code-based pairs, then sample from embedding pairs if needed
|
||||
|
|
@ -476,18 +507,6 @@ class ResolveOperation(BaseOperation):
|
|||
cluster_map[root_idx] = root1
|
||||
clusters[root_idx] = set()
|
||||
|
||||
# Calculate and print statistics
|
||||
total_possible_comparisons = len(input_data) * (len(input_data) - 1) // 2
|
||||
comparisons_made = len(blocked_pairs)
|
||||
comparisons_saved = total_possible_comparisons - comparisons_made
|
||||
self.console.log(
|
||||
f"[green]Comparisons saved by deduping and blocking: {comparisons_saved} "
|
||||
f"({(comparisons_saved / total_possible_comparisons) * 100:.2f}%)[/green]"
|
||||
)
|
||||
self.console.log(
|
||||
f"[blue]Number of pairs to compare: {len(blocked_pairs)}[/blue]"
|
||||
)
|
||||
|
||||
# Compute an auto-batch size based on the number of comparisons
|
||||
def auto_batch() -> int:
|
||||
# Maximum batch size limit for 4o-mini model
|
||||
|
|
@ -513,7 +532,14 @@ class ResolveOperation(BaseOperation):
|
|||
|
||||
# Compare pairs and update clusters in real-time
|
||||
batch_size = self.config.get("compare_batch_size", auto_batch())
|
||||
self.console.log(f"Using compare batch size: {batch_size}")
|
||||
|
||||
# Log blocking summary
|
||||
total_possible_comparisons = len(input_data) * (len(input_data) - 1) // 2
|
||||
self.console.log(
|
||||
f"Comparing {len(blocked_pairs):,} pairs "
|
||||
f"({len(blocked_pairs)/total_possible_comparisons*100:.1f}% of {total_possible_comparisons:,} total, "
|
||||
f"batch size: {batch_size})"
|
||||
)
|
||||
pair_costs = 0
|
||||
|
||||
pbar = RichLoopBar(
|
||||
|
|
|
|||
|
|
@@ -1,4 +1,5 @@
from .api import APIWrapper
+from .blocking import RuntimeBlockingOptimizer
from .cache import (
    cache,
    cache_key,

@@ -15,6 +16,7 @@ from .validation import safe_eval, convert_val, convert_dict_schema_to_list_sche

__all__ = [
    'APIWrapper',
+   'RuntimeBlockingOptimizer',
    'cache',
    'cache_key',
    'clear_cache',
@ -0,0 +1,567 @@
|
|||
"""
|
||||
Runtime blocking threshold optimization utilities.
|
||||
|
||||
This module provides functionality for automatically computing embedding-based
|
||||
blocking thresholds at runtime when no blocking configuration is provided.
|
||||
"""
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Any, Callable
|
||||
|
||||
import numpy as np
|
||||
from litellm import model_cost
|
||||
from rich.console import Console
|
||||
|
||||
from docetl.utils import completion_cost, extract_jinja_variables
|
||||
|
||||
|
||||
class RuntimeBlockingOptimizer:
|
||||
"""
|
||||
Computes optimal embedding-based blocking thresholds at runtime.
|
||||
|
||||
This class samples pairs from the dataset, performs LLM comparisons,
|
||||
and finds the optimal cosine similarity threshold that achieves a
|
||||
target recall rate.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
runner,
|
||||
config: dict[str, Any],
|
||||
default_model: str,
|
||||
max_threads: int,
|
||||
console: Console,
|
||||
target_recall: float = 0.95,
|
||||
sample_size: int = 100,
|
||||
sampling_weight: float = 20.0,
|
||||
):
|
||||
"""
|
||||
Initialize the RuntimeBlockingOptimizer.
|
||||
|
||||
Args:
|
||||
runner: The pipeline runner instance.
|
||||
config: Operation configuration.
|
||||
default_model: Default LLM model for comparisons.
|
||||
max_threads: Maximum threads for parallel processing.
|
||||
console: Rich console for logging.
|
||||
target_recall: Target recall rate (default 0.95).
|
||||
sample_size: Number of pairs to sample for threshold estimation.
|
||||
sampling_weight: Weight for exponential sampling towards higher similarities.
|
||||
"""
|
||||
self.runner = runner
|
||||
self.config = config
|
||||
self.default_model = default_model
|
||||
self.max_threads = max_threads
|
||||
self.console = console
|
||||
self.target_recall = target_recall
|
||||
self.sample_size = sample_size
|
||||
self.sampling_weight = sampling_weight
|
||||
|
||||
def compute_embeddings(
|
||||
self,
|
||||
input_data: list[dict[str, Any]],
|
||||
keys: list[str],
|
||||
embedding_model: str | None = None,
|
||||
batch_size: int = 1000,
|
||||
) -> tuple[list[list[float]], float]:
|
||||
"""
|
||||
Compute embeddings for the input data.
|
||||
|
||||
Args:
|
||||
input_data: List of input documents.
|
||||
keys: Keys to use for embedding text.
|
||||
embedding_model: Model to use for embeddings.
|
||||
batch_size: Batch size for embedding computation.
|
||||
|
||||
Returns:
|
||||
Tuple of (embeddings list, total cost).
|
||||
"""
|
||||
embedding_model = embedding_model or self.config.get(
|
||||
"embedding_model", "text-embedding-3-small"
|
||||
)
|
||||
model_input_context_length = model_cost.get(embedding_model, {}).get(
|
||||
"max_input_tokens", 8192
|
||||
)
|
||||
texts = [
|
||||
" ".join(str(item[key]) for key in keys if key in item)[
|
||||
: model_input_context_length * 3
|
||||
]
|
||||
for item in input_data
|
||||
]
|
||||
|
||||
self.console.log(f"[cyan]Creating embeddings for {len(texts)} items...[/cyan]")
|
||||
|
||||
embeddings = []
|
||||
total_cost = 0.0
|
||||
num_batches = (len(texts) + batch_size - 1) // batch_size
|
||||
for batch_idx, i in enumerate(range(0, len(texts), batch_size)):
|
||||
batch = texts[i : i + batch_size]
|
||||
if num_batches > 1:
|
||||
self.console.log(
|
||||
f"[dim] Batch {batch_idx + 1}/{num_batches} "
|
||||
f"({len(embeddings) + len(batch)}/{len(texts)} items)[/dim]"
|
||||
)
|
||||
response = self.runner.api.gen_embedding(
|
||||
model=embedding_model,
|
||||
input=batch,
|
||||
)
|
||||
embeddings.extend([data["embedding"] for data in response["data"]])
|
||||
total_cost += completion_cost(response)
|
||||
return embeddings, total_cost
|
||||
|
||||
def calculate_cosine_similarities_self(
|
||||
self, embeddings: list[list[float]]
|
||||
) -> list[tuple[int, int, float]]:
|
||||
"""
|
||||
Calculate pairwise cosine similarities for self-join.
|
||||
|
||||
Args:
|
||||
embeddings: List of embedding vectors.
|
||||
|
||||
Returns:
|
||||
List of (i, j, similarity) tuples for all pairs where i < j.
|
||||
"""
|
||||
embeddings_array = np.array(embeddings)
|
||||
norms = np.linalg.norm(embeddings_array, axis=1)
|
||||
# Avoid division by zero
|
||||
norms = np.where(norms == 0, 1e-10, norms)
|
||||
dot_products = np.dot(embeddings_array, embeddings_array.T)
|
||||
similarities_matrix = dot_products / np.outer(norms, norms)
|
||||
i, j = np.triu_indices(len(embeddings), k=1)
|
||||
similarities = list(
|
||||
zip(i.tolist(), j.tolist(), similarities_matrix[i, j].tolist())
|
||||
)
|
||||
return similarities
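A small self-contained check of what this helper computes (toy vectors, not taken from the diff): upper-triangle index pairs with their cosine similarities.

```python
import numpy as np

embeddings = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
arr = np.array(embeddings)
norms = np.linalg.norm(arr, axis=1)
sims = (arr @ arr.T) / np.outer(norms, norms)
i, j = np.triu_indices(len(embeddings), k=1)
pairs = list(zip(i.tolist(), j.tolist(), sims[i, j].round(4).tolist()))
print(pairs)
# [(0, 1, 0.0), (0, 2, 0.7071), (1, 2, 0.7071)]
```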
|
||||
|
||||
def calculate_cosine_similarities_cross(
|
||||
self,
|
||||
left_embeddings: list[list[float]],
|
||||
right_embeddings: list[list[float]],
|
||||
) -> list[tuple[int, int, float]]:
|
||||
"""
|
||||
Calculate cosine similarities between two sets of embeddings.
|
||||
|
||||
Args:
|
||||
left_embeddings: Embeddings for left dataset.
|
||||
right_embeddings: Embeddings for right dataset.
|
||||
|
||||
Returns:
|
||||
List of (left_idx, right_idx, similarity) tuples.
|
||||
"""
|
||||
left_array = np.array(left_embeddings)
|
||||
right_array = np.array(right_embeddings)
|
||||
dot_product = np.dot(left_array, right_array.T)
|
||||
norm_left = np.linalg.norm(left_array, axis=1)
|
||||
norm_right = np.linalg.norm(right_array, axis=1)
|
||||
# Avoid division by zero
|
||||
norm_left = np.where(norm_left == 0, 1e-10, norm_left)
|
||||
norm_right = np.where(norm_right == 0, 1e-10, norm_right)
|
||||
similarities = dot_product / np.outer(norm_left, norm_right)
|
||||
return [
|
||||
(i, j, float(sim))
|
||||
for i, row in enumerate(similarities)
|
||||
for j, sim in enumerate(row)
|
||||
]
|
||||
|
||||
def sample_pairs(
|
||||
self,
|
||||
similarities: list[tuple[int, int, float]],
|
||||
num_bins: int = 10,
|
||||
stratified_fraction: float = 0.5,
|
||||
) -> list[tuple[int, int]]:
|
||||
"""
|
||||
Sample pairs using a hybrid of stratified and exponential-weighted sampling.
|
||||
|
||||
This ensures coverage across the similarity distribution while still
|
||||
focusing on high-similarity pairs where matches are more likely.
|
||||
|
||||
Args:
|
||||
similarities: List of (i, j, similarity) tuples.
|
||||
num_bins: Number of bins for stratified sampling.
|
||||
stratified_fraction: Fraction of samples to allocate to stratified sampling.
|
||||
|
||||
Returns:
|
||||
List of sampled (i, j) pairs.
|
||||
"""
|
||||
if len(similarities) == 0:
|
||||
return []
|
||||
|
||||
sample_count = min(self.sample_size, len(similarities))
|
||||
stratified_count = int(sample_count * stratified_fraction)
|
||||
exponential_count = sample_count - stratified_count
|
||||
|
||||
sampled_indices = set()
|
||||
sim_values = np.array([sim[2] for sim in similarities])
|
||||
|
||||
# Part 1: Stratified sampling across bins
|
||||
if stratified_count > 0:
|
||||
bin_edges = np.linspace(
|
||||
sim_values.min(), sim_values.max() + 1e-9, num_bins + 1
|
||||
)
|
||||
samples_per_bin = max(1, stratified_count // num_bins)
|
||||
|
||||
for bin_idx in range(num_bins):
|
||||
bin_mask = (sim_values >= bin_edges[bin_idx]) & (
|
||||
sim_values < bin_edges[bin_idx + 1]
|
||||
)
|
||||
bin_indices = np.where(bin_mask)[0]
|
||||
|
||||
if len(bin_indices) > 0:
|
||||
# Within each bin, use exponential weighting
|
||||
bin_sims = sim_values[bin_indices]
|
||||
bin_weights = np.exp(self.sampling_weight * bin_sims)
|
||||
bin_weights /= bin_weights.sum()
|
||||
|
||||
n_to_sample = min(samples_per_bin, len(bin_indices))
|
||||
chosen = np.random.choice(
|
||||
bin_indices,
|
||||
size=n_to_sample,
|
||||
replace=False,
|
||||
p=bin_weights,
|
||||
)
|
||||
sampled_indices.update(chosen.tolist())
|
||||
|
||||
# Part 2: Exponential-weighted sampling for remaining slots
|
||||
if exponential_count > 0:
|
||||
remaining_indices = [
|
||||
i for i in range(len(similarities)) if i not in sampled_indices
|
||||
]
|
||||
if remaining_indices:
|
||||
remaining_sims = sim_values[remaining_indices]
|
||||
weights = np.exp(self.sampling_weight * remaining_sims)
|
||||
weights /= weights.sum()
|
||||
|
||||
n_to_sample = min(exponential_count, len(remaining_indices))
|
||||
chosen = np.random.choice(
|
||||
remaining_indices,
|
||||
size=n_to_sample,
|
||||
replace=False,
|
||||
p=weights,
|
||||
)
|
||||
sampled_indices.update(chosen.tolist())
|
||||
|
||||
sampled_pairs = [
|
||||
(similarities[i][0], similarities[i][1]) for i in sampled_indices
|
||||
]
|
||||
return sampled_pairs
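To make the exponential weighting concrete, a tiny sketch (toy similarities, default `sampling_weight=20.0`) of how the sampling probabilities skew toward high-similarity pairs.

```python
import numpy as np

sims = np.array([0.2, 0.5, 0.8, 0.95])
weights = np.exp(20.0 * sims)      # sampling_weight = 20.0
probs = weights / weights.sum()
print(probs.round(4))
# ≈ [0.0000, 0.0001, 0.0474, 0.9525] -> the 0.95 pair dominates the draw
```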
|
||||
|
||||
def _print_similarity_histogram(
|
||||
self,
|
||||
similarities: list[tuple[int, int, float]],
|
||||
comparison_results: list[tuple[int, int, bool]],
|
||||
threshold: float | None = None,
|
||||
):
|
||||
"""
|
||||
Print a histogram of embedding cosine similarity distribution.
|
||||
|
||||
Args:
|
||||
similarities: List of (i, j, similarity) tuples.
|
||||
comparison_results: List of (i, j, is_match) from LLM comparisons.
|
||||
threshold: Optional threshold to highlight in the histogram.
|
||||
"""
|
||||
# Filter out self-similarities (similarity == 1)
|
||||
flat_similarities = [sim[2] for sim in similarities if sim[2] != 1]
|
||||
if not flat_similarities:
|
||||
return
|
||||
|
||||
hist, bin_edges = np.histogram(flat_similarities, bins=20)
|
||||
max_bar_width, max_count = 40, max(hist) if max(hist) > 0 else 1
|
||||
normalized_hist = [int(count / max_count * max_bar_width) for count in hist]
|
||||
|
||||
# Create a dictionary to store true labels
|
||||
true_labels = {(i, j): is_match for i, j, is_match in comparison_results}
|
||||
|
||||
# Count pairs above threshold
|
||||
pairs_above_threshold = (
|
||||
sum(1 for sim in flat_similarities if sim >= threshold) if threshold else 0
|
||||
)
|
||||
total_pairs = len(flat_similarities)
|
||||
|
||||
lines = []
|
||||
for i, count in enumerate(normalized_hist):
|
||||
bar = "█" * count
|
||||
bin_start, bin_end = bin_edges[i], bin_edges[i + 1]
|
||||
label = f"{bin_start:.2f}-{bin_end:.2f}"
|
||||
|
||||
# Count true matches and not matches in this bin
|
||||
true_matches = 0
|
||||
not_matches = 0
|
||||
labeled_count = 0
|
||||
for sim in similarities:
|
||||
if bin_start <= sim[2] < bin_end:
|
||||
if (sim[0], sim[1]) in true_labels:
|
||||
labeled_count += 1
|
||||
if true_labels[(sim[0], sim[1])]:
|
||||
true_matches += 1
|
||||
else:
|
||||
not_matches += 1
|
||||
|
||||
# Calculate percentages of labeled pairs
|
||||
if labeled_count > 0:
|
||||
true_match_percent = (true_matches / labeled_count) * 100
|
||||
label_info = f"[green]{true_match_percent:5.1f}%[/green] match"
|
||||
else:
|
||||
label_info = "[dim]--[/dim]"
|
||||
|
||||
# Highlight the bin containing the threshold
|
||||
if threshold is not None and bin_start <= threshold < bin_end:
|
||||
lines.append(
|
||||
f"[bold yellow]{label}[/bold yellow] {bar:<{max_bar_width}} "
|
||||
f"[dim]n={hist[i]:>5}[/dim] {label_info} [bold yellow]◀ threshold[/bold yellow]"
|
||||
)
|
||||
else:
|
||||
lines.append(
|
||||
f"{label} {bar:<{max_bar_width}} "
|
||||
f"[dim]n={hist[i]:>5}[/dim] {label_info}"
|
||||
)
|
||||
|
||||
from rich.panel import Panel
|
||||
|
||||
histogram_content = "\n".join(lines)
|
||||
title = f"Similarity Distribution ({pairs_above_threshold:,} of {total_pairs:,} pairs ≥ {threshold:.4f})"
|
||||
self.console.log(Panel(histogram_content, title=title, border_style="cyan"))
|
||||
|
||||
def find_optimal_threshold(
|
||||
self,
|
||||
comparisons: list[tuple[int, int, bool]],
|
||||
similarities: list[tuple[int, int, float]],
|
||||
) -> tuple[float, float]:
|
||||
"""
|
||||
Find the optimal similarity threshold that achieves target recall.
|
||||
|
||||
Args:
|
||||
comparisons: List of (i, j, is_match) from LLM comparisons.
|
||||
similarities: List of (i, j, similarity) tuples.
|
||||
|
||||
Returns:
|
||||
Tuple of (optimal_threshold, achieved_recall).
|
||||
"""
|
||||
if not comparisons or not any(comp[2] for comp in comparisons):
|
||||
# No matches found, use a high threshold to be conservative
|
||||
self.console.log(
|
||||
"[yellow]No matches found in sample. Using 99th percentile "
|
||||
"similarity as threshold.[/yellow]"
|
||||
)
|
||||
all_sims = [sim[2] for sim in similarities]
|
||||
threshold = float(np.percentile(all_sims, 99)) if all_sims else 0.9
|
||||
return threshold, 0.0
|
||||
|
||||
true_labels = np.array([comp[2] for comp in comparisons])
|
||||
sim_dict = {(i, j): sim for i, j, sim in similarities}
|
||||
sim_scores = np.array([sim_dict.get((i, j), 0.0) for i, j, _ in comparisons])
|
||||
thresholds = np.linspace(0, 1, 100)
|
||||
recalls = []
|
||||
for threshold in thresholds:
|
||||
predictions = sim_scores >= threshold
|
||||
tp = np.sum(predictions & true_labels)
|
||||
fn = np.sum(~predictions & true_labels)
|
||||
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
|
||||
recalls.append(recall)
|
||||
|
||||
# Find highest threshold that achieves target recall
|
||||
valid_indices = [i for i, r in enumerate(recalls) if r >= self.target_recall]
|
||||
if not valid_indices:
|
||||
# If no threshold achieves target recall, use the one with highest recall
|
||||
best_idx = int(np.argmax(recalls))
|
||||
optimal_threshold = float(thresholds[best_idx])
|
||||
achieved_recall = float(recalls[best_idx])
|
||||
self.console.log(
|
||||
f"[yellow]Warning: Could not achieve target recall {self.target_recall:.0%}. "
|
||||
f"Using threshold {optimal_threshold:.4f} with recall {achieved_recall:.2%}.[/yellow]"
|
||||
)
|
||||
else:
|
||||
best_idx = max(valid_indices)
|
||||
optimal_threshold = float(thresholds[best_idx])
|
||||
achieved_recall = float(recalls[best_idx])
|
||||
|
||||
return round(optimal_threshold, 4), achieved_recall
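A worked toy example of the threshold sweep this method performs (labels and similarities invented): it keeps the highest threshold whose recall over the labeled sample still meets the target.

```python
import numpy as np

# (similarity, is_match) for five labeled sample pairs
sim_scores = np.array([0.92, 0.85, 0.78, 0.60, 0.40])
true_labels = np.array([True, True, True, False, False])

target_recall = 0.95
best = 0.0
for threshold in np.linspace(0, 1, 100):
    predictions = sim_scores >= threshold
    tp = np.sum(predictions & true_labels)
    fn = np.sum(~predictions & true_labels)
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    if recall >= target_recall:
        best = max(best, threshold)
print(round(best, 4))  # ~0.7778: highest grid threshold that keeps all 3 matches
```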
|
||||
|
||||
def optimize_resolve(
|
||||
self,
|
||||
input_data: list[dict[str, Any]],
|
||||
compare_fn: Callable[[dict, dict], tuple[bool, float, str]],
|
||||
blocking_keys: list[str] | None = None,
|
||||
) -> tuple[float, list[list[float]], float]:
|
||||
"""
|
||||
Compute optimal blocking threshold for resolve operation.
|
||||
|
||||
Args:
|
||||
input_data: Input dataset.
|
||||
compare_fn: Function to compare two items, returns (is_match, cost, prompt).
|
||||
blocking_keys: Keys to use for blocking. If None, extracted from prompt.
|
||||
|
||||
Returns:
|
||||
Tuple of (threshold, embeddings, total_cost).
|
||||
"""
|
||||
from rich.panel import Panel
|
||||
|
||||
# Determine blocking keys
|
||||
if not blocking_keys:
|
||||
prompt_template = self.config.get("comparison_prompt", "")
|
||||
prompt_vars = extract_jinja_variables(prompt_template)
|
||||
prompt_vars = [
|
||||
var for var in prompt_vars if var not in ["input", "input1", "input2"]
|
||||
]
|
||||
blocking_keys = list(set([var.split(".")[-1] for var in prompt_vars]))
|
||||
if not blocking_keys:
|
||||
blocking_keys = list(input_data[0].keys())
|
||||
|
||||
# Compute embeddings
|
||||
embeddings, embedding_cost = self.compute_embeddings(input_data, blocking_keys)
|
||||
|
||||
# Calculate similarities
|
||||
similarities = self.calculate_cosine_similarities_self(embeddings)
|
||||
|
||||
# Sample pairs
|
||||
sampled_pairs = self.sample_pairs(similarities)
|
||||
if not sampled_pairs:
|
||||
self.console.log(
|
||||
"[yellow]No pairs to sample. Using default threshold 0.8.[/yellow]"
|
||||
)
|
||||
return 0.8, embeddings, embedding_cost
|
||||
|
||||
# Perform comparisons
|
||||
comparisons = []
|
||||
comparison_cost = 0.0
|
||||
matches_found = 0
|
||||
with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
|
||||
futures = {
|
||||
executor.submit(compare_fn, input_data[i], input_data[j]): (i, j)
|
||||
for i, j in sampled_pairs
|
||||
}
|
||||
for future in as_completed(futures):
|
||||
i, j = futures[future]
|
||||
try:
|
||||
is_match, cost, _ = future.result()
|
||||
comparisons.append((i, j, is_match))
|
||||
comparison_cost += cost
|
||||
if is_match:
|
||||
matches_found += 1
|
||||
except Exception as e:
|
||||
self.console.log(f"[red]Comparison error: {e}[/red]")
|
||||
comparisons.append((i, j, False))
|
||||
|
||||
# Find optimal threshold
|
||||
threshold, achieved_recall = self.find_optimal_threshold(
|
||||
comparisons, similarities
|
||||
)
|
||||
total_cost = embedding_cost + comparison_cost
|
||||
|
||||
# Print histogram visualization
|
||||
self._print_similarity_histogram(similarities, comparisons, threshold)
|
||||
|
||||
# Print summary
|
||||
n = len(input_data)
|
||||
total_pairs = n * (n - 1) // 2
|
||||
pairs_above = sum(1 for s in similarities if s[2] >= threshold)
|
||||
|
||||
summary = (
|
||||
f"[bold]Blocking keys:[/bold] {blocking_keys}\n"
|
||||
f"[bold]Sampled:[/bold] {len(sampled_pairs)} pairs → {matches_found} matches ({matches_found/len(sampled_pairs)*100:.1f}%)\n"
|
||||
f"[bold]Threshold:[/bold] {threshold:.4f} → {achieved_recall:.1%} recall (target: {self.target_recall:.0%})\n"
|
||||
f"[bold]Pairs to compare:[/bold] {pairs_above:,} of {total_pairs:,} ({pairs_above/total_pairs*100:.1f}%)\n"
|
||||
f"[bold]Optimization cost:[/bold] ${total_cost:.4f}"
|
||||
)
|
||||
self.console.log(
|
||||
Panel(
|
||||
summary, title="Blocking Threshold Optimization", border_style="green"
|
||||
)
|
||||
)
|
||||
|
||||
return threshold, embeddings, total_cost
|
||||
|
||||
def optimize_equijoin(
|
||||
self,
|
||||
left_data: list[dict[str, Any]],
|
||||
right_data: list[dict[str, Any]],
|
||||
compare_fn: Callable[[dict, dict], tuple[bool, float]],
|
||||
left_keys: list[str] | None = None,
|
||||
right_keys: list[str] | None = None,
|
||||
) -> tuple[float, list[list[float]], list[list[float]], float]:
|
||||
"""
|
||||
Compute optimal blocking threshold for equijoin operation.
|
||||
|
||||
Args:
|
||||
left_data: Left dataset.
|
||||
right_data: Right dataset.
|
||||
compare_fn: Function to compare two items, returns (is_match, cost).
|
||||
left_keys: Keys to use for left dataset embeddings.
|
||||
right_keys: Keys to use for right dataset embeddings.
|
||||
|
||||
Returns:
|
||||
Tuple of (threshold, left_embeddings, right_embeddings, total_cost).
|
||||
"""
|
||||
from rich.panel import Panel
|
||||
|
||||
# Determine keys
|
||||
if not left_keys:
|
||||
left_keys = list(left_data[0].keys()) if left_data else []
|
||||
if not right_keys:
|
||||
right_keys = list(right_data[0].keys()) if right_data else []
|
||||
|
||||
# Compute embeddings
|
||||
left_embeddings, left_cost = self.compute_embeddings(left_data, left_keys)
|
||||
right_embeddings, right_cost = self.compute_embeddings(right_data, right_keys)
|
||||
embedding_cost = left_cost + right_cost
|
||||
|
||||
# Calculate cross similarities
|
||||
similarities = self.calculate_cosine_similarities_cross(
|
||||
left_embeddings, right_embeddings
|
||||
)
|
||||
|
||||
# Sample pairs
|
||||
sampled_pairs = self.sample_pairs(similarities)
|
||||
if not sampled_pairs:
|
||||
self.console.log(
|
||||
"[yellow]No pairs to sample. Using default threshold 0.8.[/yellow]"
|
||||
)
|
||||
return 0.8, left_embeddings, right_embeddings, embedding_cost
|
||||
|
||||
# Perform comparisons
|
||||
comparisons = []
|
||||
comparison_cost = 0.0
|
||||
matches_found = 0
|
||||
with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
|
||||
futures = {
|
||||
executor.submit(compare_fn, left_data[i], right_data[j]): (i, j)
|
||||
for i, j in sampled_pairs
|
||||
}
|
||||
for future in as_completed(futures):
|
||||
i, j = futures[future]
|
||||
try:
|
||||
is_match, cost = future.result()
|
||||
comparisons.append((i, j, is_match))
|
||||
comparison_cost += cost
|
||||
if is_match:
|
||||
matches_found += 1
|
||||
except Exception as e:
|
||||
self.console.log(f"[red]Comparison error: {e}[/red]")
|
||||
comparisons.append((i, j, False))
|
||||
|
||||
# Find optimal threshold
|
||||
threshold, achieved_recall = self.find_optimal_threshold(
|
||||
comparisons, similarities
|
||||
)
|
||||
total_cost = embedding_cost + comparison_cost
|
||||
|
||||
# Print histogram visualization
|
||||
self._print_similarity_histogram(similarities, comparisons, threshold)
|
||||
|
||||
# Print summary
|
||||
total_pairs = len(left_data) * len(right_data)
|
||||
pairs_above = sum(1 for s in similarities if s[2] >= threshold)
|
||||
|
||||
summary = (
|
||||
f"[bold]Left keys:[/bold] {left_keys} [bold]Right keys:[/bold] {right_keys}\n"
|
||||
f"[bold]Sampled:[/bold] {len(sampled_pairs)} pairs → {matches_found} matches ({matches_found/len(sampled_pairs)*100:.1f}%)\n"
|
||||
f"[bold]Threshold:[/bold] {threshold:.4f} → {achieved_recall:.1%} recall (target: {self.target_recall:.0%})\n"
|
||||
f"[bold]Pairs to compare:[/bold] {pairs_above:,} of {total_pairs:,} ({pairs_above/total_pairs*100:.1f}%)\n"
|
||||
f"[bold]Optimization cost:[/bold] ${total_cost:.4f}"
|
||||
)
|
||||
self.console.log(
|
||||
Panel(
|
||||
summary, title="Blocking Threshold Optimization", border_style="green"
|
||||
)
|
||||
)
|
||||
|
||||
return threshold, left_embeddings, right_embeddings, total_cost
|
||||
|
|
@@ -55,7 +55,7 @@ class Optimizer:
    def __init__(
        self,
        runner: "DSLRunner",
-       rewrite_agent_model: str = "gpt-4o",
+       rewrite_agent_model: str = "gpt-5.1",
        judge_agent_model: str = "gpt-4o-mini",
        litellm_kwargs: dict[str, Any] = {},
        resume: bool = False,

@@ -67,7 +67,7 @@ class Optimizer:

        Args:
            yaml_file (str): Path to the YAML configuration file.
-           model (str): The name of the language model to use. Defaults to "gpt-4o".
+           model (str): The name of the language model to use. Defaults to "gpt-5.1".
            resume (bool): Whether to resume optimization from a previous run. Defaults to False.
            timeout (int): Timeout in seconds for operations. Defaults to 60.
@ -0,0 +1,926 @@
|
|||
"""
|
||||
Fast decomposition analyzer for map operations.
|
||||
|
||||
This module provides a fast way to decompose map operations using directives,
|
||||
running candidates on samples, and selecting the best via pairwise LLM comparison.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
from litellm import completion, model_cost
|
||||
from rich.console import Console
|
||||
|
||||
from docetl.console import get_console
|
||||
from docetl.reasoning_optimizer.directives import (
|
||||
ChainingDirective,
|
||||
ClarifyInstructionsDirective,
|
||||
DeterministicDocCompressionDirective,
|
||||
DocumentChunkingDirective,
|
||||
GleaningDirective,
|
||||
IsolatingSubtasksDirective,
|
||||
)
|
||||
from docetl.runner import DSLRunner
|
||||
from docetl.utils import completion_cost, count_tokens
|
||||
|
||||
# Drop unsupported params for models like gpt-5 that don't support temperature=0
|
||||
litellm.drop_params = True
|
||||
|
||||
# Base directives always applicable to map operations
|
||||
BASE_MAP_DIRECTIVES = [
|
||||
ChainingDirective(),
|
||||
IsolatingSubtasksDirective(),
|
||||
GleaningDirective(),
|
||||
ClarifyInstructionsDirective(),
|
||||
]
|
||||
|
||||
# Directives that depend on document size
|
||||
CHUNKING_DIRECTIVE = DocumentChunkingDirective()
|
||||
COMPRESSION_DIRECTIVE = DeterministicDocCompressionDirective()
|
||||
|
||||
# Threshold for enabling DeterministicDocCompression (in characters)
|
||||
DOC_COMPRESSION_CHAR_THRESHOLD = 1000
|
||||
|
||||
# Threshold for enabling DocumentChunking (as fraction of context window)
|
||||
DOC_CHUNKING_CONTEXT_THRESHOLD = 0.10 # 10%
|
||||
|
||||
|
||||
class FastDecomposer:
|
||||
"""
|
||||
Fast decomposition of map operations using directives and pairwise comparison.
|
||||
|
||||
Instead of the full optimizer flow, this:
|
||||
1. Tries multiple directives to generate candidate decompositions
|
||||
2. Runs each candidate on sample documents
|
||||
3. Uses LLM judge for pairwise comparison
|
||||
4. Returns the winning decomposition
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
yaml_config_path: str,
|
||||
optimizer_model: str = "gpt-5.1",
|
||||
sample_size: int = 5,
|
||||
litellm_kwargs: dict[str, Any] | None = None,
|
||||
console: Console | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the decomposer.
|
||||
|
||||
Args:
|
||||
yaml_config_path: Path to the pipeline YAML config file
|
||||
optimizer_model: LLM model to use for directive instantiation and judging
|
||||
sample_size: Number of sample documents to run candidates on
|
||||
litellm_kwargs: Additional kwargs to pass to litellm.completion
|
||||
console: Rich console for output (uses default if not provided)
|
||||
"""
|
||||
self.yaml_config_path = yaml_config_path
|
||||
self.optimizer_model = optimizer_model
|
||||
self.sample_size = sample_size
|
||||
self.litellm_kwargs = litellm_kwargs or {}
|
||||
if "temperature" not in self.litellm_kwargs:
|
||||
self.litellm_kwargs["temperature"] = 0.0
|
||||
|
||||
self.total_cost = 0.0
|
||||
self.console = console or get_console()
|
||||
|
||||
# Load the config
|
||||
import yaml
|
||||
|
||||
with open(yaml_config_path, "r") as f:
|
||||
self.config = yaml.safe_load(f)
|
||||
|
||||
self.operators = self.config.get("operations", [])
|
||||
self.default_model = self.config.get("default_model", "gpt-4o-mini")
|
||||
self.intermediate_dir = (
|
||||
self.config.get("pipeline", {}).get("output", {}).get("intermediate_dir")
|
||||
)
|
||||
|
||||
def _log(self, message: str) -> None:
|
||||
"""Log a message to the console."""
|
||||
self.console.log(message)
|
||||
|
||||
def get_model_context_limit(self, model: str) -> int:
|
||||
"""
|
||||
Get the context window limit for a model.
|
||||
|
||||
Args:
|
||||
model: The model name (e.g., 'gpt-4o', 'azure/gpt-4')
|
||||
|
||||
Returns:
|
||||
Maximum number of input tokens the model can handle
|
||||
"""
|
||||
model_info = model_cost.get(model, {})
|
||||
# Try without provider prefix if not found
|
||||
if not model_info:
|
||||
model_name = model.split("/")[-1]
|
||||
model_info = model_cost.get(model_name, {})
|
||||
return model_info.get("max_input_tokens", 128000) # Default to 128k
|
||||
|
||||
def get_avg_doc_size(
|
||||
self, sample_data: list[dict[str, Any]], op_config: dict[str, Any]
|
||||
) -> tuple[float, float]:
|
||||
"""
|
||||
Calculate the average document size in characters and tokens.
|
||||
|
||||
Extracts the document content from sample data based on the operation's
|
||||
prompt template (looks for {{ input.field_name }} patterns).
|
||||
|
||||
Args:
|
||||
sample_data: List of sample documents
|
||||
op_config: The operation configuration
|
||||
|
||||
Returns:
|
||||
Tuple of (avg_chars, avg_tokens) for the document content
|
||||
"""
|
||||
import re
|
||||
|
||||
if not sample_data:
|
||||
return 0.0, 0.0
|
||||
|
||||
prompt = op_config.get("prompt", "")
|
||||
model = op_config.get("model", self.default_model)
|
||||
|
||||
# Extract field names from prompt template ({{ input.field_name }})
|
||||
field_pattern = r"\{\{\s*input\.(\w+)\s*\}\}"
|
||||
fields = re.findall(field_pattern, prompt)
|
||||
|
||||
if not fields:
|
||||
# Fallback: use all string values from the first document
|
||||
if sample_data:
|
||||
fields = [
|
||||
k
|
||||
for k, v in sample_data[0].items()
|
||||
if isinstance(v, str) and len(v) > 100
|
||||
]
|
||||
|
||||
total_chars = 0
|
||||
total_tokens = 0
|
||||
|
||||
for doc in sample_data:
|
||||
doc_content = ""
|
||||
for field in fields:
|
||||
if field in doc:
|
||||
value = doc[field]
|
||||
if isinstance(value, str):
|
||||
doc_content += value
|
||||
else:
|
||||
doc_content += str(value)
|
||||
|
||||
total_chars += len(doc_content)
|
||||
if doc_content:
|
||||
total_tokens += count_tokens(doc_content, model)
|
||||
|
||||
n = len(sample_data)
|
||||
return total_chars / n if n > 0 else 0.0, total_tokens / n if n > 0 else 0.0
|
||||
|
||||
    def get_applicable_directives(
        self,
        sample_data: list[dict[str, Any]],
        op_config: dict[str, Any],
    ) -> list:
        """
        Get the list of directives applicable based on data characteristics.

        - DocumentChunkingDirective: only if avg doc size > 10% of context window
        - DeterministicDocCompressionDirective: only if avg doc size > 1000 chars

        Args:
            sample_data: List of sample documents
            op_config: The operation configuration

        Returns:
            List of applicable directive instances
        """
        directives = []  # Build list with priority ordering

        model = op_config.get("model", self.default_model)
        context_limit = self.get_model_context_limit(model)
        avg_chars, avg_tokens = self.get_avg_doc_size(sample_data, op_config)

        self._log(
            f"Document analysis: avg_chars={avg_chars:.0f}, avg_tokens={avg_tokens:.0f}, "
            f"context_limit={context_limit}"
        )

        # Add DeterministicDocCompression FIRST if doc size > 1000 chars (high priority)
        if avg_chars > DOC_COMPRESSION_CHAR_THRESHOLD:
            self._log(
                f"[cyan]Enabling DeterministicDocCompression (priority)[/cyan] "
                f"(avg {avg_chars:.0f} chars > {DOC_COMPRESSION_CHAR_THRESHOLD})"
            )
            directives.append(COMPRESSION_DIRECTIVE)

        # Add base directives
        directives.extend(BASE_MAP_DIRECTIVES)

        # Add DocumentChunking if avg tokens > 10% of context window
        token_threshold = context_limit * DOC_CHUNKING_CONTEXT_THRESHOLD
        if avg_tokens > token_threshold:
            self._log(
                f"[cyan]Enabling DocumentChunking[/cyan] "
                f"(avg {avg_tokens:.0f} tokens > {token_threshold:.0f} = 10% of {context_limit})"
            )
            directives.append(CHUNKING_DIRECTIVE)
        else:
            self._log(
                f"[dim]Skipping DocumentChunking[/dim] "
                f"(avg {avg_tokens:.0f} tokens <= {token_threshold:.0f} = 10% of {context_limit})"
            )

        return directives

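The module-level constants and directive instances referenced above are defined elsewhere in this file. Based on the docstring and the log messages, they presumably look roughly like the following sketch; the exact values and, in particular, the contents of `BASE_MAP_DIRECTIVES` are assumptions:

```python
# Assumed definitions, inferred from the docstring and log messages above;
# the real module may differ (especially BASE_MAP_DIRECTIVES).
DOC_COMPRESSION_CHAR_THRESHOLD = 1000  # "avg doc size > 1000 chars"
DOC_CHUNKING_CONTEXT_THRESHOLD = 0.1   # "10% of context window"

COMPRESSION_DIRECTIVE = DeterministicDocCompressionDirective()
CHUNKING_DIRECTIVE = DocumentChunkingDirective()
BASE_MAP_DIRECTIVES = [ChainingDirective(), GleaningDirective()]  # hypothetical subset
```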
    def load_sample_data(self, step_name: str, op_name: str) -> list[dict[str, Any]]:
        """
        Load sample input data for an operation.

        For the first operation, loads from the dataset.
        For subsequent operations, loads from the previous operation's intermediate output.

        Args:
            step_name: Name of the pipeline step
            op_name: Name of the operation

        Returns:
            List of sample documents
        """
        # Find the operation's position
        op_names = [op.get("name") for op in self.operators]
        try:
            op_idx = op_names.index(op_name)
        except ValueError:
            raise ValueError(f"Operation '{op_name}' not found in config")

        if op_idx == 0:
            # First operation - load from dataset
            datasets = self.config.get("datasets", {})
            # Get the first dataset (or the one used by this step)
            for dataset_name, dataset_config in datasets.items():
                dataset_path = dataset_config.get("path")
                if dataset_path and os.path.exists(dataset_path):
                    with open(dataset_path, "r") as f:
                        data = json.load(f)
                    return data[: self.sample_size]
            raise FileNotFoundError("No dataset found in config")
        else:
            # Load from previous operation's intermediate output
            prev_op_name = op_names[op_idx - 1]
            output_path = os.path.join(
                self.intermediate_dir, step_name, f"{prev_op_name}.json"
            )
            if not os.path.exists(output_path):
                raise FileNotFoundError(
                    f"No intermediate output found at {output_path}. "
                    "Run the previous operation first."
                )
            with open(output_path, "r") as f:
                data = json.load(f)
            return data[: self.sample_size]

    def get_input_file_path(self) -> str:
        """Get the input file path for the pipeline."""
        datasets = self.config.get("datasets", {})
        for dataset_name, dataset_config in datasets.items():
            path = dataset_config.get("path")
            if path:
                return path
        return ""

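For orientation, a sketch of the on-disk layout this loading logic assumes (the dataset and operation names are hypothetical):

```python
# Hypothetical layout implied by load_sample_data:
#
# config["datasets"]["my_dataset"]["path"]        -> input JSON for the first operation
# {intermediate_dir}/{step_name}/{prev_op}.json   -> cached output of the previous operation
#
# e.g. intermediates/extract_step/split_doc.json feeds whichever op follows "split_doc".
```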
    def generate_candidates(
        self,
        op_name: str,
        sample_data: list[dict[str, Any]],
        target_op: dict[str, Any],
    ) -> list[dict[str, Any]]:
        """
        Generate candidate decompositions using directives.

        Directives are selected based on data characteristics:
        - DocumentChunkingDirective: only if avg doc size > 10% of context window
        - DeterministicDocCompressionDirective: only if avg doc size > 1000 chars

        Args:
            op_name: Name of the operation to decompose
            sample_data: Sample data for analyzing document characteristics
            target_op: The target operation configuration

        Returns:
            List of candidate dictionaries with 'name', 'ops', 'cost' keys
        """
        candidates = []

        # Add original as baseline
        self._log("Adding original operation as baseline candidate")
        candidates.append(
            {
                "name": "original",
                "ops": deepcopy(self.operators),
                "cost": 0.0,
                "error": None,
            }
        )

        input_file_path = self.get_input_file_path()

        # Get applicable directives based on data characteristics
        applicable_directives = self.get_applicable_directives(sample_data, target_op)

        self._log(
            f"Generating candidates using {len(applicable_directives)} directives..."
        )
        for i, directive in enumerate(applicable_directives, 1):
            with self.console.status(
                f"[bold cyan]({i}/{len(applicable_directives)}) Trying directive: {directive.name}...[/bold cyan]",
                spinner="dots",
            ):
                try:
                    new_ops_list, _, cost = directive.instantiate(
                        operators=deepcopy(self.operators),
                        target_ops=[op_name],
                        agent_llm=self.optimizer_model,
                        message_history=[],
                        global_default_model=self.default_model,
                        input_file_path=input_file_path,
                    )
                    self.total_cost += cost
                    self._log(
                        f" [green]✓[/green] {directive.name} generated {len(new_ops_list)} operations (cost: ${cost:.4f})"
                    )
                    candidates.append(
                        {
                            "name": directive.name,
                            "ops": new_ops_list,
                            "cost": cost,
                            "error": None,
                        }
                    )
                except Exception as e:
                    # Directive not applicable or failed - skip it
                    self._log(f" [red]✗[/red] {directive.name} failed: {str(e)}")
                    candidates.append(
                        {
                            "name": directive.name,
                            "ops": None,
                            "cost": 0.0,
                            "error": str(e),
                        }
                    )

        return candidates

    def extract_ops_to_run(
        self, ops_list: list[dict], original_op_name: str
    ) -> list[dict]:
        """
        Extract the operations that replaced the original operation.

        Args:
            ops_list: The full transformed operations list
            original_op_name: Name of the original operation that was decomposed

        Returns:
            List of operations that should be run on samples
        """
        # Find ops that are new (not in original) or modified
        original_names = {op["name"] for op in self.operators}

        # Find the position where the original op was
        original_idx = None
        for i, op in enumerate(self.operators):
            if op["name"] == original_op_name:
                original_idx = i
                break

        if original_idx is None:
            return ops_list

        # Find new ops (those not in original list)
        new_ops = []
        for op in ops_list:
            if op["name"] not in original_names or op["name"] == original_op_name:
                new_ops.append(op)

        return new_ops if new_ops else [ops_list[original_idx]]

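For intuition, a small worked example of the replacement-extraction rule above; the operation names are hypothetical:

```python
# Hypothetical operation names, for illustration only.
original = [{"name": "extract_themes"}, {"name": "summarize"}]
transformed = [
    {"name": "split_doc"},       # new op introduced by a directive
    {"name": "extract_themes"},  # the decomposed target (kept even though the name is old)
    {"name": "gather_chunks"},   # new op
    {"name": "summarize"},       # pre-existing downstream op, not re-run on samples
]

original_names = {op["name"] for op in original}
ops_to_run = [
    op for op in transformed
    if op["name"] not in original_names or op["name"] == "extract_themes"
]
# ops_to_run -> split_doc, extract_themes, gather_chunks
```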
    def run_candidate_on_samples(
        self,
        candidate: dict[str, Any],
        sample_data: list[dict[str, Any]],
        original_op_name: str,
    ) -> list[dict[str, Any]]:
        """
        Run a candidate's operations on sample data.

        Args:
            candidate: Candidate dictionary with 'ops' key
            sample_data: List of sample documents
            original_op_name: Name of the original operation

        Returns:
            List of output documents
        """
        if candidate["ops"] is None:
            return []

        # Extract ops to run
        ops_to_run = self.extract_ops_to_run(candidate["ops"], original_op_name)

        if not ops_to_run:
            return []

        # Create a minimal config for running these ops
        # Write sample data to a temp file
        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
            json.dump(sample_data, f)
            temp_input_path = f.name

        # Create a temp output file (required by DSLRunner validation)
        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
            temp_output_path = f.name

        try:
            # Create a minimal pipeline config
            temp_config = {
                "default_model": self.default_model,
                "operations": ops_to_run,
                "datasets": {
                    "sample_data": {
                        "type": "file",
                        "path": temp_input_path,
                    }
                },
                "pipeline": {
                    "steps": [
                        {
                            "name": "decompose_test",
                            "input": "sample_data",
                            "operations": [op["name"] for op in ops_to_run],
                        }
                    ],
                    "output": {
                        "type": "file",
                        "path": temp_output_path,
                    },
                },
            }

            # Create runner and execute
            runner = DSLRunner(temp_config, max_threads=4)

            # Run operations sequentially on the data
            current_data = sample_data
            for op_config in ops_to_run:
                current_data = runner._run_operation(op_config, current_data)

            self.total_cost += runner.total_cost

            return current_data

        finally:
            # Clean up temp files
            for temp_path in [temp_input_path, temp_output_path]:
                try:
                    os.unlink(temp_path)
                except Exception:
                    pass

    def pairwise_compare(
        self,
        candidate_a: dict[str, Any],
        candidate_b: dict[str, Any],
        original_prompt: str,
        output_schema: dict[str, Any],
    ) -> dict[str, Any]:
        """
        Compare two candidates using LLM judge.

        Args:
            candidate_a: First candidate with 'name', 'outputs' keys
            candidate_b: Second candidate with 'name', 'outputs' keys
            original_prompt: The original operation's prompt
            output_schema: The expected output schema

        Returns:
            The winning candidate
        """
        # If one has no outputs, the other wins
        if not candidate_a.get("outputs"):
            return candidate_b
        if not candidate_b.get("outputs"):
            return candidate_a

        system_prompt = """You are an expert judge comparing outputs from two data processing pipeline variants.

Your task is to determine which variant produces BETTER outputs based on:
1. **Completeness**: Does the output contain all required information?
2. **Accuracy**: Is the extracted/generated information correct?
3. **Consistency**: Are the outputs consistent across different samples?
4. **Quality**: Is the output well-structured and useful?

Be objective and focus on the actual output quality, not the approach used."""

        user_prompt = f"""Compare outputs from two pipeline variants for this task:

## Original Task
**Prompt:**
```
{original_prompt[:2000]}{"..." if len(original_prompt) > 2000 else ""}
```

**Expected Output Schema:**
```json
{json.dumps(output_schema, indent=2)}
```

## Variant A: {candidate_a["name"]}
**Sample Outputs:**
```json
{json.dumps(candidate_a["outputs"][:3], indent=2, default=str)}
```

## Variant B: {candidate_b["name"]}
**Sample Outputs:**
```json
{json.dumps(candidate_b["outputs"][:3], indent=2, default=str)}
```

## Your Task
Which variant produces better outputs? Consider completeness, accuracy, consistency, and quality.
Respond with your analysis and final choice."""

        response = completion(
            model=self.optimizer_model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            response_format={
                "type": "json_schema",
                "json_schema": {
                    "name": "comparison_result",
                    "strict": True,
                    "schema": {
                        "type": "object",
                        "properties": {
                            "winner": {
                                "type": "string",
                                "enum": ["A", "B"],
                                "description": "Which variant is better: A or B",
                            },
                            "rationale": {
                                "type": "string",
                                "description": "Explanation of why this variant is better",
                            },
                            "a_strengths": {
                                "type": "array",
                                "items": {"type": "string"},
                                "description": "Strengths of variant A",
                            },
                            "b_strengths": {
                                "type": "array",
                                "items": {"type": "string"},
                                "description": "Strengths of variant B",
                            },
                        },
                        "required": [
                            "winner",
                            "rationale",
                            "a_strengths",
                            "b_strengths",
                        ],
                        "additionalProperties": False,
                    },
                },
            },
            **self.litellm_kwargs,
        )

        cost = completion_cost(response)
        self.total_cost += cost

        result = json.loads(response.choices[0].message.content)

        winner = candidate_a if result["winner"] == "A" else candidate_b
        winner["comparison_rationale"] = result["rationale"]

        return winner

def decompose(
|
||||
self,
|
||||
step_name: str,
|
||||
op_name: str,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Decompose an operation using fast directive-based approach.
|
||||
|
||||
This is the main entry point. It:
|
||||
1. Generates candidate decompositions using directives
|
||||
2. Runs each on sample data
|
||||
3. Uses pairwise LLM comparison to select the best
|
||||
4. Returns the winning decomposition
|
||||
|
||||
Args:
|
||||
step_name: Name of the pipeline step
|
||||
op_name: Name of the operation to decompose
|
||||
|
||||
Returns:
|
||||
Dict with:
|
||||
- decomposed_ops: List of operations that replace the original
|
||||
- winning_directive: Name of the winning directive
|
||||
- candidates_evaluated: Number of candidates that were compared
|
||||
- original_outputs: Sample outputs from the original operation
|
||||
- decomposed_outputs: Sample outputs from the winning decomposition
|
||||
- comparison_rationale: LLM's explanation of why the winner was chosen
|
||||
- cost: Total LLM API cost
|
||||
"""
|
||||
self.console.rule("[bold blue]Fast Decomposition[/bold blue]")
|
||||
self._log(f"[bold]Target operation:[/bold] {op_name}")
|
||||
|
||||
# Find the target operation config
|
||||
target_op = None
|
||||
for op in self.operators:
|
||||
if op["name"] == op_name:
|
||||
target_op = op
|
||||
break
|
||||
|
||||
if target_op is None:
|
||||
raise ValueError(f"Operation '{op_name}' not found in config")
|
||||
|
||||
# Verify it's a map operation
|
||||
if target_op.get("type") != "map":
|
||||
raise ValueError(
|
||||
f"Operation '{op_name}' is type '{target_op.get('type')}', "
|
||||
"but fast decomposition only supports 'map' operations"
|
||||
)
|
||||
|
||||
# Load sample data
|
||||
self._log(f"Loading sample data from step '{step_name}'...")
|
||||
sample_data = self.load_sample_data(step_name, op_name)
|
||||
self._log(f"[green]✓[/green] Loaded {len(sample_data)} sample documents")
|
||||
|
||||
# Generate candidates
|
||||
self.console.rule("[bold]Generating Candidates[/bold]")
|
||||
candidates = self.generate_candidates(op_name, sample_data, target_op)
|
||||
|
||||
# Filter out failed candidates
|
||||
valid_candidates = [c for c in candidates if c["ops"] is not None]
|
||||
self._log(f"[bold]{len(valid_candidates)}[/bold] valid candidates generated")
|
||||
|
||||
if len(valid_candidates) < 2:
|
||||
# Only original (or nothing) - return original
|
||||
self._log(
|
||||
"[yellow]No alternative decompositions generated. Keeping original.[/yellow]"
|
||||
)
|
||||
return {
|
||||
"decomposed_ops": self.operators,
|
||||
"winning_directive": "original",
|
||||
"candidates_evaluated": len(valid_candidates),
|
||||
"original_outputs": [],
|
||||
"decomposed_outputs": [],
|
||||
"comparison_rationale": "No alternative decompositions were generated.",
|
||||
"cost": self.total_cost,
|
||||
}
|
||||
|
||||
# Run each candidate on samples IN PARALLEL
|
||||
self.console.rule("[bold]Running Candidates on Samples[/bold]")
|
||||
self._log(f"Running {len(valid_candidates)} candidates in parallel...")
|
||||
|
||||
def run_single_candidate(candidate):
|
||||
"""Run a single candidate and return results."""
|
||||
name = candidate["name"]
|
||||
try:
|
||||
outputs = self.run_candidate_on_samples(candidate, sample_data, op_name)
|
||||
return {"name": name, "outputs": outputs, "error": None}
|
||||
except Exception as e:
|
||||
error_tb = traceback.format_exc()
|
||||
return {
|
||||
"name": name,
|
||||
"outputs": [],
|
||||
"error": str(e),
|
||||
"traceback": error_tb,
|
||||
}
|
||||
|
||||
# Run all candidates in parallel
|
||||
with self.console.status(
|
||||
"[bold cyan]Running candidates on samples...[/bold cyan]", spinner="dots"
|
||||
):
|
||||
with ThreadPoolExecutor(max_workers=len(valid_candidates)) as executor:
|
||||
future_to_candidate = {
|
||||
executor.submit(run_single_candidate, c): c
|
||||
for c in valid_candidates
|
||||
}
|
||||
|
||||
for future in as_completed(future_to_candidate):
|
||||
candidate = future_to_candidate[future]
|
||||
result = future.result()
|
||||
|
||||
if result["error"]:
|
||||
candidate["outputs"] = []
|
||||
candidate["run_error"] = result["error"]
|
||||
self._log(
|
||||
f" [red]✗[/red] {result['name']} failed: {result['error']}"
|
||||
)
|
||||
else:
|
||||
candidate["outputs"] = result["outputs"]
|
||||
self._log(
|
||||
f" [green]✓[/green] {result['name']}: {len(result['outputs'])} outputs"
|
||||
)
|
||||
|
||||
# Filter to candidates with outputs
|
||||
candidates_with_outputs = [c for c in valid_candidates if c.get("outputs")]
|
||||
self._log(
|
||||
f"[bold]{len(candidates_with_outputs)}[/bold] candidates produced outputs"
|
||||
)
|
||||
|
||||
if not candidates_with_outputs:
|
||||
# All failed - return original
|
||||
self._log("[red]All decomposition candidates failed to execute.[/red]")
|
||||
# Log the errors for debugging
|
||||
for c in valid_candidates:
|
||||
if c.get("run_error"):
|
||||
self._log(f" [red]{c['name']}:[/red] {c['run_error']}")
|
||||
return {
|
||||
"decomposed_ops": self.operators,
|
||||
"winning_directive": "original",
|
||||
"candidates_evaluated": 0,
|
||||
"original_outputs": [],
|
||||
"decomposed_outputs": [],
|
||||
"comparison_rationale": "All decomposition candidates failed to execute.",
|
||||
"cost": self.total_cost,
|
||||
}
|
||||
|
||||
# Find the original candidate's outputs
|
||||
original_candidate = next(
|
||||
(c for c in candidates_with_outputs if c["name"] == "original"), None
|
||||
)
|
||||
original_outputs = original_candidate["outputs"] if original_candidate else []
|
||||
|
||||
# Parallel pairwise comparison: compare all candidates against original
|
||||
self.console.rule("[bold]Pairwise Comparison[/bold]")
|
||||
original_prompt = target_op.get("prompt", "")
|
||||
output_schema = target_op.get("output", {}).get("schema", {})
|
||||
|
||||
# If only original exists, it wins by default
|
||||
if len(candidates_with_outputs) == 1:
|
||||
winner = candidates_with_outputs[0]
|
||||
elif original_candidate is None:
|
||||
# No original - just pick the first candidate
|
||||
winner = candidates_with_outputs[0]
|
||||
else:
|
||||
# Compare all non-original candidates against original IN PARALLEL
|
||||
challengers = [
|
||||
c for c in candidates_with_outputs if c["name"] != "original"
|
||||
]
|
||||
|
||||
if not challengers:
|
||||
winner = original_candidate
|
||||
else:
|
||||
self._log(
|
||||
f"Comparing {len(challengers)} candidates against original in parallel..."
|
||||
)
|
||||
|
||||
def compare_against_original(challenger):
|
||||
"""Compare a single challenger against original."""
|
||||
try:
|
||||
result = self.pairwise_compare(
|
||||
original_candidate,
|
||||
challenger,
|
||||
original_prompt,
|
||||
output_schema,
|
||||
)
|
||||
won = result["name"] == challenger["name"]
|
||||
return {
|
||||
"challenger": challenger,
|
||||
"won": won,
|
||||
"rationale": result.get("comparison_rationale", ""),
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"challenger": challenger,
|
||||
"won": False,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
# Run all comparisons in parallel
|
||||
comparison_results = []
|
||||
with self.console.status(
|
||||
f"[bold cyan]Running {len(challengers)} comparisons in parallel...[/bold cyan]",
|
||||
spinner="dots",
|
||||
):
|
||||
with ThreadPoolExecutor(max_workers=len(challengers)) as executor:
|
||||
future_to_challenger = {
|
||||
executor.submit(compare_against_original, c): c
|
||||
for c in challengers
|
||||
}
|
||||
|
||||
for future in as_completed(future_to_challenger):
|
||||
result = future.result()
|
||||
comparison_results.append(result)
|
||||
challenger_name = result["challenger"]["name"]
|
||||
if result.get("error"):
|
||||
self._log(
|
||||
f" [red]✗[/red] {challenger_name} vs original: error - {result['error']}"
|
||||
)
|
||||
elif result["won"]:
|
||||
self._log(
|
||||
f" [green]✓[/green] {challenger_name} beat original"
|
||||
)
|
||||
else:
|
||||
self._log(
|
||||
f" [dim]○[/dim] {challenger_name} lost to original"
|
||||
)
|
||||
|
||||
# Find winners (candidates that beat original)
|
||||
winners = [r for r in comparison_results if r.get("won")]
|
||||
|
||||
if not winners:
|
||||
# Original beats all challengers
|
||||
winner = original_candidate
|
||||
self._log("[bold]Original wins against all challengers[/bold]")
|
||||
elif len(winners) == 1:
|
||||
# Single winner
|
||||
winner = winners[0]["challenger"]
|
||||
winner["comparison_rationale"] = winners[0].get("rationale", "")
|
||||
else:
|
||||
# Multiple winners beat original - run tiebreaker comparisons in parallel
|
||||
self._log(
|
||||
f"[bold]{len(winners)} candidates beat original - running tiebreaker...[/bold]"
|
||||
)
|
||||
|
||||
# Compare all winners against each other in parallel (round-robin)
|
||||
winner_candidates = [w["challenger"] for w in winners]
|
||||
win_counts = {c["name"]: 0 for c in winner_candidates}
|
||||
|
||||
# Generate all pairwise matchups
|
||||
matchups = []
|
||||
for i, a in enumerate(winner_candidates):
|
||||
for b in winner_candidates[i + 1 :]:
|
||||
matchups.append((a, b))
|
||||
|
||||
if matchups:
|
||||
|
||||
def run_matchup(matchup):
|
||||
a, b = matchup
|
||||
try:
|
||||
result = self.pairwise_compare(
|
||||
a, b, original_prompt, output_schema
|
||||
)
|
||||
return {
|
||||
"winner": result["name"],
|
||||
"a": a["name"],
|
||||
"b": b["name"],
|
||||
}
|
||||
except Exception:
|
||||
return {
|
||||
"winner": a["name"],
|
||||
"a": a["name"],
|
||||
"b": b["name"],
|
||||
} # Default to first
|
||||
|
||||
with self.console.status(
|
||||
f"[bold cyan]Running {len(matchups)} tiebreaker comparisons...[/bold cyan]",
|
||||
spinner="dots",
|
||||
):
|
||||
with ThreadPoolExecutor(
|
||||
max_workers=len(matchups)
|
||||
) as executor:
|
||||
for result in executor.map(run_matchup, matchups):
|
||||
win_counts[result["winner"]] += 1
|
||||
self._log(
|
||||
f" [dim]{result['a']} vs {result['b']} → {result['winner']}[/dim]"
|
||||
)
|
||||
|
||||
# Pick candidate with most wins
|
||||
best_name = max(win_counts, key=win_counts.get)
|
||||
winner = next(
|
||||
c for c in winner_candidates if c["name"] == best_name
|
||||
)
|
||||
self._log(
|
||||
f"[bold]Tiebreaker winner: {best_name} ({win_counts[best_name]} wins)[/bold]"
|
||||
)
|
||||
|
||||
# Extract the decomposed operations
|
||||
decomposed_ops = self.extract_ops_to_run(winner["ops"], op_name)
|
||||
|
||||
# Final summary
|
||||
self.console.rule("[bold green]Decomposition Complete[/bold green]")
|
||||
self._log(f"[bold]Winner:[/bold] [green]{winner['name']}[/green]")
|
||||
self._log(f"[bold]Candidates evaluated:[/bold] {len(candidates_with_outputs)}")
|
||||
self._log(f"[bold]New operations:[/bold] {len(decomposed_ops)}")
|
||||
self._log(f"[bold]Total cost:[/bold] ${self.total_cost:.4f}")
|
||||
|
||||
return {
|
||||
"decomposed_ops": decomposed_ops,
|
||||
"winning_directive": winner["name"],
|
||||
"candidates_evaluated": len(candidates_with_outputs),
|
||||
"original_outputs": original_outputs,
|
||||
"decomposed_outputs": winner.get("outputs", []),
|
||||
"comparison_rationale": winner.get("comparison_rationale", ""),
|
||||
"cost": self.total_cost,
|
||||
}
|
||||
|
|
@ -0,0 +1,337 @@
"""
Fast should_optimize analyzer using a single LLM call.

This module provides a lightweight alternative to the full MapOptimizer/ReduceOptimizer/JoinOptimizer
flow for quickly determining if an operation should be decomposed.
"""

import json
import os
from typing import Any

import litellm
from litellm import completion, model_cost

from docetl.utils import completion_cost, count_tokens

# Drop unsupported params for models like gpt-5 that don't support temperature=0
litellm.drop_params = True


class FastShouldOptimizeAnalyzer:
    """
    Analyzes whether an operation should be optimized using a single LLM call.

    Instead of running the operation on sample data and using complex evaluation logic,
    this reads cached outputs from intermediate files and makes one judgment call.
    """

    def __init__(
        self,
        intermediate_dir: str,
        optimizer_model: str = "gpt-5.1",
        litellm_kwargs: dict[str, Any] | None = None,
    ):
        """
        Initialize the analyzer.

        Args:
            intermediate_dir: Path to the directory containing intermediate outputs
            optimizer_model: The LLM model to use for analysis (default: gpt-5.1)
            litellm_kwargs: Additional kwargs to pass to litellm.completion
        """
        self.intermediate_dir = intermediate_dir
        self.optimizer_model = optimizer_model
        self.litellm_kwargs = litellm_kwargs or {}
        if "temperature" not in self.litellm_kwargs:
            self.litellm_kwargs["temperature"] = 0.0

    def load_operation_data(self, step_name: str, op_name: str) -> list[dict[str, Any]]:
        """
        Load data from the intermediate file for an operation.

        Args:
            step_name: Name of the pipeline step
            op_name: Name of the operation

        Returns:
            List of dictionaries

        Raises:
            FileNotFoundError: If the intermediate file doesn't exist
        """
        output_path = os.path.join(self.intermediate_dir, step_name, f"{op_name}.json")
        if not os.path.exists(output_path):
            raise FileNotFoundError(
                f"No output file found at {output_path}. "
                "Run the operation first to generate outputs."
            )
        with open(output_path, "r") as f:
            return json.load(f)

    def find_previous_operation(
        self, operations: list[dict[str, Any]], op_name: str
    ) -> str | None:
        """
        Find the operation that comes before op_name in the pipeline.

        Args:
            operations: List of operation configs from the pipeline
            op_name: Name of the current operation

        Returns:
            Name of the previous operation, or None if this is the first operation
        """
        op_names = [op.get("name") for op in operations]
        try:
            idx = op_names.index(op_name)
            if idx > 0:
                return op_names[idx - 1]
        except ValueError:
            pass
        return None

    def get_max_context_tokens(self) -> int:
        """
        Get the maximum input tokens for the optimizer model.

        Returns:
            Maximum number of input tokens the model can handle
        """
        model_info = model_cost.get(self.optimizer_model, {})
        # Try without provider prefix if not found
        if not model_info:
            model_name = self.optimizer_model.split("/")[-1]
            model_info = model_cost.get(model_name, {})
        return model_info.get("max_input_tokens", 128000)  # Default to 128k

    def calculate_samples_that_fit(
        self,
        op_config: dict[str, Any],
        outputs: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """
        Calculate how many output samples fit in the context window.

        Reserves space for system prompt, operation config, and response buffer,
        then fills remaining space with as many samples as possible.

        Args:
            op_config: The operation configuration dictionary
            outputs: List of all output documents

        Returns:
            List of samples that fit in the context window
        """
        max_tokens = self.get_max_context_tokens()

        # Reserve tokens for fixed parts
        system_prompt_tokens = 500
        op_config_tokens = count_tokens(
            json.dumps(op_config, default=str), self.optimizer_model
        )
        response_buffer = 2000

        available_for_samples = (
            max_tokens - system_prompt_tokens - op_config_tokens - response_buffer
        )

        # Collect samples that fit
        samples_to_include = []
        tokens_used = 0

        for output in outputs:
            sample_json = json.dumps(output, default=str)
            sample_tokens = count_tokens(sample_json, self.optimizer_model)
            if tokens_used + sample_tokens <= available_for_samples:
                samples_to_include.append(output)
                tokens_used += sample_tokens
            else:
                break

        return samples_to_include

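A quick numeric sanity check of the budgeting above; the numbers are made up:

```python
# Hypothetical numbers, for illustration only.
max_tokens = 128_000
system_prompt_tokens = 500
op_config_tokens = 1_500
response_buffer = 2_000

available_for_samples = max_tokens - system_prompt_tokens - op_config_tokens - response_buffer
print(available_for_samples)  # 124000: with ~800-token samples, roughly 155 of them fit;
                              # the first sample that would overflow the budget stops the loop.
```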
    def build_analysis_prompt(
        self,
        op_config: dict[str, Any],
        samples: list[dict[str, Any]],
    ) -> tuple[str, str]:
        """
        Build the system prompt and user prompt for the analysis LLM call.

        Args:
            op_config: The operation configuration dictionary
            samples: List of output samples to analyze

        Returns:
            Tuple of (system_prompt, user_prompt)
        """
        system_prompt = """You are an expert at analyzing LLM-powered data processing operations.
Your task is to determine if an operation would benefit from being decomposed into multiple
smaller, focused operations (also known as "optimization" or "decomposition").

An operation SHOULD be decomposed when:
1. The prompt asks the LLM to do multiple distinct tasks that could be done separately
2. The task is complex enough that breaking it into sequential steps would improve accuracy
3. The outputs show inconsistency, incompleteness, or quality issues that iterative refinement could fix
4. Long documents need to be processed in chunks rather than all at once
5. The prompt asks for both extraction AND analysis/synthesis in one step

An operation should NOT be decomposed when:
1. It performs a single, focused task well
2. The outputs are consistently high quality and complete
3. The task is simple and atomic (e.g., simple classification, single field extraction)
4. The operation is already well-scoped and produces reliable results

Be conservative - only recommend decomposition if there's clear evidence it would help."""

        output_schema = op_config.get("output", {}).get("schema", {})
        prompt_template = op_config.get("prompt", "No prompt specified")

        # Truncate very long prompts for display
        if len(prompt_template) > 3000:
            prompt_template = prompt_template[:3000] + "\n... [truncated]"

        user_prompt = f"""Analyze this data processing operation and its outputs:

## Operation Configuration

**Name:** {op_config.get('name', 'unknown')}
**Type:** {op_config.get('type', 'unknown')}

**Prompt Template:**
```
{prompt_template}
```

**Output Schema:**
```json
{json.dumps(output_schema, indent=2)}
```

## Sample Outputs ({len(samples)} samples from the operation)

```json
{json.dumps(samples, indent=2, default=str)}
```

## Your Task

Based on the operation configuration and sample outputs, determine:
1. Should this operation be decomposed/optimized?
2. If yes, what specific improvements would help?

Consider the quality, completeness, and consistency of the outputs when making your assessment."""

        return system_prompt, user_prompt

    def analyze(
        self,
        op_config: dict[str, Any],
        step_name: str,
        op_name: str,
    ) -> tuple[str, list[dict[str, Any]], int, float]:
        """
        Analyze whether an operation should be optimized.

        This is the main entry point. It loads outputs, builds the prompt,
        makes a single LLM call, and returns the assessment.

        Args:
            op_config: The operation configuration dictionary
            step_name: Name of the pipeline step
            op_name: Name of the operation

        Returns:
            Tuple of (rationale, output_samples, num_docs_analyzed, cost):
            - rationale: Empty string if no optimization needed, explanation if it should be optimized
            - output_samples: The samples that were analyzed
            - num_docs_analyzed: Number of documents that fit in the LLM prompt
            - cost: LLM API cost in USD

        Raises:
            ValueError: If the operation is not an LLM-powered map operation
        """
        # Validate operation type - only LLM-powered map operations
        op_type = op_config.get("type", "")
        if op_type != "map":
            raise ValueError(
                f"should_optimize only supports 'map' operations, got '{op_type}'. "
                "Only LLM-powered map operations can be analyzed for decomposition."
            )

        # Load outputs from intermediate file
        outputs = self.load_operation_data(step_name, op_name)

        if not outputs:
            return "No output samples available for analysis.", [], 0, 0.0

        # Calculate samples that fit in context
        samples = self.calculate_samples_that_fit(op_config, outputs)

        if not samples:
            return "Could not fit any samples in context window.", outputs[:5], 0, 0.0

        # Build prompt
        system_prompt, user_prompt = self.build_analysis_prompt(op_config, samples)

        # Make LLM call with structured output
        response = completion(
            model=self.optimizer_model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            response_format={
                "type": "json_schema",
                "json_schema": {
                    "name": "optimization_analysis",
                    "strict": True,
                    "schema": {
                        "type": "object",
                        "properties": {
                            "should_optimize": {
                                "type": "boolean",
                                "description": "True if operation should be decomposed/optimized",
                            },
                            "rationale": {
                                "type": "string",
                                "description": "Explanation of why the operation should or should not be optimized",
                            },
                            "suggested_improvements": {
                                "type": "array",
                                "items": {"type": "string"},
                                "description": "Specific improvements if optimization is recommended (empty if not)",
                            },
                        },
                        "required": [
                            "should_optimize",
                            "rationale",
                            "suggested_improvements",
                        ],
                        "additionalProperties": False,
                    },
                },
            },
            **self.litellm_kwargs,
        )

        # Calculate cost
        cost = completion_cost(response)

        # Parse response
        result = json.loads(response.choices[0].message.content)

        num_docs_analyzed = len(samples)

        if result["should_optimize"]:
            # Build rationale string with improvements
            rationale_parts = [result["rationale"]]
            if result["suggested_improvements"]:
                rationale_parts.append("\n\nSuggested improvements:")
                for imp in result["suggested_improvements"]:
                    rationale_parts.append(f"- {imp}")
            return "\n".join(rationale_parts), samples, num_docs_analyzed, cost
        else:
            # Return empty string to indicate no optimization needed
            return "", samples, num_docs_analyzed, cost
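A minimal usage sketch of the analyzer above; the paths, operation names, and the `pipeline_config` dict are hypothetical, and it assumes the pipeline has already written intermediate outputs:

```python
# Hypothetical paths and names; assumes intermediate outputs already exist on disk.
analyzer = FastShouldOptimizeAnalyzer(
    intermediate_dir="workloads/example/intermediates",
    optimizer_model="gpt-4o",
)

rationale, samples, n_docs, cost = analyzer.analyze(
    op_config=pipeline_config["operations"][0],  # an LLM-powered map operation
    step_name="extract_step",
    op_name="extract_themes",
)

if rationale:
    print(f"Decomposition recommended ({n_docs} docs analyzed, ${cost:.4f}):\n{rationale}")
else:
    print("Operation looks fine as-is.")
```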
@ -30,7 +30,7 @@ class LLMClient:
|
|||
Initialize the LLMClient.
|
||||
|
||||
Args:
|
||||
model (str, optional): The name of the LLM model to use. Defaults to "gpt-4o".
|
||||
model (str, optional): The name of the LLM model to use. Defaults to "gpt-5.1".
|
||||
**litellm_kwargs: Additional keyword arguments for the LLM model.
|
||||
"""
|
||||
self.rewrite_agent_model = rewrite_agent_model
|
||||
|
|
|
|||
|
|
@ -204,10 +204,6 @@ def get_openai_response(
|
|||
response = litellm.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=ResponseFormat,
|
||||
)
|
||||
assistant_response = response.choices[0].message.content
|
||||
|
|
|
|||
|
|
@ -30,6 +30,8 @@ from .map_reduce_fusion import MapReduceFusionDirective
|
|||
from .hierarchical_reduce import HierarchicalReduceDirective
|
||||
from .cascade_filtering import CascadeFilteringDirective
|
||||
from .arbitrary_rewrite import ArbitraryRewriteDirective
|
||||
from .map_to_map_resolve_reduce import MapToMapResolveReduceDirective
|
||||
from .map_resolve_to_map_with_categories import MapResolveToMapWithCategoriesDirective
|
||||
|
||||
# Registry of all available directives
|
||||
ALL_DIRECTIVES = [
|
||||
|
|
@ -53,6 +55,8 @@ ALL_DIRECTIVES = [
|
|||
HierarchicalReduceDirective(),
|
||||
CascadeFilteringDirective(),
|
||||
ArbitraryRewriteDirective(),
|
||||
MapToMapResolveReduceDirective(),
|
||||
MapResolveToMapWithCategoriesDirective(),
|
||||
]
|
||||
|
||||
ALL_COST_DIRECTIVES = [
|
||||
|
|
@ -179,8 +183,10 @@ __all__ = [
|
|||
"HierarchicalReduceDirective",
|
||||
"CascadeFilteringDirective",
|
||||
"ArbitraryRewriteDirective",
|
||||
"MapToMapResolveReduceDirective",
|
||||
"MapResolveToMapWithCategoriesDirective",
|
||||
"ALL_DIRECTIVES",
|
||||
"DIRECTIVE_REGISTRY",
|
||||
"DIRECTIVE_REGISTRY",
|
||||
"get_all_directive_strings",
|
||||
"instantiate_directive"
|
||||
]
|
||||
|
|
@ -324,10 +324,6 @@ Focus on quality over quantity - a few diverse, informative examples are better
|
|||
model=self.agent_llm,
|
||||
messages=self.message_history,
|
||||
response_format=AgentDecision,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
)
|
||||
call_cost += response._hidden_params["response_cost"]
|
||||
|
||||
|
|
@ -415,10 +411,6 @@ Provide your response as a JSON object matching this schema: {response_schema.mo
|
|||
model=self.agent_llm,
|
||||
messages=self.message_history,
|
||||
response_format=response_schema,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
)
|
||||
call_cost += schema_response._hidden_params["response_cost"]
|
||||
|
||||
|
|
|
|||
|
|
@ -123,7 +123,8 @@ class Directive(BaseModel, ABC):
|
|||
|
||||
try:
|
||||
# 1. Execute the directive
|
||||
actual_output, _ = self.instantiate(
|
||||
# instantiate returns (new_ops_plan, message_history, call_cost)
|
||||
instantiate_result = self.instantiate(
|
||||
operators=(
|
||||
[test_case.input_config]
|
||||
if isinstance(test_case.input_config, dict)
|
||||
|
|
@ -134,6 +135,11 @@ class Directive(BaseModel, ABC):
|
|||
input_file_path=temp_file_path,
|
||||
pipeline_code=fake_pipeline,
|
||||
)
|
||||
# Handle both 2-tuple and 3-tuple returns
|
||||
if isinstance(instantiate_result, tuple):
|
||||
actual_output = instantiate_result[0]
|
||||
else:
|
||||
actual_output = instantiate_result
|
||||
|
||||
# 2. Use LLM judge to evaluate
|
||||
judge_result = self._llm_judge_test(
|
||||
|
|
@ -216,7 +222,6 @@ class Directive(BaseModel, ABC):
|
|||
model=agent_llm,
|
||||
messages=messages,
|
||||
response_format=JudgeResponse,
|
||||
azure=True,
|
||||
)
|
||||
|
||||
# Parse the JSON response
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
|
@ -169,19 +168,15 @@ class ChainingDirective(Directive):
|
|||
]
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
# api_key=os.environ["GEMINI_API_KEY"],
|
||||
azure=True,
|
||||
response_format=ChainingInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
|
|
@ -202,11 +197,12 @@ class ChainingDirective(Directive):
|
|||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
|
|
@ -211,18 +210,14 @@ class ChangeModelDirective(Directive):
|
|||
},
|
||||
]
|
||||
)
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
# api_key=os.environ["GEMINI_API_KEY"],
|
||||
azure=True,
|
||||
response_format=ChangeModelInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = ChangeModelInstantiateSchema(**parsed_res)
|
||||
|
|
@ -240,11 +235,12 @@ class ChangeModelDirective(Directive):
|
|||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
|
|
@ -205,17 +204,14 @@ class ChangeModelAccDirective(Directive):
|
|||
},
|
||||
]
|
||||
)
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=ChangeModelInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = ChangeModelInstantiateSchema(**parsed_res)
|
||||
|
|
@ -233,11 +229,12 @@ class ChangeModelAccDirective(Directive):
|
|||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
|
|
@ -282,17 +281,14 @@ class ChangeModelCostDirective(Directive):
|
|||
},
|
||||
]
|
||||
)
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=ChangeModelInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = ChangeModelInstantiateSchema(**parsed_res)
|
||||
|
|
@ -310,11 +306,12 @@ class ChangeModelCostDirective(Directive):
|
|||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
|
|
@ -185,17 +184,14 @@ class ChunkHeaderSummaryDirective(Directive):
|
|||
},
|
||||
]
|
||||
)
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=ChunkHeaderSummaryInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = ChunkHeaderSummaryInstantiateSchema(**parsed_res)
|
||||
|
|
@ -204,11 +200,12 @@ class ChunkHeaderSummaryDirective(Directive):
|
|||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
|
@ -251,9 +248,7 @@ class ChunkHeaderSummaryDirective(Directive):
|
|||
)
|
||||
|
||||
if gather_idx - split_idx > 1:
|
||||
raise ValueError(
|
||||
"There should not be operators between split and gather"
|
||||
)
|
||||
raise ValueError("There should not be operators between split and gather")
|
||||
|
||||
# Get the split_key from the split operation
|
||||
split_key = split_op.get("split_key")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
|
|
@ -261,17 +260,14 @@ def transform(input_doc):
|
|||
]
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=DeterministicDocCompressionInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = DeterministicDocCompressionInstantiateSchema(**parsed_res)
|
||||
|
|
@ -284,11 +280,12 @@ def transform(input_doc):
|
|||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
|
|
@ -268,18 +267,15 @@ class DocumentChunkingDirective(Directive):
|
|||
]
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
call_cost = 0.0
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=DocumentChunkingInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = DocumentChunkingInstantiateSchema(**parsed_res)
|
||||
|
|
@ -290,11 +286,12 @@ class DocumentChunkingDirective(Directive):
|
|||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
|
@ -404,16 +401,19 @@ class DocumentChunkingDirective(Directive):
|
|||
sample_op["stratify_key"] = stratify_keys
|
||||
|
||||
if rewrite.sampling_config.samples_per_group:
|
||||
sample_op["samples_per_group"] = rewrite.sampling_config.samples_per_group
|
||||
|
||||
|
||||
sample_op["samples_per_group"] = (
|
||||
rewrite.sampling_config.samples_per_group
|
||||
)
|
||||
|
||||
# Add optional fields if provided
|
||||
if rewrite.sampling_config.random_state is not None:
|
||||
sample_op["random_state"] = rewrite.sampling_config.random_state
|
||||
|
||||
|
||||
if rewrite.sampling_config.method_kwargs is not None:
|
||||
try:
|
||||
sample_op["method_kwargs"] = json.loads(rewrite.sampling_config.method_kwargs)
|
||||
sample_op["method_kwargs"] = json.loads(
|
||||
rewrite.sampling_config.method_kwargs
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid method_kwargs: {e}")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
|
|
@ -431,17 +430,14 @@ class DocumentChunkingTopKDirective(Directive):
|
|||
]
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=DocumentChunkingTopKInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = DocumentChunkingTopKInstantiateSchema(**parsed_res)
|
||||
|
|
@ -451,11 +447,12 @@ class DocumentChunkingTopKDirective(Directive):
|
|||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
|
|
@ -164,19 +163,14 @@ class DocCompressionDirective(Directive):
|
|||
},
|
||||
]
|
||||
)
|
||||
error_message = ""
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=DocCompressionInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
|
|
@ -187,11 +181,12 @@ class DocCompressionDirective(Directive):
|
|||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Error: {error_message}"
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
|
|
@ -261,17 +260,14 @@ class DocSummarizationDirective(Directive):
|
|||
]
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=DocSummarizationInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = DocSummarizationInstantiateSchema(**parsed_res)
|
||||
|
|
@ -280,11 +276,12 @@ class DocSummarizationDirective(Directive):
|
|||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
|
|
@ -146,32 +145,31 @@ class GleaningDirective(Directive):
|
|||
]
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
# api_key=os.environ["GEMINI_API_KEY"],
|
||||
azure=True,
|
||||
response_format=GleaningInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = GleaningInstantiateSchema(**parsed_res)
|
||||
GleaningInstantiateSchema.check_no_jinja_variables(schema.validation_prompt)
|
||||
GleaningInstantiateSchema.check_no_jinja_variables(
|
||||
schema.validation_prompt
|
||||
)
|
||||
message_history.append(
|
||||
{"role": "assistant", "content": resp.choices[0].message.content}
|
||||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
|
|
@ -180,19 +179,14 @@ class HierarchicalReduceDirective(Directive):
|
|||
]
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=HierarchicalReduceInstantiateSchema,
|
||||
)
|
||||
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
|
|
@ -214,11 +208,12 @@ class HierarchicalReduceDirective(Directive):
|
|||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
|
@ -259,7 +254,7 @@ class HierarchicalReduceDirective(Directive):
|
|||
"name": rewrite.map_config.name,
|
||||
"type": "map",
|
||||
"prompt": rewrite.map_config.prompt,
|
||||
"model": default_model,
|
||||
"model": default_model,
|
||||
"litellm_completion_kwargs": {"temperature": 0},
|
||||
"output": {"schema": {rewrite.map_config.output_keys[0]: "string"}},
|
||||
}
|
||||
|
|
@ -324,4 +319,4 @@ class HierarchicalReduceDirective(Directive):
|
|||
new_ops_plan = self.apply(
|
||||
global_default_model, operators, target_ops[0], rewrite
|
||||
)
|
||||
return new_ops_plan, message_history, call_cost
|
||||
return new_ops_plan, message_history, call_cost
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
|
@ -260,17 +259,14 @@ class IsolatingSubtasksDirective(Directive):
|
|||
original_op.get("output", {}).get("schema", {}).keys()
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=IsolatingSubtasksInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = IsolatingSubtasksInstantiateSchema(**parsed_res)
|
||||
|
|
@ -285,11 +281,12 @@ class IsolatingSubtasksDirective(Directive):
|
|||
return schema, message_history, call_cost
|
||||
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease ensure all original output keys are covered by subtasks and try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
|
@ -7,14 +6,20 @@ from typing import Dict, List, Type
|
|||
from litellm import completion
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from docetl.reasoning_optimizer.instantiate_schemas import MapReduceFusionInstantiateSchema
|
||||
from docetl.reasoning_optimizer.instantiate_schemas import (
|
||||
MapReduceFusionInstantiateSchema,
|
||||
)
|
||||
|
||||
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
|
||||
|
||||
|
||||
class MapReduceFusionDirective(Directive):
|
||||
name: str = Field(default="map_reduce_fusion", description="The name of the directive")
|
||||
formal_description: str = Field(default="Map -> Reduce => Map (with new prompt) -> Reduce (with new prompt)")
|
||||
name: str = Field(
|
||||
default="map_reduce_fusion", description="The name of the directive"
|
||||
)
|
||||
formal_description: str = Field(
|
||||
default="Map -> Reduce => Map (with new prompt) -> Reduce (with new propmt)"
|
||||
)
|
||||
nl_description: str = Field(
|
||||
default="Transform a Map operation followed by a Reduce operation over long documents by updating the Map prompt to first extract or process only the relevant information from each document. Then, modify the Reduce prompt to operate on these processed outputs instead of the full document content."
|
||||
)
|
||||
|
|
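To make the description above concrete, here is a hedged before/after sketch of what map_reduce_fusion does to a pipeline. The op names come from the test case further down in this file; the prompts, the new map name, and the pre-extracted key are illustrative, not the repository's actual output:

# Before: reduce reads the full document content for every grouped input.
original_ops = [
    {
        "name": "classify_document",
        "type": "map",
        "prompt": "Classify this document: {{ input.content }}",
        "output": {"schema": {"category": "string"}},
    },
    {
        "name": "extract_organizations",
        "type": "reduce",
        "reduce_key": "category",
        "prompt": "Extract organizations from: {% for input in inputs %}{{ input.content }}{% endfor %}",
        "output": {"schema": {"organizations": "list[str]"}},
    },
]

# After: the map pre-extracts per document, and the reduce aggregates the new key.
fused_ops = [
    {
        "name": "classify_and_extract",  # rewrite.new_map_name (hypothetical)
        "type": "map",
        "prompt": "Classify this document and list its organizations: {{ input.content }}",
        "output": {"schema": {"category": "string", "orgs": "list[str]"}},  # rewrite.new_key added
    },
    {
        "name": "extract_organizations",
        "type": "reduce",
        "reduce_key": "category",
        "prompt": "Merge these lists: {% for input in inputs %}{{ input.orgs }}{% endfor %}",
        "output": {"schema": {"organizations": "list[str]"}},
    },
]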
@ -42,7 +47,7 @@ class MapReduceFusionDirective(Directive):
|
|||
" type: reduce\\n"
|
||||
" reduce_key: category\\n"
|
||||
" prompt: |\\n"
|
||||
" For each category \\\"{{ reduce_key }}\\\", extract all organization names from these documents:\\n"
|
||||
' For each category \\"{{ reduce_key }}\\", extract all organization names from these documents:\\n'
|
||||
" {% for input in inputs %}\\n"
|
||||
" Document {{ loop.index }}: {{ input.content }}\\n"
|
||||
" {% endfor %}\\n"
|
||||
|
|
@ -79,7 +84,7 @@ class MapReduceFusionDirective(Directive):
|
|||
"reduce_key": "category",
|
||||
"prompt": "For each category '{{ reduce_key }}', extract organizations from: {% for input in inputs %}{{ input.content }}{% endfor %}",
|
||||
"output": {"schema": {"organizations": "list[str]"}},
|
||||
}
|
||||
},
|
||||
],
|
||||
target_ops=["classify_document", "extract_organizations"],
|
||||
expected_behavior="Should modify map to classify AND extract organizations per document, then reduce to aggregate pre-extracted organizations",
|
||||
|
|
@ -101,7 +106,7 @@ class MapReduceFusionDirective(Directive):
|
|||
"reduce_key": "doc_type",
|
||||
"prompt": "For each document type '{{ reduce_key }}', find all people mentioned: {% for input in inputs %}Document: {{ input.text }}{% endfor %}",
|
||||
"output": {"schema": {"people": "list[str]"}},
|
||||
}
|
||||
},
|
||||
],
|
||||
target_ops=["analyze_document", "find_people"],
|
||||
expected_behavior="Should modify map to analyze type AND extract people per document, then reduce to combine pre-extracted people lists",
|
||||
|
|
@ -159,7 +164,7 @@ class MapReduceFusionDirective(Directive):
|
|||
reduce_op: Dict,
|
||||
expected_document_key,
|
||||
agent_llm: str,
|
||||
message_history: list = []
|
||||
message_history: list = [],
|
||||
):
|
||||
"""
|
||||
Use LLM to instantiate this directive by transforming the map and reduce operations.
|
||||
|
|
@ -173,66 +178,85 @@ class MapReduceFusionDirective(Directive):
|
|||
Returns:
|
||||
MapReduceFusionInstantiateSchema: The structured output from the LLM.
|
||||
"""
|
||||
|
||||
message_history.extend([
|
||||
{"role": "system", "content": "You are a helpful AI assistant for optimizing document processing pipelines."},
|
||||
{"role": "user", "content": self.to_string_for_instantiate([map_op, reduce_op])},
|
||||
])
|
||||
|
||||
message_history.extend(
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful AI assistant for optimizing document processing pipelines.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.to_string_for_instantiate([map_op, reduce_op]),
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=MapReduceFusionInstantiateSchema
|
||||
response_format=MapReduceFusionInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
if "new_map_name" not in parsed_res or "new_map_prompt" not in parsed_res or "new_key" not in parsed_res or "new_reduce_prompt" not in parsed_res:
|
||||
if (
|
||||
"new_map_name" not in parsed_res
|
||||
or "new_map_prompt" not in parsed_res
|
||||
or "new_key" not in parsed_res
|
||||
or "new_reduce_prompt" not in parsed_res
|
||||
):
|
||||
raise ValueError(
|
||||
"Response from LLM is missing required keys: 'new_map_name', 'new_map_prompt', 'new_key', or 'new_reduce_prompt'"
|
||||
)
|
||||
|
||||
|
||||
schema = MapReduceFusionInstantiateSchema(
|
||||
new_map_name=parsed_res["new_map_name"],
|
||||
new_map_prompt=parsed_res["new_map_prompt"],
|
||||
new_key=parsed_res["new_key"],
|
||||
new_reduce_prompt=parsed_res["new_reduce_prompt"]
|
||||
new_reduce_prompt=parsed_res["new_reduce_prompt"],
|
||||
)
|
||||
|
||||
# Validate the schema
|
||||
MapReduceFusionInstantiateSchema.validate_reduce_prompt_references_new_key(
|
||||
schema.new_reduce_prompt, schema.new_key, expected_document_key
|
||||
)
|
||||
|
||||
|
||||
message_history.append(
|
||||
{"role": "assistant", "content": resp.choices[0].message.content}
|
||||
)
|
||||
return schema, message_history, call_cost
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts.")
|
||||
|
||||
def apply(self, global_default_model, ops_list: List[Dict], map_target: str, reduce_target: str, rewrite: MapReduceFusionInstantiateSchema) -> List[Dict]:
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
global_default_model,
|
||||
ops_list: List[Dict],
|
||||
map_target: str,
|
||||
reduce_target: str,
|
||||
rewrite: MapReduceFusionInstantiateSchema,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Apply the directive to the pipeline config.
|
||||
"""
|
||||
# Create a copy of the pipeline config
|
||||
new_ops_list = deepcopy(ops_list)
|
||||
|
||||
|
||||
# Find positions of the target ops
|
||||
map_pos = None
|
||||
reduce_pos = None
|
||||
orig_map_op = None
|
||||
orig_reduce_op = None
|
||||
|
||||
|
||||
for i, op in enumerate(ops_list):
|
||||
if op["name"] == map_target:
|
||||
map_pos = i
|
||||
|
|
@ -240,13 +264,15 @@ class MapReduceFusionDirective(Directive):
|
|||
elif op["name"] == reduce_target:
|
||||
reduce_pos = i
|
||||
orig_reduce_op = op
|
||||
|
||||
|
||||
if map_pos is None or reduce_pos is None:
|
||||
raise ValueError(f"Could not find target operations: {map_target}, {reduce_target}")
|
||||
|
||||
raise ValueError(
|
||||
f"Could not find target operations: {map_target}, {reduce_target}"
|
||||
)
|
||||
|
||||
# Get default model
|
||||
default_model = orig_map_op.get("model", global_default_model)
|
||||
|
||||
|
||||
# Create the new map operation with fused functionality
|
||||
new_map_op = {
|
||||
"name": rewrite.new_map_name,
|
||||
|
|
@ -257,11 +283,11 @@ class MapReduceFusionDirective(Directive):
|
|||
"output": {
|
||||
"schema": {
|
||||
**orig_map_op.get("output", {}).get("schema", {}),
|
||||
rewrite.new_key: "list[str]"
|
||||
rewrite.new_key: "list[str]",
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# Create the new reduce operation that works with pre-extracted data
|
||||
new_reduce_op = {
|
||||
"name": orig_reduce_op["name"],
|
||||
|
|
@ -270,54 +296,58 @@ class MapReduceFusionDirective(Directive):
|
|||
"prompt": rewrite.new_reduce_prompt,
|
||||
"model": default_model,
|
||||
"litellm_completion_kwargs": {"temperature": 0},
|
||||
"output": orig_reduce_op.get("output", {})
|
||||
"output": orig_reduce_op.get("output", {}),
|
||||
}
|
||||
|
||||
|
||||
# Replace the operations
|
||||
new_ops_list[map_pos] = new_map_op
|
||||
new_ops_list[reduce_pos] = new_reduce_op
|
||||
|
||||
|
||||
return new_ops_list
|
||||
|
||||
|
||||
def instantiate(
|
||||
self,
|
||||
operators: List[Dict],
|
||||
target_ops: List[str],
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
optimize_goal="acc",
|
||||
global_default_model: str = None,
|
||||
**kwargs
|
||||
self,
|
||||
operators: List[Dict],
|
||||
target_ops: List[str],
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
optimize_goal="acc",
|
||||
global_default_model: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Instantiate the directive for a list of operators.
|
||||
"""
|
||||
# Assert that there are exactly two target ops (map and reduce)
|
||||
if len(target_ops) != 2:
|
||||
raise ValueError("map_reduce_fusion directive requires exactly two target operators: a map and a reduce operation")
|
||||
|
||||
raise ValueError(
|
||||
"map_reduce_fusion directive requires exactly two target operators: a map and a reduce operation"
|
||||
)
|
||||
|
||||
# Find the map and reduce operations
|
||||
first_op = None
|
||||
second_op = None
|
||||
|
||||
|
||||
for op in operators:
|
||||
if op["name"] == target_ops[0]:
|
||||
first_op = op
|
||||
elif op["name"] == target_ops[1]:
|
||||
second_op = op
|
||||
|
||||
|
||||
if first_op is None or second_op is None:
|
||||
raise ValueError(f"Could not find target operations: {target_ops}")
|
||||
|
||||
|
||||
if first_op.get("type") == "map" and second_op.get("type") == "reduce":
|
||||
map_op = first_op
|
||||
reduce_op = second_op
|
||||
map_target = target_ops[0]
|
||||
reduce_target = target_ops[1]
|
||||
else:
|
||||
raise ValueError("Target operators must be one map operation followed by one reduce operation!")
|
||||
|
||||
# Extract expected document key from the reduce prompt template
|
||||
raise ValueError(
|
||||
"Target operators must be one map operation followed by one reduce operation!"
|
||||
)
|
||||
|
||||
# Extract expected document key from the reduce prompt template
|
||||
prompt_template = reduce_op["prompt"]
|
||||
# Find all occurrences of {{ input.key }} in the prompt
|
||||
input_key_pattern = r"\{\{\s*([^\}\s]+)\s*\}\}"
|
||||
|
|
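The regex above recovers Jinja variable references from the reduce prompt so the directive can guess which field holds the document text. A small sketch of how that pattern behaves, with an illustrative helper name:

import re

input_key_pattern = r"\{\{\s*([^\}\s]+)\s*\}\}"

def extract_input_keys(prompt_template: str) -> list[str]:
    # Keep only references like "input.content" and strip the "input." prefix.
    keys = []
    for match in re.findall(input_key_pattern, prompt_template):
        if match.startswith("input."):
            keys.append(match.split(".", 1)[1])
    return keys

print(extract_input_keys(
    "{% for input in inputs %}Document: {{ input.content }}{% endfor %}"
))  # -> ['content']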
@ -335,7 +365,9 @@ class MapReduceFusionDirective(Directive):
|
|||
]
|
||||
|
||||
if document_key_candidates:
|
||||
expected_document_key = document_key_candidates[0] # Pick the first candidate
|
||||
expected_document_key = document_key_candidates[
|
||||
0
|
||||
] # Pick the first candidate
|
||||
elif input_keys:
|
||||
expected_document_key = input_keys[0] # Fall back to the first input key
|
||||
else:
|
||||
|
|
@ -344,8 +376,12 @@ class MapReduceFusionDirective(Directive):
|
|||
print(f"Detected document key: {expected_document_key}")
|
||||
|
||||
# Instantiate the directive
|
||||
rewrite, message_history, call_cost = self.llm_instantiate(map_op, reduce_op, expected_document_key, agent_llm, message_history)
|
||||
|
||||
rewrite, message_history, call_cost = self.llm_instantiate(
|
||||
map_op, reduce_op, expected_document_key, agent_llm, message_history
|
||||
)
|
||||
|
||||
# Apply the rewrite to the operators
|
||||
new_ops_plan = self.apply(global_default_model, operators, map_target, reduce_target, rewrite)
|
||||
return new_ops_plan, message_history, call_cost
|
||||
new_ops_plan = self.apply(
|
||||
global_default_model, operators, map_target, reduce_target, rewrite
|
||||
)
|
||||
return new_ops_plan, message_history, call_cost
|
||||
|
|
|
|||
|
|
@ -0,0 +1,361 @@
|
|||
import json
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
from litellm import completion
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from docetl.reasoning_optimizer.instantiate_schemas import (
|
||||
MapResolveToMapWithCategoriesInstantiateSchema,
|
||||
)
|
||||
|
||||
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
|
||||
|
||||
|
||||
class MapResolveToMapWithCategoriesDirective(Directive):
|
||||
name: str = Field(
|
||||
default="map_resolve_to_map_with_categories",
|
||||
description="The name of the directive",
|
||||
)
|
||||
formal_description: str = Field(default="Map -> Resolve => Map (with categories)")
|
||||
nl_description: str = Field(
|
||||
default="Replace a Map -> Resolve pattern with a single Map operation that has predefined categories. The agent analyzes the data and task to propose a set of canonical categories, and the new Map forces outputs into one of these categories (or 'none of the above'). This effectively performs entity resolution deterministically by standardizing outputs upfront, avoiding the need for pairwise comparisons."
|
||||
)
|
||||
when_to_use: str = Field(
|
||||
default="When a Map operation produces outputs that are then resolved/deduplicated, and the set of valid output categories can be enumerated upfront. This is more efficient than Resolve when the category space is small and well-defined (e.g., standardizing company types, product categories, sentiment labels). The target must be a Map operation followed by a Resolve operation."
|
||||
)
|
||||
instantiate_schema_type: Type[BaseModel] = (
|
||||
MapResolveToMapWithCategoriesInstantiateSchema
|
||||
)
|
||||
|
||||
example: str = Field(
|
||||
default=(
|
||||
"Original Pipeline:\n"
|
||||
"- name: extract_sentiment\n"
|
||||
" type: map\n"
|
||||
" prompt: |\n"
|
||||
" What is the sentiment of this review?\n"
|
||||
" {{ input.review }}\n"
|
||||
" output:\n"
|
||||
" schema:\n"
|
||||
" sentiment: string\n"
|
||||
"\n"
|
||||
"- name: normalize_sentiment\n"
|
||||
" type: resolve\n"
|
||||
" comparison_prompt: |\n"
|
||||
" Are these sentiments equivalent?\n"
|
||||
" Sentiment 1: {{ input1.sentiment }}\n"
|
||||
" Sentiment 2: {{ input2.sentiment }}\n"
|
||||
" resolution_prompt: |\n"
|
||||
" Normalize these sentiments:\n"
|
||||
" {% for input in inputs %}\n"
|
||||
" - {{ input.sentiment }}\n"
|
||||
" {% endfor %}\n"
|
||||
" output:\n"
|
||||
" schema:\n"
|
||||
" sentiment: string\n"
|
||||
"\n"
|
||||
"Example InstantiateSchema:\n"
|
||||
"MapResolveToMapWithCategoriesInstantiateSchema(\n"
|
||||
" categories=['Positive', 'Negative', 'Neutral', 'Mixed'],\n"
|
||||
" category_key='sentiment',\n"
|
||||
" new_prompt='''Analyze the sentiment of this review and classify it into one of the following categories:\n"
|
||||
"\n"
|
||||
"Categories:\n"
|
||||
"- Positive: Clearly positive sentiment, satisfaction, praise\n"
|
||||
"- Negative: Clearly negative sentiment, complaints, criticism\n"
|
||||
"- Neutral: No strong sentiment, factual statements\n"
|
||||
"- Mixed: Contains both positive and negative elements\n"
|
||||
"- None of the above: If the review doesn't fit any category\n"
|
||||
"\n"
|
||||
"Review: {{ input.review }}\n"
|
||||
"\n"
|
||||
"Return exactly one of: Positive, Negative, Neutral, Mixed, or None of the above.''',\n"
|
||||
" include_none_of_above=True,\n"
|
||||
")"
|
||||
),
|
||||
)
|
||||
|
||||
test_cases: List[DirectiveTestCase] = Field(
|
||||
default_factory=lambda: [
|
||||
DirectiveTestCase(
|
||||
name="sentiment_categorization",
|
||||
description="Should replace map+resolve with categorized map for sentiment",
|
||||
input_config=[
|
||||
{
|
||||
"name": "extract_sentiment",
|
||||
"type": "map",
|
||||
"prompt": "What is the sentiment? {{ input.text }}",
|
||||
"output": {"schema": {"sentiment": "string"}},
|
||||
},
|
||||
{
|
||||
"name": "normalize_sentiment",
|
||||
"type": "resolve",
|
||||
"comparison_prompt": "Same sentiment? {{ input1.sentiment }} vs {{ input2.sentiment }}",
|
||||
"resolution_prompt": "Normalize: {% for input in inputs %}{{ input.sentiment }}{% endfor %}",
|
||||
"output": {"schema": {"sentiment": "string"}},
|
||||
},
|
||||
],
|
||||
target_ops=["extract_sentiment", "normalize_sentiment"],
|
||||
expected_behavior="Should create a single map with predefined sentiment categories",
|
||||
should_pass=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, MapResolveToMapWithCategoriesDirective)
|
||||
|
||||
def __hash__(self):
|
||||
return hash("MapResolveToMapWithCategoriesDirective")
|
||||
|
||||
def to_string_for_instantiate(
|
||||
self, map_op: Dict, resolve_op: Dict, sample_data: List[Dict] = None
|
||||
) -> str:
|
||||
"""
|
||||
Generate a prompt for an agent to instantiate this directive.
|
||||
|
||||
Args:
|
||||
map_op (Dict): The map operation configuration.
|
||||
resolve_op (Dict): The resolve operation configuration.
|
||||
sample_data (List[Dict], optional): Sample data to help identify categories.
|
||||
|
||||
Returns:
|
||||
str: The agent prompt for instantiating the directive.
|
||||
"""
|
||||
sample_str = ""
|
||||
if sample_data:
|
||||
sample_str = f"\n\nSample Input Data (first 5 items):\n{json.dumps(sample_data[:5], indent=2)}"
|
||||
|
||||
return (
|
||||
f"You are an expert at optimizing data processing pipelines by replacing entity resolution with categorical constraints.\n\n"
|
||||
f"Map Operation:\n"
|
||||
f"{str(map_op)}\n\n"
|
||||
f"Resolve Operation:\n"
|
||||
f"{str(resolve_op)}\n\n"
|
||||
f"Directive: {self.name}\n"
|
||||
f"Your task is to replace the Map -> Resolve pattern with a single Map that uses predefined categories.\n\n"
|
||||
f"Key Requirements:\n"
|
||||
f"1. Analyze the map's output field and the resolve operation to understand what values are being normalized:\n"
|
||||
f" - Look at the comparison_prompt to understand what variations are being matched\n"
|
||||
f" - Look at the resolution_prompt to understand the canonical form\n\n"
|
||||
f"2. Propose a set of categories that cover all expected outputs:\n"
|
||||
f" - Categories should be mutually exclusive\n"
|
||||
f" - Categories should be exhaustive (cover all realistic cases)\n"
|
||||
f" - Consider including 'None of the above' for edge cases\n\n"
|
||||
f"3. Optionally provide descriptions for each category to help the LLM classify correctly\n\n"
|
||||
f"4. Create a new_prompt that:\n"
|
||||
f" - Lists all valid categories with their descriptions\n"
|
||||
f" - Instructs the LLM to output exactly one category\n"
|
||||
f" - References the input using {{{{ input.key }}}} syntax\n"
|
||||
f" - Includes any context from the original map prompt\n\n"
|
||||
f"5. Identify the category_key (the output field that will contain the category)\n\n"
|
||||
f"Benefits of this approach:\n"
|
||||
f"- Eliminates O(n^2) pairwise comparisons from Resolve\n"
|
||||
f"- Produces consistent, standardized outputs\n"
|
||||
f"- Reduces cost by removing the Resolve operation entirely\n"
|
||||
f"{sample_str}\n\n"
|
||||
f"Example:\n"
|
||||
f"{self.example}\n\n"
|
||||
f"Please analyze the operations and propose appropriate categories. Output the MapResolveToMapWithCategoriesInstantiateSchema."
|
||||
)
|
||||
|
||||
def llm_instantiate(
|
||||
self,
|
||||
map_op: Dict,
|
||||
resolve_op: Dict,
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
sample_data: List[Dict] = None,
|
||||
):
|
||||
"""
|
||||
Use LLM to instantiate this directive.
|
||||
|
||||
Args:
|
||||
map_op (Dict): The map operation configuration.
|
||||
resolve_op (Dict): The resolve operation configuration.
|
||||
agent_llm (str): The LLM model to use.
|
||||
message_history (List, optional): Conversation history for context.
|
||||
sample_data (List[Dict], optional): Sample data to help identify categories.
|
||||
|
||||
Returns:
|
||||
MapResolveToMapWithCategoriesInstantiateSchema: The structured output from the LLM.
|
||||
"""
|
||||
|
||||
message_history.extend(
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful AI assistant for document processing pipelines specializing in categorical classification.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.to_string_for_instantiate(
|
||||
map_op, resolve_op, sample_data
|
||||
),
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
response_format=MapResolveToMapWithCategoriesInstantiateSchema,
|
||||
)
|
||||
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = MapResolveToMapWithCategoriesInstantiateSchema(**parsed_res)
|
||||
|
||||
message_history.append(
|
||||
{"role": "assistant", "content": resp.choices[0].message.content}
|
||||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
global_default_model,
|
||||
ops_list: List[Dict],
|
||||
map_op_name: str,
|
||||
resolve_op_name: str,
|
||||
rewrite: MapResolveToMapWithCategoriesInstantiateSchema,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Apply the directive to the pipeline config.
|
||||
"""
|
||||
new_ops_list = deepcopy(ops_list)
|
||||
|
||||
# Find position of the map and resolve ops
|
||||
map_pos = None
|
||||
resolve_pos = None
|
||||
map_op = None
|
||||
|
||||
for i, op in enumerate(new_ops_list):
|
||||
if op["name"] == map_op_name:
|
||||
map_pos = i
|
||||
map_op = op
|
||||
elif op["name"] == resolve_op_name:
|
||||
resolve_pos = i
|
||||
|
||||
if map_pos is None or resolve_pos is None:
|
||||
raise ValueError(
|
||||
f"Could not find map '{map_op_name}' and resolve '{resolve_op_name}' operations"
|
||||
)
|
||||
|
||||
# Determine the model to use
|
||||
default_model = map_op.get("model", global_default_model)
|
||||
|
||||
# Build the list of valid values for validation
|
||||
valid_values = list(rewrite.categories)
|
||||
if rewrite.include_none_of_above:
|
||||
valid_values.append("None of the above")
|
||||
|
||||
# Modify the map operation with the new prompt and add validation
|
||||
new_ops_list[map_pos]["prompt"] = rewrite.new_prompt
|
||||
new_ops_list[map_pos]["model"] = default_model
|
||||
|
||||
# Add validation to ensure output is one of the categories
|
||||
if "validate" not in new_ops_list[map_pos]:
|
||||
new_ops_list[map_pos]["validate"] = []
|
||||
|
||||
# Add validation rule for the category key
|
||||
validation_rule = f"output['{rewrite.category_key}'] in {valid_values}"
|
||||
new_ops_list[map_pos]["validate"].append(validation_rule)
|
||||
|
||||
# Update the output schema to reflect the category key
|
||||
if "output" not in new_ops_list[map_pos]:
|
||||
new_ops_list[map_pos]["output"] = {"schema": {}}
|
||||
new_ops_list[map_pos]["output"]["schema"][rewrite.category_key] = "string"
|
||||
|
||||
# Remove the resolve operation
|
||||
new_ops_list.pop(resolve_pos)
|
||||
|
||||
return new_ops_list
|
||||
|
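For reference, this is roughly what apply() above leaves behind for the sentiment example used elsewhere in this file: the resolve op is removed and the map gains a categorical validate rule. Field values here are illustrative, not the actual rewrite output:

rewritten_map_op = {
    "name": "extract_sentiment",
    "type": "map",
    "prompt": "Classify the sentiment as Positive, Negative, Neutral, or Mixed: {{ input.review }}",
    "model": "gpt-4o",  # assumed default model
    "validate": [
        "output['sentiment'] in ['Positive', 'Negative', 'Neutral', 'Mixed', 'None of the above']"
    ],
    "output": {"schema": {"sentiment": "string"}},
}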
||||
def instantiate(
|
||||
self,
|
||||
operators: List[Dict],
|
||||
target_ops: List[str],
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
optimize_goal="acc",
|
||||
global_default_model: str = None,
|
||||
dataset: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Instantiate the directive for a list of operators.
|
||||
"""
|
||||
# Assert that there are exactly two target ops (map and resolve)
|
||||
assert (
|
||||
len(target_ops) == 2
|
||||
), "There must be exactly two target ops (map and resolve) to instantiate this directive"
|
||||
|
||||
# Find the map and resolve operations
|
||||
map_op = None
|
||||
resolve_op = None
|
||||
|
||||
for op in operators:
|
||||
if op["name"] == target_ops[0]:
|
||||
if op.get("type") == "map":
|
||||
map_op = op
|
||||
elif op.get("type") == "resolve":
|
||||
resolve_op = op
|
||||
elif op["name"] == target_ops[1]:
|
||||
if op.get("type") == "map":
|
||||
map_op = op
|
||||
elif op.get("type") == "resolve":
|
||||
resolve_op = op
|
||||
|
||||
if map_op is None or resolve_op is None:
|
||||
raise ValueError(
|
||||
f"Could not find both a map and resolve operation in target_ops: {target_ops}"
|
||||
)
|
||||
|
||||
# Verify the map comes before resolve
|
||||
map_idx = next(
|
||||
i for i, op in enumerate(operators) if op["name"] == map_op["name"]
|
||||
)
|
||||
resolve_idx = next(
|
||||
i for i, op in enumerate(operators) if op["name"] == resolve_op["name"]
|
||||
)
|
||||
|
||||
if map_idx >= resolve_idx:
|
||||
raise ValueError(
|
||||
f"Map operation '{map_op['name']}' must come before resolve operation '{resolve_op['name']}'"
|
||||
)
|
||||
|
||||
# Load sample data if available
|
||||
sample_data = None
|
||||
if dataset:
|
||||
try:
|
||||
with open(dataset, "r") as f:
|
||||
sample_data = json.load(f)
|
||||
except Exception:
|
||||
pass # Ignore if we can't load sample data
|
||||
|
||||
# Instantiate the directive
|
||||
rewrite, message_history, call_cost = self.llm_instantiate(
|
||||
map_op,
|
||||
resolve_op,
|
||||
agent_llm,
|
||||
message_history,
|
||||
sample_data,
|
||||
)
|
||||
|
||||
# Apply the rewrite to the operators
|
||||
new_ops_plan = self.apply(
|
||||
global_default_model, operators, map_op["name"], resolve_op["name"], rewrite
|
||||
)
|
||||
return new_ops_plan, message_history, call_cost
|
||||
|
|
@ -0,0 +1,335 @@
|
|||
import json
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
from litellm import completion
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from docetl.reasoning_optimizer.instantiate_schemas import (
|
||||
MapToMapResolveReduceInstantiateSchema,
|
||||
)
|
||||
|
||||
from .base import MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS, Directive, DirectiveTestCase
|
||||
|
||||
|
||||
class MapToMapResolveReduceDirective(Directive):
|
||||
name: str = Field(
|
||||
default="map_to_map_resolve_reduce", description="The name of the directive"
|
||||
)
|
||||
formal_description: str = Field(default="Map -> Reduce => Map -> Resolve -> Reduce")
|
||||
nl_description: str = Field(
|
||||
default="Insert a Resolve operation between Map and Reduce to deduplicate or normalize entities before aggregation. The Resolve operation uses code-powered blocking conditions to efficiently identify which pairs to compare, avoiding O(n^2) comparisons. This is useful when the Map output contains duplicate or near-duplicate entities that should be merged before the Reduce step."
|
||||
)
|
||||
when_to_use: str = Field(
|
||||
default="When a Map operation produces outputs that may contain duplicates, variations, or near-duplicates (e.g., different spellings of names, similar categories), and these should be normalized before the Reduce aggregation step. The target must be a Map operation followed by a Reduce operation."
|
||||
)
|
||||
instantiate_schema_type: Type[BaseModel] = MapToMapResolveReduceInstantiateSchema
|
||||
|
||||
example: str = Field(
|
||||
default=(
|
||||
"Original Pipeline:\n"
|
||||
"- name: extract_companies\n"
|
||||
" type: map\n"
|
||||
" prompt: |\n"
|
||||
" Extract company names from this news article:\n"
|
||||
" {{ input.article }}\n"
|
||||
" output:\n"
|
||||
" schema:\n"
|
||||
" company_name: string\n"
|
||||
"\n"
|
||||
"- name: aggregate_companies\n"
|
||||
" type: reduce\n"
|
||||
" reduce_key: sector\n"
|
||||
" prompt: |\n"
|
||||
" List all unique companies in this sector:\n"
|
||||
" {% for input in inputs %}\n"
|
||||
" - {{ input.company_name }}\n"
|
||||
" {% endfor %}\n"
|
||||
" output:\n"
|
||||
" schema:\n"
|
||||
" companies: list[str]\n"
|
||||
"\n"
|
||||
"Example InstantiateSchema:\n"
|
||||
"MapToMapResolveReduceInstantiateSchema(\n"
|
||||
" resolve_name='normalize_company_names',\n"
|
||||
" comparison_prompt='''Are these two company names referring to the same company?\n"
|
||||
"Company 1: {{ input1.company_name }}\n"
|
||||
"Company 2: {{ input2.company_name }}\n"
|
||||
"Consider variations like abbreviations (IBM vs International Business Machines), \n"
|
||||
"different legal suffixes (Inc, Corp, LLC), and common misspellings.''',\n"
|
||||
" resolution_prompt='''Given these variations of a company name:\n"
|
||||
"{% for input in inputs %}\n"
|
||||
"- {{ input.company_name }}\n"
|
||||
"{% endfor %}\n"
|
||||
"Return the canonical/official company name.''',\n"
|
||||
" blocking_conditions=[\n"
|
||||
" \"input1['company_name'][:3].lower() == input2['company_name'][:3].lower()\",\n"
|
||||
" \"input1['company_name'].split()[0].lower() == input2['company_name'].split()[0].lower()\",\n"
|
||||
" ],\n"
|
||||
" blocking_keys=['company_name'],\n"
|
||||
" limit_comparisons=1000,\n"
|
||||
" output_schema={'company_name': 'string'},\n"
|
||||
")"
|
||||
),
|
||||
)
|
||||
|
||||
test_cases: List[DirectiveTestCase] = Field(
|
||||
default_factory=lambda: [
|
||||
DirectiveTestCase(
|
||||
name="company_name_normalization",
|
||||
description="Should insert resolve between map and reduce for company names",
|
||||
input_config=[
|
||||
{
|
||||
"name": "extract_companies",
|
||||
"type": "map",
|
||||
"prompt": "Extract company name from: {{ input.text }}",
|
||||
"output": {"schema": {"company_name": "string"}},
|
||||
},
|
||||
{
|
||||
"name": "aggregate_by_sector",
|
||||
"type": "reduce",
|
||||
"reduce_key": "sector",
|
||||
"prompt": "List companies:\n{% for input in inputs %}\n- {{ input.company_name }}\n{% endfor %}",
|
||||
"output": {"schema": {"companies": "list[str]"}},
|
||||
},
|
||||
],
|
||||
target_ops=["extract_companies", "aggregate_by_sector"],
|
||||
expected_behavior="Should create a resolve operation between the map and reduce with appropriate blocking conditions",
|
||||
should_pass=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, MapToMapResolveReduceDirective)
|
||||
|
||||
def __hash__(self):
|
||||
return hash("MapToMapResolveReduceDirective")
|
||||
|
||||
def to_string_for_instantiate(self, map_op: Dict, reduce_op: Dict) -> str:
|
||||
"""
|
||||
Generate a prompt for an agent to instantiate this directive.
|
||||
|
||||
Args:
|
||||
map_op (Dict): The map operation configuration.
|
||||
reduce_op (Dict): The reduce operation configuration.
|
||||
|
||||
Returns:
|
||||
str: The agent prompt for instantiating the directive.
|
||||
"""
|
||||
return (
|
||||
f"You are an expert at optimizing data processing pipelines by inserting entity resolution steps.\n\n"
|
||||
f"Map Operation:\n"
|
||||
f"{str(map_op)}\n\n"
|
||||
f"Reduce Operation:\n"
|
||||
f"{str(reduce_op)}\n\n"
|
||||
f"Directive: {self.name}\n"
|
||||
f"Your task is to insert a Resolve operation between the Map and Reduce to deduplicate/normalize entities.\n\n"
|
||||
f"Key Requirements:\n"
|
||||
f"1. Create a comparison_prompt that determines if two items from the Map output are duplicates/variations:\n"
|
||||
f" - Must reference {{ input1.key }} and {{ input2.key }} for comparing fields\n"
|
||||
f" - Should handle common variations (abbreviations, misspellings, formatting differences)\n\n"
|
||||
f"2. Create a resolution_prompt that merges matched items into a canonical form:\n"
|
||||
f" - Must use {{% for input in inputs %}} to iterate over matched items\n"
|
||||
f" - Should produce the most authoritative/complete version\n\n"
|
||||
f"3. Create blocking_conditions to avoid O(n^2) comparisons:\n"
|
||||
f" - These are Python expressions with access to 'input1' and 'input2' dicts\n"
|
||||
f" - They should filter pairs to only those likely to match\n"
|
||||
f" - Examples:\n"
|
||||
f" * \"input1['name'][:3].lower() == input2['name'][:3].lower()\" (first 3 chars match)\n"
|
||||
f" * \"input1['name'].split()[0].lower() == input2['name'].split()[0].lower()\" (first word matches)\n"
|
||||
f" * \"abs(len(input1['name']) - len(input2['name'])) < 10\" (similar length)\n"
|
||||
f" - Multiple conditions are OR'd together\n\n"
|
||||
f"4. Set limit_comparisons to cap the number of pairs (recommended: 500-2000)\n\n"
|
||||
f"5. The output_schema should match what the Reduce operation expects from each input\n\n"
|
||||
f"Example:\n"
|
||||
f"{self.example}\n\n"
|
||||
f"Please output the MapToMapResolveReduceInstantiateSchema."
|
||||
)
|
||||
|
||||
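The instantiation prompt above describes blocking conditions as Python expressions over input1/input2 that are OR'd together. A minimal sketch of screening a candidate pair this way (the eval-based evaluation strategy is illustrative, not the engine's actual implementation):

blocking_conditions = [
    "input1['name'][:3].lower() == input2['name'][:3].lower()",
    "abs(len(input1['name']) - len(input2['name'])) < 10",
]

def should_compare(input1: dict, input2: dict) -> bool:
    # Conditions are OR'd: any condition that holds keeps the pair for comparison.
    return any(
        eval(cond, {}, {"input1": input1, "input2": input2})
        for cond in blocking_conditions
    )

print(should_compare({"name": "IBM Corp"}, {"name": "IBM"}))               # True
print(should_compare({"name": "Acme"}, {"name": "Globex Industries"}))     # False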
def llm_instantiate(
|
||||
self, map_op: Dict, reduce_op: Dict, agent_llm: str, message_history: list = []
|
||||
):
|
||||
"""
|
||||
Use LLM to instantiate this directive.
|
||||
|
||||
Args:
|
||||
map_op (Dict): The map operation configuration.
|
||||
reduce_op (Dict): The reduce operation configuration.
|
||||
agent_llm (str): The LLM model to use.
|
||||
message_history (List, optional): Conversation history for context.
|
||||
|
||||
Returns:
|
||||
MapToMapResolveReduceInstantiateSchema: The structured output from the LLM.
|
||||
"""
|
||||
|
||||
message_history.extend(
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful AI assistant for document processing pipelines specializing in entity resolution.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.to_string_for_instantiate(map_op, reduce_op),
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
response_format=MapToMapResolveReduceInstantiateSchema,
|
||||
)
|
||||
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = MapToMapResolveReduceInstantiateSchema(**parsed_res)
|
||||
|
||||
message_history.append(
|
||||
{"role": "assistant", "content": resp.choices[0].message.content}
|
||||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
global_default_model,
|
||||
ops_list: List[Dict],
|
||||
map_op_name: str,
|
||||
reduce_op_name: str,
|
||||
rewrite: MapToMapResolveReduceInstantiateSchema,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Apply the directive to the pipeline config.
|
||||
"""
|
||||
new_ops_list = deepcopy(ops_list)
|
||||
|
||||
# Find position of the map and reduce ops
|
||||
map_pos = None
|
||||
reduce_pos = None
|
||||
map_op = None
|
||||
|
||||
for i, op in enumerate(ops_list):
|
||||
if op["name"] == map_op_name:
|
||||
map_pos = i
|
||||
map_op = op
|
||||
elif op["name"] == reduce_op_name:
|
||||
reduce_pos = i
|
||||
|
||||
if map_pos is None or reduce_pos is None:
|
||||
raise ValueError(
|
||||
f"Could not find map '{map_op_name}' and reduce '{reduce_op_name}' operations"
|
||||
)
|
||||
|
||||
# Determine the model to use
|
||||
default_model = map_op.get("model", global_default_model)
|
||||
|
||||
# Find the reduce operation to get the reduce_key
|
||||
reduce_op = None
|
||||
for op in ops_list:
|
||||
if op["name"] == reduce_op_name:
|
||||
reduce_op = op
|
||||
break
|
||||
|
||||
# Derive output schema from reduce_key - that's what's being grouped/resolved
|
||||
reduce_key = reduce_op.get("reduce_key", []) if reduce_op else []
|
||||
if isinstance(reduce_key, str):
|
||||
reduce_key = [reduce_key]
|
||||
|
||||
# Build output schema from reduce_key fields
|
||||
output_schema = {key: "string" for key in reduce_key}
|
||||
|
||||
# Create the resolve operation
|
||||
resolve_op = {
|
||||
"name": rewrite.resolve_name,
|
||||
"type": "resolve",
|
||||
"comparison_prompt": rewrite.comparison_prompt,
|
||||
"resolution_prompt": rewrite.resolution_prompt,
|
||||
"blocking_conditions": rewrite.blocking_conditions,
|
||||
"blocking_keys": rewrite.blocking_keys,
|
||||
"limit_comparisons": rewrite.limit_comparisons,
|
||||
"model": default_model,
|
||||
"output": {"schema": output_schema},
|
||||
}
|
||||
|
||||
# Insert resolve operation after the map operation
|
||||
new_ops_list.insert(map_pos + 1, resolve_op)
|
||||
|
||||
return new_ops_list
|
||||
|
||||
def instantiate(
|
||||
self,
|
||||
operators: List[Dict],
|
||||
target_ops: List[str],
|
||||
agent_llm: str,
|
||||
message_history: list = [],
|
||||
optimize_goal="acc",
|
||||
global_default_model: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Instantiate the directive for a list of operators.
|
||||
"""
|
||||
# Assert that there are exactly two target ops (map and reduce)
|
||||
assert (
|
||||
len(target_ops) == 2
|
||||
), "There must be exactly two target ops (map and reduce) to instantiate this directive"
|
||||
|
||||
# Find the map and reduce operations
|
||||
map_op = None
|
||||
reduce_op = None
|
||||
|
||||
for op in operators:
|
||||
if op["name"] == target_ops[0]:
|
||||
if op.get("type") == "map":
|
||||
map_op = op
|
||||
elif op.get("type") == "reduce":
|
||||
reduce_op = op
|
||||
elif op["name"] == target_ops[1]:
|
||||
if op.get("type") == "map":
|
||||
map_op = op
|
||||
elif op.get("type") == "reduce":
|
||||
reduce_op = op
|
||||
|
||||
if map_op is None or reduce_op is None:
|
||||
raise ValueError(
|
||||
f"Could not find both a map and reduce operation in target_ops: {target_ops}"
|
||||
)
|
||||
|
||||
# Verify the map comes before reduce
|
||||
map_idx = next(
|
||||
i for i, op in enumerate(operators) if op["name"] == map_op["name"]
|
||||
)
|
||||
reduce_idx = next(
|
||||
i for i, op in enumerate(operators) if op["name"] == reduce_op["name"]
|
||||
)
|
||||
|
||||
if map_idx >= reduce_idx:
|
||||
raise ValueError(
|
||||
f"Map operation '{map_op['name']}' must come before reduce operation '{reduce_op['name']}'"
|
||||
)
|
||||
|
||||
# Instantiate the directive
|
||||
rewrite, message_history, call_cost = self.llm_instantiate(
|
||||
map_op,
|
||||
reduce_op,
|
||||
agent_llm,
|
||||
message_history,
|
||||
)
|
||||
|
||||
# Apply the rewrite to the operators
|
||||
new_ops_plan = self.apply(
|
||||
global_default_model, operators, map_op["name"], reduce_op["name"], rewrite
|
||||
)
|
||||
return new_ops_plan, message_history, call_cost
|
||||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
|
|
@ -172,17 +171,14 @@ class OperatorFusionDirective(Directive):
|
|||
]
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=OperatorFusionInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = OperatorFusionInstantiateSchema(**parsed_res)
|
||||
|
|
@ -191,11 +187,12 @@ class OperatorFusionDirective(Directive):
|
|||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
|
@ -215,7 +212,6 @@ class OperatorFusionDirective(Directive):
|
|||
new_ops_list = deepcopy(ops_list)
|
||||
op1_name, op2_name = target_ops[0], target_ops[1]
|
||||
|
||||
|
||||
# Find the operations
|
||||
op1_idx = next(i for i, op in enumerate(ops_list) if op["name"] == op1_name)
|
||||
op2_idx = next(i for i, op in enumerate(ops_list) if op["name"] == op2_name)
|
||||
|
|
@ -329,5 +325,5 @@ def transform(input_doc):
|
|||
return (
|
||||
self.apply(global_default_model, operators, target_ops, rewrite),
|
||||
message_history,
|
||||
call_cost
|
||||
call_cost,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
|
@ -127,7 +126,7 @@ class ReduceChainingDirective(Directive):
|
|||
original_op: Dict,
|
||||
expected_document_key: str,
|
||||
agent_llm: str,
|
||||
message_history: list = []
|
||||
message_history: list = [],
|
||||
):
|
||||
"""
|
||||
Use LLM to instantiate this directive by decomposing the reduce operation.
|
||||
|
|
@ -155,19 +154,15 @@ class ReduceChainingDirective(Directive):
|
|||
]
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
|
||||
call_cost = 0
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=ReduceChainingInstantiateSchema
|
||||
response_format=ReduceChainingInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = ReduceChainingInstantiateSchema(**parsed_res)
|
||||
|
|
@ -182,11 +177,12 @@ class ReduceChainingDirective(Directive):
|
|||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
|
@ -283,7 +279,9 @@ class ReduceChainingDirective(Directive):
|
|||
]
|
||||
|
||||
if document_key_candidates:
|
||||
expected_document_key = document_key_candidates[0] # Pick the first candidate
|
||||
expected_document_key = document_key_candidates[
|
||||
0
|
||||
] # Pick the first candidate
|
||||
elif input_keys:
|
||||
expected_document_key = input_keys[0] # Fall back to the first input key
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
|
|
@ -255,17 +254,14 @@ class ReduceGleaningDirective(Directive):
|
|||
]
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=GleaningInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = GleaningInstantiateSchema(**parsed_res)
|
||||
|
|
@ -274,11 +270,12 @@ class ReduceGleaningDirective(Directive):
|
|||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
|
@ -341,5 +338,6 @@ class ReduceGleaningDirective(Directive):
|
|||
# Apply the rewrite to the operators
|
||||
return (
|
||||
self.apply(global_default_model, operators, target_ops[0], rewrite),
|
||||
message_history, call_cost
|
||||
message_history,
|
||||
call_cost,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Type
|
||||
|
||||
|
|
@ -234,18 +233,14 @@ class TakeHeadTailDirective(Directive):
|
|||
]
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for _ in range(MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS):
|
||||
call_cost = 0
|
||||
resp = completion(
|
||||
model=agent_llm,
|
||||
messages=message_history,
|
||||
api_key=os.environ.get("AZURE_API_KEY"),
|
||||
api_base=os.environ.get("AZURE_API_BASE"),
|
||||
api_version=os.environ.get("AZURE_API_VERSION"),
|
||||
azure=True,
|
||||
response_format=TakeHeadTailInstantiateSchema,
|
||||
)
|
||||
call_cost = resp._hidden_params["response_cost"]
|
||||
call_cost = resp._hidden_params.get("response_cost", 0)
|
||||
try:
|
||||
parsed_res = json.loads(resp.choices[0].message.content)
|
||||
schema = TakeHeadTailInstantiateSchema(**parsed_res)
|
||||
|
|
@ -254,11 +249,12 @@ class TakeHeadTailDirective(Directive):
|
|||
)
|
||||
return schema, message_history, call_cost
|
||||
except Exception as err:
|
||||
last_error = err
|
||||
error_message = f"Validation error: {err}\nPlease try again."
|
||||
message_history.append({"role": "user", "content": error_message})
|
||||
|
||||
raise Exception(
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts."
|
||||
f"Failed to instantiate directive after {MAX_DIRECTIVE_INSTANTIATION_ATTEMPTS} attempts. Last error: {last_error}"
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ class MapOpConfig(BaseModel):
|
|||
...,
|
||||
description="The keys of the output of the Map operator, to be referenced in the downstream operator's prompt. Can be a single key or a list of keys. Can be new keys or existing keys from the map operator we are rewriting.",
|
||||
)
|
||||
|
||||
|
||||
@classmethod
|
||||
def validate_prompt_contains_input_key(cls, value: str) -> str:
|
||||
"""
|
||||
|
|
@ -445,8 +445,12 @@ class DeterministicDocCompressionInstantiateSchema(BaseModel):
|
|||
raise ValueError(
|
||||
"Code must define a function named 'transform' that takes input_doc as parameter"
|
||||
)
|
||||
if "return {" not in v and "return dict(" not in v:
|
||||
raise ValueError("Code must return a dictionary")
|
||||
if (
|
||||
"return {" not in v
|
||||
and "return dict(" not in v
|
||||
and "output = dict(" not in v
|
||||
):
|
||||
raise ValueError(f"Code must return a dictionary. Instead the code is: {v}")
|
||||
return v
|
||||
|
||||
def validate_code_returns_target_keys(self, target_ops_configs: List[Dict]) -> None:
|
||||
|
|
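The relaxed validator above now also accepts code that assembles its result with output = dict(...) instead of a literal return {...}. A sketch of a transform string it would accept; the document key is hypothetical:

code = '''
def transform(input_doc):
    output = dict(input_doc)
    # keep only the first 2000 characters of the long field
    output["content"] = input_doc["content"][:2000]
    return output
'''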
@ -563,7 +567,6 @@ class ChunkHeaderSummaryInstantiateSchema(BaseModel):
|
|||
return v
|
||||
|
||||
|
||||
|
||||
class SamplingConfig(BaseModel):
|
||||
"""Configuration for optional sampling in document chunking."""
|
||||
|
||||
|
|
@ -634,7 +637,7 @@ class DocumentChunkingInstantiateSchema(BaseModel):
|
|||
default=None,
|
||||
description="Optional sampling configuration. If provided, inserts a Sample operation between Gather and Map. Use by default UNLESS task requires processing ALL chunks (like comprehensive extraction of all instances).",
|
||||
)
|
||||
|
||||
|
||||
def validate_split_key_exists_in_input(self, input_file_path: str) -> None:
|
||||
"""
|
||||
Validates that the split_key exists in the input JSON file items.
|
||||
|
|
@ -812,7 +815,6 @@ class DocumentChunkingTopKInstantiateSchema(BaseModel):
|
|||
description="Configuration for the topk operation to select relevant chunks",
|
||||
)
|
||||
|
||||
|
||||
def validate_split_key_exists_in_input(self, input_file_path: str) -> None:
|
||||
"""
|
||||
Validates that the split_key exists in the input JSON file items.
|
||||
|
|
@ -1211,6 +1213,109 @@ class SearchReplaceEdit(BaseModel):
|
|||
# )
|
||||
|
||||
|
||||
class MapToMapResolveReduceInstantiateSchema(BaseModel):
|
||||
"""
|
||||
Schema for inserting a Resolve operation between Map and Reduce.
|
||||
Transforms Map -> Reduce => Map -> Resolve -> Reduce pattern.
|
||||
The Resolve operation deduplicates/normalizes entities before aggregation.
|
||||
"""
|
||||
|
||||
resolve_name: str = Field(..., description="The name of the new Resolve operator")
|
||||
comparison_prompt: str = Field(
|
||||
...,
|
||||
description="Jinja prompt template for comparing two items. Must use {{ input1.key }} and {{ input2.key }} to reference fields from both items being compared.",
|
||||
)
|
||||
resolution_prompt: str = Field(
|
||||
...,
|
||||
description="Jinja prompt template for resolving matched items into a canonical form. Must use {% for input in inputs %} to iterate over matched items.",
|
||||
)
|
||||
blocking_conditions: List[str] = Field(
|
||||
...,
|
||||
description="List of Python expressions that determine if two items should be compared. Each expression has access to 'input1' and 'input2' dicts. Example: \"input1['category'].lower() == input2['category'].lower()\"",
|
||||
)
|
||||
blocking_keys: List[str] = Field(
|
||||
...,
|
||||
description="Keys to use for blocking/grouping items before comparison. Must include at least the reduce_key from the downstream Reduce operation, plus any additional context helpful for resolution.",
|
||||
)
|
||||
limit_comparisons: int = Field(
|
||||
default=15000,
|
||||
description="Maximum number of pairs to compare. Code-based blocked pairs are prioritized. Defaults to 15000 to avoid O(n^2) comparisons.",
|
||||
gt=0,
|
||||
)
|
||||
|
||||
@field_validator("comparison_prompt")
|
||||
@classmethod
|
||||
def check_comparison_prompt(cls, v: str) -> str:
|
||||
if "input1" not in v or "input2" not in v:
|
||||
raise ValueError(
|
||||
"comparison_prompt must reference both 'input1' and 'input2' variables"
|
||||
)
|
||||
return v
|
||||
|
||||
@field_validator("resolution_prompt")
|
||||
@classmethod
|
||||
def check_resolution_prompt(cls, v: str) -> str:
|
||||
if "inputs" not in v:
|
||||
raise ValueError(
|
||||
"resolution_prompt must reference 'inputs' variable for iterating over matched items"
|
||||
)
|
||||
return v
|
||||
|
||||
@field_validator("blocking_conditions")
|
||||
@classmethod
|
||||
def check_blocking_conditions(cls, v: List[str]) -> List[str]:
|
||||
if not v:
|
||||
raise ValueError(
|
||||
"At least one blocking condition must be provided to avoid O(n^2) comparisons"
|
||||
)
|
||||
for condition in v:
|
||||
if "input1" not in condition or "input2" not in condition:
|
||||
raise ValueError(
|
||||
f"Blocking condition must reference both 'input1' and 'input2': {condition}"
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
class MapResolveToMapWithCategoriesInstantiateSchema(BaseModel):
|
||||
"""
|
||||
Schema for replacing Map -> Resolve with a single Map that has predefined categories.
|
||||
The agent proposes categories based on analysis of the data/task, and the new Map
|
||||
forces outputs into one of these categories (or 'none of the above'), effectively
|
||||
doing entity resolution deterministically.
|
||||
"""
|
||||
|
||||
categories: List[str] = Field(
|
||||
...,
|
||||
description="List of valid category values that the map output should be constrained to. Should include all distinct canonical values discovered from analyzing the data/task.",
|
||||
)
|
||||
category_key: str = Field(
|
||||
...,
|
||||
description="The key in the output schema that will contain the category value",
|
||||
)
|
||||
new_prompt: str = Field(
|
||||
...,
|
||||
description="The new prompt for the Map operation that includes the category list and instructs the LLM to output one of the predefined categories. Must reference {{ input.key }} for input fields.",
|
||||
)
|
||||
include_none_of_above: bool = Field(
|
||||
default=True,
|
||||
description="Whether to include 'None of the above' as a valid category option for items that don't fit any category",
|
||||
)
|
||||
|
||||
@field_validator("categories")
|
||||
@classmethod
|
||||
def check_categories(cls, v: List[str]) -> List[str]:
|
||||
if len(v) < 2:
|
||||
raise ValueError("At least 2 categories must be provided")
|
||||
if len(v) != len(set(v)):
|
||||
raise ValueError("Categories must be unique")
|
||||
return v
|
||||
|
||||
@field_validator("new_prompt")
|
||||
@classmethod
|
||||
def check_new_prompt(cls, v: str) -> str:
|
||||
return MapOpConfig.validate_prompt_contains_input_key(v)
|
||||
|
||||
|
||||
class ArbitraryRewriteInstantiateSchema(BaseModel):
|
||||
"""
|
||||
Schema for arbitrary pipeline rewrites using search/replace edits.

@@ -699,7 +699,7 @@ class DSLRunner(ConfigWrapper):
            "litellm_kwargs", {}
        )
        kwargs["rewrite_agent_model"] = self.config.get("optimizer_config", {}).get(
            "rewrite_agent_model", "gpt-4o"
            "rewrite_agent_model", "gpt-5.1"
        )
        kwargs["judge_agent_model"] = self.config.get("optimizer_config", {}).get(
            "judge_agent_model", "gpt-4o-mini"

@@ -727,7 +727,7 @@ class DSLRunner(ConfigWrapper):
            "litellm_kwargs", {}
        )
        kwargs["rewrite_agent_model"] = self.config.get("optimizer_config", {}).get(
            "rewrite_agent_model", "gpt-4o"
            "rewrite_agent_model", "gpt-5.1"
        )
        kwargs["judge_agent_model"] = self.config.get("optimizer_config", {}).get(
            "judge_agent_model", "gpt-4o-mini"

@@ -34,9 +34,9 @@ This Equijoin operation matches job candidates to job postings:

The prompt template uses Jinja2 syntax, allowing you to reference input fields directly (e.g., `left.skills`). You can reference the left and right documents using `left` and `right` respectively.

!!! warning "Performance Consideration"
!!! info "Automatic Blocking"

    For large datasets, running comparisons with an LLM can be time-consuming. It's recommended to optimize your pipeline using `docetl build pipeline.yaml` to generate efficient blocking rules for the operation.
    If you don't specify any blocking configuration (`blocking_threshold`, `blocking_conditions`, or `limit_comparisons`), the Equijoin operation will automatically compute an optimal embedding-based blocking threshold at runtime. It samples pairs from your data, runs LLM comparisons on the sample, and finds a threshold that achieves 95% recall by default. You can adjust this with the `blocking_target_recall` parameter.
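For instance, an Equijoin operation that leans entirely on automatic blocking might look like the sketch below; the operation name, input keys, and recall value are illustrative, not defaults shipped with DocETL:

```yaml
- name: match_candidates_to_jobs   # hypothetical operation name
  type: equijoin
  comparison_prompt: |
    Does this candidate fit the job posting?
    Candidate skills: {{ left.skills }}
    Job requirements: {{ right.requirements }}
  blocking_target_recall: 0.9   # optional: accept slightly lower recall than the 0.95 default to cut comparisons
```

Because no `blocking_threshold`, `blocking_conditions`, or `limit_comparisons` are given, the threshold is computed automatically at runtime as described above.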

## Blocking

@@ -95,10 +95,19 @@ A full Equijoin step combining both ideas might look like:

Equijoin shares many parameters with the Resolve operation. For a detailed list of required and optional parameters, please see the [Parameters section in the Resolve operation documentation](resolve.md#required-parameters).

Key differences for Equijoin include:
### Equijoin-Specific Parameters

| Parameter | Description | Default |
| ------------------------- | --------------------------------------------------------------------------------- | ----------------------------- |
| `limits` | Maximum matches for each left/right item: `{"left": n, "right": m}` | No limit |
| `blocking_keys` | Keys for embedding blocking: `{"left": [...], "right": [...]}` | All keys from each dataset |
| `blocking_threshold` | Embedding similarity threshold for considering pairs | Auto-computed if not set |
| `blocking_target_recall` | Target recall when auto-computing blocking threshold (0.0 to 1.0) | 0.95 |
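As a sketch of how these Equijoin-specific parameters combine (all names and values below are placeholders, not recommended settings):

```yaml
- name: match_candidates_to_jobs   # hypothetical operation name
  type: equijoin
  comparison_prompt: "Is {{ left.name }} a good fit for {{ right.title }}?"
  limits:
    left: 3      # each left item keeps at most 3 matches
    right: 10    # each right item keeps at most 10 matches
  blocking_keys:
    left: [skills]
    right: [requirements]
  blocking_target_recall: 0.95
```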

Key differences from Resolve:

- `resolution_prompt` is not used in Equijoin.
- `limits` parameter is specific to Equijoin, allowing you to set maximum matches for each left and right item.
- `blocking_keys` uses a dict with `left` and `right` keys instead of a simple list.

## Incorporating Into a Pipeline

@@ -44,9 +44,9 @@ This Resolve operation processes patient names to identify and standardize dupli

Note: The prompt templates use Jinja2 syntax, allowing you to reference input fields directly (e.g., `input1.patient_name`).

!!! warning "Performance Consideration"
!!! info "Automatic Blocking"

    You should not run this operation as-is unless your dataset is small! Running O(n^2) comparisons with an LLM can be extremely time-consuming for large datasets. Instead, optimize your pipeline first using `docetl build pipeline.yaml` and run the optimized version, which will generate efficient blocking rules for the operation. Make sure you've set `optimize: true` in your resolve operation config.
    If you don't specify any blocking configuration (`blocking_threshold`, `blocking_conditions`, or `limit_comparisons`), the Resolve operation will automatically compute an optimal embedding-based blocking threshold at runtime. It samples pairs from your data, runs LLM comparisons on the sample, and finds a threshold that achieves 95% recall by default. You can adjust this with the `blocking_target_recall` parameter.
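If you would rather control blocking yourself instead of relying on the automatic threshold, you can set the blocking parameters explicitly. A minimal sketch, with illustrative names and values:

```yaml
- name: dedupe_patients   # hypothetical operation name
  type: resolve
  comparison_prompt: "Are {{ input1.patient_name }} and {{ input2.patient_name }} the same person?"
  resolution_prompt: |
    Pick one canonical spelling for these names:
    {% for input in inputs %}- {{ input.patient_name }}
    {% endfor %}
  blocking_keys:
    - patient_name
  blocking_threshold: 0.82   # fixing this skips the automatic threshold computation
  limit_comparisons: 5000
```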

## Blocking

@@ -132,7 +132,8 @@ After determining eligible pairs for comparison, the Resolve operation uses a Un
| `resolution_model` | The language model to use for reducing matched entries | Falls back to `default_model` |
| `comparison_model` | The language model to use for comparing potential matches | Falls back to `default_model` |
| `blocking_keys` | List of keys to use for initial blocking | All keys in the input data |
| `blocking_threshold` | Embedding similarity threshold for considering entries as potential matches | None |
| `blocking_threshold` | Embedding similarity threshold for considering entries as potential matches | Auto-computed if not set |
| `blocking_target_recall` | Target recall when auto-computing blocking threshold (0.0 to 1.0) | 0.95 |
| `blocking_conditions` | List of conditions for initial blocking | [] |
| `input` | Specifies the schema or keys to subselect from each item to pass into the prompts | All keys from input items |
| `embedding_batch_size` | The number of entries to send to the embedding model at a time | 1000 |

@@ -140,9 +141,9 @@ After determining eligible pairs for comparison, the Resolve operation uses a Un
| `limit_comparisons` | Maximum number of comparisons to perform | None |
| `timeout` | Timeout for each LLM call in seconds | 120 |
| `max_retries_per_timeout` | Maximum number of retries per timeout | 2 |
| `sample` | Number of samples to use for the operation | None |
| `litellm_completion_kwargs` | Additional parameters to pass to LiteLLM completion calls. | {} |
| `bypass_cache` | If true, bypass the cache for this operation. | False |
| `sample` | Number of samples to use for the operation | None |
| `litellm_completion_kwargs` | Additional parameters to pass to LiteLLM completion calls. | {} |
| `bypass_cache` | If true, bypass the cache for this operation. | False |

## Best Practices

@@ -14,7 +14,7 @@ All fields in `optimizer_config` are required (no defaults):
| `evaluation_file` | `str` | Path to Python file containing `@register_eval` decorated function |
| `metric_key` | `str` | Key in evaluation results dictionary to use as accuracy metric |
| `max_iterations` | `int` | Maximum number of MOARSearch iterations to run |
| `model` | `str` | LLM model to use for directive instantiation during search |
| `rewrite_agent_model` | `str` | LLM model to use for directive instantiation during search |

!!! warning "All Fields Required"
    MOAR will error if any required field is missing. There are no defaults.

@@ -67,19 +67,19 @@ The optimizer will use the sample dataset, but your final pipeline uses the full
### Available Models

!!! info "LiteLLM Model Names"
    Use LiteLLM model names (e.g., `gpt-4o-mini`, `gpt-4o`, `gpt-4.1`). Make sure your API keys are set in your environment.
    Use LiteLLM model names (e.g., `gpt-4o-mini`, `gpt-4o`, `gpt-5.1`). Make sure your API keys are set in your environment.

```yaml
available_models: # LiteLLM model names - ensure API keys are set
  - gpt-4.1-nano # Cheapest, lower accuracy
  - gpt-4.1-mini # Low cost, decent accuracy
  - gpt-4.1 # Balanced
  - gpt-5.1-nano # Cheapest, lower accuracy
  - gpt-5.1-mini # Low cost, decent accuracy
  - gpt-5.1 # Balanced
  - gpt-4o # Higher cost, better accuracy
```

### Model for Directive Instantiation

The `model` field specifies which LLM to use for generating optimization directives during the search process. This doesn't affect the models tested in `available_models`.
The `rewrite_agent_model` field specifies which LLM to use for generating optimization directives during the search process. This doesn't affect the models tested in `available_models`.

!!! tip "Cost Consideration"
    Use a cheaper model (like `gpt-4o-mini`) for directive instantiation to reduce search costs.
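For instance, you could instantiate directives with a cheaper model while still letting the search evaluate stronger models; only the relevant fields are shown here, and the values are illustrative:

```yaml
optimizer_config:
  rewrite_agent_model: gpt-4o-mini   # cheap model for instantiating directives
  available_models:                  # models the search is allowed to place in the pipeline
    - gpt-4o-mini
    - gpt-5.1
```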

@@ -104,12 +104,12 @@ optimizer_config:
  available_models:
    - gpt-4o-mini
    - gpt-4o
    - gpt-4.1-mini
    - gpt-4.1
    - gpt-5.1-mini
    - gpt-5.1
  evaluation_file: evaluate_medications.py
  metric_key: medication_extraction_score
  max_iterations: 40
  model: gpt-4.1
  rewrite_agent_model: gpt-5.1
  dataset_path: data/sample.json # Optional
  exploration_weight: 1.414 # Optional
```

@@ -25,15 +25,15 @@ optimizer_config:
  dataset_path: workloads/medical/raw_sample.json # Use sample for faster optimization
  save_dir: workloads/medical/moar_results
  available_models: # LiteLLM model names - ensure API keys are set in your environment
    - gpt-4.1-nano
    - gpt-4.1-mini
    - gpt-4.1
    - gpt-5.1-nano
    - gpt-5.1-mini
    - gpt-5.1
    - gpt-4o
    - gpt-4o-mini
  evaluation_file: workloads/medical/evaluate_medications.py
  metric_key: medication_extraction_score
  max_iterations: 40
  model: gpt-4.1
  rewrite_agent_model: gpt-5.1

system_prompt:
  dataset_description: a collection of transcripts of doctor visits

@@ -109,12 +109,12 @@ optimizer_config:
  available_models: # LiteLLM model names - ensure API keys are set in your environment
    - gpt-4o-mini
    - gpt-4o
    - gpt-4.1-mini
    - gpt-4.1
    - gpt-5.1-mini
    - gpt-5.1
  evaluation_file: evaluate_medications.py
  metric_key: medication_extraction_score # This must match a key in your evaluation function's return dictionary
  max_iterations: 40
  model: gpt-4.1
  rewrite_agent_model: gpt-5.1
  dataset_path: data/transcripts_sample.json # Optional: use sample/hold-out dataset
```

@@ -19,7 +19,7 @@ High-level summary of the optimization run:
{
  "optimizer": "moar",
  "input_pipeline": "pipeline.yaml",
  "model": "gpt-4.1",
  "rewrite_agent_model": "gpt-5.1",
  "max_iterations": 40,
  "save_dir": "results/moar_optimization",
  "dataset": "transcripts",

@@ -80,9 +80,9 @@ Test your evaluation function independently and check MOAR logs for errors.

```yaml
available_models:
  - gpt-4.1-nano # Cheapest, lower accuracy
  - gpt-4.1-mini # Low cost, decent accuracy
  - gpt-4.1 # Balanced
  - gpt-5.1-nano # Cheapest, lower accuracy
  - gpt-5.1-mini # Low cost, decent accuracy
  - gpt-5.1 # Balanced
  - gpt-4o # Higher cost, better accuracy
```

@@ -267,10 +267,6 @@ class AgentCommunicator:
        response = completion(
            model=self.model,
            messages=messages,
            azure=True,
            api_key=os.environ.get("AZURE_API_KEY"),
            api_base=os.environ.get("AZURE_API_BASE"),
            api_version=os.environ.get("AZURE_API_VERSION"),
            response_format={"type": "json_object"}
        )

@@ -290,10 +286,6 @@ class AgentCommunicator:
        response = completion(
            model=self.model,
            messages=messages + [{"role": "user", "content": request_msg}],
            azure=True,
            api_key=os.environ.get("AZURE_API_KEY"),
            api_base=os.environ.get("AZURE_API_BASE"),
            api_version=os.environ.get("AZURE_API_VERSION"),
            response_format={"type": "json_object"}
        )

@@ -27,6 +27,7 @@ class OptimizeResult(BaseModel):
    should_optimize: str | None = None
    input_data: list[dict[str, Any]] | None = None
    output_data: list[dict[str, Any]] | None = None
    num_docs_analyzed: int | None = None
    cost: float | None = None
    error: str | None = None
    created_at: datetime

@@ -35,4 +36,25 @@ class OptimizeResult(BaseModel):
class OptimizeRequest(BaseModel):
    yaml_config: str
    step_name: str
    op_name: str
    op_name: str


class DecomposeRequest(BaseModel):
    yaml_config: str
    step_name: str
    op_name: str


class DecomposeResult(BaseModel):
    task_id: str
    status: TaskStatus
    decomposed_operations: list[dict[str, Any]] | None = None
    winning_directive: str | None = None
    candidates_evaluated: int | None = None
    original_outputs: list[dict[str, Any]] | None = None
    decomposed_outputs: list[dict[str, Any]] | None = None
    comparison_rationale: str | None = None
    cost: float | None = None
    error: str | None = None
    created_at: datetime
    completed_at: datetime | None = None
|
||||
|
|
@ -2,13 +2,23 @@ from typing import Any
|
|||
import uuid
|
||||
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
|
||||
from docetl.runner import DSLRunner
|
||||
from docetl.optimizers.fast_should_optimize import FastShouldOptimizeAnalyzer
|
||||
from docetl.optimizers.fast_decomposer import FastDecomposer
|
||||
import asyncio
|
||||
from asyncio import Task
|
||||
from rich.logging import RichHandler
|
||||
import logging
|
||||
import yaml
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
from server.app.models import OptimizeResult, TaskStatus, OptimizeRequest, PipelineRequest
|
||||
from server.app.models import (
|
||||
OptimizeResult,
|
||||
TaskStatus,
|
||||
OptimizeRequest,
|
||||
PipelineRequest,
|
||||
DecomposeRequest,
|
||||
DecomposeResult,
|
||||
)
|
||||
|
||||
# Setup logging
|
||||
FORMAT = "%(message)s"
|
||||
|
|
@ -18,10 +28,14 @@ logging.basicConfig(
|
|||
|
||||
router = APIRouter()
|
||||
|
||||
# Task storage
|
||||
# Task storage for optimize tasks
|
||||
tasks: dict[str, OptimizeResult] = {}
|
||||
asyncio_tasks: dict[str, Task] = {}
|
||||
|
||||
# Task storage for decompose tasks
|
||||
decompose_tasks: dict[str, DecomposeResult] = {}
|
||||
decompose_asyncio_tasks: dict[str, Task] = {}
|
||||
|
||||
# Configuration
|
||||
COMPLETED_TASK_TTL = timedelta(hours=1)
|
||||
|
||||
|
|
@ -30,50 +44,111 @@ async def cleanup_old_tasks():
|
|||
while True:
|
||||
try:
|
||||
current_time = datetime.now()
|
||||
task_ids_to_remove = []
|
||||
|
||||
# Clean up optimize tasks
|
||||
task_ids_to_remove = []
|
||||
for task_id, task in tasks.items():
|
||||
if (task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED] and
|
||||
task.completed_at and
|
||||
task.completed_at and
|
||||
current_time - task.completed_at > COMPLETED_TASK_TTL):
|
||||
task_ids_to_remove.append(task_id)
|
||||
|
||||
for task_id in task_ids_to_remove:
|
||||
del tasks[task_id]
|
||||
|
||||
|
||||
# Clean up decompose tasks
|
||||
decompose_ids_to_remove = []
|
||||
for task_id, task in decompose_tasks.items():
|
||||
if (task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED] and
|
||||
task.completed_at and
|
||||
current_time - task.completed_at > COMPLETED_TASK_TTL):
|
||||
decompose_ids_to_remove.append(task_id)
|
||||
for task_id in decompose_ids_to_remove:
|
||||
del decompose_tasks[task_id]
|
||||
|
||||
await asyncio.sleep(60)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error in cleanup task: {e}")
|
||||
await asyncio.sleep(60)
|
||||
|
||||
async def run_optimization(task_id: str, yaml_config: str, step_name: str, op_name: str):
|
||||
"""Execute the optimization task"""
|
||||
"""Execute the optimization task using fast single-LLM-call analysis."""
|
||||
try:
|
||||
tasks[task_id].status = TaskStatus.PROCESSING
|
||||
|
||||
# Run the actual optimization in a separate thread to not block
|
||||
runner = DSLRunner.from_yaml(yaml_config)
|
||||
should_optimize, input_data, output_data, cost = await asyncio.to_thread(
|
||||
runner.should_optimize,
|
||||
|
||||
# yaml_config is a file path, not YAML content - read and parse the file
|
||||
if yaml_config.endswith(".yaml") or yaml_config.endswith(".yml"):
|
||||
with open(yaml_config, "r") as f:
|
||||
config = yaml.safe_load(f)
|
||||
else:
|
||||
# Fallback: try parsing as YAML string
|
||||
config = yaml.safe_load(yaml_config)
|
||||
|
||||
# Validate that we got a dict
|
||||
if not isinstance(config, dict):
|
||||
raise ValueError(
|
||||
f"Invalid yaml_config: expected dict after parsing, got {type(config).__name__}. "
|
||||
f"Input: {str(yaml_config)[:200]}"
|
||||
)
|
||||
|
||||
# Get intermediate directory from config
|
||||
intermediate_dir = (
|
||||
config.get("pipeline", {})
|
||||
.get("output", {})
|
||||
.get("intermediate_dir")
|
||||
)
|
||||
|
||||
if not intermediate_dir:
|
||||
raise ValueError("No intermediate_dir configured in pipeline output")
|
||||
|
||||
# Find the operation config
|
||||
op_config = None
|
||||
for op in config.get("operations", []):
|
||||
if op.get("name") == op_name:
|
||||
op_config = op
|
||||
break
|
||||
|
||||
if not op_config:
|
||||
raise ValueError(f"Operation '{op_name}' not found in config")
|
||||
|
||||
# Get optimizer settings from config
|
||||
optimizer_model = (
|
||||
config.get("optimizer_config", {})
|
||||
.get("rewrite_agent_model", "gpt-5.1")
|
||||
)
|
||||
litellm_kwargs = (
|
||||
config.get("optimizer_config", {})
|
||||
.get("litellm_kwargs", {})
|
||||
)
|
||||
|
||||
# Create analyzer and run in thread pool
|
||||
analyzer = FastShouldOptimizeAnalyzer(
|
||||
intermediate_dir=intermediate_dir,
|
||||
optimizer_model=optimizer_model,
|
||||
litellm_kwargs=litellm_kwargs,
|
||||
)
|
||||
|
||||
should_optimize, output_data, num_docs_analyzed, cost = await asyncio.to_thread(
|
||||
analyzer.analyze,
|
||||
op_config,
|
||||
step_name,
|
||||
op_name,
|
||||
)
|
||||
|
||||
|
||||
# Update task result
|
||||
tasks[task_id].status = TaskStatus.COMPLETED
|
||||
tasks[task_id].should_optimize = should_optimize
|
||||
tasks[task_id].input_data = input_data
|
||||
tasks[task_id].input_data = None # We don't have input_data in fast mode
|
||||
tasks[task_id].output_data = output_data
|
||||
tasks[task_id].num_docs_analyzed = num_docs_analyzed
|
||||
tasks[task_id].cost = cost
|
||||
tasks[task_id].completed_at = datetime.now()
|
||||
|
||||
|
||||
except asyncio.CancelledError:
|
||||
runner.is_cancelled = True
|
||||
tasks[task_id].status = TaskStatus.CANCELLED
|
||||
tasks[task_id].completed_at = datetime.now()
|
||||
raise
|
||||
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
error_traceback = traceback.format_exc()
|
||||
|
|
@ -81,11 +156,10 @@ async def run_optimization(task_id: str, yaml_config: str, step_name: str, op_na
|
|||
tasks[task_id].error = f"{str(e)}\n{error_traceback}"
|
||||
tasks[task_id].completed_at = datetime.now()
|
||||
raise
|
||||
|
||||
|
||||
finally:
|
||||
if task_id in asyncio_tasks:
|
||||
del asyncio_tasks[task_id]
|
||||
runner.reset_env()
|
||||
|
||||
@router.on_event("startup")
|
||||
async def startup_event():
|
||||
|
|
@ -153,6 +227,140 @@ async def cancel_optimize_task(task_id: str):
|
|||
|
||||
return {"message": "Task cancelled successfully"}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Decompose endpoints
|
||||
# ============================================================================
|
||||
|
||||
async def run_decomposition(task_id: str, yaml_config: str, step_name: str, op_name: str):
|
||||
"""Execute the decomposition task using fast directive-based approach."""
|
||||
try:
|
||||
decompose_tasks[task_id].status = TaskStatus.PROCESSING
|
||||
|
||||
# yaml_config is a file path
|
||||
if not (yaml_config.endswith(".yaml") or yaml_config.endswith(".yml")):
|
||||
raise ValueError("yaml_config must be a path to a YAML file")
|
||||
|
||||
# Get optimizer settings from config
|
||||
with open(yaml_config, "r") as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
optimizer_model = (
|
||||
config.get("optimizer_config", {})
|
||||
.get("rewrite_agent_model", "gpt-5.1")
|
||||
)
|
||||
litellm_kwargs = (
|
||||
config.get("optimizer_config", {})
|
||||
.get("litellm_kwargs", {})
|
||||
)
|
||||
|
||||
# Create decomposer and run in thread pool
|
||||
decomposer = FastDecomposer(
|
||||
yaml_config_path=yaml_config,
|
||||
optimizer_model=optimizer_model,
|
||||
sample_size=5,
|
||||
litellm_kwargs=litellm_kwargs,
|
||||
)
|
||||
|
||||
result = await asyncio.to_thread(
|
||||
decomposer.decompose,
|
||||
step_name,
|
||||
op_name,
|
||||
)
|
||||
|
||||
# Update task result
|
||||
decompose_tasks[task_id].status = TaskStatus.COMPLETED
|
||||
decompose_tasks[task_id].decomposed_operations = result["decomposed_ops"]
|
||||
decompose_tasks[task_id].winning_directive = result["winning_directive"]
|
||||
decompose_tasks[task_id].candidates_evaluated = result["candidates_evaluated"]
|
||||
decompose_tasks[task_id].original_outputs = result["original_outputs"]
|
||||
decompose_tasks[task_id].decomposed_outputs = result["decomposed_outputs"]
|
||||
decompose_tasks[task_id].comparison_rationale = result["comparison_rationale"]
|
||||
decompose_tasks[task_id].cost = result["cost"]
|
||||
decompose_tasks[task_id].completed_at = datetime.now()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
decompose_tasks[task_id].status = TaskStatus.CANCELLED
|
||||
decompose_tasks[task_id].completed_at = datetime.now()
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
error_traceback = traceback.format_exc()
|
||||
decompose_tasks[task_id].status = TaskStatus.FAILED
|
||||
decompose_tasks[task_id].error = f"{str(e)}\n{error_traceback}"
|
||||
decompose_tasks[task_id].completed_at = datetime.now()
|
||||
raise
|
||||
|
||||
finally:
|
||||
if task_id in decompose_asyncio_tasks:
|
||||
del decompose_asyncio_tasks[task_id]
|
||||
|
||||
|
||||
@router.post("/decompose", status_code=202)
|
||||
async def submit_decompose_task(request: DecomposeRequest):
|
||||
"""Submit a new decomposition task"""
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
# Create task record
|
||||
decompose_tasks[task_id] = DecomposeResult(
|
||||
task_id=task_id,
|
||||
status=TaskStatus.PENDING,
|
||||
created_at=datetime.now()
|
||||
)
|
||||
|
||||
# Create and store the asyncio task
|
||||
task = asyncio.create_task(
|
||||
run_decomposition(
|
||||
task_id,
|
||||
request.yaml_config,
|
||||
request.step_name,
|
||||
request.op_name
|
||||
)
|
||||
)
|
||||
decompose_asyncio_tasks[task_id] = task
|
||||
|
||||
return {"task_id": task_id}
|
||||
|
||||
|
||||
@router.get("/decompose/{task_id}")
|
||||
async def get_decompose_status(task_id: str) -> DecomposeResult:
|
||||
"""Get the current status of a decomposition task"""
|
||||
if task_id not in decompose_tasks:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Task not found or has been cleaned up"
|
||||
)
|
||||
|
||||
return decompose_tasks[task_id]
|
||||
|
||||
|
||||
@router.post("/decompose/{task_id}/cancel")
|
||||
async def cancel_decompose_task(task_id: str):
|
||||
"""Cancel a running decomposition task"""
|
||||
if task_id not in decompose_tasks:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Task not found or has been cleaned up"
|
||||
)
|
||||
|
||||
if task_id not in decompose_asyncio_tasks:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Task already finished or cannot be cancelled"
|
||||
)
|
||||
|
||||
asyncio_task = decompose_asyncio_tasks[task_id]
|
||||
asyncio_task.cancel()
|
||||
|
||||
try:
|
||||
await asyncio_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
return {"message": "Decomposition task cancelled successfully"}
|
||||
|
||||
|
||||
# Keep the original run_pipeline endpoint
|
||||
@router.post("/run_pipeline")
|
||||
def run_pipeline(request: PipelineRequest) -> dict[str, Any]:
|
||||
|
|
@ -169,6 +377,12 @@ def run_pipeline(request: PipelineRequest) -> dict[str, Any]:
|
|||
|
||||
@router.websocket("/ws/run_pipeline/{client_id}")
|
||||
async def websocket_run_pipeline(websocket: WebSocket, client_id: str):
|
||||
"""
|
||||
WebSocket endpoint for running pipelines with real-time output streaming.
|
||||
|
||||
Note: The old 'optimize' flag for full pipeline optimization has been removed.
|
||||
Use the /decompose endpoint for fast operation decomposition instead.
|
||||
"""
|
||||
await websocket.accept()
|
||||
runner = None
|
||||
try:
|
||||
|
|
@ -178,19 +392,8 @@ async def websocket_run_pipeline(websocket: WebSocket, client_id: str):
|
|||
if config.get("clear_intermediate", False):
|
||||
runner.clear_intermediate()
|
||||
|
||||
if config.get("optimize", False):
|
||||
logging.info(f"Optimizing pipeline with model {config.get('optimizer_model', 'gpt-4o')}")
|
||||
|
||||
# Set the runner config to the optimizer config
|
||||
runner.config["optimizer_config"]["rewrite_agent_model"] = config.get("optimizer_model", "gpt-4o")
|
||||
runner.config["optimizer_config"]["judge_agent_model"] = config.get("optimizer_model", "gpt-4o-mini")
|
||||
|
||||
async def run_pipeline():
|
||||
return await asyncio.to_thread(runner.optimize, return_pipeline=False)
|
||||
|
||||
else:
|
||||
async def run_pipeline():
|
||||
return await asyncio.to_thread(runner.load_run_save)
|
||||
async def run_pipeline():
|
||||
return await asyncio.to_thread(runner.load_run_save)
|
||||
|
||||
pipeline_task = asyncio.create_task(run_pipeline())
|
||||
|
||||
|
|
@ -198,18 +401,6 @@ async def websocket_run_pipeline(websocket: WebSocket, client_id: str):
|
|||
console_output = runner.console.file.getvalue()
|
||||
await websocket.send_json({"type": "output", "data": console_output})
|
||||
|
||||
if config.get("optimize", False):
|
||||
optimizer_progress = runner.console.get_optimizer_progress()
|
||||
rationale = runner.console.optimizer_rationale
|
||||
await websocket.send_json({
|
||||
"type": "optimizer_progress",
|
||||
"status": optimizer_progress[0],
|
||||
"progress": optimizer_progress[1],
|
||||
"rationale": rationale[1] if rationale is not None else "",
|
||||
"should_optimize": rationale[0] if rationale is not None else False,
|
||||
"validator_prompt": rationale[2] if rationale is not None else ""
|
||||
})
|
||||
|
||||
# Check for incoming messages from the user
|
||||
try:
|
||||
user_message = await asyncio.wait_for(
|
||||
|
|
@ -219,8 +410,7 @@ async def websocket_run_pipeline(websocket: WebSocket, client_id: str):
|
|||
if user_message == "kill":
|
||||
runner.console.log("Stopping process...")
|
||||
runner.is_cancelled = True
|
||||
|
||||
|
||||
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"message": "Process stopped by user request"
|
||||
|
|
@ -250,41 +440,16 @@ async def websocket_run_pipeline(websocket: WebSocket, client_id: str):
|
|||
# Sleep for a short duration to ensure all output is captured
|
||||
await asyncio.sleep(3)
|
||||
|
||||
# If optimize is true, send back the optimized operations
|
||||
if config.get("optimize", False):
|
||||
optimized_config, cost = result
|
||||
|
||||
# Send the operations back in order
|
||||
new_pipeline_steps = optimized_config["pipeline"]["steps"]
|
||||
new_pipeline_op_name_to_op_map = {op["name"]: op for op in optimized_config["operations"]}
|
||||
new_ops_in_order = []
|
||||
for new_step in new_pipeline_steps:
|
||||
for op in new_step.get("operations", []):
|
||||
if op not in new_ops_in_order:
|
||||
new_ops_in_order.append(new_pipeline_op_name_to_op_map[op])
|
||||
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "result",
|
||||
"data": {
|
||||
"message": "Pipeline executed successfully",
|
||||
"cost": cost,
|
||||
"optimized_ops": new_ops_in_order,
|
||||
"yaml_config": config["yaml_config"],
|
||||
},
|
||||
}
|
||||
)
|
||||
else:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "result",
|
||||
"data": {
|
||||
"message": "Pipeline executed successfully",
|
||||
"cost": result,
|
||||
"yaml_config": config["yaml_config"],
|
||||
},
|
||||
}
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "result",
|
||||
"data": {
|
||||
"message": "Pipeline executed successfully",
|
||||
"cost": result,
|
||||
"yaml_config": config["yaml_config"],
|
||||
},
|
||||
}
|
||||
)
|
||||
except WebSocketDisconnect:
|
||||
if runner is not None:
|
||||
runner.reset_env()
|
||||
|
|
@ -299,3 +464,159 @@ async def websocket_run_pipeline(websocket: WebSocket, client_id: str):
|
|||
if runner is not None:
|
||||
runner.reset_env()
|
||||
await websocket.close()
|
||||
|
||||
|
||||
@router.websocket("/ws/decompose/{client_id}")
|
||||
async def websocket_decompose(websocket: WebSocket, client_id: str):
|
||||
"""
|
||||
WebSocket endpoint for fast operation decomposition with real-time output streaming.
|
||||
|
||||
Expects JSON message with:
|
||||
- yaml_config: Path to the pipeline YAML config file
|
||||
- step_name: Name of the pipeline step
|
||||
- op_name: Name of the operation to decompose
|
||||
"""
|
||||
await websocket.accept()
|
||||
decomposer = None
|
||||
|
||||
try:
|
||||
config = await websocket.receive_json()
|
||||
yaml_config = config["yaml_config"]
|
||||
step_name = config["step_name"]
|
||||
op_name = config["op_name"]
|
||||
|
||||
# Validate yaml_config is a path
|
||||
if not (yaml_config.endswith(".yaml") or yaml_config.endswith(".yml")):
|
||||
raise ValueError("yaml_config must be a path to a YAML file")
|
||||
|
||||
# Get optimizer settings from config
|
||||
with open(yaml_config, "r") as f:
|
||||
pipeline_config = yaml.safe_load(f)
|
||||
|
||||
optimizer_model = (
|
||||
pipeline_config.get("optimizer_config", {})
|
||||
.get("rewrite_agent_model", "gpt-5.1")
|
||||
)
|
||||
litellm_kwargs = (
|
||||
pipeline_config.get("optimizer_config", {})
|
||||
.get("litellm_kwargs", {})
|
||||
)
|
||||
|
||||
# Create a ThreadSafeConsole for streaming output
|
||||
from docetl.console import ThreadSafeConsole
|
||||
console = ThreadSafeConsole(
|
||||
force_terminal=True,
|
||||
soft_wrap=True,
|
||||
highlight=False,
|
||||
log_path=False,
|
||||
color_system="truecolor",
|
||||
width=120,
|
||||
style="bright_white on black",
|
||||
record=True,
|
||||
)
|
||||
|
||||
# Create decomposer with the console
|
||||
decomposer = FastDecomposer(
|
||||
yaml_config_path=yaml_config,
|
||||
optimizer_model=optimizer_model,
|
||||
sample_size=5,
|
||||
litellm_kwargs=litellm_kwargs,
|
||||
console=console,
|
||||
)
|
||||
|
||||
# Run decomposition in a thread
|
||||
async def run_decomposition():
|
||||
return await asyncio.to_thread(
|
||||
decomposer.decompose,
|
||||
step_name,
|
||||
op_name,
|
||||
)
|
||||
|
||||
decompose_task = asyncio.create_task(run_decomposition())
|
||||
|
||||
# Stream console output while decomposition runs
|
||||
accumulated_output = ""
|
||||
while not decompose_task.done():
|
||||
# get_output() processes carriage returns and clears the buffer
|
||||
new_output = console.get_output()
|
||||
if new_output:
|
||||
# Handle spinner updates: if new output doesn't start with newline
|
||||
# and accumulated doesn't end with newline, replace the last line
|
||||
if accumulated_output and not accumulated_output.endswith('\n') and not new_output.startswith('\n'):
|
||||
# Find the last newline in accumulated output
|
||||
last_newline = accumulated_output.rfind('\n')
|
||||
if last_newline >= 0:
|
||||
# Replace everything after the last newline
|
||||
accumulated_output = accumulated_output[:last_newline + 1] + new_output
|
||||
else:
|
||||
# No newline in accumulated, replace everything
|
||||
accumulated_output = new_output
|
||||
else:
|
||||
accumulated_output += new_output
|
||||
await websocket.send_json({"type": "output", "data": accumulated_output})
|
||||
|
||||
# Check for kill message
|
||||
try:
|
||||
user_message = await asyncio.wait_for(
|
||||
websocket.receive_json(), timeout=0.1
|
||||
)
|
||||
if user_message == "kill":
|
||||
console.log("[red]Stopping decomposition...[/red]")
|
||||
decompose_task.cancel()
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"message": "Decomposition stopped by user request"
|
||||
})
|
||||
raise Exception("Process stopped by user request")
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
except asyncio.CancelledError:
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"message": "Decomposition stopped by user request"
|
||||
})
|
||||
raise
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Get final result
|
||||
result = await decompose_task
|
||||
|
||||
# Send any remaining console output
|
||||
final_output = console.get_output()
|
||||
if final_output:
|
||||
accumulated_output += final_output
|
||||
await websocket.send_json({"type": "output", "data": accumulated_output})
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Send the result
|
||||
await websocket.send_json({
|
||||
"type": "result",
|
||||
"data": {
|
||||
"decomposed_operations": result["decomposed_ops"],
|
||||
"winning_directive": result["winning_directive"],
|
||||
"candidates_evaluated": result["candidates_evaluated"],
|
||||
"original_outputs": result["original_outputs"],
|
||||
"decomposed_outputs": result["decomposed_outputs"],
|
||||
"comparison_rationale": result["comparison_rationale"],
|
||||
"cost": result["cost"],
|
||||
},
|
||||
})
|
||||
|
||||
except WebSocketDisconnect:
|
||||
print(f"Decompose client {client_id} disconnected")
|
||||
except Exception as e:
|
||||
import traceback
|
||||
error_traceback = traceback.format_exc()
|
||||
print(f"Decompose error:\n{error_traceback}")
|
||||
try:
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"data": str(e),
|
||||
"traceback": error_traceback
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
await websocket.close()
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ Simple apply tests for directive testing - just ensure apply() doesn't crash.
|
|||
# Simple apply tests - no pytest needed
|
||||
|
||||
from docetl.reasoning_optimizer.directives import (
|
||||
ChainingDirective,
|
||||
ChainingDirective,
|
||||
GleaningDirective,
|
||||
ReduceGleaningDirective,
|
||||
ReduceChainingDirective,
|
||||
|
|
@ -14,7 +14,6 @@ from docetl.reasoning_optimizer.directives import (
|
|||
OperatorFusionDirective,
|
||||
DocSummarizationDirective,
|
||||
IsolatingSubtasksDirective,
|
||||
DocCompressionDirective,
|
||||
DeterministicDocCompressionDirective,
|
||||
DocumentChunkingDirective,
|
||||
DocumentChunkingTopKDirective,
|
||||
|
|
@ -22,8 +21,12 @@ from docetl.reasoning_optimizer.directives import (
|
|||
TakeHeadTailDirective,
|
||||
ClarifyInstructionsDirective,
|
||||
SwapWithCodeDirective,
|
||||
HierarchicalReduceDirective
|
||||
HierarchicalReduceDirective,
|
||||
MapToMapResolveReduceDirective,
|
||||
MapResolveToMapWithCategoriesDirective,
|
||||
)
|
||||
# DocCompressionDirective is commented out in __init__.py, import directly
|
||||
from docetl.reasoning_optimizer.directives.doc_compression import DocCompressionDirective
|
||||
|
||||
|
||||
def test_chaining_apply():
|
||||
|
|
@ -1012,6 +1015,237 @@ def test_cascade_filtering_apply():
|
|||
assert "high-quality research paper" in result[5]["prompt"]
|
||||
|
||||
|
||||
def test_map_to_map_resolve_reduce_apply():
|
||||
"""Test that map_to_map_resolve_reduce apply doesn't crash"""
|
||||
directive = MapToMapResolveReduceDirective()
|
||||
|
||||
# Pipeline with map followed by reduce
|
||||
ops_list = [
|
||||
{
|
||||
"name": "extract_companies",
|
||||
"type": "map",
|
||||
"prompt": "Extract company name from: {{ input.article }}",
|
||||
"model": "gpt-4o-mini",
|
||||
"output": {"schema": {"company_name": "string"}},
|
||||
},
|
||||
{
|
||||
"name": "aggregate_by_sector",
|
||||
"type": "reduce",
|
||||
"reduce_key": "sector",
|
||||
"prompt": "List companies:\n{% for input in inputs %}\n- {{ input.company_name }}\n{% endfor %}",
|
||||
"model": "gpt-4o-mini",
|
||||
"output": {"schema": {"companies": "list[str]"}},
|
||||
},
|
||||
]
|
||||
|
||||
from docetl.reasoning_optimizer.instantiate_schemas import (
|
||||
MapToMapResolveReduceInstantiateSchema,
|
||||
)
|
||||
|
||||
rewrite = MapToMapResolveReduceInstantiateSchema(
|
||||
resolve_name="normalize_company_names",
|
||||
comparison_prompt="""Are these two company names referring to the same company?
|
||||
Company 1: {{ input1.company_name }}
|
||||
Company 2: {{ input2.company_name }}
|
||||
Consider variations like abbreviations (IBM vs International Business Machines).""",
|
||||
resolution_prompt="""Given these variations of a company name:
|
||||
{% for input in inputs %}
|
||||
- {{ input.company_name }}
|
||||
{% endfor %}
|
||||
Return the canonical/official company name.""",
|
||||
blocking_conditions=[
|
||||
"input1['company_name'][:3].lower() == input2['company_name'][:3].lower()",
|
||||
"input1['company_name'].split()[0].lower() == input2['company_name'].split()[0].lower()",
|
||||
],
|
||||
blocking_keys=["sector", "company_name"], # Must include reduce_key (sector)
|
||||
limit_comparisons=1000,
|
||||
)
|
||||
|
||||
result = directive.apply(
|
||||
"gpt-4o-mini", ops_list, "extract_companies", "aggregate_by_sector", rewrite
|
||||
)
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 3 # map + resolve + reduce
|
||||
|
||||
# Check the resolve operation was inserted correctly
|
||||
assert result[0]["name"] == "extract_companies"
|
||||
assert result[0]["type"] == "map"
|
||||
|
||||
assert result[1]["name"] == "normalize_company_names"
|
||||
assert result[1]["type"] == "resolve"
|
||||
assert "comparison_prompt" in result[1]
|
||||
assert "resolution_prompt" in result[1]
|
||||
assert "blocking_conditions" in result[1]
|
||||
assert "blocking_keys" in result[1]
|
||||
assert result[1]["blocking_keys"] == ["sector", "company_name"]
|
||||
assert result[1]["limit_comparisons"] == 1000
|
||||
assert "input1" in result[1]["comparison_prompt"]
|
||||
assert "input2" in result[1]["comparison_prompt"]
|
||||
assert "inputs" in result[1]["resolution_prompt"]
|
||||
# Output schema should be derived from reduce_key
|
||||
assert result[1]["output"]["schema"] == {"sector": "string"}
|
||||
|
||||
assert result[2]["name"] == "aggregate_by_sector"
|
||||
assert result[2]["type"] == "reduce"
|
||||
|
||||
|
||||
def test_map_to_map_resolve_reduce_apply_with_multiple_reduce_keys():
|
||||
"""Test map_to_map_resolve_reduce with multiple reduce_keys"""
|
||||
directive = MapToMapResolveReduceDirective()
|
||||
|
||||
ops_list = [
|
||||
{
|
||||
"name": "extract_products",
|
||||
"type": "map",
|
||||
"prompt": "Extract product info: {{ input.description }}",
|
||||
"model": "gpt-4o-mini",
|
||||
"output": {"schema": {"product_name": "string", "category": "string"}},
|
||||
},
|
||||
{
|
||||
"name": "aggregate_products",
|
||||
"type": "reduce",
|
||||
"reduce_key": ["brand", "region"], # Multiple reduce keys
|
||||
"prompt": "List products:\n{% for input in inputs %}{{ input.product_name }}{% endfor %}",
|
||||
"model": "gpt-4o-mini",
|
||||
"output": {"schema": {"products": "list[str]"}},
|
||||
},
|
||||
]
|
||||
|
||||
from docetl.reasoning_optimizer.instantiate_schemas import (
|
||||
MapToMapResolveReduceInstantiateSchema,
|
||||
)
|
||||
|
||||
rewrite = MapToMapResolveReduceInstantiateSchema(
|
||||
resolve_name="normalize_products",
|
||||
comparison_prompt="Same product? {{ input1.product_name }} vs {{ input2.product_name }}",
|
||||
resolution_prompt="Normalize: {% for input in inputs %}{{ input.product_name }}{% endfor %}",
|
||||
blocking_conditions=[
|
||||
"input1['category'] == input2['category']",
|
||||
],
|
||||
blocking_keys=["brand", "region", "product_name"], # Must include reduce_keys
|
||||
limit_comparisons=500,
|
||||
)
|
||||
|
||||
result = directive.apply(
|
||||
"gpt-4o-mini", ops_list, "extract_products", "aggregate_products", rewrite
|
||||
)
|
||||
assert result[1]["type"] == "resolve"
|
||||
assert "blocking_keys" in result[1]
|
||||
assert result[1]["blocking_keys"] == ["brand", "region", "product_name"]
|
||||
# Output schema should include all reduce_keys
|
||||
assert result[1]["output"]["schema"] == {"brand": "string", "region": "string"}
|
||||
|
||||
|
||||
def test_map_resolve_to_map_with_categories_apply():
|
||||
"""Test that map_resolve_to_map_with_categories apply doesn't crash"""
|
||||
directive = MapResolveToMapWithCategoriesDirective()
|
||||
|
||||
# Pipeline with map followed by resolve
|
||||
ops_list = [
|
||||
{
|
||||
"name": "extract_sentiment",
|
||||
"type": "map",
|
||||
"prompt": "What is the sentiment? {{ input.review }}",
|
||||
"model": "gpt-4o-mini",
|
||||
"output": {"schema": {"sentiment": "string"}},
|
||||
},
|
||||
{
|
||||
"name": "normalize_sentiment",
|
||||
"type": "resolve",
|
||||
"comparison_prompt": "Same sentiment? {{ input1.sentiment }} vs {{ input2.sentiment }}",
|
||||
"resolution_prompt": "Normalize: {% for input in inputs %}{{ input.sentiment }}{% endfor %}",
|
||||
"model": "gpt-4o-mini",
|
||||
"output": {"schema": {"sentiment": "string"}},
|
||||
},
|
||||
]
|
||||
|
||||
from docetl.reasoning_optimizer.instantiate_schemas import (
|
||||
MapResolveToMapWithCategoriesInstantiateSchema,
|
||||
)
|
||||
|
||||
rewrite = MapResolveToMapWithCategoriesInstantiateSchema(
|
||||
categories=["Positive", "Negative", "Neutral", "Mixed"],
|
||||
category_key="sentiment",
|
||||
new_prompt="""Classify the sentiment of this review into one of these categories:
|
||||
- Positive: Clearly positive sentiment
|
||||
- Negative: Clearly negative sentiment
|
||||
- Neutral: No strong sentiment
|
||||
- Mixed: Both positive and negative
|
||||
- None of the above
|
||||
|
||||
Review: {{ input.review }}
|
||||
|
||||
Return exactly one category.""",
|
||||
include_none_of_above=True,
|
||||
)
|
||||
|
||||
result = directive.apply(
|
||||
"gpt-4o-mini", ops_list, "extract_sentiment", "normalize_sentiment", rewrite
|
||||
)
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 1 # map only, resolve removed
|
||||
|
||||
# Check the map operation was modified correctly
|
||||
assert result[0]["name"] == "extract_sentiment"
|
||||
assert result[0]["type"] == "map"
|
||||
assert "Positive" in result[0]["prompt"]
|
||||
assert "Negative" in result[0]["prompt"]
|
||||
assert "None of the above" in result[0]["prompt"]
|
||||
|
||||
# Check validation was added
|
||||
assert "validate" in result[0]
|
||||
assert len(result[0]["validate"]) == 1
|
||||
assert "Positive" in result[0]["validate"][0]
|
||||
assert "None of the above" in result[0]["validate"][0]
|
||||
|
||||
|
||||
def test_map_resolve_to_map_with_categories_no_none_of_above():
|
||||
"""Test map_resolve_to_map_with_categories without 'None of the above' option"""
|
||||
directive = MapResolveToMapWithCategoriesDirective()
|
||||
|
||||
ops_list = [
|
||||
{
|
||||
"name": "classify_type",
|
||||
"type": "map",
|
||||
"prompt": "What type is this? {{ input.text }}",
|
||||
"model": "gpt-4o-mini",
|
||||
"output": {"schema": {"item_type": "string"}},
|
||||
},
|
||||
{
|
||||
"name": "normalize_type",
|
||||
"type": "resolve",
|
||||
"comparison_prompt": "Same type? {{ input1.item_type }} vs {{ input2.item_type }}",
|
||||
"resolution_prompt": "Normalize: {% for input in inputs %}{{ input.item_type }}{% endfor %}",
|
||||
"model": "gpt-4o-mini",
|
||||
"output": {"schema": {"item_type": "string"}},
|
||||
},
|
||||
]
|
||||
|
||||
from docetl.reasoning_optimizer.instantiate_schemas import (
|
||||
MapResolveToMapWithCategoriesInstantiateSchema,
|
||||
)
|
||||
|
||||
rewrite = MapResolveToMapWithCategoriesInstantiateSchema(
|
||||
categories=["TypeA", "TypeB", "TypeC"],
|
||||
category_key="item_type",
|
||||
new_prompt="""Classify into: TypeA, TypeB, or TypeC
|
||||
Text: {{ input.text }}""",
|
||||
include_none_of_above=False,
|
||||
)
|
||||
|
||||
result = directive.apply(
|
||||
"gpt-4o-mini", ops_list, "classify_type", "normalize_type", rewrite
|
||||
)
|
||||
assert len(result) == 1
|
||||
|
||||
# Check validation does NOT include 'None of the above'
|
||||
assert "validate" in result[0]
|
||||
assert "None of the above" not in result[0]["validate"][0]
|
||||
assert "TypeA" in result[0]["validate"][0]
|
||||
assert "TypeB" in result[0]["validate"][0]
|
||||
assert "TypeC" in result[0]["validate"][0]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run all tests
|
||||
test_chaining_apply()
|
||||
|
|
@ -1078,4 +1312,16 @@ if __name__ == "__main__":
|
|||
test_cascade_filtering_apply()
|
||||
print("✅ Cascade filtering apply test passed")
|
||||
|
||||
test_map_to_map_resolve_reduce_apply()
|
||||
print("✅ Map to map resolve reduce apply test passed")
|
||||
|
||||
test_map_to_map_resolve_reduce_apply_with_multiple_reduce_keys()
|
||||
print("✅ Map to map resolve reduce with multiple reduce keys apply test passed")
|
||||
|
||||
test_map_resolve_to_map_with_categories_apply()
|
||||
print("✅ Map resolve to map with categories apply test passed")
|
||||
|
||||
test_map_resolve_to_map_with_categories_no_none_of_above()
|
||||
print("✅ Map resolve to map with categories (no none of above) apply test passed")
|
||||
|
||||
print("\n🎉 All directive apply tests passed!")
|
||||
|
|
|
|||
|
|
@@ -10,14 +10,13 @@ from typing import Dict, List
from datetime import datetime

from docetl.reasoning_optimizer.directives import (
    ChainingDirective,
    ChainingDirective,
    GleaningDirective,
    ReduceGleaningDirective,
    ReduceChainingDirective,
    ChangeModelDirective,
    DocSummarizationDirective,
    IsolatingSubtasksDirective,
    DocCompressionDirective,
    DeterministicDocCompressionDirective,
    OperatorFusionDirective,
    DocumentChunkingDirective,

@@ -28,8 +27,12 @@ from docetl.reasoning_optimizer.directives import (
    SwapWithCodeDirective,
    HierarchicalReduceDirective,
    CascadeFilteringDirective,
    TestResult
    MapToMapResolveReduceDirective,
    MapResolveToMapWithCategoriesDirective,
    TestResult,
)
# DocCompressionDirective is commented out in __init__.py, import directly
from docetl.reasoning_optimizer.directives.doc_compression import DocCompressionDirective

def run_all_directive_tests(agent_llm: str = "gpt-4.1") -> Dict[str, List[TestResult]]:
    """

@@ -68,6 +71,8 @@ def run_all_directive_tests(agent_llm: str = "gpt-4.1") -> Dict[str, List[TestRe
        SwapWithCodeDirective(),
        HierarchicalReduceDirective(),
        CascadeFilteringDirective(),
        MapToMapResolveReduceDirective(),
        MapResolveToMapWithCategoriesDirective(),
    ]

    all_results = {}

@@ -176,6 +181,8 @@ def run_specific_directive_test(directive_name: str, agent_llm: str = "gpt-4o-mi
        "clarify_instructions": ClarifyInstructionsDirective(),
        "swap_with_code": SwapWithCodeDirective(),
        "cascade_filtering": CascadeFilteringDirective(),
        "map_to_map_resolve_reduce": MapToMapResolveReduceDirective(),
        "map_resolve_to_map_with_categories": MapResolveToMapWithCategoriesDirective(),
    }

    if directive_name.lower() not in directive_map:

@@ -5,4 +5,9 @@ export const API_ROUTES = {
    CANCEL: (taskId: string) =>
      `/api/shouldOptimize?taskId=${taskId}&cancel=true`,
  },
  DECOMPOSE: {
    SUBMIT: "/api/decompose",
    STATUS: (taskId: string) => `/api/decompose?taskId=${taskId}`,
    CANCEL: (taskId: string) => `/api/decompose?taskId=${taskId}&cancel=true`,
  },
} as const;

@@ -0,0 +1,79 @@
import { NextRequest, NextResponse } from "next/server";

const FASTAPI_URL = `${
  process.env.NEXT_PUBLIC_BACKEND_HTTPS ? "https" : "http"
}://${process.env.NEXT_PUBLIC_BACKEND_HOST}:${
  process.env.NEXT_PUBLIC_BACKEND_PORT
}`;

// Helper to handle errors consistently
const handleError = (error: unknown, status = 500) => {
  const message =
    error instanceof Error ? error.message : "Internal server error";
  return NextResponse.json({ error: message }, { status });
};

// Helper to proxy requests to FastAPI
async function proxyRequest(path: string, init?: RequestInit) {
  const response = await fetch(`${FASTAPI_URL}${path}`, {
    ...init,
    headers: {
      "Content-Type": "application/json",
      ...init?.headers,
    },
  });

  if (!response.ok) {
    const error = await response.text();
    throw new Error(`FastAPI server error: ${error}`);
  }

  return response.json();
}

export async function POST(request: NextRequest): Promise<NextResponse> {
  try {
    // Extract task ID from the URL if it exists
    const taskId = request.nextUrl.searchParams.get("taskId");
    const isCancel = request.nextUrl.searchParams.get("cancel") === "true";

    // Handle different POST scenarios
    if (taskId) {
      if (isCancel) {
        // Cancel task
        const data = await proxyRequest(`/decompose/${taskId}/cancel`, {
          method: "POST",
        });
        return NextResponse.json(data);
      }
      // Invalid request with taskId but no cancel
      return handleError(new Error("Invalid request"), 400);
    }

    // Submit new task
    const body = await request.json();
    const data = await proxyRequest("/decompose", {
      method: "POST",
      body: JSON.stringify(body),
    });
    return NextResponse.json(data, { status: 202 });
  } catch (error) {
    return handleError(error);
  }
}

export async function GET(request: NextRequest): Promise<NextResponse> {
  try {
    // Extract task ID from the URL
    const taskId = request.nextUrl.searchParams.get("taskId");

    if (!taskId) {
      return handleError(new Error("Task ID is required"), 400);
    }

    const data = await proxyRequest(`/decompose/${taskId}`);
    return NextResponse.json(data);
  } catch (error) {
    return handleError(error);
  }
}

@@ -33,6 +33,77 @@ export function getNamespaceDir(homeDir: string, namespace: string) {
  return path.join(homeDir, ".docetl", namespace);
}

/**
 * Safely parse a JSON value - returns the parsed value if it's a string,
 * or the original value if it's already an object.
 * @param value - The value to parse
 * @param fieldName - Name of the field (for error messages)
 * @returns The parsed value
 */
function safeParseJSON(
  value: unknown,
  fieldName: string
): Record<string, unknown> | unknown[] {
  if (value === null || value === undefined) {
    throw new Error(`Field '${fieldName}' is null or undefined`);
  }

  // If it's already an object or array, return as-is
  if (typeof value === "object") {
    return value as Record<string, unknown> | unknown[];
  }

  // If it's a string, try to parse it
  if (typeof value === "string") {
    try {
      return JSON.parse(value);
    } catch (e) {
      throw new Error(
        `Field '${fieldName}' contains invalid JSON: ${value.slice(0, 100)}${value.length > 100 ? "..." : ""}`
      );
    }
  }

  throw new Error(
    `Field '${fieldName}' has unexpected type '${typeof value}': ${String(value).slice(0, 100)}`
  );
}

/**
 * Recursively sanitize an object for YAML serialization.
 * Converts any nested [object Object] strings to proper objects.
 */
function sanitizeForYaml(obj: unknown, path: string = ""): unknown {
  if (obj === null || obj === undefined) {
    return obj;
  }

  if (typeof obj === "string") {
    // Check for [object Object] which indicates a serialization issue
    if (obj === "[object Object]") {
      console.warn(
        `Warning: Found '[object Object]' string at path '${path}'. This indicates a serialization issue.`
      );
      return {}; // Return empty object as fallback
    }
    return obj;
  }

  if (Array.isArray(obj)) {
    return obj.map((item, idx) => sanitizeForYaml(item, `${path}[${idx}]`));
  }

  if (typeof obj === "object") {
    const result: Record<string, unknown> = {};
    for (const [key, value] of Object.entries(obj)) {
      result[key] = sanitizeForYaml(value, path ? `${path}.${key}` : key);
    }
    return result;
  }

  return obj;
}

export function generatePipelineConfig(
  namespace: string,
  default_model: string,

@@ -82,9 +153,18 @@ export function generatePipelineConfig(
      const litellm_completion_kwargs =
        op.otherKwargs?.litellm_completion_kwargs;
      if (litellm_completion_kwargs) {
        op.otherKwargs.litellm_completion_kwargs = JSON.parse(
          litellm_completion_kwargs
        );
        try {
          op.otherKwargs.litellm_completion_kwargs = safeParseJSON(
            litellm_completion_kwargs,
            `${op.name}.litellm_completion_kwargs`
          );
        } catch (e) {
          console.warn(
            `Warning: Could not parse litellm_completion_kwargs for operation '${op.name}':`,
            e
          );
          // Keep the original value if parsing fails
        }
      }

      const newOp: Record<string, unknown> = {

@@ -175,10 +255,14 @@ export function generatePipelineConfig(
    // If it's a sample operation with custom method, parse the samples as key-value pairs
    if (op.type === "sample" && op.otherKwargs?.method === "custom") {
      try {
        newOp.samples = JSON.parse(op.otherKwargs.samples);
        newOp.samples = safeParseJSON(
          op.otherKwargs.samples,
          `${op.name}.samples`
        );
      } catch (error) {
        console.warn(
          "Failed to parse custom samples as JSON, using raw value"
          `Failed to parse custom samples for operation '${op.name}':`,
          error
        );
      }
    }

@@ -217,14 +301,12 @@ export function generatePipelineConfig(
  // Fix type errors by asserting the pipeline config type
  let pipelineConfig: any = {
    from_docwrangler: true,
    optimizer_model: optimizerModel,
    datasets,
    default_model,
    ...(enable_observability && {
      optimizer_config: {
        force_decompose: true,
      },
    }),
    optimizer_config: {
      rewrite_agent_model: optimizerModel,
      ...(enable_observability && { force_decompose: true }),
    },
    operations: updatedOperations,
    pipeline: {
      steps: [

@@ -308,7 +390,9 @@ export function generatePipelineConfig(
  const outputOpName = operationsToRun[currentOpIndex].name;
  outputPath = path.join(outputBase, "data_processing", outputOpName + ".json");

  const yamlString = yaml.dump(pipelineConfig);
  // Sanitize the config before serializing to YAML to catch any [object Object] issues
  const sanitizedConfig = sanitizeForYaml(pipelineConfig, "pipelineConfig");
  const yamlString = yaml.dump(sanitizedConfig);

  console.log(yamlString);
@ -104,6 +104,7 @@ export interface OptimizeResult {
|
|||
should_optimize?: string;
|
||||
input_data?: Array<Record<string, unknown>>;
|
||||
output_data?: Array<Record<string, unknown>>;
|
||||
num_docs_analyzed?: number;
|
||||
cost?: number;
|
||||
error?: string;
|
||||
created_at: string;
|
||||
|
|
@ -114,3 +115,24 @@ export interface APIKey {
|
|||
name: string;
|
||||
value: string;
|
||||
}
|
||||
|
||||
export interface DecomposeRequest {
|
||||
yaml_config: string;
|
||||
step_name: string;
|
||||
op_name: string;
|
||||
}
|
||||
|
||||
export interface DecomposeResult {
|
||||
task_id: string;
|
||||
status: TaskStatus;
|
||||
decomposed_operations?: Array<Record<string, unknown>>;
|
||||
winning_directive?: string;
|
||||
candidates_evaluated?: number;
|
||||
original_outputs?: Array<Record<string, unknown>>;
|
||||
decomposed_outputs?: Array<Record<string, unknown>>;
|
||||
comparison_rationale?: string;
|
||||
cost?: number;
|
||||
error?: string;
|
||||
created_at: string;
|
||||
completed_at?: string;
|
||||
}
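A hedged sketch of how a caller might use these two interfaces; the paths and operation name are made up for illustration:

// Illustrative only: the shape of a DecomposeRequest payload and a small helper
// that summarizes a finished DecomposeResult. All literal values are assumptions.
const request: DecomposeRequest = {
  yaml_config: "/tmp/pipeline.yaml",
  step_name: "data_processing",
  op_name: "extract_findings",
};

function summarizeDecomposition(result: DecomposeResult): string {
  if (result.status !== "completed") return `still ${result.status}`;
  const winner = result.winning_directive ?? "original";
  return `winner: ${winner} (${result.candidates_evaluated ?? 0} candidates, $${(result.cost ?? 0).toFixed(4)})`;
}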
@ -24,12 +24,14 @@ interface AnsiRendererProps {
|
|||
text: string;
|
||||
readyState: number;
|
||||
setTerminalOutput: (text: string) => void;
|
||||
isDecomposing?: boolean;
|
||||
}
|
||||
|
||||
const AnsiRenderer: React.FC<AnsiRendererProps> = ({
|
||||
text,
|
||||
readyState,
|
||||
setTerminalOutput,
|
||||
isDecomposing = false,
|
||||
}) => {
|
||||
const html = convert.toHtml(text);
|
||||
const scrollRef = useRef<HTMLDivElement>(null);
|
||||
|
|
@ -51,7 +53,9 @@ const AnsiRenderer: React.FC<AnsiRendererProps> = ({
|
|||
}
|
||||
};
|
||||
|
||||
const isWebSocketClosed = readyState === WebSocket.CLOSED;
|
||||
// Consider connected if either WebSocket is open OR decomposing is in progress
|
||||
const isConnected = readyState === WebSocket.OPEN || isDecomposing;
|
||||
const isWebSocketClosed = readyState === WebSocket.CLOSED && !isDecomposing;
|
||||
|
||||
return (
|
||||
<div
|
||||
|
|
@ -92,9 +96,11 @@ const AnsiRenderer: React.FC<AnsiRendererProps> = ({
|
|||
)}
|
||||
</div>
|
||||
<div className="flex justify-between items-center text-xs text-gray-500">
|
||||
<div className={isWebSocketClosed ? "text-red-500" : ""}>
|
||||
<div className={isWebSocketClosed ? "text-red-500" : isDecomposing ? "text-blue-400" : ""}>
|
||||
Status:{" "}
|
||||
{readyState === WebSocket.CONNECTING
|
||||
{isDecomposing
|
||||
? "Decomposing..."
|
||||
: readyState === WebSocket.CONNECTING
|
||||
? "Connecting"
|
||||
: readyState === WebSocket.OPEN
|
||||
? "Connected"
|
||||
|
|
@ -106,10 +112,7 @@ const AnsiRenderer: React.FC<AnsiRendererProps> = ({
|
|||
</div>
|
||||
<button
|
||||
onClick={() => setTerminalOutput("")}
|
||||
className={`hover:text-white transition-colors ${
|
||||
isWebSocketClosed ? "cursor-not-allowed opacity-50" : ""
|
||||
}`}
|
||||
disabled={isWebSocketClosed}
|
||||
className="hover:text-white transition-colors"
|
||||
>
|
||||
Clear
|
||||
</button>
@ -0,0 +1,271 @@
|
|||
import React from "react";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
DialogDescription,
|
||||
} from "@/components/ui/dialog";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/ui/table";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { ChevronLeft, ChevronRight, Check, X } from "lucide-react";
|
||||
import { DecomposeResult } from "@/app/types";
|
||||
|
||||
interface DecompositionComparisonDialogProps {
|
||||
isOpen: boolean;
|
||||
result: DecomposeResult | null;
|
||||
operationName?: string;
|
||||
onOpenChange: (open: boolean) => void;
|
||||
onApply: () => void;
|
||||
onCancel: () => void;
|
||||
}
|
||||
|
||||
export const DecompositionComparisonDialog: React.FC<
|
||||
DecompositionComparisonDialogProps
|
||||
> = ({ isOpen, result, operationName, onOpenChange, onApply, onCancel }) => {
|
||||
const [currentPage, setCurrentPage] = React.useState(1);
|
||||
|
||||
const originalOutputs = result?.original_outputs || [];
|
||||
const decomposedOutputs = result?.decomposed_outputs || [];
|
||||
const maxRows = Math.max(originalOutputs.length, decomposedOutputs.length);
|
||||
|
||||
const renderOutputTable = (
|
||||
data: Array<Record<string, unknown>>,
|
||||
title: string
|
||||
) => {
|
||||
if (!data.length) {
|
||||
return (
|
||||
<div className="text-center text-muted-foreground py-8">
|
||||
No outputs available
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const columns = Object.keys(data[0]).filter(
|
||||
(column) => !column.startsWith("_observability")
|
||||
);
|
||||
|
||||
const currentRow = data[currentPage - 1];
|
||||
if (!currentRow) return null;
|
||||
|
||||
return (
|
||||
<div className="space-y-2">
|
||||
<h4 className="font-medium text-sm text-muted-foreground">{title}</h4>
|
||||
<div className="border rounded-md">
|
||||
<div className="max-h-[250px] overflow-auto">
|
||||
<Table className="relative w-full border-collapse">
|
||||
<TableHeader>
|
||||
<TableRow className="sticky top-0 bg-background z-10 border-b">
|
||||
{columns.map((column) => (
|
||||
<TableHead
|
||||
key={column}
|
||||
className="h-8 px-3 text-left align-middle bg-background text-xs"
|
||||
>
|
||||
{column}
|
||||
</TableHead>
|
||||
))}
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
<TableRow>
|
||||
{columns.map((column) => (
|
||||
<TableCell key={column} className="p-2 align-top">
|
||||
<pre className="whitespace-pre-wrap font-mono text-xs text-left max-w-[300px]">
|
||||
{typeof currentRow[column] === "object"
|
||||
? JSON.stringify(currentRow[column], null, 2)
|
||||
: String(currentRow[column] ?? "")}
|
||||
</pre>
|
||||
</TableCell>
|
||||
))}
|
||||
</TableRow>
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
React.useEffect(() => {
|
||||
if (isOpen) {
|
||||
setCurrentPage(1);
|
||||
}
|
||||
}, [isOpen]);
|
||||
|
||||
if (!result) return null;
|
||||
|
||||
const isOriginalWinner = result.winning_directive === "original";
|
||||
|
||||
return (
|
||||
<Dialog open={isOpen} onOpenChange={onOpenChange}>
|
||||
<DialogContent className="max-w-7xl max-h-[90vh] flex flex-col">
|
||||
<DialogHeader className="flex-shrink-0 border-b pb-4">
|
||||
<DialogTitle className="text-xl">
|
||||
Review Decomposition Results
|
||||
</DialogTitle>
|
||||
<DialogDescription>
|
||||
Compare the original operation outputs with the decomposed version
|
||||
before applying changes.
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<div className="flex-1 overflow-y-auto py-4 space-y-6">
|
||||
{/* Summary */}
|
||||
<div className="p-4 bg-muted rounded-lg space-y-2">
|
||||
<div className="flex items-center gap-4 flex-wrap">
|
||||
{operationName && (
|
||||
<div className="flex items-center">
|
||||
<span className="font-medium text-sm mr-2">Operation:</span>
|
||||
<span className="bg-primary/15 text-primary rounded-md px-2 py-1 text-sm">
|
||||
{operationName}
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
<div className="flex items-center">
|
||||
<span className="font-medium text-sm mr-2">
|
||||
Winning Strategy:
|
||||
</span>
|
||||
<span
|
||||
className={`rounded-md px-2 py-1 text-sm ${
|
||||
isOriginalWinner
|
||||
? "bg-yellow-100 text-yellow-800"
|
||||
: "bg-green-100 text-green-800"
|
||||
}`}
|
||||
>
|
||||
{result.winning_directive}
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex items-center">
|
||||
<span className="font-medium text-sm mr-2">
|
||||
Candidates Evaluated:
|
||||
</span>
|
||||
<span className="text-sm">{result.candidates_evaluated}</span>
|
||||
</div>
|
||||
{result.cost && (
|
||||
<div className="flex items-center">
|
||||
<span className="font-medium text-sm mr-2">Cost:</span>
|
||||
<span className="text-sm">${result.cost.toFixed(4)}</span>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Comparison Rationale */}
|
||||
{result.comparison_rationale && (
|
||||
<div className="space-y-2">
|
||||
<h3 className="font-medium text-base">Why This Was Chosen</h3>
|
||||
<div className="p-3 bg-blue-50 border border-blue-200 rounded-lg text-sm">
|
||||
{result.comparison_rationale}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Side by side comparison */}
|
||||
<div className="space-y-4">
|
||||
<div className="flex items-center justify-between">
|
||||
<h3 className="font-medium text-base">
|
||||
Output Comparison (Sample {currentPage} of {maxRows || 1})
|
||||
</h3>
|
||||
<div className="flex items-center space-x-1">
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() =>
|
||||
setCurrentPage((prev) => Math.max(prev - 1, 1))
|
||||
}
|
||||
disabled={currentPage === 1}
|
||||
className="px-2 py-1"
|
||||
>
|
||||
<ChevronLeft className="h-4 w-4" />
|
||||
Previous
|
||||
</Button>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() =>
|
||||
setCurrentPage((prev) => Math.min(prev + 1, maxRows))
|
||||
}
|
||||
disabled={currentPage === maxRows || maxRows === 0}
|
||||
className="px-2 py-1"
|
||||
>
|
||||
Next
|
||||
<ChevronRight className="h-4 w-4" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
{renderOutputTable(originalOutputs, "Original Output")}
|
||||
{renderOutputTable(
|
||||
decomposedOutputs,
|
||||
`Decomposed Output (${result.winning_directive})`
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Decomposed Operations Preview */}
|
||||
{result.decomposed_operations &&
|
||||
result.decomposed_operations.length > 0 &&
|
||||
!isOriginalWinner && (
|
||||
<div className="space-y-2">
|
||||
<h3 className="font-medium text-base">
|
||||
New Operations ({result.decomposed_operations.length})
|
||||
</h3>
|
||||
<div className="space-y-2">
|
||||
{result.decomposed_operations.map((op, idx) => (
|
||||
<div
|
||||
key={idx}
|
||||
className="p-3 border rounded-lg bg-muted/50"
|
||||
>
|
||||
<div className="flex items-center gap-2 mb-2">
|
||||
<span className="font-medium text-sm">
|
||||
{(op.name as string) || `Operation ${idx + 1}`}
|
||||
</span>
|
||||
<span className="text-xs bg-primary/10 text-primary px-2 py-0.5 rounded">
|
||||
{op.type as string}
|
||||
</span>
|
||||
</div>
|
||||
{op.prompt && (
|
||||
<pre className="text-xs font-mono bg-background p-2 rounded border max-h-[100px] overflow-auto">
|
||||
{String(op.prompt).slice(0, 500)}
|
||||
{String(op.prompt).length > 500 ? "..." : ""}
|
||||
</pre>
|
||||
)}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="flex justify-end items-center gap-3 pt-4 border-t mt-4">
|
||||
<Button variant="outline" onClick={onCancel}>
|
||||
<X className="h-4 w-4 mr-2" />
|
||||
Keep Original
|
||||
</Button>
|
||||
<Button
|
||||
onClick={onApply}
|
||||
disabled={isOriginalWinner}
|
||||
title={
|
||||
isOriginalWinner
|
||||
? "Original was determined to be the best option"
|
||||
: undefined
|
||||
}
|
||||
>
|
||||
<Check className="h-4 w-4 mr-2" />
|
||||
Apply Decomposition
|
||||
</Button>
|
||||
</div>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
);
|
||||
};
|
||||
|
||||
DecompositionComparisonDialog.displayName = "DecompositionComparisonDialog";
@ -39,7 +39,6 @@ import { Skeleton } from "@/components/ui/skeleton";
|
|||
import { debounce } from "lodash";
|
||||
import { Guardrails, GleaningConfig } from "./operations/args";
|
||||
import createOperationComponent from "./operations/components";
|
||||
import { useWebSocket } from "@/contexts/WebSocketContext";
|
||||
import { Badge } from "./ui/badge";
|
||||
import {
|
||||
Popover,
|
||||
|
|
@ -161,7 +160,7 @@ const OperationHeader: React.FC<OperationHeaderProps> = React.memo(
|
|||
}`}
|
||||
/>
|
||||
</HoverCardTrigger>
|
||||
<HoverCardContent className="w-72" side="bottom" align="start">
|
||||
<HoverCardContent className="w-96 max-h-80" side="bottom" align="start">
|
||||
<div className="flex flex-col space-y-1">
|
||||
<p className="text-sm font-medium">
|
||||
{optimizeResult === undefined || optimizeResult === null
|
||||
|
|
@ -170,13 +169,15 @@ const OperationHeader: React.FC<OperationHeaderProps> = React.memo(
|
|||
? "Decomposition Status"
|
||||
: "Decomposition Recommended"}
|
||||
</p>
|
||||
<p className="text-sm text-muted-foreground">
|
||||
{optimizeResult === undefined || optimizeResult === null
|
||||
? "Analyzing operation complexity..."
|
||||
: optimizeResult === ""
|
||||
? "No decomposition needed for this operation"
|
||||
: "Recommended decomposition: " + optimizeResult}
|
||||
</p>
|
||||
<div className="max-h-64 overflow-y-auto">
|
||||
<p className="text-sm text-muted-foreground whitespace-pre-wrap">
|
||||
{optimizeResult === undefined || optimizeResult === null
|
||||
? "Analyzing operation complexity..."
|
||||
: optimizeResult === ""
|
||||
? "No decomposition needed for this operation"
|
||||
: optimizeResult}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</HoverCardContent>
|
||||
</HoverCard>
|
||||
|
|
@ -367,7 +368,7 @@ const OperationHeader: React.FC<OperationHeaderProps> = React.memo(
|
|||
disabled={disabled}
|
||||
>
|
||||
<Zap className="mr-2 h-4 w-4" />
|
||||
Optimize Operation
|
||||
Decompose Operation
|
||||
</Button>
|
||||
)}
|
||||
|
||||
|
|
@ -720,7 +721,6 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
|
|||
output: pipelineOutput,
|
||||
setOutput,
|
||||
isLoadingOutputs,
|
||||
setIsLoadingOutputs,
|
||||
numOpRun,
|
||||
setNumOpRun,
|
||||
currentFile,
|
||||
|
|
@ -730,18 +730,14 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
|
|||
sampleSize,
|
||||
setCost,
|
||||
defaultModel,
|
||||
optimizerModel,
|
||||
setTerminalOutput,
|
||||
namespace,
|
||||
apiKeys,
|
||||
systemPrompt,
|
||||
extraPipelineSettings,
|
||||
onRequestDecompositionRef,
|
||||
} = usePipelineContext();
|
||||
const { toast } = useToast();
|
||||
|
||||
const operationRef = useRef(operation);
|
||||
const { connect, sendMessage, lastMessage, readyState, disconnect } =
|
||||
useWebSocket();
|
||||
|
||||
useEffect(() => {
|
||||
operationRef.current = operation;
|
||||
|
|
@ -808,75 +804,19 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
|
|||
const handleOptimizeConfirm = useCallback(async () => {
|
||||
if (!operation) return;
|
||||
|
||||
try {
|
||||
// Clear the output
|
||||
setTerminalOutput("");
|
||||
setIsLoadingOutputs(true);
|
||||
setShowOptimizeDialog(false);
|
||||
|
||||
// Write pipeline config
|
||||
const response = await fetch("/api/writePipelineConfig", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
default_model: defaultModel,
|
||||
data: { path: currentFile?.path || "" },
|
||||
operations,
|
||||
operation_id: operation.id,
|
||||
name: pipelineName,
|
||||
sample_size: sampleSize,
|
||||
optimize: true,
|
||||
clear_intermediate: false,
|
||||
system_prompt: systemPrompt,
|
||||
namespace: namespace,
|
||||
apiKeys: apiKeys,
|
||||
optimizerModel: optimizerModel,
|
||||
extraPipelineSettings: extraPipelineSettings,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(await response.text());
|
||||
}
|
||||
|
||||
const { filePath } = await response.json();
|
||||
|
||||
// Ensure WebSocket is connected
|
||||
await connect();
|
||||
|
||||
// Send message to run the pipeline
|
||||
sendMessage({
|
||||
yaml_config: filePath,
|
||||
optimize: true,
|
||||
});
|
||||
} catch (error) {
|
||||
console.error("Error optimizing operation:", error);
|
||||
// Use the decomposition flow from context if available
|
||||
if (onRequestDecompositionRef.current) {
|
||||
onRequestDecompositionRef.current(operation.id, operation.name);
|
||||
} else {
|
||||
toast({
|
||||
title: "Error",
|
||||
description: error.message,
|
||||
description: "Decomposition handler not available",
|
||||
variant: "destructive",
|
||||
});
|
||||
// Close the WebSocket connection
|
||||
disconnect();
|
||||
} finally {
|
||||
setShowOptimizeDialog(false);
|
||||
}
|
||||
}, [
|
||||
operation,
|
||||
defaultModel,
|
||||
currentFile,
|
||||
operations,
|
||||
pipelineName,
|
||||
sampleSize,
|
||||
optimizerModel,
|
||||
connect,
|
||||
sendMessage,
|
||||
systemPrompt,
|
||||
namespace,
|
||||
apiKeys,
|
||||
extraPipelineSettings,
|
||||
]);
|
||||
}, [operation, onRequestDecompositionRef, toast]);
|
||||
|
||||
const onShowOutput = useCallback(async () => {
|
||||
if (!operation) return;
|
||||
|
|
@ -1224,7 +1164,7 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
|
|||
>
|
||||
<AlertDialogContent>
|
||||
<AlertDialogHeader>
|
||||
<AlertDialogTitle>Optimize Operation</AlertDialogTitle>
|
||||
<AlertDialogTitle>Decompose Operation</AlertDialogTitle>
|
||||
<AlertDialogDescription>
|
||||
{!hasOpenAIKey && !isLocalMode ? (
|
||||
<div className="space-y-2">
|
||||
|
|
@ -1232,8 +1172,8 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
|
|||
OpenAI API Key Required
|
||||
</p>
|
||||
<p>
|
||||
To use the optimizer, please add your OpenAI API key in Edit{" "}
|
||||
{">"}
|
||||
To decompose operations, please add your OpenAI API key in
|
||||
Edit {">"}
|
||||
Edit API Keys.
|
||||
</p>
|
||||
<button
|
||||
|
|
@ -1245,11 +1185,10 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
|
|||
</div>
|
||||
) : (
|
||||
<p>
|
||||
This will analyze the operation and replace it with another
|
||||
pipeline that has higher accuracy (as determined by an
|
||||
LLM-as-a-judge), if it can be found. Do you want to proceed?
|
||||
The process may take between 2 and 10 minutes, depending on
|
||||
how complex your data is.
|
||||
This will analyze the operation and try different
|
||||
decomposition strategies (chaining, chunking, gleaning, etc.)
|
||||
to improve accuracy. The best strategy will be selected via
|
||||
LLM-as-a-judge comparison. Do you want to proceed?
|
||||
</p>
|
||||
)}
|
||||
</AlertDialogDescription>
|
||||
|
|
@ -1260,7 +1199,7 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
|
|||
onClick={handleOptimizeConfirm}
|
||||
disabled={!hasOpenAIKey && !isLocalMode}
|
||||
>
|
||||
Proceed
|
||||
Decompose
|
||||
</AlertDialogAction>
|
||||
</AlertDialogFooter>
|
||||
</AlertDialogContent>
|
@ -23,6 +23,7 @@ interface OptimizationDialogProps {
|
|||
operationName?: string;
|
||||
inputData?: Array<Record<string, unknown>>;
|
||||
outputData?: Array<Record<string, unknown>>;
|
||||
numDocsAnalyzed?: number;
|
||||
onOpenChange: (open: boolean) => void;
|
||||
onDecompose?: () => void;
|
||||
}
|
||||
|
|
@ -34,6 +35,7 @@ export const OptimizationDialog: React.FC<OptimizationDialogProps> = ({
|
|||
inputData,
|
||||
outputData,
|
||||
operationName,
|
||||
numDocsAnalyzed,
|
||||
onOpenChange,
|
||||
onDecompose,
|
||||
}) => {
|
||||
|
|
@ -140,9 +142,9 @@ export const OptimizationDialog: React.FC<OptimizationDialogProps> = ({
|
|||
<DialogHeader className="flex-shrink-0 border-b pb-4">
|
||||
<DialogTitle className="text-xl">Operation Too Complex</DialogTitle>
|
||||
<p className="text-base mt-2">
|
||||
This operation might be too complex for the LLM to handle
|
||||
efficiently. We recommend breaking it down into smaller, more
|
||||
manageable steps.
|
||||
After analyzing {numDocsAnalyzed ?? "several"} documents, this
|
||||
operation might be too complex for the LLM to handle efficiently. We
|
||||
recommend breaking it down into smaller, more manageable steps.
|
||||
</p>
|
||||
</DialogHeader>
|
@ -81,6 +81,7 @@ const useOutputContext = () => {
|
|||
const {
|
||||
output,
|
||||
isLoadingOutputs,
|
||||
isDecomposing,
|
||||
terminalOutput,
|
||||
setTerminalOutput,
|
||||
optimizerProgress,
|
||||
|
|
@ -91,6 +92,7 @@ const useOutputContext = () => {
|
|||
return {
|
||||
output,
|
||||
isLoadingOutputs,
|
||||
isDecomposing,
|
||||
terminalOutput,
|
||||
setTerminalOutput,
|
||||
optimizerProgress,
|
||||
|
|
@ -385,7 +387,7 @@ VisualizeContent.displayName = "VisualizeContent";
|
|||
|
||||
// Move ConsoleContent outside
|
||||
export const ConsoleContent = memo(() => {
|
||||
const { terminalOutput, setTerminalOutput, optimizerProgress } =
|
||||
const { terminalOutput, setTerminalOutput, optimizerProgress, isDecomposing } =
|
||||
useOutputContext();
|
||||
const { readyState } = useWebSocket();
|
||||
|
||||
|
|
@ -469,6 +471,7 @@ export const ConsoleContent = memo(() => {
|
|||
text={terminalOutput || ""}
|
||||
readyState={readyState}
|
||||
setTerminalOutput={setTerminalOutput}
|
||||
isDecomposing={isDecomposing}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
|
@ -478,7 +481,7 @@ ConsoleContent.displayName = "ConsoleContent";
|
|||
|
||||
// Main Output component
|
||||
export const Output = memo(() => {
|
||||
const { output, isLoadingOutputs, sampleSize, operations } =
|
||||
const { output, isLoadingOutputs, isDecomposing, sampleSize, operations } =
|
||||
useOutputContext();
|
||||
const operation = useOperation(output?.operationId);
|
||||
|
||||
|
|
@ -500,10 +503,14 @@ export const Output = memo(() => {
|
|||
);
|
||||
}, [operation]);
|
||||
|
||||
// Effect for tab changes
|
||||
// Effect for tab changes - switch to console when loading or decomposing
|
||||
useEffect(() => {
|
||||
setActiveTab(isLoadingOutputs ? "console" : "table");
|
||||
}, [isLoadingOutputs]);
|
||||
if (isLoadingOutputs || isDecomposing) {
|
||||
setActiveTab("console");
|
||||
} else {
|
||||
setActiveTab("table");
|
||||
}
|
||||
}, [isLoadingOutputs, isDecomposing]);
|
||||
|
||||
// Memoize columns
|
||||
const columns = useMemo(() => {
|
@ -40,9 +40,12 @@ import { Input } from "@/components/ui/input";
|
|||
import { schemaDictToItemSet } from "./utils";
|
||||
import { v4 as uuidv4 } from "uuid";
|
||||
import { useOptimizeCheck } from "@/hooks/useOptimizeCheck";
|
||||
import { useDecomposeWebSocket } from "@/hooks/useDecomposeWebSocket";
|
||||
import { canBeOptimized } from "@/lib/utils";
|
||||
import { Textarea } from "./ui/textarea";
|
||||
import { OptimizationDialog } from "@/components/OptimizationDialog";
|
||||
import { DecompositionComparisonDialog } from "@/components/DecompositionComparisonDialog";
|
||||
import { DecomposeResult } from "@/app/types";
|
||||
import {
|
||||
Popover,
|
||||
PopoverContent,
|
||||
|
|
@ -240,6 +243,9 @@ const PipelineGUI: React.FC = () => {
|
|||
apiKeys,
|
||||
extraPipelineSettings,
|
||||
setExtraPipelineSettings,
|
||||
isDecomposing,
|
||||
setIsDecomposing,
|
||||
onRequestDecompositionRef,
|
||||
} = usePipelineContext();
|
||||
const [isSettingsOpen, setIsSettingsOpen] = useState(false);
|
||||
const { toast } = useToast();
|
||||
|
|
@ -253,12 +259,14 @@ const PipelineGUI: React.FC = () => {
|
|||
outputData?: Array<Record<string, unknown>>;
|
||||
operationName?: string;
|
||||
operationId?: string;
|
||||
numDocsAnalyzed?: number;
|
||||
}>({
|
||||
isOpen: false,
|
||||
content: "",
|
||||
prompt: undefined,
|
||||
operationName: undefined,
|
||||
operationId: undefined,
|
||||
numDocsAnalyzed: undefined,
|
||||
});
|
||||
const [isEditingName, setIsEditingName] = useState(false);
|
||||
const [editedPipelineName, setEditedPipelineName] = useState(pipelineName);
|
||||
|
|
@ -277,14 +285,13 @@ const PipelineGUI: React.FC = () => {
|
|||
setCost((prev) => prev + result.cost);
|
||||
|
||||
if (result.should_optimize) {
|
||||
toast({
|
||||
title: `Hey! Consider decomposing ${operations[operations.length - 1].name
|
||||
}`,
|
||||
const lastOp = operations[operations.length - 1];
|
||||
const { dismiss } = toast({
|
||||
title: `Hey! Consider decomposing ${lastOp.name}`,
|
||||
description: (
|
||||
<span
|
||||
className="cursor-pointer text-blue-500 hover:text-blue-700"
|
||||
onClick={() => {
|
||||
const lastOp = operations[operations.length - 1];
|
||||
setOptimizationDialog({
|
||||
isOpen: true,
|
||||
content: result.should_optimize,
|
||||
|
|
@ -293,7 +300,9 @@ const PipelineGUI: React.FC = () => {
|
|||
operationId: lastOp.id,
|
||||
inputData: result.input_data,
|
||||
outputData: result.output_data,
|
||||
numDocsAnalyzed: result.num_docs_analyzed,
|
||||
});
|
||||
dismiss();
|
||||
}}
|
||||
>
|
||||
Click here to see why.
|
||||
|
|
@ -312,6 +321,144 @@ const PipelineGUI: React.FC = () => {
|
|||
},
|
||||
});
|
||||
|
||||
const [decomposeDialog, setDecomposeDialog] = useState<{
|
||||
isOpen: boolean;
|
||||
result: DecomposeResult | null;
|
||||
targetOpId: string | null;
|
||||
operationName: string | null;
|
||||
}>({
|
||||
isOpen: false,
|
||||
result: null,
|
||||
targetOpId: null,
|
||||
operationName: null,
|
||||
});
|
||||
|
||||
// Ref to store the current decomposition target (avoids stale closure issue)
|
||||
const decompositionTargetRef = useRef<{
|
||||
operationId: string;
|
||||
operationName: string;
|
||||
} | null>(null);
|
||||
|
||||
const { startDecomposition, cancel: cancelDecomposition } = useDecomposeWebSocket({
|
||||
namespace,
|
||||
onOutput: (output) => {
|
||||
// Stream output to the terminal
|
||||
setTerminalOutput(output);
|
||||
},
|
||||
onComplete: (result) => {
|
||||
setIsDecomposing(false);
|
||||
setCost((prev) => prev + (result.cost || 0));
|
||||
|
||||
// Read from ref to avoid stale closure
|
||||
const target = decompositionTargetRef.current;
|
||||
console.log("Decomposition complete:", {
|
||||
result,
|
||||
target,
|
||||
decomposed_operations: result.decomposed_operations,
|
||||
});
|
||||
|
||||
// Show the comparison dialog instead of auto-applying
|
||||
setDecomposeDialog({
|
||||
isOpen: true,
|
||||
result,
|
||||
targetOpId: target?.operationId || null,
|
||||
operationName: target?.operationName || null,
|
||||
});
|
||||
},
|
||||
onError: (error) => {
|
||||
setIsDecomposing(false);
|
||||
toast({
|
||||
title: "Decomposition Failed",
|
||||
description: error,
|
||||
variant: "destructive",
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
const handleApplyDecomposition = () => {
|
||||
const { result, targetOpId } = decomposeDialog;
|
||||
console.log("handleApplyDecomposition called:", { result, targetOpId, decomposeDialog });
|
||||
if (!result || !targetOpId) {
|
||||
console.warn("handleApplyDecomposition: missing result or targetOpId", { result, targetOpId });
|
||||
return;
|
||||
}
|
||||
|
||||
if (
|
||||
result.decomposed_operations &&
|
||||
result.decomposed_operations.length > 0 &&
|
||||
result.winning_directive !== "original"
|
||||
) {
|
||||
setOperations((prev) => {
|
||||
// Find the index of the target operation
|
||||
const targetIdx = prev.findIndex((op) => op.id === targetOpId);
|
||||
if (targetIdx === -1) return prev;
|
||||
|
||||
// Convert decomposed operations to our Operation format
|
||||
const newOps: Operation[] = result.decomposed_operations!.map(
|
||||
(opConfig) => ({
|
||||
id: uuidv4(),
|
||||
llmType:
|
||||
opConfig.type === "map" ||
|
||||
opConfig.type === "reduce" ||
|
||||
opConfig.type === "filter" ||
|
||||
opConfig.type === "parallel_map"
|
||||
? "LLM"
|
||||
: "non-LLM",
|
||||
type: opConfig.type as Operation["type"],
|
||||
name: opConfig.name as string,
|
||||
prompt: opConfig.prompt as string | undefined,
|
||||
output: opConfig.output
|
||||
? {
|
||||
schema: schemaDictToItemSet(
|
||||
(opConfig.output as { schema: Record<string, string> })
|
||||
.schema
|
||||
),
|
||||
}
|
||||
: undefined,
|
||||
gleaning: opConfig.gleaning as Operation["gleaning"],
|
||||
otherKwargs: Object.fromEntries(
|
||||
Object.entries(opConfig).filter(
|
||||
([key]) =>
|
||||
!["name", "type", "prompt", "output", "gleaning"].includes(key)
|
||||
)
|
||||
),
|
||||
visibility: true,
|
||||
})
|
||||
);
|
||||
|
||||
// Replace the target operation with the new operations
|
||||
const newOperations = [...prev];
|
||||
newOperations.splice(targetIdx, 1, ...newOps);
|
||||
return newOperations;
|
||||
});
|
||||
|
||||
toast({
|
||||
title: "Decomposition Applied",
|
||||
description: `Applied ${result.winning_directive} directive. Created ${result.decomposed_operations.length} operation(s).`,
|
||||
});
|
||||
}
|
||||
|
||||
setDecomposeDialog({
|
||||
isOpen: false,
|
||||
result: null,
|
||||
targetOpId: null,
|
||||
operationName: null,
|
||||
});
|
||||
};
|
||||
|
||||
const handleCancelDecomposition = () => {
|
||||
toast({
|
||||
title: "Decomposition Cancelled",
|
||||
description: "Keeping the original operation.",
|
||||
});
|
||||
setDecomposeDialog({
|
||||
isOpen: false,
|
||||
result: null,
|
||||
targetOpId: null,
|
||||
operationName: null,
|
||||
});
|
||||
};
|
||||
|
||||
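The ref above is the usual way to avoid a stale-closure bug: the WebSocket callbacks are created once, so they must read the current target through a ref rather than capture state. A stripped-down sketch of the same pattern, with illustrative names only:

// Minimal sketch, not part of the change: a callback created once always sees
// the latest target because it dereferences the ref at call time.
import { useCallback, useRef } from "react";

function useLatestTarget() {
  const targetRef = useRef<{ operationId: string; operationName: string } | null>(null);

  const start = (operationId: string, operationName: string) => {
    targetRef.current = { operationId, operationName }; // update before kicking off work
  };

  const onComplete = useCallback(() => {
    // Reads the current value at call time, so it never goes stale.
    console.log("completed for", targetRef.current?.operationName);
  }, []);

  return { start, onComplete };
}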
const { restoreFromYAML } = useRestorePipeline({
|
||||
setOperations,
|
||||
setPipelineName,
|
||||
|
|
@ -328,79 +475,18 @@ const PipelineGUI: React.FC = () => {
|
|||
if (lastMessage) {
|
||||
if (lastMessage.type === "output") {
|
||||
setTerminalOutput(lastMessage.data);
|
||||
} else if (lastMessage.type === "optimizer_progress") {
|
||||
setOptimizerProgress({
|
||||
status: lastMessage.status,
|
||||
progress: lastMessage.progress,
|
||||
shouldOptimize: lastMessage.should_optimize,
|
||||
rationale: lastMessage.rationale,
|
||||
validatorPrompt: lastMessage.validator_prompt,
|
||||
});
|
||||
} else if (lastMessage.type === "result") {
|
||||
const runCost = lastMessage.data.cost || 0;
|
||||
setOptimizerProgress(null);
|
||||
|
||||
// See if there was an optimized operation
|
||||
const optimizedOps = lastMessage.data.optimized_ops;
|
||||
if (optimizedOps) {
|
||||
const newOperations = optimizedOps.map((optimizedOp) => {
|
||||
const {
|
||||
id,
|
||||
type,
|
||||
name,
|
||||
prompt,
|
||||
output,
|
||||
validate,
|
||||
gleaning,
|
||||
sample,
|
||||
...otherKwargs
|
||||
} = optimizedOp;
|
||||
|
||||
// Find matching operation in previous operations list
|
||||
const existingOp = operations.find((op) => op.name === name);
|
||||
|
||||
return {
|
||||
id: id || uuidv4(),
|
||||
llmType:
|
||||
type === "map" ||
|
||||
type === "reduce" ||
|
||||
type === "resolve" ||
|
||||
type === "filter" ||
|
||||
type === "parallel_map" ||
|
||||
type === "rank" ||
|
||||
type === "extract"
|
||||
? "LLM"
|
||||
: "non-LLM",
|
||||
type: type,
|
||||
name: name || "Untitled Operation",
|
||||
prompt: prompt,
|
||||
output: output
|
||||
? {
|
||||
schema: schemaDictToItemSet(output.schema),
|
||||
}
|
||||
: undefined,
|
||||
validate: validate,
|
||||
gleaning: gleaning,
|
||||
sample: sample,
|
||||
otherKwargs: otherKwargs || {},
|
||||
...(existingOp?.runIndex && { runIndex: existingOp.runIndex }),
|
||||
visibility: true,
|
||||
} as Operation;
|
||||
});
|
||||
|
||||
setOperations(newOperations);
|
||||
} else {
|
||||
// No optimized operations, so we need to check if we should optimize the last operation
|
||||
// Trigger should optimize for the last operation
|
||||
if (autoOptimizeCheck) {
|
||||
const lastOp = operations[operations.length - 1];
|
||||
if (lastOp && canBeOptimized(lastOp.type)) {
|
||||
submitTask({
|
||||
yaml_config: lastMessage.data.yaml_config,
|
||||
step_name: "data_processing", // TODO: Make this a constant
|
||||
op_name: lastOp.name,
|
||||
});
|
||||
}
|
||||
// Trigger should_optimize check for the last operation if enabled
|
||||
if (autoOptimizeCheck) {
|
||||
const lastOp = operations[operations.length - 1];
|
||||
if (lastOp && canBeOptimized(lastOp.type)) {
|
||||
submitTask({
|
||||
yaml_config: lastMessage.data.yaml_config,
|
||||
step_name: "data_processing",
|
||||
op_name: lastOp.name,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -692,20 +778,35 @@ const PipelineGUI: React.FC = () => {
|
|||
};
|
||||
|
||||
const handleStop = () => {
|
||||
sendMessage("kill");
|
||||
// Stop pipeline run if running
|
||||
if (isLoadingOutputs) {
|
||||
sendMessage("kill");
|
||||
if (readyState === WebSocket.CLOSED) {
|
||||
setIsLoadingOutputs(false);
|
||||
}
|
||||
}
|
||||
|
||||
if (readyState === WebSocket.CLOSED && isLoadingOutputs) {
|
||||
setIsLoadingOutputs(false);
|
||||
// Stop decomposition if running
|
||||
if (isDecomposing) {
|
||||
cancelDecomposition();
|
||||
setIsDecomposing(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleOptimizeFromDialog = async () => {
|
||||
if (!optimizationDialog.operationId) return;
|
||||
if (!optimizationDialog.operationId || !optimizationDialog.operationName) return;
|
||||
|
||||
// Store target in ref so onComplete callback can access it
|
||||
decompositionTargetRef.current = {
|
||||
operationId: optimizationDialog.operationId,
|
||||
operationName: optimizationDialog.operationName,
|
||||
};
|
||||
|
||||
try {
|
||||
setTerminalOutput("");
|
||||
setIsLoadingOutputs(true);
|
||||
setIsDecomposing(true);
|
||||
setTerminalOutput(""); // Clear terminal output
|
||||
|
||||
// Write the pipeline config to get a YAML file path
|
||||
const response = await fetch("/api/writePipelineConfig", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
|
|
@ -732,24 +833,115 @@ const PipelineGUI: React.FC = () => {
|
|||
|
||||
const { filePath } = await response.json();
|
||||
|
||||
await connect();
|
||||
// Use the WebSocket decompose endpoint for streaming output
|
||||
await startDecomposition(
|
||||
filePath,
|
||||
"data_processing", // Matches the step name in generatePipelineConfig
|
||||
optimizationDialog.operationName
|
||||
);
|
||||
|
||||
sendMessage({
|
||||
yaml_config: filePath,
|
||||
optimize: true,
|
||||
toast({
|
||||
title: "Decomposing Operation",
|
||||
description: "Analyzing and generating candidate decompositions. Watch the console for progress.",
|
||||
});
|
||||
} catch (error) {
|
||||
console.error("Error optimizing operation:", error);
|
||||
console.error("Error starting decomposition:", error);
|
||||
setIsDecomposing(false);
|
||||
toast({
|
||||
title: "Error",
|
||||
description: error.message,
|
||||
description: error instanceof Error ? error.message : "Failed to start decomposition",
|
||||
variant: "destructive",
|
||||
});
|
||||
disconnect();
|
||||
setIsLoadingOutputs(false);
|
||||
}
|
||||
};
|
||||
|
||||
// Handler for decomposition requests from OperationCard dropdown
|
||||
const handleDecompositionRequest = useCallback(
|
||||
async (operationId: string, operationName: string) => {
|
||||
// Store target in ref so onComplete callback can access it
|
||||
decompositionTargetRef.current = {
|
||||
operationId,
|
||||
operationName,
|
||||
};
|
||||
|
||||
try {
|
||||
setIsDecomposing(true);
|
||||
setTerminalOutput(""); // Clear terminal output
|
||||
|
||||
// Set the dialog state so we know which operation we're decomposing
|
||||
setDecomposeDialog((prev) => ({
|
||||
...prev,
|
||||
targetOpId: operationId,
|
||||
operationName: operationName,
|
||||
}));
|
||||
|
||||
// Write the pipeline config to get a YAML file path
|
||||
const response = await fetch("/api/writePipelineConfig", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
default_model: defaultModel,
|
||||
data: { path: currentFile?.path || "" },
|
||||
operations,
|
||||
operation_id: operationId,
|
||||
name: pipelineName,
|
||||
sample_size: sampleSize,
|
||||
optimize: true,
|
||||
namespace: namespace,
|
||||
apiKeys: apiKeys,
|
||||
optimizerModel: optimizerModel,
|
||||
extraPipelineSettings: extraPipelineSettings,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(await response.text());
|
||||
}
|
||||
|
||||
const { filePath } = await response.json();
|
||||
|
||||
// Start decomposition via WebSocket
|
||||
startDecomposition(filePath, "data_processing", operationName);
|
||||
|
||||
toast({
|
||||
title: "Decomposing Operation",
|
||||
description:
|
||||
"Analyzing and generating candidate decompositions. Watch the console for progress.",
|
||||
});
|
||||
} catch (error) {
|
||||
console.error("Error starting decomposition:", error);
|
||||
setIsDecomposing(false);
|
||||
toast({
|
||||
title: "Error",
|
||||
description:
|
||||
error instanceof Error ? error.message : "Failed to start decomposition",
|
||||
variant: "destructive",
|
||||
});
|
||||
}
|
||||
},
|
||||
[
|
||||
defaultModel,
|
||||
currentFile,
|
||||
operations,
|
||||
pipelineName,
|
||||
sampleSize,
|
||||
namespace,
|
||||
apiKeys,
|
||||
optimizerModel,
|
||||
extraPipelineSettings,
|
||||
startDecomposition,
|
||||
setIsDecomposing,
|
||||
setTerminalOutput,
|
||||
toast,
|
||||
]
|
||||
);
|
||||
|
||||
// Register the decomposition handler in context so OperationCard can use it
|
||||
// Using ref assignment instead of state to avoid infinite loops
|
||||
onRequestDecompositionRef.current = handleDecompositionRequest;
|
||||
|
||||
return (
|
||||
<div className="flex flex-col h-full">
|
||||
<div
|
||||
|
|
@ -1053,7 +1245,7 @@ const PipelineGUI: React.FC = () => {
|
|||
variant="destructive"
|
||||
className="rounded-sm whitespace-nowrap"
|
||||
onClick={handleStop}
|
||||
disabled={!isLoadingOutputs}
|
||||
disabled={!isLoadingOutputs && !isDecomposing}
|
||||
>
|
||||
<StopCircle size={16} className="mr-2" />
|
||||
Stop
|
||||
|
|
@ -1144,11 +1336,22 @@ const PipelineGUI: React.FC = () => {
|
|||
operationName={optimizationDialog.operationName}
|
||||
inputData={optimizationDialog.inputData}
|
||||
outputData={optimizationDialog.outputData}
|
||||
numDocsAnalyzed={optimizationDialog.numDocsAnalyzed}
|
||||
onOpenChange={(open) =>
|
||||
setOptimizationDialog((prev) => ({ ...prev, isOpen: open }))
|
||||
}
|
||||
onDecompose={handleOptimizeFromDialog}
|
||||
/>
|
||||
<DecompositionComparisonDialog
|
||||
isOpen={decomposeDialog.isOpen}
|
||||
result={decomposeDialog.result}
|
||||
operationName={decomposeDialog.operationName || undefined}
|
||||
onOpenChange={(open) =>
|
||||
setDecomposeDialog((prev) => ({ ...prev, isOpen: open }))
|
||||
}
|
||||
onApply={handleApplyDecomposition}
|
||||
onCancel={handleCancelDecomposition}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
@ -31,6 +31,7 @@ interface PromptInputProps {
|
|||
onChange: (value: string) => void;
|
||||
disableValidation?: boolean;
|
||||
placeholder?: string;
|
||||
operationType?: "map" | "reduce" | "filter" | "resolve" | "other";
|
||||
}
|
||||
|
||||
export const PromptInput: React.FC<PromptInputProps> = React.memo(
|
||||
|
|
@ -38,31 +39,62 @@ export const PromptInput: React.FC<PromptInputProps> = React.memo(
|
|||
prompt,
|
||||
onChange,
|
||||
disableValidation = false,
|
||||
placeholder = "Enter prompt (must be a Jinja2 template)",
|
||||
placeholder = "Enter prompt (Jinja2 template optional)",
|
||||
operationType = "map",
|
||||
}) => {
|
||||
const validateJinjaTemplate = (value: string) => {
|
||||
const hasJinjaTemplate = (value: string) => {
|
||||
if (disableValidation) return true;
|
||||
const hasOpenBrace = value.includes("{{");
|
||||
const hasCloseBrace = value.includes("}}");
|
||||
return hasOpenBrace && hasCloseBrace;
|
||||
};
|
||||
|
||||
const isReduce = operationType === "reduce";
|
||||
const templateVar = isReduce ? "inputs" : "input";
|
||||
const exampleKey = isReduce ? "inputs[0].content" : "input.content";
|
||||
|
||||
return (
|
||||
<>
|
||||
<Textarea
|
||||
placeholder={placeholder}
|
||||
className={`mb-1 rounded-sm text-sm font-mono ${
|
||||
!validateJinjaTemplate(prompt) ? "border-red-500" : ""
|
||||
!hasJinjaTemplate(prompt) ? "border-amber-400" : ""
|
||||
}`}
|
||||
rows={Math.max(3, Math.ceil(prompt.split("\n").length))}
|
||||
value={prompt}
|
||||
onChange={(e) => onChange(e.target.value)}
|
||||
/>
|
||||
{!validateJinjaTemplate(prompt) && (
|
||||
<div className="text-red-500 text-sm mb-1">
|
||||
Prompt must contain Jinja2 template syntax {"{"}
|
||||
{"{"} and {"}"}
|
||||
{"}"}
|
||||
{!hasJinjaTemplate(prompt) && (
|
||||
<div className="text-amber-600 text-sm mb-1 space-y-1">
|
||||
<div>
|
||||
<strong>Warning:</strong> No Jinja2 template found.
|
||||
</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{isReduce ? (
|
||||
<>
|
||||
{"{"}
|
||||
{"{"} inputs {"}"}
|
||||
{"}"} (the list of all documents) will be auto-appended. For
|
||||
example, if documents have keys {'"'}id{'"'} and {'"'}content
|
||||
{'"'}, all docs with all keys will be included. To reference
|
||||
specific keys, use a for loop: {"{"}% for doc in inputs %{"}"}{" "}
|
||||
{"{"}
|
||||
{"{"} doc.content {"}"}
|
||||
{"}"} {"{"}% endfor %{"}"}.
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
{"{"}
|
||||
{"{"} input {"}"}
|
||||
{"}"} (the full document) will be auto-appended. For example,
|
||||
if documents have keys {'"'}id{'"'}, {'"'}title{'"'}, and {'"'}
|
||||
content{'"'}, all three will be included. To use only specific
|
||||
keys (e.g., just {'"'}content{'"'}), add {"{"}
|
||||
{"{"} input.content {"}"}
|
||||
{"}"} to your prompt.
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
|
|
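To make the hint text concrete, a couple of hypothetical prompts; key names like content and summary are examples, not fields this UI requires:

// Map prompt: referencing a specific key interpolates only that key;
// without any "{{ ... }}" the full "{{ input }}" document is auto-appended.
const mapPrompt =
  "Summarize the following report in one sentence: {{ input.content }}";

// Reduce prompt: iterate over all documents with a Jinja2 for loop.
const reducePrompt =
  "Combine these summaries into one: {% for doc in inputs %} {{ doc.summary }} {% endfor %}";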
@ -77,6 +109,7 @@ PromptInput.propTypes = {
|
|||
onChange: PropTypes.func.isRequired,
|
||||
disableValidation: PropTypes.bool,
|
||||
placeholder: PropTypes.string,
|
||||
operationType: PropTypes.oneOf(["map", "reduce", "filter", "resolve", "other"]),
|
||||
};
|
||||
|
||||
interface SchemaFormProps {
|
||||
|
@ -314,6 +314,7 @@ export const ReduceOperationComponent: React.FC<OperationComponentProps> = ({
|
|||
<PromptInput
|
||||
prompt={operation.prompt || ""}
|
||||
onChange={handlePromptChange}
|
||||
operationType="reduce"
|
||||
/>
|
||||
<OutputSchema
|
||||
schema={schemaItems}
|
||||
|
@ -26,7 +26,7 @@ const ToastViewport = React.forwardRef<
|
|||
ToastViewport.displayName = ToastPrimitives.Viewport.displayName;
|
||||
|
||||
const toastVariants = cva(
|
||||
"group pointer-events-auto relative flex w-full items-center justify-between space-x-2 overflow-y-auto max-h-[80vh] whitespace-pre-wrap break-words rounded-md border p-4 pr-6 shadow-lg transition-all data-[swipe=cancel]:translate-x-0 data-[swipe=end]:translate-x-[var(--radix-toast-swipe-end-x)] data-[swipe=move]:translate-x-[var(--radix-toast-swipe-move-x)] data-[swipe=move]:transition-none data-[state=open]:animate-in data-[state=closed]:animate-out data-[swipe=end]:animate-out data-[state=closed]:fade-out-80 data-[state=closed]:slide-out-to-right-full data-[state=open]:slide-in-from-top-full data-[state=open]:sm:slide-in-from-bottom-full",
|
||||
"group pointer-events-auto relative flex w-full items-start justify-between space-x-2 overflow-hidden max-h-[80vh] whitespace-pre-wrap break-words rounded-md border p-4 pr-6 shadow-lg transition-all data-[swipe=cancel]:translate-x-0 data-[swipe=end]:translate-x-[var(--radix-toast-swipe-end-x)] data-[swipe=move]:translate-x-[var(--radix-toast-swipe-move-x)] data-[swipe=move]:transition-none data-[state=open]:animate-in data-[state=closed]:animate-out data-[swipe=end]:animate-out data-[state=closed]:fade-out-80 data-[state=closed]:slide-out-to-right-full data-[state=open]:slide-in-from-top-full data-[state=open]:sm:slide-in-from-bottom-full",
|
||||
{
|
||||
variants: {
|
||||
variant: {
|
||||
|
@ -22,29 +22,33 @@ export function Toaster() {
|
|||
const descId = `toast-desc-${id}`;
|
||||
return (
|
||||
<Toast key={id} {...props}>
|
||||
<div className="grid gap-1">
|
||||
{title && <ToastTitle>{title}</ToastTitle>}
|
||||
<div className="grid gap-1 flex-1 overflow-hidden">
|
||||
<div className="flex items-center gap-2 shrink-0">
|
||||
{title && <ToastTitle>{title}</ToastTitle>}
|
||||
{description && (
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
className="h-5 w-5 p-0 opacity-70 hover:opacity-100"
|
||||
onClick={() => {
|
||||
const el = document.getElementById(descId);
|
||||
const text = el?.textContent ?? "";
|
||||
if (text) {
|
||||
navigator.clipboard.writeText(text);
|
||||
}
|
||||
}}
|
||||
title="Copy message"
|
||||
>
|
||||
<Copy size={12} />
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
{description && (
|
||||
<ToastDescription id={descId}>{description}</ToastDescription>
|
||||
<div className="overflow-y-auto max-h-[60vh]">
|
||||
<ToastDescription id={descId}>{description}</ToastDescription>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
{description && (
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
className="h-6 px-2 text-xs"
|
||||
onClick={() => {
|
||||
const el = document.getElementById(descId);
|
||||
const text = el?.textContent ?? "";
|
||||
if (text) {
|
||||
navigator.clipboard.writeText(text);
|
||||
}
|
||||
}}
|
||||
title="Copy message"
|
||||
>
|
||||
<Copy size={12} />
|
||||
</Button>
|
||||
)}
|
||||
{action}
|
||||
<ToastClose />
|
||||
</Toast>
|
||||
|
@ -29,6 +29,7 @@ interface PipelineState {
|
|||
validatorPrompt: string;
|
||||
} | null;
|
||||
isLoadingOutputs: boolean;
|
||||
isDecomposing: boolean;
|
||||
numOpRun: number;
|
||||
pipelineName: string;
|
||||
sampleSize: number | null;
|
||||
|
|
@ -59,6 +60,7 @@ interface PipelineContextType extends PipelineState {
|
|||
} | null>
|
||||
>;
|
||||
setIsLoadingOutputs: React.Dispatch<React.SetStateAction<boolean>>;
|
||||
setIsDecomposing: React.Dispatch<React.SetStateAction<boolean>>;
|
||||
setNumOpRun: React.Dispatch<React.SetStateAction<number>>;
|
||||
setPipelineName: React.Dispatch<React.SetStateAction<string>>;
|
||||
setSampleSize: React.Dispatch<React.SetStateAction<number | null>>;
|
||||
|
|
@ -83,6 +85,10 @@ interface PipelineContextType extends PipelineState {
|
|||
setExtraPipelineSettings: React.Dispatch<
|
||||
React.SetStateAction<Record<string, unknown> | null>
|
||||
>;
|
||||
// Ref for triggering decomposition from OperationCard (using ref to avoid infinite loops)
|
||||
onRequestDecompositionRef: React.MutableRefObject<
|
||||
((operationId: string, operationName: string) => void) | null
|
||||
>;
|
||||
}
|
||||
|
||||
const PipelineContext = createContext<PipelineContextType | undefined>(
|
||||
|
|
@ -290,6 +296,7 @@ export const PipelineProvider: React.FC<{ children: React.ReactNode }> = ({
|
|||
localStorageKeys.IS_LOADING_OUTPUTS_KEY,
|
||||
false
|
||||
),
|
||||
isDecomposing: false, // Transient state, not persisted
|
||||
numOpRun: loadFromLocalStorage(localStorageKeys.NUM_OP_RUN_KEY, 0),
|
||||
pipelineName: loadFromLocalStorage(
|
||||
localStorageKeys.PIPELINE_NAME_KEY,
|
||||
|
|
@ -333,6 +340,11 @@ export const PipelineProvider: React.FC<{ children: React.ReactNode }> = ({
|
|||
const stateRef = useRef(state);
|
||||
const [isMounted, setIsMounted] = useState(false);
|
||||
|
||||
// Ref for triggering decomposition from OperationCard (using ref to avoid infinite loops)
|
||||
const onRequestDecompositionRef = useRef<
|
||||
((operationId: string, operationName: string) => void) | null
|
||||
>(null);
|
||||
|
||||
useEffect(() => {
|
||||
stateRef.current = state;
|
||||
}, [state]);
|
||||
|
|
@ -423,6 +435,7 @@ export const PipelineProvider: React.FC<{ children: React.ReactNode }> = ({
|
|||
output: null,
|
||||
terminalOutput: "",
|
||||
isLoadingOutputs: false,
|
||||
isDecomposing: false,
|
||||
numOpRun: 0,
|
||||
pipelineName: mockPipelineName,
|
||||
sampleSize: mockSampleSize,
|
||||
|
|
@ -529,6 +542,10 @@ export const PipelineProvider: React.FC<{ children: React.ReactNode }> = ({
|
|||
(value) => setStateAndUpdate("isLoadingOutputs", value),
|
||||
[setStateAndUpdate]
|
||||
),
|
||||
setIsDecomposing: useCallback(
|
||||
(value) => setStateAndUpdate("isDecomposing", value),
|
||||
[setStateAndUpdate]
|
||||
),
|
||||
setNumOpRun: useCallback(
|
||||
(value) => setStateAndUpdate("numOpRun", value),
|
||||
[setStateAndUpdate]
|
||||
|
|
@ -589,6 +606,7 @@ export const PipelineProvider: React.FC<{ children: React.ReactNode }> = ({
|
|||
(value) => setStateAndUpdate("extraPipelineSettings", value),
|
||||
[setStateAndUpdate]
|
||||
),
|
||||
onRequestDecompositionRef,
|
||||
};
|
||||
|
||||
return (
|
||||
|
@ -0,0 +1,104 @@
|
|||
import { useState, useEffect } from "react";
|
||||
import axios from "axios";
|
||||
import { DecomposeResult, DecomposeRequest } from "@/app/types";
|
||||
import { API_ROUTES } from "@/app/api/constants";
|
||||
|
||||
interface UseDecomposeProps {
|
||||
onComplete?: (result: DecomposeResult) => void;
|
||||
onError?: (error: string) => void;
|
||||
pollInterval?: number;
|
||||
}
|
||||
|
||||
export function useDecompose({
|
||||
onComplete,
|
||||
onError,
|
||||
pollInterval = 2000,
|
||||
}: UseDecomposeProps = {}) {
|
||||
const [taskId, setTaskId] = useState<string | null>(null);
|
||||
const [result, setResult] = useState<DecomposeResult | null>(null);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
|
||||
const submitTask = async (request: DecomposeRequest) => {
|
||||
try {
|
||||
setIsLoading(true);
|
||||
setError(null);
|
||||
setResult(null);
|
||||
|
||||
const response = await axios.post<{ task_id: string }>(
|
||||
API_ROUTES.DECOMPOSE.SUBMIT,
|
||||
request
|
||||
);
|
||||
|
||||
setTaskId(response.data.task_id);
|
||||
} catch (err) {
|
||||
const errorMessage =
|
||||
err instanceof Error ? err.message : "Failed to submit decompose task";
|
||||
setError(errorMessage);
|
||||
onError?.(errorMessage);
|
||||
setIsLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const cancelTask = async () => {
|
||||
if (!taskId) return;
|
||||
|
||||
try {
|
||||
await axios.post(API_ROUTES.DECOMPOSE.CANCEL(taskId));
|
||||
setTaskId(null);
|
||||
setIsLoading(false);
|
||||
} catch (err) {
|
||||
const errorMessage =
|
||||
err instanceof Error ? err.message : "Failed to cancel task";
|
||||
setError(errorMessage);
|
||||
onError?.(errorMessage);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (!taskId) return;
|
||||
|
||||
const pollTask = async () => {
|
||||
try {
|
||||
const response = await axios.get<DecomposeResult>(
|
||||
API_ROUTES.DECOMPOSE.STATUS(taskId)
|
||||
);
|
||||
|
||||
setResult(response.data);
|
||||
|
||||
if (
|
||||
["completed", "failed", "cancelled"].includes(response.data.status)
|
||||
) {
|
||||
setTaskId(null);
|
||||
setIsLoading(false);
|
||||
|
||||
if (response.data.status === "completed") {
|
||||
onComplete?.(response.data);
|
||||
} else if (response.data.status === "failed" && response.data.error) {
|
||||
setError(response.data.error);
|
||||
onError?.(response.data.error);
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
const errorMessage =
|
||||
err instanceof Error ? err.message : "Failed to fetch task status";
|
||||
setError(errorMessage);
|
||||
onError?.(errorMessage);
|
||||
setTaskId(null);
|
||||
setIsLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const interval = setInterval(pollTask, pollInterval);
|
||||
return () => clearInterval(interval);
|
||||
}, [taskId, onComplete, onError, pollInterval]);
|
||||
|
||||
return {
|
||||
submitTask,
|
||||
cancelTask,
|
||||
result,
|
||||
error,
|
||||
isLoading,
|
||||
isRunning: !!taskId,
|
||||
};
|
||||
}
|
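A hedged usage sketch of this polling hook inside a hypothetical component; the import path, component name, YAML path, and operation name are assumptions:

// Illustrative wiring only.
import React from "react";
import { useDecompose } from "@/hooks/useDecompose";

const DecomposeButton: React.FC<{ yamlPath: string }> = ({ yamlPath }) => {
  const { submitTask, isLoading } = useDecompose({
    onComplete: (result) => console.log("winner:", result.winning_directive),
    onError: (message) => console.error(message),
  });

  return (
    <button
      disabled={isLoading}
      onClick={() =>
        submitTask({
          yaml_config: yamlPath,
          step_name: "data_processing",
          op_name: "extract_findings", // hypothetical operation name
        })
      }
    >
      Decompose
    </button>
  );
};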
@ -0,0 +1,145 @@
|
|||
import { useState, useRef, useCallback } from "react";
|
||||
import { DecomposeResult } from "@/app/types";
|
||||
|
||||
interface DecomposeWebSocketMessage {
|
||||
type: "output" | "result" | "error";
|
||||
data?: any;
|
||||
message?: string;
|
||||
traceback?: string;
|
||||
}
|
||||
|
||||
interface UseDecomposeWebSocketProps {
|
||||
namespace: string;
|
||||
onOutput?: (output: string) => void;
|
||||
onComplete?: (result: DecomposeResult) => void;
|
||||
onError?: (error: string) => void;
|
||||
}
|
||||
|
||||
export function useDecomposeWebSocket({
|
||||
namespace,
|
||||
onOutput,
|
||||
onComplete,
|
||||
onError,
|
||||
}: UseDecomposeWebSocketProps) {
|
||||
const [isConnected, setIsConnected] = useState(false);
|
||||
const [isRunning, setIsRunning] = useState(false);
|
||||
const ws = useRef<WebSocket | null>(null);
|
||||
|
||||
const connect = useCallback((): Promise<void> => {
|
||||
return new Promise((resolve, reject) => {
|
||||
if (ws.current?.readyState === WebSocket.OPEN) {
|
||||
resolve();
|
||||
return;
|
||||
}
|
||||
|
||||
if (!namespace) {
|
||||
reject(new Error("Namespace is required for WebSocket connection"));
|
||||
return;
|
||||
}
|
||||
|
||||
const wsUrl = `${
|
||||
process.env.NEXT_PUBLIC_BACKEND_HTTPS === "true" ? "wss://" : "ws://"
|
||||
}${process.env.NEXT_PUBLIC_BACKEND_HOST}:${
|
||||
process.env.NEXT_PUBLIC_BACKEND_PORT
|
||||
}/ws/decompose/${namespace}`;
|
||||
|
||||
ws.current = new WebSocket(wsUrl);
|
||||
|
||||
ws.current.onopen = () => {
|
||||
setIsConnected(true);
|
||||
resolve();
|
||||
};
|
||||
|
||||
ws.current.onclose = () => {
|
||||
setIsConnected(false);
|
||||
setIsRunning(false);
|
||||
};
|
||||
|
||||
ws.current.onerror = (error: Event) => {
|
||||
console.error("Decompose WebSocket error:", error);
|
||||
onError?.("WebSocket connection failed");
|
||||
reject(new Error("WebSocket connection failed"));
|
||||
};
|
||||
|
||||
ws.current.onmessage = (event) => {
|
||||
try {
|
||||
const message: DecomposeWebSocketMessage = JSON.parse(event.data);
|
||||
|
||||
if (message.type === "output") {
|
||||
onOutput?.(message.data);
|
||||
} else if (message.type === "result") {
|
||||
setIsRunning(false);
|
||||
// Convert the result data to DecomposeResult format
|
||||
const result: DecomposeResult = {
|
||||
task_id: "", // Not used in WebSocket mode
|
||||
status: "completed",
|
||||
decomposed_operations: message.data.decomposed_operations,
|
||||
winning_directive: message.data.winning_directive,
|
||||
candidates_evaluated: message.data.candidates_evaluated,
|
||||
original_outputs: message.data.original_outputs,
|
||||
decomposed_outputs: message.data.decomposed_outputs,
|
||||
comparison_rationale: message.data.comparison_rationale,
|
||||
cost: message.data.cost,
|
||||
created_at: new Date().toISOString(),
|
||||
completed_at: new Date().toISOString(),
|
||||
};
|
||||
onComplete?.(result);
|
||||
// Close the connection after receiving result
|
||||
ws.current?.close();
|
||||
} else if (message.type === "error") {
|
||||
setIsRunning(false);
|
||||
const errorMsg = message.data || message.message || "Unknown error";
|
||||
onError?.(errorMsg);
|
||||
ws.current?.close();
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error parsing WebSocket message:", error);
|
||||
onError?.("Failed to parse WebSocket message");
|
||||
}
|
||||
};
|
||||
});
|
||||
}, [namespace, onOutput, onComplete, onError]);
|
||||
|
||||
const startDecomposition = useCallback(
|
||||
async (yamlConfig: string, stepName: string, opName: string) => {
|
||||
try {
|
||||
await connect();
|
||||
setIsRunning(true);
|
||||
|
||||
if (ws.current?.readyState === WebSocket.OPEN) {
|
||||
ws.current.send(
|
||||
JSON.stringify({
|
||||
yaml_config: yamlConfig,
|
||||
step_name: stepName,
|
||||
op_name: opName,
|
||||
})
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
setIsRunning(false);
|
||||
throw error;
|
||||
}
|
||||
},
|
||||
[connect]
|
||||
);
|
||||
|
||||
const cancel = useCallback(() => {
|
||||
if (ws.current?.readyState === WebSocket.OPEN) {
|
||||
ws.current.send(JSON.stringify("kill"));
|
||||
}
|
||||
}, []);
|
||||
|
||||
const disconnect = useCallback(() => {
|
||||
if (ws.current) {
|
||||
ws.current.close();
|
||||
}
|
||||
}, []);
|
||||
|
||||
return {
|
||||
isConnected,
|
||||
isRunning,
|
||||
startDecomposition,
|
||||
cancel,
|
||||
disconnect,
|
||||
};
|
||||
}
|
||||
|
|
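And a minimal sketch of driving the streaming variant above; the namespace, file path, and operation name are assumptions, and the calls must live inside a React component:

// Inside a React component; literal values are illustrative.
const { startDecomposition, cancel, isRunning } = useDecomposeWebSocket({
  namespace: "demo",
  onOutput: (chunk) => console.log(chunk), // stream raw console output
  onComplete: (result) => console.log("done:", result.winning_directive),
  onError: (message) => console.error(message),
});

// e.g. in an async click handler, against an already-written pipeline config:
await startDecomposition("/tmp/pipeline.yaml", "data_processing", "extract_findings");

// ...and to abort a long-running decomposition:
cancel();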
@ -6,7 +6,8 @@ export function cn(...inputs: ClassValue[]) {
|
|||
}
|
||||
|
||||
export function canBeOptimized(operationType: string) {
|
||||
return ["resolve", "map", "reduce", "filter"].includes(operationType);
|
||||
// Only map operations can be decomposed
|
||||
return operationType === "map";
|
||||
}
|
||||
|
||||
export const generateId = () => {
|
||||