Compare commits

...

1 Commits

Author SHA1 Message Date
Shreya Shankar b193375270 add reduce to pandas api 2025-12-19 11:13:02 +05:30
2 changed files with 69 additions and 2 deletions

View File

@ -673,6 +673,68 @@ Record 2: {record_template.replace('input0', 'input2')}"""
return self._record_operation(results, "reduce", reduce_config, reduce_cost)
def reduce(
self,
prompt: str,
output: dict[str, Any] | None = None,
*,
output_schema: dict[str, Any] | None = None,
reduce_keys: str | list[str] = ["_all"],
**kwargs,
) -> pd.DataFrame:
"""
Reduce/aggregate all rows using a language model.
This is a simplified wrapper around the agg() method for common reduce operations.
Documentation: https://ucbepic.github.io/docetl/operators/reduce/
Args:
prompt: Jinja template string for the reduction prompt. Use {% for item in inputs %}
to iterate over input rows.
output: Output configuration with keys:
- "schema": Dictionary defining the expected output structure and types
- "mode": Optional output mode. Either "tools" (default) or "structured_output"
output_schema: DEPRECATED. Use 'output' parameter instead.
reduce_keys: Keys to group by for reduction (default: ["_all"] for all rows)
**kwargs: Additional configuration options passed to agg():
- model: LLM model to use
- validate: List of validation expressions
- num_retries_on_validate_failure: Number of retries
- timeout: Timeout in seconds (default: 120)
- max_retries_per_timeout: Max retries per timeout (default: 2)
Returns:
pd.DataFrame: Aggregated DataFrame with columns matching output['schema']
Examples:
>>> # Summarize all texts into one summary
>>> df.semantic.reduce(
... prompt=\"\"\"Summarize the following items:
... {% for item in inputs %}
... - {{ item.text }}
... {% endfor %}\"\"\",
... output={"schema": {"summary": "str"}}
... )
>>> # Reduce by group
>>> df.semantic.reduce(
... prompt=\"\"\"Combine items for {{ reduce_key }}:
... {% for item in inputs %}
... - {{ item.description }}
... {% endfor %}\"\"\",
... output={"schema": {"combined": "str"}},
... reduce_keys=["category"]
... )
"""
return self.agg(
reduce_prompt=prompt,
output=output,
output_schema=output_schema,
reduce_keys=reduce_keys,
reduce_kwargs=kwargs,
)
def filter(
self,
prompt: str,

View File

@ -20,7 +20,6 @@ from jinja2 import Template
from pydantic import Field, field_validator, model_validator
from docetl.operations.base import BaseOperation
from docetl.utils import has_jinja_syntax, prompt_user_for_non_jinja_confirmation
from docetl.operations.clustering_utils import (
cluster_documents,
get_embeddings_for_clustering,
@ -33,7 +32,11 @@ from docetl.operations.utils import (
# Import OutputMode enum for structured output checks
from docetl.operations.utils.api import OutputMode
from docetl.utils import completion_cost
from docetl.utils import (
completion_cost,
has_jinja_syntax,
prompt_user_for_non_jinja_confirmation,
)
class ReduceOperation(BaseOperation):
@ -776,6 +779,7 @@ class ReduceOperation(BaseOperation):
end_time = time.time()
self._update_fold_time(end_time - start_time)
fold_cost = response.total_cost
if response.validated:
structured_mode = (
@ -844,6 +848,7 @@ class ReduceOperation(BaseOperation):
end_time = time.time()
self._update_merge_time(end_time - start_time)
merge_cost = response.total_cost
if response.validated:
structured_mode = (