SLA-RedM/reference-deepwiki/deepwiki-open-main/api/rag.py

446 lines
19 KiB
Python

import logging
import weakref
import re
from dataclasses import dataclass
from typing import Any, List, Tuple, Dict
from uuid import uuid4
import adalflow as adal
from api.tools.embedder import get_embedder
from api.prompts import RAG_SYSTEM_PROMPT as system_prompt, RAG_TEMPLATE
# Create our own implementation of the conversation classes
@dataclass
class UserQuery:
    """The user's side of a single dialog turn."""

    # Raw text of the question the user asked.
    query_str: str
@dataclass
class AssistantResponse:
    """The assistant's side of a single dialog turn."""

    # Text the assistant replied with.
    response_str: str
@dataclass
class DialogTurn:
    """One question/answer exchange stored in conversation memory."""

    # Unique identifier for this turn (Memory assigns a uuid4 string).
    id: str
    # What the user asked.
    user_query: UserQuery
    # What the assistant answered.
    assistant_response: AssistantResponse
class CustomConversation:
    """Minimal conversation container holding an ordered list of dialog turns.

    Custom implementation of Conversation to fix the "list assignment index
    out of range" error.
    """

    def __init__(self):
        self.dialog_turns = []

    def append_dialog_turn(self, dialog_turn):
        """Add ``dialog_turn`` to the end of the history.

        If the ``dialog_turns`` attribute has been removed, it is recreated as
        an empty list before appending.
        """
        try:
            turns = self.dialog_turns
        except AttributeError:
            turns = self.dialog_turns = []
        turns.append(dialog_turn)
# Import other adalflow components
from adalflow.components.retriever.faiss_retriever import FAISSRetriever
from api.config import configs
from api.data_pipeline import DatabaseManager
# Module-level logger (this only obtains a logger; it does not configure handlers).
logger = logging.getLogger(__name__)

# Maximum token limit for embedding model inputs:
# a safe threshold below the 8192-token limit.
MAX_INPUT_TOKENS = 7_500
class Memory(adal.core.component.DataComponent):
    """Simple conversation management with a list of dialog turns.

    Turns are kept on a ``CustomConversation``. Both public methods are
    best-effort: on unexpected errors they log, replace the conversation with
    a fresh one (discarding previous history), and continue rather than raise.
    """

    def __init__(self):
        super().__init__()
        # Use our custom implementation instead of the original Conversation class
        self.current_conversation = CustomConversation()

    def call(self) -> Dict:
        """Return the conversation history as a dict mapping turn id -> DialogTurn.

        Turns whose ``id`` is missing or None are skipped with a warning. If
        reading the history fails, the conversation is reset and whatever was
        collected before the failure is returned.
        """
        all_dialog_turns = {}
        try:
            if not hasattr(self.current_conversation, "dialog_turns"):
                logger.info("No dialog_turns attribute in current_conversation")
                # Try to initialize it
                self.current_conversation.dialog_turns = []
            elif not self.current_conversation.dialog_turns:
                logger.info("Dialog turns list exists but is empty")
            else:
                dialog_turns = self.current_conversation.dialog_turns
                logger.info(f"Memory content: {len(dialog_turns)} turns")
                for i, turn in enumerate(dialog_turns):
                    turn_id = getattr(turn, "id", None)
                    if turn_id is None:
                        logger.warning(f"Skipping invalid turn object in memory: {turn}")
                        continue
                    all_dialog_turns[turn_id] = turn
                    # Per-turn detail at DEBUG: this runs on every call and
                    # would otherwise flood INFO logs as history grows.
                    logger.debug(f"Added turn {i+1} with ID {turn_id} to memory")
        except Exception as e:
            logger.error(f"Error accessing dialog turns: {str(e)}")
            # Try to recover
            try:
                self.current_conversation = CustomConversation()
                logger.info("Recovered by creating new conversation")
            except Exception as e2:
                logger.error(f"Failed to recover: {str(e2)}")
        logger.info(f"Returning {len(all_dialog_turns)} dialog turns from memory")
        return all_dialog_turns

    @staticmethod
    def _new_dialog_turn(user_query: str, assistant_response: str) -> DialogTurn:
        """Build a DialogTurn for the given texts with a fresh uuid4 id."""
        return DialogTurn(
            id=str(uuid4()),
            user_query=UserQuery(query_str=user_query),
            assistant_response=AssistantResponse(response_str=assistant_response),
        )

    def add_dialog_turn(self, user_query: str, assistant_response: str) -> bool:
        """
        Add a dialog turn to the conversation history.

        Args:
            user_query: The user's query
            assistant_response: The assistant's response

        Returns:
            bool: True if the turn was stored (possibly in a freshly reset
            conversation after an error), False if even recovery failed.
        """
        try:
            dialog_turn = self._new_dialog_turn(user_query, assistant_response)
            # Replace anything that is not a usable conversation object.
            if not hasattr(self.current_conversation, "append_dialog_turn"):
                logger.warning("current_conversation does not have append_dialog_turn method, creating new one")
                self.current_conversation = CustomConversation()
            if not hasattr(self.current_conversation, "dialog_turns"):
                # append_dialog_turn below recreates the list.
                logger.warning("dialog_turns not found, initializing empty list")
            self.current_conversation.append_dialog_turn(dialog_turn)
            logger.info(f"Successfully added dialog turn, now have {len(self.current_conversation.dialog_turns)} turns")
            return True
        except Exception as e:
            logger.error(f"Error adding dialog turn: {str(e)}")
            # Try to recover by starting a new conversation containing only this turn
            try:
                self.current_conversation = CustomConversation()
                self.current_conversation.append_dialog_turn(
                    self._new_dialog_turn(user_query, assistant_response)
                )
                logger.info("Recovered from error by creating new conversation")
                return True
            except Exception as e2:
                logger.error(f"Failed to recover from error: {str(e2)}")
                return False
