mirror of
https://github.com/arkorty/B.Tech-Project-III.git
synced 2026-04-19 12:41:48 +00:00
58 lines
2.0 KiB
Python
58 lines
2.0 KiB
Python
import json
import logging
from typing import Optional

from google import genai
from google.genai import types

from config import GEMINI_API_KEY

# Model identifier used by agents that are not given an explicit model_name.
_DEFAULT_MODEL = "gemini-3-flash-preview"

# Single google-genai client, authenticated with the configured API key and
# shared by every BaseAgent request made from this module.
_client = genai.Client(api_key=GEMINI_API_KEY)


_logger = logging.getLogger(__name__)


class BaseAgent:
    """Base class for agents that ask a Gemini model for a JSON object.

    An agent holds a system instruction and a model name. :meth:`call` sends a
    prompt through the module-level async ``_client`` with a JSON response MIME
    type and returns the decoded object, or an ``{"error": ...}`` dict on failure.
    """

    def __init__(self, system_prompt: str, model_name: str = _DEFAULT_MODEL):
        """Create an agent.

        Args:
            system_prompt: System instruction sent with every request.
            model_name: Gemini model identifier (defaults to ``_DEFAULT_MODEL``).
        """
        self.system_prompt = system_prompt
        self.model_name = model_name

    async def call(self, user_prompt: str, context: Optional[dict] = None) -> dict:
        """Call Gemini asynchronously and always return a dict.

        Args:
            user_prompt: The task text for the model.
            context: Optional JSON-serializable mapping. When non-empty, it is
                pretty-printed and prepended to the prompt under ``CONTEXT:``.

        Returns:
            On success, the JSON object decoded from the model's reply.

            On any failure, a dict with an ``"error"`` message is returned
            instead. Failures include an unserializable context, an API error,
            and missing or unparseable text. Parse failures also include a
            truncated ``"raw"`` copy of the reply.

            Task cancellation is not caught, because it is not an ``Exception``
            subclass.
        """
        try:
            # Built inside the try: json.dumps raises TypeError for values it
            # cannot serialize, which would otherwise escape this method.
            full_prompt = self._build_prompt(user_prompt, context)
            response = await _client.aio.models.generate_content(
                model=self.model_name,
                contents=full_prompt,
                config=types.GenerateContentConfig(
                    system_instruction=self.system_prompt,
                    response_mime_type="application/json",
                    temperature=0.7,
                ),
            )
            # In google-genai ``text`` is Optional. It is None when the response
            # carries no text parts (for example, when the prompt is blocked).
            text = response.text
            if text is None:
                return {"error": "Gemini returned no text"}
            return self._parse_json_object(text.strip())
        except Exception as e:
            # Boundary of the "always return a dict" contract: log with the
            # traceback, then report the failure as data.
            _logger.exception("[BaseAgent] Gemini call failed: %s", e)
            return {"error": str(e)}

    @staticmethod
    def _build_prompt(user_prompt: str, context: Optional[dict]) -> str:
        """Return *user_prompt*, prefixed with the JSON *context* if non-empty."""
        if not context:
            return user_prompt
        return f"CONTEXT:\n{json.dumps(context, indent=2)}\n\nTASK:\n{user_prompt}"

    @staticmethod
    def _parse_json_object(text: str) -> dict:
        """Decode a JSON object from model output, tolerating surrounding text."""
        # Fast path: the whole response is valid JSON.
        try:
            result = json.loads(text)
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(result, dict):
                return result
            return {"error": "Gemini returned non-dict JSON", "raw": text[:200]}

        # Fallback: decode the object that begins at the first "{".
        # raw_decode stops at that object's closing brace, so the following
        # are all ignored:
        #   - leading prose
        #   - code fences
        #   - trailing text, even text containing more braces
        start = text.find("{")
        if start != -1:
            try:
                result, _ = json.JSONDecoder().raw_decode(text[start:])
            except json.JSONDecodeError:
                pass
            else:
                # Decoding started at "{", so a successful result is a dict.
                return result

        return {"error": "Could not parse JSON from Gemini response", "raw": text[:500]}