# B.Tech-Project-III/thirdeye/backend/providers.py
"""Multi-provider LLM router with automatic fallback on rate limits.
Groq pool: up to 3 API keys (GROQ_API_KEY, GROQ_API_KEY_2, GROQ_API_KEY_3) all running
llama-3.3-70b-versatile. Calls are round-robined across the pool so the per-key rate
limit is shared evenly. When a key is rate-limited the router falls through to the next
key in rotation, then to the rest of the fallback chain.
"""
import asyncio
import logging
from collections import defaultdict
from openai import AsyncOpenAI
from backend.config import (
GROQ_API_KEY, GROQ_API_KEY_2, GROQ_API_KEY_3,
CEREBRAS_API_KEY, SAMBANOVA_API_KEY,
OPENROUTER_API_KEY, GEMINI_API_KEY,
OLLAMA_BASE_URL, OLLAMA_ENABLED,
)
logger = logging.getLogger("thirdeye.providers")
# ── Client registry ──────────────────────────────────────────────────────────
# ── Client registry ──────────────────────────────────────────────────────────
_clients: dict[str, AsyncOpenAI] = {}


def _init_client(name: str, base_url: str, api_key: str | None):
    """Register an AsyncOpenAI client under *name* when *api_key* looks usable.

    Missing keys and placeholder values (5 characters or fewer) are treated as
    unset, so the provider is silently skipped.
    """
    if not api_key or len(api_key) <= 5:
        return
    _clients[name] = AsyncOpenAI(base_url=base_url, api_key=api_key)


# Ollama (local) — uses a dummy key; the OpenAI client requires a non-empty value
if OLLAMA_ENABLED:
    _clients["ollama"] = AsyncOpenAI(base_url=OLLAMA_BASE_URL, api_key="ollama")

# Groq pool: register each key under its own name
for _groq_name, _groq_key in (
    ("groq", GROQ_API_KEY),
    ("groq_2", GROQ_API_KEY_2),
    ("groq_3", GROQ_API_KEY_3),
):
    _init_client(_groq_name, "https://api.groq.com/openai/v1", _groq_key)

_init_client("cerebras", "https://api.cerebras.ai/v1", CEREBRAS_API_KEY)
_init_client("sambanova", "https://api.sambanova.ai/v1", SAMBANOVA_API_KEY)
_init_client("openrouter", "https://openrouter.ai/api/v1", OPENROUTER_API_KEY)
_init_client("google", "https://generativelanguage.googleapis.com/v1beta/openai/", GEMINI_API_KEY)

# Which provider names belong to the Groq pool (only keys that registered).
_GROQ_POOL = [pool_name for pool_name in ("groq", "groq_2", "groq_3") if pool_name in _clients]
_GROQ_MODEL = "llama-3.3-70b-versatile"

# Round-robin cursor per task_type (incremented after every call attempt on the pool)
_rr_cursor: dict[str, int] = defaultdict(int)
# ── Model registry ───────────────────────────────────────────────────────────
# Groq pool entries are expanded dynamically at call time so the cursor drives order.
# Use the sentinel string "groq_pool" to indicate "use all available Groq keys".
# Sentinel provider name meaning "expand into every available Groq key".
_GROQ_POOL_SENTINEL = "__groq_pool__"

