mirror of
https://github.com/arkorty/B.Tech-Project-III.git
synced 2026-04-19 20:51:49 +00:00
280 lines
9.9 KiB
Python
280 lines
9.9 KiB
Python
"""ChromaDB setup and operations."""
|
|
import json
|
|
import uuid
|
|
import chromadb
|
|
import logging
|
|
from datetime import datetime
|
|
from backend.config import CHROMA_DB_PATH
|
|
from backend.db.embeddings import embed_texts, embed_query
|
|
|
|
# Module logger, registered under the "thirdeye.chroma" name.
logger = logging.getLogger("thirdeye.chroma")

# Initialize persistent client.
# Created once, at import time, and shared by every function in this module;
# it is backed by the directory configured as CHROMA_DB_PATH.
_chroma_client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
|
|
|
|
|
|
def get_collection(group_id: str) -> chromadb.Collection:
    """Return the collection that stores *group_id*'s signals, creating it if needed.

    The collection name is ``"ll_"`` followed by *group_id* with every ``-``
    replaced by ``_``, cut to at most 63 characters (ChromaDB's
    collection-name length limit).

    NOTE(review): the mapping is lossy — ``"a-b"`` and ``"a_b"`` share one
    collection, and long ids can collide after truncation. It is kept as-is
    because changing it would orphan already-persisted collections.
    """
    collection_name = "ll_" + "_".join(group_id.split("-"))
    return _chroma_client.get_or_create_collection(name=collection_name[:63])
|
|
|
|
|
|
def set_group_name(group_id: str, name: str):
    """Record the human-readable Telegram group name in the group's collection metadata.

    The name is stored under the ``group_name`` metadata key; all other
    metadata keys are kept. Nothing happens when *name* is empty, equal to
    *group_id*, or already stored. Any error is logged as a warning and not
    raised — this is best-effort bookkeeping.
    """
    should_skip = (not name) or name == group_id
    if should_skip:
        return
    try:
        collection = get_collection(group_id)
        current_meta = dict(collection.metadata or {})
        if current_meta.get("group_name") == name:
            return
        collection.modify(metadata={**current_meta, "group_name": name})
    except Exception as e:
        logger.warning(f"set_group_name failed for {group_id}: {e}")
|
|
|
|
|
|
def get_group_names() -> dict[str, str]:
    """Return a mapping of group_id -> human-readable name (falls back to group_id).

    Only collections whose name starts with ``"ll_"`` (the prefix used by
    ``get_collection``) are considered. The group id is recovered by stripping
    that prefix and turning ``_`` back into ``-``. The name is read from the
    collection's ``group_name`` metadata; when absent, the recovered id is used.

    NOTE(review): the ``_`` -> ``-`` reversal is lossy — an id that originally
    contained underscores comes back with hyphens instead.
    """
    prefix = "ll_"
    names: dict[str, str] = {}
    for col in _chroma_client.list_collections():
        if not col.name.startswith(prefix):
            continue
        # Strip only the leading prefix. str.replace("ll_", "") would also delete
        # "ll_" occurring inside the id (e.g. "ll_all_hands" -> "ahands").
        group_id = col.name[len(prefix):].replace("_", "-")
        names[group_id] = (col.metadata or {}).get("group_name", group_id)
    return names
|
|
|
|
|
|
def _signal_value(signal: dict, key: str, default=""):
    """Return ``signal[key]``, falling back to *default* when the key is missing or ``None``."""
    value = signal.get(key)
    return default if value is None else value


def _signal_document(signal: dict) -> str:
    """Build the text embedded for *signal*: ``"<type>: <summary>[ | Quote: <raw_quote>]"``."""
    doc_text = f"{_signal_value(signal, 'type', 'unknown')}: {_signal_value(signal, 'summary')}"
    raw_quote = signal.get("raw_quote")
    if raw_quote:
        doc_text += f" | Quote: {raw_quote}"
    return doc_text


