"""OpenRouter ModelClient integration."""
|
|
|
|
from typing import Dict, Sequence, Optional, Any, List
|
|
import logging
|
|
import json
|
|
import aiohttp
|
|
import requests
|
|
from requests.exceptions import RequestException, Timeout
|
|
|
|
from adalflow.core.model_client import ModelClient
|
|
from adalflow.core.types import (
|
|
CompletionUsage,
|
|
ModelType,
|
|
GeneratorOutput,
|
|
)
|
|
|
|
log = logging.getLogger(__name__)


class OpenRouterClient(ModelClient):
    __doc__ = r"""A component wrapper for the OpenRouter API client.

    OpenRouter provides a unified API that gives access to hundreds of AI models
    through a single endpoint. The API is compatible with OpenAI's API format,
    with a few small differences.

    Visit https://openrouter.ai/docs for more details.

    Example:

    ```python
    import adalflow as adal

    from api.openrouter_client import OpenRouterClient

    client = OpenRouterClient()
    generator = adal.Generator(
        model_client=client,
        model_kwargs={"model": "openai/gpt-4o"}
    )
    ```
    """

    def __init__(self, *args, **kwargs) -> None:
        """Initialize the OpenRouter client."""
        super().__init__(*args, **kwargs)
        self.sync_client = self.init_sync_client()
        self.async_client = None  # Initialize the async client only when needed

    def init_sync_client(self):
        """Initialize the synchronous OpenRouter client."""
        from api.config import OPENROUTER_API_KEY
        api_key = OPENROUTER_API_KEY
        if not api_key:
            log.warning("OPENROUTER_API_KEY not configured")

        # OpenRouter has no dedicated client library, so we call its REST API
        # directly with requests and keep only the connection config here.
        return {
            "api_key": api_key,
            "base_url": "https://openrouter.ai/api/v1"
        }

    def init_async_client(self):
        """Initialize the asynchronous OpenRouter client."""
        from api.config import OPENROUTER_API_KEY
        api_key = OPENROUTER_API_KEY
        if not api_key:
            log.warning("OPENROUTER_API_KEY not configured")

        # For async requests we use aiohttp with the same connection config.
        return {
            "api_key": api_key,
            "base_url": "https://openrouter.ai/api/v1"
        }
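
    # Both "clients" are plain config dicts rather than SDK objects; requests
    # and aiohttp are driven directly with these values. Illustrative sketch
    # (not a method of this class):
    #
    #   cfg = client.sync_client
    #   requests.post(
    #       f"{cfg['base_url']}/chat/completions",
    #       headers={"Authorization": f"Bearer {cfg['api_key']}"},
    #       json={"model": "openai/gpt-4o", "messages": [...]},
    #   )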

    def convert_inputs_to_api_kwargs(
        self, input: Any, model_kwargs: Dict = None, model_type: ModelType = None
    ) -> Dict:
        """Convert AdalFlow inputs to the OpenRouter API format."""
        model_kwargs = model_kwargs or {}

        if model_type == ModelType.LLM:
            # Handle LLM generation
            messages = []

            # Convert the input to the messages format if it's a string
            if isinstance(input, str):
                messages = [{"role": "user", "content": input}]
            elif isinstance(input, list) and all(isinstance(msg, dict) for msg in input):
                messages = input
            else:
                raise ValueError(f"Unsupported input format for OpenRouter: {type(input)}")

            # For debugging
            log.info(f"Messages for OpenRouter: {messages}")

            api_kwargs = {
                "messages": messages,
                **model_kwargs
            }

            # Ensure a model is specified
            if "model" not in api_kwargs:
                api_kwargs["model"] = "openai/gpt-3.5-turbo"

            return api_kwargs

        elif model_type == ModelType.EMBEDDING:
            # OpenRouter doesn't support embeddings directly. We could potentially
            # route embeddings to a specific model through OpenRouter, but for now
            # we raise an error.
            raise NotImplementedError("OpenRouter client does not support embeddings yet")

        else:
            raise ValueError(f"Unsupported model type: {model_type}")

    async def acall(self, api_kwargs: Dict = None, model_type: ModelType = None) -> Any:
        """Make an asynchronous call to the OpenRouter API."""
        if not self.async_client:
            self.async_client = self.init_async_client()

        # Check that the API key is set
        if not self.async_client.get("api_key"):
            error_msg = "OPENROUTER_API_KEY not configured. Please set this environment variable to use OpenRouter."
            log.error(error_msg)

            # Instead of raising an exception, return a generator that yields the
            # error message, so it can be shown to the user in the streaming response.
            async def error_generator():
                yield error_msg
            return error_generator()

        api_kwargs = api_kwargs or {}

        if model_type == ModelType.LLM:
            # Prepare the headers. HTTP-Referer and X-Title are optional headers
            # that OpenRouter uses to attribute requests to an app.
            headers = {
                "Authorization": f"Bearer {self.async_client['api_key']}",
                "Content-Type": "application/json",
                "HTTP-Referer": "https://github.com/AsyncFuncAI/deepwiki-open",  # Optional
                "X-Title": "DeepWiki"  # Optional
            }

            # Always use non-streaming mode for OpenRouter
            api_kwargs["stream"] = False

            # Make the API call
            try:
                log.info(f"Making async OpenRouter API call to {self.async_client['base_url']}/chat/completions")
                log.info(f"Request headers: {headers}")
                log.info(f"Request body: {api_kwargs}")

                async with aiohttp.ClientSession() as session:
                    try:
                        async with session.post(
                            f"{self.async_client['base_url']}/chat/completions",
                            headers=headers,
                            json=api_kwargs,
                            timeout=aiohttp.ClientTimeout(total=60)
                        ) as response:
                            if response.status != 200:
                                error_text = await response.text()
                                log.error(f"OpenRouter API error ({response.status}): {error_text}")

                                # Return a generator that yields the error message
                                async def error_response_generator():
                                    yield f"OpenRouter API error ({response.status}): {error_text}"
                                return error_response_generator()

                            # Get the full response
                            data = await response.json()
                            log.info(f"Received response from OpenRouter: {data}")

                            # Create a generator that yields the content
                            async def content_generator():
                                if "choices" in data and len(data["choices"]) > 0:
                                    choice = data["choices"][0]
                                    if "message" in choice and "content" in choice["message"]:
                                        content = choice["message"]["content"]
                                        log.info("Successfully retrieved response")

                                        # Check whether the content is XML and ensure it's properly formatted
                                        if content.strip().startswith("<") and ">" in content:
                                            # It's likely XML; make sure it's properly formatted
                                            try:
                                                # Extract the XML content
                                                xml_content = content

                                                # Check whether it's a wiki_structure XML document
                                                if "<wiki_structure>" in xml_content:
                                                    log.info("Found wiki_structure XML, ensuring proper format")

                                                    # Extract just the wiki_structure XML
                                                    import re
                                                    wiki_match = re.search(r'<wiki_structure>[\s\S]*?</wiki_structure>', xml_content)
                                                    if wiki_match:
                                                        # Get the raw XML
                                                        raw_xml = wiki_match.group(0)

                                                        # Clean the XML by stripping leading/trailing whitespace
                                                        clean_xml = raw_xml.strip()

                                                        # Try to fix common XML issues
                                                        try:
                                                            # Replace problematic characters in the XML
                                                            fixed_xml = clean_xml

                                                            # Escape bare & characters that are not already part of an entity
                                                            fixed_xml = re.sub(r'&(?!amp;|lt;|gt;|apos;|quot;)', '&amp;', fixed_xml)

                                                            # Fix other common XML issues: stray spaces inside tags
                                                            fixed_xml = fixed_xml.replace('< /', '</').replace(' >', '>')

                                                            # Try to parse the fixed XML
                                                            from xml.dom.minidom import parseString
                                                            dom = parseString(fixed_xml)

                                                            # Get the pretty-printed XML with proper indentation
                                                            pretty_xml = dom.toprettyxml()

                                                            # Remove the XML declaration
                                                            if pretty_xml.startswith('<?xml'):
                                                                pretty_xml = pretty_xml[pretty_xml.find('?>') + 2:].strip()

                                                            log.info(f"Extracted and validated XML: {pretty_xml[:100]}...")
                                                            yield pretty_xml
                                                        except Exception as xml_parse_error:
                                                            log.warning(f"XML validation failed: {str(xml_parse_error)}, using raw XML")

                                                            # If XML validation fails, try a more aggressive approach:
                                                            # rebuild the structure field by field with regexes.
                                                            try:
                                                                import re

                                                                # Extract the basic structure
                                                                structure_match = re.search(r'<wiki_structure>(.*?)</wiki_structure>', clean_xml, re.DOTALL)
                                                                if structure_match:
                                                                    structure = structure_match.group(1).strip()

                                                                    # Rebuild a clean XML structure
                                                                    clean_structure = "<wiki_structure>\n"

                                                                    # Extract the title
                                                                    title_match = re.search(r'<title>(.*?)</title>', structure, re.DOTALL)
                                                                    if title_match:
                                                                        title = title_match.group(1).strip()
                                                                        clean_structure += f"  <title>{title}</title>\n"

                                                                    # Extract the description
                                                                    desc_match = re.search(r'<description>(.*?)</description>', structure, re.DOTALL)
                                                                    if desc_match:
                                                                        desc = desc_match.group(1).strip()
                                                                        clean_structure += f"  <description>{desc}</description>\n"

                                                                    # Add the pages section
                                                                    clean_structure += "  <pages>\n"

                                                                    # Extract the pages
                                                                    pages = re.findall(r'<page id="(.*?)">(.*?)</page>', structure, re.DOTALL)
                                                                    for page_id, page_content in pages:
                                                                        clean_structure += f'    <page id="{page_id}">\n'

                                                                        # Extract the page title
                                                                        page_title_match = re.search(r'<title>(.*?)</title>', page_content, re.DOTALL)
                                                                        if page_title_match:
                                                                            page_title = page_title_match.group(1).strip()
                                                                            clean_structure += f"      <title>{page_title}</title>\n"

                                                                        # Extract the page description
                                                                        page_desc_match = re.search(r'<description>(.*?)</description>', page_content, re.DOTALL)
                                                                        if page_desc_match:
                                                                            page_desc = page_desc_match.group(1).strip()
                                                                            clean_structure += f"      <description>{page_desc}</description>\n"

                                                                        # Extract the importance
                                                                        importance_match = re.search(r'<importance>(.*?)</importance>', page_content, re.DOTALL)
                                                                        if importance_match:
                                                                            importance = importance_match.group(1).strip()
                                                                            clean_structure += f"      <importance>{importance}</importance>\n"

                                                                        # Extract the relevant files
                                                                        clean_structure += "      <relevant_files>\n"
                                                                        file_paths = re.findall(r'<file_path>(.*?)</file_path>', page_content, re.DOTALL)
                                                                        for file_path in file_paths:
                                                                            clean_structure += f"        <file_path>{file_path.strip()}</file_path>\n"
                                                                        clean_structure += "      </relevant_files>\n"

                                                                        # Extract the related pages
                                                                        clean_structure += "      <related_pages>\n"
                                                                        related_pages = re.findall(r'<related>(.*?)</related>', page_content, re.DOTALL)
                                                                        for related in related_pages:
                                                                            clean_structure += f"        <related>{related.strip()}</related>\n"
                                                                        clean_structure += "      </related_pages>\n"

                                                                        clean_structure += "    </page>\n"

                                                                    clean_structure += "  </pages>\n</wiki_structure>"

                                                                    log.info("Successfully rebuilt clean XML structure")
                                                                    yield clean_structure
                                                                else:
                                                                    log.warning("Could not extract wiki structure, using raw XML")
                                                                    yield clean_xml
                                                            except Exception as rebuild_error:
                                                                log.warning(f"Failed to rebuild XML: {str(rebuild_error)}, using raw XML")
                                                                yield clean_xml
                                                    else:
                                                        # If we can't extract it, just yield the original content
                                                        log.warning("Could not extract wiki_structure XML, yielding original content")
                                                        yield xml_content
                                                else:
                                                    # For other XML content, just yield it as is
                                                    yield content
                                            except Exception as xml_error:
                                                log.error(f"Error processing XML content: {str(xml_error)}")
                                                yield content
                                        else:
                                            # Not XML; just yield the content
                                            yield content
                                    else:
                                        log.error(f"Unexpected response format: {data}")
                                        yield "Error: Unexpected response format from OpenRouter API"
                                else:
                                    log.error(f"No choices in response: {data}")
                                    yield "Error: No response content from OpenRouter API"

                            return content_generator()
                    except aiohttp.ClientError as e:
                        # Keep a reference for the generator below; `e` itself is
                        # cleared when this except block exits.
                        e_client = e
                        log.error(f"Connection error with OpenRouter API: {str(e_client)}")

                        # Return a generator that yields the error message
                        async def connection_error_generator():
                            yield f"Connection error with OpenRouter API: {str(e_client)}. Please check your internet connection and that the OpenRouter API is accessible."
                        return connection_error_generator()

            except RequestException as e:
                e_req = e
                log.error(f"Error calling OpenRouter API asynchronously: {str(e_req)}")

                # Return a generator that yields the error message
                async def request_error_generator():
                    yield f"Error calling OpenRouter API: {str(e_req)}"
                return request_error_generator()

            except Exception as e:
                e_unexp = e
                log.error(f"Unexpected error calling OpenRouter API asynchronously: {str(e_unexp)}")

                # Return a generator that yields the error message
                async def unexpected_error_generator():
                    yield f"Unexpected error calling OpenRouter API: {str(e_unexp)}"
                return unexpected_error_generator()

        else:
            error_msg = f"Unsupported model type: {model_type}"
            log.error(error_msg)

            # Return a generator that yields the error message
            async def model_type_error_generator():
                yield error_msg
            return model_type_error_generator()
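
    # Note: in the LLM path, acall never raises for API or connection failures.
    # Every outcome, success or error text, is delivered through an async
    # generator so the caller can stream it to the user uniformly.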

    def _process_completion_response(self, data: Dict) -> GeneratorOutput:
        """Process a non-streaming completion response from OpenRouter."""
        try:
            # Extract the completion text from the response
            if not data.get("choices"):
                raise ValueError(f"No choices in OpenRouter response: {data}")

            choice = data["choices"][0]

            if "message" in choice:
                content = choice["message"].get("content", "")
            elif "text" in choice:
                content = choice.get("text", "")
            else:
                raise ValueError(f"Unexpected response format from OpenRouter: {choice}")

            # Extract usage information if available
            usage = None
            if "usage" in data:
                usage = CompletionUsage(
                    prompt_tokens=data["usage"].get("prompt_tokens", 0),
                    completion_tokens=data["usage"].get("completion_tokens", 0),
                    total_tokens=data["usage"].get("total_tokens", 0)
                )

            # Create and return the GeneratorOutput
            return GeneratorOutput(
                data=content,
                usage=usage,
                raw_response=data
            )

        except Exception as e_proc:
            log.error(f"Error processing OpenRouter completion response: {str(e_proc)}")
            raise
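
    # For reference, OpenRouter returns OpenAI-style completion payloads; a
    # truncated, illustrative example of what the method above parses:
    #
    #   {
    #       "choices": [{"message": {"role": "assistant", "content": "..."}}],
    #       "usage": {"prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46}
    #   }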

    def _process_streaming_response(self, response):
        """Process a streaming response from OpenRouter."""
        try:
            log.info("Starting to process streaming response from OpenRouter")
            buffer = ""

            for chunk in response.iter_content(chunk_size=1024, decode_unicode=True):
                try:
                    # Add the chunk to the buffer
                    buffer += chunk

                    # Process complete lines in the buffer
                    while '\n' in buffer:
                        line, buffer = buffer.split('\n', 1)
                        line = line.strip()

                        if not line:
                            continue

                        log.debug(f"Processing line: {line}")

                        # Skip SSE comments (lines starting with a colon)
                        if line.startswith(':'):
                            log.debug(f"Skipping SSE comment: {line}")
                            continue

                        if line.startswith("data: "):
                            data = line[6:]  # Remove the "data: " prefix

                            # Check for the end of the stream
                            if data == "[DONE]":
                                log.info("Received [DONE] marker")
                                break

                            try:
                                data_obj = json.loads(data)
                                log.debug(f"Parsed JSON data: {data_obj}")

                                # Extract content from the delta
                                if "choices" in data_obj and len(data_obj["choices"]) > 0:
                                    choice = data_obj["choices"][0]

                                    if "delta" in choice and "content" in choice["delta"] and choice["delta"]["content"]:
                                        content = choice["delta"]["content"]
                                        log.debug(f"Yielding delta content: {content}")
                                        yield content
                                    elif "text" in choice:
                                        log.debug(f"Yielding text content: {choice['text']}")
                                        yield choice["text"]
                                    else:
                                        log.debug(f"No content found in choice: {choice}")
                                else:
                                    log.debug(f"No choices found in data: {data_obj}")

                            except json.JSONDecodeError:
                                log.warning(f"Failed to parse SSE data: {data}")
                                continue
                except Exception as e_chunk:
                    log.error(f"Error processing streaming chunk: {str(e_chunk)}")
                    yield f"Error processing response chunk: {str(e_chunk)}"
        except Exception as e_stream:
            log.error(f"Error in streaming response: {str(e_stream)}")
            yield f"Error in streaming response: {str(e_stream)}"

    async def _process_async_streaming_response(self, response):
        """Process an asynchronous streaming response from OpenRouter."""
        buffer = ""
        try:
            log.info("Starting to process async streaming response from OpenRouter")
            async for chunk in response.content:
                try:
                    # Convert bytes to str and add to the buffer
                    if isinstance(chunk, bytes):
                        chunk_str = chunk.decode('utf-8')
                    else:
                        chunk_str = str(chunk)

                    buffer += chunk_str

                    # Process complete lines in the buffer
                    while '\n' in buffer:
                        line, buffer = buffer.split('\n', 1)
                        line = line.strip()

                        if not line:
                            continue

                        log.debug(f"Processing line: {line}")

                        # Skip SSE comments (lines starting with a colon)
                        if line.startswith(':'):
                            log.debug(f"Skipping SSE comment: {line}")
                            continue

                        if line.startswith("data: "):
                            data = line[6:]  # Remove the "data: " prefix

                            # Check for the end of the stream
                            if data == "[DONE]":
                                log.info("Received [DONE] marker")
                                break

                            try:
                                data_obj = json.loads(data)
                                log.debug(f"Parsed JSON data: {data_obj}")

                                # Extract content from the delta
                                if "choices" in data_obj and len(data_obj["choices"]) > 0:
                                    choice = data_obj["choices"][0]

                                    if "delta" in choice and "content" in choice["delta"] and choice["delta"]["content"]:
                                        content = choice["delta"]["content"]
                                        log.debug(f"Yielding delta content: {content}")
                                        yield content
                                    elif "text" in choice:
                                        log.debug(f"Yielding text content: {choice['text']}")
                                        yield choice["text"]
                                    else:
                                        log.debug(f"No content found in choice: {choice}")
                                else:
                                    log.debug(f"No choices found in data: {data_obj}")

                            except json.JSONDecodeError:
                                log.warning(f"Failed to parse SSE data: {data}")
                                continue
                except Exception as e_chunk:
                    log.error(f"Error processing streaming chunk: {str(e_chunk)}")
                    yield f"Error processing response chunk: {str(e_chunk)}"
        except Exception as e_stream:
            log.error(f"Error in async streaming response: {str(e_stream)}")
            yield f"Error in streaming response: {str(e_stream)}"
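

if __name__ == "__main__":
    # Minimal smoke test -- an illustrative sketch, assuming OPENROUTER_API_KEY
    # is configured in api.config. It exercises the non-streaming LLM path end
    # to end and prints whatever the async generator yields (content or error).
    import asyncio

    async def _demo() -> None:
        client = OpenRouterClient()
        api_kwargs = client.convert_inputs_to_api_kwargs(
            "Say hello in one short sentence.",
            model_kwargs={"model": "openai/gpt-4o"},
            model_type=ModelType.LLM,
        )
        async for chunk in await client.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM):
            print(chunk, end="")
        print()

    asyncio.run(_demo())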