Compare commits

10 Commits

main...cursor/imp
| Author | SHA1 | Date |
|---|---|---|
| | 3aa22d6180 | |
| | 76359016c3 | |
| | fcec8fa92d | |
| | 678693a688 | |
| | 69a2a72b07 | |
| | 67c661db30 | |
| | 0be1a167f0 | |
| | 1073ded6aa | |
| | 2d7419a887 | |
| | 672d1e5595 | |
```diff
@@ -1,6 +1,7 @@
 import datetime
 import os
 import time
+from typing import Any

 import pyrate_limiter
 from pyrate_limiter import BucketFullException, LimiterDelayException
```
```diff
@@ -70,8 +71,102 @@ class ConfigWrapper(object):
         self.rate_limiter = pyrate_limiter.Limiter(bucket_factory, max_delay=200)
         self.is_cancelled = False

+        # Store fallback configs
+        self.fallback_models_config = self.config.get("fallback_models", [])
+        self.fallback_embedding_models_config = self.config.get("fallback_embedding_models", [])
+        # Create base routers as instance variables (for fallback models only)
+        self.router = self._create_router(self.fallback_models_config, "completion")
+        self.embedding_router = self._create_router(self.fallback_embedding_models_config, "embedding")
+        # Cache routers per operation model (operation model + fallbacks)
+        self._router_cache: dict[str, Any] = {}
+
         self.api = APIWrapper(self)

+    def _create_router(self, fallback_models: list, router_type: str) -> Any | None:
+        """
+        Create a LiteLLM Router with fallback models if configured.
+
+        Args:
+            fallback_models: List of fallback model configurations
+            router_type: Type of router ("completion" or "embedding") for logging
+
+        Returns:
+            Router instance if fallback_models are configured, None otherwise.
+        """
+        if not fallback_models:
+            return None
+
+        try:
+            from litellm import Router
+        except ImportError:
+            self.console.log(
+                f"[yellow]Warning: LiteLLM Router not available. Fallback {router_type} models will be ignored.[/yellow]"
+            )
+            return None
+
+        # Build model list and fallbacks for Router
+        model_list = []
+        fallback_model_names = []
+
+        for fallback_config in fallback_models:
+            if isinstance(fallback_config, dict):
+                model_name = fallback_config.get("model_name")
+                litellm_params = fallback_config.get("litellm_params", {})
+            elif isinstance(fallback_config, str):
+                model_name = fallback_config
+                litellm_params = {}
+            else:
+                self.console.log(
+                    f"[yellow]Warning: Invalid fallback_{router_type}_models entry: {fallback_config}. Skipping.[/yellow]"
+                )
+                continue
+
+            if not model_name:
+                self.console.log(
+                    f"[yellow]Warning: fallback_{router_type}_models entry missing model_name: {fallback_config}. Skipping.[/yellow]"
+                )
+                continue
+
+            # Ensure model is included in litellm_params (required by LiteLLM Router)
+            litellm_params_with_model = litellm_params.copy()
+            litellm_params_with_model["model"] = model_name
+
+            model_list.append(
+                {
+                    "model_name": model_name,
+                    "litellm_params": litellm_params_with_model,
+                }
+            )
+            fallback_model_names.append(model_name)
+
+        if not model_list:
+            return None
+
+        try:
+            # Create Router with model_list and fallbacks parameter
+            # fallbacks should be a list of dicts: [{"model1": ["fallback1", "fallback2"]}]
+            router_kwargs = {"model_list": model_list}
+
+            # Build fallbacks list: each model falls back to the remaining models in order
+            if len(fallback_model_names) > 1:
+                fallbacks = []
+                for i, model_name in enumerate(fallback_model_names):
+                    # Each model falls back to the models after it in the list
+                    if i < len(fallback_model_names) - 1:
+                        fallbacks.append({model_name: fallback_model_names[i + 1:]})
+                router_kwargs["fallbacks"] = fallbacks
+
+            router = Router(**router_kwargs)
+            self.console.log(
+                f"[green]Created LiteLLM {router_type} Router with {len(model_list)} fallback model(s) in order: {', '.join(fallback_model_names)}[/green]"
+            )
+            return router
+        except Exception as e:
+            self.console.log(
+                f"[yellow]Warning: Failed to create LiteLLM {router_type} Router: {e}. Fallback models will be ignored.[/yellow]"
+            )
+            return None
+
     def reset_env(self):
         os.environ = self._original_env
```
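Reviewer note: the cascading fallbacks structure `_create_router` builds can be sanity-checked in isolation. A minimal sketch (model names are placeholders):

```python
# Reproduces the fallback-chain construction from _create_router above:
# each model falls back to every model after it in the configured order.
names = ["model_a", "model_b", "model_c"]  # placeholder model names
fallbacks = [
    {name: names[i + 1:]}
    for i, name in enumerate(names)
    if i < len(names) - 1  # the last model has nothing to fall back to
]
assert fallbacks == [
    {"model_a": ["model_b", "model_c"]},
    {"model_b": ["model_c"]},
]
```

This matches the `fallbacks` format LiteLLM's `Router` accepts: a list of dicts mapping a model name to its ordered fallback list.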
```diff
@@ -67,6 +67,57 @@ class APIWrapper(object):
         self.default_embedding_api_base = runner.config.get(
             "default_embedding_api_base", None
         )
+        # Use routers as instance variables (for fallback models)
+        self.router = getattr(runner, "router", None)
+        self.embedding_router = getattr(runner, "embedding_router", None)
+        # Store fallback configs and router cache from runner
+        self.fallback_models_config = getattr(runner, "fallback_models_config", [])
+        self.runner_router_cache = getattr(runner, "_router_cache", {})
+
+    def _get_router_with_operation_model(self, operation_model: str) -> Any:
+        """
+        Get Router completion function with operation's model first, then fallbacks.
+        Uses cached Router from runner if available.
+        """
+        # Return cached Router if available
+        if operation_model in self.runner_router_cache:
+            return self.runner_router_cache[operation_model].completion
+
+        from litellm import Router
+
+        # Build model list: operation model first, then fallbacks
+        model_list = [{
+            "model_name": operation_model,
+            "litellm_params": {
+                "model": operation_model,
+                **({"api_base": self.default_lm_api_base} if self.default_lm_api_base else {})
+            }
+        }]
+        model_names = [operation_model]
+
+        # Add fallback models, skipping duplicates
+        seen = {operation_model}
+        for cfg in self.fallback_models_config:
+            name = cfg.get("model_name") if isinstance(cfg, dict) else (cfg if isinstance(cfg, str) else None)
+            if not name or name in seen:
+                continue
+            seen.add(name)
+            params = cfg.get("litellm_params", {}).copy() if isinstance(cfg, dict) else {}
+            params["model"] = name
+            if self.default_lm_api_base and "api_base" not in params:
+                params["api_base"] = self.default_lm_api_base
+            model_list.append({"model_name": name, "litellm_params": params})
+            model_names.append(name)
+
+        # Build fallbacks list: operation model falls back to all fallback models
+        router_kwargs = {"model_list": model_list}
+        if len(model_names) > 1:
+            # fallbacks should be a list of dicts: [{"model1": ["fallback1", "fallback2"]}]
+            router_kwargs["fallbacks"] = [{operation_model: model_names[1:]}]
+
+        router = Router(**router_kwargs)
+        self.runner_router_cache[operation_model] = router
+        return router.completion
+
     @freezeargs
     def gen_embedding(self, model: str, input: list[str]) -> list[float]:
```
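Reviewer note: a hedged usage sketch of `_get_router_with_operation_model` (assumes an `APIWrapper` instance `api` whose runner was configured with `fallback_models`, and that provider API keys are set):

```python
# Illustrative only: `api` is assumed to be an APIWrapper from this diff.
completion_fn = api._get_router_with_operation_model("gpt-4o-mini")
response = completion_fn(
    model="gpt-4o-mini",  # operation model is tried first; fallbacks engage on errors
    messages=[{"role": "user", "content": "ping"}],
)
# The Router is cached per operation model, so a second lookup is cheap:
assert "gpt-4o-mini" in api.runner_router_cache
```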
```diff
@@ -119,7 +170,9 @@ class APIWrapper(object):
         if self.default_embedding_api_base:
             extra_kwargs["api_base"] = self.default_embedding_api_base

-        result = embedding(model=model, input=input, **extra_kwargs)
+        # Use embedding router if available (for fallback models)
+        embedding_fn = self.embedding_router.embedding if self.embedding_router else embedding
+        result = embedding_fn(model=model, input=input, **extra_kwargs)
         # Cache the result
         c.set(key, result)
```
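Reviewer note: for embeddings, the router path is equivalent to constructing a Router directly over the configured embedding fallbacks. A sketch using the models from the example config below (assumes `litellm` is installed and API keys are set):

```python
from litellm import Router

# Roughly what _create_router builds from fallback_embedding_models:
# text-embedding-3-small is tried first, then text-embedding-ada-002.
router = Router(
    model_list=[
        {"model_name": "text-embedding-3-small",
         "litellm_params": {"model": "text-embedding-3-small"}},
        {"model_name": "text-embedding-ada-002",
         "litellm_params": {"model": "text-embedding-ada-002"}},
    ],
    fallbacks=[{"text-embedding-3-small": ["text-embedding-ada-002"]}],
)
result = router.embedding(model="text-embedding-3-small", input=["some text"])
```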
```diff
@@ -305,7 +358,14 @@ class APIWrapper(object):
                 "tool_choice",
             ]

-        validator_response = completion(
+        # Use router if available (for fallback models), otherwise use direct completion
+        # When using router, ensure gleaning model is tried first, then fallback models
+        if self.router and self.fallback_models_config:
+            completion_fn = self._get_router_with_operation_model(gleaning_model)
+        else:
+            completion_fn = completion
+
+        validator_response = completion_fn(
             model=gleaning_model,
             messages=truncate_messages(
                 validator_messages
```
```diff
@@ -780,9 +840,17 @@ Your main result must be sent via send_output. The updated_scratchpad is only fo
         if self.default_lm_api_base:
             extra_litellm_kwargs["api_base"] = self.default_lm_api_base

+        # Use router if available (for fallback models), otherwise use direct completion
+        # When using router, ensure operation's model is tried first, then fallback models
+        if self.router and self.fallback_models_config:
+            # Build model list with operation's model first, then fallback models
+            completion_fn = self._get_router_with_operation_model(model)
+        else:
+            completion_fn = completion
+
         if use_structured_output:
             try:
-                response = completion(
+                response = completion_fn(
                     model=model,
                     messages=messages_with_system_prompt,
                     response_format=response_format,
```
```diff
@@ -798,7 +866,7 @@ Your main result must be sent via send_output. The updated_scratchpad is only fo
                 raise e
         elif tools is not None:
             try:
-                response = completion(
+                response = completion_fn(
                     model=model,
                     messages=messages_with_system_prompt,
                     tools=tools,
```
```diff
@@ -815,7 +883,7 @@ Your main result must be sent via send_output. The updated_scratchpad is only fo
                 raise e
         else:
             try:
-                response = completion(
+                response = completion_fn(
                     model=model,
                     messages=messages_with_system_prompt,
                     **extra_litellm_kwargs,
```
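Reviewer note: the three call sites above apply the same selection pattern; condensed into a hypothetical helper for clarity (not part of the diff):

```python
from litellm import completion

def pick_completion_fn(api, model):
    # Hypothetical helper, not in the diff: mirrors the branch used at each
    # call site. With fallbacks configured, route through the cached Router;
    # otherwise keep the original litellm.completion path unchanged.
    if api.router and api.fallback_models_config:
        return api._get_router_with_operation_model(model)
    return completion
```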
@@ -0,0 +1,60 @@
```yaml
# Example configuration demonstrating LiteLLM fallback models for reliability
#
# This example shows how to configure fallback models that will be automatically
# tried when API errors or content errors occur with the primary model.

datasets:
  example_dataset:
    type: file
    path: example_data/example.json

# Default language model for all operations unless overridden
default_model: gpt-4o-mini

# Fallback models for completion/chat operations
# Models will be tried in order when API errors or content errors occur
fallback_models:
  # First fallback model
  - model_name: gpt-3.5-turbo
    litellm_params:
      temperature: 0.0
  # Second fallback model
  - model_name: claude-3-haiku-20240307
    litellm_params:
      temperature: 0.0

# Fallback models for embedding operations
# Separate configuration for embedding model fallbacks
fallback_embedding_models:
  - model_name: text-embedding-3-small
    litellm_params: {}
  - model_name: text-embedding-ada-002
    litellm_params: {}

# Alternative simple format (just model names):
# fallback_models:
#   - gpt-3.5-turbo
#   - claude-3-haiku-20240307
#
# fallback_embedding_models:
#   - text-embedding-3-small
#   - text-embedding-ada-002

operations:
  - name: example_map
    type: map
    prompt: "Extract key information from: {{ input.contents }}"
    output:
      schema:
        extracted_info: "str"

pipeline:
  steps:
    - name: process_data
      input: example_dataset
      operations:
        - example_map

  output:
    type: file
    path: example_output.json
```
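Reviewer note: the `fallback_models` block above would yield a Router roughly equivalent to this direct construction (a sketch, assuming OpenAI and Anthropic keys are set in the environment):

```python
from litellm import Router

# gpt-3.5-turbo is tried first among the fallbacks, then claude-3-haiku.
router = Router(
    model_list=[
        {"model_name": "gpt-3.5-turbo",
         "litellm_params": {"model": "gpt-3.5-turbo", "temperature": 0.0}},
        {"model_name": "claude-3-haiku-20240307",
         "litellm_params": {"model": "claude-3-haiku-20240307", "temperature": 0.0}},
    ],
    fallbacks=[{"gpt-3.5-turbo": ["claude-3-haiku-20240307"]}],
)
response = router.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
)
```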
```diff
@@ -39,7 +39,7 @@ parsing = [
     "pydub>=0.25.1",
     "python-pptx>=1.0.2",
     "azure-ai-documentintelligence>=1.0.0b4",
-    "paddlepaddle>=2.6.2",
+    "paddlepaddle>=2.6.2,<3.2",
     "pymupdf>=1.24.10",
 ]
 server = [
```