mirror of https://github.com/arkorty/B.Tech-Project-III.git
synced 2026-04-19 12:41:48 +00:00
178 lines · 7.3 KiB · Python
"""Multi-provider LLM router with automatic fallback on rate limits.

Groq pool: up to 3 API keys (GROQ_API_KEY, GROQ_API_KEY_2, GROQ_API_KEY_3) all running
llama-3.3-70b-versatile. Calls are round-robined across the pool so the per-key rate
limit is shared evenly. When a key is rate-limited the router falls through to the next
key in rotation, then to the rest of the fallback chain.
"""
|
|
import asyncio
|
|
import logging
|
|
from collections import defaultdict
|
|
from openai import AsyncOpenAI
|
|
from backend.config import (
|
|
GROQ_API_KEY, GROQ_API_KEY_2, GROQ_API_KEY_3,
|
|
CEREBRAS_API_KEY, SAMBANOVA_API_KEY,
|
|
OPENROUTER_API_KEY, GEMINI_API_KEY,
|
|
OLLAMA_BASE_URL, OLLAMA_ENABLED,
|
|
)
|
|
|
|
# Module-level logger for provider routing events.
logger = logging.getLogger("thirdeye.providers")


# ── Client registry ──────────────────────────────────────────────────────────
# Maps provider name -> configured AsyncOpenAI client. Populated at import time
# below, only for providers whose credentials are present.
_clients: dict[str, AsyncOpenAI] = {}
|
|
|
|
def _init_client(name: str, base_url: str, api_key: str | None):
    """Register an AsyncOpenAI client under *name* when *api_key* looks usable.

    Missing keys and keys of 5 characters or fewer are treated as unset, so
    placeholder values in the environment do not create dead clients.
    """
    if not api_key or len(api_key) <= 5:
        return
    _clients[name] = AsyncOpenAI(base_url=base_url, api_key=api_key)
|
|
|
|
# Ollama (local) — the OpenAI client rejects an empty key, so pass a dummy one.
if OLLAMA_ENABLED:
    _clients["ollama"] = AsyncOpenAI(base_url=OLLAMA_BASE_URL, api_key="ollama")

# Groq pool: every configured key gets its own client entry so they can rotate.
_init_client("groq", "https://api.groq.com/openai/v1", GROQ_API_KEY)
_init_client("groq_2", "https://api.groq.com/openai/v1", GROQ_API_KEY_2)
_init_client("groq_3", "https://api.groq.com/openai/v1", GROQ_API_KEY_3)

# Single-key providers.
_init_client("cerebras", "https://api.cerebras.ai/v1", CEREBRAS_API_KEY)
_init_client("sambanova", "https://api.sambanova.ai/v1", SAMBANOVA_API_KEY)
_init_client("openrouter", "https://openrouter.ai/api/v1", OPENROUTER_API_KEY)
_init_client("google", "https://generativelanguage.googleapis.com/v1beta/openai/", GEMINI_API_KEY)

# Provider names belonging to the Groq pool — only keys that actually registered.
_GROQ_POOL = [pool_name for pool_name in ("groq", "groq_2", "groq_3") if pool_name in _clients]
_GROQ_MODEL = "llama-3.3-70b-versatile"

# Per-task-type round-robin cursor into _GROQ_POOL (advanced after attempts on the pool).
_rr_cursor: dict[str, int] = defaultdict(int)

