"""
Connection pooling for LLM providers with HTTP/2 support.

This module provides efficient connection pooling for LLM API calls,
reducing latency and improving throughput.
"""

import asyncio
from typing import Dict, Any, Optional, Union, AsyncIterator
from contextlib import asynccontextmanager
import time

import httpx
from httpx import AsyncClient, Limits, Timeout
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type

from src.core import get_logger, settings
from src.core.json_utils import dumps, loads

logger = get_logger(__name__)


class LLMConnectionPool:
    """
    Connection pool manager for LLM providers.

    Features:
    - Persistent HTTP/2 connections
    - Automatic retry with exponential backoff
    - Connection health monitoring
    - Request/response caching
    - Performance metrics
    """

    def __init__(
        self,
        max_connections: int = 20,
        max_keepalive_connections: int = 10,
        keepalive_expiry: float = 30.0,
        timeout: float = 30.0,
        http2: bool = True
    ):
        """
        Initialize LLM connection pool.

        Args:
            max_connections: Maximum number of connections
            max_keepalive_connections: Maximum idle connections
            keepalive_expiry: How long to keep idle connections (seconds)
            timeout: Request timeout (seconds)
            http2: Enable HTTP/2 support
        """
        self.max_connections = max_connections
        self.max_keepalive_connections = max_keepalive_connections
        self.keepalive_expiry = keepalive_expiry
        self.timeout = timeout
        self.http2 = http2

        self._pools: Dict[str, AsyncClient] = {}
        self._pool_stats: Dict[str, Dict[str, Any]] = {}

        self.metrics = {
            "requests": 0,
            "successes": 0,
            "failures": 0,
            "total_latency": 0.0,
            "cache_hits": 0
        }

    async def initialize(self):
        """Initialize connection pools for configured providers."""
        providers = {
            "groq": {
                "base_url": "https://api.groq.com/openai/v1",
                "headers": {
                    "Authorization": f"Bearer {settings.groq_api_key}",
                    "Content-Type": "application/json"
                }
            },
            "openai": {
                "base_url": "https://api.openai.com/v1",
                "headers": {
                    "Authorization": f"Bearer {getattr(settings, 'openai_api_key', '')}",
                    "Content-Type": "application/json"
                }
            }
        }

        for provider, config in providers.items():
            if provider == "openai" and not getattr(settings, 'openai_api_key', None):
                continue

            await self._create_pool(provider, config)

    async def _create_pool(self, provider: str, config: Dict[str, Any]):
        """Create connection pool for a provider."""
        try:
            limits = Limits(
                max_connections=self.max_connections,
                max_keepalive_connections=self.max_keepalive_connections,
                keepalive_expiry=self.keepalive_expiry
            )

            timeout = Timeout(
                connect=5.0,
                read=self.timeout,
                write=10.0,
                pool=5.0
            )

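            # http2=True requires httpx's optional HTTP/2 dependency; install it
            # with `pip install httpx[http2]` (the `h2` package), otherwise
            # AsyncClient raises an ImportError at construction time.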
            client = AsyncClient(
                base_url=config["base_url"],
                headers=config["headers"],
                limits=limits,
                timeout=timeout,
                http2=self.http2,
                follow_redirects=True
            )

            self._pools[provider] = client
            self._pool_stats[provider] = {
                "created_at": time.time(),
                "requests": 0,
                "errors": 0
            }

            logger.info(f"Created connection pool for {provider} (HTTP/2: {self.http2})")

        except Exception as e:
            logger.error(f"Failed to create pool for {provider}: {e}")

    @asynccontextmanager
    async def get_client(self, provider: str = "groq") -> AsyncIterator[AsyncClient]:
        """
        Get HTTP client for a provider.

        Args:
            provider: LLM provider name

        Yields:
            AsyncClient instance
        """
        if provider not in self._pools:
            raise ValueError(f"Provider {provider} not initialized")

        client = self._pools[provider]
        self._pool_stats[provider]["requests"] += 1

        try:
            yield client
        except Exception:
            self._pool_stats[provider]["errors"] += 1
            raise

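    # tenacity re-invokes the whole method on retryable errors (timeouts and
    # network failures), so the metrics below count attempts, not logical requests.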
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        retry=retry_if_exception_type((httpx.TimeoutException, httpx.NetworkError))
    )
    async def post(
        self,
        provider: str,
        endpoint: str,
        data: Dict[str, Any],
        **kwargs
    ) -> Dict[str, Any]:
        """
        Make POST request with automatic retry and pooling.

        Args:
            provider: LLM provider name
            endpoint: API endpoint
            data: Request data
            **kwargs: Additional httpx parameters

        Returns:
            Response data as dict
        """
        start_time = time.time()

        try:
            async with self.get_client(provider) as client:
                json_data = dumps(data)

                response = await client.post(
                    endpoint,
                    content=json_data,
                    headers={"Content-Type": "application/json"},
                    **kwargs
                )

                response.raise_for_status()

                result = loads(response.content)

                latency = time.time() - start_time
                self.metrics["requests"] += 1
                self.metrics["successes"] += 1
                self.metrics["total_latency"] += latency

                logger.debug(f"{provider} request to {endpoint} completed in {latency:.3f}s")

                return result

        except Exception as e:
            self.metrics["requests"] += 1
            self.metrics["failures"] += 1
            logger.error(f"{provider} request failed: {e}")
            raise

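    # The default model id below follows Groq's naming; whether it is still served
    # depends on the provider's current model catalog.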
    async def chat_completion(
        self,
        messages: list,
        model: str = "mixtral-8x7b-32768",
        provider: str = "groq",
        temperature: float = 0.7,
        max_tokens: int = 1000,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Make chat completion request with optimal settings.

        Args:
            messages: Chat messages
            model: Model to use
            provider: LLM provider
            temperature: Sampling temperature
            max_tokens: Maximum response tokens
            **kwargs: Additional parameters

        Returns:
            Completion response
        """
        data = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            **kwargs
        }

        return await self.post(provider, "/chat/completions", data)

    async def close(self):
        """Close all connection pools."""
        for provider, client in self._pools.items():
            try:
                await client.aclose()
                logger.info(f"Closed connection pool for {provider}")
            except Exception as e:
                logger.error(f"Error closing pool for {provider}: {e}")

        self._pools.clear()

    def get_stats(self) -> Dict[str, Any]:
        """Get connection pool statistics."""
        avg_latency = (
            self.metrics["total_latency"] / self.metrics["requests"]
            if self.metrics["requests"] > 0 else 0
        )

        return {
            "pools": self._pool_stats,
            "metrics": {
                **self.metrics,
                "avg_latency_ms": int(avg_latency * 1000),
                "success_rate": (
                    self.metrics["successes"] / self.metrics["requests"]
                    if self.metrics["requests"] > 0 else 0
                )
            }
        }


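# Module-level singleton; callers normally obtain it via get_llm_pool() rather
# than constructing LLMConnectionPool directly.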
llm_pool = LLMConnectionPool()


async def get_llm_pool() -> LLMConnectionPool:
    """Get or initialize the global LLM connection pool."""
    if not llm_pool._pools:
        await llm_pool.initialize()
    return llm_pool
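

# Illustrative usage sketch, not a production entry point: it assumes
# `settings.groq_api_key` is configured and the Groq endpoint is reachable.
# The prompt below is a placeholder.
async def _example_usage() -> None:
    pool = await get_llm_pool()
    try:
        response = await pool.chat_completion(
            messages=[{"role": "user", "content": "Say hello in one sentence."}]
        )
        # OpenAI-compatible responses nest the reply under choices[0].message.content.
        print(response["choices"][0]["message"]["content"])
        print(pool.get_stats())
    finally:
        await pool.close()


if __name__ == "__main__":
    asyncio.run(_example_usage())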