mirror of
https://github.com/arkorty/B.Tech-Project-III.git
synced 2026-04-19 12:41:48 +00:00
init
This commit is contained in:
177
thirdeye/backend/providers.py
Normal file
177
thirdeye/backend/providers.py
Normal file
@@ -0,0 +1,177 @@
|
||||
"""Multi-provider LLM router with automatic fallback on rate limits.
|
||||
|
||||
Groq pool: up to 3 API keys (GROQ_API_KEY, GROQ_API_KEY_2, GROQ_API_KEY_3) all running
|
||||
llama-3.3-70b-versatile. Calls are round-robined across the pool so the per-key rate
|
||||
limit is shared evenly. When a key is rate-limited the router falls through to the next
|
||||
key in rotation, then to the rest of the fallback chain.
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from openai import AsyncOpenAI
|
||||
from backend.config import (
|
||||
GROQ_API_KEY, GROQ_API_KEY_2, GROQ_API_KEY_3,
|
||||
CEREBRAS_API_KEY, SAMBANOVA_API_KEY,
|
||||
OPENROUTER_API_KEY, GEMINI_API_KEY,
|
||||
OLLAMA_BASE_URL, OLLAMA_ENABLED,
|
||||
)
|
||||
|
||||
# Module-level logger for the provider router.
logger = logging.getLogger("thirdeye.providers")


# ── Client registry ──────────────────────────────────────────────────────────
# Maps provider name (e.g. "groq", "groq_2", "openrouter") to a configured
# AsyncOpenAI client. Populated at import time below; providers whose API key
# is missing or looks like a placeholder are simply absent from this dict.
_clients: dict[str, AsyncOpenAI] = {}
|
||||
|
||||
def _init_client(name: str, base_url: str, api_key: str | None):
    """Register an AsyncOpenAI client for *name* when the key looks real.

    A missing key, or one of 5 characters or fewer, is treated as an unset
    placeholder: the provider is silently left out of the registry so the
    fallback chain skips it.
    """
    if not api_key or len(api_key) <= 5:
        return
    _clients[name] = AsyncOpenAI(base_url=base_url, api_key=api_key)
|
||||
|
||||
# Ollama (local) — uses a dummy key; the OpenAI client requires a non-empty value
if OLLAMA_ENABLED:
    _clients["ollama"] = AsyncOpenAI(base_url=OLLAMA_BASE_URL, api_key="ollama")

# Groq pool: register each key under its own name
_init_client("groq", "https://api.groq.com/openai/v1", GROQ_API_KEY)
_init_client("groq_2", "https://api.groq.com/openai/v1", GROQ_API_KEY_2)
_init_client("groq_3", "https://api.groq.com/openai/v1", GROQ_API_KEY_3)

# Remaining hosted providers; each is registered only when its key is set.
_init_client("cerebras", "https://api.cerebras.ai/v1", CEREBRAS_API_KEY)
_init_client("sambanova", "https://api.sambanova.ai/v1", SAMBANOVA_API_KEY)
_init_client("openrouter", "https://openrouter.ai/api/v1", OPENROUTER_API_KEY)
# Gemini, reached through Google's OpenAI-compatible endpoint.
_init_client("google", "https://generativelanguage.googleapis.com/v1beta/openai/", GEMINI_API_KEY)
|
||||
|
||||
# Which provider names belong to the Groq pool — only the keys that were
# actually configured (and therefore present in _clients).
_GROQ_POOL = [name for name in ("groq", "groq_2", "groq_3") if name in _clients]
# Every key in the pool serves this single model.
_GROQ_MODEL = "llama-3.3-70b-versatile"

# Round-robin cursor per task_type (incremented after every call attempt on the pool)
_rr_cursor: dict[str, int] = defaultdict(int)
|
||||
|
||||
# ── Model registry ───────────────────────────────────────────────────────────
# Groq pool entries are expanded dynamically at call time so the cursor drives order.
# Use the sentinel string "groq_pool" to indicate "use all available Groq keys".
_GROQ_POOL_SENTINEL = "__groq_pool__"

# Ordered fallback chains per task type: each entry is (provider_name, model_id).
# call_llm() tries entries top to bottom; the "fallback" list is appended to
# every other task type's list as a last resort.
MODEL_REGISTRY: dict[str, list[tuple[str, str]]] = {
    "fast_small": [
        ("ollama", "llama3:8b"),
        ("groq", "llama-3.1-8b-instant"),
        ("cerebras", "llama-3.1-8b"),
        ("openrouter", "openai/gpt-oss-20b:free"),
    ],
    "fast_large": [
        (_GROQ_POOL_SENTINEL, _GROQ_MODEL),  # expands to all 3 Groq keys (round-robin)
        ("openrouter", "arcee-ai/trinity-large-preview:free"),
        ("openrouter", "meta-llama/llama-3.3-70b-instruct:free"),
        ("sambanova", "Meta-Llama-3.3-70B-Instruct"),
        ("cerebras", "llama3.1-8b"),
    ],
    "reasoning": [
        ("sambanova", "DeepSeek-R1-Distill-Llama-70B"),
        ("openrouter", "nvidia/nemotron-3-super-120b-a12b:free"),
        ("openrouter", "openai/gpt-oss-120b:free"),
    ],
    "agentic": [
        ("openrouter", "minimax/minimax-m2.5:free"),
        ("openrouter", "nvidia/nemotron-3-super-120b-a12b:free"),
        (_GROQ_POOL_SENTINEL, _GROQ_MODEL),
    ],
    # Last-resort chain appended to every task type in _expand_candidates().
    "fallback": [
        ("openrouter", "openrouter/free"),
        ("google", "gemini-2.5-flash"),
    ],
}
|
||||
|
||||
|
||||
def _expand_candidates(task_type: str) -> list[tuple[str, str]]:
    """
    Build the ordered candidate list for *task_type*.

    The task's registry entries are followed by the shared "fallback" chain.
    Any Groq-pool sentinel entry is replaced by one (key_name, model) tuple
    per configured Groq key, rotated so the key at the current round-robin
    cursor position comes first. Unknown task types yield only the fallback
    chain; an empty Groq pool drops the sentinel entirely.
    """
    result: list[tuple[str, str]] = []

    for provider, model in MODEL_REGISTRY.get(task_type, []) + MODEL_REGISTRY["fallback"]:
        if provider == _GROQ_POOL_SENTINEL:
            if _GROQ_POOL:
                # Rotate the pool so iteration starts at the cursor and wraps.
                offset = _rr_cursor[task_type] % len(_GROQ_POOL)
                rotated = _GROQ_POOL[offset:] + _GROQ_POOL[:offset]
                result.extend((key_name, model) for key_name in rotated)
        else:
            result.append((provider, model))

    return result
|
||||
|
||||
|
||||
# ── Public API ────────────────────────────────────────────────────────────────
|
||||
|
||||
def _classify_error(exc: Exception) -> tuple[bool, bool]:
    """Return (is_rate_limit, is_timeout) via a best-effort message scan.

    The substring match is deliberately broad ("rate", "limit", "quota", …)
    so that any throttling-shaped provider error triggers fallback rather
    than a hard failure.
    """
    text = str(exc).lower()
    is_rate_limit = any(k in text for k in ("429", "rate", "quota", "limit", "exceeded", "capacity"))
    is_timeout = "timeout" in text or "timed out" in text
    return is_rate_limit, is_timeout


def _advance_groq_cursor(task_type: str) -> None:
    """Step this task's round-robin cursor to the next key in the Groq pool."""
    _rr_cursor[task_type] = (_rr_cursor[task_type] + 1) % len(_GROQ_POOL)


