File size: 4,328 Bytes
14208c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""Thread-safe API Key Rotator for load balancing across multiple keys.

This module provides a round-robin key rotation mechanism to distribute
API requests across multiple keys, helping to avoid rate limits.

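Usage:
    from app.shared.integrations.key_rotator import megallm_key_rotator

    # megallm_key_rotator is None when no MegaLLM API keys are configured.
    if megallm_key_rotator is not None:
        api_key = megallm_key_rotator.get_next_key()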
"""

import logging
import os
import threading
from collections.abc import Mapping
from typing import List

from dotenv import load_dotenv

logger = logging.getLogger(__name__)

# Load variables from a .env file into os.environ at import time, before
# the module-level key loading further down reads the environment.
load_dotenv()


class KeyRotator:
    """Thread-safe round-robin API key rotator.

    Distributes API calls across multiple keys to avoid per-key rate limits.
    Each call to get_next_key() returns the next key in rotation, wrapping
    back to the first key after the last one.

    Attributes:
        _keys: Private copy of the API keys to rotate through
        _name: Label used in log messages
        _index: Position of the key the next get_next_key() call returns
        _request_count: Number of keys handed out so far
        _lock: Guards _index and _request_count against concurrent access
    """

    # Number of trailing key characters shown in log lines. The suffix is
    # only revealed for keys at least twice this long, so a log line never
    # exposes more than half of a key; shorter keys are fully redacted.
    _MASK_VISIBLE_CHARS = 8

    def __init__(self, keys: List[str], name: str = "default"):
        """Initialize the key rotator.

        Args:
            keys: API keys to rotate through (must contain at least one).
                Any iterable of strings is accepted. It is copied, so later
                changes to the caller's list do not affect the rotation.
            name: Name for logging identification

        Raises:
            ValueError: If keys is empty
        """
        # Copy before validating: a caller mutating (e.g. clearing) its list
        # later would otherwise break rotation with IndexError or
        # ZeroDivisionError, and a one-shot iterable is always truthy, so it
        # could not be validated without materializing it first.
        key_list = list(keys)
        if not key_list:
            raise ValueError("At least one API key is required")

        self._keys = key_list
        self._name = name
        self._index = 0
        self._lock = threading.Lock()
        self._request_count = 0

        logger.info(
            "[KeyRotator:%s] Initialized with %d API keys", name, len(key_list)
        )

    @classmethod
    def _mask(cls, key: str) -> str:
        """Return a log-safe representation of ``key``.

        Shows "..." plus the last _MASK_VISIBLE_CHARS characters when the key
        is at least twice that long; otherwise returns "***".
        """
        visible = cls._MASK_VISIBLE_CHARS
        if len(key) >= 2 * visible:
            return f"...{key[-visible:]}"
        return "***"

    def get_next_key(self) -> str:
        """Get next API key in rotation (thread-safe).

        Returns:
            The next API key in round-robin order
        """
        with self._lock:
            position = self._index
            key = self._keys[position]
            self._index = (position + 1) % len(self._keys)
            self._request_count += 1
            request_number = self._request_count

        # Log outside the lock so slow log handlers do not serialize callers;
        # the values logged were captured atomically above.
        logger.info(
            "[KeyRotator:%s] Request #%d using key %d/%d (%s)",
            self._name,
            request_number,
            position + 1,  # 1-based for humans
            len(self._keys),
            self._mask(key),
        )
        return key

    @property
    def total_keys(self) -> int:
        """Number of keys in rotation."""
        return len(self._keys)

    @property
    def request_count(self) -> int:
        """Total number of requests made through this rotator."""
        with self._lock:
            return self._request_count

    def get_stats(self) -> dict:
        """Get rotation statistics for debugging.

        The counters are read under the lock, so the returned values form a
        consistent snapshot even while other threads are rotating.
        """
        with self._lock:
            return {
                "name": self._name,
                "total_keys": len(self._keys),
                "current_index": self._index,
                "total_requests": self._request_count,
            }


def load_megallm_keys() -> List[str]:
    """Load all MEGALLM_API_KEY_* from environment variables.
    
    Looks for keys in format: MEGALLM_API_KEY_1, MEGALLM_API_KEY_2, etc.
    Falls back to single MEGALLM_API_KEY for backward compatibility.
    
    Returns:
        List of API keys found in environment
    """
    keys = []
    i = 1
    
    # Load numbered keys (MEGALLM_API_KEY_1, MEGALLM_API_KEY_2, ...)
    while True:
        key = os.environ.get(f"MEGALLM_API_KEY_{i}")
        if not key:
            break
        keys.append(key)
        i += 1
    
    # Fallback to single key for backward compatibility
    if not keys:
        single_key = os.environ.get("MEGALLM_API_KEY")
        if single_key:
            keys = [single_key]
            logger.warning(
                "[KeyRotator] Using legacy MEGALLM_API_KEY. "
                "Consider migrating to MEGALLM_API_KEY_1, MEGALLM_API_KEY_2, etc."
            )
    
    if keys:
        logger.info(f"[KeyRotator] Loaded {len(keys)} MegaLLM API key(s)")
    else:
        logger.warning("[KeyRotator] No MegaLLM API keys found in environment")
    
    return keys


# Shared MegaLLM rotator, built once at import time from the environment.
# It is None when no MegaLLM keys are configured, so importers must check for
# None before calling get_next_key().
_megallm_keys = load_megallm_keys()
megallm_key_rotator: KeyRotator | None = (
    KeyRotator(_megallm_keys, name="MegaLLM") if _megallm_keys else None
)