def _signal_metadata(group_id: str, signal: dict) -> dict:
    """Flatten *signal* into a ChromaDB metadata dict containing only scalar values."""
    timestamp = signal.get("timestamp")
    if timestamp is None:
        # Only read the clock when the signal carries no timestamp of its own.
        timestamp = datetime.utcnow().isoformat()

    return {
        "type": _signal_value(signal, "type", "unknown"),
        "severity": _signal_value(signal, "severity", "low"),
        "status": _signal_value(signal, "status", "unknown"),
        "sentiment": _signal_value(signal, "sentiment", "neutral"),
        "urgency": _signal_value(signal, "urgency", "none"),
        # Lists are not valid metadata values, so they are stored as JSON strings.
        "entities": json.dumps(_signal_value(signal, "entities", [])),
        "keywords": json.dumps(_signal_value(signal, "keywords", [])),
        "raw_quote": _signal_value(signal, "raw_quote"),
        "summary": _signal_value(signal, "summary"),
        "timestamp": timestamp,
        "group_id": group_id,
        "lens": _signal_value(signal, "lens", "unknown"),
        "meeting_id": _signal_value(signal, "meeting_id"),
        # Voice attribution — preserved so /voicelog and /ask can cite the source
        "source": _signal_value(signal, "source"),
        "speaker": _signal_value(signal, "speaker"),
        "voice_file_id": _signal_value(signal, "voice_file_id"),
        "voice_duration": int(_signal_value(signal, "voice_duration", 0) or 0),
        "voice_language": _signal_value(signal, "voice_language"),
        # Jira tracking fields (populated for jira_raised signals)
        "jira_key": _signal_value(signal, "jira_key"),
        "jira_url": _signal_value(signal, "jira_url"),
        "jira_summary": _signal_value(signal, "jira_summary"),
        "jira_priority": _signal_value(signal, "jira_priority"),
        "original_signal_id": _signal_value(signal, "original_signal_id"),
    }


def store_signals(group_id: str, signals: list[dict]):
    """Store extracted signals in ChromaDB with embeddings.

    Each signal dict becomes one record in the group's collection:

    * document: ``"<type>: <summary>"``, plus ``" | Quote: <raw_quote>"`` when
      the signal has a non-empty ``raw_quote``. This text is what gets embedded
      (via ``embed_texts``).
    * metadata: the flat dict built by ``_signal_metadata``. A field that is
      missing or explicitly ``None`` gets the default listed there.
    * id: the signal's ``id`` when present and non-empty, otherwise a fresh UUID4.

    Args:
        group_id: Group whose collection receives the records.
        signals: Signal dicts to store; an empty list is a no-op.
    """
    if not signals:
        return

    collection = get_collection(group_id)
    documents = [_signal_document(signal) for signal in signals]
    metadatas = [_signal_metadata(group_id, signal) for signal in signals]
    # Honour a caller-supplied id, but never hand None/"" to ChromaDB as an id.
    ids = [signal.get("id") or str(uuid.uuid4()) for signal in signals]

    # Generate embeddings
    embeddings = embed_texts(documents)

    collection.add(
        documents=documents,
        metadatas=metadatas,
        embeddings=embeddings,
        ids=ids,
    )
    logger.info(f"Stored {len(signals)} signals for group {group_id}")
|
|
|
|
|
|
def query_signals(group_id: str, query: str, n_results: int = 10, signal_type: str = None) -> list[dict]:
    """Query the knowledge base with natural language.

    Args:
        group_id: Group whose collection is searched.
        query: Natural-language query, embedded with ``embed_query``.
        n_results: Maximum number of hits to return.
        signal_type: When given, only signals whose ``type`` metadata equals
            this value are searched.

    Returns:
        Up to *n_results* dicts with keys ``id``, ``document``, ``metadata`` and
        ``relevance_score`` (``1 - distance``; a missing distance counts as 0).
        An empty list when the collection is empty, *n_results* is not positive,
        or the query fails (the failure is logged as a warning).
    """
    collection = get_collection(group_id)
    count = collection.count()
    # Nothing to search: skip the embedding call and the query altogether.
    if count == 0 or n_results <= 0:
        return []

    query_embedding = embed_query(query)
    where_filter = {"type": signal_type} if signal_type else None

    try:
        results = collection.query(
            query_embeddings=[query_embedding],
            n_results=min(n_results, count),
            where=where_filter,
        )
    except Exception as e:
        logger.warning(f"Query failed: {e}")
        return []

    # Format results (a single query embedding -> read row 0 of each field).
    output = []
    if results and results["documents"]:
        for i, doc in enumerate(results["documents"][0]):
            meta = results["metadatas"][0][i] if results["metadatas"] else {}
            distance = results["distances"][0][i] if results["distances"] else None
            sig_id = results["ids"][0][i] if results.get("ids") else ""
            output.append({
                "id": sig_id,
                "document": doc,
                "metadata": meta,
                "relevance_score": 1 - (distance or 0),  # Convert distance to similarity
            })

    return output
