mirror of
https://github.com/arkorty/B.Tech-Project-III.git
synced 2026-04-19 20:51:49 +00:00
init
This commit is contained in:
67
thirdeye/backend/db/embeddings.py
Normal file
67
thirdeye/backend/db/embeddings.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Embedding provider with Cohere primary and local fallback."""
|
||||
import cohere
|
||||
import logging
|
||||
from backend.config import COHERE_API_KEY
|
||||
|
||||
logger = logging.getLogger("thirdeye.embeddings")
|
||||
|
||||
_cohere_client = None
|
||||
_local_model = None
|
||||
|
||||
def _get_cohere():
    """Return the cached Cohere client, creating it on first use.

    Returns None when no API key is configured, signalling callers to
    use the local fallback model instead.
    """
    global _cohere_client
    if _cohere_client is not None:
        return _cohere_client
    if COHERE_API_KEY:
        _cohere_client = cohere.Client(COHERE_API_KEY)
    return _cohere_client
|
||||
|
||||
def _get_local_model():
    """Return the cached local embedding model, loading it on first use.

    The import is deferred so sentence-transformers is only pulled in
    when the fallback path is actually exercised.
    """
    global _local_model
    if _local_model is not None:
        return _local_model
    from sentence_transformers import SentenceTransformer

    _local_model = SentenceTransformer("all-MiniLM-L6-v2")
    logger.info("Loaded local embedding model: all-MiniLM-L6-v2")
    return _local_model
|
||||
|
||||
|
||||
def embed_texts(texts: list[str]) -> list[list[float]]:
    """Embed a list of texts. Tries Cohere first, falls back to local model.

    Args:
        texts: Documents to embed; an empty list short-circuits to ``[]``.

    Returns:
        One embedding vector (list of floats) per input text, in order.
    """
    if not texts:
        return []

    # Try Cohere first; _get_cohere() returns None when no API key is set.
    client = _get_cohere()
    if client:
        try:
            response = client.embed(
                texts=texts,
                model="embed-english-v3.0",
                # "search_document" marks these as corpus documents
                # (queries use "search_query" — see embed_query).
                input_type="search_document",
            )
            # Lazy %-args: the message is only formatted if the level is enabled.
            logger.info("Cohere embedded %d texts", len(texts))
            return [list(e) for e in response.embeddings]
        except Exception as e:
            # Broad catch is deliberate: any API/network failure should
            # degrade gracefully to the local model rather than crash.
            logger.warning("Cohere embedding failed: %s, falling back to local", e)

    # Fallback to the local sentence-transformers model.
    model = _get_local_model()
    embeddings = model.encode(texts).tolist()
    logger.info("Local model embedded %d texts", len(texts))
    return embeddings
|
||||
|
||||
|
||||
def embed_query(text: str) -> list[float]:
    """Embed a single query text.

    Uses Cohere with the ``search_query`` input type when a client is
    available, falling back to the local model on any failure.

    Args:
        text: The query string to embed.

    Returns:
        The embedding vector for the query.
    """
    client = _get_cohere()
    if client:
        try:
            response = client.embed(
                texts=[text],
                model="embed-english-v3.0",
                input_type="search_query",
            )
            return list(response.embeddings[0])
        except Exception as e:
            # Log the failure instead of silently swallowing it, mirroring
            # the fallback behavior and messaging of embed_texts.
            logger.warning("Cohere query embedding failed: %s, falling back to local", e)

    model = _get_local_model()
    return model.encode([text]).tolist()[0]
|
||||
Reference in New Issue
Block a user