File size: 4,328 Bytes
14208c6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
"""Thread-safe API Key Rotator for load balancing across multiple keys.
This module provides a round-robin key rotation mechanism to distribute
API requests across multiple keys, helping to avoid rate limits.
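Usage:
    from app.shared.integrations.key_rotator import megallm_key_rotator
    # megallm_key_rotator is None when no MegaLLM keys are configured.
    if megallm_key_rotator is not None:
        api_key = megallm_key_rotator.get_next_key()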
"""
import logging
import os
import threading
from typing import List
from dotenv import load_dotenv
logger = logging.getLogger(__name__)
# Pull variables from .env into os.environ now, because the MegaLLM key
# lookup at the bottom of this module runs at import time.
load_dotenv()
class KeyRotator:
    """Thread-safe round-robin API key rotator.

    Distributes API calls across multiple keys to avoid per-key rate limits.
    Each call to get_next_key() returns the next key in rotation.

    Attributes:
        _keys: Private copy of the API keys to rotate through
        _name: Name used to tag log messages
        _index: Position of the key the next call will return
        _lock: Guards _index and _request_count for concurrent access
        _request_count: Number of keys handed out so far
    """

    def __init__(self, keys: List[str], name: str = "default"):
        """Initialize the key rotator.

        Args:
            keys: API keys to rotate through (must have at least one). Any
                iterable of strings is accepted; it is copied, so later
                mutation of the caller's list does not affect the rotator.
            name: Name for logging identification

        Raises:
            ValueError: If keys is empty (or None)
        """
        # Copy defensively: if we kept the caller's list and they later
        # shrank it, self._index could point past the end (IndexError).
        key_list: List[str] = list(keys) if keys else []
        if not key_list:
            raise ValueError("At least one API key is required")
        self._keys = key_list
        self._name = name
        self._index = 0
        self._lock = threading.Lock()
        self._request_count = 0
        logger.info(f"[KeyRotator:{name}] Initialized with {len(key_list)} API keys")

    @staticmethod
    def _mask_key(key: str) -> str:
        """Return a log-safe representation of *key*.

        Keys longer than 8 characters show only their last 8 characters.
        Shorter keys are hidden entirely, because any suffix would expose
        most or all of the secret.
        """
        return f"...{key[-8:]}" if len(key) > 8 else "***"

    def get_next_key(self) -> str:
        """Get next API key in rotation (thread-safe).

        Returns:
            The next API key in round-robin order
        """
        with self._lock:
            position = self._index
            key = self._keys[position]
            self._index = (position + 1) % len(self._keys)
            self._request_count += 1
            # Capture under the lock; reading self._request_count after the
            # lock is released could log another thread's request number.
            request_no = self._request_count

        logger.info(
            "[KeyRotator:%s] Request #%d using key %d/%d (%s)",
            self._name,
            request_no,
            position + 1,  # 1-based for logging
            len(self._keys),
            self._mask_key(key),
        )
        return key

    @property
    def total_keys(self) -> int:
        """Number of keys in rotation."""
        return len(self._keys)

    @property
    def request_count(self) -> int:
        """Total number of requests made through this rotator."""
        with self._lock:
            return self._request_count

    def get_stats(self) -> dict:
        """Get rotation statistics for debugging.

        The index and request count are read together under the lock so the
        snapshot is internally consistent.
        """
        with self._lock:
            return {
                "name": self._name,
                "total_keys": len(self._keys),
                "current_index": self._index,
                "total_requests": self._request_count,
            }
def load_megallm_keys() -> List[str]:
    """Collect MegaLLM API keys from the process environment.

    Numbered variables MEGALLM_API_KEY_1, MEGALLM_API_KEY_2, ... are read in
    order. The scan ends at the first number that is unset or empty. If no
    numbered key exists, the legacy single MEGALLM_API_KEY variable is used
    instead.

    Returns:
        The keys found, in order; an empty list when none are configured.
    """
    found: List[str] = []
    # The next suffix to probe is always one past the number collected so far.
    while value := os.environ.get(f"MEGALLM_API_KEY_{len(found) + 1}"):
        found.append(value)

    # Backward compatibility with the pre-rotation single-key setup.
    if not found:
        legacy = os.environ.get("MEGALLM_API_KEY")
        if legacy:
            found = [legacy]
            logger.warning(
                "[KeyRotator] Using legacy MEGALLM_API_KEY. "
                "Consider migrating to MEGALLM_API_KEY_1, MEGALLM_API_KEY_2, etc."
            )

    if not found:
        logger.warning("[KeyRotator] No MegaLLM API keys found in environment")
    else:
        logger.info(f"[KeyRotator] Loaded {len(found)} MegaLLM API key(s)")
    return found
# Module-wide MegaLLM rotator, built once at import time from the environment.
# It stays None when no keys are configured, so callers must check for that
# before calling get_next_key().
_megallm_keys = load_megallm_keys()
megallm_key_rotator: KeyRotator | None = (
    KeyRotator(_megallm_keys, name="MegaLLM") if _megallm_keys else None
)
|