|
|
""" |
|
|
Supabase integration service for direct database access. |
|
|
|
|
|
This service provides a bridge between the backend and Supabase PostgreSQL, |
|
|
allowing investigations to be stored centrally for frontend consumption. |
|
|
""" |
|
|
|
|
|
import asyncio
import json
import os
from typing import Optional, List, Dict, Any
from datetime import datetime
from contextlib import asynccontextmanager

from asyncpg import Pool, create_pool, Connection
from pydantic import BaseModel, Field

from src.core import get_logger, settings
from src.core.exceptions import CidadaoAIError
|
|
|
|
|
# Module-level logger obtained from the project's logging helper, named after this module.
logger = get_logger(__name__)
|
|
|
|
|
|
|
|
def _env_int(name: str, default: int) -> int:
    """Read an integer environment variable, falling back to *default* when unset or blank."""
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return default
    try:
        return int(raw)
    except ValueError:
        # Name the offending variable; a bare int() error does not.
        raise ValueError(f"{name} must be an integer, got {raw!r}") from None


def _normalize_dsn(url: str) -> str:
    """Strip a SQLAlchemy driver suffix (``postgresql+asyncpg://``) so asyncpg can parse the DSN."""
    scheme, sep, rest = url.partition("://")
    if sep and "+" in scheme:
        base = scheme.split("+", 1)[0]
        # asyncpg only understands the plain postgres/postgresql schemes.
        if base in ("postgres", "postgresql"):
            return f"{base}{sep}{rest}"
    return url


class SupabaseConfig(BaseModel):
    """Supabase connection configuration."""

    url: str = Field(..., description="Supabase PostgreSQL connection URL")
    anon_key: Optional[str] = Field(None, description="Supabase anon key (for Row Level Security)")
    service_role_key: Optional[str] = Field(None, description="Supabase service role key (bypasses RLS)")
    min_connections: int = Field(default=5, description="Minimum pool connections")
    max_connections: int = Field(default=20, description="Maximum pool connections")

    @classmethod
    def from_env(cls) -> "SupabaseConfig":
        """
        Load configuration from environment variables.

        Reads ``SUPABASE_DB_URL`` (falling back to ``DATABASE_URL``),
        ``SUPABASE_ANON_KEY``, ``SUPABASE_SERVICE_ROLE_KEY``,
        ``SUPABASE_MIN_CONNECTIONS`` (default 5) and
        ``SUPABASE_MAX_CONNECTIONS`` (default 20). A SQLAlchemy-style URL
        such as ``postgresql+asyncpg://...`` is normalized to
        ``postgresql://...`` so it can be passed straight to asyncpg.

        Returns:
            A populated SupabaseConfig.

        Raises:
            ValueError: If neither URL variable is set, or a connection-count
                variable is not an integer.
        """
        supabase_url = os.getenv("SUPABASE_DB_URL") or os.getenv("DATABASE_URL")

        if not supabase_url:
            raise ValueError(
                "SUPABASE_DB_URL or DATABASE_URL environment variable required. "
                "Get it from: Supabase Dashboard > Settings > Database > Connection string (URI)"
            )

        return cls(
            url=_normalize_dsn(supabase_url),
            anon_key=os.getenv("SUPABASE_ANON_KEY"),
            service_role_key=os.getenv("SUPABASE_SERVICE_ROLE_KEY"),
            min_connections=_env_int("SUPABASE_MIN_CONNECTIONS", 5),
            max_connections=_env_int("SUPABASE_MAX_CONNECTIONS", 20),
        )