# Groq pool entries in the registry are expanded dynamically at call time so the
# cursor drives their order. This sentinel provider name means "use every
# available Groq key".
_GROQ_POOL_SENTINEL = "__groq_pool__"
|
# Ordered candidate chains of (provider_name, model_id) per task type.
# call_llm tries each entry in order; the "fallback" chain is always appended
# after the task-specific chain (see _expand_candidates).
MODEL_REGISTRY: dict[str, list[tuple[str, str]]] = {
    # Small models — presumably for cheap/low-latency tasks (named "fast_small").
    "fast_small": [
        ("ollama", "llama3:8b"),
        ("groq", "llama-3.1-8b-instant"),
        ("cerebras", "llama-3.1-8b"),
        ("openrouter", "openai/gpt-oss-20b:free"),
    ],
    # Large models with the Groq key pool tried first.
    "fast_large": [
        (_GROQ_POOL_SENTINEL, _GROQ_MODEL),  # expands to all 3 Groq keys (round-robin)
        ("openrouter", "arcee-ai/trinity-large-preview:free"),
        ("openrouter", "meta-llama/llama-3.3-70b-instruct:free"),
        ("sambanova", "Meta-Llama-3.3-70B-Instruct"),
        ("cerebras", "llama3.1-8b"),
    ],
    # Reasoning-oriented models.
    "reasoning": [
        ("sambanova", "DeepSeek-R1-Distill-Llama-70B"),
        ("openrouter", "nvidia/nemotron-3-super-120b-a12b:free"),
        ("openrouter", "openai/gpt-oss-120b:free"),
    ],
    # Agentic / tool-use tasks; Groq pool as last resort before "fallback".
    "agentic": [
        ("openrouter", "minimax/minimax-m2.5:free"),
        ("openrouter", "nvidia/nemotron-3-super-120b-a12b:free"),
        (_GROQ_POOL_SENTINEL, _GROQ_MODEL),
    ],
    # Last-resort chain appended to every task type.
    "fallback": [
        ("openrouter", "openrouter/free"),
        ("google", "gemini-2.5-flash"),
    ],
}
|
|
|
|
|
|
def _expand_candidates(task_type: str) -> list[tuple[str, str]]:
    """
    Build the ordered candidate list for *task_type*.

    The task's registry chain (empty when the task type is unknown) is
    followed by the "fallback" chain. Any Groq pool sentinel entry is
    replaced by one (key_name, model) tuple per registered Groq key,
    rotated so iteration starts at the current round-robin cursor.
    """
    candidates: list[tuple[str, str]] = []

    for provider, model in MODEL_REGISTRY.get(task_type, []) + MODEL_REGISTRY["fallback"]:
        if provider != _GROQ_POOL_SENTINEL:
            candidates.append((provider, model))
        elif _GROQ_POOL:
            # Rotate the pool: begin at the cursor, wrap to the front.
            offset = _rr_cursor[task_type] % len(_GROQ_POOL)
            rotated = _GROQ_POOL[offset:] + _GROQ_POOL[:offset]
            candidates.extend((key_name, model) for key_name in rotated)

    return candidates
|
|
|
|
|
|
# ── Public API ────────────────────────────────────────────────────────────────
|
|
|
|
async def call_llm(
    task_type: str,
    messages: list,
    temperature: float = 0.3,
    max_tokens: int = 2000,
    response_format: dict | None = None,
) -> dict:
    """
    Route to the best available provider with automatic fallback.

    Candidates come from _expand_candidates(task_type) and are tried in
    order. Rate-limit and timeout errors fall through to the next candidate
    (and advance the Groq round-robin cursor so the next call skips the bad
    key); other errors are logged and also fall through.

    Args:
        task_type: Key into MODEL_REGISTRY selecting the candidate chain.
        messages: Chat messages in OpenAI chat-completions format.
        temperature: Sampling temperature forwarded to the provider.
        max_tokens: Completion token cap forwarded to the provider.
        response_format: Optional response_format payload; not forwarded to
            the "google" provider.

    Returns:
        {"content": str, "provider": str, "model": str}

    Raises:
        Exception: When every candidate provider fails.
    """
    candidates = _expand_candidates(task_type)
    errors = []

    for provider_name, model_id in candidates:
        client = _clients.get(provider_name)
        if not client:
            # Provider listed in the registry but no credentials configured.
            continue

        try:
            kwargs = {
                "model": model_id,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_tokens,
                "timeout": 45,
            }
            if response_format and provider_name not in ("google",):
                kwargs["response_format"] = response_format

            response = await client.chat.completions.create(**kwargs)
            content = response.choices[0].message.content

            # Advance round-robin cursor on success so the next call starts
            # from the following key, distributing load evenly across the pool.
            if provider_name in _GROQ_POOL:
                _rr_cursor[task_type] = (_rr_cursor[task_type] + 1) % len(_GROQ_POOL)

            # Collapse extra pool keys to a readable "groq[keyN]" label.
            display_name = provider_name if provider_name not in ("groq_2", "groq_3") else f"groq[key{provider_name[-1]}]"
            logger.info(f"LLM call success: {display_name}/{model_id} ({task_type})")
            return {
                "content": content,
                "provider": display_name,
                "model": model_id,
            }

        except Exception as e:
            # Heuristic classification from the error text, since providers
            # differ in exception types.
            err = str(e).lower()
            is_rate_limit = any(k in err for k in ["429", "rate", "quota", "limit", "exceeded", "capacity"])
            is_timeout = "timeout" in err or "timed out" in err

            if is_rate_limit or is_timeout:
                logger.warning(f"Provider {provider_name}/{model_id} unavailable: {type(e).__name__}")
                # Fix: previously timeouts were also recorded as "rate limited".
                errors.append(f"{provider_name}: {'rate limited' if is_rate_limit else 'timed out'}")
                # Also advance cursor so the next call won't start on this key.
                if provider_name in _GROQ_POOL:
                    _rr_cursor[task_type] = (_rr_cursor[task_type] + 1) % len(_GROQ_POOL)
            else:
                logger.error(f"Provider {provider_name}/{model_id} error: {e}")
                errors.append(f"{provider_name}: {e}")
            continue

    raise Exception(f"All LLM providers exhausted. Errors: {errors}")
|