# SLA-RedM/reference-deepwiki/api/ollama_patch.py

from typing import Optional, Sequence
from copy import deepcopy
import logging
import os

import requests
from tqdm import tqdm

import adalflow as adal
from adalflow.core.types import Document
from adalflow.core.component import DataComponent

# Configure logging
from api.logging_config import setup_logging

setup_logging()
logger = logging.getLogger(__name__)


class OllamaModelNotFoundError(Exception):
    """Custom exception for when an Ollama model is not found."""
    pass


def check_ollama_model_exists(model_name: str, ollama_host: Optional[str] = None) -> bool:
    """
    Check if an Ollama model exists before attempting to use it.

    Args:
        model_name: Name of the model to check
        ollama_host: Ollama host URL, defaults to http://localhost:11434

    Returns:
        bool: True if the model exists, False otherwise
    """
    if ollama_host is None:
        ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434")

    try:
        # Strip a trailing /api prefix if present; it is added back below
        if ollama_host.endswith('/api'):
            ollama_host = ollama_host[:-4]

        response = requests.get(f"{ollama_host}/api/tags", timeout=5)
        if response.status_code == 200:
            models_data = response.json()
            available_models = [model.get('name', '').split(':')[0] for model in models_data.get('models', [])]
            model_base_name = model_name.split(':')[0]  # Remove the tag if present
            is_available = model_base_name in available_models
            if is_available:
                logger.info(f"Ollama model '{model_name}' is available")
            else:
                logger.warning(f"Ollama model '{model_name}' is not available. Available models: {available_models}")
            return is_available
        else:
            logger.warning(f"Could not check Ollama models, status code: {response.status_code}")
            return False
    except requests.exceptions.RequestException as e:
        logger.warning(f"Could not connect to Ollama to check models: {e}")
        return False
    except Exception as e:
        logger.warning(f"Error checking Ollama model availability: {e}")
        return False
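

# A minimal usage sketch (an addition for illustration, not part of the
# original module). It shows how a caller might pair check_ollama_model_exists()
# with the OllamaModelNotFoundError defined above to fail fast before building
# an embedder. The model name "nomic-embed-text" is a placeholder, not a value
# taken from this file.
def _example_require_ollama_model(model_name: str = "nomic-embed-text") -> None:
    # Raise a descriptive error up front instead of letting the embedder
    # fail later with a less obvious message.
    if not check_ollama_model_exists(model_name):
        raise OllamaModelNotFoundError(
            f"Ollama model '{model_name}' was not found; pull it with `ollama pull {model_name}`."
        )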


class OllamaDocumentProcessor(DataComponent):
    """
    Process documents for Ollama embeddings one document at a time.

    The adalflow Ollama client does not support batch embedding, so each
    document is embedded individually.
    """
    def __init__(self, embedder: adal.Embedder) -> None:
        super().__init__()
        self.embedder = embedder

    def __call__(self, documents: Sequence[Document]) -> Sequence[Document]:
        output = deepcopy(documents)
        logger.info(f"Processing {len(output)} documents individually for Ollama embeddings")

        successful_docs = []
        expected_embedding_size = None

        for i, doc in enumerate(tqdm(output, desc="Processing documents for Ollama embeddings")):
            # meta_data may be absent or None, so guard both cases
            file_path = (getattr(doc, 'meta_data', None) or {}).get('file_path', f'document_{i}')
            try:
                # Get the embedding for a single document
                result = self.embedder(input=doc.text)
                if result.data and len(result.data) > 0:
                    embedding = result.data[0].embedding
                    # Validate embedding size consistency across documents
                    if expected_embedding_size is None:
                        expected_embedding_size = len(embedding)
                        logger.info(f"Expected embedding size set to: {expected_embedding_size}")
                    elif len(embedding) != expected_embedding_size:
                        logger.warning(f"Document '{file_path}' has inconsistent embedding size {len(embedding)} != {expected_embedding_size}, skipping")
                        continue
                    # Assign the embedding to the document
                    output[i].vector = embedding
                    successful_docs.append(output[i])
                else:
                    logger.warning(f"Failed to get embedding for document '{file_path}', skipping")
            except Exception as e:
                logger.error(f"Error processing document '{file_path}': {e}, skipping")

        logger.info(f"Successfully processed {len(successful_docs)}/{len(output)} documents with consistent embeddings")
        return successful_docs
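

# A minimal end-to-end sketch (an addition for illustration, not part of the
# original module). It wires OllamaDocumentProcessor to an adalflow Embedder.
# The import path adalflow.components.model_client.OllamaClient and the model
# name "nomic-embed-text" are assumptions for illustration, not values taken
# from this file.
def _example_embed_documents() -> Sequence[Document]:
    from adalflow.components.model_client import OllamaClient  # assumed import path

    embedder = adal.Embedder(
        model_client=OllamaClient(),
        model_kwargs={"model": "nomic-embed-text"},  # placeholder model name
    )
    processor = OllamaDocumentProcessor(embedder=embedder)
    docs = [
        Document(text="def hello():\n    return 'world'", meta_data={"file_path": "hello.py"}),
        Document(text="print('goodbye')", meta_data={"file_path": "goodbye.py"}),
    ]
    # Returns only the documents that received a consistent embedding
    return processor(docs)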