# SLA-RedM/reference-deepwiki/deepwiki-open-main/api/dashscope_client.py

"""Dashscope (Alibaba Cloud) ModelClient integration."""
import os
import pickle
from typing import (
Dict,
Optional,
Any,
Callable,
Generator,
Union,
Literal,
List,
Sequence,
)
import logging
import backoff
from copy import deepcopy
from tqdm import tqdm
# optional import
from adalflow.utils.lazy_import import safe_import, OptionalPackages
openai = safe_import(OptionalPackages.OPENAI.value[0], OptionalPackages.OPENAI.value[1])
from openai import OpenAI, AsyncOpenAI, Stream
from openai import (
APITimeoutError,
InternalServerError,
RateLimitError,
UnprocessableEntityError,
BadRequestError,
)
from openai.types import (
Completion,
CreateEmbeddingResponse,
)
from openai.types.chat import ChatCompletionChunk, ChatCompletion
from adalflow.core.model_client import ModelClient
from adalflow.core.types import (
ModelType,
EmbedderOutput,
CompletionUsage,
GeneratorOutput,
Document,
Embedding,
EmbedderOutputType,
EmbedderInputType,
)
from adalflow.core.component import DataComponent
from adalflow.core.embedder import (
BatchEmbedderOutputType,
BatchEmbedderInputType,
)
import adalflow.core.functional as F
from adalflow.components.model_client.utils import parse_embedding_response
from api.logging_config import setup_logging
# # Disable tqdm progress bars
# os.environ["TQDM_DISABLE"] = "1"
setup_logging()
log = logging.getLogger(__name__)


def get_first_message_content(completion: ChatCompletion) -> str:
    """Return the content of the first message in a chat completion."""
    log.info(f"🔍 get_first_message_content called with: {type(completion)}")
    log.debug(f"raw completion: {completion}")
    try:
        if hasattr(completion, "choices") and len(completion.choices) > 0:
            choice = completion.choices[0]
            if hasattr(choice, "message") and hasattr(choice.message, "content"):
                content = choice.message.content
                log.info(
                    f"✅ Successfully extracted content: {type(content)}, "
                    f"length: {len(content) if content else 0}"
                )
                return content
            else:
                log.error("❌ Choice doesn't have message.content")
                return str(completion)
        else:
            log.error("❌ Completion doesn't have choices")
            return str(completion)
    except Exception as e:
        log.error(f"❌ Error in get_first_message_content: {e}")
        return str(completion)


def parse_stream_response(completion: ChatCompletionChunk) -> str:
    """Parse the content delta out of a single streaming chunk."""
    return completion.choices[0].delta.content


def handle_streaming_response(generator: Stream[ChatCompletionChunk]):
    """Yield the parsed content of each chunk in a streaming response."""
    for completion in generator:
        log.debug(f"Raw chunk completion: {completion}")
        parsed_content = parse_stream_response(completion)
        # Skip chunks that carry no content delta (e.g. role-only chunks).
        if parsed_content is not None:
            yield parsed_content
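

# Streaming consumption sketch (assumes a configured DashscopeClient instance
# named `client` and LLM api_kwargs already prepared with "stream": True --
# both names are illustrative, not defined here):
#
#   stream = client.call(api_kwargs=api_kwargs, model_type=ModelType.LLM)
#   for piece in stream:
#       print(piece, end="", flush=True)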


class DashscopeClient(ModelClient):
    """A component wrapper for the Dashscope (Alibaba Cloud) API client.

    Dashscope provides access to Alibaba Cloud's Qwen and other models through
    an OpenAI-compatible API.

    Args:
        api_key (Optional[str], optional): Dashscope API key. Defaults to None.
        workspace_id (Optional[str], optional): Dashscope workspace ID. Defaults to None.
        base_url (str): The API base URL. Defaults to "https://dashscope.aliyuncs.com/compatible-mode/v1".
        env_base_url_name (str): Environment variable name for the base URL. Defaults to "DASHSCOPE_BASE_URL".
        env_api_key_name (str): Environment variable name for the API key. Defaults to "DASHSCOPE_API_KEY".
        env_workspace_id_name (str): Environment variable name for the workspace ID. Defaults to "DASHSCOPE_WORKSPACE_ID".

    References:
        - Dashscope API Documentation: https://help.aliyun.com/zh/dashscope/
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        workspace_id: Optional[str] = None,
        chat_completion_parser: Callable[[Completion], Any] = None,
        input_type: Literal["text", "messages"] = "text",
        base_url: Optional[str] = None,
        env_base_url_name: str = "DASHSCOPE_BASE_URL",
        env_api_key_name: str = "DASHSCOPE_API_KEY",
        env_workspace_id_name: str = "DASHSCOPE_WORKSPACE_ID",
    ):
        super().__init__()
        self._api_key = api_key
        self._workspace_id = workspace_id
        self._env_api_key_name = env_api_key_name
        self._env_workspace_id_name = env_workspace_id_name
        self._env_base_url_name = env_base_url_name
        self.base_url = base_url or os.getenv(
            self._env_base_url_name,
            "https://dashscope.aliyuncs.com/compatible-mode/v1",
        )
        self.sync_client = self.init_sync_client()
        self.async_client = None
        # Force get_first_message_content so parsed output is always a string;
        # any chat_completion_parser passed in is intentionally overridden.
        self.chat_completion_parser = get_first_message_content
        self._input_type = input_type
        self._api_kwargs = {}

    def _prepare_client_config(self):
        """
        Prepare the configuration shared by the sync and async clients.

        Returns:
            tuple: (api_key, workspace_id, base_url) for client initialization

        Raises:
            ValueError: If the API key is not provided
        """
        api_key = self._api_key or os.getenv(self._env_api_key_name)
        workspace_id = self._workspace_id or os.getenv(self._env_workspace_id_name)
        if not api_key:
            raise ValueError(
                f"Environment variable {self._env_api_key_name} must be set"
            )
        if not workspace_id:
            log.warning(
                f"Environment variable {self._env_workspace_id_name} not set. "
                "Some features may not work properly."
            )
        # The workspace ID is not encoded into the URL; it is attached later as
        # the X-DashScope-WorkSpace header in convert_inputs_to_api_kwargs.
        base_url = self.base_url.rstrip("/") if workspace_id else self.base_url
        return api_key, workspace_id, base_url

    def init_sync_client(self):
        api_key, workspace_id, base_url = self._prepare_client_config()
        client = OpenAI(api_key=api_key, base_url=base_url)
        # Store workspace_id for later use in requests
        if workspace_id:
            client._workspace_id = workspace_id
        return client

    def init_async_client(self):
        api_key, workspace_id, base_url = self._prepare_client_config()
        client = AsyncOpenAI(api_key=api_key, base_url=base_url)
        # Store workspace_id for later use in requests
        if workspace_id:
            client._workspace_id = workspace_id
        return client

    def parse_chat_completion(
        self,
        completion: Union[ChatCompletion, Generator[ChatCompletionChunk, None, None]],
    ) -> "GeneratorOutput":
        """Parse the completion response into a GeneratorOutput."""
        try:
            # If the completion is already a GeneratorOutput, return it
            # directly (prevents recursion).
            if isinstance(completion, GeneratorOutput):
                return completion
            # Non-streaming response: a ChatCompletion object.
            if hasattr(completion, "choices") and hasattr(completion, "usage"):
                # Always extract the string content directly.
                try:
                    if (
                        hasattr(completion, "choices")
                        and len(completion.choices) > 0
                        and hasattr(completion.choices[0], "message")
                        and hasattr(completion.choices[0].message, "content")
                    ):
                        content = completion.choices[0].message.content
                        parsed_data = content if isinstance(content, str) else str(content)
                    else:
                        # Fallback: convert the entire completion to a string.
                        parsed_data = str(completion)
                except Exception:
                    # Ultimate fallback.
                    parsed_data = str(completion)
                return GeneratorOutput(
                    data=parsed_data,
                    usage=CompletionUsage(
                        completion_tokens=completion.usage.completion_tokens,
                        prompt_tokens=completion.usage.prompt_tokens,
                        total_tokens=completion.usage.total_tokens,
                    ),
                    raw_response=str(completion),
                )
            else:
                # Streaming response: collect all content parts into a single string.
                content_parts = []
                usage_info = None
                for chunk in completion:
                    if chunk.choices[0].delta.content:
                        content_parts.append(chunk.choices[0].delta.content)
                    # Usage info, when present, arrives on the last chunk.
                    if hasattr(chunk, "usage") and chunk.usage:
                        usage_info = chunk.usage
                full_content = "".join(content_parts)
                usage = None
                if usage_info:
                    usage = CompletionUsage(
                        completion_tokens=usage_info.completion_tokens,
                        prompt_tokens=usage_info.prompt_tokens,
                        total_tokens=usage_info.total_tokens,
                    )
                return GeneratorOutput(
                    data=full_content,
                    usage=usage,
                    raw_response="streaming",
                )
        except Exception as e:
            log.error(f"Error parsing completion: {e}")
            raise

    def track_completion_usage(
        self,
        completion: Union[ChatCompletion, Generator[ChatCompletionChunk, None, None]],
    ) -> CompletionUsage:
        """Track the completion usage."""
        if isinstance(completion, ChatCompletion):
            return CompletionUsage(
                completion_tokens=completion.usage.completion_tokens,
                prompt_tokens=completion.usage.prompt_tokens,
                total_tokens=completion.usage.total_tokens,
            )
        else:
            # For streaming, usage cannot be tracked accurately here.
            return CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0)

    def parse_embedding_response(
        self, response: CreateEmbeddingResponse
    ) -> EmbedderOutput:
        """Parse the embedding response into an EmbedderOutput."""
        try:
            result = parse_embedding_response(response)
            if result.data:
                log.info(f"🔍 Number of embeddings: {len(result.data)}")
                if len(result.data) > 0:
                    log.info(
                        f"🔍 First embedding length: "
                        f"{len(result.data[0].embedding) if hasattr(result.data[0], 'embedding') else 'N/A'}"
                    )
            else:
                log.warning("🔍 No embedding data found in result")
            return result
        except Exception as e:
            log.error(f"🔍 Error parsing DashScope embedding response: {e}")
            log.error(f"🔍 Raw response details: {repr(response)}")
            return EmbedderOutput(data=[], error=str(e), raw_response=response)

    def convert_inputs_to_api_kwargs(
        self,
        input: Optional[Any] = None,
        model_kwargs: Dict = {},
        model_type: ModelType = ModelType.UNDEFINED,
    ) -> Dict:
        """Convert component inputs into keyword arguments for the Dashscope API."""
        final_model_kwargs = model_kwargs.copy()
        if model_type == ModelType.LLM:
            messages = []
            if isinstance(input, str):
                messages = [{"role": "user", "content": input}]
            elif isinstance(input, list):
                messages = input
            else:
                raise ValueError(f"Unsupported input type: {type(input)}")
            api_kwargs = {"messages": messages, **final_model_kwargs}
            # Add the workspace ID header if available.
            workspace_id = getattr(self.sync_client, "_workspace_id", None) or getattr(
                self.async_client, "_workspace_id", None
            )
            if workspace_id:
                # Dashscope may require the workspace ID in request headers.
                api_kwargs.setdefault("extra_headers", {})
                api_kwargs["extra_headers"]["X-DashScope-WorkSpace"] = workspace_id
            return api_kwargs
        elif model_type == ModelType.EMBEDDER:
            # Convert Documents to text strings for embedding.
            processed_input = input
            if isinstance(input, list):
                # Extract text from Document objects.
                processed_input = []
                for item in input:
                    if hasattr(item, "text"):
                        # It's a Document object; extract its text.
                        processed_input.append(item.text)
                    elif isinstance(item, str):
                        # It's already a string.
                        processed_input.append(item)
                    else:
                        # Fall back to a string conversion.
                        processed_input.append(str(item))
            elif hasattr(input, "text"):
                # A single Document object.
                processed_input = input.text
            elif isinstance(input, str):
                # A single string.
                processed_input = input
            else:
                # Fall back to a string conversion.
                processed_input = str(input)
            api_kwargs = {"input": processed_input, **final_model_kwargs}
            # Add the workspace ID header if available.
            workspace_id = getattr(self.sync_client, "_workspace_id", None) or getattr(
                self.async_client, "_workspace_id", None
            )
            if workspace_id:
                api_kwargs.setdefault("extra_headers", {})
                api_kwargs["extra_headers"]["X-DashScope-WorkSpace"] = workspace_id
            return api_kwargs
        else:
            raise ValueError(f"model_type {model_type} is not supported")
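
    # Shape of the produced kwargs (a sketch; assumes no workspace ID is
    # configured, so no extra_headers entry is added):
    #
    #   convert_inputs_to_api_kwargs(
    #       input="Hi", model_kwargs={"model": "qwen-plus"}, model_type=ModelType.LLM
    #   )
    #   # -> {"messages": [{"role": "user", "content": "Hi"}], "model": "qwen-plus"}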

    @backoff.on_exception(
        backoff.expo,
        (
            APITimeoutError,
            InternalServerError,
            RateLimitError,
            UnprocessableEntityError,
            BadRequestError,
        ),
        max_time=5,
    )
    def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED):
        """Call the Dashscope API."""
        if model_type == ModelType.LLM:
            if not api_kwargs.get("stream", False):
                # For non-streaming calls, enable_thinking must be false.
                # Pass it via extra_body to avoid a TypeError from the openai
                # client's argument validation.
                extra_body = api_kwargs.get("extra_body", {})
                extra_body["enable_thinking"] = False
                api_kwargs["extra_body"] = extra_body
            completion = self.sync_client.chat.completions.create(**api_kwargs)
            if api_kwargs.get("stream", False):
                return handle_streaming_response(completion)
            else:
                return self.parse_chat_completion(completion)
        elif model_type == ModelType.EMBEDDER:
            # Extract input texts from api_kwargs.
            texts = api_kwargs.get("input", [])
            if not texts:
                log.warning("😭 No input texts provided")
                return EmbedderOutput(data=[], error="No input texts provided", raw_response=None)
            # Ensure texts is a list.
            if isinstance(texts, str):
                texts = [texts]
            # Filter out empty or None texts, following the HuggingFace client pattern.
            valid_texts = []
            valid_indices = []
            for i, text in enumerate(texts):
                if text and isinstance(text, str) and text.strip():
                    valid_texts.append(text)
                    valid_indices.append(i)
                else:
                    log.warning(
                        f"🔍 Skipping empty or invalid text at index {i}: type={type(text)}, "
                        f"length={len(text) if hasattr(text, '__len__') else 'N/A'}, "
                        f"repr={repr(text)[:100]}"
                    )
            if not valid_texts:
                log.error("😭 No valid texts found after filtering")
                return EmbedderOutput(data=[], error="No valid texts found after filtering", raw_response=None)
            if len(valid_texts) != len(texts):
                filtered_count = len(texts) - len(valid_texts)
                log.warning(f"🔍 Filtered out {filtered_count} empty/invalid texts out of {len(texts)} total texts")
            # Create modified api_kwargs with only the valid texts.
            filtered_api_kwargs = api_kwargs.copy()
            filtered_api_kwargs["input"] = valid_texts
            log.info(f"🔍 DashScope embedding API call with {len(valid_texts)} valid texts out of {len(texts)} total")
            try:
                response = self.sync_client.embeddings.create(**filtered_api_kwargs)
                log.info(f"🔍 DashScope API call successful, response type: {type(response)}")
                result = self.parse_embedding_response(response)
                # If texts were filtered, re-align embeddings with the original indices.
                if len(valid_texts) != len(texts):
                    log.info(f"🔍 Creating embeddings for {len(texts)} original positions")
                    # Determine the dimension from the first valid embedding;
                    # fall back to 256 (the configured default, as in acall) so
                    # the zero-fill below never multiplies by None.
                    embedding_dim = 256
                    if result.data and len(result.data) > 0 and hasattr(result.data[0], "embedding"):
                        embedding_dim = len(result.data[0].embedding)
                    log.info(f"🔍 Using embedding dimension: {embedding_dim}")
                    final_data = []
                    valid_idx = 0
                    for i in range(len(texts)):
                        if i in valid_indices:
                            # Use the embedding computed for this valid text.
                            final_data.append(result.data[valid_idx])
                            valid_idx += 1
                        else:
                            # Create a zero embedding for the filtered text.
                            log.warning(f"🔍 Creating zero embedding for filtered text at index {i}")
                            final_data.append(Embedding(
                                embedding=[0.0] * embedding_dim,
                                index=i,
                            ))
                    result = EmbedderOutput(
                        data=final_data,
                        error=None,
                        raw_response=result.raw_response,
                    )
                return result
            except Exception as e:
                log.error(f"🔍 DashScope API call failed: {e}")
                return EmbedderOutput(data=[], error=str(e), raw_response=None)
        else:
            raise ValueError(f"model_type {model_type} is not supported")
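
    # Embedding usage sketch (the model name "text-embedding-v2" is
    # illustrative; substitute whichever embedding model the account exposes):
    #
    #   kwargs = client.convert_inputs_to_api_kwargs(
    #       input=["first doc", "second doc"],
    #       model_kwargs={"model": "text-embedding-v2"},
    #       model_type=ModelType.EMBEDDER,
    #   )
    #   emb_out = client.call(api_kwargs=kwargs, model_type=ModelType.EMBEDDER)
    #   # emb_out.data is a list of Embedding objects, one per input text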

    @backoff.on_exception(
        backoff.expo,
        (
            APITimeoutError,
            InternalServerError,
            RateLimitError,
            UnprocessableEntityError,
            BadRequestError,
        ),
        max_time=5,
    )
    async def acall(
        self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED
    ):
        """Async call to the Dashscope API."""
        if not self.async_client:
            self.async_client = self.init_async_client()
        if model_type == ModelType.LLM:
            if not api_kwargs.get("stream", False):
                # For non-streaming calls, enable_thinking must be false.
                extra_body = api_kwargs.get("extra_body", {})
                extra_body["enable_thinking"] = False
                api_kwargs["extra_body"] = extra_body
            completion = await self.async_client.chat.completions.create(**api_kwargs)
            if api_kwargs.get("stream", False):
                return handle_streaming_response(completion)
            else:
                return self.parse_chat_completion(completion)
        elif model_type == ModelType.EMBEDDER:
            # Extract input texts from api_kwargs.
            texts = api_kwargs.get("input", [])
            if not texts:
                log.warning("😭 No input texts provided")
                return EmbedderOutput(data=[], error="No input texts provided", raw_response=None)
            # Ensure texts is a list.
            if isinstance(texts, str):
                texts = [texts]
            # Filter out empty or None texts, following the HuggingFace client pattern.
            valid_texts = []
            valid_indices = []
            for i, text in enumerate(texts):
                if text and isinstance(text, str) and text.strip():
                    valid_texts.append(text)
                    valid_indices.append(i)
                else:
                    log.warning(
                        f"🔍 Skipping empty or invalid text at index {i}: type={type(text)}, "
                        f"length={len(text) if hasattr(text, '__len__') else 'N/A'}, "
                        f"repr={repr(text)[:100]}"
                    )
            if not valid_texts:
                log.error("😭 No valid texts found after filtering")
                return EmbedderOutput(data=[], error="No valid texts found after filtering", raw_response=None)
            if len(valid_texts) != len(texts):
                filtered_count = len(texts) - len(valid_texts)
                log.warning(f"🔍 Filtered out {filtered_count} empty/invalid texts out of {len(texts)} total texts")
            # Create modified api_kwargs with only the valid texts.
            filtered_api_kwargs = api_kwargs.copy()
            filtered_api_kwargs["input"] = valid_texts
            log.info(f"🔍 DashScope async embedding API call with {len(valid_texts)} valid texts out of {len(texts)} total")
            try:
                response = await self.async_client.embeddings.create(**filtered_api_kwargs)
                log.info(f"🔍 DashScope async API call successful, response type: {type(response)}")
                result = self.parse_embedding_response(response)
                # If texts were filtered, re-align embeddings with the original indices.
                if len(valid_texts) != len(texts):
                    log.info(f"🔍 Creating embeddings for {len(texts)} original positions")
                    # Determine the dimension from the first valid embedding;
                    # fall back to 256 (the configured default).
                    embedding_dim = 256
                    if result.data and len(result.data) > 0 and hasattr(result.data[0], "embedding"):
                        embedding_dim = len(result.data[0].embedding)
                    log.info(f"🔍 Using embedding dimension: {embedding_dim}")
                    final_data = []
                    valid_idx = 0
                    for i in range(len(texts)):
                        if i in valid_indices:
                            # Use the embedding computed for this valid text.
                            final_data.append(result.data[valid_idx])
                            valid_idx += 1
                        else:
                            # Create a zero embedding for the filtered text.
                            log.warning(f"🔍 Creating zero embedding for filtered text at index {i}")
                            final_data.append(Embedding(
                                embedding=[0.0] * embedding_dim,
                                index=i,
                            ))
                    result = EmbedderOutput(
                        data=final_data,
                        error=None,
                        raw_response=result.raw_response,
                    )
                return result
            except Exception as e:
                log.error(f"🔍 DashScope async API call failed: {e}")
                return EmbedderOutput(data=[], error=str(e), raw_response=None)
        else:
            raise ValueError(f"model_type {model_type} is not supported")

    @classmethod
    def from_dict(cls, data: Dict[str, Any]):
        """Create an instance from a dictionary."""
        return cls(**data)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary."""
        return {
            "api_key": self._api_key,
            "workspace_id": self._workspace_id,
            "base_url": self.base_url,
            "input_type": self._input_type,
        }

    def __getstate__(self):
        """
        Customize serialization to exclude non-picklable client objects.

        This method is called by pickle when saving the object's state.
        """
        state = self.__dict__.copy()
        # Remove the unpicklable client instances.
        state.pop("sync_client", None)
        state.pop("async_client", None)
        return state

    def __setstate__(self, state):
        """
        Customize deserialization to re-create the client objects.

        This method is called by pickle when loading the object's state.
        """
        self.__dict__.update(state)
        # Re-initialize the clients after unpickling.
        self.sync_client = self.init_sync_client()
        self.async_client = None  # Lazily initialized when acall is used
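

# Pickle round-trip sketch: the OpenAI client objects are dropped on dump and
# rebuilt on load (sync eagerly, async lazily), so a usable API key -- either
# the pickled _api_key attribute or the DASHSCOPE_API_KEY env var -- must be
# available at load time:
#
#   blob = pickle.dumps(client)
#   restored = pickle.loads(blob)  # restored.sync_client is freshly initialized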


class DashScopeEmbedder(DataComponent):
    r"""
    A user-facing component that orchestrates an embedder model via the
    DashScope model client and output processors.

    Args:
        model_client (ModelClient): The DashScope model client to use for the embedder.
        model_kwargs (Dict[str, Any], optional): The model kwargs to pass to the model client. Defaults to {}.
        output_processors (Optional[Component], optional): The output processors after model call. Defaults to None.
    """

    model_type: ModelType = ModelType.EMBEDDER
    model_client: ModelClient
    output_processors: Optional[DataComponent]

    def __init__(
        self,
        *,
        model_client: ModelClient,
        model_kwargs: Dict[str, Any] = {},
        output_processors: Optional[DataComponent] = None,
    ) -> None:
        super().__init__(model_kwargs=model_kwargs)
        if not isinstance(model_kwargs, Dict):
            raise TypeError(
                f"{type(self).__name__} requires a dictionary for model_kwargs, not a string"
            )
        self.model_kwargs = model_kwargs.copy()
        if not isinstance(model_client, ModelClient):
            raise TypeError(
                f"{type(self).__name__} requires a ModelClient instance for model_client."
            )
        self.model_client = model_client
        self.output_processors = output_processors

    def call(
        self,
        input: EmbedderInputType,
        model_kwargs: Optional[Dict] = {},
    ) -> EmbedderOutputType:
        log.debug(f"Calling {self.__class__.__name__} with input: {input}")
        api_kwargs = self.model_client.convert_inputs_to_api_kwargs(
            input=input,
            model_kwargs=self._compose_model_kwargs(**model_kwargs),
            model_type=self.model_type,
        )
        try:
            output = self.model_client.call(
                api_kwargs=api_kwargs, model_type=self.model_type
            )
        except Exception as e:
            log.error(f"🤡 Error calling the DashScope model: {e}")
            output = EmbedderOutput(error=str(e))
        return output

    async def acall(
        self,
        input: EmbedderInputType,
        model_kwargs: Optional[Dict] = {},
    ) -> EmbedderOutputType:
        log.debug(f"Calling {self.__class__.__name__} with input: {input}")
        api_kwargs = self.model_client.convert_inputs_to_api_kwargs(
            input=input,
            model_kwargs=self._compose_model_kwargs(**model_kwargs),
            model_type=self.model_type,
        )
        output: EmbedderOutputType = None
        try:
            response = await self.model_client.acall(
                api_kwargs=api_kwargs, model_type=self.model_type
            )
            # DashscopeClient.acall already returns an EmbedderOutput for
            # ModelType.EMBEDDER; only parse raw responses a second time.
            if isinstance(response, EmbedderOutput):
                output = response
            else:
                output = self.model_client.parse_embedding_response(response)
        except Exception as e:
            log.error(f"Error calling the DashScope model: {e}")
            output = EmbedderOutput(error=str(e))
        output.input = [input] if isinstance(input, str) else input
        log.debug(f"Output from {self.__class__.__name__}: {output}")
        return output

    def _compose_model_kwargs(self, **model_kwargs) -> Dict[str, object]:
        return F.compose_model_kwargs(self.model_kwargs, model_kwargs)
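

# DashScopeEmbedder usage sketch (the embedding model name is illustrative):
#
#   embedder = DashScopeEmbedder(
#       model_client=DashscopeClient(),
#       model_kwargs={"model": "text-embedding-v2"},
#   )
#   out = embedder.call(input=["some text"])  # -> EmbedderOutput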


# Batch Embedding Components for DashScope


class DashScopeBatchEmbedder(DataComponent):
    """Batch embedder specifically designed for the DashScope API."""

    def __init__(self, embedder, batch_size: int = 100, embedding_cache_file_name: str = "default") -> None:
        super().__init__(batch_size=batch_size)
        self.embedder = embedder
        self.batch_size = batch_size
        if self.batch_size > 25:
            log.warning(
                f"Requested batch size {self.batch_size} exceeds DashScope's "
                "batch embedding limit of 25; clamping to 25."
            )
            self.batch_size = 25
        self.cache_path = (
            f"./embedding_cache/{embedding_cache_file_name}_"
            f"{self.embedder.__class__.__name__}_dashscope_embeddings.pkl"
        )

    def call(
        self, input: BatchEmbedderInputType, model_kwargs: Optional[Dict] = {}, force_recreate: bool = False
    ) -> BatchEmbedderOutputType:
        """
        Batch call to the DashScope embedder.

        Args:
            input: List of input texts
            model_kwargs: Model parameters
            force_recreate: Whether to bypass the cache and re-embed

        Returns:
            Batch embedding output
        """
        # Check the cache first.
        if not force_recreate and os.path.exists(self.cache_path):
            try:
                with open(self.cache_path, "rb") as f:
                    embeddings = pickle.load(f)
                log.info(f"Loaded cached DashScope embeddings from: {self.cache_path}")
                return embeddings
            except Exception as e:
                log.warning(f"Failed to load cache file {self.cache_path}: {e}, proceeding with fresh embedding")
        if isinstance(input, str):
            input = [input]
        n = len(input)
        embeddings: List[EmbedderOutput] = []
        log.info(f"Starting DashScope batch embedding processing, total {n} texts, batch size: {self.batch_size}")
        for i in tqdm(
            range(0, n, self.batch_size),
            desc="DashScope batch embedding",
            disable=False,
        ):
            batch_input = input[i : min(i + self.batch_size, n)]
            try:
                # Call the embedder instance directly on each batch.
                batch_output = self.embedder(
                    input=batch_input, model_kwargs=model_kwargs
                )
                embeddings.append(batch_output)
                # Validate the batch output.
                if batch_output.error:
                    log.error(f"Batch {i // self.batch_size + 1} embedding failed: {batch_output.error}")
                elif batch_output.data:
                    log.debug(f"Batch {i // self.batch_size + 1} successfully generated {len(batch_output.data)} embedding vectors")
                else:
                    log.warning(f"Batch {i // self.batch_size + 1} returned no embedding data")
            except Exception as e:
                log.error(f"Batch {i // self.batch_size + 1} processing exception: {e}")
                # Record an error output for this batch.
                error_output = EmbedderOutput(
                    data=[],
                    error=str(e),
                    raw_response=None,
                )
                embeddings.append(error_output)
        log.info(f"DashScope batch embedding completed, processed {len(embeddings)} batches")
        # Save to cache.
        try:
            os.makedirs("./embedding_cache", exist_ok=True)
            with open(self.cache_path, "wb") as f:
                pickle.dump(embeddings, f)
            log.info(f"Saved DashScope embeddings cache to: {self.cache_path}")
        except Exception as e:
            log.warning(f"Failed to save cache to {self.cache_path}: {e}")
        return embeddings

    def __call__(self, input: BatchEmbedderInputType, model_kwargs: Optional[Dict] = {}, force_recreate: bool = False) -> BatchEmbedderOutputType:
        """Call operator interface; delegates to the call method."""
        return self.call(input=input, model_kwargs=model_kwargs, force_recreate=force_recreate)
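

# Batch embedding sketch: with batch_size clamped to 25, 60 texts produce 3
# EmbedderOutput batches, pickled to the cache path built in __init__ (the
# cache file name "mydocs" and the list `texts` are illustrative):
#
#   batch_embedder = DashScopeBatchEmbedder(
#       embedder, batch_size=25, embedding_cache_file_name="mydocs"
#   )
#   batches = batch_embedder(input=texts)  # List[EmbedderOutput]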


class DashScopeToEmbeddings(DataComponent):
    """Component that converts a sequence of documents into a sequence of
    embedding vectors, optimized for the DashScope API."""

    def __init__(self, embedder, batch_size: int = 100, force_recreate_db: bool = False, embedding_cache_file_name: str = "default") -> None:
        super().__init__(batch_size=batch_size)
        self.embedder = embedder
        self.batch_size = batch_size
        self.batch_embedder = DashScopeBatchEmbedder(
            embedder=embedder,
            batch_size=batch_size,
            embedding_cache_file_name=embedding_cache_file_name,
        )
        self.force_recreate_db = force_recreate_db

    def __call__(self, input: List[Document]) -> List[Document]:
        """
        Process a list of documents, generating an embedding vector for each one.

        Args:
            input: List of input documents

        Returns:
            List of documents with embedding vectors attached
        """
        output = deepcopy(input)
        # Convert to a list of texts.
        embedder_input: List[str] = [chunk.text for chunk in output]
        log.info(f"Starting to process embeddings for {len(embedder_input)} documents")
        # Batch-process the embeddings.
        outputs: List[EmbedderOutput] = self.batch_embedder(
            input=embedder_input,
            force_recreate=self.force_recreate_db,
        )
        # Validate the output.
        total_embeddings = 0
        error_batches = 0
        for batch_output in outputs:
            if batch_output.error:
                error_batches += 1
                log.error(f"Found error batch: {batch_output.error}")
            elif batch_output.data:
                total_embeddings += len(batch_output.data)
        log.info(f"Embedding statistics: total {total_embeddings} valid embeddings, {error_batches} error batches")
        # Assign the embedding vectors back to the documents.
        doc_idx = 0
        for batch_idx, batch_output in tqdm(
            enumerate(outputs),
            desc="Assigning embedding vectors to documents",
            disable=False,
        ):
            if batch_output.error:
                # Give documents in error batches empty vectors.
                batch_size_actual = min(self.batch_size, len(output) - doc_idx)
                log.warning(f"Creating empty vectors for {batch_size_actual} documents in batch {batch_idx}")
                for i in range(batch_size_actual):
                    if doc_idx < len(output):
                        output[doc_idx].vector = []
                        doc_idx += 1
            else:
                # Assign the embedding vectors normally.
                for embedding in batch_output.data:
                    if doc_idx < len(output):
                        if hasattr(embedding, "embedding"):
                            output[doc_idx].vector = embedding.embedding
                        else:
                            log.warning(f"Invalid embedding format for document {doc_idx}")
                            output[doc_idx].vector = []
                        doc_idx += 1
        # Validate the results.
        valid_count = 0
        empty_count = 0
        for doc in output:
            if hasattr(doc, "vector") and doc.vector and len(doc.vector) > 0:
                valid_count += 1
            else:
                empty_count += 1
        log.info(f"Embedding results: {valid_count} valid vectors, {empty_count} empty vectors")
        if valid_count == 0:
            log.error("❌ All documents have empty embedding vectors!")
        elif empty_count > 0:
            log.warning(f"⚠️ Found {empty_count} empty embedding vectors")
        else:
            log.info("✅ All documents successfully generated embedding vectors")
        return output

    def _extra_repr(self) -> str:
        return f"batch_size={self.batch_size}"