# docetl/experiments/reasoning/evaluation/cuad.py
import json
import re
import pandas as pd
import numpy as np
from Levenshtein import distance
import random


def evaluate_results(method_name, results_file, ground_truth_file, original_json_file):
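    """Evaluate one method's CUAD extraction results against the ground truth.

    Args:
        method_name: Label for the method being evaluated (used to gate debug printing).
        results_file: Path to the DocETL results JSON file.
        ground_truth_file: Path to the CUAD ground-truth CSV file.
        original_json_file: Path to the original CUAD JSON listing all documents.

    Returns:
        A dict with average precision, recall, and F1, the fraction of NaN metrics,
        the average extracted clause length, and per-metric precision/recall.
    """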
    # Read the DocETL results JSON file
    with open(results_file, "r") as f:
        docetl_results = json.load(f)
    docetl_results = pd.DataFrame(docetl_results)

    # Read the original CUAD JSON and index documents by normalized filename
    with open(original_json_file, "r") as f:
        original_json_content = json.load(f)
    original_json_content_cleaned = {}
    for item in original_json_content:
        filename = (
            item["name"].split("/")[-1].upper().rstrip(".TXT")
            .replace(".", "").replace(",", "").replace(" ", "")
            .replace("_", "").replace("-", "").replace("'", "")
            .replace(r'[^a-zA-Z0-9]$', '')
        )
        original_json_content_cleaned[filename] = item

    # Track all extracted text spans so we can report an average clause length
    all_text_spans = []

    # Read the ground truth CSV file
    ground_truth_df = pd.read_csv(ground_truth_file)
    ground_truth_df.columns = ground_truth_df.columns.map(
        lambda x: x.replace(" ", "_").lower()
    )

    # Normalize ground-truth filenames for matching
    ground_truth_df["filename"] = ground_truth_df["filename"].apply(
        lambda x: x.upper()
        .replace(".", "")
        .replace(",", "")
        .replace(" ", "")
        .replace("_", "")
        .replace("-", "")
        .replace("'", "")
        .replace(r'[^a-zA-Z0-9]$', '')
    )

    # Normalize DocETL result filenames the same way
    filename_key = "name" if "name" in list(docetl_results.columns) else "filename"
    docetl_results["filename"] = docetl_results[filename_key].apply(
        lambda x: x.split("/")[-1]
        .upper()
        .rstrip(".TXT")
        .replace(".", "")
        .replace(",", "")
        .replace(" ", "")
        .replace("_", "")
        .replace("-", "")
        .replace("'", "")
        .replace(r'[^a-zA-Z0-9]$', '')
    )

    # Function to calculate precision and recall
    def calculate_metrics(metric, predicted_results, ground_truth):
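        """Compute precision and recall for one clause type across all documents.

        Documents absent from the predictions are scored as empty predictions.
        """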
        true_positives = 0
        false_positives = 0
        false_negatives = 0
        true_negatives = 0

        # Strip non-alphanumeric characters from predicted and ground-truth spans
        def clean_text(text):
            if isinstance(text, list):
                text = ' '.join(str(t) for t in text)
            elif not isinstance(text, str):
                text = str(text)
            cleaned = re.sub(r'[^a-zA-Z0-9\s]', '', text)
            return '' if len(cleaned) < 1 else cleaned

        predicted_results_cleaned = {k: clean_text(v) for k, v in predicted_results.items()}
        ground_truth_cleaned = {k: clean_text(v) for k, v in ground_truth.items()}
        # print(len(ground_truth_cleaned.keys()))
        # print("___________________________________")
        # print(len(original_json_content_cleaned.keys()))

        # Calculate metrics across all filenames
        # print(f"----- {metric} -----")
        for filename in original_json_content_cleaned:
            if filename not in predicted_results_cleaned:
                # print(f"Filename not in predicted results: {filename}")
                pred = ""
            else:
                pred = predicted_results_cleaned[filename]
            if filename not in ground_truth_cleaned:
                truth = "ground truth not found"
            else:
                truth = ground_truth_cleaned[filename]

            # Word-level Jaccard similarity between truth and pred
            def get_jaccard_sim(str1, str2):
                words1 = set(str1.lower().split())
                words2 = set(str2.lower().split())
                intersection = words1.intersection(words2)
                union = words1.union(words2)
                return len(intersection) / len(union) if union else 0
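
            # A prediction counts as a true positive when its word-level Jaccard
            # overlap with the ground-truth span exceeds 0.15; otherwise it is a
            # false positive.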
if truth != "" and pred != "":
#print(f"True positive (correct span)")
if get_jaccard_sim(truth, pred) > 0.15:
true_positives += 1
else:
if method_name == "gemini_gemini_opt":
print(f"False positive (incorrect span) for {method_name}, {pred}")
false_positives += 1
elif truth == "" and pred != "":
if method_name == "gemini_gemini_opt":
print(f"False positive (no span) for {method_name}, {pred}")
false_positives += 1
elif truth != "" and pred == "":
#print(f"False negative (no span) for {method_name}, {truth}")
false_negatives += 1
elif truth == "" and pred == "":
true_negatives += 1
#print(f"True negative")

        # Calculate precision and recall
        total_predictions = true_positives + false_positives
        total_ground_truth = true_positives + false_negatives
        if total_predictions == 0:
            precision = 0.0
        else:
            precision = true_positives / total_predictions
        if total_ground_truth == 0:
            recall = float("nan")
        else:
            recall = true_positives / total_ground_truth
        return precision, recall

    # The 41 CUAD clause categories to evaluate (snake_cased to match the
    # normalized ground-truth column names)
    metrics = [
        "document_name", "parties", "agreement_date", "effective_date",
        "expiration_date", "renewal_term", "notice_to_terminate_renewal",
        "governing_law", "most_favored_nation", "non_compete", "exclusivity",
        "no_solicit_of_customers", "competitive_restriction_exception",
        "no_solicit_of_employees", "non_disparagement", "termination_for_convenience",
        "right_of_first_refusal", "change_of_control", "anti_assignment",
        "revenue_profit_sharing", "price_restriction", "minimum_commitment",
        "volume_restriction", "ip_ownership_assignment", "joint_ip_ownership",
        "license_grant", "non_transferable_license", "affiliate_ip_license_licensor",
        "affiliate_ip_license_licensee", "unlimited_license",
        "irrevocable_or_perpetual_license", "source_code_escrow",
        "post_termination_services", "audit_rights", "uncapped_liability",
        "cap_on_liability", "liquidated_damages", "warranty_duration",
        "insurance", "covenant_not_to_sue", "third_party_beneficiary",
    ]

    # Reindex the dataframes and join them
    docetl_results = docetl_results.sort_values(by="filename").reset_index()
    # print(f"Number of documents in results: {len(docetl_results)}")

    # Find the closest matching ground-truth filename for each DocETL result,
    # scored by the number of positionally equal characters
    matched_filenames = []
    for docetl_filename in docetl_results["filename"]:
        closest_match = max(
            ground_truth_df["filename"],
            key=lambda x: sum(a == b for a, b in zip(x, docetl_filename))
        )
        matched_filenames.append(closest_match)
    ground_truth_df = ground_truth_df[ground_truth_df["filename"].isin(matched_filenames)]
    ground_truth_df = ground_truth_df.sort_values(by="filename").reset_index()
    docetl_results = docetl_results.sort_values(by="filename").reset_index()

    # Drop any of the metric columns from docetl_results if they exist, so the
    # merged frame keeps the ground-truth columns under their original names
    existing_metrics = [col for col in metrics if col in docetl_results.columns]
    if existing_metrics:
        docetl_results = docetl_results.drop(columns=existing_metrics)

    # Merge the dataframes on the index
    merged_df = pd.merge(docetl_results, ground_truth_df, left_index=True, right_index=True, how="inner")
    # print(merged_df.head())

    # Calculate precision and recall for DocETL results
    docetl_metrics = {}
    for metric in metrics:
        if metric in merged_df.columns:
            predicted_clauses = dict(zip(merged_df["filename_x"], merged_df["clauses"]))
            predicted_results = {}
            for k, clauses in predicted_clauses.items():
                if isinstance(clauses, dict):
                    clauses = clauses.get("clauses", [])
                if not clauses:
                    predicted_results[k] = ""
                    continue
                # Flag problematic clause entries (debugging aid)
                for c in clauses:
                    if not (isinstance(c, dict) and "clause_type" in c and "text_span" in c):
                        pass
                        # print(f"Problematic clause entry: {c} (type: {type(c)})")
                # Keep only well-formed entries and normalize clause types to snake_case
                clauses = [
                    {
                        "clause_type": c["clause_type"].lower().strip().replace(" ", "_").replace("-", "_"),
                        "text_span": c["text_span"],
                    }
                    for c in clauses
                    if isinstance(c, dict) and "clause_type" in c and "text_span" in c
                ]
                clause_types = [c["clause_type"] for c in clauses]
                if len(clause_types) == 0:
                    predicted_results[k] = ""
                    continue
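                # Match this metric to the predicted clause type with the smallest
                # Levenshtein distance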
                closest_match = min(
                    clause_types,
                    key=lambda x: distance(x, metric)
                )
                # If the closest match doesn't share any words with the metric,
                # treat this document as having no prediction for the metric
                metric_words = metric.split('_')
                closest_match_words = closest_match.split('_')
                if not any(word in closest_match_words for word in metric_words):
                    predicted_results[k] = ""
                    continue
                match_for_clause_type = [c["text_span"] for c in clauses if c["clause_type"] == closest_match][0]
                predicted_results[k] = match_for_clause_type

            precision, recall = calculate_metrics(
                metric, predicted_results, merged_df[["filename_x", metric]].set_index("filename_x")[metric].to_dict()
            )
            docetl_metrics[metric] = {"precision": precision, "recall": recall}

            # Track all extracted text spans for average length calculation
            for key, text_span in predicted_results.items():
                if text_span and isinstance(text_span, str) and len(text_span.strip()) > 0:
                    all_text_spans.append(text_span)
        else:
            docetl_metrics[metric] = {"precision": np.nan, "recall": np.nan}

    # Calculate average clause length (in characters) over all extracted spans
    avg_clause_length = 0
    if len(all_text_spans) > 0:
        avg_clause_length = sum(len(span) for span in all_text_spans) / len(all_text_spans)

    # Calculate average metrics
    precisions = [v["precision"] for v in docetl_metrics.values()]
    recalls = [v["recall"] for v in docetl_metrics.values()]
    avg_precision = np.nanmean(precisions)
    avg_recall = np.nanmean(recalls)
    nan_fraction = (np.isnan(precisions) | np.isnan(recalls)).mean()

    # F1 per metric: 0 when precision + recall is 0, NaN when either value is NaN
    f1_scores = [
        (
            2 * (p * r) / (p + r)
            if not (np.isnan(p) or np.isnan(r)) and (p + r) != 0
            else 0 if (p + r) == 0 else np.nan
        )
        for p, r in zip(precisions, recalls)
    ]
    avg_f1 = np.nanmean(f1_scores)

    return {
        "avg_precision": avg_precision,
        "avg_recall": avg_recall,
        "nan_fraction": nan_fraction,
        "avg_f1": avg_f1,
        "avg_clause_length": avg_clause_length,
        "per_metric": docetl_metrics,
    }
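

# Hypothetical usage sketch (not part of the original file): shows how
# evaluate_results might be called. The method label and file paths below are
# placeholder assumptions, not paths from this repository.
if __name__ == "__main__":
    scores = evaluate_results(
        method_name="example_method",                  # hypothetical label
        results_file="results/example_results.json",   # hypothetical path to DocETL output
        ground_truth_file="data/master_clauses.csv",   # hypothetical path to CUAD ground-truth CSV
        original_json_file="data/CUAD_v1.json",        # hypothetical path to CUAD documents JSON
    )
    # Print the aggregate scores, leaving out the per-clause-type breakdown
    print({k: v for k, v in scores.items() if k != "per_metric"})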