Files
B.Tech-Project-III/thirdeye/backend/db/embeddings.py
2026-04-05 00:43:23 +05:30

68 lines
2.0 KiB
Python

"""Embedding provider with Cohere primary and local fallback."""
import cohere
import logging
from backend.config import COHERE_API_KEY
# Module-level logger; named after the package's embedding subsystem.
logger = logging.getLogger("thirdeye.embeddings")
# Lazily-initialized process-wide singletons. Both start as None and are
# populated on first use by the accessor functions below.
_cohere_client = None
_local_model = None
def _get_cohere():
    """Return the shared Cohere client, creating it lazily.

    Returns None when no API key is configured, signalling callers to
    use the local fallback model instead.
    """
    global _cohere_client
    if _cohere_client is not None:
        return _cohere_client
    if COHERE_API_KEY:
        _cohere_client = cohere.Client(COHERE_API_KEY)
    return _cohere_client
def _get_local_model():
    """Return the shared local SentenceTransformer, loading it on first call.

    The import is deferred so sentence-transformers is only required when
    the local fallback is actually used.
    """
    global _local_model
    if _local_model is not None:
        return _local_model
    from sentence_transformers import SentenceTransformer
    _local_model = SentenceTransformer("all-MiniLM-L6-v2")
    logger.info("Loaded local embedding model: all-MiniLM-L6-v2")
    return _local_model
def embed_texts(texts: list[str]) -> list[list[float]]:
    """Embed a list of texts. Tries Cohere first, falls back to local model.

    Args:
        texts: Documents to embed. An empty list short-circuits to [].

    Returns:
        One embedding vector (list of floats) per input text, in order.
    """
    if not texts:
        return []
    # Primary path: hosted Cohere embeddings (client is None without an API key).
    client = _get_cohere()
    if client:
        try:
            response = client.embed(
                texts=texts,
                model="embed-english-v3.0",
                input_type="search_document",
            )
            # Lazy %-style args so formatting is skipped when INFO is disabled.
            logger.info("Cohere embedded %d texts", len(texts))
            return [list(e) for e in response.embeddings]
        except Exception as e:
            logger.warning("Cohere embedding failed: %s, falling back to local", e)
    # Fallback path: local all-MiniLM-L6-v2 model (loaded lazily on first use).
    model = _get_local_model()
    embeddings = model.encode(texts).tolist()
    logger.info("Local model embedded %d texts", len(texts))
    return embeddings
def embed_query(text: str) -> list[float]:
    """Embed a single query text.

    Mirrors embed_texts(): Cohere first (with input_type="search_query",
    the asymmetric-search counterpart of "search_document"), local model
    as fallback.

    Args:
        text: The query string to embed.

    Returns:
        A single embedding vector.
    """
    client = _get_cohere()
    if client:
        try:
            response = client.embed(
                texts=[text],
                model="embed-english-v3.0",
                input_type="search_query",
            )
            return list(response.embeddings[0])
        except Exception as e:
            # Previously a silent `pass`; log for parity with embed_texts so
            # fallback usage is visible in production logs.
            logger.warning("Cohere query embedding failed: %s, falling back to local", e)
    model = _get_local_model()
    return model.encode([text]).tolist()[0]