from typing import Sequence, List
from copy import deepcopy
from tqdm import tqdm
import logging
import adalflow as adal
from adalflow.core.types import Document
from adalflow.core.component import DataComponent
import requests
import os

# Configure logging
from api.logging_config import setup_logging

setup_logging()
logger = logging.getLogger(__name__)


class OllamaModelNotFoundError(Exception):
    """Custom exception for when Ollama model is not found"""
    pass


def check_ollama_model_exists(model_name: str, ollama_host: str = None) -> bool:
    """
    Check if an Ollama model exists before attempting to use it.

    Args:
        model_name: Name of the model to check
        ollama_host: Ollama host URL; defaults to the OLLAMA_HOST environment
            variable or http://localhost:11434

    Returns:
        bool: True if model exists, False otherwise
    """
    if ollama_host is None:
        ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434")

    try:
        # Strip a trailing /api from the host; the request path below adds it back
        if ollama_host.endswith('/api'):
            ollama_host = ollama_host[:-4]

        response = requests.get(f"{ollama_host}/api/tags", timeout=5)
        if response.status_code == 200:
            models_data = response.json()
            available_models = [model.get('name', '').split(':')[0] for model in models_data.get('models', [])]
            model_base_name = model_name.split(':')[0]  # Remove tag if present

            is_available = model_base_name in available_models
            if is_available:
                logger.info(f"Ollama model '{model_name}' is available")
            else:
                logger.warning(f"Ollama model '{model_name}' is not available. Available models: {available_models}")
            return is_available
        else:
            logger.warning(f"Could not check Ollama models, status code: {response.status_code}")
            return False
    except requests.exceptions.RequestException as e:
        logger.warning(f"Could not connect to Ollama to check models: {e}")
        return False
    except Exception as e:
        logger.warning(f"Error checking Ollama model availability: {e}")
        return False


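# Usage sketch (illustrative; "llama3" below is an assumed model name, not one
# this module requires). The check honors the OLLAMA_HOST environment variable
# and compares base names only, so any installed tag of the model counts as a
# match:
#
#     if not check_ollama_model_exists("llama3:8b"):
#         raise OllamaModelNotFoundError(
#             "Model 'llama3:8b' not found locally; pull it with `ollama pull llama3:8b`."
#         )

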
class OllamaDocumentProcessor(DataComponent):
    """
    Process documents for Ollama embeddings by processing one document at a time.
    Adalflow Ollama Client does not support batch embedding, so we need to process each document individually.
    """
    def __init__(self, embedder: adal.Embedder) -> None:
        super().__init__()
        self.embedder = embedder

    def __call__(self, documents: Sequence[Document]) -> Sequence[Document]:
        output = deepcopy(documents)
        logger.info(f"Processing {len(output)} documents individually for Ollama embeddings")

        successful_docs = []
        expected_embedding_size = None

        for i, doc in enumerate(tqdm(output, desc="Processing documents for Ollama embeddings")):
            try:
                # Get embedding for a single document
                result = self.embedder(input=doc.text)
                if result.data and len(result.data) > 0:
                    embedding = result.data[0].embedding

                    # Validate embedding size consistency
                    if expected_embedding_size is None:
                        expected_embedding_size = len(embedding)
                        logger.info(f"Expected embedding size set to: {expected_embedding_size}")
                    elif len(embedding) != expected_embedding_size:
                        file_path = getattr(doc, 'meta_data', {}).get('file_path', f'document_{i}')
                        logger.warning(f"Document '{file_path}' has inconsistent embedding size {len(embedding)} != {expected_embedding_size}, skipping")
                        continue

                    # Assign the embedding to the document
                    output[i].vector = embedding
                    successful_docs.append(output[i])
                else:
                    file_path = getattr(doc, 'meta_data', {}).get('file_path', f'document_{i}')
                    logger.warning(f"Failed to get embedding for document '{file_path}', skipping")
            except Exception as e:
                file_path = getattr(doc, 'meta_data', {}).get('file_path', f'document_{i}')
                logger.error(f"Error processing document '{file_path}': {e}, skipping")

        logger.info(f"Successfully processed {len(successful_docs)}/{len(output)} documents with consistent embeddings")
        return successful_docs
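

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): wiring OllamaDocumentProcessor
# to an AdalFlow Embedder backed by Ollama. The OllamaClient import path, the
# "nomic-embed-text" model name, and the sample documents are illustrative
# assumptions; adjust them to the models actually pulled into your Ollama host.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from adalflow.components.model_client import OllamaClient

    embedding_model = "nomic-embed-text"
    if not check_ollama_model_exists(embedding_model):
        raise OllamaModelNotFoundError(
            f"Ollama model '{embedding_model}' not found; run `ollama pull {embedding_model}`."
        )

    # Build a single-input embedder; OllamaDocumentProcessor calls it once per document.
    embedder = adal.Embedder(
        model_client=OllamaClient(),
        model_kwargs={"model": embedding_model},
    )
    processor = OllamaDocumentProcessor(embedder=embedder)

    docs = [
        Document(text="def hello():\n    return 'world'", meta_data={"file_path": "hello.py"}),
        Document(text="# Example README content", meta_data={"file_path": "README.md"}),
    ]
    embedded = processor(docs)
    logger.info(f"Embedded {len(embedded)}/{len(docs)} documents")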