|
|
""" |
|
|
Module: services.ip_whitelist_service |
|
|
Description: IP whitelist management for production environments |
|
|
Author: Anderson H. Silva |
|
|
Date: 2025-01-25 |
|
|
License: Proprietary - All rights reserved |
|
|
""" |
|
|
|
|
|
import ipaddress |
|
|
from typing import List, Optional, Set, Dict, Any |
|
|
from datetime import datetime, timezone |
|
|
import json |
|
|
|
|
|
from src.core import get_logger |
|
|
from src.services.cache_service import cache_service |
|
|
from src.core.config import settings |
|
|
from src.models.base import BaseModel |
|
|
from sqlalchemy import Column, String, Boolean, DateTime, Integer, JSON |
|
|
from sqlalchemy.ext.asyncio import AsyncSession |
|
|
from sqlalchemy import select, delete |
|
|
|
|
|
# Module logger. Calls in this module pass structured keyword fields
# (e.g. logger.info("ip_whitelist_added", ip=...)), so get_logger presumably
# returns a structlog-style logger — confirm in src.core.
logger = get_logger(__name__)
|
|
|
|
|
|
|
|
class IPWhitelist(BaseModel):
    """IP whitelist entry model.

    An entry is either a single address (``is_cidr`` False; ``ip_address``
    holds the address) or a network (``is_cidr`` True; ``ip_address`` holds
    the network address and ``cidr_prefix`` the prefix length).
    """

    __tablename__ = "ip_whitelists"

    # Primary key; IPWhitelistService.add_ip builds it as "{environment}:{ip}".
    id = Column(String(64), primary_key=True)
    # NOTE(review): unique=True makes the address unique across *all*
    # environments, while the service treats (ip_address, environment) as the
    # key, so whitelisting the same address in a second environment presumably
    # fails on commit. CIDR rows store only the network address, so
    # "10.0.0.0/8" and "10.0.0.0/16" also collide. Fixing either needs a
    # schema migration (e.g. a composite unique constraint including
    # environment and cidr_prefix).
    ip_address = Column(String(45), nullable=False, unique=True)
    description = Column(String(255))
    environment = Column(String(20), nullable=False, default="production")
    active = Column(Boolean, default=True)
    created_by = Column(String(255), nullable=False)
    created_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc))
    # None means the entry never expires.
    expires_at = Column(DateTime(timezone=True), nullable=True)
    meta_info = Column(JSON, default=dict)
    is_cidr = Column(Boolean, default=False)
    # Prefix length for CIDR entries; None for single addresses.
    cidr_prefix = Column(Integer, nullable=True)

    def is_expired(self) -> bool:
        """Return True if ``expires_at`` is set and lies in the past.

        A naive ``expires_at`` (some drivers may return one even for
        ``DateTime(timezone=True)`` columns) is interpreted as UTC, so the
        comparison with the aware current time cannot raise ``TypeError``.
        """
        if self.expires_at is None:
            return False
        expires_at = self.expires_at
        if expires_at.tzinfo is None:
            expires_at = expires_at.replace(tzinfo=timezone.utc)
        return datetime.now(timezone.utc) > expires_at

    def matches(self, ip: str) -> bool:
        """Return True if *ip* is covered by this active, unexpired entry.

        Addresses are compared by value rather than by text, so equivalent
        spellings (e.g. ``"::0001"`` and ``"::1"``) match. An unparseable
        *ip* or a malformed stored entry never matches; both are logged.
        """
        if not self.active or self.is_expired():
            return False

        try:
            candidate = ipaddress.ip_address(ip)
        except ValueError:
            logger.error(f"Invalid IP address format: {ip}")
            return False

        try:
            if self.is_cidr:
                network = ipaddress.ip_network(f"{self.ip_address}/{self.cidr_prefix}")
                # Containment across IP versions is simply False.
                return candidate in network
            return candidate == ipaddress.ip_address(self.ip_address)
        except ValueError:
            # The stored row itself is malformed (e.g. a CIDR row without a
            # prefix); report it as such instead of blaming the request IP.
            logger.error(f"Invalid whitelist entry: {self.ip_address}/{self.cidr_prefix}")
            return False
