mirror of
https://github.com/arkorty/B.Tech-Project-III.git
synced 2026-04-19 20:51:49 +00:00
init
This commit is contained in:
279
thirdeye/backend/db/chroma.py
Normal file
279
thirdeye/backend/db/chroma.py
Normal file
@@ -0,0 +1,279 @@
|
||||
"""ChromaDB setup and operations."""
|
||||
import json
|
||||
import uuid
|
||||
import chromadb
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from backend.config import CHROMA_DB_PATH
|
||||
from backend.db.embeddings import embed_texts, embed_query
|
||||
|
||||
logger = logging.getLogger("thirdeye.chroma")
|
||||
|
||||
# Initialize persistent client
|
||||
_chroma_client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
|
||||
|
||||
|
||||
def get_collection(group_id: str) -> chromadb.Collection:
|
||||
"""Get or create a collection for a specific group."""
|
||||
safe_name = f"ll_{group_id.replace('-', '_')}"
|
||||
# ChromaDB collection names: 3-63 chars, alphanumeric + underscores
|
||||
safe_name = safe_name[:63]
|
||||
return _chroma_client.get_or_create_collection(name=safe_name)
|
||||
|
||||
|
||||
def set_group_name(group_id: str, name: str):
|
||||
"""Persist the human-readable Telegram group name in the collection metadata."""
|
||||
if not name or name == group_id:
|
||||
return
|
||||
try:
|
||||
col = get_collection(group_id)
|
||||
existing = dict(col.metadata or {})
|
||||
if existing.get("group_name") != name:
|
||||
existing["group_name"] = name
|
||||
col.modify(metadata=existing)
|
||||
except Exception as e:
|
||||
logger.warning(f"set_group_name failed for {group_id}: {e}")
|
||||
|
||||
|
||||
def get_group_names() -> dict[str, str]:
|
||||
"""Return a mapping of group_id -> human-readable name (falls back to group_id)."""
|
||||
result = {}
|
||||
for col in _chroma_client.list_collections():
|
||||
if not col.name.startswith("ll_"):
|
||||
continue
|
||||
group_id = col.name.replace("ll_", "").replace("_", "-")
|
||||
name = (col.metadata or {}).get("group_name", group_id)
|
||||
result[group_id] = name
|
||||
return result
|
||||
|
||||
|
||||
def store_signals(group_id: str, signals: list[dict]):
|
||||
"""Store extracted signals in ChromaDB with embeddings."""
|
||||
if not signals:
|
||||
return
|
||||
|
||||
collection = get_collection(group_id)
|
||||
documents = []
|
||||
metadatas = []
|
||||
ids = []
|
||||
|
||||
for signal in signals:
|
||||
doc_text = f"{signal['type']}: {signal['summary']}"
|
||||
if signal.get('raw_quote'):
|
||||
doc_text += f" | Quote: {signal['raw_quote']}"
|
||||
|
||||
documents.append(doc_text)
|
||||
metadatas.append({
|
||||
"type": signal.get("type", "unknown"),
|
||||
"severity": signal.get("severity", "low"),
|
||||
"status": signal.get("status", "unknown"),
|
||||
"sentiment": signal.get("sentiment", "neutral"),
|
||||
"urgency": signal.get("urgency", "none"),
|
||||
"entities": json.dumps(signal.get("entities", [])),
|
||||
"keywords": json.dumps(signal.get("keywords", [])),
|
||||
"raw_quote": signal.get("raw_quote", ""),
|
||||
"summary": signal.get("summary", ""),
|
||||
"timestamp": signal.get("timestamp", datetime.utcnow().isoformat()),
|
||||
"group_id": group_id,
|
||||
"lens": signal.get("lens", "unknown"),
|
||||
"meeting_id": signal.get("meeting_id", ""),
|
||||
# Voice attribution — preserved so /voicelog and /ask can cite the source
|
||||
"source": signal.get("source", ""),
|
||||
"speaker": signal.get("speaker", ""),
|
||||
"voice_file_id": signal.get("voice_file_id", ""),
|
||||
"voice_duration": int(signal.get("voice_duration", 0) or 0),
|
||||
"voice_language": signal.get("voice_language", ""),
|
||||
# Jira tracking fields (populated for jira_raised signals)
|
||||
"jira_key": signal.get("jira_key", ""),
|
||||
"jira_url": signal.get("jira_url", ""),
|
||||
"jira_summary": signal.get("jira_summary", ""),
|
||||
"jira_priority": signal.get("jira_priority", ""),
|
||||
"original_signal_id": signal.get("original_signal_id", ""),
|
||||
})
|
||||
ids.append(signal.get("id", str(uuid.uuid4())))
|
||||
|
||||
# Generate embeddings
|
||||
embeddings = embed_texts(documents)
|
||||
|
||||
collection.add(
|
||||
documents=documents,
|
||||
metadatas=metadatas,
|
||||
embeddings=embeddings,
|
||||
ids=ids,
|
||||
)
|
||||
logger.info(f"Stored {len(signals)} signals for group {group_id}")
|
||||
|
||||
|
||||
def query_signals(group_id: str, query: str, n_results: int = 10, signal_type: str = None) -> list[dict]:
|
||||
"""Query the knowledge base with natural language."""
|
||||
collection = get_collection(group_id)
|
||||
|
||||
query_embedding = embed_query(query)
|
||||
|
||||
where_filter = None
|
||||
if signal_type:
|
||||
where_filter = {"type": signal_type}
|
||||
|
||||
try:
|
||||
results = collection.query(
|
||||
query_embeddings=[query_embedding],
|
||||
n_results=min(n_results, collection.count() or 1),
|
||||
where=where_filter,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Query failed: {e}")
|
||||
return []
|
||||
|
||||
# Format results
|
||||
output = []
|
||||
if results and results["documents"]:
|
||||
for i, doc in enumerate(results["documents"][0]):
|
||||
meta = results["metadatas"][0][i] if results["metadatas"] else {}
|
||||
distance = results["distances"][0][i] if results["distances"] else None
|
||||
sig_id = results["ids"][0][i] if results.get("ids") else ""
|
||||
output.append({
|
||||
"id": sig_id,
|
||||
"document": doc,
|
||||
"metadata": meta,
|
||||
"relevance_score": 1 - (distance or 0), # Convert distance to similarity
|
||||
})
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def get_all_signals(group_id: str, signal_type: str = None) -> list[dict]:
|
||||
"""Get all signals for a group (for pattern detection)."""
|
||||
collection = get_collection(group_id)
|
||||
count = collection.count()
|
||||
if count == 0:
|
||||
return []
|
||||
|
||||
where_filter = {"type": signal_type} if signal_type else None
|
||||
|
||||
try:
|
||||
results = collection.get(where=where_filter, limit=count)
|
||||
except Exception:
|
||||
results = collection.get(limit=count)
|
||||
|
||||
output = []
|
||||
if results and results["documents"]:
|
||||
for i, doc in enumerate(results["documents"]):
|
||||
meta = results["metadatas"][i] if results["metadatas"] else {}
|
||||
output.append({"document": doc, "metadata": meta, "id": results["ids"][i]})
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def get_group_ids() -> list[str]:
|
||||
"""Get all group IDs that have collections."""
|
||||
collections = _chroma_client.list_collections()
|
||||
return [c.name.replace("ll_", "").replace("_", "-") for c in collections if c.name.startswith("ll_")]
|
||||
|
||||
|
||||
def query_signals_global(query: str, n_results: int = 5, exclude_group_id: str = None) -> list[dict]:
|
||||
"""
|
||||
Search across ALL group collections for a query.
|
||||
Used as a cross-group fallback when local search returns weak results.
|
||||
Each result is annotated with its source group_id.
|
||||
"""
|
||||
collections = _chroma_client.list_collections()
|
||||
query_embedding = embed_query(query)
|
||||
all_results = []
|
||||
|
||||
for col_meta in collections:
|
||||
if not col_meta.name.startswith("ll_"):
|
||||
continue
|
||||
|
||||
# Derive group_id from collection name
|
||||
raw = col_meta.name[len("ll_"):]
|
||||
group_id = raw.replace("_", "-")
|
||||
|
||||
if exclude_group_id and group_id == exclude_group_id:
|
||||
continue
|
||||
|
||||
try:
|
||||
col = _chroma_client.get_collection(col_meta.name)
|
||||
count = col.count()
|
||||
if count == 0:
|
||||
continue
|
||||
|
||||
results = col.query(
|
||||
query_embeddings=[query_embedding],
|
||||
n_results=min(n_results, count),
|
||||
)
|
||||
|
||||
if results and results["documents"]:
|
||||
for i, doc in enumerate(results["documents"][0]):
|
||||
meta = results["metadatas"][0][i] if results["metadatas"] else {}
|
||||
distance = results["distances"][0][i] if results["distances"] else None
|
||||
all_results.append({
|
||||
"document": doc,
|
||||
"metadata": meta,
|
||||
"relevance_score": 1 - (distance or 0),
|
||||
"source_group_id": group_id,
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"Global query failed for collection {col_meta.name}: {e}")
|
||||
continue
|
||||
|
||||
# Sort by relevance and return top n_results
|
||||
all_results.sort(key=lambda x: x["relevance_score"], reverse=True)
|
||||
return all_results[:n_results]
|
||||
|
||||
def mark_signal_as_raised(
|
||||
group_id: str,
|
||||
signal_id: str,
|
||||
jira_key: str,
|
||||
jira_url: str = "",
|
||||
jira_summary: str = "",
|
||||
jira_priority: str = "",
|
||||
):
|
||||
"""
|
||||
Tag a signal with its Jira ticket key so we never raise it twice.
|
||||
Adds a new signal of type 'jira_raised' linked to the original signal_id.
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
tracking_signal = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"type": "jira_raised",
|
||||
"summary": jira_summary or f"Jira ticket {jira_key} raised for signal {signal_id}",
|
||||
"raw_quote": signal_id, # original signal_id — used by get_raised_signal_ids
|
||||
"severity": "low",
|
||||
"status": "raised",
|
||||
"sentiment": "neutral",
|
||||
"urgency": "none",
|
||||
"entities": [jira_key],
|
||||
"keywords": ["jira", jira_key, "raised"],
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"group_id": group_id,
|
||||
"lens": "jira",
|
||||
# Jira tracking fields
|
||||
"jira_key": jira_key,
|
||||
"jira_url": jira_url,
|
||||
"jira_summary": jira_summary,
|
||||
"jira_priority": jira_priority,
|
||||
"original_signal_id": signal_id,
|
||||
}
|
||||
store_signals(group_id, [tracking_signal])
|
||||
|
||||
|
||||
def get_raised_signal_ids(group_id: str) -> set[str]:
|
||||
"""
|
||||
Return the set of signal IDs that have already had Jira tickets raised.
|
||||
Used to prevent duplicates.
|
||||
"""
|
||||
collection = get_collection(group_id)
|
||||
try:
|
||||
results = collection.get(where={"type": "jira_raised"})
|
||||
# raw_quote stores the original signal_id
|
||||
raised_ids = set()
|
||||
if results and results.get("metadatas"):
|
||||
for meta in results["metadatas"]:
|
||||
original_id = meta.get("raw_quote") # signal_id stored in raw_quote field
|
||||
if original_id:
|
||||
raised_ids.add(original_id)
|
||||
return raised_ids
|
||||
except Exception:
|
||||
return set()
|
||||
Reference in New Issue
Block a user