import os
import json
import logging
import re
from pathlib import Path
from typing import List, Union, Dict, Any

logger = logging.getLogger(__name__)

from api.openai_client import OpenAIClient
from api.openrouter_client import OpenRouterClient
from api.bedrock_client import BedrockClient
from api.google_embedder_client import GoogleEmbedderClient
from api.azureai_client import AzureAIClient
from api.dashscope_client import DashscopeClient
from adalflow import GoogleGenAIClient, OllamaClient

# Get API keys from environment variables
OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY')
GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')
OPENROUTER_API_KEY = os.environ.get('OPENROUTER_API_KEY')
AWS_ACCESS_KEY_ID = os.environ.get('AWS_ACCESS_KEY_ID')
AWS_SECRET_ACCESS_KEY = os.environ.get('AWS_SECRET_ACCESS_KEY')
AWS_REGION = os.environ.get('AWS_REGION')
AWS_ROLE_ARN = os.environ.get('AWS_ROLE_ARN')

# Set keys in environment (in case they're needed elsewhere in the code)
if OPENAI_API_KEY:
    os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
if GOOGLE_API_KEY:
    os.environ["GOOGLE_API_KEY"] = GOOGLE_API_KEY
if OPENROUTER_API_KEY:
    os.environ["OPENROUTER_API_KEY"] = OPENROUTER_API_KEY
if AWS_ACCESS_KEY_ID:
    os.environ["AWS_ACCESS_KEY_ID"] = AWS_ACCESS_KEY_ID
if AWS_SECRET_ACCESS_KEY:
    os.environ["AWS_SECRET_ACCESS_KEY"] = AWS_SECRET_ACCESS_KEY
if AWS_REGION:
    os.environ["AWS_REGION"] = AWS_REGION
if AWS_ROLE_ARN:
    os.environ["AWS_ROLE_ARN"] = AWS_ROLE_ARN

# Wiki authentication settings
raw_auth_mode = os.environ.get('DEEPWIKI_AUTH_MODE', 'False')
WIKI_AUTH_MODE = raw_auth_mode.lower() in ['true', '1', 't']
WIKI_AUTH_CODE = os.environ.get('DEEPWIKI_AUTH_CODE', '')
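# Illustrative note (not part of the original module): DEEPWIKI_AUTH_MODE is
# parsed case-insensitively, so only 'true', '1', or 't' (in any casing) enable
# wiki authentication; anything else, including leaving it unset, disables it.
#
#   DEEPWIKI_AUTH_MODE=TRUE  -> WIKI_AUTH_MODE == True
#   DEEPWIKI_AUTH_MODE=1     -> WIKI_AUTH_MODE == True
#   DEEPWIKI_AUTH_MODE=no    -> WIKI_AUTH_MODE == False
#   (unset)                  -> WIKI_AUTH_MODE == False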
# Embedder settings
EMBEDDER_TYPE = os.environ.get('DEEPWIKI_EMBEDDER_TYPE', 'openai').lower()

# Get configuration directory from environment variable, or use default if not set
CONFIG_DIR = os.environ.get('DEEPWIKI_CONFIG_DIR', None)

# Client class mapping
CLIENT_CLASSES = {
    "GoogleGenAIClient": GoogleGenAIClient,
    "GoogleEmbedderClient": GoogleEmbedderClient,
    "OpenAIClient": OpenAIClient,
    "OpenRouterClient": OpenRouterClient,
    "OllamaClient": OllamaClient,
    "BedrockClient": BedrockClient,
    "AzureAIClient": AzureAIClient,
    "DashscopeClient": DashscopeClient
}

def replace_env_placeholders(config: Union[Dict[str, Any], List[Any], str, Any]) -> Union[Dict[str, Any], List[Any], str, Any]:
    """
    Recursively replace placeholders like "${ENV_VAR}" in string values
    within a nested configuration structure (dicts, lists, strings)
    with environment variable values. Logs a warning if a placeholder is not found.
    """
    pattern = re.compile(r"\$\{([A-Z0-9_]+)\}")

    def replacer(match: re.Match[str]) -> str:
        env_var_name = match.group(1)
        original_placeholder = match.group(0)
        env_var_value = os.environ.get(env_var_name)
        if env_var_value is None:
            logger.warning(
                f"Environment variable placeholder '{original_placeholder}' was not found in the environment. "
                f"The placeholder string will be used as is."
            )
            return original_placeholder
        return env_var_value

    if isinstance(config, dict):
        return {k: replace_env_placeholders(v) for k, v in config.items()}
    elif isinstance(config, list):
        return [replace_env_placeholders(item) for item in config]
    elif isinstance(config, str):
        return pattern.sub(replacer, config)
    else:
        # Handles numbers, booleans, None, etc.
        return config

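# Illustrative example (not part of the original module), assuming MY_API_KEY is
# set in the environment:
#
#   os.environ["MY_API_KEY"] = "sk-123"
#   replace_env_placeholders({"api_key": "${MY_API_KEY}", "retries": 3})
#   # -> {"api_key": "sk-123", "retries": 3}
#   replace_env_placeholders("${MISSING_VAR}/models")
#   # -> "${MISSING_VAR}/models"  (a warning is logged, placeholder kept as is)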
# Load JSON configuration file
def load_json_config(filename):
    try:
        # If environment variable is set, use the directory specified by it
        if CONFIG_DIR:
            config_path = Path(CONFIG_DIR) / filename
        else:
            # Otherwise use default directory
            config_path = Path(__file__).parent / "config" / filename

        logger.info(f"Loading configuration from {config_path}")

        if not config_path.exists():
            logger.warning(f"Configuration file {config_path} does not exist")
            return {}

        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
            config = replace_env_placeholders(config)
            return config
    except Exception as e:
        logger.error(f"Error loading configuration file {filename}: {str(e)}")
        return {}

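# Illustrative example (not part of the original module): with
# DEEPWIKI_CONFIG_DIR=/etc/deepwiki the call below reads /etc/deepwiki/generator.json;
# without it, the loader falls back to the config/ directory next to this file.
# Missing or invalid files yield {} rather than raising.
#
#   generator = load_json_config("generator.json")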
# Load generator model configuration
def load_generator_config():
    generator_config = load_json_config("generator.json")

    # Add client classes to each provider
    if "providers" in generator_config:
        for provider_id, provider_config in generator_config["providers"].items():
            # Try to set client class from client_class
            if provider_config.get("client_class") in CLIENT_CLASSES:
                provider_config["model_client"] = CLIENT_CLASSES[provider_config["client_class"]]
            # Fall back to default mapping based on provider_id
            elif provider_id in ["google", "openai", "openrouter", "ollama", "bedrock", "azure", "dashscope"]:
                default_map = {
                    "google": GoogleGenAIClient,
                    "openai": OpenAIClient,
                    "openrouter": OpenRouterClient,
                    "ollama": OllamaClient,
                    "bedrock": BedrockClient,
                    "azure": AzureAIClient,
                    "dashscope": DashscopeClient
                }
                provider_config["model_client"] = default_map[provider_id]
            else:
                logger.warning(f"Unknown provider or client class: {provider_id}")

    return generator_config

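# Hypothetical sketch (not from the repository) of the generator.json shape this
# loader and get_model_config() read: a "providers" map keyed by provider id, each
# entry with an optional "client_class", a "default_model", and a "models" map of
# per-model parameters. Only the key names used in this module are grounded here;
# the provider, model name, and parameter values below are made up.
#
#   {
#     "default_provider": "google",
#     "providers": {
#       "openai": {
#         "client_class": "OpenAIClient",
#         "default_model": "gpt-4o",
#         "models": {
#           "gpt-4o": {"temperature": 0.7}
#         }
#       }
#     }
#   }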
# Load embedder configuration
def load_embedder_config():
    embedder_config = load_json_config("embedder.json")

    # Process client classes
    for key in ["embedder", "embedder_ollama", "embedder_google"]:
        if key in embedder_config and "client_class" in embedder_config[key]:
            class_name = embedder_config[key]["client_class"]
            if class_name in CLIENT_CLASSES:
                embedder_config[key]["model_client"] = CLIENT_CLASSES[class_name]

    return embedder_config


def get_embedder_config():
    """
    Get the current embedder configuration based on DEEPWIKI_EMBEDDER_TYPE.

    Returns:
        dict: The embedder configuration with model_client resolved
    """
    embedder_type = EMBEDDER_TYPE
    if embedder_type == 'google' and 'embedder_google' in configs:
        return configs.get("embedder_google", {})
    elif embedder_type == 'ollama' and 'embedder_ollama' in configs:
        return configs.get("embedder_ollama", {})
    else:
        return configs.get("embedder", {})

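# Illustrative example (not part of the original module): with
# DEEPWIKI_EMBEDDER_TYPE=ollama set before this module is imported, and an
# "embedder_ollama" section present in embedder.json, get_embedder_config()
# returns that section (with "model_client" resolved to OllamaClient); otherwise
# it falls back to the plain "embedder" section.
#
#   os.environ["DEEPWIKI_EMBEDDER_TYPE"] = "ollama"   # must be set before import
#   cfg = get_embedder_config()
#   cfg.get("model_client")  # -> OllamaClient, if embedder.json defines it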
def is_ollama_embedder():
    """
    Check if the current embedder configuration uses OllamaClient.

    Returns:
        bool: True if using OllamaClient, False otherwise
    """
    embedder_config = get_embedder_config()
    if not embedder_config:
        return False

    # Check if model_client is OllamaClient
    model_client = embedder_config.get("model_client")
    if model_client:
        return model_client.__name__ == "OllamaClient"

    # Fallback: check client_class string
    client_class = embedder_config.get("client_class", "")
    return client_class == "OllamaClient"


def is_google_embedder():
    """
    Check if the current embedder configuration uses GoogleEmbedderClient.

    Returns:
        bool: True if using GoogleEmbedderClient, False otherwise
    """
    embedder_config = get_embedder_config()
    if not embedder_config:
        return False

    # Check if model_client is GoogleEmbedderClient
    model_client = embedder_config.get("model_client")
    if model_client:
        return model_client.__name__ == "GoogleEmbedderClient"

    # Fallback: check client_class string
    client_class = embedder_config.get("client_class", "")
    return client_class == "GoogleEmbedderClient"


def get_embedder_type():
    """
    Get the current embedder type based on configuration.

    Returns:
        str: 'ollama', 'google', or 'openai' (default)
    """
    if is_ollama_embedder():
        return 'ollama'
    elif is_google_embedder():
        return 'google'
    else:
        return 'openai'

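# Illustrative note (not part of the original module): the reported embedder type
# is derived from the resolved client class, not from DEEPWIKI_EMBEDDER_TYPE alone,
# so a configuration whose client_class is "OllamaClient" reports 'ollama' even if
# the environment variable was left at its default.
#
#   get_embedder_type()  # -> 'ollama', 'google', or 'openai'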
# Load repository and file filters configuration
def load_repo_config():
    return load_json_config("repo.json")


# Load language configuration
def load_lang_config():
    default_config = {
        "supported_languages": {
            "en": "English",
            "ja": "Japanese (日本語)",
            "zh": "Mandarin Chinese (中文)",
            "zh-tw": "Traditional Chinese (繁體中文)",
            "es": "Spanish (Español)",
            "kr": "Korean (한국어)",
            "vi": "Vietnamese (Tiếng Việt)",
            "pt-br": "Brazilian Portuguese (Português Brasileiro)",
            "fr": "French (Français)",
            "ru": "Russian (Русский)"
        },
        "default": "en"
    }

    loaded_config = load_json_config("lang.json")  # Let load_json_config handle path and loading

    if not loaded_config:
        return default_config

    if "supported_languages" not in loaded_config or "default" not in loaded_config:
        logger.warning("Language configuration file 'lang.json' is malformed. Using default language configuration.")
        return default_config

    return loaded_config

# Default excluded directories and files
DEFAULT_EXCLUDED_DIRS: List[str] = [
    # Virtual environments and package managers
    "./.venv/", "./venv/", "./env/", "./virtualenv/",
    "./node_modules/", "./bower_components/", "./jspm_packages/",
    # Version control
    "./.git/", "./.svn/", "./.hg/", "./.bzr/",
    # Cache and compiled files
    "./__pycache__/", "./.pytest_cache/", "./.mypy_cache/", "./.ruff_cache/", "./.coverage/",
    # Build and distribution
    "./dist/", "./build/", "./out/", "./target/", "./bin/", "./obj/",
    # Documentation
    "./docs/", "./_docs/", "./site-docs/", "./_site/",
    # IDE specific
    "./.idea/", "./.vscode/", "./.vs/", "./.eclipse/", "./.settings/",
    # Logs and temporary files
    "./logs/", "./log/", "./tmp/", "./temp/",
]

DEFAULT_EXCLUDED_FILES: List[str] = [
    "yarn.lock", "pnpm-lock.yaml", "npm-shrinkwrap.json", "poetry.lock",
    "Pipfile.lock", "requirements.txt.lock", "Cargo.lock", "composer.lock",
    ".lock", ".DS_Store", "Thumbs.db", "desktop.ini", "*.lnk", ".env",
    ".env.*", "*.env", "*.cfg", "*.ini", ".flaskenv", ".gitignore",
    ".gitattributes", ".gitmodules", ".github", ".gitlab-ci.yml",
    ".prettierrc", ".eslintrc", ".eslintignore", ".stylelintrc",
    ".editorconfig", ".jshintrc", ".pylintrc", ".flake8", "mypy.ini",
    "pyproject.toml", "tsconfig.json", "webpack.config.js", "babel.config.js",
    "rollup.config.js", "jest.config.js", "karma.conf.js", "vite.config.js",
    "next.config.js", "*.min.js", "*.min.css", "*.bundle.js", "*.bundle.css",
    "*.map", "*.gz", "*.zip", "*.tar", "*.tgz", "*.rar", "*.7z", "*.iso",
    "*.dmg", "*.img", "*.msix", "*.appx", "*.appxbundle", "*.xap", "*.ipa",
    "*.deb", "*.rpm", "*.msi", "*.exe", "*.dll", "*.so", "*.dylib", "*.o",
    "*.obj", "*.jar", "*.war", "*.ear", "*.jsm", "*.class", "*.pyc", "*.pyd",
    "*.pyo", "__pycache__", "*.a", "*.lib", "*.lo", "*.la", "*.slo", "*.dSYM",
    "*.egg", "*.egg-info", "*.dist-info", "*.eggs", "node_modules",
    "bower_components", "jspm_packages", "lib-cov", "coverage", "htmlcov",
    ".nyc_output", ".tox", "dist", "build", "bld", "out", "bin", "target",
    "packages/*/dist", "packages/*/build", ".output"
]

# Initialize empty configuration
configs = {}

# Load all configuration files
generator_config = load_generator_config()
embedder_config = load_embedder_config()
repo_config = load_repo_config()
lang_config = load_lang_config()

# Update configuration
if generator_config:
    configs["default_provider"] = generator_config.get("default_provider", "google")
    configs["providers"] = generator_config.get("providers", {})

# Update embedder configuration
if embedder_config:
    for key in ["embedder", "embedder_ollama", "embedder_google", "retriever", "text_splitter"]:
        if key in embedder_config:
            configs[key] = embedder_config[key]

# Update repository configuration
if repo_config:
    for key in ["file_filters", "repository"]:
        if key in repo_config:
            configs[key] = repo_config[key]

# Update language configuration
if lang_config:
    configs["lang_config"] = lang_config

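# Illustrative note (not part of the original module): after the loading above,
# `configs` can hold up to these top-level keys, each taken from its JSON file
# (missing or empty files simply leave the corresponding keys out):
#
#   "default_provider", "providers"                       # generator.json
#   "embedder", "embedder_ollama", "embedder_google"      # embedder.json
#   "retriever", "text_splitter"                          # embedder.json
#   "file_filters", "repository"                          # repo.json
#   "lang_config"                                         # lang.json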
def get_model_config(provider="google", model=None):
    """
    Get configuration for the specified provider and model.

    Parameters:
        provider (str): Model provider ('google', 'openai', 'openrouter', 'ollama',
            'bedrock', 'azure', or 'dashscope')
        model (str): Model name, or None to use the provider's default model

    Returns:
        dict: Configuration containing model_client, model and other parameters
    """
    # Get provider configuration
    if "providers" not in configs:
        raise ValueError("Provider configuration not loaded")

    provider_config = configs["providers"].get(provider)
    if not provider_config:
        raise ValueError(f"Configuration for provider '{provider}' not found")

    model_client = provider_config.get("model_client")
    if not model_client:
        raise ValueError(f"Model client not specified for provider '{provider}'")

    # If model not provided, use default model for the provider
    if not model:
        model = provider_config.get("default_model")
        if not model:
            raise ValueError(f"No default model specified for provider '{provider}'")

    # Get model parameters (if present); otherwise fall back to the default
    # model's parameters, or an empty dict if none are defined
    model_params = {}
    if model in provider_config.get("models", {}):
        model_params = provider_config["models"][model]
    else:
        default_model = provider_config.get("default_model")
        model_params = provider_config.get("models", {}).get(default_model, {})

    # Prepare base configuration
    result = {
        "model_client": model_client,
    }

    # Provider-specific adjustments
    if provider == "ollama":
        # Ollama uses a slightly different parameter structure
        if "options" in model_params:
            result["model_kwargs"] = {"model": model, **model_params["options"]}
        else:
            result["model_kwargs"] = {"model": model}
    else:
        # Standard structure for other providers
        result["model_kwargs"] = {"model": model, **model_params}

    return result
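# Illustrative example (not part of the original module); the provider and model
# names are hypothetical and depend on what generator.json defines:
#
#   cfg = get_model_config(provider="openai", model="gpt-4o")
#   # cfg["model_client"]  -> the resolved client class (e.g. OpenAIClient)
#   # cfg["model_kwargs"]  -> {"model": "gpt-4o", ...per-model parameters...}
#
#   cfg = get_model_config(provider="ollama")
#   # uses the provider's default_model; the model entry's "options" dict is
#   # flattened into model_kwargs alongside "model"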