def _display_name(provider_name: str) -> str:
    """Collapse groq_2/groq_3 to "groq[keyN]" for logs and the result dict."""
    if provider_name in ("groq_2", "groq_3"):
        return f"groq[key{provider_name[-1]}]"
    return provider_name


async def call_llm(
    task_type: str,
    messages: list,
    temperature: float = 0.3,
    max_tokens: int = 2000,
    response_format: dict | None = None,
) -> dict:
    """
    Route to the best available provider with automatic fallback.

    Candidates come from _expand_candidates(task_type): the task's registry
    entries (Groq pool expanded round-robin) followed by the shared fallback
    chain. Each candidate is tried in order; on a transient failure (rate
    limit or timeout) the Groq cursor is advanced so the next call starts on
    a different key, and the next candidate is tried.

    Args:
        task_type: key into MODEL_REGISTRY (unknown keys use only "fallback").
        messages: chat messages in OpenAI chat-completions format.
        temperature: sampling temperature forwarded to the provider.
        max_tokens: completion-token cap forwarded to the provider.
        response_format: optional OpenAI response_format dict; not forwarded
            to providers that reject it (currently "google").

    Returns:
        {"content": str, "provider": str, "model": str}

    Raises:
        RuntimeError: when every candidate provider failed; the message
            includes the per-provider error summaries collected along the way.
    """
    candidates = _expand_candidates(task_type)
    errors: list[str] = []

    for provider_name, model_id in candidates:
        client = _clients.get(provider_name)
        if not client:
            # Provider was never registered (missing/placeholder key) — skip.
            continue

        try:
            kwargs = {
                "model": model_id,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_tokens,
                "timeout": 45,
            }
            if response_format and provider_name not in ("google",):
                kwargs["response_format"] = response_format

            response = await client.chat.completions.create(**kwargs)
            content = response.choices[0].message.content

            # Advance round-robin cursor on success so the next call starts
            # from the following key, distributing load evenly across the pool.
            if provider_name in _GROQ_POOL:
                _advance_groq_cursor(task_type)

            display_name = _display_name(provider_name)
            logger.info("LLM call success: %s/%s (%s)", display_name, model_id, task_type)
            return {
                "content": content,
                "provider": display_name,
                "model": model_id,
            }

        except Exception as e:  # noqa: BLE001 — any provider error must trigger fallback
            is_rate_limit, is_timeout = _classify_error(e)

            if is_rate_limit or is_timeout:
                logger.warning(
                    "Provider %s/%s unavailable: %s", provider_name, model_id, type(e).__name__
                )
                # Record what actually happened (previously timeouts were
                # mislabelled as "rate limited" in the error summary).
                reason = "rate limited" if is_rate_limit else "timeout"
                errors.append(f"{provider_name}: {reason}")
                # Also advance cursor so the next call won't start on this key.
                if provider_name in _GROQ_POOL:
                    _advance_groq_cursor(task_type)
            else:
                logger.error("Provider %s/%s error: %s", provider_name, model_id, e)
                errors.append(f"{provider_name}: {e}")

    raise RuntimeError(f"All LLM providers exhausted. Errors: {errors}")
|
||||
Reference in New Issue
Block a user