Files
B.Tech-Project-III/thirdeye/backend/db/chroma.py
2026-04-05 00:43:23 +05:30

280 lines
9.9 KiB
Python

"""ChromaDB setup and operations."""
import json
import uuid
import chromadb
import logging
from datetime import datetime
from backend.config import CHROMA_DB_PATH
from backend.db.embeddings import embed_texts, embed_query
# Module-level logger for this component.
logger = logging.getLogger("thirdeye.chroma")
# Initialize persistent client — data is stored on disk at CHROMA_DB_PATH,
# so it survives process restarts; shared by every function in this module.
_chroma_client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
def get_collection(group_id: str) -> chromadb.Collection:
    """Return the per-group collection, creating it on first access.

    The collection name is derived from the group id: dashes become
    underscores, an ``ll_`` prefix is added, and the result is clamped
    to ChromaDB's naming limit.
    """
    # ChromaDB collection names: 3-63 chars, alphanumeric + underscores.
    name = ("ll_" + group_id.replace("-", "_"))[:63]
    return _chroma_client.get_or_create_collection(name=name)
def set_group_name(group_id: str, name: str):
    """Persist the human-readable Telegram group name in the collection metadata."""
    # Nothing useful to store: missing name, or the name is just the id itself.
    if not name or name == group_id:
        return
    try:
        collection = get_collection(group_id)
        metadata = dict(collection.metadata or {})
        # Skip the write entirely when the stored name is already current.
        if metadata.get("group_name") == name:
            return
        metadata["group_name"] = name
        collection.modify(metadata=metadata)
    except Exception as e:
        # Best-effort: a failed metadata update must never break callers.
        logger.warning(f"set_group_name failed for {group_id}: {e}")
def get_group_names() -> dict[str, str]:
    """Return a mapping of group_id -> human-readable name (falls back to group_id).

    Only collections managed by this module (``ll_`` prefix) are considered.
    """
    result: dict[str, str] = {}
    for col in _chroma_client.list_collections():
        if not col.name.startswith("ll_"):
            continue
        # Strip only the LEADING prefix. The previous str.replace("ll_", "")
        # also removed "ll_" substrings occurring inside the group id; this
        # now matches the slicing used by query_signals_global.
        group_id = col.name.removeprefix("ll_").replace("_", "-")
        result[group_id] = (col.metadata or {}).get("group_name", group_id)
    return result
def store_signals(group_id: str, signals: list[dict]):
    """Store extracted signals in ChromaDB with embeddings.

    Each signal becomes one document ("type: summary", optionally with the
    raw quote appended) whose embedding drives semantic search; all
    structured fields are preserved in per-document metadata.
    """
    if not signals:
        return

    collection = get_collection(group_id)
    documents: list[str] = []
    metadatas: list[dict] = []
    ids: list[str] = []

    for sig in signals:
        text = f"{sig['type']}: {sig['summary']}"
        quote = sig.get('raw_quote')
        if quote:
            text = f"{text} | Quote: {quote}"
        documents.append(text)
        ids.append(sig.get("id", str(uuid.uuid4())))
        metadatas.append({
            "type": sig.get("type", "unknown"),
            "severity": sig.get("severity", "low"),
            "status": sig.get("status", "unknown"),
            "sentiment": sig.get("sentiment", "neutral"),
            "urgency": sig.get("urgency", "none"),
            # Lists are JSON-encoded: ChromaDB metadata values must be scalars.
            "entities": json.dumps(sig.get("entities", [])),
            "keywords": json.dumps(sig.get("keywords", [])),
            "raw_quote": sig.get("raw_quote", ""),
            "summary": sig.get("summary", ""),
            "timestamp": sig.get("timestamp", datetime.utcnow().isoformat()),
            "group_id": group_id,
            "lens": sig.get("lens", "unknown"),
            "meeting_id": sig.get("meeting_id", ""),
            # Voice attribution — preserved so /voicelog and /ask can cite the source
            "source": sig.get("source", ""),
            "speaker": sig.get("speaker", ""),
            "voice_file_id": sig.get("voice_file_id", ""),
            "voice_duration": int(sig.get("voice_duration", 0) or 0),
            "voice_language": sig.get("voice_language", ""),
            # Jira tracking fields (populated for jira_raised signals)
            "jira_key": sig.get("jira_key", ""),
            "jira_url": sig.get("jira_url", ""),
            "jira_summary": sig.get("jira_summary", ""),
            "jira_priority": sig.get("jira_priority", ""),
            "original_signal_id": sig.get("original_signal_id", ""),
        })

    # Generate embeddings
    embeddings = embed_texts(documents)
    collection.add(
        documents=documents,
        metadatas=metadatas,
        embeddings=embeddings,
        ids=ids,
    )
    logger.info(f"Stored {len(signals)} signals for group {group_id}")
