anderson-ufrj commited on
Commit
48a4081
·
1 Parent(s): 07ba503

test(rate-limit): implement rate limiting tests

Browse files

- Test token bucket implementation
- Test multi-window rate limiting
- Test memory-based rate limit store
- Test Redis-based distributed rate limiting
- Test rate limit middleware integration
- Add client identification tests

Files changed (1) hide show
  1. tests/unit/test_rate_limiting.py +395 -0
tests/unit/test_rate_limiting.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unit tests for rate limiting middleware."""
2
+ import pytest
3
+ from unittest.mock import MagicMock, patch, AsyncMock
4
+ from fastapi import Request
5
+ from datetime import datetime, timedelta
6
+ import asyncio
7
+
8
+ from src.api.middleware.rate_limiting import (
9
+ RateLimitMiddleware,
10
+ RateLimitConfig,
11
+ TokenBucket,
12
+ MemoryRateLimitStore,
13
+ RedisRateLimitStore,
14
+ get_client_id
15
+ )
16
+ from src.core.exceptions import RateLimitError
17
+
18
+
19
+ class TestTokenBucket:
20
+ """Test TokenBucket implementation."""
21
+
22
+ def test_token_bucket_initialization(self):
23
+ """Test token bucket initializes with full capacity."""
24
+ capacity = 10
25
+ refill_rate = 2.0 # 2 tokens per second
26
+
27
+ bucket = TokenBucket(capacity, refill_rate)
28
+
29
+ assert bucket.capacity == capacity
30
+ assert bucket.tokens == capacity
31
+ assert bucket.refill_rate == refill_rate
32
+
33
+ def test_consume_tokens_success(self):
34
+ """Test consuming tokens when available."""
35
+ bucket = TokenBucket(capacity=10, refill_rate=1.0)
36
+
37
+ # Should be able to consume 5 tokens
38
+ assert bucket.consume(5) is True
39
+ assert bucket.tokens == 5
40
+
41
+ def test_consume_tokens_insufficient(self):
42
+ """Test consuming tokens when insufficient."""
43
+ bucket = TokenBucket(capacity=10, refill_rate=1.0)
44
+
45
+ # Consume all tokens
46
+ bucket.consume(10)
47
+
48
+ # Should not be able to consume more
49
+ assert bucket.consume(1) is False
50
+ assert bucket.tokens == 0
51
+
52
+ def test_token_refill(self):
53
+ """Test token refill over time."""
54
+ bucket = TokenBucket(capacity=10, refill_rate=10.0) # 10 tokens/second
55
+
56
+ # Consume all tokens
57
+ bucket.consume(10)
58
+ assert bucket.tokens == 0
59
+
60
+ # Wait for refill (0.5 seconds = 5 tokens)
61
+ bucket.last_refill = datetime.utcnow() - timedelta(seconds=0.5)
62
+
63
+ # Try to consume, should trigger refill
64
+ assert bucket.consume(5) is True
65
+ assert bucket.tokens == 0 # 5 refilled, 5 consumed
66
+
67
+ def test_token_refill_cap(self):
68
+ """Test token refill doesn't exceed capacity."""
69
+ bucket = TokenBucket(capacity=10, refill_rate=100.0) # Very fast refill
70
+
71
+ # Wait a long time
72
+ bucket.last_refill = datetime.utcnow() - timedelta(seconds=10)
73
+
74
+ # Refill should cap at capacity
75
+ bucket._refill()
76
+ assert bucket.tokens == 10 # Not more than capacity
77
+
78
+ def test_concurrent_token_consumption(self):
79
+ """Test thread-safe token consumption."""
80
+ bucket = TokenBucket(capacity=100, refill_rate=0) # No refill
81
+ consumed_count = 0
82
+
83
+ def consume_tokens():
84
+ nonlocal consumed_count
85
+ if bucket.consume(1):
86
+ consumed_count += 1
87
+
88
+ # Simulate concurrent access
89
+ import threading
90
+ threads = []
91
+ for _ in range(150): # More than capacity
92
+ thread = threading.Thread(target=consume_tokens)
93
+ threads.append(thread)
94
+ thread.start()
95
+
96
+ for thread in threads:
97
+ thread.join()
98
+
99
+ # Should only consume up to capacity
100
+ assert consumed_count == 100
101
+
102
+
103
+ class TestMemoryRateLimitStore:
104
+ """Test in-memory rate limit store."""
105
+
106
+ @pytest.fixture
107
+ def store(self):
108
+ """Create memory store instance."""
109
+ return MemoryRateLimitStore()
110
+
111
+ @pytest.mark.asyncio
112
+ async def test_get_bucket_creates_new(self, store):
113
+ """Test getting bucket creates new one if not exists."""
114
+ client_id = "test_client"
115
+ endpoint = "/api/test"
116
+ config = RateLimitConfig(requests_per_minute=60)
117
+
118
+ bucket = await store.get_bucket(client_id, endpoint, config)
119
+
120
+ assert bucket is not None
121
+ assert bucket.capacity == 60
122
+ assert bucket.refill_rate == 1.0 # 60 per minute = 1 per second
123
+
124
+ @pytest.mark.asyncio
125
+ async def test_get_bucket_returns_existing(self, store):
126
+ """Test getting bucket returns existing one."""
127
+ client_id = "test_client"
128
+ endpoint = "/api/test"
129
+ config = RateLimitConfig(requests_per_minute=60)
130
+
131
+ # Get bucket twice
132
+ bucket1 = await store.get_bucket(client_id, endpoint, config)
133
+ bucket2 = await store.get_bucket(client_id, endpoint, config)
134
+
135
+ # Should be the same instance
136
+ assert bucket1 is bucket2
137
+
138
+ @pytest.mark.asyncio
139
+ async def test_different_clients_different_buckets(self, store):
140
+ """Test different clients get different buckets."""
141
+ endpoint = "/api/test"
142
+ config = RateLimitConfig(requests_per_minute=60)
143
+
144
+ bucket1 = await store.get_bucket("client1", endpoint, config)
145
+ bucket2 = await store.get_bucket("client2", endpoint, config)
146
+
147
+ # Should be different instances
148
+ assert bucket1 is not bucket2
149
+
150
+ @pytest.mark.asyncio
151
+ async def test_cleanup_old_buckets(self, store):
152
+ """Test cleanup of old unused buckets."""
153
+ client_id = "old_client"
154
+ endpoint = "/api/test"
155
+ config = RateLimitConfig(requests_per_minute=60)
156
+
157
+ # Create bucket and mark it as old
158
+ bucket = await store.get_bucket(client_id, endpoint, config)
159
+ key = f"{client_id}:{endpoint}"
160
+
161
+ # Mark as old by setting last refill to past
162
+ bucket.last_refill = datetime.utcnow() - timedelta(hours=2)
163
+
164
+ # Run cleanup
165
+ await store.cleanup()
166
+
167
+ # Old bucket should be removed
168
+ assert key not in store._buckets
169
+
170
+
171
+ class TestRedisRateLimitStore:
172
+ """Test Redis-based rate limit store."""
173
+
174
+ @pytest.fixture
175
+ def mock_redis(self):
176
+ """Create mock Redis client."""
177
+ redis = AsyncMock()
178
+ return redis
179
+
180
+ @pytest.fixture
181
+ def store(self, mock_redis):
182
+ """Create Redis store instance."""
183
+ with patch("src.api.middleware.rate_limiting.get_redis_client", return_value=mock_redis):
184
+ return RedisRateLimitStore()
185
+
186
+ @pytest.mark.asyncio
187
+ async def test_consume_token_success(self, store, mock_redis):
188
+ """Test consuming token from Redis."""
189
+ client_id = "test_client"
190
+ endpoint = "/api/test"
191
+ config = RateLimitConfig(requests_per_minute=60)
192
+
193
+ # Mock Redis responses
194
+ mock_redis.get.return_value = b"50" # Current tokens
195
+ mock_redis.set.return_value = True
196
+
197
+ # Should be able to consume
198
+ result = await store.consume_token(client_id, endpoint, config)
199
+
200
+ assert result is True
201
+ assert mock_redis.get.called
202
+ assert mock_redis.set.called
203
+
204
+ @pytest.mark.asyncio
205
+ async def test_consume_token_insufficient(self, store, mock_redis):
206
+ """Test consuming token when insufficient in Redis."""
207
+ client_id = "test_client"
208
+ endpoint = "/api/test"
209
+ config = RateLimitConfig(requests_per_minute=60)
210
+
211
+ # Mock Redis responses - no tokens left
212
+ mock_redis.get.return_value = b"0"
213
+
214
+ # Should not be able to consume
215
+ result = await store.consume_token(client_id, endpoint, config)
216
+
217
+ assert result is False
218
+
219
+ @pytest.mark.asyncio
220
+ async def test_redis_connection_error_fallback(self, store, mock_redis):
221
+ """Test fallback when Redis is unavailable."""
222
+ client_id = "test_client"
223
+ endpoint = "/api/test"
224
+ config = RateLimitConfig(requests_per_minute=60)
225
+
226
+ # Mock Redis connection error
227
+ mock_redis.get.side_effect = Exception("Redis connection error")
228
+
229
+ # Should fall back gracefully (allow request)
230
+ result = await store.consume_token(client_id, endpoint, config)
231
+
232
+ assert result is True # Fail open for availability
233
+
234
+
235
+ class TestRateLimitMiddleware:
236
+ """Test rate limiting middleware."""
237
+
238
+ @pytest.fixture
239
+ def middleware(self):
240
+ """Create middleware instance."""
241
+ app = MagicMock()
242
+ config = {
243
+ "/api/v1/investigations": RateLimitConfig(
244
+ requests_per_minute=30,
245
+ requests_per_hour=500
246
+ ),
247
+ "/api/v1/analysis": RateLimitConfig(
248
+ requests_per_minute=60,
249
+ requests_per_hour=1000
250
+ )
251
+ }
252
+ return RateLimitMiddleware(app, config)
253
+
254
+ @pytest.fixture
255
+ def mock_request(self):
256
+ """Create mock request."""
257
+ request = MagicMock(spec=Request)
258
+ request.client.host = "192.168.1.100"
259
+ request.headers = {
260
+ "user-agent": "TestClient/1.0"
261
+ }
262
+ request.url.path = "/api/v1/investigations"
263
+ request.method = "GET"
264
+ request.state = MagicMock()
265
+ return request
266
+
267
+ @pytest.mark.asyncio
268
+ async def test_rate_limit_allows_under_limit(self, middleware, mock_request):
269
+ """Test requests are allowed under rate limit."""
270
+ response = MagicMock()
271
+ call_next = AsyncMock(return_value=response)
272
+
273
+ # Should allow request
274
+ result = await middleware.dispatch(mock_request, call_next)
275
+
276
+ assert result == response
277
+ assert call_next.called
278
+
279
+ @pytest.mark.asyncio
280
+ async def test_rate_limit_blocks_over_limit(self, middleware, mock_request):
281
+ """Test requests are blocked over rate limit."""
282
+ call_next = AsyncMock()
283
+
284
+ # Exhaust rate limit
285
+ config = middleware.endpoint_limits["/api/v1/investigations"]
286
+ bucket = await middleware.store.get_bucket(
287
+ "192.168.1.100",
288
+ "/api/v1/investigations",
289
+ config
290
+ )
291
+ bucket.tokens = 0 # No tokens left
292
+
293
+ # Should block request
294
+ result = await middleware.dispatch(mock_request, call_next)
295
+
296
+ assert result.status_code == 429
297
+ assert b"Rate limit exceeded" in result.body
298
+ assert not call_next.called
299
+
300
+ @pytest.mark.asyncio
301
+ async def test_rate_limit_headers(self, middleware, mock_request):
302
+ """Test rate limit headers are added to response."""
303
+ response = MagicMock()
304
+ response.headers = {}
305
+ call_next = AsyncMock(return_value=response)
306
+
307
+ # Process request
308
+ result = await middleware.dispatch(mock_request, call_next)
309
+
310
+ # Check rate limit headers
311
+ assert "X-RateLimit-Limit" in result.headers
312
+ assert "X-RateLimit-Remaining" in result.headers
313
+ assert "X-RateLimit-Reset" in result.headers
314
+
315
+ @pytest.mark.asyncio
316
+ async def test_authenticated_user_priority(self, middleware, mock_request):
317
+ """Test authenticated users get different rate limits."""
318
+ # Add authenticated user to request
319
+ mock_request.state.user = MagicMock(id="user123", role="premium")
320
+
321
+ response = MagicMock()
322
+ call_next = AsyncMock(return_value=response)
323
+
324
+ # Should use user ID for rate limiting
325
+ result = await middleware.dispatch(mock_request, call_next)
326
+
327
+ assert result == response
328
+ # Verify different bucket would be used for user
329
+
330
+ @pytest.mark.asyncio
331
+ async def test_exempt_endpoints(self, middleware):
332
+ """Test certain endpoints are exempt from rate limiting."""
333
+ # Health check should be exempt
334
+ request = MagicMock()
335
+ request.url.path = "/health"
336
+ request.client.host = "192.168.1.100"
337
+
338
+ response = MagicMock()
339
+ call_next = AsyncMock(return_value=response)
340
+
341
+ # Should not apply rate limiting
342
+ result = await middleware.dispatch(request, call_next)
343
+
344
+ assert result == response
345
+ assert "X-RateLimit-Limit" not in response.headers
346
+
347
+
348
+ class TestGetClientId:
349
+ """Test client ID extraction."""
350
+
351
+ def test_get_client_id_from_ip(self):
352
+ """Test getting client ID from IP address."""
353
+ request = MagicMock()
354
+ request.client.host = "192.168.1.100"
355
+ request.state = MagicMock()
356
+ request.state.user = None
357
+
358
+ client_id = get_client_id(request)
359
+
360
+ assert client_id == "192.168.1.100"
361
+
362
+ def test_get_client_id_from_user(self):
363
+ """Test getting client ID from authenticated user."""
364
+ request = MagicMock()
365
+ request.client.host = "192.168.1.100"
366
+ request.state = MagicMock()
367
+ request.state.user = MagicMock(id="user123")
368
+
369
+ client_id = get_client_id(request)
370
+
371
+ assert client_id == "user:user123"
372
+
373
+ def test_get_client_id_from_api_key(self):
374
+ """Test getting client ID from API key."""
375
+ request = MagicMock()
376
+ request.client.host = "192.168.1.100"
377
+ request.headers = {"X-API-Key": "sk_test_abc123"}
378
+ request.state = MagicMock()
379
+ request.state.user = None
380
+
381
+ client_id = get_client_id(request)
382
+
383
+ assert client_id == "api:sk_test_abc123"
384
+
385
+ def test_get_client_id_with_forwarded_ip(self):
386
+ """Test getting client ID from X-Forwarded-For header."""
387
+ request = MagicMock()
388
+ request.client.host = "10.0.0.1" # Proxy IP
389
+ request.headers = {"X-Forwarded-For": "203.0.113.1, 10.0.0.1"}
390
+ request.state = MagicMock()
391
+ request.state.user = None
392
+
393
+ client_id = get_client_id(request)
394
+
395
+ assert client_id == "203.0.113.1" # Original client IP