Compare commits

3 commits · main...dwoptimize

| Author | SHA1 | Date |
|---|---|---|
|  | 4e9f4532ad |  |
|  | ed2f1fd0a1 |  |
|  | c0445abba9 |  |
```diff
@@ -465,8 +465,7 @@ class OpContainer:
             return cached_data, 0, curr_logs

         # Try to load from checkpoint if available
-        # Skip if this operation has bypass_cache: true
-        if not is_build and not self.config.get("bypass_cache", False):
+        if not is_build:
             attempted_input_data = self.runner._load_from_checkpoint_if_exists(
                 self.name.split("/")[0], self.name.split("/")[-1]
             )
```
```diff
@@ -17,7 +17,6 @@ from rich.prompt import Confirm

 from docetl.operations.base import BaseOperation
 from docetl.operations.utils import strict_render
-from docetl.operations.utils.blocking import RuntimeBlockingOptimizer
 from docetl.operations.utils.progress import RichLoopBar
 from docetl.utils import (
     completion_cost,
```
```diff
@@ -64,7 +63,6 @@ class EquijoinOperation(BaseOperation):
         comparison_prompt: str
         output: dict[str, Any] | None = None
         blocking_threshold: float | None = None
-        blocking_target_recall: float | None = None
         blocking_conditions: list[str] | None = None
         limits: dict[str, int] | None = None
         comparison_model: str | None = None
```
```diff
@@ -252,58 +250,6 @@ class EquijoinOperation(BaseOperation):
         if self.status:
             self.status.stop()

-        # Track pre-computed embeddings from auto-optimization
-        precomputed_left_embeddings = None
-        precomputed_right_embeddings = None
-
-        # Auto-compute blocking threshold if no blocking configuration is provided
-        if not blocking_threshold and not blocking_conditions and not limit_comparisons:
-            # Get target recall from operation config (default 0.95)
-            target_recall = self.config.get("blocking_target_recall", 0.95)
-            self.console.log(
-                f"[yellow]No blocking configuration. Auto-computing threshold (target recall: {target_recall:.0%})...[/yellow]"
-            )
-
-            # Create comparison function for threshold optimization
-            def compare_fn_for_optimization(left_item, right_item):
-                return self.compare_pair(
-                    self.config["comparison_prompt"],
-                    self.config.get("comparison_model", self.default_model),
-                    left_item,
-                    right_item,
-                    timeout_seconds=self.config.get("timeout", 120),
-                    max_retries_per_timeout=self.config.get(
-                        "max_retries_per_timeout", 2
-                    ),
-                )
-
-            # Run threshold optimization
-            optimizer = RuntimeBlockingOptimizer(
-                runner=self.runner,
-                config=self.config,
-                default_model=self.default_model,
-                max_threads=self.max_threads,
-                console=self.console,
-                target_recall=target_recall,
-                sample_size=min(100, len(left_data) * len(right_data) // 4),
-            )
-            (
-                blocking_threshold,
-                precomputed_left_embeddings,
-                precomputed_right_embeddings,
-                optimization_cost,
-            ) = optimizer.optimize_equijoin(
-                left_data,
-                right_data,
-                compare_fn_for_optimization,
-                left_keys=left_keys,
-                right_keys=right_keys,
-            )
-            total_cost += optimization_cost
-            self.console.log(
-                f"[green]Using auto-computed blocking threshold: {blocking_threshold}[/green]"
-            )
-
         # Initial blocking using multiprocessing
         num_processes = min(cpu_count(), len(left_data))
```
```diff
@@ -352,60 +298,45 @@ class EquijoinOperation(BaseOperation):
             )

         if blocking_threshold is not None:
-            # Use precomputed embeddings if available from auto-optimization
-            if (
-                precomputed_left_embeddings is not None
-                and precomputed_right_embeddings is not None
-            ):
-                left_embeddings = precomputed_left_embeddings
-                right_embeddings = precomputed_right_embeddings
-            else:
-                embedding_model = self.config.get("embedding_model", self.default_model)
-                model_input_context_length = model_cost.get(embedding_model, {}).get(
-                    "max_input_tokens", 8192
-                )
-                batch_size = 2000
-
-                def get_embeddings(
-                    input_data: list[dict[str, Any]], keys: list[str], name: str
-                ) -> tuple[list[list[float]], float]:
-                    texts = [
-                        " ".join(str(item[key]) for key in keys if key in item)[
-                            : model_input_context_length * 4
-                        ]
-                        for item in input_data
-                    ]
-                    embeddings = []
-                    embedding_cost = 0
-                    num_batches = (len(texts) + batch_size - 1) // batch_size
-                    for batch_idx, i in enumerate(range(0, len(texts), batch_size)):
-                        batch = texts[i : i + batch_size]
-                        if num_batches > 1:
-                            self.console.log(
-                                f"[dim]Creating {name} embeddings: batch {batch_idx + 1}/{num_batches} "
-                                f"({min(i + batch_size, len(texts))}/{len(texts)} items)[/dim]"
-                            )
-                        response = self.runner.api.gen_embedding(
-                            model=embedding_model,
-                            input=batch,
-                        )
-                        embeddings.extend(
-                            [data["embedding"] for data in response["data"]]
-                        )
-                        embedding_cost += completion_cost(response)
-                    return embeddings, embedding_cost
-
-                self.console.log(
-                    f"[cyan]Creating embeddings for {len(left_data)} left + {len(right_data)} right items...[/cyan]"
-                )
-                left_embeddings, left_cost = get_embeddings(
-                    left_data, left_keys, "left"
-                )
-                right_embeddings, right_cost = get_embeddings(
-                    right_data, right_keys, "right"
-                )
-                total_cost += left_cost + right_cost
+            embedding_model = self.config.get("embedding_model", self.default_model)
+            model_input_context_length = model_cost.get(embedding_model, {}).get(
+                "max_input_tokens", 8192
+            )
+
+            def get_embeddings(
+                input_data: list[dict[str, Any]], keys: list[str], name: str
+            ) -> tuple[list[list[float]], float]:
+                texts = [
+                    " ".join(str(item[key]) for key in keys if key in item)[
+                        : model_input_context_length * 4
+                    ]
+                    for item in input_data
+                ]
+                embeddings = []
+                total_cost = 0
+                batch_size = 2000
+                for i in range(0, len(texts), batch_size):
+                    batch = texts[i : i + batch_size]
+                    self.console.log(
+                        f"On iteration {i} for creating embeddings for {name} data"
+                    )
+                    response = self.runner.api.gen_embedding(
+                        model=embedding_model,
+                        input=batch,
+                    )
+                    embeddings.extend([data["embedding"] for data in response["data"]])
+                    total_cost += completion_cost(response)
+                return embeddings, total_cost
+
+            left_embeddings, left_cost = get_embeddings(left_data, left_keys, "left")
+            right_embeddings, right_cost = get_embeddings(
+                right_data, right_keys, "right"
+            )
+            total_cost += left_cost + right_cost
+            self.console.log(
+                f"Created embeddings for datasets. Total embedding creation cost: {total_cost}"
+            )

             # Compute all cosine similarities in one call
             from sklearn.metrics.pairwise import cosine_similarity
```
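Both versions of this hunk reduce to the same blocking idea: embed each record once, then keep only the pairs whose cosine similarity clears the threshold. Here is a minimal standalone sketch of that step, with toy vectors and a made-up 0.8 threshold standing in for real `gen_embedding` output and config values:

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Toy stand-ins for the left/right embedding lists (real ones come from the
# embedding model; the dimensions here are arbitrary).
left_embeddings = np.array([[0.9, 0.1], [0.2, 0.8]])
right_embeddings = np.array([[0.88, 0.12], [0.1, 0.9], [0.5, 0.5]])

blocking_threshold = 0.8  # hypothetical; the operation reads this from config

# One call computes the full left-by-right similarity matrix.
sims = cosine_similarity(left_embeddings, right_embeddings)

# Keep only pairs above the threshold; these are the pairs the LLM compares.
blocked_pairs = [
    (i, j)
    for i in range(sims.shape[0])
    for j in range(sims.shape[1])
    if sims[i, j] >= blocking_threshold
]
print(blocked_pairs)
```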
```diff
@@ -10,10 +10,10 @@ import jinja2
 from jinja2 import Template
 from litellm import model_cost
 from pydantic import Field, ValidationInfo, field_validator, model_validator
+from rich.prompt import Confirm

 from docetl.operations.base import BaseOperation
 from docetl.operations.utils import RichLoopBar, rich_as_completed, strict_render
-from docetl.operations.utils.blocking import RuntimeBlockingOptimizer
 from docetl.utils import (
     completion_cost,
     extract_jinja_variables,
```
```diff
@@ -40,7 +40,6 @@ class ResolveOperation(BaseOperation):
         comparison_model: str | None = None
         blocking_keys: list[str] | None = None
         blocking_threshold: float | None = Field(None, ge=0, le=1)
-        blocking_target_recall: float | None = Field(None, ge=0, le=1)
         blocking_conditions: list[str] | None = None
         input: dict[str, Any] | None = None
         embedding_batch_size: int | None = Field(None, gt=0)
```
```diff
@@ -267,75 +266,26 @@ class ResolveOperation(BaseOperation):
         blocking_keys = self.config.get("blocking_keys", [])
         blocking_threshold = self.config.get("blocking_threshold")
         blocking_conditions = self.config.get("blocking_conditions", [])
-        limit_comparisons = self.config.get("limit_comparisons")
-        total_cost = 0
         if self.status:
             self.status.stop()

-        # Track pre-computed embeddings from auto-optimization
-        precomputed_embeddings = None
-
-        # Auto-compute blocking threshold if no blocking configuration is provided
-        if not blocking_threshold and not blocking_conditions and not limit_comparisons:
-            # Get target recall from operation config (default 0.95)
-            target_recall = self.config.get("blocking_target_recall", 0.95)
-            self.console.log(
-                f"[yellow]No blocking configuration. Auto-computing threshold (target recall: {target_recall:.0%})...[/yellow]"
-            )
-            # Determine blocking keys if not set
-            auto_blocking_keys = blocking_keys if blocking_keys else None
-            if not auto_blocking_keys:
-                prompt_template = self.config.get("comparison_prompt", "")
-                prompt_vars = extract_jinja_variables(prompt_template)
-                prompt_vars = [
-                    var
-                    for var in prompt_vars
-                    if var not in ["input", "input1", "input2"]
-                ]
-                auto_blocking_keys = list(
-                    set([var.split(".")[-1] for var in prompt_vars])
-                )
-                if not auto_blocking_keys:
-                    auto_blocking_keys = list(input_data[0].keys())
-            blocking_keys = auto_blocking_keys
-
-            # Create comparison function for threshold optimization
-            def compare_fn_for_optimization(item1, item2):
-                return self.compare_pair(
-                    self.config["comparison_prompt"],
-                    self.config.get("comparison_model", self.default_model),
-                    item1,
-                    item2,
-                    blocking_keys=[],  # Don't use key-based shortcut during optimization
-                    timeout_seconds=self.config.get("timeout", 120),
-                    max_retries_per_timeout=self.config.get(
-                        "max_retries_per_timeout", 2
-                    ),
-                )
-
-            # Run threshold optimization
-            optimizer = RuntimeBlockingOptimizer(
-                runner=self.runner,
-                config=self.config,
-                default_model=self.default_model,
-                max_threads=self.max_threads,
-                console=self.console,
-                target_recall=target_recall,
-                sample_size=min(100, len(input_data) * (len(input_data) - 1) // 4),
-            )
-            blocking_threshold, precomputed_embeddings, optimization_cost = (
-                optimizer.optimize_resolve(
-                    input_data,
-                    compare_fn_for_optimization,
-                    blocking_keys=blocking_keys,
-                )
-            )
-            total_cost += optimization_cost
+        if not blocking_threshold and not blocking_conditions:
+            # Prompt the user for confirmation
+            if not Confirm.ask(
+                "[yellow]Warning: No blocking keys or conditions specified. "
+                "This may result in a large number of comparisons. "
+                "We recommend specifying at least one blocking key or condition, or using the optimizer to automatically come up with these. "
+                "Do you want to continue without blocking?[/yellow]",
+                console=self.runner.console,
+            ):
+                raise ValueError("Operation cancelled by user.")
+
+        input_schema = self.config.get("input", {}).get("schema", {})
+        if not blocking_keys:
+            # Set them to all keys in the input data
+            blocking_keys = list(input_data[0].keys())
+        limit_comparisons = self.config.get("limit_comparisons")
+        total_cost = 0

         def is_match(item1: dict[str, Any], item2: dict[str, Any]) -> bool:
             return any(
```
```diff
@@ -346,101 +296,120 @@ class ResolveOperation(BaseOperation):
         # Calculate embeddings if blocking_threshold is set
         embeddings = None
         if blocking_threshold is not None:
-            # Use precomputed embeddings if available from auto-optimization
-            if precomputed_embeddings is not None:
-                embeddings = precomputed_embeddings
-            else:
-                self.console.log(
-                    f"[cyan]Creating embeddings for {len(input_data)} items...[/cyan]"
-                )
-                embedding_model = self.config.get(
-                    "embedding_model", "text-embedding-3-small"
-                )
-                model_input_context_length = model_cost.get(embedding_model, {}).get(
-                    "max_input_tokens", 8192
-                )
-                batch_size = self.config.get("embedding_batch_size", 1000)
-                embeddings = []
-                embedding_cost = 0.0
-                num_batches = (len(input_data) + batch_size - 1) // batch_size
-
-                for batch_idx in range(num_batches):
-                    start_idx = batch_idx * batch_size
-                    end_idx = min(start_idx + batch_size, len(input_data))
-                    batch = input_data[start_idx:end_idx]
-
-                    if num_batches > 1:
-                        self.console.log(
-                            f"[dim]Creating embeddings: batch {batch_idx + 1}/{num_batches} "
-                            f"({end_idx}/{len(input_data)} items)[/dim]"
-                        )
-
-                    texts = [
-                        " ".join(
-                            str(item[key]) for key in blocking_keys if key in item
-                        )[: model_input_context_length * 3]
-                        for item in batch
-                    ]
-                    response = self.runner.api.gen_embedding(
-                        model=embedding_model, input=texts
-                    )
-                    embeddings.extend([data["embedding"] for data in response["data"]])
-                    embedding_cost += completion_cost(response)
-
-                total_cost += embedding_cost
-
-        # Build a mapping of blocking key values to indices
-        # This is used later for cluster merging (when two items match, merge all items sharing their key values)
-        value_to_indices: dict[tuple[str, ...], list[int]] = {}
-        for i, item in enumerate(input_data):
-            key = tuple(str(item.get(k, "")) for k in blocking_keys)
-            if key not in value_to_indices:
-                value_to_indices[key] = []
-            value_to_indices[key].append(i)
-
-        # Total number of pairs to potentially compare
-        n = len(input_data)
-        total_pairs = n * (n - 1) // 2
-
-        # Apply code-based blocking conditions (check all pairs)
-        code_blocked_pairs = []
-        if blocking_conditions:
-            for i in range(n):
-                for j in range(i + 1, n):
-                    if is_match(input_data[i], input_data[j]):
-                        code_blocked_pairs.append((i, j))
+
+            def get_embeddings_batch(
+                items: list[dict[str, Any]]
+            ) -> list[tuple[list[float], float]]:
+                embedding_model = self.config.get(
+                    "embedding_model", "text-embedding-3-small"
+                )
+                model_input_context_length = model_cost.get(embedding_model, {}).get(
+                    "max_input_tokens", 8192
+                )
+                texts = [
+                    " ".join(str(item[key]) for key in blocking_keys if key in item)[
+                        : model_input_context_length * 3
+                    ]
+                    for item in items
+                ]
+
+                response = self.runner.api.gen_embedding(
+                    model=embedding_model, input=texts
+                )
+                return [
+                    (data["embedding"], completion_cost(response))
+                    for data in response["data"]
+                ]
+
+            embeddings = []
+            costs = []
+            with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
+                for i in range(
+                    0, len(input_data), self.config.get("embedding_batch_size", 1000)
+                ):
+                    batch = input_data[
+                        i : i + self.config.get("embedding_batch_size", 1000)
+                    ]
+                    batch_results = list(executor.map(get_embeddings_batch, [batch]))
+
+                    for result in batch_results:
+                        embeddings.extend([r[0] for r in result])
+                        costs.extend([r[1] for r in result])
+
+            total_cost += sum(costs)
+
+        # Generate all pairs to compare, ensuring no duplicate comparisons
+        def get_unique_comparison_pairs() -> (
+            tuple[list[tuple[int, int]], dict[tuple[str, ...], list[int]]]
+        ):
+            # Create a mapping of values to their indices
+            value_to_indices: dict[tuple[str, ...], list[int]] = {}
+            for i, item in enumerate(input_data):
+                # Create a hashable key from the blocking keys
+                key = tuple(str(item.get(k, "")) for k in blocking_keys)
+                if key not in value_to_indices:
+                    value_to_indices[key] = []
+                value_to_indices[key].append(i)
+
+            # Generate pairs for comparison, comparing each unique value combination only once
+            comparison_pairs = []
+            keys = list(value_to_indices.keys())
+
+            # First, handle comparisons between different values
+            for i in range(len(keys)):
+                for j in range(i + 1, len(keys)):
+                    # Only need one comparison between different values
+                    idx1 = value_to_indices[keys[i]][0]
+                    idx2 = value_to_indices[keys[j]][0]
+                    if idx1 < idx2:  # Maintain ordering to avoid duplicates
+                        comparison_pairs.append((idx1, idx2))
+
+            return comparison_pairs, value_to_indices
+
+        comparison_pairs, value_to_indices = get_unique_comparison_pairs()
+
+        # Filter pairs based on blocking conditions
+        def meets_blocking_conditions(pair: tuple[int, int]) -> bool:
+            i, j = pair
+            return (
+                is_match(input_data[i], input_data[j]) if blocking_conditions else False
+            )
+
+        # Start with pairs that meet blocking conditions, or empty list if no conditions
+        code_blocked_pairs = (
+            list(filter(meets_blocking_conditions, comparison_pairs))
+            if blocking_conditions
+            else []
+        )

         # Apply cosine similarity blocking if threshold is specified
         embedding_blocked_pairs = []
         if blocking_threshold is not None and embeddings is not None:
             import numpy as np
             from sklearn.metrics.pairwise import cosine_similarity

             similarity_matrix = cosine_similarity(embeddings)

             code_blocked_set = set(code_blocked_pairs)

-            # Use numpy to efficiently find all pairs above threshold
-            i_indices, j_indices = np.triu_indices(n, k=1)
-            similarities = similarity_matrix[i_indices, j_indices]
-            above_threshold_mask = similarities >= blocking_threshold
-
-            # Get pairs above threshold
-            above_threshold_i = i_indices[above_threshold_mask]
-            above_threshold_j = j_indices[above_threshold_mask]
-
-            # Filter out pairs already in code_blocked_set
-            embedding_blocked_pairs = [
-                (int(i), int(j))
-                for i, j in zip(above_threshold_i, above_threshold_j)
-                if (i, j) not in code_blocked_set
-            ]
+            # Add pairs that meet the cosine similarity threshold and aren't already blocked
+            for i, j in comparison_pairs:
+                if (i, j) not in code_blocked_set:
+                    similarity = similarity_matrix[i, j]
+                    if similarity >= blocking_threshold:
+                        embedding_blocked_pairs.append((i, j))
+
+            self.console.log(
+                f"Cosine similarity blocking: added {len(embedding_blocked_pairs)} pairs "
+                f"(threshold: {blocking_threshold})"
+            )

-        # Combine pairs from both blocking methods
+        # Combine pairs with prioritization for sampling
         all_blocked_pairs = code_blocked_pairs + embedding_blocked_pairs

-        # If no blocking was applied, compare all pairs
-        if not blocking_conditions and blocking_threshold is None:
-            all_blocked_pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]
+        # If no pairs are blocked at all, fall back to all comparison pairs
+        if not all_blocked_pairs:
+            all_blocked_pairs = comparison_pairs

         # Apply limit_comparisons with prioritization
         if limit_comparisons is not None and len(all_blocked_pairs) > limit_comparisons:
             # Prioritize code-based pairs, then sample from embedding pairs if needed
```
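The restored `get_unique_comparison_pairs` helper avoids redundant LLM calls by grouping records that share identical blocking-key values and comparing each distinct value combination only once. A condensed standalone sketch of that grouping on invented records (the `blocking_keys` shown are assumptions for illustration):

```python
from itertools import combinations

# Hypothetical input: records 0 and 2 share identical key values, so they
# collapse into one group and need no comparison against each other.
input_data = [
    {"name": "Jane Doe", "city": "NYC"},
    {"name": "jane doe", "city": "NYC"},
    {"name": "Jane Doe", "city": "NYC"},
]
blocking_keys = ["name", "city"]  # assumed for illustration

# Group record indices by their blocking-key value tuple.
value_to_indices: dict[tuple[str, ...], list[int]] = {}
for i, item in enumerate(input_data):
    key = tuple(str(item.get(k, "")) for k in blocking_keys)
    value_to_indices.setdefault(key, []).append(i)

# Compare one representative per distinct value combination.
comparison_pairs = [
    (min(a[0], b[0]), max(a[0], b[0]))
    for a, b in combinations(value_to_indices.values(), 2)
]
print(comparison_pairs)  # [(0, 1)]: one comparison instead of three
```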
```diff
@@ -507,6 +476,18 @@ class ResolveOperation(BaseOperation):
                 cluster_map[root_idx] = root1
                 clusters[root_idx] = set()

+        # Calculate and print statistics
+        total_possible_comparisons = len(input_data) * (len(input_data) - 1) // 2
+        comparisons_made = len(blocked_pairs)
+        comparisons_saved = total_possible_comparisons - comparisons_made
+        self.console.log(
+            f"[green]Comparisons saved by deduping and blocking: {comparisons_saved} "
+            f"({(comparisons_saved / total_possible_comparisons) * 100:.2f}%)[/green]"
+        )
+        self.console.log(
+            f"[blue]Number of pairs to compare: {len(blocked_pairs)}[/blue]"
+        )
+
         # Compute an auto-batch size based on the number of comparisons
         def auto_batch() -> int:
             # Maximum batch size limit for 4o-mini model
```
```diff
@@ -532,14 +513,7 @@ class ResolveOperation(BaseOperation):

         # Compare pairs and update clusters in real-time
         batch_size = self.config.get("compare_batch_size", auto_batch())
-
-        # Log blocking summary
-        total_possible_comparisons = len(input_data) * (len(input_data) - 1) // 2
-        self.console.log(
-            f"Comparing {len(blocked_pairs):,} pairs "
-            f"({len(blocked_pairs)/total_possible_comparisons*100:.1f}% of {total_possible_comparisons:,} total, "
-            f"batch size: {batch_size})"
-        )
+        self.console.log(f"Using compare batch size: {batch_size}")
         pair_costs = 0

         pbar = RichLoopBar(
```
In `docetl/operations/utils/__init__.py`:

```diff
@@ -1,5 +1,4 @@
 from .api import APIWrapper
-from .blocking import RuntimeBlockingOptimizer
 from .cache import (
     cache,
     cache_key,
```
```diff
@@ -16,7 +15,6 @@ from .validation import safe_eval, convert_val, convert_dict_schema_to_list_schema

 __all__ = [
     'APIWrapper',
-    'RuntimeBlockingOptimizer',
     'cache',
     'cache_key',
     'clear_cache',
```
@@ -1,567 +0,0 @@

This hunk removes `docetl/operations/utils/blocking.py` (the `RuntimeBlockingOptimizer` module imported above) in its entirety:

```python
"""
Runtime blocking threshold optimization utilities.

This module provides functionality for automatically computing embedding-based
blocking thresholds at runtime when no blocking configuration is provided.
"""

from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Callable

import numpy as np
from litellm import model_cost
from rich.console import Console

from docetl.utils import completion_cost, extract_jinja_variables


class RuntimeBlockingOptimizer:
    """
    Computes optimal embedding-based blocking thresholds at runtime.

    This class samples pairs from the dataset, performs LLM comparisons,
    and finds the optimal cosine similarity threshold that achieves a
    target recall rate.
    """

    def __init__(
        self,
        runner,
        config: dict[str, Any],
        default_model: str,
        max_threads: int,
        console: Console,
        target_recall: float = 0.95,
        sample_size: int = 100,
        sampling_weight: float = 20.0,
    ):
        """
        Initialize the RuntimeBlockingOptimizer.

        Args:
            runner: The pipeline runner instance.
            config: Operation configuration.
            default_model: Default LLM model for comparisons.
            max_threads: Maximum threads for parallel processing.
            console: Rich console for logging.
            target_recall: Target recall rate (default 0.95).
            sample_size: Number of pairs to sample for threshold estimation.
            sampling_weight: Weight for exponential sampling towards higher similarities.
        """
        self.runner = runner
        self.config = config
        self.default_model = default_model
        self.max_threads = max_threads
        self.console = console
        self.target_recall = target_recall
        self.sample_size = sample_size
        self.sampling_weight = sampling_weight

    def compute_embeddings(
        self,
        input_data: list[dict[str, Any]],
        keys: list[str],
        embedding_model: str | None = None,
        batch_size: int = 1000,
    ) -> tuple[list[list[float]], float]:
        """
        Compute embeddings for the input data.

        Args:
            input_data: List of input documents.
            keys: Keys to use for embedding text.
            embedding_model: Model to use for embeddings.
            batch_size: Batch size for embedding computation.

        Returns:
            Tuple of (embeddings list, total cost).
        """
        embedding_model = embedding_model or self.config.get(
            "embedding_model", "text-embedding-3-small"
        )
        model_input_context_length = model_cost.get(embedding_model, {}).get(
            "max_input_tokens", 8192
        )
        texts = [
            " ".join(str(item[key]) for key in keys if key in item)[
                : model_input_context_length * 3
            ]
            for item in input_data
        ]

        self.console.log(f"[cyan]Creating embeddings for {len(texts)} items...[/cyan]")

        embeddings = []
        total_cost = 0.0
        num_batches = (len(texts) + batch_size - 1) // batch_size
        for batch_idx, i in enumerate(range(0, len(texts), batch_size)):
            batch = texts[i : i + batch_size]
            if num_batches > 1:
                self.console.log(
                    f"[dim]  Batch {batch_idx + 1}/{num_batches} "
                    f"({len(embeddings) + len(batch)}/{len(texts)} items)[/dim]"
                )
            response = self.runner.api.gen_embedding(
                model=embedding_model,
                input=batch,
            )
            embeddings.extend([data["embedding"] for data in response["data"]])
            total_cost += completion_cost(response)
        return embeddings, total_cost

    def calculate_cosine_similarities_self(
        self, embeddings: list[list[float]]
    ) -> list[tuple[int, int, float]]:
        """
        Calculate pairwise cosine similarities for self-join.

        Args:
            embeddings: List of embedding vectors.

        Returns:
            List of (i, j, similarity) tuples for all pairs where i < j.
        """
        embeddings_array = np.array(embeddings)
        norms = np.linalg.norm(embeddings_array, axis=1)
        # Avoid division by zero
        norms = np.where(norms == 0, 1e-10, norms)
        dot_products = np.dot(embeddings_array, embeddings_array.T)
        similarities_matrix = dot_products / np.outer(norms, norms)
        i, j = np.triu_indices(len(embeddings), k=1)
        similarities = list(
            zip(i.tolist(), j.tolist(), similarities_matrix[i, j].tolist())
        )
        return similarities

    def calculate_cosine_similarities_cross(
        self,
        left_embeddings: list[list[float]],
        right_embeddings: list[list[float]],
    ) -> list[tuple[int, int, float]]:
        """
        Calculate cosine similarities between two sets of embeddings.

        Args:
            left_embeddings: Embeddings for left dataset.
            right_embeddings: Embeddings for right dataset.

        Returns:
            List of (left_idx, right_idx, similarity) tuples.
        """
        left_array = np.array(left_embeddings)
        right_array = np.array(right_embeddings)
        dot_product = np.dot(left_array, right_array.T)
        norm_left = np.linalg.norm(left_array, axis=1)
        norm_right = np.linalg.norm(right_array, axis=1)
        # Avoid division by zero
        norm_left = np.where(norm_left == 0, 1e-10, norm_left)
        norm_right = np.where(norm_right == 0, 1e-10, norm_right)
        similarities = dot_product / np.outer(norm_left, norm_right)
        return [
            (i, j, float(sim))
            for i, row in enumerate(similarities)
            for j, sim in enumerate(row)
        ]

    def sample_pairs(
        self,
        similarities: list[tuple[int, int, float]],
        num_bins: int = 10,
        stratified_fraction: float = 0.5,
    ) -> list[tuple[int, int]]:
        """
        Sample pairs using a hybrid of stratified and exponential-weighted sampling.

        This ensures coverage across the similarity distribution while still
        focusing on high-similarity pairs where matches are more likely.

        Args:
            similarities: List of (i, j, similarity) tuples.
            num_bins: Number of bins for stratified sampling.
            stratified_fraction: Fraction of samples to allocate to stratified sampling.

        Returns:
            List of sampled (i, j) pairs.
        """
        if len(similarities) == 0:
            return []

        sample_count = min(self.sample_size, len(similarities))
        stratified_count = int(sample_count * stratified_fraction)
        exponential_count = sample_count - stratified_count

        sampled_indices = set()
        sim_values = np.array([sim[2] for sim in similarities])

        # Part 1: Stratified sampling across bins
        if stratified_count > 0:
            bin_edges = np.linspace(
                sim_values.min(), sim_values.max() + 1e-9, num_bins + 1
            )
            samples_per_bin = max(1, stratified_count // num_bins)

            for bin_idx in range(num_bins):
                bin_mask = (sim_values >= bin_edges[bin_idx]) & (
                    sim_values < bin_edges[bin_idx + 1]
                )
                bin_indices = np.where(bin_mask)[0]

                if len(bin_indices) > 0:
                    # Within each bin, use exponential weighting
                    bin_sims = sim_values[bin_indices]
                    bin_weights = np.exp(self.sampling_weight * bin_sims)
                    bin_weights /= bin_weights.sum()

                    n_to_sample = min(samples_per_bin, len(bin_indices))
                    chosen = np.random.choice(
                        bin_indices,
                        size=n_to_sample,
                        replace=False,
                        p=bin_weights,
                    )
                    sampled_indices.update(chosen.tolist())

        # Part 2: Exponential-weighted sampling for remaining slots
        if exponential_count > 0:
            remaining_indices = [
                i for i in range(len(similarities)) if i not in sampled_indices
            ]
            if remaining_indices:
                remaining_sims = sim_values[remaining_indices]
                weights = np.exp(self.sampling_weight * remaining_sims)
                weights /= weights.sum()

                n_to_sample = min(exponential_count, len(remaining_indices))
                chosen = np.random.choice(
                    remaining_indices,
                    size=n_to_sample,
                    replace=False,
                    p=weights,
                )
                sampled_indices.update(chosen.tolist())

        sampled_pairs = [
            (similarities[i][0], similarities[i][1]) for i in sampled_indices
        ]
        return sampled_pairs

    def _print_similarity_histogram(
        self,
        similarities: list[tuple[int, int, float]],
        comparison_results: list[tuple[int, int, bool]],
        threshold: float | None = None,
    ):
        """
        Print a histogram of embedding cosine similarity distribution.

        Args:
            similarities: List of (i, j, similarity) tuples.
            comparison_results: List of (i, j, is_match) from LLM comparisons.
            threshold: Optional threshold to highlight in the histogram.
        """
        # Filter out self-similarities (similarity == 1)
        flat_similarities = [sim[2] for sim in similarities if sim[2] != 1]
        if not flat_similarities:
            return

        hist, bin_edges = np.histogram(flat_similarities, bins=20)
        max_bar_width, max_count = 40, max(hist) if max(hist) > 0 else 1
        normalized_hist = [int(count / max_count * max_bar_width) for count in hist]

        # Create a dictionary to store true labels
        true_labels = {(i, j): is_match for i, j, is_match in comparison_results}

        # Count pairs above threshold
        pairs_above_threshold = (
            sum(1 for sim in flat_similarities if sim >= threshold) if threshold else 0
        )
        total_pairs = len(flat_similarities)

        lines = []
        for i, count in enumerate(normalized_hist):
            bar = "█" * count
            bin_start, bin_end = bin_edges[i], bin_edges[i + 1]
            label = f"{bin_start:.2f}-{bin_end:.2f}"

            # Count true matches and not matches in this bin
            true_matches = 0
            not_matches = 0
            labeled_count = 0
            for sim in similarities:
                if bin_start <= sim[2] < bin_end:
                    if (sim[0], sim[1]) in true_labels:
                        labeled_count += 1
                        if true_labels[(sim[0], sim[1])]:
                            true_matches += 1
                        else:
                            not_matches += 1

            # Calculate percentages of labeled pairs
            if labeled_count > 0:
                true_match_percent = (true_matches / labeled_count) * 100
                label_info = f"[green]{true_match_percent:5.1f}%[/green] match"
            else:
                label_info = "[dim]--[/dim]"

            # Highlight the bin containing the threshold
            if threshold is not None and bin_start <= threshold < bin_end:
                lines.append(
                    f"[bold yellow]{label}[/bold yellow] {bar:<{max_bar_width}} "
                    f"[dim]n={hist[i]:>5}[/dim] {label_info} [bold yellow]◀ threshold[/bold yellow]"
                )
            else:
                lines.append(
                    f"{label} {bar:<{max_bar_width}} "
                    f"[dim]n={hist[i]:>5}[/dim] {label_info}"
                )

        from rich.panel import Panel

        histogram_content = "\n".join(lines)
        title = f"Similarity Distribution ({pairs_above_threshold:,} of {total_pairs:,} pairs ≥ {threshold:.4f})"
        self.console.log(Panel(histogram_content, title=title, border_style="cyan"))

    def find_optimal_threshold(
        self,
        comparisons: list[tuple[int, int, bool]],
        similarities: list[tuple[int, int, float]],
    ) -> tuple[float, float]:
        """
        Find the optimal similarity threshold that achieves target recall.

        Args:
            comparisons: List of (i, j, is_match) from LLM comparisons.
            similarities: List of (i, j, similarity) tuples.

        Returns:
            Tuple of (optimal_threshold, achieved_recall).
        """
        if not comparisons or not any(comp[2] for comp in comparisons):
            # No matches found, use a high threshold to be conservative
            self.console.log(
                "[yellow]No matches found in sample. Using 99th percentile "
                "similarity as threshold.[/yellow]"
            )
            all_sims = [sim[2] for sim in similarities]
            threshold = float(np.percentile(all_sims, 99)) if all_sims else 0.9
            return threshold, 0.0

        true_labels = np.array([comp[2] for comp in comparisons])
        sim_dict = {(i, j): sim for i, j, sim in similarities}
        sim_scores = np.array([sim_dict.get((i, j), 0.0) for i, j, _ in comparisons])
        thresholds = np.linspace(0, 1, 100)
        recalls = []
        for threshold in thresholds:
            predictions = sim_scores >= threshold
            tp = np.sum(predictions & true_labels)
            fn = np.sum(~predictions & true_labels)
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0
            recalls.append(recall)

        # Find highest threshold that achieves target recall
        valid_indices = [i for i, r in enumerate(recalls) if r >= self.target_recall]
        if not valid_indices:
            # If no threshold achieves target recall, use the one with highest recall
            best_idx = int(np.argmax(recalls))
            optimal_threshold = float(thresholds[best_idx])
            achieved_recall = float(recalls[best_idx])
            self.console.log(
                f"[yellow]Warning: Could not achieve target recall {self.target_recall:.0%}. "
                f"Using threshold {optimal_threshold:.4f} with recall {achieved_recall:.2%}.[/yellow]"
            )
        else:
            best_idx = max(valid_indices)
            optimal_threshold = float(thresholds[best_idx])
            achieved_recall = float(recalls[best_idx])

        return round(optimal_threshold, 4), achieved_recall

    def optimize_resolve(
        self,
        input_data: list[dict[str, Any]],
        compare_fn: Callable[[dict, dict], tuple[bool, float, str]],
        blocking_keys: list[str] | None = None,
    ) -> tuple[float, list[list[float]], float]:
        """
        Compute optimal blocking threshold for resolve operation.

        Args:
            input_data: Input dataset.
            compare_fn: Function to compare two items, returns (is_match, cost, prompt).
            blocking_keys: Keys to use for blocking. If None, extracted from prompt.

        Returns:
            Tuple of (threshold, embeddings, total_cost).
        """
        from rich.panel import Panel

        # Determine blocking keys
        if not blocking_keys:
            prompt_template = self.config.get("comparison_prompt", "")
            prompt_vars = extract_jinja_variables(prompt_template)
            prompt_vars = [
                var for var in prompt_vars if var not in ["input", "input1", "input2"]
            ]
            blocking_keys = list(set([var.split(".")[-1] for var in prompt_vars]))
        if not blocking_keys:
            blocking_keys = list(input_data[0].keys())

        # Compute embeddings
        embeddings, embedding_cost = self.compute_embeddings(input_data, blocking_keys)

        # Calculate similarities
        similarities = self.calculate_cosine_similarities_self(embeddings)

        # Sample pairs
        sampled_pairs = self.sample_pairs(similarities)
        if not sampled_pairs:
            self.console.log(
                "[yellow]No pairs to sample. Using default threshold 0.8.[/yellow]"
            )
            return 0.8, embeddings, embedding_cost

        # Perform comparisons
        comparisons = []
        comparison_cost = 0.0
        matches_found = 0
        with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
            futures = {
                executor.submit(compare_fn, input_data[i], input_data[j]): (i, j)
                for i, j in sampled_pairs
            }
            for future in as_completed(futures):
                i, j = futures[future]
                try:
                    is_match, cost, _ = future.result()
                    comparisons.append((i, j, is_match))
                    comparison_cost += cost
                    if is_match:
                        matches_found += 1
                except Exception as e:
                    self.console.log(f"[red]Comparison error: {e}[/red]")
                    comparisons.append((i, j, False))

        # Find optimal threshold
        threshold, achieved_recall = self.find_optimal_threshold(
            comparisons, similarities
        )
        total_cost = embedding_cost + comparison_cost

        # Print histogram visualization
        self._print_similarity_histogram(similarities, comparisons, threshold)

        # Print summary
        n = len(input_data)
        total_pairs = n * (n - 1) // 2
        pairs_above = sum(1 for s in similarities if s[2] >= threshold)

        summary = (
            f"[bold]Blocking keys:[/bold] {blocking_keys}\n"
            f"[bold]Sampled:[/bold] {len(sampled_pairs)} pairs → {matches_found} matches ({matches_found/len(sampled_pairs)*100:.1f}%)\n"
            f"[bold]Threshold:[/bold] {threshold:.4f} → {achieved_recall:.1%} recall (target: {self.target_recall:.0%})\n"
            f"[bold]Pairs to compare:[/bold] {pairs_above:,} of {total_pairs:,} ({pairs_above/total_pairs*100:.1f}%)\n"
            f"[bold]Optimization cost:[/bold] ${total_cost:.4f}"
        )
        self.console.log(
            Panel(
                summary, title="Blocking Threshold Optimization", border_style="green"
            )
        )

        return threshold, embeddings, total_cost

    def optimize_equijoin(
        self,
        left_data: list[dict[str, Any]],
        right_data: list[dict[str, Any]],
        compare_fn: Callable[[dict, dict], tuple[bool, float]],
        left_keys: list[str] | None = None,
        right_keys: list[str] | None = None,
    ) -> tuple[float, list[list[float]], list[list[float]], float]:
        """
        Compute optimal blocking threshold for equijoin operation.

        Args:
            left_data: Left dataset.
            right_data: Right dataset.
            compare_fn: Function to compare two items, returns (is_match, cost).
            left_keys: Keys to use for left dataset embeddings.
            right_keys: Keys to use for right dataset embeddings.

        Returns:
            Tuple of (threshold, left_embeddings, right_embeddings, total_cost).
        """
        from rich.panel import Panel

        # Determine keys
        if not left_keys:
            left_keys = list(left_data[0].keys()) if left_data else []
        if not right_keys:
            right_keys = list(right_data[0].keys()) if right_data else []

        # Compute embeddings
        left_embeddings, left_cost = self.compute_embeddings(left_data, left_keys)
        right_embeddings, right_cost = self.compute_embeddings(right_data, right_keys)
        embedding_cost = left_cost + right_cost

        # Calculate cross similarities
        similarities = self.calculate_cosine_similarities_cross(
            left_embeddings, right_embeddings
        )

        # Sample pairs
        sampled_pairs = self.sample_pairs(similarities)
        if not sampled_pairs:
            self.console.log(
                "[yellow]No pairs to sample. Using default threshold 0.8.[/yellow]"
            )
            return 0.8, left_embeddings, right_embeddings, embedding_cost

        # Perform comparisons
        comparisons = []
        comparison_cost = 0.0
        matches_found = 0
        with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
            futures = {
                executor.submit(compare_fn, left_data[i], right_data[j]): (i, j)
                for i, j in sampled_pairs
            }
            for future in as_completed(futures):
                i, j = futures[future]
                try:
                    is_match, cost = future.result()
                    comparisons.append((i, j, is_match))
                    comparison_cost += cost
                    if is_match:
                        matches_found += 1
                except Exception as e:
                    self.console.log(f"[red]Comparison error: {e}[/red]")
                    comparisons.append((i, j, False))

        # Find optimal threshold
        threshold, achieved_recall = self.find_optimal_threshold(
            comparisons, similarities
        )
        total_cost = embedding_cost + comparison_cost

        # Print histogram visualization
        self._print_similarity_histogram(similarities, comparisons, threshold)

        # Print summary
        total_pairs = len(left_data) * len(right_data)
        pairs_above = sum(1 for s in similarities if s[2] >= threshold)

        summary = (
            f"[bold]Left keys:[/bold] {left_keys}  [bold]Right keys:[/bold] {right_keys}\n"
            f"[bold]Sampled:[/bold] {len(sampled_pairs)} pairs → {matches_found} matches ({matches_found/len(sampled_pairs)*100:.1f}%)\n"
            f"[bold]Threshold:[/bold] {threshold:.4f} → {achieved_recall:.1%} recall (target: {self.target_recall:.0%})\n"
            f"[bold]Pairs to compare:[/bold] {pairs_above:,} of {total_pairs:,} ({pairs_above/total_pairs*100:.1f}%)\n"
            f"[bold]Optimization cost:[/bold] ${total_cost:.4f}"
        )
        self.console.log(
            Panel(
                summary, title="Blocking Threshold Optimization", border_style="green"
            )
        )

        return threshold, left_embeddings, right_embeddings, total_cost
```
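At the heart of the deleted optimizer is `find_optimal_threshold`: sweep candidate thresholds, treat "similarity ≥ threshold" as the match prediction, measure recall against the LLM-labeled sample, and keep the highest threshold that still meets the target, since a higher threshold means fewer (cheaper) comparisons. A standalone sketch of that sweep with fabricated sample labels:

```python
import numpy as np

# Fabricated (similarity, is_match) samples standing in for LLM-labeled pairs.
sim_scores = np.array([0.95, 0.91, 0.88, 0.72, 0.65, 0.40])
true_labels = np.array([True, True, True, True, False, False])

target_recall = 0.95  # same default the optimizer uses

best = 0.0
for threshold in np.linspace(0, 1, 100):
    predictions = sim_scores >= threshold
    tp = np.sum(predictions & true_labels)
    fn = np.sum(~predictions & true_labels)
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    # Keep the highest threshold that still achieves the target recall.
    if recall >= target_recall:
        best = max(best, float(threshold))

print(round(best, 4))  # ~0.7172: just below the lowest true-match similarity
```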
```diff
@@ -34,9 +34,9 @@ This Equijoin operation matches job candidates to job postings:

 The prompt template uses Jinja2 syntax, allowing you to reference input fields directly (e.g., `left.skills`). You can reference the left and right documents using `left` and `right` respectively.

-!!! info "Automatic Blocking"
+!!! warning "Performance Consideration"

-    If you don't specify any blocking configuration (`blocking_threshold`, `blocking_conditions`, or `limit_comparisons`), the Equijoin operation will automatically compute an optimal embedding-based blocking threshold at runtime. It samples pairs from your data, runs LLM comparisons on the sample, and finds a threshold that achieves 95% recall by default. You can adjust this with the `blocking_target_recall` parameter.
+    For large datasets, running comparisons with an LLM can be time-consuming. It's recommended to optimize your pipeline using `docetl build pipeline.yaml` to generate efficient blocking rules for the operation.

 ## Blocking
```
```diff
@@ -95,19 +95,10 @@ A full Equijoin step combining both ideas might look like:

 Equijoin shares many parameters with the Resolve operation. For a detailed list of required and optional parameters, please see the [Parameters section in the Resolve operation documentation](resolve.md#required-parameters).

-### Equijoin-Specific Parameters
-
-| Parameter                | Description                                                          | Default                    |
-| ------------------------ | -------------------------------------------------------------------- | -------------------------- |
-| `limits`                 | Maximum matches for each left/right item: `{"left": n, "right": m}`  | No limit                   |
-| `blocking_keys`          | Keys for embedding blocking: `{"left": [...], "right": [...]}`       | All keys from each dataset |
-| `blocking_threshold`     | Embedding similarity threshold for considering pairs                 | Auto-computed if not set   |
-| `blocking_target_recall` | Target recall when auto-computing blocking threshold (0.0 to 1.0)    | 0.95                       |
-
-Key differences from Resolve:
+Key differences for Equijoin include:

 - `resolution_prompt` is not used in Equijoin.
 - `blocking_keys` uses a dict with `left` and `right` keys instead of a simple list.
 - `limits` parameter is specific to Equijoin, allowing you to set maximum matches for each left and right item.

 ## Incorporating Into a Pipeline
```
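For orientation, this is roughly what an Equijoin configuration with explicit blocking looks like from inside the operation, i.e. as the `self.config` dict the code above reads; every value here is illustrative, not taken from the changeset:

```python
# Illustrative only: an equijoin config as the operation would see it.
equijoin_config = {
    "name": "match_candidates_to_jobs",  # hypothetical name
    "type": "equijoin",
    "comparison_prompt": (
        "Do the candidate's skills {{ left.skills }} fit the posting "
        "{{ right.requirements }}? Answer yes or no."
    ),
    # Dict of per-side keys, unlike Resolve's flat list.
    "blocking_keys": {"left": ["skills"], "right": ["requirements"]},
    "blocking_threshold": 0.82,           # cosine similarity cutoff (invented)
    "limits": {"left": 3, "right": 10},   # max matches per left/right item
}
```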
```diff
@@ -44,9 +44,9 @@ This Resolve operation processes patient names to identify and standardize duplicates:

 Note: The prompt templates use Jinja2 syntax, allowing you to reference input fields directly (e.g., `input1.patient_name`).

-!!! info "Automatic Blocking"
+!!! warning "Performance Consideration"

-    If you don't specify any blocking configuration (`blocking_threshold`, `blocking_conditions`, or `limit_comparisons`), the Resolve operation will automatically compute an optimal embedding-based blocking threshold at runtime. It samples pairs from your data, runs LLM comparisons on the sample, and finds a threshold that achieves 95% recall by default. You can adjust this with the `blocking_target_recall` parameter.
+    You should not run this operation as-is unless your dataset is small! Running O(n^2) comparisons with an LLM can be extremely time-consuming for large datasets. Instead, optimize your pipeline first using `docetl build pipeline.yaml` and run the optimized version, which will generate efficient blocking rules for the operation. Make sure you've set `optimize: true` in your resolve operation config.

 ## Blocking
```
```diff
@@ -132,8 +132,7 @@ After determining eligible pairs for comparison, the Resolve operation uses a Union-Find algorithm
 | `resolution_model`       | The language model to use for reducing matched entries                            | Falls back to `default_model` |
 | `comparison_model`       | The language model to use for comparing potential matches                         | Falls back to `default_model` |
 | `blocking_keys`          | List of keys to use for initial blocking                                          | All keys in the input data    |
-| `blocking_threshold`     | Embedding similarity threshold for considering entries as potential matches       | Auto-computed if not set      |
-| `blocking_target_recall` | Target recall when auto-computing blocking threshold (0.0 to 1.0)                 | 0.95                          |
+| `blocking_threshold`     | Embedding similarity threshold for considering entries as potential matches       | None                          |
 | `blocking_conditions`    | List of conditions for initial blocking                                           | []                            |
 | `input`                  | Specifies the schema or keys to subselect from each item to pass into the prompts | All keys from input items     |
 | `embedding_batch_size`   | The number of entries to send to the embedding model at a time                    | 1000                          |
```
```diff
@@ -141,9 +140,9 @@ After determining eligible pairs for comparison, the Resolve operation uses a Union-Find algorithm
 | `limit_comparisons`          | Maximum number of comparisons to perform                   | None  |
 | `timeout`                    | Timeout for each LLM call in seconds                       | 120   |
 | `max_retries_per_timeout`    | Maximum number of retries per timeout                      | 2     |
 | `sample`                     | Number of samples to use for the operation                 | None  |
 | `litellm_completion_kwargs`  | Additional parameters to pass to LiteLLM completion calls. | {}    |
 | `bypass_cache`               | If true, bypass the cache for this operation.              | False |

 ## Best Practices
```
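A Resolve configuration, sketched the same way (all values invented), highlights the differences the docs call out above: a flat `blocking_keys` list and a `resolution_prompt`, which Equijoin does not use:

```python
# Illustrative only: a resolve config mirroring the parameter table above.
resolve_config = {
    "name": "dedupe_patients",  # hypothetical name
    "type": "resolve",
    "comparison_prompt": (
        "Are {{ input1.patient_name }} and {{ input2.patient_name }} "
        "the same person? Answer yes or no."
    ),
    "resolution_prompt": "Pick the canonical spelling among the matched entries.",
    "blocking_keys": ["patient_name"],  # flat list, unlike Equijoin's dict
    "blocking_threshold": 0.85,         # cosine similarity cutoff (invented)
    "embedding_batch_size": 1000,       # entries per embedding call
    "timeout": 120,
    "max_retries_per_timeout": 2,
}
```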