|
|
|
|
|
|
def get_all_signals(group_id: str, signal_type: str = None) -> list[dict]:
    """Get all signals for a group (for pattern detection).

    Args:
        group_id: Group whose collection is read.
        signal_type: When given, only signals whose ``type`` metadata equals
            this value are returned.

    Returns:
        A list of ``{"document", "metadata", "id"}`` dicts; empty when the
        collection holds no records.

    If the filtered read fails, the whole collection is read instead and the
    ``signal_type`` filter is applied client-side, so the result never
    contains signals of other types.
    """
    collection = get_collection(group_id)
    count = collection.count()
    if count == 0:
        return []

    where_filter = {"type": signal_type} if signal_type else None

    try:
        results = collection.get(where=where_filter, limit=count)
        needs_local_filter = False
    except Exception as e:
        # Best-effort fallback: fetch everything, but keep honouring
        # signal_type instead of silently returning every type.
        logger.warning(f"get_all_signals: filtered get failed for {group_id}: {e}")
        results = collection.get(limit=count)
        needs_local_filter = bool(signal_type)

    output = []
    if results and results["documents"]:
        metadatas = results["metadatas"]
        for i, doc in enumerate(results["documents"]):
            meta = metadatas[i] if metadatas else {}
            if needs_local_filter and (meta or {}).get("type") != signal_type:
                continue
            output.append({"document": doc, "metadata": meta, "id": results["ids"][i]})

    return output
|
|
|
|
|
|
def get_group_ids() -> list[str]:
    """Get all group IDs that have collections.

    Only collections whose name starts with ``"ll_"`` (the prefix used by
    ``get_collection``) are considered. The id is recovered by stripping that
    prefix and turning ``_`` back into ``-``.

    NOTE(review): the ``_`` -> ``-`` reversal is lossy — an id that originally
    contained underscores comes back with hyphens instead.
    """
    prefix = "ll_"
    # Slice off only the leading prefix. str.replace("ll_", "") would also remove
    # "ll_" inside the id (e.g. "ll_all_hands" -> "ahands").
    return [
        c.name[len(prefix):].replace("_", "-")
        for c in _chroma_client.list_collections()
        if c.name.startswith(prefix)
    ]
|
|
|
|
|
|
def query_signals_global(query: str, n_results: int = 5, exclude_group_id: str = None) -> list[dict]:
    """
    Search across ALL group collections for a query.

    Used as a cross-group fallback when local search returns weak results.
    Each result is annotated with its source group_id.

    Every ``ll_``-prefixed collection contributes at most *n_results* hits.
    The merged hits are sorted by ``relevance_score`` (``1 - distance``) and the
    best *n_results* are returned. Empty collections are skipped; collections
    whose query fails are logged as a warning and skipped.

    Args:
        query: Natural-language query, embedded once and reused for every collection.
        n_results: Maximum number of hits, per collection and overall.
        exclude_group_id: Group whose collection is not searched.

    Returns:
        Dicts with keys ``id``, ``document``, ``metadata``, ``relevance_score``
        and ``source_group_id``. ``source_group_id`` is recovered from the
        collection name, so ``_`` in the original id comes back as ``-``.
        Empty when *n_results* is not positive.
    """
    prefix = "ll_"
    if n_results <= 0:
        return []

    # Collection name of the excluded group, built the same way get_collection
    # builds it. Matching on the name (not only on the lossy reconstructed id)
    # also excludes ids containing "_" or truncated to the 63-character limit.
    excluded_name = f"{prefix}{exclude_group_id.replace('-', '_')}"[:63] if exclude_group_id else None

    collections = _chroma_client.list_collections()
    query_embedding = embed_query(query)
    all_results = []

    for col_meta in collections:
        name = col_meta.name
        if not name.startswith(prefix):
            continue

        # Derive group_id from collection name (strip only the leading prefix)
        group_id = name[len(prefix):].replace("_", "-")

        if exclude_group_id and (name == excluded_name or group_id == exclude_group_id):
            continue

        try:
            col = _chroma_client.get_collection(name)
            count = col.count()
            if count == 0:
                continue

            results = col.query(
                query_embeddings=[query_embedding],
                n_results=min(n_results, count),
            )

            if results and results["documents"]:
                for i, doc in enumerate(results["documents"][0]):
                    meta = results["metadatas"][0][i] if results["metadatas"] else {}
                    distance = results["distances"][0][i] if results["distances"] else None
                    all_results.append({
                        "id": results["ids"][0][i] if results.get("ids") else "",
                        "document": doc,
                        "metadata": meta,
                        "relevance_score": 1 - (distance or 0),
                        "source_group_id": group_id,
                    })
        except Exception as e:
            logger.warning(f"Global query failed for collection {name}: {e}")

    # Sort by relevance and return top n_results
    all_results.sort(key=lambda x: x["relevance_score"], reverse=True)
    return all_results[:n_results]