|
|
|
|
|
|
|
|
class SupabaseService:
    """
    Service for interacting with Supabase PostgreSQL.

    Provides connection pooling and CRUD operations for investigations.
    The asyncpg pool is created on first use (or explicitly via
    ``initialize``) and released with ``close``.
    """

    # Columns whose dict/list values are JSON-serialized and cast to jsonb
    # by ``update_investigation``.
    _JSONB_FIELDS = frozenset({"results", "filters", "anomaly_types"})

    def __init__(self, config: Optional[SupabaseConfig] = None):
        """
        Initialize Supabase service.

        Args:
            config: Supabase configuration (loads from env if None)

        Raises:
            ValueError: If ``config`` is None and ``SupabaseConfig.from_env``
                cannot find a database URL.
        """
        self.config = config or SupabaseConfig.from_env()
        self._pool: Optional[Pool] = None
        self._initialized = False
        # Created lazily inside the running event loop; serializes pool
        # creation so concurrent first callers don't each build a pool.
        self._init_lock: Optional[asyncio.Lock] = None

    async def initialize(self) -> None:
        """
        Create the connection pool and verify connectivity.

        Concurrent callers are serialized: the first one creates the pool,
        the rest wait and then return without creating another.

        Raises:
            CidadaoAIError: If the pool cannot be created or the test query
                fails. Any partially created pool is terminated first.
        """
        if self._initialized:
            logger.warning("Supabase service already initialized")
            return

        if self._init_lock is None:
            # No await between the check and the assignment, so this cannot
            # race within a single event loop.
            self._init_lock = asyncio.Lock()

        async with self._init_lock:
            if self._initialized:
                # Another coroutine finished initializing while we waited.
                return

            pool: Optional[Pool] = None
            try:
                logger.info("Initializing Supabase connection pool")

                pool = await create_pool(
                    dsn=self.config.url,
                    min_size=self.config.min_connections,
                    max_size=self.config.max_connections,
                    command_timeout=30,
                    server_settings={
                        'application_name': 'cidadao-ai-backend',
                        'timezone': 'UTC',
                    },
                )

                # Round-trip a query so bad credentials fail here rather
                # than on the first real request.
                async with pool.acquire() as conn:
                    version = await conn.fetchval("SELECT version()")
                    logger.info(f"Connected to Supabase PostgreSQL: {version[:50]}...")

            except Exception as e:
                logger.error(f"Failed to initialize Supabase service: {e}", exc_info=True)
                if pool is not None:
                    # Don't leak the connections of a pool we are abandoning.
                    pool.terminate()
                raise CidadaoAIError(f"Supabase initialization failed: {e}") from e

            self._pool = pool
            self._initialized = True
            logger.info("Supabase service initialized successfully")

    async def close(self) -> None:
        """Close the connection pool; a later ``get_connection`` re-initializes it."""
        pool, self._pool = self._pool, None
        self._initialized = False
        if pool is not None:
            await pool.close()
            logger.info("Supabase connection pool closed")

    @asynccontextmanager
    async def get_connection(self):
        """
        Get a database connection from the pool, initializing the pool on first use.

        Yields:
            Connection instance, released back to the pool when the block exits.
        """
        if not self._initialized:
            await self.initialize()

        async with self._pool.acquire() as conn:
            yield conn

    async def create_investigation(
        self,
        user_id: str,
        query: str,
        data_source: str,
        filters: Optional[Dict[str, Any]] = None,
        anomaly_types: Optional[List[str]] = None,
        session_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Create a new investigation in Supabase.

        The row starts with status ``"pending"`` and progress ``0.0``.

        Args:
            user_id: User ID
            query: Investigation query
            data_source: Data source to investigate
            filters: Query filters (stored as jsonb, defaults to {})
            anomaly_types: Types of anomalies to detect (stored as jsonb, defaults to [])
            session_id: Optional session ID

        Returns:
            Created investigation as dict
        """
        # One timestamp so created_at and updated_at are identical on insert.
        now = datetime.utcnow()

        async with self.get_connection() as conn:
            row = await conn.fetchrow(
                """
                INSERT INTO investigations (
                    user_id, session_id, query, data_source,
                    status, filters, anomaly_types, progress,
                    created_at, updated_at
                )
                VALUES ($1, $2, $3, $4, $5, $6::jsonb, $7::jsonb, $8, $9, $10)
                RETURNING *
                """,
                user_id,
                session_id,
                query,
                data_source,
                "pending",
                json.dumps(filters or {}),
                json.dumps(anomaly_types or []),
                0.0,
                now,
                now,
            )

            logger.info(f"Created investigation {row['id']} in Supabase")
            return dict(row)

    async def get_investigation(self, investigation_id: str) -> Optional[Dict[str, Any]]:
        """
        Get investigation by ID.

        Args:
            investigation_id: Investigation UUID

        Returns:
            Investigation dict or None
        """
        async with self.get_connection() as conn:
            row = await conn.fetchrow(
                "SELECT * FROM investigations WHERE id = $1",
                investigation_id
            )

            return dict(row) if row else None

    async def update_investigation(
        self,
        investigation_id: str,
        **updates
    ) -> Dict[str, Any]:
        """
        Update investigation fields.

        Each keyword becomes ``column = value``. Dict/list values for jsonb
        columns are JSON-encoded. ``updated_at`` is set to the current UTC
        time unless the caller supplies it.

        Args:
            investigation_id: Investigation UUID
            **updates: Fields to update

        Returns:
            Updated investigation dict

        Raises:
            ValueError: If a field name is not a plain SQL identifier, or the
                investigation does not exist.
        """
        set_clauses: List[str] = []
        values: List[Any] = []

        def bind(column: str, value: Any, cast: str = "") -> None:
            # Placeholder numbers track the position in ``values``.
            values.append(value)
            set_clauses.append(f"{column} = ${len(values)}{cast}")

        for key, value in updates.items():
            # Column names are interpolated into SQL, so only plain ASCII
            # identifiers are accepted (prevents SQL injection via **kwargs keys).
            if not (key.isascii() and key.isidentifier()):
                raise ValueError(f"Invalid column name for update: {key!r}")
            if key in self._JSONB_FIELDS and isinstance(value, (dict, list)):
                bind(key, json.dumps(value), "::jsonb")
            else:
                bind(key, value)

        # Avoid a duplicate assignment (a SQL error) if the caller set it.
        if "updated_at" not in updates:
            bind("updated_at", datetime.utcnow())

        values.append(investigation_id)

        query = f"""
            UPDATE investigations
            SET {', '.join(set_clauses)}
            WHERE id = ${len(values)}
            RETURNING *
        """

        async with self.get_connection() as conn:
            row = await conn.fetchrow(query, *values)

            if not row:
                raise ValueError(f"Investigation {investigation_id} not found")

            logger.debug(f"Updated investigation {investigation_id}")
            return dict(row)

    async def update_progress(
        self,
        investigation_id: str,
        progress: float,
        current_phase: str,
        records_processed: Optional[int] = None,
        anomalies_found: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Update investigation progress.

        Args:
            investigation_id: Investigation UUID
            progress: Progress percentage (0.0 to 1.0)
            current_phase: Current processing phase
            records_processed: Number of records processed (stored as total_records_analyzed)
            anomalies_found: Number of anomalies detected

        Returns:
            Updated investigation dict
        """
        updates: Dict[str, Any] = {
            "progress": progress,
            "current_phase": current_phase,
        }

        if records_processed is not None:
            updates["total_records_analyzed"] = records_processed

        if anomalies_found is not None:
            updates["anomalies_found"] = anomalies_found

        return await self.update_investigation(investigation_id, **updates)

    async def complete_investigation(
        self,
        investigation_id: str,
        results: List[Dict[str, Any]],
        summary: str,
        confidence_score: float,
        total_records: int,
        anomalies_found: int,
    ) -> Dict[str, Any]:
        """
        Mark investigation as completed with results.

        Args:
            investigation_id: Investigation UUID
            results: List of anomaly results (stored as jsonb)
            summary: Investigation summary
            confidence_score: Overall confidence
            total_records: Total records analyzed
            anomalies_found: Total anomalies found

        Returns:
            Updated investigation dict
        """
        return await self.update_investigation(
            investigation_id,
            status="completed",
            progress=1.0,
            current_phase="completed",
            results=results,
            summary=summary,
            confidence_score=confidence_score,
            total_records_analyzed=total_records,
            anomalies_found=anomalies_found,
            completed_at=datetime.utcnow(),
        )

    async def fail_investigation(
        self,
        investigation_id: str,
        error_message: str,
    ) -> Dict[str, Any]:
        """
        Mark investigation as failed.

        Args:
            investigation_id: Investigation UUID
            error_message: Error description

        Returns:
            Updated investigation dict
        """
        return await self.update_investigation(
            investigation_id,
            status="failed",
            current_phase="failed",
            error_message=error_message,
            completed_at=datetime.utcnow(),
        )

    async def list_user_investigations(
        self,
        user_id: str,
        limit: int = 20,
        offset: int = 0,
        status: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """
        List investigations for a user, newest first.

        Args:
            user_id: User ID
            limit: Maximum results
            offset: Pagination offset
            status: Filter by status (ignored when empty/None)

        Returns:
            List of investigation dicts
        """
        conditions = ["user_id = $1"]
        params: List[Any] = [user_id]

        if status:
            params.append(status)
            conditions.append(f"status = ${len(params)}")

        params.extend([limit, offset])
        query = (
            "SELECT * FROM investigations "
            f"WHERE {' AND '.join(conditions)} "
            f"ORDER BY created_at DESC LIMIT ${len(params) - 1} OFFSET ${len(params)}"
        )

        async with self.get_connection() as conn:
            rows = await conn.fetch(query, *params)
            return [dict(row) for row in rows]

    async def delete_investigation(
        self,
        investigation_id: str,
        user_id: str,
    ) -> bool:
        """
        Delete an investigation (soft delete by marking as cancelled).

        Only rows owned by ``user_id`` are affected.

        Args:
            investigation_id: Investigation UUID
            user_id: User ID (for authorization)

        Returns:
            True if deleted, False if not found
        """
        async with self.get_connection() as conn:
            result = await conn.execute(
                """
                UPDATE investigations
                SET status = 'cancelled', completed_at = $1
                WHERE id = $2 AND user_id = $3
                """,
                datetime.utcnow(),
                investigation_id,
                user_id,
            )

            # execute() returns the command status tag, e.g. "UPDATE 1";
            # the trailing number is the affected row count.
            rows_affected = int(result.split()[-1])

            if rows_affected > 0:
                logger.info(f"Cancelled investigation {investigation_id}")
                return True

            return False

    async def health_check(self) -> Dict[str, Any]:
        """
        Check Supabase connection health.

        Runs ``SELECT 1`` and reports pool statistics. Never raises; failures
        are reported as an ``"unhealthy"`` status dict.

        Returns:
            Health status dict
        """
        try:
            async with self.get_connection() as conn:
                await conn.fetchval("SELECT 1")
                pool_size = self._pool.get_size()
                pool_free = self._pool.get_idle_size()

                return {
                    "status": "healthy",
                    "connected": True,
                    "pool_size": pool_size,
                    "pool_free": pool_free,
                    "pool_used": pool_size - pool_free,
                }
        except Exception as e:
            logger.error(f"Supabase health check failed: {e}")
            return {
                "status": "unhealthy",
                "connected": False,
                "error": str(e),
            }
|
|
|
|
|
|
|
|
|
|
|
# Process-wide singleton returned by get_supabase_service().
# NOTE(review): constructing it here runs SupabaseConfig.from_env() at import
# time, so importing this module raises ValueError when neither
# SUPABASE_DB_URL nor DATABASE_URL is set. The connection pool itself is not
# created until initialize() is awaited.
supabase_service = SupabaseService()
|
|
|
|
|
|
|
|
async def get_supabase_service() -> SupabaseService:
    """
    Return the shared SupabaseService, creating its connection pool on first call.

    Returns:
        The module-level ``supabase_service`` instance, initialized.
    """
    if supabase_service._initialized:
        return supabase_service

    await supabase_service.initialize()
    return supabase_service
|
|
|