|
|
""" |
|
|
Request correlation ID management for distributed tracing. |
|
|
|
|
|
This module provides correlation ID generation, propagation, |
|
|
and context management across service boundaries. |
|
|
""" |
|
|
|
|
|
import uuid |
|
|
import asyncio |
|
|
from typing import Optional, Dict, Any, Callable, List |
|
|
from contextvars import ContextVar |
|
|
from functools import wraps |
|
|
import time |
|
|
|
|
|
from fastapi import Request, Response |
|
|
from starlette.middleware.base import BaseHTTPMiddleware |
|
|
|
|
|
from src.core import get_logger |
|
|
|
|
|
# Module-level logger, named after this module via the project's logger factory.
logger = get_logger(__name__)
|
|
|
|
|
|
|
|
# Context-local slots for each tracing identifier. Every slot defaults to
# None until a value is bound in the current execution context.
correlation_id_ctx: ContextVar[Optional[str]] = ContextVar("correlation_id", default=None)
request_id_ctx: ContextVar[Optional[str]] = ContextVar("request_id", default=None)
user_id_ctx: ContextVar[Optional[str]] = ContextVar("user_id", default=None)
session_id_ctx: ContextVar[Optional[str]] = ContextVar("session_id", default=None)
span_id_ctx: ContextVar[Optional[str]] = ContextVar("span_id", default=None)
|
|
|
|
|
|
|
|
# HTTP header names under which each identifier is read from incoming
# requests and written to responses / outgoing calls.
CORRELATION_ID_HEADER: str = "X-Correlation-ID"
REQUEST_ID_HEADER: str = "X-Request-ID"
USER_ID_HEADER: str = "X-User-ID"
SESSION_ID_HEADER: str = "X-Session-ID"
SPAN_ID_HEADER: str = "X-Span-ID"
|
|
|
|
|
|
|
|
class CorrelationContext:
    """
    Static accessors for the tracing identifiers held in context variables.

    Every method reads or writes the module-level ``ContextVar`` slots
    (correlation, request, user, session and span IDs); the class itself
    holds no state and is never instantiated by this module.
    """

    @staticmethod
    def get_correlation_id() -> str:
        """
        Return the correlation ID bound to the current context.

        When the slot is empty (``None`` or an empty string) a fresh UUID4
        string is generated, stored in the slot, and returned.

        Returns:
            Correlation ID string
        """
        existing = correlation_id_ctx.get()
        if existing:
            return existing
        generated = str(uuid.uuid4())
        correlation_id_ctx.set(generated)
        return generated

    @staticmethod
    def set_correlation_id(correlation_id: str):
        """
        Bind a correlation ID to the current context.

        Args:
            correlation_id: Correlation ID to store
        """
        correlation_id_ctx.set(correlation_id)

    @staticmethod
    def get_request_id() -> str:
        """
        Return the request ID bound to the current context.

        When the slot is empty (``None`` or an empty string) a fresh UUID4
        string is generated, stored in the slot, and returned.

        Returns:
            Request ID string
        """
        existing = request_id_ctx.get()
        if existing:
            return existing
        generated = str(uuid.uuid4())
        request_id_ctx.set(generated)
        return generated

    @staticmethod
    def set_request_id(request_id: str):
        """
        Bind a request ID to the current context.

        Args:
            request_id: Request ID to store
        """
        request_id_ctx.set(request_id)

    @staticmethod
    def get_user_id() -> Optional[str]:
        """Return the user ID bound to the current context, if any."""
        return user_id_ctx.get()

    @staticmethod
    def set_user_id(user_id: str):
        """
        Bind a user ID to the current context.

        Args:
            user_id: User ID to store
        """
        user_id_ctx.set(user_id)

    @staticmethod
    def get_session_id() -> Optional[str]:
        """Return the session ID bound to the current context, if any."""
        return session_id_ctx.get()

    @staticmethod
    def set_session_id(session_id: str):
        """
        Bind a session ID to the current context.

        Args:
            session_id: Session ID to store
        """
        session_id_ctx.set(session_id)

    @staticmethod
    def get_span_id() -> Optional[str]:
        """Return the span ID bound to the current context, if any."""
        return span_id_ctx.get()

    @staticmethod
    def set_span_id(span_id: str):
        """
        Bind a span ID to the current context.

        Args:
            span_id: Span ID to store
        """
        span_id_ctx.set(span_id)

    @staticmethod
    def get_all_ids() -> Dict[str, Optional[str]]:
        """
        Snapshot every identifier slot without generating missing values.

        Returns:
            Dictionary keyed by ``correlation_id``, ``request_id``,
            ``user_id``, ``session_id`` and ``span_id``; unset slots map
            to ``None``
        """
        snapshot: Dict[str, Optional[str]] = {}
        snapshot["correlation_id"] = correlation_id_ctx.get()
        snapshot["request_id"] = request_id_ctx.get()
        snapshot["user_id"] = user_id_ctx.get()
        snapshot["session_id"] = session_id_ctx.get()
        snapshot["span_id"] = span_id_ctx.get()
        return snapshot

    @staticmethod
    def clear_context():
        """Reset every identifier slot to ``None`` for the current context."""
        for slot in (
            correlation_id_ctx,
            request_id_ctx,
            user_id_ctx,
            session_id_ctx,
            span_id_ctx,
        ):
            slot.set(None)

    @staticmethod
    def copy_context() -> Dict[str, Optional[str]]:
        """
        Capture the current identifiers so they can be re-applied later.

        Returns:
            Same dictionary shape as ``get_all_ids``
        """
        return CorrelationContext.get_all_ids()

    @staticmethod
    def restore_context(context: Dict[str, Optional[str]]):
        """
        Re-bind identifiers from a previously captured dictionary.

        Only truthy entries are applied; missing, ``None`` or empty values
        leave the corresponding slot untouched.

        Args:
            context: Dictionary as produced by ``copy_context``
        """
        slots = (
            ("correlation_id", correlation_id_ctx),
            ("request_id", request_id_ctx),
            ("user_id", user_id_ctx),
            ("session_id", session_id_ctx),
            ("span_id", span_id_ctx),
        )
        for key, slot in slots:
            value = context.get(key)
            if value:
                slot.set(value)