from dataclasses import dataclass, field
@dataclass
class RAGAnswer(adal.DataClass):
    """Structured answer produced by the RAG generator's output parser."""

    rationale: str = field(
        default="",
        metadata={"desc": "Chain of thoughts for the answer."},
    )
    answer: str = field(
        default="",
        metadata={
            "desc": (
                "Answer to the user query, formatted in markdown for beautiful rendering with react-markdown. "
                "DO NOT include ``` triple backticks fences at the beginning or end of your answer."
            )
        },
    )

    # Left unannotated on purpose so the dataclass machinery treats it as a
    # plain class attribute, not a field; lists the output fields in order.
    __output_fields__ = ["rationale", "answer"]
class RAG(adal.Component):
    """RAG with one repo.

    If you want to load a new repo, call ``prepare_retriever(repo_url_or_path)``
    first; it loads the embedded documents and builds the FAISS retriever that
    ``call`` queries.
    """

    def __init__(self, provider="google", model=None, use_s3: bool = False):
        """
        Initialize the RAG component.

        Args:
            provider: Model provider to use (google, openai, openrouter, ollama)
            model: Model name to use with the provider
            use_s3: Kept only for backward compatibility; it is ignored.

        Raises:
            Exception: If the Ollama embedder is selected and its configured
                model is not installed locally.
        """
        super().__init__()
        self.provider = provider
        self.model = model

        # Import the helper functions
        from api.config import get_embedder_config, get_embedder_type

        # Determine embedder type based on current configuration
        self.embedder_type = get_embedder_type()
        self.is_ollama_embedder = self.embedder_type == "ollama"  # Backward compatibility

        # Fail fast if the configured Ollama embedding model has not been pulled.
        if self.is_ollama_embedder:
            from api.ollama_patch import check_ollama_model_exists

            embedder_config = get_embedder_config()
            if embedder_config and embedder_config.get("model_kwargs", {}).get("model"):
                model_name = embedder_config["model_kwargs"]["model"]
                if not check_ollama_model_exists(model_name):
                    raise Exception(
                        f"Ollama model '{model_name}' not found. "
                        f"Please run 'ollama pull {model_name}' to install it."
                    )

        # Initialize components
        self.memory = Memory()
        self.embedder = get_embedder(embedder_type=self.embedder_type)

        # The closure below is stored on self, so it only holds a weak reference
        # to self to avoid a self -> closure -> self reference cycle.
        self_weakref = weakref.ref(self)

        def single_string_embedder(query):
            """Embed exactly one string; a one-element list is unwrapped first."""
            if isinstance(query, list):
                if len(query) != 1:
                    raise ValueError("Ollama embedder only supports a single string")
                query = query[0]
            instance = self_weakref()
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if instance is None:
                raise RuntimeError("RAG instance is no longer available, but the query embedder was called.")
            return instance.embedder(input=query)

        # Use single string embedder for Ollama, regular embedder for others
        self.query_embedder = single_string_embedder if self.is_ollama_embedder else self.embedder

        self.initialize_db_manager()

        # Set up the output parser
        data_parser = adal.DataClassParser(data_class=RAGAnswer, return_data_class=True)

        # Format instructions to ensure proper output structure
        format_instructions = data_parser.get_output_format_str() + """
IMPORTANT FORMATTING RULES:
1. DO NOT include your thinking or reasoning process in the output
2. Provide only the final, polished answer
3. DO NOT include ```markdown fences at the beginning or end of your answer
4. DO NOT wrap your response in any kind of fences
5. Start your response directly with the content
6. The content will already be rendered as markdown
7. Do not use backslashes before special characters like [ ] { } in your answer
8. When listing tags or similar items, write them as plain text without escape characters
9. For pipe characters (|) in text, write them directly without escaping them"""

        # Get model configuration based on provider and model
        from api.config import get_model_config

        generator_config = get_model_config(self.provider, self.model)

        # Set up the main generator
        self.generator = adal.Generator(
            template=RAG_TEMPLATE,
            prompt_kwargs={
                "output_format_str": format_instructions,
                "conversation_history": self.memory(),
                "system_prompt": system_prompt,
                "contexts": None,
            },
            model_client=generator_config["model_client"](),
            model_kwargs=generator_config["model_kwargs"],
            output_processors=data_parser,
        )

    def initialize_db_manager(self):
        """Create a fresh DatabaseManager and clear the loaded documents."""
        self.db_manager = DatabaseManager()
        self.transformed_docs = []

    @staticmethod
    def _get_embedding_size(vector):
        """Return the embedding dimension of ``vector``, or None if its type is unsupported.

        Lists use ``len``; objects with a ``shape`` use their only axis when
        1-D, otherwise their last axis; other sized objects fall back to ``len``.
        May raise (e.g. IndexError for a 0-d array); callers handle that.
        """
        if isinstance(vector, list):
            return len(vector)
        if hasattr(vector, "shape"):
            shape = vector.shape
            return shape[0] if len(shape) == 1 else shape[-1]
        if hasattr(vector, "__len__"):
            return len(vector)
        return None

    @staticmethod
    def _document_label(doc, index: int):
        """Return the document's ``meta_data['file_path']`` for log messages, else ``document_<index>``.

        Tolerates a missing or None ``meta_data`` so that logging inside an
        error handler cannot itself raise.
        """
        meta = getattr(doc, "meta_data", None)
        if isinstance(meta, dict):
            return meta.get("file_path", f"document_{index}")
        return f"document_{index}"

    def _validate_and_filter_embeddings(self, documents: List) -> List:
        """
        Validate embeddings and filter out documents with invalid or mismatched embedding sizes.

        The most common embedding size is taken as the correct one (ties go to
        the size encountered first); documents with any other size, no vector,
        or an unreadable vector are dropped.

        Args:
            documents: List of documents with embeddings

        Returns:
            List of documents with valid embeddings of consistent size
        """
        if not documents:
            logger.warning("No documents provided for embedding validation")
            return []

        # First pass: count how many documents have each embedding size.
        embedding_sizes: Dict[Any, int] = {}
        for i, doc in enumerate(documents):
            vector = getattr(doc, "vector", None)
            if vector is None:
                logger.warning(f"Document {i} has no embedding vector, skipping")
                continue
            try:
                embedding_size = self._get_embedding_size(vector)
            except Exception as e:
                logger.warning(f"Error checking embedding size for document {i}: {str(e)}, skipping")
                continue
            if embedding_size is None:
                logger.warning(f"Document {i} has invalid embedding vector type: {type(vector)}, skipping")
                continue
            if embedding_size == 0:
                logger.warning(f"Document {i} has empty embedding vector, skipping")
                continue
            embedding_sizes[embedding_size] = embedding_sizes.get(embedding_size, 0) + 1

        if not embedding_sizes:
            logger.error("No valid embeddings found in any documents")
            return []

        # Find the most common embedding size (this should be the correct one)
        target_size = max(embedding_sizes, key=embedding_sizes.get)
        logger.info(f"Target embedding size: {target_size} (found in {embedding_sizes[target_size]} documents)")

        # Log all embedding sizes found
        for size, count in embedding_sizes.items():
            if size != target_size:
                logger.warning(f"Found {count} documents with incorrect embedding size {size}, will be filtered out")

        # Second pass: keep only documents whose embedding has the target size.
        valid_documents = []
        for i, doc in enumerate(documents):
            vector = getattr(doc, "vector", None)
            if vector is None:
                continue
            try:
                embedding_size = self._get_embedding_size(vector)
            except Exception as e:
                logger.warning(
                    f"Error validating embedding for document '{self._document_label(doc, i)}': {str(e)}, skipping"
                )
                continue
            if embedding_size is None:
                continue
            if embedding_size == target_size:
                valid_documents.append(doc)
            else:
                # Log which document is being filtered out
                logger.warning(
                    f"Filtering out document '{self._document_label(doc, i)}' due to embedding size mismatch: "
                    f"{embedding_size} != {target_size}"
                )

        logger.info(f"Embedding validation complete: {len(valid_documents)}/{len(documents)} documents have valid embeddings")

        if not valid_documents:
            logger.error("No documents with valid embeddings remain after filtering")
        elif len(valid_documents) < len(documents):
            filtered_count = len(documents) - len(valid_documents)
            logger.warning(f"Filtered out {filtered_count} documents due to embedding issues")

        return valid_documents

    def prepare_retriever(self, repo_url_or_path: str, type: str = "github", access_token: str = None,
                          excluded_dirs: List[str] = None, excluded_files: List[str] = None,
                          included_dirs: List[str] = None, included_files: List[str] = None):
        """
        Prepare the retriever for a repository.
        Will load database from local storage if available.

        Args:
            repo_url_or_path: URL or local path to the repository
            type: Repository type, passed through to DatabaseManager.prepare_database (default "github")
            access_token: Optional access token for private repositories
            excluded_dirs: Optional list of directories to exclude from processing
            excluded_files: Optional list of file patterns to exclude from processing
            included_dirs: Optional list of directories to include exclusively
            included_files: Optional list of file patterns to include exclusively

        Raises:
            ValueError: If no document with a valid embedding remains.
            Exception: Any error from building the FAISS retriever is logged and re-raised.
        """
        self.initialize_db_manager()
        self.repo_url_or_path = repo_url_or_path
        self.transformed_docs = self.db_manager.prepare_database(
            repo_url_or_path,
            type,
            access_token,
            embedder_type=self.embedder_type,
            excluded_dirs=excluded_dirs,
            excluded_files=excluded_files,
            included_dirs=included_dirs,
            included_files=included_files
        )
        logger.info(f"Loaded {len(self.transformed_docs)} documents for retrieval")

        # Validate and filter embeddings to ensure consistent sizes
        self.transformed_docs = self._validate_and_filter_embeddings(self.transformed_docs)

        if not self.transformed_docs:
            raise ValueError("No valid documents with embeddings found. Cannot create retriever.")

        logger.info(f"Using {len(self.transformed_docs)} documents with valid embeddings for retrieval")

        try:
            # Use the appropriate embedder for retrieval
            retrieve_embedder = self.query_embedder if self.is_ollama_embedder else self.embedder
            self.retriever = FAISSRetriever(
                **configs["retriever"],
                embedder=retrieve_embedder,
                documents=self.transformed_docs,
                document_map_func=lambda doc: doc.vector,
            )
            logger.info("FAISS retriever created successfully")
        except Exception as e:
            logger.error(f"Error creating FAISS retriever: {str(e)}")
            # Try to provide more specific error information
            if "All embeddings should be of the same size" in str(e):
                logger.error("Embedding size validation failed. This suggests there are still inconsistent embedding sizes.")
                # Log embedding sizes for debugging
                sizes = []
                for i, doc in enumerate(self.transformed_docs[:10]):  # Check first 10 docs
                    vector = getattr(doc, "vector", None)
                    if vector is None:
                        continue
                    try:
                        size = self._get_embedding_size(vector)
                    except Exception:
                        sizes.append(f"doc_{i}: error")
                        continue
                    sizes.append(f"doc_{i}: {'unknown' if size is None else size}")
                logger.error(f"Sample embedding sizes: {', '.join(sizes)}")
            raise

    def call(self, query: str, language: str = "en") -> Tuple[List]:
        """
        Retrieve documents relevant to ``query``.

        Args:
            query: The user's query
            language: Accepted for API compatibility; not used by retrieval here.

        Returns:
            On success, the retriever output list, whose first element has its
            ``documents`` filled in from ``self.transformed_docs``. On failure,
            a ``(RAGAnswer, [])`` tuple describing the error.
        """
        # NOTE(review): the success and error paths return different shapes
        # (list of retriever outputs vs. (RAGAnswer, []) tuple); callers that
        # read ``result[0].documents`` will fail on the error path. Kept as-is
        # to preserve the existing contract.
        try:
            retrieved_documents = self.retriever(query)

            # Fill in the documents
            retrieved_documents[0].documents = [
                self.transformed_docs[doc_index]
                for doc_index in retrieved_documents[0].doc_indices
            ]

            return retrieved_documents

        except Exception as e:
            logger.error(f"Error in RAG call: {str(e)}")

            # Create error response
            error_response = RAGAnswer(
                rationale="Error occurred while processing the query.",
                answer="I apologize, but I encountered an error while processing your question. Please try again or rephrase your question."
            )
            return error_response, []