Compare commits
1 Commits
main...nonmonoton

| Author | SHA1 | Date |
|---|---|---|
|  | d1016d3532 |  |
@ -0,0 +1,93 @@
"""
Debug script to see actual extractions vs ground truth for medical complaints
"""

import json
from docetl.runner import DSLRunner
from docetl.operations.map import MapOperation
from rich.console import Console
from rich.table import Table
from rich import box

# Load first 5 documents
data = json.load(open('docs/assets/medical_transcripts.json'))[:5]

# Create runner
runner = DSLRunner({
    'default_model': 'gpt-4o-mini',
    'operations': [],
    'pipeline': {'steps': [], 'output': {'path': '/tmp/test.json'}}
}, max_threads=64)

# High precision extraction
config = {
    'name': 'test_extract',
    'type': 'map',
    'prompt': '''Extract ONLY the primary chief complaint from this medical transcript.
Be very strict - only extract the main reason the patient is visiting.
Do not include any additional context or secondary complaints.

Transcript: {{ input.src }}

Return only the chief complaint as a brief string (1-2 sentences max).''',
    'output': {'schema': {'extracted_complaints': 'string'}},
    'model': 'gpt-4o-mini'
}

console = Console()
console.print("[bold]Extracting complaints...[/bold]")
op = MapOperation(runner=runner, config=config, default_model='gpt-4o-mini', max_threads=64, console=console)
results, cost = op.execute(data)

# Create comparison table
table = Table(
    title="Extraction vs Ground Truth Comparison",
    box=box.ROUNDED,
    show_header=True,
    header_style="bold cyan"
)

table.add_column("Doc", style="bold", width=3)
table.add_column("Ground Truth", style="green", width=30)
table.add_column("Extracted", style="yellow", width=30)
table.add_column("Match?", justify="center", width=8)

for i, (result, original) in enumerate(zip(results, data)):
    # Extract ground truth chief complaint
    gt_text = original.get('tgt', '')
    if 'CHIEF COMPLAINT' in gt_text:
        gt_complaint = gt_text.split('CHIEF COMPLAINT')[1].split('\n\n')[0:2]
        gt_complaint = '\n'.join(gt_complaint).strip()
    else:
        gt_complaint = "N/A"

    extracted = result.get('extracted_complaints', '')

    # Check if they match (case insensitive, stripped)
    match = "✓" if gt_complaint.lower().strip() == extracted.lower().strip() else "✗"

    table.add_row(
        str(i+1),
        gt_complaint[:50] + "..." if len(gt_complaint) > 50 else gt_complaint,
        extracted[:50] + "..." if len(extracted) > 50 else extracted,
        match
    )

console.print(table)

# Show full details for mismatches
console.print("\n[bold]Detailed Mismatches:[/bold]")
for i, (result, original) in enumerate(zip(results, data)):
    gt_text = original.get('tgt', '')
    if 'CHIEF COMPLAINT' in gt_text:
        gt_complaint = gt_text.split('CHIEF COMPLAINT')[1].split('\n\n')[0:2]
        gt_complaint = '\n'.join(gt_complaint).strip()
    else:
        gt_complaint = "N/A"

    extracted = result.get('extracted_complaints', '')

    if gt_complaint.lower().strip() != extracted.lower().strip():
        console.print(f"\n[bold red]Document {i+1}:[/bold red]")
        console.print(f"[green]Ground Truth:[/green] {gt_complaint}")
        console.print(f"[yellow]Extracted:[/yellow] {extracted}")
@ -0,0 +1,625 @@
"""
Medical Complaints Map-Reduce Precision Experiment

This experiment demonstrates how lower map precision (extracting more context)
can lead to higher reduce quality in medical complaint summarization.

The experiment:
1. Extracts primary complaints from medical transcripts with varying levels of precision
   - High precision: strict extraction of only chief complaints
   - Low precision: extract complaints with surrounding context
2. Varies the fraction of extracted complaints passed to the reduce operation
3. Reduces/summarizes all complaints into a comprehensive summary
4. Tracks:
   - Map precision: How accurately complaints are extracted
   - Reduce quality: Quality of the final summary using ROUGE and BLEU metrics

The hypothesis is that lower map precision (more context) leads to better
reduce quality because the LLM has more information to synthesize.
"""

import os
import json
import time
import random
from typing import Any, List, Dict
from rich.console import Console
from rich.table import Table
from rich import box
from dotenv import load_dotenv

# Import DocETL components
from docetl.runner import DSLRunner
from docetl.operations.map import MapOperation
from docetl.operations.reduce import ReduceOperation

# For metrics
try:
    from rouge_score import rouge_scorer
    from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
    import nltk
    # Download required NLTK data
    try:
        nltk.data.find('tokenizers/punkt')
    except LookupError:
        nltk.download('punkt', quiet=True)
except ImportError:
    print("Warning: Install rouge-score and nltk for metrics: pip install rouge-score nltk")
    rouge_scorer = None

# Load environment variables
load_dotenv()

# Constants
MEDICAL_DATA_PATH = "docs/assets/medical_transcripts.json"
GROUND_TRUTH_MODEL = "gpt-4o"  # High quality model for ground truth
EXPERIMENT_MODEL = "gpt-4o-mini"  # Faster model for experiments
MAX_WORKERS = 64

# Extraction prompts with different precision levels
EXTRACTION_PROMPTS = {
    "high_precision": """
Extract ONLY the primary chief complaint from this medical transcript.
Be very strict - only extract the main reason the patient is visiting.
Do not include any additional context or secondary complaints.

Transcript: {{ input.src }}

Return only the chief complaint as a brief string (1-2 sentences max).
""",
    "medium_precision": """
Extract the primary complaints from this medical transcript.
Include the chief complaint and any closely related symptoms or concerns.

Transcript: {{ input.src }}

Return the complaints with minimal context.
""",
    "low_precision": """
Extract all complaints and concerns mentioned in this medical transcript.
Include the chief complaint, related symptoms, and relevant context from the patient's history.
Include any information that helps understand the patient's condition.

Transcript: {{ input.src }}

Return comprehensive information about the patient's complaints and concerns.
"""
}

# Reduce prompt for summarization
REDUCE_PROMPT = """
You are a medical professional tasked with creating a comprehensive chief complaint summary.

Here are the extracted complaints from multiple medical transcripts:
{% for input in inputs %}
{{ loop.index }}. {{ input.extracted_complaints }}
{% endfor %}

Create a comprehensive and well-structured summary that:
1. Identifies the most common chief complaints
2. Notes any patterns in symptoms or presentations
3. Highlights important medical context

Provide a clear, professional medical summary.
"""
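
# Illustration only (not part of the experiment): a minimal sketch of how the
# Jinja-style placeholders above get filled in. This assumes the prompts are
# rendered with standard Jinja2 semantics, which matches the {{ input.src }} /
# {% for input in inputs %} syntax used here; the sample complaint strings are
# made up for the preview.
#
#     from jinja2 import Template
#     preview = Template(REDUCE_PROMPT).render(inputs=[
#         {"extracted_complaints": "Headache for three days"},
#         {"extracted_complaints": "Persistent dry cough"},
#     ])
#     print(preview)  # numbered complaint list followed by the instructions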


def load_medical_data(filepath: str, limit: int = None) -> List[Dict[str, Any]]:
    """Load medical transcript data"""
    with open(filepath, 'r') as f:
        data = json.load(f)

    if limit:
        data = data[:limit]

    return data


def extract_complaints_with_precision(
    documents: List[Dict[str, Any]],
    precision_level: str,
    runner: DSLRunner,
    max_workers: int = MAX_WORKERS
) -> tuple[List[Dict[str, Any]], float, float]:
    """
    Extract complaints from medical transcripts with specified precision level.

    Args:
        documents: List of medical transcript documents
        precision_level: One of 'high_precision', 'medium_precision', 'low_precision'
        runner: DSLRunner instance
        max_workers: Maximum worker threads

    Returns:
        Tuple of (processed documents, cost, runtime in seconds)
    """
    from docetl.operations.map import MapOperation

    prompt = EXTRACTION_PROMPTS[precision_level]

    config = {
        "name": f"extract_complaints_{precision_level}",
        "type": "map",
        "prompt": prompt,
        "output": {
            "schema": {
                "extracted_complaints": "string"
            }
        },
        "model": EXPERIMENT_MODEL,
    }

    op = MapOperation(
        runner=runner,
        config=config,
        default_model=EXPERIMENT_MODEL,
        max_threads=max_workers,
        console=Console()
    )

    start_time = time.time()
    results, cost = op.execute(documents)
    runtime = time.time() - start_time

    return results, cost, runtime


def sample_complaints(
    documents: List[Dict[str, Any]],
    fraction: float
) -> List[Dict[str, Any]]:
    """
    Sample a fraction of the extracted complaints.

    Args:
        documents: Documents with extracted complaints
        fraction: Fraction of documents to keep (0.0 to 1.0)

    Returns:
        Sampled documents
    """
    if fraction >= 1.0:
        return documents

    sample_size = max(1, int(len(documents) * fraction))
    return random.sample(documents, sample_size)
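
# Example of the sampling arithmetic above: with 15 extracted documents and
# fraction=0.5, sample_size = max(1, int(15 * 0.5)) = 7, so 7 documents are
# drawn at random; fractions small enough to round down to zero still keep
# at least one document.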


def reduce_complaints(
    documents: List[Dict[str, Any]],
    runner: DSLRunner,
    max_workers: int = MAX_WORKERS
) -> tuple[List[Dict[str, Any]], float, float]:
    """
    Reduce/summarize all complaints into a comprehensive summary.

    Args:
        documents: Documents with extracted complaints
        runner: DSLRunner instance
        max_workers: Maximum worker threads

    Returns:
        Tuple of (reduced documents, cost, runtime in seconds)
    """
    from docetl.operations.reduce import ReduceOperation

    config = {
        "name": "summarize_complaints",
        "type": "reduce",
        "reduce_key": "_all",  # Reduce all documents together
        "prompt": REDUCE_PROMPT,
        "output": {
            "schema": {
                "comprehensive_summary": "string"
            }
        },
        "model": EXPERIMENT_MODEL,
    }

    op = ReduceOperation(
        runner=runner,
        config=config,
        default_model=EXPERIMENT_MODEL,
        max_threads=max_workers,
        console=Console()
    )

    start_time = time.time()
    results, cost = op.execute(documents)
    runtime = time.time() - start_time

    return results, cost, runtime
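
# Note on reduce_key="_all": as used throughout this script, it collapses every
# input document into a single reduce call, so op.execute(...) is expected to
# return a one-element list and callers read results[0]. This reading is
# inferred from how the results are consumed below, not from DocETL documentation.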


def calculate_map_precision(
    extracted_documents: List[Dict[str, Any]],
    ground_truth_documents: List[Dict[str, Any]]
) -> Dict[str, float]:
    """
    Calculate precision metrics for the map operation.

    Compares extracted complaints against ground truth chief complaints
    using ROUGE scores.

    Args:
        extracted_documents: Documents with extracted complaints
        ground_truth_documents: Original documents with ground truth

    Returns:
        Dictionary with precision metrics
    """
    # Calculate average length regardless of rouge_scorer availability
    avg_length = sum(len(doc.get('extracted_complaints', '')) for doc in extracted_documents) / len(extracted_documents) if extracted_documents else 0.0

    if rouge_scorer is None:
        return {
            "rouge1": 0.0,
            "rouge2": 0.0,
            "rougeL": 0.0,
            "avg_length": avg_length
        }

    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

    rouge1_scores = []
    rouge2_scores = []
    rougeL_scores = []

    for extracted, ground_truth in zip(extracted_documents, ground_truth_documents):
        # Extract the chief complaint section from ground truth
        gt_text = ground_truth.get('tgt', '')

        # Try to extract just the chief complaint section
        if 'CHIEF COMPLAINT' in gt_text:
            chief_complaint = gt_text.split('CHIEF COMPLAINT')[1].split('\n\n')[0:2]
            chief_complaint = '\n'.join(chief_complaint).strip()
        else:
            # Use first paragraph as fallback
            chief_complaint = gt_text.split('\n\n')[0].strip()

        extracted_text = extracted.get('extracted_complaints', '')

        if extracted_text and chief_complaint:
            scores = scorer.score(chief_complaint, extracted_text)
            rouge1_scores.append(scores['rouge1'].fmeasure)
            rouge2_scores.append(scores['rouge2'].fmeasure)
            rougeL_scores.append(scores['rougeL'].fmeasure)

    return {
        "rouge1": sum(rouge1_scores) / len(rouge1_scores) if rouge1_scores else 0.0,
        "rouge2": sum(rouge2_scores) / len(rouge2_scores) if rouge2_scores else 0.0,
        "rougeL": sum(rougeL_scores) / len(rougeL_scores) if rougeL_scores else 0.0,
        "avg_length": avg_length
    }


def generate_ground_truth_summary(
    documents: List[Dict[str, Any]],
    runner: DSLRunner,
    max_workers: int = MAX_WORKERS
) -> str:
    """
    Generate a high-quality ground truth summary using GPT-4o.

    Args:
        documents: Documents to summarize
        runner: DSLRunner instance
        max_workers: Maximum worker threads

    Returns:
        Ground truth summary string
    """
    from docetl.operations.reduce import ReduceOperation

    console = Console()
    console.print("[bold cyan]Generating ground truth summary with GPT-4o...[/bold cyan]")

    config = {
        "name": "ground_truth_summary",
        "type": "reduce",
        "reduce_key": "_all",
        "prompt": REDUCE_PROMPT,
        "output": {
            "schema": {
                "comprehensive_summary": "string"
            }
        },
        "model": GROUND_TRUTH_MODEL,
    }

    op = ReduceOperation(
        runner=runner,
        config=config,
        default_model=GROUND_TRUTH_MODEL,
        max_threads=max_workers,
        console=console
    )

    # Extract complaints with low precision (maximum context) for ground truth
    extract_config = {
        "name": "ground_truth_extract",
        "type": "map",
        "prompt": EXTRACTION_PROMPTS["low_precision"],
        "output": {
            "schema": {
                "extracted_complaints": "string"
            }
        },
        "model": GROUND_TRUTH_MODEL,
    }

    from docetl.operations.map import MapOperation
    extract_op = MapOperation(
        runner=runner,
        config=extract_config,
        default_model=GROUND_TRUTH_MODEL,
        max_threads=max_workers,
        console=console
    )

    # Extract with high quality model
    extracted_docs, _ = extract_op.execute(documents)

    # Generate summary
    results, _ = op.execute(extracted_docs)

    if results:
        return results[0].get('comprehensive_summary', '')
    return ""


def calculate_reduce_quality(
    reduced_result: Dict[str, Any],
    ground_truth_summary: str
) -> Dict[str, float]:
    """
    Calculate quality metrics for the reduce operation.

    Compares the comprehensive summary against GPT-4o generated ground truth.

    Args:
        reduced_result: Single result from reduce operation with comprehensive summary
        ground_truth_summary: Ground truth summary generated by GPT-4o

    Returns:
        Dictionary with quality metrics
    """
    if rouge_scorer is None:
        return {"rouge1": 0.0, "rouge2": 0.0, "rougeL": 0.0, "bleu": 0.0}

    summary = reduced_result.get('comprehensive_summary', '')

    # Calculate ROUGE scores
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    rouge_scores = scorer.score(ground_truth_summary, summary)

    # Calculate BLEU score
    try:
        reference_tokens = [ground_truth_summary.split()]
        candidate_tokens = summary.split()
        smoothing = SmoothingFunction().method1
        bleu_score = sentence_bleu(reference_tokens, candidate_tokens, smoothing_function=smoothing)
    except Exception:
        bleu_score = 0.0

    return {
        "rouge1": rouge_scores['rouge1'].fmeasure,
        "rouge2": rouge_scores['rouge2'].fmeasure,
        "rougeL": rouge_scores['rougeL'].fmeasure,
        "bleu": bleu_score
    }


def run_experiment(
    data_path: str,
    precision_levels: List[str] = ["high_precision", "medium_precision", "low_precision"],
    sample_fractions: List[float] = [0.25, 0.5, 0.75, 1.0],
    data_limit: int = 20,  # Limit data for faster experimentation
    max_workers: int = MAX_WORKERS
):
    """
    Run the main experiment comparing precision vs quality tradeoffs.

    Args:
        data_path: Path to medical transcripts JSON file
        precision_levels: List of precision levels to test
        sample_fractions: List of fractions to sample for reduce operation
        data_limit: Limit number of documents to process
        max_workers: Maximum worker threads
    """
    console = Console()

    # Load data
    console.print(f"[bold]Loading medical data from {data_path}[/bold]")
    documents = load_medical_data(data_path, limit=data_limit)
    console.print(f"Loaded {len(documents)} medical transcripts")

    # Create runner
    runner_config = {
        "default_model": EXPERIMENT_MODEL,
        "operations": [],
        "pipeline": {"steps": [], "output": {"path": "/tmp/medical_complaints.json"}},
    }
    runner = DSLRunner(runner_config, max_threads=max_workers)

    # Generate ground truth summary with high-quality model
    console.print(f"\n[bold cyan]Step 1: Generating ground truth summary with {GROUND_TRUTH_MODEL}[/bold cyan]")
    ground_truth_summary = generate_ground_truth_summary(documents, runner, max_workers)
    console.print(f"[bold green]✓ Ground truth generated ({len(ground_truth_summary)} chars)[/bold green]")
    console.print(f"\n[italic]Ground truth preview:[/italic]\n{ground_truth_summary[:300]}...\n")

    # Store results
    results = {
        "ground_truth_summary": ground_truth_summary
    }

    for precision_level in precision_levels:
        console.print(f"\n[bold green]Testing precision level: {precision_level}[/bold green]")

        # Extract complaints with this precision level
        console.print(f"  Extracting complaints...")
        extracted_docs, extract_cost, extract_time = extract_complaints_with_precision(
            documents, precision_level, runner, max_workers
        )

        # Calculate map precision
        map_metrics = calculate_map_precision(extracted_docs, documents)
        console.print(f"  Map Precision - ROUGE-L: {map_metrics['rougeL']:.3f}, Avg Length: {map_metrics['avg_length']:.1f}")

        results[precision_level] = {
            "extract_cost": extract_cost,
            "extract_time": extract_time,
            "map_precision": map_metrics,
            "sample_results": {}
        }

        # Test different sampling fractions
        for fraction in sample_fractions:
            console.print(f"\n[bold blue]  Testing sample fraction: {fraction}[/bold blue]")

            # Sample complaints
            sampled_docs = sample_complaints(extracted_docs, fraction)
            console.print(f"  Sampled {len(sampled_docs)} documents")

            # Reduce/summarize
            console.print(f"  Reducing/summarizing complaints...")
            reduced_results, reduce_cost, reduce_time = reduce_complaints(
                sampled_docs, runner, max_workers
            )

            # Calculate reduce quality (compare against ground truth summary)
            if reduced_results:
                reduce_metrics = calculate_reduce_quality(reduced_results[0], ground_truth_summary)
                console.print(f"  Reduce Quality - ROUGE-L: {reduce_metrics['rougeL']:.3f}, BLEU: {reduce_metrics['bleu']:.3f}")

                results[precision_level]["sample_results"][fraction] = {
                    "reduce_cost": reduce_cost,
                    "reduce_time": reduce_time,
                    "reduce_quality": reduce_metrics,
                    "num_samples": len(sampled_docs),
                    "summary": reduced_results[0].get('comprehensive_summary', '')[:200] + '...'
                }

    return results


def format_results_table(results: Dict[str, Any]) -> Table:
    """Format experiment results as a Rich table"""
    table = Table(
        title="Medical Complaints Precision vs Quality Experiment",
        box=box.ROUNDED,
        show_header=True,
        header_style="bold cyan"
    )

    # Add columns
    table.add_column("Precision Level", style="bold")
    table.add_column("Map ROUGE-L", justify="right")
    table.add_column("Avg Extract Length", justify="right")
    table.add_column("Sample Fraction", justify="right")
    table.add_column("Reduce ROUGE-L", justify="right")
    table.add_column("Reduce BLEU", justify="right")
    table.add_column("Total Cost ($)", justify="right")

    # Add rows (skip the stored ground-truth summary, which is not a precision level)
    for precision_level, data in results.items():
        if precision_level == "ground_truth_summary":
            continue

        map_precision = data["map_precision"]

        for i, (fraction, sample_data) in enumerate(data["sample_results"].items()):
            reduce_quality = sample_data["reduce_quality"]
            total_cost = data["extract_cost"] + sample_data["reduce_cost"]

            # Only show precision level on first row for this precision level
            precision_display = precision_level if i == 0 else ""
            map_rouge_display = f"{map_precision['rougeL']:.3f}" if i == 0 else ""
            avg_len_display = f"{map_precision['avg_length']:.0f}" if i == 0 else ""

            table.add_row(
                precision_display,
                map_rouge_display,
                avg_len_display,
                f"{fraction:.2f}",
                f"{reduce_quality['rougeL']:.3f}",
                f"{reduce_quality['bleu']:.3f}",
                f"${total_cost:.4f}"
            )

        # Add section divider
        table.add_section()

    return table


def print_analysis(results: Dict[str, Any]):
    """Print analysis of the precision-quality tradeoff"""
    console = Console()

    console.print("\n[bold]Analysis: Map Precision vs Reduce Quality[/bold]")

    # For each sample fraction, compare across precision levels
    # (skip the stored ground-truth summary, which is not a precision level)
    precision_levels = [k for k in results.keys() if k != "ground_truth_summary"]
    sample_fractions = list(results[precision_levels[0]]["sample_results"].keys())

    for fraction in sample_fractions:
        console.print(f"\n[bold cyan]Sample Fraction: {fraction}[/bold cyan]")

        precision_data = []
        for precision_level in precision_levels:
            sample_data = results[precision_level]["sample_results"][fraction]
            map_rouge = results[precision_level]["map_precision"]["rougeL"]
            reduce_rouge = sample_data["reduce_quality"]["rougeL"]
            reduce_bleu = sample_data["reduce_quality"]["bleu"]

            precision_data.append({
                "level": precision_level,
                "map_rouge": map_rouge,
                "reduce_rouge": reduce_rouge,
                "reduce_bleu": reduce_bleu
            })

        # Sort by map precision (ascending)
        precision_data.sort(key=lambda x: x["map_rouge"])

        console.print("  Precision vs Quality:")
        for data in precision_data:
            console.print(f"    {data['level']:20s} | Map ROUGE-L: {data['map_rouge']:.3f} -> Reduce ROUGE-L: {data['reduce_rouge']:.3f}, BLEU: {data['reduce_bleu']:.3f}")

        # Check if lower map precision leads to higher reduce quality
        if len(precision_data) >= 2:
            lowest_precision = precision_data[0]
            highest_precision = precision_data[-1]

            if lowest_precision["reduce_rouge"] > highest_precision["reduce_rouge"]:
                console.print(f"  [bold green]✓ Lower map precision leads to higher reduce quality![/bold green]")
                console.print(f"    Improvement: {(lowest_precision['reduce_rouge'] - highest_precision['reduce_rouge']) / highest_precision['reduce_rouge'] * 100:.1f}%")
            else:
                console.print(f"  [yellow]✗ Higher map precision leads to higher reduce quality[/yellow]")


if __name__ == "__main__":
    console = Console()

    console.print("[bold green]Starting Medical Complaints Precision Experiment[/bold green]\n")

    # Run experiment
    results = run_experiment(
        data_path=MEDICAL_DATA_PATH,
        precision_levels=["high_precision", "medium_precision", "low_precision"],
        sample_fractions=[0.5, 1.0],  # Test with half and all samples
        data_limit=15,  # Limit for faster experimentation
        max_workers=64
    )

    # Display results
    console.print("\n[bold]Results Table:[/bold]")
    console.print(format_results_table(results))

    # Print analysis
    print_analysis(results)

    # Print example summaries (skip the stored ground-truth summary key)
    console.print("\n[bold]Example Summaries (Full Sample):[/bold]")
    for precision_level in results:
        if precision_level == "ground_truth_summary":
            continue
        if 1.0 in results[precision_level]["sample_results"]:
            summary = results[precision_level]["sample_results"][1.0]["summary"]
            console.print(f"\n[bold green]{precision_level}:[/bold green]")
            console.print(f"  {summary}")
@ -0,0 +1,632 @@
"""
Medical Complaints Map-Reduce Precision Experiment (Using Pandas API)

This experiment tests how extraction precision affects reduce quality in map-reduce pipelines.

Ground Truth:
- Map operation: Extract chief complaints using expensive model (gpt-4o)
- Reduce operation: Summarize using expensive model (gpt-4o)

Experiment Conditions:
- Map operation: Extract with varying precision levels using cheap model (gpt-4o-mini)
- Reduce operation: Summarize using expensive model (gpt-4o) - ALWAYS expensive

Metrics:
- Map precision: Compare cheap model extractions to expensive model ground truth
- Reduce quality: Compare experiment summaries to ground truth summary using ROUGE

Uses the DocETL pandas accessor API for cleaner code.
"""

import json
from typing import Any, List, Dict
from rich.console import Console
from rich.table import Table
from rich import box
from dotenv import load_dotenv
import pandas as pd

# Import to register the semantic accessor
import docetl.apis.pd_accessors

# Load environment variables early so the OpenAI client created below can
# pick up the API key from a .env file
load_dotenv()

# For metrics
try:
    from rouge_score import rouge_scorer
    from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
    import nltk
    # Download required NLTK data
    try:
        nltk.data.find('tokenizers/punkt')
    except LookupError:
        nltk.download('punkt', quiet=True)
except ImportError:
    rouge_scorer = None

# For semantic similarity
try:
    from bert_score import score as bert_score
    BERTSCORE_AVAILABLE = True
except ImportError:
    BERTSCORE_AVAILABLE = False
    print("Warning: bert-score not available. Install with: pip install bert-score")

try:
    from sklearn.metrics.pairwise import cosine_similarity
    import numpy as np
    # We'll use OpenAI embeddings since we're already using their models
    from openai import OpenAI
    openai_client = OpenAI()
    EMBEDDINGS_AVAILABLE = True
except ImportError:
    EMBEDDINGS_AVAILABLE = False

# Constants
MEDICAL_DATA_PATH = "docs/assets/medical_transcripts.json"
EXPENSIVE_MODEL = "gpt-4o"  # Expensive model for ground truth
CHEAP_MODEL = "gpt-4o-mini"  # Cheap model for experiments

# Ground truth extraction prompt (simple, no explicit precision mention)
GROUND_TRUTH_MAP_PROMPT = """
Extract the chief complaints from this medical transcript.

Transcript: {{ input.src }}

Return the chief complaints.
"""

# Extraction prompts with different precision levels for experiments
EXTRACTION_PROMPTS = {
    "high_precision": """
Extract ONLY the primary chief complaint from this medical transcript.
Be very strict - only extract the main reason the patient is visiting.
Do not include any additional context or secondary complaints.

Transcript: {{ input.src }}

Return only the chief complaint as a brief string (1-2 sentences max).
""",
    "medium_precision": """
Extract the primary complaints from this medical transcript.
Include the chief complaint and any closely related symptoms or concerns.

Transcript: {{ input.src }}

Return the complaints with minimal context.
""",
    "low_precision": """
Extract all complaints and concerns mentioned in this medical transcript.
Include the chief complaint, related symptoms, and relevant context from the patient's history.
Include any information that helps understand the patient's condition.

Transcript: {{ input.src }}

Return comprehensive information about the patient's complaints and concerns.
"""
}

# Reduce prompt for summarization
REDUCE_PROMPT = """
You are a medical professional tasked with synthesizing a highly detailed and specific chief complaint summary report.

Below are the extracted complaints from multiple medical transcripts:
{% for input in inputs %}
{{ loop.index }}. {{ input.extracted_complaints }}
{% endfor %}

Please perform the following tasks in your summary:
1. Clearly list at least 10 of the most frequently occurring chief complaints across these cases. For each, provide a brief explanation if possible.
2. If fewer than 10 distinct complaints are present, list as many as are found, but explicitly state this.
3. Quantify the frequency of each chief complaint (e.g., "Headache – 4 cases, Cough – 3 cases").
4. Identify and briefly describe at least 3 significant patterns in the symptoms, combinations of complaints, or clinical presentations across the cases.
5. Note and explain any recurring secondary symptoms or contextual factors (such as duration, severity, or relevant patient history) that could impact diagnosis or management.
6. Highlight any outlier or unusual complaints that appear only once, if any.
7. Organize your report into clear sections with informative headings: "Most Common Complaints", "Symptom Patterns", "Relevant Context & Secondary Findings", and "Outlier Complaints".
8. Provide a closing synthesis—a concise, professional paragraph summarizing key findings, suitable for a medical audience.

Be as clear, precise, and comprehensive as possible. Use bullet points and tables if helpful. Write in a professional, clinical style.
"""


def load_medical_data(filepath: str, limit: int = None) -> pd.DataFrame:
    """Load medical transcript data into a DataFrame"""
    with open(filepath, 'r') as f:
        data = json.load(f)

    if limit:
        data = data[:limit]

    return pd.DataFrame(data)


def generate_ground_truth(df: pd.DataFrame) -> tuple[pd.DataFrame, str]:
    """
    Generate ground truth using expensive model for both map and reduce.

    Args:
        df: DataFrame with medical transcripts

    Returns:
        Tuple of (ground truth map DataFrame, ground truth summary string)
    """
    console = Console()
    console.print(f"[bold cyan]Generating ground truth: Map with {EXPENSIVE_MODEL}, Reduce with {EXPENSIVE_MODEL}[/bold cyan]")

    # Extract with EXPENSIVE_MODEL using simple extraction prompt
    df_gt_map = df.semantic.map(
        prompt=GROUND_TRUTH_MAP_PROMPT,
        output={"schema": {"extracted_complaints": "string"}},
        model=EXPENSIVE_MODEL
    )

    # Aggregate with EXPENSIVE_MODEL
    result_df = df_gt_map.semantic.agg(
        reduce_prompt=REDUCE_PROMPT,
        output={"schema": {"comprehensive_summary": "string"}},
        reduce_keys=["_all"],
        reduce_kwargs={"model": EXPENSIVE_MODEL}
    )

    gt_summary = result_df.iloc[0]['comprehensive_summary']

    return df_gt_map, gt_summary


def calculate_map_precision(
    df_experiment: pd.DataFrame,
    df_ground_truth: pd.DataFrame
) -> Dict[str, float]:
    """
    Calculate precision metrics for the map operation by comparing to ground truth.

    Args:
        df_experiment: DataFrame with extracted complaints from experiment (cheap model)
        df_ground_truth: DataFrame with ground truth extracted complaints (expensive model)

    Returns:
        Dictionary with precision metrics
    """
    # Calculate average length
    avg_length = df_experiment['extracted_complaints'].str.len().mean()

    # Calculate number of tokens (rough estimate using word count)
    avg_tokens = df_experiment['extracted_complaints'].str.split().str.len().mean()

    # For precision, we want to measure what fraction of the experiment's extraction
    # is relevant (i.e., matches the ground truth)
    # Precision = (# relevant extracted) / (# total extracted)
    # We'll use the overlap with ground truth as a proxy

    precision_scores = []

    for idx, row in df_experiment.iterrows():
        gt_text = df_ground_truth.loc[idx, 'extracted_complaints']
        extracted_text = row['extracted_complaints']

        if not extracted_text or not gt_text:
            continue

        # Split into tokens
        gt_tokens = set(gt_text.lower().split())
        extracted_tokens = set(extracted_text.lower().split())

        if len(extracted_tokens) > 0:
            # Precision: what fraction of extracted tokens are in ground truth
            precision = len(gt_tokens & extracted_tokens) / len(extracted_tokens)
            precision_scores.append(precision)

    return {
        "avg_length": avg_length,
        "avg_tokens": avg_tokens,
        "precision": sum(precision_scores) / len(precision_scores) if precision_scores else 0.0
    }
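
# Worked example of the token-overlap precision above (illustrative values, not
# taken from the dataset): if the ground truth extraction is
# "patient reports severe headache" and the cheap model extracts
# "severe headache and nausea", the extracted token set has 4 tokens, of which
# 2 ("severe", "headache") appear in the ground truth, so precision = 2/4 = 0.5.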


def get_embedding(text: str, model: str = "text-embedding-3-small") -> list:
    """Get embedding for a text using OpenAI API"""
    text = text.replace("\n", " ")
    return openai_client.embeddings.create(input=[text], model=model).data[0].embedding


def calculate_semantic_similarity(text1: str, text2: str) -> float:
    """
    Calculate semantic similarity between two texts using embeddings.

    Args:
        text1: First text
        text2: Second text

    Returns:
        Cosine similarity score (0-1)
    """
    if not EMBEDDINGS_AVAILABLE:
        return 0.0

    try:
        emb1 = get_embedding(text1)
        emb2 = get_embedding(text2)

        # Calculate cosine similarity
        similarity = cosine_similarity([emb1], [emb2])[0][0]
        return float(similarity)
    except Exception as e:
        print(f"Error calculating semantic similarity: {e}")
        return 0.0


def llm_judge_quality(summary: str, ground_truth: str, model: str = "gpt-4o-mini") -> Dict[str, float]:
    """
    Use an LLM to judge the quality of a summary compared to ground truth.

    Args:
        summary: The summary to evaluate
        ground_truth: The ground truth summary
        model: The model to use for judging

    Returns:
        Dictionary with quality scores
    """
    if not EMBEDDINGS_AVAILABLE:
        return {"factual_accuracy": 0.0, "completeness": 0.0, "specificity": 0.0, "overall_quality": 0.0}

    judge_prompt = f"""You are an expert medical reviewer evaluating the quality of a chief complaint summary.

GROUND TRUTH SUMMARY (reference):
{ground_truth}

CANDIDATE SUMMARY (to evaluate):
{summary}

Please evaluate the candidate summary on the following criteria, giving a score from 0-10 for each:

1. FACTUAL ACCURACY: Are the complaints and their frequencies correct compared to the ground truth? Do they match the actual data?
2. COMPLETENESS: Does it cover all the important complaints, patterns, and context mentioned in the ground truth?
3. SPECIFICITY: Does it provide specific, actionable information (frequencies, patterns, context) rather than vague generalizations?

Respond ONLY with a JSON object in this exact format (no other text):
{{"factual_accuracy": <score>, "completeness": <score>, "specificity": <score>}}"""

    try:
        response = openai_client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": judge_prompt}],
            temperature=0.0
        )

        result_text = response.choices[0].message.content.strip()
        # Parse JSON response
        import json as json_lib
        scores = json_lib.loads(result_text)

        # Normalize to 0-1 range
        return {
            "factual_accuracy": scores["factual_accuracy"] / 10.0,
            "completeness": scores["completeness"] / 10.0,
            "specificity": scores["specificity"] / 10.0,
            "overall_quality": (scores["factual_accuracy"] + scores["completeness"] + scores["specificity"]) / 30.0
        }
    except Exception as e:
        print(f"Error in LLM judge: {e}")
        return {"factual_accuracy": 0.0, "completeness": 0.0, "specificity": 0.0, "overall_quality": 0.0}
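
# Example of the score normalization above: a judge response of
# {"factual_accuracy": 8, "completeness": 7, "specificity": 9} maps to
# 0.8, 0.7 and 0.9 respectively, with overall_quality = (8 + 7 + 9) / 30 = 0.8.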


def calculate_reduce_quality(
    summary: str,
    ground_truth_summary: str
) -> Dict[str, float]:
    """
    Calculate quality metrics for the reduce operation.

    Args:
        summary: Generated summary
        ground_truth_summary: Ground truth summary from GPT-4o

    Returns:
        Dictionary with quality metrics
    """
    metrics = {}

    # LLM-as-judge (most discriminating for structured outputs)
    if EMBEDDINGS_AVAILABLE:
        judge_scores = llm_judge_quality(summary, ground_truth_summary)
        metrics.update(judge_scores)
    else:
        metrics["factual_accuracy"] = 0.0
        metrics["completeness"] = 0.0
        metrics["specificity"] = 0.0
        metrics["overall_quality"] = 0.0

    # Calculate BERTScore (semantic similarity)
    if BERTSCORE_AVAILABLE:
        try:
            P, R, F1 = bert_score([summary], [ground_truth_summary], lang="en", verbose=False)
            metrics["bertscore_f1"] = float(F1[0])
            metrics["bertscore_precision"] = float(P[0])
            metrics["bertscore_recall"] = float(R[0])
        except Exception as e:
            print(f"Error calculating BERTScore: {e}")
            metrics["bertscore_f1"] = 0.0
            metrics["bertscore_precision"] = 0.0
            metrics["bertscore_recall"] = 0.0
    else:
        metrics["bertscore_f1"] = 0.0
        metrics["bertscore_precision"] = 0.0
        metrics["bertscore_recall"] = 0.0

    # Calculate ROUGE scores (for reference)
    if rouge_scorer is not None:
        scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
        rouge_scores = scorer.score(ground_truth_summary, summary)
        metrics["rouge1"] = rouge_scores['rouge1'].fmeasure
        metrics["rouge2"] = rouge_scores['rouge2'].fmeasure
        metrics["rougeL"] = rouge_scores['rougeL'].fmeasure
    else:
        metrics["rouge1"] = 0.0
        metrics["rouge2"] = 0.0
        metrics["rougeL"] = 0.0

    # Calculate BLEU score (for reference)
    try:
        reference_tokens = [ground_truth_summary.split()]
        candidate_tokens = summary.split()
        smoothing = SmoothingFunction().method1
        bleu_score = sentence_bleu(reference_tokens, candidate_tokens, smoothing_function=smoothing)
    except Exception:
        bleu_score = 0.0

    metrics["bleu"] = bleu_score

    return metrics


def run_experiment(
    data_path: str = MEDICAL_DATA_PATH,
    precision_levels: List[str] = ["high_precision", "medium_precision", "low_precision"],
    data_limit: int = None
):
    """
    Run the main experiment comparing precision vs quality tradeoffs.

    Args:
        data_path: Path to medical transcripts JSON file
        precision_levels: List of precision levels to test
        data_limit: Limit number of documents to process
    """
    console = Console()

    # Load data
    console.print(f"[bold]Loading medical data from {data_path}[/bold]")
    df = load_medical_data(data_path, limit=data_limit)
    console.print(f"Loaded {len(df)} medical transcripts\n")

    # Generate ground truth with expensive model
    console.print(f"[bold cyan]Step 1: Generating ground truth with {EXPENSIVE_MODEL}[/bold cyan]")
    df_gt_map, ground_truth_summary = generate_ground_truth(df.copy())
    console.print(f"[bold green]✓ Ground truth generated ({len(ground_truth_summary)} chars)[/bold green]")
    console.print(f"\n[italic]Ground truth summary preview:[/italic]\n{ground_truth_summary[:300]}...\n")

    # Store results
    results = {"ground_truth_summary": ground_truth_summary}

    console.print(f"\n[bold cyan]Step 2: Running experiments (Map: {CHEAP_MODEL}, Reduce: {EXPENSIVE_MODEL})[/bold cyan]\n")

    for precision_level in precision_levels:
        console.print(f"[bold green]Testing precision level: {precision_level}[/bold green]")

        # Extract with CHEAP_MODEL
        df_fresh = df.copy()
        df_extracted = df_fresh.semantic.map(
            prompt=EXTRACTION_PROMPTS[precision_level],
            output={"schema": {"extracted_complaints": "string"}},
            model=CHEAP_MODEL
        )

        # Calculate map precision against ground truth
        map_metrics = calculate_map_precision(df_extracted, df_gt_map)
        console.print(f"  Map Precision: {map_metrics['precision']:.3f}, Avg Length: {map_metrics['avg_length']:.1f}, Avg Tokens: {map_metrics['avg_tokens']:.1f}")

        # Reduce/summarize with EXPENSIVE_MODEL
        console.print(f"  Reducing/summarizing complaints...")
        df_result = df_extracted.semantic.agg(
            reduce_prompt=REDUCE_PROMPT,
            output={"schema": {"comprehensive_summary": "string"}},
            reduce_keys=["_all"],
            reduce_kwargs={"model": EXPENSIVE_MODEL}
        )

        summary = df_result.iloc[0]['comprehensive_summary']

        # Calculate reduce quality against ground truth
        reduce_metrics = calculate_reduce_quality(summary, ground_truth_summary)

        # Show LLM judge scores (most important)
        if EMBEDDINGS_AVAILABLE:
            console.print(f"  Reduce Quality - Overall: {reduce_metrics['overall_quality']:.3f} (Acc: {reduce_metrics['factual_accuracy']:.3f}, Comp: {reduce_metrics['completeness']:.3f}, Spec: {reduce_metrics['specificity']:.3f})")

        # Show other metrics
        if BERTSCORE_AVAILABLE:
            console.print(f"  BERTScore F1: {reduce_metrics['bertscore_f1']:.3f}, ROUGE-L: {reduce_metrics['rougeL']:.3f}\n")
        else:
            console.print(f"  ROUGE-L: {reduce_metrics['rougeL']:.3f}\n")

        # Get costs
        extract_cost = df_extracted.semantic.total_cost
        reduce_cost = df_result.semantic.total_cost

        results[precision_level] = {
            "map_precision": map_metrics,
            "reduce_quality": reduce_metrics,
            "summary": summary,
            "extract_cost": extract_cost,
            "reduce_cost": reduce_cost
        }

    return results


def format_results_table(results: Dict[str, Any]) -> Table:
    """Format experiment results as a Rich table"""
    table = Table(
        title=f"Medical Complaints Precision Experiment\nGT: Map={EXPENSIVE_MODEL}, Reduce={EXPENSIVE_MODEL} | Exp: Map={CHEAP_MODEL}, Reduce={EXPENSIVE_MODEL}",
        box=box.ROUNDED,
        show_header=True,
        header_style="bold cyan"
    )

    # Add columns
    table.add_column("Precision Level", style="bold")
    table.add_column("Map Precision", justify="right")
    table.add_column("Avg Tokens", justify="right")

    # LLM judge columns (primary metrics)
    if EMBEDDINGS_AVAILABLE:
        table.add_column("Overall Quality", justify="right")
        table.add_column("Factual Acc", justify="right")
        table.add_column("Completeness", justify="right")
        table.add_column("Specificity", justify="right")

    # Secondary metrics
    if BERTSCORE_AVAILABLE:
        table.add_column("BERTScore F1", justify="right")
    table.add_column("ROUGE-L", justify="right")

    # Add rows
    precision_levels = [k for k in results.keys() if k != "ground_truth_summary"]

    for precision_level in precision_levels:
        data = results[precision_level]
        map_precision = data["map_precision"]
        reduce_quality = data["reduce_quality"]

        row = [
            precision_level,
            f"{map_precision['precision']:.3f}",
            f"{map_precision['avg_tokens']:.0f}",
        ]

        # Add LLM judge scores
        if EMBEDDINGS_AVAILABLE:
            row.extend([
                f"{reduce_quality['overall_quality']:.3f}",
                f"{reduce_quality['factual_accuracy']:.3f}",
                f"{reduce_quality['completeness']:.3f}",
                f"{reduce_quality['specificity']:.3f}"
            ])

        # Add secondary metrics
        if BERTSCORE_AVAILABLE:
            row.append(f"{reduce_quality['bertscore_f1']:.3f}")
        row.append(f"{reduce_quality['rougeL']:.3f}")

        table.add_row(*row)

    return table


def print_analysis(results: Dict[str, Any]):
    """Print analysis of the precision-quality tradeoff"""
    console = Console()

    console.print("\n[bold]Analysis: Map Precision vs Reduce Quality[/bold]\n")

    # Compare across precision levels
    precision_levels = [k for k in results.keys() if k != "ground_truth_summary"]

    precision_data = []
    for precision_level in precision_levels:
        data = results[precision_level]
        map_precision = data["map_precision"]["precision"]
        reduce_quality = data["reduce_quality"]

        # Use LLM judge as primary metric if available
        if EMBEDDINGS_AVAILABLE:
            primary_metric = reduce_quality["overall_quality"]
            primary_metric_name = "Overall Quality"
            factual_acc = reduce_quality["factual_accuracy"]
            completeness = reduce_quality["completeness"]
            specificity = reduce_quality["specificity"]
        elif BERTSCORE_AVAILABLE:
            primary_metric = reduce_quality["bertscore_f1"]
            primary_metric_name = "BERTScore F1"
            factual_acc = completeness = specificity = 0.0
        else:
            primary_metric = reduce_quality["rougeL"]
            primary_metric_name = "ROUGE-L"
            factual_acc = completeness = specificity = 0.0

        precision_data.append({
            "level": precision_level,
            "map_precision": map_precision,
            "primary_metric": primary_metric,
            "primary_metric_name": primary_metric_name,
            "factual_acc": factual_acc,
            "completeness": completeness,
            "specificity": specificity,
            "reduce_rouge": reduce_quality["rougeL"]
        })

    # Sort by map precision (descending - higher precision = more restrictive)
    precision_data.sort(key=lambda x: x["map_precision"], reverse=True)

    console.print("Precision vs Quality:")
    for data in precision_data:
        if EMBEDDINGS_AVAILABLE:
            console.print(f"  {data['level']:20s} | Map Prec: {data['map_precision']:.3f} -> Quality: {data['primary_metric']:.3f} (Acc: {data['factual_acc']:.3f}, Comp: {data['completeness']:.3f}, Spec: {data['specificity']:.3f})")
        else:
            console.print(f"  {data['level']:20s} | Map Precision: {data['map_precision']:.3f} -> {data['primary_metric_name']}: {data['primary_metric']:.3f}")

    # Check if lower map precision leads to higher reduce quality
    if len(precision_data) >= 2:
        highest_map_precision = precision_data[0]  # Most restrictive extraction
        lowest_map_precision = precision_data[-1]  # Least restrictive extraction

        if lowest_map_precision["primary_metric"] > highest_map_precision["primary_metric"]:
            console.print(f"\n[bold green]✓ Lower map precision (more context) leads to higher reduce quality![/bold green]")
            improvement = (lowest_map_precision['primary_metric'] - highest_map_precision['primary_metric']) / highest_map_precision['primary_metric'] * 100
            console.print(f"  {precision_data[0]['primary_metric_name']} Improvement: {improvement:.1f}%")

            # Show breakdown if using LLM judge
            if EMBEDDINGS_AVAILABLE:
                console.print(f"  Breakdown:")
                console.print(f"    Factual Accuracy: {highest_map_precision['factual_acc']:.3f} -> {lowest_map_precision['factual_acc']:.3f}")
                console.print(f"    Completeness: {highest_map_precision['completeness']:.3f} -> {lowest_map_precision['completeness']:.3f}")
                console.print(f"    Specificity: {highest_map_precision['specificity']:.3f} -> {lowest_map_precision['specificity']:.3f}")
        else:
            console.print(f"\n[yellow]✗ Higher map precision (less context) leads to higher reduce quality[/yellow]")


if __name__ == "__main__":
    console = Console()

    console.print(f"[bold green]Medical Complaints Precision Experiment[/bold green]")
    console.print(f"[cyan]Ground Truth: Map={EXPENSIVE_MODEL}, Reduce={EXPENSIVE_MODEL}[/cyan]")
    console.print(f"[cyan]Experiments: Map={CHEAP_MODEL} (varying precision), Reduce={EXPENSIVE_MODEL}[/cyan]")
    console.print(f"[italic]Testing how cheap model's extraction precision affects expensive model's reduce quality[/italic]\n")

    # Run experiment
    results = run_experiment(
        data_path=MEDICAL_DATA_PATH,
        precision_levels=["high_precision", "medium_precision", "low_precision"],
        data_limit=None  # Use full dataset
    )

    # Display results
    console.print("\n[bold]Results Table:[/bold]")
    console.print(format_results_table(results))

    # Print analysis
    print_analysis(results)

    # Print example summaries
    console.print("\n[bold]Example Summaries:[/bold]")
    console.print(f"\n[bold cyan]Ground Truth (Map={EXPENSIVE_MODEL}, Reduce={EXPENSIVE_MODEL}):[/bold cyan]")
    console.print(f"  {results['ground_truth_summary'][:300]}...")

    for precision_level in ["high_precision", "medium_precision", "low_precision"]:
        summary = results[precision_level]["summary"]
        console.print(f"\n[bold green]{precision_level} (Map={CHEAP_MODEL}, Reduce={EXPENSIVE_MODEL}):[/bold green]")
        console.print(f"  {summary[:300]}...")
@ -50,6 +50,10 @@ server = [
    "azure-ai-documentintelligence>=1.0.0b4",
    "httpx>=0.27.2",
]
experiments = [
    "nltk>=3.9.2",
    "rouge-score>=0.1.2",
]

[project.scripts]
docetl = "docetl.cli:app"
uv.lock (45 changed lines)
@ -13,6 +13,15 @@ resolution-markers = [
    "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')",
]

[[package]]
name = "absl-py"
version = "2.3.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/10/2a/c93173ffa1b39c1d0395b7e842bbdc62e556ca9d8d3b5572926f3e4ca752/absl_py-2.3.1.tar.gz", hash = "sha256:a97820526f7fbfd2ec1bce83f3f25e3a14840dac0d8e02a0b71cd75db3f77fc9", size = 116588, upload-time = "2025-07-03T09:31:44.05Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/8f/aa/ba0014cc4659328dc818a28827be78e6d97312ab0cb98105a770924dc11e/absl_py-2.3.1-py3-none-any.whl", hash = "sha256:eeecf07f0c2a93ace0772c92e596ace6d3d3996c042b2128459aaae2a76de11d", size = 135811, upload-time = "2025-07-03T09:31:42.253Z" },
]

[[package]]
name = "accelerate"
version = "1.10.0"
@ -490,6 +499,10 @@ dependencies = [
]

[package.optional-dependencies]
experiments = [
    { name = "nltk" },
    { name = "rouge-score" },
]
parsing = [
    { name = "azure-ai-documentintelligence" },
    { name = "openpyxl" },
@ -543,6 +556,7 @@ requires-dist = [
    { name = "jsonschema", specifier = ">=4.23.0" },
    { name = "litellm", specifier = ">=1.75.4" },
    { name = "lzstring", specifier = ">=1.0.4" },
    { name = "nltk", marker = "extra == 'experiments'", specifier = ">=3.9.2" },
    { name = "openpyxl", marker = "extra == 'parsing'", specifier = ">=3.1.5" },
    { name = "paddlepaddle", marker = "extra == 'parsing'", specifier = ">=2.6.2" },
    { name = "pandas", specifier = ">=2.3.0" },
@ -556,13 +570,14 @@ requires-dist = [
    { name = "rank-bm25", specifier = ">=0.2.2" },
    { name = "rapidfuzz", specifier = ">=3.10.0" },
    { name = "rich", specifier = ">=13.7.1" },
    { name = "rouge-score", marker = "extra == 'experiments'", specifier = ">=0.1.2" },
    { name = "scikit-learn", specifier = ">=1.5.2" },
    { name = "tqdm", specifier = ">=4.66.4" },
    { name = "typer", specifier = ">=0.16.0" },
    { name = "uvicorn", marker = "extra == 'server'", specifier = ">=0.31.0" },
    { name = "websockets", specifier = ">=13.1" },
]
provides-extras = ["parsing", "server"]
provides-extras = ["parsing", "server", "experiments"]

[package.metadata.requires-dev]
dev = [
@ -1891,6 +1906,21 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/b9/58/612a17593c2d117f96c7f6b7f1e6570246bddc4b1e808519403a1417f217/ninja-1.11.1.4-py3-none-win_arm64.whl", hash = "sha256:5713cf50c5be50084a8693308a63ecf9e55c3132a78a41ab1363a28b6caaaee1", size = 271441, upload-time = "2025-03-22T06:46:42.147Z" },
]

[[package]]
name = "nltk"
version = "3.9.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "click" },
    { name = "joblib" },
    { name = "regex" },
    { name = "tqdm" },
]
sdist = { url = "https://files.pythonhosted.org/packages/f9/76/3a5e4312c19a028770f86fd7c058cf9f4ec4321c6cf7526bab998a5b683c/nltk-3.9.2.tar.gz", hash = "sha256:0f409e9b069ca4177c1903c3e843eef90c7e92992fa4931ae607da6de49e1419", size = 2887629, upload-time = "2025-10-01T07:19:23.764Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/60/90/81ac364ef94209c100e12579629dc92bf7a709a84af32f8c551b02c07e94/nltk-3.9.2-py3-none-any.whl", hash = "sha256:1e209d2b3009110635ed9709a67a1a3e33a10f799490fa71cf4bec218c11c88a", size = 1513404, upload-time = "2025-10-01T07:19:21.648Z" },
]

[[package]]
name = "nodeenv"
version = "1.9.1"
@ -3349,6 +3379,19 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/e3/30/3c4d035596d3cf444529e0b2953ad0466f6049528a879d27534700580395/rich-14.1.0-py3-none-any.whl", hash = "sha256:536f5f1785986d6dbdea3c75205c473f970777b4a0d6c6dd1b696aa05a3fa04f", size = 243368, upload-time = "2025-07-25T07:32:56.73Z" },
]

[[package]]
name = "rouge-score"
version = "0.1.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "absl-py" },
    { name = "nltk" },
    { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
    { name = "numpy", version = "2.3.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
    { name = "six" },
]
sdist = { url = "https://files.pythonhosted.org/packages/e2/c5/9136736c37022a6ad27fea38f3111eb8f02fe75d067f9a985cc358653102/rouge_score-0.1.2.tar.gz", hash = "sha256:c7d4da2683e68c9abf0135ef915d63a46643666f848e558a1b9f7ead17ff0f04", size = 17400, upload-time = "2022-07-22T22:46:22.909Z" }

[[package]]
name = "rpds-py"
version = "0.27.0"