|
|
|
|
|
|
|
|
class CorrelationMiddleware(BaseHTTPMiddleware):
    """
    Middleware for correlation ID management in FastAPI.

    For each request it binds the correlation identifiers found in the
    incoming headers (generating a correlation ID, and optionally a request
    ID, when absent), logs the start and outcome of the request, echoes the
    IDs on the response headers, and clears the context when done.
    """

    def __init__(self, app, generate_request_id: bool = True):
        """
        Initialize correlation middleware.

        Args:
            app: FastAPI application
            generate_request_id: When True, bind a request ID (taken from
                the ``X-Request-ID`` header or newly generated) and add it
                to the response headers
        """
        super().__init__(app)
        self.generate_request_id = generate_request_id

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """
        Process request with correlation ID management.

        Args:
            request: Incoming request
            call_next: Next middleware in chain

        Returns:
            Response with correlation headers

        Raises:
            Exception: Any exception raised downstream is logged with the
                correlation fields and re-raised unchanged.
        """
        # Monotonic clock: elapsed time stays correct even if the system
        # wall clock is adjusted while the request is in flight.
        started = time.perf_counter()

        # Reuse the caller's correlation ID when supplied so the trace
        # continues across services; otherwise start a new one.
        correlation_id = (
            request.headers.get(CORRELATION_ID_HEADER) or
            str(uuid.uuid4())
        )
        CorrelationContext.set_correlation_id(correlation_id)

        if self.generate_request_id:
            request_id = (
                request.headers.get(REQUEST_ID_HEADER) or
                str(uuid.uuid4())
            )
            CorrelationContext.set_request_id(request_id)

        # Optional identifiers are bound only when the caller sent them.
        user_id = request.headers.get(USER_ID_HEADER)
        if user_id:
            CorrelationContext.set_user_id(user_id)

        session_id = request.headers.get(SESSION_ID_HEADER)
        if session_id:
            CorrelationContext.set_session_id(session_id)

        span_id = request.headers.get(SPAN_ID_HEADER)
        if span_id:
            CorrelationContext.set_span_id(span_id)

        # NOTE: get_request_id() generates and binds an ID when none is set,
        # so log records carry a request_id even when generate_request_id
        # is False.
        logger.info(
            "Request started",
            extra={
                "correlation_id": correlation_id,
                "request_id": CorrelationContext.get_request_id(),
                "method": request.method,
                "url": str(request.url),
                "user_agent": request.headers.get("user-agent"),
                "client_ip": request.client.host if request.client else None
            }
        )

        try:
            response = await call_next(request)

            # Echo the identifiers back so the caller can correlate logs.
            response.headers[CORRELATION_ID_HEADER] = correlation_id

            if self.generate_request_id:
                response.headers[REQUEST_ID_HEADER] = CorrelationContext.get_request_id()

            duration = time.perf_counter() - started
            logger.info(
                "Request completed",
                extra={
                    "correlation_id": correlation_id,
                    "request_id": CorrelationContext.get_request_id(),
                    "status_code": response.status_code,
                    "duration_ms": duration * 1000,
                    "response_size": response.headers.get("content-length")
                }
            )

            return response

        except Exception as e:
            duration = time.perf_counter() - started
            logger.error(
                "Request failed",
                extra={
                    "correlation_id": correlation_id,
                    "request_id": CorrelationContext.get_request_id(),
                    "error": str(e),
                    "error_type": type(e).__name__,
                    "duration_ms": duration * 1000
                },
                exc_info=True
            )
            raise

        finally:
            # Always reset the slots so identifiers from this request are
            # not left bound in the current context.
            CorrelationContext.clear_context()