|
|
|
|
|
|
|
|
class IPWhitelistService:
    """Service for managing IP whitelists.

    Entries are persisted as ``IPWhitelist`` rows and looked up through two
    cache layers:

    * an in-memory snapshot of the active, unexpired entries, kept per
      environment and refreshed at most every ``_cache_ttl`` seconds;
    * per-IP check results in ``cache_service`` under
      ``"ip_whitelist:{environment}:check:{ip}"``.

    Every mutating method clears both layers via ``_invalidate_cache``. Only
    the in-memory snapshot of *this* instance is cleared, so other instances
    (e.g. in other worker processes, if any) may keep answering from their
    own snapshot for up to ``_cache_ttl`` seconds.
    """

    def __init__(self):
        """Initialize the service with empty caches."""
        self._cache_key_prefix = "ip_whitelist"
        # Lifetime in seconds of the in-memory snapshot and of cached results.
        self._cache_ttl = 300
        # In-memory snapshots keyed by environment name. One instance answers
        # checks for any environment, so a snapshot loaded for one environment
        # must never be reused for another.
        #   environment -> {canonical exact IP -> expires_at (aware) or None}
        self._whitelist_cache: Dict[str, Dict[str, Optional[datetime]]] = {}
        #   environment -> [(ipaddress network object, expires_at or None), ...]
        self._cidr_cache: Dict[str, List[tuple]] = {}
        #   environment -> time the snapshot was loaded
        self._last_cache_update: Dict[str, datetime] = {}

    async def add_ip(
        self,
        session: AsyncSession,
        ip_address: str,
        created_by: str,
        description: Optional[str] = None,
        environment: str = "production",
        expires_at: Optional[datetime] = None,
        is_cidr: bool = False,
        meta_info: Optional[Dict[str, Any]] = None
    ) -> IPWhitelist:
        """Add an IP address or CIDR range to the whitelist of *environment*.

        Args:
            session: Database session; the new row is committed on success.
            ip_address: A single address (``"10.0.0.5"``) or a CIDR range
                (``"10.0.0.0/8"``). Host bits of a range are discarded, so
                ``"10.0.0.5/8"`` is stored as ``10.0.0.0`` with prefix 8.
            created_by: Who created the entry (stored as-is).
            description: Optional free-form description.
            environment: Environment the entry applies to.
            expires_at: Optional expiry time; ``None`` means never.
            is_cidr: Treat *ip_address* as a network even without a ``/``
                (a bare address then becomes a /32 or /128 network).
            meta_info: Optional metadata stored in the JSON ``meta_info`` column.

        Returns:
            The committed ``IPWhitelist`` entry.

        Raises:
            ValueError: *ip_address* is not a valid address/range, or an entry
                with the same stored address already exists in *environment*.
            Exception: Any database error raised by the commit (e.g. an
                integrity error) is re-raised after the session is rolled back.

        NOTE(review): entries are keyed by the stored address only, so ranges or
        addresses sharing a network address (``10.0.0.0/8``, ``10.0.0.0/16``,
        ``10.0.0.0``) collide and the second one is rejected as a duplicate.
        """
        try:
            if is_cidr or "/" in ip_address:
                network = ipaddress.ip_network(ip_address, strict=False)
                ip_str = str(network.network_address)
                cidr_prefix = network.prefixlen
                is_cidr = True
            else:
                ip_str = str(ipaddress.ip_address(ip_address))
                cidr_prefix = None
                is_cidr = False
        except ValueError as e:
            logger.error(f"Invalid IP address format: {ip_address}")
            raise ValueError(f"Invalid IP address format: {ip_address}") from e

        existing = await session.execute(
            select(IPWhitelist).where(
                IPWhitelist.ip_address == ip_str,
                IPWhitelist.environment == environment
            )
        )
        if existing.scalar_one_or_none():
            raise ValueError(f"IP address already whitelisted: {ip_str}")

        entry = IPWhitelist(
            id=f"{environment}:{ip_str}",
            ip_address=ip_str,
            description=description,
            environment=environment,
            created_by=created_by,
            expires_at=expires_at,
            is_cidr=is_cidr,
            cidr_prefix=cidr_prefix,
            meta_info=meta_info or {}
        )
        session.add(entry)
        await self._commit(session)

        await self._invalidate_cache()

        logger.info(
            "ip_whitelist_added",
            ip=ip_str,
            environment=environment,
            is_cidr=is_cidr,
            created_by=created_by
        )
        return entry

    async def remove_ip(
        self,
        session: AsyncSession,
        ip_address: str,
        environment: str = "production"
    ) -> bool:
        """Remove the entry for *ip_address* from the whitelist of *environment*.

        *ip_address* is canonicalized the same way ``add_ip`` stores it, so
        ``"::0001"`` removes ``::1`` and ``"10.0.0.0/8"`` removes the entry
        stored under ``10.0.0.0``.

        Returns:
            True if a row was deleted, False otherwise.
        """
        ip_str = self._normalize_ip(ip_address)
        result = await session.execute(
            delete(IPWhitelist).where(
                IPWhitelist.ip_address == ip_str,
                IPWhitelist.environment == environment
            )
        )
        await self._commit(session)

        if result.rowcount > 0:
            await self._invalidate_cache()
            logger.info(
                "ip_whitelist_removed",
                ip=ip_str,
                environment=environment
            )
            return True
        return False

    async def check_ip(
        self,
        session: AsyncSession,
        ip_address: str,
        environment: str = "production"
    ) -> bool:
        """Return True if *ip_address* is whitelisted in *environment*.

        Exact entries match on the canonical address form, so ``"::0001"``
        matches a stored ``::1``. CIDR entries match by containment. Expired
        entries never match, even if they are still in the in-memory snapshot.
        Unparseable input is never whitelisted and is not cached.
        """
        try:
            addr = ipaddress.ip_address(ip_address)
        except ValueError:
            logger.debug("ip_whitelist_invalid_ip", ip=ip_address)
            return False
        ip_str = str(addr)

        cache_key = f"{self._cache_key_prefix}:{environment}:check:{ip_str}"
        cached = await cache_service.get(cache_key)
        if cached is not None:
            return cached

        await self._ensure_cache_loaded(session, environment)
        now = datetime.now(timezone.utc)

        # Exact entries. An expired one falls through to the CIDR check.
        exact = self._whitelist_cache.get(environment, {})
        if ip_str in exact:
            expires_at = exact[ip_str]
            if expires_at is None or now <= expires_at:
                await cache_service.set(cache_key, True, ttl=self._result_ttl(expires_at))
                return True

        # CIDR entries (networks were parsed once when the snapshot was loaded).
        for network, expires_at in self._cidr_cache.get(environment, []):
            if expires_at is not None and now > expires_at:
                continue
            if addr in network:
                await cache_service.set(cache_key, True, ttl=self._result_ttl(expires_at))
                return True

        await cache_service.set(cache_key, False, ttl=self._cache_ttl)
        return False

    async def list_ips(
        self,
        session: AsyncSession,
        environment: str = "production",
        include_expired: bool = False
    ) -> List[IPWhitelist]:
        """List the entries of *environment* (active and inactive).

        Entries whose ``expires_at`` has passed are omitted unless
        *include_expired* is True.
        """
        query = select(IPWhitelist).where(
            IPWhitelist.environment == environment
        )
        if not include_expired:
            now = datetime.now(timezone.utc)
            query = query.where(
                (IPWhitelist.expires_at.is_(None)) |
                (IPWhitelist.expires_at > now)
            )

        result = await session.execute(query)
        return list(result.scalars().all())

    async def update_ip(
        self,
        session: AsyncSession,
        ip_address: str,
        environment: str = "production",
        active: Optional[bool] = None,
        description: Optional[str] = None,
        expires_at: Optional[datetime] = None,
        meta_info: Optional[Dict[str, Any]] = None
    ) -> Optional[IPWhitelist]:
        """Update fields of the entry for *ip_address* in *environment*.

        *ip_address* is canonicalized like ``add_ip`` stores it. Arguments left
        as ``None`` are not changed; therefore ``expires_at`` cannot be cleared
        through this method.

        Returns:
            The updated entry, or None if no entry matched.
        """
        ip_str = self._normalize_ip(ip_address)
        result = await session.execute(
            select(IPWhitelist).where(
                IPWhitelist.ip_address == ip_str,
                IPWhitelist.environment == environment
            )
        )
        entry = result.scalar_one_or_none()
        if not entry:
            return None

        if active is not None:
            entry.active = active
        if description is not None:
            entry.description = description
        if expires_at is not None:
            entry.expires_at = expires_at
        if meta_info is not None:
            entry.meta_info = meta_info

        await self._commit(session)
        await self._invalidate_cache()

        logger.info(
            "ip_whitelist_updated",
            ip=ip_str,
            environment=environment,
            active=entry.active
        )
        return entry

    async def cleanup_expired(
        self,
        session: AsyncSession,
        environment: Optional[str] = None
    ) -> int:
        """Delete entries whose ``expires_at`` has passed.

        Args:
            environment: Restrict the cleanup to this environment; all
                environments when None.

        Returns:
            The number of deleted rows as reported by the driver.
        """
        query = delete(IPWhitelist).where(
            IPWhitelist.expires_at < datetime.now(timezone.utc)
        )
        if environment:
            query = query.where(IPWhitelist.environment == environment)

        result = await session.execute(query)
        await self._commit(session)

        if result.rowcount > 0:
            await self._invalidate_cache()
            logger.info(
                "ip_whitelist_cleanup",
                removed=result.rowcount,
                environment=environment
            )
        return result.rowcount

    async def _ensure_cache_loaded(
        self,
        session: AsyncSession,
        environment: str
    ) -> None:
        """Load the active, unexpired entries of *environment* into memory.

        No-op while the snapshot for *environment* is younger than
        ``_cache_ttl`` seconds. Malformed CIDR rows are logged and skipped.
        """
        now = datetime.now(timezone.utc)
        loaded_at = self._last_cache_update.get(environment)
        if loaded_at is not None and (now - loaded_at).total_seconds() < self._cache_ttl:
            return

        result = await session.execute(
            select(IPWhitelist).where(
                IPWhitelist.environment == environment,
                IPWhitelist.active == True,  # noqa: E712 - SQLAlchemy needs ==
                (IPWhitelist.expires_at.is_(None)) | (IPWhitelist.expires_at > now)
            )
        )

        exact: Dict[str, Optional[datetime]] = {}
        cidrs: List[tuple] = []
        for entry in result.scalars().all():
            expires_at = self._as_utc(entry.expires_at)
            if entry.is_cidr:
                try:
                    network = ipaddress.ip_network(f"{entry.ip_address}/{entry.cidr_prefix}")
                except ValueError:
                    logger.warning(
                        "ip_whitelist_invalid_cidr_entry",
                        ip=entry.ip_address,
                        prefix=entry.cidr_prefix,
                        environment=environment
                    )
                    continue
                cidrs.append((network, expires_at))
            else:
                exact[self._normalize_ip(entry.ip_address)] = expires_at

        self._whitelist_cache[environment] = exact
        self._cidr_cache[environment] = cidrs
        self._last_cache_update[environment] = datetime.now(timezone.utc)

        logger.debug(
            "ip_whitelist_cache_loaded",
            environment=environment,
            exact_ips=len(exact),
            cidr_ranges=len(cidrs)
        )

    async def _invalidate_cache(self) -> None:
        """Drop all in-memory snapshots and every cached check result."""
        self._whitelist_cache = {}
        self._cidr_cache = {}
        self._last_cache_update = {}

        pattern = f"{self._cache_key_prefix}:*"
        await cache_service.delete_pattern(pattern)

    @staticmethod
    def _as_utc(value: Optional[datetime]) -> Optional[datetime]:
        """Return *value* timezone-aware, interpreting a naive value as UTC.

        Some drivers may return naive datetimes even for
        ``DateTime(timezone=True)`` columns; comparing those with the aware
        ``datetime.now(timezone.utc)`` would raise ``TypeError``.
        """
        if value is None or value.tzinfo is not None:
            return value
        return value.replace(tzinfo=timezone.utc)

    @staticmethod
    def _normalize_ip(ip_address: str) -> str:
        """Return the form under which ``add_ip`` stores *ip_address*.

        Plain addresses are canonicalized (``"::0001"`` -> ``"::1"``); CIDR
        strings are reduced to their network address. Unparseable input is
        returned unchanged, so database lookups with it simply miss.
        """
        try:
            if "/" in ip_address:
                return str(ipaddress.ip_network(ip_address, strict=False).network_address)
            return str(ipaddress.ip_address(ip_address))
        except (TypeError, ValueError):
            return ip_address

    @staticmethod
    async def _commit(session: AsyncSession) -> None:
        """Commit *session*; on failure roll back (keeping it usable) and re-raise."""
        try:
            await session.commit()
        except Exception:
            await session.rollback()
            raise

    def _result_ttl(self, expires_at: Optional[datetime]) -> int:
        """TTL for a cached positive result, capped so it never outlives the entry."""
        if expires_at is None:
            return self._cache_ttl
        remaining = (expires_at - datetime.now(timezone.utc)).total_seconds()
        return max(1, min(self._cache_ttl, int(remaining)))

    def get_default_whitelist(self) -> List[str]:
        """Return the built-in whitelist entries for the configured environment.

        Loopback entries are always included. Private (RFC 1918) ranges are
        added when ``settings.is_development``, and a fixed set of public ranges
        when ``settings.is_production``.
        """
        # NOTE(review): "localhost" is a hostname, not an IP. add_ip rejects it
        # with ValueError, so initialize_defaults skips it.
        defaults = ["127.0.0.1", "::1", "localhost"]

        if settings.is_development:
            defaults.extend([
                "10.0.0.0/8",
                "172.16.0.0/12",
                "192.168.0.0/16"
            ])

        if settings.is_production:
            # Presumably hosting-provider ingress ranges — confirm with ops.
            defaults.extend([
                "76.76.21.0/24",
                "76.223.0.0/16"
            ])
            # SECURITY NOTE(review): each /8 below spans ~16.7M addresses and
            # appears to be public-cloud address space. Anyone able to rent a
            # machine there would pass the whitelist. Confirm this is intended
            # and narrow to the specific egress ranges if possible.
            defaults.extend([
                "34.0.0.0/8",
                "35.0.0.0/8"
            ])
            defaults.extend([
                "52.0.0.0/8"
            ])

        return defaults

    async def initialize_defaults(
        self,
        session: AsyncSession,
        created_by: str = "system"
    ) -> int:
        """Add every default entry to the whitelist of ``settings.app_env``.

        Entries that are invalid (e.g. ``"localhost"``), already present, or
        rejected by a database integrity constraint are skipped.

        Returns:
            The number of entries actually added.
        """
        from sqlalchemy.exc import IntegrityError

        count = 0
        for ip in self.get_default_whitelist():
            try:
                await self.add_ip(
                    session=session,
                    ip_address=ip,
                    created_by=created_by,
                    description="Default whitelist entry",
                    environment=settings.app_env,
                    is_cidr="/" in ip
                )
            except (ValueError, IntegrityError):
                # add_ip has already rolled back a failed commit, so the
                # session can be reused for the next entry.
                continue
            count += 1

        logger.info(
            "ip_whitelist_defaults_initialized",
            count=count,
            environment=settings.app_env
        )
        return count
|
|
|
|
|
|
|
|
|
|
|
# Shared module-level service instance. Its in-memory caches live on the
# instance and _invalidate_cache only resets the instance it is called on, so
# importers should use this object rather than constructing their own.
ip_whitelist_service = IPWhitelistService()