""" Request correlation ID management for distributed tracing. This module provides correlation ID generation, propagation, and context management across service boundaries. """ import uuid import asyncio from typing import Optional, Dict, Any, Callable, List from contextvars import ContextVar from functools import wraps import time from fastapi import Request, Response from starlette.middleware.base import BaseHTTPMiddleware from src.core import get_logger logger = get_logger(__name__) # Context variables for correlation tracking correlation_id_ctx: ContextVar[Optional[str]] = ContextVar('correlation_id', default=None) request_id_ctx: ContextVar[Optional[str]] = ContextVar('request_id', default=None) user_id_ctx: ContextVar[Optional[str]] = ContextVar('user_id', default=None) session_id_ctx: ContextVar[Optional[str]] = ContextVar('session_id', default=None) span_id_ctx: ContextVar[Optional[str]] = ContextVar('span_id', default=None) # Headers for correlation propagation CORRELATION_ID_HEADER = "X-Correlation-ID" REQUEST_ID_HEADER = "X-Request-ID" USER_ID_HEADER = "X-User-ID" SESSION_ID_HEADER = "X-Session-ID" SPAN_ID_HEADER = "X-Span-ID" class CorrelationContext: """ Utility class for managing correlation context. Provides methods to get, set, and propagate correlation IDs across async boundaries and service calls. """ @staticmethod def get_correlation_id() -> str: """ Get current correlation ID, generating one if needed. Returns: Correlation ID string """ correlation_id = correlation_id_ctx.get() if not correlation_id: correlation_id = str(uuid.uuid4()) correlation_id_ctx.set(correlation_id) return correlation_id @staticmethod def set_correlation_id(correlation_id: str): """ Set correlation ID in context. Args: correlation_id: Correlation ID to set """ correlation_id_ctx.set(correlation_id) @staticmethod def get_request_id() -> str: """ Get current request ID, generating one if needed. Returns: Request ID string """ request_id = request_id_ctx.get() if not request_id: request_id = str(uuid.uuid4()) request_id_ctx.set(request_id) return request_id @staticmethod def set_request_id(request_id: str): """ Set request ID in context. Args: request_id: Request ID to set """ request_id_ctx.set(request_id) @staticmethod def get_user_id() -> Optional[str]: """Get current user ID from context.""" return user_id_ctx.get() @staticmethod def set_user_id(user_id: str): """ Set user ID in context. Args: user_id: User ID to set """ user_id_ctx.set(user_id) @staticmethod def get_session_id() -> Optional[str]: """Get current session ID from context.""" return session_id_ctx.get() @staticmethod def set_session_id(session_id: str): """ Set session ID in context. Args: session_id: Session ID to set """ session_id_ctx.set(session_id) @staticmethod def get_span_id() -> Optional[str]: """Get current span ID from context.""" return span_id_ctx.get() @staticmethod def set_span_id(span_id: str): """ Set span ID in context. Args: span_id: Span ID to set """ span_id_ctx.set(span_id) @staticmethod def get_all_ids() -> Dict[str, Optional[str]]: """ Get all correlation IDs from context. Returns: Dictionary with all correlation IDs """ return { "correlation_id": correlation_id_ctx.get(), "request_id": request_id_ctx.get(), "user_id": user_id_ctx.get(), "session_id": session_id_ctx.get(), "span_id": span_id_ctx.get() } @staticmethod def clear_context(): """Clear all correlation context.""" correlation_id_ctx.set(None) request_id_ctx.set(None) user_id_ctx.set(None) session_id_ctx.set(None) span_id_ctx.set(None) @staticmethod def copy_context() -> Dict[str, Optional[str]]: """ Copy current context for propagation. Returns: Dictionary with current context values """ return CorrelationContext.get_all_ids() @staticmethod def restore_context(context: Dict[str, Optional[str]]): """ Restore context from dictionary. Args: context: Context dictionary to restore """ if context.get("correlation_id"): correlation_id_ctx.set(context["correlation_id"]) if context.get("request_id"): request_id_ctx.set(context["request_id"]) if context.get("user_id"): user_id_ctx.set(context["user_id"]) if context.get("session_id"): session_id_ctx.set(context["session_id"]) if context.get("span_id"): span_id_ctx.set(context["span_id"]) class CorrelationMiddleware(BaseHTTPMiddleware): """ Middleware for correlation ID management in FastAPI. Automatically extracts correlation IDs from headers, generates new ones if missing, and adds them to responses. """ def __init__(self, app, generate_request_id: bool = True): """ Initialize correlation middleware. Args: app: FastAPI application generate_request_id: Whether to generate request IDs """ super().__init__(app) self.generate_request_id = generate_request_id async def dispatch(self, request: Request, call_next: Callable) -> Response: """ Process request with correlation ID management. Args: request: Incoming request call_next: Next middleware in chain Returns: Response with correlation headers """ start_time = time.time() # Extract or generate correlation ID correlation_id = ( request.headers.get(CORRELATION_ID_HEADER) or str(uuid.uuid4()) ) CorrelationContext.set_correlation_id(correlation_id) # Extract or generate request ID if self.generate_request_id: request_id = ( request.headers.get(REQUEST_ID_HEADER) or str(uuid.uuid4()) ) CorrelationContext.set_request_id(request_id) # Extract user context if available user_id = request.headers.get(USER_ID_HEADER) if user_id: CorrelationContext.set_user_id(user_id) session_id = request.headers.get(SESSION_ID_HEADER) if session_id: CorrelationContext.set_session_id(session_id) span_id = request.headers.get(SPAN_ID_HEADER) if span_id: CorrelationContext.set_span_id(span_id) # Log request start logger.info( "Request started", extra={ "correlation_id": correlation_id, "request_id": CorrelationContext.get_request_id(), "method": request.method, "url": str(request.url), "user_agent": request.headers.get("user-agent"), "client_ip": request.client.host if request.client else None } ) try: # Process request response = await call_next(request) # Add correlation headers to response response.headers[CORRELATION_ID_HEADER] = correlation_id if self.generate_request_id: response.headers[REQUEST_ID_HEADER] = CorrelationContext.get_request_id() # Log successful response duration = time.time() - start_time logger.info( "Request completed", extra={ "correlation_id": correlation_id, "request_id": CorrelationContext.get_request_id(), "status_code": response.status_code, "duration_ms": duration * 1000, "response_size": response.headers.get("content-length") } ) return response except Exception as e: # Log error duration = time.time() - start_time logger.error( "Request failed", extra={ "correlation_id": correlation_id, "request_id": CorrelationContext.get_request_id(), "error": str(e), "error_type": type(e).__name__, "duration_ms": duration * 1000 }, exc_info=True ) raise finally: # Clear context after request CorrelationContext.clear_context() def propagate_correlation( headers: Optional[Dict[str, str]] = None ) -> Dict[str, str]: """ Generate headers for correlation propagation. Args: headers: Existing headers to extend Returns: Headers with correlation information """ propagation_headers = headers.copy() if headers else {} correlation_id = CorrelationContext.get_correlation_id() if correlation_id: propagation_headers[CORRELATION_ID_HEADER] = correlation_id request_id = CorrelationContext.get_request_id() if request_id: propagation_headers[REQUEST_ID_HEADER] = request_id user_id = CorrelationContext.get_user_id() if user_id: propagation_headers[USER_ID_HEADER] = user_id session_id = CorrelationContext.get_session_id() if session_id: propagation_headers[SESSION_ID_HEADER] = session_id span_id = CorrelationContext.get_span_id() if span_id: propagation_headers[SPAN_ID_HEADER] = span_id return propagation_headers def with_correlation(func: Callable) -> Callable: """ Decorator to preserve correlation context in async functions. Args: func: Function to wrap Returns: Wrapped function with correlation context """ @wraps(func) async def async_wrapper(*args, **kwargs): # Capture current context context = CorrelationContext.copy_context() try: if asyncio.iscoroutinefunction(func): result = await func(*args, **kwargs) else: result = func(*args, **kwargs) return result finally: # Restore context if it was cleared if not CorrelationContext.get_correlation_id() and context.get("correlation_id"): CorrelationContext.restore_context(context) @wraps(func) def sync_wrapper(*args, **kwargs): # For sync functions, just call directly return func(*args, **kwargs) if asyncio.iscoroutinefunction(func): return async_wrapper else: return sync_wrapper class CorrelationLogger: """ Logger wrapper that automatically includes correlation IDs. """ def __init__(self, logger_instance): """ Initialize correlation logger. Args: logger_instance: Logger instance to wrap """ self.logger = logger_instance def _add_correlation_extra(self, extra: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: """Add correlation IDs to log extra data.""" correlation_extra = extra.copy() if extra else {} # Add correlation IDs correlation_id = CorrelationContext.get_correlation_id() if correlation_id: correlation_extra["correlation_id"] = correlation_id request_id = CorrelationContext.get_request_id() if request_id: correlation_extra["request_id"] = request_id user_id = CorrelationContext.get_user_id() if user_id: correlation_extra["user_id"] = user_id session_id = CorrelationContext.get_session_id() if session_id: correlation_extra["session_id"] = session_id return correlation_extra def debug(self, msg: str, *args, extra: Optional[Dict[str, Any]] = None, **kwargs): """Log debug message with correlation IDs.""" self.logger.debug(msg, *args, extra=self._add_correlation_extra(extra), **kwargs) def info(self, msg: str, *args, extra: Optional[Dict[str, Any]] = None, **kwargs): """Log info message with correlation IDs.""" self.logger.info(msg, *args, extra=self._add_correlation_extra(extra), **kwargs) def warning(self, msg: str, *args, extra: Optional[Dict[str, Any]] = None, **kwargs): """Log warning message with correlation IDs.""" self.logger.warning(msg, *args, extra=self._add_correlation_extra(extra), **kwargs) def error(self, msg: str, *args, extra: Optional[Dict[str, Any]] = None, **kwargs): """Log error message with correlation IDs.""" self.logger.error(msg, *args, extra=self._add_correlation_extra(extra), **kwargs) def critical(self, msg: str, *args, extra: Optional[Dict[str, Any]] = None, **kwargs): """Log critical message with correlation IDs.""" self.logger.critical(msg, *args, extra=self._add_correlation_extra(extra), **kwargs) def get_correlation_logger(name: str) -> CorrelationLogger: """ Get a correlation-aware logger. Args: name: Logger name Returns: CorrelationLogger instance """ from src.core import get_logger base_logger = get_logger(name) return CorrelationLogger(base_logger) class RequestTracker: """ Track request lifecycle and performance metrics. """ def __init__(self): """Initialize request tracker.""" self.active_requests: Dict[str, Dict[str, Any]] = {} self.request_stats = { "total_requests": 0, "active_requests": 0, "avg_duration_ms": 0.0, "error_rate": 0.0 } def start_request( self, request_id: str, method: str, path: str, user_id: Optional[str] = None ): """ Start tracking a request. Args: request_id: Request ID method: HTTP method path: Request path user_id: Optional user ID """ self.active_requests[request_id] = { "start_time": time.time(), "method": method, "path": path, "user_id": user_id, "correlation_id": CorrelationContext.get_correlation_id() } self.request_stats["active_requests"] = len(self.active_requests) self.request_stats["total_requests"] += 1 def end_request( self, request_id: str, status_code: int, error: Optional[str] = None ) -> Optional[float]: """ End tracking a request. Args: request_id: Request ID status_code: HTTP status code error: Optional error message Returns: Request duration in seconds, or None if not found """ if request_id not in self.active_requests: return None request_info = self.active_requests.pop(request_id) duration = time.time() - request_info["start_time"] # Update stats self.request_stats["active_requests"] = len(self.active_requests) # Update average duration (simple moving average) current_avg = self.request_stats["avg_duration_ms"] new_avg = (current_avg + (duration * 1000)) / 2 self.request_stats["avg_duration_ms"] = new_avg # Update error rate if status_code >= 400 or error: total = self.request_stats["total_requests"] current_errors = self.request_stats["error_rate"] * (total - 1) new_error_rate = (current_errors + 1) / total self.request_stats["error_rate"] = new_error_rate return duration def get_active_requests(self) -> List[Dict[str, Any]]: """Get list of currently active requests.""" current_time = time.time() return [ { "request_id": req_id, "duration_ms": (current_time - info["start_time"]) * 1000, "method": info["method"], "path": info["path"], "user_id": info["user_id"], "correlation_id": info["correlation_id"] } for req_id, info in self.active_requests.items() ] def get_stats(self) -> Dict[str, Any]: """Get request tracking statistics.""" return self.request_stats.copy() # Global request tracker request_tracker = RequestTracker()