mirror of
https://github.com/arkorty/B.Tech-Project-III.git
synced 2026-04-19 20:51:49 +00:00
68 lines
2.0 KiB
Python
68 lines
2.0 KiB
Python
"""Embedding provider with Cohere primary and local fallback."""
|
|
import cohere
|
|
import logging
|
|
from backend.config import COHERE_API_KEY
|
|
|
|
# Module logger; child of the "thirdeye" logger hierarchy.
logger = logging.getLogger("thirdeye.embeddings")

# Lazily-initialized singletons, created on first use by _get_cohere()
# and _get_local_model() respectively (avoids paying client/model
# startup cost at import time).
_cohere_client = None
_local_model = None
|
|
|
|
def _get_cohere():
    """Return the cached Cohere client, creating it on first call.

    Returns None when no COHERE_API_KEY is configured, signalling the
    caller to use the local fallback instead.
    """
    global _cohere_client
    if _cohere_client is not None:
        return _cohere_client
    if COHERE_API_KEY:
        _cohere_client = cohere.Client(COHERE_API_KEY)
    return _cohere_client
|
|
|
|
def _get_local_model():
    """Return the cached local embedding model, loading it on first call.

    The sentence-transformers import is deferred so the heavy dependency
    is only paid when the fallback path is actually needed.
    """
    global _local_model
    if _local_model is not None:
        return _local_model
    from sentence_transformers import SentenceTransformer
    _local_model = SentenceTransformer("all-MiniLM-L6-v2")
    logger.info("Loaded local embedding model: all-MiniLM-L6-v2")
    return _local_model
|
|
|
|
|
|
def embed_texts(texts: list[str]) -> list[list[float]]:
    """Embed a list of texts. Tries Cohere first, falls back to local model.

    Args:
        texts: Documents to embed. An empty list short-circuits to [].

    Returns:
        One embedding vector (list of floats) per input text, in order.
    """
    if not texts:
        return []

    # Try Cohere first when a client is available (API key configured).
    client = _get_cohere()
    if client:
        try:
            response = client.embed(
                texts=texts,
                model="embed-english-v3.0",
                # Cohere v3 embed models require input_type; documents use
                # "search_document" (queries use "search_query" elsewhere).
                input_type="search_document",
            )
            # Lazy %-style args: formatting is skipped if INFO is disabled.
            logger.info("Cohere embedded %d texts", len(texts))
            return [list(e) for e in response.embeddings]
        except Exception as e:
            # Broad catch is deliberate: any API/network/auth error should
            # degrade to the local model rather than fail the request.
            logger.warning("Cohere embedding failed: %s, falling back to local", e)

    # Fallback to the local sentence-transformers model.
    model = _get_local_model()
    embeddings = model.encode(texts).tolist()
    logger.info("Local model embedded %d texts", len(texts))
    return embeddings
|
|
|
|
|
|
def embed_query(text: str) -> list[float]:
    """Embed a single query text.

    Mirrors embed_texts() but uses Cohere's "search_query" input_type so
    query vectors remain compatible with "search_document" vectors.

    Args:
        text: The query string to embed.

    Returns:
        A single embedding vector.
    """
    client = _get_cohere()
    if client:
        try:
            response = client.embed(
                texts=[text],
                model="embed-english-v3.0",
                input_type="search_query",
            )
            return list(response.embeddings[0])
        except Exception as e:
            # Fix: this previously swallowed the error silently, hiding why
            # the fallback fired. Log it, consistent with embed_texts().
            logger.warning("Cohere query embedding failed: %s, falling back to local", e)

    # Fallback to the local model.
    model = _get_local_model()
    return model.encode([text]).tolist()[0]
|