"""Embedding provider with Cohere primary and local fallback.""" import cohere import logging from backend.config import COHERE_API_KEY logger = logging.getLogger("thirdeye.embeddings") _cohere_client = None _local_model = None def _get_cohere(): global _cohere_client if _cohere_client is None and COHERE_API_KEY: _cohere_client = cohere.Client(COHERE_API_KEY) return _cohere_client def _get_local_model(): global _local_model if _local_model is None: from sentence_transformers import SentenceTransformer _local_model = SentenceTransformer("all-MiniLM-L6-v2") logger.info("Loaded local embedding model: all-MiniLM-L6-v2") return _local_model def embed_texts(texts: list[str]) -> list[list[float]]: """Embed a list of texts. Tries Cohere first, falls back to local model.""" if not texts: return [] # Try Cohere client = _get_cohere() if client: try: response = client.embed( texts=texts, model="embed-english-v3.0", input_type="search_document", ) logger.info(f"Cohere embedded {len(texts)} texts") return [list(e) for e in response.embeddings] except Exception as e: logger.warning(f"Cohere embedding failed: {e}, falling back to local") # Fallback to local model = _get_local_model() embeddings = model.encode(texts).tolist() logger.info(f"Local model embedded {len(texts)} texts") return embeddings def embed_query(text: str) -> list[float]: """Embed a single query text.""" client = _get_cohere() if client: try: response = client.embed( texts=[text], model="embed-english-v3.0", input_type="search_query", ) return list(response.embeddings[0]) except Exception: pass model = _get_local_model() return model.encode([text]).tolist()[0]