def query_signals(group_id: str, query: str, n_results: int = 10, signal_type: str = None) -> list[dict]:
    """Query the knowledge base with natural language.

    Args:
        group_id: Telegram group whose collection is searched.
        query: Free-text query; embedded and matched against stored signals.
        n_results: Maximum number of hits to return.
        signal_type: Optional exact-match filter on the "type" metadata field.

    Returns:
        Dicts with "id", "document", "metadata" and "relevance_score"
        (1 - distance), best matches first. Empty list on failure or when
        the collection holds no signals.
    """
    collection = get_collection(group_id)
    # An empty collection can't satisfy any query; return early instead of
    # issuing a query with a fudged n_results (the old `count() or 1`).
    count = collection.count()
    if count == 0:
        return []
    query_embedding = embed_query(query)
    where_filter = {"type": signal_type} if signal_type else None
    try:
        results = collection.query(
            query_embeddings=[query_embedding],
            n_results=min(n_results, count),
            where=where_filter,
        )
    except Exception as e:
        logger.warning(f"Query failed: {e}")
        return []
    # Format results (query returns nested lists — one inner list per query).
    output = []
    if results and results["documents"]:
        for i, doc in enumerate(results["documents"][0]):
            meta = results["metadatas"][0][i] if results["metadatas"] else {}
            distance = results["distances"][0][i] if results["distances"] else None
            sig_id = results["ids"][0][i] if results.get("ids") else ""
            output.append({
                "id": sig_id,
                "document": doc,
                "metadata": meta,
                "relevance_score": 1 - (distance or 0),  # Convert distance to similarity
            })
    return output
def get_all_signals(group_id: str, signal_type: str = None) -> list[dict]:
    """Get all signals for a group (for pattern detection).

    Args:
        group_id: Telegram group whose collection is read.
        signal_type: Optional exact-match filter on the "type" metadata field.

    Returns:
        Dicts with "document", "metadata" and "id" for every stored signal
        (of the requested type, if given). Empty list for an empty collection.
    """
    collection = get_collection(group_id)
    count = collection.count()
    if count == 0:
        return []
    where_filter = {"type": signal_type} if signal_type else None
    filtered_server_side = True
    try:
        results = collection.get(where=where_filter, limit=count)
    except Exception as e:
        # Fall back to an unfiltered fetch, but remember to apply the type
        # filter in Python below — previously the fallback silently returned
        # signals of EVERY type to a caller that asked for one type.
        logger.warning(f"Filtered get failed for group {group_id}: {e}")
        results = collection.get(limit=count)
        filtered_server_side = False
    output = []
    if results and results["documents"]:
        for i, doc in enumerate(results["documents"]):
            meta = results["metadatas"][i] if results["metadatas"] else {}
            if signal_type and not filtered_server_side and meta.get("type") != signal_type:
                continue  # client-side filter for the fallback path
            output.append({"document": doc, "metadata": meta, "id": results["ids"][i]})
    return output
def get_group_ids() -> list[str]:
    """Get all group IDs that have collections.

    Strips only the LEADING "ll_" prefix before mapping underscores back to
    dashes — the previous str.replace("ll_", "") also removed "ll_"
    substrings inside the id. This now matches query_signals_global.
    """
    return [
        col.name.removeprefix("ll_").replace("_", "-")
        for col in _chroma_client.list_collections()
        if col.name.startswith("ll_")
    ]