|
|
|
|
|
|
|
|
def propagate_correlation(
    headers: Optional[Dict[str, str]] = None
) -> Dict[str, str]:
    """
    Build outgoing headers carrying the current correlation identifiers.

    The input mapping is copied, never mutated. Correlation and request IDs
    are read through the generating getters, so they are created (and bound
    to the context) if not yet set; user, session and span IDs are added
    only when present.

    Args:
        headers: Existing headers to extend

    Returns:
        Headers with correlation information
    """
    outgoing = headers.copy() if headers else {}

    # Header name paired with the accessor that supplies its value; the
    # accessors are invoked in this order.
    sources = (
        (CORRELATION_ID_HEADER, CorrelationContext.get_correlation_id),
        (REQUEST_ID_HEADER, CorrelationContext.get_request_id),
        (USER_ID_HEADER, CorrelationContext.get_user_id),
        (SESSION_ID_HEADER, CorrelationContext.get_session_id),
        (SPAN_ID_HEADER, CorrelationContext.get_span_id),
    )
    for header_name, read_value in sources:
        value = read_value()
        if value:
            outgoing[header_name] = value

    return outgoing
|
|
|
|
|
|
|
|
def with_correlation(func: Callable) -> Callable:
    """
    Decorator to preserve correlation context in async functions.

    For coroutine functions the current identifiers are snapshotted before
    the call. If the correlation ID slot is empty once the call finishes
    (for example because the callee cleared the context), the snapshot is
    restored. Plain functions are wrapped in a transparent pass-through.

    Args:
        func: Function to wrap

    Returns:
        Wrapped function with correlation context
    """
    if not asyncio.iscoroutinefunction(func):
        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        return sync_wrapper

    @wraps(func)
    async def async_wrapper(*args, **kwargs):
        snapshot = CorrelationContext.copy_context()

        try:
            return await func(*args, **kwargs)
        finally:
            # Read the slot through get_all_ids(), which never generates a
            # value. get_correlation_id() would mint a fresh ID whenever the
            # slot is empty, making this check always false and the restore
            # unreachable.
            current_id = CorrelationContext.get_all_ids().get("correlation_id")
            if snapshot.get("correlation_id") and not current_id:
                CorrelationContext.restore_context(snapshot)

    return async_wrapper