# Ordered fallback chain per task type. Each entry is (provider_name, model_id)
# and candidates are attempted in list order; the Groq-pool sentinel is expanded
# at call time so the round-robin cursor decides which key goes first.
MODEL_REGISTRY: dict[str, list[tuple[str, str]]] = {
    "fast_small": [
        ("ollama", "llama3:8b"),
        ("groq", "llama-3.1-8b-instant"),
        ("cerebras", "llama-3.1-8b"),
        ("openrouter", "openai/gpt-oss-20b:free"),
    ],
    "fast_large": [
        (_GROQ_POOL_SENTINEL, _GROQ_MODEL),  # expands to all 3 Groq keys (round-robin)
        ("openrouter", "arcee-ai/trinity-large-preview:free"),
        ("openrouter", "meta-llama/llama-3.3-70b-instruct:free"),
        ("sambanova", "Meta-Llama-3.3-70B-Instruct"),
        # NOTE(review): spelled "llama3.1-8b" here but "llama-3.1-8b" under
        # fast_small — verify which id the Cerebras API actually accepts.
        ("cerebras", "llama3.1-8b"),
    ],
    "reasoning": [
        ("sambanova", "DeepSeek-R1-Distill-Llama-70B"),
        ("openrouter", "nvidia/nemotron-3-super-120b-a12b:free"),
        ("openrouter", "openai/gpt-oss-120b:free"),
    ],
    "agentic": [
        ("openrouter", "minimax/minimax-m2.5:free"),
        ("openrouter", "nvidia/nemotron-3-super-120b-a12b:free"),
        (_GROQ_POOL_SENTINEL, _GROQ_MODEL),
    ],
    # Last-resort chain appended to every other task type's candidates.
    "fallback": [
        ("openrouter", "openrouter/free"),
        ("google", "gemini-2.5-flash"),
    ],
}
def _expand_candidates(task_type: str) -> list[tuple[str, str]]:
    """Return the full candidate list for *task_type*.

    The task's registry entries are followed by the shared "fallback" chain,
    and every Groq-pool sentinel is expanded into ordered
    (provider_name, model_id) tuples starting from the current round-robin
    cursor position.

    Args:
        task_type: Key into MODEL_REGISTRY; unknown keys yield only the
            fallback chain.

    Returns:
        Ordered list of (provider_name, model_id) tuples to attempt.
    """
    raw = list(MODEL_REGISTRY.get(task_type, []))
    # Bug fix: the old code unconditionally appended MODEL_REGISTRY["fallback"],
    # so calling with task_type="fallback" tried every fallback provider twice.
    if task_type != "fallback":
        raw += MODEL_REGISTRY["fallback"]
    expanded: list[tuple[str, str]] = []
    for provider, model in raw:
        if provider != _GROQ_POOL_SENTINEL:
            expanded.append((provider, model))
            continue
        if not _GROQ_POOL:
            continue  # no Groq keys configured; drop the sentinel entry
        # Rotate: start from the cursor, wrap around, so load spreads evenly.
        start = _rr_cursor[task_type] % len(_GROQ_POOL)
        for key_name in _GROQ_POOL[start:] + _GROQ_POOL[:start]:
            expanded.append((key_name, model))
    return expanded
# ── Public API ────────────────────────────────────────────────────────────────
async def call_llm(
    task_type: str,
    messages: list,
    temperature: float = 0.3,
    max_tokens: int = 2000,
    response_format: dict | None = None,
) -> dict:
    """Route a chat completion to the best available provider with automatic fallback.

    Candidates come from _expand_candidates(task_type) and are attempted in
    order; rate-limited or timed-out providers are skipped and the Groq
    round-robin cursor is advanced so the next call starts on a fresh key.

    Args:
        task_type: Key into MODEL_REGISTRY selecting the candidate chain.
        messages: OpenAI-style chat message list.
        temperature: Sampling temperature forwarded to the provider.
        max_tokens: Completion token cap forwarded to the provider.
        response_format: Optional response_format payload; omitted for
            providers that don't support it (currently "google").

    Returns:
        {"content": str, "provider": str, "model": str}

    Raises:
        Exception: if every candidate provider failed.
    """
    candidates = _expand_candidates(task_type)
    errors: list[str] = []
    for provider_name, model_id in candidates:
        client = _clients.get(provider_name)
        if not client:
            continue  # provider never registered (no API key configured)
        try:
            kwargs = {
                "model": model_id,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_tokens,
                "timeout": 45,
            }
            if response_format and provider_name not in ("google",):
                kwargs["response_format"] = response_format
            response = await client.chat.completions.create(**kwargs)
            content = response.choices[0].message.content
            # Advance round-robin cursor on success so next call starts from the
            # following key, distributing load evenly across the pool.
            if provider_name in _GROQ_POOL:
                _rr_cursor[task_type] = (_rr_cursor[task_type] + 1) % len(_GROQ_POOL)
            display_name = (
                f"groq[key{provider_name[-1]}]"
                if provider_name in ("groq_2", "groq_3")
                else provider_name
            )
            logger.info("LLM call success: %s/%s (%s)", display_name, model_id, task_type)
            return {
                "content": content,
                "provider": display_name,
                "model": model_id,
            }
        except Exception as e:
            err = str(e).lower()
            is_rate_limit = any(k in err for k in ["429", "rate", "quota", "limit", "exceeded", "capacity"])
            is_timeout = "timeout" in err or "timed out" in err
            if is_rate_limit or is_timeout:
                logger.warning("Provider %s/%s unavailable: %s", provider_name, model_id, type(e).__name__)
                # Bug fix: timeouts were previously recorded as "rate limited",
                # which misled anyone reading the exhaustion error.
                errors.append(f"{provider_name}: {'rate limited' if is_rate_limit else 'timed out'}")
                # Also advance cursor so the next call won't start on this key.
                if provider_name in _GROQ_POOL:
                    _rr_cursor[task_type] = (_rr_cursor[task_type] + 1) % len(_GROQ_POOL)
            else:
                logger.error("Provider %s/%s error: %s", provider_name, model_id, e)
                errors.append(f"{provider_name}: {e}")
            continue
    raise Exception(f"All LLM providers exhausted. Errors: {errors}")