"""
Redis cache service for chat responses and investigations.
This service provides:
- Caching of frequent chat responses
- Investigation results caching
- Session state persistence
- Distributed cache for scalability
"""

import asyncio
import hashlib
import zlib  # For compression
from datetime import datetime
from enum import Enum
from functools import wraps
from typing import Optional, Any, Dict, List

import redis.asyncio as redis
from redis.asyncio.connection import ConnectionPool
from redis.exceptions import RedisError

from src.core import get_logger, settings
from src.core import json_utils
from src.core.exceptions import CacheError
from src.core.json_utils import dumps, loads, dumps_bytes

logger = get_logger(__name__)


class CacheTTL(Enum):
    """Cache Time-To-Live constants."""

    SHORT = 300  # 5 minutes
    MEDIUM = 1800  # 30 minutes
    LONG = 3600  # 1 hour
    VERY_LONG = 86400  # 24 hours
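
# Illustrative usage (a sketch, not part of the service API): CacheTTL members
# wrap plain integers, so callers unwrap .value where a method expects an int TTL:
#
#     await cache_service.set("cidadao:demo:key", {"ok": True}, ttl=CacheTTL.MEDIUM.value)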


class CacheService:
    """Redis-based caching service for Cidadão.AI."""

    def __init__(self):
        """Initialize Redis connection pool."""
        self.pool: Optional[ConnectionPool] = None
        self.redis: Optional[redis.Redis] = None
        self._initialized = False

        # Cache TTLs (in seconds)
        self.TTL_CHAT_RESPONSE = 300  # 5 minutes for chat responses
        self.TTL_INVESTIGATION = 3600  # 1 hour for investigation results
        self.TTL_SESSION = 86400  # 24 hours for session data
        self.TTL_AGENT_CONTEXT = 1800  # 30 minutes for agent context
        self.TTL_SEARCH_RESULTS = 600  # 10 minutes for search results

        # Stampede protection settings
        self.STAMPEDE_DELTA = 10  # seconds before expiry to refresh
        self.STAMPEDE_BETA = 1.0  # randomization factor

    async def initialize(self):
        """Initialize Redis connection."""
        if self._initialized:
            return

        try:
            # Create connection pool
            self.pool = ConnectionPool.from_url(
                settings.redis_url,
                decode_responses=True,
                max_connections=50,
                socket_keepalive=True,
                socket_keepalive_options={
                    1: 1,  # TCP_KEEPIDLE
                    2: 1,  # TCP_KEEPINTVL
                    3: 3,  # TCP_KEEPCNT
                },
            )

            # Create Redis client
            self.redis = redis.Redis(connection_pool=self.pool)

            # Test connection
            await self.redis.ping()

            self._initialized = True
            logger.info("Redis cache service initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize Redis: {e}")
            raise CacheError(f"Redis initialization failed: {str(e)}")

    async def close(self):
        """Close Redis connections."""
        if self.redis:
            await self.redis.close()
        if self.pool:
            await self.pool.disconnect()
        self._initialized = False

    def _generate_key(self, prefix: str, *args) -> str:
        """Generate cache key from prefix and arguments."""
        # Create a consistent key from arguments
        key_parts = [str(arg) for arg in args]
        key_data = ":".join(key_parts)

        # Hash long keys to avoid Redis key size limits
        if len(key_data) > 100:
            key_hash = hashlib.md5(key_data.encode()).hexdigest()
            return f"cidadao:{prefix}:{key_hash}"

        return f"cidadao:{prefix}:{key_data}"

    async def get(self, key: str, decompress: bool = False) -> Optional[Any]:
        """Get value from cache with optional decompression."""
        if not self._initialized:
            await self.initialize()

        try:
            value = await self.redis.get(key)
            if value:
                # Decompress if needed
                if decompress and isinstance(value, bytes):
                    try:
                        value = zlib.decompress(value)
                    except zlib.error:
                        pass  # Not compressed

                # Try to deserialize JSON
                try:
                    return loads(value)
                except Exception:
                    return value
            return None
        except RedisError as e:
            logger.error(f"Redis get error: {e}")
            return None

    async def set(self, key: str, value: Any, ttl: Optional[int] = None, compress: bool = False) -> bool:
        """Set value in cache with optional TTL and compression."""
        if not self._initialized:
            await self.initialize()

        try:
            # Serialize complex objects to JSON
            if isinstance(value, (dict, list)):
                value = dumps_bytes(value)
            elif not isinstance(value, bytes):
                value = str(value).encode('utf-8')

            # Compress if requested and value is large enough
            if compress and len(value) > 1024:  # Compress if > 1KB
                value = zlib.compress(value, level=6)

            if ttl:
                await self.redis.setex(key, ttl, value)
            else:
                await self.redis.set(key, value)
            return True
        except RedisError as e:
            logger.error(f"Redis set error: {e}")
            return False

    async def delete(self, key: str) -> bool:
        """Delete key from cache."""
        if not self._initialized:
            await self.initialize()

        try:
            result = await self.redis.delete(key)
            return result > 0
        except RedisError as e:
            logger.error(f"Redis delete error: {e}")
            return False

    async def get_with_stampede_protection(
        self,
        key: str,
        ttl: int,
        refresh_callback=None,
        decompress: bool = False
    ) -> Optional[Any]:
        """
        Get value with cache stampede protection using probabilistic early expiration.

        Args:
            key: Cache key
            ttl: Time to live for the cache
            refresh_callback: Async function to refresh cache if needed
            decompress: Whether to decompress the value

        Returns:
            Cached value or None
        """
        if not self._initialized:
            await self.initialize()

        # Get value and remaining TTL in a single round trip
        pipeline = self.redis.pipeline()
        pipeline.get(key)
        pipeline.ttl(key)
        value, remaining_ttl = await pipeline.execute()

        if value is None:
            return None

        # Decompress and deserialize
        if decompress and isinstance(value, bytes):
            try:
                value = zlib.decompress(value)
            except zlib.error:
                pass

        try:
            result = loads(value)
        except Exception:
            result = value

        # Check if we should refresh early to prevent stampede
        if refresh_callback and remaining_ttl > 0:
            import random
            import math

            # XFetch algorithm for cache stampede prevention:
            # refresh with increasing probability as expiry approaches
            delta = self.STAMPEDE_DELTA * math.log(random.random()) * self.STAMPEDE_BETA
            if remaining_ttl < abs(delta):
                # Refresh cache asynchronously
                asyncio.create_task(self._refresh_cache(key, ttl, refresh_callback))

        return result

    async def _refresh_cache(self, key: str, ttl: int, refresh_callback):
        """Refresh cache value asynchronously."""
        try:
            new_value = await refresh_callback()
            if new_value is not None:
                await self.set(key, new_value, ttl=ttl, compress=len(dumps(new_value)) > 1024)
        except Exception as e:
            logger.error(f"Error refreshing cache for key {key}: {e}")

    # Chat-specific methods

    async def cache_chat_response(
        self,
        message: str,
        response: Dict[str, Any],
        intent: Optional[str] = None
    ) -> bool:
        """Cache a chat response for a given message."""
        # Generate key from message and optional intent
        key = self._generate_key("chat", message.lower().strip(), intent)

        # Store response with metadata
        cache_data = {
            "response": response,
            "cached_at": datetime.utcnow().isoformat(),
            "hit_count": 0
        }

        return await self.set(key, cache_data, self.TTL_CHAT_RESPONSE, compress=True)

    async def get_cached_chat_response(
        self,
        message: str,
        intent: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """Get cached chat response if available."""
        key = self._generate_key("chat", message.lower().strip(), intent)
        cache_data = await self.get(key, decompress=True)

        if cache_data:
            # Increment hit count
            cache_data["hit_count"] += 1
            await self.set(key, cache_data, self.TTL_CHAT_RESPONSE, compress=True)

            logger.info(f"Cache hit for chat message: {message[:50]}...")
            return cache_data["response"]

        return None
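
    # Typical call pattern in a chat handler (illustrative sketch; the handler and
    # generate_response names here are hypothetical, not part of this module):
    #
    #     cached = await cache_service.get_cached_chat_response(user_message, intent)
    #     if cached is None:
    #         cached = await generate_response(user_message)
    #         await cache_service.cache_chat_response(user_message, cached, intent)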

    # Session management

    async def save_session_state(
        self,
        session_id: str,
        state: Dict[str, Any]
    ) -> bool:
        """Save session state to cache."""
        key = self._generate_key("session", session_id)
        state["last_updated"] = datetime.utcnow().isoformat()

        return await self.set(key, state, self.TTL_SESSION, compress=True)

    async def get_session_state(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Get session state from cache."""
        key = self._generate_key("session", session_id)
        # Sessions are stored with compress=True, so allow decompression on read
        return await self.get(key, decompress=True)

    async def update_session_field(
        self,
        session_id: str,
        field: str,
        value: Any
    ) -> bool:
        """Update a specific field in session state."""
        state = await self.get_session_state(session_id) or {}
        state[field] = value
        return await self.save_session_state(session_id, state)

    # Investigation caching

    async def cache_investigation_result(
        self,
        investigation_id: str,
        result: Dict[str, Any]
    ) -> bool:
        """Cache investigation results."""
        key = self._generate_key("investigation", investigation_id)
        return await self.set(key, result, self.TTL_INVESTIGATION, compress=True)

    async def get_cached_investigation(
        self,
        investigation_id: str
    ) -> Optional[Dict[str, Any]]:
        """Get cached investigation results."""
        key = self._generate_key("investigation", investigation_id)
        # Investigations are stored with compress=True, so allow decompression on read
        return await self.get(key, decompress=True)

    # Agent context caching

    async def save_agent_context(
        self,
        agent_id: str,
        session_id: str,
        context: Dict[str, Any]
    ) -> bool:
        """Save agent context for a session."""
        key = self._generate_key("agent_context", agent_id, session_id)
        return await self.set(key, context, self.TTL_AGENT_CONTEXT)

    async def get_agent_context(
        self,
        agent_id: str,
        session_id: str
    ) -> Optional[Dict[str, Any]]:
        """Get agent context for a session."""
        key = self._generate_key("agent_context", agent_id, session_id)
        return await self.get(key)

    # Search results caching

    async def cache_search_results(
        self,
        query: str,
        filters: Dict[str, Any],
        results: List[Dict[str, Any]]
    ) -> bool:
        """Cache search/query results."""
        # Create deterministic key from query and filters
        filter_str = json_utils.dumps(filters, sort_keys=True)
        key = self._generate_key("search", query, filter_str)

        cache_data = {
            "results": results,
            "count": len(results),
            "cached_at": datetime.utcnow().isoformat()
        }

        return await self.set(key, cache_data, self.TTL_SEARCH_RESULTS)

    async def get_cached_search_results(
        self,
        query: str,
        filters: Dict[str, Any]
    ) -> Optional[List[Dict[str, Any]]]:
        """Get cached search results."""
        filter_str = json_utils.dumps(filters, sort_keys=True)
        key = self._generate_key("search", query, filter_str)

        cache_data = await self.get(key)
        if cache_data:
            logger.info(f"Cache hit for search: {query}")
            return cache_data["results"]

        return None
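
    # Key determinism (illustrative example): because filters are serialized with
    # sort_keys=True, {"uf": "SP", "ano": 2024} and {"ano": 2024, "uf": "SP"}
    # produce the same cache key, so logically identical searches share one entry.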

    # Cache statistics

    async def get_cache_stats(self) -> Dict[str, Any]:
        """Get cache statistics."""
        if not self._initialized:
            return {"error": "Cache not initialized"}

        try:
            info = await self.redis.info("stats")
            memory = await self.redis.info("memory")

            # Count keys by pattern
            chat_keys = len([k async for k in self.redis.scan_iter("cidadao:chat:*")])
            session_keys = len([k async for k in self.redis.scan_iter("cidadao:session:*")])
            investigation_keys = len([k async for k in self.redis.scan_iter("cidadao:investigation:*")])

            # Redis exposes raw hit/miss counters, not a ready-made ratio
            hits = info.get("keyspace_hits", 0)
            misses = info.get("keyspace_misses", 0)
            hit_rate = hits / (hits + misses) if (hits + misses) > 0 else 0.0

            return {
                "connected": True,
                "total_keys": await self.redis.dbsize(),
                "keys_by_type": {
                    "chat": chat_keys,
                    "session": session_keys,
                    "investigation": investigation_keys
                },
                "memory_used": memory.get("used_memory_human", "N/A"),
                "hit_rate": f"{hit_rate:.2%}",
                "total_connections": info.get("total_connections_received", 0),
                "commands_processed": info.get("total_commands_processed", 0)
            }
        except Exception as e:
            logger.error(f"Error getting cache stats: {e}")
            return {"error": str(e)}


# Cache decorator for functions
def cache_result(prefix: str, ttl: int = 300):
    """Decorator to cache async function results."""
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Skip the first positional argument when it looks like a bound
            # instance (i.e. the decorated callable is a method)
            cache_args = args[1:] if args and hasattr(args[0], func.__name__) else args

            # Fold keyword arguments into the key material deterministically;
            # _generate_key only accepts positional parts
            kwarg_parts = [f"{k}={v}" for k, v in sorted(kwargs.items())]

            # Reuse the module-level cache service instead of building a new pool
            key = cache_service._generate_key(
                prefix,
                func.__name__,
                *cache_args,
                *kwarg_parts,
            )

            # Check cache
            cached = await cache_service.get(key)
            if cached is not None:
                logger.debug(f"Cache hit for {func.__name__}")
                return cached

            # Execute function
            result = await func(*args, **kwargs)

            # Cache result
            await cache_service.set(key, result, ttl)

            return result
        return wrapper
    return decorator

# Global cache service instance
cache_service = CacheService()
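

# Minimal end-to-end sketch (illustrative only; assumes a reachable Redis at
# settings.redis_url; build_report and handler are hypothetical caller code):
#
#     from src.services.cache_service import cache_service, cache_result, CacheTTL
#
#     @cache_result("reports", ttl=CacheTTL.LONG.value)
#     async def build_report(report_id: str) -> dict:
#         ...  # expensive work, cached for one hour per report_id
#
#     async def handler(session_id: str):
#         await cache_service.initialize()
#         await cache_service.save_session_state(session_id, {"step": 1})
#         state = await cache_service.get_session_state(session_id)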