|
|
"""Distributed rate limiting service using Redis""" |
|
|
|
|
|
import time
import uuid
from datetime import datetime, timedelta
from typing import Any, Dict, Optional, Tuple

import redis.asyncio as redis
from redis.exceptions import RedisError

from src.core import get_logger
from src.core import json_utils
from src.core.config import settings
|
|
|
|
|
logger = get_logger(__name__) |
|
|
|
|
|
|
|
|
class DistributedRateLimiter:
    """Distributed rate limiter using Redis with a sliding window algorithm.

    For every ``identifier`` (optionally scoped by ``endpoint``) four Redis
    keys are maintained under the prefix ``rate_limit:<identifier>[:<endpoint>]``:

    * ``:minute``, ``:hour``, ``:day`` -- sorted sets holding one member per
      allowed request, scored with the ``time.time()`` of that request.  The
      number of members inside the trailing window is compared against the
      configured limit.
    * ``:burst`` -- an integer counter of remaining burst tokens.  It is
      created at ``burst_size`` on first use, decremented once per allowed
      request and expires ``_BURST_TTL`` seconds after the last allowed
      request, at which point the bucket is full again.

    Errors while talking to Redis never block traffic: ``is_allowed`` fails
    open and returns ``True`` with an ``"error"`` entry in the info dict.

    NOTE(review): the limit check and the bookkeeping write are two separate
    round trips, so concurrent requests for the same identifier can both pass
    the check and slightly overshoot a limit.  Making this atomic needs a
    server-side script and is intentionally not attempted here.
    """

    # (key suffix, window length in seconds, key TTL in seconds)
    _WINDOWS = (
        ("minute", 60, 120),
        ("hour", 3600, 7200),
        ("day", 86400, 172800),
    )

    # Lifetime of the burst counter key, in seconds.
    _BURST_TTL = 60

    def __init__(
        self,
        redis_url: Optional[str] = None,
        requests_per_minute: int = 60,
        requests_per_hour: int = 600,
        requests_per_day: int = 10000,
        burst_size: int = 10,
    ):
        """Store the limits; the Redis connection is created lazily.

        Args:
            redis_url: Redis connection URL; falls back to ``settings.REDIS_URL``.
            requests_per_minute: Allowed requests per trailing 60 seconds.
            requests_per_hour: Allowed requests per trailing 3600 seconds.
            requests_per_day: Allowed requests per trailing 86400 seconds.
            burst_size: Size of the burst token bucket.
        """
        self.redis_url = redis_url or settings.REDIS_URL
        self.requests_per_minute = requests_per_minute
        self.requests_per_hour = requests_per_hour
        self.requests_per_day = requests_per_day
        self.burst_size = burst_size
        self._redis_client = None

    async def get_redis(self) -> redis.Redis:
        """Get Redis client with lazy initialization."""
        if self._redis_client is None:
            self._redis_client = redis.from_url(
                self.redis_url,
                encoding="utf-8",
                decode_responses=True,
                socket_keepalive=True,
                health_check_interval=30,
            )
        return self._redis_client

    @staticmethod
    def _key_prefix(identifier: str, endpoint: Optional[str] = None) -> str:
        """Build the common Redis key prefix for an identifier/endpoint pair."""
        prefix = f"rate_limit:{identifier}"
        if endpoint:
            prefix += f":{endpoint}"
        return prefix

    def _limits(self) -> Dict[str, int]:
        """Map each window suffix to its currently configured limit."""
        return {
            "minute": self.requests_per_minute,
            "hour": self.requests_per_hour,
            "day": self.requests_per_day,
        }

    async def _read_usage(
        self,
        redis_client: redis.Redis,
        key_prefix: str,
        current_time: float,
    ) -> Tuple[Dict[str, int], int]:
        """Read per-window request counts and the remaining burst tokens.

        Returns:
            ``(usage, burst_tokens)`` where ``usage`` maps ``"minute"``,
            ``"hour"`` and ``"day"`` to the number of recorded requests inside
            that trailing window.  A missing burst key counts as a full bucket.
        """
        async with redis_client.pipeline() as pipe:
            for name, window, _ttl in self._WINDOWS:
                pipe.zcount(
                    f"{key_prefix}:{name}", current_time - window, current_time
                )
            pipe.get(f"{key_prefix}:burst")
            *counts, raw_burst = await pipe.execute()

        usage = {
            name: count for (name, _window, _ttl), count in zip(self._WINDOWS, counts)
        }
        burst_tokens = int(raw_burst or self.burst_size)
        return usage, burst_tokens

    async def is_allowed(
        self,
        identifier: str,
        endpoint: Optional[str] = None,
    ) -> Tuple[bool, Dict[str, Any]]:
        """
        Check if request is allowed using sliding window algorithm

        Args:
            identifier: Unique identifier (IP address or user ID)
            endpoint: Optional endpoint for per-endpoint limits

        Returns:
            Tuple of (allowed, rate_info).  On denial ``rate_info`` carries
            ``reason``, ``limit``, ``count`` and ``reset_in``; on success it
            carries the updated counters and the configured limits.  If Redis
            is unavailable the request is allowed and ``rate_info`` contains
            only an ``"error"`` entry.
        """
        try:
            redis_client = await self.get_redis()
            current_time = time.time()
            key_prefix = self._key_prefix(identifier, endpoint)

            usage, burst_tokens = await self._read_usage(
                redis_client, key_prefix, current_time
            )

            limits = self._limits()
            for name, window, _ttl in self._WINDOWS:
                if usage[name] >= limits[name]:
                    return False, {
                        "reason": f"{name}_limit_exceeded",
                        "limit": limits[name],
                        "count": usage[name],
                        # NOTE(review): this is the time to the next
                        # wall-clock boundary, not to the moment the oldest
                        # entry leaves the sliding window; treat it as a hint.
                        "reset_in": window - (current_time % window),
                    }

            if burst_tokens <= 0:
                return False, {
                    "reason": "burst_limit_exceeded",
                    "limit": self.burst_size,
                    "count": 0,
                    # NOTE(review): fixed hint; the bucket actually refills
                    # when the burst key expires.
                    "reset_in": 1,
                }

            # The random suffix keeps two requests that share a timestamp from
            # collapsing into a single sorted-set member (ZADD would otherwise
            # overwrite and the second request would go uncounted).
            request_id = f"{current_time}:{identifier}:{uuid.uuid4().hex}"
            burst_key = f"{key_prefix}:burst"

            async with redis_client.pipeline() as pipe:
                for name, window, ttl in self._WINDOWS:
                    window_key = f"{key_prefix}:{name}"
                    pipe.zadd(window_key, {request_id: current_time})
                    pipe.expire(window_key, ttl)
                    # Drop entries that have already left the window.
                    pipe.zremrangebyscore(window_key, 0, current_time - window)

                # Seed the bucket before decrementing: DECR on a missing key
                # starts from 0 and would leave -1, which locked the caller
                # out after a single request.
                pipe.set(burst_key, self.burst_size, ex=self._BURST_TTL, nx=True)
                pipe.decr(burst_key)
                pipe.expire(burst_key, self._BURST_TTL)

                await pipe.execute()

            # _replenish_burst_tokens is not called here: its SET uses NX and
            # the burst key is guaranteed to exist at this point, so the call
            # could never write anything.

            return True, {
                "allowed": True,
                "minute_count": usage["minute"] + 1,
                "hour_count": usage["hour"] + 1,
                "day_count": usage["day"] + 1,
                "burst_remaining": burst_tokens - 1,
                "limits": {
                    "per_minute": self.requests_per_minute,
                    "per_hour": self.requests_per_hour,
                    "per_day": self.requests_per_day,
                    "burst": self.burst_size,
                },
            }

        except RedisError as e:
            logger.error(f"Redis error in rate limiting: {e}")
            # Fail open: a rate-limiter outage must not take the service down.
            return True, {"error": "rate_limit_unavailable"}
        except Exception as e:
            logger.error(f"Unexpected error in rate limiting: {e}")
            return True, {"error": "rate_limit_error"}

    async def _replenish_burst_tokens(self, burst_key: str, current_tokens: int):
        """Re-create the burst counter if it is missing.

        Writes ``min(current_tokens + 1, burst_size)`` with a 60 second TTL,
        but only when ``current_tokens`` is below ``burst_size`` and the key
        does not currently exist (``nx=True``); a live counter is never
        overwritten.  Kept for backward compatibility -- ``is_allowed`` seeds
        the bucket itself.  Failures are logged and otherwise ignored because
        this is best-effort bookkeeping.
        """
        if current_tokens >= self.burst_size:
            return
        try:
            redis_client = await self.get_redis()
            await redis_client.set(
                burst_key,
                min(current_tokens + 1, self.burst_size),
                ex=60,
                nx=True,
            )
        except Exception as e:
            logger.warning(f"Failed to replenish burst tokens: {e}")

    async def reset_limits(self, identifier: str, endpoint: Optional[str] = None):
        """Reset rate limits for an identifier by deleting all of its keys."""
        try:
            redis_client = await self.get_redis()
            key_prefix = self._key_prefix(identifier, endpoint)
            suffixes = [name for name, _window, _ttl in self._WINDOWS] + ["burst"]
            await redis_client.delete(
                *(f"{key_prefix}:{suffix}" for suffix in suffixes)
            )
        except Exception as e:
            logger.error(f"Error resetting limits: {e}")

    async def get_limit_status(
        self, identifier: str, endpoint: Optional[str] = None
    ) -> Dict[str, Any]:
        """Get current limit status for an identifier.

        Returns:
            A dict with ``"minute"``, ``"hour"`` and ``"day"`` entries (each
            holding ``used``, ``limit`` and ``remaining``) plus a ``"burst"``
            entry (``remaining`` and ``limit``).  ``remaining`` is never
            negative.  An empty dict is returned if the status cannot be read.
        """
        try:
            redis_client = await self.get_redis()
            key_prefix = self._key_prefix(identifier, endpoint)
            usage, burst_tokens = await self._read_usage(
                redis_client, key_prefix, time.time()
            )

            status: Dict[str, Any] = {
                name: {
                    "used": usage[name],
                    "limit": limit,
                    "remaining": max(0, limit - usage[name]),
                }
                for name, limit in self._limits().items()
            }
            status["burst"] = {
                "remaining": max(0, burst_tokens),
                "limit": self.burst_size,
            }
            return status

        except Exception as e:
            logger.error(f"Error getting limit status: {e}")
            return {}

    async def cleanup_old_entries(self):
        """Clean up old entries from all rate limit keys.

        Walks every key matching ``rate_limit:*`` with SCAN and trims the
        sorted sets down to their trailing window.  Trims for one SCAN page
        are sent to Redis as a single batch.
        """
        try:
            redis_client = await self.get_redis()
            current_time = time.time()
            window_by_suffix = {
                f":{name}": window for name, window, _ttl in self._WINDOWS
            }

            cursor = 0
            while True:
                cursor, keys = await redis_client.scan(
                    cursor,
                    match="rate_limit:*",
                    count=100,
                )

                # The trims are independent of each other, so no transaction
                # is needed -- only the batching.
                async with redis_client.pipeline(transaction=False) as pipe:
                    queued = False
                    for key in keys:
                        for suffix, window in window_by_suffix.items():
                            if key.endswith(suffix):
                                pipe.zremrangebyscore(key, 0, current_time - window)
                                queued = True
                                break
                    if queued:
                        await pipe.execute()

                if cursor == 0:
                    break

        except Exception as e:
            logger.error(f"Error cleaning up old entries: {e}")

    async def close(self):
        """Close the Redis connection and forget the client.

        The cached client is cleared so that a later ``get_redis()`` call
        builds a fresh connection instead of handing out a closed one.
        """
        client, self._redis_client = self._redis_client, None
        if client is None:
            return
        # Newer redis-py versions expose ``aclose()`` and deprecate ``close()``
        # on the asyncio client; use whichever this installation provides.
        closer = getattr(client, "aclose", None) or client.close
        await closer()
|
|
|
|
|
|
|
|
|
|
|
# Process-wide limiter instance; created on first call to get_rate_limiter().
_rate_limiter = None


def get_rate_limiter() -> DistributedRateLimiter:
    """Return the shared rate limiter, building it from settings on first use."""
    global _rate_limiter

    if _rate_limiter is not None:
        return _rate_limiter

    configured_limits = {
        "requests_per_minute": settings.RATE_LIMIT_PER_MINUTE,
        "requests_per_hour": settings.RATE_LIMIT_PER_HOUR,
        "requests_per_day": settings.RATE_LIMIT_PER_DAY,
    }
    _rate_limiter = DistributedRateLimiter(**configured_limits)
    return _rate_limiter