"""ChromaDB setup and operations.""" import json import uuid import chromadb import logging from datetime import datetime from backend.config import CHROMA_DB_PATH from backend.db.embeddings import embed_texts, embed_query logger = logging.getLogger("thirdeye.chroma") # Initialize persistent client _chroma_client = chromadb.PersistentClient(path=CHROMA_DB_PATH) def get_collection(group_id: str) -> chromadb.Collection: """Get or create a collection for a specific group.""" safe_name = f"ll_{group_id.replace('-', '_')}" # ChromaDB collection names: 3-63 chars, alphanumeric + underscores safe_name = safe_name[:63] return _chroma_client.get_or_create_collection(name=safe_name) def set_group_name(group_id: str, name: str): """Persist the human-readable Telegram group name in the collection metadata.""" if not name or name == group_id: return try: col = get_collection(group_id) existing = dict(col.metadata or {}) if existing.get("group_name") != name: existing["group_name"] = name col.modify(metadata=existing) except Exception as e: logger.warning(f"set_group_name failed for {group_id}: {e}") def get_group_names() -> dict[str, str]: """Return a mapping of group_id -> human-readable name (falls back to group_id).""" result = {} for col in _chroma_client.list_collections(): if not col.name.startswith("ll_"): continue group_id = col.name.replace("ll_", "").replace("_", "-") name = (col.metadata or {}).get("group_name", group_id) result[group_id] = name return result def store_signals(group_id: str, signals: list[dict]): """Store extracted signals in ChromaDB with embeddings.""" if not signals: return collection = get_collection(group_id) documents = [] metadatas = [] ids = [] for signal in signals: doc_text = f"{signal['type']}: {signal['summary']}" if signal.get('raw_quote'): doc_text += f" | Quote: {signal['raw_quote']}" documents.append(doc_text) metadatas.append({ "type": signal.get("type", "unknown"), "severity": signal.get("severity", "low"), "status": signal.get("status", "unknown"), "sentiment": signal.get("sentiment", "neutral"), "urgency": signal.get("urgency", "none"), "entities": json.dumps(signal.get("entities", [])), "keywords": json.dumps(signal.get("keywords", [])), "raw_quote": signal.get("raw_quote", ""), "summary": signal.get("summary", ""), "timestamp": signal.get("timestamp", datetime.utcnow().isoformat()), "group_id": group_id, "lens": signal.get("lens", "unknown"), "meeting_id": signal.get("meeting_id", ""), # Voice attribution — preserved so /voicelog and /ask can cite the source "source": signal.get("source", ""), "speaker": signal.get("speaker", ""), "voice_file_id": signal.get("voice_file_id", ""), "voice_duration": int(signal.get("voice_duration", 0) or 0), "voice_language": signal.get("voice_language", ""), # Jira tracking fields (populated for jira_raised signals) "jira_key": signal.get("jira_key", ""), "jira_url": signal.get("jira_url", ""), "jira_summary": signal.get("jira_summary", ""), "jira_priority": signal.get("jira_priority", ""), "original_signal_id": signal.get("original_signal_id", ""), }) ids.append(signal.get("id", str(uuid.uuid4()))) # Generate embeddings embeddings = embed_texts(documents) collection.add( documents=documents, metadatas=metadatas, embeddings=embeddings, ids=ids, ) logger.info(f"Stored {len(signals)} signals for group {group_id}") def query_signals(group_id: str, query: str, n_results: int = 10, signal_type: str = None) -> list[dict]: """Query the knowledge base with natural language.""" collection = get_collection(group_id) query_embedding = embed_query(query) where_filter = None if signal_type: where_filter = {"type": signal_type} try: results = collection.query( query_embeddings=[query_embedding], n_results=min(n_results, collection.count() or 1), where=where_filter, ) except Exception as e: logger.warning(f"Query failed: {e}") return [] # Format results output = [] if results and results["documents"]: for i, doc in enumerate(results["documents"][0]): meta = results["metadatas"][0][i] if results["metadatas"] else {} distance = results["distances"][0][i] if results["distances"] else None sig_id = results["ids"][0][i] if results.get("ids") else "" output.append({ "id": sig_id, "document": doc, "metadata": meta, "relevance_score": 1 - (distance or 0), # Convert distance to similarity }) return output def get_all_signals(group_id: str, signal_type: str = None) -> list[dict]: """Get all signals for a group (for pattern detection).""" collection = get_collection(group_id) count = collection.count() if count == 0: return [] where_filter = {"type": signal_type} if signal_type else None try: results = collection.get(where=where_filter, limit=count) except Exception: results = collection.get(limit=count) output = [] if results and results["documents"]: for i, doc in enumerate(results["documents"]): meta = results["metadatas"][i] if results["metadatas"] else {} output.append({"document": doc, "metadata": meta, "id": results["ids"][i]}) return output def get_group_ids() -> list[str]: """Get all group IDs that have collections.""" collections = _chroma_client.list_collections() return [c.name.replace("ll_", "").replace("_", "-") for c in collections if c.name.startswith("ll_")] def query_signals_global(query: str, n_results: int = 5, exclude_group_id: str = None) -> list[dict]: """ Search across ALL group collections for a query. Used as a cross-group fallback when local search returns weak results. Each result is annotated with its source group_id. """ collections = _chroma_client.list_collections() query_embedding = embed_query(query) all_results = [] for col_meta in collections: if not col_meta.name.startswith("ll_"): continue # Derive group_id from collection name raw = col_meta.name[len("ll_"):] group_id = raw.replace("_", "-") if exclude_group_id and group_id == exclude_group_id: continue try: col = _chroma_client.get_collection(col_meta.name) count = col.count() if count == 0: continue results = col.query( query_embeddings=[query_embedding], n_results=min(n_results, count), ) if results and results["documents"]: for i, doc in enumerate(results["documents"][0]): meta = results["metadatas"][0][i] if results["metadatas"] else {} distance = results["distances"][0][i] if results["distances"] else None all_results.append({ "document": doc, "metadata": meta, "relevance_score": 1 - (distance or 0), "source_group_id": group_id, }) except Exception as e: logger.warning(f"Global query failed for collection {col_meta.name}: {e}") continue # Sort by relevance and return top n_results all_results.sort(key=lambda x: x["relevance_score"], reverse=True) return all_results[:n_results] def mark_signal_as_raised( group_id: str, signal_id: str, jira_key: str, jira_url: str = "", jira_summary: str = "", jira_priority: str = "", ): """ Tag a signal with its Jira ticket key so we never raise it twice. Adds a new signal of type 'jira_raised' linked to the original signal_id. """ import uuid from datetime import datetime tracking_signal = { "id": str(uuid.uuid4()), "type": "jira_raised", "summary": jira_summary or f"Jira ticket {jira_key} raised for signal {signal_id}", "raw_quote": signal_id, # original signal_id — used by get_raised_signal_ids "severity": "low", "status": "raised", "sentiment": "neutral", "urgency": "none", "entities": [jira_key], "keywords": ["jira", jira_key, "raised"], "timestamp": datetime.utcnow().isoformat(), "group_id": group_id, "lens": "jira", # Jira tracking fields "jira_key": jira_key, "jira_url": jira_url, "jira_summary": jira_summary, "jira_priority": jira_priority, "original_signal_id": signal_id, } store_signals(group_id, [tracking_signal]) def get_raised_signal_ids(group_id: str) -> set[str]: """ Return the set of signal IDs that have already had Jira tickets raised. Used to prevent duplicates. """ collection = get_collection(group_id) try: results = collection.get(where={"type": "jira_raised"}) # raw_quote stores the original signal_id raised_ids = set() if results and results.get("metadatas"): for meta in results["metadatas"]: original_id = meta.get("raw_quote") # signal_id stored in raw_quote field if original_id: raised_ids.add(original_id) return raised_ids except Exception: return set()