"""Multi-provider LLM router with automatic fallback on rate limits. Groq pool: up to 3 API keys (GROQ_API_KEY, GROQ_API_KEY_2, GROQ_API_KEY_3) all running llama-3.3-70b-versatile. Calls are round-robined across the pool so the per-key rate limit is shared evenly. When a key is rate-limited the router falls through to the next key in rotation, then to the rest of the fallback chain. """ import asyncio import logging from collections import defaultdict from openai import AsyncOpenAI from backend.config import ( GROQ_API_KEY, GROQ_API_KEY_2, GROQ_API_KEY_3, CEREBRAS_API_KEY, SAMBANOVA_API_KEY, OPENROUTER_API_KEY, GEMINI_API_KEY, OLLAMA_BASE_URL, OLLAMA_ENABLED, ) logger = logging.getLogger("thirdeye.providers") # ── Client registry ────────────────────────────────────────────────────────── _clients: dict[str, AsyncOpenAI] = {} def _init_client(name: str, base_url: str, api_key: str | None): if api_key and len(api_key) > 5: _clients[name] = AsyncOpenAI(base_url=base_url, api_key=api_key) # Ollama (local) — uses a dummy key; the OpenAI client requires a non-empty value if OLLAMA_ENABLED: _clients["ollama"] = AsyncOpenAI(base_url=OLLAMA_BASE_URL, api_key="ollama") # Groq pool: register each key under its own name _init_client("groq", "https://api.groq.com/openai/v1", GROQ_API_KEY) _init_client("groq_2", "https://api.groq.com/openai/v1", GROQ_API_KEY_2) _init_client("groq_3", "https://api.groq.com/openai/v1", GROQ_API_KEY_3) _init_client("cerebras", "https://api.cerebras.ai/v1", CEREBRAS_API_KEY) _init_client("sambanova", "https://api.sambanova.ai/v1", SAMBANOVA_API_KEY) _init_client("openrouter", "https://openrouter.ai/api/v1", OPENROUTER_API_KEY) _init_client("google", "https://generativelanguage.googleapis.com/v1beta/openai/", GEMINI_API_KEY) # Which provider names belong to the Groq pool _GROQ_POOL = [name for name in ("groq", "groq_2", "groq_3") if name in _clients] _GROQ_MODEL = "llama-3.3-70b-versatile" # Round-robin cursor per task_type (incremented after every call attempt on the pool) _rr_cursor: dict[str, int] = defaultdict(int) # ── Model registry ─────────────────────────────────────────────────────────── # Groq pool entries are expanded dynamically at call time so the cursor drives order. # Use the sentinel string "groq_pool" to indicate "use all available Groq keys". _GROQ_POOL_SENTINEL = "__groq_pool__" MODEL_REGISTRY: dict[str, list[tuple[str, str]]] = { "fast_small": [ ("ollama", "llama3:8b"), ("groq", "llama-3.1-8b-instant"), ("cerebras", "llama-3.1-8b"), ("openrouter", "openai/gpt-oss-20b:free"), ], "fast_large": [ (_GROQ_POOL_SENTINEL, _GROQ_MODEL), # expands to all 3 Groq keys (round-robin) ("openrouter", "arcee-ai/trinity-large-preview:free"), ("openrouter", "meta-llama/llama-3.3-70b-instruct:free"), ("sambanova", "Meta-Llama-3.3-70B-Instruct"), ("cerebras", "llama3.1-8b"), ], "reasoning": [ ("sambanova", "DeepSeek-R1-Distill-Llama-70B"), ("openrouter", "nvidia/nemotron-3-super-120b-a12b:free"), ("openrouter", "openai/gpt-oss-120b:free"), ], "agentic": [ ("openrouter", "minimax/minimax-m2.5:free"), ("openrouter", "nvidia/nemotron-3-super-120b-a12b:free"), (_GROQ_POOL_SENTINEL, _GROQ_MODEL), ], "fallback": [ ("openrouter", "openrouter/free"), ("google", "gemini-2.5-flash"), ], } def _expand_candidates(task_type: str) -> list[tuple[str, str]]: """ Return the full candidate list for a task_type with the Groq pool sentinel expanded into ordered (provider_name, model) tuples starting from the current round-robin cursor position. """ raw = MODEL_REGISTRY.get(task_type, []) + MODEL_REGISTRY["fallback"] expanded: list[tuple[str, str]] = [] for provider, model in raw: if provider != _GROQ_POOL_SENTINEL: expanded.append((provider, model)) continue if not _GROQ_POOL: continue # Rotate: start from cursor, wrap around start = _rr_cursor[task_type] % len(_GROQ_POOL) ordered = _GROQ_POOL[start:] + _GROQ_POOL[:start] for key_name in ordered: expanded.append((key_name, model)) return expanded # ── Public API ──────────────────────────────────────────────────────────────── async def call_llm( task_type: str, messages: list, temperature: float = 0.3, max_tokens: int = 2000, response_format: dict = None, ) -> dict: """ Route to the best available provider with automatic fallback. Returns: {"content": str, "provider": str, "model": str} """ candidates = _expand_candidates(task_type) errors = [] for provider_name, model_id in candidates: client = _clients.get(provider_name) if not client: continue try: kwargs = { "model": model_id, "messages": messages, "temperature": temperature, "max_tokens": max_tokens, "timeout": 45, } if response_format and provider_name not in ("google",): kwargs["response_format"] = response_format response = await client.chat.completions.create(**kwargs) content = response.choices[0].message.content # Advance round-robin cursor on success so next call starts from the # following key, distributing load evenly across the pool. if provider_name in _GROQ_POOL: _rr_cursor[task_type] = (_rr_cursor[task_type] + 1) % len(_GROQ_POOL) display_name = provider_name if provider_name not in ("groq_2", "groq_3") else f"groq[key{provider_name[-1]}]" logger.info(f"LLM call success: {display_name}/{model_id} ({task_type})") return { "content": content, "provider": display_name, "model": model_id, } except Exception as e: err = str(e).lower() is_rate_limit = any(k in err for k in ["429", "rate", "quota", "limit", "exceeded", "capacity"]) is_timeout = "timeout" in err or "timed out" in err if is_rate_limit or is_timeout: logger.warning(f"Provider {provider_name}/{model_id} unavailable: {type(e).__name__}") errors.append(f"{provider_name}: rate limited") # Also advance cursor so the next call won't start on this key if provider_name in _GROQ_POOL: _rr_cursor[task_type] = (_rr_cursor[task_type] + 1) % len(_GROQ_POOL) else: logger.error(f"Provider {provider_name}/{model_id} error: {e}") errors.append(f"{provider_name}: {e}") continue raise Exception(f"All LLM providers exhausted. Errors: {errors}")