|
|
|
|
|
|
|
|
class CorrelationLogger:
    """
    Thin wrapper around a logger that attaches correlation IDs to ``extra``.
    """

    def __init__(self, logger_instance):
        """
        Wrap an existing logger.

        Args:
            logger_instance: Logger whose ``debug``/``info``/``warning``/
                ``error``/``critical`` methods receive the enriched records
        """
        self.logger = logger_instance

    def _add_correlation_extra(self, extra: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Return a copy of ``extra`` extended with the current context IDs."""
        enriched = extra.copy() if extra else {}

        # Correlation and request getters generate an ID when none is bound;
        # user and session IDs are included only when set.
        lookups = (
            ("correlation_id", CorrelationContext.get_correlation_id),
            ("request_id", CorrelationContext.get_request_id),
            ("user_id", CorrelationContext.get_user_id),
            ("session_id", CorrelationContext.get_session_id),
        )
        for key, read_value in lookups:
            value = read_value()
            if value:
                enriched[key] = value

        return enriched

    def debug(self, msg: str, *args, extra: Optional[Dict[str, Any]] = None, **kwargs):
        """Emit a debug record enriched with correlation IDs."""
        enriched = self._add_correlation_extra(extra)
        self.logger.debug(msg, *args, extra=enriched, **kwargs)

    def info(self, msg: str, *args, extra: Optional[Dict[str, Any]] = None, **kwargs):
        """Emit an info record enriched with correlation IDs."""
        enriched = self._add_correlation_extra(extra)
        self.logger.info(msg, *args, extra=enriched, **kwargs)

    def warning(self, msg: str, *args, extra: Optional[Dict[str, Any]] = None, **kwargs):
        """Emit a warning record enriched with correlation IDs."""
        enriched = self._add_correlation_extra(extra)
        self.logger.warning(msg, *args, extra=enriched, **kwargs)

    def error(self, msg: str, *args, extra: Optional[Dict[str, Any]] = None, **kwargs):
        """Emit an error record enriched with correlation IDs."""
        enriched = self._add_correlation_extra(extra)
        self.logger.error(msg, *args, extra=enriched, **kwargs)

    def critical(self, msg: str, *args, extra: Optional[Dict[str, Any]] = None, **kwargs):
        """Emit a critical record enriched with correlation IDs."""
        enriched = self._add_correlation_extra(extra)
        self.logger.critical(msg, *args, extra=enriched, **kwargs)
|
|
|
|
|
|
|
|
def get_correlation_logger(name: str) -> CorrelationLogger:
    """
    Build a correlation-aware logger for the given name.

    Args:
        name: Logger name passed to the project's ``get_logger`` factory

    Returns:
        CorrelationLogger wrapping the logger obtained for ``name``
    """
    # Function-scope import kept as in the original implementation.
    from src.core import get_logger

    return CorrelationLogger(get_logger(name))
|
|
|
|
|
|
|
|
class RequestTracker:
    """
    Track request lifecycle and performance metrics.

    In-flight requests are kept in ``active_requests`` keyed by request ID.
    ``request_stats`` holds aggregate counters: requests started, requests
    currently active, mean duration of completed requests in milliseconds,
    and the fraction of completed requests that failed.
    """

    def __init__(self):
        """Initialize request tracker."""
        self.active_requests: Dict[str, Dict[str, Any]] = {}
        self.request_stats = {
            "total_requests": 0,
            "active_requests": 0,
            "avg_duration_ms": 0.0,
            "error_rate": 0.0
        }
        # Exact counters backing avg_duration_ms and error_rate, so both are
        # true ratios over completed requests rather than approximations
        # derived from the previous rate.
        self._completed_requests = 0
        self._failed_requests = 0

    def start_request(
        self,
        request_id: str,
        method: str,
        path: str,
        user_id: Optional[str] = None
    ):
        """
        Start tracking a request.

        Args:
            request_id: Request ID
            method: HTTP method
            path: Request path
            user_id: Optional user ID
        """
        self.active_requests[request_id] = {
            "start_time": time.time(),
            "method": method,
            "path": path,
            "user_id": user_id,
            "correlation_id": CorrelationContext.get_correlation_id()
        }

        self.request_stats["active_requests"] = len(self.active_requests)
        self.request_stats["total_requests"] += 1

    def end_request(
        self,
        request_id: str,
        status_code: int,
        error: Optional[str] = None
    ) -> Optional[float]:
        """
        End tracking a request and fold it into the aggregate statistics.

        A request counts as failed when ``status_code`` is 400 or above, or
        when ``error`` is truthy.

        Args:
            request_id: Request ID
            status_code: HTTP status code
            error: Optional error message

        Returns:
            Request duration in seconds, or None if not found
        """
        request_info = self.active_requests.pop(request_id, None)
        if request_info is None:
            return None

        duration = time.time() - request_info["start_time"]
        duration_ms = duration * 1000

        self._completed_requests += 1
        if status_code >= 400 or error:
            self._failed_requests += 1
        completed = self._completed_requests

        stats = self.request_stats
        stats["active_requests"] = len(self.active_requests)

        # Incremental arithmetic mean over all completed requests: the first
        # sample becomes the average instead of being halved against the
        # 0.0 seed, and every sample carries equal weight.
        stats["avg_duration_ms"] += (duration_ms - stats["avg_duration_ms"]) / completed

        # Recomputed on every completion so successful requests lower the
        # rate as well as failures raising it.
        stats["error_rate"] = self._failed_requests / completed

        return duration

    def get_active_requests(self) -> List[Dict[str, Any]]:
        """Get list of currently active requests with their elapsed time."""
        current_time = time.time()
        return [
            {
                "request_id": req_id,
                "duration_ms": (current_time - info["start_time"]) * 1000,
                "method": info["method"],
                "path": info["path"],
                "user_id": info["user_id"],
                "correlation_id": info["correlation_id"]
            }
            for req_id, info in self.active_requests.items()
        ]

    def get_stats(self) -> Dict[str, Any]:
        """Get a shallow copy of the request tracking statistics."""
        return self.request_stats.copy()
|
|
|
|
|
|
|
|
|
|
|
# Module-level tracker instance created at import time.
request_tracker: RequestTracker = RequestTracker()