Compare commits

..

2 Commits

Author SHA1 Message Date
Shreya Shankar bcac6872f5
Embedding blocking threshold optimization (#473)
* feat: Add runtime blocking threshold optimization

Co-authored-by: ss.shankar505 <ss.shankar505@gmail.com>

* Checkpoint before follow-up message

Co-authored-by: ss.shankar505 <ss.shankar505@gmail.com>

* Refactor: Simplify target_recall retrieval in Equijoin and Resolve

Co-authored-by: ss.shankar505 <ss.shankar505@gmail.com>

* Refactor: Improve blocking documentation and add auto-blocking

Co-authored-by: ss.shankar505 <ss.shankar505@gmail.com>

* allow resolve and equijoin to figure out blocking thresholds on the fly.

---------

Co-authored-by: Cursor Agent <cursoragent@cursor.com>
2025-12-29 19:47:21 -06:00
Shreya Shankar 57a284bcb1
Fast Decomposition for Map Operations in DocWrangler (#472)
* refactor: docwrangler to use a faster decomposition flow, only if the last operation in a pipeline is a map operation.

* refactor: docwrangler to use a faster decomposition flow, only if the last operation in a pipeline is a map operation.

* refactor: update MOAR documentation
2025-12-29 18:22:02 -06:00
7 changed files with 838 additions and 163 deletions

View File

@ -465,7 +465,8 @@ class OpContainer:
return cached_data, 0, curr_logs
# Try to load from checkpoint if available
if not is_build:
# Skip if this operation has bypass_cache: true
if not is_build and not self.config.get("bypass_cache", False):
attempted_input_data = self.runner._load_from_checkpoint_if_exists(
self.name.split("/")[0], self.name.split("/")[-1]
)

View File

@ -17,6 +17,7 @@ from rich.prompt import Confirm
from docetl.operations.base import BaseOperation
from docetl.operations.utils import strict_render
from docetl.operations.utils.blocking import RuntimeBlockingOptimizer
from docetl.operations.utils.progress import RichLoopBar
from docetl.utils import (
completion_cost,
@ -63,6 +64,7 @@ class EquijoinOperation(BaseOperation):
comparison_prompt: str
output: dict[str, Any] | None = None
blocking_threshold: float | None = None
blocking_target_recall: float | None = None
blocking_conditions: list[str] | None = None
limits: dict[str, int] | None = None
comparison_model: str | None = None
@ -250,6 +252,58 @@ class EquijoinOperation(BaseOperation):
if self.status:
self.status.stop()
# Track pre-computed embeddings from auto-optimization
precomputed_left_embeddings = None
precomputed_right_embeddings = None
# Auto-compute blocking threshold if no blocking configuration is provided
if not blocking_threshold and not blocking_conditions and not limit_comparisons:
# Get target recall from operation config (default 0.95)
target_recall = self.config.get("blocking_target_recall", 0.95)
self.console.log(
f"[yellow]No blocking configuration. Auto-computing threshold (target recall: {target_recall:.0%})...[/yellow]"
)
# Create comparison function for threshold optimization
def compare_fn_for_optimization(left_item, right_item):
return self.compare_pair(
self.config["comparison_prompt"],
self.config.get("comparison_model", self.default_model),
left_item,
right_item,
timeout_seconds=self.config.get("timeout", 120),
max_retries_per_timeout=self.config.get(
"max_retries_per_timeout", 2
),
)
# Run threshold optimization
optimizer = RuntimeBlockingOptimizer(
runner=self.runner,
config=self.config,
default_model=self.default_model,
max_threads=self.max_threads,
console=self.console,
target_recall=target_recall,
sample_size=min(100, len(left_data) * len(right_data) // 4),
)
(
blocking_threshold,
precomputed_left_embeddings,
precomputed_right_embeddings,
optimization_cost,
) = optimizer.optimize_equijoin(
left_data,
right_data,
compare_fn_for_optimization,
left_keys=left_keys,
right_keys=right_keys,
)
total_cost += optimization_cost
self.console.log(
f"[green]Using auto-computed blocking threshold: {blocking_threshold}[/green]"
)
# Initial blocking using multiprocessing
num_processes = min(cpu_count(), len(left_data))
@ -298,45 +352,60 @@ class EquijoinOperation(BaseOperation):
)
if blocking_threshold is not None:
embedding_model = self.config.get("embedding_model", self.default_model)
model_input_context_length = model_cost.get(embedding_model, {}).get(
"max_input_tokens", 8192
)
def get_embeddings(
input_data: list[dict[str, Any]], keys: list[str], name: str
) -> tuple[list[list[float]], float]:
texts = [
" ".join(str(item[key]) for key in keys if key in item)[
: model_input_context_length * 4
]
for item in input_data
]
embeddings = []
total_cost = 0
# Use precomputed embeddings if available from auto-optimization
if (
precomputed_left_embeddings is not None
and precomputed_right_embeddings is not None
):
left_embeddings = precomputed_left_embeddings
right_embeddings = precomputed_right_embeddings
else:
embedding_model = self.config.get("embedding_model", self.default_model)
model_input_context_length = model_cost.get(embedding_model, {}).get(
"max_input_tokens", 8192
)
batch_size = 2000
for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
self.console.log(
f"On iteration {i} for creating embeddings for {name} data"
)
response = self.runner.api.gen_embedding(
model=embedding_model,
input=batch,
)
embeddings.extend([data["embedding"] for data in response["data"]])
total_cost += completion_cost(response)
return embeddings, total_cost
left_embeddings, left_cost = get_embeddings(left_data, left_keys, "left")
right_embeddings, right_cost = get_embeddings(
right_data, right_keys, "right"
)
total_cost += left_cost + right_cost
self.console.log(
f"Created embeddings for datasets. Total embedding creation cost: {total_cost}"
)
def get_embeddings(
input_data: list[dict[str, Any]], keys: list[str], name: str
) -> tuple[list[list[float]], float]:
texts = [
" ".join(str(item[key]) for key in keys if key in item)[
: model_input_context_length * 4
]
for item in input_data
]
embeddings = []
embedding_cost = 0
num_batches = (len(texts) + batch_size - 1) // batch_size
for batch_idx, i in enumerate(range(0, len(texts), batch_size)):
batch = texts[i : i + batch_size]
if num_batches > 1:
self.console.log(
f"[dim]Creating {name} embeddings: batch {batch_idx + 1}/{num_batches} "
f"({min(i + batch_size, len(texts))}/{len(texts)} items)[/dim]"
)
response = self.runner.api.gen_embedding(
model=embedding_model,
input=batch,
)
embeddings.extend(
[data["embedding"] for data in response["data"]]
)
embedding_cost += completion_cost(response)
return embeddings, embedding_cost
self.console.log(
f"[cyan]Creating embeddings for {len(left_data)} left + {len(right_data)} right items...[/cyan]"
)
left_embeddings, left_cost = get_embeddings(
left_data, left_keys, "left"
)
right_embeddings, right_cost = get_embeddings(
right_data, right_keys, "right"
)
total_cost += left_cost + right_cost
# Compute all cosine similarities in one call
from sklearn.metrics.pairwise import cosine_similarity

View File

@ -10,10 +10,10 @@ import jinja2
from jinja2 import Template
from litellm import model_cost
from pydantic import Field, ValidationInfo, field_validator, model_validator
from rich.prompt import Confirm
from docetl.operations.base import BaseOperation
from docetl.operations.utils import RichLoopBar, rich_as_completed, strict_render
from docetl.operations.utils.blocking import RuntimeBlockingOptimizer
from docetl.utils import (
completion_cost,
extract_jinja_variables,
@ -40,6 +40,7 @@ class ResolveOperation(BaseOperation):
comparison_model: str | None = None
blocking_keys: list[str] | None = None
blocking_threshold: float | None = Field(None, ge=0, le=1)
blocking_target_recall: float | None = Field(None, ge=0, le=1)
blocking_conditions: list[str] | None = None
input: dict[str, Any] | None = None
embedding_batch_size: int | None = Field(None, gt=0)
@ -266,26 +267,75 @@ class ResolveOperation(BaseOperation):
blocking_keys = self.config.get("blocking_keys", [])
blocking_threshold = self.config.get("blocking_threshold")
blocking_conditions = self.config.get("blocking_conditions", [])
limit_comparisons = self.config.get("limit_comparisons")
total_cost = 0
if self.status:
self.status.stop()
if not blocking_threshold and not blocking_conditions:
# Prompt the user for confirmation
if not Confirm.ask(
"[yellow]Warning: No blocking keys or conditions specified. "
"This may result in a large number of comparisons. "
"We recommend specifying at least one blocking key or condition, or using the optimizer to automatically come up with these. "
"Do you want to continue without blocking?[/yellow]",
console=self.runner.console,
):
raise ValueError("Operation cancelled by user.")
# Track pre-computed embeddings from auto-optimization
precomputed_embeddings = None
# Auto-compute blocking threshold if no blocking configuration is provided
if not blocking_threshold and not blocking_conditions and not limit_comparisons:
# Get target recall from operation config (default 0.95)
target_recall = self.config.get("blocking_target_recall", 0.95)
self.console.log(
f"[yellow]No blocking configuration. Auto-computing threshold (target recall: {target_recall:.0%})...[/yellow]"
)
# Determine blocking keys if not set
auto_blocking_keys = blocking_keys if blocking_keys else None
if not auto_blocking_keys:
prompt_template = self.config.get("comparison_prompt", "")
prompt_vars = extract_jinja_variables(prompt_template)
prompt_vars = [
var
for var in prompt_vars
if var not in ["input", "input1", "input2"]
]
auto_blocking_keys = list(
set([var.split(".")[-1] for var in prompt_vars])
)
if not auto_blocking_keys:
auto_blocking_keys = list(input_data[0].keys())
blocking_keys = auto_blocking_keys
# Create comparison function for threshold optimization
def compare_fn_for_optimization(item1, item2):
return self.compare_pair(
self.config["comparison_prompt"],
self.config.get("comparison_model", self.default_model),
item1,
item2,
blocking_keys=[], # Don't use key-based shortcut during optimization
timeout_seconds=self.config.get("timeout", 120),
max_retries_per_timeout=self.config.get(
"max_retries_per_timeout", 2
),
)
# Run threshold optimization
optimizer = RuntimeBlockingOptimizer(
runner=self.runner,
config=self.config,
default_model=self.default_model,
max_threads=self.max_threads,
console=self.console,
target_recall=target_recall,
sample_size=min(100, len(input_data) * (len(input_data) - 1) // 4),
)
blocking_threshold, precomputed_embeddings, optimization_cost = (
optimizer.optimize_resolve(
input_data,
compare_fn_for_optimization,
blocking_keys=blocking_keys,
)
)
total_cost += optimization_cost
input_schema = self.config.get("input", {}).get("schema", {})
if not blocking_keys:
# Set them to all keys in the input data
blocking_keys = list(input_data[0].keys())
limit_comparisons = self.config.get("limit_comparisons")
total_cost = 0
def is_match(item1: dict[str, Any], item2: dict[str, Any]) -> bool:
return any(
@ -296,120 +346,101 @@ class ResolveOperation(BaseOperation):
# Calculate embeddings if blocking_threshold is set
embeddings = None
if blocking_threshold is not None:
def get_embeddings_batch(
items: list[dict[str, Any]]
) -> list[tuple[list[float], float]]:
# Use precomputed embeddings if available from auto-optimization
if precomputed_embeddings is not None:
embeddings = precomputed_embeddings
else:
self.console.log(
f"[cyan]Creating embeddings for {len(input_data)} items...[/cyan]"
)
embedding_model = self.config.get(
"embedding_model", "text-embedding-3-small"
)
model_input_context_length = model_cost.get(embedding_model, {}).get(
"max_input_tokens", 8192
)
batch_size = self.config.get("embedding_batch_size", 1000)
embeddings = []
embedding_cost = 0.0
num_batches = (len(input_data) + batch_size - 1) // batch_size
texts = [
" ".join(str(item[key]) for key in blocking_keys if key in item)[
: model_input_context_length * 3
for batch_idx in range(num_batches):
start_idx = batch_idx * batch_size
end_idx = min(start_idx + batch_size, len(input_data))
batch = input_data[start_idx:end_idx]
if num_batches > 1:
self.console.log(
f"[dim]Creating embeddings: batch {batch_idx + 1}/{num_batches} "
f"({end_idx}/{len(input_data)} items)[/dim]"
)
texts = [
" ".join(
str(item[key]) for key in blocking_keys if key in item
)[: model_input_context_length * 3]
for item in batch
]
for item in items
]
response = self.runner.api.gen_embedding(
model=embedding_model, input=texts
)
embeddings.extend([data["embedding"] for data in response["data"]])
embedding_cost += completion_cost(response)
response = self.runner.api.gen_embedding(
model=embedding_model, input=texts
)
return [
(data["embedding"], completion_cost(response))
for data in response["data"]
]
total_cost += embedding_cost
embeddings = []
costs = []
with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
for i in range(
0, len(input_data), self.config.get("embedding_batch_size", 1000)
):
batch = input_data[
i : i + self.config.get("embedding_batch_size", 1000)
]
batch_results = list(executor.map(get_embeddings_batch, [batch]))
# Build a mapping of blocking key values to indices
# This is used later for cluster merging (when two items match, merge all items sharing their key values)
value_to_indices: dict[tuple[str, ...], list[int]] = {}
for i, item in enumerate(input_data):
key = tuple(str(item.get(k, "")) for k in blocking_keys)
if key not in value_to_indices:
value_to_indices[key] = []
value_to_indices[key].append(i)
for result in batch_results:
embeddings.extend([r[0] for r in result])
costs.extend([r[1] for r in result])
# Total number of pairs to potentially compare
n = len(input_data)
total_pairs = n * (n - 1) // 2
total_cost += sum(costs)
# Generate all pairs to compare, ensuring no duplicate comparisons
def get_unique_comparison_pairs() -> (
tuple[list[tuple[int, int]], dict[tuple[str, ...], list[int]]]
):
# Create a mapping of values to their indices
value_to_indices: dict[tuple[str, ...], list[int]] = {}
for i, item in enumerate(input_data):
# Create a hashable key from the blocking keys
key = tuple(str(item.get(k, "")) for k in blocking_keys)
if key not in value_to_indices:
value_to_indices[key] = []
value_to_indices[key].append(i)
# Generate pairs for comparison, comparing each unique value combination only once
comparison_pairs = []
keys = list(value_to_indices.keys())
# First, handle comparisons between different values
for i in range(len(keys)):
for j in range(i + 1, len(keys)):
# Only need one comparison between different values
idx1 = value_to_indices[keys[i]][0]
idx2 = value_to_indices[keys[j]][0]
if idx1 < idx2: # Maintain ordering to avoid duplicates
comparison_pairs.append((idx1, idx2))
return comparison_pairs, value_to_indices
comparison_pairs, value_to_indices = get_unique_comparison_pairs()
# Filter pairs based on blocking conditions
def meets_blocking_conditions(pair: tuple[int, int]) -> bool:
i, j = pair
return (
is_match(input_data[i], input_data[j]) if blocking_conditions else False
)
# Start with pairs that meet blocking conditions, or empty list if no conditions
code_blocked_pairs = (
list(filter(meets_blocking_conditions, comparison_pairs))
if blocking_conditions
else []
)
# Apply code-based blocking conditions (check all pairs)
code_blocked_pairs = []
if blocking_conditions:
for i in range(n):
for j in range(i + 1, n):
if is_match(input_data[i], input_data[j]):
code_blocked_pairs.append((i, j))
# Apply cosine similarity blocking if threshold is specified
embedding_blocked_pairs = []
if blocking_threshold is not None and embeddings is not None:
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
similarity_matrix = cosine_similarity(embeddings)
# Add pairs that meet the cosine similarity threshold and aren't already blocked
code_blocked_set = set(code_blocked_pairs)
for i, j in comparison_pairs:
if (i, j) not in code_blocked_set:
similarity = similarity_matrix[i, j]
if similarity >= blocking_threshold:
embedding_blocked_pairs.append((i, j))
# Use numpy to efficiently find all pairs above threshold
i_indices, j_indices = np.triu_indices(n, k=1)
similarities = similarity_matrix[i_indices, j_indices]
above_threshold_mask = similarities >= blocking_threshold
self.console.log(
f"Cosine similarity blocking: added {len(embedding_blocked_pairs)} pairs "
f"(threshold: {blocking_threshold})"
)
# Get pairs above threshold
above_threshold_i = i_indices[above_threshold_mask]
above_threshold_j = j_indices[above_threshold_mask]
# Combine pairs with prioritization for sampling
# Filter out pairs already in code_blocked_set
embedding_blocked_pairs = [
(int(i), int(j))
for i, j in zip(above_threshold_i, above_threshold_j)
if (i, j) not in code_blocked_set
]
# Combine pairs from both blocking methods
all_blocked_pairs = code_blocked_pairs + embedding_blocked_pairs
# If no pairs are blocked at all, fall back to all comparison pairs
if not all_blocked_pairs:
all_blocked_pairs = comparison_pairs
# If no blocking was applied, compare all pairs
if not blocking_conditions and blocking_threshold is None:
all_blocked_pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]
# Apply limit_comparisons with prioritization
if limit_comparisons is not None and len(all_blocked_pairs) > limit_comparisons:
# Prioritize code-based pairs, then sample from embedding pairs if needed
@ -476,18 +507,6 @@ class ResolveOperation(BaseOperation):
cluster_map[root_idx] = root1
clusters[root_idx] = set()
# Calculate and print statistics
total_possible_comparisons = len(input_data) * (len(input_data) - 1) // 2
comparisons_made = len(blocked_pairs)
comparisons_saved = total_possible_comparisons - comparisons_made
self.console.log(
f"[green]Comparisons saved by deduping and blocking: {comparisons_saved} "
f"({(comparisons_saved / total_possible_comparisons) * 100:.2f}%)[/green]"
)
self.console.log(
f"[blue]Number of pairs to compare: {len(blocked_pairs)}[/blue]"
)
# Compute an auto-batch size based on the number of comparisons
def auto_batch() -> int:
# Maximum batch size limit for 4o-mini model
@ -513,7 +532,14 @@ class ResolveOperation(BaseOperation):
# Compare pairs and update clusters in real-time
batch_size = self.config.get("compare_batch_size", auto_batch())
self.console.log(f"Using compare batch size: {batch_size}")
# Log blocking summary
total_possible_comparisons = len(input_data) * (len(input_data) - 1) // 2
self.console.log(
f"Comparing {len(blocked_pairs):,} pairs "
f"({len(blocked_pairs)/total_possible_comparisons*100:.1f}% of {total_possible_comparisons:,} total, "
f"batch size: {batch_size})"
)
pair_costs = 0
pbar = RichLoopBar(

View File

@ -1,4 +1,5 @@
from .api import APIWrapper
from .blocking import RuntimeBlockingOptimizer
from .cache import (
cache,
cache_key,
@ -15,6 +16,7 @@ from .validation import safe_eval, convert_val, convert_dict_schema_to_list_sche
__all__ = [
'APIWrapper',
'RuntimeBlockingOptimizer',
'cache',
'cache_key',
'clear_cache',

View File

@ -0,0 +1,567 @@
"""
Runtime blocking threshold optimization utilities.
This module provides functionality for automatically computing embedding-based
blocking thresholds at runtime when no blocking configuration is provided.
"""
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Callable
import numpy as np
from litellm import model_cost
from rich.console import Console
from docetl.utils import completion_cost, extract_jinja_variables
class RuntimeBlockingOptimizer:
"""
Computes optimal embedding-based blocking thresholds at runtime.
This class samples pairs from the dataset, performs LLM comparisons,
and finds the optimal cosine similarity threshold that achieves a
target recall rate.
"""
def __init__(
self,
runner,
config: dict[str, Any],
default_model: str,
max_threads: int,
console: Console,
target_recall: float = 0.95,
sample_size: int = 100,
sampling_weight: float = 20.0,
):
"""
Initialize the RuntimeBlockingOptimizer.
Args:
runner: The pipeline runner instance.
config: Operation configuration.
default_model: Default LLM model for comparisons.
max_threads: Maximum threads for parallel processing.
console: Rich console for logging.
target_recall: Target recall rate (default 0.95).
sample_size: Number of pairs to sample for threshold estimation.
sampling_weight: Weight for exponential sampling towards higher similarities.
"""
self.runner = runner
self.config = config
self.default_model = default_model
self.max_threads = max_threads
self.console = console
self.target_recall = target_recall
self.sample_size = sample_size
self.sampling_weight = sampling_weight
def compute_embeddings(
    self,
    input_data: list[dict[str, Any]],
    keys: list[str],
    embedding_model: str | None = None,
    batch_size: int = 1000,
) -> tuple[list[list[float]], float]:
    """
    Embed every input document and report what the embedding calls cost.

    Each document is turned into one text by joining the string form of
    the requested keys that are present in it, then cut to three
    characters per token of the model's input limit.

    Args:
        input_data: List of input documents.
        keys: Keys to use for embedding text.
        embedding_model: Model to use for embeddings. Falls back to the
            ``embedding_model`` config entry, then "text-embedding-3-small".
        batch_size: Batch size for embedding computation.

    Returns:
        Tuple of (embeddings list, total cost).
    """
    if not embedding_model:
        embedding_model = self.config.get(
            "embedding_model", "text-embedding-3-small"
        )

    token_limit = model_cost.get(embedding_model, {}).get("max_input_tokens", 8192)
    char_limit = token_limit * 3

    texts = []
    for record in input_data:
        joined = " ".join(str(record[field]) for field in keys if field in record)
        texts.append(joined[:char_limit])

    total_items = len(texts)
    self.console.log(f"[cyan]Creating embeddings for {total_items} items...[/cyan]")

    embeddings = []
    total_cost = 0.0
    num_batches = (total_items + batch_size - 1) // batch_size
    show_progress = num_batches > 1

    for batch_number, start in enumerate(range(0, total_items, batch_size), start=1):
        batch = texts[start : start + batch_size]
        if show_progress:
            # Progress is only worth printing when there is more than one batch.
            self.console.log(
                f"[dim] Batch {batch_number}/{num_batches} "
                f"({len(embeddings) + len(batch)}/{total_items} items)[/dim]"
            )
        response = self.runner.api.gen_embedding(
            model=embedding_model,
            input=batch,
        )
        for entry in response["data"]:
            embeddings.append(entry["embedding"])
        total_cost += completion_cost(response)

    return embeddings, total_cost
def calculate_cosine_similarities_self(
    self, embeddings: list[list[float]]
) -> list[tuple[int, int, float]]:
    """
    Calculate pairwise cosine similarities for self-join.

    Args:
        embeddings: List of embedding vectors.

    Returns:
        List of (i, j, similarity) tuples for all pairs where i < j.
        Empty when fewer than two embeddings are given.
    """
    # With fewer than two vectors there are no pairs. Returning early also
    # avoids building a 1-D array from an empty list, on which the
    # axis=1 norm below would raise.
    if len(embeddings) < 2:
        return []

    embeddings_array = np.asarray(embeddings, dtype=float)
    norms = np.linalg.norm(embeddings_array, axis=1)
    # Avoid division by zero
    norms = np.where(norms == 0, 1e-10, norms)

    dot_products = np.dot(embeddings_array, embeddings_array.T)
    similarities_matrix = dot_products / np.outer(norms, norms)

    # Upper triangle without the diagonal: each unordered pair exactly once.
    i, j = np.triu_indices(len(embeddings), k=1)
    return list(zip(i.tolist(), j.tolist(), similarities_matrix[i, j].tolist()))
def calculate_cosine_similarities_cross(
    self,
    left_embeddings: list[list[float]],
    right_embeddings: list[list[float]],
) -> list[tuple[int, int, float]]:
    """
    Calculate cosine similarities between two sets of embeddings.

    Args:
        left_embeddings: Embeddings for left dataset.
        right_embeddings: Embeddings for right dataset.

    Returns:
        List of (left_idx, right_idx, similarity) tuples, ordered by
        left index and then right index. Empty when either side is empty.
    """
    # No cross pairs exist if either side is empty. An empty list would
    # also become a 1-D array, on which the axis=1 norms below would raise.
    if len(left_embeddings) == 0 or len(right_embeddings) == 0:
        return []

    left_array = np.asarray(left_embeddings, dtype=float)
    right_array = np.asarray(right_embeddings, dtype=float)

    dot_product = np.dot(left_array, right_array.T)
    norm_left = np.linalg.norm(left_array, axis=1)
    norm_right = np.linalg.norm(right_array, axis=1)
    # Avoid division by zero
    norm_left = np.where(norm_left == 0, 1e-10, norm_left)
    norm_right = np.where(norm_right == 0, 1e-10, norm_right)

    similarities = dot_product / np.outer(norm_left, norm_right)

    # tolist() converts the whole matrix to Python floats in one call.
    return [
        (i, j, sim)
        for i, row in enumerate(similarities.tolist())
        for j, sim in enumerate(row)
    ]
def sample_pairs(
    self,
    similarities: list[tuple[int, int, float]],
    num_bins: int = 10,
    stratified_fraction: float = 0.5,
) -> list[tuple[int, int]]:
    """
    Sample pairs using a hybrid of stratified and exponential-weighted sampling.

    This ensures coverage across the similarity distribution while still
    focusing on high-similarity pairs where matches are more likely.
    The result never holds more than ``self.sample_size`` pairs, and it
    holds exactly ``min(self.sample_size, len(similarities))`` of them.

    Args:
        similarities: List of (i, j, similarity) tuples.
        num_bins: Number of bins for stratified sampling.
        stratified_fraction: Fraction of samples to allocate to stratified sampling.

    Returns:
        List of sampled (i, j) pairs, without duplicates.
    """
    if len(similarities) == 0:
        return []

    sample_count = min(self.sample_size, len(similarities))
    if sample_count <= 0:
        return []

    stratified_count = int(sample_count * stratified_fraction)
    sampled_indices: set[int] = set()
    sim_values = np.array([sim[2] for sim in similarities])

    # Part 1: Stratified sampling across bins
    if stratified_count > 0 and num_bins > 0:
        # The small epsilon keeps the maximum value inside the last
        # half-open bin.
        bin_edges = np.linspace(
            sim_values.min(), sim_values.max() + 1e-9, num_bins + 1
        )
        samples_per_bin = max(1, stratified_count // num_bins)

        # Walk from the highest-similarity bin down, so that a budget
        # smaller than the number of bins is spent where matches are
        # most likely.
        for bin_idx in reversed(range(num_bins)):
            budget = stratified_count - len(sampled_indices)
            if budget <= 0:
                # Without this cap, one pick per bin could exceed the
                # stratified share and even the overall sample size.
                break

            bin_mask = (sim_values >= bin_edges[bin_idx]) & (
                sim_values < bin_edges[bin_idx + 1]
            )
            bin_indices = np.where(bin_mask)[0]
            if len(bin_indices) == 0:
                continue

            # Within each bin, use exponential weighting
            bin_weights = np.exp(self.sampling_weight * sim_values[bin_indices])
            bin_weights /= bin_weights.sum()
            n_to_sample = min(samples_per_bin, len(bin_indices), budget)
            chosen = np.random.choice(
                bin_indices,
                size=n_to_sample,
                replace=False,
                p=bin_weights,
            )
            sampled_indices.update(chosen.tolist())

    # Part 2: Exponential-weighted sampling for remaining slots. This is
    # sized from what is actually still free, so slots left unused by
    # empty bins are filled here.
    remaining_slots = sample_count - len(sampled_indices)
    if remaining_slots > 0:
        remaining_indices = [
            i for i in range(len(similarities)) if i not in sampled_indices
        ]
        if remaining_indices:
            weights = np.exp(self.sampling_weight * sim_values[remaining_indices])
            weights /= weights.sum()
            n_to_sample = min(remaining_slots, len(remaining_indices))
            chosen = np.random.choice(
                remaining_indices,
                size=n_to_sample,
                replace=False,
                p=weights,
            )
            sampled_indices.update(chosen.tolist())

    # Sorted so the output order does not depend on set iteration order.
    return [
        (similarities[i][0], similarities[i][1]) for i in sorted(sampled_indices)
    ]
def _print_similarity_histogram(
    self,
    similarities: list[tuple[int, int, float]],
    comparison_results: list[tuple[int, int, bool]],
    threshold: float | None = None,
):
    """
    Print a histogram of embedding cosine similarity distribution.

    Each of the 20 bins shows a bar scaled to the fullest bin, the number
    of pairs in it, and the share of LLM-labelled pairs in that bin that
    were matches. Nothing is printed when no similarity other than
    exactly 1 is present.

    Args:
        similarities: List of (i, j, similarity) tuples.
        comparison_results: List of (i, j, is_match) from LLM comparisons.
        threshold: Optional threshold to highlight in the histogram.
    """
    # Filter out self-similarities (similarity == 1)
    flat_similarities = [sim[2] for sim in similarities if sim[2] != 1]
    if not flat_similarities:
        return

    hist, bin_edges = np.histogram(flat_similarities, bins=20)
    max_bar_width = 40
    max_count = max(int(hist.max()), 1)
    normalized_hist = [int(count / max_count * max_bar_width) for count in hist]
    last_bin = len(normalized_hist) - 1

    # Look up the LLM verdicts once, keeping only the labelled pairs, so
    # the per-bin loop does not rescan every pair.
    true_labels = {(i, j): is_match for i, j, is_match in comparison_results}
    labeled = [
        (sim, bool(true_labels[(i, j)]))
        for i, j, sim in similarities
        if (i, j) in true_labels
    ]

    # `is not None` rather than truthiness: 0.0 is a valid threshold.
    has_threshold = threshold is not None
    pairs_above_threshold = (
        sum(1 for sim in flat_similarities if sim >= threshold)
        if has_threshold
        else 0
    )
    total_pairs = len(flat_similarities)

    def in_bin(value: float, idx: int) -> bool:
        # np.histogram bins are half-open except the last one, which also
        # includes its right edge; mirror that so the maximum is counted.
        lower, upper = bin_edges[idx], bin_edges[idx + 1]
        if idx == last_bin:
            return lower <= value <= upper
        return lower <= value < upper

    lines = []
    for i, count in enumerate(normalized_hist):
        bar = "█" * count
        bin_start, bin_end = bin_edges[i], bin_edges[i + 1]
        label = f"{bin_start:.2f}-{bin_end:.2f}"

        # Share of labelled pairs in this bin that the LLM called a match.
        verdicts = [is_match for sim, is_match in labeled if in_bin(sim, i)]
        if verdicts:
            true_match_percent = (sum(verdicts) / len(verdicts)) * 100
            label_info = f"[green]{true_match_percent:5.1f}%[/green] match"
        else:
            label_info = "[dim]--[/dim]"

        # Highlight the bin containing the threshold
        if has_threshold and in_bin(threshold, i):
            lines.append(
                f"[bold yellow]{label}[/bold yellow] {bar:<{max_bar_width}} "
                f"[dim]n={hist[i]:>5}[/dim] {label_info} [bold yellow]◀ threshold[/bold yellow]"
            )
        else:
            lines.append(
                f"{label} {bar:<{max_bar_width}} "
                f"[dim]n={hist[i]:>5}[/dim] {label_info}"
            )

    from rich.panel import Panel

    histogram_content = "\n".join(lines)
    if has_threshold:
        title = f"Similarity Distribution ({pairs_above_threshold:,} of {total_pairs:,} pairs ≥ {threshold:.4f})"
    else:
        # Formatting None with ":.4f" raises TypeError, so the no-threshold
        # case gets its own title.
        title = f"Similarity Distribution ({total_pairs:,} pairs)"
    self.console.log(Panel(histogram_content, title=title, border_style="cyan"))
def find_optimal_threshold(
    self,
    comparisons: list[tuple[int, int, bool]],
    similarities: list[tuple[int, int, float]],
) -> tuple[float, float]:
    """
    Find the optimal similarity threshold that achieves target recall.

    Candidate thresholds are 100 evenly spaced values in [0, 1]. The
    highest candidate whose recall on the labelled sample reaches
    ``self.target_recall`` is chosen, then rounded down to four decimals
    so that rounding can never exclude a match the candidate admitted.

    Args:
        comparisons: List of (i, j, is_match) from LLM comparisons.
        similarities: List of (i, j, similarity) tuples.

    Returns:
        Tuple of (optimal_threshold, achieved_recall). When the sample
        contains no match, the threshold is the 99th percentile of all
        similarities (0.9 if there are none) and the recall is 0.0.
    """
    if not comparisons or not any(comp[2] for comp in comparisons):
        # No matches found, use a high threshold to be conservative
        self.console.log(
            "[yellow]No matches found in sample. Using 99th percentile "
            "similarity as threshold.[/yellow]"
        )
        all_sims = [sim[2] for sim in similarities]
        threshold = float(np.percentile(all_sims, 99)) if all_sims else 0.9
        return threshold, 0.0

    # Coerce labels to real booleans so they can be used as a mask.
    true_labels = np.array([bool(comp[2]) for comp in comparisons], dtype=bool)
    sim_dict = {(i, j): sim for i, j, sim in similarities}
    sim_scores = np.array([sim_dict.get((i, j), 0.0) for i, j, _ in comparisons])

    # Recall only depends on the matching pairs; there is at least one
    # here because of the guard above.
    positive_scores = sim_scores[true_labels]

    thresholds = np.linspace(0, 1, 100)
    # Row t holds, for threshold t, which matches would be kept.
    recalls = (positive_scores[np.newaxis, :] >= thresholds[:, np.newaxis]).mean(
        axis=1
    )

    # Find highest threshold that achieves target recall
    valid_indices = np.flatnonzero(recalls >= self.target_recall)
    target_met = valid_indices.size > 0
    if target_met:
        best_idx = int(valid_indices[-1])
    else:
        # If no threshold achieves target recall, use the one with highest recall
        best_idx = int(np.argmax(recalls))

    # Round DOWN to four decimals. Rounding to nearest could push the
    # threshold above a match's similarity and lose recall silently.
    optimal_threshold = float(np.floor(thresholds[best_idx] * 1e4) / 1e4)
    # Report the recall of the threshold that is actually returned.
    achieved_recall = float((positive_scores >= optimal_threshold).mean())

    if not target_met:
        self.console.log(
            f"[yellow]Warning: Could not achieve target recall {self.target_recall:.0%}. "
            f"Using threshold {optimal_threshold:.4f} with recall {achieved_recall:.2%}.[/yellow]"
        )

    return optimal_threshold, achieved_recall
def optimize_resolve(
self,
input_data: list[dict[str, Any]],
compare_fn: Callable[[dict, dict], tuple[bool, float, str]],
blocking_keys: list[str] | None = None,
) -> tuple[float, list[list[float]], float]:
"""
Compute optimal blocking threshold for resolve operation.
Args:
input_data: Input dataset.
compare_fn: Function to compare two items, returns (is_match, cost, prompt).
blocking_keys: Keys to use for blocking. If None, extracted from prompt.
Returns:
Tuple of (threshold, embeddings, total_cost).
"""
from rich.panel import Panel
# Determine blocking keys
if not blocking_keys:
prompt_template = self.config.get("comparison_prompt", "")
prompt_vars = extract_jinja_variables(prompt_template)
prompt_vars = [
var for var in prompt_vars if var not in ["input", "input1", "input2"]
]
blocking_keys = list(set([var.split(".")[-1] for var in prompt_vars]))
if not blocking_keys:
blocking_keys = list(input_data[0].keys())
# Compute embeddings
embeddings, embedding_cost = self.compute_embeddings(input_data, blocking_keys)
# Calculate similarities
similarities = self.calculate_cosine_similarities_self(embeddings)
# Sample pairs
sampled_pairs = self.sample_pairs(similarities)
if not sampled_pairs:
self.console.log(
"[yellow]No pairs to sample. Using default threshold 0.8.[/yellow]"
)
return 0.8, embeddings, embedding_cost
# Perform comparisons
comparisons = []
comparison_cost = 0.0
matches_found = 0
with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
futures = {
executor.submit(compare_fn, input_data[i], input_data[j]): (i, j)
for i, j in sampled_pairs
}
for future in as_completed(futures):
i, j = futures[future]
try:
is_match, cost, _ = future.result()
comparisons.append((i, j, is_match))
comparison_cost += cost
if is_match:
matches_found += 1
except Exception as e:
self.console.log(f"[red]Comparison error: {e}[/red]")
comparisons.append((i, j, False))
# Find optimal threshold
threshold, achieved_recall = self.find_optimal_threshold(
comparisons, similarities
)
total_cost = embedding_cost + comparison_cost
# Print histogram visualization
self._print_similarity_histogram(similarities, comparisons, threshold)
# Print summary
n = len(input_data)
total_pairs = n * (n - 1) // 2
pairs_above = sum(1 for s in similarities if s[2] >= threshold)
summary = (
f"[bold]Blocking keys:[/bold] {blocking_keys}\n"
f"[bold]Sampled:[/bold] {len(sampled_pairs)} pairs → {matches_found} matches ({matches_found/len(sampled_pairs)*100:.1f}%)\n"
f"[bold]Threshold:[/bold] {threshold:.4f}{achieved_recall:.1%} recall (target: {self.target_recall:.0%})\n"
f"[bold]Pairs to compare:[/bold] {pairs_above:,} of {total_pairs:,} ({pairs_above/total_pairs*100:.1f}%)\n"
f"[bold]Optimization cost:[/bold] ${total_cost:.4f}"
)
self.console.log(
Panel(
summary, title="Blocking Threshold Optimization", border_style="green"
)
)
return threshold, embeddings, total_cost
def optimize_equijoin(
    self,
    left_data: list[dict[str, Any]],
    right_data: list[dict[str, Any]],
    compare_fn: Callable[[dict, dict], tuple[bool, float]],
    left_keys: list[str] | None = None,
    right_keys: list[str] | None = None,
) -> tuple[float, list[list[float]], list[list[float]], float]:
    """
    Compute optimal blocking threshold for equijoin operation.

    Embeds both datasets, scores left/right pairs by similarity, labels a
    sample of those pairs with ``compare_fn``, and derives a threshold from
    the labelled sample. A summary panel is logged to ``self.console``.

    Args:
        left_data: Left dataset.
        right_data: Right dataset.
        compare_fn: Function to compare two items, returns (is_match, cost).
        left_keys: Keys to use for left dataset embeddings. When empty or
            None, the keys of the first left record are used.
        right_keys: Keys to use for right dataset embeddings. When empty or
            None, the keys of the first right record are used.

    Returns:
        Tuple of (threshold, left_embeddings, right_embeddings, total_cost).
        When no pairs can be sampled, the threshold is the default 0.8 and
        the cost covers the embeddings only.
    """
    # Determine keys: default to every field of the first record.
    if not left_keys:
        left_keys = list(left_data[0].keys()) if left_data else []
    if not right_keys:
        right_keys = list(right_data[0].keys()) if right_data else []

    # Compute embeddings for both sides and pool their cost.
    left_embeddings, left_cost = self.compute_embeddings(left_data, left_keys)
    right_embeddings, right_cost = self.compute_embeddings(right_data, right_keys)
    embedding_cost = left_cost + right_cost

    # Cross similarities; each entry is indexed below as s[2] for its score.
    similarities = self.calculate_cosine_similarities_cross(
        left_embeddings, right_embeddings
    )

    # Sample (left_index, right_index) pairs to label.
    sampled_pairs = self.sample_pairs(similarities)
    if not sampled_pairs:
        self.console.log(
            "[yellow]No pairs to sample. Using default threshold 0.8.[/yellow]"
        )
        return 0.8, left_embeddings, right_embeddings, embedding_cost

    # Label the sampled pairs concurrently.
    comparisons = []
    comparison_cost = 0.0
    matches_found = 0
    with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
        futures = {
            executor.submit(compare_fn, left_data[i], right_data[j]): (i, j)
            for i, j in sampled_pairs
        }
        for future in as_completed(futures):
            i, j = futures[future]
            try:
                is_match, cost = future.result()
                comparisons.append((i, j, is_match))
                comparison_cost += cost
                if is_match:
                    matches_found += 1
            except Exception as e:
                # Best effort: a failed comparison is logged and recorded as
                # a non-match so one bad call does not abort the optimization.
                self.console.log(f"[red]Comparison error: {e}[/red]")
                comparisons.append((i, j, False))

    # Find optimal threshold from the labelled sample.
    threshold, achieved_recall = self.find_optimal_threshold(
        comparisons, similarities
    )
    total_cost = embedding_cost + comparison_cost

    # Print histogram visualization
    self._print_similarity_histogram(similarities, comparisons, threshold)

    # Imported at the point of use: only the summary panel needs it, so the
    # early return above does not depend on it.
    from rich.panel import Panel

    # Print summary
    total_pairs = len(left_data) * len(right_data)
    pairs_above = sum(1 for s in similarities if s[2] >= threshold)
    summary = (
        f"[bold]Left keys:[/bold] {left_keys} [bold]Right keys:[/bold] {right_keys}\n"
        f"[bold]Sampled:[/bold] {len(sampled_pairs)} pairs → {matches_found} matches ({matches_found/len(sampled_pairs)*100:.1f}%)\n"
        # Separator keeps the threshold and the recall from running together.
        f"[bold]Threshold:[/bold] {threshold:.4f} → {achieved_recall:.1%} recall (target: {self.target_recall:.0%})\n"
        f"[bold]Pairs to compare:[/bold] {pairs_above:,} of {total_pairs:,} ({pairs_above/total_pairs*100:.1f}%)\n"
        f"[bold]Optimization cost:[/bold] ${total_cost:.4f}"
    )
    self.console.log(
        Panel(
            summary, title="Blocking Threshold Optimization", border_style="green"
        )
    )
    return threshold, left_embeddings, right_embeddings, total_cost

View File

@ -34,9 +34,9 @@ This Equijoin operation matches job candidates to job postings:
The prompt template uses Jinja2 syntax, allowing you to reference input fields directly (e.g., `left.skills`). You can reference the left and right documents using `left` and `right` respectively.
!!! warning "Performance Consideration"
!!! info "Automatic Blocking"
For large datasets, running comparisons with an LLM can be time-consuming. It's recommended to optimize your pipeline using `docetl build pipeline.yaml` to generate efficient blocking rules for the operation.
If you don't specify any blocking configuration (`blocking_threshold`, `blocking_conditions`, or `limit_comparisons`), the Equijoin operation will automatically compute an optimal embedding-based blocking threshold at runtime. It samples pairs from your data, runs LLM comparisons on the sample, and picks a threshold that reaches the target recall on that sample. The target recall defaults to 95% (`0.95`); you can change it with the `blocking_target_recall` parameter.
## Blocking
@ -95,10 +95,19 @@ A full Equijoin step combining both ideas might look like:
Equijoin shares many parameters with the Resolve operation. For a detailed list of required and optional parameters, please see the [Parameters section in the Resolve operation documentation](resolve.md#required-parameters).
Key differences for Equijoin include:
### Equijoin-Specific Parameters
| Parameter | Description | Default |
| ------------------------- | --------------------------------------------------------------------------------- | ----------------------------- |
| `limits` | Maximum matches for each left/right item: `{"left": n, "right": m}` | No limit |
| `blocking_keys` | Keys for embedding blocking: `{"left": [...], "right": [...]}` | All keys from each dataset |
| `blocking_threshold` | Embedding similarity threshold for considering pairs | Auto-computed if not set |
| `blocking_target_recall` | Target recall when auto-computing blocking threshold (0.0 to 1.0) | 0.95 |
Key differences from Resolve:
- `resolution_prompt` is not used in Equijoin.
- `limits` is specific to Equijoin; it sets the maximum number of matches for each left and right item.
- `blocking_keys` uses a dict with `left` and `right` keys instead of a simple list.
## Incorporating Into a Pipeline

View File

@ -44,9 +44,9 @@ This Resolve operation processes patient names to identify and standardize dupli
Note: The prompt templates use Jinja2 syntax, allowing you to reference input fields directly (e.g., `input1.patient_name`).
!!! warning "Performance Consideration"
!!! info "Automatic Blocking"
You should not run this operation as-is unless your dataset is small! Running O(n^2) comparisons with an LLM can be extremely time-consuming for large datasets. Instead, optimize your pipeline first using `docetl build pipeline.yaml` and run the optimized version, which will generate efficient blocking rules for the operation. Make sure you've set `optimize: true` in your resolve operation config.
If you don't specify any blocking configuration (`blocking_threshold`, `blocking_conditions`, or `limit_comparisons`), the Resolve operation will automatically compute an optimal embedding-based blocking threshold at runtime. It samples pairs from your data, runs LLM comparisons on the sample, and picks a threshold that reaches the target recall on that sample. The target recall defaults to 95% (`0.95`); you can change it with the `blocking_target_recall` parameter.
## Blocking
@ -132,7 +132,8 @@ After determining eligible pairs for comparison, the Resolve operation uses a Un
| `resolution_model` | The language model to use for reducing matched entries | Falls back to `default_model` |
| `comparison_model` | The language model to use for comparing potential matches | Falls back to `default_model` |
| `blocking_keys` | List of keys to use for initial blocking | All keys in the input data |
| `blocking_threshold` | Embedding similarity threshold for considering entries as potential matches | None |
| `blocking_threshold` | Embedding similarity threshold for considering entries as potential matches | Auto-computed if not set |
| `blocking_target_recall` | Target recall when auto-computing blocking threshold (0.0 to 1.0) | 0.95 |
| `blocking_conditions` | List of conditions for initial blocking | [] |
| `input` | Specifies the schema or keys to subselect from each item to pass into the prompts | All keys from input items |
| `embedding_batch_size` | The number of entries to send to the embedding model at a time | 1000 |
@ -140,9 +141,9 @@ After determining eligible pairs for comparison, the Resolve operation uses a Un
| `limit_comparisons` | Maximum number of comparisons to perform | None |
| `timeout` | Timeout for each LLM call in seconds | 120 |
| `max_retries_per_timeout` | Maximum number of retries per timeout | 2 |
| `sample` | Number of samples to use for the operation | None |
| `litellm_completion_kwargs` | Additional parameters to pass to LiteLLM completion calls. | {} |
| `bypass_cache` | If true, bypass the cache for this operation. | False |
| `sample` | Number of samples to use for the operation | None |
| `litellm_completion_kwargs` | Additional parameters to pass to LiteLLM completion calls. | {} |
| `bypass_cache` | If true, bypass the cache for this operation. | False |
## Best Practices