Compare commits
1 Commits
| Author | SHA1 | Date |
|---|---|---|
|
|
b193375270 |
|
|
@ -673,6 +673,68 @@ Record 2: {record_template.replace('input0', 'input2')}"""
|
|||
|
||||
return self._record_operation(results, "reduce", reduce_config, reduce_cost)
|
||||
|
||||
def reduce(
|
||||
self,
|
||||
prompt: str,
|
||||
output: dict[str, Any] | None = None,
|
||||
*,
|
||||
output_schema: dict[str, Any] | None = None,
|
||||
reduce_keys: str | list[str] = ["_all"],
|
||||
**kwargs,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Reduce/aggregate all rows using a language model.
|
||||
|
||||
This is a simplified wrapper around the agg() method for common reduce operations.
|
||||
|
||||
Documentation: https://ucbepic.github.io/docetl/operators/reduce/
|
||||
|
||||
Args:
|
||||
prompt: Jinja template string for the reduction prompt. Use {% for item in inputs %}
|
||||
to iterate over input rows.
|
||||
output: Output configuration with keys:
|
||||
- "schema": Dictionary defining the expected output structure and types
|
||||
- "mode": Optional output mode. Either "tools" (default) or "structured_output"
|
||||
output_schema: DEPRECATED. Use 'output' parameter instead.
|
||||
reduce_keys: Keys to group by for reduction (default: ["_all"] for all rows)
|
||||
**kwargs: Additional configuration options passed to agg():
|
||||
- model: LLM model to use
|
||||
- validate: List of validation expressions
|
||||
- num_retries_on_validate_failure: Number of retries
|
||||
- timeout: Timeout in seconds (default: 120)
|
||||
- max_retries_per_timeout: Max retries per timeout (default: 2)
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: Aggregated DataFrame with columns matching output['schema']
|
||||
|
||||
Examples:
|
||||
>>> # Summarize all texts into one summary
|
||||
>>> df.semantic.reduce(
|
||||
... prompt=\"\"\"Summarize the following items:
|
||||
... {% for item in inputs %}
|
||||
... - {{ item.text }}
|
||||
... {% endfor %}\"\"\",
|
||||
... output={"schema": {"summary": "str"}}
|
||||
... )
|
||||
|
||||
>>> # Reduce by group
|
||||
>>> df.semantic.reduce(
|
||||
... prompt=\"\"\"Combine items for {{ reduce_key }}:
|
||||
... {% for item in inputs %}
|
||||
... - {{ item.description }}
|
||||
... {% endfor %}\"\"\",
|
||||
... output={"schema": {"combined": "str"}},
|
||||
... reduce_keys=["category"]
|
||||
... )
|
||||
"""
|
||||
return self.agg(
|
||||
reduce_prompt=prompt,
|
||||
output=output,
|
||||
output_schema=output_schema,
|
||||
reduce_keys=reduce_keys,
|
||||
reduce_kwargs=kwargs,
|
||||
)
|
||||
|
||||
def filter(
|
||||
self,
|
||||
prompt: str,
|
||||
|
|
|
|||
|
|
@ -20,7 +20,6 @@ from jinja2 import Template
|
|||
from pydantic import Field, field_validator, model_validator
|
||||
|
||||
from docetl.operations.base import BaseOperation
|
||||
from docetl.utils import has_jinja_syntax, prompt_user_for_non_jinja_confirmation
|
||||
from docetl.operations.clustering_utils import (
|
||||
cluster_documents,
|
||||
get_embeddings_for_clustering,
|
||||
|
|
@ -33,7 +32,11 @@ from docetl.operations.utils import (
|
|||
|
||||
# Import OutputMode enum for structured output checks
|
||||
from docetl.operations.utils.api import OutputMode
|
||||
from docetl.utils import completion_cost
|
||||
from docetl.utils import (
|
||||
completion_cost,
|
||||
has_jinja_syntax,
|
||||
prompt_user_for_non_jinja_confirmation,
|
||||
)
|
||||
|
||||
|
||||
class ReduceOperation(BaseOperation):
|
||||
|
|
@ -776,6 +779,7 @@ class ReduceOperation(BaseOperation):
|
|||
|
||||
end_time = time.time()
|
||||
self._update_fold_time(end_time - start_time)
|
||||
fold_cost = response.total_cost
|
||||
|
||||
if response.validated:
|
||||
structured_mode = (
|
||||
|
|
@ -844,6 +848,7 @@ class ReduceOperation(BaseOperation):
|
|||
|
||||
end_time = time.time()
|
||||
self._update_merge_time(end_time - start_time)
|
||||
merge_cost = response.total_cost
|
||||
|
||||
if response.validated:
|
||||
structured_mode = (
|
||||
|
|
|
|||
Loading…
Reference in New Issue