def query_signals_global(query: str, n_results: int = 5, exclude_group_id: str = None) -> list[dict]:
    """
    Search across ALL group collections for a query.
    Used as a cross-group fallback when local search returns weak results.
    Each result is annotated with its source group_id.
    """
    query_embedding = embed_query(query)
    hits: list[dict] = []
    for entry in _chroma_client.list_collections():
        if not entry.name.startswith("ll_"):
            continue
        # Invert the get_collection() naming scheme to recover the group id.
        group_id = entry.name[len("ll_"):].replace("_", "-")
        if exclude_group_id and group_id == exclude_group_id:
            continue
        try:
            collection = _chroma_client.get_collection(entry.name)
            total = collection.count()
            if total == 0:
                continue
            results = collection.query(
                query_embeddings=[query_embedding],
                n_results=min(n_results, total),
            )
            if not (results and results["documents"]):
                continue
            for idx, doc in enumerate(results["documents"][0]):
                meta = results["metadatas"][0][idx] if results["metadatas"] else {}
                distance = results["distances"][0][idx] if results["distances"] else None
                hits.append({
                    "document": doc,
                    "metadata": meta,
                    "relevance_score": 1 - (distance or 0),
                    "source_group_id": group_id,
                })
        except Exception as e:
            # One broken collection must not abort the cross-group search.
            logger.warning(f"Global query failed for collection {entry.name}: {e}")
            continue
    # Merge hits from all collections: best matches first, capped at n_results.
    hits.sort(key=lambda h: h["relevance_score"], reverse=True)
    return hits[:n_results]
def mark_signal_as_raised(
    group_id: str,
    signal_id: str,
    jira_key: str,
    jira_url: str = "",
    jira_summary: str = "",
    jira_priority: str = "",
):
    """
    Tag a signal with its Jira ticket key so we never raise it twice.
    Adds a new signal of type 'jira_raised' linked to the original signal_id.

    Args:
        group_id: Group whose collection receives the tracking signal.
        signal_id: ID of the original signal the ticket was raised for.
        jira_key: Jira issue key (e.g. "PROJ-123").
        jira_url / jira_summary / jira_priority: Optional ticket details.
    """
    # uuid and datetime are module-level imports; the redundant
    # function-local re-imports have been removed.
    tracking_signal = {
        "id": str(uuid.uuid4()),
        "type": "jira_raised",
        "summary": jira_summary or f"Jira ticket {jira_key} raised for signal {signal_id}",
        "raw_quote": signal_id,  # original signal_id — used by get_raised_signal_ids
        "severity": "low",
        "status": "raised",
        "sentiment": "neutral",
        "urgency": "none",
        "entities": [jira_key],
        "keywords": ["jira", jira_key, "raised"],
        "timestamp": datetime.utcnow().isoformat(),
        "group_id": group_id,
        "lens": "jira",
        # Jira tracking fields
        "jira_key": jira_key,
        "jira_url": jira_url,
        "jira_summary": jira_summary,
        "jira_priority": jira_priority,
        "original_signal_id": signal_id,
    }
    store_signals(group_id, [tracking_signal])
def get_raised_signal_ids(group_id: str) -> set[str]:
    """
    Return the set of signal IDs that have already had Jira tickets raised.
    Used to prevent duplicates.

    Returns an empty set when the lookup fails (best-effort), but logs the
    failure instead of silently swallowing it as before. The try block is
    narrowed to the DB call so bugs in the parsing code surface normally.
    """
    collection = get_collection(group_id)
    try:
        results = collection.get(where={"type": "jira_raised"})
    except Exception as e:
        logger.warning(f"get_raised_signal_ids failed for {group_id}: {e}")
        return set()
    if not results or not results.get("metadatas"):
        return set()
    # raw_quote stores the original signal_id (see mark_signal_as_raised)
    return {
        meta["raw_quote"]
        for meta in results["metadatas"]
        if meta.get("raw_quote")
    }