"""API Router with /chat endpoint for Swagger testing.""" from enum import Enum from pydantic import BaseModel, Field from fastapi import APIRouter, Depends, UploadFile, File, Query, HTTPException from sqlalchemy.ext.asyncio import AsyncSession from app.agent.mmca_agent import MMCAAgent from app.agent.react_agent import ReActAgent from app.shared.db.session import get_db from app.core.config import settings from app.mcp.tools import mcp_tools from app.shared.chat_history import chat_history router = APIRouter() class LLMProvider(str, Enum): """Available LLM providers.""" GOOGLE = "Google" MEGALLM = "MegaLLM" class ChatRequest(BaseModel): """Chat request model.""" message: str = Field( ..., description="User message in natural language", examples=["Tìm quán cafe gần bãi biển Mỹ Khê"], ) user_id: str = Field( default="anonymous", description="User ID for session management", examples=["user_123", "anonymous"], ) session_id: str | None = Field( None, description="Session ID (optional, uses 'default' if not provided)", examples=["session_abc", "default"], ) image_url: str | None = Field( None, description="Optional image URL for visual similarity search", examples=["https://example.com/cafe.jpg"], ) provider: LLMProvider = Field( default=LLMProvider.MEGALLM, description="LLM provider to use: Google or MegaLLM", ) model: str | None = Field( None, description=f"Model name. Defaults: Google={settings.default_gemini_model}, MegaLLM={settings.default_megallm_model}", examples=["gemini-2.0-flash", "deepseek-r1-distill-llama-70b"], ) react_mode: bool = Field( default=False, description="Enable ReAct multi-step reasoning mode", ) max_steps: int = Field( default=5, description="Maximum reasoning steps for ReAct mode", ge=1, le=10, ) class WorkflowStepResponse(BaseModel): """Workflow step info.""" step: str = Field(..., description="Step name") tool: str | None = Field(None, description="Tool used") purpose: str = Field(default="", description="Purpose of this step") results: int = Field(default=0, description="Number of results") class WorkflowResponse(BaseModel): """Workflow trace for debugging.""" query: str = Field(..., description="Original query") intent_detected: str = Field(..., description="Detected intent") tools_used: list[str] = Field(default_factory=list, description="Tools used") steps: list[WorkflowStepResponse] = Field(default_factory=list, description="Workflow steps") total_duration_ms: float = Field(..., description="Total processing time") class PlaceItem(BaseModel): """Place item for FE rendering.""" place_id: str name: str category: str | None = None lat: float | None = None lng: float | None = None rating: float | None = None distance_km: float | None = None address: str | None = None image_url: str | None = None class ChatResponse(BaseModel): """Chat response model.""" response: str = Field(..., description="Agent's response") status: str = Field(default="success", description="Response status") provider: str = Field(..., description="LLM provider used") model: str = Field(..., description="Model used") user_id: str = Field(..., description="User ID") session_id: str = Field(..., description="Session ID used") places: list[PlaceItem] = Field(default_factory=list, description="LLM-selected places for FE rendering") tools_used: list[str] = Field(default_factory=list, description="MCP tools used") duration_ms: float = Field(default=0, description="Total processing time in ms") class NearbyRequest(BaseModel): """Nearby places request model.""" lat: float = Field(..., description="Latitude", examples=[16.0626442]) lng: float = Field(..., description="Longitude", examples=[108.2462143]) max_distance_km: float = Field( default=5.0, description="Maximum distance in kilometers", examples=[5.0, 18.72], ) category: str | None = Field( None, description="Category filter (cafe, restaurant, attraction, etc.)", examples=["cafe", "restaurant"], ) limit: int = Field(default=10, description="Maximum results", examples=[10, 20]) class PlaceResponse(BaseModel): """Place response model.""" place_id: str name: str category: str | None = None lat: float | None = None lng: float | None = None distance_km: float | None = None rating: float | None = None description: str | None = None class NearbyResponse(BaseModel): """Nearby places response model.""" places: list[PlaceResponse] count: int query: dict class ClearHistoryRequest(BaseModel): """Clear history request model.""" user_id: str = Field(..., description="User ID to clear history for") session_id: str | None = Field( None, description="Session ID to clear (clears all if not provided)", ) class HistoryResponse(BaseModel): """Chat history response model.""" user_id: str sessions: list[str] current_session: str | None message_count: int class MessageItem(BaseModel): """Single chat message.""" role: str content: str timestamp: str class MessagesResponse(BaseModel): """Chat messages response.""" user_id: str session_id: str messages: list[MessageItem] count: int @router.post( "/nearby", response_model=NearbyResponse, summary="Find nearby places (Neo4j)", description=""" Find places near a given location using Neo4j spatial query. This endpoint directly tests the `find_nearby_places` MCP tool. ## Test Cases - Case 1: lat=16.0626442, lng=108.2462143, max_distance_km=18.72 - Case 2: lat=16.0623184, lng=108.2306049, max_distance_km=17.94 """, ) async def find_nearby(request: NearbyRequest) -> NearbyResponse: """ Find nearby places using Neo4j graph database. Directly calls the find_nearby_places MCP tool. """ places = await mcp_tools.find_nearby_places( lat=request.lat, lng=request.lng, max_distance_km=request.max_distance_km, category=request.category, limit=request.limit, ) return NearbyResponse( places=[ PlaceResponse( place_id=p.place_id, name=p.name, category=p.category, lat=p.lat, lng=p.lng, distance_km=p.distance_km, rating=p.rating, description=p.description, ) for p in places ], count=len(places), query={ "lat": request.lat, "lng": request.lng, "max_distance_km": request.max_distance_km, "category": request.category, }, ) async def enrich_places_from_ids(place_ids: list[str], db: AsyncSession) -> list[PlaceItem]: """ Enrich LLM-selected place_ids with full details from DB. Args: place_ids: List of place_ids selected by LLM in synthesis db: Database session Returns: List of PlaceItem with full details """ if not place_ids: return [] # Fetch full details from DB from sqlalchemy import text result = await db.execute( text(""" SELECT place_id, name, category, address, rating, ST_X(coordinates::geometry) as lng, ST_Y(coordinates::geometry) as lat FROM places_metadata WHERE place_id = ANY(:place_ids) """), {"place_ids": place_ids} ) rows = result.fetchall() # Build PlaceItem list preserving LLM order places_dict = {row.place_id: row for row in rows} places = [] for pid in place_ids: if pid in places_dict: row = places_dict[pid] places.append(PlaceItem( place_id=row.place_id, name=row.name, category=row.category, lat=row.lat, lng=row.lng, rating=float(row.rating) if row.rating else None, address=row.address, )) return places @router.post( "/chat", response_model=ChatResponse, summary="Chat with LocalMate Agent", description=""" Chat with the Multi-Modal Contextual Agent (MMCA). ## Session Management - Each user can have up to **3 sessions** stored - Provide `user_id` and optional `session_id` to maintain conversation history - History is automatically injected into the agent prompt ## LLM Providers - **Google**: Gemini models (gemini-2.0-flash, etc.) - **MegaLLM**: DeepSeek models (deepseek-r1-distill-llama-70b, etc.) ## Examples - "Tìm quán cafe gần bãi biển Mỹ Khê" - "Nhà hàng hải sản nào gần Cầu Rồng?" """, ) async def chat( request: ChatRequest, db: AsyncSession = Depends(get_db), ) -> ChatResponse: """ Chat endpoint with per-user history support. Send a natural language message, select provider and model. The agent will analyze your intent, query relevant data sources, and return a synthesized response with conversation context. """ # Determine model to use if request.model: model = request.model elif request.provider == LLMProvider.GOOGLE: model = settings.default_gemini_model else: model = settings.default_megallm_model # Get session ID session_id = request.session_id or "default" # Get conversation history for context history = chat_history.get_history( user_id=request.user_id, session_id=session_id, max_messages=6, # Last 3 exchanges (6 messages) ) # Add user message to history chat_history.add_message( user_id=request.user_id, role="user", content=request.message, session_id=session_id, ) # Choose agent based on react_mode if request.react_mode: # ReAct multi-step agent agent = ReActAgent( provider=request.provider.value, model=model, max_steps=request.max_steps, ) response_text, agent_state = await agent.run( query=request.message, db=db, image_url=request.image_url, history=history, ) # Convert state to workflow workflow = agent.to_workflow(agent_state) workflow_data = workflow.to_dict() # Add assistant response to history chat_history.add_message( user_id=request.user_id, role="assistant", content=response_text, session_id=session_id, ) # Enrich LLM-selected place_ids with DB data places = await enrich_places_from_ids(agent_state.selected_place_ids, db) return ChatResponse( response=response_text, status="success", provider=request.provider.value, model=model, user_id=request.user_id, session_id=session_id, places=places, tools_used=list(agent_state.context.keys()), duration_ms=agent_state.total_duration_ms, ) else: # Single-step agent (original behavior) agent = MMCAAgent(provider=request.provider.value, model=model) # Pass history to agent result = await agent.chat( message=request.message, db=db, image_url=request.image_url, history=history, ) # Add assistant response to history chat_history.add_message( user_id=request.user_id, role="assistant", content=result.response, session_id=session_id, ) # Use LLM-selected places (same pattern as ReAct mode) places = [] if result.selected_place_ids: places = await enrich_places_from_ids(result.selected_place_ids, db) # Add distance info if available from tool results distance_map = {} for tool_call in result.tool_results: if tool_call.result: for item in tool_call.result: if isinstance(item, dict) and 'place_id' in item and 'distance_km' in item: distance_map[item['place_id']] = item['distance_km'] for place in places: if place.place_id in distance_map: place.distance_km = distance_map[place.place_id] return ChatResponse( response=result.response, status="success", provider=request.provider.value, model=model, user_id=request.user_id, session_id=session_id, places=places, tools_used=result.tools_used, duration_ms=result.total_duration_ms, ) @router.post( "/chat/clear", summary="Clear chat history", description="Clears the conversation history for a specific user/session.", ) async def clear_chat(request: ClearHistoryRequest): """Clear conversation history for a user.""" if request.session_id: chat_history.clear_session(request.user_id, request.session_id) message = f"Session '{request.session_id}' cleared for user '{request.user_id}'" else: chat_history.clear_all_sessions(request.user_id) message = f"All sessions cleared for user '{request.user_id}'" return {"status": "success", "message": message} @router.get( "/chat/history/{user_id}", response_model=HistoryResponse, summary="Get chat history info", description="Get information about user's chat sessions.", ) async def get_history_info(user_id: str) -> HistoryResponse: """Get chat history information for a user.""" sessions = chat_history.get_session_ids(user_id) messages = chat_history.get_messages(user_id) return HistoryResponse( user_id=user_id, sessions=sessions, current_session="default" if "default" in sessions else (sessions[0] if sessions else None), message_count=len(messages), ) @router.get( "/chat/messages/{user_id}", response_model=MessagesResponse, summary="Get chat messages", description="Get actual chat messages from a specific session.", ) async def get_chat_messages( user_id: str, session_id: str = "default", ) -> MessagesResponse: """Get chat messages for a session.""" messages = chat_history.get_messages(user_id, session_id) return MessagesResponse( user_id=user_id, session_id=session_id, messages=[ MessageItem( role=m.role, content=m.content, timestamp=m.timestamp.isoformat(), ) for m in messages ], count=len(messages), ) class ImageSearchResult(BaseModel): """Image search result model.""" place_id: str name: str category: str | None = None rating: float | None = None similarity: float matched_images: int = 1 image_url: str | None = None class ImageSearchResponse(BaseModel): """Image search response model.""" results: list[ImageSearchResult] total: int @router.post( "/search/image", response_model=ImageSearchResponse, summary="Search places by image", description=""" Upload an image to find visually similar places. Uses image embeddings stored in Supabase pgvector. """, ) async def search_by_image( image: UploadFile = File(..., description="Image file to search"), limit: int = Query(10, ge=1, le=50, description="Maximum results"), db: AsyncSession = Depends(get_db), ) -> ImageSearchResponse: """ Search places by uploading an image. Uses visual embeddings to find similar places. """ try: # Read image bytes image_bytes = await image.read() if len(image_bytes) > 10 * 1024 * 1024: # 10MB limit raise HTTPException(status_code=400, detail="Image too large (max 10MB)") # Search using visual tool results = await mcp_tools.search_by_image_bytes( db=db, image_bytes=image_bytes, limit=limit, ) return ImageSearchResponse( results=[ ImageSearchResult( place_id=r.place_id, name=r.name, category=r.category, rating=r.rating, similarity=r.similarity, matched_images=r.matched_images, image_url=r.image_url, ) for r in results ], total=len(results), ) except HTTPException: raise except Exception as e: raise HTTPException(status_code=500, detail=f"Image search error: {str(e)}")