|
|
|
|
def mark_signal_as_raised(
    group_id: str,
    signal_id: str,
    jira_key: str,
    jira_url: str = "",
    jira_summary: str = "",
    jira_priority: str = "",
):
    """
    Tag a signal with its Jira ticket key so we never raise it twice.

    Adds a new signal of type 'jira_raised' linked to the original signal_id;
    the original signal itself is not modified. The link is written twice: to
    the tracking signal's ``raw_quote`` (read by get_raised_signal_ids) and to
    its ``original_signal_id`` metadata field.

    Args:
        group_id: Group whose collection receives the tracking signal.
        signal_id: Id of the signal the Jira ticket was raised for.
        jira_key: Jira issue key; also recorded as an entity and a keyword.
        jira_url: Optional ticket URL.
        jira_summary: Optional ticket summary; also used as the tracking
            signal's summary when non-empty.
        jira_priority: Optional ticket priority.
    """
    tracking_signal = {
        "id": str(uuid.uuid4()),
        "type": "jira_raised",
        "summary": jira_summary or f"Jira ticket {jira_key} raised for signal {signal_id}",
        "raw_quote": signal_id,  # original signal_id — used by get_raised_signal_ids
        "severity": "low",
        "status": "raised",
        "sentiment": "neutral",
        "urgency": "none",
        "entities": [jira_key],
        "keywords": ["jira", jira_key, "raised"],
        "timestamp": datetime.utcnow().isoformat(),
        "group_id": group_id,
        "lens": "jira",
        # Jira tracking fields
        "jira_key": jira_key,
        "jira_url": jira_url,
        "jira_summary": jira_summary,
        "jira_priority": jira_priority,
        "original_signal_id": signal_id,
    }
    store_signals(group_id, [tracking_signal])
|
|
|
|
|
|
def get_raised_signal_ids(group_id: str) -> set[str]:
    """
    Return the set of signal IDs that have already had Jira tickets raised.

    Used to prevent duplicates. Every ``type == "jira_raised"`` record in the
    group's collection contributes its ``original_signal_id`` and its
    ``raw_quote`` metadata values (mark_signal_as_raised writes the original
    signal id to both); empty values are ignored.

    If the lookup fails, the error is logged and an empty set is returned.
    Callers then see nothing as already raised, which is why the failure is
    not swallowed silently.
    """
    collection = get_collection(group_id)
    try:
        results = collection.get(where={"type": "jira_raised"})
    except Exception as e:
        logger.warning(f"get_raised_signal_ids failed for {group_id}: {e}")
        return set()

    raised_ids = set()
    for meta in (results or {}).get("metadatas") or []:
        record = meta or {}  # tolerate a record stored without metadata
        for key in ("original_signal_id", "raw_quote"):
            original_id = record.get(key)
            if original_id:
                raised_ids.add(original_id)
    return raised_ids
|