"""MegaLLM client using OpenAI-compatible API with retry logic and key rotation."""

import asyncio
import logging

import httpx

from app.core.config import settings
from app.shared.integrations.key_rotator import megallm_key_rotator

logger = logging.getLogger(__name__)

# Generous read timeout: non-streaming LLM completions can take minutes.
REQUEST_TIMEOUT = httpx.Timeout(
    connect=30.0,
    read=300.0,
    write=30.0,
    pool=30.0,
)
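# Exponential backoff for HTTP 429s: waits of 2s, 4s, 8s across the retries.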
MAX_429_RETRIES = 3
INITIAL_BACKOFF_SECONDS = 2


class MegaLLMClient:
    """Client for MegaLLM (OpenAI-compatible API) operations with key rotation."""

    def __init__(self, model: str | None = None):
        """Initialize with optional model override."""
        self.model = model or settings.default_megallm_model
        self.base_url = settings.megallm_base_url

    def _get_api_key(self) -> str:
        """Get API key using rotation or fallback to settings."""
        if megallm_key_rotator:
            return megallm_key_rotator.get_next_key()
        if settings.megallm_api_key:
            return settings.megallm_api_key
        raise ValueError("No MegaLLM API keys configured")

    async def generate(
        self,
        prompt: str,
        temperature: float = 0.7,
        system_instruction: str | None = None,
        max_retries: int = 2,
    ) -> str:
        """
        Generate text using MegaLLM (OpenAI-compatible API).

        Args:
            prompt: Text prompt
            temperature: Sampling temperature
            system_instruction: Optional system prompt
            max_retries: Number of retries on read timeout

        Returns:
            Generated text
        """
        messages = []
        if system_instruction:
            messages.append({"role": "system", "content": system_instruction})
        messages.append({"role": "user", "content": prompt})

        last_error = None
        rate_limit_retries = 0
        timeout_retries = 0

        # Separate retry budgets per failure mode: up to max_retries read
        # timeouts and MAX_429_RETRIES rate limits, plus the initial attempt.
        total_attempts = max_retries + 1 + MAX_429_RETRIES

        for _ in range(total_attempts):
            # Fetch a fresh key each attempt so rotation spreads retries across keys.
            api_key = self._get_api_key()

            try:
                async with httpx.AsyncClient(timeout=REQUEST_TIMEOUT) as client:
                    response = await client.post(
                        f"{self.base_url}/chat/completions",
                        headers={
                            "Authorization": f"Bearer {api_key}",
                            "Content-Type": "application/json",
                        },
                        json={
                            "model": self.model,
                            "messages": messages,
                            "temperature": temperature,
                        },
                    )

                    if response.status_code == 429:
                        rate_limit_retries += 1
                        if rate_limit_retries <= MAX_429_RETRIES:
                            # Exponential backoff: 2s, 4s, 8s.
                            wait_time = INITIAL_BACKOFF_SECONDS * (2 ** (rate_limit_retries - 1))
                            logger.warning(
                                f"[MegaLLM] 429 Rate Limit hit, retry {rate_limit_retries}/{MAX_429_RETRIES} "
                                f"after {wait_time}s"
                            )
                            await asyncio.sleep(wait_time)
                            continue
                        logger.error(f"[MegaLLM] Max 429 retries exceeded ({MAX_429_RETRIES})")
                        response.raise_for_status()

                    response.raise_for_status()
                    data = response.json()
                    return data["choices"][0]["message"]["content"]
            except httpx.ReadTimeout as e:
                last_error = e
                timeout_retries += 1
                if timeout_retries <= max_retries:
                    logger.warning(f"[MegaLLM] Timeout, retry {timeout_retries}/{max_retries}")
                    continue
                raise
            except httpx.HTTPStatusError:
                # Non-429 HTTP errors (and exhausted 429 retries) are never retried.
                raise

        # Defensive fallback: the loop above always returns or raises first.
        if last_error:
            raise last_error
        raise RuntimeError("Max retries exceeded")

    async def chat(
        self,
        messages: list[dict],
        temperature: float = 0.7,
        system_instruction: str | None = None,
    ) -> str:
        """
        Generate chat completion using MegaLLM.

        Unlike generate(), this performs a single request with no retry logic.

        Args:
            messages: List of message dicts with 'role' and 'content'
            temperature: Sampling temperature
            system_instruction: Optional system prompt

        Returns:
            Generated text response
        """
        api_key = self._get_api_key()

        chat_messages = []
        if system_instruction:
            chat_messages.append({"role": "system", "content": system_instruction})

        for msg in messages:
            role = msg.get("role", "user")
            # Accept {"parts": [...]}-style messages (e.g. from Gemini-format
            # callers) as well as OpenAI-style {"content": ...}.
            content = msg.get("content") or msg.get("parts", [""])[0]
            chat_messages.append({"role": role, "content": content})

        async with httpx.AsyncClient(timeout=REQUEST_TIMEOUT) as client:
            response = await client.post(
                f"{self.base_url}/chat/completions",
                headers={
                    "Authorization": f"Bearer {api_key}",
                    "Content-Type": "application/json",
                },
                json={
                    "model": self.model,
                    "messages": chat_messages,
                    "temperature": temperature,
                },
            )
            response.raise_for_status()
            data = response.json()
            return data["choices"][0]["message"]["content"]


def get_megallm_client(model: str | None = None) -> MegaLLMClient:
    """Factory function to create MegaLLM client with specified model."""
    return MegaLLMClient(model=model)
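

if __name__ == "__main__":
    # Minimal usage sketch. Assumes MegaLLM credentials are configured via
    # app.core.config.settings (or the key rotator); the prompt text below is
    # illustrative only.
    async def _demo() -> None:
        client = get_megallm_client()
        reply = await client.generate(
            "Summarize what exponential backoff does in one sentence.",
            system_instruction="You are a concise assistant.",
        )
        print(reply)

    asyncio.run(_demo())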